Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
Native batch norm (#13263)
Summary:
- Move batch norm from TH(CU)NN to native.
- Speedups in many cases (e.g. #12006) for CUDA, thanks to the new block/grid layout and Welford-type mean/variance calculations (the latter in training mode).
- It splits the forward kernel in two pieces and reuses the evaluation kernel for the transformation.
- We change the meaning of save_mean and save_invstd (aka save_var) to accscalar to maintain reasonable precision.

Compared to the ill-fated #12368:
- I changed the CPU kernel to not call `.sum()` from within parallel_for. This seemed to have caused the breakage (NaN results) in TestModels.test_dcgan_netG (thank you houseroad for the repro; errors in assessing the fix are my own).
- I updated the Half->Float upcasting of tensors to go through `t.type().scalarType()` instead of `t.dtype()`.
- I have merged master.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/13263
Differential Revision: D12946254
Pulled By: SsnL
fbshipit-source-id: 3bb717ee250fbccaf10afe73722996aa4713d10d
Committed by: Facebook Github Bot
Parent: 392ca1e59f
Commit: 14004cbef6
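For reference, the summary above mentions "Welford-type mean/variance calculations". A minimal single-thread sketch of Welford's online update (illustrative only; the CUDA kernel in this commit combines it with a parallel merge across the block):

```cpp
#include <cmath>
#include <cstddef>

// Minimal sketch of Welford's online algorithm: one pass over the data,
// updating a running mean and a running sum of squared deviations (m2).
struct WelfordSketch {
  double mean = 0.0;
  double m2 = 0.0;      // sum of squared deviations from the current mean
  std::size_t n = 0;

  void update(double x) {
    ++n;
    double delta = x - mean;
    mean += delta / n;           // running mean
    m2 += delta * (x - mean);    // running M2
  }

  double variance() const {      // biased (population) variance, as used for normalization
    return n > 0 ? m2 / n : 0.0;
  }
};
```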
@@ -207,13 +207,13 @@ public:
   // cast the data pointer to a __restrict__ pointer.
   // In order to use this, your CUDA kernel has to take a corresponding PackedTensorAccessor
   // as an argument.
-  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
-  PackedTensorAccessor<T,N,PtrTraits> packed_accessor() const& {
+  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
+  PackedTensorAccessor<T,N,PtrTraits,index_t> packed_accessor() const& {
     static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data<T>()");
     AT_CHECK(dim() == N, "expected ", N, " dims but tensor has ", dim());
-    return PackedTensorAccessor<T,N,PtrTraits>(static_cast<typename PtrTraits<T>::PtrType>(data<T>()),sizes().data(),strides().data());
+    return PackedTensorAccessor<T,N,PtrTraits,index_t>(static_cast<typename PtrTraits<T>::PtrType>(data<T>()),sizes().data(),strides().data());
   }
-  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
+  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
   PackedTensorAccessor<T,N> packed_accessor() && = delete;

   Tensor operator-() const;
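A hedged host-side sketch of what the new index_t parameter enables (the tensor `t` and the helper below are illustrative, not part of the diff; `t` is assumed to be a 3-d CUDA float tensor whose sizes and strides fit in 32-bit integers):

```cpp
#include <ATen/ATen.h>

void make_accessors(const at::Tensor& t) {
  // Default behaviour is unchanged: 64-bit indices.
  auto acc64 = t.packed_accessor<float, 3>();
  // New: request 32-bit indices (and __restrict__ data) for cheaper index math in kernels.
  auto acc32 = t.packed_accessor<float, 3, at::RestrictPtrTraits, int32_t>();
  (void)acc64; (void)acc32;
}
```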
@@ -26,15 +26,15 @@ struct RestrictPtrTraits {
 // to functions and types available there (e.g. IntList isn't).

 // The PtrTraits argument is only relevant to cuda to support `__restrict__` pointers.
-template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
+template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
 class TensorAccessorBase {
 public:
   typedef typename PtrTraits<T>::PtrType PtrType;

   C10_HOST_DEVICE TensorAccessorBase(
       PtrType data_,
-      const int64_t* sizes_,
-      const int64_t* strides_)
+      const index_t* sizes_,
+      const index_t* strides_)
       : data_(data_), sizes_(sizes_), strides_(strides_) {}
   C10_HOST IntList sizes() const {
     return IntList(sizes_,N);
@@ -42,60 +42,62 @@ public:
   C10_HOST IntList strides() const {
     return IntList(strides_,N);
   }
-  C10_HOST_DEVICE int64_t stride(int64_t i) const {
+  C10_HOST_DEVICE index_t stride(index_t i) const {
     return strides_[i];
   }
-  C10_HOST_DEVICE int64_t size(int64_t i) const {
+  C10_HOST_DEVICE index_t size(index_t i) const {
     return sizes_[i];
   }
-  C10_HOST_DEVICE T* data() {
+  C10_HOST_DEVICE PtrType data() {
     return data_;
   }
-  C10_HOST_DEVICE const T* data() const {
+  C10_HOST_DEVICE const PtrType data() const {
     return data_;
   }

 protected:
   PtrType data_;
-  const int64_t* sizes_;
-  const int64_t* strides_;
+  const index_t* sizes_;
+  const index_t* strides_;
 };

 // The `TensorAccessor` is typically instantiated for CPU `Tensor`s using
 // `Tensor.accessor<T, N>()`.
 // For CUDA `Tensor`s, `PackedTensorAccessor` is used on the host and only
 // indexing on the device uses `TensorAccessor`s.
-template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
-class TensorAccessor : public TensorAccessorBase<T,N,PtrTraits> {
+template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
+class TensorAccessor : public TensorAccessorBase<T,N,PtrTraits,index_t> {
 public:
   typedef typename PtrTraits<T>::PtrType PtrType;

   C10_HOST_DEVICE TensorAccessor(
       PtrType data_,
-      const int64_t* sizes_,
-      const int64_t* strides_)
-      : TensorAccessorBase<T, N>(data_, sizes_, strides_) {}
+      const index_t* sizes_,
+      const index_t* strides_)
+      : TensorAccessorBase<T, N, PtrTraits, index_t>(data_,sizes_,strides_) {}

-  C10_HOST_DEVICE TensorAccessor<T, N - 1> operator[](int64_t i) {
-    return TensorAccessor<T,N-1>(this->data_ + this->strides_[0]*i,this->sizes_+1,this->strides_+1);
+  C10_HOST_DEVICE TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) {
+    return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i,this->sizes_+1,this->strides_+1);
   }

-  C10_HOST_DEVICE const TensorAccessor<T, N - 1> operator[](int64_t i) const {
-    return TensorAccessor<T,N-1>(this->data_ + this->strides_[0]*i,this->sizes_+1,this->strides_+1);
+  C10_HOST_DEVICE const TensorAccessor<T, N-1, PtrTraits, index_t> operator[](index_t i) const {
+    return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i,this->sizes_+1,this->strides_+1);
   }
 };

-template<typename T, template <typename U> class PtrTraits>
-class TensorAccessor<T,1,PtrTraits> : public TensorAccessorBase<T,1,PtrTraits> {
+template<typename T, template <typename U> class PtrTraits, typename index_t>
+class TensorAccessor<T,1,PtrTraits,index_t> : public TensorAccessorBase<T,1,PtrTraits,index_t> {
 public:
   typedef typename PtrTraits<T>::PtrType PtrType;

   C10_HOST_DEVICE TensorAccessor(
       PtrType data_,
-      const int64_t* sizes_,
-      const int64_t* strides_)
-      : TensorAccessorBase<T, 1, PtrTraits>(data_, sizes_, strides_) {}
-  C10_HOST_DEVICE T& operator[](int64_t i) {
+      const index_t* sizes_,
+      const index_t* strides_)
+      : TensorAccessorBase<T, 1, PtrTraits, index_t>(data_,sizes_,strides_) {}
+  C10_HOST_DEVICE T & operator[](index_t i) {
+    return this->data_[this->strides_[0]*i];
+  }
+  C10_HOST_DEVICE const T & operator[](index_t i) const {
     return this->data_[this->strides_[0]*i];
   }
 };
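For context, a small CPU-side usage sketch of the accessor API touched above (the tensor and function names are illustrative only):

```cpp
#include <ATen/ATen.h>

void fill_rows() {
  at::Tensor m = at::zeros({4, 5});      // float CPU tensor
  auto a = m.accessor<float, 2>();       // TensorAccessor<float, 2> with int64_t indices
  for (int64_t i = 0; i < a.size(0); ++i) {
    for (int64_t j = 0; j < a.size(1); ++j) {
      a[i][j] = static_cast<float>(i);   // operator[] chains down to a float&
    }
  }
}
```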
@@ -109,69 +111,104 @@ public:
 // Use RestrictPtrTraits as PtrTraits if you want the tensor's data pointer to be marked as __restrict__.
 // Instantiation from data, sizes, strides is only needed on the host and std::copy isn't available
 // on the device, so those functions are host only.
-template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
+template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
 class PackedTensorAccessorBase {
 public:
   typedef typename PtrTraits<T>::PtrType PtrType;
   C10_HOST PackedTensorAccessorBase(
       PtrType data_,
-      const int64_t* sizes_,
-      const int64_t* strides_)
+      const index_t* sizes_,
+      const index_t* strides_)
       : data_(data_) {
     std::copy(sizes_, sizes_ + N, std::begin(this->sizes_));
     std::copy(strides_, strides_ + N, std::begin(this->strides_));
   }
-  C10_HOST_DEVICE int64_t stride(int64_t i) const {
+
+  // if index_t is not int64_t, we want to have an int64_t constructor
+  template <typename source_index_t, class = typename std::enable_if<std::is_same<source_index_t, int64_t>::value>::type>
+  C10_HOST PackedTensorAccessorBase(
+      PtrType data_,
+      const source_index_t* sizes_,
+      const source_index_t* strides_)
+      : data_(data_) {
+    for (int i = 0; i < N; i++) {
+      this->sizes_[i] = sizes_[i];
+      this->strides_[i] = strides_[i];
+    }
+  }
+
+  C10_HOST_DEVICE index_t stride(index_t i) const {
     return strides_[i];
   }
-  C10_HOST_DEVICE int64_t size(int64_t i) const {
+  C10_HOST_DEVICE index_t size(index_t i) const {
     return sizes_[i];
   }
+  C10_HOST_DEVICE PtrType data() {
+    return data_;
+  }
+  C10_HOST_DEVICE const PtrType data() const {
+    return data_;
+  }
 protected:
   PtrType data_;
-  int64_t sizes_[N];
-  int64_t strides_[N];
+  index_t sizes_[N];
+  index_t strides_[N];
 };

-template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
-class PackedTensorAccessor : public PackedTensorAccessorBase<T,N,PtrTraits> {
+template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
+class PackedTensorAccessor : public PackedTensorAccessorBase<T,N,PtrTraits,index_t> {
 public:
   typedef typename PtrTraits<T>::PtrType PtrType;

   C10_HOST PackedTensorAccessor(
       PtrType data_,
-      const int64_t* sizes_,
-      const int64_t* strides_)
-      : PackedTensorAccessorBase<T, N, PtrTraits>(data_, sizes_, strides_){};
+      const index_t* sizes_,
+      const index_t* strides_)
+      : PackedTensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}

-  C10_DEVICE TensorAccessor<T, N - 1> operator[](int64_t i) {
-    int64_t* new_sizes = this->sizes_+1;
-    int64_t* new_strides = this->strides_+1;
-    return TensorAccessor<T,N-1>(this->data_ + this->strides_[0]*i, new_sizes, new_strides);
+  // if index_t is not int64_t, we want to have an int64_t constructor
+  template <typename source_index_t, class = typename std::enable_if<std::is_same<source_index_t, int64_t>::value>::type>
+  C10_HOST PackedTensorAccessor(
+      PtrType data_,
+      const source_index_t* sizes_,
+      const source_index_t* strides_)
+      : PackedTensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}
+
+  C10_DEVICE TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) {
+    index_t* new_sizes = this->sizes_ + 1;
+    index_t* new_strides = this->strides_ + 1;
+    return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i, new_sizes, new_strides);
   }

-  C10_DEVICE const TensorAccessor<T, N - 1> operator[](int64_t i) const {
-    int64_t* new_sizes = this->sizes_+1;
-    int64_t* new_strides = this->strides_+1;
-    return TensorAccessor<T,N-1>(this->data_ + this->strides_[0]*i, new_sizes, new_strides);
+  C10_DEVICE const TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) const {
+    const index_t* new_sizes = this->sizes_ + 1;
+    const index_t* new_strides = this->strides_ + 1;
+    return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i, new_sizes, new_strides);
   }
 };

-template<typename T, template <typename U> class PtrTraits>
-class PackedTensorAccessor<T,1,PtrTraits> : public PackedTensorAccessorBase<T,1,PtrTraits> {
+template<typename T, template <typename U> class PtrTraits, typename index_t>
+class PackedTensorAccessor<T,1,PtrTraits,index_t> : public PackedTensorAccessorBase<T,1,PtrTraits,index_t> {
 public:
   typedef typename PtrTraits<T>::PtrType PtrType;
   C10_HOST PackedTensorAccessor(
       PtrType data_,
-      const int64_t* sizes_,
-      const int64_t* strides_)
-      : PackedTensorAccessorBase<T, 1, PtrTraits>(data_, sizes_, strides_){};
+      const index_t* sizes_,
+      const index_t* strides_)
+      : PackedTensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}

-  C10_DEVICE T& operator[](int64_t i) {
-    return this->data_[this->strides_[0]*i];
+  // if index_t is not int64_t, we want to have an int64_t constructor
+  template <typename source_index_t, class = typename std::enable_if<std::is_same<source_index_t, int64_t>::value>::type>
+  C10_HOST PackedTensorAccessor(
+      PtrType data_,
+      const source_index_t* sizes_,
+      const source_index_t* strides_)
+      : PackedTensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}
+
+  C10_DEVICE T & operator[](index_t i) {
+    return this->data_[this->strides_[0] * i];
   }
-  C10_DEVICE const T& operator[](int64_t i) const {
+  C10_DEVICE const T& operator[](index_t i) const {
    return this->data_[this->strides_[0]*i];
   }
 };
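A hedged sketch of how a kernel consumes a PackedTensorAccessor with a 32-bit index type (the kernel and names below are illustrative; they are not part of this commit):

```cpp
#include <ATen/ATen.h>

// Hypothetical kernel: each block handles one row of a 2-d tensor, scaling it in place.
template <typename scalar_t>
__global__ void scale_kernel(
    at::PackedTensorAccessor<scalar_t, 2, at::RestrictPtrTraits, int32_t> a,
    scalar_t factor) {
  int32_t row = blockIdx.x;
  if (row >= a.size(0)) return;
  for (int32_t col = threadIdx.x; col < a.size(1); col += blockDim.x) {
    a[row][col] *= factor;  // operator[] yields a 1-d TensorAccessor, then a reference
  }
}
```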
@@ -482,6 +482,8 @@ _(aten, mv) \
 _(aten, mvlgamma) \
 _(aten, narrow) \
 _(aten, narrow_copy) \
+_(aten, native_batch_norm) \
+_(aten, native_batch_norm_backward) \
 _(aten, native_clone) \
 _(aten, native_get_device) \
 _(aten, native_norm) \
@@ -634,9 +636,6 @@ _(aten, th_pow) \
 _(aten, th_resize_as) \
 _(aten, th_tensor) \
 _(aten, th_zero) \
-_(aten, thnn_batch_norm) \
-_(aten, thnn_batch_norm_backward) \
-_(aten, thnn_batch_norm_forward) \
 _(aten, thnn_conv2d) \
 _(aten, thnn_conv2d_backward) \
 _(aten, thnn_conv2d_forward) \
@@ -9,7 +9,7 @@ namespace cuda {
 namespace detail {

 bool maybeOverlappingIndices(const at::Tensor& t);
-bool canUse32BitIndexMath(const at::Tensor &t, int64_t max_elem=std::numeric_limits<int64_t>::max());
+bool canUse32BitIndexMath(const at::Tensor &t, int64_t max_elem=std::numeric_limits<int32_t>::max());

 template <typename scalar, typename IndexType>
 TensorInfo<scalar, IndexType>
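This is the predicate the new CUDA entry points use to pick an index type. A hedged sketch of the dispatch pattern (run_kernel is a hypothetical helper; batch_norm_cuda below does exactly this with its templates):

```cpp
// Choose the 32-bit instantiation when every offset fits into int32_t,
// otherwise fall back to 64-bit index math.
if (at::cuda::detail::canUse32BitIndexMath(self)) {
  run_kernel<scalar_t, int32_t>(self);
} else {
  run_kernel<scalar_t, int64_t>(self);
}
```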
@@ -1,6 +1,8 @@
 #include "ATen/ATen.h"
 #include "ATen/NativeFunctions.h"
+#include "ATen/AccumulateType.h"
+#include "ATen/CPUApplyUtils.h"
+#include "ATen/Parallel.h"
 #include "ATen/Config.h"

 #include "ATen/detail/CUDAHooksInterface.h"
@@ -25,6 +27,198 @@ namespace {
   }
 }

+// TensorAccessor when it is defined to work around undefined...
+template <typename scalar_t>
+static TensorAccessor<scalar_t, 1> conditional_accessor_1d(const Tensor& t) {
+  if (! t.defined()) {
+    return TensorAccessor<scalar_t, 1>(nullptr, nullptr, nullptr);
+  }
+  return t.accessor<scalar_t, 1>();
+}
+
+template<typename scalar_t>
+std::tuple<Tensor,Tensor,Tensor> batch_norm_cpu_template(const Tensor& input, const Tensor& weight, const Tensor& bias,
+    const Tensor& running_mean, const Tensor& running_var, bool train, double momentum, double eps) {
+
+  using accscalar_t = at::acc_type<scalar_t, false>;
+  Tensor output = at::empty_like(input);
+
+  int64_t n_input = input.size(1);
+  int64_t n = input.numel() / n_input;
+
+  Tensor save_mean;
+  Tensor save_invstd;
+  const int64_t zero = 0;
+  if (train) {
+    save_mean = at::empty({n_input}, input.options());
+    save_invstd = at::empty({n_input}, input.options());
+  }
+  auto save_mean_a = conditional_accessor_1d<scalar_t>(save_mean);
+  auto save_invstd_a = conditional_accessor_1d<scalar_t>(save_invstd);
+
+  auto running_mean_a = conditional_accessor_1d<scalar_t>(running_mean);
+  auto running_var_a = conditional_accessor_1d<scalar_t>(running_var);
+
+  parallel_for(0, n_input, 1, [&](int64_t b_begin, int64_t b_end) {
+    for (int64_t f = b_begin; f < b_end; ++f) {
+      Tensor in = input.select(1, f);
+      Tensor out = output.select(1, f);
+
+      scalar_t mean, invstd;
+
+      if (train) {
+        // compute mean per input
+        accscalar_t sum = 0;
+        CPU_tensor_apply1<scalar_t>(in, [&] (const scalar_t& i) {
+          sum += i;
+        });
+
+        mean = (scalar_t) (sum / n);
+        save_mean_a[f] = mean;
+
+        // compute variance per input
+        sum = 0;
+        CPU_tensor_apply1<scalar_t>(in, [&] (const scalar_t& i) {
+          sum += (i - mean) * (i - mean);
+        });
+
+        if (sum == 0 && eps == 0.0) {
+          invstd = 0;
+        } else {
+          invstd = (scalar_t) (1 / std::sqrt(sum/n + eps));
+        }
+        save_invstd_a[f] = invstd;
+
+        // update running averages
+        if (running_mean.defined()) {
+          running_mean_a[f] = momentum * mean + (1 - momentum) * running_mean_a[f];
+        }
+        if (running_var.defined()) {
+          accscalar_t unbiased_var = sum / (n - 1);
+          running_var_a[f] = momentum * unbiased_var + (1 - momentum) * running_var_a[f];
+        }
+      } else {
+        mean = running_mean_a[f];
+        invstd = 1 / std::sqrt(running_var_a[f] + eps);
+      }
+
+      // compute output
+      scalar_t w = weight.defined() ? weight.data<scalar_t>()[f * weight.stride(0)] : 1;
+      scalar_t b = bias.defined() ? bias.data<scalar_t>()[f * bias.stride(0)] : 0;
+
+      CPU_tensor_apply2<scalar_t,scalar_t>(out, in, [&](scalar_t& o, const scalar_t& i) {
+        o = ((i - mean) * invstd) * w + b;
+      });
+    }
+  });
+  return std::make_tuple(output, save_mean, save_invstd);
+}
+
+template<typename scalar_t>
+std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu_template(const Tensor& grad_out_, const Tensor& input, const Tensor& weight,
+    const Tensor& running_mean, const Tensor& running_var, const Tensor& save_mean, const Tensor& save_invstd,
+    bool train, double eps, std::array<bool,3> grad_input_mask) {
+
+  using accscalar_t = at::acc_type<scalar_t, false>;
+
+  Tensor grad_input;
+  Tensor grad_weight;
+  Tensor grad_bias;
+  if (grad_input_mask[0]) {
+    grad_input = at::empty_like(input);
+  }
+  if (grad_input_mask[1]) {
+    grad_weight = at::empty_like(weight);
+  }
+  if (grad_input_mask[2]) {
+    grad_bias = at::empty_like(weight);
+  }
+
+  auto weight_a = conditional_accessor_1d<scalar_t>(weight);
+  auto grad_weight_a = conditional_accessor_1d<scalar_t>(grad_weight);
+  auto grad_bias_a = conditional_accessor_1d<scalar_t>(grad_bias);
+
+  int64_t n_input = input.size(1);
+  int64_t n = input.numel() / n_input;
+
+  auto save_mean_a = conditional_accessor_1d<scalar_t>(save_mean);
+  auto save_invstd_a = conditional_accessor_1d<scalar_t>(save_invstd);
+
+  auto running_mean_a = conditional_accessor_1d<scalar_t>(running_mean);
+  auto running_var_a = conditional_accessor_1d<scalar_t>(running_var);
+
+  parallel_for(0, n_input, 1, [&](int64_t b_begin, int64_t b_end) {
+    for (int64_t f = b_begin; f < b_end; ++f) {
+      Tensor in = input.select(1, f);
+      Tensor grad_out = grad_out_.select(1, f);
+
+      scalar_t w = weight.defined() ? weight_a[f] : 1;
+
+      scalar_t mean, invstd;
+      if (train) {
+        mean = save_mean_a[f];
+        invstd = save_invstd_a[f];
+      } else {
+        mean = running_mean_a[f];
+        invstd = 1 / std::sqrt(running_var_a[f] + eps);
+      }
+
+      // sum over all gradOutput in feature plane
+      accscalar_t sum = 0;
+      CPU_tensor_apply1<scalar_t>(grad_out, [&](const scalar_t& g) {
+        sum += g;
+      });
+
+      // dot product of the Q(X) and gradOuput
+      accscalar_t dotp = 0;
+      CPU_tensor_apply2<scalar_t,scalar_t>(in, grad_out, [&](const scalar_t& i, const scalar_t& go) {
+        dotp += (i - mean) * go;
+      });
+
+      if (grad_input_mask[0]) {
+        Tensor grad_in = grad_input.select(1, f);
+        if (train) {
+          // when in training mode
+          // Q(X) = X - E[x] ; i.e. input centered to zero mean
+          // Y = Q(X) / σ    ; i.e. BN output before weight and bias
+          // dL/dX = (Q(dL/dY) - dot(Y, dL/dY) * Y) / σ * w
+
+          // projection of gradOutput on to output scaled by std
+          scalar_t k = (scalar_t) dotp * invstd * invstd / n;
+          CPU_tensor_apply2<scalar_t,scalar_t>(grad_in, in, [&](scalar_t& gi, const scalar_t& i) {
+            gi = (i - mean)* k;
+          });
+
+          accscalar_t grad_mean = sum / n;
+          CPU_tensor_apply2<scalar_t,scalar_t>(grad_in, grad_out, [&](scalar_t& gi, const scalar_t& go) {
+            gi = (go - grad_mean - gi) * invstd * w;
+          });
+        } else {
+          // when in evaluation mode
+          // Q(X) = X - running_mean  ; i.e. input centered to zero mean
+          // Y = Q(X) / running_std   ; i.e. BN output before weight and bias
+          // dL/dX = w / running_std
+          CPU_tensor_apply2<scalar_t,scalar_t>(grad_in, grad_out, [&](scalar_t& gi, const scalar_t& go) {
+            gi = go * invstd * w;
+          });
+        }
+      }
+      if (grad_input_mask[1]) {
+        grad_weight_a[f] = dotp * invstd;
+      }
+
+      if (grad_input_mask[2]) {
+        grad_bias_a[f] = sum;
+      }
+    }
+  });
+  return std::make_tuple(grad_input, grad_weight, grad_bias);
+}
+
 Tensor batch_norm(
     const Tensor& input, const Tensor& weight /* optional */, const Tensor& bias /* optional */,
     const Tensor& running_mean /* optional */, const Tensor& running_var /* optional */,
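The in-code comments above compress the backward derivation. Written out for one feature plane with n elements (training mode), the quantities the CPU template computes are:

```latex
\hat{x}_i = (x_i - \mu)\,\sigma^{-1}, \qquad y_i = w\,\hat{x}_i + b
\frac{\partial L}{\partial b} = \sum_i \frac{\partial L}{\partial y_i}, \qquad
\frac{\partial L}{\partial w} = \sum_i \frac{\partial L}{\partial y_i}\,\hat{x}_i
\frac{\partial L}{\partial x_i} =
  \left(\frac{\partial L}{\partial y_i}
        - \frac{1}{n}\sum_j \frac{\partial L}{\partial y_j}
        - \frac{\hat{x}_i}{n}\sum_j \frac{\partial L}{\partial y_j}\,\hat{x}_j\right) w\,\sigma^{-1}
```

Here sum and dotp in the code correspond to the two reductions over dL/dy and (x - mu)·dL/dy, and sigma^{-1} is invstd.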
@@ -86,9 +280,8 @@ Tensor batch_norm(
                  training, momentum, eps));
   }

-  return at::thnn_batch_norm(
-      input.contiguous(), weight, bias,
-      running_mean, running_var, training, momentum, eps);
+  return std::get<0>(at::native_batch_norm(input, weight, bias,
+                                           running_mean, running_var, training, momentum, eps));
 }

 Tensor instance_norm(
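A hedged caller-side sketch of the change above (variable names are illustrative): the composite batch_norm now forwards to the native op, which returns (output, save_mean, save_invstd), and keeps only the output tensor.

```cpp
// Illustrative only; mirrors the std::get<0>(...) usage in the hunk above.
at::Tensor out, save_mean, save_invstd;
std::tie(out, save_mean, save_invstd) = at::native_batch_norm(
    input, weight, bias, running_mean, running_var,
    /*training=*/true, /*momentum=*/0.1, /*eps=*/1e-5);
```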
@@ -226,4 +419,20 @@ Tensor group_norm(const Tensor& input, int64_t num_groups,
   }
 }

+std::tuple<Tensor, Tensor, Tensor> batch_norm_cpu(const Tensor& self, const Tensor& weight, const Tensor& bias,
+                                                  const Tensor& running_mean, const Tensor& running_var,
+                                                  bool train, double momentum, double eps) {
+  return AT_DISPATCH_FLOATING_TYPES(self.type(), "batch_norm", [&] {
+      return batch_norm_cpu_template<scalar_t>(self, weight, bias, running_mean, running_var, train, momentum, eps);
+    });
+}
+
+std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu(const Tensor& grad_out, const Tensor& self, const Tensor& weight,
+                                                           const Tensor& running_mean, const Tensor& running_var, const Tensor& save_mean, const Tensor& save_invstd,
+                                                           bool train, double eps, std::array<bool,3> grad_input_mask) {
+  return AT_DISPATCH_FLOATING_TYPES(self.type(), "batch_norm_backward", [&] {
+      return batch_norm_backward_cpu_template<scalar_t>(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, eps, grad_input_mask);
+    });
+}
+
 }} // at::native
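AT_DISPATCH_FLOATING_TYPES used above switches on the runtime scalar type and defines scalar_t inside the lambda. A rough illustration of what it does (simplified; this is not the actual macro expansion):

```cpp
switch (self.type().scalarType()) {
  case at::ScalarType::Double: { using scalar_t = double; /* lambda body */ break; }
  case at::ScalarType::Float:  { using scalar_t = float;  /* lambda body */ break; }
  default: AT_ERROR("batch_norm not implemented for this type");
}
```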
aten/src/ATen/native/cuda/Normalization.cu (new file, 536 lines)
@@ -0,0 +1,536 @@
+#include <THC/THCDeviceUtils.cuh>
+#include <THC/THCGeneral.h>
+#include "ATen/ATen.h"
+#include "ATen/AccumulateType.h"
+#include "ATen/cuda/CUDAContext.h"
+#include "ATen/cuda/CUDAApplyUtils.cuh"
+
+namespace at { namespace native {
+
+namespace {
+
+#if defined(__HIP_PLATFORM_HCC__)
+constexpr int WARP_SIZE = 64;
+
+// take these out when ROCm implements std:: math functions
+#include <math.h>
+template <typename scalar_t>
+static __forceinline__ __device__ scalar_t device_sqrt(scalar_t val);
+
+template <>
+__forceinline__ __device__ float device_sqrt(float val) {
+  return ::sqrtf(val);
+}
+
+template <>
+__forceinline__ __device__ double device_sqrt(double val) {
+  return ::sqrt(val);
+}
+
+#else
+constexpr int WARP_SIZE = 32;
+
+template<typename scalar_t>
+__forceinline__ __device__ double device_sqrt(scalar_t val) {
+  return std::sqrt(val);
+}
+#endif
+
+// The maximum number of threads in a block
+#if defined(__HIP_PLATFORM_HCC__)
+constexpr int MAX_BLOCK_SIZE = 256;
+#else
+constexpr int MAX_BLOCK_SIZE = 512;
+#endif
+
+// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
+static int getNumThreads(int nElem) {
+#if defined(__HIP_PLATFORM_HCC__)
+  int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
+#else
+  int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
+#endif
+  for (int i = 0; i != 5; ++i) {
+    if (nElem <= threadSizes[i]) {
+      return threadSizes[i];
+    }
+  }
+  return MAX_BLOCK_SIZE;
+}
+
+// Returns the index of the most significant 1 bit in `val`.
+__device__ __forceinline__ int getMSB(int val) {
+  return 31 - __clz(val);
+}
+
+template <typename scalar_t, typename accscalar_t>
+struct Float2 {
+  accscalar_t v1, v2;
+  __device__ Float2() {}
+  __device__ Float2(scalar_t v1, scalar_t v2) : v1(static_cast<accscalar_t>(v1)), v2(static_cast<accscalar_t>(v2)) {}
+  __device__ Float2(int v) : v1(static_cast<accscalar_t>(v)), v2(static_cast<accscalar_t>(v)) {}
+  __device__ Float2& operator+=(const Float2& a) {
+    v1 += a.v1;
+    v2 += a.v2;
+    return *this;
+  }
+};
+
+template <typename scalar_t, typename accscalar_t, typename PTA>
+struct SumOp {
+  __device__ SumOp(const PTA& t) : tensor(t) {}
+  __device__ __forceinline__ accscalar_t operator()(int batch, int plane, int n) {
+    return static_cast<accscalar_t>(tensor[batch][plane][n]);
+  }
+  const PTA& tensor;
+};
+
+template <typename scalar_t, typename accscalar_t, typename PTA>
+struct VarOp {
+  __device__ VarOp(accscalar_t m, const PTA& t) : mean(m), tensor(t) {}
+  __device__ __forceinline__ accscalar_t operator()(int batch, int plane, int n) {
+    accscalar_t val = tensor[batch][plane][n];
+    return (val - mean) * (val - mean);
+  }
+  const accscalar_t mean;
+  const PTA& tensor;
+};
+
+template <typename scalar_t, typename accscalar_t, typename PTA>
+struct GradOp {
+  __device__ GradOp(accscalar_t m, const PTA& i, const PTA& g)
+    : mean(m), input(i), grad_output(g) {}
+  __device__ __forceinline__ Float2<scalar_t, accscalar_t> operator()(int batch, int plane, int n) {
+    accscalar_t g = grad_output[batch][plane][n];
+    accscalar_t c = static_cast<accscalar_t>(input[batch][plane][n]) - mean;
+    return Float2<scalar_t, accscalar_t>(g, g * c);
+  }
+  const accscalar_t mean;
+  const PTA& input;
+  const PTA& grad_output;
+};
+
+// Sum across all threads within a warp
+template <typename T>
+static __device__ __forceinline__ T warpSum(T val) {
+  for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
+    val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE);
+  }
+  return val;
+}
+
+template <typename scalar_t, typename accscalar_t>
+static __device__ __forceinline__ Float2<scalar_t, accscalar_t> warpSum(Float2<scalar_t, accscalar_t> value) {
+  value.v1 = warpSum(value.v1);
+  value.v2 = warpSum(value.v2);
+  return value;
+}
+
+// Sum across (batch, x/y/z) applying Op() pointwise
+// this works by first having each thread sum it's part
+// of the data. Then there is a double-shuffeling reduction.
+// First each warp (of WARP_SIZE threads) uses warpSum to reduce its
+// data to the "warp leader", who writes its value into shared memory.
+// Then a single warp reads the remaining (at most WARP_SIZE) items
+// and reduces them using another warpSum.
+// The implicit assumption is that there are no more
+// than WARP_SIZE**2 threads.
+template<typename scalar_t, typename Op, typename PTA>
+__device__ scalar_t reduce(Op op, PTA tensor, int plane) {
+  // first the reductions each thread does separately
+  scalar_t sum = static_cast<scalar_t>(0);
+  for (int batch = threadIdx.y; batch < tensor.size(0); batch += blockDim.y) {
+    for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x) {
+      sum += op(batch, plane, x);
+    }
+  }
+
+  // first warpSum to get one value per thread to
+  // one value per warp
+  sum = warpSum(sum);
+
+  // this writes each warps item into shared memory
+  // there are at most WARP_SIZE items left because
+  // there are at most WARP_SIZE**2 threads at the beginning
+  __shared__ scalar_t shared[WARP_SIZE];
+  __syncthreads();
+  int tid = threadIdx.x + threadIdx.y * blockDim.x;
+  if (tid % WARP_SIZE == 0) {
+    shared[tid / WARP_SIZE] = sum;
+  }
+  if (tid >= blockDim.x * blockDim.y / WARP_SIZE && tid < WARP_SIZE) {
+    // zero out the other entries in shared
+    shared[tid] = (scalar_t)0;
+  }
+  __syncthreads();
+  // now have a second warpSum to reduce the intermediate values
+  // from shared memory to a single number. The very first
+  // thread writes it to shared memory.
+
+  if (tid / WARP_SIZE == 0) {
+    sum = warpSum(shared[tid]);
+    if (tid == 0) {
+      shared[0] = sum;
+    }
+  }
+  __syncthreads();
+
+  // Everyone picks it up, should be broadcast into the whole grad_input
+  return shared[0];
+}
+
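warpSum above performs a butterfly (XOR) shuffle reduction through the WARP_SHFL_XOR wrapper. A standalone sketch of the same idea with the raw CUDA intrinsic, assuming a full 32-lane warp and a full active mask (illustrative only, not code from this file):

```cpp
// After log2(32) = 5 rounds every lane of the warp holds the complete sum.
__device__ __forceinline__ float warp_sum_sketch(float val) {
  for (int offset = 1; offset < 32; offset *= 2) {
    val += __shfl_xor_sync(0xffffffff, val, offset);
  }
  return val;
}
```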
+template <typename scalar_t, typename accscalar_t, bool train, typename index_t>
+__global__ void batch_norm_transform_input_kernel(
+    const PackedTensorAccessor<scalar_t, 3, RestrictPtrTraits, index_t> input,
+    PackedTensorAccessor<scalar_t, 3, RestrictPtrTraits, index_t> output,
+    const PackedTensorAccessor<typename std::conditional<train, accscalar_t, scalar_t>::type, 1, RestrictPtrTraits, index_t> mean_,
+    const PackedTensorAccessor<typename std::conditional<train, accscalar_t, scalar_t>::type, 1, RestrictPtrTraits, index_t> var_or_invstd,
+    const PackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> weight,
+    const PackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> bias,
+    accscalar_t epsilon) {
+
+  index_t plane = blockIdx.x;
+
+  if (plane >= input.size(1)) {
+    return;
+  }
+
+  accscalar_t gamma = weight.size(0) > 0 ? static_cast<accscalar_t>(weight[plane]) : static_cast<accscalar_t>(1);
+  accscalar_t beta = bias.size(0) > 0 ? static_cast<accscalar_t>(bias[plane]) : static_cast<accscalar_t>(0);
+  accscalar_t mean = static_cast<accscalar_t>(mean_[plane]);
+  accscalar_t invstd;
+  if (train) {
+    invstd = var_or_invstd[plane];
+  } else {
+    invstd = static_cast<accscalar_t>(1) / device_sqrt(static_cast<accscalar_t>(var_or_invstd[plane]) + epsilon);
+  }
+
+  index_t bs = input.size(0);
+  index_t fs = input.size(2);
+
+  index_t bstep = blockDim.y * gridDim.y;
+  for (index_t batch = threadIdx.y + blockIdx.y * blockDim.y; batch < bs; batch += bstep) {
+    auto o = output[batch][plane];
+    auto i = input[batch][plane];
+    for (index_t feature = threadIdx.x; feature < fs; feature += blockDim.x) {
+      o[feature] = static_cast<scalar_t>(gamma * (i[feature] - mean) * invstd + beta);
+    }
+  }
+}
+
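The transform kernel above is the pointwise evaluation step shared by training and eval mode (this reuse is the "splits the forward kernel in two pieces" point from the summary). Per plane it applies:

```latex
y = \gamma\,(x - \mu)\,\hat{\sigma}^{-1} + \beta,
\qquad
\hat{\sigma}^{-1} =
\begin{cases}
\text{save\_invstd} & \text{training (from the statistics kernel below)}\\
1 / \sqrt{\text{running\_var} + \varepsilon} & \text{evaluation}
\end{cases}
```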
+template <typename scalar_t, typename accscalar_t, typename index_t>
+__global__ void batch_norm_collect_statistics_kernel(
+    const PackedTensorAccessor<scalar_t, 3, RestrictPtrTraits, index_t> input,
+    const accscalar_t epsilon,
+    const accscalar_t momentum,
+    PackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> running_mean,
+    PackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> running_var,
+    PackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t> save_mean,
+    PackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t> save_invstd) {
+
+  __shared__ int shared_n[2 * 2 * WARP_SIZE + WARP_SIZE];
+
+  int plane = blockIdx.x;
+  int N = input.size(0) * input.size(2);
+  int tid = threadIdx.x + threadIdx.y * blockDim.x;
+
+  // Compute the mean and variance across (batch, x/y/z)
+  // this uses the Welford (in the for loop)/parallel algorithm (to sum across the block)
+  // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
+  // and the parallel algorithm on the same page.
+  // We use two shuffles to reduce across the entire block.
+  // https://devblogs.nvidia.com/faster-parallel-reductions-kepler/ has a description.
+  accscalar_t* shared_avg_var = (accscalar_t*) &shared_n[WARP_SIZE];
+
+  // first the reductions each thread does separately
+  accscalar_t avg = 0;
+  accscalar_t var_n = 0;
+  int n = 0;
+  for (int batch = threadIdx.y; batch < input.size(0); batch += blockDim.y) {
+    for (int x = threadIdx.x; x < input.size(2); x += blockDim.x) {
+      accscalar_t v = input[batch][plane][x];
+      accscalar_t d1 = v - avg;
+      n++;
+      avg += d1 / n;
+      var_n += d1 * (v - avg);
+    }
+  }
+
+  // first warpSum to get one value per thread to
+  // one value per warp
+  for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
+    accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, WARP_SIZE);
+    int o_n = WARP_SHFL_XOR(n, 1 << i, WARP_SIZE);
+    if (n + o_n > 0) {
+      var_n += WARP_SHFL_XOR(var_n, 1 << i, WARP_SIZE) + ((avg - o_avg) * (avg - o_avg) * n * o_n) / (n + o_n);
+      avg = (n * avg + o_n * o_avg)/(n+o_n);
+      n += o_n;
+    }
+  }
+
+  // this writes each warps item into shared memory
+  // there are at most WARP_SIZE items left because
+  // there are at most WARP_SIZE**2 threads at the beginning
+  __syncthreads();
+  if (tid % WARP_SIZE == 0) {
+    shared_n[tid / WARP_SIZE] = n;
+    shared_avg_var[tid / WARP_SIZE * 2] = avg;
+    shared_avg_var[tid / WARP_SIZE * 2 + 1] = var_n;
+  }
+  __syncthreads();
+  // now have a second warpSum to reduce the intermediate values
+  // from shared memory to a single number. The very first
+  // thread writes it to shared memory.
+
+  if (tid < WARP_SIZE) {
+    n = (tid < blockDim.x * blockDim.y / WARP_SIZE ? shared_n[tid] : 0);
+    avg = (tid < blockDim.x * blockDim.y / WARP_SIZE ? shared_avg_var[2 * tid] : 0);
+    var_n = (tid < blockDim.x * blockDim.y / WARP_SIZE ? shared_avg_var[2 * tid + 1] : 0);
+  }
+  for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
+    accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, WARP_SIZE);
+    int o_n = WARP_SHFL_XOR(n, 1 << i, WARP_SIZE);
+    if (n + o_n > 0) {
+      var_n += WARP_SHFL_XOR(var_n, 1 << i, WARP_SIZE) + ((avg - o_avg) * (avg - o_avg) * n * o_n) / (n + o_n);
+      avg = (n * avg + o_n * o_avg)/(n+o_n);
+      n += o_n;
+    }
+  }
+
+  // Save the mean, variance, and moving averages
+  if (tid == 0) {
+    accscalar_t invstd = 0;
+    if (var_n != static_cast<accscalar_t>(0) || epsilon != static_cast<accscalar_t>(0)) {
+      invstd = static_cast<accscalar_t>(1) / device_sqrt(var_n / N + epsilon);
+    }
+    save_mean[plane] = avg;
+    save_invstd[plane] = invstd;
+    if (running_mean.data() != NULL) {
+      running_mean[plane] = static_cast<scalar_t>((1 - momentum) * running_mean[plane] + momentum * avg);
+    }
+    if (running_var.data() != NULL) {
+      accscalar_t unbiasedVar = var_n / (N - 1);
+      running_var[plane] = static_cast<scalar_t>((1 - momentum) * running_var[plane] + momentum * unbiasedVar);
+    }
+  }
+}
+
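The shuffle rounds in the statistics kernel merge two partial (count, mean, M2) triples using the parallel-variance combine rule referenced by the Wikipedia link in the comments; written out:

```latex
\delta = \bar{x}_B - \bar{x}_A, \qquad n = n_A + n_B, \qquad
\bar{x} = \frac{n_A\,\bar{x}_A + n_B\,\bar{x}_B}{n}, \qquad
M_2 = M_{2,A} + M_{2,B} + \delta^2\,\frac{n_A\,n_B}{n}
```

In the code, avg/o_avg are the two means, n/o_n the two counts, and var_n accumulates M2.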
+template <typename scalar_t, typename accscalar_t, typename index_t>
+__global__ void batch_norm_backward_kernel(
+    const PackedTensorAccessor<scalar_t, 3, DefaultPtrTraits, index_t> input,
+    const PackedTensorAccessor<scalar_t, 3, DefaultPtrTraits, index_t> grad_output,
+    PackedTensorAccessor<scalar_t, 3, DefaultPtrTraits, index_t> grad_input,
+    PackedTensorAccessor<scalar_t, 1, DefaultPtrTraits, index_t> grad_weight,
+    PackedTensorAccessor<scalar_t, 1, DefaultPtrTraits, index_t> grad_bias,
+    const PackedTensorAccessor<scalar_t, 1, DefaultPtrTraits, index_t> weight,
+    const PackedTensorAccessor<scalar_t, 1, DefaultPtrTraits, index_t> running_mean,
+    const PackedTensorAccessor<scalar_t, 1, DefaultPtrTraits, index_t> running_var,
+    const PackedTensorAccessor<accscalar_t, 1, DefaultPtrTraits, index_t> save_mean,
+    const PackedTensorAccessor<accscalar_t, 1, DefaultPtrTraits, index_t> save_invstd,
+    bool train,
+    accscalar_t epsilon) {
+
+  index_t plane = blockIdx.x;
+  index_t N = grad_output.size(0) * grad_output.size(2);
+
+  accscalar_t mean, invstd;
+  if (train) {
+    mean = save_mean[plane];
+    invstd = save_invstd[plane];
+  } else {
+    mean = static_cast<accscalar_t>(running_mean[plane]);
+    invstd = static_cast<accscalar_t>(1) / device_sqrt(static_cast<accscalar_t>(running_var[plane]) + epsilon);
+  }
+
+  accscalar_t weight_val = weight.size(0) > 0 ? static_cast<accscalar_t>(weight[plane]) : accscalar_t(1);
+  accscalar_t norm = accscalar_t(1) / N;
+
+  // Compute two values across (batch, x/y/z) in one pass:
+  // 1. Sum(grad_output)
+  // 2. DotProduct(input - mean, grad_output)
+  GradOp<scalar_t, accscalar_t, PackedTensorAccessor<scalar_t, 3, DefaultPtrTraits, index_t>> g(mean, input, grad_output);
+  Float2<scalar_t, accscalar_t> res = reduce<Float2<scalar_t, accscalar_t>, GradOp<scalar_t, accscalar_t,
+                                                                                   PackedTensorAccessor<scalar_t, 3, DefaultPtrTraits, index_t>>>(g, grad_output, plane);
+  accscalar_t grad_output_sum = res.v1;
+  accscalar_t dot_p = res.v2;
+
+  accscalar_t grad_mean = grad_output_sum * norm;
+  accscalar_t proj_scale = dot_p * norm * invstd * invstd;
+  accscalar_t grad_scale = invstd * weight_val;
+
+  if (grad_input.data() != NULL) {
+    for (int batch = threadIdx.y; batch < grad_output.size(0); batch += blockDim.y) {
+      for (int x = threadIdx.x; x < grad_output.size(2); x += blockDim.x) {
+        scalar_t go = grad_output[batch][plane][x];
+        if (train) {
+          scalar_t inp = input[batch][plane][x];
+          accscalar_t proj = (inp - mean) * proj_scale;
+          grad_input[batch][plane][x] = static_cast<scalar_t>((go - proj - grad_mean) * grad_scale);
+        } else {
+          grad_input[batch][plane][x] = static_cast<scalar_t>(go * grad_scale);
+        }
+      }
+    }
+  }
+
+  if (grad_weight.size(0) > 0) {
+    if (threadIdx.x == 0) {
+      grad_weight[plane] = static_cast<scalar_t>(dot_p * invstd);
+    }
+  }
+
+  if (grad_bias.size(0) > 0) {
+    if (threadIdx.x == 0) {
+      grad_bias[plane] = static_cast<scalar_t>(grad_output_sum);
+    }
+  }
+}
+
+template <typename scalar_t, int64_t dim, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
+static PackedTensorAccessor<scalar_t, dim, PtrTraits, index_t> packed_accessor_or_dummy(const Tensor& t) {
+  if (! t.defined()) {
+    const std::vector<index_t> zeros(dim);
+    return PackedTensorAccessor<scalar_t, dim, PtrTraits, index_t>(nullptr, zeros.data(), zeros.data());
+  }
+  return t.packed_accessor<scalar_t, dim, PtrTraits, index_t>();
+}
+
+template<typename scalar_t, typename index_t>
+std::tuple<Tensor, Tensor, Tensor> batch_norm_cuda_template(const Tensor& input_, const Tensor& weight_, const Tensor& bias_,
+                                                            const Tensor& running_mean_, const Tensor& running_var_,
+                                                            bool train, double momentum, double epsilon) {
+
+  using accscalar_t = at::acc_type<scalar_t, true>;
+  int64_t n_input = input_.size(1);
+  Tensor save_mean_;
+  Tensor save_invstd_;
+  auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
+  auto output_reshaped = at::empty_like(input_reshaped);
+
+  auto bs = input_reshaped.size(0);
+  auto features = input_reshaped.size(2);
+  auto input = input_reshaped.packed_accessor<scalar_t, 3, RestrictPtrTraits, index_t>();
+  auto input_options = input_.options();
+  if (input_.type().scalarType() == at::ScalarType::Half) {
+    input_options = input_options.dtype(ScalarType::Float);
+  }
+  if (train) {
+    save_mean_ = at::empty({n_input}, input_options);
+    save_invstd_ = at::empty({n_input}, input_options);
+  } else {
+    save_mean_ = at::empty({0}, input_options);
+    save_invstd_ = at::empty({0}, input_options);
+  }
+  auto output = output_reshaped.packed_accessor<scalar_t, 3, RestrictPtrTraits, index_t>();
+  auto weight = packed_accessor_or_dummy<scalar_t, 1, RestrictPtrTraits, index_t>(weight_);
+  auto bias = packed_accessor_or_dummy<scalar_t, 1, RestrictPtrTraits, index_t>(bias_);
+  auto running_mean = packed_accessor_or_dummy<scalar_t, 1, RestrictPtrTraits, index_t>(running_mean_);
+  auto running_var = packed_accessor_or_dummy<scalar_t, 1, RestrictPtrTraits, index_t>(running_var_);
+  auto save_mean = save_mean_.packed_accessor<accscalar_t, 1, RestrictPtrTraits, index_t>();
+  auto save_invstd = save_invstd_.packed_accessor<accscalar_t, 1, RestrictPtrTraits, index_t>();
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  // The input_transform kernel is pointwise, but we need to balance reading parameters (save_var/mean,
+  // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks
+  // and good occupancy. Quiet likely, we could go with even more blocks than 1024.
+  // The various planes are independent, so we use blocks for them.
+  int tf = std::max<int>(getNumThreads(input.size(2)/4),
+                         std::min<int>(getNumThreads(input.size(2)), 64));
+  int tb = std::max<int>(64/tf, 1);
+  dim3 blocks_trans(input.size(1), std::max<int>(1, std::min<int>((256*1024)/input.size(1),
+                                                                  (input.size(0)+tb-1)/tb)));
+  dim3 threads_trans(tf, tb);
+  if (!train) {
+    batch_norm_transform_input_kernel<scalar_t, accscalar_t, false, index_t> <<<blocks_trans, threads_trans, 0, stream>>>
+      (input, output, running_mean, running_var, weight, bias, epsilon);
+  } else {
+    // for the reduction, we cannot use blocks for the batch dim, but if we have few threads in
+    // the feature dimension, we'll use some threads for blocks
+    dim3 blocks(input.size(1));
+    tf = getNumThreads(input.size(2));
+    dim3 threads(tf, std::max<int>(1, MAX_BLOCK_SIZE/tf));
+    batch_norm_collect_statistics_kernel<scalar_t, accscalar_t, index_t> <<<blocks, threads, 0, stream>>>
+      (input, epsilon, momentum, running_mean, running_var, save_mean, save_invstd);
+    batch_norm_transform_input_kernel<scalar_t, accscalar_t, true, index_t> <<<blocks_trans, threads_trans, 0, stream>>>
+      (input, output, save_mean, save_invstd, weight, bias, epsilon);
+  }
+  THCudaCheck(cudaGetLastError());
+  return std::make_tuple(output_reshaped.view(input_.sizes()), save_mean_, save_invstd_);
+}
+
+template<typename scalar_t, typename index_t>
+std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda_template(const Tensor& grad_out_, const Tensor& input_, const Tensor& weight_,
+                                                                     const Tensor& running_mean_, const Tensor& running_var_, const Tensor& save_mean_, const Tensor& save_invstd_,
+                                                                     bool train, double epsilon, std::array<bool,3> grad_input_mask) {
+
+  using accscalar_t = at::acc_type<scalar_t, true>;
+  Tensor grad_input_;
+  Tensor grad_input_reshaped;
+  Tensor grad_weight_;
+  Tensor grad_bias_;
+  auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1});
+  auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes());
+
+  if (grad_input_mask[0]) {
+    grad_input_ = at::empty_like(input_);
+    grad_input_reshaped = grad_input_.view(input_reshaped.sizes());
+  }
+  if (grad_input_mask[1]) {
+    grad_weight_ = at::empty_like(weight_);
+  }
+  if (grad_input_mask[2]) {
+    grad_bias_ = at::empty_like(weight_);
+  }
+
+  auto input = input_reshaped.packed_accessor<scalar_t, 3, DefaultPtrTraits, index_t>();
+  auto grad_output = grad_output_reshaped.packed_accessor<scalar_t, 3, DefaultPtrTraits, index_t>();
+  auto grad_input = packed_accessor_or_dummy<scalar_t, 3, DefaultPtrTraits, index_t>(grad_input_reshaped);
+  auto weight = packed_accessor_or_dummy<scalar_t, 1, DefaultPtrTraits, index_t>(weight_);
+  auto grad_weight = packed_accessor_or_dummy<scalar_t, 1, DefaultPtrTraits, index_t>(grad_weight_);
+  auto grad_bias = packed_accessor_or_dummy<scalar_t, 1, DefaultPtrTraits, index_t>(grad_bias_);
+  auto running_mean = packed_accessor_or_dummy<scalar_t, 1, DefaultPtrTraits, index_t>(running_mean_);
+  auto running_var = packed_accessor_or_dummy<scalar_t, 1, DefaultPtrTraits, index_t>(running_var_);
+  auto save_mean = packed_accessor_or_dummy<accscalar_t, 1, DefaultPtrTraits, index_t>(save_mean_);
+  auto save_invstd = packed_accessor_or_dummy<accscalar_t, 1, DefaultPtrTraits, index_t>(save_invstd_);
+
+  auto stream = at::cuda::getCurrentCUDAStream();
+  dim3 blocks(input.size(1));
+  int tf = getNumThreads(input.size(2));
+  dim3 threads(tf, std::max<int>(1, MAX_BLOCK_SIZE/tf));
+
+  batch_norm_backward_kernel<scalar_t, accscalar_t, index_t> <<<blocks, threads, 0, stream>>>
+    (input, grad_output, grad_input, grad_weight, grad_bias, weight, running_mean, running_var,
+     save_mean, save_invstd, train, epsilon);
+  THCudaCheck(cudaGetLastError());
+
+  return std::make_tuple(grad_input_, grad_weight_, grad_bias_);
+}
+
+} // anonymous namespace
+
+std::tuple<Tensor, Tensor, Tensor> batch_norm_cuda(const Tensor& self, const Tensor& weight, const Tensor& bias,
+                                                   const Tensor& running_mean, const Tensor& running_var, bool train, double momentum, double epsilon) {
+  return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm", [&] {
+      if (cuda::detail::canUse32BitIndexMath(self)) {
+        return batch_norm_cuda_template<scalar_t, int32_t>(self, weight, bias, running_mean, running_var, train, momentum, epsilon);
+      } else {
+        return batch_norm_cuda_template<scalar_t, int64_t>(self, weight, bias, running_mean, running_var, train, momentum, epsilon);
+      }
+    });
+}
+
+std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cuda(const Tensor& grad_out, const Tensor& self, const Tensor& weight, const Tensor& running_mean, const Tensor& running_var,
+                                                            const Tensor& save_mean, const Tensor& save_invstd, bool train, double epsilon, std::array<bool,3> grad_input_mask) {
+  return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm_backward", [&] {
+      if (cuda::detail::canUse32BitIndexMath(self)) {
+        return batch_norm_backward_cuda_template<scalar_t, int32_t>(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, epsilon, grad_input_mask);
+      } else {
+        return batch_norm_backward_cuda_template<scalar_t, int64_t>(grad_out, self, weight, running_mean, running_var, save_mean, save_invstd, train, epsilon, grad_input_mask);
+      }
+    });
+}
+
+} } // namespace at::native
@@ -1224,6 +1224,16 @@
   variants: function, method
   device_guard: false

+- func: native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, double momentum, double eps) -> (Tensor, Tensor, Tensor)
+  dispatch:
+    CPU: batch_norm_cpu
+    CUDA: batch_norm_cuda
+
+- func: native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, double eps, std::array<bool,3> output_mask) -> (Tensor, Tensor, Tensor)
+  dispatch:
+    CPU: batch_norm_backward_cpu
+    CUDA: batch_norm_backward_cuda
+
 - func: ones(IntList size, TensorOptions options={}) -> Tensor

 - func: ones_out(Tensor result, IntList size) -> Tensor
@@ -217,18 +217,6 @@
   output: self_->dim() == 0
   grad_input: output_->dim() == 0

-# Batch normalization
-
-# The buffers here are somewhat hazardous, because their type will be
-# based off of self, even though you may plausibly wish running_mean
-# and running_var to have different precision than self (e.g.,
-# BatchNorm on half). Fortunately, THNN doesn't actually ever do this,
-# so the buffer allocation code is "correct". If you ever do fix this,
-# you should just port the function entirely to a native ATen function.
-- name: thnn_batch_norm(Tensor self, Tensor weight, Tensor bias, Tensor running_mean, Tensor running_var, bool training, double momentum, double eps)
-  cname: BatchNormalization
-  buffers: [save_mean, save_std]
-
 # Convolutions

 - name: thnn_conv_transpose2d(Tensor self, Tensor weight, IntList[2] kernel_size, Tensor bias={}, IntList[2] stride=1, IntList[2] padding=0, IntList[2] output_padding=0, IntList[2] dilation=1)
@@ -207,13 +207,13 @@ public:
  // cast the data pointer to a __restrict__ pointer.
  // In order to use this, your CUDA kernel has to take a corresponding PackedTensorAccessor
  // as an argument.
  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
  PackedTensorAccessor<T,N,PtrTraits> packed_accessor() const& {
  PackedTensorAccessor<T,N,PtrTraits,index_t> packed_accessor() const& {
    static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data<T>()");
    AT_CHECK(dim() == N, "expected ", N, " dims but tensor has ", dim());
    return PackedTensorAccessor<T,N,PtrTraits>(static_cast<typename PtrTraits<T>::PtrType>(data<T>()),sizes().data(),strides().data());
    return PackedTensorAccessor<T,N,PtrTraits,index_t>(static_cast<typename PtrTraits<T>::PtrType>(data<T>()),sizes().data(),strides().data());
  }
  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
  PackedTensorAccessor<T,N> packed_accessor() && = delete;

  Tensor operator-() const;
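As the comment above says, a CUDA kernel that wants the __restrict__-qualified, by-value view of a tensor has to accept the PackedTensorAccessor as a parameter; with the new index_t template argument the same kernel can be instantiated for 32-bit or 64-bit indexing. A hedged sketch of such a kernel and its launcher (scale_rows and the other names are illustrative, not taken from this diff):

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAContext.h>

template <typename scalar_t, typename index_t>
__global__ void scale_rows_kernel(
    at::PackedTensorAccessor<scalar_t, 2, at::RestrictPtrTraits, index_t> a,
    index_t rows, index_t cols, scalar_t factor) {
  index_t row = blockIdx.x;
  if (row >= rows) return;
  for (index_t col = threadIdx.x; col < cols; col += blockDim.x) {
    a[row][col] *= factor;   // indexing arithmetic happens in index_t
  }
}

void scale_rows(at::Tensor t, double factor) {
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(t.type(), "scale_rows", [&] {
    // Assumes the tensor is small enough for 32-bit indexing.
    auto a = t.packed_accessor<scalar_t, 2, at::RestrictPtrTraits, int32_t>();
    scale_rows_kernel<scalar_t, int32_t><<<t.size(0), 256, 0, stream>>>(
        a, static_cast<int32_t>(t.size(0)), static_cast<int32_t>(t.size(1)),
        static_cast<scalar_t>(factor));
  });
}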
@@ -1,303 +0,0 @@
#include "THCUNN.h"
#include "common.h"
#include "TH/THHalf.h"
#include "THCHalfAutoNumerics.cuh"
#include "THCTensor.hpp"

#include "THCDeviceTensor.cuh"
#include "THCDeviceTensorUtils.cuh"
#include "THCDeviceUtils.cuh"
#if defined(__HIP_PLATFORM_HCC__)
const int WARP_SIZE = 64;
#else
const int WARP_SIZE = 32;
#endif

// The maximum number of threads in a block
#if defined(__HIP_PLATFORM_HCC__)
const int MAX_BLOCK_SIZE = 256;
#else
const int MAX_BLOCK_SIZE = 512;
#endif

// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
static int getNumThreads(int nElem) {
#if defined(__HIP_PLATFORM_HCC__)
  int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
#else
  int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
#endif
  for (int i = 0; i != 5; ++i) {
    if (nElem <= threadSizes[i]) {
      return threadSizes[i];
    }
  }
  return MAX_BLOCK_SIZE;
}

// Returns the index of the most significant 1 bit in `val`.
__device__ __forceinline__ int getMSB(int val) {
  return 31 - __clz(val);
}

template <typename Dtype, typename Acctype>
struct Float2 {
  Acctype v1, v2;
  __device__ Float2() {}
  __device__ Float2(Dtype v1, Dtype v2) : v1(ScalarConvert<Dtype, Acctype>::to(v1)), v2(ScalarConvert<Dtype, Acctype>::to(v2)) {}
  __device__ Float2(Dtype v) : v1(ScalarConvert<Dtype, Acctype>::to(v)), v2(ScalarConvert<Dtype, Acctype>::to(v)) {}
  __device__ Float2(int v) : v1(ScalarConvert<int, Acctype>::to(v)), v2(ScalarConvert<int, Acctype>::to(v)) {}
  __device__ Float2& operator+=(const Float2& a) {
    v1 += a.v1;
    v2 += a.v2;
    return *this;
  }
};

template <typename Dtype, typename Acctype, typename DeviceTensor3>
struct SumOp {
  __device__ SumOp(const DeviceTensor3 t) : tensor(t) {}
  __device__ __forceinline__ Acctype operator()(int batch, int plane, int n) {
    return ScalarConvert<Dtype, Acctype>::to(tensor[batch][plane][n]);
  }
  const DeviceTensor3 tensor;
};

template <typename Dtype, typename Acctype, typename DeviceTensor3>
struct VarOp {
  __device__ VarOp(Acctype m, const DeviceTensor3 t) : mean(m), tensor(t) {}
  __device__ __forceinline__ Acctype operator()(int batch, int plane, int n) {
    Dtype val = tensor[batch][plane][n];
    return (val - mean) * (val - mean);
  }
  const Acctype mean;
  const DeviceTensor3 tensor;
};

template <typename Dtype, typename Acctype, typename DeviceTensor3>
struct GradOp {
  __device__ GradOp(Acctype m, const DeviceTensor3 i, const DeviceTensor3 g)
    : mean(m), input(i), gradOutput(g) {}
  __device__ __forceinline__ Float2<Dtype, Acctype> operator()(int batch, int plane, int n) {
    Dtype g = gradOutput[batch][plane][n];
    Dtype c = ScalarConvert<Acctype, Dtype>::to(input[batch][plane][n] - mean);
    return Float2<Dtype, Acctype>(g, g * c);
  }
  const Acctype mean;
  const DeviceTensor3 input;
  const DeviceTensor3 gradOutput;
};

// Sum across all threads within a warp
template <typename T>
static __device__ __forceinline__ T warpSum(T val) {
#if __CUDA_ARCH__ >= 300
  for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
    val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE);
  }
#else
  __shared__ T values[MAX_BLOCK_SIZE];
  values[threadIdx.x] = val;
  __threadfence_block();
  const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE;
  for (int i = 1; i < WARP_SIZE; i++) {
    val += values[base + ((i + threadIdx.x) % WARP_SIZE)];
  }
#endif
  return val;
}

template <typename Dtype, typename Acctype>
static __device__ __forceinline__ Float2<Dtype, Acctype> warpSum(Float2<Dtype, Acctype> value) {
  value.v1 = warpSum(value.v1);
  value.v2 = warpSum(value.v2);
  return value;
}

// Sum across (batch, x/y/z) applying Op() pointwise
template<typename T, typename Op, typename DeviceTensor3>
__device__ T reduce(Op op, DeviceTensor3 tensor, int plane) {
  T sum = (T)0;
  for (int batch = 0; batch < tensor.getSize(0); ++batch) {
    for (int x = threadIdx.x; x < tensor.getSize(2); x += blockDim.x) {
      sum += op(batch, plane, x);
    }
  }

  // sum over NumThreads within a warp
  sum = warpSum(sum);

  // 'transpose', and reduce within warp again
  __shared__ T shared[WARP_SIZE];
  __syncthreads();
  if (threadIdx.x % WARP_SIZE == 0) {
    shared[threadIdx.x / WARP_SIZE] = sum;
  }
  if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
    // zero out the other entries in shared
    shared[threadIdx.x] = (T)0;
  }
  __syncthreads();
  if (threadIdx.x / WARP_SIZE == 0) {
    sum = warpSum(shared[threadIdx.x]);
    if (threadIdx.x == 0) {
      shared[0] = sum;
    }
  }
  __syncthreads();

  // Everyone picks it up, should be broadcast into the whole gradInput
  return shared[0];
}

template <typename Dtype, typename Acctype, typename DeviceTensor1, typename DeviceTensor3>
__global__ void BatchNormalizationUpdateOutputInference_kernel(
    const DeviceTensor3 input,
    DeviceTensor3 output,
    const DeviceTensor1 runningMean,
    const DeviceTensor1 runningVar,
    const DeviceTensor1 weight,
    const DeviceTensor1 bias,
    Acctype epsilon) {

  int plane = blockIdx.x;

  Acctype invstd = Acctype(1) / sqrt(runningVar[plane].ldg() + epsilon);
  Acctype mean = ScalarConvert<Dtype, Acctype>::to(runningMean[plane].ldg());
  Acctype gamma = weight.numElements() > 0 ? ScalarConvert<Dtype, Acctype>::to(weight[plane].ldg()) : Acctype(1);
  Acctype beta = bias.numElements() > 0 ? ScalarConvert<Dtype, Acctype>::to(bias[plane].ldg()) : Acctype(0);

  // Write normalized and update the output
  for (int batch = 0; batch < input.getSize(0); batch++) {
    for (int x = threadIdx.x; x < input.getSize(2); x += blockDim.x) {
      Dtype inp = input[batch][plane][x].ldg();
      output[batch][plane][x] = ScalarConvert<Acctype, Dtype>::to(gamma * (inp - mean) * invstd + beta);
    }
  }
}

template <typename Dtype, typename Acctype, typename DeviceTensor1, typename DeviceTensor3>
__global__ void BatchNormalizationUpdateOutput_kernel(
    const DeviceTensor3 input,
    DeviceTensor3 output,
    const DeviceTensor1 weight,
    const DeviceTensor1 bias,
    const Acctype epsilon,
    const Acctype momentum,
    DeviceTensor1 runningMean,
    DeviceTensor1 runningVar,
    DeviceTensor1 saveMean,
    DeviceTensor1 saveStd) {

  int plane = blockIdx.x;
  int N = input.getSize(0) * input.getSize(2);

  Acctype norm = Acctype(1) / N;

  // Compute the mean and variance across (batch, x/y/z)
  Acctype mean = reduce<Acctype>(SumOp<Dtype, Acctype, DeviceTensor3>(input), input, plane) * norm;
  __syncthreads();
  Acctype varN = reduce<Acctype>(VarOp<Dtype, Acctype, DeviceTensor3>(mean, input), input, plane);
  Acctype invStd = 0;
  if (varN != Acctype(0) || epsilon != Acctype(0)) {
    invStd = 1 / sqrt(varN * norm + epsilon);
  }

  // Save the mean, variance, and moving averages
  if (threadIdx.x == 0) {
    // Momentum based writeback
    Acctype unbiasedVar = varN / (N - 1);
    saveMean[plane] = ScalarConvert<Acctype, Dtype>::to(mean);
    saveStd[plane] = ScalarConvert<Acctype, Dtype>::to(invStd);
    if (runningMean.data() != NULL) {
      runningMean[plane] = ScalarConvert<Acctype, Dtype>::to((1 - momentum) * runningMean[plane] + momentum * mean);
    }
    if (runningVar.data() != NULL) {
      runningVar[plane] = ScalarConvert<Acctype, Dtype>::to((1 - momentum) * runningVar[plane] + momentum * unbiasedVar);
    }
  }

  // Write normalized and update the output
  Acctype gamma = weight.numElements() > 0 ? ScalarConvert<Dtype, Acctype>::to(weight[plane]) : ScalarConvert<int, Acctype>::to(1);
  Acctype beta = bias.numElements() > 0 ? ScalarConvert<Dtype, Acctype>::to(bias[plane]) : ScalarConvert<int, Acctype>::to(0);
  for (int batch = 0; batch < input.getSize(0); ++batch) {
    for (int x = threadIdx.x; x < input.getSize(2); x += blockDim.x) {
      Dtype inp = input[batch][plane][x].ldg();
      output[batch][plane][x] = ScalarConvert<Acctype, Dtype>::to(gamma * (inp - mean) * invStd + beta);
    }
  }
}

template <typename Dtype, typename Acctype, typename DeviceTensor1, typename DeviceTensor3>
__global__ void BatchNormalizationBackward_kernel(
    const DeviceTensor3 input,
    const DeviceTensor3 gradOutput,
    DeviceTensor3 gradInput,
    DeviceTensor1 gradWeight,
    DeviceTensor1 gradBias,
    const DeviceTensor1 weight,
    const DeviceTensor1 runningMean,
    const DeviceTensor1 runningVar,
    const DeviceTensor1 saveMean,
    const DeviceTensor1 saveStd,
    bool train,
    Acctype scale,
    double eps) {

  int plane = blockIdx.x;
  int N = gradOutput.getSize(0) * gradOutput.getSize(2);

  Acctype mean, stdVal;
  if (train) {
    mean = ScalarConvert<Dtype, Acctype>::to(saveMean[plane]);
    stdVal = ScalarConvert<Dtype, Acctype>::to(saveStd[plane]);
  } else {
    mean = ScalarConvert<Dtype, Acctype>::to(runningMean[plane]);
    stdVal = 1 / sqrt(runningVar[plane] + eps);
  }

  Acctype weightVal = weight.numElements() > 0 ? ScalarConvert<Dtype, Acctype>::to(weight[plane]) : Acctype(1);
  Acctype norm = Acctype(1) / N;

  // Compute two values across (batch, x/y/z) in one pass:
  // 1. Sum(gradOutput)
  // 2. DotProduct(input - mean, gradOutput)
  GradOp<Dtype, Acctype, DeviceTensor3> g(mean, input, gradOutput);
  Float2<Dtype, Acctype> res = reduce<Float2<Dtype, Acctype>, GradOp<Dtype, Acctype, DeviceTensor3>, DeviceTensor3>(g, gradOutput, plane);
  Acctype gradOutputSum = res.v1;
  Acctype dotP = res.v2;

  Acctype gradMean = gradOutputSum * norm;
  Acctype projScale = dotP * norm * stdVal * stdVal;
  Acctype gradScale = stdVal * weightVal;

  if (gradInput.numElements() > 0) {
    for (int batch = 0; batch < gradOutput.getSize(0); ++batch) {
      for (int x = threadIdx.x; x < gradOutput.getSize(2); x += blockDim.x) {
        Dtype gradOut = gradOutput[batch][plane][x];
        if (train) {
          Dtype inp = input[batch][plane][x];
          Acctype proj = (inp - mean) * projScale;
          gradInput[batch][plane][x] = ScalarConvert<Acctype, Dtype>::to((gradOut - proj - gradMean) * gradScale);
        } else {
          gradInput[batch][plane][x] = ScalarConvert<Acctype, Dtype>::to(gradOut * gradScale);
        }
      }
    }
  }

  if (gradWeight.numElements() > 0) {
    if (threadIdx.x == 0) {
      gradWeight[plane] += ScalarConvert<Acctype, Dtype>::to(scale * dotP * stdVal);
    }
  }

  if (gradBias.numElements() > 0) {
    if (threadIdx.x == 0) {
      gradBias[plane] += ScalarConvert<Acctype, Dtype>::to(scale * gradOutputSum);
    }
  }
}

#include "generic/BatchNormalization.cu"
#include "THCGenerateFloatTypes.h"
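The deleted kernel above needs two full reduction passes per plane: SumOp for the mean, then VarOp for the variance about that mean. A single-pass alternative is Welford's online update, which folds each element into a running mean and an M2 accumulator and stays numerically stable; a minimal host-side sketch of the update rule (illustrative only, not the replacement kernel code):

#include <cstddef>
#include <cmath>

// Welford's online algorithm: one pass over the data, numerically stable.
struct Welford {
  double mean = 0.0;
  double m2 = 0.0;    // running sum of squared deviations from the current mean
  std::size_t n = 0;

  void update(double x) {
    ++n;
    double delta = x - mean;
    mean += delta / n;
    m2 += delta * (x - mean);
  }
  double var_biased() const { return n ? m2 / n : 0.0; }          // divide by N
  double invstd(double eps) const { return 1.0 / std::sqrt(var_biased() + eps); }
};

Per-thread partial (mean, m2, n) triples can be merged with the standard Chan et al. combination step, which is what makes this update usable inside a parallel reduction rather than only in a serial loop.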
@@ -1,7 +1,6 @@
SET(ATen_CUDA_SRCS ${ATen_CUDA_SRCS}
${CMAKE_CURRENT_SOURCE_DIR}/AbsCriterion.cu
${CMAKE_CURRENT_SOURCE_DIR}/Abs.cu
${CMAKE_CURRENT_SOURCE_DIR}/BatchNormalization.cu
${CMAKE_CURRENT_SOURCE_DIR}/BCECriterion.cu
${CMAKE_CURRENT_SOURCE_DIR}/ClassNLLCriterion.cu
${CMAKE_CURRENT_SOURCE_DIR}/Col2Im.cu
@@ -1,108 +0,0 @@
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/BatchNormalization.cu"
#else

#define DeviceTensor3 THCDeviceTensor<scalar_t, 3>
#define DeviceTensor1 THCDeviceTensor<scalar_t, 1>

template <int Dim>
static THCDeviceTensor<scalar_t, Dim> THNN_(devicetensor)(THCState *state, THCTensor *t) {
  if (!t) {
    return THCDeviceTensor<scalar_t, Dim>();
  }

  int inDim = THCTensor_nDimensionLegacyAll(state, t);
  if (inDim == Dim) {
    return toDeviceTensor<scalar_t, Dim>(state, t);
  }

  // View in which the last dimensions are collapsed or expanded as needed
  THAssert(t->is_contiguous());
  int size[Dim];
  for (int i = 0; i < Dim || i < inDim; ++i) {
    if (i < Dim && i < inDim) {
      size[i] = THTensor_sizeLegacyNoScalars(t, i);
    } else if (i < Dim) {
      size[i] = 1;
    } else {
      size[Dim - 1] *= THTensor_sizeLegacyNoScalars(t, i);
    }
  }
  return THCDeviceTensor<scalar_t, Dim>(t->data<scalar_t>(), size);
}

void THNN_(BatchNormalization_updateOutput)(
  THCState *state, THCTensor *input_, THCTensor *output_,
  THCTensor *weight_, THCTensor *bias_, THCTensor *runningMean_,
  THCTensor *runningVar_, THCTensor *saveMean_, THCTensor *saveStd_,
  bool train, double momentum, double eps) {

  THCTensor_(resizeAs)(state, output_, input_);
  if (train) {
    int64_t nInput = THCTensor_(size)(state, input_, 1);
    THCTensor_(resize1d)(state, saveMean_, nInput);
    THCTensor_(resize1d)(state, saveStd_, nInput);
  }
  DeviceTensor3 input = THNN_(devicetensor)<3>(state, input_);
  DeviceTensor3 output = THNN_(devicetensor)<3>(state, output_);
  DeviceTensor1 weight = THNN_(devicetensor)<1>(state, weight_);
  DeviceTensor1 bias = THNN_(devicetensor)<1>(state, bias_);
  DeviceTensor1 runningMean = THNN_(devicetensor)<1>(state, runningMean_);
  DeviceTensor1 runningVar = THNN_(devicetensor)<1>(state, runningVar_);
  DeviceTensor1 saveMean = THNN_(devicetensor)<1>(state, saveMean_);
  DeviceTensor1 saveStd = THNN_(devicetensor)<1>(state, saveStd_);

  cudaStream_t s = THCState_getCurrentStream(state);
  cudaDeviceProp *prop = THCState_getCurrentDeviceProperties(state);

  if (!train) {
    dim3 blocks(input.getSize(1));
    dim3 threads(getNumThreads(input.getSize(2)));
    BatchNormalizationUpdateOutputInference_kernel<scalar_t, accreal, DeviceTensor1, DeviceTensor3> <<<blocks, threads, 0, s>>>(
      input, output, runningMean, runningVar, weight, bias, eps);
  } else {
    dim3 blocks(input.getSize(1));
    dim3 threads(getNumThreads(input.getSize(2)));
    BatchNormalizationUpdateOutput_kernel<scalar_t, accreal, DeviceTensor1, DeviceTensor3> <<<blocks, threads, 0, s>>>(
      input, output, weight, bias, static_cast<accreal>(eps), static_cast<accreal>(momentum), runningMean, runningVar,
      saveMean, saveStd);
  }
  THCudaCheck(cudaGetLastError());
}

void THNN_(BatchNormalization_backward)(
  THCState *state, THCTensor *input_, THCTensor *gradOutput_,
  THCTensor *gradInput_, THCTensor *gradWeight_, THCTensor *gradBias_,
  THCTensor *weight_, THCTensor *runningMean_, THCTensor *runningVar_,
  THCTensor *saveMean_, THCTensor *saveStd_, bool train, double scale, double eps) {

  THCUNN_check_shape(state, input_, gradOutput_);
  if (gradInput_) {
    THCTensor_(resizeAs)(state, gradInput_, input_);
  }

  DeviceTensor3 input = THNN_(devicetensor)<3>(state, input_);
  DeviceTensor3 gradOutput = THNN_(devicetensor)<3>(state, gradOutput_);
  DeviceTensor3 gradInput = THNN_(devicetensor)<3>(state, gradInput_);
  DeviceTensor1 gradWeight = THNN_(devicetensor)<1>(state, gradWeight_);
  DeviceTensor1 gradBias = THNN_(devicetensor)<1>(state, gradBias_);
  DeviceTensor1 weight = THNN_(devicetensor)<1>(state, weight_);
  DeviceTensor1 runningMean = THNN_(devicetensor)<1>(state, runningMean_);
  DeviceTensor1 runningVar = THNN_(devicetensor)<1>(state, runningVar_);
  DeviceTensor1 saveMean = THNN_(devicetensor)<1>(state, saveMean_);
  DeviceTensor1 saveStd = THNN_(devicetensor)<1>(state, saveStd_);

  cudaStream_t s = THCState_getCurrentStream(state);

  dim3 blocks(gradOutput.getSize(1));
  dim3 threads(getNumThreads(gradOutput.getSize(2)));
  BatchNormalizationBackward_kernel<scalar_t, accreal, DeviceTensor1, DeviceTensor3> <<<blocks, threads, 0, s>>>(
    input, gradOutput, gradInput, gradWeight, gradBias, weight, runningMean, runningVar,
    saveMean, saveStd, train, scale, eps);
  THCudaCheck(cudaGetLastError());
}

#undef DeviceTensor3
#undef DeviceTensor1

#endif
@@ -30,36 +30,6 @@ THC_API void THNN_(AbsCriterion_updateGradInput)(
          THCTensor *gradInput,
          int64_t reduction);

THC_API void THNN_(BatchNormalization_updateOutput)(
          THCState *state,
          THCTensor *input_,
          THCTensor *output_,
          THCTensor *weight_, // [OPTIONAL]
          THCTensor *bias_, // [OPTIONAL]
          THCTensor *runningMean_, // [OPTIONAL] if train
          THCTensor *runningVar_, // [OPTIONAL] if train
          THCTensor *saveMean_,
          THCTensor *saveStd_,
          bool train,
          double momentum,
          double eps);

THC_API void THNN_(BatchNormalization_backward)(
          THCState *state,
          THCTensor *input_,
          THCTensor *gradOutput_,
          THCTensor *gradInput_, // [OPTIONAL]
          THCTensor *gradWeight_, // [OPTIONAL]
          THCTensor *gradBias_, // [OPTIONAL]
          THCTensor *weight_, // [OPTIONAL]
          THCTensor *runningMean_, // [OPTIONAL] if train
          THCTensor *runningVar_, // [OPTIONAL] if train
          THCTensor *saveMean_, // [OPTIONAL] if !train
          THCTensor *saveStd_, // [OPTIONAL] if !train
          bool train,
          double scale,
          double eps);

THC_API void THNN_(BCECriterion_updateOutput)(
          THCState *state,
          THCTensor *input,
@@ -1,160 +0,0 @@
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/BatchNormalization.c"
#else

void THNN_(BatchNormalization_updateOutput)(
  THNNState *state, THTensor *input, THTensor *output,
  THTensor *weight, THTensor *bias,
  THTensor *running_mean, THTensor *running_var,
  THTensor *save_mean, THTensor *save_std,
  bool train, double momentum, double eps)
{
  THTensor_(resizeAs)(output, input);
  int64_t nInput = THTensor_(size)(input, 1);
  int64_t f;
  ptrdiff_t n = THTensor_(nElement)(input) / nInput;

  if (train) {
    THTensor_(resize1d)(save_mean, nInput);
    THTensor_(resize1d)(save_std, nInput);
  }

  #pragma omp parallel for
  for (f = 0; f < nInput; ++f) {
    THTensor *in = THTensor_(newSelect)(input, 1, f);
    THTensor *out = THTensor_(newSelect)(output, 1, f);

    scalar_t mean, invstd;

    if (train) {
      // compute mean per input
      accreal sum = 0;
      TH_TENSOR_APPLY(scalar_t, in, sum += *in_data;);

      mean = (scalar_t) sum / n;
      THTensor_(set1d)(save_mean, f, (scalar_t) mean);

      // compute variance per input
      sum = 0;
      TH_TENSOR_APPLY(scalar_t, in,
        sum += (*in_data - mean) * (*in_data - mean););

      if (sum == 0 && eps == 0.0) {
        invstd = 0;
      } else {
        invstd = (scalar_t) (1 / sqrt(sum/n + eps));
      }
      THTensor_(set1d)(save_std, f, (scalar_t) invstd);

      // update running averages
      if (running_mean) {
        THTensor_(set1d)(running_mean, f,
          (scalar_t) (momentum * mean + (1 - momentum) * THTensor_(get1d)(running_mean, f)));
      }
      if (running_var) {
        accreal unbiased_var = sum / (n - 1);
        THTensor_(set1d)(running_var, f,
          (scalar_t) (momentum * unbiased_var + (1 - momentum) * THTensor_(get1d)(running_var, f)));
      }
    } else {
      mean = THTensor_(get1d)(running_mean, f);
      invstd = 1 / sqrt(THTensor_(get1d)(running_var, f) + eps);
    }

    // compute output
    scalar_t w = weight ? THTensor_(get1d)(weight, f) : 1;
    scalar_t b = bias ? THTensor_(get1d)(bias, f) : 0;

    TH_TENSOR_APPLY2(scalar_t, in, scalar_t, out,
      *out_data = (scalar_t) (((*in_data - mean) * invstd) * w + b););

    c10::raw::intrusive_ptr::decref(out);
    c10::raw::intrusive_ptr::decref(in);
  }
}

void THNN_(BatchNormalization_backward)(
  THNNState *state, THTensor *input, THTensor *gradOutput, THTensor *gradInput,
  THTensor *gradWeight, THTensor *gradBias, THTensor *weight,
  THTensor *running_mean, THTensor *running_var,
  THTensor *save_mean, THTensor *save_std,
  bool train, double scale, double eps)
{
  THNN_CHECK_SHAPE(input, gradOutput);
  int64_t nInput = THTensor_(size)(input, 1);
  int64_t f;
  ptrdiff_t n = THTensor_(nElement)(input) / nInput;

  if (gradInput) {
    THTensor_(resizeAs)(gradInput, input);
  }

  #pragma omp parallel for
  for (f = 0; f < nInput; ++f) {
    THTensor *in = THTensor_(newSelect)(input, 1, f);
    THTensor *gradOut = THTensor_(newSelect)(gradOutput, 1, f);
    scalar_t w = weight ? THTensor_(get1d)(weight, f) : 1;
    scalar_t mean, invstd;
    if (train) {
      mean = THTensor_(get1d)(save_mean, f);
      invstd = THTensor_(get1d)(save_std, f);
    } else {
      mean = THTensor_(get1d)(running_mean, f);
      invstd = 1 / sqrt(THTensor_(get1d)(running_var, f) + eps);
    }

    // sum over all gradOutput in feature plane
    accreal sum = 0;
    TH_TENSOR_APPLY(scalar_t, gradOut, sum += *gradOut_data;);

    // dot product of the Q(X) and gradOuput
    accreal dotp = 0;
    TH_TENSOR_APPLY2(scalar_t, in, scalar_t, gradOut,
      dotp += (*in_data - mean) * (*gradOut_data););

    if (gradInput) {
      THTensor *gradIn = THTensor_(newSelect)(gradInput, 1, f);

      if (train) {
        // when in training mode
        // Q(X) = X - E[x] ; i.e. input centered to zero mean
        // Y = Q(X) / σ    ; i.e. BN output before weight and bias
        // dL/dX = (Q(dL/dY) - dot(Y, dL/dY) * Y) / σ * w

        // projection of gradOutput on to output scaled by std
        scalar_t k = (scalar_t) dotp * invstd * invstd / n;
        TH_TENSOR_APPLY2(scalar_t, gradIn, scalar_t, in,
          *gradIn_data = (*in_data - mean) * k;);

        accreal gradMean = sum / n;
        TH_TENSOR_APPLY2(scalar_t, gradIn, scalar_t, gradOut,
          *gradIn_data = (*gradOut_data - gradMean - *gradIn_data) * invstd * w;);

      } else {
        // when in evaluation mode
        // Q(X) = X - running_mean ; i.e. input centered to zero mean
        // Y = Q(X) / running_std  ; i.e. BN output before weight and bias
        // dL/dX = w / running_std
        TH_TENSOR_APPLY2(scalar_t, gradIn, scalar_t, gradOut,
          *gradIn_data = *gradOut_data * invstd * w;);
      }

      c10::raw::intrusive_ptr::decref(gradIn);
    }

    if (gradWeight) {
      scalar_t val = THTensor_(get1d)(gradWeight, f);
      THTensor_(set1d)(gradWeight, f, val + scale * dotp * invstd);
    }

    if (gradBias) {
      scalar_t val = THTensor_(get1d)(gradBias, f);
      THTensor_(set1d)(gradBias, f, val + scale * sum);
    }

    c10::raw::intrusive_ptr::decref(gradOut);
    c10::raw::intrusive_ptr::decref(in);
  }
}

#endif
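Spelling out the gradient that the comments in the deleted kernel describe tersely: per channel, with N elements, normalized activations x̂ = (x - μ)·σ_inv and learned scale w, the training-mode backward computed above is (up to the extra `scale` factor the THNN interface multiplies into the weight and bias gradients):

\[
\hat{x}_i = (x_i - \mu)\,\sigma_{\mathrm{inv}}, \qquad
\sigma_{\mathrm{inv}} = \frac{1}{\sqrt{\tfrac{1}{N}\sum_i (x_i - \mu)^2 + \varepsilon}}
\]
\[
\frac{\partial L}{\partial x_i}
= \Bigl(\frac{\partial L}{\partial y_i}
- \frac{1}{N}\sum_j \frac{\partial L}{\partial y_j}
- \frac{\hat{x}_i}{N}\sum_j \frac{\partial L}{\partial y_j}\,\hat{x}_j\Bigr)\,\sigma_{\mathrm{inv}}\, w,
\qquad
\frac{\partial L}{\partial w} = \sum_i \frac{\partial L}{\partial y_i}\,\hat{x}_i,
\qquad
\frac{\partial L}{\partial b} = \sum_i \frac{\partial L}{\partial y_i}.
\]

In evaluation mode the saved statistics are replaced by the running ones and only the first term survives, which is exactly the `gradOut * invstd * w` branch in the code.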
@@ -479,35 +479,6 @@ TH_API void THNN_(TemporalUpSamplingLinear_updateGradInput)(
          int osizeW,
          bool align_corners);

TH_API void THNN_(BatchNormalization_updateOutput)(
          THNNState *state,
          THTensor *input,
          THTensor *output,
          THTensor *weight, // [OPTIONAL]
          THTensor *bias, // [OPTIONAL]
          THTensor *running_mean, // [OPTIONAL] if train
          THTensor *running_var, // [OPTIONAL] if train
          THTensor *save_mean,
          THTensor *save_std,
          bool train,
          double momentum,
          double eps);
TH_API void THNN_(BatchNormalization_backward)(
          THNNState *state,
          THTensor *input,
          THTensor *gradOutput,
          THTensor *gradInput, // [OPTIONAL]
          THTensor *gradWeight, // [OPTIONAL]
          THTensor *gradBias, // [OPTIONAL]
          THTensor *weight, // [OPTIONAL]
          THTensor *running_mean, // [OPTIONAL] if train
          THTensor *running_var, // [OPTIONAL] if train
          THTensor *save_mean, // [OPTIONAL] if !train
          THTensor *save_std, // [OPTIONAL] if !train
          bool train,
          double scale,
          double eps);

TH_API void THNN_(SpatialConvolutionMM_updateOutput)(
          THNNState *state,
          THTensor *input,
@@ -139,9 +139,6 @@
#include "generic/FeatureLPPooling.c"
#include "THGenerateFloatTypes.h"

#include "generic/BatchNormalization.c"
#include "THGenerateFloatTypes.h"

#include "generic/unfold.c"
#include "THGenerateFloatTypes.h"
@@ -2789,6 +2789,17 @@ class TestNN(NNTestCase):
    def test_Conv2d_naive_groups_cuda(self, dtype=torch.float):
        self._test_Conv2d_naive_groups("cuda", dtype)

    def test_batchnorm_grad(self):
        self._test_batchnorm_grad()

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @skipIfRocm
    def test_batchnorm_grad_cuda(self):
        self._test_batchnorm_grad("cuda")
        if TEST_CUDNN:
            with torch.backends.cudnn.flags(enabled=False):
                self._test_batchnorm_grad("cuda")

    def test_batchnorm_eval(self):
        self._test_batchnorm_eval()

@@ -2796,6 +2807,9 @@ class TestNN(NNTestCase):
    @skipIfRocm
    def test_batchnorm_eval_cuda(self, dtype=torch.float):
        self._test_batchnorm_eval("cuda", dtype)
        if TEST_CUDNN:
            with torch.backends.cudnn.flags(enabled=False):
                self._test_batchnorm_eval("cuda", dtype)

    def test_batchnorm_simple_average(self):
        self._test_batchnorm_simple_average()

@@ -2804,6 +2818,9 @@ class TestNN(NNTestCase):
    @skipIfRocm
    def test_batchnorm_simple_average_cuda(self):
        self._test_batchnorm_simple_average(torch.cuda.FloatTensor)
        if TEST_CUDNN:
            with torch.backends.cudnn.flags(enabled=False):
                self._test_batchnorm_simple_average(torch.cuda.FloatTensor)

    def test_MaxPool1d_indices(self):
        self._test_maxpool_indices(1)

@@ -5021,6 +5038,9 @@ class TestNN(NNTestCase):
    @skipIfRocm
    def test_batchnorm_update_stats_cuda(self):
        self._test_batchnorm_update_stats("cuda", torch.float)
        if TEST_CUDNN:
            with torch.backends.cudnn.flags(enabled=False):
                self._test_batchnorm_update_stats("cuda", torch.float)

    def test_batchnorm_raises_error_if_running_mean_is_not_same_size_as_input(self):
        input = torch.rand(2, 10)

@@ -5056,6 +5076,18 @@ class TestNN(NNTestCase):
        with self.assertRaises(RuntimeError):
            F.batch_norm(input, running_mean, running_var, bias=Parameter(torch.rand(size)))

    def _test_batchnorm_grad(self, device="cpu", dtype=torch.double):
        bs, n_feat, size_feat = 4, 5, 6
        input = torch.arange(bs * n_feat * size_feat, device=device,
                             requires_grad=True, dtype=dtype).view(bs, n_feat, size_feat)
        weight = torch.arange(1, n_feat + 1, device=device, requires_grad=True, dtype=dtype)
        bias = torch.arange(n_feat, device=device, requires_grad=True, dtype=dtype)
        running_mean = 1 - torch.arange(n_feat, device=device, dtype=dtype)
        running_var = 2 * torch.arange(n_feat, device=device, dtype=dtype)
        for training in [False, True]:
            _assertGradAndGradgradChecks(self, F.batch_norm, (input, running_mean, running_var, weight, bias,
                                                              training, 0.1, 0.0001))

    def _test_batchnorm_eval(self, device="cpu", dtype=torch.float):
        module = nn.BatchNorm1d(3).to(device, dtype)
        module.eval()
@@ -516,6 +516,14 @@
- name: mvlgamma(Tensor self, int64_t p)
  self: mvlgamma_backward(grad, self, p)

- name: native_batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor running_mean, Tensor running_var, bool training, double momentum, double eps)
  input, weight, bias: native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, eps, grad_input_mask)

- name: native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor weight, Tensor running_mean, Tensor running_var, Tensor save_mean, Tensor save_invstd, bool train, double eps, std::array<bool,3> output_mask)
  input, weight, grad_out: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, running_mean, running_var, train, eps, save_mean, save_invstd, grad_input_mask)
  save_mean: not_implemented("native_batch_norm_backward save_mean")
  save_invstd: not_implemented("native_batch_norm_backward save_invstd")

- name: ne_(Tensor self, Scalar other)
  self: zeros_like(self)

@@ -1035,14 +1043,6 @@
- name: max_unpool3d_forward(Tensor self, Tensor indices, IntList output_size, IntList stride, IntList padding)
  self: max_unpool3d_backward(grad, self, indices, output_size, stride, padding)

- name: thnn_batch_norm_forward(Tensor self, Tensor weight, Tensor bias, Tensor running_mean, Tensor running_var, bool training, double momentum, double eps)
  self, weight, bias: thnn_batch_norm_backward(grad.contiguous(), self, weight, running_mean, running_var, training, eps, save_mean, save_std, grad_input_mask)

- name: thnn_batch_norm_backward(Tensor grad_output, Tensor self, Tensor weight, Tensor running_mean, Tensor running_var, bool training, double eps, Tensor save_mean, Tensor save_std, std::array<bool,3> output_mask)
  self, weight, grad_output: batchnorm_double_backward(self, weight, grads[0], grads[1], grads[2], grad_output, running_mean, running_var, training, eps, save_mean, save_std, grad_input_mask)
  save_mean: not_implemented("thnn_batch_norm_backward save_mean")
  save_std: not_implemented("thnn_batch_norm_backward save_std")

- name: thnn_conv_transpose2d_forward(Tensor self, Tensor weight, IntList kernel_size, Tensor bias, IntList stride, IntList padding, IntList output_padding, IntList dilation)
  self, weight, bias: thnn_conv_transpose2d_backward(grad, self, weight, kernel_size, stride, padding, output_padding, dilation, columns, ones, grad_input_mask)

@@ -1293,7 +1293,7 @@
# work.)
# NB2: The quotes around the gradient are needed to appease YAML parsing rules.
- name: cudnn_batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor running_mean, Tensor running_var, bool training, double exponential_average_factor, double epsilon)
  input, weight, bias: "training ? cudnn_batch_norm_backward(input, grad.contiguous(), weight, running_mean, running_var, result1, result2, epsilon) : thnn_batch_norm_backward(grad.contiguous(), input, weight, running_mean, running_var, training, epsilon, result1, result2, grad_input_mask)"
  input, weight, bias: "training ? cudnn_batch_norm_backward(input, grad.contiguous(), weight, running_mean, running_var, result1, result2, epsilon) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)"

# HACK: save_mean and save_var are going to be passed in as
# requires_grad variables (even though we'll never backprop through
@@ -1325,7 +1325,7 @@
  grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, std::vector<int64_t>(padding.size(), 0), groups, benchmark, deterministic, true, grad_input_mask)

- name: miopen_batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor running_mean, Tensor running_var, bool training, double exponential_average_factor, double epsilon)
  input, weight, bias: "training ? miopen_batch_norm_backward(input, grad.contiguous(), weight, running_mean, running_var, result1, result2, epsilon) : thnn_batch_norm_backward(grad.contiguous(), input, weight, running_mean, running_var, training, epsilon, result1, result2, grad_input_mask)"
  input, weight, bias: "training ? miopen_batch_norm_backward(input, grad.contiguous(), weight, running_mean, running_var, result1, result2, epsilon) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)"

- name: miopen_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor running_mean, Tensor running_var, Tensor save_mean, Tensor save_var, double epsilon)
  save_mean: not_implemented("miopen_batch_norm_backward save_mean")
@@ -1883,7 +1883,7 @@ std::tuple<Tensor, Tensor, Tensor> batchnorm_double_backward(
    bool training,
    double eps,
    const Tensor & save_mean,
    const Tensor & save_std,
    const Tensor & save_invstd,
    std::array<bool,3> output_mask) {

  bool affine = gamma.defined();
@@ -1907,9 +1907,12 @@ std::tuple<Tensor, Tensor, Tensor> batchnorm_double_backward(
  for (auto s : input.sizes().slice(2)) {
    M *= s;
  }
  auto mu = unsqueeze_dim1(training ? save_mean : running_mean, input);
  // for half inputs, save_mean, save_invstd are float (ideally, we would cast
  // everything else, but not now)
  auto mu = unsqueeze_dim1(training ? save_mean.to(input.type().scalarType()) : running_mean, input);
  auto input_sub_mu = input - mu;
  auto sigma2_eps_neg_1_2 = unsqueeze_dim1(training ? save_std : running_var.add(Scalar(eps)).pow(-0.5), input);
  auto sigma2_eps_neg_1_2 = unsqueeze_dim1(training ? save_invstd.to(input.type().scalarType())
                                                    : running_var.add(Scalar(eps)).pow(-0.5), input);
  auto sigma2_eps_neg_1 = sigma2_eps_neg_1_2.pow(2);
  auto sigma2_eps_neg_3_2 = sigma2_eps_neg_1_2.pow(3);
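save_mean and save_invstd are now kept in accumulation precision, so for half inputs they are cast back to the input's scalar type before entering mixed expressions, via Tensor::to on the scalar type reported by type().scalarType(). A tiny hedged sketch of that cast in isolation (the helper name is illustrative, not part of this change):

#include <ATen/ATen.h>

// Bring a float-precision statistic into the (possibly Half) dtype of `input`
// before using it in elementwise arithmetic with `input`.
at::Tensor align_to_input_type(const at::Tensor& stat, const at::Tensor& input) {
  return stat.to(input.type().scalarType());  // returns `stat` unchanged if types already match
}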