import numpy as np
import torch
from .forward_hook import ForwardHookForDevice
from .graph_hook import OffloadSavedTensorHook
[docs]
class ForwardBackwardOffloadHookContext(ForwardHookForDevice):
mode = ("release",) # enum["debug","release"]
def __init__(
self,
model,
offload_proportion=0.5,
device="cuda",
no_split_module_classes=None,
with_backward_hook=False, # for debug
enable=True,
num_block: int = 2,
strategy="block", # enum["module","block"],
):
"""Offload model during forward and backward.
Args:
model (torch.nn.Module): The model to which the hook will be applied.
offload_proportion (float, optional): The proportion of activations to offload. Defaults to 0.5.
device (str, optional): The device to which activations will be offloaded. Defaults to "cuda".
no_split_module_classes (list of type, optional): List of module classes that should not be split during offloading. Defaults to None.
with_backward_hook (bool, optional): If True, enables the backward hook for debugging purposes. Defaults to False.
enable (bool, optional): If True, enables the hook. Defaults to True.
num_block (int, optional): The number of blocks to use when the strategy is set to "block". Defaults to 2.
strategy (str, optional): The offloading strategy to use. Options are "module" or "block". Defaults to "block".
"""
self.enable = enable
if not enable:
return
super().__init__()
self.strategy = strategy
self.num_block = num_block
self.device = device # computing device for offloaded modules
self.with_backward_hook = with_backward_hook
self.model = model
self.handle_list = list()
if no_split_module_classes is None:
no_split_module_classes = ["LlamaDecoderLayer", "GPT2TransformerBlock"]
self.module_list = ForwardHookForDevice.get_module_list(model, no_split_module_classes=no_split_module_classes)
if ForwardBackwardOffloadHookContext.mode == "debug":
print(f"module_list:{self.module_list}")
if self.strategy == "module":
self.offload_list = self.module_list[: int(offload_proportion * len(self.module_list))]
if ForwardBackwardOffloadHookContext.mode == "debug":
print(f"self.offload_list={self.offload_list}")
elif self.strategy == "block":
self.module_info = self.get_partition_block(self.module_list, self.num_block)
if ForwardBackwardOffloadHookContext.mode == "debug":
print(f"self.module_info={self.module_info}")
def __enter__(self):
"""
Register the hook in the appropriate module
"""
if not self.enable:
return
if ForwardBackwardOffloadHookContext.mode == "debug":
print("ForwardBackwardOffloadHookContext.__enter__(self):")
if self.strategy == "module":
self.register_forward_hook_by_module(self.model)
else:
self.register_hook_by_block(self.model)
def __exit__(self, exc_type, exc_val, exc_tb):
"""
remove hook registered by __enter__
"""
if not self.enable:
return
if ForwardBackwardOffloadHookContext.mode == "debug":
print("ForwardBackwardOffloadHookContext.__exit__(self, exc_type, exc_val, exc_tb)")
for handle in self.handle_list:
handle.remove()
[docs]
def register_hook_by_block(self, module: torch.nn.Module, parent_name=""):
if self.with_backward_hook and parent_name in self.module_list:
handle = module.register_full_backward_pre_hook(hook=self.get_backward_hook(pre=True))
self.handle_list.append(handle)
handle = module.register_full_backward_hook(hook=self.get_backward_hook(pre=False))
self.handle_list.append(handle)
if parent_name in self.module_list:
if ForwardBackwardOffloadHookContext.mode == "debug":
print(f"register_hook_by_block(self, module, parent_name={parent_name}")
# forward hook==============================================================
handle = module.register_forward_pre_hook(
hook=self.get_forward_hook_by_block(info=self.module_info[parent_name], pre=True, with_kwargs=True),
with_kwargs=True,
)
self.handle_list.append(handle)
handle = module.register_forward_hook(
hook=self.get_forward_hook_by_block(info=self.module_info[parent_name], pre=False, with_kwargs=True),
with_kwargs=True,
)
self.handle_list.append(handle)
# backward hook==============================================================
handle = module.register_full_backward_pre_hook(
hook=self.get_backward_hook_by_block(info=self.module_info[parent_name], pre=True)
)
self.handle_list.append(handle)
handle = module.register_full_backward_hook(
hook=self.get_backward_hook_by_block(info=self.module_info[parent_name], pre=False)
)
self.handle_list.append(handle)
return
for name, sub_module in module.named_children():
full_name = f"{parent_name}.{name}" if parent_name else name
self.register_hook_by_block(sub_module, full_name)
[docs]
@staticmethod
def get_forward_hook_by_block(info: dict, pre=True, device="cuda", with_kwargs=True):
if device is None:
device = "cuda"
offload_device = "cpu"
info["first_block_flag"]
last_block_flag = info["last_block_flag"]
first_module_flag = info["first_module_flag"]
info["last_module_flag"]
def pre_hook_with_kwargs(module, args, kwargs):
if ForwardBackwardOffloadHookContext.mode == "debug":
from .resource_monitor import show_gpu_and_cpu_memory
show_gpu_and_cpu_memory()
# model
module.to(device)
args = tuple(arg.to(device) if isinstance(arg, torch.Tensor) else arg for arg in args)
kwargs = {n: v.to(device) if isinstance(v, torch.Tensor) else v for n, v in kwargs.items()}
# saved_tensor,such as activations.
if not last_block_flag and first_module_flag:
if ForwardBackwardOffloadHookContext.mode == "debug":
print(f"set OffloadSavedTensorHook.offload_device = offload_device:{offload_device}")
OffloadSavedTensorHook.offload_device = offload_device
elif last_block_flag and first_module_flag:
if ForwardBackwardOffloadHookContext.mode == "debug":
print(f"set OffloadSavedTensorHook.offload_device = device:{device}")
OffloadSavedTensorHook.offload_device = device
return args, kwargs
def after_hook_with_kwargs(module, args, kwargs, output):
if not last_block_flag:
module.to(offload_device)
# output = output.to(offload_device) if isinstance(output, torch.Tensor) else output
# if isinstance(output, tuple):
# output = tuple(o.to(offload_device) if isinstance(o, torch.Tensor) else o for o in output)
elif last_block_flag:
module.to(device)
# output = output.to(device) if isinstance(output, torch.Tensor) else output
# if isinstance(output, tuple):
# output = tuple(o.to(device) if isinstance(o, torch.Tensor) else o for o in output)
return output
if pre:
return pre_hook_with_kwargs
else:
return after_hook_with_kwargs
[docs]
@staticmethod
def get_backward_hook_by_block(info: dict, pre=True, device="cuda"):
if device is None:
device = "cuda"
offload_device = "cpu"
first_block_flag = info["first_block_flag"]
info["last_block_flag"]
info["first_module_flag"]
info["last_module_flag"]
def pre_hook(module, grad_output):
module.to(device)
return grad_output
def after_hook(module, grad_input, grad_output):
if not first_block_flag:
module.to(offload_device)
else:
pass
return grad_input
if pre:
return pre_hook
else:
return after_hook
[docs]
@staticmethod
def get_backward_hook(pre=True):
def pre_hook(module, grad_output):
if ForwardBackwardOffloadHookContext.mode == "debug":
from .resource_monitor import show_gpu_and_cpu_memory
show_gpu_and_cpu_memory()
return grad_output
def after_hook(module, grad_input, grad_output):
if ForwardBackwardOffloadHookContext.mode == "debug":
from .resource_monitor import show_gpu_and_cpu_memory
show_gpu_and_cpu_memory()
return grad_input
if pre:
return pre_hook
else:
return after_hook
[docs]
def register_forward_hook_by_module(self, module: torch.nn.Module, parent_name=""):
if ForwardBackwardOffloadHookContext.mode == "debug":
print(f"register_forward_hook_by_module(self, module, parent_name={parent_name}")
if self.with_backward_hook and parent_name in self.module_list:
handle = module.register_full_backward_pre_hook(hook=self.get_backward_hook())
self.handle_list.append(handle)
handle = module.register_full_backward_hook(hook=self.get_backward_hook(pre=False))
self.handle_list.append(handle)
if parent_name in self.offload_list:
handle = module.register_forward_pre_hook(
self.get_forward_hook(pre=True, device=self.device, with_kwargs=True),
with_kwargs=True,
)
self.handle_list.append(handle)
handle = module.register_forward_hook(
self.get_forward_hook(pre=False, device=self.device, with_kwargs=True),
with_kwargs=True,
)
self.handle_list.append(handle)
return
elif parent_name in self.module_list:
handle = module.register_forward_pre_hook(
self.get_align_device_pre_forward_hook(device="cuda", with_kwargs=True),
with_kwargs=True,
)
self.handle_list.append(handle)
return
for name, sub_module in module.named_children():
full_name = f"{parent_name}.{name}" if parent_name else name
self.register_forward_hook_by_module(sub_module, full_name)
[docs]
@staticmethod
def get_partition_block(module_list: list, num_block: int) -> dict:
block_list = list()
module_groups = [list(e) for e in np.array_split(module_list, num_block)]
for i in range(num_block):
block = dict()
block["module_list"] = module_groups[i]
block["first_block_flag"] = True if i == 0 else False
block["last_block_flag"] = True if i == (num_block - 1) else False
block_list.append(block)
if ForwardBackwardOffloadHookContext.mode == "debug":
print(block_list)
module_info = dict()
for block in block_list:
n_module = len(block["module_list"])
for i in range(n_module):
module_name = block["module_list"][i]
module_info[module_name] = dict()
module_info[module_name].update(
{
"first_block_flag": block["first_block_flag"],
"last_block_flag": block["last_block_flag"],
"first_module_flag": True if i == 0 else False,
"last_module_flag": True if i == (n_module - 1) else False,
}
)
return module_info