Source code for ls_mlkit.util.offload.split

import torch.nn


def get_model_memory(model: torch.nn.Module, forward_factor: float = 1.3) -> float:
    """Estimate the memory usage of a model in gibibytes (GiB).

    Args:
        model (torch.nn.Module): The model whose memory usage is to be estimated.
        forward_factor (float, optional): A factor to account for additional memory
            used during the forward pass. Defaults to 1.3.

    Returns:
        float: The estimated memory usage of the model in gibibytes.
    """
    total = 0
    for p in model.parameters():
        total += p.numel() * p.element_size()
    return forward_factor * total / 1024**3
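A minimal usage sketch (not part of the module): assuming only torch is installed, a 1000x1000 linear layer with bias holds 1,001,000 fp32 parameters at 4 bytes each, so the estimate is 1.3 * 1,001,000 * 4 / 1024**3, roughly 0.0048 GiB.

import torch

model = torch.nn.Linear(1000, 1000)  # 1,001,000 fp32 parameters
mem_gib = get_model_memory(model)    # uses the default forward_factor of 1.3
print(f"estimated memory: {mem_gib:.4f} GiB")  # ~0.0048 GiB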
def get_split_num(origin_type: str = "bf16", quant_type: str = "int8") -> int:
    """Compute the ratio of the original dtype width to the quantized dtype width.

    Args:
        origin_type (str, optional): The data type of the original tensor.
            Defaults to "bf16". Options are "fp32" and "bf16".
        quant_type (str, optional): The data type of the quantized tensor.
            Defaults to "int8". Options are "int8" and "nf4".

    Raises:
        ValueError: If origin_type is not "fp32" or "bf16".
        ValueError: If quant_type is not "int8" or "nf4".

    Returns:
        int: The ratio of the original type width to the quantized type width.
    """
    # Widths are in bits: fp32 = 32, bf16 = 16, int8 = 8, nf4 = 4.
    match origin_type:
        case "fp32":
            n_origin_bits = 32
        case "bf16":
            n_origin_bits = 16
        case _:
            raise ValueError(f"Unsupported origin_type: {origin_type}")
    match quant_type:
        case "int8":
            n_quant_bits = 8
        case "nf4":
            n_quant_bits = 4
        case _:
            raise ValueError(f"Unsupported quant_type: {quant_type}")
    return n_origin_bits // n_quant_bits
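A short illustration of the ratio (not part of the module): the result is simply the integer quotient of the two bit widths, so bf16 parameters quantized to nf4 give a split factor of 16 // 4 == 4.

print(get_split_num())               # bf16 -> int8: 16 // 8 == 2
print(get_split_num("fp32", "nf4"))  # fp32 -> nf4: 32 // 4 == 8
print(get_split_num("bf16", "nf4"))  # bf16 -> nf4: 16 // 4 == 4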