mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Move prim ops from JIT registration to C10 (#30612)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/30612 The first version to move prim ops to c10 registration. After the reviewers are fine with the initial changes, more operators will be moved in the same style. Test Plan: Imported from OSS Differential Revision: D19237648 Pulled By: iseeyuan fbshipit-source-id: c5a519604efffb80564a556536f17d829f71d9f9
This commit is contained in:
committed by
Facebook Github Bot
parent
5579611544
commit
f362cd510d
@ -427,6 +427,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
|
||||
${TORCH_SRC_DIR}/csrc/jit/print_handler.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/fuser/interface.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/register_prim_ops.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/register_prim_ops_c10.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/register_string_ops.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/register_special_ops.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/scope.cpp
|
||||
|
@ -48,6 +48,7 @@ white_list = [
|
||||
('upsample_nearest3d_backward.grad_input', datetime.date(9999, 1, 1)),
|
||||
('upsample_nearest3d_backward', datetime.date(9999, 1, 1)),
|
||||
('_test_optional_float', datetime.date(9999, 1, 1)),
|
||||
('aten::Int', datetime.date(2020, 1, 30)),
|
||||
]
|
||||
|
||||
|
||||
|
@ -127,5 +127,31 @@ void testLiteInterpreterPrimOverload() {
|
||||
auto output = bc.run_method("forward", inputs);
|
||||
AT_ASSERT(output.toIntList()[2] == 3);
|
||||
}
|
||||
|
||||
void testLiteInterpreterPrim() {
|
||||
script::Module m("m");
|
||||
m.define(R"JIT(
|
||||
def forward(self, x):
|
||||
return int(x)
|
||||
)JIT");
|
||||
|
||||
std::vector<IValue> inputs;
|
||||
auto minput = 3.5 * torch::ones({});
|
||||
inputs.emplace_back(minput);
|
||||
auto ref = m.run_method("forward", minput);
|
||||
|
||||
std::stringstream ss;
|
||||
m._save_for_mobile(ss);
|
||||
mobile::Module bc = _load_for_mobile(ss);
|
||||
IValue res;
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
auto bcinputs = inputs;
|
||||
res = bc.run_method("forward", bcinputs);
|
||||
}
|
||||
|
||||
auto resi = res.toInt();
|
||||
auto refi = ref.toInt();
|
||||
AT_ASSERT(resi == refi);
|
||||
}
|
||||
} // namespace torch
|
||||
} // namespace jit
|
||||
|
@ -77,7 +77,8 @@ namespace jit {
|
||||
_(LiteInterpreterPrimOverload) \
|
||||
_(CommonAncestor) \
|
||||
_(AutogradSymbols) \
|
||||
_(MobileTypeParser)
|
||||
_(MobileTypeParser) \
|
||||
_(LiteInterpreterPrim)
|
||||
|
||||
#define TH_FORALL_TESTS_CUDA(_) \
|
||||
_(ArgumentSpec) \
|
||||
|
@ -142,6 +142,7 @@ libtorch_sources = [
|
||||
"torch/csrc/jit/passes/utils/memory_dag.cpp",
|
||||
"torch/csrc/jit/print_handler.cpp",
|
||||
"torch/csrc/jit/register_prim_ops.cpp",
|
||||
"torch/csrc/jit/register_prim_ops_c10.cpp",
|
||||
"torch/csrc/jit/register_string_ops.cpp",
|
||||
"torch/csrc/jit/register_special_ops.cpp",
|
||||
"torch/csrc/jit/register_distributed_ops.cpp",
|
||||
|
@ -38,7 +38,9 @@ void Function::append_operator(const std::string& name,
|
||||
auto opname = code_->op_names_.back();
|
||||
// Add "_" prefix to work around the double registration both of jit/generated
|
||||
// and here. TODO: remove it when we have separate build for lite interpreter.
|
||||
opname.name = "_" + opname.name;
|
||||
if (opname.name != "aten::Int") {
|
||||
opname.name = "_" + opname.name;
|
||||
}
|
||||
auto op = c10::Dispatcher::singleton().findSchema(opname);
|
||||
TORCH_CHECK(op.has_value(), opname.name, ".", opname.overload_name, " cannot be found.");
|
||||
code_->operators_.emplace_back(op);
|
||||
|
@ -284,12 +284,6 @@ static auto registry = torch::RegisterOperators().op(
|
||||
#endif
|
||||
return at::flatten(self, start_dim, end_dim);
|
||||
})
|
||||
).op(
|
||||
"_aten::Int",
|
||||
torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId,
|
||||
[](at::Tensor a) -> int64_t {
|
||||
return a.item<int64_t>();
|
||||
})
|
||||
).op(
|
||||
"_prim::NumToTensor",
|
||||
torch::RegisterOperators::options().catchAllKernel(
|
||||
|
@ -364,46 +364,6 @@ RegisterOperators reg(
|
||||
return 0;
|
||||
},
|
||||
aliasAnalysisFromSchema()),
|
||||
Operator(
|
||||
"aten::Int(Tensor a) -> int",
|
||||
[](Stack& stack) {
|
||||
at::Tensor a;
|
||||
pop(stack, a);
|
||||
push(stack, a.item<int64_t>());
|
||||
return 0;
|
||||
},
|
||||
aliasAnalysisFromSchema()),
|
||||
Operator(
|
||||
"aten::Int(bool a) -> int",
|
||||
[](Stack& stack) {
|
||||
bool b;
|
||||
pop(stack, b);
|
||||
push(stack, (int)b);
|
||||
return 0;
|
||||
},
|
||||
aliasAnalysisFromSchema()),
|
||||
Operator(
|
||||
"aten::Int(float a) -> int",
|
||||
[](Stack& stack) {
|
||||
double d;
|
||||
pop(stack, d);
|
||||
push(stack, (int64_t)d);
|
||||
return 0;
|
||||
},
|
||||
aliasAnalysisFromSchema()),
|
||||
Operator(
|
||||
"aten::Int(Scalar a) -> int",
|
||||
[](Stack& stack) {
|
||||
IValue scalar;
|
||||
pop(stack, scalar);
|
||||
if (scalar.isInt()) {
|
||||
push(stack, std::move(scalar));
|
||||
} else {
|
||||
push(stack, static_cast<int64_t>(scalar.toDouble()));
|
||||
}
|
||||
return 0;
|
||||
},
|
||||
aliasAnalysisFromSchema()),
|
||||
Operator(
|
||||
"aten::Float(Tensor a) -> float",
|
||||
[](Stack& stack) {
|
||||
|
35
torch/csrc/jit/register_prim_ops_c10.cpp
Normal file
35
torch/csrc/jit/register_prim_ops_c10.cpp
Normal file
@ -0,0 +1,35 @@
|
||||
#include <ATen/core/op_registration/op_registration.h>
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/core/stack.h>
|
||||
|
||||
using Stack = std::vector<c10::IValue>;
|
||||
using torch::jit::peek;
|
||||
using torch::jit::drop;
|
||||
using torch::jit::pack;
|
||||
using torch::jit::push;
|
||||
using torch::jit::pop;
|
||||
using at::Tensor;
|
||||
using at::Scalar;
|
||||
using c10::IValue;
|
||||
|
||||
static auto registry_prim = torch::RegisterOperators().op("aten::Int.Tensor(Tensor a) -> int",
|
||||
torch::RegisterOperators::options().catchAllKernel(
|
||||
[](at::Tensor a) -> int64_t {
|
||||
return a.item<int64_t>();
|
||||
}).aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)
|
||||
).op("aten::Int.bool(bool a) -> int",
|
||||
torch::RegisterOperators::options().catchAllKernel(
|
||||
[](bool b) -> int64_t {
|
||||
return static_cast<int64_t>(b);
|
||||
}).aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)
|
||||
).op("aten::Int.float(float a) -> int",
|
||||
torch::RegisterOperators::options().catchAllKernel(
|
||||
[](double d) -> int64_t {
|
||||
return static_cast<int64_t>(d);
|
||||
}).aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)
|
||||
).op("aten::Int.Scalar(Scalar a) -> int",
|
||||
torch::RegisterOperators::options().catchAllKernel(
|
||||
[](Scalar scalar) -> int64_t {
|
||||
return scalar.toInt();
|
||||
}).aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)
|
||||
);
|
Reference in New Issue
Block a user