import torch
class KFA:
    """K-FAC style helper: hooks cache each known layer's input activation
    ``a`` and output gradient ``g``, keyed by the parameter they belong to."""

    known_modules = ["Linear"]
    eps = 1e-3  # damping added to the Kronecker factors before inversion

    def __init__(self, model):
        self.cache = {}
        self.handlers = []
        self.model = model

    def __enter__(self):
        self.register_save_hook(self.model)
        return self  # so ``with KFA(model) as kfa:`` exposes the cache

    def __exit__(self, exc_type, exc_value, traceback):
        self.remove_hook()
    @staticmethod
    def calculate_fisher_inverse_mult_V(cache: dict, a: torch.Tensor, g: torch.Tensor, V: torch.Tensor):
        """Apply the damped inverse Kronecker factors to ``V``.

        Shapes:
            a:   (..., m)    layer input
            g:   (..., n)    gradient w.r.t. layer output
            V:   (..., n, m) matrix to precondition (same shape as the weight)
            aaT: (..., m, m)
            ggT: (..., n, n)

        (``cache`` is currently unused.)
        """
        # Per-sample outer products a a^T and g g^T.
        aaT = torch.matmul(a.unsqueeze(-1), a.unsqueeze(-2))
        ggT = torch.matmul(g.unsqueeze(-1), g.unsqueeze(-2))
        # Average the factors over all leading (batch) dimensions.
        EaaT = torch.mean(aaT, dim=tuple(range(aaT.dim() - 2)))
        EggT = torch.mean(ggT, dim=tuple(range(ggT.dim() - 2)))
        # Damp with eps * I so the inverses stay well conditioned.
        inv_EaaT = torch.linalg.inv(EaaT + KFA.eps * torch.eye(EaaT.shape[-1], device=EaaT.device))
        inv_EggT = torch.linalg.inv(EggT + KFA.eps * torch.eye(EggT.shape[-1], device=EggT.device))
        result = torch.matmul(inv_EggT, V)
        result = torch.matmul(result, inv_EaaT)
        return result
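    # Why the matrix form above works (standard K-FAC reasoning, sketched here
    # as an assumption about this module's intent): the Fisher of a Linear
    # layer is approximated by the Kronecker product E[a a^T] ⊗ E[g g^T], and
    # the identity (A ⊗ G)^{-1} vec(V) = vec(G^{-1} V A^{-1}) lets the inverse
    # act on V directly, without materialising the (m*n) x (m*n) Fisher.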
    @staticmethod
    def get_save_hook_for_a(cache: dict, module_dot_path: str, name: str = "a"):
        def _save_hook_for_a(module: torch.nn.Module, args, kwargs):
            match module.__class__.__name__:
                case "Linear":
                    module: torch.nn.Linear
                    # Cache the layer input under the weight it feeds.
                    if cache.get(module.weight) is None:
                        cache[module.weight] = {}
                    cache[module.weight][name] = args[0]
                    if module.bias is not None:
                        if cache.get(module.bias) is None:
                            cache[module.bias] = {}
                        # The bias sees a constant input of 1, so its
                        # activation factor is the 1x1 identity.
                        cache[module.bias][name] = torch.eye(1, device=module.bias.device)
                case _:
                    pass

        return _save_hook_for_a
    @staticmethod
    def get_save_hook_for_g(cache: dict, module_dot_path: str, name: str = "g"):
        def _save_hook_for_g(module, grad_input, grad_output):
            match module.__class__.__name__:
                case "Linear":
                    module: torch.nn.Linear
                    if cache.get(module.weight) is None:
                        cache[module.weight] = {}
                    # Scale by the batch size to undo the 1/N factor that a
                    # mean-reduced loss puts on per-sample gradients.
                    cache[module.weight][name] = grad_output[0] * grad_output[0].shape[0]
                    if module.bias is not None:
                        if cache.get(module.bias) is None:
                            cache[module.bias] = {}
                        cache[module.bias][name] = grad_output[0] * grad_output[0].shape[0]
                case _:
                    pass

        return _save_hook_for_g
    def register_save_hook(self, model):
        def _register_save_hook(module: torch.nn.Module, prefix=""):
            if module.__class__.__name__ in self.known_modules:
                # Forward pre-hook captures the layer input "a".
                handler = module.register_forward_pre_hook(
                    KFA.get_save_hook_for_a(
                        cache=self.cache,
                        module_dot_path=prefix,
                    ),
                    with_kwargs=True,
                )
                self.handlers.append(handler)
                # Full backward hook captures the output gradient "g".
                handler = module.register_full_backward_hook(
                    KFA.get_save_hook_for_g(
                        cache=self.cache,
                        module_dot_path=prefix,
                    ),
                )
                self.handlers.append(handler)
            # Recurse into children, extending the dotted path.
            for name, submodule in module.named_children():
                new_prefix = prefix + "." + name if prefix != "" else name
                _register_save_hook(module=submodule, prefix=new_prefix)

        _register_save_hook(module=model, prefix="")
    def remove_hook(self):
        for handler in self.handlers:
            handler.remove()
        self.handlers = []  # drop stale handles so the instance can be reused
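
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; the toy Linear model, sizes, and
# MSE loss below are assumptions, not part of the original module).
if __name__ == "__main__":
    torch.manual_seed(0)
    model = torch.nn.Linear(4, 3)
    x = torch.randn(8, 4)
    target = torch.randn(8, 3)

    kfa = KFA(model)
    with kfa:  # hooks are registered on __enter__ and removed on __exit__
        loss = torch.nn.functional.mse_loss(model(x), target)
        loss.backward()

    # The cache is keyed by parameter tensor; each entry stores "a" and "g".
    entry = kfa.cache[model.weight]
    V = torch.randn_like(model.weight)  # arbitrary matrix to precondition
    out = KFA.calculate_fisher_inverse_mult_V(kfa.cache, entry["a"], entry["g"], V)
    print(out.shape)  # torch.Size([3, 4]), same shape as the weight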