From b176cb45dc0e6043386f8abc2080a42e8548f66b Mon Sep 17 00:00:00 2001
From: xuewanqi
Date: Tue, 17 Jul 2018 15:50:16 +0000
Subject: [PATCH] SINGA-386 Implement RNN operation for autograd

- fix bugs in the cpp parts so that the code compiles without errors.

---
 python/singa/autograd.py   |  40 ++++++-
 src/model/operation/rnn.cc | 227 ++++++++++++++++++++++++++-----------
 src/model/operation/rnn.h  |  55 ++++++---
 3 files changed, 238 insertions(+), 84 deletions(-)

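Usage note: the Python-side RNN layer is still a stub in this patch, so the
wiring is easiest to see at the op level. Below is a minimal sketch, assuming
the swig bindings expose CudnnRNNHandle and the GpuRNN* functions declared in
rnn.h; the variable names, shapes and constructor arguments are illustrative,
not final:

    from singa import autograd
    import singa.singa_wrap as singa

    # X: one CTensor per time step, followed by the initial hidden state hx
    # (and the initial cell state cx when rnn_mode is 'lstm');
    # W: the flat cuDNN weight tensor, sized from RNNHandle::weight_size.
    handle = singa.CudnnRNNHandle(X, input_size, hidden_size, num_stacks,
                                  'tanh', 0.0, False)

    autograd.training = True
    outs = autograd._RNN(handle)(X, W)  # y_0 .. y_{T-1}, hy (and cy for lstm)
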
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index a084764eb8..1d649ca9b3 100755
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -772,7 +772,7 @@ def __call__(self, x):
         self.handle.device_id = x.device.id()
 
         y = batchnorm_2d(self.handle, x, self.scale, self.bias,
-                          self.running_mean, self.running_var)
+                         self.running_mean, self.running_var)
         return y
 
@@ -936,3 +936,41 @@ def __init__(self, kernel_size, stride=None, padding=0):
             stride = kernel_size
         super(MaxPool2d, self).__init__(
             (1, kernel_size), (0, stride), (0, padding), False)
+
+
+class _RNN(Operation):
+
+    def __init__(self, handle):
+        self.handle = handle
+
+    def forward(self, X, W):
+
+        if self.handle.device_id == -1:
+            raise NotImplementedError
+        else:
+            if training:
+                out, self.cache = singa.GpuRNNForwardTraining(
+                    self.handle, X, W)
+            else:
+                out = singa.GpuRNNForwardInference(self.handle, X, W)
+        return out
+
+    def backward(self, dY):
+        assert training is True and hasattr(
+            self, 'cache'), 'Please set training to True before doing BP.'
+
+        if dY.device().id() != self.handle.device_id:
+            dY.ToDevice(self.inputs[0].device())
+
+        if self.handle.device_id == -1:
+            raise NotImplementedError
+        else:
+            dX, dW = singa.GpuRNNBackward(self.handle, dY, self.cache)
+            return dX, dW
+
+
+def rnn():
+    pass
+
+
+class RNN(Layer):
+    pass
diff --git a/src/model/operation/rnn.cc b/src/model/operation/rnn.cc
index 2ef213912f..afeba67686 100644
--- a/src/model/operation/rnn.cc
+++ b/src/model/operation/rnn.cc
@@ -1,6 +1,10 @@
-RecHandle::RecHandle(const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks,
-                     const std::string Rnn_mode, const float Dropout, const bool bidirectional) {
-
+#include "./rnn.h"
+
+namespace singa {
+
+RNNHandle::RNNHandle(const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks,
+                     const std::string Rnn_mode, const float Dropout, const bool bidirectional) {
+
   input_size_ = Input_size;
   CHECK_GT(input_size_, 0u);
   hidden_size_ = Hidden_size;
@@ -18,9 +22,9 @@ RecHandle::RecHandle(const size_t Input_size, const size_t Hidden_size, const si
   rnn_mode_ = Rnn_mode;
   if (rnn_mode_ == "lstm") {
     has_cell_ = true;
-  } else if (rnn_mode_ !="relu" && rnn_mode_ != "tanh" && rnn_mode_ != "gru") {
+  } else if (rnn_mode_ != "relu" && rnn_mode_ != "tanh" && rnn_mode_ != "gru") {
     LOG(FATAL) << "RNN memory unit (mode) of " << rnn_mode_
-               << " is not supported Please use 'relu', 'tanh', 'lstm' and 'gru'";
+               << " is not supported. Please use 'relu', 'tanh', 'lstm' or 'gru'";
   }
   // the first constant (4) is the size of float
   // the second constant (2, 8, 6) is the number of sets of params
@@ -31,30 +35,39 @@ RecHandle::RecHandle(const size_t Input_size, const size_t Hidden_size, const si
     mult *= 4;
   else if (rnn_mode_ == "gru")
     mult *= 3;
-  if (direction_ == "bidirectional")
+  if (bidirectional)
     mult *= 2;
   weight_size = 0;
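+  // dim below counts the parameters of one weight set in layer i:
+  // hidden x input for the input weights, hidden x hidden for the
+  // recurrent weights, plus two bias vectors (hence the "+ 2").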
   for (size_t i = 0; i < num_stacks_; i++) {
     size_t dim = hidden_size_ * (input_size_ + hidden_size_ + 2);
     if (i > 0) dim = hidden_size_ * (hidden_size_ + hidden_size_ + 2);
     weight_size += mult * dim;
   }
-}
+};
 
+#ifdef USE_CUDNN
 
-CudnnRecHandle::CudnnRecHandle(const vector<Tensor> &inputs,const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks,
-                               const std::string nonlinearity, const bool bias, const float dropout, const bool bidirectional):
-  RecHandle(Input_size, Hidden_size, Num_stacks, nonlinearity, bias, dropout, bidirectional){
+CudnnRNNHandle::CudnnRNNHandle(const vector<Tensor> &inputs, const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks,
+                               const std::string Rnn_mode, const float Dropout, const bool bidirectional):
+  RNNHandle(Input_size, Hidden_size, Num_stacks, Rnn_mode, Dropout, bidirectional) {
 
   CHECK_GT(inputs.size(), 1u + has_cell_);
   size_t num_x = inputs.size() - has_cell_ - 1;
 
+  DataType dtype = inputs.at(0).data_type();
+  if (rnn_desc_ != nullptr)
+    CHECK_EQ(dtype_, GetCudnnDataType(dtype))
+        << "Cannot change cudnn data type during training from " << dtype_
+        << " to " << GetCudnnDataType(dtype);
+  else
+    dtype_ = GetCudnnDataType(dtype);
+
   UpdateStates(num_x, inputs);
-  }
+};
 
-void CudnnRecHandle::UpdateStates(size_t num_x, const vector<Tensor> &inputs) {
+void CudnnRNNHandle::UpdateStates(size_t num_x, const vector<Tensor> &inputs) {
   UpdateIODescriptors(num_x, inputs);
   size_t new_batch_size = inputs.at(0).shape(0);
   if (batch_size_ != new_batch_size)
@@ -64,9 +77,28 @@ void CudnnRecHandle::UpdateStates(size_t num_x, const vector<Tensor> &inputs) {
   UpdateSpaces(num_x, inputs.at(0).device());
   batch_size_ = new_batch_size;
   seq_length_ = num_x;
-}
+};
+
+void CudnnRNNHandle::DestroyIODescriptors() {
+  if (x_descs_ != nullptr) {
+    for (size_t i = 0; i < max_length_; i++) {
+      CUDNN_CHECK(cudnnDestroyTensorDescriptor(x_descs_[i]));
+      CUDNN_CHECK(cudnnDestroyTensorDescriptor(dx_descs_[i]));
+    }
+    delete [] x_descs_;
+    delete [] dx_descs_;
+  }
+  if (y_descs_ != nullptr) {
+    for (size_t i = 0; i < max_length_; i++) {
+      CUDNN_CHECK(cudnnDestroyTensorDescriptor(y_descs_[i]));
+      CUDNN_CHECK(cudnnDestroyTensorDescriptor(dy_descs_[i]));
+    }
+    delete [] y_descs_;
+    delete [] dy_descs_;
+  }
+};
 
-void CudnnRecHandle::UpdateIODescriptors(size_t len, const vector<Tensor> &inputs) {
+void CudnnRNNHandle::UpdateIODescriptors(size_t len, const vector<Tensor> &inputs) {
   bool reset = false;
   if (max_length_ < len) {
     DestroyIODescriptors();
@@ -104,9 +136,9 @@ void CudnnRecHandle::UpdateIODescriptors(size_t len, const vector<Tensor> &input
       CUDNN_CHECK(cudnnSetTensorNdDescriptor(dy_descs_[i], dtype_, 3, d, s));
     }
   }
-}
+};
 
-void CudnnRecHandle::ResetHiddenAndCellDescriptors(size_t batch_size) {
+void CudnnRNNHandle::ResetHiddenAndCellDescriptors(size_t batch_size) {
   if (batch_size_ == 0) {
     CUDNN_CHECK(cudnnCreateTensorDescriptor(&cx_desc_));
     CUDNN_CHECK(cudnnCreateTensorDescriptor(&dcx_desc_));
@@ -133,9 +165,9 @@ void CudnnRecHandle::ResetHiddenAndCellDescriptors(size_t batch_size) {
   CUDNN_CHECK(cudnnSetTensorNdDescriptor(dcx_desc_, dtype_, 3, dim, stride));
   CUDNN_CHECK(cudnnSetTensorNdDescriptor(cy_desc_, dtype_, 3, dim, stride));
   CUDNN_CHECK(cudnnSetTensorNdDescriptor(dcy_desc_, dtype_, 3, dim, stride));
-}
+};
 
-void CudnnRecHandle::SetRNNDescriptor(shared_ptr<Device> dev) {
+void CudnnRNNHandle::SetRNNDescriptor(shared_ptr<Device> dev) {
   auto ctx = dev->context(0);
   CUDNN_CHECK(cudnnCreateDropoutDescriptor(&dropout_desc_));
   size_t state_size;
@@ -148,7 +180,7 @@ void CudnnRecHandle::SetRNNDescriptor(shared_ptr<Device> dev) {
   CUDNN_CHECK(cudnnCreateRNNDescriptor(&rnn_desc_));
   cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT;
   //if (input_mode_ == "skip")
-  //input_mode = CUDNN_SKIP_INPUT;
+  //  input_mode = CUDNN_SKIP_INPUT;
   cudnnDirectionMode_t direction = CUDNN_UNIDIRECTIONAL;
   if (num_directions_ == 2)
@@ -179,9 +211,9 @@ void CudnnRecHandle::SetRNNDescriptor(shared_ptr<Device> dev) {
   CUDNN_CHECK(cudnnCreateFilterDescriptor(&weight_desc_));
   CUDNN_CHECK(cudnnSetFilterNdDescriptor(weight_desc_, dtype_, CUDNN_TENSOR_NCHW,
                                          3, filter_dim));
-}
+};
 
-void CudnnRecHandle::UpdateSpaces(size_t seq_length, shared_ptr<Device> dev) {
+void CudnnRNNHandle::UpdateSpaces(size_t seq_length, shared_ptr<Device> dev) {
   size_t count;
   auto ctx = dev->context(0);
   CUDNN_CHECK(cudnnGetRNNWorkspaceSize(ctx->cudnn_handle, rnn_desc_,
@@ -210,11 +242,11 @@ Tensor MergeInputs(size_t num, const vector<Tensor> &in) {
     offset += in.at(i).Size();
   }
   return out;
-}
+};
 
 vector<Tensor> SplitOutput(size_t num, size_t dim,
-                           const vector<Tensor> &in,
-                           const Tensor output) {
+                           const vector<Tensor> &in,
+                           const Tensor output) {
   vector<Tensor> outputs;
   if (num == 1) {
     outputs.push_back(Reshape(output, Shape{in.at(0).shape(0), dim}));
@@ -229,30 +261,23 @@ vector<Tensor> SplitOutput(size_t num, size_t dim,
     CHECK_EQ(num, outputs.size());
   }
   return outputs;
-}
+};
 
-const std::vector<std::vector<Tensor>> GpuRecForwardTraining(const CudnnRecHandle &crh, const vector<Tensor> &inputs, const Tensor &W){
+std::vector<std::vector<Tensor>> GpuRNNForwardTraining(const CudnnRNNHandle &crh, const vector<Tensor> &inputs, const Tensor &W) {
   DataType dtype = inputs.at(0).data_type();
   auto dev = inputs.at(0).device();
 
-  CHECK_GT(inputs.size(), 1u + has_cell_);
-  size_t num_x = inputs.size() - has_cell_ - 1;
+  CHECK_GT(inputs.size(), 1u + crh.has_cell_);
+  size_t num_x = inputs.size() - crh.has_cell_ - 1;
   Tensor input = MergeInputs(num_x, inputs);
 
-  if (rnn_desc_ != nullptr)
-    CHECK_EQ(dtype_, GetCudnnDataType(dtype))
-        << "Cannot change cudnn data type during training from " << dtype_
-        << " to " << GetCudnnDataType(dtype);
-  else
-    dtype_ = GetCudnnDataType(dtype);
-
   Shape outshape{input.Size() * crh.hidden_size_ / crh.input_size_ * crh.num_directions_};
   Tensor output(outshape, dev, dtype);
   // LOG(INFO) << "output size " << output.Size();
   Tensor hx = inputs.at(num_x);
   Shape state_shape{crh.num_stacks_ * crh.num_directions_, crh.batch_size_, crh.hidden_size_};
   Tensor hy(state_shape, dev, dtype);
- 
+
   Tensor cy, cx;
   if (crh.has_cell_) {
     cx = inputs.at(num_x + 1);
     cy.ResetLike(hy);
   }
@@ -285,30 +310,30 @@ const std::vector<std::vector<Tensor>> GpuRecForwardTraining(const CudnnRecHandl
         *rspace = crh.reserve_space_.block();
 
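+  // queue the cudnn call on the device; the first block list is what the
+  // kernel reads, the second what it writes (workspace and reserve space
+  // are mutated, so they appear in the write list).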
   dev->Exec(
-          [inb, outb, wb, hxb, cxb, hyb, cyb, wspace, rspace, &crh](Context * ctx) {
-            // clang-format off
-            cudnnRNNForwardTraining(
-              ctx->cudnn_handle,
-              crh.rnn_desc_,
-              crh.seq_length_,
-              crh.x_descs_, inb->data(),
-              crh.hx_desc_, hxb == nullptr ? nullptr : hxb->data(),
-              crh.cx_desc_, cxb == nullptr ? nullptr : cxb->data(),
-              crh.weight_desc_, wb->data(),
-              crh.y_descs_, outb->mutable_data(),
-              crh.hy_desc_, hyb->mutable_data(),
-              crh.cy_desc_, cyb == nullptr ? nullptr : cyb->mutable_data(),
-              wspace->mutable_data(),
-              crh.workspace_.Size(), rspace->mutable_data(),
-              crh.reserve_space_.Size());
-            // clang-format on
-          },
-          {inb, wb, hxb, cxb}, {outb, hyb, cyb, wspace, rspace});
+    [inb, outb, wb, hxb, cxb, hyb, cyb, wspace, rspace, &crh](Context * ctx) {
+      // clang-format off
+      cudnnRNNForwardTraining(
+        ctx->cudnn_handle,
+        crh.rnn_desc_,
+        crh.seq_length_,
+        crh.x_descs_, inb->data(),
+        crh.hx_desc_, hxb == nullptr ? nullptr : hxb->data(),
+        crh.cx_desc_, cxb == nullptr ? nullptr : cxb->data(),
+        crh.weight_desc_, wb->data(),
+        crh.y_descs_, outb->mutable_data(),
+        crh.hy_desc_, hyb->mutable_data(),
+        crh.cy_desc_, cyb == nullptr ? nullptr : cyb->mutable_data(),
+        wspace->mutable_data(),
+        crh.workspace_.Size(), rspace->mutable_data(),
+        crh.reserve_space_.Size());
+      // clang-format on
+    },
+    {inb, wb, hxb, cxb}, {outb, hyb, cyb, wspace, rspace});
 
   auto outputs =
     SplitOutput(num_x, crh.hidden_size_ * crh.num_directions_, inputs, output);
   outputs.push_back(hy);
-  if (has_cell_) outputs.push_back(cy);
+  if (crh.has_cell_) outputs.push_back(cy);
 
   std::vector<Tensor> cache;
   cache.push_back(input);
@@ -318,18 +343,82 @@ const std::vector<std::vector<Tensor>> GpuRecForwardTraining(const CudnnRecHandl
   cache.push_back(W);
 
   return {outputs, cache};
-}
+};
+
+std::vector<Tensor> GpuRNNForwardInference(const CudnnRNNHandle &crh, const vector<Tensor> &inputs, const Tensor &W) {
+  DataType dtype = inputs.at(0).data_type();
+  auto dev = inputs.at(0).device();
 
-const std::vector<Tensor> GpuRecForwardInference(const CudnnRecHandle &crh, const vector<Tensor> &inputs, const Tensor &W){
+  CHECK_GT(inputs.size(), 1u + crh.has_cell_);
+  size_t num_x = inputs.size() - crh.has_cell_ - 1;
+  Tensor input = MergeInputs(num_x, inputs);
 
-}
+  Shape outshape{input.Size() * crh.hidden_size_ / crh.input_size_ * crh.num_directions_};
+  Tensor output(outshape, dev, dtype);
+  // LOG(INFO) << "output size " << output.Size();
+  Tensor hx = inputs.at(num_x);
+  Shape state_shape{crh.num_stacks_ * crh.num_directions_, crh.batch_size_, crh.hidden_size_};
+  Tensor hy(state_shape, dev, dtype);
+
+  Tensor cy, cx;
+  if (crh.has_cell_) {
+    cx = inputs.at(num_x + 1);
+    cy.ResetLike(hy);
+  }
+
+  int did = input.device()->id();
+  CHECK_EQ(did, output.device()->id());
+  if (hx.Size()) {
+    CHECK_EQ(did, hx.device()->id());
+    CHECK_EQ(hx.device()->lang(), kCuda);
+  }
+  if (cx.Size()) {
+    CHECK_EQ(did, cx.device()->id());
+    CHECK_EQ(cx.device()->lang(), kCuda);
+  }
+  CHECK_EQ(did, W.device()->id());
+  CHECK_EQ(did, crh.workspace_.device()->id());
+  CHECK_EQ(input.device()->lang(), kCuda);
+  CHECK_EQ(output.device()->lang(), kCuda);
+  CHECK_EQ(W.device()->lang(), kCuda);
+  CHECK_EQ(crh.workspace_.device()->lang(), kCuda);
 
-const std::pair<std::vector<Tensor>, Tensor> GpuRecBackward(const CudnnRecHandle &crh, const vector<Tensor> &grads, const vector<Tensor> &cache){
-  const Tensor x= cache[0];
-  const Tensor y= cache[1];
-  const Tensor hx= cache[2];
-  const Tensor cx= cache[3];
-  const Tensor W= cache[4];
+  Block *inb = input.block(), *outb = output.block(),
+        *wb = W.block(), *hxb = hx.block(), *cxb = cx.block(),
+        *hyb = hy.block(), *cyb = cy.block(),
+        *wspace = crh.workspace_.block();
+
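+  // inference variant: same descriptors as training, but cudnn needs no
+  // reserve space here because no backward pass will follow.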
+  dev->Exec([inb, outb, wb, hxb, cxb, hyb, cyb, wspace, &crh](Context * ctx) {
+    // clang-format off
+    cudnnRNNForwardInference(
+      ctx->cudnn_handle,
+      crh.rnn_desc_,
+      crh.seq_length_,
+      crh.x_descs_, inb->data(),
+      crh.hx_desc_, hxb == nullptr ? nullptr : hxb->data(),
+      crh.cx_desc_, cxb == nullptr ? nullptr : cxb->data(),
+      crh.weight_desc_, wb->data(),
+      crh.y_descs_, outb->mutable_data(),
+      crh.hy_desc_, hyb->mutable_data(),
+      crh.cy_desc_, cyb == nullptr ? nullptr : cyb->mutable_data(),
+      wspace->mutable_data(), crh.workspace_.Size());
+    // clang-format on
+  }, {inb, wb, hxb, cxb}, {outb, hyb, cyb, wspace});
+
+  auto outputs =
+    SplitOutput(num_x, crh.hidden_size_ * crh.num_directions_, inputs, output);
+  outputs.push_back(hy);
+  if (crh.has_cell_) outputs.push_back(cy);
+
+  return outputs;
+};
+
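+// cache layout written by GpuRNNForwardTraining: {x, y, hx, cx, W}.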
+std::pair<std::vector<Tensor>, Tensor> GpuRNNBackward(const CudnnRNNHandle &crh, const vector<Tensor> &grads, const vector<Tensor> &cache) {
+  const Tensor x = cache[0];
+  const Tensor y = cache[1];
+  const Tensor hx = cache[2];
+  const Tensor cx = cache[3];
+  const Tensor W = cache[4];
 
   auto dev = y.device();
   auto dtype = y.data_type();
@@ -396,12 +485,14 @@ const std::pair<std::vector<Tensor>, Tensor> GpuRecBackward(const CudnnRecHandle
 
   auto data_grads = SplitOutput(num_dy, crh.input_size_, grads, dx);
   data_grads.push_back(dhx);
-  if (has_cell_)
+  if (crh.has_cell_)
     data_grads.push_back(dcx);
 
   return std::make_pair(data_grads, dw);
-}
+};
+
+#endif  // USE_CUDNN
+
+}  // namespace singa
diff --git a/src/model/operation/rnn.h b/src/model/operation/rnn.h
index 5ca5c2151c..0dbbac9974 100644
--- a/src/model/operation/rnn.h
+++ b/src/model/operation/rnn.h
@@ -1,7 +1,23 @@
-class RecHandle {
+#ifndef SINGA_MODEL_OPERATION_CUDNN_RNN_H_
+#define SINGA_MODEL_OPERATION_CUDNN_RNN_H_
+
+#include <string>
+#include <vector>
+#include "singa/core/tensor.h"
+
+
+#ifdef USE_CUDNN
+#include <cudnn.h>
+#include "../layer/cudnn_utils.h"
+#endif  // USE_CUDNN
+
+
+namespace singa {
+
+class RNNHandle {
  public:
-  RecHandle(const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks,
-            const std::string nonlinearity, const bool bias, const float dropout, const bool bidirectional);
+  RNNHandle(const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks,
+            const std::string Rnn_mode, const float Dropout, const bool bidirectional);
 
   size_t input_size_;
   size_t hidden_size_;
@@ -10,24 +26,28 @@ class RecHandle {
   size_t seed_ = 0x1234567;
   size_t num_directions_;
   std::string rnn_mode_;
-  bool has_cell;
+  bool has_cell_;
   size_t weight_size;
-  size_t batch_size = 0;
+  size_t batch_size_ = 0;
   size_t seq_length_ = 0;
   size_t max_length_ = 0;
-}
+};
+
+#ifdef USE_CUDNN
 
-class CudnnRecHandle: public RecHandle {
+class CudnnRNNHandle: public RNNHandle {
  public:
-  CudnnRecHandle(const vector<Tensor> &inputs,const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks,
-                 const std::string nonlinearity, const bool bias, const float dropout, const bool bidirectional);
+  CudnnRNNHandle(const vector<Tensor> &inputs, const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks,
+                 const std::string Rnn_mode, const float Dropout, const bool bidirectional);
 
   void UpdateStates(size_t num_x, const vector<Tensor> &inputs);
+  void DestroyIODescriptors();
   void UpdateIODescriptors(size_t len, const vector<Tensor> &inputs);
   void ResetHiddenAndCellDescriptors(size_t batch_size);
   void SetRNNDescriptor(shared_ptr<Device> dev);
   void UpdateSpaces(size_t seq_length, shared_ptr<Device> dev);
 
+
   cudnnTensorDescriptor_t* x_descs_ = nullptr;
   cudnnTensorDescriptor_t* dx_descs_ = nullptr;
   cudnnTensorDescriptor_t* y_descs_ = nullptr;
@@ -48,16 +68,21 @@ class CudnnRecHandle: public RecHandle {
   Tensor workspace_;
   Tensor reserve_space_;
   Tensor dropout_state_;
-}
+};
 
 Tensor MergeInputs(size_t num, const vector<Tensor> &in);
 
 vector<Tensor> SplitOutput(size_t num, size_t dim,
-                           const vector<Tensor> &in,
-                           const Tensor output);
+                           const vector<Tensor> &in,
+                           const Tensor output);
+
+std::vector<std::vector<Tensor>> GpuRNNForwardTraining(const CudnnRNNHandle &crh, const vector<Tensor> &inputs, const Tensor &W);
+
+std::vector<Tensor> GpuRNNForwardInference(const CudnnRNNHandle &crh, const vector<Tensor> &inputs, const Tensor &W);
 
-std::vector<std::vector<Tensor>> GpuRecForwardTraining(const CudnnRecHandle &crh, const vector<Tensor> &inputs, const Tensor &W);
+std::pair<std::vector<Tensor>, Tensor> GpuRNNBackward(const CudnnRNNHandle &crh, const vector<Tensor> &grads, const vector<Tensor> &cache);
 
-std::vector<Tensor> GpuRecForwardInference(const CudnnRecHandle &crh, const vector<Tensor> &inputs,const Tensor &W);
+#endif  // USE_CUDNN
 
-const std::pair<std::vector<Tensor>, Tensor> GpuRecBackward(const CudnnRecHandle &crh, const vector<Tensor> &grads, const vector<Tensor> &cache);
+}  // namespace singa
+#endif  // SINGA_MODEL_OPERATION_CUDNN_RNN_H_