Source code for ls_mlkit.util.offload.gradient_offload

import torch


class GradientOffloadHookContext:
    def __init__(
        self,
        model: torch.nn.Module,
        record_dict: dict,
        enable: bool = True,
        *args,
        **kwargs,
    ):
        """Context manager that offloads gradients to the CPU during the backward pass.

        Args:
            model (torch.nn.Module): The model whose gradients will be offloaded.
            record_dict (dict): A dictionary that accumulates the offloaded
                gradients, keyed by parameter name (name -> grad).
            enable (bool, optional): If True, recorded gradients are moved to
                the CPU; if False, they are recorded on "cuda" instead.
                Defaults to True.
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.
        """
        if enable:
            self.gradient_device = "cpu"
        else:
            self.gradient_device = "cuda"
        self.handle_list = list()
        self.model = model
        self.record_dict = record_dict

    def __enter__(self):
        self.register_gradient_hook()

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Detach all hooks when leaving the context.
        for handle in self.handle_list:
            handle.remove()
    def register_gradient_hook(self):
        # Register the same sweeping hook on every trainable parameter:
        # whenever any parameter's gradient is computed, the hook moves all
        # gradients accumulated so far off the GPU.
        for _, param in self.model.named_parameters():
            if not param.requires_grad:
                # register_hook raises on tensors that do not require grad,
                # so frozen parameters are skipped.
                continue
            hook = param.register_hook(self.get_record_gradient_hook(self.model, self.record_dict))
            self.handle_list.append(hook)
    def get_record_gradient_hook(self, model, record_dict):
        def record_gradient_hook(grad):
            # Sweep every parameter whose .grad is already populated: move the
            # gradient to the target device, accumulate it into record_dict,
            # and free the on-device copy by clearing p.grad.
            for n, p in model.named_parameters():
                if p.requires_grad and p.grad is not None:
                    if n not in record_dict:
                        record_dict[n] = p.grad.to(self.gradient_device)
                    else:
                        record_dict[n] += p.grad.to(self.gradient_device)
                    p.grad = None
            return grad

        return record_gradient_hook
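
A minimal usage sketch (not part of the module): the toy linear model, input shapes, and the final manual sweep below are illustrative assumptions, and a CUDA device is assumed to be available.

import torch

model = torch.nn.Linear(1024, 1024).cuda()
record_dict = {}  # will map parameter name -> gradient on the CPU

ctx = GradientOffloadHookContext(model=model, record_dict=record_dict, enable=True)
with ctx:
    x = torch.randn(8, 1024, device="cuda")
    loss = model(x).sum()
    loss.backward()

# Tensor hooks fire before a gradient is accumulated into .grad, so the last
# parameter reached during backward can be missed; one extra manual sweep
# with the same hook catches any leftover gradient.
ctx.get_record_gradient_hook(model, record_dict)(None)

for name, grad in record_dict.items():
    print(name, grad.device, tuple(grad.shape))

Because each hook sweeps every parameter, gradients are moved off the GPU as soon as they become available rather than only after the whole backward pass, which keeps peak GPU memory lower during backpropagation.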