[Reland] [1/N] Fix clang-tidy warnings in inductor (#134544)

Reland #131979 and exclude aoti_torch_index_put_out changes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134544
Approved by: https://github.com/ColinPeppler
This commit is contained in:
cyy
2024-08-28 04:05:06 +00:00
committed by PyTorch MergeBot
parent 71d0eff6e7
commit bb4dfe90b8
10 changed files with 58 additions and 79 deletions

View File

@ -64,8 +64,8 @@ std::vector<at::Tensor> unpack_tensors(
const c10::Device& device) {
std::vector<at::Tensor> inputs;
for (size_t idx = 0; idx < stack.size(); idx++) {
auto ivalue = stack[idx];
auto ivalue_arg = arguments[idx];
const auto& ivalue = stack[idx];
const auto& ivalue_arg = arguments[idx];
if (ivalue.isTensor()) {
unpack_tensor_ivalue(ivalue, device, inputs);
} else if (ivalue.isTensorList()) {
@ -117,12 +117,10 @@ std::vector<ParameterMetadata> unpack_input_parameters(
if (stack[idx].isScalar()) {
// Beyond c10::Scalar, the floating value and interger value are also
// represented as Scalar.
inputs_metadata.push_back(
ParameterMetadata(stack[idx].toScalar(), arg_order));
inputs_metadata.emplace_back(stack[idx].toScalar(), arg_order);
} else if (stack[idx].isTensorList()) {
// tensor list
inputs_metadata.push_back(
ParameterMetadata(stack[idx].toTensorList().vec(), arg_order));
inputs_metadata.emplace_back(stack[idx].toTensorList().vec(), arg_order);
} else if (stack[idx].isOptionalTensorList()) {
// optional tensor list: std::vector<std::optional<at::Tensor>>
std::vector<at::Tensor> tensor_list;
@ -131,27 +129,23 @@ std::vector<ParameterMetadata> unpack_input_parameters(
tensor_list.push_back(item.toOptional<at::Tensor>().value());
}
}
inputs_metadata.push_back(ParameterMetadata(tensor_list, arg_order));
inputs_metadata.emplace_back(tensor_list, arg_order);
} else if (
*arguments[idx].real_type() ==
*c10::getTypePtr<std::optional<at::Tensor>>()) {
// optional tensor
if (stack[idx].toOptional<at::Tensor>().has_value()) {
inputs_metadata.push_back(ParameterMetadata(
stack[idx].toOptional<at::Tensor>().value(), arg_order));
inputs_metadata.emplace_back(
stack[idx].toOptional<at::Tensor>().value(), arg_order);
}
} else if (stack[idx].isTensor()) {
inputs_metadata.push_back(
ParameterMetadata(stack[idx].toTensor(), arg_order));
inputs_metadata.emplace_back(stack[idx].toTensor(), arg_order);
} else if (stack[idx].isString()) {
inputs_metadata.push_back(
ParameterMetadata(stack[idx].toStringRef(), arg_order));
inputs_metadata.emplace_back(stack[idx].toStringRef(), arg_order);
} else if (stack[idx].isBool()) {
inputs_metadata.push_back(
ParameterMetadata(c10::Scalar(stack[idx].toBool()), arg_order));
inputs_metadata.emplace_back(c10::Scalar(stack[idx].toBool()), arg_order);
} else if (stack[idx].isDevice()) {
inputs_metadata.push_back(
ParameterMetadata(stack[idx].toDevice(), arg_order));
inputs_metadata.emplace_back(stack[idx].toDevice(), arg_order);
} else {
TORCH_CHECK_NOT_IMPLEMENTED(
false,
@ -239,7 +233,7 @@ void AOTIPythonKernelHolder::cache_hit(
auto outputs = aoti_kernel_metadata.kernel_runner_->run(inputs);
for (auto& output : outputs) {
stack->push_back(output);
stack->emplace_back(output);
}
}
@ -343,8 +337,7 @@ void AOTIPythonKernelHolder::init_aoti_kernel_cache() {
auto tensor_metadata = build_tensor_metadata(metadata);
test_list_metadata.push_back(tensor_metadata);
}
parameter_metadata_list.push_back(
ParameterMetadata(test_list_metadata, arg_idx));
parameter_metadata_list.emplace_back(test_list_metadata, arg_idx);
} else if (is_scalar) {
// Scalar
auto metadata = item_metadata.cast<py::dict>();
@ -367,14 +360,12 @@ void AOTIPythonKernelHolder::init_aoti_kernel_cache() {
dtype_value);
}
parameter_metadata_list.push_back(
ParameterMetadata(c10::Scalar(scalar), arg_idx));
parameter_metadata_list.emplace_back(c10::Scalar(scalar), arg_idx);
} else if (is_string) {
// String
auto metadata = item_metadata.cast<py::dict>();
auto str_value = metadata["string_value"].cast<std::string>();
parameter_metadata_list.push_back(
ParameterMetadata(str_value, arg_idx));
parameter_metadata_list.emplace_back(str_value, arg_idx);
} else if (is_dtype) {
// Dtype
auto metadata = item_metadata.cast<py::dict>();
@ -382,8 +373,8 @@ void AOTIPythonKernelHolder::init_aoti_kernel_cache() {
TORCH_INTERNAL_ASSERT(THPDtype_Check(dtype_value_obj.ptr()));
auto dtype_value =
reinterpret_cast<THPDtype*>(dtype_value_obj.ptr())->scalar_type;
parameter_metadata_list.push_back(ParameterMetadata(
c10::Scalar(static_cast<int>(dtype_value)), arg_idx));
parameter_metadata_list.emplace_back(
c10::Scalar(static_cast<int>(dtype_value)), arg_idx);
} else if (is_device) {
// Device
auto metadata = item_metadata.cast<py::dict>();
@ -395,21 +386,20 @@ void AOTIPythonKernelHolder::init_aoti_kernel_cache() {
metadata["device_index_value"].cast<c10::DeviceIndex>();
device.set_index(device_index_value);
}
parameter_metadata_list.push_back(ParameterMetadata(device, arg_idx));
parameter_metadata_list.emplace_back(device, arg_idx);
} else if (is_layout) {
auto metadata = item_metadata.cast<py::dict>();
auto layout_value_obj = metadata["layout_value"].cast<py::object>();
TORCH_INTERNAL_ASSERT(THPLayout_Check(layout_value_obj.ptr()));
auto layout_value =
reinterpret_cast<THPLayout*>(layout_value_obj.ptr())->layout;
parameter_metadata_list.push_back(ParameterMetadata(
c10::Scalar(static_cast<int>(layout_value)), arg_idx));
parameter_metadata_list.emplace_back(
c10::Scalar(static_cast<int>(layout_value)), arg_idx);
} else {
// Tensor
auto metadata = item_metadata.cast<py::dict>();
auto tensor_metadata = build_tensor_metadata(metadata);
parameter_metadata_list.push_back(
ParameterMetadata(tensor_metadata, arg_idx));
parameter_metadata_list.emplace_back(tensor_metadata, arg_idx);
}
}
@ -480,9 +470,12 @@ std::string AOTIPythonKernelHolder::produce_aoti_kernel_lib(
schema.overload_name().empty() ? "default" : schema.overload_name();
auto pos = qualified_name.find("::");
TORCH_INTERNAL_ASSERT(pos != std::string::npos, qualified_name);
std::string ns_str(qualified_name.begin(), qualified_name.begin() + pos);
std::string ns_str(
qualified_name.begin(),
qualified_name.begin() + static_cast<ptrdiff_t>(pos));
std::string func_name(
qualified_name.begin() + pos + strlen("::"), qualified_name.end());
qualified_name.begin() + static_cast<ptrdiff_t>(pos + strlen("::")),
qualified_name.end());
py::gil_scoped_acquire gil;
py::handle op_py_func = op.getPythonOp(pyinterpreter_, [&]() -> PyObject* {

View File

@ -1,6 +1,7 @@
#if !defined(C10_MOBILE) && !defined(ANDROID)
#include <torch/csrc/inductor/aoti_eager/kernel_meta_info.h>
#include <iostream>
#include <utility>
namespace torch::inductor {
@ -25,8 +26,8 @@ TensorMetadata::TensorMetadata(
dtype_(dtype),
device_(device),
dispatch_key_set_(dispatch_key_set),
sizes_(sizes),
strides_(strides),
sizes_(std::move(sizes)),
strides_(std::move(strides)),
requires_grad_(requires_grad) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
!is_symbolic_, "Not support symbolic shape now");
@ -94,25 +95,24 @@ bool TensorMetadata::operator==(const TensorMetadata& other) const {
std::ostream& operator<<(
std::ostream& stream,
const TensorMetadata& tensor_metadata) {
stream << "is_symbolic_: " << tensor_metadata.is_symbolic_ << std::endl;
stream << "dtype_: " << tensor_metadata.dtype_ << std::endl;
stream << "device_: " << tensor_metadata.device_ << std::endl;
stream << "is_symbolic_: " << tensor_metadata.is_symbolic_ << '\n';
stream << "dtype_: " << tensor_metadata.dtype_ << '\n';
stream << "device_: " << tensor_metadata.device_ << '\n';
stream << "sizes_: ";
for (const auto& size : tensor_metadata.sizes_) {
stream << size << " ";
}
stream << std::endl;
stream << '\n';
stream << "strides_: ";
for (const auto& stride : tensor_metadata.strides_) {
stream << stride << " ";
}
stream << "requires_grad_: " << tensor_metadata.requires_grad_ << std::endl;
stream << "dispatch_key_set_: " << tensor_metadata.dispatch_key_set_
<< std::endl;
stream << "requires_grad_: " << tensor_metadata.requires_grad_ << '\n';
stream << "dispatch_key_set_: " << tensor_metadata.dispatch_key_set_ << '\n';
stream << "tensor_check_: " << tensor_metadata.tensor_check_.has_value()
<< std::endl;
stream << std::endl;
<< '\n';
stream << '\n';
return stream;
}
@ -138,8 +138,9 @@ ParameterMetadata::ParameterMetadata(
uint64_t input_order)
: tag_(TENSOR_LIST), order_(input_order) {
std::vector<TensorMetadata> tensor_metadata_list;
tensor_metadata_list.reserve(tensor_list.size());
for (const auto& tensor : tensor_list) {
tensor_metadata_list.push_back(TensorMetadata(tensor));
tensor_metadata_list.emplace_back(tensor);
}
value_ = tensor_metadata_list;
}
@ -147,23 +148,17 @@ ParameterMetadata::ParameterMetadata(
ParameterMetadata::ParameterMetadata(
const c10::Scalar& scalar,
uint64_t input_order)
: tag_(SCALAR), order_(input_order) {
value_ = scalar;
}
: tag_(SCALAR), value_(scalar), order_(input_order) {}
ParameterMetadata::ParameterMetadata(
const std::string& str,
uint64_t input_order)
: tag_(STRING), order_(input_order) {
value_ = str;
}
: tag_(STRING), value_(str), order_(input_order) {}
ParameterMetadata::ParameterMetadata(
const c10::Device& device,
uint64_t input_order)
: tag_(DEVICE), order_(input_order) {
value_ = device;
}
: tag_(DEVICE), value_(device), order_(input_order) {}
bool ParameterMetadata::operator==(const ParameterMetadata& other) const {
// Same type

View File

@ -10,8 +10,8 @@
namespace torch::inductor {
// Regarding a aten operation implemented by AOTI, the metadata of the input
// tensors will be cached on the disk to acclerate next run. TensorMetada
// structure is to represent the metadata of each input tensor. it includes
// tensors will be cached on the disk to accelerate next run. TensorMetada
// structure is to represent the metadata of each input tensor. It includes
// whether the tensor is symbolic, the dtype, the device, the sizes and the
// strides of the tensor. When the metadata of the input tensors is the same as
// the cached metadata, the cached kernel library will be loaded and executed.
@ -51,7 +51,6 @@ struct TensorMetadata {
TensorMetadata()
: is_symbolic_(false),
dtype_(c10::ScalarType::Undefined),
device_(c10::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES),
sizes_({}),
strides_({}) {}
@ -116,7 +115,7 @@ struct ParameterMetadata {
// same tag. For example, an operation with two input tensors, the first
// tensor is a optional tensor and the second tensor is a tensor. The first
// tensor will have the order 0 and the second tensor will have the order 1.
uint64_t order_;
uint64_t order_{};
ParameterMetadata() : tag_(INVALID) {}
ParameterMetadata(TensorMetadata tensor_metadata, uint64_t input_order);

View File

@ -72,7 +72,7 @@ AOTIModelContainerRunner::AOTIModelContainerRunner(
model_so_->sym("AOTInductorModelContainerGetCallSpec"));
// Hack to find the json file name from the model so file
size_t lastindex = model_so_path.find_last_of(".");
size_t lastindex = model_so_path.find_last_of('.');
std::string json_filename = model_so_path.substr(0, lastindex) + ".json";
if (file_exists(json_filename)) {
@ -189,8 +189,8 @@ void AOTIModelContainerRunner::swap_constant_buffer() {
}
std::vector<std::string> AOTIModelContainerRunner::get_call_spec() {
const char* in_spec;
const char* out_spec;
const char* in_spec = nullptr;
const char* out_spec = nullptr;
AOTI_RUNTIME_ERROR_CODE_CHECK(
get_call_spec_func_(container_handle_, &in_spec, &out_spec));
return {in_spec, out_spec};

View File

@ -10,7 +10,7 @@ AOTIModelContainerRunnerCpu::AOTIModelContainerRunnerCpu(
size_t num_models)
: AOTIModelContainerRunner(model_so_path, num_models, "cpu", "") {}
AOTIModelContainerRunnerCpu::~AOTIModelContainerRunnerCpu() {}
AOTIModelContainerRunnerCpu::~AOTIModelContainerRunnerCpu() = default;
std::vector<at::Tensor> AOTIModelContainerRunnerCpu::run(
std::vector<at::Tensor>& inputs) {

View File

@ -14,7 +14,7 @@ AOTIModelContainerRunnerCuda::AOTIModelContainerRunnerCuda(
device_str,
cubin_dir) {}
AOTIModelContainerRunnerCuda::~AOTIModelContainerRunnerCuda() {}
AOTIModelContainerRunnerCuda::~AOTIModelContainerRunnerCuda() = default;
std::vector<at::Tensor> AOTIModelContainerRunnerCuda::run(
std::vector<at::Tensor>& inputs) {

View File

@ -6,8 +6,7 @@
#include <ideep.hpp>
#endif
namespace torch {
namespace aot_inductor {
namespace torch::aot_inductor {
#if AT_MKLDNN_ENABLED()
@ -45,5 +44,4 @@ at::Tensor mkldnn_tensor_from_data_ptr(
#endif
} // namespace aot_inductor
} // namespace torch
} // namespace torch::aot_inductor

View File

@ -1,8 +1,7 @@
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
namespace torch {
namespace aot_inductor {
namespace torch::aot_inductor {
std::vector<AtenTensorHandle> unsafe_alloc_new_handles_from_tensors(
std::vector<at::Tensor>& tensors) {
@ -45,5 +44,4 @@ std::vector<at::Tensor> alloc_tensors_by_stealing_from_handles(
return result;
}
} // namespace aot_inductor
} // namespace torch
} // namespace torch::aot_inductor

View File

@ -10,8 +10,7 @@
#include <ATen/FunctionalTensorWrapper.h>
namespace torch {
namespace inductor {
namespace torch::inductor {
using namespace at;
Tensor _mm_plus_mm_out(
@ -111,5 +110,4 @@ TORCH_LIBRARY_FRAGMENT(inductor, m) {
{at::Tag::pt2_compliant_tag});
}
} // namespace inductor
} // namespace torch
} // namespace torch::inductor

View File

@ -7,8 +7,7 @@
#include <ATen/native/cuda/Resize.h>
#endif
namespace torch {
namespace inductor {
namespace torch::inductor {
using namespace at;
// NOLINTNEXTLINE(performance-unnecessary-value-param)
@ -63,5 +62,4 @@ TORCH_LIBRARY_IMPL(inductor, Functionalize, m) {
"resize_storage_bytes_", TORCH_FN(resize_storage_bytes__functionalize));
}
} // namespace inductor
} // namespace torch
} // namespace torch::inductor