Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[PyTorchEdge] backport v8 to v7 to support promoted ops as instruction (#71662)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71662

Backport v8 to v7 to support promoted ops as instructions. Adds a flag so that promoted ops are exported as instructions for v8 and as operators for v7 and below.

Test Plan:
```
buck test caffe2/test/cpp/jit:jit -- LiteInterpreterTest.BackPortByteCodeModelAllVersions
Started reporting to test run: https://www.internalfb.com/intern/testinfra/testrun/5629499620570927
✓ ListingSuccess: caffe2/test/cpp/jit:jit : 461 tests discovered (15.693)
✓ Pass: caffe2/test/cpp/jit:jit - LiteInterpreterTest.BackPortByteCodeModelAllVersions (2.712)
Summary
  Pass: 1
  ListingSuccess: 1
If you need help understanding your runs, please follow the wiki: https://fburl.com/posting_in_tpx_users
Finished test run: https://www.internalfb.com/intern/testinfra/testrun/5629499620570927
```

```
buck run mode/opt //caffe2/torch/fb/mobile/upgrader_codegen:upgrader_codegen
buck test mode/opt //caffe2/test:upgrader_codegen -- mobile.test_upgrader_codegen.TestLiteScriptModule
Parsing buck files: finished in 0.8 sec
Downloaded 0/2 artifacts, 0.00 bytes, 100.0% cache miss (for updated rules)
Building: finished in 01:39.4 min (100%) 11031/11031 jobs, 2/11031 updated
Total time: 01:40.2 min
More details at https://www.internalfb.com/intern/buck/build/a8b0e417-019c-44ba-be6b-23379411a965
BUILD SUCCEEDED
Tpx test run coordinator for Facebook. See https://fburl.com/tpx for details.
Running with tpx session id: 44fbfa66-cce8-4277-82ac-f89d79558581
Trace available for this run at /tmp/tpx-20220202-160956.915412/trace.log
RemoteExecution session id: reSessionID-44fbfa66-cce8-4277-82ac-f89d79558581-tpx
Started reporting to test run: https://www.internalfb.com/intern/testinfra/testrun/281475200877601
✓ ListingSuccess: caffe2/test:upgrader_codegen : 1 tests discovered (1.249)
✓ Pass: caffe2/test:upgrader_codegen - test_generate_bytecode (mobile.test_upgrader_codegen.TestLiteScriptModule) (1.365)
Summary
  Pass: 1
  ListingSuccess: 1
If you need help understanding your runs, please follow the wiki: https://fburl.com/posting_in_tpx_users
Finished test run: https://www.internalfb.com/intern/testinfra/testrun/281475200877601
```

Reviewed By: iseeyuan

Differential Revision: D33719098

fbshipit-source-id: e2d2b23d298f98e4d4fcdfc344f7b8c6f92cff26
(cherry picked from commit 81b956c23abc19489b69eee986721252474d00dc)
Committed by: PyTorch MergeBot
Parent: d2d982c739
Commit: a482aeb0ce
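The backport path exercised by this change is driven through the `_backport_for_mobile` entry point. Below is a minimal sketch of stepping an in-memory v8 model down to v7; the header paths and stream setup are assumptions based on the mobile serialization code of this era, not part of the diff:

```cpp
#include <sstream>

#include <torch/csrc/jit/mobile/backport.h>
#include <torch/csrc/jit/mobile/model_compatibility.h>

// Sketch: rewrite a model saved at bytecode v8 as bytecode v7, so that
// promoted ops go back to being emitted as operators.
bool backportToV7(
    std::stringstream& input_model_stream,
    std::stringstream& output_model_stream) {
  // Check which bytecode version the input was produced with.
  auto from_version =
      torch::jit::_get_model_bytecode_version(input_model_stream);
  if (from_version <= 7) {
    return false; // already at or below the target version
  }
  // BackportManager applies one registered step per version gap,
  // e.g. backport_v8_to_v7, then backport_v7_to_v6 if asked to go lower.
  return torch::jit::_backport_for_mobile(
      input_model_stream, output_model_stream, /*to_version=*/7);
}
```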
@@ -110,22 +110,28 @@ constexpr uint64_t kMinProducedFileFormatVersion = 0x3L;
 // 0x2L: (Comment missing)
 // 0x3L: (Comment missing)
 // 0x4L: (update) Added schema to function tuple. Forward-compatible change.
-// 0x5L: (update) Update bytecode is sharing constant tensor files from torchscript, and only serialize
-// extra tensors that are not in the torchscript constant table. Also update tensor storage schema adapting
-// to the unify format, the root key of tensor storage is updated from {index} to
-// {the_pointer_value_the_tensor.storage}, for example: `140245072983168.storage`
-// Forward-compatibility change.
-// 0x6L: Implicit opereator versioning using number of specified argument.
-// Refer to the summary of https://github.com/pytorch/pytorch/pull/56845
-// for details.
-// 0x7L: Enable support for operators with default arguments plus out arguments.
-constexpr uint64_t kProducedBytecodeVersion = 0x7L;
+// 0x5L: (update) Update bytecode is sharing constant tensor files from
+// torchscript, and only serialize extra tensors that are not in the
+// torchscript constant table. Also update tensor storage schema adapting to
+// the unify format, the root key of tensor storage is updated from {index} to
+// {the_pointer_value_the_tensor.storage}, for example:
+// `140245072983168.storage` Forward-compatibility change. 0x6L: Implicit
+// opereator versioning using number of specified argument. Refer to the
+// summary of https://github.com/pytorch/pytorch/pull/56845 for details. 0x7L:
+// Enable support for operators with default arguments plus out arguments.
+// 0x8L: Emit promoted operators as instructions
+constexpr uint64_t kProducedBytecodeVersion = 0x8L;

 // static_assert(
 //     kProducedBytecodeVersion >= kProducedFileFormatVersion,
 //     "kProducedBytecodeVersion must be higher or equal to
 //     kProducedFileFormatVersion.");

 // Introduce kMinSupportedBytecodeVersion and kMaxSupportedBytecodeVersion
 // for limited backward/forward compatibility support of bytecode. If
-// kMinSupportedBytecodeVersion <= model_version <= kMaxSupportedBytecodeVersion (in loader),
-// we should support this model_version. For example, we provide a wrapper to
-// handle an updated operator.
+// kMinSupportedBytecodeVersion <= model_version <= kMaxSupportedBytecodeVersion
+// (in loader), we should support this model_version. For example, we provide a
+// wrapper to handle an updated operator.
 constexpr uint64_t kMinSupportedBytecodeVersion = 0x3L;
 constexpr uint64_t kMaxSupportedBytecodeVersion = 0x8L;
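The comment block above is the loader-side contract: a model loads only if its bytecode version falls inside the supported window. As a hypothetical illustration (the helper `isSupportedBytecodeVersion` is not a PyTorch API), the check reduces to:

```cpp
#include <cstdint>

// Illustrative only: mirrors the rule described in the comment above.
// A loader accepts a model when its bytecode version lies in
// [kMinSupportedBytecodeVersion, kMaxSupportedBytecodeVersion].
constexpr uint64_t kMinSupportedBytecodeVersion = 0x3L;
constexpr uint64_t kMaxSupportedBytecodeVersion = 0x8L;

bool isSupportedBytecodeVersion(uint64_t model_version) {
  return model_version >= kMinSupportedBytecodeVersion &&
      model_version <= kMaxSupportedBytecodeVersion;
}
```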
@@ -571,19 +571,34 @@ namespace {

 void compareModelOutput(
     c10::ArrayRef<IValue> actual_result_list,
-    const std::vector<Tensor>& expect_result_list) {
+    const std::vector<IValue>& expect_result_list) {
   AT_ASSERT(actual_result_list.size() == expect_result_list.size());
-  AT_ASSERT(actual_result_list[0].toTensor().equal(expect_result_list[0]));
-  AT_ASSERT(
-      actual_result_list[1].toTensor().dim() == expect_result_list[1].dim());
-  AT_ASSERT(actual_result_list[2].toTensor().equal(expect_result_list[2]));
-  AT_ASSERT(actual_result_list[3].toTensor().equal(expect_result_list[3]));
+  AT_ASSERT(
+      actual_result_list[0].toTensor().equal(expect_result_list[0].toTensor()));
+  AT_ASSERT(
+      actual_result_list[1].toTensor().dim() ==
+      expect_result_list[1].toTensor().dim());
+  AT_ASSERT(
+      actual_result_list[2].toTensor().equal(expect_result_list[2].toTensor()));
+  AT_ASSERT(
+      actual_result_list[3].toTensor().equal(expect_result_list[3].toTensor()));
+  ASSERT_EQ(
+      actual_result_list[4].toStringRef(), expect_result_list[4].toStringRef());
+  ASSERT_EQ(actual_result_list[5].toBool(), expect_result_list[5].toBool());
+  ASSERT_EQ(actual_result_list[6].toBool(), expect_result_list[6].toBool());
+  ASSERT_EQ(actual_result_list[7].toBool(), expect_result_list[7].toBool());
+  AT_ASSERT(
+      actual_result_list[8].toTensor().equal(expect_result_list[8].toTensor()));
+  ASSERT_EQ(
+      actual_result_list[9].toStringRef(), expect_result_list[9].toStringRef());
+  ASSERT_EQ(actual_result_list[10].toInt(), expect_result_list[10].toInt());
+  ASSERT_EQ(actual_result_list[11].toBool(), expect_result_list[11].toBool());
 }

 void runAndCheckTorchScriptModel(
     std::stringstream& input_model_stream,
     const std::vector<IValue>& input_data,
-    const std::vector<Tensor>& expect_result_list,
+    const std::vector<IValue>& expect_result_list,
     const int64_t expect_version) {
   auto actual_version = _get_model_bytecode_version(input_model_stream);
   AT_ASSERT(actual_version == expect_version);
@@ -600,7 +615,7 @@ void runAndCheckTorchScriptModel(
 void runAndCheckBytecodeModel(
     std::stringstream& input_model_stream,
     const std::vector<IValue>& input_data,
-    const std::vector<Tensor>& expect_result_list,
+    const std::vector<IValue>& expect_result_list,
     const int64_t expect_version) {
   auto actual_version = _get_model_bytecode_version(input_model_stream);
   AT_ASSERT(actual_version == expect_version);
@@ -618,7 +633,7 @@ void runAndCheckBytecodeModel(
 void backportAllVersionCheck(
     std::stringstream& test_model_file_stream,
     std::vector<IValue>& input_data,
-    std::vector<Tensor>& expect_result_list,
+    std::vector<IValue>& expect_result_list,
     const int64_t expect_from_version) {
   auto from_version = _get_model_bytecode_version(test_model_file_stream);
   AT_ASSERT(from_version == expect_from_version);
@@ -668,6 +683,9 @@ TEST(LiteInterpreterTest, BackPortByteCodeModelAllVersions) {
   // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
   module.register_parameter("bias", torch::ones({20}), false);
   module.define(R"(
+    def fn(self, x:float=1.0):
+      return x
+
     def forward(self, input):
       x1 = torch.zeros(2, 2)
       x2 = torch.empty_like(torch.empty(2, 2))
@@ -677,8 +695,22 @@ TEST(LiteInterpreterTest, BackPortByteCodeModelAllVersions) {
       x = 2 * torch.ones(1)
       h = torch.ones(1)
       torch.add(x, h, out=x)
-      return (x1, x2, x3, x)
-  )");
+      device = torch.ones(1, 1).cpu().device.type
+      is_cuda = x1.is_cuda
+      bool_val = True
+      check_is = [] is None
+      check_is_not = [1] is not None
+      check_not = not bool_val
+      num_to_tensor = torch.tensor([self.fn()])
+      d = {"a": "abc"}
+      check_dict_index = d["a"]
+      check_dim = x1.dim()
+      return (
+        x1, x2, x3, x, device, is_cuda, check_is,
+        check_is_not, num_to_tensor, check_dict_index,
+        check_dim, check_not
+      )
+  )");

   torch::jit::Module module_freeze = freeze(module);
@@ -686,12 +718,21 @@ TEST(LiteInterpreterTest, BackPortByteCodeModelAllVersions) {
   module_freeze._save_for_mobile(input_model_stream);
   std::vector<IValue> input_data =
       std::vector<IValue>({torch::ones({1, 1, 28, 28})});
-  std::vector<Tensor> expect_result_list;
+  std::vector<IValue> expect_result_list;
   expect_result_list.emplace_back(at::ones({2, 2}, ScalarType::Float) * 0);
   expect_result_list.emplace_back(at::ones({2, 2}, ScalarType::Float));
   expect_result_list.emplace_back(
       at::ones({1, 20, 24, 24}, ScalarType::Float) * 26);
   expect_result_list.emplace_back(3 * at::ones({1}));
+  // "cpu", False, False, True, tensor(1), "abc", 2, False
+  expect_result_list.emplace_back(c10::IValue("cpu"));
+  expect_result_list.emplace_back(c10::IValue(false));
+  expect_result_list.emplace_back(c10::IValue(false));
+  expect_result_list.emplace_back(c10::IValue(true));
+  expect_result_list.emplace_back(c10::IValue(at::ones({1})));
+  expect_result_list.emplace_back(c10::IValue("abc"));
+  expect_result_list.emplace_back(c10::IValue(2));
+  expect_result_list.emplace_back(c10::IValue(false));

   backportAllVersionCheck(
       input_model_stream,
@@ -151,7 +151,7 @@ class TestOptimizer(TestCase):
         bn_scripted_module = torch.jit.script(bn_test_module)
         bn_scripted_module.eval()

-        self.assertEqual(len(torch.jit.export_opnames(bn_scripted_module)), 14)
+        self.assertEqual(len(torch.jit.export_opnames(bn_scripted_module)), 11)
         FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
             .run(str(get_forward(bn_scripted_module._c).graph))

@@ -252,7 +252,7 @@ class TestOptimizer(TestCase):
         bn_no_forward_scripted_module = torch.jit.script(bn_test_no_forward_module)
         bn_no_forward_scripted_module.eval()

-        self.assertEqual(len(torch.jit.export_opnames(bn_no_forward_scripted_module)), 14)
+        self.assertEqual(len(torch.jit.export_opnames(bn_no_forward_scripted_module)), 11)
         FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
             .run(bn_no_forward_scripted_module.foo.graph)
@@ -27,6 +27,7 @@ constexpr int64_t kBytecodeVersionV4 = 0x4L;
 constexpr int64_t kBytecodeVersionV5 = 0x5L;
 constexpr int64_t kBytecodeVersionV6 = 0x6L;
 constexpr int64_t kBytecodeVersionV7 = 0x7L;
+constexpr int64_t kBytecodeVersionV8 = 0x8L;
 } // namespace

 /********************** Utility Functions **********************/
@@ -434,7 +435,8 @@ std::stringstream backport_v6_to_v5(std::stringstream& input_model_stream) {
   {
     BytecodeEmitModeGuard argNumGuard(
         true /*emit_default_input_instructions*/,
-        false /*enable_defaults_args_with_out_args*/);
+        false /*enable_defaults_args_with_out_args*/,
+        false /*enable_emit_promoted_ops*/);
     torch_script._save_for_mobile(
         intermediate_model_stream, extra_files, hasBytecodeDebug);
   }
@@ -501,7 +503,8 @@ std::stringstream backport_v7_to_v6(std::stringstream& input_model_stream) {
   {
     BytecodeEmitModeGuard argNumGuard(
         false /*emit_default_input_instructions*/,
-        false /*enable_defaults_args_with_out_args*/);
+        false /*enable_defaults_args_with_out_args*/,
+        false /*enable_emit_promoted_ops*/);
     torch_script._save_for_mobile(
         intermediate_model_stream, extra_files, hasBytecodeDebug);
   }
@@ -512,6 +515,39 @@ std::stringstream backport_v7_to_v6(std::stringstream& input_model_stream) {
   return output_model_stream;
 }

+std::stringstream backport_v8_to_v7(std::stringstream& input_model_stream) {
+  std::shared_ptr<IStreamAdapter> rai =
+      std::make_shared<IStreamAdapter>(&input_model_stream);
+  auto reader = std::make_shared<PyTorchStreamReader>(rai);
+  // extra_files are kept
+  auto records = reader->getAllRecords();
+  bool hasBytecodeDebug = reader->hasRecord("mobile_debug_handles.pkl");
+  ExtraFilesMap extra_files;
+  for (const auto& record : records) {
+    std::size_t found = record.find_last_of("/\\");
+    auto path = record.substr(0, found);
+    if ("extra" == path) {
+      extra_files.emplace(record.substr(found + 1), "");
+    }
+  }
+  Module torch_script = torch::jit::load(rai, c10::nullopt, extra_files);
+  std::stringstream intermediate_model_stream;
+  {
+    BytecodeEmitModeGuard argNumGuard(
+        false /*emit_default_input_instructions*/,
+        true /*enable_defaults_args_with_out_args*/,
+        false /*enable_emit_promoted_ops*/);
+    torch_script._save_for_mobile(
+        intermediate_model_stream, extra_files, hasBytecodeDebug);
+  }
+
+  // Update the bytecode version (from 8 to 7)
+  std::stringstream output_model_stream =
+      update_bytecode_version(intermediate_model_stream, kBytecodeVersionV7);
+
+  return output_model_stream;
+}
+
 } // namespace

 /********************** BackportManager **********************/

@@ -528,6 +564,7 @@ BackportManager::BackportManager() {
   registerBytecodeBackportFunction(kBytecodeVersionV5, backport_v5_to_v4);
   registerBytecodeBackportFunction(kBytecodeVersionV6, backport_v6_to_v5);
   registerBytecodeBackportFunction(kBytecodeVersionV7, backport_v7_to_v6);
+  registerBytecodeBackportFunction(kBytecodeVersionV8, backport_v8_to_v7);
 }

 std::unordered_map<
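Each registered backport function steps down exactly one version, so reaching an older target is a chain of single-step rewrites (v8 -> v7 -> v6 -> ...). A rough sketch of that dispatch, with illustrative names standing in for BackportManager's internals:

```cpp
#include <cstdint>
#include <functional>
#include <sstream>
#include <unordered_map>

// Illustrative: one single-step backport function per source version,
// keyed the same way the registrations above key them.
using BackportStep = std::function<std::stringstream(std::stringstream&)>;

std::stringstream backportTo(
    std::stringstream model,
    int64_t from_version,
    int64_t to_version,
    const std::unordered_map<int64_t, BackportStep>& steps) {
  // e.g. from 8 down to 6: apply backport_v8_to_v7, then backport_v7_to_v6.
  while (from_version > to_version) {
    model = steps.at(from_version)(model);
    --from_version;
  }
  return model;
}
```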
@@ -346,7 +346,7 @@ const std::vector<ByteCodeFunctionWithOperator>& getUpgraderBytecodeList() {
       Instruction{OpCode::STOREN, 1, 7},
       Instruction{OpCode::LOAD, 3, 0},
       Instruction{OpCode::LOADC, 0, 0},
-      Instruction{OpCode::OP, 0, 0},
+      Instruction{OpCode::__IS__, 0, 0},
       Instruction{OpCode::JF, 10, 0},
       Instruction{OpCode::LOAD, 1, 0},
       Instruction{OpCode::LOAD, 2, 0},
@@ -355,17 +355,17 @@ const std::vector<ByteCodeFunctionWithOperator>& getUpgraderBytecodeList() {
       Instruction{OpCode::LOAD, 5, 0},
       Instruction{OpCode::LOAD, 6, 0},
       Instruction{OpCode::LOAD, 7, 0},
-      Instruction{OpCode::OP, 1, 0},
+      Instruction{OpCode::OP, 0, 0},
       Instruction{OpCode::JMP, 10, 0},
       Instruction{OpCode::LOAD, 1, 0},
       Instruction{OpCode::LOAD, 2, 0},
       Instruction{OpCode::LOAD, 3, 0},
-      Instruction{OpCode::OP, 2, 0},
+      Instruction{OpCode::OP, 1, 0},
       Instruction{OpCode::LOAD, 4, 0},
       Instruction{OpCode::LOAD, 5, 0},
       Instruction{OpCode::LOAD, 6, 0},
       Instruction{OpCode::LOAD, 7, 0},
-      Instruction{OpCode::OP, 1, 0},
+      Instruction{OpCode::OP, 0, 0},
       Instruction{OpCode::STORE, 8, 0},
       Instruction{OpCode::DROPR, 7, 0},
       Instruction{OpCode::DROPR, 6, 0},
@@ -385,7 +385,6 @@ const std::vector<ByteCodeFunctionWithOperator>& getUpgraderBytecodeList() {
           8
       ),
       std::vector<OperatorString>({
-          OperatorString({"aten::__is__", "", 2}),
           OperatorString({"aten::linspace", "", 7}),
           OperatorString({"prim::unchecked_cast", "", 1}),
       }), // operators list
@@ -397,20 +396,20 @@ const std::vector<ByteCodeFunctionWithOperator>& getUpgraderBytecodeList() {
       Instruction{OpCode::STOREN, 1, 4},
       Instruction{OpCode::LOAD, 3, 0},
       Instruction{OpCode::LOADC, 0, 0},
-      Instruction{OpCode::OP, 0, 0},
+      Instruction{OpCode::__IS__, 0, 0},
       Instruction{OpCode::JF, 7, 0},
       Instruction{OpCode::LOAD, 1, 0},
       Instruction{OpCode::LOAD, 2, 0},
       Instruction{OpCode::LOADC, 1, 0},
       Instruction{OpCode::LOAD, 4, 0},
-      Instruction{OpCode::OP, 1, 0},
+      Instruction{OpCode::OP, 0, 0},
       Instruction{OpCode::JMP, 7, 0},
       Instruction{OpCode::LOAD, 1, 0},
       Instruction{OpCode::LOAD, 2, 0},
       Instruction{OpCode::LOAD, 3, 0},
-      Instruction{OpCode::OP, 2, 0},
-      Instruction{OpCode::LOAD, 4, 0},
-      Instruction{OpCode::OP, 1, 0},
+      Instruction{OpCode::OP, 1, 0},
+      Instruction{OpCode::LOAD, 4, 0},
+      Instruction{OpCode::OP, 0, 0},
       Instruction{OpCode::STORE, 5, 0},
       Instruction{OpCode::DROPR, 4, 0},
       Instruction{OpCode::DROPR, 2, 0},
@@ -427,7 +426,6 @@ const std::vector<ByteCodeFunctionWithOperator>& getUpgraderBytecodeList() {
           5
       ),
       std::vector<OperatorString>({
-          OperatorString({"aten::__is__", "", 2}),
           OperatorString({"aten::linspace", "out", 4}),
           OperatorString({"prim::unchecked_cast", "", 1}),
       }), // operators list
@@ -439,7 +437,7 @@ const std::vector<ByteCodeFunctionWithOperator>& getUpgraderBytecodeList() {
       Instruction{OpCode::STOREN, 1, 8},
       Instruction{OpCode::LOAD, 3, 0},
       Instruction{OpCode::LOADC, 0, 0},
-      Instruction{OpCode::OP, 0, 0},
+      Instruction{OpCode::__IS__, 0, 0},
       Instruction{OpCode::JF, 11, 0},
       Instruction{OpCode::LOAD, 1, 0},
       Instruction{OpCode::LOAD, 2, 0},
@@ -449,18 +447,18 @@ const std::vector<ByteCodeFunctionWithOperator>& getUpgraderBytecodeList() {
       Instruction{OpCode::LOAD, 6, 0},
       Instruction{OpCode::LOAD, 7, 0},
       Instruction{OpCode::LOAD, 8, 0},
-      Instruction{OpCode::OP, 1, 0},
+      Instruction{OpCode::OP, 0, 0},
       Instruction{OpCode::JMP, 11, 0},
       Instruction{OpCode::LOAD, 1, 0},
       Instruction{OpCode::LOAD, 2, 0},
       Instruction{OpCode::LOAD, 3, 0},
-      Instruction{OpCode::OP, 2, 0},
+      Instruction{OpCode::OP, 1, 0},
       Instruction{OpCode::LOAD, 4, 0},
       Instruction{OpCode::LOAD, 5, 0},
       Instruction{OpCode::LOAD, 6, 0},
       Instruction{OpCode::LOAD, 7, 0},
       Instruction{OpCode::LOAD, 8, 0},
-      Instruction{OpCode::OP, 1, 0},
+      Instruction{OpCode::OP, 0, 0},
       Instruction{OpCode::STORE, 9, 0},
       Instruction{OpCode::DROPR, 8, 0},
       Instruction{OpCode::DROPR, 7, 0},
@@ -481,7 +479,6 @@ const std::vector<ByteCodeFunctionWithOperator>& getUpgraderBytecodeList() {
           9
       ),
       std::vector<OperatorString>({
-          OperatorString({"aten::__is__", "", 2}),
           OperatorString({"aten::logspace", "", 8}),
           OperatorString({"prim::unchecked_cast", "", 1}),
       }), // operators list
@@ -493,22 +490,22 @@ const std::vector<ByteCodeFunctionWithOperator>& getUpgraderBytecodeList() {
       Instruction{OpCode::STOREN, 1, 5},
       Instruction{OpCode::LOAD, 3, 0},
       Instruction{OpCode::LOADC, 0, 0},
-      Instruction{OpCode::OP, 0, 0},
+      Instruction{OpCode::__IS__, 0, 0},
       Instruction{OpCode::JF, 8, 0},
       Instruction{OpCode::LOAD, 1, 0},
       Instruction{OpCode::LOAD, 2, 0},
       Instruction{OpCode::LOADC, 1, 0},
       Instruction{OpCode::LOAD, 4, 0},
       Instruction{OpCode::LOAD, 5, 0},
-      Instruction{OpCode::OP, 1, 0},
+      Instruction{OpCode::OP, 0, 0},
       Instruction{OpCode::JMP, 8, 0},
       Instruction{OpCode::LOAD, 1, 0},
       Instruction{OpCode::LOAD, 2, 0},
       Instruction{OpCode::LOAD, 3, 0},
-      Instruction{OpCode::OP, 2, 0},
+      Instruction{OpCode::OP, 1, 0},
       Instruction{OpCode::LOAD, 4, 0},
       Instruction{OpCode::LOAD, 5, 0},
-      Instruction{OpCode::OP, 1, 0},
+      Instruction{OpCode::OP, 0, 0},
       Instruction{OpCode::STORE, 6, 0},
       Instruction{OpCode::DROPR, 5, 0},
       Instruction{OpCode::DROPR, 4, 0},
@@ -526,7 +523,6 @@ const std::vector<ByteCodeFunctionWithOperator>& getUpgraderBytecodeList() {
           6
       ),
       std::vector<OperatorString>({
-          OperatorString({"aten::__is__", "", 2}),
           OperatorString({"aten::logspace", "out", 5}),
           OperatorString({"prim::unchecked_cast", "", 1}),
       }), // operators list
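The pattern across these upgrader hunks is the v8 change in miniature: `aten::__is__` leaves each upgrader's operator table and becomes the dedicated `__IS__` instruction, so every surviving `OP` index shifts down by one. A tiny self-contained demonstration of that index shift (illustrative tables, not the real interpreter):

```cpp
#include <cassert>
#include <string>
#include <vector>

int main() {
  // Operator tables before and after the v8 change, as in the hunks above.
  std::vector<std::string> v7_ops = {
      "aten::__is__", "aten::linspace", "prim::unchecked_cast"};
  std::vector<std::string> v8_ops = {
      "aten::linspace", "prim::unchecked_cast"};

  // v7 bytecode reached linspace via OP index 1; in v8 the same call is
  // OP index 0, because __is__ is now the __IS__ instruction instead of
  // an operator-table entry.
  assert(v7_ops[1] == v8_ops[0]);
  assert(v7_ops[2] == v8_ops[1]);
  return 0;
}
```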
@@ -1059,12 +1059,14 @@ MobileCode::MobileCode(
     std::string function_name,
     bool emit_default_input_instructions,
     bool support_default_args_before_out,
+    bool emit_promoted_ops,
     size_t remaining_bailout_depth)
     : Code(new interpreter::MobileCodeImpl(
           graph,
           std::move(function_name),
           emit_default_input_instructions,
           support_default_args_before_out,
+          emit_promoted_ops,
           remaining_bailout_depth)) {}

 MobileCode::~MobileCode() = default;
@@ -88,6 +88,7 @@ struct TORCH_API MobileCode : Code {
       std::string function_name,
       bool emit_default_input_instructions = true,
       bool support_default_args_before_out = true,
+      bool emit_promoted_ops = true,
       size_t remaining_bailout_depth = 0);
   ~MobileCode();
 };
@@ -869,10 +869,12 @@ struct MobileCodeImpl : CodeImpl {
       std::string function_name,
       bool emit_default_input_instructions,
       bool support_default_args_before_out,
+      bool emit_promoted_ops,
       size_t remaining_bailout_depth)
       : CodeImpl(graph, function_name, remaining_bailout_depth, false),
         emit_default_input_instructions_(emit_default_input_instructions),
-        support_default_args_before_out_(support_default_args_before_out) {
+        support_default_args_before_out_(support_default_args_before_out),
+        emit_promoted_ops_(emit_promoted_ops) {
     // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
     run();
   }

@@ -965,7 +967,6 @@ struct MobileCodeImpl : CodeImpl {
       int64_t X = 0,
       uint64_t N = 0,
       bool emit_inputs = true) override {
-    bool emit_promoted_ops_ = false;
     if (emit_promoted_ops_) {
       CodeImpl::emitOperatorOrInstruction(node, op, X, N, emit_inputs);
     } else {

@@ -977,6 +978,8 @@ struct MobileCodeImpl : CodeImpl {
   bool emit_default_input_instructions_;
   // To support forward compatibility for bytecode version bump from v6 to v7
   bool support_default_args_before_out_;
+  // To support forward compatibility for bytecode version bump from v7 to v8
+  bool emit_promoted_ops_;
 };

 } // namespace interpreter
@@ -201,6 +201,9 @@ struct TORCH_API BytecodeEmitMode {

   static bool is_default_args_before_out_args_enabled();
   static void set_default_args_before_out_args_enabled(bool enabled);
+
+  static bool is_emit_promoted_ops_enabled();
+  static void set_default_emit_promoted_ops_enabled(bool enabled);
 };

 // RAII guard to switch the way JIT emits the bytecode for inputs.
@@ -216,24 +219,32 @@ struct TORCH_API BytecodeEmitMode {
 struct TORCH_API BytecodeEmitModeGuard {
   BytecodeEmitModeGuard(
       bool enable_default_value_for_unspecified_arg,
-      bool enable_default_args_before_out_args)
+      bool enable_default_args_before_out_args,
+      bool enable_emit_promoted_ops)
       : prev_default_value_for_unspecified_arg_mode(
             BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled()),
         prev_default_args_before_out_args(
-            BytecodeEmitMode::is_default_args_before_out_args_enabled()) {
+            BytecodeEmitMode::is_default_args_before_out_args_enabled()),
+        prev_default_emit_promoted_ops(
+            BytecodeEmitMode::is_emit_promoted_ops_enabled()) {
     BytecodeEmitMode::set_default_value_for_unspecified_arg_enabled(
         enable_default_value_for_unspecified_arg);
     BytecodeEmitMode::set_default_args_before_out_args_enabled(
         enable_default_args_before_out_args);
+    BytecodeEmitMode::set_default_emit_promoted_ops_enabled(
+        enable_emit_promoted_ops);
   }
   ~BytecodeEmitModeGuard() {
     BytecodeEmitMode::set_default_value_for_unspecified_arg_enabled(
         prev_default_value_for_unspecified_arg_mode);
     BytecodeEmitMode::set_default_args_before_out_args_enabled(
         prev_default_args_before_out_args);
+    BytecodeEmitMode::set_default_emit_promoted_ops_enabled(
+        prev_default_emit_promoted_ops);
   }
   bool prev_default_value_for_unspecified_arg_mode;
   bool prev_default_args_before_out_args;
+  bool prev_default_emit_promoted_ops;
 };

 TORCH_API IValue to_tuple(std::vector<IValue> ivalues);
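This guard is what the backport functions above wrap around `_save_for_mobile` to re-serialize a module under older emission rules. A minimal sketch of saving with v7 semantics, mirroring the guard usage in `backport_v8_to_v7` (include paths and the surrounding setup are assumptions):

```cpp
#include <sstream>

#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/serialization/export.h>

// Sketch: re-save `module` with v7 emission semantics, i.e. default args
// before out args still supported, but promoted ops emitted as operators.
void saveWithV7Semantics(torch::jit::Module& module, std::stringstream& out) {
  torch::jit::BytecodeEmitModeGuard guard(
      /*enable_default_value_for_unspecified_arg=*/false,
      /*enable_default_args_before_out_args=*/true,
      /*enable_emit_promoted_ops=*/false);
  module._save_for_mobile(out); // previous flags restored on scope exit
}
```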
@@ -142,7 +142,8 @@ mobile::Code compileGraphToMobileCode(
       graph,
       name,
       compilation_options.enable_default_value_for_unspecified_arg,
-      compilation_options.enable_default_args_before_out_args);
+      compilation_options.enable_default_args_before_out_args,
+      compilation_options.enable_emit_promoted_ops);

   mobile::Code mobile_code;
@@ -20,6 +20,7 @@ struct TORCH_API CompilationOptions {
   bool incl_interface_call = false;
   bool enable_default_value_for_unspecified_arg = false;
   bool enable_default_args_before_out_args = true;
+  bool enable_emit_promoted_ops = true;
   int model_version = caffe2::serialize::kProducedBytecodeVersion;
 };
@@ -44,6 +44,8 @@ CompilationOptions getOptionsFromGlobal() {
       BytecodeEmitMode::is_default_args_before_out_args_enabled();
   compilation_options.enable_default_value_for_unspecified_arg =
       BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled();
+  compilation_options.enable_emit_promoted_ops =
+      BytecodeEmitMode::is_emit_promoted_ops_enabled();
   compilation_options.incl_interface_call = getMobileInterfaceCallExport();
   compilation_options.model_version =
       caffe2::serialize::kProducedBytecodeVersion;

@@ -864,5 +866,14 @@ void BytecodeEmitMode::set_default_args_before_out_args_enabled(bool enabled) {
   emitDefautlArgsWithOutArgs = enabled;
 }

+thread_local bool emitDefaultEmitPromotedOps =
+    caffe2::serialize::kProducedBytecodeVersion <= 7 ? false : true;
+bool BytecodeEmitMode::is_emit_promoted_ops_enabled() {
+  return emitDefaultEmitPromotedOps;
+}
+void BytecodeEmitMode::set_default_emit_promoted_ops_enabled(bool enabled) {
+  emitDefaultEmitPromotedOps = enabled;
+}
+
 } // namespace jit
 } // namespace torch