Skip to content

Commit

Permalink
Merge pull request #56 from PaulaGarciaMolina/main
Browse files Browse the repository at this point in the history
Update algorithms tests
  • Loading branch information
PaulaGarciaMolina authored Jan 11, 2024
2 parents ba94b4f + 1b79eb6 commit 12cfa1e
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 20 deletions.
18 changes: 4 additions & 14 deletions tests/test_arnoldi.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import numpy as np
from seemps.optimization.arnoldi import arnoldi_eigh
from seemps.hamiltonians import HeisenbergHamiltonian

from seemps import MPO, product_state
from seemps.hamiltonians import HeisenbergHamiltonian
from seemps.optimization.arnoldi import arnoldi_eigh

from .tools import *


Expand All @@ -26,15 +28,3 @@ def test_arnoldi_eigh_with_local_field(self):
result = arnoldi_eigh(H, guess, maxiter=20, tol=1e-15)
self.assertAlmostEqual(result.energy, H.expectation(exact))
self.assertSimilarStates(result.state, exact, atol=1e-7)

def test_arnoldi_eigsh_acknowledges_tolerance(self):
"""Check that algorithm stops if energy change is below tolerance."""
N = 4
H = HeisenbergHamiltonian(N).to_mpo()
tol = 1e-5
guess = CanonicalMPS(
random_uniform_mps(2, N, rng=self.rng), center=0, normalize=True
)
result = arnoldi_eigh(H, guess, tol=tol, maxiter=10)
self.assertTrue(result.converged)
self.assertTrue(abs(result.trajectory[-1] - result.trajectory[-2]) < tol)
8 changes: 5 additions & 3 deletions tests/test_gradient_descent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import numpy as np
from seemps.optimization.descent import gradient_descent
from seemps.hamiltonians import HeisenbergHamiltonian

from seemps import MPO, product_state
from seemps.hamiltonians import HeisenbergHamiltonian
from seemps.optimization.descent import gradient_descent

from .tools import *


Expand Down Expand Up @@ -59,4 +61,4 @@ def test_gradient_descent_with_callback(self):
H, guess, maxiter=maxiter, tol=1e-15, callback=callback_func
)
self.assertSimilar(norms, np.ones(len(norms)))
self.assertEqual(maxiter, len(norms))
self.assertEqual(maxiter + 1, len(norms))
6 changes: 3 additions & 3 deletions tests/test_itime.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_euler_with_callback(self):
callback_func, norms = callback()
result = euler(H, guess, maxiter=maxiter, callback=callback_func)
self.assertSimilar(norms, np.ones(len(norms)))
self.assertEqual(maxiter, len(norms))
self.assertEqual(maxiter + 1, len(norms))


class TestImprovedEuler(TestCase):
Expand Down Expand Up @@ -82,7 +82,7 @@ def test_improved_euler_with_callback(self):
callback_func, norms = callback()
result = improved_euler(H, guess, maxiter=maxiter, callback=callback_func)
self.assertSimilar(norms, np.ones(len(norms)))
self.assertEqual(maxiter, len(norms))
self.assertEqual(maxiter + 1, len(norms))


class TestRungeKutta(TestCase):
Expand Down Expand Up @@ -115,7 +115,7 @@ def test_runge_kutta_with_callback(self):
callback_func, norms = callback()
result = runge_kutta(H, guess, maxiter=maxiter, callback=callback_func)
self.assertSimilar(norms, np.ones(len(norms)))
self.assertEqual(maxiter, len(norms))
self.assertEqual(maxiter + 1, len(norms))


class TestRungeKuttaFehlberg(TestCase):
Expand Down

0 comments on commit 12cfa1e

Please sign in to comment.