Source code for ls_mlkit.util.se3

# Adapted from OpenFold
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import Tensor


# According to DeepMind, this prevents rotation compositions from being
# computed on low-precision tensor cores. I'm personally skeptical that it
# makes a difference, but to get as close as possible to their outputs, I'm
# adding it.
def rot_matmul(a, b):
    """Multiply two batches of 3x3 matrices with explicit elementwise arithmetic.

    The product is spelled out entry by entry (multiplications and additions
    only) instead of calling a matmul kernel.

    Args:
        a: Tensor whose last two dimensions are indexed as a 3x3 matrix.
        b: Tensor whose last two dimensions are indexed as a 3x3 matrix.

    Returns:
        Tensor holding the matrix product of ``a`` and ``b`` in its last two
        dimensions.
    """

    def _entry(i, j):
        # Dot product of row i of `a` with column j of `b`, accumulated in a
        # fixed left-to-right order.
        return a[..., i, 0] * b[..., 0, j] + a[..., i, 1] * b[..., 1, j] + a[..., i, 2] * b[..., 2, j]

    rows = [torch.stack([_entry(i, j) for j in range(3)], dim=-1) for i in range(3)]
    return torch.stack(rows, dim=-2)
def rot_vec_mul(r, t):
    """Apply 3x3 matrices to 3-vectors with explicit elementwise arithmetic.

    Args:
        r: Tensor whose last two dimensions are indexed as a 3x3 matrix.
        t: Tensor whose last dimension holds the vector components; only
            indices 0, 1 and 2 are read.

    Returns:
        Tensor whose last dimension holds the three components of ``r @ t``.
    """
    vx, vy, vz = (t[..., k] for k in range(3))
    components = [r[..., i, 0] * vx + r[..., i, 1] * vy + r[..., i, 2] * vz for i in range(3)]
    return torch.stack(components, dim=-1)
class T:
    """Batch of rigid transformations stored as rotations plus translations.

    ``rots`` has shape ``[*, 3, 3]`` and ``trans`` has shape ``[*, 3]``, where
    ``*`` stands for the leading (batch) dimensions shared by both tensors.
    """

    def __init__(self, rots, trans):
        """Store the rotation and translation tensors after validating shapes.

        Args:
            rots: Tensor of shape ``[*, 3, 3]``, or ``None`` to use identity
                rotations matching ``trans``.
            trans: Tensor of shape ``[*, 3]``, or ``None`` to use zero
                translations matching ``rots``.

        Raises:
            ValueError: If both arguments are ``None``, or if the trailing
                dimensions are not ``(3, 3)`` / ``3``, or if the leading
                dimensions of the two tensors differ.
        """
        self.rots = rots
        self.trans = trans

        if self.rots is None and self.trans is None:
            raise ValueError("Only one of rots and trans can be None")
        elif self.rots is None:
            self.rots = T.identity_rot(self.trans.shape[:-1], self.trans.dtype, self.trans.device)
        elif self.trans is None:
            self.trans = T.identity_trans(self.rots.shape[:-2], self.rots.dtype, self.rots.device)

        if (
            self.rots.shape[-2:] != (3, 3)
            or self.trans.shape[-1] != 3
            or self.rots.shape[:-2] != self.trans.shape[:-1]
        ):
            raise ValueError("Incorrectly shaped input")

    def __getitem__(self, index):
        """Index the batch dimensions, leaving the trailing 3x3 / 3 dims intact.

        Args:
            index: A single index or a tuple of indices for the leading
                dimensions.

        Returns:
            A new ``T`` built from the indexed rotations and translations.
        """
        if not isinstance(index, tuple):
            index = (index,)
        # Full slices are appended so the matrix / vector dimensions are kept.
        return T(self.rots[index + (slice(None), slice(None))], self.trans[index + (slice(None),)])

    def __eq__(self, obj):
        """Return whether all rotation and translation entries are equal.

        The result is the value of a ``torch.all(...) and torch.all(...)``
        expression, i.e. a zero-dimensional boolean tensor.
        """
        return torch.all(self.rots == obj.rots) and torch.all(self.trans == obj.trans)

    def __mul__(self, right):
        """Multiply rotations and translations elementwise by ``right``.

        Args:
            right: Object indexable with ``[..., None]`` (a tensor); it is
                broadcast over the trailing 3x3 / 3 dimensions. Presumably a
                per-transform mask or weight -- confirm with callers.

        Returns:
            A new ``T`` holding the scaled tensors.
        """
        rots = self.rots * right[..., None, None]
        trans = self.trans * right[..., None]
        return T(rots, trans)

    def __rmul__(self, left):
        """Right-hand variant of ``__mul__``; delegates to it unchanged."""
        return self.__mul__(left)

    def to(self, device):
        """Move (or cast) the stored tensors in place and return ``self``.

        Args:
            device: Either another ``T`` (its ``trans`` / ``rots`` tensors are
                passed to ``Tensor.to`` of the matching attribute) or any
                argument accepted by ``Tensor.to``.

        Returns:
            This same instance, with ``trans`` and ``rots`` rebound.
        """
        if isinstance(device, T):
            self.trans = self.trans.to(device.get_trans())
            self.rots = self.rots.to(device.get_rots())
        else:
            self.trans = self.trans.to(device)
            self.rots = self.rots.to(device)
        return self

    @property
    def shape(self):
        """Batch shape of the transformation.

        Returns ``torch.Size([1])`` when the rotations have no leading
        dimensions.
        """
        s = self.rots.shape[:-2]
        return s if len(s) > 0 else torch.Size([1])

    def get_trans(self):
        """Return the translation tensor."""
        return self.trans

    def get_rots(self):
        """Return the rotation tensor."""
        return self.rots

    def compose(self, t):
        """Compose this transformation with ``t``.

        The result maps a point ``p`` to ``self.apply(t.apply(p))``: the
        rotations are multiplied as ``self.rots @ t.rots`` and the translation
        is ``self.rots @ t.trans + self.trans``.

        Args:
            t: The ``T`` applied first.

        Returns:
            A new ``T`` representing the composition.
        """
        rot_1, trn_1 = self.rots, self.trans
        rot_2, trn_2 = t.rots, t.trans

        rot = rot_matmul(rot_1, rot_2)
        trn = rot_vec_mul(rot_1, trn_2) + trn_1

        return T(rot, trn)

    def apply(self, pts):
        """Rotate ``pts`` by ``rots`` and then add ``trans``.

        Args:
            pts: Tensor whose last dimension holds 3-D coordinates.

        Returns:
            The transformed points.
        """
        r, t = self.rots, self.trans
        rotated = rot_vec_mul(r, pts)
        return rotated + t

    def invert_apply(self, pts):
        """Subtract ``trans`` from ``pts`` and rotate by the transposed rotation.

        The transpose is used as the inverse rotation, so this undoes
        ``apply`` only when ``rots`` is orthonormal.

        Args:
            pts: Tensor whose last dimension holds 3-D coordinates.

        Returns:
            The transformed points.
        """
        r, t = self.rots, self.trans
        pts = pts - t
        return rot_vec_mul(r.transpose(-1, -2), pts)

    def invert(self):
        """Return the transformation with transposed rotation and matching translation.

        The new rotation is ``rots^T`` and the new translation is
        ``-(rots^T @ trans)``; this is the true inverse when ``rots`` is
        orthonormal.
        """
        rot_inv = self.rots.transpose(-1, -2)
        trn_inv = rot_vec_mul(rot_inv, self.trans)

        return T(rot_inv, -1 * trn_inv)

    def unsqueeze(self, dim):
        """Insert a batch dimension of size one.

        Args:
            dim: Position of the new dimension relative to the batch
                dimensions. Negative values count from the last batch
                dimension.

        Returns:
            A new ``T`` with the extra dimension.

        Raises:
            ValueError: If ``dim >= len(self.shape)``.
        """
        if dim >= len(self.shape):
            raise ValueError("Invalid dimension")
        # Negative positions are shifted past the trailing 3x3 / 3 dimensions.
        rots = self.rots.unsqueeze(dim if dim >= 0 else dim - 2)
        trans = self.trans.unsqueeze(dim if dim >= 0 else dim - 1)

        return T(rots, trans)

    @staticmethod
    def identity_rot(shape, dtype, device, requires_grad=False):
        """Return identity rotations of shape ``[*shape, 3, 3]``.

        The result is produced with ``expand`` from a single 3x3 identity
        matrix rather than by allocating one matrix per batch element.
        """
        rots = torch.eye(3, dtype=dtype, device=device, requires_grad=requires_grad)
        rots = rots.view(*((1,) * len(shape)), 3, 3)
        rots = rots.expand(*shape, -1, -1)

        return rots

    @staticmethod
    def identity_trans(shape, dtype, device, requires_grad=False):
        """Return zero translations of shape ``[*shape, 3]``."""
        trans = torch.zeros((*shape, 3), dtype=dtype, device=device, requires_grad=requires_grad)
        return trans

    @staticmethod
    def identity(shape, dtype, device, requires_grad=False):
        """Return a ``T`` of batch shape ``shape`` with identity rotations and zero translations."""
        return T(
            T.identity_rot(shape, dtype, device, requires_grad),
            T.identity_trans(shape, dtype, device, requires_grad),
        )

    @staticmethod
    def from_4x4(t):
        """Build a ``T`` from a ``[*, 4, 4]`` tensor.

        The upper-left 3x3 block becomes the rotation and the first three
        entries of the last column become the translation.
        """
        rots = t[..., :3, :3]
        trans = t[..., :3, 3]
        return T(rots, trans)

    def to_4x4(self):
        """Pack the transformation into a ``[*, 4, 4]`` tensor.

        The rotation fills the upper-left 3x3 block, the translation fills the
        first three entries of the last column, and the bottom-right entry is
        one. The leading dimensions come from ``self.shape``.

        Returns:
            A new tensor on the device of ``rots`` and with the dtype of
            ``rots``.
        """
        # Allocate with the rotations' dtype so float64 / half-precision
        # transforms are not silently converted to the default dtype.
        tensor = torch.zeros((*self.shape, 4, 4), dtype=self.rots.dtype, device=self.rots.device)
        tensor[..., :3, :3] = self.rots
        tensor[..., :3, 3] = self.trans
        tensor[..., 3, 3] = 1
        return tensor

    @staticmethod
    def from_tensor(t):
        """Alias of ``from_4x4``."""
        return T.from_4x4(t)

    @staticmethod
    def rigid_from_3_points(x_1: Tensor, x_2: Tensor, x_3: Tensor, eps: float = 1e-8):
        """Build a frame from three points by Gram-Schmidt orthogonalisation.

        The origin is ``x_2``. The first axis points along ``x_3 - x_2``, the
        second axis is the component of ``x_1 - x_2`` orthogonal to the first,
        and the third axis is their cross product. The three axes are stored
        as the columns of the rotation matrix.

        Args:
            x_1: Tensor of shape ``[*, 3]``.
            x_2: Tensor of shape ``[*, 3]``; used as the translation.
            x_3: Tensor of shape ``[*, 3]``.
            eps: Value added to the norms before dividing, to avoid division
                by zero.

        Returns:
            A ``T`` with rotations ``[*, 3, 3]`` and translations ``x_2``.
        """
        v1 = x_3 - x_2
        v2 = x_1 - x_2
        e1 = v1 / (torch.norm(v1, dim=-1, keepdim=True) + eps)
        # Plain dot product over the last dimension; works for any number of
        # leading dimensions, including none.
        proj = torch.sum(e1 * v2, dim=-1, keepdim=True)
        u2 = v2 - proj * e1
        e2 = u2 / (torch.norm(u2, dim=-1, keepdim=True) + eps)
        e3 = torch.cross(e1, e2, dim=-1)
        R = torch.stack([e1, e2, e3], dim=-1)  # [*, 3, 3] - rotation matrix
        return T(R, x_2)

    @staticmethod
    def concat(ts, dim):
        """Concatenate several ``T`` objects along a batch dimension.

        Args:
            ts: Iterable of ``T`` instances.
            dim: Batch dimension to concatenate along. Negative values count
                from the last batch dimension.

        Returns:
            A new ``T`` holding the concatenated tensors.
        """
        # Negative positions are shifted past the trailing 3x3 / 3 dimensions.
        rots = torch.cat([t.rots for t in ts], dim=dim if dim >= 0 else dim - 2)
        trans = torch.cat([t.trans for t in ts], dim=dim if dim >= 0 else dim - 1)

        return T(rots, trans)

    def map_tensor_fn(self, fn):
        """Apply a tensor function to every rotation / translation component.

        Apply a function that takes a tensor as its only argument to the
        rotations and translations, treating the final two/one
        dimension(s), respectively, as batch dimensions.

        E.g.: Given t, an instance of T of shape [N, M], this function can
        be used to sum out the second dimension thereof as follows:

            t = t.map_tensor_fn(lambda x: torch.sum(x, dim=-1))

        The resulting object has rotations of shape [N, 3, 3] and
        translations of shape [N, 3]
        """
        # reshape (not view) so non-contiguous rotations, e.g. the transposed
        # ones produced by invert(), are accepted as well.
        rots = self.rots.reshape(*self.rots.shape[:-2], 9)
        rots = torch.stack(list(map(fn, torch.unbind(rots, -1))), dim=-1)
        rots = rots.reshape(*rots.shape[:-1], 3, 3)

        trans = torch.stack(list(map(fn, torch.unbind(self.trans, -1))), dim=-1)

        return T(rots, trans)

    def stop_rot_gradient(self):
        """Return a ``T`` whose rotations are detached from the autograd graph."""
        return T(self.rots.detach(), self.trans)

    def scale_translation(self, factor):
        """Return a ``T`` with the same rotations and translations multiplied by ``factor``."""
        return T(self.rots, self.trans * factor)
_quat_elements = ["a", "b", "c", "d"] _qtr_keys = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements] _qtr_ind_dict = {key: ind for ind, key in enumerate(_qtr_keys)} def _to_mat(pairs): mat = torch.zeros((4, 4)) for pair in pairs: key, value = pair ind = _qtr_ind_dict[key] mat[ind // 4][ind % 4] = value return mat _qtr_mat = torch.zeros((4, 4, 3, 3)) _qtr_mat[..., 0, 0] = _to_mat([("aa", 1), ("bb", 1), ("cc", -1), ("dd", -1)]) _qtr_mat[..., 0, 1] = _to_mat([("bc", 2), ("ad", -2)]) _qtr_mat[..., 0, 2] = _to_mat([("bd", 2), ("ac", 2)]) _qtr_mat[..., 1, 0] = _to_mat([("bc", 2), ("ad", 2)]) _qtr_mat[..., 1, 1] = _to_mat([("aa", 1), ("bb", -1), ("cc", 1), ("dd", -1)]) _qtr_mat[..., 1, 2] = _to_mat([("cd", 2), ("ab", -2)]) _qtr_mat[..., 2, 0] = _to_mat([("bd", 2), ("ac", -2)]) _qtr_mat[..., 2, 1] = _to_mat([("cd", 2), ("ab", 2)]) _qtr_mat[..., 2, 2] = _to_mat([("aa", 1), ("bb", -1), ("cc", -1), ("dd", 1)])
def quat_to_rot(quat):
    """Convert quaternions to rotation matrices via the coefficient table.

    Args:
        quat: Tensor of shape ``[*, 4]`` holding quaternion components.

    Returns:
        Tensor of shape ``[*, 3, 3]`` obtained by weighting every pairwise
        product of quaternion components with ``_qtr_mat`` and summing over
        the two component axes.
    """
    # [*, 4, 4]: all pairwise products of the quaternion components.
    outer = quat[..., None] * quat[..., None, :]

    # [1, ..., 1, 4, 4, 3, 3]: coefficient table padded with singleton batch
    # dimensions so it broadcasts against `outer`.
    batch_rank = len(outer.shape[:-2])
    coeffs = _qtr_mat.view((1,) * batch_rank + (4, 4, 3, 3))

    # [*, 4, 4, 3, 3]
    weighted = outer[..., None, None] * coeffs.to(outer.device)

    # [*, 3, 3]
    return torch.sum(weighted, dim=(-3, -4))