ls_mlkit.diffuser.sde.base_sde module

Abstract SDE classes

Note

t always denotes continuous time in this module.

class ls_mlkit.diffuser.sde.base_sde.SDE(n_discretization_steps: int, ndim_micro_shape: int = 2)

Bases: ABC

Abstract base class for SDEs. All methods operate on a mini-batch of inputs.

abstract property T: float

End time of the SDE.

get_diffusion_coefficient_with_proper_shape(x: Tensor, diffusion: Tensor) → Tensor

Get the diffusion coefficient with the proper shape, i.e. complete its micro shape so that it can be combined with x.

Parameters:
  • x (Tensor) – the sample.

  • diffusion (Tensor) – the diffusion coefficient.

Returns:

the diffusion coefficient with the proper shape.

Return type:

Tensor
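One plausible reading of "completing the micro shape" is appending singleton axes so that a per-sample coefficient broadcasts over the trailing (micro) dimensions of x. The sketch below assumes exactly that; `complete_micro_shape` is a hypothetical helper rather than this module's implementation, and NumPy arrays stand in for torch tensors.

```python
import numpy as np


def complete_micro_shape(x: np.ndarray, diffusion: np.ndarray) -> np.ndarray:
    """Hypothetical helper: append singleton axes to `diffusion` until it has
    as many dimensions as `x`, so it broadcasts over each sample's entries."""
    n_missing = x.ndim - diffusion.ndim
    return diffusion.reshape(diffusion.shape + (1,) * n_missing)


x = np.ones((4, 3, 2))          # batch of 4 samples, micro shape (3, 2)
g = np.arange(4, dtype=float)   # one diffusion value per sample, shape (4,)
g_full = complete_micro_shape(x, g)
print(g_full.shape)             # (4, 1, 1)
print((g_full * x).shape)       # (4, 3, 2)
```

Working from `x.ndim` rather than a fixed count keeps the sketch valid for any micro shape; a version driven by `ndim_micro_shape` would append that many axes instead.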

get_discretization_steps(t: Tensor) → Tensor

Get the discretization steps.

Parameters:

t (Tensor) – the time step.

Returns:

the discretization steps.

Return type:

Tensor

get_discretized_drift_and_diffusion(x: Tensor, t: Tensor, mask=None) → Tuple[Tensor, Tensor]

Euler-Maruyama discretization.

\[ \begin{align}\begin{aligned}dx &= f(x, t)\,dt + g(x, t)\,dz\\x_{t+\Delta t} &= x_t + f(x_t, t)\,\Delta t + g(x_t, t)\,\epsilon, \quad \epsilon \sim \mathcal{N}(0, |\Delta t|)\end{aligned}\end{align} \]

Parameters:
  • x (Tensor) – the sample.

  • t (Tensor) – the continuous time step, in [0, self.T].

  • mask (Tensor, optional) – 1 marks valid entries, 0 marks invalid entries. Defaults to None.

Note

Here dt is always positive, so the step size enters the returned values as |Δt|.

Returns:

f, g

\[ \begin{align}\begin{aligned}f &= f(x,t) |\Delta t|\\g &= g(x,t) \sqrt{|\Delta t|}\end{aligned}\end{align} \]
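The Returns formula and the Euler-Maruyama update it feeds can be sketched as follows. NumPy arrays stand in for torch tensors, and `discretized_drift_and_diffusion` and `euler_maruyama_step` are hypothetical names. Because g already carries the factor sqrt(|Δt|), a single standard-normal draw per entry reproduces ε ~ N(0, |Δt|).

```python
import numpy as np


def discretized_drift_and_diffusion(drift, diffusion, dt):
    """Hypothetical helper: scale f(x, t) by |dt| and g(x, t) by sqrt(|dt|)."""
    step = abs(dt)  # only the magnitude of the step is used
    return drift * step, diffusion * np.sqrt(step)


def euler_maruyama_step(x, f, g, rng):
    """One update x + f + g * z with z ~ N(0, I).

    Since g already includes sqrt(|dt|), g * z has variance g(x, t)^2 |dt|.
    """
    z = rng.standard_normal(x.shape)
    return x + f + g * z


# Usage: integrate dx = -0.5 x dt + dz forward from t = 0 to t = T.
rng = np.random.default_rng(0)
x = rng.standard_normal((8, 3, 2))
n_steps, T = 100, 1.0
dt = T / n_steps
for _ in range(n_steps):
    f, g = discretized_drift_and_diffusion(-0.5 * x, np.ones_like(x), dt)
    x = euler_maruyama_step(x, f, g, rng)
```

Here `n_steps` plays the role that `n_discretization_steps` presumably plays in the constructor; the drift and diffusion in the loop are an illustrative Ornstein-Uhlenbeck process, not one defined by this module.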

abstractmethod get_drift_and_diffusion(x: Tensor, t: Tensor, mask=None) → Tuple[Tensor, Tensor]

Get the drift and diffusion of the SDE.

Parameters:
  • x (Tensor) – the sample.

  • t (Tensor) – the time step.

  • mask (Tensor, optional) – the mask of the sample. Defaults to None.

Returns:

the drift and diffusion of the SDE.

Return type:

Tuple[Tensor, Tensor]
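A concrete subclass has to provide T and get_drift_and_diffusion. The sketch below mirrors that interface with a hypothetical `MiniSDE` base (it does not import ls_mlkit) and implements the variance-preserving SDE with a linear β schedule. NumPy arrays in place of tensors, a t of shape (B,), and returning one diffusion value per sample are all assumptions.

```python
from abc import ABC, abstractmethod
from typing import Tuple

import numpy as np


class MiniSDE(ABC):
    """Hypothetical stand-in for the documented SDE base class (not ls_mlkit)."""

    @property
    @abstractmethod
    def T(self) -> float:
        """End time of the SDE."""

    @abstractmethod
    def get_drift_and_diffusion(self, x, t, mask=None) -> Tuple[np.ndarray, np.ndarray]:
        """Return the drift f(x, t) and diffusion g(t) for a mini-batch."""


class LinearVPSDE(MiniSDE):
    """dx = -0.5 beta(t) x dt + sqrt(beta(t)) dz,
    with beta(t) = beta_min + t (beta_max - beta_min) on [0, 1]."""

    def __init__(self, beta_min: float = 0.1, beta_max: float = 20.0):
        self.beta_min = beta_min
        self.beta_max = beta_max

    @property
    def T(self) -> float:
        return 1.0

    def beta(self, t: np.ndarray) -> np.ndarray:
        return self.beta_min + t * (self.beta_max - self.beta_min)

    def get_drift_and_diffusion(self, x, t, mask=None):
        # t has shape (B,); append singleton axes so beta broadcasts over x.
        beta = self.beta(t).reshape(t.shape + (1,) * (x.ndim - t.ndim))
        drift = -0.5 * beta * x
        diffusion = np.sqrt(self.beta(t))  # one value per sample, shape (B,)
        # mask is accepted for interface compatibility but ignored in this sketch.
        return drift, diffusion
```

Returning the diffusion per sample leaves its broadcasting to a step like get_diffusion_coefficient_with_proper_shape; a subclass could equally return it already expanded.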

get_reverse_sde(score_fn: object, use_probability_flow=False)

Create the reverse-time SDE/ODE.

Parameters:
  • score_fn – a time-dependent score-based model that takes (x, t, mask) and returns the score.

  • use_probability_flow – if True, create the reverse-time ODE used for probability flow sampling. Defaults to False.
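For a forward SDE dx = f dt + g dz with score s(x, t) = ∇ log p_t(x), the standard reverse-time SDE has drift f - g²s and diffusion g, while the probability-flow ODE has drift f - ½g²s and no noise. The helper below is a hypothetical sketch of that arithmetic only, using NumPy arrays; how get_reverse_sde packages the result and handles the sign of backward time steps is not shown.

```python
import numpy as np


def reverse_drift_and_diffusion(drift, diffusion, score, use_probability_flow=False):
    """Hypothetical helper: reverse-time coefficients from forward f, g and a score.

    `diffusion` is assumed to already broadcast against `score`.
    """
    g2 = diffusion ** 2
    if use_probability_flow:
        # Probability-flow ODE: dx = [f - 0.5 g^2 s] dt, with no noise term.
        return drift - 0.5 * g2 * score, np.zeros_like(diffusion)
    # Reverse-time SDE: dx = [f - g^2 s] dt + g dw, run from t = T back to t = 0.
    return drift - g2 * score, diffusion


def score_fn(x, t, mask=None):
    """Hypothetical score of a standard normal: grad log N(x; 0, I) = -x."""
    return -x


# Forward process dx = -0.5 x dt + dz, whose stationary law is N(0, I).
x = np.array([[1.0, -2.0]])
f, g = -0.5 * x, np.ones_like(x)
ode_f, ode_g = reverse_drift_and_diffusion(f, g, score_fn(x, None), use_probability_flow=True)
sde_f, sde_g = reverse_drift_and_diffusion(f, g, score_fn(x, None))
# ode_f is zero: N(0, I) is already stationary, so the flow velocity vanishes.
# sde_f equals 0.5 * x.
```

The stationary example is a convenient sanity check: with the exact score, the probability-flow ODE should not move samples that are already distributed as N(0, I).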