Skip to content

Commit

Permalink
New version of vector2mps, more stable, creating canonical forms
Browse files Browse the repository at this point in the history
  • Loading branch information
juanjosegarciaripoll committed Mar 24, 2024
1 parent 401cb8a commit bd4ec2d
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 6 deletions.
17 changes: 13 additions & 4 deletions src/seemps/state/canonical_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def __init__(
center: Optional[int] = None,
normalize: bool = False,
strategy: Strategy = DEFAULT_STRATEGY,
is_canonical: bool = False,
**kwdargs,
):
super().__init__(data, **kwdargs)
Expand All @@ -105,7 +106,10 @@ def __init__(
self.center = actual_center = self._interpret_center(
0 if center is None else center
)
self.update_error(_canonicalize(self._data, actual_center, self.strategy))
if not is_canonical:
self.update_error(
_canonicalize(self._data, actual_center, self.strategy)
)
if normalize or self.strategy.get_normalize_flag():
A = self[actual_center]
self[actual_center] = A / np.linalg.norm(A)
Expand All @@ -117,6 +121,7 @@ def from_vector(
dimensions: Sequence[int],
strategy: Strategy = DEFAULT_STRATEGY,
normalize: bool = True,
center: int = 0,
**kwdargs,
) -> CanonicalMPS:
"""Create an MPS in canonical form starting from a state vector.
Expand All @@ -132,6 +137,8 @@ def from_vector(
Default truncation strategy for algorithms working on this state.
normalize : bool, default = True
Whether the state is normalized to compensate truncation errors.
center : int, default = 0
Center for the canonical form of this decomposition.
Returns
-------
Expand All @@ -142,10 +149,12 @@ def from_vector(
--------
:py:meth:`~seemps.state.MPS.from_vector`
"""
data, error = schmidt.vector2mps(ψ, dimensions, strategy, normalize, center)
return CanonicalMPS(
schmidt.vector2mps(ψ, dimensions, strategy, normalize),
center=kwdargs.get("center", 0),
strategy=strategy,
data,
error=error,
center=center,
is_canonical=True,
)

def norm_squared(self) -> float:
Expand Down
4 changes: 3 additions & 1 deletion src/seemps/state/mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def from_vector(
dimensions: Sequence[int],
strategy: Strategy = DEFAULT_STRATEGY,
normalize: bool = True,
center: int = -1,
**kwdargs,
) -> MPS:
"""Create a matrix-product state from a state vector.
Expand All @@ -115,7 +116,8 @@ def from_vector(
MPS
A valid matrix-product state approximating this state vector.
"""
return MPS(vector2mps(ψ, dimensions, strategy, normalize))
data, error = vector2mps(ψ, dimensions, strategy, normalize, center)
return MPS(data, error)

@classmethod
def from_tensor(
Expand Down
50 changes: 49 additions & 1 deletion src/seemps/state/schmidt.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def right_orth_2site(AA, strategy: Strategy):
return (U[:, :D] * S).reshape(α, d1, D), V[:D, :].reshape(D, d2, β), err


def vector2mps(
def old_vector2mps(
state: VectorLike,
dimensions: Sequence[int],
strategy: Strategy = DEFAULT_STRATEGY,
Expand Down Expand Up @@ -129,3 +129,51 @@ def vector2mps(
output[-1] = ψ.reshape(Da, dimensions[-1], 1)

return output


def vector2mps(
state: VectorLike,
dimensions: Sequence[int],
strategy: Strategy = DEFAULT_STRATEGY,
normalize: bool = True,
center: int = -1,
) -> tuple[list[Tensor3], float]:
"""Construct a list of tensors for an MPS that approximates the state ψ
represented as a complex vector in a Hilbert space.
Parameters
----------
ψ -- wavefunction with \\prod_i dimensions[i] elements
dimensions -- list of dimensions of the Hilbert spaces that build ψ
tolerance -- truncation criterion for dropping Schmidt numbers
normalize -- boolean to determine if the MPS is normalized
"""
ψ: NDArray = np.asarray(state).copy().reshape(1, -1, 1)
L = len(dimensions)
if math.prod(dimensions) != ψ.size:
raise Exception("Wrong dimensions specified when converting a vector to MPS")
output = [None] * L
Da = 1
if center < 0:
center = L + center
if center < 0 or center >= L:
raise Exception("Invalid value of center in vector2mps")
err = 0.0
for i in range(center):
s = ψ.shape
output[i], ψ, new_err = left_orth_2site(
ψ.reshape(ψ.shape[0], dimensions[i], -1, ψ.shape[-1]), strategy
)
err += np.sqrt(new_err)
for i in range(L - 1, center, -1):
s = ψ.shape
ψ, output[i], new_err = right_orth_2site(
ψ.reshape(ψ.shape[0], -1, dimensions[i], ψ.shape[-1]), strategy
)
err += np.sqrt(new_err)
if normalize:
N = np.linalg.norm(ψ.reshape(-1))
ψ /= N
err /= N
output[center] = ψ
return output, err * err
1 change: 1 addition & 0 deletions tests/test_states/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
test_canonical,
test_mps,
test_mpssum,
test_mps_from_vector,
test_random_mps,
test_sampling,
test_sample_states,
Expand Down
94 changes: 94 additions & 0 deletions tests/test_states/test_mps_from_vector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import numpy as np
from seemps.state import NO_TRUNCATION
from seemps.state.schmidt import vector2mps, old_vector2mps
from seemps.state.array import TensorArray
from .. import tools


class TestMPSFromVector(tools.TestCase):
def join_tensors(self, state):
w = np.ones((1, 1))
for A in state:
w = np.einsum("ia,ajb->ijb", w, A)
w = w.reshape(-1, w.shape[-1])
return w.reshape(-1)

def test_mps_from_vector_on_one_site(self):
v = self.rng.normal(size=5)
state, err = vector2mps(v, [5], strategy=NO_TRUNCATION, normalize=False)
self.assertTrue(err >= 0)
self.assertAlmostEqual(err, 0)
self.assertEqual(len(state), 1)
self.assertEqual(state[0].shape, (1, 5, 1))

def test_mps_from_vector_on_different_sizes(self):
v1 = self.rng.normal(size=2)
v2 = self.rng.normal(size=3)
v3 = self.rng.normal(size=4)
w = v1[:, np.newaxis, np.newaxis] * v2[:, np.newaxis] * v3

state, err = vector2mps(
w.reshape(-1), [2, 3, 4], strategy=NO_TRUNCATION, normalize=False
)
self.assertTrue(err >= 0)
self.assertAlmostEqual(err, 0)
self.assertEqual(len(state), 3)
self.assertEqual(state[0].shape, (1, 2, 2))
self.assertEqual(state[1].shape, (2, 3, 4))
self.assertEqual(state[2].shape, (4, 4, 1))

w = np.einsum("aib,bjc,ckd->ijk", state[0], state[1], state[2])
self.assertSimilar(w.reshape(-1), w.reshape(-1))

def test_mps_from_vector_on_random_qubit_states(self):
for normalize in [False, True]:
for N in range(1, 18):
v = self.rng.normal(size=(2**N,))
state, err = vector2mps(
v, [2] * N, strategy=NO_TRUNCATION, normalize=normalize
)

self.assertTrue(err >= 0)
self.assertAlmostEqual(err, 0)

self.assertEqual(len(state), N)
for i in range(N):
self.assertEqual(state[i].shape[1], 2)

w = self.join_tensors(state)
if normalize:
self.assertSimilar(w, v / np.linalg.norm(v))
self.assertAlmostEqual(np.linalg.norm(w), 1)
else:
self.assertSimilar(w, v)

def test_mps_from_vector_works_on_all_centers(self):
for N in range(1, 10):
v = self.rng.normal(size=(2**N,))
for center in range(-N + 1, N):
state, err = vector2mps(
v, [2] * N, center=center, strategy=NO_TRUNCATION, normalize=False
)
self.assertSimilar(self.join_tensors(state), v)

def test_mps_from_vector_produces_isometries(self):
for N in range(2, 10):
v = self.rng.normal(size=(2**N,))
for center in range(0, N):
state, err = vector2mps(
v, [2] * N, center=center, strategy=NO_TRUNCATION
)
for i, A in enumerate(state):
if i < center:
self.assertApproximateIsometry(A, +1)
elif i > center:
self.assertApproximateIsometry(A, -1)

def test_mps_from_vector_normalizes_central_tensor(self):
for N in range(1, 10):
v = self.rng.normal(size=(2**N,))
for center in range(0, N):
state, err = vector2mps(
v, [2] * N, center=center, normalize=True, strategy=NO_TRUNCATION
)
self.assertAlmostEqual(np.linalg.norm(state[center].reshape(-1)), 1)

0 comments on commit bd4ec2d

Please sign in to comment.