From 41ef7a100a723980451356563c05cfdba5b82ae1 Mon Sep 17 00:00:00 2001 From: shengtsui <124718322+shengtsui@users.noreply.github.com> Date: Tue, 12 Nov 2024 10:13:32 -0800 Subject: [PATCH] VTEN-13-Add-astype-method (#15) --- README.md | 1 - docs/source/api/core/astype.rst | 7 +++++++ docs/source/api/core/index.rst | 1 + lib/core/BUILD | 1 + lib/core/astype.hpp | 31 +++++++++++++++++++++++++++++++ lib/core/tensor.hpp | 13 ++++--------- 6 files changed, 44 insertions(+), 10 deletions(-) create mode 100644 docs/source/api/core/astype.rst create mode 100644 lib/core/astype.hpp diff --git a/README.md b/README.md index a0a4b22..13f625f 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,6 @@ sphinx-build -b html docs/source docs/build/html ``` ### Future updates -- Support Solvers (Cholesky, QR etc.) - Support GPUDirect - Support more matrix operations - Support Sparse martix with CuSparse diff --git a/docs/source/api/core/astype.rst b/docs/source/api/core/astype.rst new file mode 100644 index 0000000..4ca80a9 --- /dev/null +++ b/docs/source/api/core/astype.rst @@ -0,0 +1,7 @@ +vt::astype +======================= + +.. toctree:: + :maxdepth: 1 + +.. doxygenfunction:: vt::astype diff --git a/docs/source/api/core/index.rst b/docs/source/api/core/index.rst index 726755c..7ad73f8 100644 --- a/docs/source/api/core/index.rst +++ b/docs/source/api/core/index.rst @@ -4,6 +4,7 @@ Core .. toctree:: :maxdepth: 1 + astype broadcast broadcast_to cutensor diff --git a/lib/core/BUILD b/lib/core/BUILD index 97bc24d..0e46270 100644 --- a/lib/core/BUILD +++ b/lib/core/BUILD @@ -5,6 +5,7 @@ cuda_library( visibility = ["//visibility:public"], hdrs = [ "assertions.hpp", + "astype.hpp", "broadcast.hpp", "cutensor.hpp", "iterator.hpp", diff --git a/lib/core/astype.hpp b/lib/core/astype.hpp new file mode 100644 index 0000000..9505a91 --- /dev/null +++ b/lib/core/astype.hpp @@ -0,0 +1,31 @@ +#pragma once + +#include "lib/core/tensor.hpp" + +namespace vt { + +/// Forward declaration of the Tensor class. +template +class Tensor; + +/** + * @brief Convert the tensor to a new data type. + * + * @tparam T: Data type of the tensor. + * @tparam U: The data type to cast. + * @tparam N: Number of dimensions of the tensor. + * @param tensor: The tensor object. + * @return Tensor: The new tensor object. + */ +template +Tensor astype(Tensor tensor) { + if constexpr (std::is_same_v) { + return tensor; + } else { + auto result = Tensor(tensor.shape()); + thrust::transform(tensor.begin(), tensor.end(), result.begin(), [] __device__(const T& x) { return static_cast(x); }); + return result; + } +} + +} // namespace vt \ No newline at end of file diff --git a/lib/core/tensor.hpp b/lib/core/tensor.hpp index a20a7ff..d5f692e 100644 --- a/lib/core/tensor.hpp +++ b/lib/core/tensor.hpp @@ -11,6 +11,7 @@ #include #include "lib/core/assertions.hpp" +#include "lib/core/astype.hpp" #include "lib/core/iterator.hpp" #include "lib/core/slice.hpp" @@ -30,7 +31,7 @@ using Shape = std::array; template size_t get_size(const Shape& shape) { size_t size = 1; - for (size_t i = 0; i < N; ++i) size *= shape[i]; + for (int i = 0; i < N; ++i) size *= shape[i]; return size; } @@ -417,14 +418,8 @@ class Tensor { * @return Tensor: The new tensor object. */ template - Tensor astype() const { - if constexpr (std::is_same_v) { - return *this; - } else { - auto result = Tensor(_shape); - thrust::transform(this->begin(), this->end(), result.begin(), [] __device__(const T& x) { return static_cast(x); }); - return result; - } + Tensor astype() { + return vt::astype(*this); } /**