diff --git a/tests/test_analysis/test_cross.py b/tests/test_analysis/test_cross.py index 187ebe8..dd4c315 100644 --- a/tests/test_analysis/test_cross.py +++ b/tests/test_analysis/test_cross.py @@ -1,5 +1,6 @@ import numpy as np -import functools +from abc import abstractmethod +import unittest import seemps from seemps.state import MPS, scprod @@ -43,7 +44,7 @@ def callback(mps: MPS, **kwargs) -> float: def gaussian_setup_mps(dims, n=5, a=-1, b=1): - func = lambda tensor: np.exp(-np.sum(tensor, axis=0) ** 2) + func = lambda tensor: np.exp(-(np.sum(tensor, axis=0) ** 2)) intervals = [RegularInterval(a, b, 2**n) for _ in range(dims)] mesh = Mesh(intervals) # type: ignore mesh_tensor = mesh.to_tensor() @@ -66,31 +67,29 @@ def gaussian_setup_1d_mpo(is_diagonal, n=5, a=-1, b=1): class CrossTests(TestCase): - def setUp(self, method): - if method == "maxvol": - self.cross_method = cross_maxvol - elif method == "dmrg": - self.cross_method = cross_dmrg - elif method == "greedy_full": - strat = CrossStrategyGreedy(partial=False) - self.cross_method = functools.partial(cross_greedy, cross_strategy=strat) - elif method == "greedy_partial": - strat = CrossStrategyGreedy(partial=True) - self.cross_method = functools.partial(cross_greedy, cross_strategy=strat) - - def _test_load_1d_mps(self, n=5): + @classmethod + def setUpClass(cls): + if cls is CrossTests: + raise unittest.SkipTest(f"Skip {cls} tests, it's a base class") + super().setUpClass() + + @abstractmethod + def cross_method(self, function, *args, **kwdargs): + raise Exception("cross_method not implemented in " + str(self.cls)) + + def test_load_1d_mps(self, n=5): func, mesh, _, y = gaussian_setup_mps(1, n=n) black_box = BlackBoxLoadMPS(func, mesh) cross_results = self.cross_method(black_box) self.assertSimilar(y, cross_results.mps.to_vector()) - def _test_load_2d_mps(self, n=5): + def test_load_2d_mps(self, n=5): func, mesh, _, y = gaussian_setup_mps(2, n=n) black_box = BlackBoxLoadMPS(func, mesh) cross_results = self.cross_method(black_box) self.assertSimilar(y, cross_results.mps.to_vector()) - def _test_load_2d_mps_with_order_B(self, n=5): + def test_load_2d_mps_with_order_B(self, n=5): func, mesh, _, y = gaussian_setup_mps(2, n=n) black_box = BlackBoxLoadMPS(func, mesh, mps_order="B") cross_results = self.cross_method(black_box) @@ -98,14 +97,14 @@ def _test_load_2d_mps_with_order_B(self, n=5): tensor = reorder_tensor(cross_results.mps.to_vector(), qubits) self.assertSimilar(y, tensor) - def _test_load_2d_tt(self, n=5): + def test_load_2d_tt(self, n=5): func, mesh, _, y = gaussian_setup_mps(2, n=n) black_box = BlackBoxLoadTT(func, mesh) cross_results = self.cross_method(black_box) vector = cross_results.mps.to_vector() self.assertSimilar(y, vector) - def _test_2d_integration_callback(self, n=8): + def test_2d_integration_callback(self, n=8): a, b = -1, 1 func = lambda tensor: np.exp(tensor[0] + tensor[1]) # f(x,y) = e^(x+y) interval = RegularInterval(a, b, 2**n, endpoint_right=True) @@ -119,14 +118,14 @@ def _test_2d_integration_callback(self, n=8): integral = cross_results.callback_output[-1] # type: ignore self.assertAlmostEqual(integral, exact_integral) - def _test_load_1d_mpo_diagonal(self, n=5): + def test_load_1d_mpo_diagonal(self, n=5): func, x, mesh, mps_I = gaussian_setup_1d_mpo(is_diagonal=True, n=n) black_box = BlackBoxLoadMPO(func, mesh, is_diagonal=True) cross_results = self.cross_method(black_box) mps_diagonal = mps_as_mpo(cross_results.mps).apply(mps_I) self.assertSimilar(func(x, x), mps_diagonal.to_vector()) - def _test_load_1d_mpo_nondiagonal(self, n=5): + def test_load_1d_mpo_nondiagonal(self, n=5): func, x, mesh, _ = gaussian_setup_1d_mpo(is_diagonal=False, n=n) black_box = BlackBoxLoadMPO(func, mesh) cross_results = self.cross_method(black_box) @@ -134,14 +133,14 @@ def _test_load_1d_mpo_nondiagonal(self, n=5): xx, yy = np.meshgrid(x, x) self.assertSimilar(func(xx, yy), y_mps) - def _test_compose_1d_mps_list(self, n=5): + def test_compose_1d_mps_list(self, n=5): _, _, mps_0, y_0 = gaussian_setup_mps(1, n=n) func = lambda v: v[0] + np.sin(v[1]) + np.cos(v[2]) black_box = BlackBoxComposeMPS(func, [mps_0, mps_0, mps_0]) cross_results = self.cross_method(black_box) self.assertSimilar(func([y_0, y_0, y_0]), cross_results.mps.to_vector()) - def _test_compose_2d_mps_list(self, n=5): + def test_compose_2d_mps_list(self, n=5): _, _, mps_0, y_0 = gaussian_setup_mps(2, n=n) func = lambda v: v[0] + np.sin(v[1]) + np.cos(v[2]) black_box = BlackBoxComposeMPS(func, [mps_0, mps_0, mps_0]) @@ -150,131 +149,33 @@ def _test_compose_2d_mps_list(self, n=5): class TestCrossMaxvol(CrossTests): - def setUp(self): - super().setUp("maxvol") - - def test_load_1d_mps(self): - super()._test_load_1d_mps() - - def test_load_2d_mps(self): - super()._test_load_2d_mps() - - def test_load_2d_mps_with_order_B(self): - super()._test_load_2d_mps_with_order_B() - - def test_load_2d_tt(self): - super()._test_load_2d_tt() - - def test_2d_integration_callback(self): - super()._test_2d_integration_callback() - - def test_load_1d_mpo_diagonal(self): - super()._test_load_1d_mpo_diagonal() - - def test_load_1d_mpo_nondiagonal(self): - super()._test_load_1d_mpo_nondiagonal() - - def test_compose_1d_mps_list(self): - super()._test_compose_1d_mps_list() - - def test_compose_2d_mps_list(self): - super()._test_compose_2d_mps_list() + def cross_method(self, function, *args, **kwdargs): + return cross_maxvol(function, *args, **kwdargs) class TestCrossDMRG(CrossTests): - def setUp(self): - super().setUp("dmrg") - - def test_load_1d_mps(self): - super()._test_load_1d_mps() - - def test_load_2d_mps(self): - super()._test_load_2d_mps() - - def test_load_2d_mps_with_order_B(self): - super()._test_load_2d_mps_with_order_B() - - def test_load_2d_tt(self): - super()._test_load_2d_tt() - - def test_2d_integration_callback(self): - super()._test_2d_integration_callback() - - def test_load_1d_mpo_diagonal(self): - super()._test_load_1d_mpo_diagonal() - - def test_load_1d_mpo_nondiagonal(self): - super()._test_load_1d_mpo_nondiagonal() - - def test_compose_1d_mps_list(self): - super()._test_compose_1d_mps_list() - - def test_compose_2d_mps_list(self): - super()._test_compose_2d_mps_list() + def cross_method(self, function, *args, **kwdargs): + return cross_dmrg(function, *args, **kwdargs) class TestCrossGreedyFull(CrossTests): - def setUp(self): - super().setUp("greedy_full") - - def test_load_1d_mps(self): - super()._test_load_1d_mps() - - def test_load_2d_mps(self): - super()._test_load_2d_mps() - - def test_load_2d_mps_with_order_B(self): - super()._test_load_2d_mps_with_order_B() - - def test_load_2d_tt(self): - super()._test_load_2d_tt() - - def test_2d_integration_callback(self): - super()._test_2d_integration_callback() - - def test_load_1d_mpo_diagonal(self): - super()._test_load_1d_mpo_diagonal() - - def test_load_1d_mpo_nondiagonal(self): - super()._test_load_1d_mpo_nondiagonal() - - def test_compose_1d_mps_list(self): - super()._test_compose_1d_mps_list() - - def test_compose_2d_mps_list(self): - super()._test_compose_2d_mps_list() + def cross_method(self, function, *args, **kwdargs): + return cross_greedy( + function, + *args, + cross_strategy=CrossStrategyGreedy(partial=False), + **kwdargs, + ) class TestCrossGreedyPartial(CrossTests): - def setUp(self): - super().setUp("greedy_partial") - - def test_load_1d_mps(self): - super()._test_load_1d_mps() - - def test_load_2d_mps(self): - super()._test_load_2d_mps() - - def test_load_2d_mps_with_order_B(self): - super()._test_load_2d_mps_with_order_B() - - def test_load_2d_tt(self): - super()._test_load_2d_tt() - - def test_2d_integration_callback(self): - super()._test_2d_integration_callback() - - def test_load_1d_mpo_diagonal(self): - super()._test_load_1d_mpo_diagonal() - - def test_load_1d_mpo_nondiagonal(self): - super()._test_load_1d_mpo_nondiagonal() - - def test_compose_1d_mps_list(self): - super()._test_compose_1d_mps_list() - - def test_compose_2d_mps_list(self): - super()._test_compose_2d_mps_list() + def cross_method(self, function, *args, **kwdargs): + return cross_greedy( + function, + *args, + cross_strategy=CrossStrategyGreedy(partial=True), + **kwdargs, + ) class TestSkeleton(TestCase):