ls_mlkit.util.offload.forward_hook module

class ls_mlkit.util.offload.forward_hook.ForwardHookForDevice

Bases: object

static get_align_device_pre_forward_hook(device='cuda', with_kwargs=False)

Build a pre-forward hook that ensures the module and its inputs are on the same device (device, 'cuda' by default).
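A minimal sketch of such a pre-forward hook follows. It is not the library's implementation. The factory name make_align_device_pre_hook is hypothetical, and the sketch assumes that "same device" means moving both the module and every input that has a .to() method onto device. The two hook signatures match what torch.nn.Module.register_forward_pre_hook expects with and without with_kwargs=True (PyTorch 2.0 and later). Nested containers of tensors are not handled.

```python
def make_align_device_pre_hook(device="cuda", with_kwargs=False):
    """Hypothetical sketch: a pre-forward hook that puts the module and its
    inputs on the same device before forward() runs."""

    def _move(x):
        # Anything exposing .to() (e.g. torch.Tensor) is moved to `device`;
        # other values (ints, None, strings) pass through unchanged.
        return x.to(device) if hasattr(x, "to") else x

    if with_kwargs:
        # Signature used by register_forward_pre_hook(..., with_kwargs=True):
        # returning (args, kwargs) replaces the inputs of forward().
        def hook(module, args, kwargs):
            module.to(device)
            return tuple(_move(a) for a in args), {k: _move(v) for k, v in kwargs.items()}
    else:
        # Plain pre-hook: returning a tuple replaces the positional inputs.
        def hook(module, args):
            module.to(device)
            return tuple(_move(a) for a in args)

    return hook
```

With PyTorch, the hook would be attached with something like model.register_forward_pre_hook(make_align_device_pre_hook("cuda")).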

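static get_forward_hook(pre: bool, device=None, with_kwargs=False)

Build a forward hook; pre selects whether it runs before the forward pass (pre-forward hook) or after it. device is the execution device, and origin_device is the device where the tensors are stored after the forward pass.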
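The following sketch shows one way a pre/post hook pair could implement this offloading pattern. It makes several assumptions and is not the library's code. The factory name make_offload_hook and the attribute _offload_origin_device are hypothetical. The origin device is taken to be the device of the module's first parameter. The pre hook moves the module and its inputs to device, and the post hook moves the module back afterwards. The signatures follow register_forward_pre_hook and register_forward_hook in PyTorch 2.0 and later.

```python
def make_offload_hook(pre, device=None, with_kwargs=False):
    """Hypothetical sketch of an offload hook.

    pre=True:  before forward, remember the module's current (origin) device,
               then move the module and its inputs to the execution `device`.
    pre=False: after forward, move the module back to its origin device.
    """
    attr = "_offload_origin_device"  # hypothetical attribute name

    def _move(x):
        return x.to(device) if hasattr(x, "to") else x

    def _pre(module, args, kwargs):
        # Record where the weights currently live (None if no parameters).
        first = next(iter(module.parameters()), None)
        setattr(module, attr, first.device if first is not None else None)
        if device is not None:
            module.to(device)
            args = tuple(_move(a) for a in args)
            kwargs = {k: _move(v) for k, v in kwargs.items()}
        return args, kwargs

    def _post(module):
        origin = getattr(module, attr, None)
        if origin is not None:
            module.to(origin)
        # Returning None keeps the forward output unchanged.

    if pre:
        if with_kwargs:
            return lambda module, args, kwargs: _pre(module, args, kwargs)
        return lambda module, args: _pre(module, args, {})[0]
    if with_kwargs:
        return lambda module, args, kwargs, output: _post(module)
    return lambda module, args, output: _post(module)
```

A typical pairing registers make_offload_hook(True, device="cuda") with register_forward_pre_hook and make_offload_hook(False) with register_forward_hook on the same module. The weights then occupy the execution device only while that module's forward pass runs.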

static get_full_name_list(model)

Get the list of full (dotted) module names of the leaf nodes of the module tree.
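Here is a minimal sketch of collecting leaf-module names by recursing over named_children(), which torch.nn.Module provides. The function name leaf_module_names is hypothetical. Unlike named_modules(), this sketch does not deduplicate submodules that are shared between several parents.

```python
def leaf_module_names(model, prefix=""):
    """Hypothetical sketch: full dotted names of modules that have no children."""
    names = []
    for name, child in model.named_children():
        full = f"{prefix}.{name}" if prefix else name
        if next(iter(child.named_children()), None) is None:
            # No children: this is a leaf of the module tree.
            names.append(full)
        else:
            names.extend(leaf_module_names(child, full))
    return names
```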

static get_module_list(model, no_split_module_classes=None)

Get the module name list of the leaf nodes of the module tree, but stop recursing when a node whose class is listed in no_split_module_classes is reached, treating that node as a leaf.
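The variant below adds the no-split rule to the same recursive walk. It is a sketch under stated assumptions. The name module_names_with_no_split is hypothetical. Entries of no_split_module_classes may be given either as class names (strings, the convention Hugging Face Accelerate uses) or as classes. Matching is done on the exact class name, so subclasses are not matched automatically.

```python
def module_names_with_no_split(model, no_split_module_classes=None, prefix=""):
    """Hypothetical sketch: leaf-module names, but a module whose class is in
    no_split_module_classes is kept whole instead of being recursed into."""
    no_split = {c if isinstance(c, str) else c.__name__ for c in no_split_module_classes or ()}
    names = []
    for name, child in model.named_children():
        full = f"{prefix}.{name}" if prefix else name
        is_leaf = next(iter(child.named_children()), None) is None
        if is_leaf or type(child).__name__ in no_split:
            # Stop here: either a true leaf or a block that must not be split.
            names.append(full)
        else:
            names.extend(module_names_with_no_split(child, no_split, full))
    return names
```

Keeping blocks such as transformer layers whole is useful when whole layers, not their individual submodules, should be placed on or offloaded to a device.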