Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Use JIT op registration directly for lite interpreter. (#34070)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/34070

This is the first step toward making all operators available to the lite interpreter. The original code registered lite-interpreter ops manually, with a "_" prefix, for two reasons:

1. To minimize the build size.
2. To avoid duplicate registration in OSS (mainly in feature testing and unit tests).

Now that there are more and more models to support, manual registration is no longer practical. To make the process automatic while keeping the binary size under control, the plan is to:

1. Make all necessary ops callable from the lite interpreter, falling back to JIT dispatch when an op is not found in the c10 registry (see the function.cpp hunk below).
2. Offset the binary-size increase caused by step 1 with ljk53's custom build, which selectively builds the binary with only the ops used in specific models. The ops are collected automatically using get_opnames.
3. Remove the temporary "register_mobile_ops.cpp".

Test Plan: Imported from OSS

Differential Revision: D20291596

Pulled By: iseeyuan

fbshipit-source-id: 553b4699619cd71fea20658f3bc8c2d48852ef5c
Committed by: Facebook GitHub Bot
Parent: 3789db40f2
Commit: 361eed6a6e
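For step 2 of the plan, the per-model op list is collected from a scripted model. Below is a minimal sketch of that collection step, assuming the torch.jit.export_opnames API (the Python entry point for the kind of collection get_opnames performs); the torchvision model, the output file name, and the SELECTED_OP_LIST flag are illustrative conventions from the mobile custom-build workflow, not part of this patch:

import torch
import torchvision
import yaml

# Script the model whose operators the custom build must keep.
model = torch.jit.script(torchvision.models.mobilenet_v2())

# Collect the root operator names the model calls,
# e.g. "aten::conv2d", "aten::add.Tensor", ...
op_names = torch.jit.export_opnames(model)

# Write the list in the YAML form consumed by the selective-build
# scripts (e.g. SELECTED_OP_LIST=mobilenet_v2_ops.yaml).
with open("mobilenet_v2_ops.yaml", "w") as f:
    yaml.dump(op_names, f)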
@@ -474,7 +474,6 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
   ${TORCH_SRC_DIR}/csrc/jit/codegen/fuser/fallback.cpp
   ${TORCH_SRC_DIR}/csrc/jit/api/function_impl.cpp
   ${TORCH_SRC_DIR}/csrc/jit/runtime/vararg_functions.cpp
   ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/codegen.cpp
   ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/eval.cpp
   ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/expr.cpp
@@ -121,6 +121,7 @@ white_list = [
     ('_xnnpack::conv2d_prepack', datetime.date(2020, 4, 2)),
     ('_xnnpack::linear_packed', datetime.date(2020, 4, 2)),
     ('_xnnpack::linear_prepack', datetime.date(2020, 4, 2)),
+    ('_aten', datetime.date(2020, 4, 15)),
 ]
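The new ('_aten', ...) entry above allow-lists the '_'-prefixed operator schemas touched by this change in the backward-compatibility check, with an expiration date of April 15, 2020. As a rough illustration of how such a (prefix, expiration-date) allow list can be consulted — the allow_listed helper is hypothetical, not the actual check_backward_compatibility logic:

import datetime

# (schema-name prefix, expiration date) pairs, as in the hunk above.
white_list = [
    ('_xnnpack::conv2d_prepack', datetime.date(2020, 4, 2)),
    ('_aten', datetime.date(2020, 4, 15)),
]

def allow_listed(schema_name, today=None):
    # Hypothetical helper: a schema skips the BC check while a
    # matching prefix entry has not yet expired.
    today = today or datetime.date.today()
    return any(
        schema_name.startswith(prefix) and today < expiry
        for prefix, expiry in white_list)

# '_aten::embedding_bag' matches the '_aten' prefix before its expiration.
assert allow_listed('_aten::embedding_bag', today=datetime.date(2020, 4, 1))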
@@ -1,43 +1,55 @@
 #include "function.h"
 #include <ATen/core/op_registration/op_registration.h>
 #include <torch/csrc/jit/runtime/instruction.h>
 #include <torch/csrc/jit/runtime/operator.h>
 #include <torch/csrc/jit/runtime/vararg_functions.h>
 #include <torch/custom_class_detail.h>
 #include "interpreter.h"
 
-namespace torch{
-namespace jit{
+namespace torch {
+namespace jit {
 
-char const * toString(OpCode op);
+char const* toString(OpCode op);
 namespace mobile {
 Function::Function(c10::QualifiedName name)
     : name_(name), code_(std::make_shared<Code>()) {}
 
 void Function::append_instruction(OpCode op, int X, int N) {
-  TORCH_CHECK(isOpSupportedInMobile(op), toString(op),
-              " is not supported in mobile module.");
+  TORCH_CHECK(
+      isOpSupportedInMobile(op),
+      toString(op),
+      " is not supported in mobile module.");
   code_->instructions_.emplace_back(op, X, N);
 }
 
-bool Function::append_operator(const std::string& name,
-                               const std::string& overload_name) {
+bool Function::append_operator(
+    const std::string& name,
+    const std::string& overload_name) {
   // Keep the original opname in code_
   code_->op_names_.emplace_back(name, overload_name);
   auto opname = code_->op_names_.back();
-  // Add "_" prefix to work around the double registration both of jit/generated
-  // and here. TODO: remove it when we have separate build for lite interpreter.
-  if (opname.name != "aten::Int") {
-    opname.name = "_" + opname.name;
+
+  auto opname_c10 = opname;
+  std::function<void(Stack&)> fn;
+
+  // Add "_" prefix to work around the double registration, for operators
+  // registered in register_mobile_ops.cpp.
+  // TODO: remove it when we migrate all c10 ops.
+  if (opname_c10.name != "aten::Int") {
+    opname_c10.name = "_" + opname_c10.name;
   }
-  auto op = c10::Dispatcher::singleton().findSchema(opname);
-  if (!op.has_value()) {
-    return false;
+  auto op = c10::Dispatcher::singleton().findSchema(opname_c10);
+  if (op.has_value()) {
+    fn = [op](Stack& stack) {
+      c10::Dispatcher::singleton().callBoxed(*op, &stack);
+    };
+  } else { // Not found in c10 registration, use JIT dispatch
+    auto jit_op = findOperatorFor(opname);
+    TORCH_CHECK(
+        jit_op, opname.name, ".", opname.overload_name, " cannot be found.");
+    fn = [jit_op](Stack& stack) { jit_op->getOperation()(stack); };
   }
-  // TODO: operator.h now does not depend on Node* so we can also look up operators from
-  // that registry for use in mobile as a way to share implementations.
-  auto fn = [op](Stack& stack) {
-    c10::Dispatcher::singleton().callBoxed(*op, &stack);
-  };
+
   code_->operators_.emplace_back(fn);
   return true;
 }

@@ -59,5 +71,5 @@ bool Function::run(Stack& stack) const {
   return interp_state.run(stack);
 }
 } // namespace mobile
-} // namespace torch
 } // namespace jit
+} // namespace torch