Skip to content

Commit

Permalink
Fixed comparison of states
Browse files Browse the repository at this point in the history
  • Loading branch information
juanjosegarciaripoll committed Dec 20, 2023
1 parent d6c16f1 commit 3ae5059
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions tests/test_gradient_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_gradient_descent_with_local_field(self):
exact = product_state([0, 1], N)
result = gradient_descent(H, guess, tol=1e-15)
self.assertAlmostEqual(result.energy, H.expectation(exact))
self.assertSimilar(result.state, exact, atol=1e-7)
self.assertSimilarStates(result.state, exact, atol=1e-7)

def test_gradient_descent_acknowledges_tolerance(self):
"""Check that algorithm stops if energy change is below tolerance."""
Expand All @@ -42,9 +42,11 @@ def test_gradient_descent_acknowledges_tolerance(self):

def callback(self):
norms = []
def callback_func(state:MPS):

def callback_func(state: MPS):
norms.append(np.sqrt(state.norm_squared()))
return None

return callback_func, norms

def test_gradient_descent_with_callback(self):
Expand All @@ -53,6 +55,8 @@ def test_gradient_descent_with_callback(self):
H = self.make_local_Sz_mpo(N)
guess = product_state(np.asarray([1, 1]) / np.sqrt(2.0), N)
callback_func, norms = self.callback()
result = gradient_descent(H, guess, maxiter=maxiter, tol=1e-15, callback=callback_func)
result = gradient_descent(
H, guess, maxiter=maxiter, tol=1e-15, callback=callback_func
)
self.assertSimilar(norms, np.ones(len(norms)))
self.assertEqual(maxiter, len(norms))

0 comments on commit 3ae5059

Please sign in to comment.