Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
Move glu to Aten(CPU) (#33179)
Summary:
This PR moves glu to ATen (CPU).

Test script:

```python
import torch
import torch.nn.functional as F
import time

torch.manual_seed(0)

def _time():
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()

device = "cpu"

# warm up
for n in [10, 100, 1000, 10000]:
    input = torch.randn(128, n, requires_grad=True, device=device)
    grad_output = torch.ones(128, n // 2, device=device)
    for i in range(1000):
        output = F.glu(input)
        output.backward(grad_output)

for n in [10, 100, 1000, 10000]:
    fwd_t = 0
    bwd_t = 0
    input = torch.randn(128, n, requires_grad=True, device=device)
    grad_output = torch.ones(128, n // 2, device=device)
    for i in range(10000):
        t1 = _time()
        output = F.glu(input)
        t2 = _time()
        output.backward(grad_output)
        t3 = _time()
        fwd_t = fwd_t + (t2 - t1)
        bwd_t = bwd_t + (t3 - t2)
    fwd_avg = fwd_t / 10000 * 1000
    bwd_avg = bwd_t / 10000 * 1000
    print("input size(128, %d) forward time is %.2f (ms); backward avg time is %.2f (ms)."
          % (n, fwd_avg, bwd_avg))
```

Test device: **skx-8180**.

Before:

```
input size(128, 10) forward time is 0.04 (ms); backward avg time is 0.08 (ms).
input size(128, 100) forward time is 0.06 (ms); backward avg time is 0.14 (ms).
input size(128, 1000) forward time is 0.11 (ms); backward avg time is 0.31 (ms).
input size(128, 10000) forward time is 1.52 (ms); backward avg time is 2.04 (ms).
```

After:

```
input size(128, 10) forward time is 0.02 (ms); backward avg time is 0.05 (ms).
input size(128, 100) forward time is 0.04 (ms); backward avg time is 0.09 (ms).
input size(128, 1000) forward time is 0.07 (ms); backward avg time is 0.17 (ms).
input size(128, 10000) forward time is 0.13 (ms); backward avg time is 1.03 (ms).
```

Fixes https://github.com/pytorch/pytorch/issues/24707 and https://github.com/pytorch/pytorch/issues/24708.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/33179

Differential Revision: D19839835

Pulled By: VitalyFedyunin

fbshipit-source-id: e4d3438556a1068da2c4a7e573d6bbf8d2a6e2b9
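For context on what is being benchmarked: `F.glu(input, dim)` halves `input` along `dim`, which is why the script above sizes `grad_output` as `(128, n // 2)`. Writing the two halves as x1 and x2, the op computes the standard gated linear unit (restated here for reference, not part of the original message):

```latex
% Gated Linear Unit: split the input in half along dim into x1 and x2,
% so the output has half the input size along that dimension.
\[
  \mathrm{GLU}(x) = x_1 \odot \sigma(x_2),
  \qquad
  \sigma(t) = \frac{1}{1 + e^{-t}}
\]
```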
Committed by: Facebook Github Bot
Parent: 3c5677a676
Commit: b678256bfb
@@ -57,7 +57,6 @@ time python aten/src/ATen/gen.py \
-s aten/src/ATen \
-d build/aten/src/ATen \
aten/src/ATen/Declarations.cwrap \
aten/src/THNN/generic/THNN.h \
aten/src/THCUNN/generic/THCUNN.h \
aten/src/ATen/nn.yaml \
aten/src/ATen/native/native_functions.yaml
1 .github/workflows/lint.yml (vendored)
@@ -178,7 +178,6 @@ jobs:
-s aten/src/ATen \
-d build/aten/src/ATen \
aten/src/ATen/Declarations.cwrap \
aten/src/THNN/generic/THNN.h \
aten/src/THCUNN/generic/THCUNN.h \
aten/src/ATen/nn.yaml \
aten/src/ATen/native/native_functions.yaml
5 .gitignore (vendored)
@@ -57,11 +57,6 @@ torch/csrc/jit/generated/*
torch/csrc/jit/fuser/config.h
torch/csrc/nn/THCUNN.cpp
torch/csrc/nn/THCUNN.cwrap
torch/csrc/nn/THNN_generic.cpp
torch/csrc/nn/THNN_generic.cwrap
torch/csrc/nn/THNN_generic.h
torch/csrc/nn/THNN.cpp
torch/csrc/nn/THNN.cwrap
torch/bin/
torch/cmake/
torch/lib/*.a*
@@ -129,7 +129,6 @@ and `python setup.py clean`. Then you can install in `develop` mode again.
* [src](aten/src)
* [TH](aten/src/TH)
[THC](aten/src/THC)
[THNN](aten/src/THNN)
[THCUNN](aten/src/THCUNN) - Legacy library code from the original
Torch. Try not to add things here; we're slowly porting these to
[native](aten/src/ATen/native).
@@ -49,7 +49,6 @@ set(TH_CPU_INCLUDE
${CMAKE_BINARY_DIR}/aten/src)
list(APPEND ATen_CPU_INCLUDE ${TH_CPU_INCLUDE})

add_subdirectory(src/THNN)

# Find the HIP package, set the HIP paths, load the HIP CMake.
IF(USE_ROCM)
@@ -315,8 +315,6 @@ def generate_storage_type_and_tensor(backend, density, declarations, per_op_regi
env['th_headers'] = [
'#include <TH/TH.h>',
'#include <TH/THTensor.hpp>',
'#include <THNN/THNN.h>',
'#undef THNN_',
]
env['extra_cuda_headers'] = []
env['state'] = []
@@ -38,6 +38,8 @@ DECLARE_DISPATCH(shrink_fn, softshrink_stub);
DECLARE_DISPATCH(shrink_backward_fn, shrink_backward_stub);
DECLARE_DISPATCH(leaky_relu_fn, leaky_relu_stub);
DECLARE_DISPATCH(leaky_relu_backward_fn, leaky_relu_backward_stub);
DECLARE_DISPATCH(activation_fn, glu_stub);
DECLARE_DISPATCH(activation_backward_fn, glu_backward_stub);

} // namespace native
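The two `DECLARE_DISPATCH` lines above are the dispatch hook-up for the new kernels: the stubs are defined once in Activation.cpp (`DEFINE_DISPATCH`, next hunk), and the vectorized CPU kernels register themselves against them with `REGISTER_DISPATCH` later in this diff. The sketch below is a deliberately simplified, self-contained model of that pattern, not ATen's actual `DispatchStub` machinery:

```cpp
// Toy model of the dispatch-stub idea (an assumption-level sketch, not ATen's
// real implementation): a stub holds one kernel pointer per device type,
// kernels register themselves, and callers invoke the stub with the device
// they are running on, as in glu_stub(iter.device_type(), iter).
#include <cstdio>

enum class DeviceType { CPU, CUDA };
using activation_fn = void (*)(double*, long);   // stand-in for void(TensorIterator&)

struct DispatchStub {
  activation_fn cpu_kernel = nullptr;
  void operator()(DeviceType device, double* data, long n) const {
    if (device == DeviceType::CPU && cpu_kernel) cpu_kernel(data, n);
  }
};

DispatchStub glu_stub;                           // plays the role of DEFINE_DISPATCH(glu_stub)

static void glu_cpu_kernel(double* data, long n) {   // the registered CPU kernel
  for (long i = 0; i < n; ++i) std::printf("%g\n", data[i]);
}
static const bool kRegistered = (glu_stub.cpu_kernel = &glu_cpu_kernel, true);  // REGISTER_DISPATCH

int main() {
  double values[2] = {1.0, 2.0};
  glu_stub(DeviceType::CPU, values, 2);          // operator code calls through the stub
}
```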
@@ -1,31 +1,34 @@
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/Activation.h>
#include <ATen/native/TensorIterator.h>

namespace at {
namespace native {

DEFINE_DISPATCH(glu_stub);
DEFINE_DISPATCH(glu_backward_stub);

Tensor& glu_out(Tensor &result, const Tensor& self, int64_t dim) {
// this can't pass anyway because a 0-dimensional tensor has "size" 1, which
// can't be evenly halved, but give a nicer error message here.
TORCH_CHECK(self.dim() > 0, "glu does not support 0-dimensional tensors");
dim = maybe_wrap_dim(dim, self.dim());
const int64_t nIn = self.size(dim);
auto wrap_dim = maybe_wrap_dim(dim, self.dim());
const int64_t nIn = self.size(wrap_dim);
TORCH_CHECK(nIn % 2 == 0, "Halving dimension must be even, but dimension ",
dim, " is size ", nIn);

wrap_dim, " is size ", nIn);
// size output to half of input
const int64_t selfSize = nIn / 2;
auto newSizes = self.sizes().vec();
newSizes[dim] = selfSize;
newSizes[wrap_dim] = selfSize;
result.resize_(newSizes);
// half tensor
Tensor firstHalf = self.narrow(wrap_dim, 0, selfSize);
Tensor secondHalf = self.narrow(wrap_dim, selfSize, selfSize);

// halve tensor
Tensor firstHalf = self.narrow(dim, 0, selfSize);
Tensor secondHalf = self.narrow(dim, selfSize, selfSize);

// x = x1:cmul( sigmoid(x2) )
at::sigmoid_out(result, secondHalf);
return result.mul_(firstHalf);
auto iter = TensorIterator::binary_op(result, firstHalf, secondHalf);
glu_stub(iter.device_type(), iter);
return result;
}

Tensor glu(const Tensor& self, int64_t dim) {
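What the rewritten `glu_out` computes has not changed: the output is still the first half of the input times the sigmoid of the second half; only the `sigmoid_out` plus `mul_` pair is replaced by one fused `TensorIterator` binary kernel. A minimal consistency check, assuming a libtorch/ATen build (this snippet is illustrative and not part of the PR):

```cpp
// Illustrative check: at::glu should match narrow + sigmoid + mul done by hand.
#include <ATen/ATen.h>
#include <cassert>

int main() {
  at::Tensor x = at::randn({128, 10});
  const int64_t dim = 1;
  const int64_t half = x.size(dim) / 2;            // halving dimension must be even
  at::Tensor first = x.narrow(dim, 0, half);       // x1
  at::Tensor second = x.narrow(dim, half, half);   // x2
  at::Tensor expected = first * at::sigmoid(second);
  assert(at::allclose(at::glu(x, dim), expected));
  return 0;
}
```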
@@ -33,5 +36,40 @@ Tensor glu(const Tensor& self, int64_t dim) {
return at::glu_out(result, self, dim);
}

Tensor& glu_backward_out(Tensor& grad_input,
const Tensor& grad_output, const Tensor& input, int64_t dim) {
TORCH_CHECK(input.dim() > 0, "glu does not support 0-dimensional tensors");
auto wrap_dim = maybe_wrap_dim(dim, input.dim());
const int64_t nIn = input.size(wrap_dim);
TORCH_CHECK(nIn % 2 == 0, "Halving dimension must be even, but dimension ",
wrap_dim, " is size ", nIn);

grad_input.resize_as_(input);
const int64_t inputSize = nIn / 2;
// half tensor
Tensor firstHalf = input.narrow(wrap_dim, 0, inputSize);
Tensor secondHalf = input.narrow(wrap_dim, inputSize, inputSize);
Tensor gradInputfirstHalf = grad_input.narrow(wrap_dim, 0, inputSize);
Tensor gradInputsecondHalf = grad_input.narrow(wrap_dim, inputSize, inputSize);

at::sigmoid_out(gradInputfirstHalf, secondHalf);
// for second gradinput half, can get a better performance by fusion
auto iter = at::TensorIterator();
iter.set_check_mem_overlap(true);
iter.add_output(gradInputsecondHalf);
iter.add_input(gradInputfirstHalf);
iter.add_input(firstHalf);
iter.add_input(grad_output);
iter.build();
glu_backward_stub(iter.device_type(), iter);
gradInputfirstHalf.mul_(grad_output);
return grad_input;
}

Tensor glu_backward(const Tensor& grad_output, const Tensor& input, int64_t dim) {
auto grad_input = at::empty({0}, input.options());
return at::glu_backward_out(grad_input, grad_output, input, dim);
}

} // at::native
} // at
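The backward path above needs only one `sigmoid_out`, one fused elementwise pass, and one in-place multiply because of the shape of the GLU derivative. With y = x1 * sigma(x2), upstream gradient g, and sigma the logistic sigmoid (derivation added for clarity; the notation is mine, not the commit's):

```latex
\[
  \frac{\partial y}{\partial x_1} = \sigma(x_2),
  \qquad
  \frac{\partial y}{\partial x_2} = x_1\,\sigma(x_2)\bigl(1-\sigma(x_2)\bigr)
\]
\[
  \nabla_{x_1} = g\,\sigma(x_2),
  \qquad
  \nabla_{x_2} = \bigl(1-\sigma(x_2)\bigr)\,\sigma(x_2)\,x_1\,g
\]
```

In the code, `gradInputfirstHalf` temporarily holds sigma(x2) after `sigmoid_out`; the TensorIterator pass consumes it as input `a` (with `b` = firstHalf and `c` = grad_output) and writes (1 - a) * a * b * c into `gradInputsecondHalf`; the final `mul_(grad_output)` then turns the stored sigma(x2) into the x1 gradient.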
@@ -448,6 +448,40 @@ void softplus_backward_kernel(TensorIterator& iter, Scalar beta_, Scalar thresho
});
}

void glu_kernel(TensorIterator& iter) {
AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "glu_cpu", [&] {
using Vec = Vec256<scalar_t>;
const scalar_t one_val(1);
const Vec one_vec(one_val);
cpu_kernel_vec(
iter,
[one_val](scalar_t a, scalar_t b) -> scalar_t {
return a * (one_val / (one_val + std::exp(-b)));
},
[one_vec](Vec a, Vec b) -> Vec {
return a * (one_vec / (one_vec + b.neg().exp()));
}
);
});
}

void glu_backward_kernel(TensorIterator& iter) {
AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "glu_backward_cpu", [&] {
using Vec = Vec256<scalar_t>;
const scalar_t one_val(1);
const Vec one_vec(one_val);
cpu_kernel_vec(
iter,
[one_val](scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
return (one_val - a) * a * b * c;
},
[one_vec](Vec a, Vec b, Vec c) -> Vec {
return (one_vec - a) * a * b * c;
}
);
});
}

} // namespace

REGISTER_DISPATCH(log_sigmoid_cpu_stub, &log_sigmoid_cpu_kernel);
@@ -465,6 +499,8 @@ REGISTER_DISPATCH(leaky_relu_stub, &leaky_relu_kernel);
REGISTER_DISPATCH(leaky_relu_backward_stub, &leaky_relu_backward_kernel);
REGISTER_DISPATCH(softplus_stub, &softplus_kernel);
REGISTER_DISPATCH(softplus_backward_stub, &softplus_backward_kernel);
REGISTER_DISPATCH(glu_stub, &glu_kernel);
REGISTER_DISPATCH(glu_backward_stub, &glu_backward_kernel);

} // namespace native
} // namespace at
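`cpu_kernel_vec` pairs a scalar lambda with a `Vec256` lambda that must compute the same expression; the vector path simply rewrites `1 / (1 + exp(-b))` using `b.neg().exp()`. A dependency-free sanity check of the scalar formulas used above against a numerical derivative (illustrative only, not part of the PR):

```cpp
// Standalone check of the scalar math in glu_kernel / glu_backward_kernel:
//   forward          f(a, b) = a * sigmoid(b)
//   backward wrt b   (1 - s) * s * a * grad, with s = sigmoid(b)
// Compares the analytic gradient against a central finite difference.
#include <cmath>
#include <cstdio>

static double sigmoid(double t) { return 1.0 / (1.0 + std::exp(-t)); }
static double glu_fwd(double a, double b) { return a * sigmoid(b); }

int main() {
  const double a = 0.7, b = -1.3, grad = 1.0, eps = 1e-6;
  const double s = sigmoid(b);
  const double analytic_db = (1.0 - s) * s * a * grad;   // glu_backward_kernel formula
  const double numeric_db = (glu_fwd(a, b + eps) - glu_fwd(a, b - eps)) / (2 * eps);
  std::printf("analytic %.8f numeric %.8f\n", analytic_db, numeric_db);
  return std::fabs(analytic_db - numeric_db) < 1e-6 ? 0 : 1;
}
```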
@@ -5637,14 +5637,14 @@
- func: glu_backward.grad_input(Tensor grad_output, Tensor self, int dim, *, Tensor(a!) grad_input) -> Tensor(a!)
python_module: nn
dispatch:
CPU: legacy::cpu::_thnn_glu_backward_out
CPU: glu_backward_out
CUDA: legacy::cuda::_thnn_glu_backward_out

- func: glu_backward(Tensor grad_output, Tensor self, int dim) -> Tensor
use_c10_dispatcher: full
python_module: nn
dispatch:
CPU: legacy::cpu::_thnn_glu_backward
CPU: glu_backward
CUDA: legacy::cuda::_thnn_glu_backward

- func: hardtanh.out(Tensor self, Scalar min_val=-1, Scalar max_val=1, *, Tensor(a!) out) -> Tensor(a!)
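The only edit in these `native_functions.yaml` entries is the CPU dispatch target: `legacy::cpu::_thnn_glu_backward(_out)` is swapped for the new native `glu_backward(_out)`, while CUDA keeps routing to the legacy THCUNN kernel. The public operator is unchanged, so a CPU-side check like the following should still hold (illustrative, assumes a libtorch/ATen build):

```cpp
// at::glu_backward should agree with the closed-form gradient derived earlier,
// regardless of which backend the dispatch table routes it to.
#include <ATen/ATen.h>
#include <cassert>

int main() {
  const int64_t dim = 1;
  at::Tensor x = at::randn({4, 6});
  at::Tensor grad_out = at::randn({4, 3});       // output has half the columns of x
  at::Tensor x1 = x.narrow(dim, 0, 3);
  at::Tensor x2 = x.narrow(dim, 3, 3);
  at::Tensor s = at::sigmoid(x2);
  at::Tensor expected = at::cat({grad_out * s, grad_out * x1 * (s - s * s)}, dim);
  assert(at::allclose(at::glu_backward(grad_out, x, dim), expected));
  return 0;
}
```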
@@ -8,7 +8,7 @@ multiple variants of the library, summarized here:
* THC = TorcH Cuda
* THCS = TorcH Cuda Sparse (now defunct)
* THCUNN = TorcH CUda Neural Network (see cunn)
* THNN = TorcH Neural Network
* THNN = TorcH Neural Network (now defunct)
* THS = TorcH Sparse (now defunct)

(You'll also see these abbreviations show up in symbol names.)
@@ -196,18 +196,6 @@ void THTensor_(mul)(THTensor *r_, THTensor *t, scalar_t value)
}
}

void THTensor_(div)(THTensor *r_, THTensor *t, scalar_t value)
{
THTensor_(resizeAs)(r_, t);
int64_t r_Size = THTensor_(nElement)(r_);
int r_Contig = THTensor_(isContiguous)(r_);
int tContig = THTensor_(isContiguous)(t);
if (r_Contig && tContig) {
TH_TENSOR_APPLY2_CONTIG(scalar_t, r_, scalar_t, t, THVector_(divs)(r__data, t_data, value, r__len););
} else {
TH_TENSOR_APPLY2_PARALLEL(r_Size, r_Contig, tContig, scalar_t, r_, scalar_t, t, *r__data = *t_data / value;, ORDIN_TH_OMP_OVERHEAD_THRESHOLD)
}
}
#endif

#if !defined(TH_REAL_IS_BFLOAT16) /* non bfloat16 part*/
@ -336,72 +336,6 @@ static inline bool modulo_wrap(scalar_t a, scalar_t b) {
|
||||
return (a != 0) && (a < 0) != (b < 0);
|
||||
}
|
||||
|
||||
void THTensor_(cadd)(THTensor *r_, THTensor *t, scalar_t value, THTensor *src)
|
||||
{
|
||||
THTensor_(resizeAs)(r_, t);
|
||||
int64_t r_Size = THTensor_(nElement)(r_);
|
||||
int64_t srcSize = THTensor_(nElement)(src);
|
||||
int r_Contig = THTensor_(isContiguous)(r_);
|
||||
int tContig = THTensor_(isContiguous)(t);
|
||||
int srcContig = THTensor_(isContiguous)(src);
|
||||
if (srcSize == r_Size) {
|
||||
if (r_Contig && tContig && srcContig) {
|
||||
if(r_ == t) {
|
||||
THBlas_(axpy)(THTensor_(nElement)(t), value, src->data<scalar_t>(), 1, r_->data<scalar_t>(), 1);
|
||||
} else {
|
||||
TH_TENSOR_APPLY3_CONTIG(scalar_t, r_, scalar_t, t, scalar_t, src, THVector_(cadd)(r__data, t_data, src_data, value, r__len););
|
||||
}
|
||||
} else {
|
||||
TH_TENSOR_APPLY3_PARALLEL(r_Size, r_Contig, tContig, srcContig, scalar_t, r_, scalar_t, t, scalar_t, src, *r__data = *t_data + value * *src_data;, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD);
|
||||
}
|
||||
} else {
|
||||
TH_TENSOR_APPLY3(scalar_t, r_, scalar_t, t, scalar_t, src, *r__data = *t_data + value * *src_data;);
|
||||
}
|
||||
}
|
||||
|
||||
void THTensor_(csub)(THTensor *r_, THTensor *t, scalar_t value, THTensor *src)
|
||||
{
|
||||
THTensor_(cadd)(r_, t, -value, src);
|
||||
}
|
||||
|
||||
void THTensor_(cmul)(THTensor *r_, THTensor *t, THTensor *src)
|
||||
{
|
||||
THTensor_(resizeAs)(r_, t);
|
||||
int64_t r_Size = THTensor_(nElement)(r_);
|
||||
int64_t srcSize = THTensor_(nElement)(src);
|
||||
int r_Contig = THTensor_(isContiguous)(r_);
|
||||
int tContig = THTensor_(isContiguous)(t);
|
||||
int srcContig = THTensor_(isContiguous)(src);
|
||||
if (srcSize == r_Size){
|
||||
if (r_Contig && tContig && srcContig) {
|
||||
TH_TENSOR_APPLY3_CONTIG(scalar_t, r_, scalar_t, t, scalar_t, src, THVector_(cmul)(r__data, t_data, src_data, r__len););
|
||||
} else {
|
||||
TH_TENSOR_APPLY3_PARALLEL(r_Size, r_Contig, tContig, srcContig, scalar_t, r_, scalar_t, t, scalar_t, src, *r__data = *t_data * *src_data;, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD);
|
||||
}
|
||||
} else {
|
||||
TH_TENSOR_APPLY3(scalar_t, r_, scalar_t, t, scalar_t, src, *r__data = *t_data * *src_data;);
|
||||
}
|
||||
}
|
||||
|
||||
void THTensor_(cdiv)(THTensor *r_, THTensor *t, THTensor *src)
|
||||
{
|
||||
THTensor_(resizeAs)(r_, t);
|
||||
int64_t r_Size = THTensor_(nElement)(r_);
|
||||
int64_t srcSize = THTensor_(nElement)(src);
|
||||
int r_Contig = THTensor_(isContiguous)(r_);
|
||||
int tContig = THTensor_(isContiguous)(t);
|
||||
int srcContig = THTensor_(isContiguous)(src);
|
||||
if (srcSize == r_Size){
|
||||
if (r_Contig && tContig && srcContig) {
|
||||
TH_TENSOR_APPLY3_CONTIG(scalar_t, r_, scalar_t, t, scalar_t, src, THVector_(cdiv)(r__data, t_data, src_data, r__len););
|
||||
} else {
|
||||
TH_TENSOR_APPLY3_PARALLEL(r_Size, r_Contig, tContig, srcContig, scalar_t, r_, scalar_t, t, scalar_t, src, *r__data = *t_data / *src_data;, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD);
|
||||
}
|
||||
} else {
|
||||
TH_TENSOR_APPLY3(scalar_t, r_, scalar_t, t, scalar_t, src, *r__data = *t_data / *src_data;);
|
||||
}
|
||||
}
|
||||
|
||||
void THTensor_(cremainder)(THTensor *r_, THTensor *t, THTensor *src)
|
||||
{
|
||||
THTensor_(resizeAs)(r_, t);
|
||||
|
@@ -65,7 +65,6 @@ TH_API void THTensor_(addr)(THTensor *r_, THTensor *t, THTensor *vec1, THTensor

#if !defined(TH_REAL_IS_BOOL)
TH_API void THTensor_(mul)(THTensor *r_, THTensor *t, scalar_t value);
TH_API void THTensor_(div)(THTensor *r_, THTensor *t, scalar_t value);
#endif

#if !defined(TH_REAL_IS_BFLOAT16)
@@ -96,10 +95,6 @@ TH_API accreal THTensor_(dot)(THTensor *t, THTensor *src);
TH_API void THTensor_(remainder)(THTensor *r_, THTensor *t, scalar_t value);
TH_API void THTensor_(clamp)(THTensor *r_, THTensor *t, scalar_t min_value, scalar_t max_value);

TH_API void THTensor_(cadd)(THTensor *r_, THTensor *t, scalar_t value, THTensor *src);
TH_API void THTensor_(csub)(THTensor *self, THTensor *src1, scalar_t value, THTensor *src2);
TH_API void THTensor_(cmul)(THTensor *r_, THTensor *t, THTensor *src);
TH_API void THTensor_(cdiv)(THTensor *r_, THTensor *t, THTensor *src);
TH_API void THTensor_(cremainder)(THTensor *r_, THTensor *t, THTensor *src);

TH_API void THTensor_(addbmm)(THTensor *r_, THTensor *t, THTensor *batch1, THTensor *batch2, scalar_t beta, scalar_t alpha);
@@ -9,12 +9,7 @@ TH_API void THVector_(fill)(scalar_t *x, const scalar_t c, const ptrdiff_t n);

#if !defined(TH_REAL_IS_BOOL) /* non bool only part */

TH_API void THVector_(cadd)(scalar_t *z, const scalar_t *x, const scalar_t *y, const scalar_t c, const ptrdiff_t n);
TH_API void THVector_(adds)(scalar_t *y, const scalar_t *x, const scalar_t c, const ptrdiff_t n);
TH_API void THVector_(cmul)(scalar_t *z, const scalar_t *x, const scalar_t *y, const ptrdiff_t n);
TH_API void THVector_(muls)(scalar_t *y, const scalar_t *x, const scalar_t c, const ptrdiff_t n);
TH_API void THVector_(cdiv)(scalar_t *z, const scalar_t *x, const scalar_t *y, const ptrdiff_t n);
TH_API void THVector_(divs)(scalar_t *y, const scalar_t *x, const scalar_t c, const ptrdiff_t n);
TH_API void THVector_(neg)(scalar_t *y, const scalar_t *x, const ptrdiff_t n);
TH_API void THVector_(normal_fill)(scalar_t *data,
const int64_t size,
@ -36,54 +36,6 @@ void THVector_(copy_DEFAULT)(scalar_t *x, const scalar_t *y, const ptrdiff_t n)
|
||||
x[i] = y[i];
|
||||
}
|
||||
|
||||
void THVector_(cadd_DEFAULT)(scalar_t *z, const scalar_t *x, const scalar_t *y, const scalar_t c, const ptrdiff_t n)
|
||||
{
|
||||
ptrdiff_t i = 0;
|
||||
|
||||
for(; i<n-4; i+=4)
|
||||
{
|
||||
z[i] = x[i] + c * y[i];
|
||||
z[i+1] = x[i+1] + c * y[i+1];
|
||||
z[i+2] = x[i+2] + c * y[i+2];
|
||||
z[i+3] = x[i+3] + c * y[i+3];
|
||||
}
|
||||
|
||||
for(; i<n; i++)
|
||||
z[i] = x[i] + c * y[i];
|
||||
}
|
||||
|
||||
void THVector_(adds_DEFAULT)(scalar_t *y, const scalar_t *x, const scalar_t c, const ptrdiff_t n)
|
||||
{
|
||||
ptrdiff_t i = 0;
|
||||
|
||||
for(; i<n-4; i+=4)
|
||||
{
|
||||
y[i] = x[i] + c;
|
||||
y[i+1] = x[i+1] + c;
|
||||
y[i+2] = x[i+2] + c;
|
||||
y[i+3] = x[i+3] + c;
|
||||
}
|
||||
|
||||
for(; i<n; i++)
|
||||
y[i] = x[i] + c;
|
||||
}
|
||||
|
||||
void THVector_(cmul_DEFAULT)(scalar_t *z, const scalar_t *x, const scalar_t *y, const ptrdiff_t n)
|
||||
{
|
||||
ptrdiff_t i = 0;
|
||||
|
||||
for(; i <n-4; i+=4)
|
||||
{
|
||||
z[i] = x[i] * y[i];
|
||||
z[i+1] = x[i+1] * y[i+1];
|
||||
z[i+2] = x[i+2] * y[i+2];
|
||||
z[i+3] = x[i+3] * y[i+3];
|
||||
}
|
||||
|
||||
for(; i < n; i++)
|
||||
z[i] = x[i] * y[i];
|
||||
}
|
||||
|
||||
void THVector_(muls_DEFAULT)(scalar_t *y, const scalar_t *x, const scalar_t c, const ptrdiff_t n)
|
||||
{
|
||||
ptrdiff_t i = 0;
|
||||
@ -100,97 +52,6 @@ void THVector_(muls_DEFAULT)(scalar_t *y, const scalar_t *x, const scalar_t c, c
|
||||
y[i] = x[i] * c;
|
||||
}
|
||||
|
||||
void THVector_(cdiv_DEFAULT)(scalar_t *z, const scalar_t *x, const scalar_t *y, const ptrdiff_t n)
|
||||
{
|
||||
ptrdiff_t i = 0;
|
||||
|
||||
for(; i<n-4; i+=4)
|
||||
{
|
||||
z[i] = x[i] / y[i];
|
||||
z[i+1] = x[i+1] / y[i+1];
|
||||
z[i+2] = x[i+2] / y[i+2];
|
||||
z[i+3] = x[i+3] / y[i+3];
|
||||
}
|
||||
|
||||
for(; i < n; i++)
|
||||
z[i] = x[i] / y[i];
|
||||
}
|
||||
|
||||
void THVector_(divs_DEFAULT)(scalar_t *y, const scalar_t *x, const scalar_t c, const ptrdiff_t n)
|
||||
{
|
||||
ptrdiff_t i = 0;
|
||||
|
||||
for(; i<n-4; i+=4)
|
||||
{
|
||||
y[i] = x[i] / c;
|
||||
y[i+1] = x[i+1] / c;
|
||||
y[i+2] = x[i+2] / c;
|
||||
y[i+3] = x[i+3] / c;
|
||||
}
|
||||
|
||||
for(; i < n; i++)
|
||||
y[i] = x[i] / c;
|
||||
}
|
||||
|
||||
// Fills 16 normally distributed samples into data, interleaved with a
|
||||
// stride of 8, i.e. in order of ([0], [8]), ([1], [9]), ...
|
||||
static void THVector_(interleaved_normal_fill_16)(scalar_t *data,
|
||||
const scalar_t mean,
|
||||
const scalar_t stddev)
|
||||
{
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
const scalar_t u1 = 1 - data[j]; // [0, 1) -> (0, 1] for log.
|
||||
const scalar_t u2 = data[j + 8];
|
||||
|
||||
const scalar_t radius = sqrt(-2 * log(u1));
|
||||
const scalar_t theta = 2.0f * M_PI * u2;
|
||||
|
||||
data[j] = radius * cos(theta) * stddev + mean;
|
||||
data[j + 8] = radius * std::sin(theta) * stddev + mean;
|
||||
}
|
||||
}
|
||||
|
||||
void THVector_(normal_fill_DEFAULT)(scalar_t *data,
|
||||
int64_t size,
|
||||
at::Generator *generator,
|
||||
const scalar_t mean,
|
||||
const scalar_t stddev)
|
||||
{
|
||||
THAssert(size >= 16 && "Size must be >= 16 for normal fill");
|
||||
auto gen = at::get_generator_or_default<at::CPUGenerator>(generator, at::detail::getDefaultCPUGenerator());
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(gen->mutex_);
|
||||
|
||||
for (int64_t i = 0; i < size; ++i) {
|
||||
#ifdef TH_REAL_IS_FLOAT
|
||||
at::uniform_real_distribution<float> uniform(0, 1);
|
||||
data[i] = uniform(gen);
|
||||
#else
|
||||
at::uniform_real_distribution<double> uniform(0, 1);
|
||||
data[i] = uniform(gen);
|
||||
#endif
|
||||
}
|
||||
|
||||
for (int64_t i = 0; i < size - 15; i += 16) {
|
||||
THVector_(interleaved_normal_fill_16)(data + i, mean, stddev);
|
||||
}
|
||||
|
||||
if (size % 16 != 0) {
|
||||
// Recompute the last 16 values.
|
||||
data = data + size - 16;
|
||||
for (int64_t i = 0; i < 16; ++i) {
|
||||
#ifdef TH_REAL_IS_FLOAT
|
||||
at::uniform_real_distribution<float> uniform(0, 1);
|
||||
data[i] = uniform(gen);
|
||||
#else
|
||||
at::uniform_real_distribution<double> uniform(0, 1);
|
||||
data[i] = uniform(gen);
|
||||
#endif
|
||||
}
|
||||
THVector_(interleaved_normal_fill_16)(data, mean, stddev);
|
||||
}
|
||||
}
|
||||
|
||||
#define VECTOR_IMPLEMENT_FUNCTION(NAME, CFUNC) \
|
||||
void THVector_(NAME)(scalar_t *y, const scalar_t *x, const ptrdiff_t n) \
|
||||
{ \
|
||||
|
@ -39,80 +39,6 @@ void THVector_(fill)(scalar_t *x, const scalar_t c, const ptrdiff_t n) {
|
||||
}
|
||||
|
||||
#if !defined(TH_REAL_IS_BOOL) /* non bool only part */
|
||||
|
||||
static void (*THVector_(cadd_DISPATCHPTR))(scalar_t *, const scalar_t *, const scalar_t *, const scalar_t, const ptrdiff_t) = &THVector_(cadd_DEFAULT);
|
||||
static FunctionDescription THVector_(cadd_DISPATCHTABLE)[] = {
|
||||
#if defined(__NEON__)
|
||||
#if defined(TH_REAL_IS_FLOAT)
|
||||
FUNCTION_IMPL(THVector_(cadd_NEON), SIMDExtension_NEON),
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if defined(USE_AVX2)
|
||||
#if defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)
|
||||
FUNCTION_IMPL(THVector_(cadd_AVX2), SIMDExtension_AVX2),
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if defined(USE_AVX)
|
||||
#if defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)
|
||||
FUNCTION_IMPL(THVector_(cadd_AVX), SIMDExtension_AVX),
|
||||
#endif
|
||||
#endif
|
||||
|
||||
FUNCTION_IMPL(THVector_(cadd_DEFAULT), SIMDExtension_DEFAULT)
|
||||
};
|
||||
void THVector_(cadd)(scalar_t *z, const scalar_t *x, const scalar_t *y, const scalar_t c, const ptrdiff_t n) {
|
||||
THVector_(cadd_DISPATCHPTR)(z, x, y, c, n);
|
||||
}
|
||||
|
||||
static void (*THVector_(adds_DISPATCHPTR))(scalar_t *, const scalar_t *, const scalar_t, const ptrdiff_t) = &THVector_(adds_DEFAULT);
|
||||
static FunctionDescription THVector_(adds_DISPATCHTABLE)[] = {
|
||||
#if defined(__NEON__)
|
||||
#if defined(TH_REAL_IS_FLOAT)
|
||||
FUNCTION_IMPL(THVector_(adds_NEON), SIMDExtension_NEON),
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if defined(__PPC64__)
|
||||
#if defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)
|
||||
FUNCTION_IMPL(THVector_(adds_VSX), SIMDExtension_VSX),
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if defined(USE_AVX)
|
||||
#if defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)
|
||||
FUNCTION_IMPL(THVector_(adds_AVX), SIMDExtension_AVX),
|
||||
#endif
|
||||
#endif
|
||||
|
||||
FUNCTION_IMPL(THVector_(adds_DEFAULT), SIMDExtension_DEFAULT)
|
||||
};
|
||||
// Dispatch stubs that just call the pointers
|
||||
void THVector_(adds)(scalar_t *r_, const scalar_t *t, const scalar_t value, const ptrdiff_t n) {
|
||||
THVector_(adds_DISPATCHPTR)(r_, t, value, n);
|
||||
}
|
||||
|
||||
static void (*THVector_(cmul_DISPATCHPTR))(scalar_t *, const scalar_t *, const scalar_t *, const ptrdiff_t) = &THVector_(cmul_DEFAULT);
|
||||
static FunctionDescription THVector_(cmul_DISPATCHTABLE)[] = {
|
||||
#if defined(__NEON__)
|
||||
#if defined(TH_REAL_IS_FLOAT)
|
||||
FUNCTION_IMPL(THVector_(cmul_NEON), SIMDExtension_NEON),
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if defined(USE_AVX)
|
||||
#if defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)
|
||||
FUNCTION_IMPL(THVector_(cmul_AVX), SIMDExtension_AVX),
|
||||
#endif
|
||||
#endif
|
||||
|
||||
FUNCTION_IMPL(THVector_(cmul_DEFAULT), SIMDExtension_DEFAULT)
|
||||
};
|
||||
void THVector_(cmul)(scalar_t *z, const scalar_t *x, const scalar_t *y, const ptrdiff_t n) {
|
||||
THVector_(cmul_DISPATCHPTR)(z, x, y, n);
|
||||
}
|
||||
|
||||
static void (*THVector_(muls_DISPATCHPTR))(scalar_t *, const scalar_t *, const scalar_t, const ptrdiff_t) = &THVector_(muls_DEFAULT);
|
||||
static FunctionDescription THVector_(muls_DISPATCHTABLE)[] = {
|
||||
#if defined(__NEON__)
|
||||
@ -139,46 +65,6 @@ void THVector_(muls)(scalar_t *y, const scalar_t *x, const scalar_t c, const ptr
|
||||
THVector_(muls_DISPATCHPTR)(y, x, c, n);
|
||||
}
|
||||
|
||||
static void (*THVector_(cdiv_DISPATCHPTR))(scalar_t *, const scalar_t *, const scalar_t *, const ptrdiff_t) = &THVector_(cdiv_DEFAULT);
|
||||
static FunctionDescription THVector_(cdiv_DISPATCHTABLE)[] = {
|
||||
#if defined(__NEON__)
|
||||
#if defined(TH_REAL_IS_FLOAT)
|
||||
FUNCTION_IMPL(THVector_(cdiv_NEON), SIMDExtension_NEON),
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if defined(USE_AVX)
|
||||
#if defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)
|
||||
FUNCTION_IMPL(THVector_(cdiv_AVX), SIMDExtension_AVX),
|
||||
#endif
|
||||
#endif
|
||||
|
||||
FUNCTION_IMPL(THVector_(cdiv_DEFAULT), SIMDExtension_DEFAULT)
|
||||
};
|
||||
void THVector_(cdiv)(scalar_t *z, const scalar_t *x, const scalar_t *y, const ptrdiff_t n) {
|
||||
THVector_(cdiv_DISPATCHPTR)(z, x, y, n);
|
||||
}
|
||||
|
||||
static void (*THVector_(divs_DISPATCHPTR))(scalar_t *, const scalar_t *, const scalar_t, const ptrdiff_t) = &THVector_(divs_DEFAULT);
|
||||
static FunctionDescription THVector_(divs_DISPATCHTABLE)[] = {
|
||||
#if defined(__NEON__)
|
||||
#if defined(TH_REAL_IS_FLOAT)
|
||||
FUNCTION_IMPL(THVector_(divs_NEON), SIMDExtension_NEON),
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if defined(USE_AVX)
|
||||
#if defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)
|
||||
FUNCTION_IMPL(THVector_(divs_AVX), SIMDExtension_AVX),
|
||||
#endif
|
||||
#endif
|
||||
|
||||
FUNCTION_IMPL(THVector_(divs_DEFAULT), SIMDExtension_DEFAULT)
|
||||
};
|
||||
void THVector_(divs)(scalar_t *y, const scalar_t *x, const scalar_t c, const ptrdiff_t n) {
|
||||
THVector_(divs_DISPATCHPTR)(y, x, c, n);
|
||||
}
|
||||
|
||||
/*
|
||||
* This struct's constructor initializes the dispatch tables. It simply checks
|
||||
* what SIMD extensions are available, and then walks the dispatch table
|
||||
@ -191,12 +77,7 @@ struct THVector_(startup) {
|
||||
THVector_(startup)() {
|
||||
uint32_t hostSimdExts = detectHostSIMDExtensions();
|
||||
INIT_DISPATCH_PTR(fill);
|
||||
INIT_DISPATCH_PTR(cadd);
|
||||
INIT_DISPATCH_PTR(adds);
|
||||
INIT_DISPATCH_PTR(cmul);
|
||||
INIT_DISPATCH_PTR(muls);
|
||||
INIT_DISPATCH_PTR(cdiv);
|
||||
INIT_DISPATCH_PTR(divs);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -24,59 +24,6 @@ void THDoubleVector_fill_AVX(double *x, const double c, const ptrdiff_t n) {
|
||||
}
|
||||
}
|
||||
|
||||
void THDoubleVector_cdiv_AVX(double *z, const double *x, const double *y, const ptrdiff_t n) __ubsan_ignore_float_divide_by_zero__ {
|
||||
ptrdiff_t i;
|
||||
__m256d YMM0, YMM1, YMM2, YMM3;
|
||||
for (i=0; i<=((n)-8); i+=8) {
|
||||
YMM0 = _mm256_loadu_pd(x+i);
|
||||
YMM1 = _mm256_loadu_pd(x+i+4);
|
||||
YMM2 = _mm256_loadu_pd(y+i);
|
||||
YMM3 = _mm256_loadu_pd(y+i+4);
|
||||
YMM2 = _mm256_div_pd(YMM0, YMM2);
|
||||
YMM3 = _mm256_div_pd(YMM1, YMM3);
|
||||
_mm256_storeu_pd(z+i, YMM2);
|
||||
_mm256_storeu_pd(z+i+4, YMM3);
|
||||
}
|
||||
for (; i<(n); i++) {
|
||||
z[i] = x[i] / y[i];
|
||||
}
|
||||
}
|
||||
|
||||
void THDoubleVector_divs_AVX(double *y, const double *x, const double c, const ptrdiff_t n) __ubsan_ignore_float_divide_by_zero__ {
|
||||
ptrdiff_t i;
|
||||
__m256d YMM15 = _mm256_set_pd(c, c, c, c);
|
||||
__m256d YMM0, YMM1;
|
||||
for (i=0; i<=((n)-8); i+=8) {
|
||||
YMM0 = _mm256_loadu_pd(x+i);
|
||||
YMM1 = _mm256_loadu_pd(x+i+4);
|
||||
YMM0 = _mm256_div_pd(YMM0, YMM15);
|
||||
YMM1 = _mm256_div_pd(YMM1, YMM15);
|
||||
_mm256_storeu_pd(y+i, YMM0);
|
||||
_mm256_storeu_pd(y+i+4, YMM1);
|
||||
}
|
||||
for (; i<(n); i++) {
|
||||
y[i] = x[i] / c;
|
||||
}
|
||||
}
|
||||
|
||||
void THDoubleVector_cmul_AVX(double *z, const double *x, const double *y, const ptrdiff_t n) {
|
||||
ptrdiff_t i;
|
||||
__m256d YMM0, YMM1, YMM2, YMM3;
|
||||
for (i=0; i<=((n)-8); i+=8) {
|
||||
YMM0 = _mm256_loadu_pd(x+i);
|
||||
YMM1 = _mm256_loadu_pd(x+i+4);
|
||||
YMM2 = _mm256_loadu_pd(y+i);
|
||||
YMM3 = _mm256_loadu_pd(y+i+4);
|
||||
YMM2 = _mm256_mul_pd(YMM0, YMM2);
|
||||
YMM3 = _mm256_mul_pd(YMM1, YMM3);
|
||||
_mm256_storeu_pd(z+i, YMM2);
|
||||
_mm256_storeu_pd(z+i+4, YMM3);
|
||||
}
|
||||
for (; i<n; i++) {
|
||||
z[i] = x[i] * y[i];
|
||||
}
|
||||
}
|
||||
|
||||
void THDoubleVector_muls_AVX(double *y, const double *x, const double c, const ptrdiff_t n) {
|
||||
ptrdiff_t i;
|
||||
__m256d YMM15 = _mm256_set_pd(c, c, c, c);
|
||||
@ -94,39 +41,6 @@ void THDoubleVector_muls_AVX(double *y, const double *x, const double c, const p
|
||||
}
|
||||
}
|
||||
|
||||
void THDoubleVector_cadd_AVX(double *z, const double *x, const double *y, const double c, const ptrdiff_t n) {
|
||||
ptrdiff_t i;
|
||||
__m256d YMM15 = _mm256_set_pd(c, c, c, c);
|
||||
__m256d YMM0, YMM1, YMM2, YMM3;
|
||||
for (i=0; i<=((n)-4); i+=4) {
|
||||
YMM0 = _mm256_loadu_pd(y+i);
|
||||
YMM1 = _mm256_loadu_pd(x+i);
|
||||
YMM2 = _mm256_mul_pd(YMM0, YMM15);
|
||||
YMM3 = _mm256_add_pd(YMM1, YMM2);
|
||||
_mm256_storeu_pd(z+i, YMM3);
|
||||
}
|
||||
for (; i<(n); i++) {
|
||||
z[i] = x[i] + y[i] * c;
|
||||
}
|
||||
}
|
||||
|
||||
void THDoubleVector_adds_AVX(double *y, const double *x, const double c, const ptrdiff_t n) {
|
||||
ptrdiff_t i;
|
||||
__m256d YMM15 = _mm256_set_pd(c, c, c, c);
|
||||
__m256d YMM0, YMM1;
|
||||
for (i=0; i<=((n)-8); i+=8) {
|
||||
YMM0 = _mm256_loadu_pd(x+i);
|
||||
YMM1 = _mm256_loadu_pd(x+i+4);
|
||||
YMM0 = _mm256_add_pd(YMM0, YMM15);
|
||||
YMM1 = _mm256_add_pd(YMM1, YMM15);
|
||||
_mm256_storeu_pd(y+i, YMM0);
|
||||
_mm256_storeu_pd(y+i+4, YMM1);
|
||||
}
|
||||
for (; i<(n); i++) {
|
||||
y[i] = x[i] + c;
|
||||
}
|
||||
}
|
||||
|
||||
void THFloatVector_fill_AVX(float *x, const float c, const ptrdiff_t n) {
|
||||
ptrdiff_t i;
|
||||
ptrdiff_t off;
|
||||
@ -143,59 +57,6 @@ void THFloatVector_fill_AVX(float *x, const float c, const ptrdiff_t n) {
|
||||
}
|
||||
}
|
||||
|
||||
void THFloatVector_cdiv_AVX(float *z, const float *x, const float *y, const ptrdiff_t n) __ubsan_ignore_float_divide_by_zero__ {
|
||||
ptrdiff_t i;
|
||||
__m256 YMM0, YMM1, YMM2, YMM3;
|
||||
for (i=0; i<=((n)-16); i+=16) {
|
||||
YMM0 = _mm256_loadu_ps(x+i);
|
||||
YMM1 = _mm256_loadu_ps(x+i+8);
|
||||
YMM2 = _mm256_loadu_ps(y+i);
|
||||
YMM3 = _mm256_loadu_ps(y+i+8);
|
||||
YMM2 = _mm256_div_ps(YMM0, YMM2);
|
||||
YMM3 = _mm256_div_ps(YMM1, YMM3);
|
||||
_mm256_storeu_ps(z+i, YMM2);
|
||||
_mm256_storeu_ps(z+i+8, YMM3);
|
||||
}
|
||||
for (; i<(n); i++) {
|
||||
z[i] = x[i] / y[i];
|
||||
}
|
||||
}
|
||||
|
||||
void THFloatVector_divs_AVX(float *y, const float *x, const float c, const ptrdiff_t n) __ubsan_ignore_float_divide_by_zero__ {
|
||||
ptrdiff_t i;
|
||||
__m256 YMM15 = _mm256_set_ps(c, c, c, c, c, c, c, c);
|
||||
__m256 YMM0, YMM1;
|
||||
for (i=0; i<=((n)-16); i+=16) {
|
||||
YMM0 = _mm256_loadu_ps(x+i);
|
||||
YMM1 = _mm256_loadu_ps(x+i+8);
|
||||
YMM0 = _mm256_div_ps(YMM0, YMM15);
|
||||
YMM1 = _mm256_div_ps(YMM1, YMM15);
|
||||
_mm256_storeu_ps(y+i, YMM0);
|
||||
_mm256_storeu_ps(y+i+8, YMM1);
|
||||
}
|
||||
for (; i<(n); i++) {
|
||||
y[i] = x[i] / c;
|
||||
}
|
||||
}
|
||||
|
||||
void THFloatVector_cmul_AVX(float *z, const float *x, const float *y, const ptrdiff_t n) {
|
||||
ptrdiff_t i;
|
||||
__m256 YMM0, YMM1, YMM2, YMM3;
|
||||
for (i=0; i<=((n)-16); i+=16) {
|
||||
YMM0 = _mm256_loadu_ps(x+i);
|
||||
YMM1 = _mm256_loadu_ps(x+i+8);
|
||||
YMM2 = _mm256_loadu_ps(y+i);
|
||||
YMM3 = _mm256_loadu_ps(y+i+8);
|
||||
YMM2 = _mm256_mul_ps(YMM0, YMM2);
|
||||
YMM3 = _mm256_mul_ps(YMM1, YMM3);
|
||||
_mm256_storeu_ps(z+i, YMM2);
|
||||
_mm256_storeu_ps(z+i+8, YMM3);
|
||||
}
|
||||
for (; i<n; i++) {
|
||||
z[i] = x[i] * y[i];
|
||||
}
|
||||
}
|
||||
|
||||
void THFloatVector_muls_AVX(float *y, const float *x, const float c, const ptrdiff_t n) {
|
||||
ptrdiff_t i;
|
||||
__m256 YMM15 = _mm256_set_ps(c, c, c, c, c, c, c, c);
|
||||
@ -213,37 +74,4 @@ void THFloatVector_muls_AVX(float *y, const float *x, const float c, const ptrdi
|
||||
}
|
||||
}
|
||||
|
||||
void THFloatVector_cadd_AVX(float *z, const float *x, const float *y, const float c, const ptrdiff_t n) {
|
||||
ptrdiff_t i;
|
||||
__m256 YMM15 = _mm256_set_ps(c, c, c, c, c, c, c, c);
|
||||
__m256 YMM0, YMM1, YMM2, YMM3;
|
||||
for (i=0; i<=((n)-8); i+=8) {
|
||||
YMM0 = _mm256_loadu_ps(y+i);
|
||||
YMM1 = _mm256_loadu_ps(x+i);
|
||||
YMM2 = _mm256_mul_ps(YMM0, YMM15);
|
||||
YMM3 = _mm256_add_ps(YMM1, YMM2);
|
||||
_mm256_storeu_ps(z+i, YMM3);
|
||||
}
|
||||
for (; i<(n); i++) {
|
||||
z[i] = x[i] + y[i] * c;
|
||||
}
|
||||
}
|
||||
|
||||
void THFloatVector_adds_AVX(float *y, const float *x, const float c, const ptrdiff_t n) {
|
||||
ptrdiff_t i;
|
||||
__m256 YMM15 = _mm256_set_ps(c, c, c, c, c, c, c, c);
|
||||
__m256 YMM0, YMM1;
|
||||
for (i=0; i<=((n)-16); i+=16) {
|
||||
YMM0 = _mm256_loadu_ps(x+i);
|
||||
YMM1 = _mm256_loadu_ps(x+i+8);
|
||||
YMM0 = _mm256_add_ps(YMM0, YMM15);
|
||||
YMM1 = _mm256_add_ps(YMM1, YMM15);
|
||||
_mm256_storeu_ps(y+i, YMM0);
|
||||
_mm256_storeu_ps(y+i+8, YMM1);
|
||||
}
|
||||
for (; i<(n); i++) {
|
||||
y[i] = x[i] + c;
|
||||
}
|
||||
}
|
||||
|
||||
#endif // defined(__AVX__)
|
||||
|
@ -14,21 +14,6 @@ static void THFloatVector_fill_NEON(float *x, const float c, const ptrdiff_t n)
|
||||
|
||||
}
|
||||
|
||||
static void THFloatVector_cmul_NEON(float *z, const float *x, const float* y, const ptrdiff_t n) {
|
||||
int64_t i = 0;
|
||||
|
||||
for(; i < n-4; i += 4)
|
||||
{
|
||||
z[i] = x[i] * y[i];
|
||||
z[i+1] = x[i+1] * y[i+1];
|
||||
z[i+2] = x[i+2] * y[i+2];
|
||||
z[i+3] = x[i+3] * y[i+3];
|
||||
}
|
||||
|
||||
for(; i < n; i++)
|
||||
z[i] = x[i] * y[i];
|
||||
}
|
||||
|
||||
static void THFloatVector_muls_NEON(float *y, const float *x, const float c, const ptrdiff_t n) {
|
||||
int64_t i = 0;
|
||||
|
||||
@ -43,63 +28,3 @@ static void THFloatVector_muls_NEON(float *y, const float *x, const float c, con
|
||||
for(; i < n; i++)
|
||||
y[i] = x[i] * c;
|
||||
}
|
||||
|
||||
static void THFloatVector_cadd_NEON(float *z, const float *x, const float *y, const float c, const ptrdiff_t n) {
|
||||
int64_t i = 0;
|
||||
|
||||
for(;i < n-4; i += 4)
|
||||
{
|
||||
z[i] = x[i] + c * y[i];
|
||||
z[i+1] = x[i+1] + c * y[i+1];
|
||||
z[i+2] = x[i+2] + c * y[i+2];
|
||||
z[i+3] = x[i+3] + c * y[i+3];
|
||||
}
|
||||
|
||||
for(; i < n; i++)
|
||||
z[i] = x[i] + c * y[i];
|
||||
}
|
||||
|
||||
static void THFloatVector_adds_NEON(float *y, const float *x, const float c, const ptrdiff_t n) {
|
||||
int64_t i = 0;
|
||||
|
||||
for(;i < n-4; i += 4)
|
||||
{
|
||||
y[i] = x[i] + c;
|
||||
y[i+1] = x[i+1] + c;
|
||||
y[i+2] = x[i+2] + c;
|
||||
y[i+3] = x[i+3] + c;
|
||||
}
|
||||
|
||||
for(; i < n; i++)
|
||||
y[i] = x[i] + c;
|
||||
}
|
||||
|
||||
static void THFloatVector_cdiv_NEON(float *z, const float *x, const float *y, const ptrdiff_t n) {
|
||||
int64_t i = 0;
|
||||
|
||||
for(;i < n-4; i += 4)
|
||||
{
|
||||
z[i] = x[i] / y[i];
|
||||
z[i+1] = x[i+1] / y[i+1];
|
||||
z[i+2] = x[i+2] / y[i+2];
|
||||
z[i+3] = x[i+3] / y[i+3];
|
||||
}
|
||||
|
||||
for(; i < n; i++)
|
||||
z[i] = x[i] / y[i];
|
||||
}
|
||||
|
||||
static void THFloatVector_divs_NEON(float *y, const float *x, const float c, const ptrdiff_t n) {
|
||||
int64_t i = 0;
|
||||
|
||||
for(;i < n-4; i += 4)
|
||||
{
|
||||
y[i] = x[i] / c;
|
||||
y[i+1] = x[i+1] / c;
|
||||
y[i+2] = x[i+2] / c;
|
||||
y[i+3] = x[i+3] / c;
|
||||
}
|
||||
|
||||
for(; i < n; i++)
|
||||
y[i] = x[i] / c;
|
||||
}
|
||||
|
File diff suppressed because it is too large
@@ -1,8 +1,7 @@
# THNN
# THCUNN

THNN is a library that gathers nn's C implementations of neural network modules. It's entirely free of Lua dependency and therefore can be used in any application that has a C FFI. Please note that it only contains quite low level functions; most users will want to use ATen, which provides a C++ wrapper around these functions.
THCUNN is a library that gathers nn's C implementations of neural network modules. It's entirely free of Lua dependency and therefore can be used in any application that has a C FFI. Please note that it only contains quite low level functions; most users will want to use ATen, which provides a C++ wrapper around these functions.

There is also a CUDA counterpart of THNN, THCUNN.

Looking to add an implementation? Consider writing an ATen native function
instead! See [../ATen/native](../ATen/native).
@@ -14,7 +13,7 @@ instead! See [../ATen/native](../ATen/native).

## API

THNN is a purely functional library. It provides 2-3 functions for each module, that perform the most important operations:
THCUNN is a purely functional library. It provides 2-3 functions for each module, that perform the most important operations:

* **updateOutput** - applies the module to an input
* **updateGradInput** - accepts gradient w.r.t. output and previous module input, and computes a gradient w.r.t. that input
@@ -1,10 +1,10 @@
# API docs

This document describes the conventions behind the THNN API.
This document describes the conventions behind the THCUNN API.

### The API

All functions provided by THNN are stored in `aten/src/THNN/generic/THNN.h`.
All functions provided by THCUNN are stored in `aten/src/THCUNN/generic/THCUNN.h`.
Look at this file.

### Note on function names
@@ -14,7 +14,7 @@ Please remember, that because C doesn't support function overloading, functions
* `void THNN_FloatAbs_updateOutput(...)`
* `void THNN_DoubleAbs_updateOutput(...)`

In these docs such function will be referred to as `void THNN_Abs_updateOutput(...)`, and it's up to developer to add a type prefix. `real` is an alias for that type.
In these docs such function will be referred to as `void THCUNN_Abs_updateOutput(...)`, and it's up to developer to add a type prefix. `real` is an alias for that type.

### Argument types

@@ -24,4 +24,3 @@ Some arguments have additional tags placed in square brackets in their header de
* **[OPTIONAL]** - This argument is optional and can be safely set to NULL
* **[BUFFER]** - A buffer. `updateGradInput` and `accGradParameters` should get the same buffers that were used in `updateOutput` call.
* **[MODIFIED]** - Some functions accept an `inplace` flag. If set to true, this argument might be modified (in addition to the output).
@@ -16,12 +16,16 @@ accGradParameters: state, input, gradOutput, [gradWeight], [gradBias], ...

e.g.
```C
void THNN_(HardShrink_updateGradInput)(
THNNState* state,
THTensor *input,
THTensor *gradOutput,
THTensor *gradInput,
real lambda)
void THNN_(ClassNLLCriterion_updateGradInput)(
THCState *state,
THCTensor *input,
THCIndexTensor *target,
THCTensor *gradOutput,
THCTensor *gradInput,
int64_t reduction,
THCTensor *weights,
THCTensor *total_weight,
int64_t ignore_index)
```

### Criterions
@@ -34,23 +38,24 @@ e.g.

```C
void THNN_(ClassNLLCriterion_updateOutput)(
THNNState* state,
THTensor *input,
THLongTensor *target,
THTensor *output,
THTensor *weights,
THTensor *total_weight,
bool sizeAverage)
THCState *state,
THCTensor *input,
THCIndexTensor *target,
THCTensor *output,
int64_t reduction,
THCTensor *weights,
THCTensor *total_weight,
int64_t ignore_index)
```

## Code style guide

```C
void THNN_Linear_updateOutput(
THTensor *input,
THTensor *output,
THTensor *weight,
THTensor *bias);
void THNN_(GatedLinear_updateOutput)(
THCState *state,
THCTensor *input,
THCTensor *output,
int dim)
//<- 10 ->
```
@@ -1,4 +0,0 @@
set(ATen_CPU_SRCS ${ATen_CPU_SRCS}
${CMAKE_CURRENT_SOURCE_DIR}/init.cpp
PARENT_SCOPE)
INSTALL(FILES generic/THNN.h DESTINATION "${ATEN_INSTALL_INCLUDE_SUBDIR}/THNN/generic")
@@ -1,24 +0,0 @@
#ifndef THNN_H
#define THNN_H

#include <stdbool.h>
#include <TH/TH.h>

#define THNN_(NAME) TH_CONCAT_3(THNN_, Real, NAME)

#define THIndexTensor THLongTensor
#define THIndexTensor_(NAME) THLongTensor_ ## NAME

typedef int64_t THIndex_t;
typedef void THNNState;

#include <THNN/generic/THNN.h>
#include <THGenerateFloatTypes.h>

#include <THNN/generic/THNN.h>
#include <THGenerateLongType.h>

#include <THNN/generic/THNN.h>
#include <THGenerateBFloat16Type.h>

#endif
@@ -1,57 +0,0 @@
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "THNN/generic/GatedLinearUnit.c"
#else

#include <ATen/WrapDimUtils.h>

void THNN_(GatedLinear_updateOutput)(
THNNState *state,
THTensor *input,
THTensor *output,
int dim)
{
TORCH_INTERNAL_ASSERT(false, "GatedLinear_updateOutput called, but this is just " \
"a stub for nn.yaml parsing");
}

void THNN_(GatedLinear_updateGradInput)(
THNNState *state,
THTensor *input,
THTensor *gradOutput,
THTensor *gradInput,
int dim)
{
dim = at::maybe_wrap_dim(dim, input);
// set up tensors
const int64_t nIn = THTensor_(size)(input, dim);
THArgCheck(nIn % 2 == 0, 2, "Halving dimension must be even. Dim %d is size %ld",
dim, nIn);

THTensor_(resizeAs)(gradInput, input);
const int64_t inputSize = THTensor_(size)(input, dim) / 2;
THTensor *firstHalf = THTensor_(newNarrow)(input, dim, 0, inputSize);
THTensor *secondHalf = THTensor_(newNarrow)(input, dim, inputSize, inputSize);
THTensor *gradInputfirstHalf = THTensor_(newNarrow)(gradInput, dim, 0, inputSize);
THTensor *gradInputsecondHalf = THTensor_(newNarrow)(gradInput, dim, inputSize, inputSize);

at::Tensor gradInputfirstHalf_wrap = THTensor_wrap(gradInputfirstHalf);
at::Tensor secondHalf_wrap = THTensor_wrap(secondHalf);
at::native::sigmoid_out(gradInputfirstHalf_wrap, secondHalf_wrap);

TH_TENSOR_APPLY2(scalar_t, gradInputsecondHalf, scalar_t, gradInputfirstHalf,
scalar_t z = *gradInputfirstHalf_data;
*gradInputsecondHalf_data = (1. - z) * z;
);

THTensor_(cmul)(gradInputfirstHalf, gradInputfirstHalf, gradOutput);

THTensor_(cmul)(gradInputsecondHalf, gradInputsecondHalf, gradOutput);
THTensor_(cmul)(gradInputsecondHalf, gradInputsecondHalf, firstHalf);

c10::raw::intrusive_ptr::decref(firstHalf);
c10::raw::intrusive_ptr::decref(secondHalf);
c10::raw::intrusive_ptr::decref(gradInputfirstHalf);
c10::raw::intrusive_ptr::decref(gradInputsecondHalf);
}

#endif
@@ -1,24 +0,0 @@
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "THNN/generic/THNN.h"
#else

#include <ATen/core/Reduction.h>
#include <ATen/core/Generator.h>
#include <ATen/core/DistributionsHelper.h>

#if !defined(TH_REAL_IS_LONG)

TH_API void THNN_(GatedLinear_updateOutput)(
THNNState *state, // library's state
THTensor *input, // input tensor
THTensor *output, // [OUT] output tensor, half size of input along dimension dim
int dim); // dimension for halving operation
TH_API void THNN_(GatedLinear_updateGradInput)(
THNNState *state, // library's state
THTensor *input, // input tensor
THTensor *gradOutput, // gradient w.r.t module's output
THTensor *gradInput, // [OUT] gradient w.r.t input
int dim); // dimension for halving operation

#endif
#endif
@ -1,65 +0,0 @@
|
||||
#include <TH/TH.h>
|
||||
#include <THNN/THNN.h>
|
||||
|
||||
#include <TH/THTensor.hpp>
|
||||
#include <cmath>
|
||||
|
||||
#define torch_(NAME) TH_CONCAT_3(torch_, Real, NAME)
|
||||
#define nn_(NAME) TH_CONCAT_3(nn_, Real, NAME)
|
||||
|
||||
#define THNN_CHECK_SHAPE(I1, I2) \
|
||||
if (I1 != NULL && I2 != NULL && !THTensor_(isSameSizeAs)(I1, I2)) \
|
||||
{ \
|
||||
THDescBuff s1 = THTensor_(sizeDesc)(I1); \
|
||||
THDescBuff s2 = THTensor_(sizeDesc)(I2); \
|
||||
THError(#I1 " and " #I2 " shapes do not match: " \
|
||||
#I1 " %s, " #I2 " %s", s1.str, s2.str); \
|
||||
}
|
||||
|
||||
#define THNN_CHECK_SHAPE_INDICES(I1, I2) \
|
||||
if (I1 != NULL && I2 != NULL && !I1->sizes().equals(I2->sizes())) \
|
||||
{ \
|
||||
THDescBuff s1 = THTensor_(sizeDesc)(I1); \
|
||||
THDescBuff s2 = THLongTensor_sizeDesc(I2); \
|
||||
THError(#I1 " and " #I2 " shapes do not match: " \
|
||||
#I1 " %s, " #I2 " %s", s1.str, s2.str); \
|
||||
}
|
||||
|
||||
#define THNN_CHECK_NELEMENT(I1, I2) \
|
||||
if (I1 != NULL && I2 != NULL ) { \
|
||||
ptrdiff_t n1 = THTensor_(nElement)(I1); \
|
||||
ptrdiff_t n2 = THTensor_(nElement)(I2); \
|
||||
if (n1 != n2) \
|
||||
{ \
|
||||
THDescBuff s1 = THTensor_(sizeDesc)(I1); \
|
||||
THDescBuff s2 = THTensor_(sizeDesc)(I2); \
|
||||
THError(#I1 " and " #I2 " have different number of elements: " \
|
||||
#I1 "%s has %ld elements, while " \
|
||||
#I2 "%s has %ld elements", s1.str, n1, s2.str, n2); \
|
||||
} \
|
||||
}
|
||||
|
||||
#define THNN_CHECK_DIM_SIZE(T, DIM, DIM_SIZE, SIZE) \
|
||||
if (THTensor_(nDimensionLegacyNoScalars)(T) != DIM || \
|
||||
THTensor_sizeLegacyNoScalars(T, DIM_SIZE) != SIZE) { \
|
||||
THDescBuff s1 = THTensor_(sizeDesc)(T); \
|
||||
THError("Need " #T " of dimension %d and " #T ".size[%d] == %d" \
|
||||
" but got " #T " to be of shape: %s", DIM, DIM_SIZE, SIZE, s1.str); \
|
||||
}
|
||||
|
||||
#define THNN_CHECK_DIM_SIZE_INDICES(T, DIM, DIM_SIZE, SIZE) \
|
||||
if (THIndexTensor_(nDimensionLegacyNoScalars)(T) != DIM || \
|
||||
THTensor_sizeLegacyNoScalars(T, DIM_SIZE) != SIZE) { \
|
||||
THDescBuff s1 = THIndexTensor_(sizeDesc)(T); \
|
||||
THError("Need " #T " of dimension %d and " #T ".size[%d] == %d" \
|
||||
" but got " #T " to be of shape: %s", DIM, DIM_SIZE, SIZE, s1.str); \
|
||||
}
|
||||
|
||||
#define THNN_ARGCHECK(COND, ARG, T, FORMAT) \
|
||||
if (!(COND)) { \
|
||||
THDescBuff s1 = THTensor_(sizeDesc)(T); \
|
||||
THArgCheck(COND, ARG, FORMAT, s1.str); \
|
||||
}
|
||||
|
||||
#include <THNN/generic/GatedLinearUnit.c>
|
||||
#include <TH/THGenerateFloatTypes.h>
|
@@ -295,7 +295,6 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
$<$<BOOL:${SELECTED_OP_LIST}>:--selected-op-list-path="${SELECTED_OP_LIST}">
DEPENDS
"${CMAKE_BINARY_DIR}/aten/src/ATen/Declarations.yaml"
"${CMAKE_CURRENT_LIST_DIR}/../aten/src/THNN/generic/THNN.h"
"${TOOLS_PATH}/autograd/templates/VariableType.h"
"${TOOLS_PATH}/autograd/templates/VariableType.cpp"
"${TOOLS_PATH}/autograd/templates/Functions.h"
@@ -140,7 +140,6 @@ if (INTERN_BUILD_ATEN_OPS)

set(cwrap_files
${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/Declarations.cwrap
${CMAKE_CURRENT_LIST_DIR}/../aten/src/THNN/generic/THNN.h
${CMAKE_CURRENT_LIST_DIR}/../aten/src/THCUNN/generic/THCUNN.h
${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/nn.yaml
${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/native_functions.yaml)
@@ -19,7 +19,6 @@ python aten/src/ATen/gen.py \
-s aten/src/ATen \
-d build/aten/src/ATen \
aten/src/ATen/Declarations.cwrap \
aten/src/THNN/generic/THNN.h \
aten/src/THCUNN/generic/THCUNN.h \
aten/src/ATen/nn.yaml \
aten/src/ATen/native/native_functions.yaml