Move prim ops from JIT registration to C10 (#30612)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30612

This is the first pass at moving prim ops from JIT registration to c10 registration. Once reviewers are satisfied with the initial changes, more operators will be moved in the same style.
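To illustrate the shape of the change (excerpted from the diff below, using aten::Int on float as the example): a hand-written JIT Operator that manipulates the interpreter stack is replaced by a typed kernel registered through the c10 torch::RegisterOperators API, with alias analysis taken from the schema.

// Before: aten::Int(float) as a JIT Operator with an explicit
// stack-manipulating lambda (removed from register_prim_ops.cpp).
Operator(
    "aten::Int(float a) -> int",
    [](Stack& stack) {
      double d;
      pop(stack, d);
      push(stack, (int64_t)d);
      return 0;
    },
    aliasAnalysisFromSchema()),

// After: the same op registered with the c10 dispatcher as a typed
// catch-all kernel; boxing/unboxing of the stack is handled by c10
// (added in register_prim_ops_c10.cpp).
static auto registry_prim = torch::RegisterOperators().op(
    "aten::Int.float(float a) -> int",
    torch::RegisterOperators::options()
        .catchAllKernel([](double d) -> int64_t {
          return static_cast<int64_t>(d);
        })
        .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA));

Note that the overload name (e.g. .float) is now explicit in the schema string, which is presumably why 'aten::Int' is added to the backward-compatibility white_list in this diff.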

Test Plan: Imported from OSS

Differential Revision: D19237648

Pulled By: iseeyuan

fbshipit-source-id: c5a519604efffb80564a556536f17d829f71d9f9
Authored by: Martin Yuan, 2020-01-04 13:46:05 -08:00
Committed by: Facebook Github Bot
Parent: 5579611544
Commit: f362cd510d
9 changed files with 69 additions and 48 deletions

@@ -427,6 +427,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/jit/print_handler.cpp
${TORCH_SRC_DIR}/csrc/jit/fuser/interface.cpp
${TORCH_SRC_DIR}/csrc/jit/register_prim_ops.cpp
${TORCH_SRC_DIR}/csrc/jit/register_prim_ops_c10.cpp
${TORCH_SRC_DIR}/csrc/jit/register_string_ops.cpp
${TORCH_SRC_DIR}/csrc/jit/register_special_ops.cpp
${TORCH_SRC_DIR}/csrc/jit/scope.cpp

@@ -48,6 +48,7 @@ white_list = [
('upsample_nearest3d_backward.grad_input', datetime.date(9999, 1, 1)),
('upsample_nearest3d_backward', datetime.date(9999, 1, 1)),
('_test_optional_float', datetime.date(9999, 1, 1)),
('aten::Int', datetime.date(2020, 1, 30)),
]

@@ -127,5 +127,31 @@ void testLiteInterpreterPrimOverload() {
auto output = bc.run_method("forward", inputs);
AT_ASSERT(output.toIntList()[2] == 3);
}
void testLiteInterpreterPrim() {
script::Module m("m");
m.define(R"JIT(
def forward(self, x):
return int(x)
)JIT");
std::vector<IValue> inputs;
auto minput = 3.5 * torch::ones({});
inputs.emplace_back(minput);
auto ref = m.run_method("forward", minput);
std::stringstream ss;
m._save_for_mobile(ss);
mobile::Module bc = _load_for_mobile(ss);
IValue res;
for (int i = 0; i < 3; ++i) {
auto bcinputs = inputs;
res = bc.run_method("forward", bcinputs);
}
auto resi = res.toInt();
auto refi = ref.toInt();
AT_ASSERT(resi == refi);
}
} // namespace torch
} // namespace jit

@@ -77,7 +77,8 @@ namespace jit {
_(LiteInterpreterPrimOverload) \
_(CommonAncestor) \
_(AutogradSymbols) \
_(MobileTypeParser)
_(MobileTypeParser) \
_(LiteInterpreterPrim)
#define TH_FORALL_TESTS_CUDA(_) \
_(ArgumentSpec) \

@@ -142,6 +142,7 @@ libtorch_sources = [
"torch/csrc/jit/passes/utils/memory_dag.cpp",
"torch/csrc/jit/print_handler.cpp",
"torch/csrc/jit/register_prim_ops.cpp",
"torch/csrc/jit/register_prim_ops_c10.cpp",
"torch/csrc/jit/register_string_ops.cpp",
"torch/csrc/jit/register_special_ops.cpp",
"torch/csrc/jit/register_distributed_ops.cpp",

@@ -38,7 +38,9 @@ void Function::append_operator(const std::string& name,
auto opname = code_->op_names_.back();
// Add "_" prefix to work around the double registration both of jit/generated
// and here. TODO: remove it when we have separate build for lite interpreter.
opname.name = "_" + opname.name;
if (opname.name != "aten::Int") {
opname.name = "_" + opname.name;
}
auto op = c10::Dispatcher::singleton().findSchema(opname);
TORCH_CHECK(op.has_value(), opname.name, ".", opname.overload_name, " cannot be found.");
code_->operators_.emplace_back(op);

@@ -284,12 +284,6 @@ static auto registry = torch::RegisterOperators().op(
#endif
return at::flatten(self, start_dim, end_dim);
})
).op(
"_aten::Int",
torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId,
[](at::Tensor a) -> int64_t {
return a.item<int64_t>();
})
).op(
"_prim::NumToTensor",
torch::RegisterOperators::options().catchAllKernel(

@@ -364,46 +364,6 @@ RegisterOperators reg(
return 0;
},
aliasAnalysisFromSchema()),
Operator(
"aten::Int(Tensor a) -> int",
[](Stack& stack) {
at::Tensor a;
pop(stack, a);
push(stack, a.item<int64_t>());
return 0;
},
aliasAnalysisFromSchema()),
Operator(
"aten::Int(bool a) -> int",
[](Stack& stack) {
bool b;
pop(stack, b);
push(stack, (int)b);
return 0;
},
aliasAnalysisFromSchema()),
Operator(
"aten::Int(float a) -> int",
[](Stack& stack) {
double d;
pop(stack, d);
push(stack, (int64_t)d);
return 0;
},
aliasAnalysisFromSchema()),
Operator(
"aten::Int(Scalar a) -> int",
[](Stack& stack) {
IValue scalar;
pop(stack, scalar);
if (scalar.isInt()) {
push(stack, std::move(scalar));
} else {
push(stack, static_cast<int64_t>(scalar.toDouble()));
}
return 0;
},
aliasAnalysisFromSchema()),
Operator(
"aten::Float(Tensor a) -> float",
[](Stack& stack) {

@@ -0,0 +1,35 @@
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/ATen.h>
#include <ATen/core/stack.h>
using Stack = std::vector<c10::IValue>;
using torch::jit::peek;
using torch::jit::drop;
using torch::jit::pack;
using torch::jit::push;
using torch::jit::pop;
using at::Tensor;
using at::Scalar;
using c10::IValue;
static auto registry_prim = torch::RegisterOperators().op("aten::Int.Tensor(Tensor a) -> int",
torch::RegisterOperators::options().catchAllKernel(
[](at::Tensor a) -> int64_t {
return a.item<int64_t>();
}).aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)
).op("aten::Int.bool(bool a) -> int",
torch::RegisterOperators::options().catchAllKernel(
[](bool b) -> int64_t {
return static_cast<int64_t>(b);
}).aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)
).op("aten::Int.float(float a) -> int",
torch::RegisterOperators::options().catchAllKernel(
[](double d) -> int64_t {
return static_cast<int64_t>(d);
}).aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)
).op("aten::Int.Scalar(Scalar a) -> int",
torch::RegisterOperators::options().catchAllKernel(
[](Scalar scalar) -> int64_t {
return scalar.toInt();
}).aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)
);