mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Add support for group ConvTranspose (#18794)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18794 Add support for group ConvTranspose Reviewed By: houseroad Differential Revision: D14741327 fbshipit-source-id: 5d947ca044bf8495dd7f8f56122441ebbcc6c7e4
This commit is contained in:
committed by
Facebook Github Bot
parent
8732a1b42e
commit
b145dcca04
@ -1,7 +1,10 @@
|
|||||||
|
#include "caffe2/operators/conv_transpose_op.h"
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "caffe2/core/context_gpu.h"
|
#include "caffe2/core/context_gpu.h"
|
||||||
#include "caffe2/core/cudnn_wrappers.h"
|
#include "caffe2/core/cudnn_wrappers.h"
|
||||||
#include "caffe2/operators/conv_op_cache_cudnn.h"
|
#include "caffe2/operators/conv_op_cache_cudnn.h"
|
||||||
#include "caffe2/operators/conv_transpose_op.h"
|
|
||||||
#include "caffe2/operators/op_utils_cudnn.h"
|
#include "caffe2/operators/op_utils_cudnn.h"
|
||||||
|
|
||||||
namespace caffe2 {
|
namespace caffe2 {
|
||||||
@ -49,6 +52,7 @@ class CudnnConvTransposeOpBase : public ConvTransposeUnpoolBase<CUDAContext> {
|
|||||||
CUDNN_ENFORCE(cudnnCreateFilterDescriptor(&filter_desc_));
|
CUDNN_ENFORCE(cudnnCreateFilterDescriptor(&filter_desc_));
|
||||||
if (InputSize() == 3) {
|
if (InputSize() == 3) {
|
||||||
CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&bias_desc_));
|
CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&bias_desc_));
|
||||||
|
CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&top_desc_for_bias_));
|
||||||
}
|
}
|
||||||
CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&top_desc_));
|
CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&top_desc_));
|
||||||
CUDNN_ENFORCE(cudnnCreateConvolutionDescriptor(&conv_desc_));
|
CUDNN_ENFORCE(cudnnCreateConvolutionDescriptor(&conv_desc_));
|
||||||
@ -59,27 +63,59 @@ class CudnnConvTransposeOpBase : public ConvTransposeUnpoolBase<CUDAContext> {
|
|||||||
CUDNN_ENFORCE(cudnnDestroyFilterDescriptor(filter_desc_));
|
CUDNN_ENFORCE(cudnnDestroyFilterDescriptor(filter_desc_));
|
||||||
if (InputSize() == 3) {
|
if (InputSize() == 3) {
|
||||||
CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(bias_desc_));
|
CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(bias_desc_));
|
||||||
|
CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(top_desc_for_bias_));
|
||||||
}
|
}
|
||||||
CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(top_desc_));
|
CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(top_desc_));
|
||||||
CUDNN_ENFORCE(cudnnDestroyConvolutionDescriptor(conv_desc_));
|
CUDNN_ENFORCE(cudnnDestroyConvolutionDescriptor(conv_desc_));
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
vector<int64_t> cudnn_input_dims_;
|
// Configures *desc as a 4-D tensor descriptor of element type data_type for
// an N x C x H x W tensor stored in order_ (NCHW or NHWC), using explicit
// strides via cudnnSetTensor4dDescriptorEx.
//
// The channel extent written into the descriptor is CC (see below), but the
// strides are always derived from the full channel count C.
// Logs FATAL when order_ is neither NCHW nor NHWC.
void SetTensor4DDescriptorWithGroup(
|
||||||
vector<int64_t> cudnn_filter_dims_;
|
const cudnnDataType_t data_type,
|
||||||
|
const int N,
|
||||||
|
const int C,
|
||||||
|
const int H,
|
||||||
|
const int W,
|
||||||
|
cudnnTensorDescriptor_t* desc) const {
|
||||||
|
// CC is the channel extent stored in the descriptor: all C channels when
// CUDNN_VERSION_MIN(7, 0, 0) holds, otherwise only one group's share
// (C / group_).
#if CUDNN_VERSION_MIN(7, 0, 0)
|
||||||
|
const int CC = C;
|
||||||
|
#else
|
||||||
|
const int CC = C / group_;
|
||||||
|
#endif
|
||||||
|
switch (order_) {
|
||||||
|
case StorageOrder::NCHW: {
|
||||||
|
// Strides (n, c, h, w) = (C*H*W, H*W, W, 1).
CUDNN_ENFORCE(cudnnSetTensor4dDescriptorEx(
|
||||||
|
*desc, data_type, N, CC, H, W, C * H * W, H * W, W, 1));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case StorageOrder::NHWC: {
|
||||||
|
// Strides (n, c, h, w) = (H*W*C, 1, W*C, C).
CUDNN_ENFORCE(cudnnSetTensor4dDescriptorEx(
|
||||||
|
*desc, data_type, N, CC, H, W, H * W * C, 1, W * C, C));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default: {
|
||||||
|
LOG(FATAL) << "Unknown storage order: " << order_;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::int64_t> cudnn_input_dims_;
|
||||||
|
std::vector<std::int64_t> cudnn_filter_dims_;
|
||||||
|
|
||||||
CuDNNWrapper cudnn_wrapper_;
|
CuDNNWrapper cudnn_wrapper_;
|
||||||
cudnnTensorDescriptor_t bottom_desc_;
|
cudnnTensorDescriptor_t bottom_desc_;
|
||||||
cudnnFilterDescriptor_t filter_desc_;
|
cudnnFilterDescriptor_t filter_desc_;
|
||||||
cudnnTensorDescriptor_t bias_desc_;
|
cudnnTensorDescriptor_t bias_desc_;
|
||||||
cudnnTensorDescriptor_t top_desc_;
|
cudnnTensorDescriptor_t top_desc_;
|
||||||
|
cudnnTensorDescriptor_t top_desc_for_bias_;
|
||||||
cudnnConvolutionDescriptor_t conv_desc_;
|
cudnnConvolutionDescriptor_t conv_desc_;
|
||||||
|
|
||||||
const size_t cudnn_ws_nbytes_limit_;
|
const size_t cudnn_ws_nbytes_limit_;
|
||||||
size_t cudnn_ws_nbytes_;
|
size_t cudnn_ws_nbytes_;
|
||||||
bool exhaustive_search_;
|
bool exhaustive_search_;
|
||||||
bool deterministic_;
|
bool deterministic_;
|
||||||
size_t cudnn_state_;
|
size_t cudnn_state_;
|
||||||
vector<int> force_algo_; // stored as FWD, dFILTER, dDATA
|
std::vector<int> force_algo_; // stored as FWD, dFILTER, dDATA
|
||||||
bool enable_tensor_core_;
|
bool enable_tensor_core_;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -141,10 +177,10 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
|
|||||||
int C = 0;
|
int C = 0;
|
||||||
switch (order_) {
|
switch (order_) {
|
||||||
case StorageOrder::NHWC:
|
case StorageOrder::NHWC:
|
||||||
C = filter.dim32(3);
|
C = filter.dim32(3) * group_;
|
||||||
break;
|
break;
|
||||||
case StorageOrder::NCHW:
|
case StorageOrder::NCHW:
|
||||||
C = filter.dim32(1);
|
C = filter.dim32(1) * group_;
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
LOG(FATAL) << "Unknown storage order: " << order_;
|
LOG(FATAL) << "Unknown storage order: " << order_;
|
||||||
@ -162,9 +198,8 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
|
|||||||
H_out = Y->dim32(1);
|
H_out = Y->dim32(1);
|
||||||
W_out = Y->dim32(2);
|
W_out = Y->dim32(2);
|
||||||
CAFFE_ENFORCE_EQ(filter.dim32(1), kernel_h());
|
CAFFE_ENFORCE_EQ(filter.dim32(1), kernel_h());
|
||||||
CAFFE_ENFORCE_EQ(filter.dim32(1), kernel_h());
|
|
||||||
CAFFE_ENFORCE_EQ(filter.dim32(2), kernel_w());
|
CAFFE_ENFORCE_EQ(filter.dim32(2), kernel_w());
|
||||||
CAFFE_ENFORCE_EQ(filter.dim32(3), C);
|
CAFFE_ENFORCE_EQ(filter.dim32(3), C / group_);
|
||||||
break;
|
break;
|
||||||
case StorageOrder::NCHW:
|
case StorageOrder::NCHW:
|
||||||
N = X.dim32(0);
|
N = X.dim32(0);
|
||||||
@ -173,13 +208,14 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
|
|||||||
W = X.dim32(3);
|
W = X.dim32(3);
|
||||||
H_out = Y->dim32(2);
|
H_out = Y->dim32(2);
|
||||||
W_out = Y->dim32(3);
|
W_out = Y->dim32(3);
|
||||||
CAFFE_ENFORCE_EQ(filter.dim32(1), C);
|
CAFFE_ENFORCE_EQ(filter.dim32(1), C / group_);
|
||||||
CAFFE_ENFORCE_EQ(filter.dim32(2), kernel_h());
|
CAFFE_ENFORCE_EQ(filter.dim32(2), kernel_h());
|
||||||
CAFFE_ENFORCE_EQ(filter.dim32(3), kernel_w());
|
CAFFE_ENFORCE_EQ(filter.dim32(3), kernel_w());
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
LOG(FATAL) << "Unknown storage order: " << order_;
|
LOG(FATAL) << "Unknown storage order: " << order_;
|
||||||
}
|
}
|
||||||
|
CAFFE_ENFORCE_EQ(M % group_, 0);
|
||||||
|
|
||||||
if (InputSize() == 3) {
|
if (InputSize() == 3) {
|
||||||
auto& bias = Input(BIAS);
|
auto& bias = Input(BIAS);
|
||||||
@ -188,30 +224,29 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Set up the cudnn algorithms & workspace if necessary
|
// Set up the cudnn algorithms & workspace if necessary
|
||||||
bool input_changed = (X.sizes() != cudnn_input_dims_);
|
const bool input_changed = (X.sizes() != cudnn_input_dims_);
|
||||||
bool filter_changed = (filter.sizes() != cudnn_filter_dims_);
|
const bool filter_changed = (filter.sizes() != cudnn_filter_dims_);
|
||||||
|
|
||||||
if (input_changed || filter_changed) {
|
if (input_changed || filter_changed) {
|
||||||
VLOG(1) << "Changing the cudnn descriptor configurations.";
|
VLOG(1) << "Changing the cudnn descriptor configurations.";
|
||||||
if (input_changed) {
|
if (input_changed) {
|
||||||
cudnn_input_dims_ = X.sizes().vec();
|
cudnn_input_dims_ = X.sizes().vec();
|
||||||
CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
|
SetTensor4DDescriptorWithGroup(
|
||||||
bottom_desc_,
|
cudnnTypeWrapper<T>::type, N, M, H, W, &bottom_desc_);
|
||||||
GetCudnnTensorFormat(order_),
|
|
||||||
cudnnTypeWrapper<T>::type,
|
|
||||||
N,
|
|
||||||
M,
|
|
||||||
H,
|
|
||||||
W));
|
|
||||||
}
|
}
|
||||||
if (filter_changed) {
|
if (filter_changed) {
|
||||||
cudnn_filter_dims_ = filter.sizes().vec();
|
cudnn_filter_dims_ = filter.sizes().vec();
|
||||||
|
#if CUDNN_VERSION_MIN(7, 0, 0)
|
||||||
|
const int MM = M;
|
||||||
|
#else
|
||||||
|
const int MM = M / group_;
|
||||||
|
#endif
|
||||||
CUDNN_ENFORCE(cudnnSetFilter4dDescriptor(
|
CUDNN_ENFORCE(cudnnSetFilter4dDescriptor(
|
||||||
filter_desc_,
|
filter_desc_,
|
||||||
cudnnTypeWrapper<T>::type,
|
cudnnTypeWrapper<T>::type,
|
||||||
GetCudnnTensorFormat(order_),
|
GetCudnnTensorFormat(order_),
|
||||||
M,
|
MM,
|
||||||
C,
|
C / group_,
|
||||||
kernel_h(),
|
kernel_h(),
|
||||||
kernel_w()));
|
kernel_w()));
|
||||||
if (InputSize() == 3) {
|
if (InputSize() == 3) {
|
||||||
@ -226,14 +261,19 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Set the output
|
// Set the output
|
||||||
CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
|
SetTensor4DDescriptorWithGroup(
|
||||||
top_desc_,
|
cudnnTypeWrapper<T>::type, N, C, H_out, W_out, &top_desc_);
|
||||||
GetCudnnTensorFormat(order_),
|
if (InputSize() == 3) {
|
||||||
cudnnTypeWrapper<T>::type,
|
CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
|
||||||
N,
|
top_desc_for_bias_,
|
||||||
C,
|
GetCudnnTensorFormat(order_),
|
||||||
H_out,
|
cudnnTypeWrapper<T>::type,
|
||||||
W_out));
|
N,
|
||||||
|
C,
|
||||||
|
H_out,
|
||||||
|
W_out));
|
||||||
|
}
|
||||||
|
|
||||||
// Set the convolution descriptor
|
// Set the convolution descriptor
|
||||||
CAFFE_ENFORCE_EQ(
|
CAFFE_ENFORCE_EQ(
|
||||||
pad_t(),
|
pad_t(),
|
||||||
@ -246,7 +286,7 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
|
|||||||
"The current padding scheme leads to unequal padding on the left "
|
"The current padding scheme leads to unequal padding on the left "
|
||||||
"and right, which is not supported by cudnn.");
|
"and right, which is not supported by cudnn.");
|
||||||
// Set the convolution descriptor
|
// Set the convolution descriptor
|
||||||
#if CUDNN_VERSION_MIN(6,0,0)
|
#if CUDNN_VERSION_MIN(6, 0, 0)
|
||||||
CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor(
|
CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor(
|
||||||
conv_desc_,
|
conv_desc_,
|
||||||
pad_t(),
|
pad_t(),
|
||||||
@ -268,6 +308,7 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
|
|||||||
1,
|
1,
|
||||||
CUDNN_CROSS_CORRELATION));
|
CUDNN_CROSS_CORRELATION));
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if CUDNN_VERSION_MIN(7, 0, 0)
|
#if CUDNN_VERSION_MIN(7, 0, 0)
|
||||||
// enable TensorCore math if desired
|
// enable TensorCore math if desired
|
||||||
enable_tensor_core_ &= TensorCoreAvailable();
|
enable_tensor_core_ &= TensorCoreAvailable();
|
||||||
@ -275,7 +316,10 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
|
|||||||
CUDNN_ENFORCE(
|
CUDNN_ENFORCE(
|
||||||
cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH));
|
cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH));
|
||||||
}
|
}
|
||||||
|
// set cuDNN groups if appropriate
|
||||||
|
CUDNN_ENFORCE(cudnnSetConvolutionGroupCount(conv_desc_, group_));
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
if (force_algo_[ALGO_DGRAD] >= 0) {
|
if (force_algo_[ALGO_DGRAD] >= 0) {
|
||||||
bwd_data_algo_ = (cudnnConvolutionBwdDataAlgo_t)force_algo_[ALGO_DGRAD];
|
bwd_data_algo_ = (cudnnConvolutionBwdDataAlgo_t)force_algo_[ALGO_DGRAD];
|
||||||
} else if (deterministic_) {
|
} else if (deterministic_) {
|
||||||
@ -331,24 +375,56 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
|
|||||||
VLOG(1) << "CuDNN workspace size: " << bwd_data_ws_size;
|
VLOG(1) << "CuDNN workspace size: " << bwd_data_ws_size;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const T* X_data = X.template data<T>();
|
||||||
|
const T* filter_data = filter.template data<T>();
|
||||||
|
T* Y_data = Y->template mutable_data<T>();
|
||||||
|
|
||||||
// Now, actually run the computation.
|
// Now, actually run the computation.
|
||||||
// Filter
|
// Filter
|
||||||
|
#if CUDNN_VERSION_MIN(7, 0, 0)
|
||||||
cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) {
|
cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) {
|
||||||
CUDNN_ENFORCE(cudnnConvolutionBackwardData(
|
CUDNN_ENFORCE(cudnnConvolutionBackwardData(
|
||||||
state->cudnn_handle(),
|
state->cudnn_handle(),
|
||||||
cudnnTypeWrapper<T>::kOne(),
|
cudnnTypeWrapper<T>::kOne(),
|
||||||
filter_desc_,
|
filter_desc_,
|
||||||
filter.template data<T>(),
|
filter_data,
|
||||||
bottom_desc_,
|
bottom_desc_,
|
||||||
X.template data<T>(),
|
X_data,
|
||||||
conv_desc_,
|
conv_desc_,
|
||||||
bwd_data_algo_,
|
bwd_data_algo_,
|
||||||
state->workspace().get(cudnn_ws_nbytes_),
|
state->workspace().get(cudnn_ws_nbytes_),
|
||||||
cudnn_ws_nbytes_,
|
cudnn_ws_nbytes_,
|
||||||
cudnnTypeWrapper<T>::kZero(),
|
cudnnTypeWrapper<T>::kZero(),
|
||||||
top_desc_,
|
top_desc_,
|
||||||
Y->template mutable_data<T>()));
|
Y_data));
|
||||||
});
|
});
|
||||||
|
#else
|
||||||
|
const int X_HxW = H * W;
|
||||||
|
const int Y_HxW = H_out * W_out;
|
||||||
|
const int group_offset_X =
|
||||||
|
order_ == StorageOrder::NCHW ? M / group_ * X_HxW : M / group_;
|
||||||
|
const int group_offset_Y =
|
||||||
|
order_ == StorageOrder::NCHW ? C / group_ * Y_HxW : C / group_;
|
||||||
|
const int group_offset_filter = filter.numel() / group_;
|
||||||
|
for (int i = 0; i < group_; ++i) {
|
||||||
|
cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) {
|
||||||
|
CUDNN_ENFORCE(
|
||||||
|
cudnnConvolutionBackwardData(state->cudnn_handle(),
|
||||||
|
cudnnTypeWrapper<T>::kOne(),
|
||||||
|
filter_desc_,
|
||||||
|
filter_data + i * group_offset_filter,
|
||||||
|
bottom_desc_,
|
||||||
|
X_data + i * group_offset_X;
|
||||||
|
conv_desc_,
|
||||||
|
bwd_data_algo_,
|
||||||
|
state->workspace().get(cudnn_ws_nbytes_),
|
||||||
|
cudnn_ws_nbytes_,
|
||||||
|
cudnnTypeWrapper<T_DX>::kZero(),
|
||||||
|
top_desc_,
|
||||||
|
Y_data + i * group_offset_Y));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
#endif
|
||||||
// Bias
|
// Bias
|
||||||
if (InputSize() == 3) {
|
if (InputSize() == 3) {
|
||||||
CUDNN_ENFORCE(cudnnAddTensor(
|
CUDNN_ENFORCE(cudnnAddTensor(
|
||||||
@ -357,7 +433,7 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
|
|||||||
bias_desc_,
|
bias_desc_,
|
||||||
Input(BIAS).template data<T>(),
|
Input(BIAS).template data<T>(),
|
||||||
cudnnTypeWrapper<T>::kOne(),
|
cudnnTypeWrapper<T>::kOne(),
|
||||||
top_desc_,
|
top_desc_for_bias_,
|
||||||
Y->template mutable_data<T>()));
|
Y->template mutable_data<T>()));
|
||||||
}
|
}
|
||||||
// Done.
|
// Done.
|
||||||
@ -368,19 +444,19 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
|
|||||||
// consolidating them.
|
// consolidating them.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
|
bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
|
||||||
auto& X = Input(INPUT);
|
const auto& X = Input(INPUT);
|
||||||
auto& filter = Input(FILTER);
|
const auto& filter = Input(FILTER);
|
||||||
auto& dY = Input(OUTPUT_GRAD);
|
const auto& dY = Input(OUTPUT_GRAD);
|
||||||
|
|
||||||
CAFFE_ENFORCE_EQ(X.dim(), 4);
|
CAFFE_ENFORCE_EQ(X.dim(), 4);
|
||||||
CAFFE_ENFORCE_EQ(filter.dim(), 4);
|
CAFFE_ENFORCE_EQ(filter.dim(), 4);
|
||||||
int C = 0;
|
int C = 0;
|
||||||
switch (order_) {
|
switch (order_) {
|
||||||
case StorageOrder::NHWC:
|
case StorageOrder::NHWC:
|
||||||
C = filter.dim32(3);
|
C = filter.dim32(3) * group_;
|
||||||
break;
|
break;
|
||||||
case StorageOrder::NCHW:
|
case StorageOrder::NCHW:
|
||||||
C = filter.dim32(1);
|
C = filter.dim32(1) * group_;
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
LOG(FATAL) << "Unknown storage order: " << order_;
|
LOG(FATAL) << "Unknown storage order: " << order_;
|
||||||
@ -398,7 +474,7 @@ bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
|
|||||||
CAFFE_ENFORCE_EQ(filter.dim32(1), kernel_h());
|
CAFFE_ENFORCE_EQ(filter.dim32(1), kernel_h());
|
||||||
CAFFE_ENFORCE_EQ(filter.dim32(1), kernel_h());
|
CAFFE_ENFORCE_EQ(filter.dim32(1), kernel_h());
|
||||||
CAFFE_ENFORCE_EQ(filter.dim32(2), kernel_w());
|
CAFFE_ENFORCE_EQ(filter.dim32(2), kernel_w());
|
||||||
CAFFE_ENFORCE_EQ(filter.dim32(3), C);
|
CAFFE_ENFORCE_EQ(filter.dim32(3), C / group_);
|
||||||
break;
|
break;
|
||||||
case StorageOrder::NCHW:
|
case StorageOrder::NCHW:
|
||||||
N = X.dim32(0);
|
N = X.dim32(0);
|
||||||
@ -407,41 +483,42 @@ bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
|
|||||||
W = X.dim32(3);
|
W = X.dim32(3);
|
||||||
H_out = dY.dim32(2);
|
H_out = dY.dim32(2);
|
||||||
W_out = dY.dim32(3);
|
W_out = dY.dim32(3);
|
||||||
CAFFE_ENFORCE_EQ(filter.dim32(1), C);
|
CAFFE_ENFORCE_EQ(filter.dim32(1), C / group_);
|
||||||
CAFFE_ENFORCE_EQ(filter.dim32(2), kernel_h());
|
CAFFE_ENFORCE_EQ(filter.dim32(2), kernel_h());
|
||||||
CAFFE_ENFORCE_EQ(filter.dim32(3), kernel_w());
|
CAFFE_ENFORCE_EQ(filter.dim32(3), kernel_w());
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
LOG(FATAL) << "Unknown storage order: " << order_;
|
LOG(FATAL) << "Unknown storage order: " << order_;
|
||||||
}
|
}
|
||||||
|
CAFFE_ENFORCE_EQ(M % group_, 0);
|
||||||
|
|
||||||
// Since we only handle LegacyPadding::NOTSET, we don't need to
|
// Since we only handle LegacyPadding::NOTSET, we don't need to
|
||||||
// compute padding.
|
// compute padding.
|
||||||
auto* dfilter = Output(FILTER_GRAD, filter.sizes(), at::dtype<T>());
|
auto* dfilter = Output(FILTER_GRAD, filter.sizes(), at::dtype<T>());
|
||||||
|
|
||||||
// Set up the cudnn algorithms & workspace if necessary
|
// Set up the cudnn algorithms & workspace if necessary
|
||||||
bool input_changed = (X.sizes() != cudnn_input_dims_);
|
const bool input_changed = (X.sizes() != cudnn_input_dims_);
|
||||||
bool filter_changed = (filter.sizes() != cudnn_filter_dims_);
|
const bool filter_changed = (filter.sizes() != cudnn_filter_dims_);
|
||||||
if (input_changed || filter_changed) {
|
if (input_changed || filter_changed) {
|
||||||
VLOG(1) << "Changing the cudnn descriptor configurations.";
|
VLOG(1) << "Changing the cudnn descriptor configurations.";
|
||||||
if (input_changed) {
|
if (input_changed) {
|
||||||
cudnn_input_dims_ = X.sizes().vec();
|
cudnn_input_dims_ = X.sizes().vec();
|
||||||
CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
|
SetTensor4DDescriptorWithGroup(
|
||||||
bottom_desc_,
|
cudnnTypeWrapper<T>::type, N, M, H, W, &bottom_desc_);
|
||||||
GetCudnnTensorFormat(order_),
|
|
||||||
cudnnTypeWrapper<T>::type,
|
|
||||||
N,
|
|
||||||
M,
|
|
||||||
H,
|
|
||||||
W));
|
|
||||||
}
|
}
|
||||||
if (filter_changed) {
|
if (filter_changed) {
|
||||||
cudnn_filter_dims_ = filter.sizes().vec();
|
cudnn_filter_dims_ = filter.sizes().vec();
|
||||||
|
#if CUDNN_VERSION_MIN(7, 0, 0)
|
||||||
|
const int MM = M;
|
||||||
|
#else
|
||||||
|
const int MM = M / group_;
|
||||||
|
#endif
|
||||||
CUDNN_ENFORCE(cudnnSetFilter4dDescriptor(
|
CUDNN_ENFORCE(cudnnSetFilter4dDescriptor(
|
||||||
filter_desc_,
|
filter_desc_,
|
||||||
cudnnTypeWrapper<T>::type,
|
cudnnTypeWrapper<T>::type,
|
||||||
GetCudnnTensorFormat(order_),
|
GetCudnnTensorFormat(order_),
|
||||||
M,
|
MM,
|
||||||
C,
|
C / group_,
|
||||||
kernel_h(),
|
kernel_h(),
|
||||||
kernel_w()));
|
kernel_w()));
|
||||||
if (!no_bias_) {
|
if (!no_bias_) {
|
||||||
@ -456,14 +533,19 @@ bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Set the output
|
// Set the output
|
||||||
CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
|
SetTensor4DDescriptorWithGroup(
|
||||||
top_desc_,
|
cudnnTypeWrapper<T>::type, N, C, H_out, W_out, &top_desc_);
|
||||||
GetCudnnTensorFormat(order_),
|
if (!no_bias_) {
|
||||||
cudnnTypeWrapper<T>::type,
|
CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
|
||||||
N,
|
top_desc_for_bias_,
|
||||||
C,
|
GetCudnnTensorFormat(order_),
|
||||||
H_out,
|
cudnnTypeWrapper<T>::type,
|
||||||
W_out));
|
N,
|
||||||
|
C,
|
||||||
|
H_out,
|
||||||
|
W_out));
|
||||||
|
}
|
||||||
|
|
||||||
// Set the convolution descriptor
|
// Set the convolution descriptor
|
||||||
CAFFE_ENFORCE_EQ(
|
CAFFE_ENFORCE_EQ(
|
||||||
pad_t(),
|
pad_t(),
|
||||||
@ -475,7 +557,7 @@ bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
|
|||||||
pad_r(),
|
pad_r(),
|
||||||
"The current padding scheme leads to unequal padding on the left "
|
"The current padding scheme leads to unequal padding on the left "
|
||||||
"and right, which is not supported by cudnn.");
|
"and right, which is not supported by cudnn.");
|
||||||
#if CUDNN_VERSION_MIN(6,0,0)
|
#if CUDNN_VERSION_MIN(6, 0, 0)
|
||||||
CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor(
|
CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor(
|
||||||
conv_desc_,
|
conv_desc_,
|
||||||
pad_t(),
|
pad_t(),
|
||||||
@ -504,6 +586,8 @@ bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
|
|||||||
CUDNN_ENFORCE(
|
CUDNN_ENFORCE(
|
||||||
cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH));
|
cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH));
|
||||||
}
|
}
|
||||||
|
// set cuDNN groups if appropriate
|
||||||
|
CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc_, group_));
|
||||||
#endif
|
#endif
|
||||||
if (force_algo_[ALGO_WGRAD] >= 0) {
|
if (force_algo_[ALGO_WGRAD] >= 0) {
|
||||||
bwd_filter_algo_ =
|
bwd_filter_algo_ =
|
||||||
@ -622,13 +706,14 @@ bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
|
|||||||
CUDNN_ENFORCE(cudnnConvolutionBackwardBias(
|
CUDNN_ENFORCE(cudnnConvolutionBackwardBias(
|
||||||
cudnn_wrapper_.inline_cudnn_handle(),
|
cudnn_wrapper_.inline_cudnn_handle(),
|
||||||
cudnnTypeWrapper<T>::kOne(),
|
cudnnTypeWrapper<T>::kOne(),
|
||||||
top_desc_,
|
top_desc_for_bias_,
|
||||||
dY.template data<T>(),
|
dY.template data<T>(),
|
||||||
cudnnTypeWrapper<T>::kZero(),
|
cudnnTypeWrapper<T>::kZero(),
|
||||||
bias_desc_,
|
bias_desc_,
|
||||||
dbias->template mutable_data<T>()));
|
dbias->template mutable_data<T>()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if CUDNN_VERSION_MIN(7, 0, 0)
|
||||||
cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) {
|
cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) {
|
||||||
CUDNN_ENFORCE(cudnnConvolutionBackwardFilter(
|
CUDNN_ENFORCE(cudnnConvolutionBackwardFilter(
|
||||||
state->cudnn_handle(),
|
state->cudnn_handle(),
|
||||||
@ -647,7 +732,6 @@ bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
|
|||||||
|
|
||||||
if (OutputSize() == 3 || (no_bias_ && (OutputSize() == 2))) {
|
if (OutputSize() == 3 || (no_bias_ && (OutputSize() == 2))) {
|
||||||
// Compute the gradient w.r.t. the input.
|
// Compute the gradient w.r.t. the input.
|
||||||
|
|
||||||
auto* dX = Output(
|
auto* dX = Output(
|
||||||
no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD,
|
no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD,
|
||||||
X.sizes(),
|
X.sizes(),
|
||||||
@ -668,6 +752,55 @@ bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
|
|||||||
dX->template mutable_data<T>()));
|
dX->template mutable_data<T>()));
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
#else
|
||||||
|
const int X_HxW = H * W;
|
||||||
|
const int Y_HxW = H_out * W_out;
|
||||||
|
const int group_offset_X =
|
||||||
|
order_ == StorageOrder::NCHW ? M / group_ * X_HxW : M / group_;
|
||||||
|
const int group_offset_Y =
|
||||||
|
order_ == StorageOrder::NCHW ? C / group_ * Y_HxW : C / group_;
|
||||||
|
const int group_offset_filter = filter.numel() / group_;
|
||||||
|
for (int i = 0; i < group_; ++i) {
|
||||||
|
cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) {
|
||||||
|
CUDNN_ENFORCE(cudnnConvolutionBackwardFilter(
|
||||||
|
state->cudnn_handle(),
|
||||||
|
cudnnTypeWrapper<T>::kOne(),
|
||||||
|
top_desc_,
|
||||||
|
dY.template data<T>() + i * group_offset_Y,
|
||||||
|
bottom_desc_,
|
||||||
|
X.template data<T>() + i * group_offset_X,
|
||||||
|
conv_desc_,
|
||||||
|
bwd_filter_algo_,
|
||||||
|
state->workspace().get(cudnn_ws_nbytes_),
|
||||||
|
cudnn_ws_nbytes_,
|
||||||
|
cudnnTypeWrapper<T>::kZero(),
|
||||||
|
filter_desc_,
|
||||||
|
dfilter->template mutable_data<T>() + i * group_offset_filter));
|
||||||
|
if (OutputSize() == 3 || (no_bias_ && (OutputSize() == 2))) {
|
||||||
|
// Compute the gradient w.r.t. the input.
|
||||||
|
auto* dX = Output(
|
||||||
|
no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD,
|
||||||
|
X.sizes(),
|
||||||
|
at::dtype<T>());
|
||||||
|
cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) {
|
||||||
|
CUDNN_ENFORCE(cudnnConvolutionForward(
|
||||||
|
state->cudnn_handle(),
|
||||||
|
cudnnTypeWrapper<T>::kOne(),
|
||||||
|
top_desc_,
|
||||||
|
dY.template data<T>() + i * group_offset_Y,
|
||||||
|
filter_desc_,
|
||||||
|
filter.template data<T>() + i * group_offset_filter,
|
||||||
|
conv_desc_,
|
||||||
|
algo_,
|
||||||
|
state->workspace().get(cudnn_ws_nbytes_),
|
||||||
|
cudnn_ws_nbytes_,
|
||||||
|
cudnnTypeWrapper<T>::kZero(),
|
||||||
|
bottom_desc_,
|
||||||
|
dX->template mutable_data<T>() + i * group_offset_X));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -17,7 +17,9 @@ template <class Context>
|
|||||||
class ConvTransposeUnpoolBase : public Operator<Context> {
|
class ConvTransposeUnpoolBase : public Operator<Context> {
|
||||||
public:
|
public:
|
||||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||||
explicit ConvTransposeUnpoolBase(const OperatorDef& operator_def, Workspace* ws)
|
explicit ConvTransposeUnpoolBase(
|
||||||
|
const OperatorDef& operator_def,
|
||||||
|
Workspace* ws)
|
||||||
: Operator<Context>(operator_def, ws),
|
: Operator<Context>(operator_def, ws),
|
||||||
legacy_pad_(
|
legacy_pad_(
|
||||||
static_cast<LegacyPadding>(this->template GetSingleArgument<int>(
|
static_cast<LegacyPadding>(this->template GetSingleArgument<int>(
|
||||||
@ -27,6 +29,7 @@ class ConvTransposeUnpoolBase : public Operator<Context> {
|
|||||||
stride_(this->template GetRepeatedArgument<int>("strides")),
|
stride_(this->template GetRepeatedArgument<int>("strides")),
|
||||||
pads_(this->template GetRepeatedArgument<int>("pads")),
|
pads_(this->template GetRepeatedArgument<int>("pads")),
|
||||||
adj_(this->template GetRepeatedArgument<int>("adjs")),
|
adj_(this->template GetRepeatedArgument<int>("adjs")),
|
||||||
|
group_(this->template GetSingleArgument<int>("group", 1)),
|
||||||
order_(StringToStorageOrder(
|
order_(StringToStorageOrder(
|
||||||
this->template GetSingleArgument<string>("order", "NCHW"))),
|
this->template GetSingleArgument<string>("order", "NCHW"))),
|
||||||
shared_buffer_(
|
shared_buffer_(
|
||||||
@ -206,19 +209,7 @@ class ConvTransposeUnpoolBase : public Operator<Context> {
|
|||||||
|
|
||||||
virtual ~ConvTransposeUnpoolBase() {}
|
virtual ~ConvTransposeUnpoolBase() {}
|
||||||
|
|
||||||
private:
|
|
||||||
LegacyPadding legacy_pad_;
|
|
||||||
int pad_;
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
vector<int> kernel_;
|
|
||||||
vector<int> stride_;
|
|
||||||
vector<int> pads_;
|
|
||||||
vector<int> adj_;
|
|
||||||
StorageOrder order_;
|
|
||||||
bool shared_buffer_;
|
|
||||||
Workspace* ws_;
|
|
||||||
|
|
||||||
// Accessors for 2D conv params.
|
// Accessors for 2D conv params.
|
||||||
|
|
||||||
inline int pad_t() const {
|
inline int pad_t() const {
|
||||||
@ -289,14 +280,35 @@ class ConvTransposeUnpoolBase : public Operator<Context> {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LegacyPadding legacy_pad_;
|
||||||
|
int pad_;
|
||||||
|
|
||||||
|
std::vector<int> kernel_;
|
||||||
|
std::vector<int> stride_;
|
||||||
|
std::vector<int> pads_;
|
||||||
|
std::vector<int> adj_;
|
||||||
|
int group_;
|
||||||
|
StorageOrder order_;
|
||||||
|
bool shared_buffer_;
|
||||||
|
Workspace* ws_;
|
||||||
};
|
};
|
||||||
|
|
||||||
#define USE_CONV_TRANSPOSE_UNPOOL_BASE_FUNCTIONS(Context) \
|
#define USE_CONV_TRANSPOSE_UNPOOL_BASE_FUNCTIONS(Context) \
|
||||||
USE_OPERATOR_FUNCTIONS(Context); \
|
USE_OPERATOR_FUNCTIONS(Context); \
|
||||||
using ConvTransposeUnpoolBase<Context>::kernel_; \
|
using ConvTransposeUnpoolBase<Context>::kernel_; \
|
||||||
|
using ConvTransposeUnpoolBase<Context>::kernel_h; \
|
||||||
|
using ConvTransposeUnpoolBase<Context>::kernel_w; \
|
||||||
using ConvTransposeUnpoolBase<Context>::stride_; \
|
using ConvTransposeUnpoolBase<Context>::stride_; \
|
||||||
|
using ConvTransposeUnpoolBase<Context>::stride_h; \
|
||||||
|
using ConvTransposeUnpoolBase<Context>::stride_w; \
|
||||||
using ConvTransposeUnpoolBase<Context>::pads_; \
|
using ConvTransposeUnpoolBase<Context>::pads_; \
|
||||||
|
using ConvTransposeUnpoolBase<Context>::pad_t; \
|
||||||
|
using ConvTransposeUnpoolBase<Context>::pad_l; \
|
||||||
|
using ConvTransposeUnpoolBase<Context>::pad_b; \
|
||||||
|
using ConvTransposeUnpoolBase<Context>::pad_r; \
|
||||||
using ConvTransposeUnpoolBase<Context>::adj_; \
|
using ConvTransposeUnpoolBase<Context>::adj_; \
|
||||||
|
using ConvTransposeUnpoolBase<Context>::group_; \
|
||||||
using ConvTransposeUnpoolBase<Context>::order_; \
|
using ConvTransposeUnpoolBase<Context>::order_; \
|
||||||
using ConvTransposeUnpoolBase<Context>::shared_buffer_; \
|
using ConvTransposeUnpoolBase<Context>::shared_buffer_; \
|
||||||
using ConvTransposeUnpoolBase<Context>::ws_
|
using ConvTransposeUnpoolBase<Context>::ws_
|
||||||
|
@ -6,6 +6,7 @@ import numpy as np
|
|||||||
from hypothesis import assume, given, settings
|
from hypothesis import assume, given, settings
|
||||||
import hypothesis.strategies as st
|
import hypothesis.strategies as st
|
||||||
|
|
||||||
|
from caffe2.proto import caffe2_pb2
|
||||||
from caffe2.python import core, utils
|
from caffe2.python import core, utils
|
||||||
import caffe2.python.hypothesis_test_util as hu
|
import caffe2.python.hypothesis_test_util as hu
|
||||||
import caffe2.python.hip_test_util as hiputl
|
import caffe2.python.hip_test_util as hiputl
|
||||||
@ -360,6 +361,68 @@ class TestConvolutionTranspose(hu.HypothesisTestCase):
|
|||||||
for i in outputs_to_check:
|
for i in outputs_to_check:
|
||||||
self.assertGradientChecks(gc, op, inputs, i, [0])
|
self.assertGradientChecks(gc, op, inputs, i, [0])
|
||||||
|
|
||||||
|
@given(stride=st.integers(1, 3),
       pad=st.integers(0, 3),
       kernel=st.integers(1, 3),
       adj=st.integers(0, 2),
       size=st.integers(7, 10),
       input_channels=st.integers(1, 8),
       output_channels=st.integers(1, 8),
       batch_size=st.integers(1, 4),
       group=st.integers(1, 4),
       order=st.sampled_from(["NCHW", "NHWC"]),
       engine=st.sampled_from(["", "CUDNN", "BLOCK"]),
       shared_buffer=st.booleans(),
       use_bias=st.booleans(),
       **hu.gcs)
def test_convolution_transpose_with_group(
        self, stride, pad, kernel, adj, size, input_channels,
        output_channels, batch_size, group, order, engine, shared_buffer,
        use_bias, gc, dc):
    """Device and gradient checks for ConvTranspose with a `group` argument.

    `input_channels` and `output_channels` are drawn as per-group counts and
    multiplied by `group` to obtain the tensor channel dimensions. Data is
    generated in NHWC and converted when `order` is "NCHW".
    """
    assume(adj < stride)

    # TODO: Group conv_transpose in NHWC not implemented for GPU yet.
    on_cpu = gc.device_type == caffe2_pb2.CPU if not (
        group == 1 or order == "NCHW") else True
    assume(on_cpu)
    if group != 1 and order == "NHWC":
        # Restrict the cross-device comparison to CPU devices as well.
        dc = [d for d in dc if d.device_type == caffe2_pb2.CPU]

    if hiputl.run_in_hip(gc, dc) and order == "NHWC":
        engine = ""

    blob_names = ["X", "w", "b"] if use_bias else ["X", "w"]
    op = core.CreateOperator(
        "ConvTranspose",
        blob_names,
        ["Y"],
        stride=stride,
        kernel=kernel,
        pad=pad,
        adj=adj,
        group=group,
        order=order,
        engine=engine,
        shared_buffer=int(shared_buffer),
        device_option=gc,
    )

    # Channel counts of the full tensors; the drawn values are per group.
    total_in = input_channels * group
    total_out = output_channels * group

    # Random draws happen in the order X, w, b.
    X = np.random.rand(
        batch_size, size, size, total_in).astype(np.float32) - 0.5
    # The filter's last dimension holds one group's output channels.
    w = np.random.rand(
        total_in, kernel, kernel, total_out // group) \
        .astype(np.float32) - 0.5
    b = np.random.rand(total_out).astype(np.float32) - 0.5
    if order == "NCHW":
        X = utils.NHWC2NCHW(X)
        w = utils.NHWC2NCHW(w)

    inputs = [X, w, b] if use_bias else [X, w]
    self.assertDeviceChecks(dc, op, inputs, [0])
    for input_index, _ in enumerate(inputs):
        self.assertGradientChecks(gc, op, inputs, input_index, [0])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import unittest
|
import unittest
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Reference in New Issue
Block a user