Source code for ls_mlkit.diffuser.lie_group_diffuser
r"""
Lie Group Diffuser
"""
from typing import Any
from ..util.decorators import inherit_docstrings
from ..util.manifold.lie_group import LieGroup
from .manifold_diffuser import RiemannianManifoldDiffuser, RiemannianManifoldDiffuserConfig
from .time_scheduler import TimeScheduler
[docs]
@inherit_docstrings
class LieGroupDiffuserConfig(RiemannianManifoldDiffuserConfig):
"""
Lie Group Diffuser Config
"""
def __init__(self, n_discretization_steps: int, ndim_micro_shape: int, *args: list[Any], **kwargs: dict[Any, Any]):
super().__init__(
n_discretization_steps=n_discretization_steps, ndim_micro_shape=ndim_micro_shape, *args, **kwargs
)
[docs]
@inherit_docstrings
class LieGroupDiffuser(RiemannianManifoldDiffuser):
"""
Riemannian Manifold Diffuser
"""
def __init__(
self,
config: LieGroupDiffuserConfig,
time_scheduler: TimeScheduler,
lie_group: LieGroup,
):
"""Initialize the LieGroupDiffuser
Args:
config (LieGroupDiffuserConfig): the config of the LieGroupDiffuser
time_scheduler (TimeScheduler): the time scheduler of the LieGroupDiffuser
lie_group (LieGroup): the Lie group of the LieGroupDiffuser
"""
super().__init__(
config=config,
time_scheduler=time_scheduler,
riemannian_manifold=lie_group,
)
self.lie_group = lie_group