ls_mlkit.util.observer module
- class ls_mlkit.util.observer.Observer(model: Module = None, optimizer: Optimizer = None, scheduler: LambdaLR = None, dataset: Dataset | Dataset = None, target_modules: List[str] = None, no_split_classes: List[str] = None)
Bases: object
- function_mapping = {'gradient_norm': <function gradient_norm_fn>, 'gradients': <function gradients_fn>, 'weight_norm': <function weight_norm_fn>, 'weights': <function weights_fn>}
- get_gradient_norm(strategy: Literal['all', 'block'] = 'all', no_split_classes: List[str] = None)
- get_gradients(strategy: Literal['all', 'block'] = 'all', no_split_classes: List[str] = None)
- get_something(name, strategy: Literal['all', 'block'] = 'all', no_split_classes: List[str] = None)
- get_statistics(name, strategy: Literal['all', 'block'] = 'all', no_split_classes: List[str] = None)
- get_weight_norm(strategy: Literal['all', 'block'] = 'all', no_split_classes: List[str] = None)
- log_distribution(name, bins: int = 16, section='observer/', strategy: Literal['all', 'block'] = 'all', no_split_classes: List[str] = None, desc='')
- log_gradient_norm(section='observer/', strategy: Literal['all', 'block', 'both'] = 'both', no_split_classes: List[str] = None, desc='gradient_norm/')
- log_gradients_distribution(bins: int = 16, section='observer/', strategy: Literal['all', 'block', 'both'] = 'both', no_split_classes: List[str] = None, desc='')
- log_statistics(strategy: Literal['all', 'block', 'both'] = 'both', no_split_classes: List[str] = None, section='statistics/')
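The `function_mapping` attribute pairs quantity names with module-level functions, and `get_something(name, ...)` takes such a name. A minimal sketch of that kind of name-based dispatch follows. It is an assumption about the pattern, not the library's implementation: `TOY_MAPPING`, `lookup_quantity`, and the `toy_*` functions are hypothetical, and a plain list of flattened gradients stands in for a real `Module`.

```python
import math
from typing import Callable, Dict, List

# Hypothetical stand-in for a module: one flat list of gradient values per parameter.
ToyModule = List[List[float]]


def toy_gradients_fn(module: ToyModule) -> List[List[float]]:
    """Return a copy of every parameter's gradient values."""
    return [list(grad) for grad in module]


def toy_gradient_norm_fn(module: ToyModule) -> float:
    """L2 norm over all gradient values (assumed norm type)."""
    return math.sqrt(sum(v * v for grad in module for v in grad))


# Name -> function registry, shaped like `function_mapping` in the class above.
TOY_MAPPING: Dict[str, Callable[[ToyModule], object]] = {
    "gradients": toy_gradients_fn,
    "gradient_norm": toy_gradient_norm_fn,
}


def lookup_quantity(module: ToyModule, name: str) -> object:
    """Look up `name` in the registry and apply the matching function."""
    if name not in TOY_MAPPING:
        raise ValueError(f"unknown name {name!r}; expected one of {sorted(TOY_MAPPING)}")
    return TOY_MAPPING[name](module)


print(lookup_quantity([[3.0], [4.0]], "gradient_norm"))  # 5.0
```

Keeping the registry as a plain dictionary means a new quantity only needs one new entry, and an unknown name fails early with a readable error.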
- ls_mlkit.util.observer.gradient_norm_fn(module: Module)
Compute the gradient norm of a module.
- Parameters:
module (Module) – the module whose gradient norm is computed
- Returns:
the gradient norm of the module
- Return type:
float
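The entry above does not say which norm is used. As a minimal sketch, assuming a single L2 norm taken over all parameter gradients together (an assumption, not confirmed by the source), the computation reduces to the following. Plain Python lists stand in for gradient tensors, and `global_l2_norm` is a hypothetical name.

```python
import math
from typing import Iterable, Sequence


def global_l2_norm(grads: Iterable[Sequence[float]]) -> float:
    """Hypothetical helper: L2 norm over all gradient entries taken together.

    Each element of `grads` stands in for one parameter's flattened gradient.
    """
    total = 0.0
    for grad in grads:
        # Accumulate squared entries across every parameter.
        total += sum(value * value for value in grad)
    return math.sqrt(total)


# Two "parameters" with flattened gradients [3, 0] and [4].
norm = global_l2_norm([[3.0, 0.0], [4.0]])
print(norm)  # 5.0
```

Summing squares across parameters before taking the square root gives the same value as taking the L2 norm of the per-parameter L2 norms, which is why a single float can summarize a whole module.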
- ls_mlkit.util.observer.gradients_fn(module: Module) → list[Tensor]
Get the gradients of a module.
- Parameters:
module (Module) – the module whose gradients are returned
- Returns:
the gradients of the module
- Return type:
list[Tensor]
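A minimal sketch of what collecting a module's gradients could look like, assuming the function walks `module.parameters()` and keeps each `.grad` that is set (the source does not show the implementation). `collect_gradients`, `FakeParam`, and `FakeModule` are hypothetical stand-ins so the example runs without PyTorch; a `torch.nn.Module` exposes the same `parameters()` and `.grad` interface.

```python
from typing import Any, List, Optional


class FakeParam:
    """Hypothetical stand-in for a parameter; it only carries a .grad attribute."""

    def __init__(self, grad: Optional[List[float]]):
        self.grad = grad


class FakeModule:
    """Hypothetical stand-in exposing parameters() like torch.nn.Module."""

    def __init__(self, params: List[FakeParam]):
        self._params = params

    def parameters(self):
        return iter(self._params)


def collect_gradients(module: Any) -> list:
    """Return the gradients of all parameters that currently have one."""
    # A parameter with no gradient yet has .grad set to None; skip those.
    return [p.grad for p in module.parameters() if p.grad is not None]


module = FakeModule([FakeParam([0.1, -0.2]), FakeParam(None), FakeParam([0.3])])
grads = collect_gradients(module)
print(len(grads))  # 2
```

Skipping `None` entries keeps the returned list usable for norms or histograms even when some parameters are frozen or were not part of the last backward pass.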