Source code for ls_mlkit.optimizer.sam
from dataclasses import dataclass, field
from typing import Literal, Optional
import torch
from torch.optim.optimizer import Optimizer
@dataclass
class SAMConfig:
    """Hyperparameters for the SAM optimizer."""

    epsilon: float = field(default=1e-9)  # numerical-stability term added to the gradient norm
    rho: float = field(default=0.05)  # radius of the perturbation neighborhood
    adaptive: bool = field(default=False)  # use adaptive (ASAM-style) element-wise scaling
class SAM(Optimizer):
    """Sharpness-Aware Minimization (SAM) wrapper around a base optimizer.

    Each step runs two forward/backward passes: one at the current weights to
    compute the ascent direction, and one at the perturbed weights whose
    gradients are then applied by the base optimizer.
    """

    def __init__(
        self,
        params,
        base_optimizer,
        model: torch.nn.Module,
        sam_config: Optional[SAMConfig] = None,
        **kwargs,
    ):
        if sam_config is None:
            sam_config = SAMConfig()
        defaults = dict(adaptive=sam_config.adaptive, rho=sam_config.rho, **kwargs)
        super().__init__(params, defaults)
        # Accept either an optimizer class or an already-constructed optimizer.
        if isinstance(base_optimizer, type):
            base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.base_optimizer = base_optimizer
        # Share parameter groups and defaults with the base optimizer.
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)
        self.model = model
        self.sam_config = sam_config
        self.closure = None
    def set_closure(self, loss_fn, x, y, closure=None):
        """Register the closure that performs one forward/backward pass.

        If ``closure`` is given it is used directly; otherwise a default
        closure is built from ``loss_fn`` and the batch ``(x, y)``.
        """
        if closure is not None:
            self.closure = closure
            return

        @torch.enable_grad()
        def _closure():
            self.zero_grad()
            output = self.model(x)
            loss = loss_fn(output, y)
            loss.backward()
            return loss.detach().item()

        self.closure = _closure
def step(self, closure=None):
if closure is not None:
self.closure = closure
assert self.closure is not None, "closure is not set"
loss = self.closure()
grad_norm = self.get_gradient_norm(src="grad")
theta = "theta"
with torch.no_grad():
for group in self.param_groups:
scale = group["rho"] / (grad_norm + self.sam_config.epsilon)
for p in group["params"]:
if p.grad is None:
continue
self.state[p][theta] = p.data.clone()
grad = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad
p.data.add_(grad, alpha=scale)
self.closure()
x = torch.randn_like(p.data)
x.norm()
with torch.no_grad():
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
p.data = self.state[p][theta]
self.base_optimizer.step()
self.closure = None
return loss
    @torch.no_grad()
    def get_gradient_norm(self, src: Literal["grad", "state", "weight"] = "grad", **kwargs):
        """Return the global L2 norm of the gradients, a stored state entry, or the weights."""
        assert src in ("grad", "state", "weight"), f"src must be one of ['grad', 'state', 'weight'], got {src!r}"
        gradient_norm = 0.0
        if src == "grad":
            for group in self.param_groups:
                for p in group["params"]:
                    if p.grad is None:
                        continue
                    grad = (torch.abs(p) if group["adaptive"] else 1.0) * p.grad
                    gradient_norm += torch.sum(grad * grad)
        elif src == "state":
            key = kwargs.get("key", None)
            assert key is not None, "src='state' requires the state key to be passed as key=..."
            for group in self.param_groups:
                for p in group["params"]:
                    if p.grad is None:
                        continue
                    if self.state[p].get(key, None) is None:
                        continue
                    grad = self.state[p][key]
                    gradient_norm += torch.sum(grad * grad)
        elif src == "weight":
            for group in self.param_groups:
                for p in group["params"]:
                    if p.grad is None:
                        continue
                    gradient_norm += torch.sum(p.data * p.data)
        return torch.sqrt(gradient_norm)
    def zero_grad(self, set_to_none: bool = True):
        self.base_optimizer.zero_grad(set_to_none=set_to_none)
    def state_dict(self):
        return self.base_optimizer.state_dict()
    def load_state_dict(self, state_dict):
        self.base_optimizer.load_state_dict(state_dict)

    def __repr__(self):
        return f"SAM({self.base_optimizer.__class__.__name__})"
if __name__ == "__main__":
    pass
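    # --- Illustrative usage sketch, not part of the library API ---
    # A minimal single-step example assuming a toy linear model, random data,
    # and placeholder hyperparameters (rho, lr, and batch shapes are arbitrary).
    import torch.nn as nn

    model = nn.Linear(10, 2)
    optimizer = SAM(
        model.parameters(),
        torch.optim.SGD,  # base optimizer class; instantiated internally with **kwargs
        model,
        sam_config=SAMConfig(rho=0.05, adaptive=False),
        lr=0.1,
    )
    x = torch.randn(8, 10)
    y = torch.randint(0, 2, (8,))
    optimizer.set_closure(nn.functional.cross_entropy, x, y)
    loss = optimizer.step()  # runs both forward/backward passes internally
    print(f"loss after one SAM step: {loss:.4f}")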