Add support for group ConvTranspose (#18794)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18794

Add support for group ConvTranspose: ConvTransposeUnpoolBase gains a group argument (default 1), the cuDNN forward and gradient operators call cudnnSetConvolutionGroupCount on cuDNN 7+ and fall back to a per-group loop on older cuDNN, and a grouped test covering NCHW and NHWC is added.

Reviewed By: houseroad

Differential Revision: D14741327

fbshipit-source-id: 5d947ca044bf8495dd7f8f56122441ebbcc6c7e4
Authored by Xiaomeng Yang on 2019-04-04 11:46:37 -07:00; committed by Facebook Github Bot.
parent 8732a1b42e
commit b145dcca04
4 changed files with 771 additions and 492 deletions
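For context, the new group argument can be exercised from the Caffe2 Python frontend as in the sketch below. This is illustrative only (shapes and argument values are made up, not part of the change) and mirrors the grouped test added at the end of this diff.

# Minimal sketch: grouped ConvTranspose via the Caffe2 Python frontend (illustrative).
import numpy as np
from caffe2.python import core, workspace

group = 2
input_channels = 4       # M; the op enforces M % group == 0
output_channels = 6      # C; a multiple of group by construction of the filter below

op = core.CreateOperator(
    "ConvTranspose",
    ["X", "w", "b"],
    ["Y"],
    kernel=3, stride=2, pad=1, adj=1,   # adj must be < stride
    group=group,
    order="NCHW",
)

# NCHW input and an (M, C/group, kH, kW) filter, matching the checks in the cuDNN op below.
X = np.random.rand(1, input_channels, 7, 7).astype(np.float32)
w = np.random.rand(input_channels, output_channels // group, 3, 3).astype(np.float32)
b = np.random.rand(output_channels).astype(np.float32)

workspace.FeedBlob("X", X)
workspace.FeedBlob("w", w)
workspace.FeedBlob("b", b)
workspace.RunOperatorOnce(op)
Y = workspace.FetchBlob("Y")             # shape (1, output_channels, H_out, W_out)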


@@ -1,7 +1,10 @@
#include "caffe2/operators/conv_transpose_op.h"
#include <vector>
#include "caffe2/core/context_gpu.h"
#include "caffe2/core/cudnn_wrappers.h"
#include "caffe2/operators/conv_op_cache_cudnn.h"
#include "caffe2/operators/conv_transpose_op.h"
#include "caffe2/operators/op_utils_cudnn.h"
namespace caffe2 {
@@ -49,6 +52,7 @@ class CudnnConvTransposeOpBase : public ConvTransposeUnpoolBase<CUDAContext> {
CUDNN_ENFORCE(cudnnCreateFilterDescriptor(&filter_desc_));
if (InputSize() == 3) {
CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&bias_desc_));
CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&top_desc_for_bias_));
}
CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&top_desc_));
CUDNN_ENFORCE(cudnnCreateConvolutionDescriptor(&conv_desc_));
@@ -59,27 +63,59 @@ class CudnnConvTransposeOpBase : public ConvTransposeUnpoolBase<CUDAContext> {
CUDNN_ENFORCE(cudnnDestroyFilterDescriptor(filter_desc_));
if (InputSize() == 3) {
CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(bias_desc_));
CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(top_desc_for_bias_));
}
CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(top_desc_));
CUDNN_ENFORCE(cudnnDestroyConvolutionDescriptor(conv_desc_));
}
protected:
vector<int64_t> cudnn_input_dims_;
vector<int64_t> cudnn_filter_dims_;
void SetTensor4DDescriptorWithGroup(
const cudnnDataType_t data_type,
const int N,
const int C,
const int H,
const int W,
cudnnTensorDescriptor_t* desc) const {
#if CUDNN_VERSION_MIN(7, 0, 0)
const int CC = C;
#else
const int CC = C / group_;
#endif
switch (order_) {
case StorageOrder::NCHW: {
CUDNN_ENFORCE(cudnnSetTensor4dDescriptorEx(
*desc, data_type, N, CC, H, W, C * H * W, H * W, W, 1));
break;
}
case StorageOrder::NHWC: {
CUDNN_ENFORCE(cudnnSetTensor4dDescriptorEx(
*desc, data_type, N, CC, H, W, H * W * C, 1, W * C, C));
break;
}
default: {
LOG(FATAL) << "Unknown storage order: " << order_;
}
}
}
std::vector<std::int64_t> cudnn_input_dims_;
std::vector<std::int64_t> cudnn_filter_dims_;
CuDNNWrapper cudnn_wrapper_;
cudnnTensorDescriptor_t bottom_desc_;
cudnnFilterDescriptor_t filter_desc_;
cudnnTensorDescriptor_t bias_desc_;
cudnnTensorDescriptor_t top_desc_;
cudnnTensorDescriptor_t top_desc_for_bias_;
cudnnConvolutionDescriptor_t conv_desc_;
const size_t cudnn_ws_nbytes_limit_;
size_t cudnn_ws_nbytes_;
bool exhaustive_search_;
bool deterministic_;
size_t cudnn_state_;
vector<int> force_algo_; // stored as FWD, dFILTER, dDATA
std::vector<int> force_algo_; // stored as FWD, dFILTER, dDATA
bool enable_tensor_core_;
};
@@ -141,10 +177,10 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
int C = 0;
switch (order_) {
case StorageOrder::NHWC:
C = filter.dim32(3);
C = filter.dim32(3) * group_;
break;
case StorageOrder::NCHW:
C = filter.dim32(1);
C = filter.dim32(1) * group_;
break;
default:
LOG(FATAL) << "Unknown storage order: " << order_;
@@ -162,9 +198,8 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
H_out = Y->dim32(1);
W_out = Y->dim32(2);
CAFFE_ENFORCE_EQ(filter.dim32(1), kernel_h());
CAFFE_ENFORCE_EQ(filter.dim32(1), kernel_h());
CAFFE_ENFORCE_EQ(filter.dim32(2), kernel_w());
CAFFE_ENFORCE_EQ(filter.dim32(3), C);
CAFFE_ENFORCE_EQ(filter.dim32(3), C / group_);
break;
case StorageOrder::NCHW:
N = X.dim32(0);
@@ -173,13 +208,14 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
W = X.dim32(3);
H_out = Y->dim32(2);
W_out = Y->dim32(3);
CAFFE_ENFORCE_EQ(filter.dim32(1), C);
CAFFE_ENFORCE_EQ(filter.dim32(1), C / group_);
CAFFE_ENFORCE_EQ(filter.dim32(2), kernel_h());
CAFFE_ENFORCE_EQ(filter.dim32(3), kernel_w());
break;
default:
LOG(FATAL) << "Unknown storage order: " << order_;
}
CAFFE_ENFORCE_EQ(M % group_, 0);
if (InputSize() == 3) {
auto& bias = Input(BIAS);
@@ -188,30 +224,29 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
}
// Set up the cudnn algorithms & workspace if necessary
bool input_changed = (X.sizes() != cudnn_input_dims_);
bool filter_changed = (filter.sizes() != cudnn_filter_dims_);
const bool input_changed = (X.sizes() != cudnn_input_dims_);
const bool filter_changed = (filter.sizes() != cudnn_filter_dims_);
if (input_changed || filter_changed) {
VLOG(1) << "Changing the cudnn descriptor configurations.";
if (input_changed) {
cudnn_input_dims_ = X.sizes().vec();
CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
bottom_desc_,
GetCudnnTensorFormat(order_),
cudnnTypeWrapper<T>::type,
N,
M,
H,
W));
SetTensor4DDescriptorWithGroup(
cudnnTypeWrapper<T>::type, N, M, H, W, &bottom_desc_);
}
if (filter_changed) {
cudnn_filter_dims_ = filter.sizes().vec();
#if CUDNN_VERSION_MIN(7, 0, 0)
const int MM = M;
#else
const int MM = M / group_;
#endif
CUDNN_ENFORCE(cudnnSetFilter4dDescriptor(
filter_desc_,
cudnnTypeWrapper<T>::type,
GetCudnnTensorFormat(order_),
M,
C,
MM,
C / group_,
kernel_h(),
kernel_w()));
if (InputSize() == 3) {
@@ -226,14 +261,19 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
}
}
// Set the output
SetTensor4DDescriptorWithGroup(
cudnnTypeWrapper<T>::type, N, C, H_out, W_out, &top_desc_);
if (InputSize() == 3) {
CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
top_desc_,
top_desc_for_bias_,
GetCudnnTensorFormat(order_),
cudnnTypeWrapper<T>::type,
N,
C,
H_out,
W_out));
}
// Set the convolution descriptor
CAFFE_ENFORCE_EQ(
pad_t(),
@@ -268,6 +308,7 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
1,
CUDNN_CROSS_CORRELATION));
#endif
#if CUDNN_VERSION_MIN(7, 0, 0)
// enable TensorCore math if desired
enable_tensor_core_ &= TensorCoreAvailable();
@@ -275,7 +316,10 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
CUDNN_ENFORCE(
cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH));
}
// set cuDNN groups if appropriate
CUDNN_ENFORCE(cudnnSetConvolutionGroupCount(conv_desc_, group_));
#endif
if (force_algo_[ALGO_DGRAD] >= 0) {
bwd_data_algo_ = (cudnnConvolutionBwdDataAlgo_t)force_algo_[ALGO_DGRAD];
} else if (deterministic_) {
@@ -331,24 +375,56 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
VLOG(1) << "CuDNN workspace size: " << bwd_data_ws_size;
}
const T* X_data = X.template data<T>();
const T* filter_data = filter.template data<T>();
T* Y_data = Y->template mutable_data<T>();
// Now, actually run the computation.
// Filter
#if CUDNN_VERSION_MIN(7, 0, 0)
cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) {
CUDNN_ENFORCE(cudnnConvolutionBackwardData(
state->cudnn_handle(),
cudnnTypeWrapper<T>::kOne(),
filter_desc_,
filter.template data<T>(),
filter_data,
bottom_desc_,
X.template data<T>(),
X_data,
conv_desc_,
bwd_data_algo_,
state->workspace().get(cudnn_ws_nbytes_),
cudnn_ws_nbytes_,
cudnnTypeWrapper<T>::kZero(),
top_desc_,
Y->template mutable_data<T>()));
Y_data));
});
#else
const int X_HxW = H * W;
const int Y_HxW = H_out * W_out;
const int group_offset_X =
order_ == StorageOrder::NCHW ? M / group_ * X_HxW : M / group_;
const int group_offset_Y =
order_ == StorageOrder::NCHW ? C / group_ * Y_HxW : C / group_;
const int group_offset_filter = filter.numel() / group_;
for (int i = 0; i < group_; ++i) {
cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) {
CUDNN_ENFORCE(
cudnnConvolutionBackwardData(state->cudnn_handle(),
cudnnTypeWrapper<T>::kOne(),
filter_desc_,
filter_data + i * group_offset_filter,
bottom_desc_,
X_data + i * group_offset_X,
conv_desc_,
bwd_data_algo_,
state->workspace().get(cudnn_ws_nbytes_),
cudnn_ws_nbytes_,
cudnnTypeWrapper<T>::kZero(),
top_desc_,
Y_data + i * group_offset_Y));
});
}
#endif
// Bias
if (InputSize() == 3) {
CUDNN_ENFORCE(cudnnAddTensor(
@@ -357,7 +433,7 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
bias_desc_,
Input(BIAS).template data<T>(),
cudnnTypeWrapper<T>::kOne(),
top_desc_,
top_desc_for_bias_,
Y->template mutable_data<T>()));
}
// Done.
@@ -368,19 +444,19 @@ bool CudnnConvTransposeOp<T>::RunOnDevice() {
// consolidating them.
template <typename T>
bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
auto& X = Input(INPUT);
auto& filter = Input(FILTER);
auto& dY = Input(OUTPUT_GRAD);
const auto& X = Input(INPUT);
const auto& filter = Input(FILTER);
const auto& dY = Input(OUTPUT_GRAD);
CAFFE_ENFORCE_EQ(X.dim(), 4);
CAFFE_ENFORCE_EQ(filter.dim(), 4);
int C = 0;
switch (order_) {
case StorageOrder::NHWC:
C = filter.dim32(3);
C = filter.dim32(3) * group_;
break;
case StorageOrder::NCHW:
C = filter.dim32(1);
C = filter.dim32(1) * group_;
break;
default:
LOG(FATAL) << "Unknown storage order: " << order_;
@@ -398,7 +474,7 @@ bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
CAFFE_ENFORCE_EQ(filter.dim32(1), kernel_h());
CAFFE_ENFORCE_EQ(filter.dim32(1), kernel_h());
CAFFE_ENFORCE_EQ(filter.dim32(2), kernel_w());
CAFFE_ENFORCE_EQ(filter.dim32(3), C);
CAFFE_ENFORCE_EQ(filter.dim32(3), C / group_);
break;
case StorageOrder::NCHW:
N = X.dim32(0);
@@ -407,41 +483,42 @@ bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
W = X.dim32(3);
H_out = dY.dim32(2);
W_out = dY.dim32(3);
CAFFE_ENFORCE_EQ(filter.dim32(1), C);
CAFFE_ENFORCE_EQ(filter.dim32(1), C / group_);
CAFFE_ENFORCE_EQ(filter.dim32(2), kernel_h());
CAFFE_ENFORCE_EQ(filter.dim32(3), kernel_w());
break;
default:
LOG(FATAL) << "Unknown storage order: " << order_;
}
CAFFE_ENFORCE_EQ(M % group_, 0);
// Since we only handle LegacyPadding::NOTSET, we don't need to
// compute padding.
auto* dfilter = Output(FILTER_GRAD, filter.sizes(), at::dtype<T>());
// Set up the cudnn algorithms & workspace if necessary
bool input_changed = (X.sizes() != cudnn_input_dims_);
bool filter_changed = (filter.sizes() != cudnn_filter_dims_);
const bool input_changed = (X.sizes() != cudnn_input_dims_);
const bool filter_changed = (filter.sizes() != cudnn_filter_dims_);
if (input_changed || filter_changed) {
VLOG(1) << "Changing the cudnn descriptor configurations.";
if (input_changed) {
cudnn_input_dims_ = X.sizes().vec();
CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
bottom_desc_,
GetCudnnTensorFormat(order_),
cudnnTypeWrapper<T>::type,
N,
M,
H,
W));
SetTensor4DDescriptorWithGroup(
cudnnTypeWrapper<T>::type, N, M, H, W, &bottom_desc_);
}
if (filter_changed) {
cudnn_filter_dims_ = filter.sizes().vec();
#if CUDNN_VERSION_MIN(7, 0, 0)
const int MM = M;
#else
const int MM = M / group_;
#endif
CUDNN_ENFORCE(cudnnSetFilter4dDescriptor(
filter_desc_,
cudnnTypeWrapper<T>::type,
GetCudnnTensorFormat(order_),
M,
C,
MM,
C / group_,
kernel_h(),
kernel_w()));
if (!no_bias_) {
@@ -456,14 +533,19 @@ bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
}
}
// Set the output
SetTensor4DDescriptorWithGroup(
cudnnTypeWrapper<T>::type, N, C, H_out, W_out, &top_desc_);
if (!no_bias_) {
CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
top_desc_,
top_desc_for_bias_,
GetCudnnTensorFormat(order_),
cudnnTypeWrapper<T>::type,
N,
C,
H_out,
W_out));
}
// Set the convolution descriptor
CAFFE_ENFORCE_EQ(
pad_t(),
@@ -504,6 +586,8 @@ bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
CUDNN_ENFORCE(
cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH));
}
// set cuDNN groups if appropriate
CUDNN_ENFORCE(cudnnSetConvolutionGroupCount(conv_desc_, group_));
#endif
if (force_algo_[ALGO_WGRAD] >= 0) {
bwd_filter_algo_ =
@@ -622,13 +706,14 @@ bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
CUDNN_ENFORCE(cudnnConvolutionBackwardBias(
cudnn_wrapper_.inline_cudnn_handle(),
cudnnTypeWrapper<T>::kOne(),
top_desc_,
top_desc_for_bias_,
dY.template data<T>(),
cudnnTypeWrapper<T>::kZero(),
bias_desc_,
dbias->template mutable_data<T>()));
}
#if CUDNN_VERSION_MIN(7, 0, 0)
cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) {
CUDNN_ENFORCE(cudnnConvolutionBackwardFilter(
state->cudnn_handle(),
@@ -647,7 +732,6 @@ bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
if (OutputSize() == 3 || (no_bias_ && (OutputSize() == 2))) {
// Compute the gradient w.r.t. the input.
auto* dX = Output(
no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD,
X.sizes(),
@@ -668,6 +752,55 @@ bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
dX->template mutable_data<T>()));
}
});
#else
const int X_HxW = H * W;
const int Y_HxW = H_out * W_out;
const int group_offset_X =
order_ == StorageOrder::NCHW ? M / group_ * X_HxW : M / group_;
const int group_offset_Y =
order_ == StorageOrder::NCHW ? C / group_ * Y_HxW : C / group_;
const int group_offset_filter = filter.numel() / group_;
for (int i = 0; i < group_; ++i) {
cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) {
CUDNN_ENFORCE(cudnnConvolutionBackwardFilter(
state->cudnn_handle(),
cudnnTypeWrapper<T>::kOne(),
top_desc_,
dY.template data<T>() + i * group_offset_Y,
bottom_desc_,
X.template data<T>() + i * group_offset_X,
conv_desc_,
bwd_filter_algo_,
state->workspace().get(cudnn_ws_nbytes_),
cudnn_ws_nbytes_,
cudnnTypeWrapper<T>::kZero(),
filter_desc_,
dfilter->template mutable_data<T>() + i * group_offset_filter));
if (OutputSize() == 3 || (no_bias_ && (OutputSize() == 2))) {
// Compute the gradient w.r.t. the input.
auto* dX = Output(
no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD,
X.sizes(),
at::dtype<T>());
cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) {
CUDNN_ENFORCE(cudnnConvolutionForward(
state->cudnn_handle(),
cudnnTypeWrapper<T>::kOne(),
top_desc_,
dY.template data<T>() + i * group_offset_Y,
filter_desc_,
filter.template data<T>() + i * group_offset_filter,
conv_desc_,
algo_,
state->workspace().get(cudnn_ws_nbytes_),
cudnn_ws_nbytes_,
cudnnTypeWrapper<T>::kZero(),
bottom_desc_,
dX->template mutable_data<T>() + i * group_offset_X));
});
}
}
#endif
return true;
}
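Implementation note on the cuDNN operator above: on cuDNN 7.0+ the group count is handed to cuDNN via cudnnSetConvolutionGroupCount and a single call covers all groups; on older cuDNN the op loops over groups, using tensor descriptors with C/group channels but full-tensor strides, and advances the data pointers by per-group element counts. A rough Python rendering of that offset arithmetic (names mirror the C++ locals; illustrative only, not part of the change):

# Per-group element offsets used by the pre-cuDNN7 fallback loop (sketch).
def group_offsets(order, M, C, H, W, H_out, W_out, filter_numel, group):
    X_HxW, Y_HxW = H * W, H_out * W_out
    if order == "NCHW":
        # Channels follow the batch dimension, so consecutive groups are a full
        # (channels_per_group * spatial) block apart within each image.
        group_offset_X = (M // group) * X_HxW
        group_offset_Y = (C // group) * Y_HxW
    else:  # NHWC: channels are innermost, so groups sit channels_per_group apart.
        group_offset_X = M // group
        group_offset_Y = C // group
    group_offset_filter = filter_numel // group
    return group_offset_X, group_offset_Y, group_offset_filter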

File diff suppressed because it is too large.


@@ -17,7 +17,9 @@ template <class Context>
class ConvTransposeUnpoolBase : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
explicit ConvTransposeUnpoolBase(const OperatorDef& operator_def, Workspace* ws)
explicit ConvTransposeUnpoolBase(
const OperatorDef& operator_def,
Workspace* ws)
: Operator<Context>(operator_def, ws),
legacy_pad_(
static_cast<LegacyPadding>(this->template GetSingleArgument<int>(
@@ -27,6 +29,7 @@ class ConvTransposeUnpoolBase : public Operator<Context> {
stride_(this->template GetRepeatedArgument<int>("strides")),
pads_(this->template GetRepeatedArgument<int>("pads")),
adj_(this->template GetRepeatedArgument<int>("adjs")),
group_(this->template GetSingleArgument<int>("group", 1)),
order_(StringToStorageOrder(
this->template GetSingleArgument<string>("order", "NCHW"))),
shared_buffer_(
@@ -206,19 +209,7 @@ class ConvTransposeUnpoolBase : public Operator<Context> {
virtual ~ConvTransposeUnpoolBase() {}
private:
LegacyPadding legacy_pad_;
int pad_;
protected:
vector<int> kernel_;
vector<int> stride_;
vector<int> pads_;
vector<int> adj_;
StorageOrder order_;
bool shared_buffer_;
Workspace* ws_;
// Accessors for 2D conv params.
inline int pad_t() const {
@@ -289,14 +280,35 @@ class ConvTransposeUnpoolBase : public Operator<Context> {
break;
}
}
LegacyPadding legacy_pad_;
int pad_;
std::vector<int> kernel_;
std::vector<int> stride_;
std::vector<int> pads_;
std::vector<int> adj_;
int group_;
StorageOrder order_;
bool shared_buffer_;
Workspace* ws_;
};
#define USE_CONV_TRANSPOSE_UNPOOL_BASE_FUNCTIONS(Context) \
USE_OPERATOR_FUNCTIONS(Context); \
using ConvTransposeUnpoolBase<Context>::kernel_; \
using ConvTransposeUnpoolBase<Context>::kernel_h; \
using ConvTransposeUnpoolBase<Context>::kernel_w; \
using ConvTransposeUnpoolBase<Context>::stride_; \
using ConvTransposeUnpoolBase<Context>::stride_h; \
using ConvTransposeUnpoolBase<Context>::stride_w; \
using ConvTransposeUnpoolBase<Context>::pads_; \
using ConvTransposeUnpoolBase<Context>::pad_t; \
using ConvTransposeUnpoolBase<Context>::pad_l; \
using ConvTransposeUnpoolBase<Context>::pad_b; \
using ConvTransposeUnpoolBase<Context>::pad_r; \
using ConvTransposeUnpoolBase<Context>::adj_; \
using ConvTransposeUnpoolBase<Context>::group_; \
using ConvTransposeUnpoolBase<Context>::order_; \
using ConvTransposeUnpoolBase<Context>::shared_buffer_; \
using ConvTransposeUnpoolBase<Context>::ws_


@@ -6,6 +6,7 @@ import numpy as np
from hypothesis import assume, given, settings
import hypothesis.strategies as st
from caffe2.proto import caffe2_pb2
from caffe2.python import core, utils
import caffe2.python.hypothesis_test_util as hu
import caffe2.python.hip_test_util as hiputl
@@ -360,6 +361,68 @@ class TestConvolutionTranspose(hu.HypothesisTestCase):
for i in outputs_to_check:
self.assertGradientChecks(gc, op, inputs, i, [0])
@given(stride=st.integers(1, 3),
pad=st.integers(0, 3),
kernel=st.integers(1, 3),
adj=st.integers(0, 2),
size=st.integers(7, 10),
input_channels=st.integers(1, 8),
output_channels=st.integers(1, 8),
batch_size=st.integers(1, 4),
group=st.integers(1, 4),
order=st.sampled_from(["NCHW", "NHWC"]),
engine=st.sampled_from(["", "CUDNN", "BLOCK"]),
shared_buffer=st.booleans(),
use_bias=st.booleans(),
**hu.gcs)
def test_convolution_transpose_with_group(
self, stride, pad, kernel, adj, size, input_channels,
output_channels, batch_size, group, order, engine, shared_buffer,
use_bias, gc, dc):
assume(adj < stride)
# TODO: Group conv_transpose in NHWC not implemented for GPU yet.
assume(group == 1 or order == "NCHW" or
gc.device_type == caffe2_pb2.CPU)
if group != 1 and order == "NHWC":
dc = [d for d in dc if d.device_type == caffe2_pb2.CPU]
if hiputl.run_in_hip(gc, dc) and order == "NHWC":
engine = ""
op = core.CreateOperator(
"ConvTranspose",
["X", "w", "b"] if use_bias else ["X", "w"],
["Y"],
stride=stride,
kernel=kernel,
pad=pad,
adj=adj,
group=group,
order=order,
engine=engine,
shared_buffer=int(shared_buffer),
device_option=gc,
)
input_channels *= group
output_channels *= group
X = np.random.rand(
batch_size, size, size, input_channels).astype(np.float32) - 0.5
w = np.random.rand(
input_channels, kernel, kernel, int(output_channels / group)) \
.astype(np.float32) - 0.5
b = np.random.rand(output_channels).astype(np.float32) - 0.5
if order == "NCHW":
X = utils.NHWC2NCHW(X)
w = utils.NHWC2NCHW(w)
inputs = [X, w, b] if use_bias else [X, w]
self.assertDeviceChecks(dc, op, inputs, [0])
for i in range(len(inputs)):
self.assertGradientChecks(gc, op, inputs, i, [0])
if __name__ == "__main__":
import unittest
unittest.main()
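For reference, a hedged sketch (not part of the change) of the shape relations the grouped test above relies on; the output spatial size follows the usual ConvTranspose relation for LegacyPadding::NOTSET, which is the only padding mode the cuDNN path above handles.

# Expected ConvTranspose output size and channel count for the grouped test (sketch).
def conv_transpose_out_dim(in_dim, kernel, stride, pad, adj):
    return stride * (in_dim - 1) + kernel + adj - 2 * pad

print(conv_transpose_out_dim(8, 3, 2, 1, 1))   # 16, e.g. for size=8, kernel=3, stride=2, pad=1, adj=1

# Output channels come from the filter and the group count: w.shape[-1] * group in NHWC,
# or w.shape[1] * group in NCHW, matching C = filter.dim32(...) * group_ in the cuDNN op.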