Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62418

Debug handles have a one-to-one correspondence with instructions, so combine the two into a single record.

Test Plan: CI

Imported from OSS

Reviewed By: raziel

Differential Revision: D29993661

fbshipit-source-id: 125c7163174cf66624dd95f110fdc8208fea8a07
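For context, a minimal sketch of the combined record this change stores in code_->instructions_with_handles_, inferred from the call sites in the file below. The struct name InstructionWithDebugHandle and the wrapping namespace are assumptions for illustration, not part of this diff:

#include <cstdint>
#include <torch/csrc/jit/runtime/instruction.h>

namespace sketch {
// Hypothetical name; the real definition is not shown in this file. Whatever
// the actual type is, it must be constructible from (Instruction, int64_t)
// and expose a debug_handle member, matching how this file emplaces and
// reads the combined record.
struct InstructionWithDebugHandle {
  InstructionWithDebugHandle(torch::jit::Instruction inst, int64_t handle)
      : instruction(inst), debug_handle(handle) {}
  torch::jit::Instruction instruction;
  int64_t debug_handle;
};
} // namespace sketch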
190 lines · 6.3 KiB · C++
#include <torch/csrc/jit/mobile/function.h>

#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/mobile/interpreter.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/custom_class_detail.h>

namespace torch {
namespace jit {

char const* toString(OpCode op);
namespace mobile {
Function::Function(c10::QualifiedName name)
    : name_(std::move(name)), code_(std::make_shared<Code>()) {}

const c10::QualifiedName& Function::qualname() const {
  return name_;
}

const std::string& Function::name() const {
  return name_.name();
}

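// Appends one bytecode instruction, paired with its debug handle, to
// code_->instructions_with_handles_. The opcode is validated against the set
// of operations supported by the mobile interpreter.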
void Function::append_instruction(OpCode op, int X, int N, int64_t dbg_handle) {
  TORCH_CHECK(
      isOpSupportedInMobile(op),
      toString(op),
      " is not supported in mobile module.");
  code_->instructions_with_handles_.emplace_back(
      Instruction(op, X, N), dbg_handle);
}

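// Resolves the operator <name, overload_name> and appends a callable wrapper
// for it to code_->operators_. Resolution first consults operator_cache, then
// the JIT operator registry, then the c10 dispatcher. For older models the
// wrapper also pushes default values for trailing arguments (and allow_tf32
// for byte-code v3 aten::_convolution) to stay backward compatible. Returns
// false if the operator cannot be found.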
bool Function::append_operator(
    const std::string& name,
    const std::string& overload_name,
    const c10::optional<int>& num_specified_args,
    int64_t model_version, /* TODO: T90339189 deprecate all v3 when v3 models
                              are removed */
    OperatorCacheType& operator_cache) {
  // TODO: The c10::OperatorName class contains 2 std::string members, one
  // for the operator name, and one for the overload name. Creating a new
  // object of type c10::OperatorName creates these 2 strings, which cause
  // a heap memory allocation for each element in code_->op_names_.
  // This can be a significant perf. overhead for models that have a very
  // large list of operators.

  // Keep the original opname in code_
  code_->op_names_.emplace_back(name, overload_name);
  const auto& opname = code_->op_names_.back();

  const auto& opname_c10 = opname;
  std::function<void(Stack&)> fn;

  auto it = operator_cache.find(opname);
  if (it != operator_cache.end()) {
    // Operator (with fully qualified name) was found in the cache.
    if (it->second.has_same_arg_num(num_specified_args)) {
      // And it has the same (or unspecified) number of arguments.
      code_->operators_.emplace_back(it->second.fn);
      return true;
    }
    // Operator found, but with a different argument list or a
    // specified/unspecified mismatch. Fall back to creating one from scratch.
  }

  auto jit_op = findOperatorFor(opname);
  std::vector<c10::Argument> args;
  if (jit_op) {
    fn = [jit_op](Stack& stack) { jit_op->getOperation()(&stack); };
    args = jit_op->schema().arguments();
  } else {
    auto op = c10::Dispatcher::singleton().findSchema(opname_c10);
    if (op.has_value()) {
      fn = [op](Stack& stack) { op->callBoxed(&stack); };
      if (op->hasSchema()) {
        args = op->schema().arguments();
      } else {
        TORCH_CHECK(false, "arguments are missing for operator ", opname);
      }
    } else {
      return false;
    }
  }

  if (model_version == 0x3LL &&
      opname == c10::OperatorName("aten::_convolution", "")) {
    // Since byte-code version 0x4L, convolution has an additional
    // default-value argument (allow_tf32=True, see
    // https://github.com/pytorch/pytorch/pull/40737). This wrapper handles
    // backward compatibility with models of byte-code version <= 0x3L, where
    // this bool argument does not yet exist.
    fn = [fn](Stack& stack) {
      stack.push_back(true);
      fn(stack);
    };
  } else {
    // num_specified_args, if set, indicates the number of arguments available
    // from the model. We can use it to handle backward compatibility.
    if (num_specified_args &&
        num_specified_args.value() < static_cast<int64_t>(args.size())) {
      // Sanity check at load time, to save perf at runtime
      for (size_t i = num_specified_args.value(); i < args.size(); ++i) {
        auto default_val = args[i].default_value();
        TORCH_CHECK(
            default_val.has_value(),
            "Error happened while preparing default values for the arguments. The ",
            i,
            "th argument of operator ",
            opname,
            " does not have a specified value or default value. ");
      }
      fn = [fn, num_specified_args, args](Stack& stack) {
        for (size_t i = num_specified_args.value(); i < args.size(); ++i) {
          stack.push_back(args[i].default_value());
        }
        fn(stack);
      };
    }
  }
  code_->operators_.emplace_back(fn);
  if (it == operator_cache.end()) {
    // We got here because the operator name wasn't found in the cache,
    // not because there was a schema mismatch. Do add it into the cache.
    operator_cache.insert(std::make_pair(
        opname, OperatorFunctionWithSchema{fn, num_specified_args}));
  }
  return true;
}

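// The next three methods populate the constant pool, the type table, and the
// register count of the underlying mobile Code object.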
void Function::append_constant(const c10::IValue& constant) {
  code_->constants_.push_back(constant);
}

void Function::append_type(const at::TypePtr& type) {
  code_->types_.push_back(type);
}

void Function::set_register_size(size_t size) {
  code_->register_size_ = size;
}

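// Returns the debug handle recorded for the instruction at program counter
// pc; with the combined record there is exactly one handle per instruction.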
int64_t Function::get_debug_handle(size_t pc) const {
  TORCH_CHECK(code_, "Valid code must exist.");
  TORCH_CHECK(
      pc < code_->instructions_with_handles_.size(),
      "Module debug info index out of boundary.");
  return code_->instructions_with_handles_[pc].debug_handle;
}

void Function::setSchema(c10::FunctionSchema schema) {
  schema_ = std::move(schema);
}

const at::optional<c10::FunctionSchema>& Function::getSchema() const {
  return schema_;
}

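// Normalizes the inputs against the schema (when one is set) and then
// executes the function's bytecode on the mobile InterpreterState.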
bool Function::run(Stack& stack) const {
  const auto& schema = getSchema();
  if (schema) { // if we have a schema then resolve optional args if any
    schema->checkAndNormalizeInputs(
        stack, std::unordered_map<std::string, IValue>{} /*kwargs*/);
  }
  InterpreterState interp_state(code_);
  return interp_state.run(stack);
}

c10::IValue Function::operator()(Stack& stack) const {
  run(stack);
  return stack.front();
}

const std::shared_ptr<Code> Function::get_code() const {
  return code_;
}

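// Maps the program counter at which the interpreter raised an exception back
// to the corresponding debug handle, for error reporting.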
int64_t Function::getExceptionDebugHandle() const {
  size_t pc = getInterpretersExceptionPC();
  // We don't do an explicit bounds check here: pc comes from the internal
  // getInterpretersExceptionPC() helper, which returns the PC where the
  // interpreter currently is, and .at() bounds-checks anyway.
  return code_->instructions_with_handles_.at(pc).debug_handle;
}

} // namespace mobile
} // namespace jit
} // namespace torch