diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5231f992..781c3ddb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -45,9 +45,10 @@ repos: hooks: - id: isort name: isort (python) - args: ["--profile", "black", "--filter-files"] + args: ["--profile", "black", "--line-length", "80", "--filter-files"] - repo: https://github.com/psf/black rev: 24.10.0 hooks: - id: black + args: ["--line-length", "80"] diff --git a/setup.cfg b/setup.cfg index 872b81d9..fdc5b219 100644 --- a/setup.cfg +++ b/setup.cfg @@ -47,7 +47,7 @@ install_requires = scipy tad-dftd3>=0.3.0 tad-dftd4>=0.2.0 - tad-mctc>=0.2.0 + tad-mctc>=0.3.0 tad-multicharge tomli tomli-w diff --git a/src/dxtb/_src/basis/slater.py b/src/dxtb/_src/basis/slater.py index b5799e2a..9cf03bb9 100644 --- a/src/dxtb/_src/basis/slater.py +++ b/src/dxtb/_src/basis/slater.py @@ -125,7 +125,7 @@ def slater_to_gauss( # <φ|φ> = (2i-1)!!(2j-1)!!(2k-1)!!/(4α)^(i+j+k) · sqrt(π/2α)³ # N² = (4α)^(i+j+k)/((2i-1)!!(2j-1)!!(2k-1)!!) · sqrt(2α/π)³ # N = (4α)^((i+j+k)/2) / sqrt((2i-1)!!(2j-1)!!(2k-1)!!) · (2α/π)^(3/4) - if norm: + if norm is True: coeff = coeff * ( (top * alpha) ** 0.75 * torch.sqrt(4 * alpha) ** l diff --git a/src/dxtb/_src/calculators/gfn1.py b/src/dxtb/_src/calculators/gfn1.py index 023ccaa6..7376db82 100644 --- a/src/dxtb/_src/calculators/gfn1.py +++ b/src/dxtb/_src/calculators/gfn1.py @@ -57,6 +57,7 @@ def __init__( # pylint: disable=import-outside-toplevel from dxtb import GFN1_XTB + # constructor can be found in src/dxtb/_src/calculators/types/base.py super().__init__( numbers, GFN1_XTB, diff --git a/src/dxtb/_src/calculators/gfn2.py b/src/dxtb/_src/calculators/gfn2.py index b8f0faf3..936caeb9 100644 --- a/src/dxtb/_src/calculators/gfn2.py +++ b/src/dxtb/_src/calculators/gfn2.py @@ -57,6 +57,7 @@ def __init__( # pylint: disable=import-outside-toplevel from dxtb import GFN2_XTB + # constructor can be found in src/dxtb/_src/calculators/types/base.py super().__init__( numbers, GFN2_XTB, @@ -66,5 +67,3 @@ def __init__( device=device, dtype=dtype, ) - - raise NotImplementedError("GFN2-xTB is not yet implemented.") diff --git a/src/dxtb/_src/components/classicals/list.py b/src/dxtb/_src/components/classicals/list.py index 7030c80b..888f9bda 100644 --- a/src/dxtb/_src/components/classicals/list.py +++ b/src/dxtb/_src/components/classicals/list.py @@ -157,16 +157,10 @@ def get_gradient( ########################################################################### @overload - def get_interaction( - self, - name: Literal["Halogen"], - ) -> Halogen: ... + def get_interaction(self, name: Literal["Halogen"]) -> Halogen: ... @overload - def get_interaction( - self, - name: Literal["Repulsion"], - ) -> Repulsion: ... + def get_interaction(self, name: Literal["Repulsion"]) -> Repulsion: ... @override # generic implementation for typing def get_interaction(self, name: str) -> Classical: diff --git a/src/dxtb/_src/components/classicals/repulsion/factory.py b/src/dxtb/_src/components/classicals/repulsion/factory.py index 2bd53fe8..a18921fc 100644 --- a/src/dxtb/_src/components/classicals/repulsion/factory.py +++ b/src/dxtb/_src/components/classicals/repulsion/factory.py @@ -74,7 +74,9 @@ def new_repulsion( """ if hasattr(par, "repulsion") is False or par.repulsion is None: - # TODO: Repulsion is used in all models, so error or just warning? + # Although repulsion is used in all models, we do not want to exit + # for custom models that are loaded from a parameter file. Hence, we + # only issue a warning here, not an error. warnings.warn("No repulsion scheme found.", ParameterWarning) return None diff --git a/src/dxtb/_src/components/interactions/coulomb/thirdorder.py b/src/dxtb/_src/components/interactions/coulomb/thirdorder.py index 82443324..787bf7a0 100644 --- a/src/dxtb/_src/components/interactions/coulomb/thirdorder.py +++ b/src/dxtb/_src/components/interactions/coulomb/thirdorder.py @@ -188,9 +188,9 @@ def get_cache( Note ---- If the :class:`.ES3` interaction is evaluated within the - :class:`dxtb.components.InteractionList`, ``positions`` will be passed - as an argument, too. Hence, it is necessary in signature - of the function to absorb it (also see + :class:`dxtb.components.InteractionList`, ``positions`` will be + passed as an argument, too. Hence, it is necessary to absorb + the ``positions`` in the signature of the function (also see :meth:`dxtb.components.Interaction.get_cache`). """ if numbers is None: @@ -297,7 +297,7 @@ def new_es3( if device is not None: if device != numbers.device: raise DeviceError( - f"Passed device ({device}) and device of electric field " + f"Passed device ({device}) and device of `numbers` tensor " f"({numbers.device}) do not match." ) diff --git a/src/dxtb/_src/components/interactions/solvation/alpb.py b/src/dxtb/_src/components/interactions/solvation/alpb.py index e3e0a3ff..33cdf68e 100644 --- a/src/dxtb/_src/components/interactions/solvation/alpb.py +++ b/src/dxtb/_src/components/interactions/solvation/alpb.py @@ -270,10 +270,11 @@ def get_cache( Note ---- - If the :class:`.GeneralizedBorn` interaction is evaluated within the - :class:`dxtb.components.InteractionList`, the :class:`dxtb.IndexHelper` - will be passed as an argument, too. Hence, it is necessary in signature - of the function to absorb it. + If the :class:`.GeneralizedBorn` interaction is evaluated + within the :class:`dxtb.components.InteractionList`, the + :class:`dxtb.IndexHelper` will be passed as an argument, too. Hence, + it is necessary to absorb the ``positions`` in the signature of the + function. """ if numbers is None: raise ValueError("Atomic numbers are required for cache.") diff --git a/src/dxtb/_src/components/list.py b/src/dxtb/_src/components/list.py index 59efa68b..1e3b16dc 100644 --- a/src/dxtb/_src/components/list.py +++ b/src/dxtb/_src/components/list.py @@ -75,10 +75,10 @@ def restore(self) -> None: """ pass - def __str__(self) -> str: + def __str__(self) -> str: # pragma: no cover return f"{self.__class__.__name__}({list(self.keys())})" - def __repr__(self) -> str: + def __repr__(self) -> str: # pragma: no cover return str(self) @@ -238,7 +238,8 @@ def reset_all(self) -> None: @property def labels(self) -> list[str]: - return [component.label for component in self.components] + """Alphabetically sorted list of all components labels.""" + return sorted([component.label for component in self.components]) def get_interaction(self, name: str) -> C: """ @@ -265,10 +266,13 @@ def get_interaction(self, name: str) -> C: raise ValueError(f"The component named '{name}' is not in the list.") - def __str__(self) -> str: + def __len__(self) -> int: + return len(self.components) + + def __str__(self) -> str: # pragma: no cover return f"{self.__class__.__name__}({self.labels})" - def __repr__(self) -> str: + def __repr__(self) -> str: # pragma: no cover return str(self) @override diff --git a/src/dxtb/_src/integral/container.py b/src/dxtb/_src/integral/container.py index 24b85e53..693c059c 100644 --- a/src/dxtb/_src/integral/container.py +++ b/src/dxtb/_src/integral/container.py @@ -482,6 +482,23 @@ def reset_all(self) -> None: if isinstance(i, BaseIntegral) or isinstance(i, BaseHamiltonian): i.clear() + @property + def labels(self) -> list[str]: + """Return all initialized integrals by label.""" + return [ + slot[1:] + for slot in self.__slots__ + if slot.startswith("_") and getattr(self, slot) is not None + ] + + def __len__(self) -> int: + """Print all initialized integrals.""" + return sum( + 1 + for slot in self.__slots__ + if slot.startswith("_") and getattr(self, slot) is not None + ) + # pretty print def __str__(self) -> str: diff --git a/src/dxtb/_src/xtb/abc.py b/src/dxtb/_src/xtb/abc.py index d261d75e..83f88876 100644 --- a/src/dxtb/_src/xtb/abc.py +++ b/src/dxtb/_src/xtb/abc.py @@ -36,6 +36,28 @@ class HamiltonianABC(ABC): Abstract base class for Hamiltonians. """ + @abstractmethod + def _get_hscale(self) -> Tensor: + """ + Obtain the off-site scaling factor for the Hamiltonian. + + Returns + ------- + Tensor + Off-site scaling factor for the Hamiltonian. + """ + + @abstractmethod + def _get_elem_valence(self) -> Tensor: + """ + Obtain a mask for valence and non-valence shells. This is only required for GFN1-xTB's second hydrogen s-function. + + Returns + ------- + Tensor + Mask indicating valence shells for each unique species. + """ + @abstractmethod def build(self, positions: Tensor, overlap: Tensor | None = None) -> Tensor: """ diff --git a/src/dxtb/_src/xtb/base.py b/src/dxtb/_src/xtb/base.py index c982c416..9473c196 100644 --- a/src/dxtb/_src/xtb/base.py +++ b/src/dxtb/_src/xtb/base.py @@ -24,6 +24,11 @@ from __future__ import annotations import torch +from tad_mctc import storch +from tad_mctc.batch import real_pairs +from tad_mctc.convert import symmetrize +from tad_mctc.data.radii import ATOMIC as ATOMIC_RADII +from tad_mctc.units import EV2AU from dxtb import IndexHelper from dxtb._src.param import Param @@ -33,6 +38,8 @@ __all__ = ["BaseHamiltonian"] +PAD = -1 + class BaseHamiltonian(HamiltonianABC, TensorLike): """ @@ -68,14 +75,17 @@ class BaseHamiltonian(HamiltonianABC, TensorLike): shpoly: Tensor """Polynomial parameters for the distant dependent scaling.""" valence: Tensor - """Whether the shell belongs to the valence shell.""" + """ + Whether the shell belongs to the valence shell. + Only requried for GFN1-xTB (second s-function for H). + """ en: Tensor """Pauling electronegativity of each species.""" rad: Tensor """Van-der-Waals radius of each species.""" - cn: None | CNFunction + cn: CNFunction | None """Coordination number function.""" __slots__ = [ @@ -117,6 +127,62 @@ def __init__( self.label = self.__class__.__name__ self._matrix = None + # Initialize Hamiltonian parameters + + if self.par.hamiltonian is None: + raise RuntimeError("Parametrization does not specify Hamiltonian.") + + # atom-resolved parameters + self.rad = ATOMIC_RADII.to(**self.dd)[self.unique] + self.en = self._get_elem_param("en") + + # shell-resolved element parameters + self.kcn = self._get_elem_param("kcn") + self.selfenergy = self._get_elem_param("levels") + self.shpoly = self._get_elem_param("shpoly") + self.refocc = self._get_elem_param("refocc") + self.valence = self._get_elem_valence() + + # shell-pair-resolved pair parameters + self.hscale = self._get_hscale() + self.kpair = self._get_pair_param(self.par.hamiltonian.xtb.kpair) + + # unit conversion + self.selfenergy = self.selfenergy * EV2AU + self.kcn = self.kcn * EV2AU + # dtype should always be correct as it always uses self.dtype + if any( + tensor.dtype != self.dtype + for tensor in ( + self.hscale, + self.kcn, + self.kpair, + self.refocc, + self.selfenergy, + self.shpoly, + self.en, + self.rad, + ) + ): # pragma: no cover + raise ValueError("All tensors must have same dtype") + + # device should always be correct as it always uses self.device + if any( + tensor.device != self.device + for tensor in ( + self.hscale, + self.kcn, + self.kpair, + self.refocc, + self.selfenergy, + self.shpoly, + self.valence, + self.en, + self.rad, + ) + ): # pragma: no cover + raise ValueError("All tensors must be on the same device") + @property def matrix(self) -> Tensor | None: return self._matrix @@ -138,6 +204,59 @@ def requires_grad(self) -> bool: return self._matrix.requires_grad + def _get_elem_param(self, key: str) -> Tensor: + """ + Obtain element parameters for species. + + Parameters + ---------- + key : str + Name of the parameter to be retrieved. + + Returns + ------- + Tensor + Parameters for each species. + """ + # pylint: disable=import-outside-toplevel + from dxtb._src.param import get_elem_param + + return get_elem_param( + self.unique, self.par.element, key, pad_val=PAD, **self.dd + ) + + def _get_elem_valence(self) -> Tensor: + """ + Obtain a mask for valence and non-valence shells. This is only required for GFN1-xTB's second hydrogen s-function. For GFN2-xTB, this is a dummy method, i.e., the mask is always ``True``. + + Returns + ------- + Tensor + Mask indicating valence shells for each unique species. + """ + return torch.ones( + len(self.ihelp.unique_angular), device=self.device, dtype=torch.bool + ) + + def _get_pair_param(self, pair: dict[str, float]) -> Tensor: + """ + Obtain element-pair-specific parameters for all species. + + Parameters + ---------- + pair : dict[str, float] + Pair parametrization. + + Returns + ------- + Tensor + Pair parameters for each species. + """ + # pylint: disable=import-outside-toplevel + from dxtb._src.param import get_pair_param + + return get_pair_param(self.unique.tolist(), pair, **self.dd) + def get_occupation(self) -> Tensor: """ Obtain the reference occupation numbers for each orbital. @@ -168,3 +287,117 @@ def to_pt(self, path: PathLike | None = None) -> None: path = f"{self.label.casefold()}.pt" torch.save(self.matrix, path) + + def build(self, positions: Tensor, overlap: Tensor | None = None) -> Tensor: + """ + Build the xTB Hamiltonian. + + Parameters + ---------- + positions : Tensor + Atomic positions of molecular structure. + overlap : Tensor | None, optional + Overlap matrix. If ``None``, the true xTB Hamiltonian is *not* + built. Defaults to ``None``. + + Returns + ------- + Tensor + Hamiltonian (always symmetric). + """ + if self.par.hamiltonian is None: + raise RuntimeError("No Hamiltonian specified.") + + # masks + mask_atom_diagonal = real_pairs(self.numbers, mask_diagonal=True) + mask_shell = real_pairs( + self.ihelp.spread_atom_to_shell(self.numbers), mask_diagonal=False + ) + mask_shell_diagonal = self.ihelp.spread_atom_to_shell( + mask_atom_diagonal, dim=(-2, -1) + ) + + zero = torch.tensor(0.0, **self.dd) + + # ---------------- + # Eq.29: H_(mu,mu) + # ---------------- + if self.cn is None: + cn = torch.zeros_like(self.numbers, **self.dd) + else: + cn = self.cn(self.numbers, positions) + + kcn = self.ihelp.spread_ushell_to_shell(self.kcn) + + # formula differs from paper to be consistent with GFN2 -> "kcn" adapted + selfenergy = self.ihelp.spread_ushell_to_shell( + self.selfenergy + ) - kcn * self.ihelp.spread_atom_to_shell(cn) + + # ---------------------- + # Eq.24: PI(R_AB, l, l') + # ---------------------- + distances = storch.cdist(positions, positions, p=2) + rad = self.ihelp.spread_uspecies_to_atom(self.rad) + + rr = storch.divide(distances, rad.unsqueeze(-1) + rad.unsqueeze(-2)) + rr_shell = self.ihelp.spread_atom_to_shell( + torch.where(mask_atom_diagonal, storch.sqrt(rr), zero), + (-2, -1), + ) + + shpoly = self.ihelp.spread_ushell_to_shell(self.shpoly) + var_pi = (1.0 + shpoly.unsqueeze(-1) * rr_shell) * ( + 1.0 + shpoly.unsqueeze(-2) * rr_shell + ) + + # -------------------- + # Eq.28: X(EN_A, EN_B) + # -------------------- + en = self.ihelp.spread_uspecies_to_shell(self.en) + var_x = torch.where( + mask_shell_diagonal, + 1.0 + + self.par.hamiltonian.xtb.enscale + * torch.pow(en.unsqueeze(-1) - en.unsqueeze(-2), 2.0), + zero, + ) + + # -------------------- + # Eq.23: K_{AB}^{l,l'} + # -------------------- + kpair = self.ihelp.spread_uspecies_to_shell(self.kpair, dim=(-2, -1)) + hscale = self.ihelp.spread_ushell_to_shell(self.hscale, dim=(-2, -1)) + valence = self.ihelp.spread_ushell_to_shell(self.valence) + + var_k = torch.where( + valence.unsqueeze(-1) * valence.unsqueeze(-2), + hscale * kpair * var_x, + hscale, + ) + + # ------------ + # Eq.23: H_EHT + # ------------ + var_h = torch.where( + mask_shell, + 0.5 * (selfenergy.unsqueeze(-1) + selfenergy.unsqueeze(-2)), + zero, + ) + + hcore = self.ihelp.spread_shell_to_orbital( + torch.where( + mask_shell_diagonal, + var_pi * var_k * var_h, # scale only off-diagonals + var_h, + ), + dim=(-2, -1), + ) + + if overlap is not None: + hcore = hcore * overlap + + # force symmetry to avoid problems through numerical errors + h0 = symmetrize(hcore) + self.matrix = h0 + return h0 diff --git a/src/dxtb/_src/xtb/gfn1.py b/src/dxtb/_src/xtb/gfn1.py index 5fdac428..3025befe 100644 --- a/src/dxtb/_src/xtb/gfn1.py +++ b/src/dxtb/_src/xtb/gfn1.py @@ -26,20 +26,13 @@ import torch from tad_mctc import storch from tad_mctc.batch import real_pairs -from tad_mctc.convert import symmetrize -from tad_mctc.data.radii import ATOMIC as ATOMIC_RADII -from tad_mctc.ncoord import cn_d3 -from tad_mctc.units import EV2AU from dxtb import IndexHelper from dxtb._src.components.interactions import Potential from dxtb._src.param import Param -from dxtb._src.typing import Any, Tensor +from dxtb._src.typing import Any, Tensor, override -from .base import BaseHamiltonian - -PAD = -1 -"""Value used for padding of tensors.""" +from .base import PAD, BaseHamiltonian __all__ = ["GFN1Hamiltonian"] @@ -58,122 +51,13 @@ def __init__( ) -> None: super().__init__(numbers, par, ihelp, device, dtype) - if self.par.hamiltonian is None: - raise RuntimeError("Parametrization does not specify Hamiltonian.") - - # atom-resolved parameters - self.rad = ATOMIC_RADII.to(**self.dd)[self.unique] - self.en = self._get_elem_param("en") - - # shell-resolved element parameters - self.kcn = self._get_elem_param("kcn") - self.selfenergy = self._get_elem_param("levels") - self.shpoly = self._get_elem_param("shpoly") - self.refocc = self._get_elem_param("refocc") - self.valence = self._get_elem_valence() - - # shell-pair-resolved pair parameters - self.hscale = self._get_hscale() - self.kpair = self._get_pair_param(self.par.hamiltonian.xtb.kpair) - - # unit conversion - self.selfenergy = self.selfenergy * EV2AU - self.kcn = self.kcn * EV2AU - # coordination number function - self.cn = kwargs.pop("cn", cn_d3) - - # dtype should always be correct as it always uses self.dtype - if any( - tensor.dtype != self.dtype - for tensor in ( - self.hscale, - self.kcn, - self.kpair, - self.refocc, - self.selfenergy, - self.shpoly, - self.en, - self.rad, - ) - ): # pragma: no cover - raise ValueError("All tensors must have same dtype") - - # device should always be correct as it always uses self.device - if any( - tensor.device != self.device - for tensor in ( - self.hscale, - self.kcn, - self.kpair, - self.refocc, - self.selfenergy, - self.shpoly, - self.valence, - self.en, - self.rad, - ) - ): # pragma: no cover - raise ValueError("All tensors must be on the same device") - - def _get_elem_param(self, key: str) -> Tensor: - """ - Obtain element parameters for species. - - Parameters - ---------- - key : str - Name of the parameter to be retrieved. - - Returns - ------- - Tensor - Parameters for each species. - """ - # pylint: disable=import-outside-toplevel - from dxtb._src.param import get_elem_param - - return get_elem_param( - self.unique, self.par.element, key, pad_val=PAD, **self.dd - ) - - def _get_elem_valence(self) -> Tensor: - """ - Obtain "valence" parameters for shells of species. - - Returns - ------- - Tensor - Valence parameters for each species. - """ - # pylint: disable=import-outside-toplevel - from dxtb._src.param import get_elem_valence - - return get_elem_valence( - self.unique, - self.par.element, - pad_val=PAD, - device=self.device, - ) - - def _get_pair_param(self, pair: dict[str, float]) -> Tensor: - """ - Obtain element-pair-specific parameters for all species. - - Parameters - ---------- - pair : dict[str, float] - Pair parametrization. - - Returns - ------- - Tensor - Pair parameters for each species. - """ - # pylint: disable=import-outside-toplevel - from dxtb._src.param import get_pair_param + if "cn" in kwargs: + self.cn = kwargs.pop("cn") + else: + from tad_mctc.ncoord import cn_d3 - return get_pair_param(self.unique.tolist(), pair, **self.dd) + self.cn = cn_d3 def _get_hscale(self) -> Tensor: """ @@ -261,120 +145,26 @@ def _get_hscale(self) -> Tensor: return ksh - def build(self, positions: Tensor, overlap: Tensor | None = None) -> Tensor: + @override + def _get_elem_valence(self) -> Tensor: """ - Build the xTB Hamiltonian. - - Parameters - ---------- - positions : Tensor - Atomic positions of molecular structure. - overlap : Tensor | None, optional - Overlap matrix. If ``None``, the true xTB Hamiltonian is *not* - built. Defaults to ``None``. + Obtain a mask for valence and non-valence shells. This is only required for GFN1-xTB's second hydrogen s-function. Returns ------- Tensor - Hamiltonian (always symmetric). + Mask indicating valence shells for each unique species. """ - if self.par.hamiltonian is None: - raise RuntimeError("No Hamiltonian specified.") - - # masks - mask_atom_diagonal = real_pairs(self.numbers, mask_diagonal=True) - mask_shell = real_pairs( - self.ihelp.spread_atom_to_shell(self.numbers), mask_diagonal=False - ) - mask_shell_diagonal = self.ihelp.spread_atom_to_shell( - mask_atom_diagonal, dim=(-2, -1) - ) - - zero = torch.tensor(0.0, **self.dd) - - # ---------------- - # Eq.29: H_(mu,mu) - # ---------------- - if self.cn is None: - cn = torch.zeros_like(self.numbers, **self.dd) - else: - cn = self.cn(self.numbers, positions) - - kcn = self.ihelp.spread_ushell_to_shell(self.kcn) - - # formula differs from paper to be consistent with GFN2 -> "kcn" adapted - selfenergy = self.ihelp.spread_ushell_to_shell( - self.selfenergy - ) - kcn * self.ihelp.spread_atom_to_shell(cn) - - # ---------------------- - # Eq.24: PI(R_AB, l, l') - # ---------------------- - distances = storch.cdist(positions, positions, p=2) - rad = self.ihelp.spread_uspecies_to_atom(self.rad) - - rr = storch.divide(distances, rad.unsqueeze(-1) + rad.unsqueeze(-2)) - rr_shell = self.ihelp.spread_atom_to_shell( - torch.where(mask_atom_diagonal, storch.sqrt(rr), zero), - (-2, -1), - ) - - shpoly = self.ihelp.spread_ushell_to_shell(self.shpoly) - var_pi = (1.0 + shpoly.unsqueeze(-1) * rr_shell) * ( - 1.0 + shpoly.unsqueeze(-2) * rr_shell - ) - - # -------------------- - # Eq.28: X(EN_A, EN_B) - # -------------------- - en = self.ihelp.spread_uspecies_to_shell(self.en) - var_x = torch.where( - mask_shell_diagonal, - 1.0 - + self.par.hamiltonian.xtb.enscale - * torch.pow(en.unsqueeze(-1) - en.unsqueeze(-2), 2.0), - zero, - ) - - # -------------------- - # Eq.23: K_{AB}^{l,l'} - # -------------------- - kpair = self.ihelp.spread_uspecies_to_shell(self.kpair, dim=(-2, -1)) - hscale = self.ihelp.spread_ushell_to_shell(self.hscale, dim=(-2, -1)) - valence = self.ihelp.spread_ushell_to_shell(self.valence) - - var_k = torch.where( - valence.unsqueeze(-1) * valence.unsqueeze(-2), - hscale * kpair * var_x, - hscale, - ) - - # ------------ - # Eq.23: H_EHT - # ------------ - var_h = torch.where( - mask_shell, - 0.5 * (selfenergy.unsqueeze(-1) + selfenergy.unsqueeze(-2)), - zero, - ) + # pylint: disable=import-outside-toplevel + from dxtb._src.param import get_elem_valence - hcore = self.ihelp.spread_shell_to_orbital( - torch.where( - mask_shell_diagonal, - var_pi * var_k * var_h, # scale only off-diagonals - var_h, - ), - dim=(-2, -1), + return get_elem_valence( + self.unique, + self.par.element, + pad_val=PAD, + device=self.device, ) - if overlap is not None: - hcore = hcore * overlap - - # force symmetry to avoid problems through numerical errors - h0 = symmetrize(hcore) - self.matrix = h0 - return h0 - def get_gradient( self, positions: Tensor, diff --git a/src/dxtb/_src/xtb/gfn2.py b/src/dxtb/_src/xtb/gfn2.py index 147099f0..7ffbaa91 100644 --- a/src/dxtb/_src/xtb/gfn2.py +++ b/src/dxtb/_src/xtb/gfn2.py @@ -23,10 +23,16 @@ from __future__ import annotations +from functools import partial + +import torch + +from dxtb import IndexHelper from dxtb._src.components.interactions import Potential -from dxtb._src.typing import Tensor +from dxtb._src.param import Param +from dxtb._src.typing import Any, Tensor -from .base import BaseHamiltonian +from .base import PAD, BaseHamiltonian __all__ = ["GFN2Hamiltonian"] @@ -36,8 +42,94 @@ class GFN2Hamiltonian(BaseHamiltonian): The GFN2-xTB Hamiltonian. """ - def build(self, positions: Tensor, overlap: Tensor | None = None) -> Tensor: - raise NotImplementedError("GFN2 not implemented yet.") + def __init__( + self, + numbers: Tensor, + par: Param, + ihelp: IndexHelper, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + **kwargs: Any, + ) -> None: + super().__init__(numbers, par, ihelp, device, dtype) + + # coordination number function + if "cn" in kwargs: + self.cn = kwargs.pop("cn") + else: + from tad_mctc.ncoord import cn_d3, gfn2_count + + self.cn = partial(cn_d3, counting_function=gfn2_count) + + def _get_hscale(self) -> Tensor: + """ + Obtain the off-site scaling factor for the Hamiltonian. + + Returns + ------- + Tensor + Off-site scaling factor for the Hamiltonian. + """ + if self.par.hamiltonian is None: + raise RuntimeError("No Hamiltonian specified.") + + # extract some vars for convenience + shell = self.par.hamiltonian.xtb.shell + wexp = self.par.hamiltonian.xtb.wexp + ushells = self.ihelp.unique_angular + + angular2label = { + 0: "s", + 1: "p", + 2: "d", + 3: "f", + 4: "g", + } + angular_labels = [angular2label.get(int(ang), PAD) for ang in ushells] + + # ---------------------- + # Eq.37: Y(z^A_l, z^B_m) + # ---------------------- + z = self._get_elem_param("slater") + zi = z.unsqueeze(-1) + zj = z.unsqueeze(-2) + zmat = (2 * torch.sqrt(zi * zj) / (zi + zj)) ** wexp + + ksh = torch.ones((len(ushells), len(ushells)), **self.dd) + for i, ang_i in enumerate(ushells): + ang_i = angular_labels[i] + + for j, ang_j in enumerate(ushells): + ang_j = angular_labels[j] + + # Since the parametrization only contains "sp" (not "ps"), + # we need to check both. + # For some reason, the parametrization does not contain "sp" + # or "ps", although the value is calculated from "ss" and "pp", + # and hence, always the same. The paper, however, specifically + # mentions this. + # tblite: xtb/gfn2.f90::new_gfn2_h0spec + if f"{ang_i}{ang_j}" in shell: + kij = shell[f"{ang_i}{ang_j}"] + elif f"{ang_j}{ang_i}" in shell: + kij = shell[f"{ang_j}{ang_i}"] + else: + if f"{ang_i}{ang_i}" not in shell: + raise KeyError( + f"GFN2 HCore: Missing {ang_i}{ang_i} in shell." + ) + if f"{ang_j}{ang_j}" not in shell: # pragma: no cover + raise KeyError( + f"GFN2 HCore: Missing {ang_j}{ang_j} in shell." + ) + + kij = 0.5 * ( + shell[f"{ang_i}{ang_i}"] + shell[f"{ang_j}{ang_j}"] + ) + + ksh[i, j] = kij * zmat[i, j] + + return ksh def get_gradient( self, diff --git a/test/test_classical/test_dispersion/test_d3.py b/test/test_classical/test_dispersion/test_d3.py index bc407f61..8c3a128f 100644 --- a/test/test_classical/test_dispersion/test_d3.py +++ b/test/test_classical/test_dispersion/test_d3.py @@ -29,11 +29,8 @@ from tad_mctc.batch import pack from dxtb import GFN1_XTB as par -from dxtb._src.components.classicals.dispersion import ( - DispersionD3, - new_dispersion, -) from dxtb._src.typing import DD, Tensor +from dxtb.components.dispersion import DispersionD3, new_dispersion from ...conftest import DEVICE from .samples import samples @@ -81,7 +78,7 @@ def test_disp_batch(dtype: torch.dtype) -> None: } energy = d3.disp.dispersion( - numbers, positions, param, c6, rvdw, r4r2, d3.disp.rational_damping + numbers, positions, param, c6, rvdw, r4r2, d3.damping.rational_damping ) assert energy.dtype == dtype assert pytest.approx(ref.cpu()) == energy.cpu() diff --git a/test/test_classical/test_dispersion/test_energy.py b/test/test_classical/test_dispersion/test_energy.py index ac4db9a2..e9764ad8 100644 --- a/test/test_classical/test_dispersion/test_energy.py +++ b/test/test_classical/test_dispersion/test_energy.py @@ -78,7 +78,7 @@ def test_disp_batch(dtype: torch.dtype) -> None: } energy = d3.disp.dispersion( - numbers, positions, param, c6, rvdw, r4r2, d3.disp.rational_damping + numbers, positions, param, c6, rvdw, r4r2, d3.damping.rational_damping ) assert energy.dtype == dtype assert pytest.approx(ref.cpu()) == energy.cpu() diff --git a/test/test_classical/test_repulsion/test_grad_pos.py b/test/test_classical/test_repulsion/test_grad_pos.py index b8e7dc74..f7e84e7d 100644 --- a/test/test_classical/test_repulsion/test_grad_pos.py +++ b/test/test_classical/test_repulsion/test_grad_pos.py @@ -48,7 +48,10 @@ @pytest.mark.grad @pytest.mark.parametrize("dtype", [torch.double]) @pytest.mark.parametrize("name", ["H2O", "SiH4"]) -def test_backward_vs_tblite(dtype: torch.dtype, name: str) -> None: +@pytest.mark.parametrize("with_analytical_gradient", [True, False]) +def test_backward_vs_tblite( + dtype: torch.dtype, name: str, with_analytical_gradient: bool +) -> None: """Compare with reference values from tblite.""" dd: DD = {"device": DEVICE, "dtype": dtype} @@ -57,7 +60,9 @@ def test_backward_vs_tblite(dtype: torch.dtype, name: str) -> None: positions = sample["positions"].to(**dd) ref = sample["gfn1_grad"].to(**dd) - rep = new_repulsion(numbers, par, **dd) + rep = new_repulsion( + numbers, par, with_analytical_gradient=with_analytical_gradient, **dd + ) assert rep is not None ihelp = IndexHelper.from_numbers(numbers, par) @@ -83,8 +88,9 @@ def test_backward_vs_tblite(dtype: torch.dtype, name: str) -> None: @pytest.mark.parametrize("dtype", [torch.double]) @pytest.mark.parametrize("name1", ["H2O", "SiH4"]) @pytest.mark.parametrize("name2", ["H2O", "SiH4"]) +@pytest.mark.parametrize("with_analytical_gradient", [True, False]) def test_backward_batch_vs_tblite( - dtype: torch.dtype, name1: str, name2: str + dtype: torch.dtype, name1: str, name2: str, with_analytical_gradient: bool ) -> None: """Compare with reference values from tblite.""" dd: DD = {"device": DEVICE, "dtype": dtype} @@ -109,7 +115,9 @@ def test_backward_batch_vs_tblite( ] ) - rep = new_repulsion(numbers, par, **dd) + rep = new_repulsion( + numbers, par, with_analytical_gradient=with_analytical_gradient, **dd + ) assert rep is not None ihelp = IndexHelper.from_numbers(numbers, par) diff --git a/test/test_components/__init__.py b/test/test_components/__init__.py new file mode 100644 index 00000000..15d042be --- /dev/null +++ b/test/test_components/__init__.py @@ -0,0 +1,16 @@ +# This file is part of dxtb. +# +# SPDX-Identifier: Apache-2.0 +# Copyright (C) 2024 Grimme Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/test/test_components/test_list.py b/test/test_components/test_list.py new file mode 100644 index 00000000..1236eed4 --- /dev/null +++ b/test/test_components/test_list.py @@ -0,0 +1,64 @@ +# This file is part of dxtb. +# +# SPDX-Identifier: Apache-2.0 +# Copyright (C) 2024 Grimme Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Test collection of list of components. +""" + +from __future__ import annotations + +import pytest +import torch + +from dxtb._src.components.list import ComponentList, ComponentListCache +from dxtb._src.typing import Any + +from ..conftest import DEVICE + + +def test_cache() -> None: + cache = ComponentListCache() + + dummy = torch.tensor([1, 2, 3], device=DEVICE) + assert len(list(cache.keys())) == 0 + + # dummy functions + cache.cull(dummy, dummy) # type: ignore + cache.restore() + + +def test_list() -> None: + class Dummy(ComponentList): + def get_energy(self, *_: Any, **__: Any) -> None: + pass + + def get_gradient(self, *_: Any, **__: Any) -> None: + pass + + def get_cache(self, *_: Any, **__: Any) -> None: + pass + + clist = Dummy() + assert len(clist) == 0 + + with pytest.raises(ValueError): + clist.get_interaction("dummy") + + with pytest.raises(ValueError): + clist.update("dummy") + + with pytest.raises(ValueError): + clist.reset("dummy") diff --git a/test/test_hamiltonian/test_general.py b/test/test_hamiltonian/test_general.py index 3b8f90df..8d3c9af1 100644 --- a/test/test_hamiltonian/test_general.py +++ b/test/test_hamiltonian/test_general.py @@ -25,12 +25,13 @@ from tad_mctc.convert import str_to_device from tad_mctc.typing import MockTensor -from dxtb import GFN1_XTB as par -from dxtb import IndexHelper +from dxtb import GFN1_XTB, GFN2_XTB, IndexHelper, Param from dxtb._src.xtb.gfn1 import GFN1Hamiltonian +from dxtb._src.xtb.gfn2 import GFN2Hamiltonian -def test_no_h0_fail() -> None: +@pytest.mark.parametrize("par", [GFN1_XTB, GFN2_XTB]) +def test_no_h0_fail(par: Param) -> None: dummy = torch.tensor([]) _par = par.model_copy(deep=True) _par.hamiltonian = None @@ -38,11 +39,49 @@ def test_no_h0_fail() -> None: with pytest.raises(RuntimeError): GFN1Hamiltonian(dummy, _par, dummy) # type: ignore + with pytest.raises(RuntimeError): + GFN2Hamiltonian(dummy, _par, dummy) # type: ignore + + +def test_no_h0_fail_2() -> None: + numbers = torch.tensor([1]) + par_gfn1 = GFN1_XTB.model_copy(deep=True) + par_gfn2 = GFN2_XTB.model_copy(deep=True) + + ihelp_gfn1 = IndexHelper.from_numbers(numbers, par_gfn1) + ihelp_gfn2 = IndexHelper.from_numbers(numbers, par_gfn2) + + h0_gfn1 = GFN1Hamiltonian(numbers, par_gfn1, ihelp_gfn1) + h0_gfn1.par.hamiltonian = None + h0_gfn2 = GFN2Hamiltonian(numbers, par_gfn2, ihelp_gfn2) + h0_gfn2.par.hamiltonian = None + + with pytest.raises(RuntimeError): + h0_gfn1._get_hscale() + + with pytest.raises(RuntimeError): + h0_gfn2._get_hscale() + -def test_no_h0_fail2() -> None: +def test_no_h0_fail_3() -> None: + numbers = torch.tensor([1]) + par_gfn2 = GFN2_XTB.model_copy(deep=True) + ihelp_gfn2 = IndexHelper.from_numbers(numbers, par_gfn2) + + h0_gfn2 = GFN2Hamiltonian(numbers, par_gfn2, ihelp_gfn2) + assert h0_gfn2.par.hamiltonian is not None + + h0_gfn2.par.hamiltonian.xtb.shell = {} + h0_gfn2.ihelp.unique_angular = torch.tensor([4]) + + with pytest.raises(KeyError): + h0_gfn2._get_hscale() + + +def test_no_h0_fail_4() -> None: numbers = torch.tensor([1]) ihelp = IndexHelper.from_numbers_angular(numbers, {1: [0]}) - _par = par.model_copy(deep=True) + _par = GFN1_XTB.model_copy(deep=True) h0 = GFN1Hamiltonian(numbers, _par, ihelp) _par.hamiltonian = None @@ -68,14 +107,14 @@ def test_no_h0_fail2() -> None: def test_change_type(dtype: torch.dtype) -> None: numbers = torch.tensor([1]) ihelp = IndexHelper.from_numbers_angular(numbers, {1: [0]}) - h0 = GFN1Hamiltonian(numbers, par, ihelp) + h0 = GFN1Hamiltonian(numbers, GFN1_XTB, ihelp) assert h0.type(dtype).dtype == dtype def test_change_type_fail() -> None: numbers = torch.tensor([1]) ihelp = IndexHelper.from_numbers_angular(numbers, {1: [0]}) - h0 = GFN1Hamiltonian(numbers, par, ihelp) + h0 = GFN1Hamiltonian(numbers, GFN1_XTB, ihelp) # trying to use setter with pytest.raises(AttributeError): @@ -93,7 +132,7 @@ def test_change_device(device_str: str) -> None: numbers = torch.tensor([1], device=device) ihelp = IndexHelper.from_numbers_angular(numbers, {1: [0]}) - h0 = GFN1Hamiltonian(numbers, par, ihelp, device=device) + h0 = GFN1Hamiltonian(numbers, GFN1_XTB, ihelp, device=device) if device_str == "cpu": dev = torch.device("cpu") @@ -109,7 +148,7 @@ def test_change_device(device_str: str) -> None: def test_change_device_fail() -> None: numbers = torch.tensor([1]) ihelp = IndexHelper.from_numbers_angular(numbers, {1: [0]}) - h0 = GFN1Hamiltonian(numbers, par, ihelp) + h0 = GFN1Hamiltonian(numbers, GFN1_XTB, ihelp) # trying to use setter with pytest.raises(AttributeError): @@ -125,4 +164,4 @@ def test_wrong_device_fail() -> None: # numbers is on a different device with pytest.raises(ValueError): - GFN1Hamiltonian(numbers, par, ihelp, device=torch.device("cpu")) + GFN1Hamiltonian(numbers, GFN1_XTB, ihelp, device=torch.device("cpu")) diff --git a/test/test_hamiltonian/test_h0.py b/test/test_hamiltonian/test_gfn1.py similarity index 100% rename from test/test_hamiltonian/test_h0.py rename to test/test_hamiltonian/test_gfn1.py diff --git a/test/test_hamiltonian/test_gfn2.py b/test/test_hamiltonian/test_gfn2.py new file mode 100644 index 00000000..56df9877 --- /dev/null +++ b/test/test_hamiltonian/test_gfn2.py @@ -0,0 +1,276 @@ +# This file is part of dxtb. +# +# SPDX-Identifier: Apache-2.0 +# Copyright (C) 2024 Grimme Group +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Run tests for building the Hamiltonian matrix. +References calculated with tblite 0.3.0. +""" + +from __future__ import annotations + +from math import sqrt + +import pytest +import torch + +from dxtb import GFN2_XTB, IndexHelper +from dxtb._src.integral.driver.pytorch import IntDriverPytorch as IntDriver +from dxtb._src.integral.driver.pytorch import OverlapPytorch as Overlap +from dxtb._src.param import Param +from dxtb._src.typing import DD, Tensor +from dxtb._src.xtb.gfn2 import GFN2Hamiltonian + +from ..conftest import DEVICE +from .samples import samples + +small = ["H2", "LiH", "S2", "SiH4"] + +# No overlap! +ref = { + "H2": torch.tensor( + [ + *[-3.91986886330804e-1, -7.10344830094570e-1, -7.10344830094570e-1], + *[-3.91986886330804e-1], + ], + dtype=torch.float64, + ).reshape(2, 2), + "LiH": torch.tensor( + [ + *[-1.85652593278629e-1, -1.32503346243859e-1, -1.32503346243859e-1], + *[-1.32503346243859e-1, -5.11314980009571e-1, -1.32503346243859e-1], + *[-7.93540992090888e-2, -7.93540992090888e-2, -7.93540992090888e-2], + *[-5.67490975937946e-1, -1.32503346243859e-1, -7.93540992090888e-2], + *[-7.93540992090888e-2, -7.93540992090888e-2, -5.67490975937946e-1], + *[-1.32503346243859e-1, -7.93540992090888e-2, -7.93540992090888e-2], + *[-7.93540992090888e-2, -5.67490975937946e-1, -5.11314980009571e-1], + *[-5.67490975937946e-1, -5.67490975937946e-1, -5.67490975937946e-1], + *[-3.91761150304657e-1], + ], + dtype=torch.float64, + ).reshape(5, 5), + "S2": torch.tensor( + [ + *[-7.35145168851922e-1, -5.76455469048441e-1, -5.76455469048441e-1], + *[-5.76455469048441e-1, -3.78932593408957e-1, -3.78932593408957e-1], + *[-3.78932593408957e-1, -3.78932593408957e-1, -3.78932593408957e-1], + *[-7.69842185412757e-1, -8.16534240351844e-1, -8.16534240351844e-1], + *[-8.16534240351844e-1, -7.11124165876788e-1, -7.11124165876788e-1], + *[-7.11124165876788e-1, -7.11124165876788e-1, -7.11124165876788e-1], + *[-5.76455469048441e-1, -4.17765769244959e-1, -4.17765769244959e-1], + *[-4.17765769244959e-1, -2.20242893605476e-1, -2.20242893605476e-1], + *[-2.20242893605476e-1, -2.20242893605476e-1, -2.20242893605476e-1], + *[-8.16534240351844e-1, -7.93531808806201e-1, -7.93531808806201e-1], + *[-7.93531808806201e-1, -5.06786853369552e-1, -5.06786853369552e-1], + *[-5.06786853369552e-1, -5.06786853369552e-1, -5.06786853369552e-1], + *[-5.76455469048441e-1, -4.17765769244959e-1, -4.17765769244959e-1], + *[-4.17765769244959e-1, -2.20242893605476e-1, -2.20242893605476e-1], + *[-2.20242893605476e-1, -2.20242893605476e-1, -2.20242893605476e-1], + *[-8.16534240351844e-1, -7.93531808806201e-1, -7.93531808806201e-1], + *[-7.93531808806201e-1, -5.06786853369552e-1, -5.06786853369552e-1], + *[-5.06786853369552e-1, -5.06786853369552e-1, -5.06786853369552e-1], + *[-5.76455469048441e-1, -4.17765769244959e-1, -4.17765769244959e-1], + *[-4.17765769244959e-1, -2.20242893605476e-1, -2.20242893605476e-1], + *[-2.20242893605476e-1, -2.20242893605476e-1, -2.20242893605476e-1], + *[-8.16534240351844e-1, -7.93531808806201e-1, -7.93531808806201e-1], + *[-7.93531808806201e-1, -5.06786853369552e-1, -5.06786853369552e-1], + *[-5.06786853369552e-1, -5.06786853369552e-1, -5.06786853369552e-1], + *[-3.78932593408957e-1, -2.20242893605476e-1, -2.20242893605476e-1], + *[-2.20242893605476e-1, -2.27200179659920e-2, -2.27200179659920e-2], + *[-2.27200179659920e-2, -2.27200179659920e-2, -2.27200179659920e-2], + *[-7.11124165876788e-1, -5.06786853369552e-1, -5.06786853369552e-1], + *[-5.06786853369552e-1, -7.90334123556902e-2, -7.90334123556902e-2], + *[-7.90334123556902e-2, -7.90334123556902e-2, -7.90334123556902e-2], + *[-3.78932593408957e-1, -2.20242893605476e-1, -2.20242893605476e-1], + *[-2.20242893605476e-1, -2.27200179659920e-2, -2.27200179659920e-2], + *[-2.27200179659920e-2, -2.27200179659920e-2, -2.27200179659920e-2], + *[-7.11124165876788e-1, -5.06786853369552e-1, -5.06786853369552e-1], + *[-5.06786853369552e-1, -7.90334123556902e-2, -7.90334123556902e-2], + *[-7.90334123556902e-2, -7.90334123556902e-2, -7.90334123556902e-2], + *[-3.78932593408957e-1, -2.20242893605476e-1, -2.20242893605476e-1], + *[-2.20242893605476e-1, -2.27200179659920e-2, -2.27200179659920e-2], + *[-2.27200179659920e-2, -2.27200179659920e-2, -2.27200179659920e-2], + *[-7.11124165876788e-1, -5.06786853369552e-1, -5.06786853369552e-1], + *[-5.06786853369552e-1, -7.90334123556902e-2, -7.90334123556902e-2], + *[-7.90334123556902e-2, -7.90334123556902e-2, -7.90334123556902e-2], + *[-3.78932593408957e-1, -2.20242893605476e-1, -2.20242893605476e-1], + *[-2.20242893605476e-1, -2.27200179659920e-2, -2.27200179659920e-2], + *[-2.27200179659920e-2, -2.27200179659920e-2, -2.27200179659920e-2], + *[-7.11124165876788e-1, -5.06786853369552e-1, -5.06786853369552e-1], + *[-5.06786853369552e-1, -7.90334123556902e-2, -7.90334123556902e-2], + *[-7.90334123556902e-2, -7.90334123556902e-2, -7.90334123556902e-2], + *[-3.78932593408957e-1, -2.20242893605476e-1, -2.20242893605476e-1], + *[-2.20242893605476e-1, -2.27200179659920e-2, -2.27200179659920e-2], + *[-2.27200179659920e-2, -2.27200179659920e-2, -2.27200179659920e-2], + *[-7.11124165876788e-1, -5.06786853369552e-1, -5.06786853369552e-1], + *[-5.06786853369552e-1, -7.90334123556902e-2, -7.90334123556902e-2], + *[-7.90334123556902e-2, -7.90334123556902e-2, -7.90334123556902e-2], + *[-7.69842185412757e-1, -8.16534240351844e-1, -8.16534240351844e-1], + *[-8.16534240351844e-1, -7.11124165876788e-1, -7.11124165876788e-1], + *[-7.11124165876788e-1, -7.11124165876788e-1, -7.11124165876788e-1], + *[-7.35145168851922e-1, -5.76455469048441e-1, -5.76455469048441e-1], + *[-5.76455469048441e-1, -3.78932593408957e-1, -3.78932593408957e-1], + *[-3.78932593408957e-1, -3.78932593408957e-1, -3.78932593408957e-1], + *[-8.16534240351844e-1, -7.93531808806201e-1, -7.93531808806201e-1], + *[-7.93531808806201e-1, -5.06786853369552e-1, -5.06786853369552e-1], + *[-5.06786853369552e-1, -5.06786853369552e-1, -5.06786853369552e-1], + *[-5.76455469048441e-1, -4.17765769244959e-1, -4.17765769244959e-1], + *[-4.17765769244959e-1, -2.20242893605476e-1, -2.20242893605476e-1], + *[-2.20242893605476e-1, -2.20242893605476e-1, -2.20242893605476e-1], + *[-8.16534240351844e-1, -7.93531808806201e-1, -7.93531808806201e-1], + *[-7.93531808806201e-1, -5.06786853369552e-1, -5.06786853369552e-1], + *[-5.06786853369552e-1, -5.06786853369552e-1, -5.06786853369552e-1], + *[-5.76455469048441e-1, -4.17765769244959e-1, -4.17765769244959e-1], + *[-4.17765769244959e-1, -2.20242893605476e-1, -2.20242893605476e-1], + *[-2.20242893605476e-1, -2.20242893605476e-1, -2.20242893605476e-1], + *[-8.16534240351844e-1, -7.93531808806201e-1, -7.93531808806201e-1], + *[-7.93531808806201e-1, -5.06786853369552e-1, -5.06786853369552e-1], + *[-5.06786853369552e-1, -5.06786853369552e-1, -5.06786853369552e-1], + *[-5.76455469048441e-1, -4.17765769244959e-1, -4.17765769244959e-1], + *[-4.17765769244959e-1, -2.20242893605476e-1, -2.20242893605476e-1], + *[-2.20242893605476e-1, -2.20242893605476e-1, -2.20242893605476e-1], + *[-7.11124165876788e-1, -5.06786853369552e-1, -5.06786853369552e-1], + *[-5.06786853369552e-1, -7.90334123556902e-2, -7.90334123556902e-2], + *[-7.90334123556902e-2, -7.90334123556902e-2, -7.90334123556902e-2], + *[-3.78932593408957e-1, -2.20242893605476e-1, -2.20242893605476e-1], + *[-2.20242893605476e-1, -2.27200179659920e-2, -2.27200179659920e-2], + *[-2.27200179659920e-2, -2.27200179659920e-2, -2.27200179659920e-2], + *[-7.11124165876788e-1, -5.06786853369552e-1, -5.06786853369552e-1], + *[-5.06786853369552e-1, -7.90334123556902e-2, -7.90334123556902e-2], + *[-7.90334123556902e-2, -7.90334123556902e-2, -7.90334123556902e-2], + *[-3.78932593408957e-1, -2.20242893605476e-1, -2.20242893605476e-1], + *[-2.20242893605476e-1, -2.27200179659920e-2, -2.27200179659920e-2], + *[-2.27200179659920e-2, -2.27200179659920e-2, -2.27200179659920e-2], + *[-7.11124165876788e-1, -5.06786853369552e-1, -5.06786853369552e-1], + *[-5.06786853369552e-1, -7.90334123556902e-2, -7.90334123556902e-2], + *[-7.90334123556902e-2, -7.90334123556902e-2, -7.90334123556902e-2], + *[-3.78932593408957e-1, -2.20242893605476e-1, -2.20242893605476e-1], + *[-2.20242893605476e-1, -2.27200179659920e-2, -2.27200179659920e-2], + *[-2.27200179659920e-2, -2.27200179659920e-2, -2.27200179659920e-2], + *[-7.11124165876788e-1, -5.06786853369552e-1, -5.06786853369552e-1], + *[-5.06786853369552e-1, -7.90334123556902e-2, -7.90334123556902e-2], + *[-7.90334123556902e-2, -7.90334123556902e-2, -7.90334123556902e-2], + *[-3.78932593408957e-1, -2.20242893605476e-1, -2.20242893605476e-1], + *[-2.20242893605476e-1, -2.27200179659920e-2, -2.27200179659920e-2], + *[-2.27200179659920e-2, -2.27200179659920e-2, -2.27200179659920e-2], + *[-7.11124165876788e-1, -5.06786853369552e-1, -5.06786853369552e-1], + *[-5.06786853369552e-1, -7.90334123556902e-2, -7.90334123556902e-2], + *[-7.90334123556902e-2, -7.90334123556902e-2, -7.90334123556902e-2], + *[-3.78932593408957e-1, -2.20242893605476e-1, -2.20242893605476e-1], + *[-2.20242893605476e-1, -2.27200179659920e-2, -2.27200179659920e-2], + *[-2.27200179659920e-2, -2.27200179659920e-2, -2.27200179659920e-2], + ], + dtype=torch.float64, + ).reshape(18, 18), + "SiH4": torch.tensor( + [ + *[-5.52421014778594e-1, -3.94095353144021e-1, -3.94095353144021e-1], + *[-3.94095353144021e-1, -2.96900602546706e-1, -2.96900602546706e-1], + *[-2.96900602546706e-1, -2.96900602546706e-1, -2.96900602546706e-1], + *[-8.79823684667296e-1, -8.79823684667296e-1, -8.79823684667296e-1], + *[-8.79823684667296e-1, -3.94095353144021e-1, -2.35769691509448e-1], + *[-2.35769691509448e-1, -2.35769691509448e-1, -1.38574940912133e-1], + *[-1.38574940912133e-1, -1.38574940912133e-1, -1.38574940912133e-1], + *[-1.38574940912133e-1, -5.80507239433161e-1, -5.80507239433161e-1], + *[-5.80507239433161e-1, -5.80507239433161e-1, -3.94095353144021e-1], + *[-2.35769691509448e-1, -2.35769691509448e-1, -2.35769691509448e-1], + *[-1.38574940912133e-1, -1.38574940912133e-1, -1.38574940912133e-1], + *[-1.38574940912133e-1, -1.38574940912133e-1, -5.80507239433161e-1], + *[-5.80507239433161e-1, -5.80507239433161e-1, -5.80507239433161e-1], + *[-3.94095353144021e-1, -2.35769691509448e-1, -2.35769691509448e-1], + *[-2.35769691509448e-1, -1.38574940912133e-1, -1.38574940912133e-1], + *[-1.38574940912133e-1, -1.38574940912133e-1, -1.38574940912133e-1], + *[-5.80507239433161e-1, -5.80507239433161e-1, -5.80507239433161e-1], + *[-5.80507239433161e-1, -2.96900602546706e-1, -1.38574940912133e-1], + *[-1.38574940912133e-1, -1.38574940912133e-1, -4.13801903148179e-2], + *[-4.13801903148179e-2, -4.13801903148179e-2, -4.13801903148179e-2], + *[-4.13801903148179e-2, -4.79036928570672e-1, -4.79036928570672e-1], + *[-4.79036928570672e-1, -4.79036928570672e-1, -2.96900602546706e-1], + *[-1.38574940912133e-1, -1.38574940912133e-1, -1.38574940912133e-1], + *[-4.13801903148179e-2, -4.13801903148179e-2, -4.13801903148179e-2], + *[-4.13801903148179e-2, -4.13801903148179e-2, -4.79036928570672e-1], + *[-4.79036928570672e-1, -4.79036928570672e-1, -4.79036928570672e-1], + *[-2.96900602546706e-1, -1.38574940912133e-1, -1.38574940912133e-1], + *[-1.38574940912133e-1, -4.13801903148179e-2, -4.13801903148179e-2], + *[-4.13801903148179e-2, -4.13801903148179e-2, -4.13801903148179e-2], + *[-4.79036928570672e-1, -4.79036928570672e-1, -4.79036928570672e-1], + *[-4.79036928570672e-1, -2.96900602546706e-1, -1.38574940912133e-1], + *[-1.38574940912133e-1, -1.38574940912133e-1, -4.13801903148179e-2], + *[-4.13801903148179e-2, -4.13801903148179e-2, -4.13801903148179e-2], + *[-4.13801903148179e-2, -4.79036928570672e-1, -4.79036928570672e-1], + *[-4.79036928570672e-1, -4.79036928570672e-1, -2.96900602546706e-1], + *[-1.38574940912133e-1, -1.38574940912133e-1, -1.38574940912133e-1], + *[-4.13801903148179e-2, -4.13801903148179e-2, -4.13801903148179e-2], + *[-4.13801903148179e-2, -4.13801903148179e-2, -4.79036928570672e-1], + *[-4.79036928570672e-1, -4.79036928570672e-1, -4.79036928570672e-1], + *[-8.79823684667296e-1, -5.80507239433161e-1, -5.80507239433161e-1], + *[-5.80507239433161e-1, -4.79036928570672e-1, -4.79036928570672e-1], + *[-4.79036928570672e-1, -4.79036928570672e-1, -4.79036928570672e-1], + *[-3.91823589915583e-1, -6.98232664598717e-1, -6.98232664598717e-1], + *[-6.98232664598717e-1, -8.79823684667296e-1, -5.80507239433161e-1], + *[-5.80507239433161e-1, -5.80507239433161e-1, -4.79036928570672e-1], + *[-4.79036928570672e-1, -4.79036928570672e-1, -4.79036928570672e-1], + *[-4.79036928570672e-1, -6.98232664598717e-1, -3.91823589915583e-1], + *[-6.98232664598717e-1, -6.98232664598717e-1, -8.79823684667296e-1], + *[-5.80507239433161e-1, -5.80507239433161e-1, -5.80507239433161e-1], + *[-4.79036928570672e-1, -4.79036928570672e-1, -4.79036928570672e-1], + *[-4.79036928570672e-1, -4.79036928570672e-1, -6.98232664598717e-1], + *[-6.98232664598717e-1, -3.91823589915583e-1, -6.98232664598717e-1], + *[-8.79823684667296e-1, -5.80507239433161e-1, -5.80507239433161e-1], + *[-5.80507239433161e-1, -4.79036928570672e-1, -4.79036928570672e-1], + *[-4.79036928570672e-1, -4.79036928570672e-1, -4.79036928570672e-1], + *[-6.98232664598717e-1, -6.98232664598717e-1, -6.98232664598717e-1], + *[-3.91823589915583e-1], + ], + dtype=torch.float64, + ).reshape(13, 13), +} + + +def run( + numbers: Tensor, + positions: Tensor, + par: Param, + ref: Tensor, + dd: DD, + skip_overlap: bool = True, +) -> None: + tol = sqrt(torch.finfo(dd["dtype"]).eps) * 10 + + ihelp = IndexHelper.from_numbers(numbers, par) + driver = IntDriver(numbers, par, ihelp, **dd) + overlap = Overlap(**dd) + h0 = GFN2Hamiltonian(numbers, par, ihelp, **dd) + + driver.setup(positions) + + s = None if skip_overlap is True else overlap.build(driver) + h = h0.build(positions, s) + assert pytest.approx(h.cpu(), abs=tol) == h.mT.cpu() + assert pytest.approx(h.cpu(), abs=tol) == ref.cpu() + + +# No overlap! +@pytest.mark.parametrize("dtype", [torch.float, torch.double]) +@pytest.mark.parametrize("name", small) +def test_single(dtype: torch.dtype, name: str) -> None: + dd: DD = {"dtype": dtype, "device": DEVICE} + + sample = samples[name] + numbers = sample["numbers"].to(DEVICE) + positions = sample["positions"].to(**dd) + _ref = ref[name].to(**dd) + + run(numbers, positions, GFN2_XTB, _ref, dd) diff --git a/test/test_integrals/test_general.py b/test/test_integrals/test_general.py index 2a0c20d3..3890458c 100644 --- a/test/test_integrals/test_general.py +++ b/test/test_integrals/test_general.py @@ -34,6 +34,17 @@ from ..conftest import DEVICE +@pytest.mark.parametrize("dtype", [torch.float, torch.double]) +def test_properties(dtype: torch.dtype): + dd: DD = {"dtype": dtype, "device": DEVICE} + + mgr = ints.DriverManager(INTDRIVER_LIBCINT, **dd) + i = ints.Integrals(mgr, **dd) + + assert len(i.labels) == 0 + assert len(i) == 0 + + @pytest.mark.parametrize("dtype", [torch.float, torch.double]) def test_empty(dtype: torch.dtype): dd: DD = {"dtype": dtype, "device": DEVICE} diff --git a/test/test_integrals/test_wrappers.py b/test/test_integrals/test_wrappers.py index 488e799b..a8d0e085 100644 --- a/test/test_integrals/test_wrappers.py +++ b/test/test_integrals/test_wrappers.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Test overlap build from integral container. +Test wrappers for integrals. """ from __future__ import annotations @@ -39,60 +39,71 @@ ) -def test_fail() -> None: +@pytest.mark.parametrize("par", [GFN1_XTB, GFN2_XTB]) +def test_fail(par: Param) -> None: with pytest.raises(ValueError): - par1 = GFN1_XTB.model_copy(deep=True) - par1.meta = None - wrappers.hcore(numbers, positions, par1) + _par = par.model_copy(deep=True) + _par.meta = None + wrappers.hcore(numbers, positions, _par) with pytest.raises(ValueError): - par1 = GFN1_XTB.model_copy(deep=True) - assert par1.meta is not None + _par = par.model_copy(deep=True) + assert _par.meta is not None - par1.meta.name = None - wrappers.hcore(numbers, positions, par1) + _par.meta.name = None + wrappers.hcore(numbers, positions, _par) with pytest.raises(ValueError): - par1 = GFN1_XTB.model_copy(deep=True) - assert par1.meta is not None + _par = par.model_copy(deep=True) + assert _par.meta is not None - par1.meta.name = "fail" - wrappers.hcore(numbers, positions, par1) + _par.meta.name = "fail" + wrappers.hcore(numbers, positions, _par) with pytest.raises(ValueError): # pylint: disable=import-outside-toplevel from dxtb._src.integral.wrappers import _integral - _integral("fail", numbers, positions, par1) # type: ignore + _integral("fail", numbers, positions, _par) # type: ignore -@pytest.mark.parametrize("par", [GFN1_XTB]) -def test_h0_gfn1(par: Param) -> None: - h0 = wrappers.hcore(numbers, positions, par) +def test_h0_gfn1() -> None: + h0 = wrappers.hcore(numbers, positions, GFN1_XTB) assert h0.shape == (17, 17) - h0 = wrappers.hcore(numbers, positions, par, cn=None) + h0 = wrappers.hcore(numbers, positions, GFN1_XTB, cn=None) assert h0.shape == (17, 17) -@pytest.mark.parametrize("par", [GFN2_XTB]) -def test_h0_gfn2(par: Param) -> None: - with pytest.raises(NotImplementedError): - wrappers.hcore(numbers, positions, par) +def test_h0_gfn2() -> None: + h0 = wrappers.hcore(numbers, positions, GFN2_XTB) + assert h0.shape == (13, 13) + + h0 = wrappers.hcore(numbers, positions, GFN2_XTB, cn=None) + assert h0.shape == (13, 13) def test_overlap() -> None: s = wrappers.overlap(numbers, positions, GFN1_XTB) assert s.shape == (17, 17) + s = wrappers.overlap(numbers, positions, GFN2_XTB) + assert s.shape == (13, 13) + @pytest.mark.skipif(not has_libcint, reason="libcint not available") def test_dipole() -> None: s = wrappers.dipint(numbers, positions, GFN1_XTB) assert s.shape == (3, 17, 17) + s = wrappers.dipint(numbers, positions, GFN2_XTB) + assert s.shape == (3, 13, 13) + @pytest.mark.skipif(not has_libcint, reason="libcint not available") def test_quad() -> None: s = wrappers.quadint(numbers, positions, GFN1_XTB) assert s.shape == (9, 17, 17) + + s = wrappers.quadint(numbers, positions, GFN2_XTB) + assert s.shape == (9, 13, 13) diff --git a/test/test_interaction/test_list.py b/test/test_interaction/test_list.py index 191e22e4..4bee15b1 100644 --- a/test/test_interaction/test_list.py +++ b/test/test_interaction/test_list.py @@ -22,15 +22,23 @@ import torch -from dxtb import IndexHelper -from dxtb._src.components.interactions import ( - InteractionList, - InteractionListCache, -) +from dxtb import GFN1_XTB, IndexHelper +from dxtb.components.base import InteractionList, InteractionListCache +from dxtb.components.coulomb import new_es3 from ..conftest import DEVICE +def test_properties() -> None: + numbers = torch.tensor([6, 1], device=DEVICE) + es3 = new_es3(numbers, GFN1_XTB) + ilist = InteractionList(es3) + + assert len(ilist.components) == 1 + assert len(ilist) == 1 + assert id(ilist.get_interaction("ES3")) == id(es3) + + def test_empty() -> None: ilist = InteractionList() assert len(ilist.components) == 0