Skip to content

Commit

Permalink
Fix device in full SCF (#196)
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinfriede authored Jan 13, 2025
1 parent 92a7e77 commit 79f011a
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 6 deletions.
43 changes: 43 additions & 0 deletions examples/issues/194/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import torch

import dxtb
from dxtb.typing import DD

dd: DD = {"device": torch.device("cuda:0"), "dtype": torch.double}


numbers = torch.tensor(
[
[3, 1, 0],
[8, 1, 1],
],
device=dd["device"],
)
positions = torch.tensor(
[
[
[0.0, 0.0, 0.0],
[0.0, 0.0, 1.0],
[0.0, 0.0, 0.0],
],
[
[0.0, 0.0, 0.0],
[0.0, 0.0, 1.0],
[0.0, 0.0, 2.0],
],
],
**dd
)
positions.requires_grad_(True)
charge = torch.tensor([0, 0], **dd)


# no conformers -> batched mode 1
opts = {"verbosity": 6, "batch_mode": 1, "scf_mode": dxtb.labels.SCF_MODE_FULL}

dxtb.timer.reset()

calc = dxtb.Calculator(numbers, dxtb.GFN1_XTB, opts=opts, **dd)
result = calc.energy(positions, chrg=charge)

dxtb.timer.print(v=-999)
12 changes: 6 additions & 6 deletions src/dxtb/_src/scf/unrolling/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,18 +124,18 @@ def scf(
potential_data = self._data.potential.copy()

# shape: (nb, <number of moments>, norb)
q_converged = torch.full_like(guess, defaults.PADNZ)
q_converged = torch.full_like(guess, defaults.PADNZ, device=self.device)

overlap = self._data.ints.overlap
hcore = self._data.ints.hcore
dipole = self._data.ints.dipole
quad = self._data.ints.quadrupole

# indices for systems in batch, required for culling
idxs = torch.arange(guess.size(0))
idxs = torch.arange(guess.size(0), device=self.device)

# tracker for converged systems
converged = torch.full(idxs.shape, False)
converged = torch.full(idxs.shape, False, device=self.device)

# maximum number of orbitals in batch
norb = self._data.ihelp.nao
Expand Down Expand Up @@ -291,7 +291,7 @@ def scf(

# collect unconverged indices with convergence tracker; charges
# are already culled, and hence, require no further indexing
idxs = torch.arange(guess.size(0))
idxs = torch.arange(guess.size(0), device=self.device)
iconv = idxs[~converged]
q_converged[iconv, ..., :norb] = q

Expand All @@ -311,14 +311,14 @@ def scf(
msg + msg_converged, exceptions.SCFConvergenceWarning
)

if culled:
if culled is True:
# write converged variables back to `self._data` for final
# energy evaluation; if we continue with unconverged properties,
# we first need to write the unconverged values from the
# `_data` object back to the converged variable before saving it
# for the final energy evaluation
if not converged.all():
idxs = torch.arange(guess.size(0))
idxs = torch.arange(guess.size(0), device=self.device)
iconv = idxs[~converged]

cevals[iconv, :norb] = self._data.evals
Expand Down

0 comments on commit 79f011a

Please sign in to comment.