# Source code for ls_mlkit.util.nma.force_fields

import abc

import torch
from torch import Tensor


class ForceField(metaclass=abc.ABCMeta):
    r"""
    Abstract base class for the force fields of an *elastic network model*.

    A concrete subclass decides the force constant of the spring that links
    two atoms of the model. ``...`` stands for an arbitrary number of leading
    (macro) dimensions, for example a batch dimension.

    Dimensions:
        n: the number of atoms
        m: the number of edges (interacting atom pairs)

    Properties (every one of them returns ``None`` in this base class;
    subclasses override the ones they need):
        cutoff_distance : float or None
            Two atoms only interact if the distance between them is smaller
            than or equal to this value. ``None`` means that the interaction
            between all atoms is considered.
        natoms : [...] or None
            The number of atoms in the model. ``None`` when the force field
            does not depend on the individual atoms, i.e. when ``atom_i`` and
            ``atom_j`` are unused in :meth:`force_constant`.
        contact_shutdown : Tensor, shape=(..., n), dtype=float, optional
            Indices pointing to atoms whose contacts to all other atoms are
            artificially switched off. ``None`` means no contacts are switched
            off. NOTE(review): a float dtype for indices looks inconsistent —
            confirm whether this is an index tensor or a mask.
        contact_pair_off : Tensor, shape=(..., m, 2), dtype=int, optional
            Indices pointing to atom pairs whose contacts are artificially
            switched off. ``None`` means no contacts are switched off.
        contact_pair_on : Tensor, shape=(..., m, 2), dtype=int, optional
            Indices pointing to atom pairs whose contacts are established in
            any case. ``None`` means no contacts are artificially switched on.
    """

    @abc.abstractmethod
    def force_constant(self, atom_i: Tensor, atom_j: Tensor, sq_distance: Tensor):
        """
        Return the force constant for each given interacting atom pair.

        ABSTRACT: subclasses must override this method.

        Parameters:
            atom_i, atom_j : Tensor, shape=(len(...) + 2, m), len(...)=macro_shape, dtype=int
                Indices of the first and the second atom of each interacting
                atom pair.
            sq_distance : Tensor, shape=(m), dtype=float
                The *squared* distance between the atoms indicated by
                ``atom_i`` and ``atom_j``.

        Notes:
            Implementations do not need to check the cutoff distance of the
            :class:`ForceField` themselves: the pairs they receive are already
            limited to those within ``cutoff_distance`` of each other. If
            ``cutoff_distance`` is ``None``, however, the atom indices cover
            the Cartesian product of all atom indices, i.e. every possible
            combination.
        """

    @property
    def cutoff_distance(self):
        """Interaction cutoff; the base class imposes none (``None``)."""
        return None

    @property
    def natoms(self):
        """Number of atoms; ``None`` in the base class."""
        return None

    @property
    def contact_shutdown(self):
        """Atoms whose contacts are switched off; ``None`` in the base class."""
        return None

    @property
    def contact_pair_off(self):
        """Atom pairs whose contacts are switched off; ``None`` in the base class."""
        return None

    @property
    def contact_pair_on(self):
        """Atom pairs whose contacts are forced on; ``None`` in the base class."""
        return None
class InvariantForceField(ForceField):
    """
    Force field that treats every interaction with the same force constant.

    Every atom pair within the cutoff distance gets a force constant of ``1``.

    Parameters:
        cutoff_distance : float
            The interaction of two atoms is only considered if the distance
            between them is smaller than or equal to this value.

    Raises:
        ValueError: If ``cutoff_distance`` is ``None``.
    """

    def __init__(self, cutoff_distance: float):
        if cutoff_distance is None:
            # A value of 'None' would give a fully connected network
            # with equal force constants for each connection,
            # which is unreasonable
            raise ValueError("Cutoff distance must be a float")
        self._cutoff_distance = cutoff_distance

    def force_constant(self, atom_i: Tensor, atom_j: Tensor = None, sq_distance: Tensor = None):
        """
        Return a force constant of ``1`` for every interacting atom pair.

        Args:
            atom_i: Tensor, shape=(len(...) + 2, m), len(...)=macro_shape, dtype=int
                Only its last dimension (the number of edges ``m``) is used,
                plus its device when ``sq_distance`` is not given.
            atom_j: Tensor, shape=(len(...) + 2, m), len(...)=macro_shape, dtype=int
                Unused.
            sq_distance: Tensor, shape=(m), dtype=float, optional
                Unused except for its device and floating dtype, which the
                result adopts.

        Returns:
            Tensor of shape ``(m,)`` filled with ones.
        """
        n_edges = atom_i.shape[-1]  # (m)
        # A bare ``torch.ones(n_edges)`` always lands on the CPU with the
        # default dtype, which breaks (device mismatch) or silently changes
        # precision when the inputs live on a GPU or use another float type.
        if sq_distance is not None and sq_distance.is_floating_point():
            dtype, device = sq_distance.dtype, sq_distance.device
        else:
            dtype, device = torch.get_default_dtype(), atom_i.device
        return torch.ones(n_edges, dtype=dtype, device=device)

    @property
    def cutoff_distance(self):
        """float: The cutoff distance given to the constructor."""
        return self._cutoff_distance
class HinsenForceField(ForceField):
    """
    The Hinsen force field was parametrized using the *Amber94* force field
    for a local energy minimum, with crambin as template.

    In a strict distance-dependent manner, contacts are subdivided into
    nearest-neighbour pairs along the backbone (r < 4 Å) and mid-/far-range
    pair interactions (r >= 4 Å). Force constants for these interactions are
    computed with two distinct formulas. 2.9 Å is the lowest accepted distance
    between ``CA`` atoms; values below that threshold are set to 2.9 Å.

    Parameters:
        cutoff_distance : float, optional
            The interaction of two atoms is only considered if the distance
            between them is smaller than or equal to this value. By default
            all interactions are included.
    """

    def __init__(self, cutoff_distance: float = None):
        self._cutoff_distance = cutoff_distance

    def force_constant(self, atom_i: Tensor, atom_j: Tensor, sq_distance: Tensor):
        """
        Calculate force constants using the Hinsen force field parameters.

        ``k = 860 * r - 2390`` for ``r < 4`` and ``k = 1.28e6 * r**-6``
        otherwise, where ``r`` is the distance clipped from below at 2.9.

        Args:
            atom_i: Tensor, indices of first atoms (unused)
            atom_j: Tensor, indices of second atoms (unused)
            sq_distance: Tensor, squared distances between atom pairs

        Returns:
            Tensor: Force constants for each atom pair, same shape as
            ``sq_distance``.
        """
        min_distance = 2.9
        # Clamp the *squared* distance before taking the root. sqrt has an
        # infinite derivative at 0, and zero distances can occur (e.g.
        # self-pairs when cutoff_distance is None and all atom combinations
        # are passed). Clipping only after sqrt yields 0/0 = NaN gradients.
        distance = torch.sqrt(torch.clamp(sq_distance, min=min_distance**2))
        # Re-clamp in case sqrt(2.9**2) rounds just below 2.9, so forward
        # values stay identical to clipping the distance itself.
        distance = torch.clamp(distance, min=min_distance)
        return torch.where(
            distance < 4.0,
            distance * 8.6e2 - 2.39e3,
            distance ** (-6) * 128e4,
        )

    @property
    def cutoff_distance(self):
        """float or None: The cutoff distance given to the constructor."""
        return self._cutoff_distance