r"""
SO3 Utils
"""
from typing import Tuple
import torch
from torch import Tensor
from ...util.decorators import cache_to_disk
from ...util.interp import interp
# Small positive constant used for numerical safety in this module: it is added
# to denominators to avoid division by zero and used as the margin when
# clamping the argument of arccos.
EPS = 1e-8
[docs]
def get_macro_shape(x: Tensor, ndim_micro_shape: int) -> Tuple[int, ...]:
    r"""
    Return the leading ("macro" / batch) part of the shape of ``x``.

    Args:
        x.shape: (..., *micro_shape)
        ndim_micro_shape: number of trailing dimensions that make up the micro shape
    Returns:
        the shape of ``x`` with the last ``ndim_micro_shape`` dimensions removed
    """
    # ``x.shape[:-0]`` is ``x.shape[:0]`` (the empty shape), so "no trailing
    # dimensions" needs its own branch to return the full shape.
    if ndim_micro_shape == 0:
        return x.shape
    return x.shape[:-ndim_micro_shape]
[docs]
def flatten_batch_dimension(x: Tensor, ndim_micro_shape: int) -> Tensor:
    r"""
    Collapse all leading (batch) dimensions of ``x`` into a single dimension.

    Args:
        x.shape: (..., *micro_shape)
        ndim_micro_shape: number of trailing dimensions to keep untouched
    Returns:
        shape: (N, *micro_shape), where N is the product of the leading dimensions
    """
    # ``x.shape[-0:]`` is the whole shape, so "no trailing dimensions" is
    # handled explicitly and flattens everything into one axis.
    micro_shape = () if ndim_micro_shape == 0 else x.shape[-ndim_micro_shape:]
    # ``reshape`` returns a view whenever ``view`` would have worked and copies
    # otherwise, so non-contiguous inputs (e.g. after a transpose) no longer raise.
    return x.reshape(-1, *micro_shape)
[docs]
def trace(A: Tensor) -> Tensor:
    r"""
    Trace of the matrices stored in the last two dimensions of ``A``.

    Args:
        A.shape: (..., 3, 3)
    Returns:
        shape: (..., )
    """
    diagonal_entries = A.diagonal(dim1=-1, dim2=-2)  # (..., 3)
    return diagonal_entries.sum(dim=-1)  # (..., )
[docs]
def vector_to_skew_symmetric(v: Tensor) -> Tensor:
    r"""
    Hat map from the vector space $$\mathbb{R}^3$$ to the Lie algebra $$\mathfrak{so}(3)$$
    $$
    (x,y,z) \to \begin{pmatrix}
    0 & -z & y \\
    z & 0 & -x \\
    -y & x & 0
    \end{pmatrix}
    $$
    Args:
        v.shape: (..., 3)
    Returns:
        shape: (..., 3, 3)
    """
    batch_shape = get_macro_shape(x=v, ndim_micro_shape=1)
    upper = torch.zeros([*batch_shape, 3, 3], dtype=v.dtype, device=v.device)
    # Fill only the strictly upper triangle; the lower triangle follows from
    # skew-symmetry below.
    upper[..., 0, 1] = -v[..., 2]
    upper[..., 0, 2] = v[..., 1]
    upper[..., 1, 2] = -v[..., 0]
    return upper - upper.transpose(-1, -2)
[docs]
def skew_symmetric_to_vector(hat_v: Tensor) -> Tensor:
    r"""
    Inverse of the hat map: read the vector out of a skew-symmetric matrix
    $$\mathfrak{so}(3) \mapsto \mathbb{R}^3$$
    $$
    \begin{pmatrix}
    0 & -z & y \\
    z & 0 & -x \\
    -y & x & 0
    \end{pmatrix} \to (x,y,z)
    $$
    Only the strictly upper-triangular entries of ``hat_v`` are read.

    Args:
        hat_v.shape: (..., 3, 3)
    Returns:
        shape: (..., 3)
    """
    batch_shape = hat_v.shape[:-2]
    components = torch.zeros([*batch_shape, 3], dtype=hat_v.dtype, device=hat_v.device)
    components[..., 0] = -hat_v[..., 1, 2]  # x
    components[..., 1] = hat_v[..., 0, 2]  # y
    components[..., 2] = -hat_v[..., 0, 1]  # z
    return components
[docs]
def skew_symmetric_to_angle(A: Tensor) -> Tensor:
    r"""
    Euclidean norm of the vector encoded by a skew-symmetric matrix
    $$
    \begin{pmatrix}
    0 & -z & y \\
    z & 0 & -x \\
    -y & x & 0
    \end{pmatrix} \to \sqrt{x^2 + y^2 + z^2}
    $$
    Only the strictly upper-triangular entries of ``A`` are read.

    Args:
        A.shape: (..., 3, 3)
    Returns:
        shape: (..., )
    """
    a01 = A[..., 0, 1]  # -z
    a02 = A[..., 0, 2]  # y
    a12 = A[..., 1, 2]  # -x
    squared_norm = a01**2 + a02**2 + a12**2
    return squared_norm.sqrt()
[docs]
def unit_skew_symmetric(A: Tensor) -> Tensor:
    r"""
    Normalize a skew-symmetric matrix by its rotation angle.

    The divisor is the angle plus ``EPS``, so a zero matrix maps to zero
    instead of producing a division by zero.

    Args:
        A.shape: (..., 3, 3)
    Returns:
        shape: (..., 3, 3)
    """
    angle = skew_symmetric_to_angle(A)[..., None, None]  # (..., 1, 1)
    return A / (angle + EPS)
[docs]
def rotation_matrix_to_angle(R: Tensor) -> Tensor:
    r"""
    Rotation angle of a rotation matrix
    $$
    \theta = \arccos(\frac{Tr(R)-1}{2})
    $$
    The cosine is clamped to ``[-1 + EPS, 1 - EPS]`` before ``arccos`` is applied.

    Args:
        R.shape: (..., 3, 3)
    Returns:
        shape: (..., )
    """
    cos_theta = (trace(R) - 1) / 2  # (..., )
    cos_theta = cos_theta.clamp(min=-1 + EPS, max=1 - EPS)
    return torch.arccos(cos_theta)  # (..., )
[docs]
def logarithmic_map(R: Tensor) -> Tensor:
    r"""
    Logarithmic map from SO(3) to so(3), i.e. the matrix logarithm
    $$SO(3) \mapsto \mathfrak{so}(3)$$
    $$
    \begin{align*}
    \theta &= \arccos(\frac{Tr(R)-1}{2})\\
    \log(R) &= \frac{\theta}{2\sin(\theta)} (R - R^T)
    \end{align*}
    $$
    ``EPS`` is added to the denominator ``2 sin(theta)`` to avoid a division by zero.

    Args:
        R.shape: (..., 3, 3)
    Returns:
        shape: (..., 3, 3)
    """
    angle = rotation_matrix_to_angle(R)  # (..., )
    scale = angle / (2 * torch.sin(angle) + EPS)  # (..., )
    antisymmetric_part = R - R.transpose(-1, -2)  # (..., 3, 3)
    return scale[..., None, None] * antisymmetric_part
[docs]
def exponential_map(A: Tensor) -> Tensor:
    r"""
    Exponential map from the Lie algebra $$\mathfrak{so}(3)$$ to SO(3), i.e. the matrix
    exponential of a skew-symmetric matrix
    $$\mathfrak{so}(3) \mapsto SO(3)$$
    $$
    \begin{align*}
    \theta &= \sqrt{A_{0,1}^2 + A_{0,2}^2 + A_{1,2}^2}\\
    B &= A / \theta\\
    \exp(A) &= I + \sin(\theta) B + (1 - \cos(\theta)) B^2
    \end{align*}
    $$
    Args:
        A.shape: (..., 3, 3)
    Returns:
        shape: (..., 3, 3)
    """
    angle = skew_symmetric_to_angle(A)[..., None, None]  # (..., 1, 1)
    B = unit_skew_symmetric(A)  # (..., 3, 3)
    # The (3, 3) identity broadcasts against the batch dimensions of the other terms.
    identity = torch.eye(3).to(dtype=A.dtype, device=A.device)
    first_order = torch.sin(angle) * B  # (..., 3, 3)
    second_order = ((1 - torch.cos(angle)) * B) @ B  # (..., 3, 3)
    return identity + first_order + second_order
# Default truncation level (number of series terms kept) used by the IGSO(3)
# functions below when no explicit ``L`` is passed.
L_default = 2000
[docs]
def f_igso3(omega: Tensor, c: Tensor, L: int = L_default) -> Tensor:
    r"""Truncated sum of IGSO(3) distribution.

    This function approximates the power series in equation 5 of
    "DENOISING DIFFUSION PROBABILISTIC MODELS ON SO(3) FOR ROTATIONAL
    ALIGNMENT"
    Leach et al. 2022

    This expression diverges from the expression in Leach in that here,
    $$\sigma = \sqrt{2} * \epsilon$$, if eps_leach were the scale parameter of the IGSO(3).
    With this reparameterization, IGSO(3) agrees with the Brownian motion on
    SO(3) with t=sigma^2 when defined for the canonical inner product on SO3,
    $$<u, v>_{SO3} = Tr(u v^T)/2$$

    Args:
        omega: (...,), i.e. the angle of rotation associated with rotation matrix
        c: (1,), variance parameter of IGSO(3), maps onto time in Brownian motion
        L: (1,) Truncation level
    Returns:
        shape (..., ):
        $$
        \sum_{l=0}^{L-1} (2l+1) e^{-l(l+1)c/2} \sin (\omega (l +1/2)) / \sin(\omega/2)
        $$
        clamped from below at 1e-12.
    """
    # Build the series index on omega's device so that non-CPU inputs do not
    # run into a device mismatch.
    ls = torch.arange(L, device=omega.device)  # (L,)
    series_shape = (*([1] * omega.ndim), L)
    s = (2 * ls + 1) * torch.exp(-ls * (ls + 1) * c / 2).view(series_shape)  # (..., L)
    numerator = torch.sin(omega[..., None] * (ls + 1 / 2).view(series_shape))  # (..., L)
    denominator = torch.sin(omega / 2).unsqueeze(-1)  # (..., 1)
    # Replace near-zero denominators (omega close to 0 or 2*pi) by +1e-8 to
    # prevent a division by zero.
    denominator_safe = torch.where(
        torch.abs(denominator) < 1e-8, torch.ones_like(denominator) * 1e-8, denominator
    )
    result = (s * numerator / denominator_safe).sum(dim=-1)  # (..., )
    # Keep the result strictly positive so callers can safely take its logarithm.
    return torch.clamp(result, min=1e-12)
[docs]
def d_logf_d_omega(omega: Tensor, c: Tensor, L: int = L_default) -> Tensor:
    r"""
    Score function of IGSO(3) distribution, computed with autograd.
    $$
    \frac{d}{d\omega} \log f(\omega, c, L)
    $$
    The input ``omega`` is cloned and detached, so the caller's tensor and
    graph are left untouched.

    Args:
        omega: (...,), i.e. the angle of rotation associated with rotation matrix
        c: (1,), variance parameter of IGSO(3), maps onto time in Brownian motion
        L: (1,) Truncation level
    Returns:
        (..., )
    """
    # Force gradient recording locally so the derivative can still be computed
    # when the caller runs inside ``torch.no_grad()``.
    with torch.enable_grad():
        omega = omega.clone().detach().requires_grad_(True)
        f_val = f_igso3(omega, c, L)
        # Clamp before the log to avoid log(0) or log(negative).
        log_f = torch.log(torch.clamp(f_val, min=1e-8))
        (grad_omega,) = torch.autograd.grad(log_f.sum(), omega)
    return grad_omega
[docs]
def igso3_density(Rt: Tensor, c: Tensor, L: int = L_default) -> Tensor:
    r"""
    IGSO3 density with respect to the volume form on SO(3), evaluated at the
    rotation angle of ``Rt``.

    Args:
        Rt: (..., 3, 3), rotation matrix
        c: (1,), variance parameter of IGSO(3), maps onto time in Brownian motion
        L: (1,) Truncation level
    Returns:
        (..., )
    """
    return f_igso3(rotation_matrix_to_angle(Rt), c, L)
[docs]
def igso3_density_angle(omega: Tensor, c: Tensor, L: int = L_default) -> Tensor:
    r"""
    Density of the rotation angle: the truncated series ``f`` weighted by
    $$((1-\cos(\omega)) / \pi ) f$$

    Args:
        omega: (...,), i.e. the angle of rotation associated with rotation matrix
        c: (1,), variance parameter of IGSO(3), maps onto time in Brownian motion
        L: (1,) Truncation level
    Returns:
        (..., )
    """
    series_value = f_igso3(omega, c, L)  # (..., )
    angle_weight = 1 - torch.cos(omega)  # (..., )
    return series_value * angle_weight / torch.pi
[docs]
def igso3_score(R: Tensor, c: Tensor, L: int = L_default) -> Tensor:
    r"""
    grad_R log IGSO3(R; I_3, c)
    $$
    \nabla_R \log IG_{SO3}(R; I_3, c) = R \frac{log(R)}{\omega(R)} \frac{d}{d\omega} \log f(\omega, c, L)
    $$
    Args:
        R: (..., 3, 3), rotation matrix
        c: (1,), variance parameter of IGSO(3), maps onto time in Brownian motion
        L: (1,) Truncation level
    Returns:
        (..., 3, 3)
    """
    omega = rotation_matrix_to_angle(R)  # (..., )
    tangent = torch.einsum("...ij,...jk->...ik", R, logarithmic_map(R))  # (..., 3, 3)
    # Index with an ellipsis so any number of batch dimensions (including none)
    # broadcasts correctly; EPS guards the division when omega is zero.
    unit = tangent / (omega[..., None, None] + EPS)  # (..., 3, 3)
    return unit * d_logf_d_omega(omega, c, L)[..., None, None]
[docs]
@cache_to_disk(root_datadir="cache")
def calculate_igso3(
    *, num_sigma: int, num_omega: int, min_sigma: float, max_sigma: float, discrete_omega=None, discrete_sigma=None
) -> dict[str, Tensor]:
    r"""calculate_igso3 pre-computes numerical approximations to the IGSO3 cdfs
    and score norms and expected squared score norms.

    Args:
        num_sigma: number of different sigmas for which to compute igso3
            quantities.
        num_omega: number of points in the discretization of the angle of
            rotation. Also sets the grid spacing ``pi / num_omega`` used for
            the CDF.
        min_sigma, max_sigma: lower and upper end of the log-spaced sigma
            schedule (only used when ``discrete_sigma`` is None). min_sigma
            cannot be too low or it will create numerical instability.
        discrete_omega: optional pre-built grid of rotation angles; when None a
            uniform grid on (0, pi] with ``num_omega`` points is used.
            NOTE(review): the CDF normalisation assumes the uniform
            ``pi / num_omega`` spacing -- confirm for custom grids.
        discrete_sigma: optional pre-built tensor of sigmas; when None the
            log-spaced schedule between min_sigma and max_sigma is used.
    Returns:
        dict with keys
            "cdf": [num_sigma, num_omega]
            "score_norm": [num_sigma, num_omega]
            "exp_score_norms": [num_sigma, ]
            "discrete_omega": [num_omega, ]
            "discrete_sigma": [num_sigma, ]
    """
    import math  # local import: only needed for the scalar log10 below

    # Discretize omegas for calculating CDFs. Skip omega=0.
    if discrete_omega is None:
        discrete_omega = torch.linspace(0, torch.pi, num_omega + 1)[1:]  # [num_omega, ]

    # Exponential noise schedule. This choice is closely tied to the
    # scalings used when simulating the reverse time SDE. For each step n,
    # discrete_sigma[n] = min_sigma^(1-n/num_sigma) * max_sigma^(n/num_sigma)
    if discrete_sigma is None:
        # ``torch.log10`` only accepts tensors, so take the log of the scalar
        # bounds with ``math.log10`` (``float`` also accepts 0-dim tensors).
        log_min = math.log10(float(min_sigma))
        log_max = math.log10(float(max_sigma))
        discrete_sigma = 10 ** torch.linspace(log_min, log_max, num_sigma + 1)[1:]  # [num_sigma, ]

    # Compute the pdf and cdf values for the marginal distribution of the angle
    # of rotation (which is needed for sampling).
    pdf_vals = torch.stack(
        [igso3_density_angle(discrete_omega, sigma**2) for sigma in discrete_sigma]
    )  # [num_sigma, num_omega]
    # pi / num_omega is the length of one interval of the omega grid.
    cdf_vals = pdf_vals.cumsum(dim=-1) / num_omega * torch.pi  # [num_sigma, num_omega]

    # Compute the norms of the scores. These are used to scale the rotation
    # axis when computing the score as a vector.
    score_norm = torch.stack(
        [d_logf_d_omega(discrete_omega, sigma**2) for sigma in discrete_sigma]
    )  # [num_sigma, num_omega]

    # Root of the pdf-weighted mean squared score norm for each sigma:
    # sqrt( E_omega[ (d/d omega log f(omega, c, L))^2 ] )
    exp_score_norms = torch.sqrt(
        torch.sum(score_norm**2 * pdf_vals, dim=1) / torch.sum(pdf_vals, dim=1)
    )  # [num_sigma, ]

    return {
        "cdf": cdf_vals,  # [num_sigma, num_omega]
        "score_norm": score_norm,  # [num_sigma, num_omega]
        "exp_score_norms": exp_score_norms,  # [num_sigma, ]
        "discrete_omega": discrete_omega,  # [num_omega, ]
        "discrete_sigma": discrete_sigma,  # [num_sigma, ]
    }