# Source code for ls_mlkit.util.mask.masker_interface

import abc
from typing import Any

from torch import Tensor


class MaskerInterface(abc.ABC):
    """Abstract interface for objects that apply masks to tensors.

    Concrete subclasses must implement every abstract method declared
    here. The constructor accepts arbitrary positional and keyword
    arguments and stores them unchanged on the instance.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Store the constructor arguments on the instance.

        Args:
            *args: Arbitrary positional arguments, kept as ``self.args``.
            **kwargs: Arbitrary keyword arguments, kept as ``self.kwargs``.
        """
        # NOTE: an annotation on ``**kwargs`` describes each *value*, not the
        # mapping itself, hence ``Any`` here and ``dict[str, Any]`` below.
        self.args: tuple[Any, ...] = args
        self.kwargs: dict[str, Any] = kwargs

    @abc.abstractmethod
    def apply_mask(self, x: Tensor, mask: Tensor) -> Tensor:
        """Apply ``mask`` to ``x`` and return the resulting tensor.

        The masking semantics are defined by the concrete subclass.
        """

    @abc.abstractmethod
    def check_mask_shape(self, x: Tensor, mask: Tensor):
        """Check whether the shape of ``mask`` is as expected for ``x``.

        How a mismatch is reported (return value or exception) is left to
        the concrete subclass.
        """

    @abc.abstractmethod
    def count_bright_area(self, mask: Tensor) -> Tensor:
        """Count the bright area of ``mask``.

        Bright area can be seen; dark area cannot be seen.
        """

    @abc.abstractmethod
    def get_full_bright_mask(self, x: Tensor) -> Tensor:
        """Return a mask for ``x`` that is all bright."""

    @abc.abstractmethod
    def apply_inpainting_mask(
        self, x_0: Tensor, x_t: Tensor, inpainting_mask: Tensor
    ) -> Tensor:
        """Combine ``x_0`` and ``x_t`` according to ``inpainting_mask``.

        In ``inpainting_mask``, 1 represents the region that can be seen.
        How the two inputs are combined is defined by the concrete
        subclass.
        """