Source code for ls_mlkit.optimizer.kfa.kfa_optimizer

import math
from dataclasses import dataclass, field
from typing import Literal

import torch
from torch.optim import Optimizer

from .kfa import KFA


@dataclass
class UserConfig:
    xi: float = field(default=1e-4)
    alpha: float = field(default=0.9)
    rho: float = field(default=0.1)
    rho_cov: float = field(default=0.1)


class KFAOptimizer(Optimizer):
    def __init__(
        self,
        params,
        base_optimizer: Optimizer,
        model: torch.nn.Module,
        user_config: UserConfig,
        kfa: KFA,
        **kwargs,
    ):
        rho = user_config.rho
        rho_cov = user_config.rho_cov
        defaults = dict(rho=rho, rho_cov=rho_cov, **kwargs)
        super().__init__(params, defaults)
        # Accept either an optimizer instance or an optimizer class; a class is
        # instantiated over this optimizer's param_groups with the same kwargs.
        if isinstance(base_optimizer, type):
            base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.base_optimizer = base_optimizer
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)
        self.model = model
        self.user_config = user_config
        self.kfa = kfa

    def set_closure(self, loss_fn, x, y, closure=None):
        if closure is not None:
            self.closure = closure
            return

        @torch.enable_grad()
        def _closure():
            self.zero_grad()
            o = self.model(x)
            loss = loss_fn(o, y)
            loss.backward()
            return loss.detach().clone().item()

        self.closure = _closure

    def step(self, closure=None):
        if closure is not None:
            self.closure = closure
        assert self.closure is not None, "closure is not set"
        with self.kfa:
            loss = self.closure()
            self.save_params("theta")
            self.save_grad("g1")
            self.perturb(name="grad")
            self.closure()
            self.save_grad("g2")
            self.save_moving_average_and_save_d(name="d", alpha=self.user_config.alpha)
            self.calculate_fisher_inverse_mult(tgt_key="inv_Fd", src_key="d")
            self.calculate_fisher_inverse_mult(tgt_key="inv_Fg1", src_key="g1")
            self.calculate_C_inverse_mult_d(tgt_key="inv_Cd")
            self.epsilon_perturb()
            self.closure()
            self.back_to("theta")
        self.base_optimizer.step()
        self.closure = None
        return loss

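    # In math terms, one step() evaluates (reading the calls above in order):
    #
    #     g1 = \nabla L(\theta)
    #     \theta' = \theta + \rho * g1 / (\|g1\| + \xi)            (perturb)
    #     g2 = \nabla L(\theta'),   m <- \alpha m + (1 - \alpha)(g2 - g1)
    #     d = (g2 - g1) - m
    #     \epsilon = C^{-1} d / \sqrt{d^T C^{-1} d}
    #     final gradient = \nabla L(\theta + \rho_cov * \epsilon)
    #
    # after which the parameters are restored to \theta and the base optimizer
    # applies its update using that final gradient.
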
    @torch.no_grad()
    def epsilon_perturb(self):
        # Accumulate the quadratic form d^T (C^{-1} d) across all parameters.
        scale = 0.0
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                scale += torch.sum(self.state[p]["d"] * self.state[p]["inv_Cd"])
        scale = 1 / torch.sqrt(scale)
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                self.state[p]["epsilon"] = scale * self.state[p]["inv_Cd"]
        # Restore the saved parameters before applying the perturbation.
        self.back_to("theta")
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                p.data = p.data + group["rho_cov"] * self.state[p]["epsilon"]

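    # When inv_Cd is exact, epsilon = C^{-1} d / sqrt(d^T C^{-1} d) satisfies
    # epsilon^T C epsilon = 1, so rho_cov sets the perturbation radius in the
    # C-metric. The back_to("theta") call means the perturbation is applied from
    # the saved parameters, not from the gradient-perturbed point.
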
    @torch.no_grad()
    def calculate_C_inverse_mult_d(self, tgt_key: str = "inv_Cd"):
        numerator, denominator = 0.0, 0.0
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                numerator += torch.sum(self.state[p]["g1"] * self.state[p]["inv_Fd"])
                denominator += torch.sum(self.state[p]["g1"] * self.state[p]["inv_Fg1"])
        denominator = 1 - denominator
        scale = numerator / denominator
        # Rank-one correction of F^{-1} d in the direction F^{-1} g1.
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                self.state[p][tgt_key] = self.state[p]["inv_Fd"] + scale * self.state[p]["inv_Fg1"]

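    # The update above matches the Sherman-Morrison identity
    #
    #     (F - g1 g1^T)^{-1} d
    #         = F^{-1} d + (g1^T F^{-1} d) / (1 - g1^T F^{-1} g1) * F^{-1} g1
    #
    # which suggests C is implicitly F - g1 g1^T, with each F^{-1} v computed by
    # the Kronecker-factored routine below.
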
    @torch.no_grad()
    def calculate_fisher_inverse_mult(self, tgt_key: str, src_key: str):
        def _calculate_fisher_inverse_mult(module: torch.nn.Module, prefix: str):
            if module.__class__.__name__ in self.kfa.known_modules:
                match module.__class__.__name__:
                    case "Linear":
                        # Compute $F^{-1} d$ for this layer's weight and
                        # (optionally) bias.
                        if not module.weight.requires_grad:
                            return
                        self.state[module.weight][tgt_key] = KFA.calculate_fisher_inverse_mult_V(
                            cache=self.kfa.cache,
                            a=self.kfa.cache[module.weight]["a"],
                            g=self.kfa.cache[module.weight]["g"],
                            V=self.state[module.weight][src_key],
                        )
                        if module.bias is not None:
                            self.state[module.bias][tgt_key] = (
                                KFA.calculate_fisher_inverse_mult_V(
                                    cache=self.kfa.cache,
                                    a=self.kfa.cache[module.bias]["a"],
                                    g=self.kfa.cache[module.bias]["g"],
                                    V=self.state[module.bias][src_key].unsqueeze(-2).transpose(-2, -1),
                                )
                                .transpose(-2, -1)
                                .squeeze(-2)
                            )
                    case _:
                        raise ValueError(f"unknown module: {module.__class__.__name__}")
                return
            for name, sub_module in module.named_children():
                # Build the child's qualified name without mutating `prefix`,
                # so sibling modules receive correct prefixes.
                child_prefix = f"{prefix}.{name}" if prefix else name
                _calculate_fisher_inverse_mult(sub_module, child_prefix)

        _calculate_fisher_inverse_mult(self.model, "")

    @torch.no_grad()
    def back_to(self, name):
        # Restore parameters from the tensors saved under `name` (e.g. "theta").
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                p.data = self.state[p][name]

    @torch.no_grad()
    def perturb(self, name: Literal["grad", "state"] = "grad", **kwargs):
        # Norm of the perturbation direction; xi keeps the division stable.
        # Computed for both name="grad" and name="state" (the original only set
        # it for "grad", leaving it undefined on the "state" path).
        grad_norm = self.get_something_norm(name, **kwargs) + self.user_config.xi
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                if name == "grad":
                    grad = p.grad
                elif name == "state":
                    key = kwargs.get("key", None)
                    assert key is not None, "name='state' requires a 'key' kwarg"
                    grad = self.state[p][key]
                p.data = p.data + group["rho"] * grad / grad_norm

    @torch.no_grad()
    def save_moving_average_and_save_d(self, name: str, alpha: float = 0.9):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                m = self.state[p]["g2"] - self.state[p]["g1"]
                # Exponential moving average of m, stored under `name`. Note
                # that step() passes name="d", so this buffer shares its key
                # with the difference saved below.
                if self.state[p].get(name, None) is None:
                    self.state[p][name] = torch.zeros_like(m)
                else:
                    self.state[p][name] = alpha * self.state[p][name] + (1 - alpha) * m
                self.state[p]["d"] = m - self.state[p][name]

    @torch.no_grad()
    def save_grad(self, name: str):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    self.state[p][name] = p.grad.detach().clone()

    @torch.no_grad()
    def save_params(self, name: str):
        for group in self.param_groups:
            for p in group["params"]:
                self.state[p][name] = p.data.detach().clone()

    @torch.no_grad()
    def get_something_norm(self, something_name: Literal["grad", "state", "weight"] = "grad", **kwargs):
        assert something_name in (
            "grad",
            "state",
            "weight",
        ), f"something_name must be one of ['grad', 'state', 'weight'], got {something_name!r}"
        something_norm = 0.0
        if something_name == "grad":
            for group in self.param_groups:
                for p in group["params"]:
                    if p.grad is None:
                        continue
                    grad = (torch.abs(p) if group.get("adaptive", None) else 1.0) * p.grad
                    something_norm += torch.sum(grad * grad).item()
        elif something_name == "state":
            key = kwargs.get("key", None)
            assert key is not None, "something_name='state' requires a 'key' kwarg"
            for group in self.param_groups:
                for p in group["params"]:
                    if p.grad is None:
                        continue
                    if self.state[p].get(key, None) is None:
                        continue
                    grad = self.state[p][key]
                    something_norm += torch.sum(grad * grad).item()
        elif something_name == "weight":
            for group in self.param_groups:
                for p in group["params"]:
                    if p.grad is None:
                        continue
                    something_norm += torch.sum(p.data * p.data).item()
        return math.sqrt(something_norm)

    def zero_grad(self):
        self.base_optimizer.zero_grad()

    def state_dict(self):
        return self.base_optimizer.state_dict()

    def load_state_dict(self, state_dict):
        self.base_optimizer.load_state_dict(state_dict)

    def __repr__(self):
        return f"KFAOptimizer({self.base_optimizer.__class__.__name__})"
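
# A minimal usage sketch (illustrative only, not part of the module above): the
# KFA constructor call is an assumption, and `model`, `loader`, and `loss_fn`
# are hypothetical placeholders.
#
#     model = torch.nn.Sequential(torch.nn.Linear(8, 2))
#     kfa = KFA(model)  # assumed constructor signature
#     optimizer = KFAOptimizer(
#         model.parameters(),
#         base_optimizer=torch.optim.SGD,  # a class is instantiated internally
#         model=model,
#         user_config=UserConfig(rho=0.05, rho_cov=0.05),
#         kfa=kfa,
#         lr=1e-2,  # extra kwargs are forwarded to the base optimizer
#     )
#     for x, y in loader:
#         optimizer.set_closure(loss_fn, x, y)
#         loss = optimizer.step()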