from typing import Any
import torch
from torch import Tensor
from ...util.se3 import T
from .masker_interface import MaskerInterface
[docs]
class BioCAOnlyMasker(MaskerInterface):
def __init__(self, ndim_mini_micro_shape: int = 1, **kwargs: dict[Any, Any]):
super().__init__(**kwargs)
self.ndim_mini_micro_shape: int = ndim_mini_micro_shape
[docs]
def apply_mask(self, x: Tensor, mask: Tensor) -> Tensor:
self.check_mask_shape(x, mask)
return x * mask.view(*mask.shape, *[1 for _ in range(self.ndim_mini_micro_shape)])
[docs]
def check_mask_shape(self, x: Tensor, mask: Tensor):
if self.ndim_mini_micro_shape == 0:
assert x.shape == mask.shape
else:
assert x.shape[: -self.ndim_mini_micro_shape] == mask.shape
[docs]
def count_bright_area(self, mask: Tensor) -> Tensor:
"""
Bright area can be seen
Dark area cannot be seen
"""
return torch.sum(mask)
[docs]
def get_full_bright_mask(self, x: Tensor) -> Tensor:
"""
b, n, 3 -> b, n
"""
shape = x.shape[: -self.ndim_mini_micro_shape] if self.ndim_mini_micro_shape != 0 else x.shape
device = x.device
return torch.ones(shape, device=device)
[docs]
def apply_inpainting_mask(self, x_0: Tensor, x_t: Tensor, inpainting_mask: Tensor) -> Tensor:
"""
1 represents the region that can be seen
"""
self.check_mask_shape(x_0, inpainting_mask)
inpainting_mask = inpainting_mask.view(*inpainting_mask.shape, *[1 for _ in range(self.ndim_mini_micro_shape)])
return x_t * (1 - inpainting_mask) + x_0 * inpainting_mask
[docs]
class BioBackboneFrameMasker(MaskerInterface):
def __init__(self, **kwargs: dict[Any, Any]):
super().__init__(**kwargs)
self.trans_ndim_mini_micro_shape: int = 1
self.rots_ndim_mini_micro_shape: int = 2
[docs]
def apply_mask(self, x: T, mask: Tensor) -> Tensor:
r"""
Args:
x: T, (b,n,3,3) rotation matrix and (b,n,3) translation vector
mask: b, n
"""
frames: T = x
self.check_mask_shape(frames, mask)
frames.trans *= mask.view(*mask.shape, *[1 for _ in range(self.trans_ndim_mini_micro_shape)])
identity = torch.eye(3, device=frames.rots.device).view(1, 1, 3, 3)
identity = identity.expand(frames.rots.shape[0], frames.rots.shape[1], 3, 3)
rot_mask = mask.view(mask.shape[0], mask.shape[1], 1, 1)
frames.rots = frames.rots * rot_mask + identity * (1 - rot_mask)
return frames
[docs]
def check_mask_shape(self, x: T, mask: Tensor):
assert x.trans[: -self.trans_ndim_mini_micro_shape] == mask.shape
[docs]
def count_bright_area(self, mask: Tensor) -> Tensor:
"""
Bright area can be seen
Dark area cannot be seen
"""
return torch.sum(mask)
[docs]
def get_full_bright_mask(self, x: T) -> Tensor:
"""
b, n, 3 -> b, n
"""
x: Tensor = x.trans
shape = x.shape[: -self.trans_ndim_mini_micro_shape]
device = x.device
return torch.ones(shape, device=device)
[docs]
def apply_inpainting_mask(self, x_0: T, x_t: T, inpainting_mask: Tensor) -> Tensor:
"""
1 represents the region that can be seen
"""
self.check_mask_shape(x_0, inpainting_mask)
b, n = x_0.rots.shape[:2]
inpainting_mask = inpainting_mask.view(
*inpainting_mask.shape, *[1 for _ in range(self.trans_ndim_mini_micro_shape)]
)
x_t.trans = x_t.trans * (1 - inpainting_mask) + x_0.trans * inpainting_mask
rot_inpainting_mask = inpainting_mask.view(b, n, 1, 1)
x_t.rots = x_t.rots * (1 - rot_inpainting_mask) + x_0.rots * rot_inpainting_mask
[docs]
class BioSO3Masker(MaskerInterface):
def __init__(self, **kwargs: dict[Any, Any]):
super().__init__(**kwargs)
self.ndim_mini_micro_shape: int = 2
[docs]
def apply_mask(self, x: Tensor, mask: Tensor) -> Tensor:
r"""
Args:
x: (b,n,3,3) rotation matrix
mask: b, n
"""
b, n = x.shape[:2]
device = x.device
self.check_mask_shape(x, mask)
identity = torch.eye(3, device=device).view(1, 1, 3, 3)
identity = identity.expand(b, n, -1, -1)
rot_mask = mask.view(b, n, 1, 1)
x = x * rot_mask + identity * (1 - rot_mask)
return x
[docs]
def check_mask_shape(self, x: Tensor, mask: Tensor):
# print(f"x.shape: {x.shape}, mask.shape: {mask.shape}")
assert x.shape[: -self.ndim_mini_micro_shape] == mask.shape
[docs]
def count_bright_area(self, mask: Tensor) -> Tensor:
"""
Bright area can be seen
Dark area cannot be seen
"""
return torch.sum(mask)
[docs]
def get_full_bright_mask(self, x: Tensor) -> Tensor:
"""
b, n, 3 -> b, n
"""
shape = x.shape[: -self.ndim_mini_micro_shape]
device = x.device
return torch.ones(shape, device=device)
[docs]
def apply_inpainting_mask(self, x_0: Tensor, x_t: Tensor, inpainting_mask: Tensor) -> Tensor:
"""
1 represents the region that can be seen
"""
self.check_mask_shape(x_0, inpainting_mask)
b, n = x_0.shape[:2]
rot_inpainting_mask = inpainting_mask.view(b, n, 1, 1)
x_t.rots = x_t * (1 - rot_inpainting_mask) + x_0 * rot_inpainting_mask