Source code for ls_mlkit.util.observer

from typing import Callable, Dict, List, Literal

import numpy as np
import torch
import wandb
from datasets import Dataset as HFDataset
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import Dataset


def weight_norm_fn(module: Module) -> Tensor:
    """Compute the weight norm of a module.

    Args:
        module (Module): the module whose weight norm is computed.

    Returns:
        Tensor: a scalar tensor holding the L2 norm over all trainable weights.
    """
    return torch.sqrt(sum(torch.sum(p.data * p.data) for p in module.parameters() if p.requires_grad))


def gradient_norm_fn(module: Module) -> Tensor:
    """Compute the gradient norm of a module.

    Args:
        module (Module): the module whose gradient norm is computed.

    Returns:
        Tensor: a scalar tensor holding the L2 norm over all parameter gradients.
    """
    return torch.sqrt(sum(torch.sum(p.grad.data * p.grad.data) for p in module.parameters() if p.grad is not None))


def weights_fn(module: Module) -> List[Tensor]:
    """Get the weights of a module.

    Args:
        module (Module): the module whose weights are collected.

    Returns:
        List[Tensor]: detached CPU copies of the trainable parameters.
    """
    return [p.detach().cpu() for p in module.parameters() if p.requires_grad]


def gradients_fn(module: Module) -> List[Tensor]:
    """Get the gradients of a module.

    Args:
        module (Module): the module whose gradients are collected.

    Returns:
        List[Tensor]: detached CPU copies of the parameter gradients.
    """
    return [p.grad.detach().cpu() for p in module.parameters() if p.grad is not None]
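
# A minimal sketch of how the helpers above behave on a toy model; the
# two-layer net and the dummy backward pass below are illustrative
# assumptions, not part of this module.
def _helper_usage_sketch():
    net = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 2))
    net(torch.randn(3, 4)).sum().backward()  # populate .grad on every parameter
    print(weight_norm_fn(net))    # scalar tensor: L2 norm over all trainable weights
    print(gradient_norm_fn(net))  # scalar tensor: L2 norm over all gradients
    print(len(weights_fn(net)))   # 4: two weight matrices and two biases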
class Observer(object):
    function_mapping = {
        "weight_norm": weight_norm_fn,
        "gradient_norm": gradient_norm_fn,
        "weights": weights_fn,
        "gradients": gradients_fn,
    }

    def __init__(
        self,
        model: Module = None,
        optimizer: Optimizer = None,
        scheduler: LambdaLR = None,
        dataset: Dataset | HFDataset = None,
        target_modules: List[str] = None,
        no_split_classes: List[str] = None,
    ):
        """Initialize the Observer.

        Args:
            model (Module, optional): the model to observe. Defaults to None.
            optimizer (Optimizer, optional): the optimizer to observe. Defaults to None.
            scheduler (LambdaLR, optional): the scheduler to observe. Defaults to None.
            dataset (Dataset | HFDataset, optional): the dataset to observe. Defaults to None.
            target_modules (List[str], optional): the modules to observe. Defaults to None.
                If target_modules is not None, no_split_classes and strategy are ignored.
            no_split_classes (List[str], optional): names of module classes that are never
                split into sub-modules. Defaults to None.
        """
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.dataset = dataset
        self.no_split_classes = no_split_classes
        self.target_modules = target_modules

    # get something =================================================================
    @staticmethod
    @torch.no_grad()
    def _get_something(
        model: Module,
        strategy: Literal["all", "block"] = "all",
        no_split_classes: List[str] = None,
        function: Callable = None,
    ):
        info = dict()

        def __get_something(module: Module, prefix=""):
            # A module is treated as a leaf block if it has no children, or if its
            # class is listed in no_split_classes; it must own trainable parameters.
            if (
                len(list(module.named_children())) == 0
                or (no_split_classes is not None and module.__class__.__name__ in no_split_classes)
            ) and any(param.requires_grad for param in module.parameters()):
                info[prefix] = function(module)
                return
            for name, sub_module in module.named_children():
                sub_module_name = f"{prefix}.{name}" if prefix != "" else name
                __get_something(sub_module, sub_module_name)

        match strategy:
            case "all":
                something = function(model)
                return {"total_model": something}
            case "block":
                __get_something(model, "")
                return info
            case _:
                raise ValueError(f"Unsupported strategy: {strategy}")

    @staticmethod
    @torch.no_grad()
    def _get_target_modules(model: Module, target_modules: List[str]):
        info = dict()

        def __get_target_modules(module: Module, prefix=""):
            # Match by substring: a module is selected as soon as its dotted
            # path contains any of the target_modules patterns.
            if any(target_module in prefix for target_module in target_modules):
                info[prefix] = module
                return
            for name, sub_module in module.named_children():
                sub_module_name = f"{prefix}.{name}" if prefix != "" else name
                __get_target_modules(sub_module, sub_module_name)

        __get_target_modules(model, "")
        return info

    @staticmethod
    @torch.no_grad()
    def _get_something_from_targets(
        model: Module = None,
        target_modules_dict: Dict[str, Module] = None,
        target_modules: List[str] = None,
        function: Callable = None,
    ):
        info = dict()
        if target_modules_dict is None:
            target_modules_dict = Observer._get_target_modules(model, target_modules)
        for module_path, module in target_modules_dict.items():
            info[module_path] = function(module)
        return info
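    # An illustrative sketch of the substring matching (the transformer-style
    # path "layers.0.self_attn.q_proj" below is an assumption, not part of
    # this module):
    #
    #     Observer._get_target_modules(model, ["q_proj"])
    #     # -> {"layers.0.self_attn.q_proj": Linear(...), ...}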
    @torch.no_grad()
    def get_something_from_targets(self, function: Callable):
        return Observer._get_something_from_targets(
            model=self.model,
            target_modules_dict=None,
            target_modules=self.target_modules,
            function=function,
        )
    @torch.no_grad()
    def get_something(
        self,
        name,
        strategy: Literal["all", "block"] = "all",
        no_split_classes: List[str] = None,
    ):
        if self.target_modules is None:
            if no_split_classes is None:
                no_split_classes = self.no_split_classes
            return Observer._get_something(
                model=self.model,
                strategy=strategy,
                no_split_classes=no_split_classes,
                function=Observer.function_mapping[name],
            )
        return self.get_something_from_targets(function=Observer.function_mapping[name])
    @torch.no_grad()
    def get_weight_norm(
        self,
        strategy: Literal["all", "block"] = "all",
        no_split_classes: List[str] = None,
    ):
        return self.get_something("weight_norm", strategy, no_split_classes)

    @torch.no_grad()
    def get_gradient_norm(
        self,
        strategy: Literal["all", "block"] = "all",
        no_split_classes: List[str] = None,
    ):
        return self.get_something("gradient_norm", strategy, no_split_classes)

    @torch.no_grad()
    def get_weights(
        self,
        strategy: Literal["all", "block"] = "all",
        no_split_classes: List[str] = None,
    ):
        return self.get_something("weights", strategy, no_split_classes)

    @torch.no_grad()
    def get_gradients(
        self,
        strategy: Literal["all", "block"] = "all",
        no_split_classes: List[str] = None,
    ):
        return self.get_something("gradients", strategy, no_split_classes)
    @staticmethod
    @torch.no_grad()
    def _get_statistics(data: List[Tensor]):
        flattened_tensor = torch.cat([item.reshape(-1) for item in data], dim=0)
        mean = flattened_tensor.mean()
        std = flattened_tensor.std()
        median = flattened_tensor.median()
        var = flattened_tensor.var()
        return {"mean": mean, "std": std, "median": median, "variance": var}
    @torch.no_grad()
    def get_statistics(
        self,
        name,
        strategy: Literal["all", "block"] = "all",
        no_split_classes: List[str] = None,
    ):
        something = self.get_something(name, strategy=strategy, no_split_classes=no_split_classes)
        return {key: Observer._get_statistics(value) for key, value in something.items()}
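    # An illustrative sketch of the return shapes (assuming a model with
    # sub-modules "encoder" and "head", which are not part of this module):
    #
    #     obs = Observer(model=model)
    #     obs.get_weight_norm(strategy="block")
    #     # -> {"encoder": tensor(...), "head": tensor(...)}
    #     obs.get_statistics("weights", strategy="all")
    #     # -> {"total_model": {"mean": ..., "std": ..., "median": ..., "variance": ...}}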
    # log something =================================================================
    @torch.no_grad()
    def log_statistics(
        self,
        strategy: Literal["all", "block", "both"] = "both",
        no_split_classes: List[str] = None,
        section="statistics/",
    ):
        if self.target_modules is not None:
            strategy = "block"
        result = {}
        if strategy in ["all", "both"]:
            weights_statistics_total_model = self.get_statistics(
                name="weights", strategy="all", no_split_classes=no_split_classes
            )
            gradients_statistics_total_model = self.get_statistics(
                name="gradients", strategy="all", no_split_classes=no_split_classes
            )
            weights_statistics = {
                section + "weight/" + module_path + "/" + k: v
                for module_path, info in weights_statistics_total_model.items()
                for k, v in info.items()
                if k in ["mean", "std", "median"]
            }
            gradients_statistics = {
                section + "gradient/" + module_path + "/" + k: v
                for module_path, info in gradients_statistics_total_model.items()
                for k, v in info.items()
                if k in ["mean", "std", "median"]
            }
            result.update(weights_statistics)
            result.update(gradients_statistics)
        if strategy in ["block", "both"]:
            weights_statistics_block = self.get_statistics(
                name="weights", strategy="block", no_split_classes=no_split_classes
            )
            gradients_statistics_block = self.get_statistics(
                name="gradients", strategy="block", no_split_classes=no_split_classes
            )
            weights_statistics_block = {
                section + "weight/" + module_path + "/" + k: v
                for module_path, info in weights_statistics_block.items()
                for k, v in info.items()
                if k in ["mean", "std", "median"]
            }
            gradients_statistics_block = {
                section + "gradient/" + module_path + "/" + k: v
                for module_path, info in gradients_statistics_block.items()
                for k, v in info.items()
                if k in ["mean", "std", "median"]
            }
            result.update(weights_statistics_block)
            result.update(gradients_statistics_block)
        wandb.log(result)
    @staticmethod
    @torch.no_grad()
    def log_histograms(
        data: Dict[str, List[Tensor]],
        bins: int = 16,
        section="histogram/",
        prefix="",
    ):
        results = dict()
        for key, tensors in data.items():
            flattened_tensor = torch.cat([item.reshape(-1) for item in tensors], dim=0)
            flattened_numpy = flattened_tensor.numpy()
            np_histogram = np.histogram(flattened_numpy, bins=bins)
            wandb_histogram = wandb.Histogram(np_histogram=np_histogram)
            results[section + prefix + key] = wandb_histogram
        wandb.log(results)
    @torch.no_grad()
    def log_distribution(
        self,
        name,
        bins: int = 16,
        section="observer/",
        strategy: Literal["all", "block"] = "all",
        no_split_classes: List[str] = None,
        desc="",
    ):
        something = self.get_something(name, strategy=strategy, no_split_classes=no_split_classes)
        Observer.log_histograms(
            data=something,
            bins=bins,
            section=section,
            prefix=name + "/" + desc,
        )
    @torch.no_grad()
    def log_weights_distribution(
        self,
        bins: int = 16,
        section="observer/",
        strategy: Literal["all", "block", "both"] = "both",
        no_split_classes: List[str] = None,
        desc="",
    ):
        if self.target_modules is not None:
            strategy = "block"
        if strategy in ["all", "both"]:
            self.log_distribution(
                name="weights",
                bins=bins,
                section=section,
                strategy="all",
                no_split_classes=no_split_classes,
                desc=desc,
            )
        if strategy in ["block", "both"]:
            self.log_distribution(
                name="weights",
                bins=bins,
                section=section,
                strategy="block",
                no_split_classes=no_split_classes,
                desc=desc,
            )

    @torch.no_grad()
    def log_gradients_distribution(
        self,
        bins: int = 16,
        section="observer/",
        strategy: Literal["all", "block", "both"] = "both",
        no_split_classes: List[str] = None,
        desc="",
    ):
        if self.target_modules is not None:
            strategy = "block"
        if strategy in ["all", "both"]:
            self.log_distribution(
                name="gradients",
                bins=bins,
                section=section,
                strategy="all",
                no_split_classes=no_split_classes,
                desc=desc,
            )
        if strategy in ["block", "both"]:
            self.log_distribution(
                name="gradients",
                bins=bins,
                section=section,
                strategy="block",
                no_split_classes=no_split_classes,
                desc=desc,
            )
    def log_gradient_norm(
        self,
        section="observer/",
        strategy: Literal["all", "block", "both"] = "both",
        no_split_classes: List[str] = None,
        desc="gradient_norm/",
    ):
        if self.target_modules is not None:
            strategy = "block"
        results = dict()
        if strategy in ["all", "both"]:
            gradient_norm_total_model = self.get_something(
                "gradient_norm", strategy="all", no_split_classes=no_split_classes
            )
            results.update(gradient_norm_total_model)
        if strategy in ["block", "both"]:
            gradient_norm_block = self.get_something(
                "gradient_norm", strategy="block", no_split_classes=no_split_classes
            )
            results.update(gradient_norm_block)
        sectioned_results = {section + desc + key: value for key, value in results.items()}
        wandb.log(sectioned_results)
    def log_weight_norm(
        self,
        section="observer/",
        strategy: Literal["all", "block", "both"] = "both",
        desc="",
        no_split_classes: List[str] = None,
    ):
        if self.target_modules is not None:
            strategy = "block"
        results = dict()
        if strategy in ["all", "both"]:
            weight_norm_total_model = self.get_something(
                "weight_norm", strategy="all", no_split_classes=no_split_classes
            )
            results.update(weight_norm_total_model)
        if strategy in ["block", "both"]:
            weight_norm_block = self.get_something(
                "weight_norm", strategy="block", no_split_classes=no_split_classes
            )
            results.update(weight_norm_block)
        sectioned_results = {section + "weight_norm/" + desc + key: value for key, value in results.items()}
        wandb.log(sectioned_results)
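
# A minimal end-to-end sketch, assuming a toy two-layer model and an offline
# wandb run; none of the names below are part of this module.
if __name__ == "__main__":
    wandb.init(project="observer-demo", mode="offline")
    model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
    observer = Observer(model=model)
    model(torch.randn(16, 4)).pow(2).mean().backward()  # populate gradients
    observer.log_statistics()            # per-block and whole-model mean/std/median
    observer.log_weight_norm()           # keys under observer/weight_norm/
    observer.log_gradient_norm()         # keys under observer/gradient_norm/
    observer.log_weights_distribution()  # wandb histograms of the weights
    wandb.finish()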