[2/N] Use internal linkage in aten C++ files (#151070)

Turn functions and variables into static if they are not used outside the ten cpp files. In some cases, missing header inclusion is added. In other cases, unused functions are removed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151070
Approved by: https://github.com/Skylion007
This commit is contained in:
cyy
2025-04-14 16:07:17 +00:00
committed by PyTorch MergeBot
parent 24b3ab9255
commit eb19f5abab
39 changed files with 130 additions and 167 deletions

View File

@ -10,15 +10,13 @@
#include <mkl.h>
#endif
#if AT_MKLDNN_ENABLED()
#include <ATen/native/mkldnn/IDeepRegistration.h>
#endif
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
namespace at {
#if AT_MKLDNN_ENABLED()
namespace native::mkldnn {
// NOLINTNEXTLINE(misc-use-internal-linkage)
void clear_computation_cache();
} // namespace native::mkldnn
#endif
namespace {
// Number of threads set by the user

View File

@ -849,10 +849,7 @@ namespace at::native {
// linear algebra operations
template<class scalar_t>
void lapackCholeskySolve(char uplo, int n, int nrhs, scalar_t *a, int lda, scalar_t *b, int ldb, int *info);
template<class scalar_t, class value_t=scalar_t>
void lapackSymeig(char jobz, char uplo, int n, scalar_t *a, int lda, value_t *w, scalar_t *work, int lwork, value_t *rwork, int *info);
static void lapackCholeskySolve(char uplo, int n, int nrhs, scalar_t *a, int lda, scalar_t *b, int ldb, int *info);
template<> void lapackLu<c10::complex<double>>(int m, int n, c10::complex<double> *a, int lda, int *ipiv, int *info) {
zgetrf_(&m, &n, reinterpret_cast<std::complex<double>*>(a), &lda, ipiv, info);

View File

@ -1383,35 +1383,35 @@ Tensor bitwise_right_shift(const Scalar& self, const Tensor& other) {
}
template <typename Stub>
Tensor& comparison_op_out(Tensor& result, const Tensor& self, const Tensor& other, Stub& stub) {
static Tensor& comparison_op_out(Tensor& result, const Tensor& self, const Tensor& other, Stub& stub) {
auto iter = TensorIterator::comparison_op(result, self, other);
stub(iter.device_type(), iter);
return result;
}
template <typename OutImpl>
Tensor comparison_op(const Tensor& self, const Tensor& other, OutImpl& out_impl) {
static Tensor comparison_op(const Tensor& self, const Tensor& other, OutImpl& out_impl) {
Tensor result = at::empty({0}, self.options().dtype(kBool));
return out_impl(result, self, other);
}
template <typename OutImpl>
Tensor& comparison_op_(Tensor& self, const Tensor& other, OutImpl& out_impl) {
static Tensor& comparison_op_(Tensor& self, const Tensor& other, OutImpl& out_impl) {
return out_impl(self, self, other);
}
template <typename OutImpl>
Tensor& comparison_op_out(Tensor& result, const Tensor& self, const Scalar& other, OutImpl& out_impl) {
static Tensor& comparison_op_out(Tensor& result, const Tensor& self, const Scalar& other, OutImpl& out_impl) {
return out_impl(result, self, wrapped_scalar_tensor(other));
}
template <typename OutImpl>
Tensor comparison_op(const Tensor& self, const Scalar& other, OutImpl& out_impl) {
static Tensor comparison_op(const Tensor& self, const Scalar& other, OutImpl& out_impl) {
return comparison_op(self, wrapped_scalar_tensor(other), out_impl);
}
template <typename OutImpl>
Tensor& comparison_op_(Tensor& self, const Scalar& other, OutImpl& out_impl) {
static Tensor& comparison_op_(Tensor& self, const Scalar& other, OutImpl& out_impl) {
return out_impl(self, self, wrapped_scalar_tensor(other));
}

View File

@ -116,7 +116,7 @@ void fp16_gemv_trans(
fp16_gemv_trans_stub(kCPU, m, n, alpha, a, lda, x, incx, beta, y, incy);
}
void bf16_gemv_trans(
static void bf16_gemv_trans(
const int m,
const int n,
const at::BFloat16 alpha,
@ -146,14 +146,14 @@ void fp16_gemv_notrans(
#endif // defined(__aarch64__) && !defined(C10_MOBILE)
template <typename scalar_t>
bool scal_use_fast_path(
static bool scal_use_fast_path(
[[maybe_unused]] int64_t n,
[[maybe_unused]] int64_t incx) {
return false;
}
template <typename scalar_t>
bool gemv_use_fast_path(
static bool gemv_use_fast_path(
[[maybe_unused]] char trans,
[[maybe_unused]] int64_t m,
[[maybe_unused]] int64_t n,
@ -166,7 +166,7 @@ bool gemv_use_fast_path(
}
template <typename scalar_t>
void scal_fast_path(
static void scal_fast_path(
[[maybe_unused]] int* n,
[[maybe_unused]] scalar_t* a,
[[maybe_unused]] scalar_t* x,
@ -176,7 +176,7 @@ void scal_fast_path(
}
template <typename scalar_t>
void gemv_fast_path(
static void gemv_fast_path(
[[maybe_unused]] const char* trans,
[[maybe_unused]] const int* m,
[[maybe_unused]] const int* n,

View File

@ -554,7 +554,7 @@ using is_blas_library_type = std::integral_constant<bool,
std::is_same_v<scalar_t, c10::complex<float>>>;
template <typename scalar_t>
void gemm_batched_generic(
static void gemm_batched_generic(
TransposeType transa, TransposeType transb,
int64_t batch_size, int64_t m, int64_t n, int64_t k,
scalar_t alpha,
@ -568,7 +568,7 @@ void gemm_batched_generic(
}
template <typename scalar_t>
void gemm_batched(
static void gemm_batched(
TransposeType transa, TransposeType transb,
int64_t batch_size, int64_t m, int64_t n, int64_t k,
scalar_t alpha,
@ -596,7 +596,7 @@ void gemm_batched(
}
template <typename scalar_t>
void gemm_batched_with_stride_generic(
static void gemm_batched_with_stride_generic(
TransposeType transa, TransposeType transb,
int64_t batch_size, int64_t m, int64_t n, int64_t k,
scalar_t alpha,
@ -945,7 +945,7 @@ struct PackKey {
}
};
inline dnnl::memory::data_type get_dnnl_dtype(ScalarType dtype) {
static inline dnnl::memory::data_type get_dnnl_dtype(ScalarType dtype) {
if (dtype == ScalarType::Float) {
return dnnl::memory::data_type::f32;
} else if (dtype == ScalarType::BFloat16) {

View File

@ -13,7 +13,7 @@ class Tensor;
namespace native {
template<typename O, typename C>
void _assert_match(const O& original, const C& compared, const std::string& name) {
static void _assert_match(const O& original, const C& compared, const std::string& name) {
if (compared) {
bool equal = (original == compared.value());
if (!equal) {

View File

@ -97,7 +97,7 @@ static bool conv_benchmark_empty_cache = true;
// Check workload to activate fast depthwise FP16 cudnn conv kernels
template <typename T>
bool check_cudnn_depthwise_workload(const at::Tensor& input, T stride) {
static bool check_cudnn_depthwise_workload(const at::Tensor& input, T stride) {
auto w = at::symint::size<T>(input, 3); // same as h
auto ch = at::symint::size<T>(input, 1);
auto bs = at::symint::size<T>(input, 0);
@ -220,7 +220,7 @@ bool check_cudnn_depthwise_workload(const at::Tensor& input, T stride) {
// simplified version for cudnn 8.2 and above
template <typename T>
bool check_cudnn_depthwise_workload_with_filter(const at::Tensor& input, T stride, const at::Tensor& weight) {
static bool check_cudnn_depthwise_workload_with_filter(const at::Tensor& input, T stride, const at::Tensor& weight) {
// 1D conv
if(at::symint::size<T>(input, 2) == 1 && stride == 1){
return true;
@ -640,7 +640,7 @@ REGISTER_NO_CPU_DISPATCH(miopen_convolution_transpose_backward_stub)
REGISTER_NO_CPU_DISPATCH(miopen_depthwise_convolution_backward_stub)
template <typename T>
std::ostream& operator<<(std::ostream & out, const ConvParams<T>& params) {
static std::ostream& operator<<(std::ostream & out, const ConvParams<T>& params) {
out << "ConvParams {"
<< " stride = " << IntArrayRef{params.stride}
<< " padding = " << ArrayRef<T>{params.padding}
@ -1203,7 +1203,7 @@ at::Tensor convolution_overrideable(
// a bool indicating whether the bias is defined. This is done to save memory by
// avoiding saving the full bias tensor for backward.
template <typename T>
ConvBackend _select_conv_backend(
static ConvBackend _select_conv_backend(
const Tensor& input,
const Tensor& weight,
const std::optional<Tensor>& bias,

View File

@ -1059,7 +1059,7 @@ static Tensor apply_bag_size_backward(
}
template <typename scalar_t>
void embedding_bag_cpu_max_out(
static void embedding_bag_cpu_max_out(
Tensor* max_indices,
const Tensor& weight,
const Tensor& indices,
@ -1505,7 +1505,7 @@ static std::vector<index_t> compute_counts_uniq(
}
template <typename scalar_t>
void _embedding_bag_dense_backward_cpu_sum_mean(
static void _embedding_bag_dense_backward_cpu_sum_mean(
const Tensor& grad,
const Tensor& indices_,
const Tensor& offset2bag_,
@ -1641,7 +1641,7 @@ Tensor _embedding_bag_dense_backward_cpu(const Tensor &grad_, const Tensor &indi
}
template<typename scalar_t>
Tensor _embedding_bag_per_sample_weights_backward_cpu_template(
static Tensor _embedding_bag_per_sample_weights_backward_cpu_template(
const Tensor& grad,
const Tensor& weight, // NB: embedding table, not per_sample_weights
const Tensor& indices_,

View File

@ -285,7 +285,7 @@ TORCH_META_FUNC(_linalg_slogdet)(const Tensor& A) {
}
template <typename Meta>
void common_checks_baddbmm_bmm(Meta& meta, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, bool is_bmm, const std::optional<Tensor>& self_baddbmm = std::nullopt) {
static void common_checks_baddbmm_bmm(Meta& meta, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, bool is_bmm, const std::optional<Tensor>& self_baddbmm = std::nullopt) {
TORCH_CHECK(batch1.dim() == 3, "batch1 must be a 3D tensor");
TORCH_CHECK(batch2.dim() == 3, "batch2 must be a 3D tensor");
@ -1639,7 +1639,7 @@ TORCH_IMPL_FUNC(mm_out_cpu)(const Tensor & self, const Tensor & mat2, const Tens
}
template <typename scalar_t, bool is_bmm>
inline void baddbmm_cpu_kernel(const Tensor& result, const Tensor& self, const Tensor& mat2, const Scalar& beta_, const Scalar& alpha_) {
static inline void baddbmm_cpu_kernel(const Tensor& result, const Tensor& self, const Tensor& mat2, const Scalar& beta_, const Scalar& alpha_) {
int64_t bs = result.size(0);
int64_t is = result.size(1);
int64_t js = result.size(2);

View File

@ -20,9 +20,6 @@
namespace at::native {
template<typename scalar_t>
void gemv(char trans, int64_t m, int64_t n, scalar_t alpha, scalar_t *a, int64_t lda, scalar_t *x, int64_t incx, scalar_t beta, scalar_t *y, int64_t incy);
namespace {
static inline void slow_conv_transpose3d_shape_check(

View File

@ -132,7 +132,7 @@ static inline MemoryFormat suggest_memory_format_contig(const Tensor& t) {
}
template<typename scalar_t, typename param_t>
std::tuple<Tensor,Tensor,Tensor> batch_norm_cpu_transform_input_template(
static std::tuple<Tensor,Tensor,Tensor> batch_norm_cpu_transform_input_template(
const Tensor& input, const Tensor& weight, const Tensor& bias,
const Tensor& save_mean /* optional */, const Tensor& save_invstd /* optional */,
const Tensor& running_mean /* optional */, const Tensor& running_var /* optional */,
@ -197,7 +197,7 @@ std::tuple<Tensor,Tensor,Tensor> batch_norm_cpu_transform_input_template(
}
template<typename scalar_t, typename param_t, template<typename T> class VarTransform>
std::tuple<Tensor,Tensor> batch_norm_cpu_update_stats_template(
static std::tuple<Tensor,Tensor> batch_norm_cpu_update_stats_template(
const Tensor& input, const Tensor& running_mean, const Tensor& running_var,
double momentum, double eps, Tensor& save_mean, Tensor& save_var_transform) {
@ -287,7 +287,7 @@ std::tuple<Tensor,Tensor> batch_norm_cpu_update_stats_template(
}
template<typename scalar_t, typename param_t, template<typename T> class VarTransform>
std::tuple<Tensor,Tensor> batch_norm_cpu_update_stats_template(
static std::tuple<Tensor,Tensor> batch_norm_cpu_update_stats_template(
const Tensor& input, const Tensor& running_mean, const Tensor& running_var,
double momentum, double eps) {
int64_t n_input = input.size(1);
@ -306,7 +306,7 @@ std::tuple<Tensor,Tensor> batch_norm_cpu_update_stats_template(
}
template<typename scalar_t, typename param_t>
std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu_template(
static std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu_template(
const Tensor& grad_out_, const Tensor& input, const Tensor& weight,
const Tensor& running_mean, const Tensor& running_var, const Tensor& save_mean, const Tensor& save_invstd,
bool train, double eps, std::array<bool,3> grad_input_mask) {

View File

@ -472,7 +472,7 @@ Tensor& logcumsumexp_out(const Tensor& self, int64_t dim, Tensor& result) {
}
template <class Stub>
void impl_func_cum_ops(
static void impl_func_cum_ops(
const Tensor& self,
int64_t dim,
const Tensor& result,
@ -769,7 +769,7 @@ inline bool isnan_(T x) {
}
template<typename T1, typename T2, typename Operation>
void cummax_cummin_helper(const T1* self_data, T1* values_data, T2* indices_data,
static void cummax_cummin_helper(const T1* self_data, T1* values_data, T2* indices_data,
int self_dim_size, int self_stride, int values_stride, int indices_stride) {
Operation op;
T1 out = c10::load(self_data);
@ -1182,7 +1182,7 @@ std::vector<Tensor> gradient(const Tensor& self, IntArrayRef dim, int64_t edge_o
// ALL REDUCE #################################################################
inline bool should_use_acc_buffer(at::TensorIterator& iter) {
static inline bool should_use_acc_buffer(at::TensorIterator& iter) {
const auto ndim = iter.ndim();
if (!iter.device().is_cpu() || iter.noutputs() != 1) {
return false;
@ -1591,7 +1591,7 @@ Tensor norm(const Tensor& self, const Scalar& p) {
return at::norm(self, p, IntArrayRef{}, false);
}
inline TensorIterator get_allany_iter(
static inline TensorIterator get_allany_iter(
const Tensor& self,
const Tensor& result,
OptionalIntArrayRef dims,
@ -1608,7 +1608,7 @@ inline TensorIterator get_allany_iter(
}
template <int identity, typename Stub>
inline void allany_impl(
static inline void allany_impl(
const Tensor& self,
const Tensor& result,
OptionalIntArrayRef dims,
@ -1653,7 +1653,7 @@ TORCH_IMPL_FUNC(any_all_out)(const Tensor& self, const Tensor& result) {
}
template <bool is_all>
Tensor allany_dims_default(const Tensor &self, OptionalIntArrayRef dim, bool keepdim) {
static Tensor allany_dims_default(const Tensor &self, OptionalIntArrayRef dim, bool keepdim) {
// Default implementation in terms of all-reduce or single dim reduce
if (!dim) {
Tensor out;
@ -1732,7 +1732,7 @@ TORCH_IMPL_FUNC(amax_out) (const Tensor& self, IntArrayRef dim, bool keepdim, co
}
template <class Stub>
void argmax_argmin_impl(
static void argmax_argmin_impl(
const Tensor& self,
std::optional<int64_t> dim,
bool keepdim,

View File

@ -9,6 +9,7 @@
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/resize_as_native.h>
#include <ATen/ops/resize_as_sparse_native.h>
#include <ATen/ops/resize_native.h>
#include <ATen/ops/resize.h>
#include <ATen/ops/_resize_output.h>
@ -21,7 +22,7 @@ namespace at::native {
// Returns true if resize is necessary
template <typename T>
bool _resize_output_check(const Tensor& output, ArrayRef<T> shape) {
static bool _resize_output_check(const Tensor& output, ArrayRef<T> shape) {
// Tests for resizing of tensors with one or more elements
if (at::symint::sizes<T>(output).equals(shape)) {
return false;
@ -56,7 +57,7 @@ static void native_resize_(const Tensor& output, SymIntArrayRef shape) {
}
template <typename T>
bool _resize_output(const Tensor& output, ArrayRef<T> shape) {
static bool _resize_output(const Tensor& output, ArrayRef<T> shape) {
if (_resize_output_check<T>(output, shape)) {
// avoid a redispatch for cpu and cuda.
// TODO: when resize_cuda_ is re-written to be unified with resize_,
@ -196,7 +197,7 @@ static void _maybe_resize_storage(TensorImpl* self, c10::SymInt new_size_bytes)
}
template <typename T>
TensorImpl* _resize_impl_(
static TensorImpl* _resize_impl_(
TensorImpl* self,
ArrayRef<T> size,
at::OptionalArrayRef<T> stride,
@ -234,7 +235,7 @@ TensorImpl* resize_impl_cpu_(
}
template <typename T>
const Tensor& _resize_(
static const Tensor& _resize_(
const Tensor& self,
ArrayRef<T> size,
std::optional<MemoryFormat> optional_memory_format) {

View File

@ -147,7 +147,6 @@
namespace at::native {
std::string shapes_as_str(TensorList tensors);
AdvancedIndex make_info(Tensor self, IOptTensorListRef orig);
} // namespace at::native
@ -186,7 +185,7 @@ TORCH_META_FUNC(gather)
}
template <bool use_new_options = false, typename Meta>
void scatter_meta_impl(
static void scatter_meta_impl(
Meta& meta,
const Tensor& self,
int64_t dim,
@ -358,7 +357,7 @@ TORCH_PRECOMPUTE_META_FUNC(index_copy)
}
template <typename Meta>
void index_func_meta_impl(
static void index_func_meta_impl(
Meta& meta,
const Tensor& self,
int64_t dim,
@ -593,21 +592,6 @@ static bool all_strides_match(TensorList tensors) {
return true;
}
inline std::string shapes_as_str(TensorList tensors) {
std::ostringstream os;
bool first = true;
for (auto& tensor : tensors) {
if (tensor.defined()) {
if (!first) {
os << ", ";
}
os << tensor.sizes();
first = false;
}
}
return os.str();
}
// Replace indexed dimensions in src with stride 0 and the size of the result
// tensor. The offset in these dimensions is computed by the kernel using the
// index tensor's values and the stride of src. The new shape is not meaningful.
@ -2249,7 +2233,7 @@ template <
typename T,
typename ReduceStub,
typename FillStub>
void scatter_impl(
static void scatter_impl(
const Tensor& self,
int64_t dim,
const Tensor& index,
@ -2822,7 +2806,7 @@ Tensor _gather_sparse_backward(
}
template <typename scalar_t>
int64_t count_nonzero_impl(TensorIteratorBase& iter, Range range) {
static int64_t count_nonzero_impl(TensorIteratorBase& iter, Range range) {
int64_t num_nonzero = 0;
auto loop = [&](char** data, const int64_t* strides, int64_t n) {

View File

@ -569,7 +569,7 @@ static void isin_sorting(
}
template <typename... Args>
Device out_device(Args&... inps) {
static Device out_device(Args&... inps) {
for (const auto& i : {inps...}) {
if (i.device() != at::kCPU) {
return i.device();
@ -739,7 +739,7 @@ std::tuple<Tensor&, Tensor&> mode_out(
}
template <class Stub>
void minmax_out_impl(
static void minmax_out_impl(
const Tensor& self,
int64_t dim,
bool keepdim,

View File

@ -806,7 +806,7 @@ Tensor sparse_compressed_to_dense(
// Computes the strides for view_dtype output when the view dtype is
// smaller than the original dtype
inline SymDimVector compute_strides_for_view_dtype_downsize(
static inline SymDimVector compute_strides_for_view_dtype_downsize(
SymIntArrayRef old_strides,
int64_t size_ratio,
ScalarType old_dtype,
@ -832,7 +832,7 @@ inline SymDimVector compute_strides_for_view_dtype_downsize(
// Computes the strides for view_dtype output when the view dtype is
// larger than the original dtype
inline SymDimVector compute_strides_for_view_dtype_upsize(
static inline SymDimVector compute_strides_for_view_dtype_upsize(
SymIntArrayRef old_strides,
int64_t size_ratio,
ScalarType old_dtype,
@ -1989,7 +1989,7 @@ TORCH_IMPL_FUNC(_convert_indices_from_csr_to_coo_structured_cpu)
* Modified to ensure sorted BSR column indices.
*/
template <class index_t, class scalar_t, bool compressed_rows>
void _compressed_to_block_compressed_cpu_kernel(
static void _compressed_to_block_compressed_cpu_kernel(
const index_t n_compressed, // Tensor size along compressed dimension
const index_t n_plain, // Tensor size along plain dimension
const index_t C, // Block size along compressed dimensions
@ -2086,7 +2086,7 @@ void _compressed_to_block_compressed_cpu_kernel(
* https://github.com/scipy/scipy/blob/8a64c938ddf1ae4c02a08d2c5e38daeb8d061d38/scipy/sparse/sparsetools/csr.h
*/
template <class index_t>
index_t compressed_count_blocks(
static index_t compressed_count_blocks(
const index_t n_compressed, // Tensor size along compressed dimension
const index_t n_plain, // Tensor size along plain dimension
const index_t C, // Block size along compressed dimensions
@ -2110,7 +2110,7 @@ index_t compressed_count_blocks(
}
template <Layout target_layout>
Tensor _compressed_to_block_compressed_cpu(
static Tensor _compressed_to_block_compressed_cpu(
const Tensor& self,
IntArrayRef blocksize) {
static_assert(

View File

@ -2072,22 +2072,24 @@ Tensor vander(const Tensor& x, std::optional<int64_t> N, bool increasing) {
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ tensor ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template <typename T>
Tensor tensor_cpu(ArrayRef<T> values, const TensorOptions& options) {
static Tensor tensor_cpu(ArrayRef<T> values, const TensorOptions& options) {
return at::detail::tensor_cpu(values, options);
}
template <typename T>
Tensor tensor_backend(ArrayRef<T> values, const TensorOptions& options) {
static Tensor tensor_backend(ArrayRef<T> values, const TensorOptions& options) {
return at::detail::tensor_backend(values, options);
}
template <typename T>
Tensor tensor_complex_cpu(ArrayRef<T> values, const TensorOptions& options) {
static Tensor tensor_complex_cpu(
ArrayRef<T> values,
const TensorOptions& options) {
return at::detail::tensor_complex_cpu(values, options);
}
template <typename T>
Tensor tensor_complex_backend(
static Tensor tensor_complex_backend(
ArrayRef<T> values,
const TensorOptions& options) {
return at::detail::tensor_complex_backend(values, options);

View File

@ -216,7 +216,7 @@
namespace at::meta {
inline c10::MemoryFormat cat_compute_output_memory_format(
static inline c10::MemoryFormat cat_compute_output_memory_format(
const MaterializedITensorListRef& inputs) {
std::optional<c10::MemoryFormat> format = std::nullopt;
for (const Tensor& t : inputs) {
@ -1119,7 +1119,7 @@ std::vector<Tensor> tensor_split_sections_symint(
}
template <typename T>
std::vector<Tensor> _tensor_split_indices(
static std::vector<Tensor> _tensor_split_indices(
const Tensor& self,
ArrayRef<T> indices,
int64_t dim) {
@ -1417,7 +1417,7 @@ Tensor as_strided_tensorimpl(
}
template <typename T>
inline void setStridedUnchecked(
static inline void setStridedUnchecked(
const Tensor& self,
ArrayRef<T> size,
ArrayRef<T> stride,
@ -1922,7 +1922,7 @@ Tensor tile_symint(const Tensor& self, SymIntArrayRef reps) {
// templated for ArrayRef<int64_t> and SmallVector<int64_t> use cases
//
template <typename Vec>
Tensor alias_with_sizes_and_strides(
static Tensor alias_with_sizes_and_strides(
const Tensor& self,
const Vec& sizes,
const Vec& strides) {
@ -1958,7 +1958,7 @@ Tensor alias_with_sizes_and_strides(
// SymIntArrayRef/ArrayRef<c10::SymInt> and
// SmallVector<c10::SymInt>/SymDimVector
template <template <typename...> typename Container>
Tensor alias_with_sizes_and_strides(
static Tensor alias_with_sizes_and_strides(
const Tensor& self,
const Container<c10::SymInt>& sizes,
const Container<c10::SymInt>& strides) {
@ -3290,7 +3290,7 @@ static inline std::vector<Tensor> get_stack_inputs(
return inputs;
}
bool inline maybe_native_stack(
static bool inline maybe_native_stack(
Tensor& result,
TensorList tensors,
int64_t dim) {
@ -4021,7 +4021,7 @@ Tensor& squeeze_(Tensor& self, IntArrayRef dims) {
// This is a hack because in-place operations on tensors treated like views
// can be much more expensive than the same operations on non-view tensors.
inline Tensor view_impl(const Tensor& self, IntArrayRef size) {
static inline Tensor view_impl(const Tensor& self, IntArrayRef size) {
at::DimVector inferred_size = at::infer_size_dv(size, self.numel());
auto stride =
at::detail::computeStride(self.sizes(), self.strides(), inferred_size);

View File

@ -6,6 +6,7 @@
#include <c10/util/accumulate.h>
#include <ATen/native/quantized/cpu/QuantUtils.h>
#include <ATen/native/quantized/library.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#include <ATen/native/ao_sparse/quantized/cpu/packed_params.h>
@ -20,8 +21,6 @@
namespace ao::sparse {
int register_linear_params();
#ifdef USE_PYTORCH_QNNPACK
template <>
at::Tensor PackedLinearWeightQnnp::apply_dynamic_impl<true>(

View File

@ -1,15 +1,16 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Config.h>
#include <c10/core/Allocator.h>
#if AT_MKLDNN_ENABLED()
#include <c10/core/Allocator.h>
// needs to be included only once in library.
#include <ideep_pin_singletons.hpp>
#include <ATen/native/mkldnn/IDeepRegistration.h>
using namespace ideep;
RegisterEngineAllocator cpu_alloc(
static RegisterEngineAllocator cpu_alloc(
engine::cpu_engine(),
[](size_t size) {
return c10::GetAllocator(c10::DeviceType::CPU)->raw_allocate(size);
@ -20,8 +21,6 @@ RegisterEngineAllocator cpu_alloc(
);
namespace at::native::mkldnn{
void clear_computation_cache();
void clear_computation_cache() {
// Reset computation_cache for forward convolutions
// As it also caches max number of OpenMP workers

View File

@ -0,0 +1,7 @@
#pragma once
namespace at::native::mkldnn{
void clear_computation_cache();
} // namespace at::native::mkldnn

View File

@ -109,7 +109,7 @@ static bool use_mkldnn_bf32_matmul() {
template<typename scalar_t>
inline typename std::enable_if_t<
static inline typename std::enable_if_t<
std::is_same_v<scalar_t, float> ||
std::is_same_v<scalar_t, c10::Half> ||
std::is_same_v<scalar_t, c10::BFloat16>,
@ -322,7 +322,7 @@ void mkldnn_matmul(
}
inline bool checksize(const Tensor& mat1, const Tensor& mat2){
static inline bool checksize(const Tensor& mat1, const Tensor& mat2){
// if dim = 2, mat1's size = (m * n), mat2's size = (n * k)
// else if dim = 3, mat1's size = (b * m * n), mat2's size = (b * n * k)
// else called from aten::mv, mat1.size = (m * n), mat2.size = (n)

View File

@ -169,7 +169,7 @@ struct RNNParams {
};
template<bool is_single_direction>
std::vector<int64_t> _output_size(const RNNParams& rnn) {
static std::vector<int64_t> _output_size(const RNNParams& rnn) {
auto output_channels = is_single_direction ? rnn.hidden_size
: rnn.hidden_size * rnn.num_directions;
return {rnn.seq_length, rnn.mini_batch, output_channels};

View File

@ -25,7 +25,7 @@ static bool is_mkldnn_fp16_supported() {
return mkldnn_fp16_device_check();
}
constexpr bool is_mkldnn_acl_supported() {
static constexpr bool is_mkldnn_acl_supported() {
return AT_MKLDNN_ACL_ENABLED();
}

View File

@ -84,7 +84,7 @@ void check_mkldnn_binary_fusion_inputs(
return ideep::attr_t::fuse_##NAME(); \
}
AttrFunction attr_func_leaky_relu =
static AttrFunction attr_func_leaky_relu =
[](torch::List<std::optional<at::Scalar>> scalars,
std::optional<std::string_view> algorithm) {
TORCH_CHECK(
@ -96,7 +96,7 @@ AttrFunction attr_func_leaky_relu =
return ideep::attr_t::fuse_relu(1.0, alpha_value);
};
AttrFunction attr_func_hardtanh =
static AttrFunction attr_func_hardtanh =
[](torch::List<std::optional<at::Scalar>> scalars,
std::optional<std::string_view> algorithm) {
TORCH_CHECK(
@ -112,7 +112,7 @@ AttrFunction attr_func_hardtanh =
return ideep::attr_t::fuse_clamp(lower_bound_value, upper_bound_value);
};
AttrFunction attr_func_gelu = [](torch::List<std::optional<at::Scalar>> scalars,
static AttrFunction attr_func_gelu = [](torch::List<std::optional<at::Scalar>> scalars,
std::optional<std::string_view> algorithm) {
TORCH_CHECK(
algorithm.has_value(),
@ -130,7 +130,7 @@ AttrFunction attr_func_gelu = [](torch::List<std::optional<at::Scalar>> scalars,
return ideep::attr_t::fuse_gelu(1.0, 0.f, 0.f, gelu_type);
};
AttrFunction attr_func_hardsigmoid =
static AttrFunction attr_func_hardsigmoid =
[](torch::List<std::optional<at::Scalar>> scalars,
std::optional<std::string_view> algorithm) {
ideep::attr_t attr;

View File

@ -26,7 +26,7 @@ DEFINE_DISPATCH(qmean_inner_dim_stub);
DEFINE_DISPATCH(qstd_inner_dim_stub);
// If mean/std is taken in the innermost dims, the fast path can be used.
inline bool is_innnermost_dim(
static inline bool is_innnermost_dim(
const Tensor& self,
OptionalIntArrayRef opt_dim) {
if (!opt_dim.has_value()) {
@ -43,7 +43,7 @@ inline bool is_innnermost_dim(
return is_innermost;
}
inline bool is_mean_inner_dim_fast_path(
static inline bool is_mean_inner_dim_fast_path(
const Tensor& self,
OptionalIntArrayRef opt_dim,
std::optional<ScalarType> opt_dtype) {
@ -172,7 +172,7 @@ Tensor mean_quantized_cpu(
}
// qstd
inline bool is_std_inner_dim_fast_path(
static inline bool is_std_inner_dim_fast_path(
const Tensor& self,
OptionalIntArrayRef dim,
const std::optional<Scalar>& correction) {

View File

@ -117,7 +117,7 @@ static void upsample_nearest2d_out_frame_nhwc(
}
template <nn_compute_source_index_fn_t nn_compute_source_index_fn>
Tensor _upsample_nearest2d_quantized_cpu(
static Tensor _upsample_nearest2d_quantized_cpu(
const Tensor& input,
IntArrayRef output_size,
std::optional<double> scales_h,

View File

@ -129,7 +129,7 @@ static void upsample_nearest3d_out_frame_nhwc(
}
template <nn_compute_source_index_fn_t nn_compute_source_index_fn>
Tensor _upsample_nearest3d_quantized_cpu(
static Tensor _upsample_nearest3d_quantized_cpu(
const Tensor& input,
IntArrayRef output_size,
std::optional<double> scales_d,

View File

@ -44,7 +44,7 @@ constexpr int64_t kReasonableMaxDim = 1000000;
} // namespace
template <int kSpatialDim = 2>
bool ConvDimChecks(
static bool ConvDimChecks(
int64_t act_dims,
int64_t stride_dims,
int64_t padding_dims,
@ -95,7 +95,7 @@ bool ConvDimChecks(
return true;
}
inline int64_t compute_deconv_shape(int64_t input,
static inline int64_t compute_deconv_shape(int64_t input,
int64_t kernel,
int64_t stride,
int64_t input_padding,
@ -107,7 +107,7 @@ inline int64_t compute_deconv_shape(int64_t input,
}
template <int64_t kSpatialDim>
at::SmallVector<int64_t, kSpatialDim + 2> MakeDeConvOutputShape(
static at::SmallVector<int64_t, kSpatialDim + 2> MakeDeConvOutputShape(
int64_t N, int64_t M,
const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& kernel,
@ -178,7 +178,7 @@ at::SmallVector<int64_t, 5> MakeConvOutputShape<3>(
#ifdef USE_PYTORCH_QNNPACK
template <size_t kSpatialDim>
std::array<int64_t, kSpatialDim> MakeInputShape(
static std::array<int64_t, kSpatialDim> MakeInputShape(
int64_t D,
int64_t H,
int64_t W);

View File

@ -3,6 +3,7 @@
#include <ATen/native/quantized/cpu/EmbeddingPackedParams.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/cpu/qembeddingbag.h>
#include <ATen/native/quantized/library.h>
#include <torch/library.h>
#ifdef USE_FBGEMM
#include <fbgemm/Fbgemm.h>
@ -28,8 +29,6 @@
#include <arm_neon.h>
#endif
int register_embedding_params();
namespace {
// Fallback implementation when FBGEMM is not available.

View File

@ -7,6 +7,7 @@
#include <ATen/core/custom_class.h>
#include <ATen/native/quantized/cpu/EmbeddingPackedParams.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/library.h>
#include <c10/core/ScalarType.h>
#include <torch/library.h>
@ -23,8 +24,6 @@
#include <utility>
int register_embedding_params();
/*
* Prepack function for embedding_bag weights.
* This function expects a per-row quantized weight tensor

View File

@ -4,6 +4,7 @@
#include <ATen/native/quantized/cpu/EmbeddingPackedParams.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/cpu/qembeddingbag.h>
#include <ATen/native/quantized/library.h>
#include <c10/util/irange.h>
#include <torch/library.h>
@ -17,8 +18,6 @@
#include <ATen/ops/resize_native.h>
#endif
int register_embedding_params();
at::Tensor PackedEmbeddingBagWeight::unpack() {
auto packed_weight = packed_w;
at::Tensor weight_origin;

View File

@ -21,6 +21,8 @@
#else
#include <ATen/ops/_saturate_weight_to_fp16.h>
#include <ATen/ops/_saturate_weight_to_fp16_native.h>
#include <ATen/ops/_wrapped_linear_prepack_native.h>
#include <ATen/ops/_wrapped_quantized_linear_prepacked_native.h>
#include <ATen/ops/dequantize.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/quantize_per_tensor.h>
@ -292,7 +294,7 @@ c10::intrusive_ptr<LinearPackedParamsBase> PackedLinearWeightsOnednn::prepack(
return ret_ptr;
}
inline at::Tensor pack_weight_to_onednn_tensor(
static inline at::Tensor pack_weight_to_onednn_tensor(
const at::Tensor& weight,
std::optional<torch::List<int64_t>>& input_shape) {
std::vector<int64_t> w_dims = weight.sizes().vec();
@ -312,7 +314,7 @@ inline at::Tensor pack_weight_to_onednn_tensor(
return packed_weight;
}
inline at::Tensor pack_weight_to_fp16_onednn_tensor(
static inline at::Tensor pack_weight_to_fp16_onednn_tensor(
at::Tensor& weight,
std::optional<torch::List<int64_t>>& input_shape) {
TORCH_CHECK(weight.scalar_type() == at::kHalf || weight.scalar_type() == at::kFloat, "Weight should be of type float or float16");
@ -342,12 +344,12 @@ at::Tensor _saturate_weight_to_fp16(const Tensor& weight) {
}
template <class... Inputs>
inline std::vector<c10::IValue> makeStack(Inputs&&... inputs) {
static inline std::vector<c10::IValue> makeStack(Inputs&&... inputs) {
return {std::forward<Inputs>(inputs)...};
}
template <class... Args>
inline std::vector<c10::IValue> callOpByHandle(
static inline std::vector<c10::IValue> callOpByHandle(
const c10::OperatorHandle& op,
Args... args) {
auto stack = makeStack(std::forward<Args>(args)...);
@ -356,7 +358,7 @@ inline std::vector<c10::IValue> callOpByHandle(
}
template <class... Args>
inline std::vector<c10::IValue> callOpByName(
static inline std::vector<c10::IValue> callOpByName(
const char* func_name,
const char* overload_name,
Args... args) {
@ -366,7 +368,7 @@ inline std::vector<c10::IValue> callOpByName(
return callOpByHandle(op_handle.value(), std::forward<Args>(args)...);
}
at::Tensor wrapped_quantized_linear(
static at::Tensor wrapped_quantized_linear(
at::Tensor input,
const at::Tensor& input_scale,
const at::Tensor& input_zero_point,
@ -422,7 +424,7 @@ at::Tensor wrapped_quantized_linear(
#endif // USE_FBGEMM
}
at::Tensor wrapped_quantized_linear_meta(
static at::Tensor wrapped_quantized_linear_meta(
at::Tensor input,
[[maybe_unused]] const at::Tensor& input_scale,
[[maybe_unused]] const at::Tensor& input_zero_point,
@ -457,11 +459,6 @@ at::Tensor wrapped_quantized_linear_meta(
#endif // USE_FBGEMM
}
at::Tensor _wrapped_linear_prepack(const at::Tensor& weight,
const at::Tensor& weight_scale,
const at::Tensor& weight_zero_point,
const at::Tensor& bias);
at::Tensor _wrapped_linear_prepack(const at::Tensor& weight,
const at::Tensor& weight_scale,
const at::Tensor& weight_zero_point,
@ -495,13 +492,6 @@ at::Tensor _wrapped_linear_prepack(const at::Tensor& weight,
#endif // USE_FBGEMM
}
at::Tensor _wrapped_quantized_linear_prepacked(const at::Tensor& input, const at::Tensor& input_scale,
const at::Tensor& input_zero_point,
const at::Tensor& packed_weight,
const at::Tensor& output_scale,
const at::Tensor& output_zero_point,
[[maybe_unused]] const int64_t out_channel);
at::Tensor _wrapped_quantized_linear_prepacked(const at::Tensor& input, const at::Tensor& input_scale,
const at::Tensor& input_zero_point,
const at::Tensor& packed_weight,
@ -528,12 +518,7 @@ at::Tensor _wrapped_quantized_linear_prepacked(const at::Tensor& input, const at
#endif // USE_FBGEMM
}
at::Tensor _wrapped_linear_prepack_meta(const at::Tensor& weight,
[[maybe_unused]] const at::Tensor& weight_scale,
[[maybe_unused]] const at::Tensor& weight_zero_point,
[[maybe_unused]] const at::Tensor& bias);
at::Tensor _wrapped_linear_prepack_meta(const at::Tensor& weight,
static at::Tensor _wrapped_linear_prepack_meta(const at::Tensor& weight,
[[maybe_unused]] const at::Tensor& weight_scale,
[[maybe_unused]] const at::Tensor& weight_zero_point,
[[maybe_unused]] const at::Tensor& bias) {
@ -551,15 +536,7 @@ at::Tensor _wrapped_linear_prepack_meta(const at::Tensor& weight,
#endif // USE_FBGEMM
}
at::Tensor _wrapped_quantized_linear_prepacked_meta(const at::Tensor& input,
[[maybe_unused]] const at::Tensor& input_scale,
[[maybe_unused]] const at::Tensor& input_zero_point,
[[maybe_unused]] const at::Tensor& packed_weight,
[[maybe_unused]] const at::Tensor& output_scale,
[[maybe_unused]] const at::Tensor& output_zero_point,
const int64_t out_channel);
at::Tensor _wrapped_quantized_linear_prepacked_meta(const at::Tensor& input,
static at::Tensor _wrapped_quantized_linear_prepacked_meta(const at::Tensor& input,
[[maybe_unused]] const at::Tensor& input_scale,
[[maybe_unused]] const at::Tensor& input_zero_point,
[[maybe_unused]] const at::Tensor& packed_weight,

View File

@ -480,7 +480,7 @@ Tensor _sparse_compressed_tensor_unsafe_symint(
}
template <Layout required_layout>
Tensor _sparse_compressed_tensor_unsafe_template(const Tensor& compressed_indices,
static Tensor _sparse_compressed_tensor_unsafe_template(const Tensor& compressed_indices,
const Tensor& plain_indices,
const Tensor& values,
IntArrayRef size,
@ -967,7 +967,7 @@ Tensor empty_like_sparse_csr(
}
template <bool require_view, bool require_copy>
Tensor select_sparse_csr_worker(const Tensor& self, int64_t dim, int64_t index) {
static Tensor select_sparse_csr_worker(const Tensor& self, int64_t dim, int64_t index) {
#ifndef STRIP_ERROR_MESSAGES
constexpr const char* select_name = (require_view ? "select()" : "select_copy()");
#endif

View File

@ -219,7 +219,7 @@ Tensor& mul_out_sparse_csr(const Tensor& t_, const Tensor& src_, Tensor& r) {
}
template <typename op_t>
Tensor intersection_binary_op_with_wrapped_scalar(const Tensor& sparse, const Tensor& scalar, const op_t& op) {
static Tensor intersection_binary_op_with_wrapped_scalar(const Tensor& sparse, const Tensor& scalar, const op_t& op) {
// NOTE: intersection_binary_op_with_wrapped_scalar assumes scalar.numel() == 1.
const auto result_values = op(sparse.values(), scalar.squeeze()).to(at::result_type(sparse, scalar));
const auto result_sizes = infer_size(sparse.sizes(), scalar.sizes());
@ -233,7 +233,7 @@ Tensor intersection_binary_op_with_wrapped_scalar(const Tensor& sparse, const Te
}
template <typename op_t>
Tensor& intersection_binary_op_with_wrapped_scalar_(Tensor& sparse, const Tensor& scalar, const string& op_name, const op_t& op) {
static Tensor& intersection_binary_op_with_wrapped_scalar_(Tensor& sparse, const Tensor& scalar, const string& op_name, const op_t& op) {
// NOTE: intersection_binary_op_with_wrapped_scalar_ assumes scalar.numel() == 1.
const auto broadcasted_shape = infer_size(sparse.sizes(), scalar.sizes());
if (sparse.sizes() != broadcasted_shape) {
@ -522,7 +522,7 @@ CREATE_UNARY_UFUNC_FUNCTIONAL(isnan)
CREATE_UNARY_UFUNC_FUNCTIONAL(isinf)
template <typename scalar_t>
void addmm_out_sparse_csr_native_cpu(
static void addmm_out_sparse_csr_native_cpu(
const Tensor& sparse,
const Tensor& dense,
const Tensor& r,

View File

@ -551,7 +551,7 @@ static SparseTensor& add_out_sparse_non_contiguous(SparseTensor& r, const Sparse
return r;
}
Tensor& add_out_dense_sparse_cpu(Tensor& r, const Tensor& dense, const SparseTensor& sparse_, const Scalar& value);
static Tensor& add_out_dense_sparse_cpu(Tensor& r, const Tensor& dense, const SparseTensor& sparse_, const Scalar& value);
SparseTensor& add_out_sparse_cpu(const SparseTensor& t, const SparseTensor& src, const Scalar& value, SparseTensor& r) {
if (!t.is_sparse()) {
@ -593,7 +593,7 @@ SparseTensor& add_out_sparse_cpu(const SparseTensor& t, const SparseTensor& src,
// formerly known as spcadd
// --------------------------------------------------------------------
template <typename scalar_t>
void add_dense_sparse_worker_non_hybrid_cpu(Tensor& r, const Scalar& value, const SparseTensor& sparse, const Tensor& indices, const Tensor& values) {
static void add_dense_sparse_worker_non_hybrid_cpu(Tensor& r, const Scalar& value, const SparseTensor& sparse, const Tensor& indices, const Tensor& values) {
auto indices_accessor = indices.accessor<int64_t, 2>();
auto values_accessor = values.accessor<scalar_t, 1>();
@ -616,7 +616,7 @@ void add_dense_sparse_worker_non_hybrid_cpu(Tensor& r, const Scalar& value, cons
}
template <typename scalar_t>
inline void add_dense_sparse_worker_hybrid_cpu(Tensor& r, const Scalar& value, const SparseTensor& sparse, const Tensor& indices, const Tensor& values) {
static inline void add_dense_sparse_worker_hybrid_cpu(Tensor& r, const Scalar& value, const SparseTensor& sparse, const Tensor& indices, const Tensor& values) {
// Get the dense dimension element numbers of hybrid sparse tensor
int64_t values_dense_size = values.stride(0);
@ -647,7 +647,7 @@ inline void add_dense_sparse_worker_hybrid_cpu(Tensor& r, const Scalar& value, c
}
template <typename scalar_t>
inline void add_dense_sparse_worker_non_coalesced_cpu(Tensor& r, const Scalar& value,
static inline void add_dense_sparse_worker_non_coalesced_cpu(Tensor& r, const Scalar& value,
const SparseTensor& sparse, const Tensor& indices, const Tensor& values) {
// Get the dense dimension element numbers of hybrid sparse tensor
@ -829,7 +829,7 @@ Tensor& mul_sparse_(Tensor& self, const Tensor& other) {
// so it is up to the user to supply right implementations for non-commutative
// operations.
template <typename binary_func_t>
Tensor& intersection_binary_op_sparse_dense_out(
static Tensor& intersection_binary_op_sparse_dense_out(
const Tensor& d,
const SparseTensor& s_,
Tensor& res,
@ -1183,7 +1183,7 @@ SparseTensor& mul_out_sparse_cpu(const Tensor& t_, const Tensor& src_, Tensor& r
// --------------------------------------------------------------------
template <typename scalar_t>
void s_addmm_out_sparse_dense_worker(int64_t nnz, int64_t dim_i, int64_t dim_j, int64_t dim_k, Tensor& r, const Scalar& beta, const Tensor& t, const Scalar& alpha, const Tensor& indices, const Tensor& values, const Tensor& dense) {
static void s_addmm_out_sparse_dense_worker(int64_t nnz, int64_t dim_i, int64_t dim_j, int64_t dim_k, Tensor& r, const Scalar& beta, const Tensor& t, const Scalar& alpha, const Tensor& indices, const Tensor& values, const Tensor& dense) {
// r_ = alpha * sparse * dense
scalar_t cast_alpha = alpha.to<scalar_t>();
@ -1905,7 +1905,7 @@ Tensor bmm_sparse_cpu(const SparseTensor& self, const Tensor& mat2) {
// Returns the index of the found element.
// Returns by reference `found`, true if search value was found, false otherwise
template<typename scalar_t>
scalar_t binary_search_strided_rightmost(scalar_t search_val, TensorAccessor<scalar_t, 1>& sorted_arr_accessor, int64_t sorted_arr_begin_idx, int64_t length, bool* found) {
static scalar_t binary_search_strided_rightmost(scalar_t search_val, TensorAccessor<scalar_t, 1>& sorted_arr_accessor, int64_t sorted_arr_begin_idx, int64_t length, bool* found) {
if (length == 0) {
*found = false;
return -1;

View File

@ -7,7 +7,7 @@
namespace at::native::xnnpack {
inline std::vector<size_t> get_mem_format_aware_shape(const at::Tensor& in) {
static inline std::vector<size_t> get_mem_format_aware_shape(const at::Tensor& in) {
const auto mem_format = in.suggest_memory_format();
const auto& sizes = in.sizes();
std::vector<size_t> ret(sizes.begin(), sizes.end());

View File

@ -16,6 +16,7 @@
#include <torch/csrc/jit/passes/pass_manager.h>
#include <torch/csrc/jit/passes/remove_redundant_profiles.h>
#include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/graph_executor.h>

View File

@ -69,5 +69,9 @@ TORCH_API bool isSupported(Node* node);
/// @return Reference of the custome operator set
///
TORCH_API OperatorSet& getCustomOperatorSet();
} // namespace tensorexpr
} // namespace torch::jit
C10_DECLARE_bool(torch_jit_disable_cat);
C10_DECLARE_bool(torch_jit_enable_dynamic_shape_fusion);