Source code for ls_mlkit.util.offload.gradient_offload_v2
import torch.nn
def get_offload_grad_hook(
    offload_attr_name="offloaded_grad", offload_device="cpu", init_iters="init_iters", *args, **kwargs
):
    """Build a post-accumulate-grad hook that moves a parameter's gradient to ``offload_device``.

    The gradient is accumulated into the attribute named by ``offload_attr_name`` on the
    parameter, and the attribute named by ``init_iters`` counts how many times the hook has fired.
    """

    def offload_grad_hook(x):
        if x.grad is not None:
            grad = x.grad
            x.grad = None  # free the on-device gradient
            if not hasattr(x, offload_attr_name):
                # First accumulation: move the gradient to the offload device.
                setattr(x, offload_attr_name, grad.to(offload_device))
                setattr(x, init_iters, 0)
            else:
                # Subsequent accumulations: add onto the offloaded gradient.
                accumulated_grad = getattr(x, offload_attr_name)
                setattr(x, offload_attr_name, accumulated_grad + grad.to(offload_device))
                init_niter = getattr(x, init_iters)
                setattr(x, init_iters, init_niter + 1)

    return offload_grad_hook
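
A minimal usage sketch (not part of the module): the hook factory above can be attached directly via ``Tensor.register_post_accumulate_grad_hook``, available in PyTorch 2.1 and later. The toy model below is a placeholder.

# Usage sketch (illustrative only, assumes PyTorch >= 2.1):
import torch

model = torch.nn.Linear(8, 2)
hook = get_offload_grad_hook(offload_attr_name="offloaded_grad", offload_device="cpu")
handles = [p.register_post_accumulate_grad_hook(hook) for p in model.parameters()]

loss = model(torch.randn(4, 8)).sum()
loss.backward()  # each p.grad is moved to p.offloaded_grad on the CPU and p.grad is cleared

for h in handles:
    h.remove()  # detach the hooks when done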
class GradientOffloadHookContext:
    """Context manager that registers gradient-offload hooks on every parameter of ``model``."""

    def __init__(self, model: torch.nn.Module, enable: bool, *args, **kwargs):
        self.enable = enable
        self.handle_list = []
        if not enable:
            return
        self.model = model
        self.offload_attr_name = "offloaded_grad"
        self.offload_device = "cpu"

    def __enter__(self):
        if not self.enable:
            return self
        for n, p in self.model.named_parameters():
            handle = p.register_post_accumulate_grad_hook(
                hook=GradientOffloadHookContext.get_offload_grad_hook(
                    offload_attr_name=self.offload_attr_name, offload_device=self.offload_device
                )
            )
            self.handle_list.append(handle)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self.enable:
            return
        # Detach all hooks registered in __enter__.
        for handle in self.handle_list:
            handle.remove()
    @staticmethod
    def get_offload_grad_hook(
        offload_attr_name="offloaded_grad", offload_device="cpu", init_iters="init_iters", *args, **kwargs
    ):
        """Static counterpart of the module-level ``get_offload_grad_hook`` factory."""

        def offload_grad_hook(x):
            if x.grad is not None:
                grad = x.grad
                x.grad = None  # free the on-device gradient
                if not hasattr(x, offload_attr_name):
                    # First accumulation: move the gradient to the offload device.
                    setattr(x, offload_attr_name, grad.to(offload_device))
                    setattr(x, init_iters, 0)
                else:
                    # Subsequent accumulations: add onto the offloaded gradient.
                    accumulated_grad = getattr(x, offload_attr_name)
                    setattr(x, offload_attr_name, accumulated_grad + grad.to(offload_device))
                    init_niter = getattr(x, init_iters)
                    setattr(x, init_iters, init_niter + 1)

        return offload_grad_hook
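
A minimal training-loop sketch (not part of the module) showing how ``GradientOffloadHookContext`` could wrap a backward pass; the toy model, data, and learning rate are placeholders.

# Usage sketch (illustrative only):
import torch

model = torch.nn.Linear(8, 2)
inputs, targets = torch.randn(4, 8), torch.randn(4, 2)

with GradientOffloadHookContext(model=model, enable=True):
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    loss.backward()  # per-parameter gradients end up in p.offloaded_grad on the CPU

# The offloaded gradients can then be read back, e.g. for a manual SGD-style update:
for p in model.parameters():
    if hasattr(p, "offloaded_grad"):
        p.data -= 1e-2 * p.offloaded_grad.to(p.device)
        delattr(p, "offloaded_grad")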