[Reland] Use missing-prototypes in torch_cpu (#104138)

This PR enables -Wmissing-prototypes in torch_cpu, except for some generated cpp files, the mps, metal, and vulkan backends, and the caffe2 sources.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104138
Approved by: https://github.com/albanD, https://github.com/malfet
Author: cyy
Date: 2023-06-26 22:53:43 +00:00
Committed by: PyTorch MergeBot
Parent: 436d035dc7
Commit: d4a98280a8
28 changed files with 101 additions and 67 deletions
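For illustration, here is a minimal standalone sketch of what -Wmissing-prototypes reports and the two remedies applied throughout this diff. It is not taken from the PR, and the function names are invented.

#include <tuple>

// Defining a non-static function with no prior declaration triggers a clang
// warning like "no previous prototype for function 'make_pair_of'":
//   std::tuple<int, int> make_pair_of(int v) { return {v, v}; }

// Remedy 1: give the file-local helper internal linkage with static, as most
// hunks in this PR do.
static std::tuple<int, int> make_pair_of(int v) {
  return {v, v};
}

// Remedy 2: wrap file-local helpers in an anonymous namespace, as done for
// squeeze_qtensor and the mkldnn convolution helpers in the hunks below.
namespace {
std::tuple<int, int> another_helper(int v) {
  return {v + 1, v - 1};
}
} // anonymous namespace

int main() {
  auto [a, b] = make_pair_of(1);
  auto [c, d] = another_helper(2);
  return a + b + c + d;
}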

View File

@ -1315,7 +1315,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _thnn_fused_lstm_cell_backwar
std::move(packed_output.data), std::move(std::get<1>(result))); \
}
#define ONE_HIDDEN_QRNN(NAME, CELL) \
std::tuple<Tensor, Tensor> NAME##_input( \
static std::tuple<Tensor, Tensor> NAME##_input( \
const Tensor& _input, \
const Tensor& hx, \
c10::List<c10::intrusive_ptr<CellParamsBase>> _params, \
@ -1345,7 +1345,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _thnn_fused_lstm_cell_backwar
return results; \
} \
\
std::tuple<Tensor, Tensor> NAME##_data( \
static std::tuple<Tensor, Tensor> NAME##_data( \
const Tensor& data, \
const Tensor& batch_sizes, \
const Tensor& hx, \
@ -1690,7 +1690,7 @@ Tensor rnn_relu_cell(
// an int8 or float16 quantized weight. This is advantageous in small-batch-size
// scenarios where runtime is dominated by memory fetches of the weight matrix.
std::tuple<Tensor, Tensor, Tensor> quantized_lstm_input(
static std::tuple<Tensor, Tensor, Tensor> quantized_lstm_input(
const Tensor& _input,
c10::List<at::Tensor> hx_,
c10::List<c10::intrusive_ptr<CellParamsBase>> _params_,
@ -1763,7 +1763,7 @@ static std::tuple<Tensor, Tensor, Tensor> quantized_lstm_input_legacy(
"using the newer definitions in torch.jit.quantized");
}
std::tuple<Tensor, Tensor, Tensor> quantized_lstm_data(
static std::tuple<Tensor, Tensor, Tensor> quantized_lstm_data(
const Tensor& data,
const Tensor& batch_sizes,
c10::List<at::Tensor> hx_,

View File

@ -502,7 +502,7 @@ Tensor& fft_rfftn_symint_out(const Tensor& self,
return out;
}
ShapeAndDims canonicalize_fft_c2r_shape_and_dim_args(
static ShapeAndDims canonicalize_fft_c2r_shape_and_dim_args(
c10::string_view fname, const Tensor& self,
const at::OptionalSymIntArrayRef& s,
const at::OptionalIntArrayRef& dims,

View File

@ -18,6 +18,10 @@
#include <ATen/ops/is_set_to_native.h>
#include <ATen/ops/size_native.h>
#include <ATen/ops/stride_native.h>
#include <ATen/ops/sym_numel_native.h>
#include <ATen/ops/sym_size_native.h>
#include <ATen/ops/sym_storage_offset_native.h>
#include <ATen/ops/sym_stride_native.h>
#endif
#include <c10/util/irange.h>

View File

@ -3130,7 +3130,6 @@ struct InferUnsqueezeGeometryResult {
: sizes(tensor_sizes.begin(), tensor_sizes.end())
, strides(tensor_strides.begin(), tensor_strides.end()) {}
};
}
InferUnsqueezeGeometryResult
inferUnsqueezeGeometry(const Tensor& tensor, int64_t dim) {
InferUnsqueezeGeometryResult result(tensor.sizes(), tensor.strides());
@ -3142,7 +3141,7 @@ inferUnsqueezeGeometry(const Tensor& tensor, int64_t dim) {
}
// dim is present if squeezing a single dimension and absent if squeezing all dimensions
static Tensor squeeze_qtensor(const Tensor& self, c10::OptionalIntArrayRef dims) {
Tensor squeeze_qtensor(const Tensor& self, c10::OptionalIntArrayRef dims) {
auto quantizer = get_qtensorimpl(self)->quantizer();
SymDimVector sizes;
SymDimVector strides;
@ -3176,6 +3175,7 @@ static Tensor squeeze_qtensor(const Tensor& self, c10::OptionalIntArrayRef dims)
namedinference::propagate_names_if_nonempty(result, maybe_outnames);
return result;
}
}
Tensor squeeze(const Tensor& self) {
auto g = inferSqueezeGeometry(self);

View File

@ -55,7 +55,7 @@ Tensor& mkldnn_mul_(Tensor& self, const Tensor& other) {
namespace at {
namespace native {
Tensor emptyBinaryOp(const Tensor& self, const Tensor& other) {
static Tensor emptyBinaryOp(const Tensor& self, const Tensor& other) {
if (!self.requires_grad() && !other.requires_grad()) {
auto out_size = infer_size(self.sizes(), other.sizes());
auto out_dtype = promoteTypes(

View File

@ -208,7 +208,7 @@ static inline at::MemoryFormat mkldnn_convolution_memory_format(int64_t dims, bo
return memory_format;
}
void _mkldnn_convolution_out (
static void _mkldnn_convolution_out (
const Tensor& input_t,
const Tensor& weight_t,
const Tensor& bias,
@ -256,7 +256,7 @@ void _mkldnn_convolution_out (
}
}
Tensor _mkldnn_convolution(
static Tensor _mkldnn_convolution(
const Tensor& input_t,
const Tensor& weight_t,
const c10::optional<Tensor>& bias_opt,
@ -344,6 +344,7 @@ Tensor mkldnn_convolution(
use_channels_last);
}
namespace{
Tensor mkldnn_convolution_pointwise(
const Tensor& input_t,
const Tensor& weight_t,
@ -936,9 +937,11 @@ std::tuple<Tensor, Tensor, Tensor> mkldnn_convolution_backward(
}
return std::make_tuple(grad_input, grad_weight, grad_bias);
}
}
REGISTER_ALL_CPU_DISPATCH(mkldnn_convolution_backward_stub, &mkldnn_convolution_backward);
namespace{
Tensor mkldnn_convolution_transpose(
const Tensor& input,
const Tensor& weight,
@ -1081,6 +1084,7 @@ std::tuple<Tensor, Tensor, Tensor> mkldnn_convolution_transpose_backward(
}
return std::make_tuple(grad_input, grad_weight, grad_bias);
}
}
REGISTER_ALL_CPU_DISPATCH(mkldnn_convolution_transpose_stub, &mkldnn_convolution_transpose);
REGISTER_ALL_CPU_DISPATCH(mkldnn_convolution_transpose_backward_stub, &mkldnn_convolution_transpose_backward);

View File

@ -94,7 +94,7 @@ ContextConv create(
attr};
}
void _mkldnn_convolution_out(
static void _mkldnn_convolution_out(
const ideep::tensor& x,
ideep::tensor& y,
const ideep::tensor& w,
@ -143,7 +143,7 @@ void _mkldnn_convolution_out(
}
}
void mkldnn_convolution_out(
static void mkldnn_convolution_out(
const Tensor& input,
ideep::tensor& mkldnn_output,
const ideep::tensor& mkldnn_weight,
@ -178,7 +178,7 @@ void mkldnn_convolution_out(
attr);
}
std::vector<int64_t> get_output_sizes(
static std::vector<int64_t> get_output_sizes(
ContextConv& context,
const Tensor& input) {
const ideep::tensor& mkldnn_weight = context.weight_packed_;

View File

@ -19,7 +19,8 @@ RegisterEngineAllocator cpu_alloc(
}
);
namespace at { namespace native { namespace mkldnn {
namespace at::native::mkldnn{
void clear_computation_cache();
void clear_computation_cache() {
// Reset computation_cache for forward convolutions
@ -27,6 +28,6 @@ void clear_computation_cache() {
ideep::convolution_forward::t_store().clear();
}
}}} // namespace at::native::mkldnn
} // namespace at::native::mkldnn
#endif // AT_MKLDNN_ENALBED()
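The hunk above also shows a third remedy, used when a helper cannot simply be made static (for example because it is referenced from other translation units): declare a prototype before the definition. A hypothetical self-contained sketch of the same pattern, using invented names and the C++17 nested-namespace syntax this hunk switches to:

namespace example::detail { // stands in for at::native::mkldnn

void reset_caches(); // the prototype keeps the externally linked definition below warning-free

void reset_caches() {
  // ... clear whatever file-level caches exist ...
}

} // namespace example::detail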

View File

@ -177,7 +177,7 @@ std::tuple<Tensor, Tensor, Tensor> mkldnn_linear_backward(
return std::tuple<Tensor, Tensor, Tensor>{grad_input, grad_weight, grad_bias};
}
Tensor mkldnn_linear_pointwise(
static Tensor mkldnn_linear_pointwise(
const Tensor& input_t,
const Tensor& weight_t,
const c10::optional<Tensor>& bias_opt,
@ -248,7 +248,7 @@ Tensor mkldnn_linear_pointwise(
return output;
}
Tensor mkldnn_linear_pointwise_binary(
static Tensor mkldnn_linear_pointwise_binary(
const Tensor& input_t,
const Tensor& other_t,
const Tensor& weight_t,
@ -329,7 +329,7 @@ Tensor mkldnn_linear_pointwise_binary(
#if AT_MKL_ENABLED()
#include <mkl.h>
Tensor mkl_linear(
static Tensor mkl_linear(
const Tensor& self,
const Tensor& mkl_weight_t,
const Tensor& origin_weight_t,
@ -417,7 +417,7 @@ TORCH_LIBRARY_IMPL(mkl, MkldnnCPU, m) {
#else // AT_MKL_ENABLED
Tensor mkl_linear(
static Tensor mkl_linear(
const Tensor& self,
const Tensor& mkl_weight_t,
const Tensor& origin_weight_t,

View File

@ -210,7 +210,7 @@ Tensor mkldnn_reorder_conv3d_weight(
return new_with_itensor_mkldnn(std::move(result), optTypeMetaToScalarType(self.options().dtype_opt()), self.options().device_opt());
}
Tensor mkldnn_reorder_linear_weight(
static Tensor mkldnn_reorder_linear_weight(
const Tensor& self,
c10::optional<int64_t> batch_size_opt) {
if (self.scalar_type() == ScalarType::BFloat16) {
@ -236,7 +236,7 @@ Tensor mkldnn_reorder_linear_weight(
return new_with_itensor_mkldnn(std::move(result), optTypeMetaToScalarType(self.options().dtype_opt()), self.options().device_opt());
}
ideep::tensor::desc get_conv_transpose_expected_weights_desc(
static ideep::tensor::desc get_conv_transpose_expected_weights_desc(
const ideep::tensor::dims& weights_dims,
ideep::tensor::data_type w_dtype,
const ideep::tensor::dims& strides,
@ -275,7 +275,7 @@ ideep::tensor::desc get_conv_transpose_expected_weights_desc(
}
}
Tensor mkldnn_reorder_conv_transpose2d_weight(
static Tensor mkldnn_reorder_conv_transpose2d_weight(
const Tensor& self,
IntArrayRef padding,
IntArrayRef output_padding,
@ -373,7 +373,7 @@ Tensor mkldnn_reorder_conv3d_weight(
#if AT_MKL_ENABLED() && AT_MKLDNN_ENABLED()
#include <mkl.h>
Tensor mkl_reorder_linear_weight(
static Tensor mkl_reorder_linear_weight(
const Tensor& weight,
const int64_t batch_size) {
TORCH_CHECK(

View File

@ -165,7 +165,7 @@ struct RNNParams {
}
};
std::vector<int64_t> _hidden_size(const RNNParams& rnn) {
static std::vector<int64_t> _hidden_size(const RNNParams& rnn) {
return {rnn.num_layers * rnn.num_directions, rnn.mini_batch, rnn.hidden_size};
}
@ -196,7 +196,7 @@ std::vector<int64_t> _output_size(const RNNParams& rnn) {
// | nt2 |
// +---------+
//
Tensor _shuffle_weight(const Tensor& weight, int64_t fn_mode) {
static Tensor _shuffle_weight(const Tensor& weight, int64_t fn_mode) {
auto weight_t = weight.contiguous();
if (static_cast<ideep::rnn_kind>(fn_mode) == ideep::rnn_kind::GRU) {
std::vector<Tensor> gates = weight_t.chunk(3, /*gates*/0);
@ -205,7 +205,7 @@ Tensor _shuffle_weight(const Tensor& weight, int64_t fn_mode) {
return weight_t;
}
Tensor _shuffle_bias(const Tensor& bias_ih, const Tensor& bias_hh, int64_t fn_mode) {
static Tensor _shuffle_bias(const Tensor& bias_ih, const Tensor& bias_hh, int64_t fn_mode) {
if (static_cast<ideep::rnn_kind>(fn_mode) == ideep::rnn_kind::GRU) {
std::vector<Tensor> b1 = bias_ih.chunk(3, /*output_channels*/0);
std::vector<Tensor> b2 = bias_hh.chunk(3, /*output_channels*/0);
@ -468,7 +468,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> mkldnn_rnn_la
// b. padded sequence input support
//
std::tuple<Tensor, Tensor, Tensor> mkldnn_rnn(
static std::tuple<Tensor, Tensor, Tensor> mkldnn_rnn(
const Tensor& input_, TensorList weight, int64_t weight_stride0,
const Tensor& hx_, const Tensor& cx_,
int64_t mode, int64_t hidden_size,
@ -566,8 +566,6 @@ std::pair<Tensor, hidden_type> mkldnn_impl(
pack_hidden<hidden_type>(std::get<1>(mkldnn_output), std::get<2>(mkldnn_output))};
}
} // anonymous namespace
void lstm_mkldnn(Tensor& output, Tensor& hy, Tensor& cy,
const Tensor& input, TensorList hx, TensorList params, bool has_biases,
int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
@ -577,6 +575,7 @@ void lstm_mkldnn(Tensor& output, Tensor& hy, Tensor& cy,
hy = std::get<0>(result.second);
cy = std::get<1>(result.second);
}
} // anonymous namespace
REGISTER_ALL_CPU_DISPATCH(lstm_mkldnn_stub, &lstm_mkldnn);

View File

@ -13,7 +13,7 @@ namespace mkldnn {
using namespace internal::convolution;
bool is_mkldnn_bf16_supported() {
static bool is_mkldnn_bf16_supported() {
#if defined(__aarch64__)
return mkldnn_bf16_device_check_arm();
#else

View File

@ -662,7 +662,7 @@ DEFINE_DISPATCH(sparse_mask_projection_out_stub);
using OptTensor = c10::optional<Tensor>;
std::tuple<Tensor, Tensor, OptTensor> sparse_mask_like_prepare_sparse_inputs(
static std::tuple<Tensor, Tensor, OptTensor> sparse_mask_like_prepare_sparse_inputs(
const std::string& method_name,
const Tensor& t,
const Tensor& mask) {

View File

@ -824,6 +824,27 @@ if(BUILD_CAFFE2 AND NOT MSVC)
target_compile_options(torch_cpu PRIVATE "-Wno-sign-compare")
endif()
if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang" AND NOT USE_VULKAN AND NOT USE_IOS AND NOT USE_PYTORCH_METAL AND NOT USE_MPS AND NOT USE_COREML_DELEGATE)
target_compile_options_if_supported(torch_cpu "-Wmissing-prototypes")
target_compile_options_if_supported(torch_cpu "-Werror=missing-prototypes")
get_target_property(TORCH_CPU_SOURCES torch_cpu SOURCES)
foreach(generated_file IN LISTS GENERATED_CXX_TORCH)
set_source_files_properties(${generated_file} PROPERTIES COMPILE_OPTIONS "-Wno-missing-prototypes;-Wno-error=missing-prototypes")
endforeach()
foreach(source_file IN LISTS TORCH_CPU_SOURCES)
get_filename_component(source_file "${source_file}" REALPATH)
string(FIND "${source_file}" "${CMAKE_BINARY_DIR}" res)
if(res GREATER -1)
set_source_files_properties(${source_file} PROPERTIES COMPILE_OPTIONS "-Wno-missing-prototypes;-Wno-error=missing-prototypes")
continue()
endif()
string(FIND "${source_file}" "caffe2" res)
if(res GREATER -1)
set_source_files_properties(${source_file} PROPERTIES COMPILE_OPTIONS "-Wno-missing-prototypes;-Wno-error=missing-prototypes")
endif()
endforeach()
endif()
set_property(SOURCE ${ATen_CORE_SRCS} APPEND
PROPERTY COMPILE_DEFINITIONS "TORCH_ASSERT_ONLY_METHOD_OPERATORS")
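target_compile_options_if_supported, used above, is a PyTorch CMake helper that probes whether the compiler accepts a flag before adding it. A minimal sketch of how such a helper is conventionally written on top of CMake's CheckCXXCompilerFlag module (an assumption for illustration, not PyTorch's actual definition):

include(CheckCXXCompilerFlag)

# Hypothetical helper: add a compile option to a target only if the active
# C++ compiler accepts the flag.
function(my_target_compile_options_if_supported target flag)
  # Derive a valid cache-variable name from the flag, e.g.
  # -Wmissing-prototypes -> HAS_CXX_FLAG__Wmissing_prototypes.
  string(MAKE_C_IDENTIFIER "HAS_CXX_FLAG_${flag}" flag_var)
  check_cxx_compiler_flag("${flag}" ${flag_var})
  if(${flag_var})
    target_compile_options(${target} PRIVATE "${flag}")
  endif()
endfunction()

# Usage, mirroring the hunk above:
#   my_target_compile_options_if_supported(torch_cpu "-Wmissing-prototypes")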

View File

@ -1451,7 +1451,7 @@ Tensor mm_mat1_sparse_backward(
mat2.layout());
}
Tensor sparse_mask_like_grad(const Tensor& x, const Tensor& gx) {
static Tensor sparse_mask_like_grad(const Tensor& x, const Tensor& gx) {
if (x.is_coalesced() && gx.is_coalesced()) {
if (x._nnz() >= gx._nnz()) {
// search into x is faster

View File

@ -61,7 +61,7 @@ at::Tensor empty_llga(
std::move(storage_impl), options.dtype(), desc);
}
const LlgaTensorDesc& get_llga_desc(const at::Tensor& tensor) {
static const LlgaTensorDesc& get_llga_desc(const at::Tensor& tensor) {
TORCH_INTERNAL_ASSERT(
tensor.is_mkldnn(), "get_llga_desc expects Mkldnn tensor input");
return static_cast<LlgaTensorImpl*>(tensor.unsafeGetTensorImpl())->desc();

View File

@ -10,7 +10,7 @@ namespace jit {
namespace fuser {
namespace onednn {
bool shouldDecomposeSilu(Node* node) {
static bool shouldDecomposeSilu(Node* node) {
if (node->kind() != aten::silu) {
return false;
}
@ -26,7 +26,7 @@ bool shouldDecomposeSilu(Node* node) {
return false;
}
void DecomposeSilu(Node* node) {
static void DecomposeSilu(Node* node) {
if (shouldDecomposeSilu(node)) {
auto dtype = node->input(0)->type()->expect<TensorType>();

View File

@ -1,3 +1,4 @@
#include <torch/csrc/jit/codegen/onednn/defer_size_check.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/runtime/symbolic_shape_registry_util.h>

View File

@ -12,7 +12,7 @@ namespace onednn {
using opkind = dnnl::graph::op::kind;
void fixConvOptionalBias(Node* node) {
static void fixConvOptionalBias(Node* node) {
if (node->namedInput("bias")->mustNotBeNone() == false) {
// Replace non-existent optional bias with const None
auto g = node->owningGraph();
@ -22,7 +22,7 @@ void fixConvOptionalBias(Node* node) {
}
}
c10::optional<size_t> getDimensions(Value* v) {
static c10::optional<size_t> getDimensions(Value* v) {
if (v->type()->isSubtypeOf(TensorType::get())) {
return v->type()->cast<TensorType>()->sizes().size();
} else {
@ -36,7 +36,7 @@ c10::optional<size_t> getDimensions(Value* v) {
// no need to check beforehand whether the op is supported by oneDNN Graph or
// not oneDNN Graph ops separated by wildcards don't end up in the same
// partition.
Operator makeWildcardOp(Node* node) {
static Operator makeWildcardOp(Node* node) {
auto o = Operator(node, opkind::Wildcard);
// wildcard op contains only topology info
for (size_t i = 0; i < node->inputs().size(); i++) {
@ -307,7 +307,7 @@ Operator LlgaGraphHelper::createOperator(Node* node) {
return makeWildcardOp(node);
}
DeviceType inferDeviceFromValue(Value* v) {
static DeviceType inferDeviceFromValue(Value* v) {
auto tt = v->type()->cast<TensorType>();
if (!tt) {
return at::kCPU;
@ -319,7 +319,7 @@ DeviceType inferDeviceFromValue(Value* v) {
return device->type();
}
DeviceType inferDevice(const std::shared_ptr<Graph>& graph) {
static DeviceType inferDevice(const std::shared_ptr<Graph>& graph) {
auto dt = inferDeviceFromValue(graph->inputs()[0]);
TORCH_CHECK(
std::all_of(
@ -330,7 +330,7 @@ DeviceType inferDevice(const std::shared_ptr<Graph>& graph) {
return dt;
}
dnnl::graph::engine::kind getLlgaEngineKind(DeviceType type) {
static dnnl::graph::engine::kind getLlgaEngineKind(DeviceType type) {
switch (type) {
case DeviceType::CPU:
return dnnl::graph::engine::kind::cpu;
@ -339,7 +339,7 @@ dnnl::graph::engine::kind getLlgaEngineKind(DeviceType type) {
}
}
void mayAddListConstructIntoConcatPartition(
static void mayAddListConstructIntoConcatPartition(
Node* n,
OpPartitionMap& opToOwningPartition) {
// Since prim::ListConstruct is not visible to the LLGA,
@ -360,7 +360,7 @@ void mayAddListConstructIntoConcatPartition(
// Scalars would be converted to 1-D tensors later anyway,
// but they shouldn't be complex-double
// If this check fails, convert op to wildcard
bool checkInputCompatibility(Node* node) {
static bool checkInputCompatibility(Node* node) {
auto allInputs = node->inputs();
for (auto input : allInputs) {
c10::IValue inputIValue = toIValue(input);
@ -465,7 +465,7 @@ bool LlgaGraphHelper::shouldMerge(Node* toMerge, Node* subgraph) {
// only use single-op partitions for ops unsupported by NNC, or ops
// that oneDNN executes faster. prim::ListConstruct is an exception, since
// we simply want to fuse it with cat.
bool isBetterSuitedForLLGA(NodeKind kindOfOp) {
static bool isBetterSuitedForLLGA(NodeKind kindOfOp) {
return (
(kindOfOp == aten::layer_norm) || (kindOfOp == aten::avg_pool2d) ||
(kindOfOp == aten::matmul) || (kindOfOp == aten::max_pool2d) ||

View File

@ -97,7 +97,7 @@ void fuseGraph(std::shared_ptr<Graph>& g) {
} // namespace onednn
} // namespace fuser
Operation createLlgaKernel(const Node* node) {
static Operation createLlgaKernel(const Node* node) {
auto kernel = std::make_shared<fuser::onednn::LlgaKernel>(node);
return [kernel](Stack& stack) {
RECORD_FUNCTION(kernel->debugName(), std::vector<c10::IValue>());
@ -117,7 +117,7 @@ RegisterOperators oneDNNFusionGroupOp({
// binary ops to a 1D tensor. Other scalar inputs are prim::Constant nodes.
// But if we have any scalar inputs to guard in the future, some logic here
// would have to be changed.
Operation createLlgaGuardKernel(const Node* node) {
static Operation createLlgaGuardKernel(const Node* node) {
return [node](Stack& stack) {
#ifdef GRAPH_DEBUG_ENABLED
GRAPH_DEBUG("Guarding node: ", node->kind().toQualString());

View File

@ -10,7 +10,7 @@ namespace onednn {
static std::atomic<bool> onednn_enabled{false};
std::atomic<bool>& getLlgaEnabled() {
static std::atomic<bool>& getLlgaEnabled() {
return onednn_enabled;
}

View File

@ -7,7 +7,7 @@ namespace jit {
namespace fuser {
namespace onednn {
void LayoutPropagation(Node* n) {
static void LayoutPropagation(Node* n) {
if (!LlgaGraphHelper::isLlgaSubgraph(n))
return;
@ -37,7 +37,7 @@ void LayoutPropagation(Node* n) {
}
}
void LayoutPropagation(at::ArrayRef<Block*> blocks) {
static void LayoutPropagation(at::ArrayRef<Block*> blocks) {
for (Block* block : blocks)
for (Node* node : block->nodes())
LayoutPropagation(node);

View File

@ -8,14 +8,14 @@ namespace jit {
namespace fuser {
namespace onednn {
bool compareConstValue(Value* v, double d) {
static bool compareConstValue(Value* v, double d) {
auto ival = toIValue(v);
return ival.has_value() &&
((ival->isInt() && static_cast<int>(ival->toInt()) == d) ||
(ival->isDouble() && ival->toDouble() == d));
}
void handleBinaryOpInputs(Node* node) {
static void handleBinaryOpInputs(Node* node) {
// We do not handle binary ops with two scalar inputs,
// and we assume scalar is always at the second place.
if (node->input(0)->type()->isSubtypeOf(TensorType::get())) {
@ -123,7 +123,7 @@ static void ConvertScalarToTensor(Block* block) {
}
}
void mayDecomposeAdd(Node* node) {
static void mayDecomposeAdd(Node* node) {
if (node->inputs().size() < 3) {
return; // corner-case in BERT-mrpc that's not in line with
// native_functions.yaml

View File

@ -5,7 +5,7 @@ namespace jit {
namespace fuser {
namespace onednn {
bool canFuseNode(const Node* node) {
static bool canFuseNode(const Node* node) {
switch (node->kind()) {
case aten::conv2d:
case aten::_convolution:

View File

@ -13,12 +13,12 @@ namespace jit {
#if AT_MKLDNN_ENABLED()
c10::VaryingShape<int64_t> getSizesOf(Node* n, size_t idx) {
static c10::VaryingShape<int64_t> getSizesOf(Node* n, size_t idx) {
auto tt = n->input(idx)->type()->cast<TensorType>();
return tt->sizes();
}
void insertPrePackedConvOpForNode(Node* n) {
static void insertPrePackedConvOpForNode(Node* n) {
constexpr int POS_INPUT = 0;
constexpr int POS_WEIGHT = 1;
if (!tensorexpr::isContiguous(
@ -72,7 +72,7 @@ void insertPrePackedConvOpForNode(Node* n) {
n->output()->replaceAllUsesWith(prepack_conv->output());
}
bool isTensorTypeCPU(Node* node) {
static bool isTensorTypeCPU(Node* node) {
for (const auto& input : node->inputs()) {
auto type = input->type()->cast<TensorType>();
if (!type) {
@ -89,7 +89,7 @@ bool isTensorTypeCPU(Node* node) {
return true;
}
void insertPrePackedConvOp(Block* b) {
static void insertPrePackedConvOp(Block* b) {
for (Node* n : b->nodes()) {
for (Block* b : n->blocks()) {
insertPrePackedConvOp(b);
@ -104,15 +104,15 @@ void insertPrePackedConvOp(Block* b) {
EliminateDeadCode(b);
}
void insertMkldnnPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
static void insertMkldnnPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
insertPrePackedConvOp(graph->block());
}
void insertMkldnnPrePackedOps(std::shared_ptr<Graph>& graph) {
static void insertMkldnnPrePackedOps(std::shared_ptr<Graph>& graph) {
insertMkldnnPrePackedConv2dOp(graph);
}
void insertMkldnnPrePackedOps(script::Module& module) {
static void insertMkldnnPrePackedOps(script::Module& module) {
for (auto& method : module.get_methods()) {
auto graph = method.graph();
insertMkldnnPrePackedOps(graph);
@ -122,7 +122,7 @@ void insertMkldnnPrePackedOps(script::Module& module) {
}
}
void FuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
static void FuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
auto conv_op_rstring = at::jit::CodeTemplate(R"(
graph(%input, %weight, %bias, %stride:int[], %padding:int[],
%dilation:int[], %groups:int, %input_size:int[], %dummy_attr:str):
@ -164,7 +164,7 @@ void FuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
}
}
void PrePackingOpsFolder(Block* b) {
static void PrePackingOpsFolder(Block* b) {
auto is_foldable_op = [](const Node* n) -> bool {
return (
n->kind() ==
@ -201,7 +201,7 @@ void PrePackingOpsFolder(Block* b) {
}
}
void FoldPrePackingOps(std::shared_ptr<Graph>& graph) {
static void FoldPrePackingOps(std::shared_ptr<Graph>& graph) {
PrePackingOpsFolder(graph->block());
}

View File

@ -29,7 +29,7 @@ void TEWrapper::call(const std::vector<void*>& args) {
cg->call_raw(args);
}
void optimizePointwise(LoopNest* ln, Tensor target, int width) {
static void optimizePointwise(LoopNest* ln, Tensor target, int width) {
std::vector<ForPtr> loops = ln->getLoopStmtsFor(target);
ForPtr inner, tail;
TORCH_CHECK(loops.size() > 0, "No loops created for pointwise op");
@ -37,7 +37,7 @@ void optimizePointwise(LoopNest* ln, Tensor target, int width) {
ln->vectorize(inner);
}
std::shared_ptr<TEWrapper> wrapTECompute(
static std::shared_ptr<TEWrapper> wrapTECompute(
std::shared_ptr<TEWrapper> wrap,
Tensor out,
std::vector<CodeGen::BufferArg> args,
@ -54,7 +54,7 @@ std::shared_ptr<TEWrapper> wrapTECompute(
return wrap;
}
std::shared_ptr<TEWrapper> wrapTECompute(
static std::shared_ptr<TEWrapper> wrapTECompute(
std::shared_ptr<TEWrapper> wrap,
LoopNest* ln,
std::vector<CodeGen::BufferArg> args) {

View File

@ -1,5 +1,6 @@
#pragma once
#include <ATen/Config.h>
#include <ATen/Functions.h>
#include <c10/macros/Macros.h>
#include <torch/csrc/Export.h>
@ -97,6 +98,9 @@ void DispatchParallel(
int8_t* packed_data) noexcept;
FOR_ALL_EXTERNAL_FUNCTIONS(DECLARE_EXTERNAL_FUNCTION)
#if AT_MKLDNN_ENABLED()
DECLARE_EXTERNAL_FUNCTION(nnc_mkldnn_prepacked_conv_run);
#endif
TORCH_API void nnc_aten_free(int64_t bufs_num, void** ptrs) noexcept;

View File

@ -1087,7 +1087,7 @@ void LLVMCodeGenImpl::visit(BoolImmPtr v) {
value_ = llvm::ConstantInt::get(BoolTy_, v->value());
}
llvm::Type* llvmTypeToVec(llvm::Type* type, int lanes) {
static llvm::Type* llvmTypeToVec(llvm::Type* type, int lanes) {
if (lanes > 1) {
return llvm::VectorType::get(type, ElementCount(lanes));
} else {