Profile moving devices (#190)
marvinfriede authored Jan 1, 2025
1 parent 3d10537 commit 38ad6fa
Showing 7 changed files with 155 additions and 30 deletions.
95 changes: 95 additions & 0 deletions examples/profiling/move-device.py
@@ -0,0 +1,95 @@
# This file is part of dxtb.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Profile tensor moves between devices during a simple energy calculation.
"""
import functools
import logging
import traceback

import torch

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(message)s",
)


def log_tensor_move(func):
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        # The target device is the first positional argument or the `device`
        # keyword (for `.to(dtype)` calls, this may actually be a dtype).
        device = None
        if args:
            device = args[0]
        elif "device" in kwargs:
            device = kwargs["device"]

        # Get tensor details
        tensor_id = id(self)
        tensor_shape = tuple(self.size())
        tensor_dtype = self.dtype
        tensor_device = self.device

        # Capture stack trace
        stack = "".join(traceback.format_stack(limit=4)[:-1])

        # Only log if the tensor is moved to a different device
        if tensor_device == device:
            return func(self, *args, **kwargs)

        logging.info(
            f"Tensor ID: {tensor_id}, Shape: {tensor_shape}, Dtype: {tensor_dtype}, "
            f"From Device: {tensor_device}, To Device: {device}, "
            f"Called from:\n{stack}"
        )

        return func(self, *args, **kwargs)

    return wrapper


def override_tensor_methods():
    tensor_methods_to_override = ["to", "cuda", "cpu"]

    for method_name in tensor_methods_to_override:
        original_method = getattr(torch.Tensor, method_name)
        decorated_method = log_tensor_move(original_method)
        setattr(torch.Tensor, method_name, decorated_method)


override_tensor_methods()

###############################################################################
###############################################################################
###############################################################################
###############################################################################

import dxtb

dd = {"dtype": torch.double, "device": torch.device("cuda:0")}

# LiH
numbers = torch.tensor([3, 1], device=dd["device"])
positions = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.5]], **dd)

# instantiate a calculator
opts = {"verbosity": 6}
calc = dxtb.calculators.GFN1Calculator(numbers, opts=opts, **dd)

# compute the energy
pos = positions.clone().requires_grad_(True)
energy = calc.get_energy(pos)
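
Note that the example patches ``torch.Tensor`` globally for the remainder of the process. If logging should only cover a specific region, the same idea can be wrapped in a context manager that restores the original methods on exit. A minimal sketch building on the ``log_tensor_move`` decorator above (not part of this commit):

import contextlib

@contextlib.contextmanager
def tensor_move_logging():
    # Save the original methods, patch them with the logging wrapper,
    # and restore the originals when leaving the block.
    originals = {
        name: getattr(torch.Tensor, name) for name in ("to", "cuda", "cpu")
    }
    try:
        for name, method in originals.items():
            setattr(torch.Tensor, name, log_tensor_move(method))
        yield
    finally:
        for name, method in originals.items():
            setattr(torch.Tensor, name, method)

# Usage: only moves inside the block are logged.
# with tensor_move_logging():
#     energy = calc.get_energy(pos)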
26 changes: 23 additions & 3 deletions src/dxtb/_src/basis/indexhelper.py
@@ -282,7 +282,11 @@ def __init__(

@classmethod
def from_numbers(
cls, numbers: Tensor, par: Param, batch_mode: int | None = None
cls,
numbers: Tensor,
par: Param,
batch_mode: int | None = None,
move_to_numbers_device: bool = True,
) -> IndexHelper:
"""
Construct an index helper instance from atomic numbers and a
@@ -308,21 +312,32 @@ def from_numbers(
- 0: Single system
- 1: Multiple systems with padding
- 2: Multiple systems with no padding (conformer ensemble)
move_to_numbers_device : bool
Move the resulting tensors to the device of the ``numbers`` tensor.
This should be switched off for GPU calculations that use `libcint`
for integrals, as the :class:`.IndexHelper` has to be on the CPU
for this step.
Returns
-------
IndexHelper
Instance of index helper for given basis set.
"""
angular = get_elem_angular(par.element)
return cls.from_numbers_angular(numbers, angular, batch_mode)
return cls.from_numbers_angular(
numbers,
angular,
batch_mode,
move_to_numbers_device=move_to_numbers_device,
)

@classmethod
def from_numbers_angular(
cls,
numbers: Tensor,
angular: dict[int, list[int]],
batch_mode: int | None = None,
move_to_numbers_device: bool = True,
) -> IndexHelper:
"""
Construct an index helper instance from atomic numbers and their
@@ -350,14 +365,19 @@ def from_numbers_angular(
- 0: Single system
- 1: Multiple systems with padding
- 2: Multiple systems with no padding (conformer ensemble)
move_to_numbers_device : bool
Move the resulting tensors to the device of the ``numbers`` tensor.
This should be switched off for GPU calculations that use `libcint`
for integrals, as the :class:`.IndexHelper` has to be on the CPU
for this step.
Returns
-------
IndexHelper
Instance of index helper for given basis set.
"""
device = numbers.device
cpu = torch.device("cpu")
device = numbers.device if move_to_numbers_device else cpu

# Ensure that all tensors are moved to CPU to avoid inefficient
# memory transfers between devices (.item() and native for-loops).
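The new ``move_to_numbers_device`` flag controls where the constructed helper ends up. A minimal usage sketch; the top-level re-exports (``IndexHelper``, ``GFN1_XTB``) are assumed from dxtb's public API:

import torch
from dxtb import GFN1_XTB, IndexHelper

numbers = torch.tensor([3, 1], device=torch.device("cuda:0"))

# Default: the helper's tensors are moved to the device of `numbers`.
ihelp = IndexHelper.from_numbers(numbers, GFN1_XTB)

# For integrals via `libcint`, keep the helper on the CPU instead.
ihelp_cpu = IndexHelper.from_numbers(
    numbers, GFN1_XTB, move_to_numbers_device=False
)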
10 changes: 5 additions & 5 deletions src/dxtb/_src/calculators/config/scf.py
@@ -110,7 +110,7 @@ def __init__(
# Fermi
fermi_etemp: float = defaults.FERMI_ETEMP,
fermi_maxiter: int = defaults.FERMI_MAXITER,
fermi_thresh: dict = defaults.FERMI_THRESH,
fermi_thresh: float | int | None = defaults.FERMI_THRESH,
fermi_partition: str | int = defaults.FERMI_PARTITION,
# PyTorch
device: torch.device = get_default_device(),
@@ -376,7 +376,7 @@ class ConfigFermi:
maxiter: int
"""Maximum number of iterations for Fermi smearing."""

thresh: dict
thresh: float | int | None
"""Float data type dependent threshold for Fermi iterations."""

partition: int
@@ -395,7 +395,7 @@ def __init__(
*,
etemp: float | int = defaults.FERMI_ETEMP,
maxiter: int = defaults.FERMI_MAXITER,
thresh: dict = defaults.FERMI_THRESH,
thresh: float | int | None = defaults.FERMI_THRESH,
partition: str | int = defaults.FERMI_PARTITION,
# PyTorch
device: torch.device = get_default_device(),
@@ -444,7 +444,7 @@ def __init__(
f"'{type(partition)}' was given."
)

def info(self) -> dict[str, dict[str, float | int | str]]:
def info(self) -> dict[str, dict[str, None | float | int | str]]:
"""
Return a dictionary with the Fermi smearing configuration.
@@ -457,7 +457,7 @@ def info(self) -> dict[str, dict[str, float | int | str]]:
"Fermi Smearing": {
"Temperature": self.etemp,
"Maxiter": self.maxiter,
"Threshold": self.thresh[self.dtype].item(),
"Threshold": self.thresh,
"Partioning": labels.FERMI_PARTITION_MAP[self.partition],
}
}
16 changes: 13 additions & 3 deletions src/dxtb/_src/calculators/types/base.py
@@ -459,7 +459,7 @@ def __init__(
numbers : Tensor
Atomic numbers for all atoms in the system (shape: ``(..., nat)``).
par : Param
Representation of an extended tight-binding model (full xtb
Representation of an extended tight-binding model (full xTB
parametrization). Decides energy contributions.
classical : Sequence[Classical] | None, optional
Additional classical contributions. Defaults to ``None``.
@@ -515,6 +515,16 @@ def __init__(
if self.opts.batch_mode == 0 and numbers.ndim > 1:
self.opts.batch_mode = 1

# PERF: The IndexHelper is created on CPU and moved to the device of the
# `numbers` tensor. This is required for the instantiation of the
# integral classes later in this constructor. However, if the `libcint`
# interface is used, we need to transfer the IndexHelper back to the CPU,
# which leaves us with one unnecessary transfer.
# (This could be circumvented if the integrals were calculated immediately
# after instantiation, i.e., compute integrals with `libcint` first, then
# move the IndexHelper to the device and compute the Hamiltonian. However,
# this would require a change in the code structure, so we take the very
# small performance hit here.)
self.ihelp = IndexHelper.from_numbers(
numbers, par, self.opts.batch_mode
)
@@ -546,7 +556,7 @@ def __init__(
else:
raise TypeError(
"Expected 'interaction' to be 'None' or of type 'Interaction', "
"'list[Interaction]' or 'tuple[Interaction]', but got "
"'list[Interaction]', or 'tuple[Interaction]', but got "
f"'{type(interaction).__name__}'."
)

@@ -590,7 +600,7 @@ def __init__(
else:
raise TypeError(
"Expected 'classical' to be 'None' or of type 'Classical', "
"'list[Classical]' or 'tuple[Classical]', but got "
"'list[Classical]', or 'tuple[Classical]', but got "
f"'{type(classical).__name__}'."
)

11 changes: 5 additions & 6 deletions src/dxtb/_src/constants/defaults.py
@@ -200,12 +200,11 @@
FERMI_MAXITER = 200
"""Maximum number of iterations for Fermi smearing."""

FERMI_THRESH = {
torch.float16: torch.tensor(1e-2, dtype=torch.float16),
torch.float32: torch.tensor(1e-5, dtype=torch.float32),
torch.float64: torch.tensor(1e-10, dtype=torch.float64),
}
"""Convergence thresholds for different float data types."""
FERMI_THRESH = None
"""
Convergence threshold for the Fermi iterations. ``None`` uses the square
root of the machine epsilon of the working dtype.
"""

FERMI_PARTITION = labels.FERMI_PARTITION_EQUAL
"""Partitioning scheme for electronic free energy."""
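For orientation, the square-root-of-eps fallback lands in the same ballpark as the removed per-dtype defaults. A quick check (machine epsilons are the standard IEEE values; old defaults taken from the deleted dictionary above):

import torch

for dtype in (torch.float16, torch.float32, torch.float64):
    print(dtype, torch.finfo(dtype).eps ** 0.5)

# torch.float16: ~3.1e-02 (old default: 1e-2)
# torch.float32: ~3.5e-04 (old default: 1e-5)
# torch.float64: ~1.5e-08 (old default: 1e-10)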
11 changes: 4 additions & 7 deletions src/dxtb/_src/wavefunction/filling.py
@@ -26,11 +26,10 @@
from __future__ import annotations

import torch
from tad_mctc.convert import any_to_tensor

from dxtb._src.typing import DD, Tensor

from ..constants import defaults

__all__ = [
"get_alpha_beta_occupation",
"get_aufbau_occupation",
@@ -266,7 +265,7 @@ def get_fermi_occupation(
emo: Tensor,
kt: Tensor | None = None,
mask: Tensor | None = None,
thr: dict[torch.dtype, Tensor] | None = None,
thr: Tensor | float | int | None = None,
maxiter: int = 200,
) -> Tensor:
"""
@@ -325,10 +324,8 @@ def get_fermi_occupation(
return torch.zeros_like(emo)

if thr is None:
thr = defaults.FERMI_THRESH
thresh = thr.get(emo.dtype, torch.tensor(1e-5, dtype=torch.float)).to(
emo.device
)
thr = torch.tensor(torch.finfo(emo.dtype).eps, **dd) ** 0.5
thresh = any_to_tensor(thr, **dd)

e_fermi, homo = get_fermi_energy(nel, emo, mask=mask)

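Since ``thr`` now also accepts plain Python scalars, the conversion to a tensor is delegated to ``any_to_tensor``. A minimal sketch of the normalization, mirroring the call above and assuming ``any_to_tensor`` takes ``device``/``dtype`` keywords as its use here suggests:

import torch
from tad_mctc.convert import any_to_tensor

dd = {"device": torch.device("cpu"), "dtype": torch.float64}

# Tensor, float, and int inputs all resolve to a tensor with the
# requested device and dtype.
for thr in (torch.tensor(1e-10), 1e-4, 1):
    thresh = any_to_tensor(thr, **dd)
    print(thresh, thresh.dtype, thresh.device)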
16 changes: 10 additions & 6 deletions test/test_scf/test_elements.py
@@ -322,11 +322,6 @@
opts = {
"fermi_etemp": 300,
"fermi_maxiter": 500,
"fermi_thresh": {
# instead of 1e-5
torch.float32: torch.tensor(1e-4, dtype=torch.float32),
torch.float64: torch.tensor(1e-10, dtype=torch.float64),
},
"scf_mode": labels.SCF_MODE_IMPLICIT_NON_PURE,
"scp_mode": labels.SCP_MODE_POTENTIAL, # better convergence for atoms
"verbosity": 0,
@@ -348,7 +343,14 @@ def test_element(dtype: torch.dtype, number: int) -> None:

# opts["spin"] = uhf[number - 1]
atol = 1e-5 if dtype == torch.float else 1e-6
options = dict(opts, **{"f_atol": atol, "x_atol": atol})
options = dict(
opts,
**{
"f_atol": atol,
"x_atol": atol,
"fermi_thresh": 1e-4 if dtype == torch.float32 else 1e-10,
},
)
calc = Calculator(numbers, par, opts=options, **dd)
results = calc.singlepoint(positions, charges)
assert pytest.approx(r.cpu(), abs=tol) == results.scf.sum(-1).cpu()
@@ -377,6 +379,7 @@ def test_element_cation(dtype: torch.dtype, number: int) -> None:
**{
"f_atol": 1e-5, # avoids Jacobian inversion error
"x_atol": 1e-5, # avoids Jacobian inversion error
"fermi_thresh": 1e-4 if dtype == torch.float32 else 1e-10,
},
)
calc = Calculator(numbers, par, opts=options, **dd)
@@ -413,6 +416,7 @@ def test_element_anion(dtype: torch.dtype, number: int) -> None:
**{
"f_atol": 1e-5, # avoid Jacobian inversion error
"x_atol": 1e-5, # avoid Jacobian inversion error
"fermi_thresh": 1e-4 if dtype == torch.float32 else 1e-10,
},
)
calc = Calculator(numbers, par, opts=options, **dd)