Source code for ls_mlkit.util.nma.nma

from typing import Optional

import torch
from torch import Tensor

from .anm import ANM
from .force_fields import HinsenForceField


def get_nma_displacement_from_node_coordinates(
    node_coordinates: Tensor,
    cutoff_distance: float = 10.0,
    indexes: Optional[list[int]] = None,
    node_mask: Optional[Tensor] = None,
) -> Tensor:
    """Return normal-mode displacements for a set of nodes.

    Builds a ``HinsenForceField`` with the given cutoff, wraps the nodes in an
    ``ANM`` model (no explicit masses, on the coordinates' device) and returns
    the result of ``ANM.get_displacements_from_normal_modes`` for ``indexes``.

    Args:
        node_coordinates: Node positions, shape ``(..., n, 3)``.
        cutoff_distance: Distance cutoff forwarded to ``HinsenForceField``.
        indexes: Indices of the normal modes to return. ``None`` (the default)
            means ``[6]``, matching the previous default. NOTE(review): index 6
            is presumably the first mode after the six rigid-body modes —
            confirm against ``ANM``.
        node_mask: Optional node validity mask, shape ``(..., n)``. ``None`` is
            forwarded to ``ANM`` unchanged.

    Returns:
        The displacements produced by ``ANM.get_displacements_from_normal_modes``.
    """
    # Use a fresh list per call instead of a shared mutable default argument.
    # A copy of the caller's list is passed, so ANM cannot mutate the caller's object.
    indexes = [6] if indexes is None else list(indexes)

    force_field = HinsenForceField(cutoff_distance=cutoff_distance)
    anm = ANM(
        atoms=node_coordinates,
        force_field=force_field,
        masses=None,
        device=node_coordinates.device,
        node_mask=node_mask,
    )
    return anm.get_displacements_from_normal_modes(indexes=indexes)


def get_nma_displacement_from_protein_ligand_complex(
    protein_ca_coordinates: Tensor,
    ligand_center_of_mass: Tensor,
    cutoff_distance: float = 10.0,
    indexes: Optional[list[int]] = None,
    protein_mask: Optional[Tensor] = None,
) -> Tensor:
    """Return normal-mode displacements for a protein-ligand complex.

    The ligand is represented by a single node at its center of mass, appended
    after the protein nodes. The combined ``n + 1`` nodes are passed to
    ``get_nma_displacement_from_node_coordinates``.

    Args:
        protein_ca_coordinates: Protein node positions, shape ``(..., n, 3)``.
        ligand_center_of_mass: Ligand center of mass, shape ``(..., 3)``.
        cutoff_distance: Distance cutoff forwarded to the force field.
        indexes: Indices of the normal modes to return. ``None`` (the default)
            means ``[6]``, matching the previous default.
        protein_mask: Optional protein node mask, shape ``(..., n)``. When given,
            the ligand node is appended as always-valid (a value of one in the
            mask's dtype). When ``None``, no mask is passed downstream.
            NOTE(review): this presumably makes ANM treat every node as valid —
            confirm.

    Returns:
        Displacements for the ``n + 1`` combined nodes. Per the original
        documentation, the shape is ``(..., k, n + 1, 3)``, or
        ``(..., n + 1, 3)`` when a single mode is requested (``k`` = number of
        modes). The last node entry presumably corresponds to the ligand —
        assumes ANM preserves node order; confirm.
    """
    # Use a fresh list per call instead of a shared mutable default argument.
    indexes = [6] if indexes is None else list(indexes)

    # The ligand becomes one extra node, placed after the protein nodes.
    ligand_node = ligand_center_of_mass.unsqueeze(-2)  # (..., 1, 3)
    node_coordinates = torch.cat(
        [protein_ca_coordinates, ligand_node], dim=-2
    )  # (..., n + 1, 3)

    if protein_mask is None:
        # Mirror the node-level API default: no mask at all.
        node_mask = None
    else:
        # Derive the ligand mask from protein_mask so that batch dims, dtype and
        # device all match (the previous code passed a Python list to
        # torch.ones_like and built a (..., 1, 1) shape).
        ligand_mask = protein_mask.new_ones((*protein_mask.shape[:-1], 1))  # (..., 1)
        node_mask = torch.cat([protein_mask, ligand_mask], dim=-1)  # (..., n + 1)

    return get_nma_displacement_from_node_coordinates(
        node_coordinates=node_coordinates,
        cutoff_distance=cutoff_distance,
        indexes=indexes,
        node_mask=node_mask,
    )