ls_mlkit.diffuser.sde package

Module contents

class ls_mlkit.diffuser.sde.Corrector(sde: SDE, score_fn: object, snr: float, n_steps: int)

Bases: ABC

The abstract class for a corrector algorithm.

abstractmethod update_fn(x: Tensor, t: Tensor, mask=None)

One update of the corrector.

Parameters:
  • x – A PyTorch tensor representing the current state.

  • t – A PyTorch tensor representing the current time step.

  • mask – Optional mask of the state (1 indicates a valid region, 0 an invalid region). Defaults to None.

Returns:

x – A PyTorch tensor of the next state. x_mean – A PyTorch tensor of the next state without random noise; useful for denoising.

Return type:

Tuple[Tensor, Tensor]

class ls_mlkit.diffuser.sde.LangevinCorrector(sde: SDE, score_fn: object, snr: float, n_steps: int, n_dim: int = 2)

Bases: Corrector

update_fn(x: Tensor, t: Tensor, mask=None)

One update of the corrector.

Parameters:
  • x – A PyTorch tensor representing the current state.

  • t – A PyTorch tensor representing the current time step.

  • mask – Optional mask of the state (1 indicates a valid region, 0 an invalid region). Defaults to None.

Returns:

x – A PyTorch tensor of the next state. x_mean – A PyTorch tensor of the next state without random noise; useful for denoising.

Return type:

Tuple[Tensor, Tensor]
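For intuition, here is a minimal sketch of one Langevin corrector update of the kind used in predictor-corrector samplers: the step size is derived from snr so that the norm of the drift term step_size * grad is roughly snr times the norm of the noise term sqrt(2 * step_size) * noise. It is an illustration under assumptions, not LangevinCorrector itself: the name langevin_step is hypothetical, score_fn is assumed to take (x, t), the last n_dim dimensions are assumed to form the per-sample micro shape, mask is ignored, and any schedule-dependent scaling of the step size is omitted.

```python
import torch


def langevin_step(x, t, score_fn, snr, n_dim=2):
    """One Langevin corrector update (sketch). Returns (x_next, x_mean)."""
    grad = score_fn(x, t)
    noise = torch.randn_like(x)
    micro_dims = tuple(range(-n_dim, 0))
    # Average per-sample L2 norms of the score and of the Gaussian noise.
    grad_norm = torch.linalg.vector_norm(grad, dim=micro_dims).mean()
    noise_norm = torch.linalg.vector_norm(noise, dim=micro_dims).mean()
    # Step size chosen so that ||step * grad|| is roughly snr * ||sqrt(2 * step) * noise||.
    step_size = 2 * (snr * noise_norm / grad_norm) ** 2
    x_mean = x + step_size * grad
    x_next = x_mean + torch.sqrt(2 * step_size) * noise
    return x_next, x_mean
```

With the score of a standard Gaussian, score_fn = lambda x, t: -x, x_mean is a small uniform shrink of x toward the origin.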

class ls_mlkit.diffuser.sde.NoneCorrector(sde, score_fn, snr, n_steps)

Bases: Corrector

An empty corrector that does nothing.

update_fn(x, t, mask=None)

One update of the corrector.

Parameters:
  • x – A PyTorch tensor representing the current state.

  • t – A PyTorch tensor representing the current time step.

  • mask – Optional mask of the state (1 indicates a valid region, 0 an invalid region). Defaults to None.

Returns:

x – A PyTorch tensor of the next state. x_mean – A PyTorch tensor of the next state without random noise; useful for denoising.

Return type:

Tuple[Tensor, Tensor]

class ls_mlkit.diffuser.sde.NonePredictor(sde, score_fn, use_probability_flow=False)

Bases: Predictor

update_fn(x, t, mask=None)

One update of the predictor.

Parameters:
  • x – A PyTorch tensor representing the current state.

  • t – A PyTorch tensor representing the current time step.

  • mask – Optional mask of the state (1 indicates a valid region, 0 an invalid region). Defaults to None.

Returns:

x – A PyTorch tensor of the next state. x_mean – A PyTorch tensor of the next state without random noise; useful for denoising.

Return type:

Tuple[Tensor, Tensor]

class ls_mlkit.diffuser.sde.Predictor(sde: SDE, score_fn: object, use_probability_flow=False)

Bases: ABC

abstractmethod update_fn(x: Tensor, t: Tensor, mask=None) → Tuple[Tensor, Tensor]

One update of the predictor.

Parameters:
  • x – A PyTorch tensor representing the current state.

  • t – A PyTorch tensor representing the current time step.

  • mask – Optional mask of the state (1 indicates a valid region, 0 an invalid region). Defaults to None.

Returns:

x – A PyTorch tensor of the next state. x_mean – A PyTorch tensor of the next state without random noise; useful for denoising.

Return type:

Tuple[Tensor, Tensor]

class ls_mlkit.diffuser.sde.ReverseDiffusionPredictor(sde: SDE, score_fn, use_probability_flow=False, n_dim: int = 3)

Bases: Predictor

update_fn(x: Tensor, t: Tensor, mask=None) → Tuple[Tensor, Tensor]

One update of the predictor, written in terms of the discretized drift \(f\) and diffusion \(g\):

\[ \begin{align}\begin{aligned}x_{t+\Delta t} &= x_t + f(x_t, t)\,\Delta t + g(x_t, t)\, \epsilon, \quad \epsilon \sim \mathcal{N}(0, |\Delta t|)\\ &= x_t + f + g z, \quad z \sim \mathcal{N}(0, I)\\f &= f(x_t, t)|\Delta t|\\g &= g(x_t, t)\sqrt{|\Delta t|}\end{aligned}\end{align} \]

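Once \(f\) and \(g\) have been computed, the update reduces to \(x + f + g z\) with \(z \sim \mathcal{N}(0, I)\). The sketch below shows only that last step. The helper name is hypothetical, \(g\) is assumed to carry the macro (batch) shape of x, and the sign convention for integrating backwards in time is assumed to be folded into \(f\) already.

```python
import torch


def reverse_diffusion_step(x, f, g):
    """Apply x_next = x + f + g * z with z ~ N(0, I) (sketch). Returns (x_next, x_mean)."""
    # Append singleton dims so the per-sample g broadcasts over x's micro shape.
    g = g.reshape(g.shape + (1,) * (x.dim() - g.dim()))
    z = torch.randn_like(x)
    x_mean = x + f  # noise-free part of the update
    return x_mean + g * z, x_mean
```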
class ls_mlkit.diffuser.sde.SDE(n_discretization_steps: int, ndim_micro_shape: int = 2)

Bases: ABC

Abstract base class for SDEs. All methods operate on a mini-batch of inputs.

abstract property T: float

End time of the SDE.

get_diffusion_coefficient_with_proper_shape(x: Tensor, diffusion: Tensor) → Tensor

Get the diffusion coefficient with the proper shape by completing its micro shape, so that it broadcasts against x.

Parameters:
  • x (Tensor) – the sample.

  • diffusion (Tensor) – the diffusion coefficient.

Returns:

the diffusion coefficient with the proper shape.

Return type:

Tensor
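One plausible way to complete the micro shape is to append singleton dimensions so that a coefficient with the macro (batch) shape of x broadcasts against x. The sketch assumes exactly that (a reshape rather than an expanded copy); the function name is hypothetical.

```python
import torch


def complete_micro_shape(x: torch.Tensor, diffusion: torch.Tensor) -> torch.Tensor:
    """Reshape `diffusion` (macro shape of x) so that it broadcasts against x."""
    n_missing = x.dim() - diffusion.dim()
    return diffusion.reshape(diffusion.shape + (1,) * n_missing)
```

For example, with x of shape (8, 10, 3) and ndim_micro_shape = 2, a diffusion coefficient of shape (8,) becomes (8, 1, 1).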

get_discretization_steps(t: Tensor) → Tensor

Get the discretization steps.

Parameters:

t (Tensor) – the time step.

Returns:

the discretization steps.

Return type:

Tensor

get_discretized_drift_and_diffusion(x: Tensor, t: Tensor, mask=None) → Tuple[Tensor, Tensor]

Euler-Maruyama discretization.

\[ \begin{align}\begin{aligned}dx &= f(x, t)\,dt + g(x, t)\,dw\\x_{t+\Delta t} &= x_t + f(x_t, t)\,\Delta t + g(x_t, t)\, \epsilon, \quad \epsilon \sim \mathcal{N}(0, |\Delta t|)\end{aligned}\end{align} \]
Parameters:
  • x – a PyTorch tensor (the current sample).

  • t – a PyTorch float tensor representing the time step, in [0, self.T].

  • mask – 1 indicates a valid region, 0 an invalid region.

Note

Here \(\Delta t\) is always positive.

Returns:

f, g

\[ \begin{align}\begin{aligned}f &= f(x,t) |\Delta t|\\g &= g(x,t) \sqrt{|\Delta t|}\end{aligned}\end{align} \]
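The scaling can be checked numerically: with \(f = f(x,t)|\Delta t|\) and \(g = g(x,t)\sqrt{|\Delta t|}\), the noise term \(g z\) with \(z \sim \mathcal{N}(0, I)\) has variance \(g(x,t)^2 |\Delta t|\), matching \(\epsilon \sim \mathcal{N}(0, |\Delta t|)\). A minimal sketch, with hypothetical function names and the drift and diffusion passed in as already-evaluated tensors:

```python
import math
import torch


def discretize(drift, diffusion, dt):
    """Euler-Maruyama scaling: f = drift * |dt|, g = diffusion * sqrt(|dt|)."""
    return drift * abs(dt), diffusion * math.sqrt(abs(dt))


def euler_maruyama_step(x, drift, diffusion, dt):
    """x_next = x + f + g * z with z ~ N(0, I); diffusion must broadcast against x."""
    f, g = discretize(drift, diffusion, dt)
    return x + f + g * torch.randn_like(x)
```

With zero drift, unit diffusion and dt = 0.25, the samples produced by one step from x = 0 have a standard deviation close to 0.5.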

abstractmethod get_drift_and_diffusion(x: Tensor, t: Tensor, mask=None) → Tuple[Tensor, Tensor]

Get the drift and diffusion of the SDE.

Parameters:
  • x (Tensor) – the sample.

  • t (Tensor) – the time step.

  • mask (Tensor, optional) – the mask of the sample. Defaults to None.

Returns:

the drift and diffusion of the SDE.

Return type:

Tuple[Tensor, Tensor]

get_reverse_sde(score_fn: object, use_probability_flow=False)

Create the reverse-time SDE/ODE.

Parameters:
  • score_fn – A time-dependent score-based model that takes (x, t, mask) and returns the score.

  • use_probability_flow – If True, create the reverse-time ODE used for probability flow sampling.
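For reference, the standard construction of the reverse-time process from a forward SDE \(dx = f\,dt + g\,dw\) uses drift \(f - g^2 \nabla_x \log p_t(x)\) and diffusion \(g\) for the reverse SDE, and drift \(f - \tfrac{1}{2} g^2 \nabla_x \log p_t(x)\) with zero diffusion for the probability-flow ODE, both integrated from T down to 0. The sketch below computes these coefficients from already-evaluated tensors; it is not the body of get_reverse_sde, and the helper name is hypothetical.

```python
import torch


def reverse_drift_and_diffusion(drift, diffusion, score, use_probability_flow=False):
    """Reverse-time coefficients from forward drift/diffusion and the score (sketch).

    `diffusion` must already broadcast against `score` (micro shape completed).
    """
    coeff = 0.5 if use_probability_flow else 1.0
    rev_drift = drift - coeff * diffusion**2 * score
    # The probability-flow ODE is deterministic, so its diffusion is zero.
    rev_diffusion = torch.zeros_like(diffusion) if use_probability_flow else diffusion
    return rev_drift, rev_diffusion
```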

class ls_mlkit.diffuser.sde.SubVPSDE(beta_min: float = 0.1, beta_max: float = 20, n_discretization_steps: int = 1000, ndim_micro_shape: int = 2)

Bases: SDE

property T: float

End time of the SDE.

get_drift_and_diffusion(x: Tensor, t: Tensor, mask=None) → Tuple[Tensor, Tensor]

Get the drift and diffusion of the SDE.

Parameters:
  • x (Tensor) – the sample.

  • t (Tensor) – the time step.

  • mask (Tensor, optional) – the mask of the sample. Defaults to None.

Returns:

the drift and diffusion of the SDE.

Return type:

Tuple[Tensor, Tensor]

marginal_prob(x, t, mask=None)

prior_logp(z)

prior_sampling(shape)

class ls_mlkit.diffuser.sde.VESDE(sigma_min=0.01, sigma_max=50, n_discretization_steps=1000, ndim_micro_shape=2, drop_first_step=False)

Bases: SDE

property T: float

End time of the SDE.

get_discretized_drift_and_diffusion(x: Tensor, t: Tensor, mask=None) → Tuple[Tensor, Tensor]

SMLD (NCSN) discretization.

\[ \begin{align}\begin{aligned}x_t &= x_{t-1} + g \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)\\x_t &\sim \mathcal{N}(x_0, \sigma_t^2 I)\\\sigma_t^2 &= \sigma_{t-1}^2 + g^2\\g &= \sqrt{\sigma_t^2 - \sigma_{t-1}^2}\end{aligned}\end{align} \]
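The recursion telescopes: summing the per-step variances \(g^2\) recovers \(\sigma_t^2\). The sketch below checks this on a geometric noise schedule \(\sigma_i = \sigma_{min}(\sigma_{max}/\sigma_{min})^{i/(N-1)}\). The schedule, the convention \(\sigma_{-1} = 0\) for the first step, and the helper name are assumptions made for illustration.

```python
import torch


def smld_step_scales(sigma_min=0.01, sigma_max=50.0, n_steps=1000):
    """Return (sigmas, g) with g_i = sqrt(sigma_i^2 - sigma_{i-1}^2) and sigma_{-1} = 0."""
    frac = torch.linspace(0.0, 1.0, n_steps, dtype=torch.float64)
    sigmas = sigma_min * (sigma_max / sigma_min) ** frac  # geometric schedule
    prev = torch.cat([torch.zeros(1, dtype=torch.float64), sigmas[:-1]])
    g = torch.sqrt(sigmas**2 - prev**2)
    return sigmas, g
```

Accumulating g**2 with torch.cumsum reproduces sigmas**2 up to floating-point error.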
get_drift_and_diffusion(x: Tensor, t: Tensor, mask=None) → Tuple[Tensor, Tensor]

\[ \begin{align}\begin{aligned}dx &= 0\,dt + \sigma_{min} \left(\frac{\sigma_{max}}{\sigma_{min}}\right)^t \sqrt{2 \log\left(\frac{\sigma_{max}}{\sigma_{min}}\right)}\, dw\\\sigma_t &= \sigma_{min} \left(\frac{\sigma_{max}}{\sigma_{min}}\right)^t\\\text{diffusion} &= \sigma_t \sqrt{2 \log\left(\frac{\sigma_{max}}{\sigma_{min}}\right)}\end{aligned}\end{align} \]

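Because \(\sigma_t^2 = \sigma_{min}^2 (\sigma_{max}/\sigma_{min})^{2t}\), its time derivative is \(2 \log(\sigma_{max}/\sigma_{min})\,\sigma_t^2\), which is exactly the squared diffusion coefficient above. A small sketch of these coefficients, with hypothetical helper names and mask handling omitted:

```python
import math
import torch


def ve_sigma(t, sigma_min=0.01, sigma_max=50.0):
    """sigma_t = sigma_min * (sigma_max / sigma_min) ** t."""
    return sigma_min * (sigma_max / sigma_min) ** t


def ve_drift_and_diffusion(x, t, sigma_min=0.01, sigma_max=50.0):
    """Zero drift with x's shape; diffusion with the macro shape of t."""
    drift = torch.zeros_like(x)
    diffusion = ve_sigma(t, sigma_min, sigma_max) * math.sqrt(2 * math.log(sigma_max / sigma_min))
    return drift, diffusion
```

A central finite difference of ve_sigma(t) ** 2 agrees with diffusion ** 2, which is a quick way to confirm the identity.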
marginal_prob(x, t, mask=None)

prior_logp(z)

prior_sampling(shape)

class ls_mlkit.diffuser.sde.VPSDE(beta_min: float = 0.1, beta_max: float = 20, n_discretization_steps: int = 1000, ndim_micro_shape: int = 2)

Bases: SDE

property T: float

End time of the SDE.

get_discretized_drift_and_diffusion(x: Tensor, t: Tensor, mask=None) → Tuple[Tensor, Tensor]

DDPM discretization.
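The DDPM update \(x_i = \sqrt{1-\beta_i}\,x_{i-1} + \sqrt{\beta_i}\,z\) fits the \(x + f + g z\) form of get_discretized_drift_and_diffusion with \(f = (\sqrt{1-\beta_i} - 1)\,x\) and \(g = \sqrt{\beta_i}\). The sketch assumes the discrete schedule \(\beta_i\) is a linear grid from beta_min / N to beta_max / N (a common convention) and that integer step indices are already known; the helper name is hypothetical and VPSDE's own discretization may differ in detail.

```python
import torch


def ddpm_discretized_drift_and_diffusion(x, step, beta_min=0.1, beta_max=20.0, n_steps=1000):
    """DDPM step as x_next = x + f + g * z (sketch; linear discrete beta schedule assumed)."""
    betas = torch.linspace(beta_min / n_steps, beta_max / n_steps, n_steps)
    beta = betas[step]  # step: integer indices with the macro shape of x
    beta = beta.reshape(beta.shape + (1,) * (x.dim() - beta.dim()))
    f = (torch.sqrt(1.0 - beta) - 1.0) * x  # so that x + f = sqrt(1 - beta) * x
    g = torch.sqrt(beta)
    return f, g
```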

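get_drift_and_diffusion(x: Tensor, t: Tensor, mask=None) → Tuple[Tensor, Tensor]

Continuous-time DDPM (VP) SDE:

\[ \begin{align}\begin{aligned}dx &= -\frac{1}{2}\beta_t x\, dt + \sqrt{\beta_t}\, dw\\\beta_t &= \beta_0 + t(\beta_1 - \beta_0)\end{aligned}\end{align} \]
Parameters:
  • x – the sample.

  • t – the time step, with shape macro_shape.

  • mask – the mask of the sample. Defaults to None.

Returns:

drift, with shape x.shape; diffusion, with the macro shape of x.

Return type:

Tuple[Tensor, Tensor]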
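A minimal sketch of these coefficients, assuming the linear schedule \(\beta_t = \beta_0 + t(\beta_1 - \beta_0)\) with \(\beta_0\) = beta_min and \(\beta_1\) = beta_max, and ignoring mask; the helper name is hypothetical. It returns the drift with x.shape and the diffusion with the macro shape of t.

```python
import torch


def vp_drift_and_diffusion(x, t, beta_min=0.1, beta_max=20.0):
    """Continuous VP SDE coefficients (sketch): drift = -beta_t * x / 2, diffusion = sqrt(beta_t)."""
    beta_t = beta_min + t * (beta_max - beta_min)  # macro shape of x
    # Append singleton dims so beta_t broadcasts over the micro shape of x.
    beta_b = beta_t.reshape(beta_t.shape + (1,) * (x.dim() - t.dim()))
    drift = -0.5 * beta_b * x
    diffusion = torch.sqrt(beta_t)
    return drift, diffusion
```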

get_target_score(x_0: Tensor, x_t: Tensor, t: Tensor, mask: Tensor, continuous: bool = False) → Tensor

Target score, i.e. the score of the perturbation kernel:

\[\nabla_{x_t} \ln p_{0t} (x_t \mid x_0)\]

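Because the VP perturbation kernel is Gaussian with the mean and std returned by marginal_prob, this gradient has the closed form \(-(x_t - \text{mean}) / \text{std}^2\). The sketch below implements that closed form; the helper name is hypothetical, std is assumed to carry the macro shape of x_t, and mask and the continuous flag are not modelled.

```python
import torch


def gaussian_target_score(x_t, mean, std):
    """Score of N(mean, std^2 I) at x_t, i.e. -(x_t - mean) / std^2 (sketch)."""
    # std has the macro shape of x_t; append singleton dims to broadcast.
    std = std.reshape(std.shape + (1,) * (x_t.dim() - std.dim()))
    return -(x_t - mean) / std**2
```

Differentiating \(-\tfrac{1}{2}\lVert (x_t - \text{mean})/\text{std} \rVert^2\) with torch.autograd.grad yields the same tensor.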
marginal_prob(x_0: Tensor, t: Tensor, mask: Tensor = None) → Tuple[Tensor, Tensor]

Mean and standard deviation of the perturbation kernel \(p_{0t}(x_t \mid x_0) = \mathcal{N}(x_t; \text{mean}, \text{std}^2 I)\):

\[ \begin{align}\begin{aligned}\gamma &= -\frac{1}{4}t^2 (\beta_1 - \beta_0) - \frac{1}{2} t \beta_0\\\text{mean} &= e^{\gamma} x_0\\\text{std} &= \sqrt{1 - e^{2 \gamma }}\end{aligned}\end{align} \]

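These formulas translate directly into tensor code. The sketch below uses a hypothetical name, ignores mask, and assumes \(\beta_0\) = beta_min and \(\beta_1\) = beta_max. It returns mean with the shape of x_0 and std with the macro shape of t; at t = 0 it gives mean = x_0 and std = 0, and at t = 1 std is close to 1 for the default schedule.

```python
import torch


def vp_marginal_prob(x_0, t, beta_min=0.1, beta_max=20.0):
    """Mean and std of p_0t(x_t | x_0) for the VP SDE (sketch, mask ignored)."""
    gamma = -0.25 * t**2 * (beta_max - beta_min) - 0.5 * t * beta_min
    # Broadcast gamma over the micro shape of x_0 for the mean.
    gamma_b = gamma.reshape(gamma.shape + (1,) * (x_0.dim() - gamma.dim()))
    mean = torch.exp(gamma_b) * x_0
    std = torch.sqrt(1.0 - torch.exp(2.0 * gamma))
    return mean, std
```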
prior_logp(z: Tensor) → Tensor

Log-density of the standard Gaussian prior, i.e. the logarithm of

\[(2\pi)^{-k/2} \det(\Sigma)^{-1/2} \exp\left( -\frac{1}{2} (\mathbf{z} - \boldsymbol{\mu})^\mathrm{T} \Sigma^{-1} (\mathbf{z} - \boldsymbol{\mu}) \right)\]

where \(\Sigma = I\), \(\boldsymbol{\mu} = 0\), and \(k\) is the dimension of \(\mathbf{z}\).
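With \(\Sigma = I\) and \(\boldsymbol{\mu} = 0\), the logarithm reduces to \(-\tfrac{k}{2}\log(2\pi) - \tfrac{1}{2}\lVert \mathbf{z} \rVert^2\). A sketch, assuming the log-density is summed over the last n_dim dimensions (the micro shape) and returned per sample; the helper name and the n_dim argument are hypothetical.

```python
import math
import torch


def standard_normal_logp(z, n_dim=2):
    """log N(z; 0, I) per sample, summed over the last n_dim (micro) dimensions."""
    dims = tuple(range(-n_dim, 0))
    k = math.prod(z.shape[-n_dim:])  # number of elements in the micro shape
    return -0.5 * k * math.log(2 * math.pi) - 0.5 * (z**2).sum(dim=dims)
```

This agrees with summing torch.distributions.Normal(0.0, 1.0).log_prob(z) over the same dimensions.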

prior_sampling(shape: Tuple) → Tensor

Draw a sample of the given shape from the standard Gaussian prior:

\[\epsilon \sim \mathcal{N}(0, I)\]

ls_mlkit.diffuser.sde.get_model_fn(model, train=False)

Create a function that returns the output of the score-based model.

Parameters:
  • model – The score model.

  • train – True for training mode and False for evaluation mode.

Returns:

A model function.
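In the usual pattern, such a wrapper only switches the module between training and evaluation mode before calling it. The sketch below assumes the model is an nn.Module called as model(x, t, mask); that forward signature and the name make_model_fn are hypothetical.

```python
import torch
from torch import nn


def make_model_fn(model: nn.Module, train: bool = False):
    """Return a callable that runs `model` in training or evaluation mode (sketch)."""

    def model_fn(x, t, mask=None):
        # nn.Module.train(False) switches to evaluation mode, like model.eval().
        model.train(train)
        return model(x, t, mask)  # hypothetical forward signature (x, t, mask)

    return model_fn
```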

ls_mlkit.diffuser.sde.get_pc_sampler(sde: SDE, shape: Tuple[int, ...], predictor_class: Predictor, corrector_class: Corrector, inverse_scaler: Callable, snr: float, n_correct_steps: int = 1, use_probability_flow: bool = False, denoise_at_final: bool = True, eps: float = 0.001, device: str = 'cuda')

Create a Predictor-Corrector (PC) sampler.

Parameters:
  • sde – An SDE object representing the forward SDE.

  • shape – A sequence of integers. The expected shape of the sampled batch; the first dimension is the batch size.

  • predictor_class – A subclass of Predictor representing the predictor algorithm.

  • corrector_class – A subclass of Corrector representing the corrector algorithm.

  • inverse_scaler – The inverse data normalizer.

  • snr – A float number. The signal-to-noise ratio for configuring correctors.

  • n_correct_steps – An integer. The number of corrector steps per predictor update.

  • use_probability_flow – If True, solve the reverse-time probability flow ODE when running the predictor.

  • denoise_at_final – If True, add one-step denoising to the final samples.

  • eps – A float number. The reverse-time SDE or ODE is integrated from T down to eps rather than to 0, to avoid numerical issues.

  • device – PyTorch device.

Returns:

A sampling function that returns samples and the number of function evaluations during sampling.
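The sampler's control flow can be sketched as follows: starting from a prior sample, it walks a time grid from T down to eps, applies n_correct_steps corrector updates and then one predictor update at each time, and optionally returns the noise-free x_mean of the last step. The callables predictor_step and corrector_step (standing in for the update_fn of predictor and corrector instances), the uniform time grid, and the function-evaluation count are assumptions for illustration.

```python
import torch


def pc_sample(prior_sample, predictor_step, corrector_step, T=1.0, eps=1e-3,
              n_steps=1000, n_correct_steps=1, denoise_at_final=True,
              inverse_scaler=lambda x: x):
    """Predictor-Corrector loop (sketch). Each step callable maps (x, t) -> (x, x_mean)."""
    x = x_mean = prior_sample
    timesteps = torch.linspace(T, eps, n_steps)  # integrate from T down to eps
    for time in timesteps:
        t = torch.full((x.shape[0],), time.item())  # one time value per batch element
        for _ in range(n_correct_steps):
            x, x_mean = corrector_step(x, t)
        x, x_mean = predictor_step(x, t)
    samples = inverse_scaler(x_mean if denoise_at_final else x)
    # Each time step costs one predictor and n_correct_steps corrector evaluations.
    return samples, n_steps * (n_correct_steps + 1)
```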

ls_mlkit.diffuser.sde.get_score_fn(sde, model, train=False, continuous=False)

Wrap the model so that its output corresponds to a real time-dependent score function.

Parameters:
  • sde – An SDE object that represents the forward SDE.

  • model – A score model.

  • train – True for training mode and False for evaluation mode.

  • continuous – If True, the score-based model is expected to directly take continuous time steps.

Returns:

A score function.
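For a VP-type SDE, a common convention is that the network predicts the noise \(\epsilon\) in \(x_t = \text{mean} + \text{std}\,\epsilon\); the score of \(p_{0t}(x_t \mid x_0)\) is then \(-\epsilon / \text{std}\). The sketch below shows that wrapping under this assumption; model_fn(x, t) and marginal_std(t) are hypothetical callables, and the continuous flag and mask are not modelled.

```python
import torch


def make_vp_score_fn(model_fn, marginal_std):
    """Wrap a noise-predicting model as a score function (sketch).

    If x_t = mean + std * eps and model_fn(x_t, t) predicts eps, then
    grad log p_0t(x_t | x_0) = -(x_t - mean) / std^2 = -eps / std.
    """

    def score_fn(x, t):
        std = marginal_std(t)  # macro shape of x
        std = std.reshape(std.shape + (1,) * (x.dim() - std.dim()))
        return -model_fn(x, t) / std

    return score_fn
```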