ls_mlkit.diffuser.so3_diffuser module

class ls_mlkit.diffuser.so3_diffuser.SO3Diffuser(config: SO3DiffuserConfig, time_scheduler: TimeScheduler, masker: BioSO3Masker, sde: SDE, score_fn: Callable[[Tensor, Tensor, Tensor], Tensor], loss_fn: Callable[[Tensor, Tensor, Tensor], Tensor])

Bases: LieGroupDiffuser

compute_loss(batch: dict[str, Any], *args: list[Any], **kwargs: dict[Any, Any]) → Tensor

Compute the loss for a batch.

Parameters:

batch (dict[str, Any]) – the batch of data

Returns:

a dictionary that must contain the key “loss”

Return type:

dict

forward_process(x_0: Tensor, discrete_t: Tensor, mask: Tensor, *args: list[Any], **kwargs: dict[Any, Any]) → dict

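Run the forward (noising) process, perturbing x_0 to x_t at timestep discrete_t. The perturbation kernel is the isotropic Gaussian distribution on SO(3), whose density depends on \(\mathbf{x}\) only through the rotation angle \(\omega\) of \(\mathbf{\mu}^T \mathbf{x}\):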

\[\text{IG}_{\text{SO}(3)} (\mathbf{x}; \mathbf{\mu}, \sigma^2) = f_{\sigma} (\arccos((\text{tr}(\mathbf{\mu}^T \mathbf{x}) - 1)/2)) \quad \forall \mathbf{x} \in \text{SO}(3)\]
Parameters:
  • x_0 (Tensor) – the initial sample

  • discrete_t (Tensor) – the discrete timestep

  • mask (Tensor) – the mask

  • *args – additional arguments

  • **kwargs – additional keyword arguments

Returns:

a dictionary that must contain the key “x_t”

Return type:

dict
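In the IGSO(3) diffusion literature, \(f_{\sigma}\) in the equation above is usually written as a series over degrees \(l\), \(f_{\sigma}(\omega) = \sum_{l \ge 0} (2l+1)\, e^{-l(l+1)\sigma^2/2}\, \sin((l+\tfrac{1}{2})\omega) / \sin(\omega/2)\), and the marginal density of the angle picks up the Haar weight \((1 - \cos\omega)/\pi\). The NumPy sketch below evaluates that truncated series; it is an illustration under that assumption, not this module's implementation, and the function names are hypothetical.

```python
import numpy as np

def igso3_density(omega, sigma, num_terms=1000):
    """Truncated series for f_sigma(omega): the IGSO(3) density with respect to
    the normalized Haar measure, as a function of the rotation angle omega."""
    # Clip away omega = 0, where sin((l + 1/2) omega) / sin(omega / 2) is 0/0.
    omega = np.clip(np.asarray(omega, dtype=np.float64), 1e-6, np.pi)
    l = np.arange(num_terms, dtype=np.float64)[:, None]  # degrees l = 0, 1, ...
    terms = ((2 * l + 1) * np.exp(-l * (l + 1) * sigma**2 / 2)
             * np.sin((l + 0.5) * omega) / np.sin(omega / 2))
    return terms.sum(axis=0)

def igso3_angle_density(omega, sigma, num_terms=1000):
    """Marginal density of omega on [0, pi]: Haar weight (1 - cos omega) / pi times f_sigma."""
    omega = np.asarray(omega, dtype=np.float64)
    return (1 - np.cos(omega)) / np.pi * igso3_density(omega, sigma, num_terms)
```

Integrating igso3_angle_density over [0, π] should give approximately 1 for any σ > 0, which is a quick check that the truncation is adequate; smaller σ needs more terms.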

get_ground_truth_score(x_0: Tensor, x_t: Tensor, discrete_t: Tensor, padding_mask: Tensor) → Tensor

Ground-truth score for denoising score matching, i.e. the gradient of the log transition density:

\[\nabla_x \log p_{0t}(x_t \mid x_0)\]

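Parameters:
  • x_0 (Tensor) – the initial (clean) sample

  • x_t (Tensor) – the noised sample at timestep discrete_t

  • discrete_t (Tensor) – the discrete timestep

  • padding_mask (Tensor) – the padding mask

Returns:

the ground-truth score \(\nabla_x \log p_{0t}(x_t \mid x_0)\)

Return type:

Tensor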
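A minimal NumPy sketch of this target for a single rotation, assuming that x_t is drawn from IGSO(3) centred at x_0 and that scores are expressed as 3-vectors in the body frame of x_t (both are assumptions, as are the function names). Under those assumptions the score points along the rotation axis of \(x_0^T x_t\), scaled by \(d \log f_{\sigma} / d\omega\), which is approximated here by a central finite difference; the sketch also assumes \(0 < \omega < \pi\) so that the axis is well defined.

```python
import numpy as np

def igso3_density(omega, sigma, num_terms=1000):
    """Truncated IGSO(3) series f_sigma(omega) for a scalar angle 0 < omega < pi."""
    l = np.arange(num_terms, dtype=np.float64)
    return np.sum((2 * l + 1) * np.exp(-l * (l + 1) * sigma**2 / 2)
                  * np.sin((l + 0.5) * omega) / np.sin(omega / 2))

def so3_log(R):
    """Rotation angle and unit axis of R, assuming 0 < angle < pi."""
    omega = np.arccos(np.clip((np.trace(R) - 1) / 2, -1.0, 1.0))
    # vee of the skew part: R - R^T = 2 sin(omega) hat(axis)
    axis = np.array([R[2, 1] - R[1, 2], R[0, 2] - R[2, 0], R[1, 0] - R[0, 1]])
    return omega, axis / (2 * np.sin(omega))

def dsm_score(x_0, x_t, sigma, eps=1e-5):
    """Hypothetical DSM target as a body-frame 3-vector at x_t
    (assumption: x_t ~ IGSO(3) centred at x_0 with scale sigma)."""
    omega, axis = so3_log(x_0.T @ x_t)  # x_0^T x_t = exp(omega * hat(axis))
    # d/d omega of log f_sigma(omega) by central finite differences
    dlogf = (np.log(igso3_density(omega + eps, sigma))
             - np.log(igso3_density(omega - eps, sigma))) / (2 * eps)
    return dlogf * axis
```

Because \(f_{\sigma}\) decreases with the angle, the resulting vector points back toward x_0.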

property igso3_cdf: Tensor
property igso3_discrete_omega: Tensor
property igso3_discrete_sigma: Tensor
property igso3_exp_score_norms: Tensor
property igso3_score_norm: Tensor
prior_sampling(shape: Tuple[int, ...]) → Tensor

Sample the initial noise for the reverse process from the uniform distribution on SO(3):

\[\mathbf{x}_T \sim \mathcal{U}_{\text{SO}(3)}\]
Parameters:

shape (Tuple[int, ...]) – the shape of the sample

Returns:

the initial noise

Return type:

Tensor
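One standard way to sample the uniform distribution on SO(3) is to normalize a 4-dimensional standard Gaussian into a uniformly distributed unit quaternion and convert it to a rotation matrix. The NumPy sketch below illustrates that technique; whether prior_sampling uses it is not stated here, and sample_uniform_so3 is a hypothetical name.

```python
import numpy as np

def sample_uniform_so3(shape, rng=None):
    """Hypothetical helper: Haar-uniform rotation matrices of shape (*shape, 3, 3)."""
    rng = np.random.default_rng() if rng is None else rng
    q = rng.standard_normal((*shape, 4))
    q /= np.linalg.norm(q, axis=-1, keepdims=True)  # uniform on the unit 3-sphere
    w, x, y, z = np.moveaxis(q, -1, 0)  # scalar-first (Hamilton) quaternion
    R = np.stack([
        1 - 2 * (y**2 + z**2), 2 * (x * y - w * z), 2 * (x * z + w * y),
        2 * (x * y + w * z), 1 - 2 * (x**2 + z**2), 2 * (y * z - w * x),
        2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x**2 + y**2),
    ], axis=-1)
    return R.reshape(*shape, 3, 3)
```

The antipodal quaternions q and -q map to the same rotation, so the double cover does not bias the result.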

sample_noise_in_lie_algebra(macro_shape: Tuple[int, ...]) → Tensor

Sample noise in the Lie algebra \(\mathfrak{so}(3)\), represented as skew-symmetric matrices.

Parameters:

macro_shape (Tuple[int, ...]) – the macro shape of the noise

Returns:

the noise in the Lie algebra, of shape (*macro_shape, 3, 3)

Return type:

Tensor
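A sketch of one way to produce such noise: draw standard Gaussian coefficients in R^3 and map them to skew-symmetric matrices with the hat operator, so that hat(v) @ w equals the cross product v × w. The Gaussian coefficient distribution and the helper names are assumptions, not necessarily what sample_noise_in_lie_algebra does.

```python
import numpy as np

def hat(v):
    """Map vectors of shape (..., 3) to skew-symmetric matrices of shape (..., 3, 3)."""
    x, y, z = np.moveaxis(v, -1, 0)
    zero = np.zeros_like(x)
    return np.stack([
        np.stack([zero, -z, y], axis=-1),
        np.stack([z, zero, -x], axis=-1),
        np.stack([-y, x, zero], axis=-1),
    ], axis=-2)

def sample_noise_in_so3_algebra(macro_shape, rng=None):
    """Hypothetical helper: standard Gaussian coordinates in so(3), returned as matrices."""
    rng = np.random.default_rng() if rng is None else rng
    return hat(rng.standard_normal((*macro_shape, 3)))
```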

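sample_xtm1_conditional_on_xt(x_t: Tensor, discrete_t: Tensor, padding_mask: Tensor, *args: list[Any], **kwargs: dict[Any, Any]) → Tensor

Sample \(x_{t-1}\) given \(x_t\) by taking one discretized step of the reverse-time SDE on SO(3) through the exponential map at \(x_t\). The step \(\Delta_t\) is negative, so \(x_{t+\Delta_t}\) corresponds to \(x_{t-1}\).

\[\begin{split}x_{t+dt} &= \exp_{x_t}(f_{rev}\, dt + g_{rev}\, dw)\\ x_{t+\Delta_t} &= \exp_{x_t}(- f_{rev} |\Delta_t| + g_{rev} \Delta w)\\ f_{rev} &= f - g^2 \nabla_x \ln p_t(x)\\ g_{rev} &= g\end{split}\]
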
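A minimal NumPy sketch of the discretized update for a single rotation, assuming f = 0 (a variance-exploding SDE), a score already given as a 3-vector in the body frame of \(x_t\), and \(\exp_{x_t}(v) = x_t \exp(\hat{v})\). With f = 0 the drift term \(-f_{rev}|\Delta_t|\) reduces to \(g^2 \nabla_x \ln p_t(x)\, |\Delta_t|\). The helpers hat, so3_exp, and reverse_step are hypothetical; so3_exp uses Rodrigues' formula.

```python
import numpy as np

def hat(v):
    """3-vector -> skew-symmetric 3x3 matrix."""
    x, y, z = v
    return np.array([[0.0, -z, y], [z, 0.0, -x], [-y, x, 0.0]])

def so3_exp(v):
    """Rodrigues' formula for exp(hat(v))."""
    theta = np.linalg.norm(v)
    K = hat(v)
    if theta < 1e-8:
        return np.eye(3) + K  # first-order approximation near the identity
    return np.eye(3) + np.sin(theta) / theta * K + (1 - np.cos(theta)) / theta**2 * (K @ K)

def reverse_step(x_t, score, g, dt_abs, rng):
    """One geodesic Euler-Maruyama step of the reverse SDE (assumption: f = 0)."""
    drift = g**2 * score * dt_abs  # -f_rev |dt| with f = 0
    noise = g * np.sqrt(dt_abs) * rng.standard_normal(3)  # g * Delta w, Delta w ~ N(0, |dt| I)
    # exp_{x_t}(v) realized as right multiplication (body-frame convention, assumed)
    return x_t @ so3_exp(drift + noise)
```

Applying the update through the exponential map keeps every iterate exactly on SO(3), whereas adding the tangent vector to the matrix directly would drift off the manifold.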
class ls_mlkit.diffuser.so3_diffuser.SO3DiffuserConfig(n_discretization_steps: int, ndim_micro_shape: int, igso3_num_sigma: int, igso3_num_omega: int, igso3_min_sigma: float, igso3_max_sigma: float, *args: list[Any], **kwargs: dict[Any, Any])

Bases: LieGroupDiffuserConfig
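The igso3_* configuration fields and the igso3_* properties of SO3Diffuser suggest precomputed lookup tables over a grid of σ values and rotation angles ω. The sketch below is a hypothetical construction of such tables (log-spaced σ between igso3_min_sigma and igso3_max_sigma, uniform ω on [0, π], one CDF per σ via the cumulative trapezoid rule) together with inverse-CDF sampling of ω. The grid spacing, the truncated series, and all function names are assumptions, not the module's actual implementation.

```python
import numpy as np

def igso3_angle_density(omega, sigma, num_terms=1000):
    """Truncated IGSO(3) series times the Haar angle weight (1 - cos omega) / pi."""
    omega_safe = np.clip(omega, 1e-6, np.pi)  # avoid 0/0 at omega = 0
    l = np.arange(num_terms, dtype=np.float64)[:, None]
    f = np.sum((2 * l + 1) * np.exp(-l * (l + 1) * sigma**2 / 2)
               * np.sin((l + 0.5) * omega_safe) / np.sin(omega_safe / 2), axis=0)
    return (1 - np.cos(omega)) / np.pi * f

def build_igso3_tables(min_sigma, max_sigma, num_sigma, num_omega):
    """Hypothetical tables: sigma grid, omega grid, and one CDF per sigma."""
    sigmas = np.exp(np.linspace(np.log(min_sigma), np.log(max_sigma), num_sigma))  # assumed log spacing
    omegas = np.linspace(0.0, np.pi, num_omega)
    # clip tiny negative values caused by floating-point cancellation in the series
    pdfs = np.clip(np.stack([igso3_angle_density(omegas, s) for s in sigmas]), 0.0, None)
    increments = (pdfs[:, 1:] + pdfs[:, :-1]) / 2 * np.diff(omegas)
    cdfs = np.concatenate([np.zeros((num_sigma, 1)), np.cumsum(increments, axis=1)], axis=1)
    cdfs /= cdfs[:, -1:]  # normalize so every CDF ends at exactly 1
    return sigmas, omegas, cdfs

def sample_angle(cdf, omegas, size, rng):
    """Inverse-CDF sampling of omega by linear interpolation."""
    return np.interp(rng.uniform(size=size), cdf, omegas)
```

Because IGSO(3) is isotropic, pairing a sampled angle ω with a uniformly random unit axis gives a draw from the full distribution.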