Source code for ls_mlkit.util.dequantize
import torch
from bitsandbytes.nn import Linear4bit, Linear8bitLt
from peft.tuners.lora.bnb import Linear4bit as LoraLinear4bit
from peft.tuners.lora.bnb import Linear8bitLt as LoraLinear8bitLt
from torch.nn import Linear
def get_float_weight(model: torch.nn.Module):
    """Get the dequantized float weight of a linear layer.

    Args:
        model (torch.nn.Module): the (possibly quantized) linear layer to read the weight from

    Returns:
        torch.Tensor: the float weight matrix of shape (out_features, in_features)
    """
    model: torch.nn.Linear
    device = model.weight.device
    in_features = model.in_features
    with torch.no_grad():
        # A linear layer computes x @ W^T + b, so feeding the identity matrix
        # yields W^T + b; subtract the bias and transpose to recover W.
        I = torch.eye(in_features).to(device)
        w = model(I)
        if hasattr(model, "bias") and isinstance(model.bias, torch.Tensor):
            w -= model.bias
        w = torch.transpose(w, 0, 1)
    w.requires_grad = model.weight.requires_grad
    return w
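
# Minimal sanity-check sketch: for a plain torch.nn.Linear the identity probe
# above recovers the stored weight (up to floating-point rounding from adding
# and subtracting the bias).
#
#     layer = torch.nn.Linear(3, 4)
#     assert torch.allclose(get_float_weight(layer), layer.weight)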

def replace_module_with_linear(model: torch.nn.Module, target) -> None:
    """Recursively replace every ``target`` submodule with a plain linear layer.

    Args:
        model (torch.nn.Module): the model whose submodules are replaced in place
        target (type): the module class to replace

    Returns:
        None
    """
    for name, module in model.named_children():
        if isinstance(module, target):
            in_features = module.in_features
            out_features = module.out_features
            bias = module.bias is not None
            new_module = torch.nn.Linear(in_features, out_features, bias)
            with torch.no_grad():
                new_module.weight.data = get_float_weight(module).data
                if bias:
                    new_module.bias.data = module.bias.data
            setattr(model, name, new_module)
        else:
            replace_module_with_linear(module, target)
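
# Usage sketch (assumes bitsandbytes and a CUDA device are available): the
# replacement recurses through named_children, so quantized layers nested in
# containers are also swapped for plain torch.nn.Linear modules holding the
# dequantized weights.
#
#     seq = torch.nn.Sequential(Linear8bitLt(3, 4, has_fp16_weights=False)).to("cuda")
#     replace_module_with_linear(seq, Linear8bitLt)
#     assert type(seq[0]) is torch.nn.Linear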

def dequantize(model, dtype) -> None:
    """Dequantize a model in place.

    Args:
        model (torch.nn.Module): the model to dequantize
        dtype (str): the quantization type of the model, either "int8" or "nf4"

    Returns:
        None
    """
    if dtype == "int8":
        target = LoraLinear8bitLt
    elif dtype == "nf4":
        target = LoraLinear4bit
    else:
        raise ValueError(f"unsupported dtype: {dtype}")
    replace_module_with_linear(model=model, target=target)
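
# Usage sketch (assuming a PEFT model whose LoRA adapters wrap 8-bit
# bitsandbytes layers, i.e. peft.tuners.lora.bnb.Linear8bitLt): dequantize()
# swaps each such layer for a plain torch.nn.Linear. Because get_float_weight
# probes the layer's forward pass, the recovered weight reflects whatever that
# forward computes, including any active LoRA adapter.
#
#     peft_model = ...  # e.g. a model prepared with get_peft_model(...)
#     dequantize(peft_model, dtype="int8")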

class Config:
    in_features = 3
    out_features = 4
    device = "cuda"

def main():
    config = Config()

    print("get float weight")
    linear = Linear(config.in_features, config.out_features)
    print(linear.weight)
    w = get_float_weight(model=linear)
    print(w)

    print("int8 quant==============================================")
    m = Linear8bitLt(config.in_features, config.out_features, has_fp16_weights=False)
    m.load_state_dict(linear.state_dict())
    m.to(config.device)
    print(m.weight)
    w = get_float_weight(model=m)
    print(linear.weight)
    print(w)

    print("nf4 quant==============================================")
    m = Linear4bit(config.in_features, config.out_features)
    m.load_state_dict(linear.state_dict())
    m.to(config.device)
    print(m.weight)
    print(linear.weight)
    w = get_float_weight(model=m)
    print(w)


if __name__ == "__main__":
    main()