Source code for ls_mlkit.util.offload.saved_tensor_offload
import torch
from .graph_hook import OffloadSavedTensorHook
[docs]
class SavedTensorOffloadContext:
    """Context manager that installs ``OffloadSavedTensorHook`` on autograd.

    The instance wraps a ``torch.autograd.graph.saved_tensors_hooks`` object
    built with ``OffloadSavedTensorHook.pack`` as the pack hook and
    ``OffloadSavedTensorHook.unpack`` as the unpack hook. Entering and
    leaving this context simply enters and leaves that wrapped object.

    NOTE(review): what the pack/unpack hooks actually do with a tensor
    (presumably move it off the accelerator and bring it back) is defined
    in ``graph_hook`` -- confirm there.

    Example::

        with SavedTensorOffloadContext():
            loss = model(batch).sum()
        loss.backward()
    """

    def __init__(self) -> None:
        # Attribute name is kept as-is (camelCase) because it is publicly
        # reachable and external code may already reference it.
        self.savedTensorOffloadContext = torch.autograd.graph.saved_tensors_hooks(
            pack_hook=OffloadSavedTensorHook.pack,
            unpack_hook=OffloadSavedTensorHook.unpack,
        )

    def __enter__(self) -> "SavedTensorOffloadContext":
        """Enter the wrapped hooks context and return this instance.

        Returning ``self`` makes ``with SavedTensorOffloadContext() as ctx``
        bind ``ctx`` to the manager instead of ``None``.
        """
        self.savedTensorOffloadContext.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Leave the wrapped hooks context, forwarding the exception info.

        Nothing is returned, so an exception raised inside the ``with``
        body is never suppressed by this wrapper.
        """
        self.savedTensorOffloadContext.__exit__(exc_type, exc_val, exc_tb)