Fast and generic implementation using OpenMP and CUDA #45
Open
shikishima-TasakiLab wants to merge 14 commits into d-li14:main from shikishima-TasakiLab:main
Commits (showing changes from all 14 commits)
All 14 commits are by shikishima-TasakiLab:

221b32b  First commit
c32b399  CPU-only support
41a5fb0  Update README.md
1140e73  Remove "AutoNonVariableTypeMode".
684769a  Merge branch 'main' of https://github.com/shikishima-TasakiLab/Involu…
4b036d3  Merge remote-tracking branch 'upstream/main' into main
ce9100f  Compatible with PyTorch 1.7.0 or later
4dc6c8f  Update README.md
57a6d6a  Merge remote-tracking branch 'upstream/main' into main
0c190bc  Fixed the value of CUDA_MAX_THREADS
353c22a  Fixed auto-casting to float32 when using AMP.
265c309  Merge remote-tracking branch 'upstream/main' into main
2319dea  Support fatbin.
7734b83  Merge remote-tracking branch 'upstream/main' into main
New file: shell script that launches a PyTorch Docker container with the repository mounted.
@@ -0,0 +1,17 @@
#!/bin/bash
RUN_DIR=$(dirname $(readlink -f $0))

DOCKER_VOLUME="${DOCKER_VOLUME} -v $(dirname ${RUN_DIR}):/workspace/involution:rw"

docker run \
    -it \
    --rm \
    --gpus '"device=0"' \
    ${DOCKER_VOLUME} \
    --name Involution-PyTorch \
    pytorch/pytorch:1.7.0-cuda11.0-cudnn8-devel bash
# pytorch/pytorch:1.7.1-cuda11.0-cudnn8-devel bash
# pytorch/pytorch:1.9.0-cuda11.1-cudnn8-devel bash
# pytorch/pytorch:1.8.1-cuda11.1-cudnn8-devel bash
# nvcr.io/nvidia/pytorch:21.05-py3
# nvcr.io/nvidia/pytorch:20.08-py3
New file: involution2d_cpu.h (declarations of the CPU forward/backward kernels).
@@ -0,0 +1,53 @@
#pragma once

#include <ATen/ATen.h>
#include <ATen/Parallel.h>

namespace involution {
namespace cpu {

at::Tensor involution2d_forward(
    const at::Tensor& input,
    const at::Tensor& weight,
    const std::vector<int64_t>& kernel_size,
    const std::vector<int64_t>& stride,
    const std::vector<int64_t>& padding,
    const std::vector<int64_t>& dilation,
    const int64_t groups
);

at::Tensor involution2d_backward_grad_input(
    const at::Tensor& grad,
    const at::Tensor& weight,
    const std::vector<int64_t>& input_shape,
    const std::vector<int64_t>& kernel_size,
    const std::vector<int64_t>& stride,
    const std::vector<int64_t>& padding,
    const std::vector<int64_t>& dilation,
    const int64_t groups
);

at::Tensor involution2d_backward_grad_weight(
    const at::Tensor& grad,
    const at::Tensor& input,
    const std::vector<int64_t>& weight_shape,
    const std::vector<int64_t>& kernel_size,
    const std::vector<int64_t>& stride,
    const std::vector<int64_t>& padding,
    const std::vector<int64_t>& dilation,
    const int64_t groups
);

std::vector<at::Tensor> involution2d_backward(
    const at::Tensor& grad,
    const at::Tensor& weight,
    const at::Tensor& input,
    const std::vector<int64_t>& kernel_size,
    const std::vector<int64_t>& stride,
    const std::vector<int64_t>& padding,
    const std::vector<int64_t>& dilation,
    const int64_t groups
);

} // namespace cpu
} // namespace involution
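The header above only declares the CPU entry points; the OpenMP parallelism mentioned in the PR title comes through ATen's <ATen/Parallel.h> layer included here. The following is a minimal sketch, not part of this diff, of how an at::parallel_for loop is typically structured in such a CPU kernel. The function name parallel_fill_like and its trivial element-wise body are placeholders for illustration, not the PR's involution implementation.

// Sketch only: demonstrates the at::parallel_for pattern (OpenMP-backed in
// standard PyTorch builds) that CPU kernels like the ones declared above
// typically use. The element-wise fill stands in for the real involution math.
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>

at::Tensor parallel_fill_like(const at::Tensor& input, const double value) {
    at::Tensor output = at::empty_like(input).contiguous();
    const int64_t numel = output.numel();

    AT_DISPATCH_FLOATING_TYPES(output.scalar_type(), "parallel_fill_like", [&] {
        scalar_t* out_ptr = output.data_ptr<scalar_t>();
        // Split [0, numel) across worker threads; each lambda call handles [begin, end).
        at::parallel_for(0, numel, at::internal::GRAIN_SIZE, [&](int64_t begin, int64_t end) {
            for (int64_t i = begin; i < end; ++i) {
                out_ptr[i] = static_cast<scalar_t>(value);
            }
        });
    });

    return output;
}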
New file: involution2d_cuda.cuh (declarations of the CUDA forward/backward kernels and launch macros).
@@ -0,0 +1,59 @@
#pragma once

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

namespace involution {
namespace cuda {

#define CUDA_MAX_THREADS 1024u

#define CUDA_KERNEL_LOOP(i, n) \
    for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)

at::Tensor involution2d_forward(
    const at::Tensor& input,
    const at::Tensor& weight,
    const std::vector<int64_t>& kernel_size,
    const std::vector<int64_t>& stride,
    const std::vector<int64_t>& padding,
    const std::vector<int64_t>& dilation,
    const int64_t groups
);

at::Tensor involution2d_backward_grad_input(
    const at::Tensor& grad,
    const at::Tensor& weight,
    const std::vector<int64_t>& input_shape,
    const std::vector<int64_t>& kernel_size,
    const std::vector<int64_t>& stride,
    const std::vector<int64_t>& padding,
    const std::vector<int64_t>& dilation,
    const int64_t groups
);

at::Tensor involution2d_backward_grad_weight(
    const at::Tensor& grad,
    const at::Tensor& input,
    const std::vector<int64_t>& weight_shape,
    const std::vector<int64_t>& kernel_size,
    const std::vector<int64_t>& stride,
    const std::vector<int64_t>& padding,
    const std::vector<int64_t>& dilation,
    const int64_t groups
);

std::vector<at::Tensor> involution2d_backward(
    const at::Tensor& grad,
    const at::Tensor& weight,
    const at::Tensor& input,
    const std::vector<int64_t>& kernel_size,
    const std::vector<int64_t>& stride,
    const std::vector<int64_t>& padding,
    const std::vector<int64_t>& dilation,
    const int64_t groups
);

} // namespace cuda
} // namespace involution
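CUDA_KERNEL_LOOP above is the usual grid-stride-loop macro, and CUDA_MAX_THREADS caps the block size at 1024 threads per block (the value the commit 0c190bc fixed). As an illustration only, a kernel built around the macro and a matching launch configuration might look like the sketch below; the element-wise copy body and the helper names are placeholders, not the PR's involution kernels.

// Sketch only: how a kernel using CUDA_KERNEL_LOOP is typically defined and
// launched. The copy body stands in for the real involution arithmetic.
template <typename scalar_t>
__global__ void copy_kernel(const scalar_t* in, scalar_t* out, const int64_t n) {
    CUDA_KERNEL_LOOP(i, n) {    // i advances by blockDim.x * gridDim.x each iteration
        out[i] = in[i];
    }
}

template <typename scalar_t>
void launch_copy(const scalar_t* in, scalar_t* out, const int64_t n) {
    const uint32_t threads = CUDA_MAX_THREADS;                             // 1024 threads per block
    const uint32_t blocks  = static_cast<uint32_t>((n + threads - 1) / threads);
    // Launch on the stream PyTorch is currently using (from ATen/cuda/CUDAContext.h).
    copy_kernel<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(in, out, n);
}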
New file: C++ wrapper header that binds the CPU and CUDA kernels to the dispatcher, torch::autograd, and Autocast.
@@ -0,0 +1,234 @@
#pragma once

#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/autocast_mode.h>
#include <torch/csrc/autograd/custom_function.h>

#include "involution2d_cpu.h"

#ifdef USE_CUDA
#   include "involution2d_cuda.cuh"
#endif

namespace involution {

at::Tensor involution2d(
    const at::Tensor& input,
    const at::Tensor& weight,
    const std::vector<int64_t>& stride,
    const std::vector<int64_t>& padding,
    const std::vector<int64_t>& dilation
) {
    static auto op = at::Dispatcher::singleton()
        .findSchemaOrThrow("involution::involution2d", "")
        .typed<decltype(involution2d)>();

    return op.call(input, weight, stride, padding, dilation);
}

at::Tensor involution2d_autocast(
    const at::Tensor& input,
    const at::Tensor& weight,
    const std::vector<int64_t>& stride,
    const std::vector<int64_t>& padding,
    const std::vector<int64_t>& dilation
) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
    auto exec_type = at::autocast::promote_type(at::kFloat, input, weight);
    return involution2d(at::autocast::cached_cast(exec_type, input), at::autocast::cached_cast(exec_type, weight), stride, padding, dilation)
        .to(input.scalar_type());
}

at::Tensor _involution2d_backward_grad_input(
    const at::Tensor& grad,
    const at::Tensor& weight,
    const std::vector<int64_t>& input_shape,
    const std::vector<int64_t>& stride,
    const std::vector<int64_t>& padding,
    const std::vector<int64_t>& dilation
) {
    static auto op = at::Dispatcher::singleton()
        .findSchemaOrThrow("involution2d::_involution2d_backward_grad_input", "")
        .typed<decltype(_involution2d_backward_grad_input)>();

    return op.call(grad, weight, input_shape, stride, padding, dilation);
}

at::Tensor _involution2d_backward_grad_weight(
    const at::Tensor& grad,
    const at::Tensor& input,
    const std::vector<int64_t>& weight_shape,
    const std::vector<int64_t>& stride,
    const std::vector<int64_t>& padding,
    const std::vector<int64_t>& dilation
) {
    static auto op = at::Dispatcher::singleton()
        .findSchemaOrThrow("involution2d::_involution2d_backward_grad_weight", "")
        .typed<decltype(_involution2d_backward_grad_weight)>();

    return op.call(grad, input, weight_shape, stride, padding, dilation);
}

namespace cpu {

class Involution2dFunctionCPU : public torch::autograd::Function<Involution2dFunctionCPU>
{
public:

    static torch::autograd::variable_list forward(
        torch::autograd::AutogradContext* ctx,
        const torch::autograd::Variable& input,
        const torch::autograd::Variable& weight,
        const std::vector<int64_t>& kernel_size,
        const std::vector<int64_t>& stride,
        const std::vector<int64_t>& padding,
        const std::vector<int64_t>& dilation,
        const int64_t groups
    ) {
        ctx->saved_data["kernel_size"] = kernel_size;
        ctx->saved_data["stride"] = stride;
        ctx->saved_data["padding"] = padding;
        ctx->saved_data["dilation"] = dilation;
        ctx->saved_data["groups"] = groups;
        ctx->save_for_backward({input, weight});

        auto output = involution2d_forward(input, weight, kernel_size, stride, padding, dilation, groups);

        return {output};
    }

    static torch::autograd::variable_list backward(
        torch::autograd::AutogradContext* ctx,
        const torch::autograd::variable_list grad_output
    ) {
        torch::autograd::variable_list saved = ctx->get_saved_variables();
        torch::autograd::Variable input = saved[0];
        torch::autograd::Variable weight = saved[1];

        auto kernel_size = ctx->saved_data["kernel_size"].toIntVector();
        auto stride = ctx->saved_data["stride"].toIntVector();
        auto padding = ctx->saved_data["padding"].toIntVector();
        auto dilation = ctx->saved_data["dilation"].toIntVector();
        auto groups = ctx->saved_data["groups"].toInt();

        auto grads = involution2d_backward(grad_output[0], weight, input, kernel_size, stride, padding, dilation, groups);

        return {
            grads[0],
            grads[1],
            torch::autograd::Variable(),
            torch::autograd::Variable(),
            torch::autograd::Variable(),
            torch::autograd::Variable(),
            torch::autograd::Variable()
        };
    }
};

at::Tensor involution2d_autograd(
    const torch::autograd::Variable& input,
    const torch::autograd::Variable& weight,
    const std::vector<int64_t>& kernel_size,
    const std::vector<int64_t>& stride,
    const std::vector<int64_t>& padding,
    const std::vector<int64_t>& dilation,
    const int64_t groups
) {
    return Involution2dFunctionCPU::apply(input, weight, kernel_size, stride, padding, dilation, groups)[0];
}

} // namespace cpu

#ifdef USE_CUDA

namespace cuda {

class Involution2dFunctionCUDA : public torch::autograd::Function<Involution2dFunctionCUDA>
{
public:

    static torch::autograd::variable_list forward(
        torch::autograd::AutogradContext* ctx,
        const torch::autograd::Variable& input,
        const torch::autograd::Variable& weight,
        const std::vector<int64_t>& kernel_size,
        const std::vector<int64_t>& stride,
        const std::vector<int64_t>& padding,
        const std::vector<int64_t>& dilation,
        const int64_t groups
    ) {
        ctx->saved_data["kernel_size"] = kernel_size;
        ctx->saved_data["stride"] = stride;
        ctx->saved_data["padding"] = padding;
        ctx->saved_data["dilation"] = dilation;
        ctx->saved_data["groups"] = groups;
        ctx->save_for_backward({input, weight});

        auto output = involution2d_forward(input, weight, kernel_size, stride, padding, dilation, groups);

        return {output};
    }

    static torch::autograd::variable_list backward(
        torch::autograd::AutogradContext* ctx,
        const torch::autograd::variable_list grad_output
    ) {
        torch::autograd::variable_list saved = ctx->get_saved_variables();
        torch::autograd::Variable input = saved[0];
        torch::autograd::Variable weight = saved[1];

        auto kernel_size = ctx->saved_data["kernel_size"].toIntVector();
        auto stride = ctx->saved_data["stride"].toIntVector();
        auto padding = ctx->saved_data["padding"].toIntVector();
        auto dilation = ctx->saved_data["dilation"].toIntVector();
        auto groups = ctx->saved_data["groups"].toInt();

        auto grads = involution2d_backward(grad_output[0], weight, input, kernel_size, stride, padding, dilation, groups);

        return {
            grads[0],
            grads[1],
            torch::autograd::Variable(),
            torch::autograd::Variable(),
            torch::autograd::Variable(),
            torch::autograd::Variable(),
            torch::autograd::Variable()
        };
    }
};

at::Tensor involution2d_autograd(
    const torch::autograd::Variable& input,
    const torch::autograd::Variable& weight,
    const std::vector<int64_t>& kernel_size,
    const std::vector<int64_t>& stride,
    const std::vector<int64_t>& padding,
    const std::vector<int64_t>& dilation,
    const int64_t groups
) {
    return Involution2dFunctionCUDA::apply(input, weight, kernel_size, stride, padding, dilation, groups)[0];
}

at::Tensor involution2d_autocast(
    const torch::autograd::Variable& input,
    const torch::autograd::Variable& weight,
    const std::vector<int64_t>& kernel_size,
    const std::vector<int64_t>& stride,
    const std::vector<int64_t>& padding,
    const std::vector<int64_t>& dilation,
    const int64_t groups
) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
    auto exec_type = at::autocast::promote_type(at::kFloat, input, weight);
    return involution2d_autograd(
        at::autocast::cached_cast(exec_type, input),
        at::autocast::cached_cast(exec_type, weight),
        kernel_size, stride, padding, dilation, groups
    );
}

} // namespace cuda

#endif

} // namespace involution
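The involution2d wrapper above resolves the op through at::Dispatcher::singleton().findSchemaOrThrow("involution::involution2d", ""), so the schema must be registered elsewhere in the extension; the registration file is not part of this excerpt. The sketch below shows what such a registration typically looks like with torch::library. The schema string and the choice to bind the Autocast key to the five-argument involution2d_autocast wrapper are assumptions inferred from the wrappers above, not copied from the PR.

// Sketch only: a typical registration consistent with the wrappers above.
// Schema text and dispatch-key bindings are assumptions, not from this diff.
#include <torch/library.h>

TORCH_LIBRARY(involution, m) {
    m.def("involution2d(Tensor input, Tensor weight, int[] stride, int[] padding, int[] dilation) -> Tensor");
}

// Under AMP, route calls through the wrapper that promotes reduced-precision
// inputs back to at least float32 before re-dispatching to the real kernel.
TORCH_LIBRARY_IMPL(involution, Autocast, m) {
    m.impl("involution2d", involution::involution2d_autocast);
}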
New file: Python package __init__ that loads the compiled extension and re-exports Involution2d.
@@ -0,0 +1,9 @@
from glob import glob
import os

from torch import ops

_LIB_PATH = glob(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'involution.*.so'))[0]
ops.load_library(_LIB_PATH)

from .involution2d import Involution2d
@csvance
Fixed CUDA implementation input to be full precision using Autocast.