From b678256bfbece92c30f4524b1c59e83c082306d6 Mon Sep 17 00:00:00 2001
From: "xiaobing.zhang"
Date: Fri, 28 Feb 2020 14:51:18 -0800
Subject: [PATCH] Move glu to Aten(CPU) (#33179)

Summary:
This PR moves glu to ATen (CPU).
Test script:
```
import torch
import torch.nn.functional as F
import time

torch.manual_seed(0)

def _time():
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()

device = "cpu"

# warm up
for n in [10, 100, 1000, 10000]:
    input = torch.randn(128, n, requires_grad=True, device=device)
    grad_output = torch.ones(128, n // 2, device=device)
    for i in range(1000):
        output = F.glu(input)
        output.backward(grad_output)

for n in [10, 100, 1000, 10000]:
    fwd_t = 0
    bwd_t = 0
    input = torch.randn(128, n, requires_grad=True, device=device)
    grad_output = torch.ones(128, n // 2, device=device)
    for i in range(10000):
        t1 = _time()
        output = F.glu(input)
        t2 = _time()
        output.backward(grad_output)
        t3 = _time()
        fwd_t = fwd_t + (t2 - t1)
        bwd_t = bwd_t + (t3 - t2)
    fwd_avg = fwd_t / 10000 * 1000
    bwd_avg = bwd_t / 10000 * 1000
    print("input size(128, %d) forward avg time is %.2f (ms); backward avg time is %.2f (ms)."
          % (n, fwd_avg, bwd_avg))
```
Test device: **skx-8180.**

Before:
```
input size(128, 10) forward avg time is 0.04 (ms); backward avg time is 0.08 (ms).
input size(128, 100) forward avg time is 0.06 (ms); backward avg time is 0.14 (ms).
input size(128, 1000) forward avg time is 0.11 (ms); backward avg time is 0.31 (ms).
input size(128, 10000) forward avg time is 1.52 (ms); backward avg time is 2.04 (ms).
```
After:
```
input size(128, 10) forward avg time is 0.02 (ms); backward avg time is 0.05 (ms).
input size(128, 100) forward avg time is 0.04 (ms); backward avg time is 0.09 (ms).
input size(128, 1000) forward avg time is 0.07 (ms); backward avg time is 0.17 (ms).
input size(128, 10000) forward avg time is 0.13 (ms); backward avg time is 1.03 (ms).
```
Fixes https://github.com/pytorch/pytorch/issues/24707 and https://github.com/pytorch/pytorch/issues/24708.
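For reference, `glu` halves the input along `dim` and computes GLU(x) = x1 * sigmoid(x2). With s = sigmoid(x2), the backward pass is grad_x1 = g * s and grad_x2 = g * x1 * s * (1 - s); the latter is exactly the fused `(one_val - a) * a * b * c` expression in the new `glu_backward_kernel`, where `a` holds the sigmoid values written by `sigmoid_out`. A minimal standalone check of that math against autograd (illustrative only, not part of the test plan above; sizes and dtype are arbitrary):
```
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# double precision so gradcheck's finite differences are reliable
x = torch.randn(128, 20, dtype=torch.double, requires_grad=True)
x1, x2 = x.chunk(2, dim=-1)

# forward: GLU(x) = x1 * sigmoid(x2), the same math as glu_kernel
assert torch.allclose(F.glu(x, dim=-1), x1 * torch.sigmoid(x2))

# backward: with s = sigmoid(x2),
#   grad_x1 = g * s
#   grad_x2 = g * x1 * s * (1 - s)   # the fused (1 - a) * a * b * c
g = torch.randn(128, 10, dtype=torch.double)
(grad,) = torch.autograd.grad(F.glu(x, dim=-1), x, g)
s = torch.sigmoid(x2)
assert torch.allclose(grad, torch.cat([g * s, g * x1 * s * (1 - s)], dim=-1))

# and against numerically estimated gradients
assert torch.autograd.gradcheck(lambda t: F.glu(t, dim=-1), (x,))
```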
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33179

Differential Revision: D19839835

Pulled By: VitalyFedyunin

fbshipit-source-id: e4d3438556a1068da2c4a7e573d6bbf8d2a6e2b9
---
 .circleci/scripts/cpp_doc_push_script.sh          |    1 -
 .github/workflows/lint.yml                        |    1 -
 .gitignore                                        |    5 -
 CONTRIBUTING.md                                   |    1 -
 aten/CMakeLists.txt                               |    1 -
 aten/src/ATen/gen.py                              |    2 -
 aten/src/ATen/native/Activation.h                 |    2 +
 aten/src/ATen/native/GatedLinearUnit.cpp          |   62 +-
 aten/src/ATen/native/cpu/Activation.cpp           |   36 +
 aten/src/ATen/native/native_functions.yaml        |    4 +-
 aten/src/README.md                                |    2 +-
 aten/src/TH/generic/THTensorEvenMoreMath.cpp      |   12 -
 aten/src/TH/generic/THTensorMath.cpp              |   66 -
 aten/src/TH/generic/THTensorMath.h                |    5 -
 aten/src/TH/generic/THVector.h                    |    5 -
 aten/src/TH/generic/THVectorDefault.cpp           |  139 --
 aten/src/TH/generic/THVectorDispatch.cpp          |  119 --
 aten/src/TH/vector/AVX.cpp                        |  172 --
 aten/src/TH/vector/NEON.cpp                       |   75 -
 aten/src/TH/vector/VSX.cpp                        | 1748 +-----------------
 aten/src/{THNN => THCUNN}/README.md               |    7 +-
 aten/src/{THNN => THCUNN}/doc/api_reference.md    |    7 +-
 aten/src/{THNN => THCUNN}/doc/style_guidelines.md |   41 +-
 aten/src/THNN/CMakeLists.txt                      |    4 -
 aten/src/THNN/THNN.h                              |   24 -
 aten/src/THNN/generic/GatedLinearUnit.c           |   57 -
 aten/src/THNN/generic/THNN.h                      |   24 -
 aten/src/THNN/init.cpp                            |   65 -
 caffe2/CMakeLists.txt                             |    1 -
 cmake/Codegen.cmake                               |    1 -
 docs/cpp/source/check-doxygen.sh                  |    1 -
 setup.py                                          |    2 -
 32 files changed, 121 insertions(+), 2571 deletions(-)
 rename aten/src/{THNN => THCUNN}/README.md (59%)
 rename aten/src/{THNN => THCUNN}/doc/api_reference.md (74%)
 rename aten/src/{THNN => THCUNN}/doc/style_guidelines.md (56%)
 delete mode 100644 aten/src/THNN/CMakeLists.txt
 delete mode 100644 aten/src/THNN/THNN.h
 delete mode 100644 aten/src/THNN/generic/GatedLinearUnit.c
 delete mode 100644 aten/src/THNN/generic/THNN.h
 delete mode 100644 aten/src/THNN/init.cpp

diff --git a/.circleci/scripts/cpp_doc_push_script.sh b/.circleci/scripts/cpp_doc_push_script.sh
index 797914d4ba0a..d5942ab08231 100755
--- a/.circleci/scripts/cpp_doc_push_script.sh
+++ b/.circleci/scripts/cpp_doc_push_script.sh
@@ -57,7 +57,6 @@ time python aten/src/ATen/gen.py \
   -s aten/src/ATen \
   -d build/aten/src/ATen \
   aten/src/ATen/Declarations.cwrap \
-  aten/src/THNN/generic/THNN.h \
   aten/src/THCUNN/generic/THCUNN.h \
   aten/src/ATen/nn.yaml \
   aten/src/ATen/native/native_functions.yaml
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index b7bbf3d4f9a5..84827f9ac3fa 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -178,7 +178,6 @@ jobs:
           -s aten/src/ATen \
           -d build/aten/src/ATen \
           aten/src/ATen/Declarations.cwrap \
-          aten/src/THNN/generic/THNN.h \
           aten/src/THCUNN/generic/THCUNN.h \
           aten/src/ATen/nn.yaml \
           aten/src/ATen/native/native_functions.yaml
diff --git a/.gitignore b/.gitignore
index e01a1b140ded..d0a7ad3e67f6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -57,11 +57,6 @@ torch/csrc/jit/generated/*
 torch/csrc/jit/fuser/config.h
 torch/csrc/nn/THCUNN.cpp
 torch/csrc/nn/THCUNN.cwrap
-torch/csrc/nn/THNN_generic.cpp
-torch/csrc/nn/THNN_generic.cwrap
-torch/csrc/nn/THNN_generic.h
-torch/csrc/nn/THNN.cpp
-torch/csrc/nn/THNN.cwrap
 torch/bin/
 torch/cmake/
 torch/lib/*.a*
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 98f02a7482d9..edbb43e854b6 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -129,7 +129,6 @@ and `python setup.py clean`. Then you can install in `develop` mode again.
 * [src](aten/src)
   * [TH](aten/src/TH)
     [THC](aten/src/THC)
-    [THNN](aten/src/THNN)
     [THCUNN](aten/src/THCUNN) - Legacy library code from the original Torch. Try
     not to add things here; we're slowly porting these to
     [native](aten/src/ATen/native).
diff --git a/aten/CMakeLists.txt b/aten/CMakeLists.txt
index 944491cf21ea..c25a2570d1bd 100644
--- a/aten/CMakeLists.txt
+++ b/aten/CMakeLists.txt
@@ -49,7 +49,6 @@ set(TH_CPU_INCLUDE
   ${CMAKE_BINARY_DIR}/aten/src)
 list(APPEND ATen_CPU_INCLUDE ${TH_CPU_INCLUDE})
 
-add_subdirectory(src/THNN)
 
 # Find the HIP package, set the HIP paths, load the HIP CMake.
 IF(USE_ROCM)
diff --git a/aten/src/ATen/gen.py b/aten/src/ATen/gen.py
index 396af17acdc1..a2099a0102a0 100644
--- a/aten/src/ATen/gen.py
+++ b/aten/src/ATen/gen.py
@@ -315,8 +315,6 @@ def generate_storage_type_and_tensor(backend, density, declarations, per_op_regi
     env['th_headers'] = [
         '#include ',
         '#include ',
-        '#include ',
-        '#undef THNN_',
     ]
     env['extra_cuda_headers'] = []
     env['state'] = []
diff --git a/aten/src/ATen/native/Activation.h b/aten/src/ATen/native/Activation.h
index 55ab487f15b2..8f622487f299 100644
--- a/aten/src/ATen/native/Activation.h
+++ b/aten/src/ATen/native/Activation.h
@@ -38,6 +38,8 @@ DECLARE_DISPATCH(shrink_fn, softshrink_stub);
 DECLARE_DISPATCH(shrink_backward_fn, shrink_backward_stub);
 DECLARE_DISPATCH(leaky_relu_fn, leaky_relu_stub);
 DECLARE_DISPATCH(leaky_relu_backward_fn, leaky_relu_backward_stub);
+DECLARE_DISPATCH(activation_fn, glu_stub);
+DECLARE_DISPATCH(activation_backward_fn, glu_backward_stub);
 
 } // namespace native
diff --git a/aten/src/ATen/native/GatedLinearUnit.cpp b/aten/src/ATen/native/GatedLinearUnit.cpp
index c4d646fcfaa7..d3ba828e414f 100644
--- a/aten/src/ATen/native/GatedLinearUnit.cpp
+++ b/aten/src/ATen/native/GatedLinearUnit.cpp
@@ -1,31 +1,34 @@
 #include <ATen/ATen.h>
 #include <ATen/NativeFunctions.h>
+#include <ATen/native/Activation.h>
+#include <ATen/native/TensorIterator.h>
 
 namespace at { namespace native {
 
+DEFINE_DISPATCH(glu_stub);
+DEFINE_DISPATCH(glu_backward_stub);
+
 Tensor& glu_out(Tensor &result, const Tensor& self, int64_t dim) {
   // this can't pass anyway because a 0-dimensional tensor has "size" 1, which
   // can't be evenly halved, but give a nicer error message here.
   TORCH_CHECK(self.dim() > 0, "glu does not support 0-dimensional tensors");
-  dim = maybe_wrap_dim(dim, self.dim());
-  const int64_t nIn = self.size(dim);
+  auto wrap_dim = maybe_wrap_dim(dim, self.dim());
+  const int64_t nIn = self.size(wrap_dim);
   TORCH_CHECK(nIn % 2 == 0, "Halving dimension must be even, but dimension ",
-              dim, " is size ", nIn);
-
+              wrap_dim, " is size ", nIn);
   // size output to half of input
   const int64_t selfSize = nIn / 2;
   auto newSizes = self.sizes().vec();
-  newSizes[dim] = selfSize;
+  newSizes[wrap_dim] = selfSize;
   result.resize_(newSizes);
+  // half tensor
+  Tensor firstHalf = self.narrow(wrap_dim, 0, selfSize);
+  Tensor secondHalf = self.narrow(wrap_dim, selfSize, selfSize);
 
-  // halve tensor
-  Tensor firstHalf = self.narrow(dim, 0, selfSize);
-  Tensor secondHalf = self.narrow(dim, selfSize, selfSize);
-
-  // x = x1:cmul( sigmoid(x2) )
-  at::sigmoid_out(result, secondHalf);
-  return result.mul_(firstHalf);
+  auto iter = TensorIterator::binary_op(result, firstHalf, secondHalf);
+  glu_stub(iter.device_type(), iter);
+  return result;
 }
 
 Tensor glu(const Tensor& self, int64_t dim) {
@@ -33,5 +36,40 @@ Tensor glu(const Tensor& self, int64_t dim) {
   return at::glu_out(result, self, dim);
 }
 
+Tensor& glu_backward_out(Tensor& grad_input,
+          const Tensor& grad_output, const Tensor& input, int64_t dim) {
+  TORCH_CHECK(input.dim() > 0, "glu does not support 0-dimensional tensors");
+  auto wrap_dim = maybe_wrap_dim(dim, input.dim());
+  const int64_t nIn = input.size(wrap_dim);
+  TORCH_CHECK(nIn % 2 == 0, "Halving dimension must be even, but dimension ",
+              wrap_dim, " is size ", nIn);
+
+  grad_input.resize_as_(input);
+  const int64_t inputSize = nIn / 2;
+  // half tensor
+  Tensor firstHalf = input.narrow(wrap_dim, 0, inputSize);
+  Tensor secondHalf = input.narrow(wrap_dim, inputSize, inputSize);
+  Tensor gradInputfirstHalf = grad_input.narrow(wrap_dim, 0, inputSize);
+  Tensor gradInputsecondHalf = grad_input.narrow(wrap_dim, inputSize, inputSize);
+
+  at::sigmoid_out(gradInputfirstHalf, secondHalf);
+  // for the second half of grad_input, fusing the elementwise ops gives better performance
+  auto iter = at::TensorIterator();
+  iter.set_check_mem_overlap(true);
+  iter.add_output(gradInputsecondHalf);
+  iter.add_input(gradInputfirstHalf);
+  iter.add_input(firstHalf);
+  iter.add_input(grad_output);
+  iter.build();
+  glu_backward_stub(iter.device_type(), iter);
+  gradInputfirstHalf.mul_(grad_output);
+  return grad_input;
+}
+
+Tensor glu_backward(const Tensor& grad_output, const Tensor& input, int64_t dim) {
+  auto grad_input = at::empty({0}, input.options());
+  return at::glu_backward_out(grad_input, grad_output, input, dim);
+}
+
 } // at::native
 } // at
diff --git a/aten/src/ATen/native/cpu/Activation.cpp b/aten/src/ATen/native/cpu/Activation.cpp
index f1e9f5d4d439..1c4913713189 100644
--- a/aten/src/ATen/native/cpu/Activation.cpp
+++ b/aten/src/ATen/native/cpu/Activation.cpp
@@ -448,6 +448,40 @@ void softplus_backward_kernel(TensorIterator& iter, Scalar beta_, Scalar thresho
   });
 }
 
+void glu_kernel(TensorIterator& iter) {
+  AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "glu_cpu", [&] {
+    using Vec = Vec256<scalar_t>;
+    const scalar_t one_val(1);
+    const Vec one_vec(one_val);
+    cpu_kernel_vec(
+      iter,
+      [one_val](scalar_t a, scalar_t b) -> scalar_t {
+        return a * (one_val / (one_val + std::exp(-b)));
+      },
+      [one_vec](Vec a, Vec b) -> Vec {
+        return a * (one_vec / (one_vec + b.neg().exp()));
+      }
+    );
+  });
+}
+
+void glu_backward_kernel(TensorIterator& iter) {
+  AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "glu_backward_cpu", [&] {
"glu_backward_cpu", [&] { + using Vec = Vec256; + const scalar_t one_val(1); + const Vec one_vec(one_val); + cpu_kernel_vec( + iter, + [one_val](scalar_t a, scalar_t b, scalar_t c) -> scalar_t { + return (one_val - a) * a * b * c; + }, + [one_vec](Vec a, Vec b, Vec c) -> Vec { + return (one_vec - a) * a * b * c; + } + ); + }); +} + } // namespace REGISTER_DISPATCH(log_sigmoid_cpu_stub, &log_sigmoid_cpu_kernel); @@ -465,6 +499,8 @@ REGISTER_DISPATCH(leaky_relu_stub, &leaky_relu_kernel); REGISTER_DISPATCH(leaky_relu_backward_stub, &leaky_relu_backward_kernel); REGISTER_DISPATCH(softplus_stub, &softplus_kernel); REGISTER_DISPATCH(softplus_backward_stub, &softplus_backward_kernel); +REGISTER_DISPATCH(glu_stub, &glu_kernel); +REGISTER_DISPATCH(glu_backward_stub, &glu_backward_kernel); } // namespace native } // namespace at diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 911937051f7d..c3538b86b998 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -5637,14 +5637,14 @@ - func: glu_backward.grad_input(Tensor grad_output, Tensor self, int dim, *, Tensor(a!) grad_input) -> Tensor(a!) python_module: nn dispatch: - CPU: legacy::cpu::_thnn_glu_backward_out + CPU: glu_backward_out CUDA: legacy::cuda::_thnn_glu_backward_out - func: glu_backward(Tensor grad_output, Tensor self, int dim) -> Tensor use_c10_dispatcher: full python_module: nn dispatch: - CPU: legacy::cpu::_thnn_glu_backward + CPU: glu_backward CUDA: legacy::cuda::_thnn_glu_backward - func: hardtanh.out(Tensor self, Scalar min_val=-1, Scalar max_val=1, *, Tensor(a!) out) -> Tensor(a!) diff --git a/aten/src/README.md b/aten/src/README.md index cba392f6daf4..6ad5d1970e08 100644 --- a/aten/src/README.md +++ b/aten/src/README.md @@ -8,7 +8,7 @@ multiple variants of the library, summarized here: * THC = TorcH Cuda * THCS = TorcH Cuda Sparse (now defunct) * THCUNN = TorcH CUda Neural Network (see cunn) -* THNN = TorcH Neural Network +* THNN = TorcH Neural Network (now defunct) * THS = TorcH Sparse (now defunct) (You'll also see these abbreviations show up in symbol names.) 
diff --git a/aten/src/TH/generic/THTensorEvenMoreMath.cpp b/aten/src/TH/generic/THTensorEvenMoreMath.cpp index 94eace363f5e..539a540817dd 100644 --- a/aten/src/TH/generic/THTensorEvenMoreMath.cpp +++ b/aten/src/TH/generic/THTensorEvenMoreMath.cpp @@ -196,18 +196,6 @@ void THTensor_(mul)(THTensor *r_, THTensor *t, scalar_t value) } } -void THTensor_(div)(THTensor *r_, THTensor *t, scalar_t value) -{ - THTensor_(resizeAs)(r_, t); - int64_t r_Size = THTensor_(nElement)(r_); - int r_Contig = THTensor_(isContiguous)(r_); - int tContig = THTensor_(isContiguous)(t); - if (r_Contig && tContig) { - TH_TENSOR_APPLY2_CONTIG(scalar_t, r_, scalar_t, t, THVector_(divs)(r__data, t_data, value, r__len);); - } else { - TH_TENSOR_APPLY2_PARALLEL(r_Size, r_Contig, tContig, scalar_t, r_, scalar_t, t, *r__data = *t_data / value;, ORDIN_TH_OMP_OVERHEAD_THRESHOLD) - } -} #endif #if !defined(TH_REAL_IS_BFLOAT16) /* non bfloat16 part*/ diff --git a/aten/src/TH/generic/THTensorMath.cpp b/aten/src/TH/generic/THTensorMath.cpp index bb18f9de97c3..ee2825411eee 100644 --- a/aten/src/TH/generic/THTensorMath.cpp +++ b/aten/src/TH/generic/THTensorMath.cpp @@ -336,72 +336,6 @@ static inline bool modulo_wrap(scalar_t a, scalar_t b) { return (a != 0) && (a < 0) != (b < 0); } -void THTensor_(cadd)(THTensor *r_, THTensor *t, scalar_t value, THTensor *src) -{ - THTensor_(resizeAs)(r_, t); - int64_t r_Size = THTensor_(nElement)(r_); - int64_t srcSize = THTensor_(nElement)(src); - int r_Contig = THTensor_(isContiguous)(r_); - int tContig = THTensor_(isContiguous)(t); - int srcContig = THTensor_(isContiguous)(src); - if (srcSize == r_Size) { - if (r_Contig && tContig && srcContig) { - if(r_ == t) { - THBlas_(axpy)(THTensor_(nElement)(t), value, src->data(), 1, r_->data(), 1); - } else { - TH_TENSOR_APPLY3_CONTIG(scalar_t, r_, scalar_t, t, scalar_t, src, THVector_(cadd)(r__data, t_data, src_data, value, r__len);); - } - } else { - TH_TENSOR_APPLY3_PARALLEL(r_Size, r_Contig, tContig, srcContig, scalar_t, r_, scalar_t, t, scalar_t, src, *r__data = *t_data + value * *src_data;, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD); - } - } else { - TH_TENSOR_APPLY3(scalar_t, r_, scalar_t, t, scalar_t, src, *r__data = *t_data + value * *src_data;); - } -} - -void THTensor_(csub)(THTensor *r_, THTensor *t, scalar_t value, THTensor *src) -{ - THTensor_(cadd)(r_, t, -value, src); -} - -void THTensor_(cmul)(THTensor *r_, THTensor *t, THTensor *src) -{ - THTensor_(resizeAs)(r_, t); - int64_t r_Size = THTensor_(nElement)(r_); - int64_t srcSize = THTensor_(nElement)(src); - int r_Contig = THTensor_(isContiguous)(r_); - int tContig = THTensor_(isContiguous)(t); - int srcContig = THTensor_(isContiguous)(src); - if (srcSize == r_Size){ - if (r_Contig && tContig && srcContig) { - TH_TENSOR_APPLY3_CONTIG(scalar_t, r_, scalar_t, t, scalar_t, src, THVector_(cmul)(r__data, t_data, src_data, r__len);); - } else { - TH_TENSOR_APPLY3_PARALLEL(r_Size, r_Contig, tContig, srcContig, scalar_t, r_, scalar_t, t, scalar_t, src, *r__data = *t_data * *src_data;, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD); - } - } else { - TH_TENSOR_APPLY3(scalar_t, r_, scalar_t, t, scalar_t, src, *r__data = *t_data * *src_data;); - } -} - -void THTensor_(cdiv)(THTensor *r_, THTensor *t, THTensor *src) -{ - THTensor_(resizeAs)(r_, t); - int64_t r_Size = THTensor_(nElement)(r_); - int64_t srcSize = THTensor_(nElement)(src); - int r_Contig = THTensor_(isContiguous)(r_); - int tContig = THTensor_(isContiguous)(t); - int srcContig = THTensor_(isContiguous)(src); - if (srcSize == r_Size){ - if (r_Contig && 
tContig && srcContig) { - TH_TENSOR_APPLY3_CONTIG(scalar_t, r_, scalar_t, t, scalar_t, src, THVector_(cdiv)(r__data, t_data, src_data, r__len);); - } else { - TH_TENSOR_APPLY3_PARALLEL(r_Size, r_Contig, tContig, srcContig, scalar_t, r_, scalar_t, t, scalar_t, src, *r__data = *t_data / *src_data;, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD); - } - } else { - TH_TENSOR_APPLY3(scalar_t, r_, scalar_t, t, scalar_t, src, *r__data = *t_data / *src_data;); - } -} - void THTensor_(cremainder)(THTensor *r_, THTensor *t, THTensor *src) { THTensor_(resizeAs)(r_, t); diff --git a/aten/src/TH/generic/THTensorMath.h b/aten/src/TH/generic/THTensorMath.h index ebd3a15e95cc..25610b102891 100644 --- a/aten/src/TH/generic/THTensorMath.h +++ b/aten/src/TH/generic/THTensorMath.h @@ -65,7 +65,6 @@ TH_API void THTensor_(addr)(THTensor *r_, THTensor *t, THTensor *vec1, THTensor #if !defined(TH_REAL_IS_BOOL) TH_API void THTensor_(mul)(THTensor *r_, THTensor *t, scalar_t value); -TH_API void THTensor_(div)(THTensor *r_, THTensor *t, scalar_t value); #endif #if !defined(TH_REAL_IS_BFLOAT16) @@ -96,10 +95,6 @@ TH_API accreal THTensor_(dot)(THTensor *t, THTensor *src); TH_API void THTensor_(remainder)(THTensor *r_, THTensor *t, scalar_t value); TH_API void THTensor_(clamp)(THTensor *r_, THTensor *t, scalar_t min_value, scalar_t max_value); -TH_API void THTensor_(cadd)(THTensor *r_, THTensor *t, scalar_t value, THTensor *src); -TH_API void THTensor_(csub)(THTensor *self, THTensor *src1, scalar_t value, THTensor *src2); -TH_API void THTensor_(cmul)(THTensor *r_, THTensor *t, THTensor *src); -TH_API void THTensor_(cdiv)(THTensor *r_, THTensor *t, THTensor *src); TH_API void THTensor_(cremainder)(THTensor *r_, THTensor *t, THTensor *src); TH_API void THTensor_(addbmm)(THTensor *r_, THTensor *t, THTensor *batch1, THTensor *batch2, scalar_t beta, scalar_t alpha); diff --git a/aten/src/TH/generic/THVector.h b/aten/src/TH/generic/THVector.h index fe71fa475142..3600c3f398db 100644 --- a/aten/src/TH/generic/THVector.h +++ b/aten/src/TH/generic/THVector.h @@ -9,12 +9,7 @@ TH_API void THVector_(fill)(scalar_t *x, const scalar_t c, const ptrdiff_t n); #if !defined(TH_REAL_IS_BOOL) /* non bool only part */ -TH_API void THVector_(cadd)(scalar_t *z, const scalar_t *x, const scalar_t *y, const scalar_t c, const ptrdiff_t n); -TH_API void THVector_(adds)(scalar_t *y, const scalar_t *x, const scalar_t c, const ptrdiff_t n); -TH_API void THVector_(cmul)(scalar_t *z, const scalar_t *x, const scalar_t *y, const ptrdiff_t n); TH_API void THVector_(muls)(scalar_t *y, const scalar_t *x, const scalar_t c, const ptrdiff_t n); -TH_API void THVector_(cdiv)(scalar_t *z, const scalar_t *x, const scalar_t *y, const ptrdiff_t n); -TH_API void THVector_(divs)(scalar_t *y, const scalar_t *x, const scalar_t c, const ptrdiff_t n); TH_API void THVector_(neg)(scalar_t *y, const scalar_t *x, const ptrdiff_t n); TH_API void THVector_(normal_fill)(scalar_t *data, const int64_t size, diff --git a/aten/src/TH/generic/THVectorDefault.cpp b/aten/src/TH/generic/THVectorDefault.cpp index 9c43664f036e..ba4bad50a188 100644 --- a/aten/src/TH/generic/THVectorDefault.cpp +++ b/aten/src/TH/generic/THVectorDefault.cpp @@ -36,54 +36,6 @@ void THVector_(copy_DEFAULT)(scalar_t *x, const scalar_t *y, const ptrdiff_t n) x[i] = y[i]; } -void THVector_(cadd_DEFAULT)(scalar_t *z, const scalar_t *x, const scalar_t *y, const scalar_t c, const ptrdiff_t n) -{ - ptrdiff_t i = 0; - - for(; i (0, 1] for log. 
- const scalar_t u2 = data[j + 8]; - - const scalar_t radius = sqrt(-2 * log(u1)); - const scalar_t theta = 2.0f * M_PI * u2; - - data[j] = radius * cos(theta) * stddev + mean; - data[j + 8] = radius * std::sin(theta) * stddev + mean; - } -} - -void THVector_(normal_fill_DEFAULT)(scalar_t *data, - int64_t size, - at::Generator *generator, - const scalar_t mean, - const scalar_t stddev) -{ - THAssert(size >= 16 && "Size must be >= 16 for normal fill"); - auto gen = at::get_generator_or_default(generator, at::detail::getDefaultCPUGenerator()); - // See Note [Acquire lock when using random generators] - std::lock_guard lock(gen->mutex_); - - for (int64_t i = 0; i < size; ++i) { -#ifdef TH_REAL_IS_FLOAT - at::uniform_real_distribution uniform(0, 1); - data[i] = uniform(gen); -#else - at::uniform_real_distribution uniform(0, 1); - data[i] = uniform(gen); -#endif - } - - for (int64_t i = 0; i < size - 15; i += 16) { - THVector_(interleaved_normal_fill_16)(data + i, mean, stddev); - } - - if (size % 16 != 0) { - // Recompute the last 16 values. - data = data + size - 16; - for (int64_t i = 0; i < 16; ++i) { -#ifdef TH_REAL_IS_FLOAT - at::uniform_real_distribution uniform(0, 1); - data[i] = uniform(gen); -#else - at::uniform_real_distribution uniform(0, 1); - data[i] = uniform(gen); -#endif - } - THVector_(interleaved_normal_fill_16)(data, mean, stddev); - } -} - #define VECTOR_IMPLEMENT_FUNCTION(NAME, CFUNC) \ void THVector_(NAME)(scalar_t *y, const scalar_t *x, const ptrdiff_t n) \ { \ diff --git a/aten/src/TH/generic/THVectorDispatch.cpp b/aten/src/TH/generic/THVectorDispatch.cpp index 6e07dded79ea..38618b68877e 100644 --- a/aten/src/TH/generic/THVectorDispatch.cpp +++ b/aten/src/TH/generic/THVectorDispatch.cpp @@ -39,80 +39,6 @@ void THVector_(fill)(scalar_t *x, const scalar_t c, const ptrdiff_t n) { } #if !defined(TH_REAL_IS_BOOL) /* non bool only part */ - -static void (*THVector_(cadd_DISPATCHPTR))(scalar_t *, const scalar_t *, const scalar_t *, const scalar_t, const ptrdiff_t) = &THVector_(cadd_DEFAULT); -static FunctionDescription THVector_(cadd_DISPATCHTABLE)[] = { - #if defined(__NEON__) - #if defined(TH_REAL_IS_FLOAT) - FUNCTION_IMPL(THVector_(cadd_NEON), SIMDExtension_NEON), - #endif - #endif - - #if defined(USE_AVX2) - #if defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT) - FUNCTION_IMPL(THVector_(cadd_AVX2), SIMDExtension_AVX2), - #endif - #endif - - #if defined(USE_AVX) - #if defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT) - FUNCTION_IMPL(THVector_(cadd_AVX), SIMDExtension_AVX), - #endif - #endif - - FUNCTION_IMPL(THVector_(cadd_DEFAULT), SIMDExtension_DEFAULT) -}; -void THVector_(cadd)(scalar_t *z, const scalar_t *x, const scalar_t *y, const scalar_t c, const ptrdiff_t n) { - THVector_(cadd_DISPATCHPTR)(z, x, y, c, n); -} - -static void (*THVector_(adds_DISPATCHPTR))(scalar_t *, const scalar_t *, const scalar_t, const ptrdiff_t) = &THVector_(adds_DEFAULT); -static FunctionDescription THVector_(adds_DISPATCHTABLE)[] = { - #if defined(__NEON__) - #if defined(TH_REAL_IS_FLOAT) - FUNCTION_IMPL(THVector_(adds_NEON), SIMDExtension_NEON), - #endif - #endif - - #if defined(__PPC64__) - #if defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT) - FUNCTION_IMPL(THVector_(adds_VSX), SIMDExtension_VSX), - #endif - #endif - - #if defined(USE_AVX) - #if defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT) - FUNCTION_IMPL(THVector_(adds_AVX), SIMDExtension_AVX), - #endif - #endif - - FUNCTION_IMPL(THVector_(adds_DEFAULT), SIMDExtension_DEFAULT) -}; -// Dispatch stubs that just 
call the pointers -void THVector_(adds)(scalar_t *r_, const scalar_t *t, const scalar_t value, const ptrdiff_t n) { - THVector_(adds_DISPATCHPTR)(r_, t, value, n); -} - -static void (*THVector_(cmul_DISPATCHPTR))(scalar_t *, const scalar_t *, const scalar_t *, const ptrdiff_t) = &THVector_(cmul_DEFAULT); -static FunctionDescription THVector_(cmul_DISPATCHTABLE)[] = { - #if defined(__NEON__) - #if defined(TH_REAL_IS_FLOAT) - FUNCTION_IMPL(THVector_(cmul_NEON), SIMDExtension_NEON), - #endif - #endif - - #if defined(USE_AVX) - #if defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT) - FUNCTION_IMPL(THVector_(cmul_AVX), SIMDExtension_AVX), - #endif - #endif - - FUNCTION_IMPL(THVector_(cmul_DEFAULT), SIMDExtension_DEFAULT) -}; -void THVector_(cmul)(scalar_t *z, const scalar_t *x, const scalar_t *y, const ptrdiff_t n) { - THVector_(cmul_DISPATCHPTR)(z, x, y, n); -} - static void (*THVector_(muls_DISPATCHPTR))(scalar_t *, const scalar_t *, const scalar_t, const ptrdiff_t) = &THVector_(muls_DEFAULT); static FunctionDescription THVector_(muls_DISPATCHTABLE)[] = { #if defined(__NEON__) @@ -139,46 +65,6 @@ void THVector_(muls)(scalar_t *y, const scalar_t *x, const scalar_t c, const ptr THVector_(muls_DISPATCHPTR)(y, x, c, n); } -static void (*THVector_(cdiv_DISPATCHPTR))(scalar_t *, const scalar_t *, const scalar_t *, const ptrdiff_t) = &THVector_(cdiv_DEFAULT); -static FunctionDescription THVector_(cdiv_DISPATCHTABLE)[] = { - #if defined(__NEON__) - #if defined(TH_REAL_IS_FLOAT) - FUNCTION_IMPL(THVector_(cdiv_NEON), SIMDExtension_NEON), - #endif - #endif - - #if defined(USE_AVX) - #if defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT) - FUNCTION_IMPL(THVector_(cdiv_AVX), SIMDExtension_AVX), - #endif - #endif - - FUNCTION_IMPL(THVector_(cdiv_DEFAULT), SIMDExtension_DEFAULT) -}; -void THVector_(cdiv)(scalar_t *z, const scalar_t *x, const scalar_t *y, const ptrdiff_t n) { - THVector_(cdiv_DISPATCHPTR)(z, x, y, n); -} - -static void (*THVector_(divs_DISPATCHPTR))(scalar_t *, const scalar_t *, const scalar_t, const ptrdiff_t) = &THVector_(divs_DEFAULT); -static FunctionDescription THVector_(divs_DISPATCHTABLE)[] = { - #if defined(__NEON__) - #if defined(TH_REAL_IS_FLOAT) - FUNCTION_IMPL(THVector_(divs_NEON), SIMDExtension_NEON), - #endif - #endif - - #if defined(USE_AVX) - #if defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT) - FUNCTION_IMPL(THVector_(divs_AVX), SIMDExtension_AVX), - #endif - #endif - - FUNCTION_IMPL(THVector_(divs_DEFAULT), SIMDExtension_DEFAULT) -}; -void THVector_(divs)(scalar_t *y, const scalar_t *x, const scalar_t c, const ptrdiff_t n) { - THVector_(divs_DISPATCHPTR)(y, x, c, n); -} - /* * This struct's constructor initializes the dispatch tables. 
It simply checks * what SIMD extensions are available, and then walks the dispatch table @@ -191,12 +77,7 @@ struct THVector_(startup) { THVector_(startup)() { uint32_t hostSimdExts = detectHostSIMDExtensions(); INIT_DISPATCH_PTR(fill); - INIT_DISPATCH_PTR(cadd); - INIT_DISPATCH_PTR(adds); - INIT_DISPATCH_PTR(cmul); INIT_DISPATCH_PTR(muls); - INIT_DISPATCH_PTR(cdiv); - INIT_DISPATCH_PTR(divs); } }; diff --git a/aten/src/TH/vector/AVX.cpp b/aten/src/TH/vector/AVX.cpp index 11d8d0b1dcd4..bd4538b373bc 100644 --- a/aten/src/TH/vector/AVX.cpp +++ b/aten/src/TH/vector/AVX.cpp @@ -24,59 +24,6 @@ void THDoubleVector_fill_AVX(double *x, const double c, const ptrdiff_t n) { } } -void THDoubleVector_cdiv_AVX(double *z, const double *x, const double *y, const ptrdiff_t n) __ubsan_ignore_float_divide_by_zero__ { - ptrdiff_t i; - __m256d YMM0, YMM1, YMM2, YMM3; - for (i=0; i<=((n)-8); i+=8) { - YMM0 = _mm256_loadu_pd(x+i); - YMM1 = _mm256_loadu_pd(x+i+4); - YMM2 = _mm256_loadu_pd(y+i); - YMM3 = _mm256_loadu_pd(y+i+4); - YMM2 = _mm256_div_pd(YMM0, YMM2); - YMM3 = _mm256_div_pd(YMM1, YMM3); - _mm256_storeu_pd(z+i, YMM2); - _mm256_storeu_pd(z+i+4, YMM3); - } - for (; i<(n); i++) { - z[i] = x[i] / y[i]; - } -} - -void THDoubleVector_divs_AVX(double *y, const double *x, const double c, const ptrdiff_t n) __ubsan_ignore_float_divide_by_zero__ { - ptrdiff_t i; - __m256d YMM15 = _mm256_set_pd(c, c, c, c); - __m256d YMM0, YMM1; - for (i=0; i<=((n)-8); i+=8) { - YMM0 = _mm256_loadu_pd(x+i); - YMM1 = _mm256_loadu_pd(x+i+4); - YMM0 = _mm256_div_pd(YMM0, YMM15); - YMM1 = _mm256_div_pd(YMM1, YMM15); - _mm256_storeu_pd(y+i, YMM0); - _mm256_storeu_pd(y+i+4, YMM1); - } - for (; i<(n); i++) { - y[i] = x[i] / c; - } -} - -void THDoubleVector_cmul_AVX(double *z, const double *x, const double *y, const ptrdiff_t n) { - ptrdiff_t i; - __m256d YMM0, YMM1, YMM2, YMM3; - for (i=0; i<=((n)-8); i+=8) { - YMM0 = _mm256_loadu_pd(x+i); - YMM1 = _mm256_loadu_pd(x+i+4); - YMM2 = _mm256_loadu_pd(y+i); - YMM3 = _mm256_loadu_pd(y+i+4); - YMM2 = _mm256_mul_pd(YMM0, YMM2); - YMM3 = _mm256_mul_pd(YMM1, YMM3); - _mm256_storeu_pd(z+i, YMM2); - _mm256_storeu_pd(z+i+4, YMM3); - } - for (; i ``` diff --git a/aten/src/THNN/CMakeLists.txt b/aten/src/THNN/CMakeLists.txt deleted file mode 100644 index ab4bb755071c..000000000000 --- a/aten/src/THNN/CMakeLists.txt +++ /dev/null @@ -1,4 +0,0 @@ -set(ATen_CPU_SRCS ${ATen_CPU_SRCS} - ${CMAKE_CURRENT_SOURCE_DIR}/init.cpp -PARENT_SCOPE) -INSTALL(FILES generic/THNN.h DESTINATION "${ATEN_INSTALL_INCLUDE_SUBDIR}/THNN/generic") diff --git a/aten/src/THNN/THNN.h b/aten/src/THNN/THNN.h deleted file mode 100644 index fbf8a03c4596..000000000000 --- a/aten/src/THNN/THNN.h +++ /dev/null @@ -1,24 +0,0 @@ -#ifndef THNN_H -#define THNN_H - -#include -#include - -#define THNN_(NAME) TH_CONCAT_3(THNN_, Real, NAME) - -#define THIndexTensor THLongTensor -#define THIndexTensor_(NAME) THLongTensor_ ## NAME - -typedef int64_t THIndex_t; -typedef void THNNState; - -#include -#include - -#include -#include - -#include -#include - -#endif diff --git a/aten/src/THNN/generic/GatedLinearUnit.c b/aten/src/THNN/generic/GatedLinearUnit.c deleted file mode 100644 index c0634028646f..000000000000 --- a/aten/src/THNN/generic/GatedLinearUnit.c +++ /dev/null @@ -1,57 +0,0 @@ -#ifndef TH_GENERIC_FILE -#define TH_GENERIC_FILE "THNN/generic/GatedLinearUnit.c" -#else - -#include - -void THNN_(GatedLinear_updateOutput)( - THNNState *state, - THTensor *input, - THTensor *output, - int dim) -{ - TORCH_INTERNAL_ASSERT(false, 
"GatedLinear_updateOutput called, but this is just " \ - "a stub for nn.yaml parsing"); -} - -void THNN_(GatedLinear_updateGradInput)( - THNNState *state, - THTensor *input, - THTensor *gradOutput, - THTensor *gradInput, - int dim) -{ - dim = at::maybe_wrap_dim(dim, input); - // set up tensors - const int64_t nIn = THTensor_(size)(input, dim); - THArgCheck(nIn % 2 == 0, 2, "Halving dimension must be even. Dim %d is size %ld", - dim, nIn); - - THTensor_(resizeAs)(gradInput, input); - const int64_t inputSize = THTensor_(size)(input, dim) / 2; - THTensor *firstHalf = THTensor_(newNarrow)(input, dim, 0, inputSize); - THTensor *secondHalf = THTensor_(newNarrow)(input, dim, inputSize, inputSize); - THTensor *gradInputfirstHalf = THTensor_(newNarrow)(gradInput, dim, 0, inputSize); - THTensor *gradInputsecondHalf = THTensor_(newNarrow)(gradInput, dim, inputSize, inputSize); - - at::Tensor gradInputfirstHalf_wrap = THTensor_wrap(gradInputfirstHalf); - at::Tensor secondHalf_wrap = THTensor_wrap(secondHalf); - at::native::sigmoid_out(gradInputfirstHalf_wrap, secondHalf_wrap); - - TH_TENSOR_APPLY2(scalar_t, gradInputsecondHalf, scalar_t, gradInputfirstHalf, - scalar_t z = *gradInputfirstHalf_data; - *gradInputsecondHalf_data = (1. - z) * z; - ); - - THTensor_(cmul)(gradInputfirstHalf, gradInputfirstHalf, gradOutput); - - THTensor_(cmul)(gradInputsecondHalf, gradInputsecondHalf, gradOutput); - THTensor_(cmul)(gradInputsecondHalf, gradInputsecondHalf, firstHalf); - - c10::raw::intrusive_ptr::decref(firstHalf); - c10::raw::intrusive_ptr::decref(secondHalf); - c10::raw::intrusive_ptr::decref(gradInputfirstHalf); - c10::raw::intrusive_ptr::decref(gradInputsecondHalf); -} - -#endif diff --git a/aten/src/THNN/generic/THNN.h b/aten/src/THNN/generic/THNN.h deleted file mode 100644 index 1acb792502fe..000000000000 --- a/aten/src/THNN/generic/THNN.h +++ /dev/null @@ -1,24 +0,0 @@ -#ifndef TH_GENERIC_FILE -#define TH_GENERIC_FILE "THNN/generic/THNN.h" -#else - -#include -#include -#include - -#if !defined(TH_REAL_IS_LONG) - -TH_API void THNN_(GatedLinear_updateOutput)( - THNNState *state, // library's state - THTensor *input, // input tensor - THTensor *output, // [OUT] output tensor, half size of input along dimension dim - int dim); // dimension for halving operation -TH_API void THNN_(GatedLinear_updateGradInput)( - THNNState *state, // library's state - THTensor *input, // input tensor - THTensor *gradOutput, // gradient w.r.t module's output - THTensor *gradInput, // [OUT] gradient w.r.t input - int dim); // dimension for halving operation - -#endif -#endif diff --git a/aten/src/THNN/init.cpp b/aten/src/THNN/init.cpp deleted file mode 100644 index 7aaabfb960b7..000000000000 --- a/aten/src/THNN/init.cpp +++ /dev/null @@ -1,65 +0,0 @@ -#include -#include - -#include -#include - -#define torch_(NAME) TH_CONCAT_3(torch_, Real, NAME) -#define nn_(NAME) TH_CONCAT_3(nn_, Real, NAME) - -#define THNN_CHECK_SHAPE(I1, I2) \ - if (I1 != NULL && I2 != NULL && !THTensor_(isSameSizeAs)(I1, I2)) \ - { \ - THDescBuff s1 = THTensor_(sizeDesc)(I1); \ - THDescBuff s2 = THTensor_(sizeDesc)(I2); \ - THError(#I1 " and " #I2 " shapes do not match: " \ - #I1 " %s, " #I2 " %s", s1.str, s2.str); \ - } - -#define THNN_CHECK_SHAPE_INDICES(I1, I2) \ - if (I1 != NULL && I2 != NULL && !I1->sizes().equals(I2->sizes())) \ - { \ - THDescBuff s1 = THTensor_(sizeDesc)(I1); \ - THDescBuff s2 = THLongTensor_sizeDesc(I2); \ - THError(#I1 " and " #I2 " shapes do not match: " \ - #I1 " %s, " #I2 " %s", s1.str, s2.str); \ - } - -#define 
THNN_CHECK_NELEMENT(I1, I2) \ - if (I1 != NULL && I2 != NULL ) { \ - ptrdiff_t n1 = THTensor_(nElement)(I1); \ - ptrdiff_t n2 = THTensor_(nElement)(I2); \ - if (n1 != n2) \ - { \ - THDescBuff s1 = THTensor_(sizeDesc)(I1); \ - THDescBuff s2 = THTensor_(sizeDesc)(I2); \ - THError(#I1 " and " #I2 " have different number of elements: " \ - #I1 "%s has %ld elements, while " \ - #I2 "%s has %ld elements", s1.str, n1, s2.str, n2); \ - } \ - } - -#define THNN_CHECK_DIM_SIZE(T, DIM, DIM_SIZE, SIZE) \ - if (THTensor_(nDimensionLegacyNoScalars)(T) != DIM || \ - THTensor_sizeLegacyNoScalars(T, DIM_SIZE) != SIZE) { \ - THDescBuff s1 = THTensor_(sizeDesc)(T); \ - THError("Need " #T " of dimension %d and " #T ".size[%d] == %d" \ - " but got " #T " to be of shape: %s", DIM, DIM_SIZE, SIZE, s1.str); \ - } - -#define THNN_CHECK_DIM_SIZE_INDICES(T, DIM, DIM_SIZE, SIZE) \ - if (THIndexTensor_(nDimensionLegacyNoScalars)(T) != DIM || \ - THTensor_sizeLegacyNoScalars(T, DIM_SIZE) != SIZE) { \ - THDescBuff s1 = THIndexTensor_(sizeDesc)(T); \ - THError("Need " #T " of dimension %d and " #T ".size[%d] == %d" \ - " but got " #T " to be of shape: %s", DIM, DIM_SIZE, SIZE, s1.str); \ - } - -#define THNN_ARGCHECK(COND, ARG, T, FORMAT) \ - if (!(COND)) { \ - THDescBuff s1 = THTensor_(sizeDesc)(T); \ - THArgCheck(COND, ARG, FORMAT, s1.str); \ - } - -#include -#include diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index c98cdc4010cf..ad2d88c2ee48 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -295,7 +295,6 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) $<$:--selected-op-list-path="${SELECTED_OP_LIST}"> DEPENDS "${CMAKE_BINARY_DIR}/aten/src/ATen/Declarations.yaml" - "${CMAKE_CURRENT_LIST_DIR}/../aten/src/THNN/generic/THNN.h" "${TOOLS_PATH}/autograd/templates/VariableType.h" "${TOOLS_PATH}/autograd/templates/VariableType.cpp" "${TOOLS_PATH}/autograd/templates/Functions.h" diff --git a/cmake/Codegen.cmake b/cmake/Codegen.cmake index f292a97fec2e..671a029152d7 100644 --- a/cmake/Codegen.cmake +++ b/cmake/Codegen.cmake @@ -140,7 +140,6 @@ if (INTERN_BUILD_ATEN_OPS) set(cwrap_files ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/Declarations.cwrap - ${CMAKE_CURRENT_LIST_DIR}/../aten/src/THNN/generic/THNN.h ${CMAKE_CURRENT_LIST_DIR}/../aten/src/THCUNN/generic/THCUNN.h ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/nn.yaml ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/native_functions.yaml) diff --git a/docs/cpp/source/check-doxygen.sh b/docs/cpp/source/check-doxygen.sh index 5d7b6a893478..8ac1361da2cf 100755 --- a/docs/cpp/source/check-doxygen.sh +++ b/docs/cpp/source/check-doxygen.sh @@ -19,7 +19,6 @@ python aten/src/ATen/gen.py \ -s aten/src/ATen \ -d build/aten/src/ATen \ aten/src/ATen/Declarations.cwrap \ - aten/src/THNN/generic/THNN.h \ aten/src/THCUNN/generic/THCUNN.h \ aten/src/ATen/nn.yaml \ aten/src/ATen/native/native_functions.yaml diff --git a/setup.py b/setup.py index c3ecb43a6494..12ff07e0a856 100644 --- a/setup.py +++ b/setup.py @@ -852,8 +852,6 @@ if __name__ == '__main__': 'include/THH/*.cuh', 'include/THH/*.h*', 'include/THH/generic/*.h', - 'include/THNN/*.h', - 'include/THNN/generic/*.h', 'share/cmake/ATen/*.cmake', 'share/cmake/Caffe2/*.cmake', 'share/cmake/Caffe2/public/*.cmake',