Source code for ls_mlkit.util.vector_utils
from typing import Callable
import torch
from torch import Tensor
[docs]
def get_vector_cosines(vectors: Tensor) -> Tensor:
    """Compute the pairwise cosines between the vectors of each set.

    Each vector is scaled to unit length and the dot product of every pair is
    taken. A vector whose norm is exactly zero is left as all zeros (its norm
    is replaced by 1 before dividing), so every cosine involving it, including
    its own diagonal entry, is 0 rather than NaN.

    Args:
        vectors: torch.Tensor, shape (..., n_nodes, 3). The norm is taken over
            the last dimension.

    Returns:
        cosines: torch.Tensor, shape (..., n_nodes, n_nodes), where entry
            [..., i, k] is the cosine between vector i and vector k.
    """
    rows_norms = torch.norm(vectors, dim=-1)
    # Replace zero norms by one to avoid a division by zero. ones_like keeps
    # the replacement on the same dtype/device/shape as the norms, so the
    # torch.where call needs no implicit dtype promotion.
    safe_norms = torch.where(rows_norms == 0, torch.ones_like(rows_norms), rows_norms)
    # Broadcast the reciprocal norm over the coordinate axis.
    normalised_vectors = vectors * (1 / safe_norms).unsqueeze(-1)
    # Dot product of every (i, k) pair of unit vectors.
    cosines = torch.einsum("...ij,...kj->...ik", normalised_vectors, normalised_vectors)
    return cosines
[docs]
def get_cosines_and_amplitudes(vectors: Tensor, mask: Tensor) -> tuple[Tensor, Tensor]:
    """Return pairwise cosines plus masked, L2-normalised vector lengths.

    Args:
        vectors: tensor whose last dimension holds the vector coordinates
            (presumably shape (..., n_nodes, 3), as expected by
            get_vector_cosines -- confirm against callers).
        mask: tensor multiplied elementwise with the per-vector norms, so it
            must broadcast against the shape of ``vectors`` without its last
            dimension.

    Returns:
        A ``(cosines, amplitudes)`` tuple: the result of
        ``get_vector_cosines(vectors)`` and the masked norms after L2
        normalisation along their last dimension.
    """
    pairwise_cosines = get_vector_cosines(vectors)
    # Static-typing hint only; nothing is assigned at runtime.
    torch.linalg.norm: Callable
    masked_lengths = torch.linalg.norm(vectors, dim=-1) * mask  # type: ignore
    unit_amplitudes = torch.nn.functional.normalize(masked_lengths, p=2, dim=-1)  # type: ignore
    return pairwise_cosines, unit_amplitudes