ls_mlkit.diffuser.sde.corrector module

class ls_mlkit.diffuser.sde.corrector.Corrector(sde: SDE, score_fn: object, snr: float, n_steps: int)

Bases: ABC

Abstract base class for corrector algorithms.

abstractmethod update_fn(x: Tensor, t: Tensor, mask=None)

One update of the corrector.

Parameters:
  • x – A PyTorch tensor representing the current state.

  • t – A PyTorch tensor representing the current time step.

  • mask – Optional mask (default: None).

Returns:

  • x – A PyTorch tensor of the next state.

  • x_mean – A PyTorch tensor of the next state without random noise. Useful for denoising.

Return type:

tuple[Tensor, Tensor]
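To make the contract concrete, here is a minimal sketch that re-declares the documented interface with Python's abc module instead of importing ls_mlkit. CorrectorSketch, ToyNoiseCorrector, and corrector_output are hypothetical names, and the toy subclass only adds Gaussian noise, so it is not a real corrector. The sketch shows that update_fn must be overridden before a subclass can be instantiated, and that callers receive both x and x_mean and can keep x_mean when a noise-free result is wanted.

```python
from abc import ABC, abstractmethod

import torch


class CorrectorSketch(ABC):
    """Hypothetical stand-in mirroring the documented Corrector signature."""

    def __init__(self, sde, score_fn, snr: float, n_steps: int):
        self.sde = sde
        self.score_fn = score_fn
        self.snr = snr
        self.n_steps = n_steps

    @abstractmethod
    def update_fn(self, x: torch.Tensor, t: torch.Tensor, mask=None):
        """Return (x, x_mean): the next state and the same state without noise."""


class ToyNoiseCorrector(CorrectorSketch):
    """Toy subclass (not a real corrector): no drift, only added noise."""

    def update_fn(self, x, t, mask=None):
        x_mean = x  # the toy applies no drift, so the mean is the input
        x = x_mean + self.snr * torch.randn_like(x_mean)
        return x, x_mean


def corrector_output(corrector, x, t, denoise=False):
    # Keep x_mean when a noise-free result is wanted, otherwise the noisy x.
    x_next, x_mean = corrector.update_fn(x, t)
    return x_mean if denoise else x_next


torch.manual_seed(0)
c = ToyNoiseCorrector(sde=None, score_fn=None, snr=0.1, n_steps=1)
x = torch.zeros(4, 2)
t = torch.ones(4)
denoised = corrector_output(c, x, t, denoise=True)  # equals x in this toy
```

Instantiating CorrectorSketch directly raises TypeError, because update_fn is still abstract.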

class ls_mlkit.diffuser.sde.corrector.LangevinCorrector(sde: SDE, score_fn: object, snr: float, n_steps: int, n_dim: int = 2)

Bases: Corrector

update_fn(x: Tensor, t: Tensor, mask=None)

One update of the corrector.

Parameters:
  • x – A PyTorch tensor representing the current state.

  • t – A PyTorch tensor representing the current time step.

  • mask – Optional mask (default: None).

Returns:

  • x – A PyTorch tensor of the next state.

  • x_mean – A PyTorch tensor of the next state without random noise. Useful for denoising.

Return type:

tuple[Tensor, Tensor]
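The entry above does not spell out the Langevin update rule. The Corrector interface here (sde, score_fn, snr, n_steps, and update_fn returning x and x_mean) matches the one in Song et al.'s score-SDE reference code. That code's Langevin corrector takes n_steps Langevin steps whose size is set from the target signal-to-noise ratio snr. The sketch below follows that scheme under stated assumptions: langevin_correct is a hypothetical name, alpha is fixed to 1 (the VE-SDE case; the reference code uses a per-timestep alpha for VP SDEs), and mask and n_dim are ignored. It is not claimed to be ls_mlkit's implementation.

```python
import torch


def langevin_correct(x, t, score_fn, snr, n_steps, alpha=1.0):
    """Hypothetical sketch of an SNR-scaled Langevin corrector.

    Assumes alpha = 1 (VE-SDE style) and ignores masking. Returns
    (x, x_mean): the next state and the same state before the last
    noise injection.
    """
    x_mean = x
    for _ in range(n_steps):
        grad = score_fn(x, t)
        noise = torch.randn_like(x)
        # Per-sample L2 norms (dim 0 is the batch), averaged over the batch.
        grad_norm = grad.reshape(grad.shape[0], -1).norm(dim=-1).mean()
        noise_norm = noise.reshape(noise.shape[0], -1).norm(dim=-1).mean()
        # With alpha = 1, the drift displacement step_size * grad_norm is
        # snr times the noise displacement sqrt(2 * step_size) * noise_norm.
        step_size = (snr * noise_norm / grad_norm) ** 2 * 2 * alpha
        x_mean = x + step_size * grad
        x = x_mean + torch.sqrt(2 * step_size) * noise
    return x, x_mean


# Usage: the exact score of a standard Gaussian N(0, I) is -x.
torch.manual_seed(0)
x0 = torch.full((256, 100), 3.0)
t = torch.ones(256)
x, x_mean = langevin_correct(x0, t, lambda x, t: -x, snr=0.16, n_steps=300)
print(x.shape, round(x.mean().item(), 2))  # mean drifts from 3.0 toward 0
```

Scaling the step by the ratio of noise and score norms keeps the update's signal-to-noise ratio near snr regardless of the score's magnitude. A zero score would make grad_norm zero and the step undefined, which this sketch does not guard against.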

class ls_mlkit.diffuser.sde.corrector.NoneCorrector(sde, score_fn, snr, n_steps)

Bases: Corrector

An empty corrector that does nothing.

update_fn(x, t, mask=None)

One update of the corrector.

Parameters:
  • x – A PyTorch tensor representing the current state.

  • t – A PyTorch tensor representing the current time step.

  • mask – Optional mask (default: None).

Returns:

  • x – A PyTorch tensor of the next state.

  • x_mean – A PyTorch tensor of the next state without random noise. Useful for denoising.

Return type:

tuple[Tensor, Tensor]
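Assuming NoneCorrector follows the (x, x_mean) convention documented for update_fn, a corrector that does nothing amounts to returning the input unchanged as both values. The minimal sketch below uses a hypothetical class name, PassThroughCorrector, and is not the library's code. It accepts the same constructor arguments but ignores them.

```python
import torch


class PassThroughCorrector:
    """Hypothetical stand-in for a do-nothing corrector.

    Stores the documented constructor arguments without using them;
    update_fn returns the input as both the next state and its noise-free mean.
    """

    def __init__(self, sde=None, score_fn=None, snr=0.0, n_steps=0):
        self.sde, self.score_fn, self.snr, self.n_steps = sde, score_fn, snr, n_steps

    def update_fn(self, x: torch.Tensor, t: torch.Tensor, mask=None):
        # No correction: the state passes through untouched.
        return x, x


x = torch.randn(3, 2)
x_next, x_mean = PassThroughCorrector().update_fn(x, torch.zeros(3))
```

Because it keeps the same signature and return shape as a real corrector, calling code can use a single code path whether or not correction is enabled.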