Source code for ls_mlkit.util.shape
from typing import Tuple, Union
from torch import Tensor
[docs]
def get_macroscopic_shape(obj: Union[Tensor, tuple], ndim_microscopic: int) -> Tuple[int]:
"""
Get the macroscopic shape of an object.
"""
if isinstance(obj, tuple):
if len(obj) == ndim_microscopic:
result = (1,)
else:
result = obj[:-ndim_microscopic]
elif isinstance(obj, Tensor):
if obj.ndim == ndim_microscopic:
result = (1,)
else:
result = obj.shape[:-ndim_microscopic]
else:
raise ValueError(f"Invalid type: {type(obj)}")
return result