diff --git a/src/seemps/state/canonical_mps.py b/src/seemps/state/canonical_mps.py index 5464da64..cf7d2f4e 100644 --- a/src/seemps/state/canonical_mps.py +++ b/src/seemps/state/canonical_mps.py @@ -90,6 +90,7 @@ def __init__( center: Optional[int] = None, normalize: bool = False, strategy: Strategy = DEFAULT_STRATEGY, + is_canonical: bool = False, **kwdargs, ): super().__init__(data, **kwdargs) @@ -105,7 +106,10 @@ def __init__( self.center = actual_center = self._interpret_center( 0 if center is None else center ) - self.update_error(_canonicalize(self._data, actual_center, self.strategy)) + if not is_canonical: + self.update_error( + _canonicalize(self._data, actual_center, self.strategy) + ) if normalize or self.strategy.get_normalize_flag(): A = self[actual_center] self[actual_center] = A / np.linalg.norm(A) @@ -117,6 +121,7 @@ def from_vector( dimensions: Sequence[int], strategy: Strategy = DEFAULT_STRATEGY, normalize: bool = True, + center: int = 0, **kwdargs, ) -> CanonicalMPS: """Create an MPS in canonical form starting from a state vector. @@ -132,6 +137,8 @@ def from_vector( Default truncation strategy for algorithms working on this state. normalize : bool, default = True Whether the state is normalized to compensate truncation errors. + center : int, default = 0 + Center for the canonical form of this decomposition. Returns ------- @@ -142,10 +149,12 @@ def from_vector( -------- :py:meth:`~seemps.state.MPS.from_vector` """ + data, error = schmidt.vector2mps(ψ, dimensions, strategy, normalize, center) return CanonicalMPS( - schmidt.vector2mps(ψ, dimensions, strategy, normalize), - center=kwdargs.get("center", 0), - strategy=strategy, + data, + error=error, + center=center, + is_canonical=True, ) def norm_squared(self) -> float: diff --git a/src/seemps/state/mps.py b/src/seemps/state/mps.py index 1e1ebfd1..fb7aee21 100644 --- a/src/seemps/state/mps.py +++ b/src/seemps/state/mps.py @@ -94,6 +94,7 @@ def from_vector( dimensions: Sequence[int], strategy: Strategy = DEFAULT_STRATEGY, normalize: bool = True, + center: int = -1, **kwdargs, ) -> MPS: """Create a matrix-product state from a state vector. @@ -115,7 +116,8 @@ def from_vector( MPS A valid matrix-product state approximating this state vector. """ - return MPS(vector2mps(ψ, dimensions, strategy, normalize)) + data, error = vector2mps(ψ, dimensions, strategy, normalize, center) + return MPS(data, error) @classmethod def from_tensor( diff --git a/src/seemps/state/schmidt.py b/src/seemps/state/schmidt.py index d8aac655..bdfed5e9 100644 --- a/src/seemps/state/schmidt.py +++ b/src/seemps/state/schmidt.py @@ -93,7 +93,7 @@ def right_orth_2site(AA, strategy: Strategy): return (U[:, :D] * S).reshape(α, d1, D), V[:D, :].reshape(D, d2, β), err -def vector2mps( +def old_vector2mps( state: VectorLike, dimensions: Sequence[int], strategy: Strategy = DEFAULT_STRATEGY, @@ -129,3 +129,51 @@ def vector2mps( output[-1] = ψ.reshape(Da, dimensions[-1], 1) return output + + +def vector2mps( + state: VectorLike, + dimensions: Sequence[int], + strategy: Strategy = DEFAULT_STRATEGY, + normalize: bool = True, + center: int = -1, +) -> tuple[list[Tensor3], float]: + """Construct a list of tensors for an MPS that approximates the state ψ + represented as a complex vector in a Hilbert space. + + Parameters + ---------- + ψ -- wavefunction with \\prod_i dimensions[i] elements + dimensions -- list of dimensions of the Hilbert spaces that build ψ + tolerance -- truncation criterion for dropping Schmidt numbers + normalize -- boolean to determine if the MPS is normalized + """ + ψ: NDArray = np.asarray(state).copy().reshape(1, -1, 1) + L = len(dimensions) + if math.prod(dimensions) != ψ.size: + raise Exception("Wrong dimensions specified when converting a vector to MPS") + output = [None] * L + Da = 1 + if center < 0: + center = L + center + if center < 0 or center >= L: + raise Exception("Invalid value of center in vector2mps") + err = 0.0 + for i in range(center): + s = ψ.shape + output[i], ψ, new_err = left_orth_2site( + ψ.reshape(ψ.shape[0], dimensions[i], -1, ψ.shape[-1]), strategy + ) + err += np.sqrt(new_err) + for i in range(L - 1, center, -1): + s = ψ.shape + ψ, output[i], new_err = right_orth_2site( + ψ.reshape(ψ.shape[0], -1, dimensions[i], ψ.shape[-1]), strategy + ) + err += np.sqrt(new_err) + if normalize: + N = np.linalg.norm(ψ.reshape(-1)) + ψ /= N + err /= N + output[center] = ψ + return output, err * err diff --git a/tests/test_states/__init__.py b/tests/test_states/__init__.py index afecc981..18667116 100644 --- a/tests/test_states/__init__.py +++ b/tests/test_states/__init__.py @@ -2,6 +2,7 @@ test_canonical, test_mps, test_mpssum, + test_mps_from_vector, test_random_mps, test_sampling, test_sample_states, diff --git a/tests/test_states/test_mps_from_vector.py b/tests/test_states/test_mps_from_vector.py new file mode 100644 index 00000000..e9c3c746 --- /dev/null +++ b/tests/test_states/test_mps_from_vector.py @@ -0,0 +1,94 @@ +import numpy as np +from seemps.state import NO_TRUNCATION +from seemps.state.schmidt import vector2mps, old_vector2mps +from seemps.state.array import TensorArray +from .. import tools + + +class TestMPSFromVector(tools.TestCase): + def join_tensors(self, state): + w = np.ones((1, 1)) + for A in state: + w = np.einsum("ia,ajb->ijb", w, A) + w = w.reshape(-1, w.shape[-1]) + return w.reshape(-1) + + def test_mps_from_vector_on_one_site(self): + v = self.rng.normal(size=5) + state, err = vector2mps(v, [5], strategy=NO_TRUNCATION, normalize=False) + self.assertTrue(err >= 0) + self.assertAlmostEqual(err, 0) + self.assertEqual(len(state), 1) + self.assertEqual(state[0].shape, (1, 5, 1)) + + def test_mps_from_vector_on_different_sizes(self): + v1 = self.rng.normal(size=2) + v2 = self.rng.normal(size=3) + v3 = self.rng.normal(size=4) + w = v1[:, np.newaxis, np.newaxis] * v2[:, np.newaxis] * v3 + + state, err = vector2mps( + w.reshape(-1), [2, 3, 4], strategy=NO_TRUNCATION, normalize=False + ) + self.assertTrue(err >= 0) + self.assertAlmostEqual(err, 0) + self.assertEqual(len(state), 3) + self.assertEqual(state[0].shape, (1, 2, 2)) + self.assertEqual(state[1].shape, (2, 3, 4)) + self.assertEqual(state[2].shape, (4, 4, 1)) + + w = np.einsum("aib,bjc,ckd->ijk", state[0], state[1], state[2]) + self.assertSimilar(w.reshape(-1), w.reshape(-1)) + + def test_mps_from_vector_on_random_qubit_states(self): + for normalize in [False, True]: + for N in range(1, 18): + v = self.rng.normal(size=(2**N,)) + state, err = vector2mps( + v, [2] * N, strategy=NO_TRUNCATION, normalize=normalize + ) + + self.assertTrue(err >= 0) + self.assertAlmostEqual(err, 0) + + self.assertEqual(len(state), N) + for i in range(N): + self.assertEqual(state[i].shape[1], 2) + + w = self.join_tensors(state) + if normalize: + self.assertSimilar(w, v / np.linalg.norm(v)) + self.assertAlmostEqual(np.linalg.norm(w), 1) + else: + self.assertSimilar(w, v) + + def test_mps_from_vector_works_on_all_centers(self): + for N in range(1, 10): + v = self.rng.normal(size=(2**N,)) + for center in range(-N + 1, N): + state, err = vector2mps( + v, [2] * N, center=center, strategy=NO_TRUNCATION, normalize=False + ) + self.assertSimilar(self.join_tensors(state), v) + + def test_mps_from_vector_produces_isometries(self): + for N in range(2, 10): + v = self.rng.normal(size=(2**N,)) + for center in range(0, N): + state, err = vector2mps( + v, [2] * N, center=center, strategy=NO_TRUNCATION + ) + for i, A in enumerate(state): + if i < center: + self.assertApproximateIsometry(A, +1) + elif i > center: + self.assertApproximateIsometry(A, -1) + + def test_mps_from_vector_normalizes_central_tensor(self): + for N in range(1, 10): + v = self.rng.normal(size=(2**N,)) + for center in range(0, N): + state, err = vector2mps( + v, [2] * N, center=center, normalize=True, strategy=NO_TRUNCATION + ) + self.assertAlmostEqual(np.linalg.norm(state[center].reshape(-1)), 1)