Source code for ls_mlkit.util.offload.model_offload

from .forward_backward_offload import ForwardBackwardOffloadHookContext
from .saved_tensor_offload import SavedTensorOffloadContext


class ModelOffloadHookContext:
    def __init__(
        self,
        model,
        no_split_module_classes=None,
        num_block: int = 2,
        enable=True,
        # =========================
        device="cuda",
        strategy="block",
        with_backward_hook=False,
    ):
        """Initialize the ModelOffloadHookContext, which manages offloading of
        model computations and saved tensors.

        Args:
            model (torch.nn.Module): The model to which the hooks will be applied.
            no_split_module_classes (list of type, optional): Module classes that
                should not be split during offloading. Defaults to None.
            num_block (int, optional): The number of blocks to use when the
                strategy is "block". Defaults to 2.
            enable (bool, optional): If True, enables the hooks. Defaults to True.
            device (str, optional): The device to which activations and gradients
                will be offloaded. Defaults to "cuda".
            strategy (str, optional): The offloading strategy to use, either
                "module" or "block". Defaults to "block".
            with_backward_hook (bool, optional): If True, enables the backward
                hook for debugging purposes. Defaults to False.
        """
        self.enable = enable
        if not enable:
            return
        self.forwardBackwardOffloadHookContext = ForwardBackwardOffloadHookContext(
            model=model,
            device=device,
            no_split_module_classes=no_split_module_classes,
            with_backward_hook=with_backward_hook,  # for debugging
            enable=True,
            num_block=num_block,
            strategy=strategy,  # one of "module" or "block"
        )
        self.savedTensorOffloadContext = SavedTensorOffloadContext()

    def __enter__(self):
        if not self.enable:
            return
        self.forwardBackwardOffloadHookContext.__enter__()
        self.savedTensorOffloadContext.__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self.enable:
            return
        self.forwardBackwardOffloadHookContext.__exit__(exc_type, exc_val, exc_tb)
        self.savedTensorOffloadContext.__exit__(exc_type, exc_val, exc_tb)
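
A minimal usage sketch follows. The toy model, optimizer, and batch are illustrative assumptions, not part of this module; the context is entered around the forward and backward passes so the offload hooks installed in __enter__ are active while they run (teardown on exit is delegated to the two wrapped sub-contexts).

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 1)).to("cuda")
optimizer = torch.optim.AdamW(model.parameters())
inputs = torch.randn(8, 16, device="cuda")
targets = torch.randn(8, 1, device="cuda")

# Passing enable=False turns the context into a no-op (see __init__ above),
# which makes it easy to A/B-test runs with and without offloading.
with ModelOffloadHookContext(
    model=model,
    no_split_module_classes=None,
    num_block=2,
    enable=True,
    device="cuda",
    strategy="block",  # or "module"
):
    loss = nn.functional.mse_loss(model(inputs), targets)
    loss.backward()

optimizer.step()
optimizer.zero_grad()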