diff --git a/src/seemps/state/canonical_mps.py b/src/seemps/state/canonical_mps.py index afe17e6..2848b40 100644 --- a/src/seemps/state/canonical_mps.py +++ b/src/seemps/state/canonical_mps.py @@ -7,8 +7,8 @@ from seemps.state.core import ( DEFAULT_STRATEGY, Strategy, - _update_in_canonical_form_right, - _update_in_canonical_form_left, + _update_canonical_right, + _update_canonical_left, _update_canonical_2site_left, _update_canonical_2site_right, _canonicalize, @@ -244,11 +244,11 @@ def update_canonical( The truncation error of this update. """ if direction > 0: - self.center, err = _update_in_canonical_form_right( + self.center, err = _update_canonical_right( self._data, A, self.center, truncation ) else: - self.center, err = _update_in_canonical_form_left( + self.center, err = _update_canonical_left( self._data, A, self.center, truncation ) self.update_error(err) diff --git a/src/seemps/state/core.cc b/src/seemps/state/core.cc index f4767bd..f933f90 100644 --- a/src/seemps/state/core.cc +++ b/src/seemps/state/core.cc @@ -135,14 +135,14 @@ PYBIND11_MODULE(core, m) { m.def("schmidt_weights", &schmidt_weights); m.def( - "_update_in_canonical_form_right", &_update_in_canonical_form_right, - py::arg("state"), py::arg("tensor"), py::arg("site"), py::arg("strategy"), + "_update_canonical_right", &_update_canonical_right, py::arg("state"), + py::arg("tensor"), py::arg("site"), py::arg("strategy"), py::arg("overwrite") = false, R"doc(Insert a tensor in canonical form into the MPS Ψ at the given site. Update the neighboring sites in the process)doc"); m.def( - "_update_in_canonical_form_left", &_update_in_canonical_form_left, - py::arg("state"), py::arg("tensor"), py::arg("site"), py::arg("strategy"), + "_update_canonical_left", &_update_canonical_left, py::arg("state"), + py::arg("tensor"), py::arg("site"), py::arg("strategy"), py::arg("overwrite") = false, R"doc(Insert a tensor in canonical form into the MPS Ψ at the given site. Update the neighboring sites in the process)doc"); diff --git a/src/seemps/state/core.pyi b/src/seemps/state/core.pyi index b8684c9..cf671ca 100644 --- a/src/seemps/state/core.pyi +++ b/src/seemps/state/core.pyi @@ -77,10 +77,10 @@ def right_orth_2site( ) -> tuple[Tensor3, Tensor3, float]: ... def _destructive_svd(A: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]: ... def schmidt_weights(A: np.ndarray) -> np.ndarray: ... -def _update_in_canonical_form_left( +def _update_canonical_left( state: list[Tensor3], A: Tensor3, site: int, truncation: Strategy ) -> tuple[int, float]: ... -def _update_in_canonical_form_right( +def _update_canonical_right( state: list[Tensor3], A: Tensor3, site: int, truncation: Strategy ) -> tuple[int, float]: ... def _update_canonical_2site_left( diff --git a/src/seemps/state/mps.h b/src/seemps/state/mps.h index 96748e3..b9831df 100644 --- a/src/seemps/state/mps.h +++ b/src/seemps/state/mps.h @@ -17,14 +17,14 @@ Weight scprod(py::object A, py::object B); py::object schmidt_weights(py::object A); -std::tuple -_update_in_canonical_form_right(py::list state, py::object A, int site, - const Strategy &truncation, - bool overwrite = false); -std::tuple -_update_in_canonical_form_left(py::list state, py::object A, int site, - const Strategy &truncation, - bool overwrite = false); +std::tuple _update_canonical_right(py::list state, py::object A, + int site, + const Strategy &truncation, + bool overwrite = false); +std::tuple _update_canonical_left(py::list state, py::object A, + int site, + const Strategy &truncation, + bool overwrite = false); double _canonicalize(py::list state, int center, const Strategy &truncation); std::tuple left_orth_2site(py::object AA, const Strategy &strategy); diff --git a/src/seemps/state/schmidt.cc b/src/seemps/state/schmidt.cc index 7c3e086..3bf8a55 100644 --- a/src/seemps/state/schmidt.cc +++ b/src/seemps/state/schmidt.cc @@ -5,12 +5,13 @@ namespace seemps { -std::tuple -_update_in_canonical_form_right(py::list state, py::object A, int site, - const Strategy &strategy, bool overwrite) { +std::tuple _update_canonical_right(py::list state, py::object A, + int site, + const Strategy &strategy, + bool overwrite) { if (!is_array(A) || array_ndim(A) != 3) { throw std::invalid_argument( - "Invalid tensor passed to _update_in_canonical_form_right"); + "Invalid tensor passed to _update_canonical_right"); } py::object tensor = overwrite ? array_getcontiguous(A) : array_copy(A); auto a = array_dim(tensor, 0); @@ -28,13 +29,13 @@ _update_in_canonical_form_right(py::list state, py::object A, int site, return {site, err}; } -std::tuple _update_in_canonical_form_left(py::list state, - py::object A, int site, - const Strategy &strategy, - bool overwrite) { +std::tuple _update_canonical_left(py::list state, py::object A, + int site, + const Strategy &strategy, + bool overwrite) { if (!is_array(A) || array_ndim(A) != 3) { throw std::invalid_argument( - "Invalid tensor passed to _update_in_canonical_form_right"); + "Invalid tensor passed to _update_canonical_right"); } py::object tensor = overwrite ? array_getcontiguous(A) : array_copy(A); auto a = array_dim(tensor, 0); @@ -55,14 +56,12 @@ std::tuple _update_in_canonical_form_left(py::list state, double _canonicalize(py::list state, int center, const Strategy &strategy) { double err = 0.0; for (int i = 0; i < center;) { - auto [site, errk] = - _update_in_canonical_form_right(state, state[i], i, strategy); + auto [site, errk] = _update_canonical_right(state, state[i], i, strategy); err += errk; i = site; } for (int i = state.size() - 1; i > center;) { - auto [site, errk] = - _update_in_canonical_form_left(state, state[i], i, strategy); + auto [site, errk] = _update_canonical_left(state, state[i], i, strategy); err += errk; i = site; } diff --git a/tests/test_states/test_canonical.py b/tests/test_states/test_canonical.py index a67db8e..74a579d 100644 --- a/tests/test_states/test_canonical.py +++ b/tests/test_states/test_canonical.py @@ -7,8 +7,8 @@ random_uniform_mps, ) from seemps.state.canonical_mps import ( - _update_in_canonical_form_left, - _update_in_canonical_form_right, + _update_canonical_left, + _update_canonical_right, _canonicalize, ) from ..fixture_mps_states import MPSStatesFixture @@ -18,18 +18,18 @@ class TestCanonicalForm(MPSStatesFixture): def test_local_update_canonical(self): # - # We verify that _update_in_canonical_form() leaves a tensor that + # We verify that _update_canonical() leaves a tensor that # is an approximate isometry. # def ok(Ψ, normalization=False): strategy = DEFAULT_STRATEGY.replace(normalize=normalization) for i in range(Ψ.size - 1): ξ = Ψ.copy() - _update_in_canonical_form_right(ξ._data, ξ[i], i, strategy) + _update_canonical_right(ξ._data, ξ[i], i, strategy) self.assertTrue(approximateIsometry(ξ[i], +1)) for i in range(1, Ψ.size): ξ = Ψ.copy() - _update_in_canonical_form_left(ξ._data, ξ[i], i, strategy) + _update_canonical_left(ξ._data, ξ[i], i, strategy) self.assertTrue(approximateIsometry(ξ[i], -1)) run_over_random_uniform_mps(ok)