ls_mlkit.optimizer.kfa.kfa module

class ls_mlkit.optimizer.kfa.kfa.KFA(model)[source]

Bases: object

static calculate_fisher_inverse_mult_V(cache: dict, a: Tensor, g: Tensor, V: Tensor)[source]

Expected shapes: $V$: $(\ldots, m, n)$; $a a^\top$: $(\ldots, m, m)$; $g g^\top$: $(\ldots, n, n)$.

eps = 0.001
static get_save_hook_for_a(cache: dict, module_dot_path: str, name: str = 'a')[source]
static get_save_hook_for_g(cache: dict, module_dot_path: str, name: str = 'g')[source]
known_modules = ['Linear']
register_save_hook(model)[source]
remove_hook()[source]