Skip to content

Commit

Permalink
Fix class hierarchy in test_cross.py
Browse files Browse the repository at this point in the history
  • Loading branch information
juanjosegarciaripoll committed Nov 26, 2024
1 parent 2a5ad6c commit 6ae5c06
Showing 1 changed file with 40 additions and 139 deletions.
179 changes: 40 additions & 139 deletions tests/test_analysis/test_cross.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import functools
from abc import abstractmethod
import unittest

import seemps
from seemps.state import MPS, scprod
Expand Down Expand Up @@ -43,7 +44,7 @@ def callback(mps: MPS, **kwargs) -> float:


def gaussian_setup_mps(dims, n=5, a=-1, b=1):
func = lambda tensor: np.exp(-np.sum(tensor, axis=0) ** 2)
func = lambda tensor: np.exp(-(np.sum(tensor, axis=0) ** 2))
intervals = [RegularInterval(a, b, 2**n) for _ in range(dims)]
mesh = Mesh(intervals) # type: ignore
mesh_tensor = mesh.to_tensor()
Expand All @@ -66,46 +67,44 @@ def gaussian_setup_1d_mpo(is_diagonal, n=5, a=-1, b=1):


class CrossTests(TestCase):
def setUp(self, method):
if method == "maxvol":
self.cross_method = cross_maxvol
elif method == "dmrg":
self.cross_method = cross_dmrg
elif method == "greedy_full":
strat = CrossStrategyGreedy(partial=False)
self.cross_method = functools.partial(cross_greedy, cross_strategy=strat)
elif method == "greedy_partial":
strat = CrossStrategyGreedy(partial=True)
self.cross_method = functools.partial(cross_greedy, cross_strategy=strat)

def _test_load_1d_mps(self, n=5):
@classmethod
def setUpClass(cls):
if cls is CrossTests:
raise unittest.SkipTest(f"Skip {cls} tests, it's a base class")
super().setUpClass()

@abstractmethod
def cross_method(self, function, *args, **kwdargs):
raise Exception("cross_method not implemented in " + str(self.cls))

def test_load_1d_mps(self, n=5):
func, mesh, _, y = gaussian_setup_mps(1, n=n)
black_box = BlackBoxLoadMPS(func, mesh)
cross_results = self.cross_method(black_box)
self.assertSimilar(y, cross_results.mps.to_vector())

def _test_load_2d_mps(self, n=5):
def test_load_2d_mps(self, n=5):
func, mesh, _, y = gaussian_setup_mps(2, n=n)
black_box = BlackBoxLoadMPS(func, mesh)
cross_results = self.cross_method(black_box)
self.assertSimilar(y, cross_results.mps.to_vector())

def _test_load_2d_mps_with_order_B(self, n=5):
def test_load_2d_mps_with_order_B(self, n=5):
func, mesh, _, y = gaussian_setup_mps(2, n=n)
black_box = BlackBoxLoadMPS(func, mesh, mps_order="B")
cross_results = self.cross_method(black_box)
qubits = [int(np.log2(s)) for s in mesh.dimensions]
tensor = reorder_tensor(cross_results.mps.to_vector(), qubits)
self.assertSimilar(y, tensor)

def _test_load_2d_tt(self, n=5):
def test_load_2d_tt(self, n=5):
func, mesh, _, y = gaussian_setup_mps(2, n=n)
black_box = BlackBoxLoadTT(func, mesh)
cross_results = self.cross_method(black_box)
vector = cross_results.mps.to_vector()
self.assertSimilar(y, vector)

def _test_2d_integration_callback(self, n=8):
def test_2d_integration_callback(self, n=8):
a, b = -1, 1
func = lambda tensor: np.exp(tensor[0] + tensor[1]) # f(x,y) = e^(x+y)
interval = RegularInterval(a, b, 2**n, endpoint_right=True)
Expand All @@ -119,29 +118,29 @@ def _test_2d_integration_callback(self, n=8):
integral = cross_results.callback_output[-1] # type: ignore
self.assertAlmostEqual(integral, exact_integral)

def _test_load_1d_mpo_diagonal(self, n=5):
def test_load_1d_mpo_diagonal(self, n=5):
func, x, mesh, mps_I = gaussian_setup_1d_mpo(is_diagonal=True, n=n)
black_box = BlackBoxLoadMPO(func, mesh, is_diagonal=True)
cross_results = self.cross_method(black_box)
mps_diagonal = mps_as_mpo(cross_results.mps).apply(mps_I)
self.assertSimilar(func(x, x), mps_diagonal.to_vector())

def _test_load_1d_mpo_nondiagonal(self, n=5):
def test_load_1d_mpo_nondiagonal(self, n=5):
func, x, mesh, _ = gaussian_setup_1d_mpo(is_diagonal=False, n=n)
black_box = BlackBoxLoadMPO(func, mesh)
cross_results = self.cross_method(black_box)
y_mps = mps_as_mpo(cross_results.mps).to_matrix()
xx, yy = np.meshgrid(x, x)
self.assertSimilar(func(xx, yy), y_mps)

def _test_compose_1d_mps_list(self, n=5):
def test_compose_1d_mps_list(self, n=5):
_, _, mps_0, y_0 = gaussian_setup_mps(1, n=n)
func = lambda v: v[0] + np.sin(v[1]) + np.cos(v[2])
black_box = BlackBoxComposeMPS(func, [mps_0, mps_0, mps_0])
cross_results = self.cross_method(black_box)
self.assertSimilar(func([y_0, y_0, y_0]), cross_results.mps.to_vector())

def _test_compose_2d_mps_list(self, n=5):
def test_compose_2d_mps_list(self, n=5):
_, _, mps_0, y_0 = gaussian_setup_mps(2, n=n)
func = lambda v: v[0] + np.sin(v[1]) + np.cos(v[2])
black_box = BlackBoxComposeMPS(func, [mps_0, mps_0, mps_0])
Expand All @@ -150,131 +149,33 @@ def _test_compose_2d_mps_list(self, n=5):


class TestCrossMaxvol(CrossTests):
def setUp(self):
super().setUp("maxvol")

def test_load_1d_mps(self):
super()._test_load_1d_mps()

def test_load_2d_mps(self):
super()._test_load_2d_mps()

def test_load_2d_mps_with_order_B(self):
super()._test_load_2d_mps_with_order_B()

def test_load_2d_tt(self):
super()._test_load_2d_tt()

def test_2d_integration_callback(self):
super()._test_2d_integration_callback()

def test_load_1d_mpo_diagonal(self):
super()._test_load_1d_mpo_diagonal()

def test_load_1d_mpo_nondiagonal(self):
super()._test_load_1d_mpo_nondiagonal()

def test_compose_1d_mps_list(self):
super()._test_compose_1d_mps_list()

def test_compose_2d_mps_list(self):
super()._test_compose_2d_mps_list()
def cross_method(self, function, *args, **kwdargs):
return cross_maxvol(function, *args, **kwdargs)


class TestCrossDMRG(CrossTests):
def setUp(self):
super().setUp("dmrg")

def test_load_1d_mps(self):
super()._test_load_1d_mps()

def test_load_2d_mps(self):
super()._test_load_2d_mps()

def test_load_2d_mps_with_order_B(self):
super()._test_load_2d_mps_with_order_B()

def test_load_2d_tt(self):
super()._test_load_2d_tt()

def test_2d_integration_callback(self):
super()._test_2d_integration_callback()

def test_load_1d_mpo_diagonal(self):
super()._test_load_1d_mpo_diagonal()

def test_load_1d_mpo_nondiagonal(self):
super()._test_load_1d_mpo_nondiagonal()

def test_compose_1d_mps_list(self):
super()._test_compose_1d_mps_list()

def test_compose_2d_mps_list(self):
super()._test_compose_2d_mps_list()
def cross_method(self, function, *args, **kwdargs):
return cross_dmrg(function, *args, **kwdargs)


class TestCrossGreedyFull(CrossTests):
def setUp(self):
super().setUp("greedy_full")

def test_load_1d_mps(self):
super()._test_load_1d_mps()

def test_load_2d_mps(self):
super()._test_load_2d_mps()

def test_load_2d_mps_with_order_B(self):
super()._test_load_2d_mps_with_order_B()

def test_load_2d_tt(self):
super()._test_load_2d_tt()

def test_2d_integration_callback(self):
super()._test_2d_integration_callback()

def test_load_1d_mpo_diagonal(self):
super()._test_load_1d_mpo_diagonal()

def test_load_1d_mpo_nondiagonal(self):
super()._test_load_1d_mpo_nondiagonal()

def test_compose_1d_mps_list(self):
super()._test_compose_1d_mps_list()

def test_compose_2d_mps_list(self):
super()._test_compose_2d_mps_list()
def cross_method(self, function, *args, **kwdargs):
return cross_greedy(
function,
*args,
cross_strategy=CrossStrategyGreedy(partial=False),
**kwdargs,
)


class TestCrossGreedyPartial(CrossTests):
def setUp(self):
super().setUp("greedy_partial")

def test_load_1d_mps(self):
super()._test_load_1d_mps()

def test_load_2d_mps(self):
super()._test_load_2d_mps()

def test_load_2d_mps_with_order_B(self):
super()._test_load_2d_mps_with_order_B()

def test_load_2d_tt(self):
super()._test_load_2d_tt()

def test_2d_integration_callback(self):
super()._test_2d_integration_callback()

def test_load_1d_mpo_diagonal(self):
super()._test_load_1d_mpo_diagonal()

def test_load_1d_mpo_nondiagonal(self):
super()._test_load_1d_mpo_nondiagonal()

def test_compose_1d_mps_list(self):
super()._test_compose_1d_mps_list()

def test_compose_2d_mps_list(self):
super()._test_compose_2d_mps_list()
def cross_method(self, function, *args, **kwdargs):
return cross_greedy(
function,
*args,
cross_strategy=CrossStrategyGreedy(partial=True),
**kwdargs,
)


class TestSkeleton(TestCase):
Expand Down

0 comments on commit 6ae5c06

Please sign in to comment.