Skip to content

Commit

Permalink
Rename _update_in_canonical_form* for consistency
Browse files Browse the repository at this point in the history
  • Loading branch information
juanjosegarciaripoll committed Apr 12, 2024
1 parent 74e318b commit fb04d4b
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 36 deletions.
8 changes: 4 additions & 4 deletions src/seemps/state/canonical_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from seemps.state.core import (
DEFAULT_STRATEGY,
Strategy,
_update_in_canonical_form_right,
_update_in_canonical_form_left,
_update_canonical_right,
_update_canonical_left,
_update_canonical_2site_left,
_update_canonical_2site_right,
_canonicalize,
Expand Down Expand Up @@ -244,11 +244,11 @@ def update_canonical(
The truncation error of this update.
"""
if direction > 0:
self.center, err = _update_in_canonical_form_right(
self.center, err = _update_canonical_right(
self._data, A, self.center, truncation
)
else:
self.center, err = _update_in_canonical_form_left(
self.center, err = _update_canonical_left(
self._data, A, self.center, truncation
)
self.update_error(err)
Expand Down
8 changes: 4 additions & 4 deletions src/seemps/state/core.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,14 +135,14 @@ PYBIND11_MODULE(core, m) {

m.def("schmidt_weights", &schmidt_weights);
m.def(
"_update_in_canonical_form_right", &_update_in_canonical_form_right,
py::arg("state"), py::arg("tensor"), py::arg("site"), py::arg("strategy"),
"_update_canonical_right", &_update_canonical_right, py::arg("state"),
py::arg("tensor"), py::arg("site"), py::arg("strategy"),
py::arg("overwrite") = false,
R"doc(Insert a tensor in canonical form into the MPS Ψ at the given site.
Update the neighboring sites in the process)doc");
m.def(
"_update_in_canonical_form_left", &_update_in_canonical_form_left,
py::arg("state"), py::arg("tensor"), py::arg("site"), py::arg("strategy"),
"_update_canonical_left", &_update_canonical_left, py::arg("state"),
py::arg("tensor"), py::arg("site"), py::arg("strategy"),
py::arg("overwrite") = false,
R"doc(Insert a tensor in canonical form into the MPS Ψ at the given site.
Update the neighboring sites in the process)doc");
Expand Down
4 changes: 2 additions & 2 deletions src/seemps/state/core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,10 @@ def right_orth_2site(
) -> tuple[Tensor3, Tensor3, float]: ...
def _destructive_svd(A: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]: ...
def schmidt_weights(A: np.ndarray) -> np.ndarray: ...
def _update_in_canonical_form_left(
def _update_canonical_left(
state: list[Tensor3], A: Tensor3, site: int, truncation: Strategy
) -> tuple[int, float]: ...
def _update_in_canonical_form_right(
def _update_canonical_right(
state: list[Tensor3], A: Tensor3, site: int, truncation: Strategy
) -> tuple[int, float]: ...
def _update_canonical_2site_left(
Expand Down
16 changes: 8 additions & 8 deletions src/seemps/state/mps.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ Weight scprod(py::object A, py::object B);

py::object schmidt_weights(py::object A);

std::tuple<int, double>
_update_in_canonical_form_right(py::list state, py::object A, int site,
const Strategy &truncation,
bool overwrite = false);
std::tuple<int, double>
_update_in_canonical_form_left(py::list state, py::object A, int site,
const Strategy &truncation,
bool overwrite = false);
std::tuple<int, double> _update_canonical_right(py::list state, py::object A,
int site,
const Strategy &truncation,
bool overwrite = false);
std::tuple<int, double> _update_canonical_left(py::list state, py::object A,
int site,
const Strategy &truncation,
bool overwrite = false);
double _canonicalize(py::list state, int center, const Strategy &truncation);
std::tuple<py::object, py::object, double>
left_orth_2site(py::object AA, const Strategy &strategy);
Expand Down
25 changes: 12 additions & 13 deletions src/seemps/state/schmidt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@

namespace seemps {

std::tuple<int, double>
_update_in_canonical_form_right(py::list state, py::object A, int site,
const Strategy &strategy, bool overwrite) {
std::tuple<int, double> _update_canonical_right(py::list state, py::object A,
int site,
const Strategy &strategy,
bool overwrite) {
if (!is_array(A) || array_ndim(A) != 3) {
throw std::invalid_argument(
"Invalid tensor passed to _update_in_canonical_form_right");
"Invalid tensor passed to _update_canonical_right");
}
py::object tensor = overwrite ? array_getcontiguous(A) : array_copy(A);
auto a = array_dim(tensor, 0);
Expand All @@ -28,13 +29,13 @@ _update_in_canonical_form_right(py::list state, py::object A, int site,
return {site, err};
}

std::tuple<int, double> _update_in_canonical_form_left(py::list state,
py::object A, int site,
const Strategy &strategy,
bool overwrite) {
std::tuple<int, double> _update_canonical_left(py::list state, py::object A,
int site,
const Strategy &strategy,
bool overwrite) {
if (!is_array(A) || array_ndim(A) != 3) {
throw std::invalid_argument(
"Invalid tensor passed to _update_in_canonical_form_right");
"Invalid tensor passed to _update_canonical_right");
}
py::object tensor = overwrite ? array_getcontiguous(A) : array_copy(A);
auto a = array_dim(tensor, 0);
Expand All @@ -55,14 +56,12 @@ std::tuple<int, double> _update_in_canonical_form_left(py::list state,
double _canonicalize(py::list state, int center, const Strategy &strategy) {
double err = 0.0;
for (int i = 0; i < center;) {
auto [site, errk] =
_update_in_canonical_form_right(state, state[i], i, strategy);
auto [site, errk] = _update_canonical_right(state, state[i], i, strategy);
err += errk;
i = site;
}
for (int i = state.size() - 1; i > center;) {
auto [site, errk] =
_update_in_canonical_form_left(state, state[i], i, strategy);
auto [site, errk] = _update_canonical_left(state, state[i], i, strategy);
err += errk;
i = site;
}
Expand Down
10 changes: 5 additions & 5 deletions tests/test_states/test_canonical.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
random_uniform_mps,
)
from seemps.state.canonical_mps import (
_update_in_canonical_form_left,
_update_in_canonical_form_right,
_update_canonical_left,
_update_canonical_right,
_canonicalize,
)
from ..fixture_mps_states import MPSStatesFixture
Expand All @@ -18,18 +18,18 @@
class TestCanonicalForm(MPSStatesFixture):
def test_local_update_canonical(self):
#
# We verify that _update_in_canonical_form() leaves a tensor that
# We verify that _update_canonical() leaves a tensor that
# is an approximate isometry.
#
def ok(Ψ, normalization=False):
strategy = DEFAULT_STRATEGY.replace(normalize=normalization)
for i in range(Ψ.size - 1):
ξ = Ψ.copy()
_update_in_canonical_form_right(ξ._data, ξ[i], i, strategy)
_update_canonical_right(ξ._data, ξ[i], i, strategy)
self.assertTrue(approximateIsometry(ξ[i], +1))
for i in range(1, Ψ.size):
ξ = Ψ.copy()
_update_in_canonical_form_left(ξ._data, ξ[i], i, strategy)
_update_canonical_left(ξ._data, ξ[i], i, strategy)
self.assertTrue(approximateIsometry(ξ[i], -1))

run_over_random_uniform_mps(ok)
Expand Down

0 comments on commit fb04d4b

Please sign in to comment.