Source code for ls_mlkit.diffuser.euclidean_diffuser

from typing import Any, Literal, Optional, Tuple

import torch
from torch import Tensor
from tqdm.auto import tqdm

from ls_mlkit.diffuser.time_scheduler import TimeScheduler

from ..util.decorators import inherit_docstrings
from ..util.mask.masker_interface import MaskerInterface
from .base_diffuser import BaseDiffuser, BaseDiffuserConfig
from .conditioner import Conditioner


@inherit_docstrings
class EuclideanDiffuserConfig(BaseDiffuserConfig):
    # Configuration for ``EuclideanDiffuser``. It only changes the defaults of two
    # base-config fields; everything else is forwarded untouched to ``BaseDiffuserConfig``.
    # No docstring is written here so ``@inherit_docstrings`` can supply the parent's.
    #
    #   n_discretization_steps: number of discrete diffusion steps (default 1000).
    #   ndim_micro_shape: number of trailing sample dimensions that EuclideanDiffuser
    #       treats as the per-element ("micro") shape; the leading dims are the
    #       "macro" shape used for timestep tensors (default 2).

    def __init__(self, n_discretization_steps: int = 1000, ndim_micro_shape: int = 2, *args, **kwargs):
        # Positional extras are listed first here only for readability. Python binds
        # every positional argument before any keyword argument, so this is the same
        # call as the keyword-first form.
        # NOTE(review): if BaseDiffuserConfig's first positional parameters are
        # n_discretization_steps / ndim_micro_shape, any non-empty *args would collide
        # with the keywords below (TypeError). Confirm against the base signature.
        super().__init__(
            *args,
            n_discretization_steps=n_discretization_steps,
            ndim_micro_shape=ndim_micro_shape,
            **kwargs,
        )
@inherit_docstrings
class EuclideanDiffuser(BaseDiffuser):
    # Diffuser that implements unconditional sampling and (RePaint-style) inpainting
    # loops. The kernels it calls -- ``prior_sampling``, ``forward_process`` and
    # ``sample_xtm1_conditional_on_xt`` -- are not defined here; presumably they come
    # from ``BaseDiffuser`` or a subclass (confirm). The class and ``__init__`` have no
    # docstring so ``@inherit_docstrings`` can supply one from the parent.

    def __init__(
        self,
        config: EuclideanDiffuserConfig,
        time_scheduler: TimeScheduler,
        masker: MaskerInterface,
        conditioner_list: Optional[list[Conditioner]] = None,
    ):
        super().__init__(config=config, time_scheduler=time_scheduler)
        self.masker = masker
        # ``None`` instead of a ``[]`` default: a mutable default list would be shared by
        # every instance built without an explicit list.
        self.conditioner_list = conditioner_list if conditioner_list is not None else []

    def forward_process_one_step(self, x: Tensor, t: Tensor, padding_mask: Tensor, *args: Any, **kwargs: Any) -> Tensor:
        r"""Forward process one step

        Args:
            x (``Tensor``): the sample
            t (``Tensor``): the timestep
            padding_mask (``Tensor``): the padding mask

        Returns:
            ``Tensor``: the sample at the next timestep

        Raises:
            NotImplementedError: subclasses must implement this step.
        """
        raise NotImplementedError(f"{type(self).__name__} does not implement forward_process_one_step")

    def forward_process_n_step(
        self, x: Tensor, t: Tensor, next_t: Tensor, padding_mask: Tensor, *args: Any, **kwargs: Any
    ) -> Tensor:
        r"""Forward process n step, from t to next_t

        Args:
            x (``Tensor``): the sample
            t (``Tensor``): the timestep
            next_t (``Tensor``): the next timestep
            padding_mask (``Tensor``): the padding mask

        Returns:
            ``Tensor``: the sample at the next timestep

        Raises:
            NotImplementedError: subclasses must implement this step (it is required
                by ``inpainting_x0_unconditionally`` when ``n_repaint_steps > 1``).
        """
        raise NotImplementedError(f"{type(self).__name__} does not implement forward_process_n_step")

    def get_accumulated_conditional_score(
        self, x_t: Tensor, t: Tensor, padding_mask: Tensor, *args: Any, **kwargs: Any
    ) -> Tensor:
        r"""Get the accumulated conditional score

        Sums ``conditioner.get_conditional_score(...)`` over every enabled conditioner in
        ``self.conditioner_list``. The result is all zeros when none are enabled.

        Args:
            x_t (``Tensor``): :math:`x_t`
            t (``Tensor``): :math:`t`
            padding_mask (``Tensor``): the padding mask

        Returns:
            ``Tensor``: the accumulated conditional score
        """
        accumulated_conditional_score = torch.zeros_like(x_t)
        for conditioner in self.conditioner_list:
            if not conditioner.is_enabled():
                continue
            accumulated_conditional_score += conditioner.get_conditional_score(x_t, t, padding_mask, *args, **kwargs)
        return accumulated_conditional_score

    def _macro_shape(self, shape):
        """Return the leading ("macro") dims of ``shape``, dropping the trailing ``config.ndim_micro_shape`` dims.

        Unlike ``shape[:-k]``, this is also correct for ``k == 0``: it returns the whole shape, not ``()``.
        """
        n_macro = max(len(shape) - self.config.ndim_micro_shape, 0)
        return shape[:n_macro]

    def _initial_x_t(self, shape, macro_shape, device, x_init_posterior: Optional[Tensor], padding_mask):
        """Return the starting sample of the reverse process.

        Draws from ``prior_sampling(shape)`` when ``x_init_posterior`` is ``None``.
        Otherwise moves ``x_init_posterior`` to ``device`` and diffuses it with
        ``forward_process`` to the last discrete timestep. If ``padding_mask`` is
        ``None``, the masker's full-bright mask is used for that step.
        """
        if x_init_posterior is None:
            return self.prior_sampling(shape).to(device)
        x_t = x_init_posterior.to(device)
        if padding_mask is None:
            padding_mask = self.masker.get_full_bright_mask(x_t)
        t_last = torch.ones(macro_shape, device=device, dtype=torch.long) * (self.config.n_discretization_steps - 1)
        return self.forward_process(x_t, t_last, padding_mask)["x_t"]

    @torch.no_grad()
    def sample_x0_unconditionally(
        self, shape: Tuple[int, ...], device, x_init_posterior: Optional[Tensor] = None, *args: Any, **kwargs: Any
    ) -> Tensor:
        r"""Sample :math:`x_0` unconditionally

        Args:
            shape (``Tuple[int, ...]``): the shape of the sample (ignored when ``x_init_posterior`` is given)
            device (``device``): the device to use for sampling
            x_init_posterior (``Tensor``): Use x_init_posterior as the initial posterior if provided, otherwise
                sample from prior. It is moved to ``device`` and diffused to the last timestep first.
            *args: additional arguments, forwarded to ``sample_xtm1_conditional_on_xt``
            **kwargs: additional keyword arguments, forwarded to ``sample_xtm1_conditional_on_xt``. An optional
                ``padding_mask`` entry is consumed here and not forwarded. It is only used when diffusing
                ``x_init_posterior``; the reverse steps always use the masker's full-bright mask.

        Returns:
            ``Tensor``: :math:`x_0`
        """
        if x_init_posterior is not None:
            shape = x_init_posterior.shape
        macro_shape = self._macro_shape(shape)

        # pop, not get: the reverse steps pass a mask positionally as the third argument of
        # sample_xtm1_conditional_on_xt. Leaving ``padding_mask`` in kwargs would bind that
        # argument twice (assuming it is named ``padding_mask``, like the other methods here).
        padding_mask = kwargs.pop("padding_mask", None)
        x_t = self._initial_x_t(shape, macro_shape, device, x_init_posterior, padding_mask)

        # The reverse loop uses a full-bright mask. It is built once, assuming it depends only
        # on x_t's shape/device, which stay the same across steps.
        no_padding_mask = self.masker.get_full_bright_mask(x_t)
        for t in tqdm(self.time_scheduler.get_discrete_timesteps_schedule()):
            # NOTE(review): this timestep tensor is float, while inpainting builds a
            # torch.long one. Kept as-is; confirm what sample_xtm1_conditional_on_xt expects.
            t_batch = torch.ones(macro_shape, device=device) * t
            x_t = self.sample_xtm1_conditional_on_xt(x_t, t_batch, no_padding_mask, *args, **kwargs)
        return x_t

    @torch.no_grad()
    def inpainting_x0_unconditionally(
        self,
        x_0: Tensor,
        padding_mask: Tensor,
        inpainting_mask: Tensor,
        device,
        recovery_mode: Literal["x_0", "x_t"] = "x_t",
        n_repaint_steps: int = 1,
        x_init_posterior: Optional[Tensor] = None,
        inpainting_mask_key="inpainting_mask",
        *args: Any,
        **kwargs: Any,
    ) -> Tensor:
        r"""Inpaint :math:`x_0` unconditionally

        Args:
            x_0 (``Tensor``): :math:`x_0`
            padding_mask (``Tensor``): the padding mask
            inpainting_mask (``Tensor``): the inpainting mask; also injected into ``kwargs`` under
                ``inpainting_mask_key`` so it reaches ``sample_xtm1_conditional_on_xt``
            device (``device``): the device to use for sampling
            recovery_mode (``Literal["x_0", "x_t"]``): the recovery mode. With ``"x_t"``, the known region
                is re-imposed as ``x_0`` noised to the current timestep. With ``"x_0"``, it is clean ``x_0``.
            n_repaint_steps (``int``): the number of repaint steps (>= 1); n_repaint_steps > 1 means use RePaint.
            x_init_posterior (``Tensor``): Use x_init_posterior as the initial posterior if provided, otherwise
                sample from prior. It is moved to ``device`` and diffused to the last timestep first.
            inpainting_mask_key (``str``): kwargs key under which ``inpainting_mask`` is passed on
            *args: additional arguments
            **kwargs: additional keyword arguments

        Returns:
            ``Tensor``: :math:`x_0` inpainted

        Raises:
            ValueError: if ``recovery_mode`` is not ``"x_0"`` or ``"x_t"``, if ``n_repaint_steps < 1``, or if
                a RePaint re-noising step is needed at the last scheduled timestep while it is still > 0.
        """
        if recovery_mode not in ("x_0", "x_t"):
            raise ValueError(f"recovery_mode must be 'x_0' or 'x_t', got {recovery_mode!r}")
        if n_repaint_steps < 1:
            raise ValueError(f"n_repaint_steps must be >= 1, got {n_repaint_steps}")

        masker = self.masker
        shape = x_0.shape
        macro_shape = self._macro_shape(shape)
        x_t = self._initial_x_t(shape, macro_shape, device, x_init_posterior, padding_mask)
        x_0 = masker.apply_mask(x_0, padding_mask)
        timesteps = self.time_scheduler.get_discrete_timesteps_schedule()

        # Add inpainting_mask to kwargs so it gets passed to the model
        kwargs[inpainting_mask_key] = inpainting_mask

        for i, t in enumerate(tqdm(timesteps)):
            # The per-timestep tensor is the same for every repaint iteration, so build it once.
            t_batch = torch.ones(macro_shape, device=device, dtype=torch.long) * t
            for u in range(1, n_repaint_steps + 1):
                if recovery_mode == "x_t":
                    x_0t = self.forward_process(x_0, t_batch, padding_mask)["x_t"]
                else:
                    x_0t = x_0
                # Re-impose the known region, then take one reverse step.
                x_t = masker.apply_inpainting_mask(x_0t, x_t, inpainting_mask)
                x_t = self.sample_xtm1_conditional_on_xt(x_t, t_batch, padding_mask, *args, **kwargs)  # get x_tm1
                x_t = masker.apply_mask(x_t, padding_mask)
                # RePaint: on every iteration but the last, diffuse the new sample from the
                # next scheduled timestep back up to the current one.
                if u < n_repaint_steps and (t_batch > 0).all():
                    if i >= len(timesteps) - 1:
                        raise ValueError(
                            "RePaint needs a following timestep, but the schedule ends at a timestep > 0"
                        )
                    prev_t = timesteps[i + 1].to(device)
                    x_t = self.forward_process_n_step(x_t, prev_t, t_batch, padding_mask, *args, **kwargs)
        return masker.apply_inpainting_mask(x_0, x_t, inpainting_mask)