ls_mlkit.diffuser.sde.base_sde module
Abstract SDE classes
Note
t is always a continuous time step in this module.
- class ls_mlkit.diffuser.sde.base_sde.SDE(n_discretization_steps: int, ndim_micro_shape: int = 2)
Bases: ABC
SDE abstract class. Its methods are designed to operate on a mini-batch of inputs.
- abstract property T: float
End time of the SDE.
- get_diffusion_coefficient_with_proper_shape(x: Tensor, diffusion: Tensor) → Tensor
Get the diffusion coefficient with the proper shape, i.e. complete the micro shape of the diffusion coefficient so that it matches the sample x.
- Parameters:
x (Tensor) – the sample.
diffusion (Tensor) – the diffusion coefficient.
- Returns:
the diffusion coefficient with the proper shape.
- Return type:
Tensor
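The snippet below is a minimal sketch of what "completing the micro shape" could look like, assuming the diffusion coefficient has one value per leading (batch) index and only needs trailing singleton dimensions appended so it broadcasts against x. The helper name complete_micro_shape is hypothetical and not part of ls_mlkit; with the default ndim_micro_shape = 2 and a sample of shape (batch, N, 3), appending dimensions until the ranks match is the same as appending exactly ndim_micro_shape of them.

```python
import torch
from torch import Tensor


def complete_micro_shape(x: Tensor, diffusion: Tensor) -> Tensor:
    """Hypothetical helper: append trailing singleton dims to `diffusion`
    until it has as many dims as `x`, so `diffusion * x` broadcasts."""
    n_missing = x.ndim - diffusion.ndim
    return diffusion.reshape(*diffusion.shape, *([1] * n_missing))


x = torch.randn(4, 16, 3)         # batch of 4 samples, micro shape (16, 3)
g = torch.linspace(0.1, 0.4, 4)   # one diffusion value per sample
g_full = complete_micro_shape(x, g)
scaled = g_full * x               # broadcasts over the micro dimensions
```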
- get_discretization_steps(t: Tensor) → Tensor
Get the discretization steps.
- Parameters:
t (Tensor) – the time step.
- Returns:
the discretization steps.
- Return type:
Tensor
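If, as the name suggests, this method maps continuous times in [0, T] to indices on a grid of n_discretization_steps points, a uniform-grid version might look like the sketch below. Both the uniform grid and the rounding rule are assumptions; the function continuous_time_to_step is hypothetical and may differ from what ls_mlkit actually returns.

```python
import torch
from torch import Tensor


def continuous_time_to_step(t: Tensor, T: float, n_discretization_steps: int) -> Tensor:
    """Hypothetical mapping of continuous time t in [0, T] to an integer index
    on a uniform grid of n_discretization_steps points (0 .. n - 1)."""
    idx = torch.round(t / T * (n_discretization_steps - 1)).long()
    # Guard against t lying slightly outside [0, T] due to floating-point error.
    return idx.clamp(0, n_discretization_steps - 1)


t = torch.tensor([0.0, 0.25, 1.0])
steps = continuous_time_to_step(t, T=1.0, n_discretization_steps=1000)
```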
- get_discretized_drift_and_diffusion(x: Tensor, t: Tensor, mask=None) → Tuple[Tensor, Tensor]
Euler-Maruyama discretization.
\[ \begin{align}\begin{aligned}dx &= f(x, t)\,dt + g(x, t)\,dz\\x_{t+\Delta t} &= x_t + f(x_t, t)\,\Delta t + g(x_t, t)\,\epsilon, \quad \epsilon \sim \mathcal{N}(0, |\Delta t|)\end{aligned}\end{align} \]
- Parameters:
x – a torch tensor.
t – a torch float representing the time step (from 0 to self.T).
mask – 1 indicates a valid region, 0 indicates an invalid region.
Note
Here dt is always greater than 0.
- Returns:
f, g
\[ \begin{align}\begin{aligned}f &= f(x,t) |\Delta t|\\g &= g(x,t) \sqrt{|\Delta t|}\end{aligned}\end{align} \]
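Because the returned pair is already scaled by |Δt| and √|Δt|, one Euler–Maruyama update reduces to x + f + g·z with z drawn from a standard normal. The sketch below illustrates this on a toy Ornstein–Uhlenbeck SDE (dx = −θx dt + σ dz); discretized_ou and euler_maruyama_step are hypothetical stand-ins, not ls_mlkit functions, and leaving masked-out positions unchanged is an assumption about how the mask is meant to be used.

```python
import math
from typing import Optional, Tuple

import torch
from torch import Tensor


def discretized_ou(x: Tensor, t: Tensor, dt: float, theta: float = 1.0,
                   sigma: float = 0.5) -> Tuple[Tensor, Tensor]:
    """Toy stand-in for get_discretized_drift_and_diffusion on the
    Ornstein-Uhlenbeck SDE dx = -theta * x dt + sigma dz.
    Returns f(x, t)|dt| and g(x, t) sqrt(|dt|), as in the formulas above.
    t is unused because this SDE is time-homogeneous."""
    f = -theta * x * abs(dt)
    g = torch.full_like(x, sigma * math.sqrt(abs(dt)))
    return f, g


def euler_maruyama_step(x: Tensor, t: Tensor, dt: float,
                        mask: Optional[Tensor] = None) -> Tensor:
    """One Euler-Maruyama update x_{t+dt} = x_t + f + g * z, z ~ N(0, I)."""
    f, g = discretized_ou(x, t, dt)
    # g already carries the sqrt(|dt|) factor, so z is a standard normal draw.
    x_next = x + f + g * torch.randn_like(x)
    if mask is not None:
        # Assumption of this sketch: positions with mask == 0 stay unchanged.
        x_next = torch.where(mask.bool(), x_next, x)
    return x_next


torch.manual_seed(0)
x = torch.ones(2, 3)
mask = torch.tensor([[1, 1, 0], [1, 0, 0]])
x_next = euler_maruyama_step(x, t=torch.zeros(2), dt=0.01, mask=mask)
```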
- abstractmethod get_drift_and_diffusion(x: Tensor, t: Tensor, mask=None) → Tuple[Tensor, Tensor]
Get the drift and diffusion of the SDE.
- Parameters:
x (Tensor) – the sample.
t (Tensor) – the time step.
mask (Tensor, optional) – the mask of the sample. Defaults to None.
- Returns:
the drift and diffusion of the SDE.
- Return type:
Tuple[Tensor, Tensor]
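A concrete subclass mainly needs to supply T and get_drift_and_diffusion. The sketch below uses a hypothetical stand-in base class MiniSDE that mirrors only these two abstract members (it is not the real ls_mlkit SDE, whose constructor and helpers are omitted) and implements a variance-preserving SDE, dx = −½β(t)x dt + √β(t) dz, with an assumed linear schedule β(t) = β_min + t(β_max − β_min) on [0, 1]. The mask argument is accepted but ignored here.

```python
from abc import ABC, abstractmethod
from typing import Optional, Tuple

import torch
from torch import Tensor


class MiniSDE(ABC):
    """Hypothetical stand-in mirroring the abstract members of the SDE class."""

    @property
    @abstractmethod
    def T(self) -> float: ...

    @abstractmethod
    def get_drift_and_diffusion(
        self, x: Tensor, t: Tensor, mask: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]: ...


class ToyVPSDE(MiniSDE):
    """dx = -0.5 * beta(t) * x dt + sqrt(beta(t)) dz with a linear beta(t)."""

    def __init__(self, beta_min: float = 0.1, beta_max: float = 20.0):
        self.beta_min = beta_min
        self.beta_max = beta_max

    @property
    def T(self) -> float:
        return 1.0

    def get_drift_and_diffusion(self, x, t, mask=None):
        beta = self.beta_min + t * (self.beta_max - self.beta_min)  # shape (batch,)
        # Append singleton dims so beta broadcasts against x; mask is ignored here.
        beta = beta.reshape(-1, *([1] * (x.ndim - 1)))
        drift = -0.5 * beta * x
        diffusion = torch.sqrt(beta)
        return drift, diffusion


sde = ToyVPSDE()
x = torch.ones(2, 4, 3)
f, g = sde.get_drift_and_diffusion(x, torch.tensor([0.0, 1.0]))
```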
- get_reverse_sde(score_fn: object, use_probability_flow=False)
Create the reverse-time SDE/ODE.
- Parameters:
score_fn – A time-dependent score-based model that takes (x, t, mask) and returns the score.
use_probability_flow – If True, create the reverse-time ODE used for probability flow sampling.
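In score-based generative modeling (Song et al., 2021), the reverse-time SDE has drift f − g²·score and diffusion g, while the probability-flow ODE has drift f − ½g²·score and no noise; both are integrated from t = T back to t = 0. The sketch below builds these coefficients from a forward (f, g) function and a score_fn with the documented (x, t, mask) signature. The helper make_reverse_drift_and_diffusion is hypothetical, and whether get_reverse_sde uses exactly this sign and time convention is an assumption. The toy check uses dx = −½x dt + dz, whose stationary density N(0, 1) has score −x, so the ODE drift cancels to zero.

```python
from typing import Callable, Optional, Tuple

import torch
from torch import Tensor

DriftDiffusionFn = Callable[[Tensor, Tensor, Optional[Tensor]], Tuple[Tensor, Tensor]]
ScoreFn = Callable[[Tensor, Tensor, Optional[Tensor]], Tensor]


def make_reverse_drift_and_diffusion(
    forward_fn: DriftDiffusionFn,
    score_fn: ScoreFn,
    use_probability_flow: bool = False,
) -> DriftDiffusionFn:
    """Hypothetical helper: reverse-time coefficients from forward (f, g) and a score.

    Reverse SDE:          drift = f - g^2 * score,       diffusion = g
    Probability-flow ODE: drift = f - 0.5 * g^2 * score, diffusion = 0
    """
    coeff = 0.5 if use_probability_flow else 1.0

    def reverse_fn(x: Tensor, t: Tensor, mask: Optional[Tensor] = None):
        f, g = forward_fn(x, t, mask)
        drift = f - coeff * g**2 * score_fn(x, t, mask)
        diffusion = torch.zeros_like(g) if use_probability_flow else g
        return drift, diffusion

    return reverse_fn


# Toy forward SDE dx = -0.5 x dt + dz with stationary score -x.
forward_fn = lambda x, t, mask=None: (-0.5 * x, torch.ones_like(x))
score_fn = lambda x, t, mask=None: -x

x = torch.tensor([1.0, -2.0])
t = torch.tensor(0.5)
drift, diff = make_reverse_drift_and_diffusion(forward_fn, score_fn)(x, t)
drift_ode, diff_ode = make_reverse_drift_and_diffusion(
    forward_fn, score_fn, use_probability_flow=True
)(x, t)
```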