Source code for ls_mlkit.diffuser.manifold_diffuser

r"""
Riemannian Manifold Diffuser
"""

from typing import Any

from ..util.decorators import inherit_docstrings
from ..util.manifold.riemannian_manifold import RiemannianManifold
from .base_diffuser import BaseDiffuser, BaseDiffuserConfig
from .time_scheduler import TimeScheduler


[docs] @inherit_docstrings class RiemannianManifoldDiffuserConfig(BaseDiffuserConfig): """ Riemannian Manifold Diffuser Config """ def __init__(self, n_discretization_steps: int, ndim_micro_shape: int, *args: list[Any], **kwargs: dict[Any, Any]): super().__init__( n_discretization_steps=n_discretization_steps, ndim_micro_shape=ndim_micro_shape, *args, **kwargs )
[docs] @inherit_docstrings class RiemannianManifoldDiffuser(BaseDiffuser): """ Riemannian Manifold Diffuser """ def __init__( self, config: RiemannianManifoldDiffuserConfig, time_scheduler: TimeScheduler, riemannian_manifold: RiemannianManifold, ): """Initialize the RiemannianManifoldDiffuser Args: config (RiemannianManifoldDiffuserConfig): the config of the RiemannianManifoldDiffuser time_scheduler (TimeScheduler): the time scheduler of the RiemannianManifoldDiffuser riemannian_manifold (RiemannianManifold): the Riemannian manifold of the RiemannianManifoldDiffuser """ super().__init__( config=config, time_scheduler=time_scheduler, ) self.riemannian_manifold = riemannian_manifold