Source code for ls_mlkit.diffuser.so3_diffuser

from typing import Any, Callable, Tuple

import torch
from torch import Tensor

from ls_mlkit.diffuser.sde.sde_lib import VESDE

from ..util.decorators import inherit_docstrings
from ..util.interp import interp
from ..util.manifold.so3 import SO3
from ..util.manifold.so3_utils import (
    calculate_igso3,
    inverse_transform_sampling,
    rotation_matrix_to_angle,
    vector_to_skew_symmetric,
)
from ..util.mask.bio_masker import BioSO3Masker
from .lie_group_diffuser import LieGroupDiffuser, LieGroupDiffuserConfig
from .sde import SDE
from .time_scheduler import TimeScheduler

# Small positive constant added to the rotation angle before dividing by it
# in ``SO3Diffuser.get_ground_truth_score``, to avoid division by zero.
EPS = 1e-6


@inherit_docstrings
class SO3DiffuserConfig(LieGroupDiffuserConfig):
    """Configuration for :class:`SO3Diffuser`.

    Extends the Lie-group diffuser configuration with the four settings
    that ``SO3Diffuser`` forwards to ``calculate_igso3`` when it builds its
    IGSO(3) lookup tables.
    """

    def __init__(
        self,
        n_discretization_steps: int,
        ndim_micro_shape: int,
        igso3_num_sigma: int,
        igso3_num_omega: int,
        igso3_min_sigma: float,
        igso3_max_sigma: float,
        *args: list[Any],
        **kwargs: dict[Any, Any],
    ):
        """Store the IGSO(3) settings and initialise the parent config.

        Args:
            n_discretization_steps (int): forwarded to the parent config.
            ndim_micro_shape (int): forwarded to the parent config.
            igso3_num_sigma (int): passed as ``num_sigma`` to ``calculate_igso3``.
            igso3_num_omega (int): passed as ``num_omega`` to ``calculate_igso3``.
            igso3_min_sigma (float): passed as ``min_sigma`` to ``calculate_igso3``.
            igso3_max_sigma (float): passed as ``max_sigma`` to ``calculate_igso3``.
            *args: extra positional arguments forwarded to the parent config.
            **kwargs: extra keyword arguments forwarded to the parent config.
        """
        # IGSO(3) table settings (grid sizes first, then the sigma range).
        self.igso3_num_sigma, self.igso3_num_omega = igso3_num_sigma, igso3_num_omega
        self.igso3_min_sigma, self.igso3_max_sigma = igso3_min_sigma, igso3_max_sigma

        parent_kwargs = dict(
            n_discretization_steps=n_discretization_steps,
            ndim_micro_shape=ndim_micro_shape,
        )
        super().__init__(*args, **parent_kwargs, **kwargs)

        # Re-assert the IGSO(3) fields so they hold the values given here even
        # if the parent initialiser touched attributes of the same name.
        self.igso3_num_sigma = igso3_num_sigma
        self.igso3_num_omega = igso3_num_omega
        self.igso3_min_sigma = igso3_min_sigma
        self.igso3_max_sigma = igso3_max_sigma
@inherit_docstrings
class SO3Diffuser(LieGroupDiffuser):
    r"""Score-based diffuser acting on rotation matrices in SO(3).

    Noise is drawn from the isotropic Gaussian on SO(3) (IGSO(3)); the CDF and
    score-norm tables returned by ``calculate_igso3`` are computed once in the
    constructor and registered as buffers. Only ``VESDE`` is accepted as the
    SDE.
    """

    def __init__(
        self,
        config: SO3DiffuserConfig,
        time_scheduler: TimeScheduler,
        masker: BioSO3Masker,
        sde: SDE,
        score_fn: Callable[[Tensor, Tensor, Tensor], Tensor],  # (x, t, mask) -> score
        loss_fn: Callable[[Tensor, Tensor, Tensor], Tensor],  # (predicted_score, ground_truth_score, mask) -> loss
    ):
        """Initialize the SO3Diffuser

        Args:
            config (SO3DiffuserConfig): the config of the SO3Diffuser
            time_scheduler (TimeScheduler): the time scheduler of the SO3Diffuser
            masker (BioSO3Masker): the masker of the SO3Diffuser
            sde (SDE): the SDE of the SO3Diffuser; must be a ``VESDE``
            score_fn (Callable[[Tensor, Tensor, Tensor], Tensor]): the score function
                of the SO3Diffuser, called as ``score_fn(x, t, mask)``
            loss_fn (Callable[[Tensor, Tensor, Tensor], Tensor]): the loss function
                of the SO3Diffuser, called as
                ``loss_fn(predicted_score, ground_truth_score, mask)``

        Raises:
            AssertionError: if ``sde`` is not a ``VESDE``.
        """
        so3 = SO3()
        super().__init__(
            config=config,
            time_scheduler=time_scheduler,
            lie_group=so3,
        )
        self.config = config
        self.time_scheduler = time_scheduler
        self.masker = masker
        self.sde = sde
        self.loss_fn = loss_fn
        self.so3 = so3
        self.score_fn = score_fn
        assert isinstance(self.sde, VESDE), "only VESDE is supported"

        # Tabulate IGSO(3) quantities on a (sigma, omega) grid. The omega grid
        # is linspace(0, pi, num_omega + 1) with the first point (omega = 0)
        # dropped; the sigma grid is taken from the SDE.
        igso3_cache = calculate_igso3(
            num_sigma=config.igso3_num_sigma,
            num_omega=config.igso3_num_omega,
            min_sigma=config.igso3_min_sigma,
            max_sigma=config.igso3_max_sigma,
            discrete_omega=torch.linspace(0, torch.pi, config.igso3_num_omega + 1)[1:],
            discrete_sigma=self.sde.discrete_sigmas,
        )

        # Register buffers - these will automatically move with the model
        # [num_sigma, num_omega]
        self.register_buffer("_igso3_cdf", igso3_cache["cdf"])
        # [num_sigma, num_omega]  $$\frac{d}{d\omega} f(\omega, c, L)$$
        self.register_buffer("_igso3_score_norm", igso3_cache["score_norm"])
        # [num_sigma, ]  $$\sqrt{\mathbb{E}_{\omega} || \frac{d}{d\omega} f(\omega, c, L)||_2^2}$$
        self.register_buffer("_igso3_exp_score_norms", igso3_cache["exp_score_norms"])
        # [num_omega, ]
        self.register_buffer("_igso3_discrete_omega", igso3_cache["discrete_omega"])
        # [num_sigma, ]
        self.register_buffer("_igso3_discrete_sigma", igso3_cache["discrete_sigma"])

    # Read-only views of the registered IGSO(3) buffers.
    @property
    def igso3_cdf(self) -> Tensor:
        return self._igso3_cdf

    @property
    def igso3_score_norm(self) -> Tensor:
        return self._igso3_score_norm

    @property
    def igso3_exp_score_norms(self) -> Tensor:
        return self._igso3_exp_score_norms

    @property
    def igso3_discrete_omega(self) -> Tensor:
        return self._igso3_discrete_omega

    @property
    def igso3_discrete_sigma(self) -> Tensor:
        return self._igso3_discrete_sigma

    def _sample_igso3_rotation(self, shape: Tuple[int, ...], cdf: Tensor, device: torch.device) -> Tensor:
        """Draw rotation matrices as exp(angle * axis) with a random unit axis.

        The axis is a normalised standard-normal 3-vector; the angle comes from
        ``inverse_transform_sampling`` applied to ``cdf``.

        Args:
            shape (Tuple[int, ...]): leading shape of the sample
            cdf (Tensor): CDF rows to sample angles from, shape ``(*shape, num_omega)``
            device (torch.device): device on which the random axis is created

        Returns:
            Tensor: the result of ``so3.exp`` on the sampled skew-symmetric
            matrices (``(*shape, 3, 3)`` per the shape comments below)
        """
        axis = torch.randn(tuple(shape) + (3,), device=device)  # (*shape, 3)
        axis_in_s2 = axis / torch.norm(axis, dim=-1, keepdim=True)  # (*shape, 3)
        angle = inverse_transform_sampling(
            shape=shape, cdf=cdf, discrete_omega=self.igso3_discrete_omega
        )  # (*shape,)
        # The angle needs a trailing axis so it scales each 3-vector.
        rotation_vector = angle.unsqueeze(-1) * axis_in_s2  # (*shape, 3)
        rotation_skew_symmetric = vector_to_skew_symmetric(rotation_vector)  # (*shape, 3, 3)
        return self.so3.exp(v=rotation_skew_symmetric)  # (*shape, 3, 3)

    def prior_sampling(self, shape: Tuple[int, ...]) -> Tensor:
        r"""Sample initial noise used for reverse process

        .. math::

            \mathcal{U}_{SO(3)}

        The sample is drawn from the IGSO(3) CDF row of the last training
        timestep (``num_train_timesteps - 1``).

        Args:
            shape (Tuple[int, ...]): the shape of the sample

        Returns:
            Tensor: the initial noise
        """
        macro_shape = tuple(shape)
        discrete_t = self.time_scheduler.num_train_timesteps - 1
        # Expand the single CDF row to (*macro_shape, num_omega), matching the
        # layout forward_process hands to inverse_transform_sampling.
        cdf = self.igso3_cdf[discrete_t].expand(*macro_shape, -1)
        # Create the random axis on the buffers' device so the sample follows
        # the module when it is moved.
        return self._sample_igso3_rotation(shape=macro_shape, cdf=cdf, device=cdf.device)

    def forward_process(
        self, x_0: Tensor, discrete_t: Tensor, mask: Tensor, *args: list[Any], **kwargs: dict[Any, Any]
    ) -> dict:
        r"""Forward process

        .. math::

            \text{IG}_{\text{SO}(3)} (\mathbf{x}; \mathbf{\mu}, \sigma^2)
            = f_{\sigma} (\arccos((\text{tr}(\mathbf{\mu}^T \mathbf{x}) - 1)/2))
            \quad \forall \mathbf{x} \in \text{SO}(3)

        Args:
            x_0 (Tensor): the initial sample
            discrete_t (Tensor): the discrete timestep
            mask (Tensor): the mask
            *args: additional arguments
            **kwargs: additional keyword arguments

        Returns:
            dict: a dictionary that must contain the key "x_t"
        """
        # x.shape = (b, n, 3, 3)
        macro_shape = self.get_macro_shape(x_0)  # (*macro_shape, ) = (b,)
        n = x_0.shape[-3]
        shape = macro_shape + (n,)

        igso3_cdf = self.igso3_cdf[discrete_t]  # (*macro_shape, num_omega)
        igso3_cdf = igso3_cdf.unsqueeze(-2).expand(*macro_shape, n, -1)  # (*macro_shape, n, num_omega)

        rotation_matrix = self._sample_igso3_rotation(
            shape=shape, cdf=igso3_cdf, device=x_0.device
        )  # (*macro_shape, n, 3, 3)
        x_t = self.so3.multiply(rotation_matrix, x_0)  # (*macro_shape, n, 3, 3)
        # NOTE(review): the return value of apply_mask is discarded; this is
        # only effective if apply_mask modifies x_t in place - confirm.
        self.masker.apply_mask(x_t, mask)
        return {"x_t": x_t}

    def get_ground_truth_score(self, x_0: Tensor, x_t: Tensor, discrete_t: Tensor, padding_mask: Tensor) -> Tensor:
        r"""Denoise Score Matching

        .. math::

            \nabla_x \log p_{0t} (x_t | x_0)

        Args:
            x_0 (Tensor): the clean sample
            x_t (Tensor): the noised sample
            discrete_t (Tensor): the discrete timestep, used to select the row of
                the tabulated score norm
            padding_mask (Tensor): accepted for interface compatibility; not used
                by this method

        Returns:
            Tensor: the ground-truth score, with the same shape comment as
            ``x_t`` (``(*macro_shape, n, 3, 3)``)
        """
        macro_shape = self.get_macro_shape(x_0)
        n = x_0.shape[-3]
        x_0t = x_0.transpose(-1, -2) @ x_t  # (*macro_shape, n, 3, 3)
        omega = rotation_matrix_to_angle(x_0t)  # (*macro_shape, n)

        igso3_score_norm = self.igso3_score_norm[discrete_t]  # (*macro_shape, num_omega)
        igso3_score_norm = igso3_score_norm.unsqueeze(-2).expand(*macro_shape, n, -1)  # (*macro_shape, n, num_omega)

        # log(x_0^T x_t) divided by the rotation angle; EPS guards omega == 0.
        direction = self.so3.log(q=x_0t) / (omega.unsqueeze(-1).unsqueeze(-1) + EPS)  # (*macro_shape, n, 3, 3)
        # Score magnitude interpolated from the tabulated values at omega.
        magnitude = (
            interp(x=omega, xp=self.igso3_discrete_omega, fp=igso3_score_norm).unsqueeze(-1).unsqueeze(-1)
        )  # (*macro_shape, n, 1, 1)
        ground_truth_score = x_t @ direction * magnitude  # (*macro_shape, n, 3, 3)
        return ground_truth_score

    def compute_loss(self, batch: dict[str, Any], *args: list[Any], **kwargs: dict[Any, Any]) -> Tensor:
        """Compute the denoising score-matching loss for one batch.

        Args:
            batch (dict[str, Any]): must contain ``"x_0"`` and ``"padding_mask"``;
                may contain ``"t"`` (discrete timesteps). When ``"t"`` is missing
                or ``None`` a timestep is sampled from the time scheduler.
            *args: additional arguments (unused)
            **kwargs: additional keyword arguments (unused)

        Returns:
            Tensor: the value returned by ``loss_fn``
        """
        x_0 = batch["x_0"]
        padding_mask = batch["padding_mask"]
        macro_shape = self.get_macro_shape(x_0)
        discrete_t = batch.get("t", None)
        if discrete_t is None:
            discrete_t = self.time_scheduler.sample_a_discrete_time_step_uniformly(macro_shape=macro_shape)

        x_t = self.forward_process(x_0, discrete_t=discrete_t, mask=padding_mask)["x_t"]
        ground_truth_score = self.get_ground_truth_score(x_0, x_t, discrete_t, padding_mask)
        predicted_score = self.score_fn(x_t, discrete_t, padding_mask)
        loss = self.loss_fn(predicted_score, ground_truth_score, padding_mask)
        return loss

    def sample_xtm1_conditional_on_xt(
        self, x_t: Tensor, discrete_t: Tensor, padding_mask: Tensor, *args: list[Any], **kwargs: dict[Any, Any]
    ) -> Tensor:
        r"""Take one reverse-time step from ``x_t``.

        .. math::

            dx &= \exp_{x_t}(f_{rev} dt + g_{rev} dw)\\
            x_{t+\Delta_t} &= \exp_{x_t}(- f_{rev} |\Delta_t| + g_{rev} \Delta w)\\
            f_{rev} &= (f - g^2 \nabla_x \ln p_t(x))\\
            g_{rev} &= g\\

        Args:
            x_t (Tensor): the current sample
            discrete_t (Tensor): the discrete timestep
            padding_mask (Tensor): the mask passed to the SDE and to ``score_fn``
            *args: additional arguments (unused)
            **kwargs: additional keyword arguments (unused)

        Returns:
            Tensor: the sample after the step

        Raises:
            AssertionError: if the drift returned by the SDE does not sum to zero.
        """
        continuous_t = self.time_scheduler.discrete_time_to_continuous_time(discrete_t)
        f, g = self.sde.get_drift_and_diffusion(x=x_t, t=continuous_t, mask=padding_mask)
        riemannian_grad = self.score_fn(x_t, discrete_t, padding_mask)
        assert f.sum() == 0, "f should be 0"

        rev_f = f - g**2 * riemannian_grad
        rev_g = g
        delta_t = self.time_scheduler.T / self.time_scheduler.num_inference_timesteps
        term1 = -rev_f * delta_t

        # NOTE(review): the noise is sampled with get_macro_shape(x_t), whereas
        # forward_process samples per element with macro_shape + (n,); confirm
        # that the result broadcasts against x_t as intended.
        noise_lie_algebra = self.sample_noise_in_lie_algebra(macro_shape=self.get_macro_shape(x_t))
        # ``** 0.5`` accepts both plain numbers and tensors; torch.sqrt only
        # accepts tensors and would fail if delta_t is a Python float.
        delta_w = delta_t**0.5 * x_t @ noise_lie_algebra
        term2 = rev_g * delta_w

        move_in_tangent_space = term1 + term2
        x_tm1 = self.so3.exp(p=x_t, v=move_in_tangent_space)
        return x_tm1

    def sample_noise_in_lie_algebra(
        self,
        macro_shape: Tuple[int, ...],
    ) -> Tensor:
        r"""Sample noise in Lie algebra, Skew-symmetric matrix

        Args:
            macro_shape (Tuple[int, ...]): the macro shape of the noise

        Returns:
            Tensor: the noise in Lie algebra of shape :math:`(*macro_shape, 3, 3)`
        """
        return self.so3.random_tangent(p=self.so3.identity(macro_shape=macro_shape))