diff --git a/src/tensor/dispatch.h b/src/tensor/dispatch.h index 421fffb..0398b14 100644 --- a/src/tensor/dispatch.h +++ b/src/tensor/dispatch.h @@ -3,16 +3,20 @@ #include "Tensor.h" #include +#include namespace xt { // Tensor version -template -auto dispatch(Tensor& t, T&... args) -> typename std::result_of)(F&, Tensor&, T&...)>::type +template +auto dispatch(TensorT&& t, Args&&... args) -> typename std::result_of)(F&, TensorT&&, Args&&...)>::type { - using ReturnType = typename std::result_of)(F&, Tensor&, T&...)>::type; + using ReturnT = typename std::result_of)(F&, TensorT&&, Args&&...)>::type; + using FunctionT = std::function; + F functor; + if(t.device() == kCPU) { - static std::array, 7> dyn = {{ + static std::array dyn = {{ &F::template cpu, &F::template cpu, &F::template cpu, @@ -21,10 +25,9 @@ auto dispatch(Tensor& t, T&... args) -> typename std::result_of, &F::template cpu, }}; - F functor; - return dyn.at(t.type())(functor, t, args...); + return dyn.at(t.type())(functor, std::forward(t), std::forward(args)...); } else if(t.device() == kGPU) { - static std::array, 7> dyn = {{ + static std::array dyn = {{ &F::template gpu, &F::template gpu, &F::template gpu, @@ -33,54 +36,22 @@ auto dispatch(Tensor& t, T&... args) -> typename std::result_of, &F::template gpu, }}; - F functor; - return dyn.at(t.type())(functor, t, args...); + return dyn.at(t.type())(functor, std::forward(t), std::forward(args)...); } else { throw std::invalid_argument("unsupported device"); } } // Context, Tensor version -template -auto dispatch(Context& ctx, Tensor& t, T&... args) -> typename std::result_of)(F&, Context&, Tensor&, T&...)>::type +template +auto dispatch(Context& ctx, TensorT&& t, Args&&... args) -> typename std::result_of)(F&, Context&, TensorT&&, Args&&...)>::type { - using ReturnType = typename std::result_of)(F&, Context&, Tensor&, T&...)>::type; - if(t.device() == kCPU) { - static std::array, 7> dyn = {{ - &F::template cpu, - &F::template cpu, - &F::template cpu, - &F::template cpu, - &F::template cpu, - &F::template cpu, - &F::template cpu, - }}; - F functor; - return dyn.at(t.type())(functor, t, args...); - } else if(t.device() == kGPU) { - static std::array, 7> dyn = {{ - &F::template gpu, - &F::template gpu, - &F::template gpu, - &F::template gpu, - &F::template gpu, - &F::template gpu, - &F::template gpu, - }}; - F functor; - return dyn.at(t.type())(functor, ctx, t, args...); - } else { - throw std::invalid_argument("unsupported device"); - } -} + using ReturnT = typename std::result_of)(F&, Context&, TensorT&&, Args&&...)>::type; + using FunctionT = std::function; + F functor; -// const Tensor version -template -auto dispatch(const Tensor& t, T&... args) -> typename std::result_of)(F&, const Tensor&, T&...)>::type -{ - using ReturnType = typename std::result_of)(F&, const Tensor&, T&...)>::type; if(t.device() == kCPU) { - static std::array, 7> dyn = {{ + static std::array dyn = {{ &F::template cpu, &F::template cpu, &F::template cpu, @@ -89,10 +60,9 @@ auto dispatch(const Tensor& t, T&... args) -> typename std::result_of, &F::template cpu, }}; - F functor; - return dyn.at(t.type())(functor, t, args...); + return dyn.at(t.type())(functor, ctx, std::forward(t), std::forward(args)...); } else if(t.device() == kGPU) { - static std::array, 7> dyn = {{ + static std::array dyn = {{ &F::template gpu, &F::template gpu, &F::template gpu, @@ -101,54 +71,22 @@ auto dispatch(const Tensor& t, T&... args) -> typename std::result_of, &F::template gpu, }}; - F functor; - return dyn.at(t.type())(functor, t, args...); - } else { - throw std::invalid_argument("unsupported device"); - } -} - -// Context, const Tensor version -template -auto dispatch(Context& ctx, const Tensor& t, T&... args) -> typename std::result_of)(F&, Context&, const Tensor&, T&...)>::type -{ - using ReturnType = typename std::result_of)(F&, Context&, const Tensor&, T&...)>::type; - if(t.device() == kCPU) { - static std::array, 7> dyn = {{ - &F::template cpu, - &F::template cpu, - &F::template cpu, - &F::template cpu, - &F::template cpu, - &F::template cpu, - &F::template cpu, - }}; - F functor; - return dyn.at(t.type())(functor, ctx, t, args...); - } else if(t.device() == kGPU) { - static std::array, 7> dyn = {{ - &F::template gpu, - &F::template gpu, - &F::template gpu, - &F::template gpu, - &F::template gpu, - &F::template gpu, - &F::template gpu, - }}; - F functor; - return dyn.at(t.type())(functor, ctx, t, args...); + return dyn.at(t.type())(functor, ctx, std::forward(t), std::forward(args)...); } else { throw std::invalid_argument("unsupported device"); } } // type/device version -template -auto dispatch(TensorType ttype, TensorDevice tdev, T&... args) -> typename std::result_of)(F&, T&...)>::type +template +auto dispatch(TensorType ttype, TensorDevice tdev, Args&&... args) -> typename std::result_of)(F&, Args&&...)>::type { - using ReturnType = typename std::result_of)(F&, T&...)>::type; + using ReturnT = typename std::result_of)(F&, Args&&...)>::type; + using FunctionT = std::function; + F functor; + if(tdev == kCPU) { - static std::array, 7> dyn = {{ + static std::array dyn = {{ &F::template cpu, &F::template cpu, &F::template cpu, @@ -157,10 +95,9 @@ auto dispatch(TensorType ttype, TensorDevice tdev, T&... args) -> typename std:: &F::template cpu, &F::template cpu, }}; - F functor; - return dyn.at(ttype)(functor, args...); + return dyn.at(ttype)(functor, std::forward(args)...); } else if(tdev == kGPU) { - static std::array, 7> dyn = {{ + static std::array dyn = {{ &F::template gpu, &F::template gpu, &F::template gpu, @@ -169,8 +106,7 @@ auto dispatch(TensorType ttype, TensorDevice tdev, T&... args) -> typename std:: &F::template gpu, &F::template gpu, }}; - F functor; - return dyn.at(ttype)(functor, args...); + return dyn.at(ttype)(functor, std::forward(args)...); } else { throw std::invalid_argument("unsupported device"); } diff --git a/src/tensor/test/basic.cc b/src/tensor/test/basic.cc index ad407d3..c1a2a42 100644 --- a/src/tensor/test/basic.cc +++ b/src/tensor/test/basic.cc @@ -4,7 +4,7 @@ using namespace xt; -struct sum_op +struct sum_op_ref { template Tensor cpu(Tensor& x) { @@ -19,12 +19,60 @@ struct sum_op } return sum; }; + + template Tensor gpu(Tensor& x) { throw std::invalid_argument("device not supported"); }; }; +struct sum_op_const_ref +{ + template Tensor cpu(const Tensor& x) + { + if(!isContiguous(x)) { + throw std::invalid_argument("contiguous tensor expected"); + } + T* x_p = x.data(); + int64_t size = numel(x); + T sum = 0; + for(int64_t i = 0; i < size; i++) { + sum += x_p[i]; + } + return sum; + }; + + + template Tensor gpu(const Tensor& x) + { + throw std::invalid_argument("device not supported"); + }; +}; + +struct sum_op_rvalue_ref +{ + template Tensor cpu(Tensor&& x) + { + if(!isContiguous(x)) { + throw std::invalid_argument("contiguous tensor expected"); + } + T* x_p = x.data(); + int64_t size = numel(x); + T sum = 0; + for(int64_t i = 0; i < size; i++) { + sum += x_p[i]; + } + return sum; + }; + + + template Tensor gpu(const Tensor&& x) + { + throw std::invalid_argument("device not supported"); + }; +}; + static void test(TensorDevice device) { { @@ -184,10 +232,26 @@ static void test(TensorDevice device) if(device == kCPU) { - std::cout << "manual sum:" << std::endl; + std::cout << "manual sum (ref dispatch):" << std::endl; + Tensor a = rand({3, 7}, kFloat, device); + std::cout << a << std::endl; + std::cout << dispatch(a) << " == " << sum(a) << std::endl; + } + + if(device == kCPU) + { + std::cout << "manual sum (const ref dispatch):" << std::endl; + const Tensor a = rand({3, 7}, kFloat, device); + std::cout << a << std::endl; + std::cout << dispatch(a) << " == " << sum(a) << std::endl; + } + + if(device == kCPU) + { + std::cout << "manual sum (rvalue ref dispatch):" << std::endl; Tensor a = rand({3, 7}, kFloat, device); std::cout << a << std::endl; - std::cout << dispatch(a) << " == " << sum(a) << std::endl; + std::cout << dispatch(std::move(a)) << " == " << sum(a) << std::endl; } {