from typing import Any, Callable, Literal, Tuple, cast
import numpy as np
import torch
from torch import Tensor
from ls_mlkit.diffuser.time_scheduler import TimeScheduler
from ..util.decorators import inherit_docstrings
from ..util.mask.masker_interface import MaskerInterface
from .conditioner import Conditioner
from .euclidean_diffuser import EuclideanDiffuser, EuclideanDiffuserConfig
from .model_interface import Model4DiffuserInterface
[docs]
@inherit_docstrings
class EuclideanDDPMConfig(EuclideanDiffuserConfig):
    """
    Config Class for Euclidean DDPM Diffuser

    Stores the noise schedule (``betas``), the quantities derived from it
    (``alphas``, ``alphas_cumprod`` and their square roots) and the options
    that control clipping / dynamic thresholding of the predicted ``x_0``.
    """

    def __init__(
        self,
        n_discretization_steps: int = 1000,
        ndim_micro_shape: int = 2,
        use_probability_flow=False,
        use_clip: bool = True,
        clip_sample_range: float = 1.0,
        use_dyn_thresholding: bool = False,
        dynamic_thresholding_ratio=0.995,
        sample_max_value: float = 1.0,
        betas=None,
        *args,
        **kwargs,
    ):
        r"""
        Args:
            n_discretization_steps: the number of discretization steps
            ndim_micro_shape: the number of dimensions of the micro shape
            use_probability_flow: whether to use probability flow
            use_clip: whether to use clip
            clip_sample_range: the range of the clip
            use_dyn_thresholding: whether to use dynamic thresholding
            dynamic_thresholding_ratio: the ratio of the dynamic thresholding
            sample_max_value: the maximum value of the sample used in thresholding
            betas: the betas; a Tensor or any array-like accepted by
                ``torch.as_tensor``. When ``None`` a linear schedule from
                0.0001 to 0.02 with ``n_discretization_steps`` entries is used.
            *args: accepted for signature compatibility; not used here
            **kwargs: accepted for signature compatibility; not used here

        Returns:
            None
        """
        super().__init__(
            n_discretization_steps=n_discretization_steps,
            ndim_micro_shape=ndim_micro_shape,
        )
        self.betas: Tensor
        if betas is None:
            # Use the same beta schedule as standard DDPMScheduler
            # Linear schedule from beta_start=0.0001 to beta_end=0.02
            self.betas = torch.linspace(0.0001, 0.02, steps=self.n_discretization_steps, dtype=torch.float32)
        else:
            # as_tensor returns an existing Tensor unchanged and converts
            # lists / numpy arrays, so the arithmetic below always works.
            self.betas = torch.as_tensor(betas)

        # Quantities derived from the schedule.
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)  # expectation
        self.sqrt_1m_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)  # std

        # Sampling options. use_probability_flow used to be dropped silently;
        # keep it so that the documented argument is readable from the config.
        self.use_probability_flow = use_probability_flow
        self.use_clip = use_clip
        self.clip_sample_range = clip_sample_range
        self.use_dyn_thresholding = use_dyn_thresholding
        self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
        self.sample_max_value = sample_max_value
[docs]
@inherit_docstrings
class EuclideanDDPMDiffuser(EuclideanDiffuser):
def __init__(
    self,
    config: EuclideanDDPMConfig,
    time_scheduler: TimeScheduler,
    masker: MaskerInterface,
    conditioner_list: list[Conditioner],
    model: Model4DiffuserInterface,
    loss_fn: Callable[[Tensor, Tensor, Tensor], Tensor],  # (predicted, ground_true, padding_mask)
):
    """Build a DDPM diffuser around a model and a loss function.

    The config, time scheduler, masker and conditioners are handed to the
    parent initializer; the model and loss function are kept on the instance.

    Args:
        config (EuclideanDDPMConfig): the config of the diffuser
        time_scheduler (TimeScheduler): the time scheduler of the diffuser
        masker (MaskerInterface): the masker of the diffuser
        conditioner_list (list[Conditioner]): the list of conditioners of the diffuser
        model (Model4DiffuserInterface): the model of the diffuser
        loss_fn (Callable[[Tensor, Tensor, Tensor], Tensor]): the loss function of the
            diffuser, called as ``loss_fn(predicted, ground_true, padding_mask)``

    Returns:
        None
    """
    parent_kwargs = {
        "config": config,
        "time_scheduler": time_scheduler,
        "masker": masker,
        "conditioner_list": conditioner_list,
    }
    super().__init__(**parent_kwargs)
    # Re-declare config with the narrower DDPM type for type checkers.
    self.config: EuclideanDDPMConfig = config
    self.model = model
    self.loss_fn = loss_fn
[docs]
def prior_sampling(self, shape: Tuple[int, ...]) -> Tensor:
    """Draw a standard-normal sample of the requested shape.

    Args:
        shape (Tuple[int, ...]): the shape of the sample to draw

    Returns:
        Tensor: a tensor of i.i.d. N(0, 1) values created by ``torch.randn``
        (default dtype and device)
    """
    prior_sample = torch.randn(shape)
    return prior_sample
[docs]
def compute_loss(self, batch: dict[str, Any], *args: Any, **kwargs: Any) -> dict:
    r"""Compute the DDPM training loss for one batch.

    The batch is first passed through ``self.model.prepare_batch_data_for_input``.
    ``x_0`` is then noised to a time step ``t`` (taken from the batch, or
    sampled from ``self.time_scheduler`` when absent), the model is evaluated
    on ``x_t`` and the loss is computed in noise space:

    * ``mode == "epsilon"``: ``model_output["x"]`` is the predicted noise.
    * ``mode == "x_0"``: ``model_output["x"]`` is the predicted ``x_0`` and is
      converted with :math:`\epsilon = (x_t - a\,\hat{x}_0) / b`, where
      ``a = sqrt_alphas_cumprod[t]`` and ``b = sqrt_1m_alphas_cumprod[t]``.

    When ``self.conditioner_list`` is non-empty the loss is recomputed between
    the scaled unconditional score prediction and the scaled ground-truth
    score plus the accumulated conditional score.

    Args:
        batch (dict[str, Any]): must contain ``"x_0"`` and ``"padding_mask"``
            after preparation; may contain ``"t"`` and ``"mode"``. Every other
            entry is forwarded to the model as a keyword argument. The dict is
            not modified.
        *args: unused
        **kwargs: unused

    Returns:
        dict: ``{"loss": loss, "model_output": model_output}``

    Raises:
        ValueError: if ``mode`` is ``"score"`` (not supported yet) or is not
            one of the known modes. Raised before the model is evaluated.
    """
    mode: Literal["epsilon", "x_0", "score"] = batch.get("mode", "epsilon")
    # Reject bad modes up front instead of after an expensive forward pass.
    if mode == "score":
        raise ValueError(f"Currently not supported mode: {mode}")
    if mode not in ("epsilon", "x_0"):
        raise ValueError(f"Invalid mode: {mode}")

    batch = self.model.prepare_batch_data_for_input(batch)
    assert isinstance(batch, dict), "batch must be a dictionary"
    x_0 = batch["x_0"]
    padding_mask = batch["padding_mask"]
    device = x_0.device
    macro_shape = self.get_macro_shape(x_0)
    t = batch.get("t", None)
    if t is None:
        t = self.time_scheduler.sample_a_discrete_time_step_uniformly(macro_shape).to(device)

    self.config = self.config.to(t)
    b = self.complete_micro_shape(self.config.sqrt_1m_alphas_cumprod[t])  # std of q(x_t | x_0)
    a = self.complete_micro_shape(self.config.sqrt_alphas_cumprod[t])  # mean scale of q(x_t | x_0)

    forward_result = self.forward_process(x_0, t, padding_mask)
    x_t, noise = forward_result["x_t"], forward_result["noise"]

    # Build the extra model kwargs from a filtered copy: popping from `batch`
    # itself would strip x_0 / padding_mask / t from the caller's dict and
    # break any later reuse of the same batch.
    positional_keys = ("x_0", "padding_mask", "t")
    model_input_dict = {key: value for key, value in batch.items() if key not in positional_keys}
    model_output = self.model(x_t, t, padding_mask, **model_input_dict)

    # Simplified loss calculation following standard DDPM
    if mode == "epsilon":
        predicted_noise = model_output["x"]
    else:  # mode == "x_0": convert to a noise prediction for a consistent loss
        predicted_noise = (x_t - a * model_output["x"]) / b
    # Standard DDPM loss between predicted and actual noise
    loss = self.loss_fn(predicted_noise, noise, padding_mask)

    # Handle conditioners if any (for advanced use cases)
    if self.conditioner_list:
        p_uc_score = -predicted_noise / b
        gt_uc_score = -noise / b
        tgt_mask = padding_mask
        for conditioner in self.conditioner_list:
            if not conditioner.is_enabled():
                continue
            condition = conditioner.prepare_condition_dict(
                train=True,
                tgt_mask=tgt_mask,
                x_0=x_0,
                padding_mask=padding_mask,
                posterior_mean_fn=self.get_posterior_mean_fn(score=p_uc_score, score_fn=None),
            )
            conditioner.set_condition(**condition)
        acc_c_score = self.get_accumulated_conditional_score(x_t, t, padding_mask)
        gt_score = gt_uc_score + acc_c_score
        # Scale both sides by b and compute the conditioned loss
        loss = self.loss_fn(b * p_uc_score, b * gt_score, padding_mask)

    return {"loss": loss, "model_output": model_output}
[docs]
def q_xt_x_0(self, x_0: Tensor, t: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
    r"""Mean and standard deviation of the forward marginal

    .. math::
        q(x_t|x_0) = \mathcal{N}(\sqrt{\bar{\alpha}_t}\, x_0,\ (1-\bar{\alpha}_t) I)

    where :math:`\bar{\alpha}_t` is the cumulative product of the alphas
    stored in the config.

    Args:
        x_0 (Tensor): :math:`x_0`
        t (Tensor): :math:`t`, used to index the schedule tables
        mask (Tensor): the mask of the sample (not used by this method)

    Returns:
        Tuple[Tensor, Tensor]: the expectation
        :math:`\sqrt{\bar{\alpha}_t}\, x_0` and the standard deviation
        :math:`\sqrt{1-\bar{\alpha}_t}` (expanded by ``complete_micro_shape``)
    """
    schedule = cast(EuclideanDDPMConfig, self.config.to(t))
    mean_scale = self.complete_micro_shape(schedule.sqrt_alphas_cumprod[t])
    mean = mean_scale * x_0
    std = self.complete_micro_shape(schedule.sqrt_1m_alphas_cumprod[t])
    return mean, std
[docs]
def forward_process_one_step(self, x: Tensor, t: Tensor, padding_mask: Tensor, *args: Any, **kwargs: Any) -> Tensor:
    r"""Apply a single forward (noising) step using :math:`\beta_t`.

    .. math::
        x_{next} = \sqrt{1-\beta_t}\, x + \sqrt{\beta_t}\, \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)

    Args:
        x (Tensor): the current sample
        t (Tensor): the time step(s) used to index ``betas``
        padding_mask (Tensor): the mask passed to ``self.masker.apply_mask``
        *args: unused
        **kwargs: unused

    Returns:
        Tensor: the masked noised sample
    """
    schedule = cast(EuclideanDDPMConfig, self.config.to(t))
    beta = schedule.betas[t]  # (macro_shape)
    keep_scale = self.complete_micro_shape((1 - beta) ** 0.5)
    noise_scale = self.complete_micro_shape(beta**0.5)
    eps = torch.randn_like(x)
    noised = keep_scale * x + noise_scale * eps
    return self.masker.apply_mask(noised, padding_mask)
[docs]
def forward_process_n_step(
    self, x: Tensor, t: Tensor, next_t: Tensor, padding_mask: Tensor, *args: Any, **kwargs: Any
) -> Tensor:
    r"""Jump from step ``t`` directly to the later step ``next_t``.

    With :math:`r = \bar{\alpha}_{next} / \bar{\alpha}_t`:

    .. math::
        x_{next} = \sqrt{r}\, x + \sqrt{1 - r}\, \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)

    Args:
        x (Tensor): the sample at step ``t``
        t (Tensor): the current step(s); must be non-negative
        next_t (Tensor): the target step(s); must be strictly greater than
            ``t`` and smaller than ``config.n_discretization_steps``
        padding_mask (Tensor): the mask passed to ``self.masker.apply_mask``
        *args: unused
        **kwargs: unused

    Returns:
        Tensor: the masked sample at step ``next_t``

    Raises:
        AssertionError: if the step constraints above are violated
    """
    assert (next_t > t).all()
    assert (t >= 0).all()
    assert (next_t < self.config.n_discretization_steps).all()
    schedule = cast(EuclideanDDPMConfig, self.config.to(t))
    cumprod = schedule.alphas_cumprod
    ratio = cumprod[next_t] / cumprod[t]
    signal_scale = self.complete_micro_shape(ratio**0.5)
    noise_scale = self.complete_micro_shape((1 - ratio) ** 0.5)
    eps = torch.randn_like(x)
    noised = signal_scale * x + noise_scale * eps
    return self.masker.apply_mask(noised, padding_mask)
[docs]
def forward_process(
    self, x_0: Tensor, discrete_t: Tensor, mask: Tensor, *args: list[Any], **kwargs: dict[Any, Any]
) -> dict:
    r"""Sample :math:`x_t` from :math:`q(x_t|x_0)` in one shot.

    The mean and standard deviation come from ``self.q_xt_x_0``; fresh
    Gaussian noise is drawn and the result is masked with
    ``self.masker.apply_mask``.

    Args:
        x_0 (Tensor): the clean sample
        discrete_t (Tensor): the discrete time step(s)
        mask (Tensor): the mask of the sample
        *args: unused
        **kwargs: unused

    Returns:
        dict: ``"x_t"`` (masked noised sample), ``"noise"`` (the noise that
        was drawn), ``"expectation"`` and ``"standard_deviation"``
    """
    mean, std = self.q_xt_x_0(x_0, discrete_t, mask)
    eps = torch.randn_like(mean, device=x_0.device)
    noised = self.masker.apply_mask(mean + std * eps, mask)
    return {"x_t": noised, "noise": eps, "expectation": mean, "standard_deviation": std}
[docs]
def sample_xtm1_conditional_on_xt(
    self, x_t: Tensor, t: Tensor, padding_mask: Tensor, *args: Any, **kwargs: Any
) -> Tensor:
    r"""
    Predict the sample from the previous timestep by reversing the SDE.
    This function propagates the diffusion process from the learned model outputs.

    Based on the standard DDPM sampling formula:

    .. math::
        \hat{\mathbf{x}}_0:=\frac{1}{\sqrt{\bar{\alpha}_t}}(\mathbf{x}_t - \sqrt{1-\bar{\alpha}_t}\mathbf{\epsilon}_{\theta}(\mathbf{x}_t,t))

        \mathcal{N}\left( \boldsymbol{x}_{t-1}; \underbrace{\frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})\boldsymbol{x}_t + \sqrt{\bar{\alpha}_{t-1}}(1-\alpha_t)\hat{\boldsymbol{x}}_0}{1-\bar{\alpha}_t}}_{\mu_q(\boldsymbol{x}_t, \hat{\boldsymbol{x}}_0)}, \underbrace{\frac{(1-\alpha_t)(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\mathbf{I}}_{\Sigma_q(t)} \right)

    Args:
        x_t (Tensor): the sample at timestep t
        t (Tensor): the timestep; every entry must hold the same value
        padding_mask (Tensor): the padding mask, forwarded to the model
        *args: extra positional arguments forwarded to the model
        **kwargs: extra keyword arguments forwarded to the model. The optional
            ``mode`` entry (``"epsilon"`` by default, or ``"x_0"``) selects how
            ``model_output["x"]`` is interpreted; it is forwarded as well.

    Returns:
        Tensor: the sample at timestep t-1

    Raises:
        AssertionError: if the entries of ``t`` are not all equal
        ValueError: if ``mode`` is ``"score"`` (not supported yet) or unknown.
            Raised before the model is evaluated.
    """
    mode: Literal["epsilon", "x_0", "score"] = kwargs.get("mode", "epsilon")
    assert torch.all(t == t.view(-1)[0]).item()
    # Reject bad modes before paying for a model forward pass.
    if mode == "score":
        raise ValueError(f"Currently not supported mode: {mode}")
    if mode not in ("epsilon", "x_0"):
        raise ValueError(f"Invalid mode: {mode}")

    config = cast("EuclideanDDPMConfig", self.config.to(t))
    # Convert to scalar timestep for indexing
    t_scalar = t.view(-1)[0].long()

    # Get model prediction (noise or x_0, depending on mode)
    model_pred = self.model(x_t, t.long(), padding_mask, *args, **kwargs)["x"]

    # Calculate previous timestep (handle both standard and custom timestep schedules)
    prev_t = self._get_previous_timestep(t_scalar)

    # Get alpha values; alpha_bar of the step "before 0" is defined as 1
    alpha_prod_t = config.alphas_cumprod[t_scalar]
    alpha_prod_t_prev = config.alphas_cumprod[prev_t] if prev_t >= 0 else torch.tensor(1.0, device=t.device)
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev
    current_alpha_t = alpha_prod_t / alpha_prod_t_prev
    current_beta_t = 1 - current_alpha_t

    # Compute predicted original sample
    if mode == "epsilon":
        pred_original_sample = (x_t - beta_prod_t**0.5 * model_pred) / alpha_prod_t**0.5
    else:  # mode == "x_0": the model already predicts x_0
        pred_original_sample = model_pred

    # Clip or threshold "predicted x_0" (following standard DDPM implementation)
    if config.use_dyn_thresholding:
        pred_original_sample = self._threshold_sample(pred_original_sample)
    elif config.use_clip:
        pred_original_sample = pred_original_sample.clamp(-config.clip_sample_range, config.clip_sample_range)

    # Compute coefficients for pred_original_sample x_0 and current sample x_t
    # See formula (7) from https://huggingface.co/papers/2006.11239
    pred_original_sample_coeff = (alpha_prod_t_prev**0.5 * current_beta_t) / beta_prod_t
    current_sample_coeff = current_alpha_t**0.5 * beta_prod_t_prev / beta_prod_t

    # Compute predicted previous sample mu_t
    pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * x_t

    # Add noise only for t > 0; the final step returns the mean.
    if t_scalar > 0:
        # Standard DDPM variance: beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)
        variance_value = self._get_variance(t_scalar, alpha_prod_t, alpha_prod_t_prev, current_beta_t)
        pred_prev_sample = pred_prev_sample + (variance_value**0.5) * torch.randn_like(x_t)

    return pred_prev_sample
def _get_previous_timestep(self, timestep: int) -> int:
    r"""Return the timestep visited right after ``timestep`` during sampling.

    Args:
        timestep (int): the current timestep

    Returns:
        int: ``timestep - 1``
    """
    previous_timestep = timestep - 1
    return previous_timestep
def _get_variance(self, t: int, alpha_prod_t: Tensor, alpha_prod_t_prev: Tensor, current_beta_t: Tensor) -> Tensor:
    r"""Posterior variance of the reverse step, following the standard DDPM
    formula (see formula (6) and (7) from https://huggingface.co/papers/2006.11239)

    .. math::
        \sigma^2 = (\frac{1 - \bar{\alpha}_{pre}}{1 - \bar{\alpha}_{t}}) \cdot ( 1- \frac{\bar{\alpha}_{t}}{\bar{\alpha}_{pre}})

    The result is clamped from below at ``1e-20``.

    Args:
        t (int): timestep (not used in the computation)
        alpha_prod_t (Tensor): :math:`\bar{\alpha}_t`
        alpha_prod_t_prev (Tensor): :math:`\bar{\alpha}_{t-1}`
        current_beta_t (Tensor): :math:`\beta_t`

    Returns:
        Tensor: the variance for timestep t
    """
    ratio = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t)
    raw_variance = ratio * current_beta_t
    # Lower bound keeps later square roots / logs away from exact zero
    return raw_variance.clamp(min=1e-20)
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
    """
    "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
    prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
    s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
    pixels from saturation at each step. We find that dynamic thresholding results in significantly better
    photorealism as well as better image-text alignment, especially when using very large guidance weights."

    https://huggingface.co/papers/2205.11487

    The quantile is taken per element of the first dimension over all
    remaining dimensions, so the input may have any number of trailing
    dimensions (including none beyond the second).

    Args:
        sample (torch.Tensor): the predicted ``x_0``; the first dimension is
            treated as the batch dimension

    Returns:
        torch.Tensor: the thresholded sample, same shape and dtype as the input
    """
    dtype = sample.dtype
    original_shape = sample.shape
    batch_size = original_shape[0]

    if dtype not in (torch.float32, torch.float64):
        sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

    # Flatten everything but the batch dimension. Letting reshape infer the
    # size avoids multiplying by np.prod([]) == 1.0 (a float), which made
    # 2-D inputs fail.
    sample = sample.reshape(batch_size, -1)

    abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

    s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)  # (batch_size,)
    s = torch.clamp(
        s, min=1, max=self.config.sample_max_value
    )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
    s = s.unsqueeze(1)  # (batch_size, 1) so that clamp broadcasts per batch element
    sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

    return sample.reshape(original_shape).to(dtype)
[docs]
def get_posterior_mean_fn(self, score: Tensor = None, score_fn: Callable = None):
    r"""Get the posterior mean function

    Args:
        score (Tensor, optional): a precomputed score; when given it is used
            for every call of the returned function
        score_fn (Callable, optional): the function to compute score, called as
            ``score_fn(x_t, t, padding_mask)`` on every call of the returned
            function when ``score`` is not given

    Returns:
        Callable: the posterior mean function ``f(x_t, t, padding_mask)``
    """

    def _posterior_mean_fn(
        x_t: Tensor,
        t: Tensor,
        padding_mask: Tensor,
    ):
        r"""
        Args:
            x_t: shape=(..., n_nodes, 3)
            t: shape=(...), dtype=torch.long
            padding_mask: forwarded to ``score_fn`` when the score is computed

        For the case of DDPM sampling, the posterior mean is given by

        .. math::
            E[x_0|x_t] = \frac{1}{\sqrt{\bar{\alpha}(t)}}(x_t + (1 - \bar{\alpha}(t))\nabla_{x_t}\log p_t(x_t))
        """
        assert score is not None or score_fn is not None, "either score or score_fn must be provided"
        # Keep the score in a local: writing it back to the enclosing scope
        # would freeze the first score_fn result and reuse it for every later
        # (x_t, t), silently returning stale posterior means.
        current_score = score if score is not None else score_fn(x_t, t, padding_mask)
        config = cast("EuclideanDDPMConfig", self.config.to(t))
        alpha_bar_t = self.complete_micro_shape(config.alphas_cumprod[t])  # macro_shape -> broadcastable
        return (x_t + (1 - alpha_bar_t) * current_score) / torch.sqrt(alpha_bar_t)

    return _posterior_mean_fn