Skip to content

Commit

Permalink
Fixed code to work with GCC
Browse files Browse the repository at this point in the history
  • Loading branch information
juanjosegarciaripoll committed Apr 30, 2024
1 parent 6770298 commit ecdf533
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 53 deletions.
6 changes: 6 additions & 0 deletions src/seemps/state/blas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,13 @@ template <class f>
static void load_wrapper(py::dict &__pyx_capi__, const char *name,
f *&pointer) {
py::capsule wrapper = __pyx_capi__[name];
#if 1
// This copes with a bug in pybind11, which uses
// static_cast to cast a void* to the pointer.
pointer = reinterpret_cast<f *>(wrapper.get_pointer<void>());
#else
pointer = wrapper.get_pointer<f>();
#endif
}

void load_scipy_wrappers() {
Expand Down
29 changes: 15 additions & 14 deletions src/seemps/state/canonical.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,24 +110,25 @@ Environment CanonicalMPS::right_environment(int site) const {
return rho;
}

py::object CanonicalMPS::Schmidt_weights(int site) const {
py::list CanonicalMPS::Schmidt_weights(int site) const {
site = interpret_center(site);
return schmidt_weights(
((site == center()) ? (*this) : copy().recenter(site, strategy_))
.center_tensor());
}

double CanonicalMPS::entanglement_entropy(int center) const {
auto s = Schmidt_weights(center);
return std::accumulate(s.begin(), s.end(), 0.0,
[](double entropy, auto schmidt_weight) -> double {
auto w = schmidt_weight.cast<double>();
return entropy - w * std::log2(w);
});
const py::list s = Schmidt_weights(center);
return std::accumulate(
py::begin(s), py::end(s), 0.0,
[](double entropy, const py::object schmidt_weight) -> double {
auto w = schmidt_weight.cast<double>();
return entropy - w * std::log2(w);
});
}

double CanonicalMPS::Renyi_entropy(int center, double alpha) const {
auto s = Schmidt_weights(center);
const py::list s = Schmidt_weights(center);
if (alpha < 0) {
std::invalid_argument("Invalid Renyi entropy power");
}
Expand All @@ -136,11 +137,12 @@ double CanonicalMPS::Renyi_entropy(int center, double alpha) const {
} else if (alpha == 1) {
alpha = 1 - 1e-9;
}
return std::log(std::accumulate(s.begin(), s.end(), 0.0,
[=](double sum, auto schmidt_weight) {
double w = schmidt_weight.cast<double>();
return sum + std::pow(w, alpha);
})) /
return std::log(
std::accumulate(py::begin(s), py::end(s), 0.0,
[=](double sum, const py::object schmidt_weight) {
double w = schmidt_weight.cast<double>();
return sum + std::pow(w, alpha);
})) /
(1 - alpha);
}

Expand Down Expand Up @@ -184,7 +186,6 @@ const CanonicalMPS &CanonicalMPS::recenter(int new_center,
const Strategy &strategy) {

new_center = interpret_center(new_center);
auto old_center = center();
while (center() < new_center) {
update_canonical(center_tensor(), +1, strategy);
}
Expand Down
2 changes: 1 addition & 1 deletion src/seemps/state/contractions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

namespace seemps {

py::object _matmul(py::object &A, py::object &B) {
py::object _matmul(const py::object &A, const py::object &B) {
auto numpy = py::module_::import("numpy");
auto matmul = numpy.attr("matmul");
return matmul(A, B);
Expand Down
2 changes: 1 addition & 1 deletion src/seemps/state/mps.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ class CanonicalMPS : public MPS {

Environment right_environment(int site) const;

py::object Schmidt_weights(int site = no_defined_center) const;
py::list Schmidt_weights(int site = no_defined_center) const;

double entanglement_entropy(int center) const;

Expand Down
20 changes: 0 additions & 20 deletions src/seemps/state/strategy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,6 @@ static double _truncate_relative_norm_squared(const py::object &a,
}

double max_error = total * s.get_tolerance();
double final_error = 0.0;
size_t final_size = 1;
for (i = 1; i < N; ++i) {
if (errors[i] > max_error) {
Expand Down Expand Up @@ -243,23 +242,4 @@ double destructively_truncate_vector(const py::object a, const Strategy &s) {
}
}

static py::object contract_nrjl_ijk_klm(py::object U, py::object A,
py::object B) {
if (PyArray_Check(A.ptr()) == 0 || PyArray_Check(B.ptr()) == 0 ||
PyArray_Check(U.ptr()) == 0 || array_ndim(A) != 3 || array_ndim(B) != 3 ||
array_ndim(U) != 2) {
throw std::invalid_argument("Invalid arguments to _contract_nrjl_ijk_klm");
}
auto a = array_dim(A, 0);
auto d = array_dim(A, 1);
auto b = array_dim(A, 2);
auto e = array_dim(B, 1);
auto c = array_dim(B, 2);
npy_intp final_dims[4] = {a, d, e, c};
npy_intp intermediate_dims[3] = {a, d * e, c};
auto AB = matrix_product(as_matrix(A, a * d, b), as_matrix(B, b, e * c));
return array_reshape(_matmul(U, array_reshape(AB, intermediate_dims)),
final_dims);
}

} // namespace seemps
6 changes: 3 additions & 3 deletions src/seemps/state/tensors.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ template <typename elt> inline elt *array_data(const py::object &a) {
using array_dims_t = std::initializer_list<npy_intp>;

template <class Dimensions>
inline py::object array_reshape(const py::object &a, Dimensions &d) {
inline py::object array_reshape(const py::object &a, const Dimensions &d) {
PyArray_Dims dims = {const_cast<npy_intp *>(&(*std::begin(d))),
static_cast<int>(std::size(d))};
return py::reinterpret_steal<py::object>(
Expand Down Expand Up @@ -164,7 +164,7 @@ inline py::object empty_matrix(npy_intp rows, npy_intp cols, int type) {
template <class Dimensions>
inline py::object zero_array(const Dimensions &dims, int type = NPY_DOUBLE) {
auto the_dims = const_cast<npy_intp *>(&(*std::begin(dims)));
auto rank = static_cast<int>(std::size(d));
auto rank = static_cast<int>(std::size(dims));
return py::reinterpret_steal<py::object>(
PyArray_ZEROS(rank, the_dims, type, 0));
}
Expand Down Expand Up @@ -204,7 +204,7 @@ inline py::object array_conjugate(const py::object &a) {
* Advanced contractions
*/

py::object _matmul(py::object &A, py::object &B);
py::object _matmul(const py::object &A, const py::object &B);

py::object contract_last_and_first(py::object A, py::object B);
py::object contract_nrjl_ijk_klm(py::object U, py::object A, py::object B);
Expand Down
29 changes: 15 additions & 14 deletions src/seemps/state/tools.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ class python_list_iterator {
operator object() const { return list_[index_]; }
};

iterator(const iterator &) = default;
iterator(iterator &&) = default;
iterator &operator=(const iterator &) = default;
iterator &operator=(iterator &&) = default;
iterator(list &list, size_t index) : list_{list}, index_{index} {}
python_list_iterator(const python_list_iterator &) = default;
python_list_iterator(python_list_iterator &&) = default;
python_list_iterator &operator=(const python_list_iterator &) = default;
python_list_iterator &operator=(python_list_iterator &&) = default;
python_list_iterator(list &list, size_t index) : list_{list}, index_{index} {}
~python_list_iterator() = default;

bool operator==(const iterator &other) const {
Expand Down Expand Up @@ -70,35 +70,36 @@ class python_list_const_iterator {
using pointer = object *;
using reference = object &;

iterator(const iterator &) = default;
iterator(iterator &&) = default;
iterator &operator=(const iterator &it) {
python_list_const_iterator(const python_list_const_iterator &) = default;
python_list_const_iterator(python_list_const_iterator &&) = default;
python_list_const_iterator &operator=(const python_list_const_iterator &it) {
list_ = it.list_;
index_ = it.index_;
return *this;
}
iterator &operator=(iterator &&it) {
python_list_const_iterator &operator=(python_list_const_iterator &&it) {
list_ = std::move(it.list_);
index_ = it.index_;
return *this;
}
iterator(const list &list, size_t index) : list_{list}, index_{index} {}
python_list_const_iterator(const list &list, size_t index)
: list_{list}, index_{index} {}
~python_list_const_iterator() = default;

bool operator==(const iterator &other) const {
bool operator==(const python_list_const_iterator &other) const {
return index_ == other.index_;
}

bool operator!=(const iterator &other) const {
bool operator!=(const python_list_const_iterator &other) const {
return index_ != other.index_;
}

iterator &operator++() {
python_list_const_iterator &operator++() {
++index_;
return *this;
}

iterator operator++(int) {
python_list_const_iterator operator++(int) {
auto return_value = *this;
++index_;
return return_value;
Expand Down

0 comments on commit ecdf533

Please sign in to comment.