"""Abstract model interface for the ls_mlkit diffuser package (module ``ls_mlkit.diffuser.model_interface``)."""
import abc
from typing import Any
import torch
from torch import Tensor
[docs]
class Model4DiffuserInterface(abc.ABC):
    """Abstract interface a model implements for use with a diffuser (per the class name).

    Concrete subclasses must override both :meth:`get_model_device` and
    :meth:`__call__`. Both are declared with :func:`abc.abstractmethod`, so this
    class, and any subclass that leaves either method unimplemented, raises
    ``TypeError`` when instantiated.
    """

    def __init__(self):
        """Initialise the interface; intentionally performs no work.

        NOTE(review): this does not call ``super().__init__()``. In a
        multiple-inheritance setup (for example, mixed with another base
        class), the next ``__init__`` in the MRO is therefore not reached from
        here, and the subclass must invoke it explicitly.
        """

    @abc.abstractmethod
    def get_model_device(self) -> torch.device:
        """Report which device the model resides on.

        Returns:
            torch.device: the device of the model.
        """

    @abc.abstractmethod
    def __call__(
        self,
        x_t: Tensor,
        t: Tensor,
        padding_mask: Tensor,
        *args: Any,
        **kwargs: Any,
    ) -> dict:
        r"""Run the model on input ``x_t`` at time ``t``.

        Args:
            x_t (Tensor): the input tensor. Its shape convention is defined by
                implementers (TODO confirm).
            t (Tensor): the time tensor.
            padding_mask (Tensor): the padding mask. Presumably it flags padded
                positions of ``x_t``; confirm the polarity with implementers.
            *args: extra positional arguments forwarded to the implementation.
            **kwargs: extra keyword arguments forwarded to the implementation.

        Returns:
            dict: the model output. Its key schema is determined by the
            concrete implementation.
        """