Source code for ls_mlkit.diffuser.sde.sampler

import functools
from typing import Callable, Optional, Tuple, Type

import torch

from .base_sde import SDE
from .corrector import Corrector, NoneCorrector
from .predictor import NonePredictor, Predictor


def shared_predictor_update_fn(x, t, sde, score_fn, predictor_class, use_probability_flow):
    """Perform a single predictor step at time ``t``.

    Instantiates ``predictor_class(sde, score_fn, use_probability_flow)`` and
    delegates to its ``update_fn(x, t)``. When ``predictor_class`` is ``None``
    (a corrector-only sampler), ``NonePredictor`` is built with the same
    arguments instead.

    Returns:
        Whatever the predictor's ``update_fn`` returns (the PC sampler in this
        module unpacks it as an ``(x, x_mean)`` pair).
    """
    chosen_cls = NonePredictor if predictor_class is None else predictor_class
    predictor = chosen_cls(sde, score_fn, use_probability_flow)
    return predictor.update_fn(x, t)
def shared_corrector_update_fn(x, t, sde, score_fn, corrector_class, snr, n_steps):
    """Perform a single corrector update at time ``t``.

    Instantiates ``corrector_class(sde, score_fn, snr, n_steps)`` and delegates
    to its ``update_fn(x, t)``. When ``corrector_class`` is ``None``,
    ``NoneCorrector`` is built with the same arguments instead.

    Returns:
        Whatever the corrector's ``update_fn`` returns (the PC sampler in this
        module unpacks it as an ``(x, x_mean)`` pair).
    """
    if corrector_class is not None:
        return corrector_class(sde, score_fn, snr, n_steps).update_fn(x, t)
    return NoneCorrector(sde, score_fn, snr, n_steps).update_fn(x, t)
def get_pc_sampler(
    sde: SDE,
    shape: Tuple[int, ...],
    predictor_class: Optional[Type[Predictor]],
    corrector_class: Optional[Type[Corrector]],
    inverse_scaler: Callable,
    snr: float,
    n_correct_steps: int = 1,
    use_probability_flow: bool = False,
    denoise_at_final: bool = True,
    eps: float = 1e-3,
    device: str = "cuda",
):
    """Create a Predictor-Corrector (PC) sampler.

    Args:
        sde: An `SDE` object representing the forward SDE. Its `prior_sampling`,
            `T` and `n_discretization_steps` members are used by the sampler.
        shape: A sequence of integers. The expected shape of a single sample.
            First dimension is batch size.
        predictor_class: A subclass of `Predictor` (the class itself, not an
            instance) representing the predictor algorithm, or `None` to use
            `NonePredictor` (corrector-only sampling).
        corrector_class: A subclass of `Corrector` (the class itself, not an
            instance) representing the corrector algorithm, or `None` to use
            `NoneCorrector`.
        inverse_scaler: The inverse data normalizer, applied to the final samples.
        snr: A `float` number. The signal-to-noise ratio for configuring correctors.
        n_correct_steps: An integer. The number of corrector steps per predictor update.
        use_probability_flow: If `True`, solve the reverse-time probability flow
            ODE when running the predictor.
        denoise_at_final: If `True`, return the predictor's final `x_mean`
            (one-step denoising) instead of the final noisy `x`.
        eps: A `float` number. The reverse-time SDE and ODE are integrated to
            `eps` instead of 0 to avoid numerical issues.
        device: PyTorch device on which sampling runs.

    Returns:
        A function `pc_sampler(score_fn)` returning `(samples, nfe)`, where
        `samples` has been passed through `inverse_scaler` and `nfe` is the
        number of function evaluations, computed as
        `sde.n_discretization_steps * (n_correct_steps + 1)`.
    """
    # Bind everything except the state, time and score function once.
    predictor_update_fn = functools.partial(
        shared_predictor_update_fn,
        sde=sde,
        predictor_class=predictor_class,
        use_probability_flow=use_probability_flow,
    )
    corrector_update_fn = functools.partial(
        shared_corrector_update_fn,
        sde=sde,
        corrector_class=corrector_class,
        snr=snr,
        n_steps=n_correct_steps,
    )

    def pc_sampler(score_fn):
        """Draw samples by alternating corrector and predictor steps from `sde.T` down to `eps`.

        Args:
            score_fn: The score function forwarded to the predictor/corrector objects.

        Returns:
            `(samples, nfe)` as described in `get_pc_sampler`.
        """
        with torch.no_grad():
            x = sde.prior_sampling(shape).to(device)
            # If there are zero discretization steps the loop body never runs;
            # fall back to the prior sample so `x_mean` is always bound.
            x_mean = x
            timesteps = torch.linspace(sde.T, eps, sde.n_discretization_steps, device=device)
            for t in timesteps:
                # Broadcast the scalar time to one entry per batch element.
                vec_t = torch.ones(shape[0], device=t.device) * t
                # Corrector first, then predictor, at the same time step.
                x, x_mean = corrector_update_fn(x, vec_t, score_fn=score_fn)
                x, x_mean = predictor_update_fn(x, vec_t, score_fn=score_fn)
            nfe = sde.n_discretization_steps * (n_correct_steps + 1)
            return inverse_scaler(x_mean if denoise_at_final else x), nfe

    return pc_sampler