Skip to content

Commit

Permalink
Merge pull request #1122 from openforcefield/cache-parameter-lookups
Browse files Browse the repository at this point in the history
Improve performance of `Interchange.from_smirnoff` on polymers
  • Loading branch information
mattwthompson authored Jan 16, 2025
2 parents 9460db5 + adf40f3 commit b623003
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 32 deletions.
1 change: 1 addition & 0 deletions docs/releasehistory.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Please note that all releases prior to a version 1.0.0 are considered pre-releas
### Performance improvements

* #1097 Migrates version handling to `versioningit`, which should result in shorter import times.
* #1122 Improves performance of Interchange.from_smirnoff on polymers.

### Documentation improvements

Expand Down
57 changes: 56 additions & 1 deletion openff/interchange/components/toolkit.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,35 @@
"""Utilities for processing and interfacing with the OpenFF Toolkit."""

from functools import lru_cache
from typing import TYPE_CHECKING, Union

import networkx
import numpy
from openff.toolkit import ForceField, Molecule, Quantity, Topology
from openff.toolkit.topology._mm_molecule import _SimpleMolecule
from openff.toolkit.typing.engines.smirnoff.parameters import VirtualSiteHandler
from openff.toolkit.typing.engines.smirnoff.parameters import ParameterHandler, VirtualSiteHandler
from openff.toolkit.utils.collections import ValidatedList
from openff.utilities.utilities import has_package

from openff.interchange.models import (
PotentialKey,
)

if has_package("openmm") or TYPE_CHECKING:
import openmm.app


_IDIVF_1 = Quantity(1.0, "dimensionless")
_PERIODICITIES = {
1: Quantity(1, "dimensionless"),
2: Quantity(2, "dimensionless"),
3: Quantity(3, "dimensionless"),
4: Quantity(4, "dimensionless"),
5: Quantity(5, "dimensionless"),
6: Quantity(6, "dimensionless"),
}


def _get_num_h_bonds(topology: "Topology") -> int:
"""Get the number of (covalent) bonds containing a hydrogen atom."""
n_bonds_containing_hydrogen = 0
Expand Down Expand Up @@ -202,3 +218,42 @@ def _lookup_virtual_site_parameter(
raise ValueError(
f"No VirtualSiteType found with {smirks=}, name={name=}, and match={match=}.",
)


@lru_cache
def _cache_angle_parameter_lookup(
potential_key: PotentialKey,
parameter_handler: ParameterHandler,
) -> dict[str, Quantity]:
parameter = parameter_handler.parameters[potential_key.id]

return {parameter_name: getattr(parameter, parameter_name) for parameter_name in ["k", "angle"]}


@lru_cache
def _cache_torsion_parameter_lookup(
potential_key: PotentialKey,
parameter_handler: ParameterHandler,
idivf: float | None = None,
) -> dict[str, Quantity]:
smirks = potential_key.id
n = potential_key.mult
parameter = parameter_handler.parameters[smirks]

if idivf is not None:
# case of non-standard default_idivf in impropers
_idivf = idivf
elif parameter.idivf is None:
# This appears to only come from imports
_idivf = _IDIVF_1
elif parameter.idivf[n] == 1.0:
_idivf = _IDIVF_1
else:
_idivf = Quantity(parameter.idivf[n], "dimensionless")

return {
"k": parameter.k[n],
"periodicity": _PERIODICITIES[parameter.periodicity[n]],
"phase": parameter.phase[n],
"idivf": _idivf,
}
4 changes: 1 addition & 3 deletions openff/interchange/interop/openmm/_import/_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,6 @@ def from_openmm(
def _convert_constraints(
system: "openmm.System",
) -> ConstraintCollection | None:
from openff.toolkit import unit

from openff.interchange.components.potentials import Potential
from openff.interchange.models import BondKey, PotentialKey

Expand Down Expand Up @@ -164,7 +162,7 @@ def _convert_constraints(
potential_key = PotentialKey(id=f"Constraint{index}")
_keys[distance] = potential_key
constraints.potentials[potential_key] = Potential(
parameters={"distance": distance * unit.nanometer},
parameters={"distance": Quantity(distance, "nanometer")},
)

for index in range(system.getNumConstraints()):
Expand Down
14 changes: 11 additions & 3 deletions openff/interchange/smirnoff/_nonbonded.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,17 @@
LibraryChargeHandler,
]

_ZERO_CHARGE = Quantity(0.0, unit.elementary_charge)


@functools.lru_cache(None)
def _add_charges(
charge1: float,
charge2: float,
) -> "Quantity":
"""Add two charges together."""
return Quantity(charge1 + charge2, "elementary_charge")


def _upconvert_vdw_handler(vdw_handler: vdWHandler):
"""Given a vdW with version 0.3 or 0.4, up-convert to 0.4 or short-circuit if already 0.4."""
Expand Down Expand Up @@ -362,7 +373,6 @@ def _get_charges(
# the "charge" and "charge_increment" keys may not appear in that order, so
# we "add" the charge whether or not the increment was already applied.
# There should be a better way to do this.

charges[atom_index] = charges.get(atom_index, 0.0) + parameter_value.m

else:
Expand Down Expand Up @@ -781,7 +791,6 @@ def _find_reference_matches(
@classmethod
def _assign_charges_from_molecules(
cls,
topology: Topology,
unique_molecule: Molecule,
molecules_with_preset_charges=list[Molecule] | None,
) -> tuple[bool, dict, dict]:
Expand Down Expand Up @@ -851,7 +860,6 @@ def store_matches(
unique_molecule = topology.molecule(unique_molecule_index)

flag, matches, potentials = self._assign_charges_from_molecules(
topology,
unique_molecule,
molecules_with_preset_charges,
)
Expand Down
47 changes: 22 additions & 25 deletions openff/interchange/smirnoff/_valence.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
ProperTorsionCollection,
)
from openff.interchange.components.potentials import Potential, WrappedPotential
from openff.interchange.components.toolkit import _cache_angle_parameter_lookup, _cache_torsion_parameter_lookup
from openff.interchange.exceptions import (
DuplicateMoleculeError,
InvalidParameterHandlerError,
Expand Down Expand Up @@ -448,15 +449,16 @@ def store_potentials(self, parameter_handler: AngleHandler) -> None:
"""
for potential_key in self.key_map.values():
smirks = potential_key.id
parameter = parameter_handler.parameters[smirks]
potential = Potential(
parameters={
parameter_name: getattr(parameter, parameter_name)
for parameter_name in self.potential_parameters()
self.potentials.update(
{
potential_key: Potential(
parameters=_cache_angle_parameter_lookup(
potential_key,
parameter_handler,
),
),
},
)
self.potentials[potential_key] = potential


class SMIRNOFFProperTorsionCollection(SMIRNOFFCollection, ProperTorsionCollection):
Expand Down Expand Up @@ -550,11 +552,11 @@ def store_potentials(self, parameter_handler: ProperTorsionHandler) -> None:
"""
for topology_key, potential_key in self.key_map.items():
smirks = potential_key.id
n = potential_key.mult
parameter = parameter_handler.parameters[smirks]

if topology_key.bond_order: # type: ignore[union-attr]
smirks = potential_key.id
n = potential_key.mult
parameter = parameter_handler.parameters[smirks]

bond_order = topology_key.bond_order # type: ignore[union-attr]
data = parameter.k_bondorder[n]
coeffs = _get_interpolation_coeffs(
Expand All @@ -580,12 +582,7 @@ def store_potentials(self, parameter_handler: ProperTorsionHandler) -> None:
{pot: coeff for pot, coeff in zip(pots, coeffs)},
)
else:
parameters = {
"k": parameter.k[n],
"periodicity": parameter.periodicity[n] * unit.dimensionless,
"phase": parameter.phase[n],
"idivf": parameter.idivf[n] * unit.dimensionless,
}
parameters = _cache_torsion_parameter_lookup(potential_key, parameter_handler)
potential = Potential(parameters=parameters) # type: ignore[assignment]
self.potentials[potential_key] = potential

Expand Down Expand Up @@ -734,11 +731,11 @@ def store_potentials(self, parameter_handler: ImproperTorsionHandler) -> None:
# Assumed to be a numerical value
idivf = _default_idivf * unit.dimensionless

parameters = {
"k": parameter.k[n],
"periodicity": parameter.periodicity[n] * unit.dimensionless,
"phase": parameter.phase[n],
"idivf": idivf,
}
potential = Potential(parameters=parameters)
self.potentials[potential_key] = potential
# parameter keys happen to be the same as keys in proper torsions
self.potentials[potential_key] = Potential(
parameters=_cache_torsion_parameter_lookup(
potential_key,
parameter_handler,
idivf=idivf,
),
)

0 comments on commit b623003

Please sign in to comment.