From 74e318baab7159e517fe4de0ee40cd435b770bc5 Mon Sep 17 00:00:00 2001 From: Juan Jose Garcia-Ripoll Date: Fri, 12 Apr 2024 09:04:32 +0200 Subject: [PATCH] Replace left/right_orth_2site with _update_canonical_2site_left/right --- src/seemps/state/canonical_mps.py | 10 ++--- src/seemps/state/core.cc | 2 + src/seemps/state/core.pyi | 6 +++ src/seemps/state/mps.h | 5 ++- src/seemps/state/schmidt.cc | 67 +++++++++++++++++++------------ src/seemps/state/schmidt.py | 2 +- 6 files changed, 58 insertions(+), 34 deletions(-) diff --git a/src/seemps/state/canonical_mps.py b/src/seemps/state/canonical_mps.py index 8c75744..afe17e6 100644 --- a/src/seemps/state/canonical_mps.py +++ b/src/seemps/state/canonical_mps.py @@ -9,6 +9,8 @@ Strategy, _update_in_canonical_form_right, _update_in_canonical_form_left, + _update_canonical_2site_left, + _update_canonical_2site_right, _canonicalize, ) from .mps import MPS @@ -270,9 +272,7 @@ def update_2site_right(self, AA: Tensor4, site: int, strategy: Strategy) -> None Truncation strategy, including relative tolerances and maximum bond dimensions """ - self._data[site], self._data[site + 1], err = schmidt.left_orth_2site( - AA, strategy - ) + err = _update_canonical_2site_left(self._data, AA, site, strategy) self.center = site + 1 self.update_error(err) @@ -293,9 +293,7 @@ def update_2site_left(self, AA: Tensor4, site: int, strategy: Strategy) -> None: Truncation strategy, including relative tolerances and maximum bond dimensions """ - self._data[site], self._data[site + 1], err = schmidt.right_orth_2site( - AA, strategy - ) + err = _update_canonical_2site_right(self._data, AA, site, strategy) self.center = site self.update_error(err) diff --git a/src/seemps/state/core.cc b/src/seemps/state/core.cc index 41872a2..f4767bd 100644 --- a/src/seemps/state/core.cc +++ b/src/seemps/state/core.cc @@ -149,6 +149,8 @@ PYBIND11_MODULE(core, m) { m.def("_canonicalize", &_canonicalize, R"doc(Update a list of `Tensor3` objects to be in canonical form with respect to `center`.)doc"); + m.def("_update_canonical_2site_left", &_update_canonical_2site_left); + m.def("_update_canonical_2site_right", &_update_canonical_2site_right); m.def("left_orth_2site", &left_orth_2site); m.def("right_orth_2site", &right_orth_2site); } diff --git a/src/seemps/state/core.pyi b/src/seemps/state/core.pyi index 22a5f5d..b8684c9 100644 --- a/src/seemps/state/core.pyi +++ b/src/seemps/state/core.pyi @@ -83,6 +83,12 @@ def _update_in_canonical_form_left( def _update_in_canonical_form_right( state: list[Tensor3], A: Tensor3, site: int, truncation: Strategy ) -> tuple[int, float]: ... +def _update_canonical_2site_left( + state: list[Tensor3], A: Tensor4, site: int, truncation: Strategy +) -> float: ... +def _update_canonical_2site_right( + state: list[Tensor3], A: Tensor4, site: int, truncation: Strategy +) -> float: ... def _canonicalize(state: list[Tensor3], center: int, truncation: Strategy) -> float: ... from .mps import MPS # noqa: E402 diff --git a/src/seemps/state/mps.h b/src/seemps/state/mps.h index a07deae..96748e3 100644 --- a/src/seemps/state/mps.h +++ b/src/seemps/state/mps.h @@ -30,5 +30,8 @@ std::tuple left_orth_2site(py::object AA, const Strategy &strategy); std::tuple right_orth_2site(py::object AA, const Strategy &strategy); - +double _update_canonical_2site_left(py::list state, py::object A, int site, + const Strategy &strategy); +double _update_canonical_2site_right(py::list state, py::object A, int site, + const Strategy &strategy); } // namespace seemps diff --git a/src/seemps/state/schmidt.cc b/src/seemps/state/schmidt.cc index 93cfd37..7c3e086 100644 --- a/src/seemps/state/schmidt.cc +++ b/src/seemps/state/schmidt.cc @@ -104,36 +104,51 @@ right_orth_2site(py::object AA, const Strategy &strategy) { as_3tensor(matrix_resize(V, D, -1), D, d2, b), err}; } -/* - - +double _update_canonical_2site_left(py::list state, py::object A, int site, + const Strategy &strategy) { + if (!is_array(A) || array_ndim(A) != 4) { + throw std::invalid_argument( + "Invalid tensor passed to _update_canonical_2site_left"); + } + py::object tensor = array_getcontiguous(A); + auto a = array_dim(tensor, 0); + auto d1 = array_dim(tensor, 1); + auto d2 = array_dim(tensor, 2); + auto b = array_dim(tensor, 3); + // Split tensor + auto [U, s, V] = + destructive_svd(array_reshape(tensor, array_dims_t{a * d1, d2 * b})); + auto err = destructively_truncate_vector(s, strategy); + auto D = array_size(s); -def left_orth_2site(AA, strategy: Strategy): - """Split a tensor AA[a,b,c,d] into B[a,b,r] and C[r,c,d] such - that 'B' is a left-isometry, truncating the size 'r' according - to the given 'strategy'. Tensor 'AA' may be overwritten.""" - α, d1, d2, β = AA.shape - U, S, V = _destructive_svd(AA.reshape(α * d1, β * d2)) - err = destructively_truncate_vector(S, strategy) - D = S.size - return ( - U[:, :D].reshape(α, d1, D), - (S.reshape(D, 1) * V[:D, :]).reshape(D, d2, β), - err, - ) + state[site] = as_3tensor(matrix_resize(U, -1, D), a, d1, D); + state[site + 1] = + as_3tensor(as_matrix(s, D, 1) * matrix_resize(V, D, -1), D, d2, b); + return err; +} +double _update_canonical_2site_right(py::list state, py::object A, int site, + const Strategy &strategy) { + if (!is_array(A) || array_ndim(A) != 4) { + throw std::invalid_argument( + "Invalid tensor passed to _update_canonical_2site_left"); + } + py::object tensor = array_getcontiguous(A); + auto a = array_dim(tensor, 0); + auto d1 = array_dim(tensor, 1); + auto d2 = array_dim(tensor, 2); + auto b = array_dim(tensor, 3); -def right_orth_2site(AA, strategy: Strategy): - """Split a tensor AA[a,b,c,d] into B[a,b,r] and C[r,c,d] such - that 'C' is a right-isometry, truncating the size 'r' according - to the given 'strategy'. Tensor 'AA' may be overwritten.""" - α, d1, d2, β = AA.shape - U, S, V = _destructive_svd(AA.reshape(α * d1, β * d2)) - err = destructively_truncate_vector(S, strategy) - D = S.size - return (U[:, :D] * S).reshape(α, d1, D), V[:D, :].reshape(D, d2, β), err + // Split tensor + auto [U, s, V] = + destructive_svd(array_reshape(tensor, array_dims_t{a * d1, d2 * b})); + auto err = destructively_truncate_vector(s, strategy); + auto D = array_size(s); - */ + state[site] = as_3tensor(matrix_resize(U, -1, D) * s, a, d1, D); + state[site + 1] = as_3tensor(matrix_resize(V, D, -1), D, d2, b); + return err; +} } // namespace seemps diff --git a/src/seemps/state/schmidt.py b/src/seemps/state/schmidt.py index f64630d..7de77c2 100644 --- a/src/seemps/state/schmidt.py +++ b/src/seemps/state/schmidt.py @@ -13,7 +13,7 @@ right_orth_2site, ) -__all__ = ["schmidt_weights", "left_orth_2site", "right_orth_2site", "vector2mps"] +__all__ = ["schmidt_weights", "vector2mps"] def vector2mps(