Source code for ls_mlkit.util.nma.anm

from typing import Callable

import biotite.structure as struc
import einops
import torch
from torch import Tensor

from .force_fields import ForceField

# Physical constants in SI units.
K_B: float = 1.380649e-23  # Boltzmann constant k_B [J/K]
N_A: float = 6.02214076e23  # Avogadro constant N_A [1/mol]


r"""
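Boltzmann distribution of the harmonic (ANM) energy around the equilibrium
conformation $x_0$ with Hessian $H$:
$$
P(x) \propto \exp \left(-\frac{1}{2k_BT}(x-x_0)^T H(x-x_0)\right)
$$
Multi-dimensional Gaussian distribution:
$$
P(x) \propto \exp \left(-\frac{1}{2}(x-\mu)^T \Sigma^{-1}(x-\mu)\right)
$$
Comparing the exponents gives $\mu = x_0$ and $\Sigma^{-1} = H / (k_BT)$, i.e.
$\Sigma = k_BT\,H^{+}$, where the pseudo-inverse $H^{+}$ is required because the
six rigid-body (translation/rotation) modes make $H$ singular.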
"""


class ANM:
    """
    Anisotropic Network Model.

    Attributes:
        hessian : tensor, shape=``(..., n*3, n*3)``, dtype=float64
            The *Hessian* matrix for this model (mass-weighted if masses were given,
            with rows/columns of masked nodes zeroed). It is computed lazily on first
            access and cached.
            Each dimension is partitioned in the form ``[x1, y1, z1, ... xn, yn, zn]``.
            This is not a copy: Create a copy before modifying this matrix.
        masses : tensor, shape=(..., n), dtype=float
            The mass for each atom (zero for masked nodes), `None` if no mass
            weighting is applied.
    """

    def __init__(
        self,
        atoms: Tensor,
        force_field: "ForceField",
        masses=None,
        device="cuda",
        node_mask: Tensor = None,
    ):
        """
        Args:
            atoms : tensor, shape=(..., n, 3), dtype=float
                Atom coordinates that are part of the model.
                It usually contains only CA atoms.
            force_field : ForceField
                Defines the force constants between the given `atoms`. This class
                only reads its `cutoff_distance` attribute and calls
                `force_constant(atom_i, atom_j, sq_dist)`.
            masses : tensor, shape=(..., n), dtype=float, optional
                If given, the Hessian is weighted with the inverse square root of the
                given masses. By default no mass-weighting is applied.
                Must not be negative (checked before masking). Masses of masked nodes
                are set to 0; a zero mass on a valid node yields an infinite weight.
            device : str, optional
                Device on which the Hessian is allocated.
            node_mask : tensor, shape=(..., n), dtype=long, optional
                1 for valid nodes, 0 for invalid nodes. Defaults to all nodes valid.

        Raises:
            AssertionError: if ``masses.shape != atoms.shape[:-1]``.
            ValueError: if any mass is negative.
        """
        self._coord = atoms
        self._ff = force_field
        self.device = device
        if node_mask is None:
            # Materialise an all-valid mask so every masking step below (and in the
            # other methods) works without special-casing None.
            node_mask = torch.ones(atoms.shape[:-1], dtype=torch.long, device=atoms.device)
        self._node_mask = node_mask

        if masses is None:
            self._masses = None
            self._mass_weight_matrix = None
        else:
            assert (
                masses.shape == atoms.shape[:-1]
            ), f"shape(masses) = {masses.shape} != shape(atoms[:-1]) = {atoms.shape[:-1]}"
            if torch.any(masses < 0):
                raise ValueError("Masses must not be negative")
            self._masses = masses * node_mask
            # 1/sqrt(m) for valid nodes, 0 for masked ones; torch.where selects the
            # zero branch so masked entries never propagate inf.
            mass_weights = torch.where(
                node_mask == 1, 1 / torch.sqrt(self._masses), torch.zeros_like(self._masses)
            )  # (..., n)
            mass_weights = mass_weights.repeat_interleave(repeats=3, dim=-1)  # (..., 3n)
            # Outer product M^{-1/2} (x) M^{-1/2}; elementwise multiplication with the
            # Hessian gives M^{-1/2} H M^{-1/2}.
            self._mass_weight_matrix = mass_weights.unsqueeze(-1) * mass_weights.unsqueeze(-2)  # (..., 3n, 3n)
        self._hessian = None

    @property
    def masses(self):
        return self._masses

    @property
    def hessian(self):
        if self._hessian is None:
            hessian, _ = self.compute_hessian(
                self._coord, self._ff, device=self.device, use_cell_list=False, node_mask=self._node_mask
            )
            if self._mass_weight_matrix is not None:
                hessian *= self._mass_weight_matrix.to(hessian.device)
            # Zero every row/column that belongs to a masked node.
            mask3 = self._node_mask.repeat_interleave(repeats=3, dim=-1).to(hessian)  # (..., 3n)
            hessian *= mask3.unsqueeze(-1) * mask3.unsqueeze(-2)
            self._hessian = hessian
        return self._hessian

    @hessian.setter
    def hessian(self, value):
        self._hessian = value

    def eigen(self, epsilon=1e-7):
        """
        Compute the Eigenvalues and Eigenvectors of the *Hessian* matrix.

        The first six Eigenvalues/Eigenvectors correspond to trivial modes
        (translations/rotations) and are usually omitted in normal mode analysis.
        Each masked node adds three further zero Eigenvalues, because its
        rows/columns of the Hessian are zero.

        Args:
            epsilon : float, optional
                Scale of the Gaussian noise added to the Hessian before the
                decomposition. The noise makes the result non-deterministic across
                calls. NOTE(review): presumably meant to lift exactly repeated (zero)
                eigenvalues, for which ``torch.linalg.eigh`` gradients are unstable —
                confirm. ``eigh`` only reads the lower triangle, so the noise acts as
                a symmetric perturbation. Pass 0 to disable it.

        Returns:
            eig_values : tensor, shape=(..., n*3), dtype=float
                Eigenvalues of the *Hessian* matrix in ascending order.
            eig_vectors : tensor, shape=(..., n*3, n*3), dtype=float
                Eigenvectors of the *Hessian* matrix, stored as columns, with the
                same dtype as the *Hessian*.
        """
        hessian = self.hessian
        noise = epsilon * torch.randn_like(hessian)
        eig_values, eig_vectors = torch.linalg.eigh(hessian + noise)
        return eig_values, eig_vectors

    def get_displacements_from_normal_modes(self, indexes: list[int]):
        """
        Get the displacement vectors for the given normal modes.

        Indexes refer to the ascending spectrum of the valid (unmasked) sub-system.
        For every batch element, they are shifted past the ``3 * n_masked`` zero
        modes contributed by its masked nodes, so index 6 is the first non-trivial
        mode in every element.

        Args:
            indexes: list of integers, the indexes of the normal modes.

        Returns:
            Tensor of shape (..., len(indexes), n, 3), or (..., n, 3) when a single
            index is given, where n is the number of atoms.

        Note:
            torch raises an error if a shifted index is ``>= 3n``.
        """
        k = len(indexes)
        node_mask = self._node_mask  # (..., n)
        n = node_mask.shape[-1]
        _, eig_vectors = self.eigen()  # (..., 3n, 3n), eigenvectors as columns
        macro_shape = eig_vectors.shape[:-2]
        three_n = eig_vectors.shape[-1]

        # Each masked node contributes three zero eigenvalues at the bottom of the spectrum.
        skips = 3 * (node_mask == 0).sum(dim=-1, keepdim=True)  # (..., 1)
        requested = torch.tensor(indexes, device=eig_vectors.device, dtype=torch.long)  # (k,)
        mode_idx = requested + skips.to(eig_vectors.device)  # (..., k), broadcast; no in-place op on a view

        # Pick the selected eigenvector columns per batch element (also works unbatched).
        gather_idx = mode_idx.unsqueeze(-2).expand(*macro_shape, three_n, k)  # (..., 3n, k)
        mode_vectors = torch.gather(eig_vectors, -1, gather_idx)  # (..., 3n, k)
        mode_vectors = mode_vectors.transpose(-1, -2).reshape(*macro_shape, k, n, 3)
        if k == 1:
            return mode_vectors[..., 0, :, :]
        return mode_vectors

    def compute_hessian(
        self, coord: Tensor, force_field: "ForceField", device, use_cell_list=False, node_mask: Tensor = None
    ):
        """
        Compute the *Hessian* matrix for atoms with given coordinates and the chosen force field.

        Args:
            coord : tensor, shape=(..., n, 3), dtype=float
                The coordinates.
            force_field : ForceField
                The :class:`ForceField` that defines the force constants.
            device :
                Device on which the Hessian is allocated.
            use_cell_list : bool, optional
                Currently ignored: all pairwise distances are always computed.
            node_mask : tensor, shape=(..., n), dtype=long, optional
                1 for valid nodes, 0 for invalid nodes. Defaults to all nodes valid.

        Returns:
            hessian : tensor, shape=(..., n*3, n*3), dtype=float64
                The computed *Hessian* matrix.
                Each dimension is partitioned in the form ``[x1, y1, z1, ... xn, yn, zn]``.
            pairs : tensor, shape=(len(...) + 2, m), dtype=int
                Indices for interacting atoms, i.e. atoms within `cutoff_distance`.
        """
        # Convert into higher precision to avert numerical issues in
        # pseudoinverse calculation
        coord = coord.to(torch.float64)
        # pairs: (len(...) + 2, m) batch/atom indices of interacting pairs,
        # disp: (m, 3) displacement r_i - r_j, sq_dist: (m,) squared distance.
        pairs, disp, sq_dist = self._prepare_values_for_interaction_matrix(
            coord, force_field, device, use_cell_list, node_mask=node_mask
        )
        macro_shape = coord.shape[:-2]
        n = coord.shape[-2]
        hessian = torch.zeros((*macro_shape, n, n, 3, 3), dtype=torch.float64, device=device)

        atom_i = pairs[:-1]  # batch indices + i
        atom_j = torch.cat([pairs[:-2], pairs[-1:]], dim=0)  # batch indices + j
        force_constants = force_field.force_constant(atom_i, atom_j, sq_dist)  # (m)
        # Off-diagonal 3x3 super-elements: H_ij = -gamma_ij / r_ij^2 * d d^T
        hessian[tuple(pairs)] = (
            -(force_constants / sq_dist).view(-1, 1, 1) * disp.view(-1, 3, 1) * disp.view(-1, 1, 3)
        )

        # Diagonal super-elements: H_jj = -sum_i H_ij (shape (..., n, 3, 3)); the
        # diagonal is still zero here, so the sum covers only off-diagonal blocks.
        indices = torch.arange(n, device=device)
        hessian[..., indices, indices, :, :] = -torch.sum(hessian, dim=-4)

        # (..., n_i, n_j, 3_c, 3_d) -> (..., n_i*3 + c, n_j*3 + d)
        hessian = hessian.transpose(-3, -2).reshape(*macro_shape, 3 * n, 3 * n)
        return hessian, pairs

    def _prepare_values_for_interaction_matrix(self, coord, force_field, device, use_cell_list, node_mask):
        """
        Check input values and calculate common intermediate values for
        :func:`compute_hessian()`.

        Args:
            coord : tensor, shape=(..., n, 3), dtype=float
                The coordinates.
            force_field : ForceField
                The :class:`ForceField` that defines the force constants.
            device :
                Device for the adjacency matrix.
            use_cell_list : bool
                Currently ignored.
            node_mask : tensor, shape=(..., n), dtype=long, or None
                1 for valid nodes, 0 for invalid nodes; None means all valid.

        Returns:
            pair_indices : tensor, shape=(len(...) + 2, m), dtype=int
                Batch and atom indices for interacting atoms, i.e. atoms within
                `cutoff_distance` (both orderings of every pair are included).
            disp : tensor, shape=(m, 3), dtype=float
                The displacement vector ``r_i - r_j`` for each pair.
            sq_dist : tensor, shape=(m), dtype=float
                The squared distance for each pair.

        Raises:
            ValueError: if the last dimension of `coord` is not 3.
        """
        if coord.shape[-1] != 3:
            raise ValueError(f"Expected coordinates with shape (..., n, 3), got {coord.shape}")

        cutoff_distance = force_field.cutoff_distance
        macro_shape = coord.shape[:-2]
        n = coord.shape[-2]
        adj_matrix_shape = list(macro_shape) + [n, n]
        if node_mask is None:
            node_mask = torch.ones(adj_matrix_shape[:-1], dtype=torch.long, device=device)

        if cutoff_distance is None:
            # Include all possible interactions
            adj_matrix = torch.ones(adj_matrix_shape, dtype=torch.bool, device=device)
        else:
            # TODO: check if this is correct
            dist_matrix = torch.cdist(coord, coord, p=2).reshape(adj_matrix_shape)
            sq_dist_matrix = dist_matrix**2
            adj_matrix = sq_dist_matrix <= cutoff_distance**2

        # Remove interactions of atoms with themselves (identity broadcasts over batch dims)
        adj_matrix = adj_matrix & ~torch.eye(n, dtype=torch.bool, device=device)
        # Keep only pairs where both nodes are valid (non-zero mask product)
        pair_valid = (node_mask.unsqueeze(-1) * node_mask.unsqueeze(-2)) != 0  # (..., n, n)
        adj_matrix = adj_matrix & pair_valid

        # Convert matrix to indices where interaction exists
        pair_indices = torch.stack(torch.where(adj_matrix), dim=0)  # (len(macro_shape) + 2, m)
        atom_i = pair_indices[:-1]  # (len(macro_shape) + 1, m)
        atom_j = torch.cat([pair_indices[:-2], pair_indices[-1:]], dim=0)  # (len(macro_shape) + 1, m)
        disp = coord[tuple(atom_i)] - coord[tuple(atom_j)]  # (m, 3)

        # Squared distances for distance-dependent force fields
        if cutoff_distance is None:
            sq_dist = torch.sum(disp * disp, dim=-1)
        else:
            sq_dist = sq_dist_matrix[tuple(pair_indices)]
        return pair_indices, disp, sq_dist

    def _patch_adjacency_matrix(self, matrix, contact_shutdown, contact_pair_off, contact_pair_on):
        """
        NOT USED

        Apply contacts that are artificially switched off/on to an adjacency matrix.
        The matrix is modified in-place.

        NOTE(review): the indexing below only addresses the atom axes when `matrix`
        is 2-D (n, n); it does not handle leading batch dimensions.

        Args
            matrix: tensor of shape (..., n, n), dtype=bool
            contact_shutdown: tensor of shape (..., n), dtype=int
            contact_pair_off: tensor of shape (..., m, 2), dtype=int
            contact_pair_on: tensor of shape (..., m, 2), dtype=int

        Raises:
            ValueError: if `contact_pair_on` contains a self-pair.
        """
        if contact_shutdown is not None:
            matrix[:, contact_shutdown] = False
            matrix[contact_shutdown, :] = False
        if contact_pair_off is not None:
            atom_i, atom_j = contact_pair_off.T
            matrix[atom_i, atom_j] = False
            matrix[atom_j, atom_i] = False
        if contact_pair_on is not None:
            atom_i, atom_j = contact_pair_on.T
            if (atom_i == atom_j).any():
                raise ValueError("Cannot turn on interaction of an atom with itself")
            matrix[atom_i, atom_j] = True
            matrix[atom_j, atom_i] = True