Skip to content

Commit

Permalink
VTEN-13-Add-astype-method (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
shengtsui authored Nov 12, 2024
1 parent 35e19dd commit 41ef7a1
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 10 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ sphinx-build -b html docs/source docs/build/html
```

### Future updates
- Support Solvers (Cholesky, QR etc.)
- Support GPUDirect
- Support more matrix operations
- Support Sparse martix with CuSparse
7 changes: 7 additions & 0 deletions docs/source/api/core/astype.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
vt::astype
=======================

.. toctree::
:maxdepth: 1

.. doxygenfunction:: vt::astype
1 change: 1 addition & 0 deletions docs/source/api/core/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Core
.. toctree::
:maxdepth: 1

astype
broadcast
broadcast_to
cutensor
Expand Down
1 change: 1 addition & 0 deletions lib/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ cuda_library(
visibility = ["//visibility:public"],
hdrs = [
"assertions.hpp",
"astype.hpp",
"broadcast.hpp",
"cutensor.hpp",
"iterator.hpp",
Expand Down
31 changes: 31 additions & 0 deletions lib/core/astype.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#pragma once

#include "lib/core/tensor.hpp"

namespace vt {

/// Forward declaration of the Tensor class.
template <typename T, size_t N>
class Tensor;

/**
* @brief Convert the tensor to a new data type.
*
* @tparam T: Data type of the tensor.
* @tparam U: The data type to cast.
* @tparam N: Number of dimensions of the tensor.
* @param tensor: The tensor object.
* @return Tensor<U, N>: The new tensor object.
*/
template <typename T, typename U, size_t N>
Tensor<U, N> astype(Tensor<T, N> tensor) {
if constexpr (std::is_same_v<T, U>) {
return tensor;
} else {
auto result = Tensor<U, N>(tensor.shape());
thrust::transform(tensor.begin(), tensor.end(), result.begin(), [] __device__(const T& x) { return static_cast<U>(x); });
return result;
}
}

} // namespace vt
13 changes: 4 additions & 9 deletions lib/core/tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <rmm/device_vector.hpp>

#include "lib/core/assertions.hpp"
#include "lib/core/astype.hpp"
#include "lib/core/iterator.hpp"
#include "lib/core/slice.hpp"

Expand All @@ -30,7 +31,7 @@ using Shape = std::array<size_t, N>;
template <size_t N>
size_t get_size(const Shape<N>& shape) {
size_t size = 1;
for (size_t i = 0; i < N; ++i) size *= shape[i];
for (int i = 0; i < N; ++i) size *= shape[i];
return size;
}

Expand Down Expand Up @@ -417,14 +418,8 @@ class Tensor {
* @return Tensor<U, N>: The new tensor object.
*/
template <typename U>
Tensor<U, N> astype() const {
if constexpr (std::is_same_v<T, U>) {
return *this;
} else {
auto result = Tensor<U, N>(_shape);
thrust::transform(this->begin(), this->end(), result.begin(), [] __device__(const T& x) { return static_cast<U>(x); });
return result;
}
Tensor<U, N> astype() {
return vt::astype<T, U, N>(*this);
}

/**
Expand Down

0 comments on commit 41ef7a1

Please sign in to comment.