mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Spatial batch norm; currently just based on cudnn.
This commit is contained in:
@ -26,6 +26,7 @@ cc_library(
|
||||
"relu_op.cc",
|
||||
"sigmoid_op.cc",
|
||||
"softmax_op.cc",
|
||||
"spatial_batch_norm_op.cc",
|
||||
"summarize_op.cc",
|
||||
"tanh_op.cc",
|
||||
"tensor_protos_db_input.cc",
|
||||
@ -97,6 +98,7 @@ cc_library(
|
||||
srcs = [
|
||||
"conv_op_cudnn.cc",
|
||||
"softmax_op_cudnn.cc",
|
||||
"spatial_batch_norm_op_cudnn.cc",
|
||||
],
|
||||
deps = [
|
||||
":operators_headers",
|
||||
|
65
caffe2/operators/spatial_batch_norm_op.cc
Normal file
65
caffe2/operators/spatial_batch_norm_op.cc
Normal file
@ -0,0 +1,65 @@
|
||||
#include "caffe2/operators/spatial_batch_norm_op.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
// TODO: implement the CPU version of spatial batch normalization.
|
||||
|
||||
// Spatial batch normalization's gradient, depending on the various input sizes,
|
||||
// is a bit more complex than usual gradient operators.
|
||||
namespace {
|
||||
struct GetSpatialBNGradient : public GetGradientDefBase {
|
||||
static vector<OperatorDef>* Create(const OperatorDef& def) {
|
||||
// Check if we are in training or testing mode.
|
||||
bool is_test = false;
|
||||
if (HasArgument(def, "is_test")) {
|
||||
const auto& arg = GetArgument(def, "is_test");
|
||||
CAFFE_CHECK(arg.has_i());
|
||||
is_test = arg.i();
|
||||
}
|
||||
vector<string> grad_outputs{GI(def, 0), GI(def, 1), GI(def, 2)};
|
||||
vector<string> grad_inputs;
|
||||
if (is_test) {
|
||||
// This is in testing mode. The operator should have five input:
|
||||
// X, scale, bias, estimated_mean, estimated_inv_variance
|
||||
// The gradient inputs are:
|
||||
// X, scale, dY, estimated_mean, estimated_inv_variance
|
||||
CAFFE_CHECK_EQ(def.input_size(), 5);
|
||||
CAFFE_CHECK_EQ(def.output_size(), 1);
|
||||
grad_inputs = vector<string>{
|
||||
I(def, 0), I(def, 1), GO(def, 0), I(def, 3), I(def, 4)};
|
||||
} else {
|
||||
CAFFE_CHECK_EQ(def.input_size(), 3);
|
||||
CAFFE_CHECK(def.output_size() == 3 || def.output_size() == 5);
|
||||
// This is in training mode. The operator should have either three output:
|
||||
// Y, running_mean, running_inv_variance
|
||||
// or five:
|
||||
// Y, running_mean, running_inv_variance, saved_mean,
|
||||
// saved_inv_variance
|
||||
switch (def.output_size()) {
|
||||
case 3:
|
||||
// The original operator does not have saved mean and inv variance,
|
||||
// so the gradient operator cannot take advantage of that.
|
||||
// The gradient inputs are:
|
||||
// X, scale, dY
|
||||
grad_inputs = vector<string>{I(def, 0), I(def, 1), GO(def, 0)};
|
||||
break;
|
||||
case 5:
|
||||
// The original operator does have saved mean and inv variance,
|
||||
// and the gradient operator can take advantage of that.
|
||||
// The gradient inputs are:
|
||||
// X, scale, dY, saved_mean, saved_inv_variance
|
||||
grad_inputs = vector<string>{
|
||||
I(def, 0), I(def, 1), GO(def, 0), O(def, 3), O(def, 4)};
|
||||
break;
|
||||
default:
|
||||
CAFFE_LOG_FATAL << "Should not happen.";
|
||||
}
|
||||
}
|
||||
return SingleGradientDef(
|
||||
"SpatialBNGradient", "", grad_inputs, grad_outputs);
|
||||
}
|
||||
};
|
||||
REGISTER_GRADIENT(SpatialBN, GetSpatialBNGradient);
|
||||
} // namespace
|
||||
|
||||
} // namespace caffe2
|
81
caffe2/operators/spatial_batch_norm_op.h
Normal file
81
caffe2/operators/spatial_batch_norm_op.h
Normal file
@ -0,0 +1,81 @@
|
||||
#ifndef CAFFE2_OPERATORS_SPATIAL_BATCH_NORM_OP_H_
|
||||
#define CAFFE2_OPERATORS_SPATIAL_BATCH_NORM_OP_H_
|
||||
|
||||
#include "caffe2/core/context.h"
|
||||
#include "caffe2/core/operator.h"
|
||||
#include "caffe2/utils/math.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
template <class Context>
|
||||
class SpatialBNOpBase : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
SpatialBNOpBase(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
is_test_(OperatorBase::GetSingleArgument<int>("is_test", 0)),
|
||||
epsilon_(OperatorBase::GetSingleArgument<float>("epsilon", 1e-5)),
|
||||
momentum_(OperatorBase::GetSingleArgument<float>("momentum", 0.9)),
|
||||
order_(StringToStorageOrder(
|
||||
OperatorBase::GetSingleArgument<string>("order", "NCHW"))) {
|
||||
CAFFE_CHECK((is_test_ && InputSize() == 5) ||
|
||||
(!is_test_ && InputSize() == 3));
|
||||
CAFFE_CHECK((is_test_ && OutputSize() == 1) ||
|
||||
(!is_test_ && (OutputSize() == 3 || OutputSize() == 5)));
|
||||
CAFFE_CHECK_GT(epsilon_, 0);
|
||||
CAFFE_CHECK_GE(momentum_, 0);
|
||||
CAFFE_CHECK_LE(momentum_, 1);
|
||||
}
|
||||
~SpatialBNOpBase() {}
|
||||
|
||||
protected:
|
||||
bool is_test_;
|
||||
double epsilon_;
|
||||
double momentum_;
|
||||
StorageOrder order_;
|
||||
// Input: X, scale, bias (if training mode)
|
||||
// Input: X, scale, bias, estimated_mean, estimated_inv_variance
|
||||
// (if inference mode)
|
||||
// Output: Y, running_mean, running_inv_variance (if training mode, type 1)
|
||||
// Output: Y, running_mean, running_inv_variance, saved_mean,
|
||||
// saved_inv_variance (if training mode, type 2)
|
||||
// Output: Y (if inference mode)
|
||||
INPUT_OUTPUT_STATS(3, 5, 1, 5);
|
||||
INPUT_TAGS(INPUT, SCALE, BIAS, EST_MEAN, EST_INV_VAR);
|
||||
OUTPUT_TAGS(OUTPUT, RUNNING_MEAN, RUNNING_INV_VAR, SAVED_MEAN, SAVED_INV_VAR);
|
||||
DISABLE_COPY_AND_ASSIGN(SpatialBNOpBase);
|
||||
};
|
||||
|
||||
template <class Context>
|
||||
class SpatialBNGradientOpBase : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
SpatialBNGradientOpBase(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
is_test_(OperatorBase::GetSingleArgument<int>("is_test", 0)),
|
||||
epsilon_(OperatorBase::GetSingleArgument<float>("epsilon", 1e-5)),
|
||||
order_(StringToStorageOrder(
|
||||
OperatorBase::GetSingleArgument<string>("order", "NCHW"))) {
|
||||
CAFFE_CHECK((InputSize() == 6) || (!is_test_ && InputSize() == 4));
|
||||
CAFFE_CHECK_EQ(OutputSize(), 3);
|
||||
}
|
||||
~SpatialBNGradientOpBase() {}
|
||||
|
||||
protected:
|
||||
bool is_test_;
|
||||
double epsilon_;
|
||||
StorageOrder order_;
|
||||
// Input: X, scale, dY (type 1)
|
||||
// Input: X, scale, dY, saved_mean, saved_inv_variance
|
||||
// (type 2, faster, and also necessary if one wants to compute gradient
|
||||
// in testing mode)
|
||||
// Output: dX, dscale, dbias
|
||||
INPUT_OUTPUT_STATS(3, 5, 3, 3);
|
||||
INPUT_TAGS(INPUT, SCALE, OUTPUT_GRAD, SAVED_MEAN, SAVED_INV_VAR);
|
||||
OUTPUT_TAGS(INPUT_GRAD, SCALE_GRAD, BIAS_GRAD);
|
||||
DISABLE_COPY_AND_ASSIGN(SpatialBNGradientOpBase);
|
||||
};
|
||||
|
||||
} // namespace caffe2
|
||||
|
||||
#endif // CAFFE2_OPERATORS_SPATIAL_BATCH_NORM_OP_H_
|
238
caffe2/operators/spatial_batch_norm_op_cudnn.cc
Normal file
238
caffe2/operators/spatial_batch_norm_op_cudnn.cc
Normal file
@ -0,0 +1,238 @@
|
||||
#include "caffe2/core/common_cudnn.h"
|
||||
#include "caffe2/core/context_gpu.h"
|
||||
#include "caffe2/operators/spatial_batch_norm_op.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
static_assert(CUDNN_VERSION >= 4000,
|
||||
"CudnnSpatialBN requires cudnn version 4.0 or above.");
|
||||
|
||||
constexpr cudnnBatchNormMode_t kSpatialBNMode = CUDNN_BATCHNORM_SPATIAL;
|
||||
|
||||
template <typename T>
|
||||
class CudnnSpatialBNOp final : public SpatialBNOpBase<CUDAContext> {
|
||||
public:
|
||||
USE_OPERATOR_FUNCTIONS(CUDAContext);
|
||||
CudnnSpatialBNOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: SpatialBNOpBase<CUDAContext>(operator_def, ws),
|
||||
cudnn_wrapper_(&device_context_) {
|
||||
CUDNN_CHECK(cudnnCreateTensorDescriptor(&data_desc_));
|
||||
CUDNN_CHECK(cudnnCreateTensorDescriptor(&bn_param_desc_));
|
||||
if (epsilon_ < CUDNN_BN_MIN_EPSILON) {
|
||||
CAFFE_LOG_ERROR << "Provided epsilon is smaller than "
|
||||
<< "CUDNN_BN_MIN_EPSILON. Setting it to "
|
||||
<< "CUDNN_BN_MIN_EPSILON instead.";
|
||||
epsilon_ = CUDNN_BN_MIN_EPSILON;
|
||||
}
|
||||
}
|
||||
|
||||
~CudnnSpatialBNOp() {
|
||||
CUDNN_CHECK(cudnnDestroyTensorDescriptor(data_desc_));
|
||||
CUDNN_CHECK(cudnnDestroyTensorDescriptor(bn_param_desc_));
|
||||
}
|
||||
|
||||
bool RunOnDevice() override;
|
||||
|
||||
protected:
|
||||
CuDNNWrapper cudnn_wrapper_;
|
||||
cudnnTensorDescriptor_t data_desc_;
|
||||
cudnnTensorDescriptor_t bn_param_desc_;
|
||||
vector<int> cudnn_input_dims_;
|
||||
DISABLE_COPY_AND_ASSIGN(CudnnSpatialBNOp);
|
||||
};
|
||||
|
||||
|
||||
template <typename T>
|
||||
class CudnnSpatialBNGradientOp final
|
||||
: public SpatialBNGradientOpBase<CUDAContext> {
|
||||
public:
|
||||
USE_OPERATOR_FUNCTIONS(CUDAContext);
|
||||
CudnnSpatialBNGradientOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: SpatialBNGradientOpBase<CUDAContext>(operator_def, ws),
|
||||
cudnn_wrapper_(&device_context_) {
|
||||
CUDNN_CHECK(cudnnCreateTensorDescriptor(&data_desc_));
|
||||
CUDNN_CHECK(cudnnCreateTensorDescriptor(&bn_param_desc_));
|
||||
if (epsilon_ < CUDNN_BN_MIN_EPSILON) {
|
||||
CAFFE_LOG_ERROR << "Provided epsilon is smaller than "
|
||||
<< "CUDNN_BN_MIN_EPSILON. Setting it to "
|
||||
<< "CUDNN_BN_MIN_EPSILON instead.";
|
||||
epsilon_ = CUDNN_BN_MIN_EPSILON;
|
||||
}
|
||||
}
|
||||
|
||||
~CudnnSpatialBNGradientOp() {
|
||||
CUDNN_CHECK(cudnnDestroyTensorDescriptor(data_desc_));
|
||||
CUDNN_CHECK(cudnnDestroyTensorDescriptor(bn_param_desc_));
|
||||
}
|
||||
|
||||
bool RunOnDevice() override;
|
||||
|
||||
protected:
|
||||
CuDNNWrapper cudnn_wrapper_;
|
||||
cudnnTensorDescriptor_t data_desc_;
|
||||
cudnnTensorDescriptor_t bn_param_desc_;
|
||||
vector<int> cudnn_input_dims_;
|
||||
DISABLE_COPY_AND_ASSIGN(CudnnSpatialBNGradientOp);
|
||||
};
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Implementations
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T>
|
||||
bool CudnnSpatialBNOp<T>::RunOnDevice() {
|
||||
const auto& X = Input(INPUT);
|
||||
const auto& scale = Input(SCALE);
|
||||
const auto& bias = Input(BIAS);
|
||||
|
||||
CAFFE_DCHECK_EQ(X.ndim(), 4);
|
||||
const int N = X.dim(0);
|
||||
const int C = (order_ == StorageOrder::NCHW ? X.dim(1) : X.dim(3));
|
||||
const int H = (order_ == StorageOrder::NCHW ? X.dim(2) : X.dim(1));
|
||||
const int W = (order_ == StorageOrder::NCHW ? X.dim(3) : X.dim(2));
|
||||
CAFFE_DCHECK_EQ(scale.ndim(), 1);
|
||||
CAFFE_DCHECK_EQ(bias.ndim(), 1);
|
||||
CAFFE_DCHECK_EQ(scale.dim(0), C);
|
||||
CAFFE_DCHECK_EQ(bias.dim(0), C);
|
||||
// See if we need to reshape.
|
||||
if (X.dims() != cudnn_input_dims_) {
|
||||
CAFFE_VLOG(1) << "Setting descriptors.";
|
||||
cudnn_input_dims_ = X.dims();
|
||||
CUDNN_CHECK(cudnnSetTensor4dDescriptor(
|
||||
data_desc_, GetCudnnTensorFormat(order_),
|
||||
cudnnTypeWrapper<T>::type, N, C, H, W));
|
||||
CUDNN_CHECK(cudnnDeriveBNTensorDescriptor(
|
||||
bn_param_desc_, data_desc_, kSpatialBNMode));
|
||||
}
|
||||
|
||||
// Now, depending on whether we are running test or not, we have two paths.
|
||||
const typename cudnnTypeWrapper<T>::ScalingParamType kOne = 1;
|
||||
const typename cudnnTypeWrapper<T>::ScalingParamType kZero = 0;
|
||||
if (is_test_) {
|
||||
// Run inference mode.
|
||||
const auto& est_mean = Input(EST_MEAN);
|
||||
const auto& est_inv_var = Input(EST_INV_VAR);
|
||||
CAFFE_DCHECK_EQ(est_mean.ndim(), 1);
|
||||
CAFFE_DCHECK_EQ(est_inv_var.ndim(), 1);
|
||||
CAFFE_DCHECK_EQ(est_mean.dim(0), C);
|
||||
CAFFE_DCHECK_EQ(est_inv_var.dim(0), C);
|
||||
|
||||
auto* Y = Output(OUTPUT);
|
||||
Y->ReshapeLike(X);
|
||||
CUDNN_CHECK(cudnnBatchNormalizationForwardInference(
|
||||
cudnn_wrapper_.cudnn_handle(), kSpatialBNMode, &kOne, &kZero,
|
||||
data_desc_, X.template data<T>(),
|
||||
data_desc_, Y->template mutable_data<T>(),
|
||||
bn_param_desc_, scale.template data<T>(), bias.template data<T>(),
|
||||
est_mean.template data<T>(), est_inv_var.template data<T>(),
|
||||
epsilon_));
|
||||
} else {
|
||||
// Run training mode.
|
||||
auto* Y = Output(OUTPUT);
|
||||
Y->ReshapeLike(X);
|
||||
// obtain running mean and running inv var, and see if we need to
|
||||
// initialize them.
|
||||
auto* running_mean = Output(RUNNING_MEAN);
|
||||
auto* running_inv_var = Output(RUNNING_INV_VAR);
|
||||
double this_momentum;
|
||||
if (running_mean->size() == 0) {
|
||||
CAFFE_VLOG(1) << "Initializing running mean and var.";
|
||||
// Need to do initialization
|
||||
running_mean->Reshape(C);
|
||||
running_inv_var->Reshape(C);
|
||||
this_momentum = 1;
|
||||
} else {
|
||||
// Does not need to do initialization.
|
||||
CAFFE_DCHECK_EQ(running_mean->ndim(), 1);
|
||||
CAFFE_DCHECK_EQ(running_inv_var->ndim(), 1);
|
||||
CAFFE_DCHECK_EQ(running_mean->dim(0), C);
|
||||
CAFFE_DCHECK_EQ(running_inv_var->dim(0), C);
|
||||
this_momentum = momentum_;
|
||||
}
|
||||
// If specified, save the mean and inv var results.
|
||||
void* save_mean_data = nullptr;
|
||||
void* save_inv_var_data = nullptr;
|
||||
if (OutputSize() == 5) {
|
||||
auto* save_mean = Output(SAVED_MEAN);
|
||||
auto* save_inv_var = Output(SAVED_INV_VAR);
|
||||
save_mean->Reshape(C);
|
||||
save_inv_var->Reshape(C);
|
||||
save_mean_data = save_mean->template mutable_data<T>();
|
||||
save_inv_var_data = save_inv_var->template mutable_data<T>();
|
||||
}
|
||||
|
||||
CUDNN_CHECK(cudnnBatchNormalizationForwardTraining(
|
||||
cudnn_wrapper_.cudnn_handle(), kSpatialBNMode, &kOne, &kZero,
|
||||
data_desc_, X.template data<T>(),
|
||||
data_desc_, Y->template mutable_data<T>(),
|
||||
bn_param_desc_, scale.template data<T>(), bias.template data<T>(),
|
||||
this_momentum, running_mean->template mutable_data<T>(),
|
||||
running_inv_var->template mutable_data<T>(), epsilon_,
|
||||
save_mean_data, save_inv_var_data));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
bool CudnnSpatialBNGradientOp<T>::RunOnDevice() {
|
||||
const auto& X = Input(INPUT);
|
||||
const auto& scale = Input(SCALE);
|
||||
const auto& dY = Input(OUTPUT_GRAD);
|
||||
|
||||
CAFFE_DCHECK_EQ(X.ndim(), 4);
|
||||
const int N = X.dim(0);
|
||||
const int C = (order_ == StorageOrder::NCHW ? X.dim(1) : X.dim(3));
|
||||
const int H = (order_ == StorageOrder::NCHW ? X.dim(2) : X.dim(1));
|
||||
const int W = (order_ == StorageOrder::NCHW ? X.dim(3) : X.dim(2));
|
||||
CAFFE_DCHECK_EQ(scale.ndim(), 1);
|
||||
CAFFE_DCHECK_EQ(scale.dim(0), C);
|
||||
// See if we need to reshape.
|
||||
if (X.dims() != cudnn_input_dims_) {
|
||||
cudnn_input_dims_ = X.dims();
|
||||
CUDNN_CHECK(cudnnSetTensor4dDescriptor(
|
||||
data_desc_, GetCudnnTensorFormat(order_),
|
||||
cudnnTypeWrapper<T>::type, N, C, H, W));
|
||||
CUDNN_CHECK(cudnnDeriveBNTensorDescriptor(
|
||||
bn_param_desc_, data_desc_, kSpatialBNMode));
|
||||
}
|
||||
|
||||
auto* dX = Output(INPUT_GRAD);
|
||||
auto* dScale = Output(SCALE_GRAD);
|
||||
auto* dBias = Output(BIAS_GRAD);
|
||||
dX->ReshapeLike(X);
|
||||
dScale->ReshapeLike(scale);
|
||||
dBias->ReshapeLike(scale);
|
||||
|
||||
const void* saved_mean_data = nullptr;
|
||||
const void* saved_inv_var_data = nullptr;
|
||||
if (InputSize() == 5) {
|
||||
const auto& saved_mean = Input(SAVED_MEAN);
|
||||
const auto& saved_inv_var = Input(SAVED_INV_VAR);
|
||||
saved_mean_data = saved_mean.template data<T>();
|
||||
saved_inv_var_data = saved_inv_var.template data<T>();
|
||||
}
|
||||
|
||||
const typename cudnnTypeWrapper<T>::ScalingParamType kOne = 1;
|
||||
const typename cudnnTypeWrapper<T>::ScalingParamType kZero = 0;
|
||||
CUDNN_CHECK(cudnnBatchNormalizationBackward(
|
||||
cudnn_wrapper_.cudnn_handle(), kSpatialBNMode, &kOne, &kZero,
|
||||
data_desc_, X.template data<T>(), data_desc_, dY.template data<T>(),
|
||||
data_desc_, dX->template mutable_data<T>(),
|
||||
bn_param_desc_, scale.template data<T>(),
|
||||
dScale->template mutable_data<T>(), dBias->template mutable_data<T>(),
|
||||
epsilon_, saved_mean_data, saved_inv_var_data));
|
||||
return true;
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Since there is no default implementation for spatial batch normalization,
|
||||
// we will register the cudnn version as the default as well.
|
||||
REGISTER_CUDA_OPERATOR(SpatialBN, CudnnSpatialBNOp<float>);
|
||||
REGISTER_CUDA_OPERATOR(SpatialBNGradient, CudnnSpatialBNGradientOp<float>);
|
||||
|
||||
REGISTER_CUDNN_OPERATOR(SpatialBN, CudnnSpatialBNOp<float>);
|
||||
REGISTER_CUDNN_OPERATOR(SpatialBNGradient, CudnnSpatialBNGradientOp<float>);
|
||||
} // namespace
|
||||
} // namespace caffe2
|
Reference in New Issue
Block a user