import functools
from typing import Callable, Tuple
import torch
from .base_sde import SDE
from .corrector import Corrector, NoneCorrector
from .predictor import NonePredictor, Predictor
[docs]
def shared_predictor_update_fn(x, t, sde, score_fn, predictor_class, use_probability_flow):
    """Build a predictor and apply a single update step to ``x`` at time ``t``.

    Args:
        x: Current state, forwarded unchanged to the predictor's ``update_fn``.
        t: Time value(s), forwarded unchanged to the predictor's ``update_fn``.
        sde: SDE object handed to the predictor constructor.
        score_fn: Score function handed to the predictor constructor.
        predictor_class: Class to instantiate as
            ``predictor_class(sde, score_fn, use_probability_flow)``. When
            ``None``, ``NonePredictor`` is used instead.
        use_probability_flow: Flag handed to the predictor constructor.

    Returns:
        Whatever the predictor's ``update_fn(x, t)`` returns.
    """
    # A missing predictor class means a corrector-only sampler, so fall back
    # to the placeholder predictor.
    chosen_class = NonePredictor if predictor_class is None else predictor_class
    predictor = chosen_class(sde, score_fn, use_probability_flow)
    return predictor.update_fn(x, t)
[docs]
def shared_corrector_update_fn(x, t, sde, score_fn, corrector_class, snr, n_steps):
    """Build a corrector and apply a single update step to ``x`` at time ``t``.

    Args:
        x: Current state, forwarded unchanged to the corrector's ``update_fn``.
        t: Time value(s), forwarded unchanged to the corrector's ``update_fn``.
        sde: SDE object handed to the corrector constructor.
        score_fn: Score function handed to the corrector constructor.
        corrector_class: Class to instantiate as
            ``corrector_class(sde, score_fn, snr, n_steps)``. When ``None``,
            ``NoneCorrector`` is used instead.
        snr: Signal-to-noise ratio handed to the corrector constructor.
        n_steps: Number of corrector steps handed to the corrector constructor.

    Returns:
        Whatever the corrector's ``update_fn(x, t)`` returns.
    """
    # A missing corrector class falls back to the placeholder corrector.
    chosen_class = NoneCorrector if corrector_class is None else corrector_class
    corrector = chosen_class(sde, score_fn, snr, n_steps)
    return corrector.update_fn(x, t)
[docs]
def get_pc_sampler(
    sde: SDE,
    shape: Tuple[int, ...],
    predictor_class: Predictor,
    corrector_class: Corrector,
    inverse_scaler: Callable,
    snr: float,
    n_correct_steps: int = 1,
    use_probability_flow: bool = False,
    denoise_at_final: bool = True,
    eps: float = 1e-3,
    device: str = "cuda",
):
    """Create a Predictor-Corrector (PC) sampler.

    Args:
        sde: An `SDE` object representing the forward SDE. The sampler reads
            `sde.T` and `sde.n_discretization_steps` and calls
            `sde.prior_sampling(shape)`.
        shape: A sequence of integers. The expected shape of the samples.
            The first dimension is the batch size.
        predictor_class: A subclass of `Predictor` representing the predictor
            algorithm (`None` selects the placeholder `NonePredictor`).
        corrector_class: A subclass of `Corrector` representing the corrector
            algorithm (`None` selects the placeholder `NoneCorrector`).
        inverse_scaler: The inverse data normalizer, applied to the final
            samples before they are returned.
        snr: A `float` number. The signal-to-noise ratio for configuring
            correctors.
        n_correct_steps: An integer. The number of corrector steps per
            predictor update.
        use_probability_flow: If `True`, solve the reverse-time probability
            flow ODE when running the predictor.
        denoise_at_final: If `True`, return the mean from the last predictor
            update (`x_mean`) rather than the noisy state `x`.
        eps: A `float` number. Time is discretized from `sde.T` down to `eps`
            (rather than 0) to avoid numerical issues.
        device: PyTorch device on which sampling runs.

    Returns:
        A function `pc_sampler(score_fn)` that returns a tuple of the
        inverse-scaled samples and the number of function evaluations used
        during sampling.
    """

    def pc_sampler(score_fn):
        # The whole sampling loop, including the inverse scaling, runs with
        # gradient tracking disabled.
        with torch.no_grad():
            batch_size = shape[0]
            n_iters = sde.n_discretization_steps

            # Start from the prior and walk the time grid from sde.T to eps.
            x = sde.prior_sampling(shape).to(device)
            timesteps = torch.linspace(sde.T, eps, n_iters, device=device)

            for t in timesteps:
                # Broadcast the scalar time to one entry per batch element.
                vec_t = torch.ones(batch_size, device=t.device) * t

                # The corrector runs first, then the predictor; only the
                # predictor's x_mean survives to the end of the iteration.
                x, x_mean = shared_corrector_update_fn(
                    x,
                    vec_t,
                    sde=sde,
                    score_fn=score_fn,
                    corrector_class=corrector_class,
                    snr=snr,
                    n_steps=n_correct_steps,
                )
                x, x_mean = shared_predictor_update_fn(
                    x,
                    vec_t,
                    sde=sde,
                    score_fn=score_fn,
                    predictor_class=predictor_class,
                    use_probability_flow=use_probability_flow,
                )

            if denoise_at_final:
                samples = inverse_scaler(x_mean)
            else:
                samples = inverse_scaler(x)

            # Counted as one predictor call plus n_correct_steps corrector
            # calls per discretization step.
            n_function_evals = sde.n_discretization_steps * (n_correct_steps + 1)
            return samples, n_function_evals

    return pc_sampler