[18/N] Fix clang-tidy warnings in jit (#132963)

Follows #132753

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132963
Approved by: https://github.com/Skylion007
Author: cyy
Date: 2024-08-09 01:27:32 +00:00
Committed by: PyTorch MergeBot
Parent: 313aa151da
Commit: 8967d55b01
42 changed files with 141 additions and 238 deletions
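The changes are mechanical clang-tidy cleanups: collapsing nested namespaces into C++17 nested-namespace definitions, replacing typedef with using, passing heavy parameters (py::args, shared_ptr, OperatorName) by const reference, dropping redundant std::move and .get() calls, preferring '\n' over std::endl, using .empty() instead of .size(), and initializing members in place. As a hedged illustration of the most common pattern (this snippet is not from the patch; the function body is a placeholder), the nested-namespace consolidation looks like this:

// Before (flagged by clang-tidy's modernize-concat-nested-namespaces):
//   namespace torch {
//   namespace jit {
//   namespace mobile {
//   void run();
//   } // namespace mobile
//   } // namespace jit
//   } // namespace torch
//
// After: one C++17 nested-namespace definition with a matching closing comment.
#include <iostream>

namespace torch::jit::mobile {
void run() {
  std::cout << "hello from torch::jit::mobile\n"; // placeholder body
}
} // namespace torch::jit::mobile

int main() {
  torch::jit::mobile::run();
  return 0;
}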

View File

@@ -6,9 +6,7 @@
 #include <ATen/core/operator_name.h>
 #include <torch/csrc/jit/runtime/instruction.h>
-namespace torch {
-namespace jit {
-namespace mobile {
+namespace torch::jit::mobile {
 using Stack = std::vector<c10::IValue>;
 using DebugHandle = int64_t;
@@ -34,6 +32,4 @@ struct Code {
   bool initialized = false;
 };
-} // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::mobile

View File

@@ -9,8 +9,7 @@
 #include <c10/util/string_view.h>
-namespace torch {
-namespace jit {
+namespace torch::jit {
 namespace {
@@ -140,7 +139,7 @@ MobileDebugTable::MobileDebugTable(
   }
   for (auto& val : lines.toTuple()->elements()) {
-    auto tup_elems = std::move(*std::move(val).toTuple()).elements();
+    auto tup_elems = std::move(*val.toTuple()).elements();
     // For BC we decode only tuples with 3 elements
     // assuming it contains
     // byte_offset, debug_handle (=source range tag), source range
@@ -159,7 +158,7 @@ MobileDebugTable::MobileDebugTable(
     reader->getRecord(callstack_debug_file);
     CallStackDebugInfoUnpickler unpickler;
     callstack_ptr_map_ = unpickler.unpickle(
-        std::move(callstack_data), callstack_data_size, source_range_map, cu);
+        callstack_data, callstack_data_size, source_range_map, cu);
   }
 }
@@ -229,5 +228,4 @@ std::pair<std::string, std::string> MobileDebugTable::
       debug_infos, "top", top_module_type_name));
 }
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit

View File

@@ -5,8 +5,7 @@
 #include <torch/csrc/jit/ir/scope.h>
 #include <torch/csrc/jit/serialization/source_range_serialization.h>
-namespace torch {
-namespace jit {
+namespace torch::jit {
 /*
  * MobileDebugTable:
  * Deserializes debug_pkl and callstack_map records from PT model's zip archive
@@ -53,5 +52,4 @@ class MobileDebugTable {
   ska::flat_hash_map<int64_t, DebugInfoTuple> callstack_ptr_map_;
 };
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit

View File

@@ -29,8 +29,7 @@
  * only be called from one or two locations per binary.
  */
-namespace torch {
-namespace jit {
+namespace torch::jit {
 /**
  * The format of a file or data stream.
@@ -119,9 +118,9 @@ static void file_not_found_error() {
   std::stringstream message;
   message << "Error while opening file: ";
   if (errno == ENOENT) {
-    message << "no such file or directory" << std::endl;
+    message << "no such file or directory" << '\n';
   } else {
-    message << "error no is: " << errno << std::endl;
+    message << "error no is: " << errno << '\n';
   }
   TORCH_CHECK(false, message.str());
 }
@@ -192,5 +191,4 @@ static inline std::tuple<std::shared_ptr<char>, size_t> get_rai_content(
   return std::make_tuple(data, buffer_size);
 }
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit

View File

@@ -55,8 +55,7 @@ namespace flatbuffers = flatbuffers_fbsource;
 #include <torch/csrc/jit/serialization/mobile_bytecode_generated.h> // NOLINT
 #endif
-namespace torch {
-namespace jit {
+namespace torch::jit {
 // Our own alignment requirement does not need to be exactly the same as what
 // flatbuffers supports, but what flatbuffers supports needs to satisfy our
@@ -91,9 +90,9 @@ class FlatbufferLoader final {
       ExtraFilesMap* jit_sources,
       std::vector<IValue>* constants);
-  typedef TypePtr (*TypeResolver)(
+  using TypeResolver = TypePtr (*)(
       const std::string& type_str,
-      std::shared_ptr<CompilationUnit> cu);
+      const std::shared_ptr<CompilationUnit>& cu);
   void internal_registerTypeResolver(TypeResolver type_resolver);
@@ -187,7 +186,7 @@ IValue parseEnum(
 TypePtr resolveType(
     const std::string& type_string,
-    std::shared_ptr<CompilationUnit> cu) {
+    const std::shared_ptr<CompilationUnit>& cu) {
   TypePtr type;
   c10::string_view type_str(type_string);
   if (type_str.starts_with(kCustomClassPrefix)) {
@@ -531,7 +530,7 @@ IValue parseList(
     const mobile::serialization::IValue& ivalue) {
   const mobile::serialization::List* list = ivalue.val_as_List();
   auto res = c10::impl::GenericList(AnyType::get());
-  for (int i : *list->items()) {
+  for (auto i : *list->items()) {
     res.emplace_back(loader.getIValue(i));
   }
   auto type = loader.getOrCreateTypeAnnotations(list->annotation_str());
@@ -575,11 +574,13 @@ IValue parseTuple(
     FlatbufferLoader& loader,
     const mobile::serialization::IValue& ivalue) {
   const auto& tuple = ivalue.val_as_Tuple();
+  const auto items = tuple->items();
   std::vector<IValue> res;
-  for (int i : *tuple->items()) {
+  res.reserve(items->size());
+  for (auto i : *items) {
     res.emplace_back(loader.getIValue(i));
   }
-  return c10::ivalue::Tuple::create(res);
+  return c10::ivalue::Tuple::create(std::move(res));
 }
 IValue parseDict(
@@ -939,5 +940,4 @@ bool register_flatbuffer_loader() {
   return true;
 }
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit

View File

@@ -18,8 +18,7 @@
  * types, to avoid leaking those details to PyTorch clients.
  */
-namespace torch {
-namespace jit {
+namespace torch::jit {
 /// All non-copied data pointers provided to `parse_and_initialize_*` functions
 /// must be aligned to this boundary. Since the Module will point directly into
@@ -132,5 +131,4 @@ TORCH_API mobile::Module parse_and_initialize_mobile_module(
 // no op, TODO(qihan) delete
 TORCH_API bool register_flatbuffer_loader();
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit

View File

@@ -5,9 +5,7 @@
 #include <torch/csrc/jit/mobile/code.h>
 #include <optional>
-namespace torch {
-namespace jit {
-namespace mobile {
+namespace torch::jit::mobile {
 class Frame {
  public:
@@ -48,6 +46,4 @@ class Frame {
   size_t pc_{0};
 };
-} // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::mobile

View File

@@ -8,8 +8,7 @@
 #include <torch/csrc/jit/runtime/instruction.h>
 #include <torch/csrc/jit/runtime/operator.h>
-namespace torch {
-namespace jit {
+namespace torch::jit {
 char const* toString(OpCode op);
 namespace mobile {
@@ -27,7 +26,11 @@ const c10::QualifiedName& Function::qualname() const {
   return name_;
 }
-void Function::append_instruction(OpCode op, int X, int N, int64_t dbg_handle) {
+void Function::append_instruction(
+    OpCode op,
+    int64_t X,
+    int64_t N,
+    int64_t dbg_handle) {
   TORCH_CHECK(
       isOpSupportedInMobile(op),
       toString(op),
@@ -36,7 +39,7 @@ void Function::append_instruction(OpCode op, int X, int N, int64_t dbg_handle) {
   code_.debug_handles_.emplace_back(dbg_handle);
 }
-void Function::append_instruction(OpCode op, int X, int N) {
+void Function::append_instruction(OpCode op, int64_t X, int64_t N) {
   TORCH_CHECK(
       isOpSupportedInMobile(op),
       toString(op),
@@ -166,7 +169,7 @@ const std::vector<int64_t>& Function::getExceptionDebugHandles() const {
 }
 std::optional<std::function<void(Stack&)>> makeOperatorFunction(
-    c10::OperatorName opname,
+    const c10::OperatorName& opname,
     std::optional<int> num_specified_args) {
   std::function<void(Stack&)> fn;
   const auto full_name = c10::toString(opname);
@@ -269,5 +272,4 @@ Function& Function::registerFunc(
 }
 } // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit

View File

@@ -7,8 +7,7 @@
 #include <ATen/core/ivalue.h>
 #include <torch/csrc/jit/mobile/code.h>
-namespace torch {
-namespace jit {
+namespace torch::jit {
 enum OpCode : uint8_t;
 struct Instruction;
 struct OperatorString;
@@ -32,8 +31,8 @@ class TORCH_API Function : public torch::jit::Function {
   // NOTE: the APIs below is dangerous: if you call append_instruction with
   // dbg_handle and then call it without; then the dbg_handle will become
   // misaligned. Therefore only use ONE variant at time.
-  void append_instruction(OpCode op, int X, int N, int64_t dbg_handle);
-  void append_instruction(OpCode op, int X, int N);
+  void append_instruction(OpCode op, int64_t X, int64_t N, int64_t dbg_handle);
+  void append_instruction(OpCode op, int64_t X, int64_t N);
   void append_operator(
       const std::string& name,
       const std::string& overload_name,
@@ -76,11 +75,10 @@ class TORCH_API Function : public torch::jit::Function {
 };
 std::optional<std::function<void(Stack&)>> makeOperatorFunction(
-    c10::OperatorName opname,
+    const c10::OperatorName& opname,
     std::optional<int> num_specified_args);
 TORCH_API std::string operator_str(const c10::OperatorName& opname);
 } // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
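The NOTE in the hunk above warns that mixing the two append_instruction overloads misaligns debug handles. A minimal, standalone toy model of that hazard (this is not PyTorch's actual Code structure; all names are illustrative):

// Toy sketch: one overload records a debug handle, the other does not, so after
// mixing them debug_handles[i] no longer describes instructions[i].
#include <cassert>
#include <cstdint>
#include <vector>

struct ToyFunction {
  std::vector<int> instructions;       // stand-in for the instruction list
  std::vector<int64_t> debug_handles;  // stand-in for the debug-handle list

  void append_instruction(int op, int64_t dbg_handle) {
    instructions.push_back(op);
    debug_handles.push_back(dbg_handle);
  }
  void append_instruction(int op) {  // no handle recorded for this instruction
    instructions.push_back(op);
  }
};

int main() {
  ToyFunction f;
  f.append_instruction(/*op=*/1, /*dbg_handle=*/100);
  f.append_instruction(/*op=*/2);                      // mixing the variants...
  f.append_instruction(/*op=*/3, /*dbg_handle=*/300);
  // ...leaves debug_handles[1] describing instruction 2 rather than instruction 1.
  assert(f.instructions.size() == 3);
  assert(f.debug_handles.size() == 2);  // misaligned with instructions
  return 0;
}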

View File

@@ -81,8 +81,7 @@
 // - Argument::{known_length_,kwarg_only_}
 // - FunctionSchema::{overload_name_, is_vararg_, is_varret_}
-namespace torch {
-namespace jit {
+namespace torch::jit {
 using caffe2::serialize::MemoryReadAdapter;
 using caffe2::serialize::PyTorchStreamReader;
 using caffe2::serialize::ReadAdapterInterface;
@@ -91,7 +90,7 @@ OpCode parseOpCode(const char* str);
 TypePtr resolveTypeNameMobile(
     const c10::QualifiedName& qn,
-    std::shared_ptr<CompilationUnit> compilation_unit) {
+    const std::shared_ptr<CompilationUnit>& compilation_unit) {
   // HACK: first we check whether the name starts with special prefix to
   // tell if it's a supported pytorch class type. There are two special
   // prefixes. "__torch__" for nn module, and "torch.jit" from to_backend.
@@ -146,7 +145,7 @@ c10::intrusive_ptr<c10::ivalue::Object> objLoaderMobile(
     custom_class_type->getMethod("__setstate__").run(stack);
     return obj;
   } else {
-    auto dict = std::move(input).toGenericDict();
+    auto dict = input.toGenericDict();
     size_t ndict = dict.size();
     auto obj = c10::ivalue::Object::create(type, ndict);
     auto it = dict.begin();
@@ -223,8 +222,8 @@ class BytecodeDeserializer final {
   // dynamically. It's used for finding the minimum required runtime to run all
   // operators from the given model. If it's less than the current runtime,
   // upgrader will be applied at loading stage.
-  uint64_t operator_version_;
-  uint64_t bytecode_version_;
+  uint64_t operator_version_{0};
+  uint64_t bytecode_version_{0};
 };
 BytecodeDeserializer::BytecodeDeserializer(
@@ -486,8 +485,7 @@ c10::IValue BytecodeDeserializer::readArchive(
   };
   bool bytecode_tensor_in_constants_archive =
-      (archive_name == "bytecode" &&
-       !isTensorInBytecodeArchive(*reader_.get()));
+      (archive_name == "bytecode" && !isTensorInBytecodeArchive(*reader_));
   auto ivalues = torch::jit::readArchiveAndTensors(
       archive_name,
@@ -497,7 +495,7 @@ c10::IValue BytecodeDeserializer::readArchive(
       type_resolver,
      obj_loader,
       device_,
-      *reader_.get(),
+      *reader_,
       nullptr);
   return ivalues;
 }
@@ -734,5 +732,4 @@ std::set<std::string> _export_operator_list(
 }
 } // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit

View File

@@ -7,8 +7,7 @@
 #include <caffe2/serialize/file_adapter.h>
-namespace torch {
-namespace jit {
+namespace torch::jit {
 using caffe2::serialize::FileAdapter;
 using caffe2::serialize::IStreamAdapter;
 using caffe2::serialize::ReadAdapterInterface;
@@ -77,7 +76,7 @@ void _load_extra_only_for_mobile(
 // version type_resolver and obj_loader.
 at::TypePtr resolveTypeNameMobile(
     const c10::QualifiedName& qn,
-    std::shared_ptr<CompilationUnit> compilation_unit);
+    const std::shared_ptr<CompilationUnit>& compilation_unit);
 c10::StrongTypePtr typeResolverMobile(
     const c10::QualifiedName& qn,
     const std::shared_ptr<CompilationUnit>& compilation_unit);
@@ -108,5 +107,4 @@ TORCH_API std::set<std::string> _export_operator_list(
 } // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit

View File

@@ -9,8 +9,7 @@
 #include <map>
 #include <string>
-namespace torch {
-namespace jit {
+namespace torch::jit {
 /**
  * Loads named parameters from the serialized data in @p in.
@@ -34,5 +33,4 @@ TORCH_API std::map<std::string, at::Tensor> _load_parameters(
 TORCH_API std::map<std::string, at::Tensor> mobile_module_to_parameter_map(
     const mobile::Module& module);
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit

View File

@@ -5,19 +5,11 @@
  * Declarations shared between import_data.cpp and export_data.cpp
  */
-namespace torch {
-namespace jit {
-namespace mobile {
-namespace internal {
+namespace torch::jit::mobile::internal {
 /**
  * The name of the mobile::Module attribute which contains saved parameters, as
  * a Dict of names to Tensors. Only used for Flatbuffer serialization.
  */
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
 constexpr char kSavedParametersAttributeName[] = "data";
-} // namespace internal
-} // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::mobile::internal

View File

@@ -15,8 +15,7 @@
 #include <torch/csrc/jit/runtime/jit_exception.h>
 #include <torch/csrc/jit/runtime/vararg_functions.h>
-namespace torch {
-namespace jit {
+namespace torch::jit {
 char const* toString(OpCode op);
 std::ostream& operator<<(std::ostream& out, Instruction inst);
 namespace mobile {
@@ -400,5 +399,4 @@ IValue& InterpreterState::reg(size_t reg) {
 }
 } // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit

View File

@@ -5,9 +5,7 @@
 #include <torch/csrc/jit/mobile/code.h>
 #include <torch/csrc/jit/mobile/frame.h>
-namespace torch {
-namespace jit {
-namespace mobile {
+namespace torch::jit::mobile {
 struct InterpreterState {
   TORCH_API explicit InterpreterState(const Code& code);
@@ -25,6 +23,4 @@ struct InterpreterState {
 };
 const std::vector<DebugHandle>& getInterpretersExceptionDebugHandles();
-} // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::mobile

View File

@@ -3,9 +3,7 @@
 #include <ATen/core/ivalue.h>
 #include <torch/csrc/jit/mobile/function.h>
-namespace torch {
-namespace jit {
-namespace mobile {
+namespace torch::jit::mobile {
 class Module;
@@ -40,6 +38,4 @@ struct TORCH_API Method {
   Function* function_;
 };
-} // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::mobile

View File

@@ -11,8 +11,7 @@
 #include <c10/util/ScopeExit.h>
 #include <c10/util/irange.h>
-namespace torch {
-namespace jit {
+namespace torch::jit {
 std::ostream& operator<<(std::ostream& out, Instruction inst);
 namespace mobile {
@@ -351,5 +350,4 @@ TORCH_API ModuleInfo get_module_info(const mobile::Module& module) {
 }
 } // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit

View File

@@ -7,9 +7,7 @@
 #include <utility>
-namespace torch {
-namespace jit {
-namespace mobile {
+namespace torch::jit::mobile {
 using Stack = std::vector<c10::IValue>;
 // A CompilationUnit object is the one that gets executed by the lite
@@ -135,7 +133,7 @@ class TORCH_API Module {
   }
   const CompilationUnit& compilation_unit() const {
-    return *cu_.get();
+    return *cu_;
   }
   void set_delete_memory(std::shared_ptr<char> delete_mem) {
@@ -192,6 +190,4 @@ struct TORCH_API ModuleInfo {
 };
 TORCH_API ModuleInfo get_module_info(const mobile::Module& module);
-} // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::mobile

View File

@@ -8,8 +8,7 @@
 #include <torch/csrc/jit/serialization/import_export_functions.h>
 #include <torch/custom_class_detail.h>
-namespace torch {
-namespace jit {
+namespace torch::jit {
 OpCode parseOpCode(const char* str);
 using c10::IValue;
@@ -156,8 +155,8 @@ void parseInstructions(
         "There should be three parts in an instruction. The function name is ",
         function_name);
     OpCode op_code = opCodeCache.parse(*ins_item[0].toString());
-    int X = ins_item[1].toInt();
-    int N = ins_item[2].toInt();
+    auto X = ins_item[1].toInt();
+    auto N = ins_item[2].toInt();
     if (!debug_handles_list.empty()) {
       int64_t debug_handle = debug_handles_list[j];
@@ -195,5 +194,4 @@ void parseRegisterSize(size_t rsize, mobile::Function* function) {
 }
 } // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit

View File

@@ -1,9 +1,7 @@
 #pragma once
 #include <torch/csrc/jit/mobile/function.h>
-namespace torch {
-namespace jit {
-namespace mobile {
+namespace torch::jit::mobile {
 using c10::IValue;
 TORCH_API void parseInstructions(
     const std::string& function_name,
@@ -20,6 +18,4 @@ TORCH_API void parseRegisterSize(size_t rsize, mobile::Function* function);
 TORCH_API void applyUpgrader(
     mobile::Function* function,
     uint64_t operator_version);
-} // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::mobile

View File

@@ -1,9 +1,7 @@
 #include <ATen/core/ivalue.h>
 #include <torch/csrc/jit/mobile/parse_operators.h>
-namespace torch {
-namespace jit {
-namespace mobile {
+namespace torch::jit::mobile {
 void parseOperators(
     c10::ivalue::TupleElements&& ops_list,
@@ -27,6 +25,4 @@ void parseOperators(
       (module_load_options & MobileModuleLoadOptions::OPERATOR_CHECK));
 }
-} // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::mobile

View File

@@ -1,8 +1,7 @@
 #pragma once
 #include <torch/csrc/jit/mobile/function.h>
-namespace torch {
-namespace jit {
+namespace torch::jit {
 using c10::IValue;
 enum MobileModuleLoadOptions {
@@ -23,5 +22,4 @@ TORCH_API void parseOperators(
     const uint64_t& module_load_options,
     mobile::Function* function);
 } // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit

View File

@@ -1,8 +1,6 @@
 #include <torch/csrc/jit/mobile/prim_ops_registery.h>
-namespace torch {
-namespace jit {
-namespace mobile {
+namespace torch::jit::mobile {
 static std::unordered_map<std::string, std::function<void(Stack&)>>&
 primOpsFnTable() {
@@ -30,6 +28,4 @@ std::function<void(Stack&)>& getPrimOpsFn(const std::string& name) {
   return primOpsFnTable()[name];
 }
-} // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::mobile

View File

@@ -4,9 +4,7 @@
 #include <functional>
 #include <vector>
-namespace torch {
-namespace jit {
-namespace mobile {
+namespace torch::jit::mobile {
 using Stack = std::vector<c10::IValue>;
@@ -27,6 +25,4 @@ class prim_op_fn_register {
   }
 };
-} // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::mobile

View File

@@ -5,9 +5,7 @@
 #include <string>
 #include <vector>
-namespace torch {
-namespace jit {
-namespace mobile {
+namespace torch::jit::mobile {
 thread_local KinetoEdgeCPUProfiler* tls_edge_profiler{nullptr};
@@ -24,7 +22,7 @@ KinetoEdgeCPUProfiler::KinetoEdgeCPUProfiler(
     : m_(m), trace_file_name_(fname) {
   torch::profiler::impl::ExperimentalConfig experimental_config;
   // Enable hardware counters
-  if (events.size()) {
+  if (!events.empty()) {
     experimental_config.performance_events = std::move(events);
   }
@@ -138,6 +136,4 @@ KinetoEdgeCPUProfiler* getCurrentEdgeProfiler() {
   return tls_edge_profiler;
 }
-} // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::mobile

View File

@@ -2,9 +2,7 @@
 #include <torch/csrc/autograd/profiler_kineto.h>
 #include <torch/csrc/jit/mobile/module.h>
-namespace torch {
-namespace jit {
-namespace mobile {
+namespace torch::jit::mobile {
 // If we dont have kineto available then edge profiler does not
 // work since it relies on Kineto
@@ -114,6 +112,4 @@ TORCH_API KinetoEdgeCPUProfiler* getCurrentEdgeProfiler();
 #define RECORD_BACKEND_MEMORY_EVENT_TO_EDGE_PROFILER( \
     ptr, alloc_size, total_allocated, total_reserved, device)
 #endif
-} // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::mobile

View File

@@ -1,7 +1,8 @@
 #include <ATen/ScalarOps.h>
+#include <fmt/format.h>
 #include <torch/csrc/jit/mobile/promoted_prim_ops.h>
-namespace torch {
-namespace jit {
+namespace torch::jit {
 void tupleIndex(Stack& stack) {
   int64_t index = pop(stack).toInt();
@@ -94,8 +95,8 @@ void device(Stack& stack) {
 void device_with_index(Stack& stack) {
   std::string type = pop(stack).toStringRef();
-  int index = pop(stack).toInt();
-  std::string device_str = type + ":" + std::to_string(index);
+  auto index = pop(stack).toInt();
+  std::string device_str = fmt::format("{}:{}", type, index);
   auto device = c10::Device(device_str);
   push(stack, device);
 }
@@ -220,8 +221,7 @@ void isCuda(Stack& stack) {
 }
 void numToTensorBool(Stack& stack) {
-  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-  bool b;
+  bool b = false;
   pop(stack, b);
   push(stack, c10::scalar_to_tensor(b));
 }
@@ -260,5 +260,4 @@ static const C10_UNUSED std::array<mobile::prim_op_fn_register, 16> op_reg = {
     // mobile::prim_op_fn_register("aten::size", size)
 };
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit

View File

@@ -2,8 +2,7 @@
 #include <torch/csrc/jit/mobile/prim_ops_registery.h>
 #include <torch/csrc/jit/mobile/register_ops_common_utils.h>
-namespace torch {
-namespace jit {
+namespace torch::jit {
 void tupleIndex(Stack& stack);
@@ -59,5 +58,4 @@ void dictIndex(Stack& stack);
 void raiseExceptionWithMessage(Stack& stack);
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit

View File

@@ -2,10 +2,7 @@
 #include <torch/csrc/jit/mobile/module.h>
 #include <torch/csrc/jit/mobile/quantization.h>
-namespace torch {
-namespace jit {
-namespace mobile {
-namespace quantization {
+namespace torch::jit::mobile::quantization {
 void PTQQuanizationHelper::quantize_dynamic(
     torch::jit::mobile::Module& m,
@@ -60,7 +57,4 @@ void PTQQuanizationHelper::quantize_dynamic(
   m.unsafeRemoveMethod(observe_method_name);
   m.unsafeRemoveMethod(reset_observers_method_name);
 }
-} // namespace quantization
-} // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::mobile::quantization

View File

@@ -3,9 +3,7 @@
 #include <c10/macros/Export.h>
 #include <string>
-namespace torch {
-namespace jit {
-namespace mobile {
+namespace torch::jit::mobile {
 class Module;
 namespace quantization {
 /*
@@ -33,6 +31,4 @@ class TORCH_API PTQQuanizationHelper {
       const std::string& method_name);
 };
 } // namespace quantization
-} // namespace mobile
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::mobile

View File

@@ -2,8 +2,7 @@
 #include <ATen/core/type_factory.h>
 #include <torch/csrc/jit/mobile/register_ops_common_utils.h>
-namespace torch {
-namespace jit {
+namespace torch::jit {
 int64_t normalizeIndex(int64_t idx, int64_t list_size) {
   if (idx < 0) {
@@ -99,5 +98,4 @@ IValue tensorToListRecursive(
   return result;
 }
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit

View File

@@ -7,8 +7,7 @@
 #include <torch/csrc/jit/runtime/jit_exception.h>
 #include <torch/csrc/jit/runtime/vararg_functions.h>
-namespace torch {
-namespace jit {
+namespace torch::jit {
 inline void noop(Stack& n) {}
@@ -51,5 +50,4 @@ IValue tensorToListRecursive(
     at::IntArrayRef strides,
     size_t element_size);
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit

View File

@@ -1,7 +1,5 @@
 #include <torch/csrc/jit/mobile/type_parser.h>
-#include <queue>
 #include <ATen/core/jit_type.h>
 #include <ATen/core/type_factory.h>
 #include <c10/util/string_view.h>
@@ -122,7 +120,7 @@ TypePtr TypeParser::parseNonSimple(const std::string& token) {
      }
    }
    expect("]");
-    return DynamicTypeFactory::create<TupleType>(std::move(types));
+    return DynamicTypeFactory::create<TupleType>(types);
   }
   return nullptr;
 }
@@ -186,7 +184,6 @@ TypePtr TypeParser::parse() {
 TypePtr TypeParser::parseNamedTuple(const std::string& qualified_name) {
   std::vector<c10::string_view> field_names;
   std::vector<TypePtr> field_types;
-  std::string ns;
   expect(",");
   expect("[");
   while (cur() != "]") {

View File

@@ -12,8 +12,7 @@ namespace c10 {
 TypePtr parseType(const std::string& pythonStr);
 } // namespace c10
-namespace torch {
-namespace jit {
+namespace torch::jit {
 // clang-format off
@@ -684,5 +683,4 @@ const std::vector<ByteCodeFunctionWithOperator>& getUpgraderBytecodeList() {
 // clang-format on
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit

View File

@@ -11,8 +11,7 @@
 #include <unordered_map>
 #include <vector>
-namespace torch {
-namespace jit {
+namespace torch::jit {
 struct Instruction;
 struct Upgrader {
   int min_version;
@@ -39,5 +38,4 @@ struct ByteCodeFunctionWithOperator {
 TORCH_API const std::vector<ByteCodeFunctionWithOperator>&
 getUpgraderBytecodeList();
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
} // namespace torch

View File

@@ -159,7 +159,7 @@ struct DeviceTypePropagationPass : public PropertyPropBase {
   }
  private:
-  void propagateNode(Node* n, bool _ = false) override {
+  void propagateNode(Node* n, bool _ = true) override {
     GRAPH_DEBUG("processNode");
     switch (n->kind()) {
       case prim::If:

View File

@@ -748,14 +748,13 @@ py::object toPyObject(IValue ivalue) {
 std::pair<std::shared_ptr<Operator>, Stack> getOpWithStack(
     const std::vector<std::shared_ptr<Operator>>& operations,
-    py::args args,
+    const py::args& args,
     const py::kwargs& kwargs) {
   Stack stack;
   if (operations.size() == 1) {
     std::shared_ptr<Operator> op = operations.at(0);
     // Create a stack full of the arguments and keyword arguments.
-    stack = createStackForSchema(
-        op->schema(), std::move(args), kwargs, std::nullopt);
+    stack = createStackForSchema(op->schema(), args, kwargs, std::nullopt);
     return std::make_pair(std::move(op), std::move(stack));
   } else {
@@ -802,10 +801,10 @@ bool checkSchemaAllowFakeScriptObject(
 py::object invokeOperatorFromPython(
     const std::vector<std::shared_ptr<Operator>>& operations,
-    py::args args,
+    const py::args& args,
     const py::kwargs& kwargs,
     std::optional<c10::DispatchKey> dk) {
-  auto [found_op, stack] = getOpWithStack(operations, std::move(args), kwargs);
+  auto [found_op, stack] = getOpWithStack(operations, args, kwargs);
   {
     pybind11::gil_scoped_release no_gil_guard;
     if (dk) {

View File

@@ -1250,12 +1250,12 @@ inline py::object invokeScriptMethodFromPython(
 TORCH_PYTHON_API std::pair<std::shared_ptr<Operator>, Stack> getOpWithStack(
     const std::vector<std::shared_ptr<Operator>>& operations,
-    py::args args,
+    const py::args& args,
     const py::kwargs& kwargs);
 TORCH_PYTHON_API py::object invokeOperatorFromPython(
     const std::vector<std::shared_ptr<Operator>>& operations,
-    py::args args,
+    const py::args& args,
     const py::kwargs& kwargs,
     std::optional<c10::DispatchKey> dk = std::nullopt);
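Most of the pybind-related hunks in this commit swap py::args/py::kwargs value or rvalue parameters for const references. A minimal pybind11 sketch of that signature style (not from this patch; the module and function names are made up for illustration, and copying a py::args only bumps a reference count, so the const reference avoids needless work):

// Build as a pybind11 extension module; `example_ext` and `count_args` are
// illustrative names only.
#include <pybind11/pybind11.h>

namespace py = pybind11;

// Count positional plus keyword arguments without copying the argument handles.
int count_args(const py::args& args, const py::kwargs& kwargs) {
  return static_cast<int>(args.size() + kwargs.size());
}

PYBIND11_MODULE(example_ext, m) {
  m.def("count_args", &count_args);
}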

View File

@@ -940,19 +940,19 @@ void initJitScriptBindings(PyObject* module) {
   special_magic_methods.emplace(
       "__str__",
       [](const Object& self,
-         py::args args,
+         const py::args& args,
          const py::kwargs& kwargs) -> py::object {
         auto method = self.find_method("__str__");
         if (!method) {
           return py::str("ScriptObject <" + self.type()->str() + ">");
         }
-        return invokeScriptMethodFromPython(*method, std::move(args), kwargs);
+        return invokeScriptMethodFromPython(*method, args, kwargs);
       });
   special_magic_methods.emplace(
       "__repr__",
       [](const Object& self,
-         py::args args,
+         const py::args& args,
          const py::kwargs& kwargs) -> py::object {
         auto method = self.find_method("__repr__");
         if (!method) {
@@ -960,7 +960,7 @@ void initJitScriptBindings(PyObject* module) {
           ss << std::hex << static_cast<const void*>(&self);
           return py::str("<torch.ScriptObject object at " + ss.str() + ">");
         }
-        return invokeScriptMethodFromPython(*method, std::move(args), kwargs);
+        return invokeScriptMethodFromPython(*method, args, kwargs);
       });
   for (const char* mm_name : magic_method_names) {
@@ -970,7 +970,9 @@ void initJitScriptBindings(PyObject* module) {
     object_class.def(
         mm_name,
         [mm_name](
-            const Object& self, py::args args, const py::kwargs& kwargs) {
+            const Object& self,
+            const py::args& args,
+            const py::kwargs& kwargs) {
          auto method = self.find_method(mm_name);
          if (!method) {
            std::string msg = fmt::format(
@@ -979,8 +981,7 @@ void initJitScriptBindings(PyObject* module) {
               self.type()->str());
           throw c10::NotImplementedError(msg);
          }
-          return invokeScriptMethodFromPython(
-              *method, std::move(args), kwargs);
+          return invokeScriptMethodFromPython(*method, args, kwargs);
        });
   }
 }
@@ -1271,7 +1272,7 @@ void initJitScriptBindings(PyObject* module) {
           consts["c" + std::to_string(i)] = constant;
           i += 1;
         }
-        return std::make_tuple(pp.str(), consts);
+        return std::make_tuple(pp.str(), std::move(consts));
       })
       .def("apply", &Module::apply)
       .def("__copy__", &Module::copy)
@@ -1584,7 +1585,7 @@ void initJitScriptBindings(PyObject* module) {
           consts["c" + std::to_string(i)] = constant;
           i += 1;
         }
-        return std::make_tuple(pp.str(), consts);
+        return std::make_tuple(pp.str(), std::move(consts));
       })
       .def_property_readonly("owner", &Method::owner)
       .def_property_readonly("raw_owner", [](const Method& self) {

View File

@@ -222,11 +222,11 @@ static torch::_RegisterOrVerify register_or_verify() {
 static py::object ophandle_call_boxed(
     const c10::OperatorHandle& handle,
-    py::args args,
+    const py::args& args,
     const py::kwargs& kwargs) {
   auto stack = torch::jit::createStackForSchema(
       handle.schema(),
-      std::move(args),
+      args,
       kwargs,
       /*self=*/std::nullopt);
   {

View File

@@ -24,19 +24,19 @@ void ThroughputBenchmark::addInput(py::args args, py::kwargs kwargs) {
 }
 py::object ThroughputBenchmark::runOnce(
-    py::args&& args,
+    const py::args& args,
     const py::kwargs& kwargs) {
   CHECK(script_module_.initialized() ^ module_.initialized());
   if (script_module_.initialized()) {
     c10::IValue result;
     {
       pybind11::gil_scoped_release no_gil_guard;
-      result = script_module_.runOnce(std::move(args), kwargs);
+      result = script_module_.runOnce(args, kwargs);
     }
     return jit::toPyObject(std::move(result));
   } else {
     CHECK(module_.initialized());
-    return module_.runOnce(std::move(args), kwargs);
+    return module_.runOnce(args, kwargs);
   }
 }
@@ -75,12 +75,12 @@ void ScriptModuleBenchmark::runOnce(ScriptModuleInput&& input) const {
 template <>
 ScriptModuleOutput ScriptModuleBenchmark::runOnce(
-    py::args&& args,
+    const py::args& args,
     const py::kwargs& kwargs) const {
   CHECK(initialized_);
   auto& function = model_.get_method("forward").function();
   ScriptModuleInput stack = jit::createStackForSchema(
-      function.getSchema(), std::move(args), kwargs, model_._ivalue());
+      function.getSchema(), args, kwargs, model_._ivalue());
   return function(std::move(stack));
 }
@@ -92,8 +92,9 @@ void ModuleBenchmark::runOnce(ModuleInput&& input) const {
 }
 template <>
-ModuleOutput ModuleBenchmark::runOnce(py::args&& args, const py::kwargs& kwargs)
-    const {
+ModuleOutput ModuleBenchmark::runOnce(
+    const py::args& args,
+    const py::kwargs& kwargs) const {
   CHECK(initialized_);
   pybind11::gil_scoped_acquire gil_guard;
   return model_(*args, **kwargs);
@@ -103,7 +104,7 @@ template <>
 void ScriptModuleBenchmark::addInput(py::args&& args, py::kwargs&& kwargs) {
   jit::Stack stack = jit::createStackForSchema(
       model_.get_method("forward").function().getSchema(),
-      std::move(args),
+      args,
       kwargs,
       model_._ivalue());
   inputs_.emplace_back(std::move(stack));

View File

@@ -79,7 +79,7 @@ class BenchmarkHelper {
   // would race with Python
   void runOnce(Input&&) const;
   // This method is to be used when calling from Python directly
-  Output runOnce(py::args&&, const py::kwargs&) const;
+  Output runOnce(const py::args&, const py::kwargs&) const;
   // Aggregate input in the format Model expects in order to avoid further
   // conversions at the benchmark time
   void addInput(py::args&&, py::kwargs&&);
@@ -134,15 +134,16 @@ void ScriptModuleBenchmark::runOnce(ScriptModuleInput&& input) const;
 template <>
 ScriptModuleOutput ScriptModuleBenchmark::runOnce(
-    py::args&& args,
+    const py::args& args,
     const py::kwargs& kwargs) const;
 template <>
 void ModuleBenchmark::runOnce(ModuleInput&& input) const;
 template <>
-ModuleOutput ModuleBenchmark::runOnce(py::args&& args, const py::kwargs& kwargs)
-    const;
+ModuleOutput ModuleBenchmark::runOnce(
+    const py::args& args,
+    const py::kwargs& kwargs) const;
 template <>
 void ScriptModuleBenchmark::addInput(py::args&& args, py::kwargs&& kwargs);
@@ -180,7 +181,7 @@ class C10_HIDDEN ThroughputBenchmark {
   void addInput(py::args args, py::kwargs kwargs);
   // Equivalent to just running the model directly on the given input
-  py::object runOnce(py::args&& args, const py::kwargs& kwargs);
+  py::object runOnce(const py::args& args, const py::kwargs& kwargs);
   // The main method of the class allows to perform a multi-threaded benchmark
   // It returns BenchmarkExecutionStats object with a lot of useful statistics