ls_mlkit.diffuser.sde.sampler module
- ls_mlkit.diffuser.sde.sampler.get_pc_sampler(sde: SDE, shape: Tuple[int, ...], predictor_class: Predictor, corrector_class: Corrector, inverse_scaler: Callable, snr: float, n_correct_steps: int = 1, use_probability_flow: bool = False, denoise_at_final: bool = True, eps: float = 0.001, device: str = 'cuda')
Create a Predictor-Corrector (PC) sampler.
- Parameters:
sde – An SDE object representing the forward SDE.
shape – A tuple of integers. The expected shape of the generated batch of samples; the first dimension is the batch size.
predictor_class – A subclass of Predictor representing the predictor algorithm.
corrector_class – A subclass of Corrector representing the corrector algorithm.
inverse_scaler – The inverse data normalizer.
snr – A float. The signal-to-noise ratio used to configure the corrector (for example, to set its step size).
n_correct_steps – An integer. The number of corrector steps per predictor update.
use_probability_flow – If True, solve the reverse-time probability flow ODE when running the predictor.
denoise_at_final – If True, add one-step denoising to the final samples.
eps – A float. The reverse-time SDE (or ODE) is integrated down to time eps instead of 0 to avoid numerical issues.
device – The PyTorch device on which sampling runs.
- Returns:
A sampling function that returns the generated samples together with the number of function evaluations (NFE) performed during sampling.
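To make the roles of snr, n_correct_steps, use_probability_flow, denoise_at_final, eps and the returned NFE count concrete, here is a minimal, self-contained sketch of the kind of predictor-corrector loop such a sampler runs, following the scheme of Song et al. (2021). It does not call ls_mlkit: the toy VP SDE, the Gaussian data distribution (chosen so the score is exact and no network is needed), the names `pc_sample` and `marginal`, the Euler-Maruyama predictor, and the Langevin corrector with its snr-based step-size rule (modeled on the reference score_sde code, with the corrector's alpha fixed to 1) are all illustrative assumptions, not this module's implementation.

```python
import math
import random
from typing import Callable, List, Tuple

# Toy VP SDE (assumption for illustration): beta(t) = B0 + t * (B1 - B0) on t in [0, 1].
B0, B1 = 0.1, 20.0
# Toy data distribution N(MU, S0**2), chosen so the score of every marginal is exact.
MU, S0 = 3.0, 0.5


def beta(t: float) -> float:
    return B0 + t * (B1 - B0)


def marginal(t: float) -> Tuple[float, float]:
    """Mean and variance of p_t for the toy VP SDE started from N(MU, S0**2)."""
    integral = B0 * t + 0.5 * t * t * (B1 - B0)  # integral of beta from 0 to t
    alpha = math.exp(-integral)
    return math.sqrt(alpha) * MU, alpha * S0 ** 2 + (1.0 - alpha)


def pc_sample(batch_size: int, n_steps: int = 500, snr: float = 0.16,
              n_correct_steps: int = 1, use_probability_flow: bool = False,
              denoise_at_final: bool = True, eps: float = 1e-3,
              inverse_scaler: Callable[[float], float] = lambda v: v,
              seed: int = 0) -> Tuple[List[float], int]:
    """Hypothetical PC sampler for scalar samples; returns (samples, nfe)."""
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in range(batch_size)]  # draw from the N(0, 1) prior
    x_mean = list(x)
    dt = -(1.0 - eps) / n_steps  # negative step: integrate from t = 1 down to t = eps
    for i in range(n_steps):
        t = 1.0 + i * dt
        mean, var = marginal(t)

        # Corrector: Langevin MCMC at fixed t; snr sets the step size.
        for _ in range(n_correct_steps):
            grad = [-(xi - mean) / var for xi in x]  # exact score of p_t
            z = [rng.gauss(0.0, 1.0) for _ in x]
            # Batch-averaged per-sample norms (|.| for scalar samples).
            grad_norm = max(sum(abs(g) for g in grad) / batch_size, 1e-12)
            noise_norm = sum(abs(zi) for zi in z) / batch_size
            step = 2.0 * (snr * noise_norm / grad_norm) ** 2
            x_mean = [xi + step * g for xi, g in zip(x, grad)]
            x = [m + math.sqrt(2.0 * step) * zi for m, zi in zip(x_mean, z)]

        # Predictor: one Euler-Maruyama step of the reverse SDE or probability-flow ODE.
        b = beta(t)  # g(t)**2 for the VP SDE; forward drift is f(x, t) = -0.5 * b * x
        weight = 0.5 * b if use_probability_flow else b
        # Reverse drift f - weight * score, with score = -(x - mean) / var.
        x_mean = [xi + (-0.5 * b * xi + weight * (xi - mean) / var) * dt for xi in x]
        if use_probability_flow:
            x = list(x_mean)  # the ODE injects no noise
        else:
            x = [m + math.sqrt(b * -dt) * rng.gauss(0.0, 1.0) for m in x_mean]

    nfe = n_steps * (n_correct_steps + 1)  # one score evaluation per predictor/corrector step
    final = x_mean if denoise_at_final else x  # denoising drops the last injected noise
    return [inverse_scaler(v) for v in final], nfe


if __name__ == "__main__":
    samples, nfe = pc_sample(batch_size=1000)
    m = sum(samples) / len(samples)
    s = math.sqrt(sum((v - m) ** 2 for v in samples) / len(samples))
    print(nfe, round(m, 2), round(s, 2))  # sample mean/std should approach MU and S0
```

In this sketch the corrector runs before each predictor step, so the NFE grows as n_steps × (n_correct_steps + 1); with denoise_at_final the returned batch is the mean of the last predictor update rather than the noisy state.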