ls_mlkit.util.offload.forward_hook module
- class ls_mlkit.util.offload.forward_hook.ForwardHookForDevice
Bases: object
- static get_align_device_pre_forward_hook(device='cuda', with_kwargs=False)
Ensures that the module and its inputs are on the same device before the forward pass.
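A minimal sketch of a hook with this purpose, assuming it is registered through PyTorch's `nn.Module.register_forward_pre_hook(hook, with_kwargs=...)`. The names `move_to_device` and `make_align_device_pre_forward_hook` are hypothetical, and moving the module as well as its inputs to `device` is an assumption about what "same device" means here; this is not the library's implementation. The code does not import torch: it moves any object that exposes a `.to` method, which covers tensors and modules.

```python
from typing import Any, Callable


def move_to_device(obj: Any, device: Any) -> Any:
    """Hypothetical helper: recursively call .to(device) on tensors in nested containers."""
    if isinstance(obj, tuple):
        return tuple(move_to_device(o, device) for o in obj)
    if isinstance(obj, list):
        return [move_to_device(o, device) for o in obj]
    if isinstance(obj, dict):
        return {k: move_to_device(v, device) for k, v in obj.items()}
    if hasattr(obj, "to"):
        # torch.Tensor (or anything else exposing .to) is moved to the target device.
        return obj.to(device)
    # Non-tensor values (ints, strings, None, ...) pass through unchanged.
    return obj


def make_align_device_pre_forward_hook(device: Any = "cuda", with_kwargs: bool = False) -> Callable:
    """Hypothetical sketch: build a hook for nn.Module.register_forward_pre_hook."""
    if with_kwargs:
        def hook(module, args, kwargs):
            module.to(device)  # assumption: the module itself is moved too
            # With with_kwargs=True, PyTorch expects an (args, kwargs) tuple back.
            return move_to_device(args, device), move_to_device(kwargs, device)
    else:
        def hook(module, args):
            module.to(device)  # assumption: the module itself is moved too
            # Returning a tuple replaces the positional inputs.
            return move_to_device(args, device)
    return hook
```

In PyTorch this would be registered as `model.register_forward_pre_hook(make_align_device_pre_forward_hook("cuda", with_kwargs=True), with_kwargs=True)`. The two `with_kwargs` flags must agree, because PyTorch calls a pre-hook as `hook(module, args)` or `hook(module, args, kwargs)` depending on that flag.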
- static get_forward_hook(pre: bool, device=None, with_kwargs=False)
`device` is the device on which the forward pass executes; `origin_device` is the device where tensors are kept after the forward pass.
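One plausible reading of this pair of hooks, sketched under assumptions: with `pre=True` the hook records the module's current device as `origin_device` and moves the module to the executing `device`; with `pre=False` it moves the module back to `origin_device` after the forward pass, which is the usual weight-offloading pattern. The docstring could also mean that output tensors are moved to `origin_device`; this sketch assumes it refers to the module's weights. `make_forward_hook`, `_module_device`, and the `_origin_device` attribute are hypothetical names, and `device=None` is assumed to mean "leave the module where it is".

```python
from typing import Any, Callable, Optional


def _module_device(module) -> Optional[Any]:
    """Hypothetical helper: device of the module's first parameter, or None if it has none."""
    for p in module.parameters():
        return p.device
    return None


def make_forward_hook(pre: bool, device: Optional[Any] = None, with_kwargs: bool = False) -> Callable:
    """Hypothetical sketch of an offloading pre-/post-forward hook pair."""
    if pre:
        def enter(module):
            # Remember where the weights live so the post-hook can restore them.
            module._origin_device = _module_device(module)
            if device is not None:
                module.to(device)  # move weights to the executing device

        if with_kwargs:
            def hook(module, args, kwargs):
                enter(module)  # returning None keeps args/kwargs unchanged
        else:
            def hook(module, args):
                enter(module)  # returning None keeps args unchanged
        return hook

    def leave(module):
        origin = getattr(module, "_origin_device", None)
        if origin is not None:
            module.to(origin)  # move weights back to origin_device

    if with_kwargs:
        def hook(module, args, kwargs, output):
            leave(module)  # returning None keeps the output unchanged
    else:
        def hook(module, args, output):
            leave(module)  # returning None keeps the output unchanged
    return hook
```

The `pre=True` hook would go to `register_forward_pre_hook` and the `pre=False` hook to `register_forward_hook`, each registered with the same `with_kwargs` value that was passed to the factory. Both hooks return `None`, which tells PyTorch to keep the inputs and output unchanged.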