diff --git a/aten/src/ATen/core/Tensor.h b/aten/src/ATen/core/Tensor.h index fe0696eefc1a..899348b89d1a 100644 --- a/aten/src/ATen/core/Tensor.h +++ b/aten/src/ATen/core/Tensor.h @@ -207,13 +207,13 @@ public: // cast the data pointer to a __restrict__ pointer. // In order to use this, your CUDA kernel has to take a corresponding PackedTensorAccessor // as an argument. - template class PtrTraits = DefaultPtrTraits> - PackedTensorAccessor packed_accessor() const& { + template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> + PackedTensorAccessor packed_accessor() const& { static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data()"); AT_CHECK(dim() == N, "expected ", N, " dims but tensor has ", dim()); - return PackedTensorAccessor(static_cast::PtrType>(data()),sizes().data(),strides().data()); + return PackedTensorAccessor(static_cast::PtrType>(data()),sizes().data(),strides().data()); } - template class PtrTraits = DefaultPtrTraits> + template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> PackedTensorAccessor packed_accessor() && = delete; Tensor operator-() const; diff --git a/aten/src/ATen/core/TensorAccessor.h b/aten/src/ATen/core/TensorAccessor.h index 442be6331ed3..d008f9e8b461 100644 --- a/aten/src/ATen/core/TensorAccessor.h +++ b/aten/src/ATen/core/TensorAccessor.h @@ -26,15 +26,15 @@ struct RestrictPtrTraits { // to functions and types available there (e.g. IntList isn't). // The PtrTraits argument is only relevant to cuda to support `__restrict__` pointers. -template class PtrTraits = DefaultPtrTraits> +template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> class TensorAccessorBase { public: typedef typename PtrTraits::PtrType PtrType; C10_HOST_DEVICE TensorAccessorBase( PtrType data_, - const int64_t* sizes_, - const int64_t* strides_) + const index_t* sizes_, + const index_t* strides_) : data_(data_), sizes_(sizes_), strides_(strides_) {} C10_HOST IntList sizes() const { return IntList(sizes_,N); @@ -42,60 +42,62 @@ public: C10_HOST IntList strides() const { return IntList(strides_,N); } - C10_HOST_DEVICE int64_t stride(int64_t i) const { + C10_HOST_DEVICE index_t stride(index_t i) const { return strides_[i]; } - C10_HOST_DEVICE int64_t size(int64_t i) const { + C10_HOST_DEVICE index_t size(index_t i) const { return sizes_[i]; } - C10_HOST_DEVICE T* data() { + C10_HOST_DEVICE PtrType data() { return data_; } - C10_HOST_DEVICE const T* data() const { + C10_HOST_DEVICE const PtrType data() const { return data_; } - protected: PtrType data_; - const int64_t* sizes_; - const int64_t* strides_; + const index_t* sizes_; + const index_t* strides_; }; // The `TensorAccessor` is typically instantiated for CPU `Tensor`s using // `Tensor.accessor()`. // For CUDA `Tensor`s, `PackedTensorAccessor` is used on the host and only // indexing on the device uses `TensorAccessor`s. 
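A minimal usage sketch of the index_t template parameter added in this hunk (illustrative only, not part of the diff; the kernel name, tensor t, and shapes are assumptions): a CUDA kernel takes the matching PackedTensorAccessor and does all of its index arithmetic in 32-bit, while the host requests the 32-bit accessor only when the tensor is known to fit 32-bit indexing.

    #include <ATen/ATen.h>

    template <typename scalar_t>
    __global__ void scale_rows_kernel(
        at::PackedTensorAccessor<scalar_t, 2, at::RestrictPtrTraits, int32_t> acc,
        scalar_t alpha) {
      // One block per row; threads stride over the columns. All index math stays 32-bit.
      const int32_t row = blockIdx.x;
      for (int32_t col = threadIdx.x; col < acc.size(1); col += blockDim.x) {
        acc[row][col] *= alpha;
      }
    }

    // Host side, given a 2-d CUDA tensor t that fits 32-bit indexing
    // (e.g. cuda::detail::canUse32BitIndexMath(t) holds); real ATen code would also
    // pass at::cuda::getCurrentCUDAStream() as the launch stream.
    auto acc = t.packed_accessor<float, 2, at::RestrictPtrTraits, int32_t>();
    scale_rows_kernel<float><<<acc.size(0), 128>>>(acc, 2.0f);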
-template class PtrTraits = DefaultPtrTraits> -class TensorAccessor : public TensorAccessorBase { +template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> +class TensorAccessor : public TensorAccessorBase { public: typedef typename PtrTraits::PtrType PtrType; C10_HOST_DEVICE TensorAccessor( PtrType data_, - const int64_t* sizes_, - const int64_t* strides_) - : TensorAccessorBase(data_, sizes_, strides_) {} + const index_t* sizes_, + const index_t* strides_) + : TensorAccessorBase(data_,sizes_,strides_) {} - C10_HOST_DEVICE TensorAccessor operator[](int64_t i) { - return TensorAccessor(this->data_ + this->strides_[0]*i,this->sizes_+1,this->strides_+1); + C10_HOST_DEVICE TensorAccessor operator[](index_t i) { + return TensorAccessor(this->data_ + this->strides_[0]*i,this->sizes_+1,this->strides_+1); } - C10_HOST_DEVICE const TensorAccessor operator[](int64_t i) const { - return TensorAccessor(this->data_ + this->strides_[0]*i,this->sizes_+1,this->strides_+1); + C10_HOST_DEVICE const TensorAccessor operator[](index_t i) const { + return TensorAccessor(this->data_ + this->strides_[0]*i,this->sizes_+1,this->strides_+1); } }; -template class PtrTraits> -class TensorAccessor : public TensorAccessorBase { +template class PtrTraits, typename index_t> +class TensorAccessor : public TensorAccessorBase { public: typedef typename PtrTraits::PtrType PtrType; C10_HOST_DEVICE TensorAccessor( PtrType data_, - const int64_t* sizes_, - const int64_t* strides_) - : TensorAccessorBase(data_, sizes_, strides_) {} - C10_HOST_DEVICE T& operator[](int64_t i) { + const index_t* sizes_, + const index_t* strides_) + : TensorAccessorBase(data_,sizes_,strides_) {} + C10_HOST_DEVICE T & operator[](index_t i) { + return this->data_[this->strides_[0]*i]; + } + C10_HOST_DEVICE const T & operator[](index_t i) const { return this->data_[this->strides_[0]*i]; } }; @@ -109,69 +111,104 @@ public: // Use RestrictPtrTraits as PtrTraits if you want the tensor's data pointer to be marked as __restrict__. // Instantiation from data, sizes, strides is only needed on the host and std::copy isn't available // on the device, so those functions are host only. 
-template class PtrTraits = DefaultPtrTraits> +template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> class PackedTensorAccessorBase { public: typedef typename PtrTraits::PtrType PtrType; C10_HOST PackedTensorAccessorBase( PtrType data_, - const int64_t* sizes_, - const int64_t* strides_) + const index_t* sizes_, + const index_t* strides_) : data_(data_) { std::copy(sizes_, sizes_ + N, std::begin(this->sizes_)); std::copy(strides_, strides_ + N, std::begin(this->strides_)); } - C10_HOST_DEVICE int64_t stride(int64_t i) const { + + // if index_t is not int64_t, we want to have an int64_t constructor + template ::value>::type> + C10_HOST PackedTensorAccessorBase( + PtrType data_, + const source_index_t* sizes_, + const source_index_t* strides_) + : data_(data_) { + for (int i = 0; i < N; i++) { + this->sizes_[i] = sizes_[i]; + this->strides_[i] = strides_[i]; + } + } + + C10_HOST_DEVICE index_t stride(index_t i) const { return strides_[i]; } - C10_HOST_DEVICE int64_t size(int64_t i) const { + C10_HOST_DEVICE index_t size(index_t i) const { return sizes_[i]; } - + C10_HOST_DEVICE PtrType data() { + return data_; + } + C10_HOST_DEVICE const PtrType data() const { + return data_; + } protected: PtrType data_; - int64_t sizes_[N]; - int64_t strides_[N]; + index_t sizes_[N]; + index_t strides_[N]; }; -template class PtrTraits = DefaultPtrTraits> -class PackedTensorAccessor : public PackedTensorAccessorBase { +template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> +class PackedTensorAccessor : public PackedTensorAccessorBase { public: typedef typename PtrTraits::PtrType PtrType; C10_HOST PackedTensorAccessor( PtrType data_, - const int64_t* sizes_, - const int64_t* strides_) - : PackedTensorAccessorBase(data_, sizes_, strides_){}; + const index_t* sizes_, + const index_t* strides_) + : PackedTensorAccessorBase(data_, sizes_, strides_) {} - C10_DEVICE TensorAccessor operator[](int64_t i) { - int64_t* new_sizes = this->sizes_+1; - int64_t* new_strides = this->strides_+1; - return TensorAccessor(this->data_ + this->strides_[0]*i, new_sizes, new_strides); + // if index_t is not int64_t, we want to have an int64_t constructor + template ::value>::type> + C10_HOST PackedTensorAccessor( + PtrType data_, + const source_index_t* sizes_, + const source_index_t* strides_) + : PackedTensorAccessorBase(data_, sizes_, strides_) {} + + C10_DEVICE TensorAccessor operator[](index_t i) { + index_t* new_sizes = this->sizes_ + 1; + index_t* new_strides = this->strides_ + 1; + return TensorAccessor(this->data_ + this->strides_[0]*i, new_sizes, new_strides); } - C10_DEVICE const TensorAccessor operator[](int64_t i) const { - int64_t* new_sizes = this->sizes_+1; - int64_t* new_strides = this->strides_+1; - return TensorAccessor(this->data_ + this->strides_[0]*i, new_sizes, new_strides); + C10_DEVICE const TensorAccessor operator[](index_t i) const { + const index_t* new_sizes = this->sizes_ + 1; + const index_t* new_strides = this->strides_ + 1; + return TensorAccessor(this->data_ + this->strides_[0]*i, new_sizes, new_strides); } }; -template class PtrTraits> -class PackedTensorAccessor : public PackedTensorAccessorBase { +template class PtrTraits, typename index_t> +class PackedTensorAccessor : public PackedTensorAccessorBase { public: typedef typename PtrTraits::PtrType PtrType; C10_HOST PackedTensorAccessor( PtrType data_, - const int64_t* sizes_, - const int64_t* strides_) - : PackedTensorAccessorBase(data_, sizes_, strides_){}; + const index_t* sizes_, + const index_t* strides_) 
+ : PackedTensorAccessorBase(data_, sizes_, strides_) {} - C10_DEVICE T& operator[](int64_t i) { - return this->data_[this->strides_[0]*i]; + // if index_t is not int64_t, we want to have an int64_t constructor + template ::value>::type> + C10_HOST PackedTensorAccessor( + PtrType data_, + const source_index_t* sizes_, + const source_index_t* strides_) + : PackedTensorAccessorBase(data_, sizes_, strides_) {} + + C10_DEVICE T & operator[](index_t i) { + return this->data_[this->strides_[0] * i]; } - C10_DEVICE const T& operator[](int64_t i) const { + C10_DEVICE const T& operator[](index_t i) const { return this->data_[this->strides_[0]*i]; } }; diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h index 688a70c81c7a..c9cf16b92f71 100644 --- a/aten/src/ATen/core/aten_interned_strings.h +++ b/aten/src/ATen/core/aten_interned_strings.h @@ -482,6 +482,8 @@ _(aten, mv) \ _(aten, mvlgamma) \ _(aten, narrow) \ _(aten, narrow_copy) \ +_(aten, native_batch_norm) \ +_(aten, native_batch_norm_backward) \ _(aten, native_clone) \ _(aten, native_get_device) \ _(aten, native_norm) \ @@ -634,9 +636,6 @@ _(aten, th_pow) \ _(aten, th_resize_as) \ _(aten, th_tensor) \ _(aten, th_zero) \ -_(aten, thnn_batch_norm) \ -_(aten, thnn_batch_norm_backward) \ -_(aten, thnn_batch_norm_forward) \ _(aten, thnn_conv2d) \ _(aten, thnn_conv2d_backward) \ _(aten, thnn_conv2d_forward) \ diff --git a/aten/src/ATen/cuda/detail/IndexUtils.cuh b/aten/src/ATen/cuda/detail/IndexUtils.cuh index 9bbf8f7af88f..0f2e42115a78 100644 --- a/aten/src/ATen/cuda/detail/IndexUtils.cuh +++ b/aten/src/ATen/cuda/detail/IndexUtils.cuh @@ -9,7 +9,7 @@ namespace cuda { namespace detail { bool maybeOverlappingIndices(const at::Tensor& t); -bool canUse32BitIndexMath(const at::Tensor &t, int64_t max_elem=std::numeric_limits::max()); +bool canUse32BitIndexMath(const at::Tensor &t, int64_t max_elem=std::numeric_limits::max()); template TensorInfo diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp index b3b557999372..8cdaa78b8355 100644 --- a/aten/src/ATen/native/Normalization.cpp +++ b/aten/src/ATen/native/Normalization.cpp @@ -1,6 +1,8 @@ #include "ATen/ATen.h" #include "ATen/NativeFunctions.h" - +#include "ATen/AccumulateType.h" +#include "ATen/CPUApplyUtils.h" +#include "ATen/Parallel.h" #include "ATen/Config.h" #include "ATen/detail/CUDAHooksInterface.h" @@ -25,6 +27,198 @@ namespace { } } +// TensorAccessor when it is defined to work around undefined... +template +static TensorAccessor conditional_accessor_1d(const Tensor& t) { + if (! 
t.defined()) { + return TensorAccessor(nullptr, nullptr, nullptr); + } + return t.accessor(); +} + + +template +std::tuple batch_norm_cpu_template(const Tensor& input, const Tensor& weight, const Tensor& bias, + const Tensor& running_mean, const Tensor& running_var, bool train, double momentum, double eps) { + + using accscalar_t = at::acc_type; + Tensor output = at::empty_like(input); + + int64_t n_input = input.size(1); + int64_t n = input.numel() / n_input; + + Tensor save_mean; + Tensor save_invstd; + const int64_t zero = 0; + if (train) { + save_mean = at::empty({n_input}, input.options()); + save_invstd = at::empty({n_input}, input.options()); + } + auto save_mean_a = conditional_accessor_1d(save_mean); + auto save_invstd_a = conditional_accessor_1d(save_invstd); + + auto running_mean_a = conditional_accessor_1d(running_mean); + auto running_var_a = conditional_accessor_1d(running_var); + + parallel_for(0, n_input, 1, [&](int64_t b_begin, int64_t b_end) { + for (int64_t f = b_begin; f < b_end; ++f) { + Tensor in = input.select(1, f); + Tensor out = output.select(1, f); + + scalar_t mean, invstd; + + if (train) { + // compute mean per input + accscalar_t sum = 0; + CPU_tensor_apply1(in, [&] (const scalar_t& i) { + sum += i; + }); + + mean = (scalar_t) (sum / n); + save_mean_a[f] = mean; + + // compute variance per input + sum = 0; + CPU_tensor_apply1(in, [&] (const scalar_t& i) { + sum += (i - mean) * (i - mean); + }); + + if (sum == 0 && eps == 0.0) { + invstd = 0; + } else { + invstd = (scalar_t) (1 / std::sqrt(sum/n + eps)); + } + save_invstd_a[f] = invstd; + + // update running averages + if (running_mean.defined()) { + running_mean_a[f] = momentum * mean + (1 - momentum) * running_mean_a[f]; + } + if (running_var.defined()) { + accscalar_t unbiased_var = sum / (n - 1); + running_var_a[f] = momentum * unbiased_var + (1 - momentum) * running_var_a[f]; + } + } else { + mean = running_mean_a[f]; + invstd = 1 / std::sqrt(running_var_a[f] + eps); + } + + // compute output + scalar_t w = weight.defined() ? weight.data()[f * weight.stride(0)] : 1; + scalar_t b = bias.defined() ? 
bias.data()[f * bias.stride(0)] : 0; + + CPU_tensor_apply2(out, in, [&](scalar_t& o, const scalar_t& i) { + o = ((i - mean) * invstd) * w + b; + }); + } + }); + return std::make_tuple(output, save_mean, save_invstd); +} + + +template +std::tuple batch_norm_backward_cpu_template(const Tensor& grad_out_, const Tensor& input, const Tensor& weight, + const Tensor& running_mean, const Tensor& running_var, const Tensor& save_mean, const Tensor& save_invstd, + bool train, double eps, std::array grad_input_mask) { + + using accscalar_t = at::acc_type; + + Tensor grad_input; + Tensor grad_weight; + Tensor grad_bias; + if (grad_input_mask[0]) { + grad_input = at::empty_like(input); + } + if (grad_input_mask[1]) { + grad_weight = at::empty_like(weight); + } + if (grad_input_mask[2]) { + grad_bias = at::empty_like(weight); + } + + auto weight_a = conditional_accessor_1d(weight); + auto grad_weight_a = conditional_accessor_1d(grad_weight); + auto grad_bias_a = conditional_accessor_1d(grad_bias); + + int64_t n_input = input.size(1); + int64_t n = input.numel() / n_input; + + auto save_mean_a = conditional_accessor_1d(save_mean); + auto save_invstd_a = conditional_accessor_1d(save_invstd); + + auto running_mean_a = conditional_accessor_1d(running_mean); + auto running_var_a = conditional_accessor_1d(running_var); + + + parallel_for(0, n_input, 1, [&](int64_t b_begin, int64_t b_end) { + for (int64_t f = b_begin; f < b_end; ++f) { + Tensor in = input.select(1, f); + Tensor grad_out = grad_out_.select(1, f); + + scalar_t w = weight.defined() ? weight_a[f] : 1; + + scalar_t mean, invstd; + if (train) { + mean = save_mean_a[f]; + invstd = save_invstd_a[f]; + } else { + mean = running_mean_a[f]; + invstd = 1 / std::sqrt(running_var_a[f] + eps); + } + + // sum over all gradOutput in feature plane + accscalar_t sum = 0; + CPU_tensor_apply1(grad_out, [&](const scalar_t& g) { + sum += g; + }); + + // dot product of the Q(X) and gradOuput + accscalar_t dotp = 0; + CPU_tensor_apply2(in, grad_out, [&](const scalar_t& i, const scalar_t& go) { + dotp += (i - mean) * go; + }); + + if (grad_input_mask[0]) { + Tensor grad_in = grad_input.select(1, f); + if (train) { + // when in training mode + // Q(X) = X - E[x] ; i.e. input centered to zero mean + // Y = Q(X) / σ ; i.e. BN output before weight and bias + // dL/dX = (Q(dL/dY) - dot(Y, dL/dY) * Y) / σ * w + + // projection of gradOutput on to output scaled by std + scalar_t k = (scalar_t) dotp * invstd * invstd / n; + + CPU_tensor_apply2(grad_in, in, [&](scalar_t& gi, const scalar_t& i) { + gi = (i - mean)* k; + }); + + accscalar_t grad_mean = sum / n; + CPU_tensor_apply2(grad_in, grad_out, [&](scalar_t& gi, const scalar_t& go) { + gi = (go - grad_mean - gi) * invstd * w; + }); + } else { + // when in evaluation mode + // Q(X) = X - running_mean ; i.e. input centered to zero mean + // Y = Q(X) / running_std ; i.e. 
BN output before weight and bias + // dL/dX = w / running_std + CPU_tensor_apply2(grad_in, grad_out, [&](scalar_t& gi, const scalar_t& go) { + gi = go * invstd * w; + }); + } + } + if (grad_input_mask[1]) { + grad_weight_a[f] = dotp * invstd; + } + + if (grad_input_mask[2]) { + grad_bias_a[f] = sum; + } + } + }); + return std::make_tuple(grad_input, grad_weight, grad_bias); +} + Tensor batch_norm( const Tensor& input, const Tensor& weight /* optional */, const Tensor& bias /* optional */, const Tensor& running_mean /* optional */, const Tensor& running_var /* optional */, @@ -86,9 +280,8 @@ Tensor batch_norm( training, momentum, eps)); } - return at::thnn_batch_norm( - input.contiguous(), weight, bias, - running_mean, running_var, training, momentum, eps); + return std::get<0>(at::native_batch_norm(input, weight, bias, + running_mean, running_var, training, momentum, eps)); } Tensor instance_norm( @@ -226,4 +419,20 @@ Tensor group_norm(const Tensor& input, int64_t num_groups, } } +std::tuple batch_norm_cpu(const Tensor& self, const Tensor& weight, const Tensor& bias, + const Tensor& running_mean, const Tensor& running_var, + bool train, double momentum, double eps) { + return AT_DISPATCH_FLOATING_TYPES(self.type(), "batch_norm", [&] { + return batch_norm_cpu_template(self, weight, bias, running_mean, running_var, train, momentum, eps); + }); +} + +std::tuple batch_norm_backward_cpu(const Tensor& grad_out, const Tensor& self, const Tensor& weight, + const Tensor& running_mean, const Tensor& running_var, const Tensor& save_mean, const Tensor& save_invstd, + bool train, double eps, std::array grad_input_mask) { + return AT_DISPATCH_FLOATING_TYPES(self.type(), "batch_norm_backward", [&] { + return batch_norm_backward_cpu_template(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, eps, grad_input_mask); + }); +} + }} // at::native diff --git a/aten/src/ATen/native/cuda/Normalization.cu b/aten/src/ATen/native/cuda/Normalization.cu new file mode 100644 index 000000000000..b73635fd26a0 --- /dev/null +++ b/aten/src/ATen/native/cuda/Normalization.cu @@ -0,0 +1,536 @@ +#include +#include +#include "ATen/ATen.h" +#include "ATen/AccumulateType.h" +#include "ATen/cuda/CUDAContext.h" +#include "ATen/cuda/CUDAApplyUtils.cuh" + +namespace at { namespace native { + +namespace { + + +#if defined(__HIP_PLATFORM_HCC__) +constexpr int WARP_SIZE = 64; + +// take these out when ROCm implements std:: math functions +#include +template +static __forceinline__ __device__ scalar_t device_sqrt(scalar_t val); + +template <> +__forceinline__ __device__ float device_sqrt(float val) { + return ::sqrtf(val); +} + +template <> +__forceinline__ __device__ double device_sqrt(double val) { + return ::sqrt(val); +} + +#else +constexpr int WARP_SIZE = 32; + +template +__forceinline__ __device__ double device_sqrt(scalar_t val) { + return std::sqrt(val); +} +#endif + +// The maximum number of threads in a block +#if defined(__HIP_PLATFORM_HCC__) +constexpr int MAX_BLOCK_SIZE = 256; +#else +constexpr int MAX_BLOCK_SIZE = 512; +#endif + +// Number of threads in a block given an input size up to MAX_BLOCK_SIZE +static int getNumThreads(int nElem) { +#if defined(__HIP_PLATFORM_HCC__) + int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE }; +#else + int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE }; +#endif + for (int i = 0; i != 5; ++i) { + if (nElem <= threadSizes[i]) { + return threadSizes[i]; + } + } + return MAX_BLOCK_SIZE; +} + +// Returns the index of the most significant 1 bit in 
`val`. +__device__ __forceinline__ int getMSB(int val) { + return 31 - __clz(val); +} + +template +struct Float2 { + accscalar_t v1, v2; + __device__ Float2() {} + __device__ Float2(scalar_t v1, scalar_t v2) : v1(static_cast(v1)), v2(static_cast(v2)) {} + __device__ Float2(int v) : v1(static_cast(v)), v2(static_cast(v)) {} + __device__ Float2& operator+=(const Float2& a) { + v1 += a.v1; + v2 += a.v2; + return *this; + } +}; + +template +struct SumOp { + __device__ SumOp(const PTA& t) : tensor(t) {} + __device__ __forceinline__ accscalar_t operator()(int batch, int plane, int n) { + return static_cast(tensor[batch][plane][n]); + } + const PTA& tensor; +}; + + template +struct VarOp { + __device__ VarOp(accscalar_t m, const PTA& t) : mean(m), tensor(t) {} + __device__ __forceinline__ accscalar_t operator()(int batch, int plane, int n) { + accscalar_t val = tensor[batch][plane][n]; + return (val - mean) * (val - mean); + } + const accscalar_t mean; + const PTA& tensor; +}; + +template +struct GradOp { + __device__ GradOp(accscalar_t m, const PTA& i, const PTA& g) + : mean(m), input(i), grad_output(g) {} + __device__ __forceinline__ Float2 operator()(int batch, int plane, int n) { + accscalar_t g = grad_output[batch][plane][n]; + accscalar_t c = static_cast(input[batch][plane][n]) - mean; + return Float2(g, g * c); + } + const accscalar_t mean; + const PTA& input; + const PTA& grad_output; +}; + +// Sum across all threads within a warp +template +static __device__ __forceinline__ T warpSum(T val) { + for (int i = 0; i < getMSB(WARP_SIZE); ++i) { + val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE); + } + return val; +} + +template +static __device__ __forceinline__ Float2 warpSum(Float2 value) { + value.v1 = warpSum(value.v1); + value.v2 = warpSum(value.v2); + return value; +} + +// Sum across (batch, x/y/z) applying Op() pointwise +// this works by first having each thread sum it's part +// of the data. Then there is a double-shuffeling reduction. +// First each warp (of WARP_SIZE threads) uses warpSum to reduce its +// data to the "warp leader", who writes its value into shared memory. +// Then a single warp reads the remaining (at most WARP_SIZE) items +// and reduces them using another warpSum. +// The implicit assumption is that there are no more +// than WARP_SIZE**2 threads. +template +__device__ scalar_t reduce(Op op, PTA tensor, int plane) { + // first the reductions each thread does separately + scalar_t sum = static_cast(0); + for (int batch = threadIdx.y; batch < tensor.size(0); batch += blockDim.y) { + for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x) { + sum += op(batch, plane, x); + } + } + + // first warpSum to get one value per thread to + // one value per warp + sum = warpSum(sum); + + // this writes each warps item into shared memory + // there are at most WARP_SIZE items left because + // there are at most WARP_SIZE**2 threads at the beginning + __shared__ scalar_t shared[WARP_SIZE]; + __syncthreads(); + int tid = threadIdx.x + threadIdx.y * blockDim.x; + if (tid % WARP_SIZE == 0) { + shared[tid / WARP_SIZE] = sum; + } + if (tid >= blockDim.x * blockDim.y / WARP_SIZE && tid < WARP_SIZE) { + // zero out the other entries in shared + shared[tid] = (scalar_t)0; + } + __syncthreads(); + // now have a second warpSum to reduce the intermediate values + // from shared memory to a single number. The very first + // thread writes it to shared memory. 
+ + if (tid / WARP_SIZE == 0) { + sum = warpSum(shared[tid]); + if (tid == 0) { + shared[0] = sum; + } + } + __syncthreads(); + + // Everyone picks it up, should be broadcast into the whole grad_input + return shared[0]; +} + +template +__global__ void batch_norm_transform_input_kernel( + const PackedTensorAccessor input, + PackedTensorAccessor output, + const PackedTensorAccessor::type, 1, RestrictPtrTraits, index_t> mean_, + const PackedTensorAccessor::type, 1, RestrictPtrTraits, index_t> var_or_invstd, + const PackedTensorAccessor weight, + const PackedTensorAccessor bias, + accscalar_t epsilon) { + + index_t plane = blockIdx.x; + + if (plane >= input.size(1)) { + return; + } + + accscalar_t gamma = weight.size(0) > 0 ? static_cast(weight[plane]) : static_cast(1); + accscalar_t beta = bias.size(0) > 0 ? static_cast(bias[plane]) : static_cast(0); + accscalar_t mean = static_cast(mean_[plane]); + accscalar_t invstd; + if (train) { + invstd = var_or_invstd[plane]; + } else { + invstd = static_cast(1) / device_sqrt(static_cast(var_or_invstd[plane]) + epsilon); + } + + index_t bs = input.size(0); + index_t fs = input.size(2); + + index_t bstep = blockDim.y * gridDim.y; + for (index_t batch = threadIdx.y + blockIdx.y * blockDim.y; batch < bs; batch += bstep) { + auto o = output[batch][plane]; + auto i = input[batch][plane]; + for (index_t feature = threadIdx.x; feature < fs; feature += blockDim.x) { + o[feature] = static_cast(gamma * (i[feature] - mean) * invstd + beta); + } + } +} + + +template +__global__ void batch_norm_collect_statistics_kernel( + const PackedTensorAccessor input, + const accscalar_t epsilon, + const accscalar_t momentum, + PackedTensorAccessor running_mean, + PackedTensorAccessor running_var, + PackedTensorAccessor save_mean, + PackedTensorAccessor save_invstd) { + + __shared__ int shared_n[2 * 2 * WARP_SIZE + WARP_SIZE]; + + int plane = blockIdx.x; + int N = input.size(0) * input.size(2); + int tid = threadIdx.x + threadIdx.y * blockDim.x; + + // Compute the mean and variance across (batch, x/y/z) + // this uses the Welford (in the for loop)/parallel algorithm (to sum across the block) + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm + // and the parallel algorithm on the same page. + // We use two shuffles to reduce across the entire block. + // https://devblogs.nvidia.com/faster-parallel-reductions-kepler/ has a description. 
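  // Note on the shuffle loops below: each thread carries a partial (n, avg, var_n), where
  // var_n is the running sum of squared deviations (M2 in Welford's terms). Two partials
  // (n_a, avg_a, M2_a) and (n_b, avg_b, M2_b) are merged with Chan et al.'s update:
  //   avg = (n_a * avg_a + n_b * avg_b) / (n_a + n_b)
  //   M2  = M2_a + M2_b + (avg_a - avg_b)^2 * n_a * n_b / (n_a + n_b)
  //   n   = n_a + n_b
  // After the block-wide reduction, M2 / N is the biased variance used for save_invstd
  // and M2 / (N - 1) the unbiased variance written into running_var.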
+ accscalar_t* shared_avg_var = (accscalar_t*) &shared_n[WARP_SIZE]; + + // first the reductions each thread does separately + accscalar_t avg = 0; + accscalar_t var_n = 0; + int n = 0; + for (int batch = threadIdx.y; batch < input.size(0); batch += blockDim.y) { + for (int x = threadIdx.x; x < input.size(2); x += blockDim.x) { + accscalar_t v = input[batch][plane][x]; + accscalar_t d1 = v - avg; + n++; + avg += d1 / n; + var_n += d1 * (v - avg); + } + } + + // first warpSum to get one value per thread to + // one value per warp + for (int i = 0; i < getMSB(WARP_SIZE); ++i) { + accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, WARP_SIZE); + int o_n = WARP_SHFL_XOR(n, 1 << i, WARP_SIZE); + if (n + o_n > 0) { + var_n += WARP_SHFL_XOR(var_n, 1 << i, WARP_SIZE) + ((avg - o_avg) * (avg - o_avg) * n * o_n) / (n + o_n); + avg = (n * avg + o_n * o_avg)/(n+o_n); + n += o_n; + } + } + + // this writes each warps item into shared memory + // there are at most WARP_SIZE items left because + // there are at most WARP_SIZE**2 threads at the beginning + __syncthreads(); + if (tid % WARP_SIZE == 0) { + shared_n[tid / WARP_SIZE] = n; + shared_avg_var[tid / WARP_SIZE * 2] = avg; + shared_avg_var[tid / WARP_SIZE * 2 + 1] = var_n; + } + __syncthreads(); + // now have a second warpSum to reduce the intermediate values + // from shared memory to a single number. The very first + // thread writes it to shared memory. + + if (tid < WARP_SIZE) { + n = (tid < blockDim.x * blockDim.y / WARP_SIZE ? shared_n[tid] : 0); + avg = (tid < blockDim.x * blockDim.y / WARP_SIZE ? shared_avg_var[2 * tid] : 0); + var_n = (tid < blockDim.x * blockDim.y / WARP_SIZE ? shared_avg_var[2 * tid + 1] : 0); + } + for (int i = 0; i < getMSB(WARP_SIZE); ++i) { + accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, WARP_SIZE); + int o_n = WARP_SHFL_XOR(n, 1 << i, WARP_SIZE); + if (n + o_n > 0) { + var_n += WARP_SHFL_XOR(var_n, 1 << i, WARP_SIZE) + ((avg - o_avg) * (avg - o_avg) * n * o_n) / (n + o_n); + avg = (n * avg + o_n * o_avg)/(n+o_n); + n += o_n; + } + } + + // Save the mean, variance, and moving averages + if (tid == 0) { + accscalar_t invstd = 0; + if (var_n != static_cast(0) || epsilon != static_cast(0)) { + invstd = static_cast(1) / device_sqrt(var_n / N + epsilon); + } + save_mean[plane] = avg; + save_invstd[plane] = invstd; + if (running_mean.data() != NULL) { + running_mean[plane] = static_cast((1 - momentum) * running_mean[plane] + momentum * avg); + } + if (running_var.data() != NULL) { + accscalar_t unbiasedVar = var_n / (N - 1); + running_var[plane] = static_cast((1 - momentum) * running_var[plane] + momentum * unbiasedVar); + } + } + +} + +template +__global__ void batch_norm_backward_kernel( + const PackedTensorAccessor input, + const PackedTensorAccessor grad_output, + PackedTensorAccessor grad_input, + PackedTensorAccessor grad_weight, + PackedTensorAccessor grad_bias, + const PackedTensorAccessor weight, + const PackedTensorAccessor running_mean, + const PackedTensorAccessor running_var, + const PackedTensorAccessor save_mean, + const PackedTensorAccessor save_invstd, + bool train, + accscalar_t epsilon) { + + index_t plane = blockIdx.x; + index_t N = grad_output.size(0) * grad_output.size(2); + + accscalar_t mean, invstd; + if (train) { + mean = save_mean[plane]; + invstd = save_invstd[plane]; + } else { + mean = static_cast(running_mean[plane]); + invstd = static_cast(1) / device_sqrt(static_cast(running_var[plane]) + epsilon); + } + + accscalar_t weight_val = weight.size(0) > 0 ? 
static_cast(weight[plane]) : accscalar_t(1); + accscalar_t norm = accscalar_t(1) / N; + + // Compute two values across (batch, x/y/z) in one pass: + // 1. Sum(grad_output) + // 2. DotProduct(input - mean, grad_output) + GradOp> g(mean, input, grad_output); + Float2 res = reduce, GradOp>>(g, grad_output, plane); + accscalar_t grad_output_sum = res.v1; + accscalar_t dot_p = res.v2; + + accscalar_t grad_mean = grad_output_sum * norm; + accscalar_t proj_scale = dot_p * norm * invstd * invstd; + accscalar_t grad_scale = invstd * weight_val; + + if (grad_input.data() != NULL) { + for (int batch = threadIdx.y; batch < grad_output.size(0); batch += blockDim.y) { + for (int x = threadIdx.x; x < grad_output.size(2); x += blockDim.x) { + scalar_t go = grad_output[batch][plane][x]; + if (train) { + scalar_t inp = input[batch][plane][x]; + accscalar_t proj = (inp - mean) * proj_scale; + grad_input[batch][plane][x] = static_cast((go - proj - grad_mean) * grad_scale); + } else { + grad_input[batch][plane][x] = static_cast(go * grad_scale); + } + } + } + } + + if (grad_weight.size(0) > 0) { + if (threadIdx.x == 0) { + grad_weight[plane] = static_cast(dot_p * invstd); + } + } + + if (grad_bias.size(0) > 0) { + if (threadIdx.x == 0) { + grad_bias[plane] = static_cast(grad_output_sum); + } + } +} + +template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> +static PackedTensorAccessor packed_accessor_or_dummy(const Tensor& t) { + if (! t.defined()) { + const std::vector zeros(dim); + return PackedTensorAccessor(nullptr, zeros.data(), zeros.data()); + } + return t.packed_accessor(); +} + +template +std::tuple batch_norm_cuda_template(const Tensor& input_, const Tensor& weight_, const Tensor& bias_, + const Tensor& running_mean_, const Tensor& running_var_, + bool train, double momentum, double epsilon) { + + using accscalar_t = at::acc_type; + int64_t n_input = input_.size(1); + Tensor save_mean_; + Tensor save_invstd_; + auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions + auto output_reshaped = at::empty_like(input_reshaped); + + auto bs = input_reshaped.size(0); + auto features = input_reshaped.size(2); + auto input = input_reshaped.packed_accessor(); + auto input_options = input_.options(); + if (input_.type().scalarType() == at::ScalarType::Half) { + input_options = input_options.dtype(ScalarType::Float); + } + if (train) { + save_mean_ = at::empty({n_input}, input_options); + save_invstd_ = at::empty({n_input}, input_options); + } else { + save_mean_ = at::empty({0}, input_options); + save_invstd_ = at::empty({0}, input_options); + } + auto output = output_reshaped.packed_accessor(); + auto weight = packed_accessor_or_dummy(weight_); + auto bias = packed_accessor_or_dummy(bias_); + auto running_mean = packed_accessor_or_dummy(running_mean_); + auto running_var = packed_accessor_or_dummy(running_var_); + auto save_mean = save_mean_.packed_accessor(); + auto save_invstd = save_invstd_.packed_accessor(); + auto stream = at::cuda::getCurrentCUDAStream(); + + // The input_transform kernel is pointwise, but we need to balance reading parameters (save_var/mean, + // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks + // and good occupancy. Quiet likely, we could go with even more blocks than 1024. + // The various planes are independent, so we use blocks for them. 
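  // Launch geometry for the transform kernel computed below: gridDim.x is one block per
  // plane (channel), blockDim.y = tb together with gridDim.y tiles the batch dimension,
  // and blockDim.x = tf strides over the flattened feature dimension.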
+ int tf = std::max(getNumThreads(input.size(2)/4), + std::min(getNumThreads(input.size(2)), 64)); + int tb = std::max(64/tf, 1); + dim3 blocks_trans(input.size(1), std::max(1, std::min((256*1024)/input.size(1), + (input.size(0)+tb-1)/tb))); + dim3 threads_trans(tf, tb); + if (!train) { + batch_norm_transform_input_kernel <<>> + (input, output, running_mean, running_var, weight, bias, epsilon); + } else { + // for the reduction, we cannot use blocks for the batch dim, but if we have few threads in + // the feature dimension, we'll use some threads for blocks + dim3 blocks(input.size(1)); + tf = getNumThreads(input.size(2)); + dim3 threads(tf, std::max(1, MAX_BLOCK_SIZE/tf)); + batch_norm_collect_statistics_kernel <<>> + (input, epsilon, momentum, running_mean, running_var, save_mean, save_invstd); + batch_norm_transform_input_kernel <<>> + (input, output, save_mean, save_invstd, weight, bias, epsilon); + } + THCudaCheck(cudaGetLastError()); + return std::make_tuple(output_reshaped.view(input_.sizes()), save_mean_, save_invstd_); +} + +template +std::tuple batch_norm_backward_cuda_template(const Tensor& grad_out_, const Tensor& input_, const Tensor& weight_, + const Tensor& running_mean_, const Tensor& running_var_, const Tensor& save_mean_, const Tensor& save_invstd_, + bool train, double epsilon, std::array grad_input_mask) { + + using accscalar_t = at::acc_type; + Tensor grad_input_; + Tensor grad_input_reshaped; + Tensor grad_weight_; + Tensor grad_bias_; + auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); + auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes()); + + if (grad_input_mask[0]) { + grad_input_ = at::empty_like(input_); + grad_input_reshaped = grad_input_.view(input_reshaped.sizes()); + } + if (grad_input_mask[1]) { + grad_weight_ = at::empty_like(weight_); + } + if (grad_input_mask[2]) { + grad_bias_ = at::empty_like(weight_); + } + + auto input = input_reshaped.packed_accessor(); + auto grad_output = grad_output_reshaped.packed_accessor(); + auto grad_input = packed_accessor_or_dummy(grad_input_reshaped); + auto weight = packed_accessor_or_dummy(weight_); + auto grad_weight = packed_accessor_or_dummy(grad_weight_); + auto grad_bias = packed_accessor_or_dummy(grad_bias_); + auto running_mean = packed_accessor_or_dummy(running_mean_); + auto running_var = packed_accessor_or_dummy(running_var_); + auto save_mean = packed_accessor_or_dummy(save_mean_); + auto save_invstd = packed_accessor_or_dummy(save_invstd_); + + auto stream = at::cuda::getCurrentCUDAStream(); + dim3 blocks(input.size(1)); + int tf = getNumThreads(input.size(2)); + dim3 threads(tf, std::max(1, MAX_BLOCK_SIZE/tf)); + + batch_norm_backward_kernel <<>> + (input, grad_output, grad_input, grad_weight, grad_bias, weight, running_mean, running_var, + save_mean, save_invstd, train, epsilon); + THCudaCheck(cudaGetLastError()); + + return std::make_tuple(grad_input_, grad_weight_, grad_bias_); +} + +} // anonymous namespace + +std::tuple batch_norm_cuda(const Tensor& self, const Tensor& weight, const Tensor& bias, + const Tensor& running_mean, const Tensor& running_var, bool train, double momentum, double epsilon) { + return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm", [&] { + if (cuda::detail::canUse32BitIndexMath(self)) { + return batch_norm_cuda_template(self, weight, bias, running_mean, running_var, train, momentum, epsilon); + } else { + return batch_norm_cuda_template(self, weight, bias, running_mean, running_var, train, momentum, epsilon); + } + }); +} 
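A short usage sketch of the new op (not part of the diff; shapes and values are illustrative, and it assumes the at::native_batch_norm binding generated from the schema added to native_functions.yaml below): the op returns the normalized output together with the per-channel statistics that the backward pass consumes.

    #include <ATen/ATen.h>
    #include <tuple>

    void native_batch_norm_example() {
      at::Tensor input = at::randn({4, 5, 6});          // (N, C, L)
      at::Tensor weight = at::ones({5});
      at::Tensor bias = at::zeros({5});
      at::Tensor running_mean = at::zeros({5});
      at::Tensor running_var = at::ones({5});

      at::Tensor output, save_mean, save_invstd;
      std::tie(output, save_mean, save_invstd) = at::native_batch_norm(
          input, weight, bias, running_mean, running_var,
          /*training=*/true, /*momentum=*/0.1, /*eps=*/1e-5);
      // save_mean / save_invstd are reused by native_batch_norm_backward in training mode.
    }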
+ +std::tuple batch_norm_backward_cuda(const Tensor& grad_out, const Tensor& self, const Tensor& weight, const Tensor& running_mean, const Tensor& running_var, + const Tensor& save_mean, const Tensor& save_invstd, bool train, double epsilon, std::array grad_input_mask) { + return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm_backward", [&] { + if (cuda::detail::canUse32BitIndexMath(self)) { + return batch_norm_backward_cuda_template(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, epsilon, grad_input_mask); + } else { + return batch_norm_backward_cuda_template(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, epsilon, grad_input_mask); + } + }); +} + +} } // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index f295886cfebb..707cacd0359b 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1224,6 +1224,16 @@ variants: function, method device_guard: false +- func: native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, double momentum, double eps) -> (Tensor, Tensor, Tensor) + dispatch: + CPU: batch_norm_cpu + CUDA: batch_norm_cuda + +- func: native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, double eps, std::array output_mask) -> (Tensor, Tensor, Tensor) + dispatch: + CPU: batch_norm_backward_cpu + CUDA: batch_norm_backward_cuda + - func: ones(IntList size, TensorOptions options={}) -> Tensor - func: ones_out(Tensor result, IntList size) -> Tensor diff --git a/aten/src/ATen/nn.yaml b/aten/src/ATen/nn.yaml index 399e94f1b86a..b5747817957f 100644 --- a/aten/src/ATen/nn.yaml +++ b/aten/src/ATen/nn.yaml @@ -217,18 +217,6 @@ output: self_->dim() == 0 grad_input: output_->dim() == 0 -# Batch normalization - -# The buffers here are somewhat hazardous, because their type will be -# based off of self, even though you may plausibly wish running_mean -# and running_var to have different precision than self (e.g., -# BatchNorm on half). Fortunately, THNN doesn't actually ever do this, -# so the buffer allocation code is "correct". If you ever do fix this, -# you should just port the function entirely to a native ATen function. -- name: thnn_batch_norm(Tensor self, Tensor weight, Tensor bias, Tensor running_mean, Tensor running_var, bool training, double momentum, double eps) - cname: BatchNormalization - buffers: [save_mean, save_std] - # Convolutions - name: thnn_conv_transpose2d(Tensor self, Tensor weight, IntList[2] kernel_size, Tensor bias={}, IntList[2] stride=1, IntList[2] padding=0, IntList[2] output_padding=0, IntList[2] dilation=1) diff --git a/aten/src/ATen/templates/Tensor.h b/aten/src/ATen/templates/Tensor.h index 65f0c91b8210..37835767d472 100644 --- a/aten/src/ATen/templates/Tensor.h +++ b/aten/src/ATen/templates/Tensor.h @@ -207,13 +207,13 @@ public: // cast the data pointer to a __restrict__ pointer. // In order to use this, your CUDA kernel has to take a corresponding PackedTensorAccessor // as an argument. 
- template class PtrTraits = DefaultPtrTraits> - PackedTensorAccessor packed_accessor() const& { + template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> + PackedTensorAccessor packed_accessor() const& { static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data()"); AT_CHECK(dim() == N, "expected ", N, " dims but tensor has ", dim()); - return PackedTensorAccessor(static_cast::PtrType>(data()),sizes().data(),strides().data()); + return PackedTensorAccessor(static_cast::PtrType>(data()),sizes().data(),strides().data()); } - template class PtrTraits = DefaultPtrTraits> + template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> PackedTensorAccessor packed_accessor() && = delete; Tensor operator-() const; diff --git a/aten/src/THCUNN/BatchNormalization.cu b/aten/src/THCUNN/BatchNormalization.cu deleted file mode 100644 index 97579d1c4aef..000000000000 --- a/aten/src/THCUNN/BatchNormalization.cu +++ /dev/null @@ -1,303 +0,0 @@ -#include "THCUNN.h" -#include "common.h" -#include "TH/THHalf.h" -#include "THCHalfAutoNumerics.cuh" -#include "THCTensor.hpp" - -#include "THCDeviceTensor.cuh" -#include "THCDeviceTensorUtils.cuh" -#include "THCDeviceUtils.cuh" -#if defined(__HIP_PLATFORM_HCC__) -const int WARP_SIZE = 64; -#else -const int WARP_SIZE = 32; -#endif - -// The maximum number of threads in a block -#if defined(__HIP_PLATFORM_HCC__) -const int MAX_BLOCK_SIZE = 256; -#else -const int MAX_BLOCK_SIZE = 512; -#endif - -// Number of threads in a block given an input size up to MAX_BLOCK_SIZE -static int getNumThreads(int nElem) { -#if defined(__HIP_PLATFORM_HCC__) - int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE }; -#else - int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE }; -#endif - for (int i = 0; i != 5; ++i) { - if (nElem <= threadSizes[i]) { - return threadSizes[i]; - } - } - return MAX_BLOCK_SIZE; -} - -// Returns the index of the most significant 1 bit in `val`. 
-__device__ __forceinline__ int getMSB(int val) { - return 31 - __clz(val); -} - -template -struct Float2 { - Acctype v1, v2; - __device__ Float2() {} - __device__ Float2(Dtype v1, Dtype v2) : v1(ScalarConvert::to(v1)), v2(ScalarConvert::to(v2)) {} - __device__ Float2(Dtype v) : v1(ScalarConvert::to(v)), v2(ScalarConvert::to(v)) {} - __device__ Float2(int v) : v1(ScalarConvert::to(v)), v2(ScalarConvert::to(v)) {} - __device__ Float2& operator+=(const Float2& a) { - v1 += a.v1; - v2 += a.v2; - return *this; - } -}; - -template -struct SumOp { - __device__ SumOp(const DeviceTensor3 t) : tensor(t) {} - __device__ __forceinline__ Acctype operator()(int batch, int plane, int n) { - return ScalarConvert::to(tensor[batch][plane][n]); - } - const DeviceTensor3 tensor; -}; - -template -struct VarOp { - __device__ VarOp(Acctype m, const DeviceTensor3 t) : mean(m), tensor(t) {} - __device__ __forceinline__ Acctype operator()(int batch, int plane, int n) { - Dtype val = tensor[batch][plane][n]; - return (val - mean) * (val - mean); - } - const Acctype mean; - const DeviceTensor3 tensor; -}; - -template -struct GradOp { - __device__ GradOp(Acctype m, const DeviceTensor3 i, const DeviceTensor3 g) - : mean(m), input(i), gradOutput(g) {} - __device__ __forceinline__ Float2 operator()(int batch, int plane, int n) { - Dtype g = gradOutput[batch][plane][n]; - Dtype c = ScalarConvert::to(input[batch][plane][n] - mean); - return Float2(g, g * c); - } - const Acctype mean; - const DeviceTensor3 input; - const DeviceTensor3 gradOutput; -}; - -// Sum across all threads within a warp -template -static __device__ __forceinline__ T warpSum(T val) { -#if __CUDA_ARCH__ >= 300 - for (int i = 0; i < getMSB(WARP_SIZE); ++i) { - val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE); - } -#else - __shared__ T values[MAX_BLOCK_SIZE]; - values[threadIdx.x] = val; - __threadfence_block(); - const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE; - for (int i = 1; i < WARP_SIZE; i++) { - val += values[base + ((i + threadIdx.x) % WARP_SIZE)]; - } -#endif - return val; -} - -template -static __device__ __forceinline__ Float2 warpSum(Float2 value) { - value.v1 = warpSum(value.v1); - value.v2 = warpSum(value.v2); - return value; -} - -// Sum across (batch, x/y/z) applying Op() pointwise -template -__device__ T reduce(Op op, DeviceTensor3 tensor, int plane) { - T sum = (T)0; - for (int batch = 0; batch < tensor.getSize(0); ++batch) { - for (int x = threadIdx.x; x < tensor.getSize(2); x += blockDim.x) { - sum += op(batch, plane, x); - } - } - - // sum over NumThreads within a warp - sum = warpSum(sum); - - // 'transpose', and reduce within warp again - __shared__ T shared[WARP_SIZE]; - __syncthreads(); - if (threadIdx.x % WARP_SIZE == 0) { - shared[threadIdx.x / WARP_SIZE] = sum; - } - if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) { - // zero out the other entries in shared - shared[threadIdx.x] = (T)0; - } - __syncthreads(); - if (threadIdx.x / WARP_SIZE == 0) { - sum = warpSum(shared[threadIdx.x]); - if (threadIdx.x == 0) { - shared[0] = sum; - } - } - __syncthreads(); - - // Everyone picks it up, should be broadcast into the whole gradInput - return shared[0]; -} - -template -__global__ void BatchNormalizationUpdateOutputInference_kernel( - const DeviceTensor3 input, - DeviceTensor3 output, - const DeviceTensor1 runningMean, - const DeviceTensor1 runningVar, - const DeviceTensor1 weight, - const DeviceTensor1 bias, - Acctype epsilon) { - - int plane = blockIdx.x; - - Acctype invstd = Acctype(1) / 
sqrt(runningVar[plane].ldg() + epsilon); - Acctype mean = ScalarConvert::to(runningMean[plane].ldg()); - Acctype gamma = weight.numElements() > 0 ? ScalarConvert::to(weight[plane].ldg()) : Acctype(1); - Acctype beta = bias.numElements() > 0 ? ScalarConvert::to(bias[plane].ldg()) : Acctype(0); - - // Write normalized and update the output - for (int batch = 0; batch < input.getSize(0); batch++) { - for (int x = threadIdx.x; x < input.getSize(2); x += blockDim.x) { - Dtype inp = input[batch][plane][x].ldg(); - output[batch][plane][x] = ScalarConvert::to(gamma * (inp - mean) * invstd + beta); - } - } -} - -template -__global__ void BatchNormalizationUpdateOutput_kernel( - const DeviceTensor3 input, - DeviceTensor3 output, - const DeviceTensor1 weight, - const DeviceTensor1 bias, - const Acctype epsilon, - const Acctype momentum, - DeviceTensor1 runningMean, - DeviceTensor1 runningVar, - DeviceTensor1 saveMean, - DeviceTensor1 saveStd) { - - int plane = blockIdx.x; - int N = input.getSize(0) * input.getSize(2); - - Acctype norm = Acctype(1) / N; - - // Compute the mean and variance across (batch, x/y/z) - Acctype mean = reduce(SumOp(input), input, plane) * norm; - __syncthreads(); - Acctype varN = reduce(VarOp(mean, input), input, plane); - Acctype invStd = 0; - if (varN != Acctype(0) || epsilon != Acctype(0)) { - invStd = 1 / sqrt(varN * norm + epsilon); - } - - // Save the mean, variance, and moving averages - if (threadIdx.x == 0) { - // Momentum based writeback - Acctype unbiasedVar = varN / (N - 1); - saveMean[plane] = ScalarConvert::to(mean); - saveStd[plane] = ScalarConvert::to(invStd); - if (runningMean.data() != NULL) { - runningMean[plane] = ScalarConvert::to((1 - momentum) * runningMean[plane] + momentum * mean); - } - if (runningVar.data() != NULL) { - runningVar[plane] = ScalarConvert::to((1 - momentum) * runningVar[plane] + momentum * unbiasedVar); - } - } - - // Write normalized and update the output - Acctype gamma = weight.numElements() > 0 ? ScalarConvert::to(weight[plane]) : ScalarConvert::to(1); - Acctype beta = bias.numElements() > 0 ? ScalarConvert::to(bias[plane]) : ScalarConvert::to(0); - for (int batch = 0; batch < input.getSize(0); ++batch) { - for (int x = threadIdx.x; x < input.getSize(2); x += blockDim.x) { - Dtype inp = input[batch][plane][x].ldg(); - output[batch][plane][x] = ScalarConvert::to(gamma * (inp - mean) * invStd + beta); - } - } -} - -template -__global__ void BatchNormalizationBackward_kernel( - const DeviceTensor3 input, - const DeviceTensor3 gradOutput, - DeviceTensor3 gradInput, - DeviceTensor1 gradWeight, - DeviceTensor1 gradBias, - const DeviceTensor1 weight, - const DeviceTensor1 runningMean, - const DeviceTensor1 runningVar, - const DeviceTensor1 saveMean, - const DeviceTensor1 saveStd, - bool train, - Acctype scale, - double eps) { - - int plane = blockIdx.x; - int N = gradOutput.getSize(0) * gradOutput.getSize(2); - - Acctype mean, stdVal; - if (train) { - mean = ScalarConvert::to(saveMean[plane]); - stdVal = ScalarConvert::to(saveStd[plane]); - } else { - mean = ScalarConvert::to(runningMean[plane]); - stdVal = 1 / sqrt(runningVar[plane] + eps); - } - - Acctype weightVal = weight.numElements() > 0 ? ScalarConvert::to(weight[plane]) : Acctype(1); - Acctype norm = Acctype(1) / N; - - // Compute two values across (batch, x/y/z) in one pass: - // 1. Sum(gradOutput) - // 2. 
DotProduct(input - mean, gradOutput) - GradOp g(mean, input, gradOutput); - Float2 res = reduce, GradOp, DeviceTensor3>(g, gradOutput, plane); - Acctype gradOutputSum = res.v1; - Acctype dotP = res.v2; - - Acctype gradMean = gradOutputSum * norm; - Acctype projScale = dotP * norm * stdVal * stdVal; - Acctype gradScale = stdVal * weightVal; - - if (gradInput.numElements() > 0) { - for (int batch = 0; batch < gradOutput.getSize(0); ++batch) { - for (int x = threadIdx.x; x < gradOutput.getSize(2); x += blockDim.x) { - Dtype gradOut = gradOutput[batch][plane][x]; - if (train) { - Dtype inp = input[batch][plane][x]; - Acctype proj = (inp - mean) * projScale; - gradInput[batch][plane][x] = ScalarConvert::to((gradOut - proj - gradMean) * gradScale); - } else { - gradInput[batch][plane][x] = ScalarConvert::to(gradOut * gradScale); - } - } - } - } - - if (gradWeight.numElements() > 0) { - if (threadIdx.x == 0) { - gradWeight[plane] += ScalarConvert::to(scale * dotP * stdVal); - } - } - - if (gradBias.numElements() > 0) { - if (threadIdx.x == 0) { - gradBias[plane] += ScalarConvert::to(scale * gradOutputSum); - } - } -} - -#include "generic/BatchNormalization.cu" -#include "THCGenerateFloatTypes.h" diff --git a/aten/src/THCUNN/CMakeLists.txt b/aten/src/THCUNN/CMakeLists.txt index 76382f7b1b04..ce37a2067be6 100644 --- a/aten/src/THCUNN/CMakeLists.txt +++ b/aten/src/THCUNN/CMakeLists.txt @@ -1,7 +1,6 @@ SET(ATen_CUDA_SRCS ${ATen_CUDA_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/AbsCriterion.cu ${CMAKE_CURRENT_SOURCE_DIR}/Abs.cu -${CMAKE_CURRENT_SOURCE_DIR}/BatchNormalization.cu ${CMAKE_CURRENT_SOURCE_DIR}/BCECriterion.cu ${CMAKE_CURRENT_SOURCE_DIR}/ClassNLLCriterion.cu ${CMAKE_CURRENT_SOURCE_DIR}/Col2Im.cu diff --git a/aten/src/THCUNN/generic/BatchNormalization.cu b/aten/src/THCUNN/generic/BatchNormalization.cu deleted file mode 100644 index 227cc4765679..000000000000 --- a/aten/src/THCUNN/generic/BatchNormalization.cu +++ /dev/null @@ -1,108 +0,0 @@ -#ifndef THC_GENERIC_FILE -#define THC_GENERIC_FILE "generic/BatchNormalization.cu" -#else - -#define DeviceTensor3 THCDeviceTensor -#define DeviceTensor1 THCDeviceTensor - -template -static THCDeviceTensor THNN_(devicetensor)(THCState *state, THCTensor *t) { - if (!t) { - return THCDeviceTensor(); - } - - int inDim = THCTensor_nDimensionLegacyAll(state, t); - if (inDim == Dim) { - return toDeviceTensor(state, t); - } - - // View in which the last dimensions are collapsed or expanded as needed - THAssert(t->is_contiguous()); - int size[Dim]; - for (int i = 0; i < Dim || i < inDim; ++i) { - if (i < Dim && i < inDim) { - size[i] = THTensor_sizeLegacyNoScalars(t, i); - } else if (i < Dim) { - size[i] = 1; - } else { - size[Dim - 1] *= THTensor_sizeLegacyNoScalars(t, i); - } - } - return THCDeviceTensor(t->data(), size); -} - -void THNN_(BatchNormalization_updateOutput)( - THCState *state, THCTensor *input_, THCTensor *output_, - THCTensor *weight_, THCTensor *bias_, THCTensor *runningMean_, - THCTensor *runningVar_, THCTensor *saveMean_, THCTensor *saveStd_, - bool train, double momentum, double eps) { - - THCTensor_(resizeAs)(state, output_, input_); - if (train) { - int64_t nInput = THCTensor_(size)(state, input_, 1); - THCTensor_(resize1d)(state, saveMean_, nInput); - THCTensor_(resize1d)(state, saveStd_, nInput); - } - DeviceTensor3 input = THNN_(devicetensor)<3>(state, input_); - DeviceTensor3 output = THNN_(devicetensor)<3>(state, output_); - DeviceTensor1 weight = THNN_(devicetensor)<1>(state, weight_); - DeviceTensor1 bias = THNN_(devicetensor)<1>(state, 
bias_); - DeviceTensor1 runningMean = THNN_(devicetensor)<1>(state, runningMean_); - DeviceTensor1 runningVar = THNN_(devicetensor)<1>(state, runningVar_); - DeviceTensor1 saveMean = THNN_(devicetensor)<1>(state, saveMean_); - DeviceTensor1 saveStd = THNN_(devicetensor)<1>(state, saveStd_); - - cudaStream_t s = THCState_getCurrentStream(state); - cudaDeviceProp *prop = THCState_getCurrentDeviceProperties(state); - - if (!train) { - dim3 blocks(input.getSize(1)); - dim3 threads(getNumThreads(input.getSize(2))); - BatchNormalizationUpdateOutputInference_kernel <<>>( - input, output, runningMean, runningVar, weight, bias, eps); - } else { - dim3 blocks(input.getSize(1)); - dim3 threads(getNumThreads(input.getSize(2))); - BatchNormalizationUpdateOutput_kernel <<>>( - input, output, weight, bias, static_cast(eps), static_cast(momentum), runningMean, runningVar, - saveMean, saveStd); - } - THCudaCheck(cudaGetLastError()); -} - -void THNN_(BatchNormalization_backward)( - THCState *state, THCTensor *input_, THCTensor *gradOutput_, - THCTensor *gradInput_, THCTensor *gradWeight_, THCTensor *gradBias_, - THCTensor *weight_, THCTensor *runningMean_, THCTensor *runningVar_, - THCTensor *saveMean_, THCTensor *saveStd_, bool train, double scale, double eps) { - - THCUNN_check_shape(state, input_, gradOutput_); - if (gradInput_) { - THCTensor_(resizeAs)(state, gradInput_, input_); - } - - DeviceTensor3 input = THNN_(devicetensor)<3>(state, input_); - DeviceTensor3 gradOutput = THNN_(devicetensor)<3>(state, gradOutput_); - DeviceTensor3 gradInput = THNN_(devicetensor)<3>(state, gradInput_); - DeviceTensor1 gradWeight = THNN_(devicetensor)<1>(state, gradWeight_); - DeviceTensor1 gradBias = THNN_(devicetensor)<1>(state, gradBias_); - DeviceTensor1 weight = THNN_(devicetensor)<1>(state, weight_); - DeviceTensor1 runningMean = THNN_(devicetensor)<1>(state, runningMean_); - DeviceTensor1 runningVar = THNN_(devicetensor)<1>(state, runningVar_); - DeviceTensor1 saveMean = THNN_(devicetensor)<1>(state, saveMean_); - DeviceTensor1 saveStd = THNN_(devicetensor)<1>(state, saveStd_); - - cudaStream_t s = THCState_getCurrentStream(state); - - dim3 blocks(gradOutput.getSize(1)); - dim3 threads(getNumThreads(gradOutput.getSize(2))); - BatchNormalizationBackward_kernel <<>>( - input, gradOutput, gradInput, gradWeight, gradBias, weight, runningMean, runningVar, - saveMean, saveStd, train, scale, eps); - THCudaCheck(cudaGetLastError()); -} - -#undef DeviceTensor3 -#undef DeviceTensor1 - -#endif diff --git a/aten/src/THCUNN/generic/THCUNN.h b/aten/src/THCUNN/generic/THCUNN.h index d22d96cd0d7a..3cc5e0d518ae 100644 --- a/aten/src/THCUNN/generic/THCUNN.h +++ b/aten/src/THCUNN/generic/THCUNN.h @@ -30,36 +30,6 @@ THC_API void THNN_(AbsCriterion_updateGradInput)( THCTensor *gradInput, int64_t reduction); -THC_API void THNN_(BatchNormalization_updateOutput)( - THCState *state, - THCTensor *input_, - THCTensor *output_, - THCTensor *weight_, // [OPTIONAL] - THCTensor *bias_, // [OPTIONAL] - THCTensor *runningMean_, // [OPTIONAL] if train - THCTensor *runningVar_, // [OPTIONAL] if train - THCTensor *saveMean_, - THCTensor *saveStd_, - bool train, - double momentum, - double eps); - -THC_API void THNN_(BatchNormalization_backward)( - THCState *state, - THCTensor *input_, - THCTensor *gradOutput_, - THCTensor *gradInput_, // [OPTIONAL] - THCTensor *gradWeight_, // [OPTIONAL] - THCTensor *gradBias_, // [OPTIONAL] - THCTensor *weight_, // [OPTIONAL] - THCTensor *runningMean_, // [OPTIONAL] if train - THCTensor *runningVar_, // 
[OPTIONAL] if train - THCTensor *saveMean_, // [OPTIONAL] if !train - THCTensor *saveStd_, // [OPTIONAL] if !train - bool train, - double scale, - double eps); - THC_API void THNN_(BCECriterion_updateOutput)( THCState *state, THCTensor *input, diff --git a/aten/src/THNN/generic/BatchNormalization.c b/aten/src/THNN/generic/BatchNormalization.c deleted file mode 100644 index 1d481cb59ff4..000000000000 --- a/aten/src/THNN/generic/BatchNormalization.c +++ /dev/null @@ -1,160 +0,0 @@ -#ifndef TH_GENERIC_FILE -#define TH_GENERIC_FILE "generic/BatchNormalization.c" -#else - -void THNN_(BatchNormalization_updateOutput)( - THNNState *state, THTensor *input, THTensor *output, - THTensor *weight, THTensor *bias, - THTensor *running_mean, THTensor *running_var, - THTensor *save_mean, THTensor *save_std, - bool train, double momentum, double eps) -{ - THTensor_(resizeAs)(output, input); - int64_t nInput = THTensor_(size)(input, 1); - int64_t f; - ptrdiff_t n = THTensor_(nElement)(input) / nInput; - - if (train) { - THTensor_(resize1d)(save_mean, nInput); - THTensor_(resize1d)(save_std, nInput); - } - - #pragma omp parallel for - for (f = 0; f < nInput; ++f) { - THTensor *in = THTensor_(newSelect)(input, 1, f); - THTensor *out = THTensor_(newSelect)(output, 1, f); - - scalar_t mean, invstd; - - if (train) { - // compute mean per input - accreal sum = 0; - TH_TENSOR_APPLY(scalar_t, in, sum += *in_data;); - - mean = (scalar_t) sum / n; - THTensor_(set1d)(save_mean, f, (scalar_t) mean); - - // compute variance per input - sum = 0; - TH_TENSOR_APPLY(scalar_t, in, - sum += (*in_data - mean) * (*in_data - mean);); - - if (sum == 0 && eps == 0.0) { - invstd = 0; - } else { - invstd = (scalar_t) (1 / sqrt(sum/n + eps)); - } - THTensor_(set1d)(save_std, f, (scalar_t) invstd); - - // update running averages - if (running_mean) { - THTensor_(set1d)(running_mean, f, - (scalar_t) (momentum * mean + (1 - momentum) * THTensor_(get1d)(running_mean, f))); - } - if (running_var) { - accreal unbiased_var = sum / (n - 1); - THTensor_(set1d)(running_var, f, - (scalar_t) (momentum * unbiased_var + (1 - momentum) * THTensor_(get1d)(running_var, f))); - } - } else { - mean = THTensor_(get1d)(running_mean, f); - invstd = 1 / sqrt(THTensor_(get1d)(running_var, f) + eps); - } - - // compute output - scalar_t w = weight ? THTensor_(get1d)(weight, f) : 1; - scalar_t b = bias ? THTensor_(get1d)(bias, f) : 0; - - TH_TENSOR_APPLY2(scalar_t, in, scalar_t, out, - *out_data = (scalar_t) (((*in_data - mean) * invstd) * w + b);); - - c10::raw::intrusive_ptr::decref(out); - c10::raw::intrusive_ptr::decref(in); - } -} - -void THNN_(BatchNormalization_backward)( - THNNState *state, THTensor *input, THTensor *gradOutput, THTensor *gradInput, - THTensor *gradWeight, THTensor *gradBias, THTensor *weight, - THTensor *running_mean, THTensor *running_var, - THTensor *save_mean, THTensor *save_std, - bool train, double scale, double eps) -{ - THNN_CHECK_SHAPE(input, gradOutput); - int64_t nInput = THTensor_(size)(input, 1); - int64_t f; - ptrdiff_t n = THTensor_(nElement)(input) / nInput; - - if (gradInput) { - THTensor_(resizeAs)(gradInput, input); - } - - #pragma omp parallel for - for (f = 0; f < nInput; ++f) { - THTensor *in = THTensor_(newSelect)(input, 1, f); - THTensor *gradOut = THTensor_(newSelect)(gradOutput, 1, f); - scalar_t w = weight ? 
-void THNN_(BatchNormalization_backward)(
-  THNNState *state, THTensor *input, THTensor *gradOutput, THTensor *gradInput,
-  THTensor *gradWeight, THTensor *gradBias, THTensor *weight,
-  THTensor *running_mean, THTensor *running_var,
-  THTensor *save_mean, THTensor *save_std,
-  bool train, double scale, double eps)
-{
-  THNN_CHECK_SHAPE(input, gradOutput);
-  int64_t nInput = THTensor_(size)(input, 1);
-  int64_t f;
-  ptrdiff_t n = THTensor_(nElement)(input) / nInput;
-
-  if (gradInput) {
-    THTensor_(resizeAs)(gradInput, input);
-  }
-
-  #pragma omp parallel for
-  for (f = 0; f < nInput; ++f) {
-    THTensor *in = THTensor_(newSelect)(input, 1, f);
-    THTensor *gradOut = THTensor_(newSelect)(gradOutput, 1, f);
-    scalar_t w = weight ? THTensor_(get1d)(weight, f) : 1;
-    scalar_t mean, invstd;
-    if (train) {
-      mean = THTensor_(get1d)(save_mean, f);
-      invstd = THTensor_(get1d)(save_std, f);
-    } else {
-      mean = THTensor_(get1d)(running_mean, f);
-      invstd = 1 / sqrt(THTensor_(get1d)(running_var, f) + eps);
-    }
-
-    // sum over all gradOutput in feature plane
-    accreal sum = 0;
-    TH_TENSOR_APPLY(scalar_t, gradOut, sum += *gradOut_data;);
-
-    // dot product of the Q(X) and gradOutput
-    accreal dotp = 0;
-    TH_TENSOR_APPLY2(scalar_t, in, scalar_t, gradOut,
-      dotp += (*in_data - mean) * (*gradOut_data););
-
-    if (gradInput) {
-      THTensor *gradIn = THTensor_(newSelect)(gradInput, 1, f);
-
-      if (train) {
-        // when in training mode
-        // Q(X) = X - E[x] ; i.e. input centered to zero mean
-        // Y = Q(X) / σ    ; i.e. BN output before weight and bias
-        // dL/dX = (Q(dL/dY) - dot(Y, dL/dY) * Y) / σ * w
-
-        // projection of gradOutput on to output scaled by std
-        scalar_t k = (scalar_t) dotp * invstd * invstd / n;
-        TH_TENSOR_APPLY2(scalar_t, gradIn, scalar_t, in,
-          *gradIn_data = (*in_data - mean) * k;);
-
-        accreal gradMean = sum / n;
-        TH_TENSOR_APPLY2(scalar_t, gradIn, scalar_t, gradOut,
-          *gradIn_data = (*gradOut_data - gradMean - *gradIn_data) * invstd * w;);
-
-      } else {
-        // when in evaluation mode
-        // Q(X) = X - running_mean ; i.e. input centered to zero mean
-        // Y = Q(X) / running_std  ; i.e. BN output before weight and bias
-        // dL/dX = w / running_std
-        TH_TENSOR_APPLY2(scalar_t, gradIn, scalar_t, gradOut,
-          *gradIn_data = *gradOut_data * invstd * w;);
-      }
-
-      c10::raw::intrusive_ptr::decref(gradIn);
-    }
-
-    if (gradWeight) {
-      scalar_t val = THTensor_(get1d)(gradWeight, f);
-      THTensor_(set1d)(gradWeight, f, val + scale * dotp * invstd);
-    }
-
-    if (gradBias) {
-      scalar_t val = THTensor_(get1d)(gradBias, f);
-      THTensor_(set1d)(gradBias, f, val + scale * sum);
-    }
-
-    c10::raw::intrusive_ptr::decref(gradOut);
-    c10::raw::intrusive_ptr::decref(in);
-  }
-}
-
-#endif
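The gradient formulas in the comments of the removed backward can be written more compactly. A hedged sketch of the training-mode path with illustrative names (grad_weight and grad_bias are shown without the `scale` factor that the C code accumulated into existing buffers):

```python
import torch

# Training-mode gradient per feature plane, as in the comments above:
#   y_hat = (x - mean) * invstd
#   dL/dx = (dL/dy - mean(dL/dy) - y_hat * mean(dL/dy * y_hat)) * invstd * w
def batch_norm_backward_reference(x, grad_out, weight, save_mean, save_invstd):
    mean = save_mean[None, :, None]
    invstd = save_invstd[None, :, None]
    w = weight[None, :, None]
    n = x.numel() // x.size(1)
    x_hat = (x - mean) * invstd
    grad_mean = grad_out.sum(dim=(0, 2), keepdim=True) / n
    proj = (grad_out * x_hat).sum(dim=(0, 2), keepdim=True) / n * x_hat
    grad_input = (grad_out - grad_mean - proj) * invstd * w
    grad_weight = (grad_out * x_hat).sum(dim=(0, 2))   # dotp * invstd in the C code
    grad_bias = grad_out.sum(dim=(0, 2))               # sum in the C code
    return grad_input, grad_weight, grad_bias
```

In evaluation mode the mean and projection terms drop out and grad_input reduces to grad_out * invstd * w, matching the else branch above.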
"generic/unfold.c" #include "THGenerateFloatTypes.h" diff --git a/test/test_nn.py b/test/test_nn.py index 54d6b26d4775..e0d5f4c1c919 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -2789,6 +2789,17 @@ class TestNN(NNTestCase): def test_Conv2d_naive_groups_cuda(self, dtype=torch.float): self._test_Conv2d_naive_groups("cuda", dtype) + def test_batchnorm_grad(self): + self._test_batchnorm_grad() + + @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") + @skipIfRocm + def test_batchnorm_grad_cuda(self): + self._test_batchnorm_grad("cuda") + if TEST_CUDNN: + with torch.backends.cudnn.flags(enabled=False): + self._test_batchnorm_grad("cuda") + def test_batchnorm_eval(self): self._test_batchnorm_eval() @@ -2796,6 +2807,9 @@ class TestNN(NNTestCase): @skipIfRocm def test_batchnorm_eval_cuda(self, dtype=torch.float): self._test_batchnorm_eval("cuda", dtype) + if TEST_CUDNN: + with torch.backends.cudnn.flags(enabled=False): + self._test_batchnorm_eval("cuda", dtype) def test_batchnorm_simple_average(self): self._test_batchnorm_simple_average() @@ -2804,6 +2818,9 @@ class TestNN(NNTestCase): @skipIfRocm def test_batchnorm_simple_average_cuda(self): self._test_batchnorm_simple_average(torch.cuda.FloatTensor) + if TEST_CUDNN: + with torch.backends.cudnn.flags(enabled=False): + self._test_batchnorm_simple_average(torch.cuda.FloatTensor) def test_MaxPool1d_indices(self): self._test_maxpool_indices(1) @@ -5021,6 +5038,9 @@ class TestNN(NNTestCase): @skipIfRocm def test_batchnorm_update_stats_cuda(self): self._test_batchnorm_update_stats("cuda", torch.float) + if TEST_CUDNN: + with torch.backends.cudnn.flags(enabled=False): + self._test_batchnorm_update_stats("cuda", torch.float) def test_batchnorm_raises_error_if_running_mean_is_not_same_size_as_input(self): input = torch.rand(2, 10) @@ -5056,6 +5076,18 @@ class TestNN(NNTestCase): with self.assertRaises(RuntimeError): F.batch_norm(input, running_mean, running_var, bias=Parameter(torch.rand(size))) + def _test_batchnorm_grad(self, device="cpu", dtype=torch.double): + bs, n_feat, size_feat = 4, 5, 6 + input = torch.arange(bs * n_feat * size_feat, device=device, + requires_grad=True, dtype=dtype).view(bs, n_feat, size_feat) + weight = torch.arange(1, n_feat + 1, device=device, requires_grad=True, dtype=dtype) + bias = torch.arange(n_feat, device=device, requires_grad=True, dtype=dtype) + running_mean = 1 - torch.arange(n_feat, device=device, dtype=dtype) + running_var = 2 * torch.arange(n_feat, device=device, dtype=dtype) + for training in [False, True]: + _assertGradAndGradgradChecks(self, F.batch_norm, (input, running_mean, running_var, weight, bias, + training, 0.1, 0.0001)) + def _test_batchnorm_eval(self, device="cpu", dtype=torch.float): module = nn.BatchNorm1d(3).to(device, dtype) module.eval() diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index bca641e315a2..98fb7ceb764e 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -516,6 +516,14 @@ - name: mvlgamma(Tensor self, int64_t p) self: mvlgamma_backward(grad, self, p) +- name: native_batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor running_mean, Tensor running_var, bool training, double momentum, double eps) + input, weight, bias: native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, eps, grad_input_mask) + +- name: native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor weight, Tensor running_mean, Tensor running_var, Tensor save_mean, Tensor save_invstd, bool 
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index bca641e315a2..98fb7ceb764e 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -516,6 +516,14 @@
 - name: mvlgamma(Tensor self, int64_t p)
   self: mvlgamma_backward(grad, self, p)
 
+- name: native_batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor running_mean, Tensor running_var, bool training, double momentum, double eps)
+  input, weight, bias: native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, eps, grad_input_mask)
+
+- name: native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor weight, Tensor running_mean, Tensor running_var, Tensor save_mean, Tensor save_invstd, bool train, double eps, std::array<bool,3> output_mask)
+  input, weight, grad_out: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, running_mean, running_var, train, eps, save_mean, save_invstd, grad_input_mask)
+  save_mean: not_implemented("native_batch_norm_backward save_mean")
+  save_invstd: not_implemented("native_batch_norm_backward save_invstd")
+
 - name: ne_(Tensor self, Scalar other)
   self: zeros_like(self)
 
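These entries make `native_batch_norm_backward` itself differentiable through `batchnorm_double_backward`, which is exactly what gradgradcheck needs; the `grad_input_mask` / `output_mask` argument controls which of (grad_input, grad_weight, grad_bias) are actually computed. A small sketch of the user-level pattern that routes through the second entry (double backward through batch norm):

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 5, 6, dtype=torch.double, requires_grad=True)
w = torch.randn(5, dtype=torch.double, requires_grad=True)
b = torch.randn(5, dtype=torch.double, requires_grad=True)

y = F.batch_norm(x, None, None, w, b, training=True, momentum=0.1, eps=1e-5)
# First backward goes through native_batch_norm_backward; create_graph=True
# records that backward so that...
gx, = torch.autograd.grad(y.sum(), x, create_graph=True)
# ...this second backward is routed through batchnorm_double_backward.
ggx, ggw = torch.autograd.grad(gx.pow(2).sum(), (x, w))
```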
@@ -1035,14 +1043,6 @@
 - name: max_unpool3d_forward(Tensor self, Tensor indices, IntList output_size, IntList stride, IntList padding)
   self: max_unpool3d_backward(grad, self, indices, output_size, stride, padding)
 
-- name: thnn_batch_norm_forward(Tensor self, Tensor weight, Tensor bias, Tensor running_mean, Tensor running_var, bool training, double momentum, double eps)
-  self, weight, bias: thnn_batch_norm_backward(grad.contiguous(), self, weight, running_mean, running_var, training, eps, save_mean, save_std, grad_input_mask)
-
-- name: thnn_batch_norm_backward(Tensor grad_output, Tensor self, Tensor weight, Tensor running_mean, Tensor running_var, bool training, double eps, Tensor save_mean, Tensor save_std, std::array<bool,3> output_mask)
-  self, weight, grad_output: batchnorm_double_backward(self, weight, grads[0], grads[1], grads[2], grad_output, running_mean, running_var, training, eps, save_mean, save_std, grad_input_mask)
-  save_mean: not_implemented("thnn_batch_norm_backward save_mean")
-  save_std: not_implemented("thnn_batch_norm_backward save_std")
-
 - name: thnn_conv_transpose2d_forward(Tensor self, Tensor weight, IntList kernel_size, Tensor bias, IntList stride, IntList padding, IntList output_padding, IntList dilation)
   self, weight, bias: thnn_conv_transpose2d_backward(grad, self, weight, kernel_size, stride, padding, output_padding, dilation, columns, ones, grad_input_mask)
 
@@ -1293,7 +1293,7 @@
 #     work.)
 # NB2: The quotes around the gradient are needed to appease YAML parsing rules.
 - name: cudnn_batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor running_mean, Tensor running_var, bool training, double exponential_average_factor, double epsilon)
-  input, weight, bias: "training ? cudnn_batch_norm_backward(input, grad.contiguous(), weight, running_mean, running_var, result1, result2, epsilon) : thnn_batch_norm_backward(grad.contiguous(), input, weight, running_mean, running_var, training, epsilon, result1, result2, grad_input_mask)"
+  input, weight, bias: "training ? cudnn_batch_norm_backward(input, grad.contiguous(), weight, running_mean, running_var, result1, result2, epsilon) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)"
 
 # HACK: save_mean and save_var are going to be passed in as
 # requires_grad variables (even though we'll never backprop through
@@ -1325,7 +1325,7 @@
   grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, std::vector<int64_t>(padding.size(), 0), groups, benchmark, deterministic, true, grad_input_mask)
 
 - name: miopen_batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor running_mean, Tensor running_var, bool training, double exponential_average_factor, double epsilon)
-  input, weight, bias: "training ? miopen_batch_norm_backward(input, grad.contiguous(), weight, running_mean, running_var, result1, result2, epsilon) : thnn_batch_norm_backward(grad.contiguous(), input, weight, running_mean, running_var, training, epsilon, result1, result2, grad_input_mask)"
+  input, weight, bias: "training ? miopen_batch_norm_backward(input, grad.contiguous(), weight, running_mean, running_var, result1, result2, epsilon) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)"
 
 - name: miopen_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor running_mean, Tensor running_var, Tensor save_mean, Tensor save_var, double epsilon)
   save_mean: not_implemented("miopen_batch_norm_backward save_mean")
diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp
index ef0a1ccb8f61..d0dcf951016b 100644
--- a/tools/autograd/templates/Functions.cpp
+++ b/tools/autograd/templates/Functions.cpp
@@ -1883,7 +1883,7 @@ std::tuple<Tensor, Tensor, Tensor> batchnorm_double_backward(
     bool training,
     double eps,
     const Tensor & save_mean,
-    const Tensor & save_std,
+    const Tensor & save_invstd,
     std::array<bool,3> output_mask) {
 
   bool affine = gamma.defined();
@@ -1907,9 +1907,12 @@ std::tuple<Tensor, Tensor, Tensor> batchnorm_double_backward(
   for (auto s : input.sizes().slice(2)) {
     M *= s;
   }
-  auto mu = unsqueeze_dim1(training ? save_mean : running_mean, input);
+  // for half inputs, save_mean, save_invstd are float (ideally, we would cast
+  // everything else, but not now)
+  auto mu = unsqueeze_dim1(training ? save_mean.to(input.type().scalarType()) : running_mean, input);
   auto input_sub_mu = input - mu;
-  auto sigma2_eps_neg_1_2 = unsqueeze_dim1(training ? save_std : running_var.add(Scalar(eps)).pow(-0.5), input);
+  auto sigma2_eps_neg_1_2 = unsqueeze_dim1(training ? save_invstd.to(input.type().scalarType())
+                                                    : running_var.add(Scalar(eps)).pow(-0.5), input);
   auto sigma2_eps_neg_1 = sigma2_eps_neg_1_2.pow(2);
   auto sigma2_eps_neg_3_2 = sigma2_eps_neg_1_2.pow(3);
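The `.to(input.type().scalarType())` casts above handle mixed precision: for half inputs the saved batch statistics are kept in float, so they are cast back to the input's dtype before being broadcast against `input - mu`. A rough Python sketch of that dtype handling (values are placeholders, and the float16 statistics here are just simulated):

```python
import torch

# float16 input, float32 saved statistics, as the CUDA batch norm kernels produce.
x = torch.randn(4, 5, 6, dtype=torch.half)
save_mean = torch.randn(5, dtype=torch.float)
save_invstd = torch.rand(5, dtype=torch.float)

# Cast the statistics to the input's dtype before broadcasting, mirroring
# save_mean.to(input.type().scalarType()) in the C++ above.
mu = save_mean.to(x.dtype).view(1, -1, 1)
input_sub_mu = x - mu
sigma2_eps_neg_1_2 = save_invstd.to(x.dtype).view(1, -1, 1)
```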