Source code for ls_mlkit.diffuser.sde.predictor

import abc
import functools
from typing import Tuple

import torch
from overrides import override
from torch import Tensor

from ...util.decorators import register_class_to_dict
from .base_sde import SDE

# Module-level registry that ``register_predictor`` writes into.
# NOTE(review): presumably maps each ``key_name`` to the decorated predictor
# class -- confirm against ``register_class_to_dict``.
_PREDICTORS: dict = {}

# ``register_class_to_dict`` pre-bound to the predictor registry; used in this
# module as ``@register_predictor(key_name="...")`` on predictor classes.
register_predictor = functools.partial(
    register_class_to_dict,
    global_dict=_PREDICTORS,
)


class Predictor(abc.ABC):
    """Abstract base class for predictors built on a reverse-time SDE/ODE.

    The constructor asks ``sde`` for its reverse process and keeps it on the
    instance as ``rsde``; concrete subclasses implement :meth:`update_fn`.

    Args:
        sde: The SDE object; must provide ``get_reverse_sde``.
        score_fn: Score function handed to ``sde.get_reverse_sde`` and stored
            on the instance.
        use_probability_flow: Forwarded unchanged to ``sde.get_reverse_sde``.
    """

    def __init__(self, sde: SDE, score_fn: object, use_probability_flow=False):
        super().__init__()
        # Build the reverse SDE/ODE once, up front.
        reverse_sde = sde.get_reverse_sde(
            score_fn=score_fn,
            use_probability_flow=use_probability_flow,
        )
        self.sde = sde
        self.rsde = reverse_sde
        self.score_fn = score_fn

    @abc.abstractmethod
    def update_fn(self, x: Tensor, t: Tensor, mask=None) -> Tuple[Tensor, Tensor]:
        r"""Perform one predictor update.

        Args:
            x: A PyTorch tensor representing the current state.
            t: A PyTorch tensor representing the current time step.
            mask: Optional mask (``None`` by default); how it is used is left
                to the concrete predictor.

        Returns:
            x: A PyTorch tensor of the next state.
            x_mean: A PyTorch tensor. The next state without random noise.
                Useful for denoising.
        """
@register_predictor(key_name="none")
class NonePredictor(Predictor):
    """Predictor that performs no update and hands the state back unchanged."""

    def __init__(self, sde, score_fn, use_probability_flow=False):
        # Intentionally empty: the base initializer is not invoked, so no
        # ``sde`` / ``rsde`` / ``score_fn`` attributes are set on the instance.
        pass

    def update_fn(self, x, t, mask=None):
        """Return ``x`` as both the next state and its noise-free mean.

        ``t`` and ``mask`` are accepted for interface compatibility and are
        not used.
        """
        unchanged = x
        return unchanged, unchanged
@register_predictor(key_name="reverse_diffusion_predictor")
class ReverseDiffusionPredictor(Predictor):
    """Predictor that steps the state using the discretized reverse SDE.

    Args:
        sde: The SDE object, forwarded to :class:`Predictor`.
        score_fn: Score function, forwarded to :class:`Predictor`.
        use_probability_flow: Forwarded to :class:`Predictor`.
        n_dim: Stored on the instance as ``n_dim``; not read by the code
            in this class.
    """

    def __init__(self, sde: SDE, score_fn, use_probability_flow=False, n_dim: int = 3):
        super().__init__(
            sde=sde,
            score_fn=score_fn,
            use_probability_flow=use_probability_flow,
        )
        self.n_dim = n_dim

    @override
    def update_fn(self, x: Tensor, t: Tensor, mask=None) -> Tuple[Tensor, Tensor]:
        r"""Take one reverse-diffusion step.

        The drift ``f`` and diffusion ``g`` are obtained from
        ``self.rsde.get_discretized_drift_and_diffusion(x, t, mask=mask)``.
        NOTE(review): they are presumably already scaled by the step size,
        i.e.

        .. math::

            f &= f(x_t, t)\,|\Delta t| \\
            g &= g(x_t, t)\,\sqrt{|\Delta t|}

        -- confirm against the reverse SDE implementation. The update
        applied here is

        .. math::

            \bar{x} &= x_t - f \\
            x_{t+\Delta t} &= \bar{x} + g\,z, \qquad z \sim \mathcal{N}(0, I)

        Args:
            x: A PyTorch tensor representing the current state.
            t: A PyTorch tensor representing the current time step.
            mask: Optional mask, passed through to the reverse SDE only.

        Returns:
            A tuple ``(x, x_mean)``: the next state with noise added, and the
            next state without the noise term.
        """
        # Query the reverse SDE first, then draw the noise, so the order of
        # operations (and RNG consumption) matches the established behavior.
        drift, diffusion = self.rsde.get_discretized_drift_and_diffusion(x, t, mask=mask)
        noise = torch.randn_like(x)
        x_mean = x - drift
        return x_mean + diffusion * noise, x_mean