Source code for ls_mlkit.diffuser.sde.sde_lib

from typing import Tuple

import numpy as np
import torch
from overrides import override
from torch import Tensor

from .base_sde import SDE


[docs] class VPSDE(SDE): def __init__( self, beta_min: float = 0.1, beta_max: float = 20, n_discretization_steps: int = 1000, ndim_micro_shape: int = 2 ): r"""Construct a Variance Preserving SDE. Args: beta_min: value of beta(0) beta_max: value of beta(1) n_discretization_steps: number of discretization steps ndim_micro_shape: number of dimensions of a sample """ super().__init__(n_discretization_steps=n_discretization_steps, ndim_micro_shape=ndim_micro_shape) self.beta_0 = beta_min self.beta_1 = beta_max self.discrete_betas = torch.linspace( beta_min / n_discretization_steps, beta_max / n_discretization_steps, n_discretization_steps ) self.alphas = 1.0 - self.discrete_betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) # mean self.sqrt_1m_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod) # std @property @override def T(self) -> float: return 1
[docs] @override def get_drift_and_diffusion(self, x: Tensor, t: Tensor, mask=None) -> Tuple[Tensor, Tensor]: r"""continuous DDPM SDE .. math:: dx &= -\frac{1}{2}\beta_t x dt + \sqrt{\beta_t} dw Args: x: t: (macro_shape) mask: Returns: drift: shape = x.shape diffusion: shape=x.macro_shape """ macro_shape = x.shape[: self.ndim_micro_shape] beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0) drift = -0.5 * beta_t.view(*macro_shape, *[1 for _ in range(self.ndim_micro_shape)]) * x diffusion = torch.sqrt(beta_t) return drift, diffusion
[docs] @override def get_discretized_drift_and_diffusion(self, x: Tensor, t: Tensor, mask=None) -> Tuple[Tensor, Tensor]: """DDPM discretization.""" timestep = (t * (self.n_discretization_steps - 1) / self.T).long() beta = self.discrete_betas.to(x.device)[timestep] alpha = self.alphas.to(x.device)[timestep] sqrt_beta = torch.sqrt(beta) f = torch.sqrt(alpha).view(alpha.shape[0], *[1 for _ in range(self.ndim_micro_shape)]) * x - x g = sqrt_beta return f, g
[docs] def get_target_score(self, x_0: Tensor, x_t: Tensor, t: Tensor, mask: Tensor, continuous: bool = False) -> Tensor: r""" .. math:: p_{0t} (x_t|x_0) = \nabla_{x_t} \ln p_{0t} (x_t|x_0) """ mu = None # $$E_{x_t\sim p_{0t}(x_t|x_0)}[x_t]$$ sigma = None # $$\sqrt{Var_{x_t\sim p_{0t}(x_t|x_0)}[x_t]}$$ macro_shape = x_t.shape[: self.ndim_micro_shape] if continuous: mu, sigma = self.marginal_prob(x_0, t, mask=None) else: mu = self.sqrt_alphas_cumprod[t].view(*macro_shape, *[1 for _ in range(self.ndim_micro_shape)]) * x_0 sigma = self.sqrt_1m_alphas_cumprod[t].view(*macro_shape, *[1 for _ in range(self.ndim_micro_shape)]) score = -(x_t - mu) / sigma**2 return score
[docs] def marginal_prob(self, x_0: Tensor, t: Tensor, mask: Tensor = None) -> Tuple[Tensor, Tensor]: r""" .. math:: p_{0t} (x_t|x_0) .. math:: \gamma = -\frac{1}{4}t^2 (\beta_1 - \beta_0) - \frac{1}{2} t \beta_0 mean = e^{\gamma} * x std = \sqrt{1 - e^{2 \gamma }} """ log_mean_coeff = -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 macro_shape = x_0.shape[: self.ndim_micro_shape] log_mean_coeff = log_mean_coeff.view(*macro_shape, *[1 for _ in range(self.ndim_micro_shape)]) mean = torch.exp(log_mean_coeff) * x_0 std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff)) return mean, std
[docs] def prior_sampling(self, shape: Tuple) -> Tensor: r""" .. math:: \epsilon \sim \mathbfcal{N}(0,1) """ return torch.randn(*shape)
[docs] def prior_logp(self, z: torch.Tensor) -> Tensor: r""" .. math:: (2\pi)^{-k/2} \det(\Sigma)^{-1/2} \exp\left( -\frac{1}{2} (\mathbf{x} - \boldsymbol{\mu})^\mathrm{T} \Sigma^{-1} (\mathbf{x} - \boldsymbol{\mu}) \right) where :math:`\Sigma = I` and :math:`\mathbf{\mu} = 0` """ shape = z.shape N = np.prod(shape[1:]) logps = -N / 2.0 * np.log(2 * np.pi) - torch.sum(z**2, dim=(1, 2, 3)) / 2.0 return logps
[docs] class SubVPSDE(SDE): def __init__( self, beta_min: float = 0.1, beta_max: float = 20, n_discretization_steps: int = 1000, ndim_micro_shape: int = 2 ): """Construct the sub-VP SDE that excels at likelihoods. Args: beta_min: value of beta(0) beta_max: value of beta(1) n_discretization_steps: number of discretization steps ndim_micro_shape: number of dimensions of a sample """ super().__init__(n_discretization_steps=n_discretization_steps, ndim_micro_shape=ndim_micro_shape) self.beta_0 = beta_min self.beta_1 = beta_max @property @override def T(self) -> float: return 1
[docs] @override def get_drift_and_diffusion(self, x: Tensor, t: Tensor, mask=None) -> Tuple[Tensor, Tensor]: beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0) macro_shape = x.shape[: self.ndim_micro_shape] beta_t = beta_t.view(*macro_shape, *[1 for _ in range(self.ndim_micro_shape)]) drift = -0.5 * beta_t * x discount = 1.0 - torch.exp(-2 * self.beta_0 * t - (self.beta_1 - self.beta_0) * t**2) diffusion = torch.sqrt(beta_t * discount) return drift, diffusion
[docs] def marginal_prob(self, x, t, mask=None): log_mean_coeff = -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 macro_shape = x.shape[: self.ndim_micro_shape] log_mean_coeff = log_mean_coeff.view(*macro_shape, *[1 for _ in range(self.ndim_micro_shape)]) mean = torch.exp(log_mean_coeff) * x std = 1 - torch.exp(2.0 * log_mean_coeff) return mean, std
[docs] def prior_sampling(self, shape): return torch.randn(*shape)
[docs] def prior_logp(self, z): shape = z.shape N = np.prod(shape[1:]) return -N / 2.0 * np.log(2 * np.pi) - torch.sum(z**2, dim=(1, 2, 3)) / 2.0
[docs] class VESDE(SDE): def __init__( self, sigma_min=0.01, sigma_max=50, n_discretization_steps=1000, ndim_micro_shape=2, drop_first_step=False ): """Construct a Variance Exploding SDE. Args: sigma_min: smallest sigma. sigma_max: largest sigma. n_discretization_steps: number of discretization steps ndim_micro_shape: number of dimensions of a sample """ super().__init__(n_discretization_steps=n_discretization_steps, ndim_micro_shape=ndim_micro_shape) self.sigma_min = sigma_min self.sigma_max = sigma_max self.drop_first_step = drop_first_step sigma_min = torch.tensor(sigma_min) sigma_max = torch.tensor(sigma_max) if drop_first_step: self.discrete_sigmas = ( 10 ** torch.linspace(torch.log10(sigma_min), torch.log10(sigma_max), n_discretization_steps + 1)[1:] ) else: self.discrete_sigmas = torch.exp( torch.linspace(torch.log(sigma_min), torch.log(sigma_max), n_discretization_steps) ) @property @override def T(self) -> float: return 1
[docs] @override def get_drift_and_diffusion(self, x: Tensor, t: Tensor, mask=None) -> Tuple[Tensor, Tensor]: r""" .. math:: dx = 0 dt + \sigma_{min} \left(\frac{\sigma_{max}}{\sigma_{min}}\right)^t \sqrt{2 \log(\frac{\sigma_{max}}{\sigma_{min}})} dw \sigma_t = \sigma_{min} \left(\frac{\sigma_{max}}{\sigma_{min}}\right)^t diffusion = \sigma_t * \sqrt{2 \log(\frac{\sigma_{max}}{\sigma_{min}})} """ sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t drift = torch.zeros_like(x) diffusion = sigma * torch.sqrt( torch.tensor(2 * (np.log(self.sigma_max) - np.log(self.sigma_min)), device=t.device) ) return drift, diffusion
[docs] @override def get_discretized_drift_and_diffusion(self, x: Tensor, t: Tensor, mask=None) -> Tuple[Tensor, Tensor]: r"""SMLD(NCSN) discretization. .. math:: x_t &= x_0 + g \epsilon x_t &\sim \mathcal{N}(x_0, \sigma_t^2) \sigma_t^2 &= \sigma_{t-1}^2 + g^2 g &= \sqrt{\sigma_t^2 - \sigma_{t-1}^2} """ timestep = (t * (self.n_discretization_steps - 1) / self.T).long() sigma = self.discrete_sigmas.to(t.device)[timestep] adjacent_sigma = torch.where( timestep == 0, torch.zeros_like(t), self.discrete_sigmas[timestep - 1].to(t.device) ) f = torch.zeros_like(x) g = torch.sqrt(sigma**2 - adjacent_sigma**2) return f, g
[docs] def marginal_prob(self, x, t, mask=None): std = self.sigma_min * (self.sigma_max / self.sigma_min) ** t mean = x return mean, std
[docs] def prior_sampling(self, shape): return torch.randn(*shape) * self.sigma_max
[docs] def prior_logp(self, z): shape = z.shape N = np.prod(shape[1:]) return -N / 2.0 * np.log(2 * np.pi * self.sigma_max**2) - torch.sum(z**2, dim=(1, 2, 3)) / ( 2 * self.sigma_max**2 )