From f2900420da9cd96465e84071b6fe1f7c110ed527 Mon Sep 17 00:00:00 2001
From: cyy
Date: Thu, 15 Jun 2023 16:48:25 +0000
Subject: [PATCH] fix missing-prototypes warnings in torch_cpu (Part 6)
 (#101845)

This PR fixes more missing-prototypes violations in the torch_cpu source,
following PRs #100053, #100147, #100245, #100849 and #101788.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101845
Approved by: https://github.com/albanD
---
 aten/src/ATen/core/ivalue.cpp                  | 10 +--
 aten/src/ATen/native/TensorFactories.cpp       | 13 ++--
 aten/src/ATen/native/cpu/UnaryOpsKernel.cpp    | 64 ++++++++-----------
 aten/src/ATen/native/mkldnn/Conv.cpp           | 14 ++--
 aten/src/ATen/native/mkldnn/Normalization.cpp  |  3 +-
 aten/src/ATen/native/mkldnn/TensorShape.cpp    |  2 +-
 .../src/ATen/native/prim_native_functions.cpp  |  1 +
 .../src/ATen/native/quantized/cpu/Pooling.cpp  |  1 +
 .../native/sparse/SparseCsrTensorMath.cpp      | 15 +----
 .../src/ATen/native/sparse/SparseUnaryOps.cpp  |  5 ++
 aten/src/ATen/native/xnnpack/Activation.cpp    |  3 +-
 .../ATen/native/xnnpack/AveragePooling.cpp     |  3 +-
 .../ATen/native/xnnpack/ChannelShuffle.cpp     |  1 +
 aten/src/ATen/native/xnnpack/Convolution.cpp   |  3 +-
 aten/src/ATen/native/xnnpack/Convolution.h     |  9 +++
 aten/src/ATen/native/xnnpack/Linear.h          | 10 +++
 aten/src/ATen/native/xnnpack/MaxPooling.cpp    |  3 +-
 .../templates/CompositeViewCopyKernels.cpp     |  2 +
 aten/src/ATen/vulkan/Context.h                 |  3 +-
 caffe2/utils/threadpool/ThreadPool.h           |  1 +
 torch/csrc/autograd/VariableTypeManual.cpp     | 24 ++++---
 torch/csrc/autograd/profiler_kineto.cpp        |  1 +
 torch/csrc/distributed/c10d/Ops.cpp            |  4 ++
 torch/csrc/jit/mobile/nnc/aot_compiler.cpp     | 33 +++++-----
 torch/csrc/profiler/unwind/unwind.cpp          |  3 +-
 torchgen/gen_backend_stubs.py                  |  1 +
 26 files changed, 132 insertions(+), 100 deletions(-)
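For context: clang's -Wmissing-prototypes warns when a function with external
linkage is defined without a previous declaration, usually a sign that a
file-local helper is unintentionally exported. A minimal sketch of the warning
and the fixes applied in the hunks below (hypothetical file, not part of this
patch):

    // helpers.cpp, compiled with: clang++ -Wmissing-prototypes -c helpers.cpp
    int leaky(int x) { return x + 1; }         // warning: no previous prototype

    static int local(int x) { return x + 1; }  // fix 1: internal linkage

    namespace {                                // fix 2: anonymous namespace
    int hidden(int x) { return x + 1; }
    } // namespace

    int exported(int x);                       // fix 3: declare first, normally
    int exported(int x) { return x + 1; }      // via the header consumers use

The changes below use all of these, plus a clang pragma where a definition must
stay extern but has no natural declaration site.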
diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp
index 8ef0ceb29607..9c4fba468415 100644
--- a/aten/src/ATen/core/ivalue.cpp
+++ b/aten/src/ATen/core/ivalue.cpp
@@ -64,6 +64,11 @@ bool operator==(const ivalue::Tuple& lhs, const ivalue::Tuple& rhs) {
       _fastEqualsForContainer);
 }
 
+std::ostream& operator<<(std::ostream& out, const ivalue::EnumHolder& v) {
+  out << v.qualifiedClassName() << "." << v.name();
+  return out;
+}
+
 bool operator==(const ivalue::EnumHolder& lhs, const ivalue::EnumHolder& rhs) {
   return lhs.name() == rhs.name() && *rhs.type() == *lhs.type();
 }
@@ -763,11 +768,6 @@ IValueComparator getGreaterThanComparator(const IValue& v) {
   };
 }
 
-std::ostream& operator<<(std::ostream& out, const ivalue::EnumHolder& v) {
-  out << v.qualifiedClassName() << "." << v.name();
-  return out;
-}
-
 std::ostream& operator<<(std::ostream & out, const IValue & v) {
   auto formatter = [&](std::ostream& out, const IValue& v) {
     out << v;
diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp
index a4b2a6716540..8c5127c088fc 100644
--- a/aten/src/ATen/native/TensorFactories.cpp
+++ b/aten/src/ATen/native/TensorFactories.cpp
@@ -171,7 +171,7 @@ Tensor arange(
   return at::arange_out(result, start, end, step);
 }
 
-Tensor& arange_start_out(const Scalar& start, const Scalar& end, Tensor& result) {
+static Tensor& arange_start_out(const Scalar& start, const Scalar& end, Tensor& result) {
   return at::arange_out(result, start, end, /*step=*/1);
 }
 
@@ -179,7 +179,7 @@ Tensor& arange_out(const Scalar& end, Tensor& result) {
   return at::arange_out(result, /*start=*/0, end, /*step=*/1);
 }
 
-Tensor& arange_out(Tensor& result, const Scalar& start, const Scalar& end) {
+static Tensor& arange_out(Tensor& result, const Scalar& start, const Scalar& end) {
   return at::arange_out(result, start, end, /*step=*/1);
 }
 
@@ -189,14 +189,14 @@ Tensor _dim_arange(const Tensor& like, int64_t dim) {
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ complex / polar ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-void complex_check_floating(const Tensor& a, const Tensor& b) {
+static void complex_check_floating(const Tensor& a, const Tensor& b) {
   TORCH_CHECK((a.scalar_type() == kFloat || a.scalar_type() == kDouble || a.scalar_type() == kHalf) &&
               (b.scalar_type() == kFloat || b.scalar_type() == kDouble || b.scalar_type() == kHalf),
               "Expected both inputs to be Half, Float or Double tensors but got ",
               a.scalar_type(), " and ", b.scalar_type());
 }
 
-void complex_check_dtype(
+static void complex_check_dtype(
     const Tensor& result,
     const Tensor& a,
     const Tensor& b) {
@@ -352,7 +352,12 @@ Tensor& empty_out(IntArrayRef size,
     return self.to(ScalarType::n, non_blocking); \
   }
 
+// Some scalar types in CAST_OP have no declarations; they may be unused in PyTorch.
+// But we keep them and suppress the warning here until that is verified.
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wmissing-prototypes"
 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, DEFINE_CAST_OP)
+#pragma clang diagnostic pop
 
 #undef DEFINE_CAST_OP
diff --git a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
index 57190f9ae092..6a54be7ace79 100644
--- a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
@@ -691,7 +691,7 @@ static void modified_bessel_k1_kernel(TensorIteratorBase& iterator) {
 
 #define IMPLEMENT_FLOAT_KERNEL(op)                    \
   inline namespace CPU_CAPABILITY {                   \
-  void op##_kernel(TensorIteratorBase& iter) {        \
+  static void op##_kernel(TensorIteratorBase& iter) { \
     TORCH_INTERNAL_ASSERT(iter.ntensors() == 2);      \
     AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), #op "_vml_cpu", [&]() { \
       constexpr int64_t grain_size = 2048;            \
@@ -715,6 +715,19 @@ static void modified_bessel_k1_kernel(TensorIteratorBase& iterator) {
   }                                                   \
   REGISTER_DISPATCH(op##_stub, &CPU_CAPABILITY::op##_kernel)
 
+#define STATIC_IMPLEMENT_COMPLEX_KERNEL(op)           \
+  inline namespace CPU_CAPABILITY {                   \
+  static void op##_kernel(TensorIteratorBase& iter) { \
+    TORCH_INTERNAL_ASSERT(iter.ntensors() == 2);      \
+    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), #op "_vml_cpu", [&]() { \
+      constexpr int64_t grain_size = 2048;            \
+      iter.for_each(IMPLEMENT_ITERATOR_LAMBDA(op), grain_size); \
+    });                                               \
+    iter.cast_outputs();                              \
+  }                                                   \
+  }                                                   \
+  REGISTER_DISPATCH(op##_stub, &CPU_CAPABILITY::op##_kernel)
+
 } // CPU_CAPABILITY namespace
 
 REGISTER_DISPATCH(rsqrt_stub, &CPU_CAPABILITY::rsqrt_kernel);
@@ -761,51 +774,28 @@ REGISTER_DISPATCH(special_modified_bessel_i1_stub, &CPU_CAPABILITY::modified_bessel_i1_kernel);
 REGISTER_DISPATCH(special_modified_bessel_k0_stub, &CPU_CAPABILITY::modified_bessel_k0_kernel);
 REGISTER_DISPATCH(special_modified_bessel_k1_stub, &CPU_CAPABILITY::modified_bessel_k1_kernel);
 
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-IMPLEMENT_COMPLEX_KERNEL(acos)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-IMPLEMENT_COMPLEX_KERNEL(asin)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-IMPLEMENT_COMPLEX_KERNEL(atan)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
+STATIC_IMPLEMENT_COMPLEX_KERNEL(acos)
+STATIC_IMPLEMENT_COMPLEX_KERNEL(asin)
+STATIC_IMPLEMENT_COMPLEX_KERNEL(atan)
 IMPLEMENT_FLOAT_KERNEL(ceil)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-IMPLEMENT_COMPLEX_KERNEL(cos)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
+STATIC_IMPLEMENT_COMPLEX_KERNEL(cos)
 IMPLEMENT_FLOAT_KERNEL(erf)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
 IMPLEMENT_FLOAT_KERNEL(erfc)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
 IMPLEMENT_FLOAT_KERNEL(erfinv)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-IMPLEMENT_COMPLEX_KERNEL(exp)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-IMPLEMENT_COMPLEX_KERNEL(expm1)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
+STATIC_IMPLEMENT_COMPLEX_KERNEL(exp)
+STATIC_IMPLEMENT_COMPLEX_KERNEL(expm1)
 IMPLEMENT_FLOAT_KERNEL(floor)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-IMPLEMENT_COMPLEX_KERNEL(log)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-IMPLEMENT_COMPLEX_KERNEL(log10)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-IMPLEMENT_COMPLEX_KERNEL(log1p)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-IMPLEMENT_COMPLEX_KERNEL(log2)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
+STATIC_IMPLEMENT_COMPLEX_KERNEL(log)
+STATIC_IMPLEMENT_COMPLEX_KERNEL(log10)
+STATIC_IMPLEMENT_COMPLEX_KERNEL(log1p)
+STATIC_IMPLEMENT_COMPLEX_KERNEL(log2)
 IMPLEMENT_FLOAT_KERNEL(i0)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
 IMPLEMENT_FLOAT_KERNEL(round)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-IMPLEMENT_COMPLEX_KERNEL(sin)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
+STATIC_IMPLEMENT_COMPLEX_KERNEL(sin)
 IMPLEMENT_COMPLEX_KERNEL(sqrt)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-IMPLEMENT_COMPLEX_KERNEL(tan)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-IMPLEMENT_COMPLEX_KERNEL(tanh)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
+STATIC_IMPLEMENT_COMPLEX_KERNEL(tan)
+STATIC_IMPLEMENT_COMPLEX_KERNEL(tanh)
 IMPLEMENT_FLOAT_KERNEL(trunc)
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
 IMPLEMENT_FLOAT_KERNEL(lgamma)
 
 } // namespace at::native
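Making these kernels static (directly, or via the new
STATIC_IMPLEMENT_COMPLEX_KERNEL macro) is safe because each kernel is only
named inside its own translation unit, where REGISTER_DISPATCH stores its
address in the dispatch stub; external callers go through the stub, never the
symbol. A generic, self-contained mock-up of that pattern (simplified names,
not the real dispatch machinery):

    #include <cstdio>

    // The "stub": a global function pointer external code calls through.
    using unary_fn = void (*)(int);
    unary_fn ceil_stub = nullptr;

    // Internal linkage: no prototype needed, not visible outside this TU.
    static void ceil_kernel(int n) { std::printf("ceil kernel: %d\n", n); }

    // Static-init registration, standing in for REGISTER_DISPATCH.
    static const bool registered = (ceil_stub = &ceil_kernel, true);

    int main() {
      ceil_stub(42); // callers never name ceil_kernel directly
    }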
diff --git a/aten/src/ATen/native/mkldnn/Conv.cpp b/aten/src/ATen/native/mkldnn/Conv.cpp
index 5cee9303f181..19f5c59e0843 100644
--- a/aten/src/ATen/native/mkldnn/Conv.cpp
+++ b/aten/src/ATen/native/mkldnn/Conv.cpp
@@ -27,19 +27,19 @@ Tensor mkldnn_convolution(
   TORCH_CHECK(false, "mkldnn_convolution_forward: ATen not compiled with MKLDNN support");
 }
 
-Tensor mkldnn_convolution_backward_input(
+static Tensor mkldnn_convolution_backward_input(
     IntArrayRef input_size, const Tensor& grad_output, const Tensor& weight,
     IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) {
   TORCH_CHECK(false, "mkldnn_convolution_backward_input: ATen not compiled with MKLDNN support");
 }
 
-std::tuple<Tensor, Tensor> mkldnn_convolution_backward_weights(
+static std::tuple<Tensor, Tensor> mkldnn_convolution_backward_weights(
     IntArrayRef weight_size, const Tensor& grad_output, const Tensor& input,
     IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) {
   TORCH_CHECK(false, "mkldnn_convolution_backward_weights: ATen not compiled with MKLDNN support");
 }
 
-std::tuple<Tensor, Tensor, Tensor> mkldnn_convolution_backward(
+static std::tuple<Tensor, Tensor, Tensor> mkldnn_convolution_backward(
     const Tensor& input, const Tensor& grad_output_t, const Tensor& weight,
     IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, std::array<bool, 3> output_mask) {
   TORCH_CHECK(false, "mkldnn_convolution_backward: ATen not compiled with MKLDNN support");
@@ -47,27 +47,27 @@ std::tuple<Tensor, Tensor, Tensor> mkldnn_convolution_backward(
 
 REGISTER_NO_CPU_DISPATCH(mkldnn_convolution_backward_stub);
 
-Tensor mkldnn_convolution_transpose(
+static Tensor mkldnn_convolution_transpose(
     const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
     IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups) {
   TORCH_CHECK(false, "mkldnn_convolution_transpose: ATen not compiled with MKLDNN support");
 }
 
-Tensor mkldnn_convolution_transpose_backward_input(
+static Tensor mkldnn_convolution_transpose_backward_input(
     IntArrayRef input_size, const Tensor& grad_output, const Tensor& weight,
     IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) {
   TORCH_CHECK(false, "mkldnn_convolution_transpose_backward_input: ATen not compiled with MKLDNN support");
 }
 
-std::tuple<Tensor, Tensor> mkldnn_convolution_transpose_backward_weights(
+static std::tuple<Tensor, Tensor> mkldnn_convolution_transpose_backward_weights(
     IntArrayRef weight_size, const Tensor& grad_output, const Tensor& input,
     IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) {
   TORCH_CHECK(false, "mkldnn_convolution_transpose_backward_weights: ATen not compiled with MKLDNN support");
 }
 
-std::tuple<Tensor, Tensor, Tensor> mkldnn_convolution_transpose_backward(
+static std::tuple<Tensor, Tensor, Tensor> mkldnn_convolution_transpose_backward(
     const Tensor& input, const Tensor& grad_output_t, const Tensor& weight,
     IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
     std::array<bool, 3> output_mask) {
diff --git a/aten/src/ATen/native/mkldnn/Normalization.cpp b/aten/src/ATen/native/mkldnn/Normalization.cpp
index d0171865fac6..108ce354ec9b 100644
--- a/aten/src/ATen/native/mkldnn/Normalization.cpp
+++ b/aten/src/ATen/native/mkldnn/Normalization.cpp
@@ -6,6 +6,7 @@
 #ifndef AT_PER_OPERATOR_HEADERS
 #include
 #else
+#include
 #include
 #include
 #include
@@ -34,7 +35,7 @@ std::tuple<Tensor, Tensor, Tensor> mkldnn_batch_norm_backward(
   TORCH_CHECK(false, "mkldnn_batch_norm_backward: ATen not compiled with MKLDNN support");
 }
 
-std::tuple<Tensor, Tensor> mkldnn_layer_norm_last_index_weight_bias_f32(
+static std::tuple<Tensor, Tensor> mkldnn_layer_norm_last_index_weight_bias_f32(
     const Tensor& input,
     IntArrayRef normalized_shape, const Tensor& weight, const Tensor& bias,
     double eps, bool inplace) {
diff --git a/aten/src/ATen/native/mkldnn/TensorShape.cpp b/aten/src/ATen/native/mkldnn/TensorShape.cpp
index a665d36b3069..18d82ca676eb 100644
--- a/aten/src/ATen/native/mkldnn/TensorShape.cpp
+++ b/aten/src/ATen/native/mkldnn/TensorShape.cpp
@@ -106,7 +106,7 @@
 namespace at {
 namespace native {
 
-Tensor mkldnn_view_symint(const Tensor& self, c10::SymIntArrayRef size) {
+static Tensor mkldnn_view_symint(const Tensor& self, c10::SymIntArrayRef size) {
   return mkldnn_view(self, C10_AS_INTARRAYREF_SLOW(size));
 }
 
diff --git a/aten/src/ATen/native/prim_native_functions.cpp b/aten/src/ATen/native/prim_native_functions.cpp
index 8db2bbbd1e82..e41535cf9505 100644
--- a/aten/src/ATen/native/prim_native_functions.cpp
+++ b/aten/src/ATen/native/prim_native_functions.cpp
@@ -5,6 +5,7 @@
 #include
 #else
 #include
+#include
 #include
 #endif
diff --git a/aten/src/ATen/native/quantized/cpu/Pooling.cpp b/aten/src/ATen/native/quantized/cpu/Pooling.cpp
index a3c541d23dbb..b422ad7ef62b 100644
--- a/aten/src/ATen/native/quantized/cpu/Pooling.cpp
+++ b/aten/src/ATen/native/quantized/cpu/Pooling.cpp
@@ -21,6 +21,7 @@
 #include
 #include
 #include
+#include
 #endif
 
 #include
diff --git a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
index f7de8f75b074..e2b2da549712 100644
--- a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
+++ b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
@@ -32,6 +32,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -464,7 +465,7 @@ CREATE_UNARY_UFUNC(tan);
 CREATE_UNARY_UFUNC(tanh);
 CREATE_UNARY_UFUNC(trunc);
 CREATE_UNARY_UFUNC(conj_physical);
-CREATE_UNARY_UFUNC(relu);
+static CREATE_UNARY_UFUNC(relu);
 
 // With addition of `round.decimals` overload, using CREATE_UNARY_UFUNC leads
 // to unresolved overload.
@@ -776,18 +777,6 @@ Tensor _sparse_csr_mm(const Tensor& mat1, const Tensor& mat2) {
       1.0);
 }
 
-Tensor _sparse_csr_addmm(
-    const Tensor& t,
-    const SparseCsrTensor& sparse,
-    const Tensor& dense,
-    const Scalar& beta,
-    const Scalar& alpha) {
-  // _sparse_addmm forward is functionally equivalent to addmm; it's
-  // just the backward that is different. This technically does an
-  // unnecessary redispatch, I was too lazy to make it not do that
-  return at::addmm(t, sparse, dense, beta, alpha);
-}
-
 // Functions for element-wise addition.
 Tensor add_sparse_csr(
     const Tensor& self,
diff --git a/aten/src/ATen/native/sparse/SparseUnaryOps.cpp b/aten/src/ATen/native/sparse/SparseUnaryOps.cpp
index 4946c20d243a..6d9c6e23a461 100644
--- a/aten/src/ATen/native/sparse/SparseUnaryOps.cpp
+++ b/aten/src/ATen/native/sparse/SparseUnaryOps.cpp
@@ -188,7 +188,12 @@ COALESCED_UNARY_UFUNC(sqrt);
 COALESCED_UNARY_UFUNC(tan);
 COALESCED_UNARY_UFUNC(tanh);
 COALESCED_UNARY_UFUNC(trunc);
+// The relu function has no declaration; it may be unused in PyTorch.
+// But we keep it and suppress the warning here until that is verified.
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wmissing-prototypes"
 COALESCED_UNARY_UFUNC(relu);
+#pragma clang diagnostic pop
 
 COALESCED_UNARY_UFUNC_NO_INPLACE(signbit);
 COALESCED_UNARY_UFUNC_NO_INPLACE(isneginf);
diff --git a/aten/src/ATen/native/xnnpack/Activation.cpp b/aten/src/ATen/native/xnnpack/Activation.cpp
index 664be585be19..a9791d478cee 100644
--- a/aten/src/ATen/native/xnnpack/Activation.cpp
+++ b/aten/src/ATen/native/xnnpack/Activation.cpp
@@ -1,6 +1,7 @@
 #ifdef USE_XNNPACK
 
 #include
+#include
 #include
 
 namespace at {
@@ -18,7 +19,7 @@ bool use_hardswish(
       true;
 }
 
-Tensor& hardswish_impl(Tensor& input, Tensor& output) {
+static Tensor& hardswish_impl(Tensor& input, Tensor& output) {
   using namespace internal;
 
   xnn_operator_t hardswish_op{};
diff --git a/aten/src/ATen/native/xnnpack/AveragePooling.cpp b/aten/src/ATen/native/xnnpack/AveragePooling.cpp
index 0599b3fe986b..401f6fa5fea9 100644
--- a/aten/src/ATen/native/xnnpack/AveragePooling.cpp
+++ b/aten/src/ATen/native/xnnpack/AveragePooling.cpp
@@ -1,7 +1,8 @@
 #ifdef USE_XNNPACK
 
-#include
 #include
+#include
+#include
 #include
 
 namespace at {
diff --git a/aten/src/ATen/native/xnnpack/ChannelShuffle.cpp b/aten/src/ATen/native/xnnpack/ChannelShuffle.cpp
index 8b20eca3aa96..b8eaabbc425c 100644
--- a/aten/src/ATen/native/xnnpack/ChannelShuffle.cpp
+++ b/aten/src/ATen/native/xnnpack/ChannelShuffle.cpp
@@ -1,6 +1,7 @@
 #ifdef USE_XNNPACK
 
 #include
+#include
 #include
 
 namespace at {
diff --git a/aten/src/ATen/native/xnnpack/Convolution.cpp b/aten/src/ATen/native/xnnpack/Convolution.cpp
index cf9d180b2153..1907b38b1743 100644
--- a/aten/src/ATen/native/xnnpack/Convolution.cpp
+++ b/aten/src/ATen/native/xnnpack/Convolution.cpp
@@ -2,11 +2,12 @@
 
 #include
 
-#include
 #include
 #include
 #include
+#include
 #include
+#include
 #include
 
 namespace at {
diff --git a/aten/src/ATen/native/xnnpack/Convolution.h b/aten/src/ATen/native/xnnpack/Convolution.h
index 3b1ccdfe4c59..a0e9dc54c4d6 100644
--- a/aten/src/ATen/native/xnnpack/Convolution.h
+++ b/aten/src/ATen/native/xnnpack/Convolution.h
@@ -62,6 +62,15 @@ Tensor run(ContextConv2D& context, const Tensor& input);
 } // namespace convolution2d
 } // namespace internal
 
+
+Tensor convolution2d(
+    const Tensor& input,
+    const Tensor& weight,
+    const Tensor& bias,
+    const IntArrayRef padding,
+    const IntArrayRef stride,
+    const IntArrayRef dilation,
+    const int64_t groups);
 } // namespace xnnpack
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/xnnpack/Linear.h b/aten/src/ATen/native/xnnpack/Linear.h
index 65dbe1b661aa..9dda8fa9cb61 100644
--- a/aten/src/ATen/native/xnnpack/Linear.h
+++ b/aten/src/ATen/native/xnnpack/Linear.h
@@ -32,6 +32,16 @@ ContextLinear create(
 Tensor run(const ContextLinear& context, const Tensor& input);
 } // namespace linear
 } // namespace internal
 
+
+bool use_linear(
+    const Tensor& input,
+    const Tensor& weight,
+    const Tensor& bias);
+
+Tensor linear(
+    const Tensor& input,
+    const Tensor& weight,
+    const Tensor& bias);
 } // namespace xnnpack
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/xnnpack/MaxPooling.cpp b/aten/src/ATen/native/xnnpack/MaxPooling.cpp
index f2e31a3e5606..980c0644d0d3 100644
--- a/aten/src/ATen/native/xnnpack/MaxPooling.cpp
+++ b/aten/src/ATen/native/xnnpack/MaxPooling.cpp
@@ -1,8 +1,9 @@
 #ifdef USE_XNNPACK
 
 #include
-#include
 #include
+#include
+#include
 #include
 
 namespace at {
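The Convolution.h and Linear.h hunks above take the opposite route: rather
than hiding the definitions, they publish prototypes in the header the .cpp
files already include, because convolution2d, use_linear, and linear are
called from other translation units. The shape of that fix, as a
self-contained sketch with hypothetical names:

    // area.h: the prototype lives next to the rest of the public API.
    #pragma once
    int rect_area(int w, int h);

    // area.cpp: including the header gives the definition a prior
    // declaration, satisfying -Wmissing-prototypes, and guarantees the
    // definition and the declaration callers see cannot drift apart.
    #include "area.h"
    int rect_area(int w, int h) { return w * h; }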
diff --git a/aten/src/ATen/templates/CompositeViewCopyKernels.cpp b/aten/src/ATen/templates/CompositeViewCopyKernels.cpp
index 7548d7c1a3a8..47097d7aa432 100644
--- a/aten/src/ATen/templates/CompositeViewCopyKernels.cpp
+++ b/aten/src/ATen/templates/CompositeViewCopyKernels.cpp
@@ -18,6 +18,7 @@ namespace native {
 // This file contains a number of kernels for aten functions that are fully code-generated.
 // TODO: rename this file to something more generic.
 
+namespace {
 at::Tensor clone_arg(const at::Tensor& t) {
   return t.clone();
 }
@@ -59,6 +60,7 @@ void resize_out_helper(const at::TensorList& dst, const at::TensorList& src) {
     at::native::resize_output(dst[i], src[i].sizes());
   }
 }
+}
 
 ${CompositeViewCopyKernel_Definitions}
diff --git a/aten/src/ATen/vulkan/Context.h b/aten/src/ATen/vulkan/Context.h
index b3485623a829..521bc3dd5761 100644
--- a/aten/src/ATen/vulkan/Context.h
+++ b/aten/src/ATen/vulkan/Context.h
@@ -22,9 +22,10 @@ class VulkanImplRegistrar {
 };
 
 at::Tensor& vulkan_copy_(at::Tensor& self, const at::Tensor& src);
+} // namespace vulkan
+
 namespace native {
 bool is_vulkan_available();
 }// namespace native
-} // namespace vulkan
 } // namespace at
diff --git a/caffe2/utils/threadpool/ThreadPool.h b/caffe2/utils/threadpool/ThreadPool.h
index af21b6c14c95..dbaefc51389f 100644
--- a/caffe2/utils/threadpool/ThreadPool.h
+++ b/caffe2/utils/threadpool/ThreadPool.h
@@ -62,6 +62,7 @@ class TORCH_API /*alignas(kCacheLineSize)*/ ThreadPool {
   size_t minWorkSize_;
 };
 
+size_t getDefaultNumThreads();
 } // namespace caffe2
 
 #endif // CAFFE2_UTILS_THREADPOOL_H_
diff --git a/torch/csrc/autograd/VariableTypeManual.cpp b/torch/csrc/autograd/VariableTypeManual.cpp
index 437c41aefd2b..b42d22d0fa95 100644
--- a/torch/csrc/autograd/VariableTypeManual.cpp
+++ b/torch/csrc/autograd/VariableTypeManual.cpp
@@ -8,6 +8,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 
@@ -22,7 +23,7 @@ namespace torch {
 namespace autograd {
 namespace VariableType {
 
-std::vector<DeprecatedTypeProperties*> allTypesForBackends(
+static std::vector<DeprecatedTypeProperties*> allTypesForBackends(
     at::ArrayRef<at::Backend> backends) {
   std::vector<DeprecatedTypeProperties*> res;
   res.reserve(backends.size());
@@ -37,16 +38,16 @@ std::vector<DeprecatedTypeProperties*> allTypesForBackends(
   return res;
 }
 
-C10_EXPORT std::vector<DeprecatedTypeProperties*> allCPUTypes() {
+std::vector<DeprecatedTypeProperties*> allCPUTypes() {
   return allTypesForBackends({Backend::CPU, Backend::SparseCPU});
 }
 
-C10_EXPORT std::vector<DeprecatedTypeProperties*> allCUDATypes() {
+std::vector<DeprecatedTypeProperties*> allCUDATypes() {
   at::globalContext().lazyInitCUDA();
   return allTypesForBackends({Backend::CUDA, Backend::SparseCUDA});
 }
 
-C10_EXPORT std::vector<DeprecatedTypeProperties*> allXPUTypes() {
+std::vector<DeprecatedTypeProperties*> allXPUTypes() {
   return allTypesForBackends({Backend::XPU, Backend::SparseXPU});
 }
 
@@ -375,7 +376,7 @@ namespace ADInplaceOrView {
       : (at::GradMode::is_enabled() ? CreationMeta::DEFAULT \
                                     : CreationMeta::NO_GRAD_MODE)
 
-Tensor& copy_(
+static Tensor& copy_(
     c10::DispatchKeySet ks,
     Tensor& self,
     const Tensor& src,
@@ -389,7 +390,7 @@ Tensor& copy_(
   return self;
 }
 
-const Tensor& resize_(
+static const Tensor& resize_(
     c10::DispatchKeySet ks,
     const Tensor& self,
     SymIntArrayRef size,
@@ -413,7 +414,7 @@ const Tensor& resize_(
   return self;
 }
 
-const Tensor& resize_as_(
+static const Tensor& resize_as_(
     c10::DispatchKeySet ks,
     const Tensor& self,
     const Tensor& the_template,
@@ -438,7 +439,7 @@ const Tensor& resize_as_(
   return self;
 }
 
-Tensor detach(c10::DispatchKeySet ks, const Tensor& self) {
+static Tensor detach(c10::DispatchKeySet ks, const Tensor& self) {
   auto out = ([&]() {
     at::AutoDispatchBelowADInplaceOrView guard;
     return at::_ops::detach::redispatch(
@@ -460,7 +461,10 @@ Tensor detach(c10::DispatchKeySet ks, const Tensor& self) {
   return result;
 }
 
-Tensor _fw_primal(c10::DispatchKeySet ks, const Tensor& self, int64_t level) {
+static Tensor _fw_primal(
+    c10::DispatchKeySet ks,
+    const Tensor& self,
+    int64_t level) {
   auto tmp = ([&]() {
     at::AutoDispatchBelowADInplaceOrView guard;
     return at::alias(self);
@@ -484,7 +488,7 @@ Tensor _fw_primal(c10::DispatchKeySet ks, const Tensor& self, int64_t level) {
 }
 
 // NB: This does not redispatch any further
-Tensor _make_dual(
+static Tensor _make_dual(
     c10::DispatchKeySet ks,
     const Tensor& primal,
     const Tensor& tangent,
diff --git a/torch/csrc/autograd/profiler_kineto.cpp b/torch/csrc/autograd/profiler_kineto.cpp
index 3ff32f8bd138..30c530bed3ce 100644
--- a/torch/csrc/autograd/profiler_kineto.cpp
+++ b/torch/csrc/autograd/profiler_kineto.cpp
@@ -40,6 +40,7 @@ extern "C" {
 // This function is needed to avoid a superfluous dependency on the GNU OpenMP
 // library when cuPTI is linked statically. For more details see
 // https://github.com/pytorch/pytorch/issues/51026
+__attribute__((weak)) int acc_get_device_type();
 __attribute__((weak)) int acc_get_device_type() {
   throw std::runtime_error(
       "Dummy implementation of acc_get_device_type is not supposed to be called!");
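The profiler_kineto.cpp hunk above (and the unwind.cpp hunk further below) use
the narrowest fix: a self-declaration immediately before the definition. That
fits symbols which must keep external linkage, such as weak interposition
targets and extern "C" entry points resolved by name, yet belong in no header.
A minimal sketch of the pattern, with a hypothetical symbol name:

    #include <stdexcept>

    // Cannot be static: other code resolves this symbol by name, so the
    // redundant-looking declaration exists only to supply the prototype.
    extern "C" int acc_stub_entry();
    extern "C" int acc_stub_entry() {
      throw std::runtime_error("stub: not supposed to be called");
    }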
diff --git a/torch/csrc/distributed/c10d/Ops.cpp b/torch/csrc/distributed/c10d/Ops.cpp
index dc978e958acd..d1fc4d404fa0 100644
--- a/torch/csrc/distributed/c10d/Ops.cpp
+++ b/torch/csrc/distributed/c10d/Ops.cpp
@@ -62,6 +62,9 @@ namespace ops {
 // Below are ProcessGroup's corresponding ops for each backend. Ops are
 // routed through the dispatcher to be dispatched to the appropriate backend.
 // Currently a no-op as the process group does not have a list of backends.
+
+namespace {
+
 #define IMPL_SEND(DEV)                  \
   c10::intrusive_ptr<Work> send##DEV(   \
       at::TensorList tensors,           \
@@ -425,6 +428,7 @@ void monitored_barrier_CPU(
       BarrierOptions{device_ids, std::chrono::milliseconds(timeout)},
       wait_all_ranks);
 }
+} // namespace
 
 // register functions to dispatcher
 namespace {
diff --git a/torch/csrc/jit/mobile/nnc/aot_compiler.cpp b/torch/csrc/jit/mobile/nnc/aot_compiler.cpp
index 0755df8db763..7f8819780140 100644
--- a/torch/csrc/jit/mobile/nnc/aot_compiler.cpp
+++ b/torch/csrc/jit/mobile/nnc/aot_compiler.cpp
@@ -30,7 +30,7 @@ namespace jit {
 namespace mobile {
 namespace nnc {
 
-std::vector<int64_t> getConstSizes(const BufPtr b) {
+static std::vector<int64_t> getConstSizes(const BufPtr b) {
   std::vector<int64_t> r;
   for (const auto& dim : b->dims()) {
     LongImmPtr imm_dim = to<LongImm>(dim);
@@ -42,7 +42,7 @@ std::vector<int64_t> getConstSizes(const BufPtr b) {
 }
 
 // Construct input-specs vector from the inputs of the original graph
-std::vector<mobile::nnc::InputSpec> toInputSpecs(
+static std::vector<mobile::nnc::InputSpec> toInputSpecs(
     const std::shared_ptr<tensorexpr::TensorExprKernel>& kernel) {
   const std::shared_ptr<Graph>& g = kernel->graph();
   std::vector<mobile::nnc::InputSpec> specs;
@@ -89,7 +89,7 @@ std::vector<mobile::nnc::InputSpec> toInputSpecs(
 // If a symbolic shape can be found in several different positions, we
 // return the first one we find (TODO: maybe we should return all and
 // verify that they all match at runtime).
-std::vector<SymbolicShapePosition> findSymbolicShapePositions(
+static std::vector<SymbolicShapePosition> findSymbolicShapePositions(
     std::shared_ptr<tensorexpr::TensorExprKernel> kernel) {
   std::vector<SymbolicShapePosition> res;
   for (int64_t sym_idx : kernel->getSymbolicShapeInputs()) {
@@ -122,7 +122,7 @@ std::vector<SymbolicShapePosition> findSymbolicShapePositions(
   return res;
 }
 
-std::unique_ptr<Function> compileMethod(
+static std::unique_ptr<Function> compileMethod(
    std::shared_ptr<tensorexpr::TensorExprKernel> kernel,
    const std::string& method_name,
    const std::vector<std::vector<int64_t>>& sizes,
@@ -181,7 +181,7 @@ std::unique_ptr<Function> compileMethod(
   return func;
 }
 
-std::pair<std::unique_ptr<Function>, const std::string> aotCompile(
+static std::pair<std::unique_ptr<Function>, const std::string> aotCompile(
     const std::string& method_name,
     std::shared_ptr<Graph>& g,
     const std::vector<std::vector<int64_t>>& sizes,
@@ -217,7 +217,7 @@ std::pair<std::unique_ptr<Function>, const std::string> aotCompile(
   return std::make_pair(std::move(func), compiled_assembly);
 }
 
-void writeOutputLlvmAssembly(
+static void writeOutputLlvmAssembly(
     const std::string& asm_code,
     const std::string& output_llvm_file_name) {
   std::ofstream output(output_llvm_file_name);
@@ -226,7 +226,7 @@ void writeOutputLlvmAssembly(
       "The compiled llvm assembly code was saved to ", output_llvm_file_name);
 }
 
-std::vector<std::string> split(
+static std::vector<std::string> split(
     char separator,
     const std::string& string,
     bool ignore_empty = true) {
@@ -241,7 +241,7 @@ std::vector<std::string> split(
   return pieces;
 }
 
-std::vector<std::vector<int64_t>> parseInputShapes(
+static std::vector<std::vector<int64_t>> parseInputShapes(
     const std::string& input_dims_s) {
   std::vector<std::string> input_dims_list = split(';', input_dims_s);
   std::vector<std::vector<int64_t>> inputs;
@@ -257,7 +257,7 @@ std::vector<std::vector<int64_t>> parseInputShapes(
   return inputs;
 }
 
-std::vector<at::ScalarType> parseInputTypes(
+static std::vector<at::ScalarType> parseInputTypes(
     const std::string& input_types_str) {
   std::vector<std::string> inputTypes = split(';', input_types_str);
   std::vector<at::ScalarType> scalarTypes;
@@ -277,7 +277,7 @@ std::vector<at::ScalarType> parseInputTypes(
   return scalarTypes;
 }
 
-std::vector<at::MemoryFormat> parseInputMemoryFormats(
+static std::vector<at::MemoryFormat> parseInputMemoryFormats(
     const std::string& input_memory_format_str) {
   std::vector<std::string> memFormatsStr = split(';', input_memory_format_str);
   std::vector<at::MemoryFormat> memFormats;
@@ -295,7 +295,7 @@ std::vector<at::MemoryFormat> parseInputMemoryFormats(
   return memFormats;
 }
 
-std::vector<int64_t> parseInputDynamicShapes(
+static std::vector<int64_t> parseInputDynamicShapes(
     const std::string& dynamic_dims_s) {
   std::vector<std::string> dynamic_dims_list = split(',', dynamic_dims_s);
   std::vector<int64_t> dynamic_dims;
@@ -306,7 +306,7 @@ std::vector<int64_t> parseInputDynamicShapes(
   return dynamic_dims;
 }
 
-std::string getNncKernelId(
+static std::string getNncKernelId(
     const std::string& model_name,
     const std::string& model_version,
     const std::string& method_name) {
@@ -316,7 +316,7 @@ std::string getNncKernelId(
       version_token;
 }
 
-std::string getNncKernelFuncName(
+static std::string getNncKernelFuncName(
     const std::string& model_name,
     const std::string& model_version,
     const std::string& method_name) {
@@ -325,7 +325,8 @@ std::string getNncKernelFuncName(
 
 // Preprocess the graph and return the processed graph and
 // symbolic values if dynamic input shapes are specified
-std::pair<std::shared_ptr<Graph>, std::vector<int64_t>> preprocessGraphPasses(
+static std::pair<std::shared_ptr<Graph>, std::vector<int64_t>>
+preprocessGraphPasses(
     std::shared_ptr<Graph>& graph,
     const std::vector<c10::optional<at::Tensor>>& example_inputs,
     const std::vector<int64_t>& dynamic_sizes) {
@@ -367,7 +368,7 @@ std::pair<std::shared_ptr<Graph>, std::vector<int64_t>> preprocessGraphPasses(
   return std::make_pair(graph, sym_val);
 }
 
-std::vector<c10::optional<at::Tensor>> generateExampleInputs(
+static std::vector<c10::optional<at::Tensor>> generateExampleInputs(
     const std::vector<std::vector<int64_t>>& inputShapes,
     const std::vector<at::ScalarType>& inputTypes,
     const std::vector<at::MemoryFormat>& inputMemoryFormats) {
@@ -382,7 +383,7 @@ std::vector<c10::optional<at::Tensor>> generateExampleInputs(
   return example_inputs;
 }
 
-c10::IValue preprocess(
+static c10::IValue preprocess(
     const torch::jit::Module& mod,
     const c10::Dict<c10::IValue, c10::IValue>& compile_spec,
     const torch::jit::BackendDebugHandleGenerator& generate_debug_handles) {
diff --git a/torch/csrc/profiler/unwind/unwind.cpp b/torch/csrc/profiler/unwind/unwind.cpp
index 6807545b03cf..4ce104e066c7 100644
--- a/torch/csrc/profiler/unwind/unwind.cpp
+++ b/torch/csrc/profiler/unwind/unwind.cpp
@@ -95,7 +95,7 @@ struct LibraryInfo {
   EHFrameHdr eh_frame_hdr_;
 };
 
-const char* process_name() {
+static const char* process_name() {
   static char name[PATH_MAX + 1] = "";
   if (*name == '\0') {
     ssize_t len = readlink("/proc/self/exe", name, PATH_MAX);
@@ -267,6 +267,7 @@ struct UnwindCache {
 static UnwindCache unwind_cache;
 static std::shared_timed_mutex cache_mutex_;
 
+extern "C" void unwind_c(std::vector<void*>* result, int64_t rsp, int64_t rbp);
 extern "C" void unwind_c(std::vector<void*>* result, int64_t rsp, int64_t rbp) {
   std::shared_lock lock(cache_mutex_);
   UnwindState state;
diff --git a/torchgen/gen_backend_stubs.py b/torchgen/gen_backend_stubs.py
index 3ded8b1c093f..7322daa5dc76 100644
--- a/torchgen/gen_backend_stubs.py
+++ b/torchgen/gen_backend_stubs.py
@@ -467,6 +467,7 @@ TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) {
     else:
         deferred_template = CodeTemplate(
             """\
+TORCH_API void Register${backend_name}${dispatch_key}NativeFunctions();
 TORCH_API void Register${backend_name}${dispatch_key}NativeFunctions() {
     static auto m = MAKE_TORCH_LIBRARY_IMPL(aten, $dispatch_key);
     $dispatch_registrations_body