Source code for ls_mlkit.util.mask.image_masker

from typing import Any

import torch
from torch import Tensor

from .masker_interface import MaskerInterface


class ImageMasker(MaskerInterface):
    """Apply element-wise masks to image-like tensors.

    A mask value of 1 marks a "bright" (visible) element and 0 a "dark"
    (hidden) one. ``ndim_mini_micro_shape`` is the number of trailing
    dimensions of ``x`` that the mask does not cover; the mask is broadcast
    over those dimensions by appending that many size-1 axes.
    """

    def __init__(self, ndim_mini_micro_shape: int = 0, **kwargs: Any):
        """Create the masker.

        Args:
            ndim_mini_micro_shape: Number of trailing dimensions of the data
                tensor that are not described by the mask. Must be >= 0.
            **kwargs: Forwarded unchanged to ``MaskerInterface.__init__``.

        Raises:
            ValueError: If ``ndim_mini_micro_shape`` is negative.
        """
        if ndim_mini_micro_shape < 0:
            raise ValueError(
                f"ndim_mini_micro_shape must be >= 0, got {ndim_mini_micro_shape}"
            )
        super().__init__(**kwargs)
        self.ndim_mini_micro_shape: int = ndim_mini_micro_shape

    def _append_trailing_dims(self, mask: Tensor) -> Tensor:
        """Return ``mask`` with ``ndim_mini_micro_shape`` size-1 axes appended."""
        if self.ndim_mini_micro_shape == 0:
            return mask
        # Indexing with None inserts size-1 axes without copying and, unlike
        # ``view``, does not depend on the memory layout of the mask.
        return mask[(Ellipsis,) + (None,) * self.ndim_mini_micro_shape]

    def apply_mask(self, x: Tensor, mask: Tensor) -> Tensor:
        """Multiply ``x`` by ``mask``, zeroing the dark (mask == 0) elements.

        Args:
            x: Data tensor.
            mask: Mask whose shape passes :meth:`check_mask_shape` for ``x``.

        Returns:
            ``x * mask`` with the mask broadcast over the trailing dimensions.

        Raises:
            AssertionError: If the mask shape is incompatible with ``x``.
        """
        self.check_mask_shape(x, mask)
        return x * self._append_trailing_dims(mask)

    def check_mask_shape(self, x: Tensor, mask: Tensor) -> None:
        """Validate that ``mask`` is shaped to be applied to ``x``.

        With ``ndim_mini_micro_shape == 0`` the mask must have the same shape
        as ``x``, except that a size-1 axis at position ``-3`` is accepted and
        treated as broadcast to ``x.shape[-3]``. This works for any rank of at
        least 3, batched or not. Otherwise the mask shape must equal the shape
        of ``x`` with its last ``ndim_mini_micro_shape`` dimensions removed.

        Raises:
            AssertionError: If the shapes are incompatible. Raised explicitly
                (not via ``assert``) so the check survives ``python -O``.
        """
        n_trailing = self.ndim_mini_micro_shape
        expected = list(mask.shape)
        if n_trailing == 0:
            actual = tuple(x.shape)
            # A singleton axis at -3 stands for "same mask for every channel".
            if len(expected) >= 3 and len(actual) >= 3 and expected[-3] == 1:
                expected[-3] = actual[-3]
        else:
            actual = tuple(x.shape[:-n_trailing])

        if actual != tuple(expected):
            raise AssertionError(
                f"mask shape {tuple(mask.shape)} is incompatible with data shape "
                f"{tuple(x.shape)} (ndim_mini_micro_shape={n_trailing})"
            )

    def count_bright_area(self, mask: Tensor) -> Tensor:
        r"""Return the sum of all mask entries.

        Bright area can be seen.
        Dark area cannot be seen.
        """
        return torch.sum(mask)

    def get_full_bright_mask(self, x: Tensor) -> Tensor:
        """Return an all-ones mask for ``x`` on the same device as ``x``.

        The mask has the shape of ``x`` without its last
        ``ndim_mini_micro_shape`` dimensions and uses torch's default dtype.
        """
        if self.ndim_mini_micro_shape == 0:
            mask_shape = x.shape
        else:
            mask_shape = x.shape[: -self.ndim_mini_micro_shape]
        return torch.ones(mask_shape, device=x.device)

    def apply_inpainting_mask(self, x_0: Tensor, x_t: Tensor, inpainting_mask: Tensor) -> Tensor:
        r"""Blend ``x_0`` into ``x_t`` where the inpainting mask is set.

        1 represents the region that can be seen: there the result takes the
        value of ``x_0``; where the mask is 0 it keeps ``x_t``.

        Args:
            x_0: Tensor supplying the visible region.
            x_t: Tensor supplying the remaining region.
            inpainting_mask: Mask whose shape passes :meth:`check_mask_shape`
                for ``x_0``. Boolean masks are accepted.

        Returns:
            ``x_t * (1 - mask) + x_0 * mask``.

        Raises:
            AssertionError: If the mask shape is incompatible with ``x_0``.
        """
        self.check_mask_shape(x_0, inpainting_mask)
        if inpainting_mask.dtype == torch.bool:
            # ``1 - mask`` is not defined for boolean tensors in torch.
            inpainting_mask = inpainting_mask.to(x_t.dtype)
        inpainting_mask = self._append_trailing_dims(inpainting_mask)
        return x_t * (1 - inpainting_mask) + x_0 * inpainting_mask