[PyTorchEdge] backport v8 to v7 to support promoted ops as instruction (#71662)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71662

backport v8 to v7 to support promoted ops as instruction

a flag to help export as instruction from v8 and export as operators for v7 and below

Test Plan:
```
buck test caffe2/test/cpp/jit:jit -- LiteInterpreterTest.BackPortByteCodeModelAllVersions

Started reporting to test run: https://www.internalfb.com/intern/testinfra/testrun/5629499620570927
    ✓ ListingSuccess: caffe2/test/cpp/jit:jit : 461 tests discovered (15.693)
    ✓ Pass: caffe2/test/cpp/jit:jit - LiteInterpreterTest.BackPortByteCodeModelAllVersions (2.712)
Summary
  Pass: 1
  ListingSuccess: 1
If you need help understanding your runs, please follow the wiki: https://fburl.com/posting_in_tpx_users
Finished test run: https://www.internalfb.com/intern/testinfra/testrun/5629499620570927
```

```
buck run mode/opt //caffe2/torch/fb/mobile/upgrader_codegen:upgrader_codegen

buck test mode/opt //caffe2/test:upgrader_codegen -- mobile.test_upgrader_codegen.TestLiteScriptModule
Parsing buck files: finished in 0.8 sec
Downloaded 0/2 artifacts, 0.00 bytes, 100.0% cache miss (for updated rules)
Building: finished in 01:39.4 min (100%) 11031/11031 jobs, 2/11031 updated
  Total time: 01:40.2 min
More details at https://www.internalfb.com/intern/buck/build/a8b0e417-019c-44ba-be6b-23379411a965
BUILD SUCCEEDED
Tpx test run coordinator for Facebook. See https://fburl.com/tpx for details.
Running with tpx session id: 44fbfa66-cce8-4277-82ac-f89d79558581
Trace available for this run at /tmp/tpx-20220202-160956.915412/trace.log
RemoteExecution session id: reSessionID-44fbfa66-cce8-4277-82ac-f89d79558581-tpx
Started reporting to test run: https://www.internalfb.com/intern/testinfra/testrun/281475200877601
    ✓ ListingSuccess: caffe2/test:upgrader_codegen : 1 tests discovered (1.249)
    ✓ Pass: caffe2/test:upgrader_codegen - test_generate_bytecode (mobile.test_upgrader_codegen.TestLiteScriptModule) (1.365)
Summary
  Pass: 1
  ListingSuccess: 1
If you need help understanding your runs, please follow the wiki: https://fburl.com/posting_in_tpx_users
Finished test run: https://www.internalfb.com/intern/testinfra/testrun/281475200877601
```

Reviewed By: iseeyuan

Differential Revision: D33719098

fbshipit-source-id: e2d2b23d298f98e4d4fcdfc344f7b8c6f92cff26
(cherry picked from commit 81b956c23abc19489b69eee986721252474d00dc)
This commit is contained in:
Pavithran Ramachandran
2022-02-14 19:42:44 -08:00
committed by PyTorch MergeBot
parent d2d982c739
commit a482aeb0ce
12 changed files with 163 additions and 53 deletions

View File

@ -110,22 +110,28 @@ constexpr uint64_t kMinProducedFileFormatVersion = 0x3L;
// 0x2L: (Comment missing)
// 0x3L: (Comment missing)
// 0x4L: (update) Added schema to function tuple. Forward-compatible change.
// 0x5L: (update) Update bytecode is sharing constant tensor files from torchscript, and only serialize
// extra tensors that are not in the torchscript constant table. Also update tensor storage schema adapting
// to the unify format, the root key of tensor storage is updated from {index} to
// {the_pointer_value_the_tensor.storage}, for example: `140245072983168.storage`
// Forward-compatibility change.
// 0x6L: Implicit opereator versioning using number of specified argument.
// Refer to the summary of https://github.com/pytorch/pytorch/pull/56845
// for details.
// 0x7L: Enable support for operators with default arguments plus out arguments.
constexpr uint64_t kProducedBytecodeVersion = 0x7L;
// 0x5L: (update) Update bytecode is sharing constant tensor files from
// torchscript, and only serialize extra tensors that are not in the
// torchscript constant table. Also update tensor storage schema adapting to
// the unify format, the root key of tensor storage is updated from {index} to
// {the_pointer_value_the_tensor.storage}, for example:
// `140245072983168.storage` Forward-compatibility change. 0x6L: Implicit
// opereator versioning using number of specified argument. Refer to the
// summary of https://github.com/pytorch/pytorch/pull/56845 for details. 0x7L:
// Enable support for operators with default arguments plus out arguments.
// 0x8L: Emit promoted operators as instructions
constexpr uint64_t kProducedBytecodeVersion = 0x8L;
// static_assert(
// kProducedBytecodeVersion >= kProducedFileFormatVersion,
// "kProducedBytecodeVersion must be higher or equal to
// kProducedFileFormatVersion.");
// Introduce kMinSupportedBytecodeVersion and kMaxSupportedBytecodeVersion
// for limited backward/forward compatibility support of bytecode. If
// kMinSupportedBytecodeVersion <= model_version <= kMaxSupportedBytecodeVersion (in loader),
// we should support this model_version. For example, we provide a wrapper to
// handle an updated operator.
// kMinSupportedBytecodeVersion <= model_version <= kMaxSupportedBytecodeVersion
// (in loader), we should support this model_version. For example, we provide a
// wrapper to handle an updated operator.
constexpr uint64_t kMinSupportedBytecodeVersion = 0x3L;
constexpr uint64_t kMaxSupportedBytecodeVersion = 0x8L;

View File

@ -571,19 +571,34 @@ namespace {
void compareModelOutput(
c10::ArrayRef<IValue> actual_result_list,
const std::vector<Tensor>& expect_result_list) {
const std::vector<IValue>& expect_result_list) {
AT_ASSERT(actual_result_list.size() == expect_result_list.size());
AT_ASSERT(actual_result_list[0].toTensor().equal(expect_result_list[0]));
AT_ASSERT(
actual_result_list[1].toTensor().dim() == expect_result_list[1].dim());
AT_ASSERT(actual_result_list[2].toTensor().equal(expect_result_list[2]));
AT_ASSERT(actual_result_list[3].toTensor().equal(expect_result_list[3]));
actual_result_list[0].toTensor().equal(expect_result_list[0].toTensor()));
AT_ASSERT(
actual_result_list[1].toTensor().dim() ==
expect_result_list[1].toTensor().dim());
AT_ASSERT(
actual_result_list[2].toTensor().equal(expect_result_list[2].toTensor()));
AT_ASSERT(
actual_result_list[3].toTensor().equal(expect_result_list[3].toTensor()));
ASSERT_EQ(
actual_result_list[4].toStringRef(), expect_result_list[4].toStringRef());
ASSERT_EQ(actual_result_list[5].toBool(), expect_result_list[5].toBool());
ASSERT_EQ(actual_result_list[6].toBool(), expect_result_list[6].toBool());
ASSERT_EQ(actual_result_list[7].toBool(), expect_result_list[7].toBool());
AT_ASSERT(
actual_result_list[8].toTensor().equal(expect_result_list[8].toTensor()));
ASSERT_EQ(
actual_result_list[9].toStringRef(), expect_result_list[9].toStringRef());
ASSERT_EQ(actual_result_list[10].toInt(), expect_result_list[10].toInt());
ASSERT_EQ(actual_result_list[11].toBool(), expect_result_list[11].toBool());
}
void runAndCheckTorchScriptModel(
std::stringstream& input_model_stream,
const std::vector<IValue>& input_data,
const std::vector<Tensor>& expect_result_list,
const std::vector<IValue>& expect_result_list,
const int64_t expect_version) {
auto actual_version = _get_model_bytecode_version(input_model_stream);
AT_ASSERT(actual_version == expect_version);
@ -600,7 +615,7 @@ void runAndCheckTorchScriptModel(
void runAndCheckBytecodeModel(
std::stringstream& input_model_stream,
const std::vector<IValue>& input_data,
const std::vector<Tensor>& expect_result_list,
const std::vector<IValue>& expect_result_list,
const int64_t expect_version) {
auto actual_version = _get_model_bytecode_version(input_model_stream);
AT_ASSERT(actual_version == expect_version);
@ -618,7 +633,7 @@ void runAndCheckBytecodeModel(
void backportAllVersionCheck(
std::stringstream& test_model_file_stream,
std::vector<IValue>& input_data,
std::vector<Tensor>& expect_result_list,
std::vector<IValue>& expect_result_list,
const int64_t expect_from_version) {
auto from_version = _get_model_bytecode_version(test_model_file_stream);
AT_ASSERT(from_version == expect_from_version);
@ -668,6 +683,9 @@ TEST(LiteInterpreterTest, BackPortByteCodeModelAllVersions) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
module.register_parameter("bias", torch::ones({20}), false);
module.define(R"(
def fn(self, x:float=1.0):
return x
def forward(self, input):
x1 = torch.zeros(2, 2)
x2 = torch.empty_like(torch.empty(2, 2))
@ -677,8 +695,22 @@ TEST(LiteInterpreterTest, BackPortByteCodeModelAllVersions) {
x = 2 * torch.ones(1)
h = torch.ones(1)
torch.add(x, h, out=x)
return (x1, x2, x3, x)
)");
device = torch.ones(1, 1).cpu().device.type
is_cuda = x1.is_cuda
bool_val = True
check_is = [] is None
check_is_not = [1] is not None
check_not = not bool_val
num_to_tensor = torch.tensor([self.fn()])
d = {"a": "abc"}
check_dict_index = d["a"]
check_dim = x1.dim()
return (
x1, x2, x3, x, device, is_cuda, check_is,
check_is_not, num_to_tensor, check_dict_index,
check_dim, check_not
)
)");
torch::jit::Module module_freeze = freeze(module);
@ -686,12 +718,21 @@ TEST(LiteInterpreterTest, BackPortByteCodeModelAllVersions) {
module_freeze._save_for_mobile(input_model_stream);
std::vector<IValue> input_data =
std::vector<IValue>({torch::ones({1, 1, 28, 28})});
std::vector<Tensor> expect_result_list;
std::vector<IValue> expect_result_list;
expect_result_list.emplace_back(at::ones({2, 2}, ScalarType::Float) * 0);
expect_result_list.emplace_back(at::ones({2, 2}, ScalarType::Float));
expect_result_list.emplace_back(
at::ones({1, 20, 24, 24}, ScalarType::Float) * 26);
expect_result_list.emplace_back(3 * at::ones({1}));
// "cpu" False, False, True, tensor(1), "abc", 2, False)
expect_result_list.emplace_back(c10::IValue("cpu"));
expect_result_list.emplace_back(c10::IValue(false));
expect_result_list.emplace_back(c10::IValue(false));
expect_result_list.emplace_back(c10::IValue(true));
expect_result_list.emplace_back(c10::IValue(at::ones({1})));
expect_result_list.emplace_back(c10::IValue("abc"));
expect_result_list.emplace_back(c10::IValue(2));
expect_result_list.emplace_back(c10::IValue(false));
backportAllVersionCheck(
input_model_stream,

View File

@ -151,7 +151,7 @@ class TestOptimizer(TestCase):
bn_scripted_module = torch.jit.script(bn_test_module)
bn_scripted_module.eval()
self.assertEqual(len(torch.jit.export_opnames(bn_scripted_module)), 14)
self.assertEqual(len(torch.jit.export_opnames(bn_scripted_module)), 11)
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
.run(str(get_forward(bn_scripted_module._c).graph))
@ -252,7 +252,7 @@ class TestOptimizer(TestCase):
bn_no_forward_scripted_module = torch.jit.script(bn_test_no_forward_module)
bn_no_forward_scripted_module.eval()
self.assertEqual(len(torch.jit.export_opnames(bn_no_forward_scripted_module)), 14)
self.assertEqual(len(torch.jit.export_opnames(bn_no_forward_scripted_module)), 11)
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
.run(bn_no_forward_scripted_module.foo.graph)

View File

@ -27,6 +27,7 @@ constexpr int64_t kBytecodeVersionV4 = 0x4L;
constexpr int64_t kBytecodeVersionV5 = 0x5L;
constexpr int64_t kBytecodeVersionV6 = 0x6L;
constexpr int64_t kBytecodeVersionV7 = 0x7L;
constexpr int64_t kBytecodeVersionV8 = 0x8L;
} // namespace
/********************** Utility Functions **********************/
@ -434,7 +435,8 @@ std::stringstream backport_v6_to_v5(std::stringstream& input_model_stream) {
{
BytecodeEmitModeGuard argNumGuard(
true /*emit_default_input_instructions*/,
false /*enable_defaults_args_with_out_args*/);
false /*enable_defaults_args_with_out_args*/,
false /*enable_emit_promoted_ops*/);
torch_script._save_for_mobile(
intermediate_model_stream, extra_files, hasBytecodeDebug);
}
@ -501,7 +503,8 @@ std::stringstream backport_v7_to_v6(std::stringstream& input_model_stream) {
{
BytecodeEmitModeGuard argNumGuard(
false /*emit_default_input_instructions*/,
false /*enable_defaults_args_with_out_args*/);
false /*enable_defaults_args_with_out_args*/,
false /*enable_emit_promoted_ops*/);
torch_script._save_for_mobile(
intermediate_model_stream, extra_files, hasBytecodeDebug);
}
@ -512,6 +515,39 @@ std::stringstream backport_v7_to_v6(std::stringstream& input_model_stream) {
return output_model_stream;
}
std::stringstream backport_v8_to_v7(std::stringstream& input_model_stream) {
std::shared_ptr<IStreamAdapter> rai =
std::make_shared<IStreamAdapter>(&input_model_stream);
auto reader = std::make_shared<PyTorchStreamReader>(rai);
// extra_files are kept
auto records = reader->getAllRecords();
bool hasBytecodeDebug = reader->hasRecord("mobile_debug_handles.pkl");
ExtraFilesMap extra_files;
for (const auto& record : records) {
std::size_t found = record.find_last_of("/\\");
auto path = record.substr(0, found);
if ("extra" == path) {
extra_files.emplace(record.substr(found + 1), "");
}
}
Module torch_script = torch::jit::load(rai, c10::nullopt, extra_files);
std::stringstream intermediate_model_stream;
{
BytecodeEmitModeGuard argNumGuard(
false /*emit_default_input_instructions*/,
true /*enable_defaults_args_with_out_args*/,
false /*enable_emit_promoted_ops*/);
torch_script._save_for_mobile(
intermediate_model_stream, extra_files, hasBytecodeDebug);
}
// Update the bytecode version (from 8 to 7)
std::stringstream output_model_stream =
update_bytecode_version(intermediate_model_stream, kBytecodeVersionV7);
return output_model_stream;
}
} // namespace
/********************** BackportManager **********************/
@ -528,6 +564,7 @@ BackportManager::BackportManager() {
registerBytecodeBackportFunction(kBytecodeVersionV5, backport_v5_to_v4);
registerBytecodeBackportFunction(kBytecodeVersionV6, backport_v6_to_v5);
registerBytecodeBackportFunction(kBytecodeVersionV7, backport_v7_to_v6);
registerBytecodeBackportFunction(kBytecodeVersionV8, backport_v8_to_v7);
}
std::unordered_map<

View File

@ -346,7 +346,7 @@ const std::vector<ByteCodeFunctionWithOperator>& getUpgraderBytecodeList() {
Instruction{OpCode::STOREN, 1, 7},
Instruction{OpCode::LOAD, 3, 0},
Instruction{OpCode::LOADC, 0, 0},
Instruction{OpCode::OP, 0, 0},
Instruction{OpCode::__IS__, 0, 0},
Instruction{OpCode::JF, 10, 0},
Instruction{OpCode::LOAD, 1, 0},
Instruction{OpCode::LOAD, 2, 0},
@ -355,17 +355,17 @@ const std::vector<ByteCodeFunctionWithOperator>& getUpgraderBytecodeList() {
Instruction{OpCode::LOAD, 5, 0},
Instruction{OpCode::LOAD, 6, 0},
Instruction{OpCode::LOAD, 7, 0},
Instruction{OpCode::OP, 1, 0},
Instruction{OpCode::OP, 0, 0},
Instruction{OpCode::JMP, 10, 0},
Instruction{OpCode::LOAD, 1, 0},
Instruction{OpCode::LOAD, 2, 0},
Instruction{OpCode::LOAD, 3, 0},
Instruction{OpCode::OP, 2, 0},
Instruction{OpCode::OP, 1, 0},
Instruction{OpCode::LOAD, 4, 0},
Instruction{OpCode::LOAD, 5, 0},
Instruction{OpCode::LOAD, 6, 0},
Instruction{OpCode::LOAD, 7, 0},
Instruction{OpCode::OP, 1, 0},
Instruction{OpCode::OP, 0, 0},
Instruction{OpCode::STORE, 8, 0},
Instruction{OpCode::DROPR, 7, 0},
Instruction{OpCode::DROPR, 6, 0},
@ -385,7 +385,6 @@ const std::vector<ByteCodeFunctionWithOperator>& getUpgraderBytecodeList() {
8
),
std::vector<OperatorString>({
OperatorString({"aten::__is__", "", 2}),
OperatorString({"aten::linspace", "", 7}),
OperatorString({"prim::unchecked_cast", "", 1}),
}), // operators list
@ -397,20 +396,20 @@ const std::vector<ByteCodeFunctionWithOperator>& getUpgraderBytecodeList() {
Instruction{OpCode::STOREN, 1, 4},
Instruction{OpCode::LOAD, 3, 0},
Instruction{OpCode::LOADC, 0, 0},
Instruction{OpCode::OP, 0, 0},
Instruction{OpCode::__IS__, 0, 0},
Instruction{OpCode::JF, 7, 0},
Instruction{OpCode::LOAD, 1, 0},
Instruction{OpCode::LOAD, 2, 0},
Instruction{OpCode::LOADC, 1, 0},
Instruction{OpCode::LOAD, 4, 0},
Instruction{OpCode::OP, 1, 0},
Instruction{OpCode::OP, 0, 0},
Instruction{OpCode::JMP, 7, 0},
Instruction{OpCode::LOAD, 1, 0},
Instruction{OpCode::LOAD, 2, 0},
Instruction{OpCode::LOAD, 3, 0},
Instruction{OpCode::OP, 2, 0},
Instruction{OpCode::LOAD, 4, 0},
Instruction{OpCode::OP, 1, 0},
Instruction{OpCode::LOAD, 4, 0},
Instruction{OpCode::OP, 0, 0},
Instruction{OpCode::STORE, 5, 0},
Instruction{OpCode::DROPR, 4, 0},
Instruction{OpCode::DROPR, 2, 0},
@ -427,7 +426,6 @@ const std::vector<ByteCodeFunctionWithOperator>& getUpgraderBytecodeList() {
5
),
std::vector<OperatorString>({
OperatorString({"aten::__is__", "", 2}),
OperatorString({"aten::linspace", "out", 4}),
OperatorString({"prim::unchecked_cast", "", 1}),
}), // operators list
@ -439,7 +437,7 @@ const std::vector<ByteCodeFunctionWithOperator>& getUpgraderBytecodeList() {
Instruction{OpCode::STOREN, 1, 8},
Instruction{OpCode::LOAD, 3, 0},
Instruction{OpCode::LOADC, 0, 0},
Instruction{OpCode::OP, 0, 0},
Instruction{OpCode::__IS__, 0, 0},
Instruction{OpCode::JF, 11, 0},
Instruction{OpCode::LOAD, 1, 0},
Instruction{OpCode::LOAD, 2, 0},
@ -449,18 +447,18 @@ const std::vector<ByteCodeFunctionWithOperator>& getUpgraderBytecodeList() {
Instruction{OpCode::LOAD, 6, 0},
Instruction{OpCode::LOAD, 7, 0},
Instruction{OpCode::LOAD, 8, 0},
Instruction{OpCode::OP, 1, 0},
Instruction{OpCode::OP, 0, 0},
Instruction{OpCode::JMP, 11, 0},
Instruction{OpCode::LOAD, 1, 0},
Instruction{OpCode::LOAD, 2, 0},
Instruction{OpCode::LOAD, 3, 0},
Instruction{OpCode::OP, 2, 0},
Instruction{OpCode::OP, 1, 0},
Instruction{OpCode::LOAD, 4, 0},
Instruction{OpCode::LOAD, 5, 0},
Instruction{OpCode::LOAD, 6, 0},
Instruction{OpCode::LOAD, 7, 0},
Instruction{OpCode::LOAD, 8, 0},
Instruction{OpCode::OP, 1, 0},
Instruction{OpCode::OP, 0, 0},
Instruction{OpCode::STORE, 9, 0},
Instruction{OpCode::DROPR, 8, 0},
Instruction{OpCode::DROPR, 7, 0},
@ -481,7 +479,6 @@ const std::vector<ByteCodeFunctionWithOperator>& getUpgraderBytecodeList() {
9
),
std::vector<OperatorString>({
OperatorString({"aten::__is__", "", 2}),
OperatorString({"aten::logspace", "", 8}),
OperatorString({"prim::unchecked_cast", "", 1}),
}), // operators list
@ -493,22 +490,22 @@ const std::vector<ByteCodeFunctionWithOperator>& getUpgraderBytecodeList() {
Instruction{OpCode::STOREN, 1, 5},
Instruction{OpCode::LOAD, 3, 0},
Instruction{OpCode::LOADC, 0, 0},
Instruction{OpCode::OP, 0, 0},
Instruction{OpCode::__IS__, 0, 0},
Instruction{OpCode::JF, 8, 0},
Instruction{OpCode::LOAD, 1, 0},
Instruction{OpCode::LOAD, 2, 0},
Instruction{OpCode::LOADC, 1, 0},
Instruction{OpCode::LOAD, 4, 0},
Instruction{OpCode::LOAD, 5, 0},
Instruction{OpCode::OP, 1, 0},
Instruction{OpCode::OP, 0, 0},
Instruction{OpCode::JMP, 8, 0},
Instruction{OpCode::LOAD, 1, 0},
Instruction{OpCode::LOAD, 2, 0},
Instruction{OpCode::LOAD, 3, 0},
Instruction{OpCode::OP, 2, 0},
Instruction{OpCode::OP, 1, 0},
Instruction{OpCode::LOAD, 4, 0},
Instruction{OpCode::LOAD, 5, 0},
Instruction{OpCode::OP, 1, 0},
Instruction{OpCode::OP, 0, 0},
Instruction{OpCode::STORE, 6, 0},
Instruction{OpCode::DROPR, 5, 0},
Instruction{OpCode::DROPR, 4, 0},
@ -526,7 +523,6 @@ const std::vector<ByteCodeFunctionWithOperator>& getUpgraderBytecodeList() {
6
),
std::vector<OperatorString>({
OperatorString({"aten::__is__", "", 2}),
OperatorString({"aten::logspace", "out", 5}),
OperatorString({"prim::unchecked_cast", "", 1}),
}), // operators list

View File

@ -1059,12 +1059,14 @@ MobileCode::MobileCode(
std::string function_name,
bool emit_default_input_instructions,
bool support_default_args_before_out,
bool emit_promoted_ops,
size_t remaining_bailout_depth)
: Code(new interpreter::MobileCodeImpl(
graph,
std::move(function_name),
emit_default_input_instructions,
support_default_args_before_out,
emit_promoted_ops,
remaining_bailout_depth)) {}
MobileCode::~MobileCode() = default;

View File

@ -88,6 +88,7 @@ struct TORCH_API MobileCode : Code {
std::string function_name,
bool emit_default_input_instructions = true,
bool support_default_args_before_out = true,
bool emit_promoted_ops = true,
size_t remaining_bailout_depth = 0);
~MobileCode();
};

View File

@ -869,10 +869,12 @@ struct MobileCodeImpl : CodeImpl {
std::string function_name,
bool emit_default_input_instructions,
bool support_default_args_before_out,
bool emit_promoted_ops,
size_t remaining_bailout_depth)
: CodeImpl(graph, function_name, remaining_bailout_depth, false),
emit_default_input_instructions_(emit_default_input_instructions),
support_default_args_before_out_(support_default_args_before_out) {
support_default_args_before_out_(support_default_args_before_out),
emit_promoted_ops_(emit_promoted_ops) {
// NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
run();
}
@ -965,7 +967,6 @@ struct MobileCodeImpl : CodeImpl {
int64_t X = 0,
uint64_t N = 0,
bool emit_inputs = true) override {
bool emit_promoted_ops_ = false;
if (emit_promoted_ops_) {
CodeImpl::emitOperatorOrInstruction(node, op, X, N, emit_inputs);
} else {
@ -977,6 +978,8 @@ struct MobileCodeImpl : CodeImpl {
bool emit_default_input_instructions_;
// To support forward compatibility for bytecode version bump from v6 to v7
bool support_default_args_before_out_;
// To support forward compatibility for bytecode version bump from v7 to v8
bool emit_promoted_ops_;
};
} // namespace interpreter

View File

@ -201,6 +201,9 @@ struct TORCH_API BytecodeEmitMode {
static bool is_default_args_before_out_args_enabled();
static void set_default_args_before_out_args_enabled(bool enabled);
static bool is_emit_promoted_ops_enabled();
static void set_default_emit_promoted_ops_enabled(bool enabled);
};
// RAII guard to switch the way JIT emits the bytecode for inputs.
@ -216,24 +219,32 @@ struct TORCH_API BytecodeEmitMode {
struct TORCH_API BytecodeEmitModeGuard {
BytecodeEmitModeGuard(
bool enable_default_value_for_unspecified_arg,
bool enable_default_args_before_out_args)
bool enable_default_args_before_out_args,
bool enable_emit_promoted_ops)
: prev_default_value_for_unspecified_arg_mode(
BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled()),
prev_default_args_before_out_args(
BytecodeEmitMode::is_default_args_before_out_args_enabled()) {
BytecodeEmitMode::is_default_args_before_out_args_enabled()),
prev_default_emit_promoted_ops(
BytecodeEmitMode::is_emit_promoted_ops_enabled()) {
BytecodeEmitMode::set_default_value_for_unspecified_arg_enabled(
enable_default_value_for_unspecified_arg);
BytecodeEmitMode::set_default_args_before_out_args_enabled(
enable_default_args_before_out_args);
BytecodeEmitMode::set_default_emit_promoted_ops_enabled(
enable_emit_promoted_ops);
}
~BytecodeEmitModeGuard() {
BytecodeEmitMode::set_default_value_for_unspecified_arg_enabled(
prev_default_value_for_unspecified_arg_mode);
BytecodeEmitMode::set_default_args_before_out_args_enabled(
prev_default_args_before_out_args);
BytecodeEmitMode::set_default_emit_promoted_ops_enabled(
prev_default_emit_promoted_ops);
}
bool prev_default_value_for_unspecified_arg_mode;
bool prev_default_args_before_out_args;
bool prev_default_emit_promoted_ops;
};
TORCH_API IValue to_tuple(std::vector<IValue> ivalues);

View File

@ -142,7 +142,8 @@ mobile::Code compileGraphToMobileCode(
graph,
name,
compilation_options.enable_default_value_for_unspecified_arg,
compilation_options.enable_default_args_before_out_args);
compilation_options.enable_default_args_before_out_args,
compilation_options.enable_emit_promoted_ops);
mobile::Code mobile_code;

View File

@ -20,6 +20,7 @@ struct TORCH_API CompilationOptions {
bool incl_interface_call = false;
bool enable_default_value_for_unspecified_arg = false;
bool enable_default_args_before_out_args = true;
bool enable_emit_promoted_ops = true;
int model_version = caffe2::serialize::kProducedBytecodeVersion;
};

View File

@ -44,6 +44,8 @@ CompilationOptions getOptionsFromGlobal() {
BytecodeEmitMode::is_default_args_before_out_args_enabled();
compilation_options.enable_default_value_for_unspecified_arg =
BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled();
compilation_options.enable_emit_promoted_ops =
BytecodeEmitMode::is_emit_promoted_ops_enabled();
compilation_options.incl_interface_call = getMobileInterfaceCallExport();
compilation_options.model_version =
caffe2::serialize::kProducedBytecodeVersion;
@ -864,5 +866,14 @@ void BytecodeEmitMode::set_default_args_before_out_args_enabled(bool enabled) {
emitDefautlArgsWithOutArgs = enabled;
}
thread_local bool emitDefaultEmitPromotedOps =
caffe2::serialize::kProducedBytecodeVersion <= 7 ? false : true;
bool BytecodeEmitMode::is_emit_promoted_ops_enabled() {
return emitDefaultEmitPromotedOps;
}
void BytecodeEmitMode::set_default_emit_promoted_ops_enabled(bool enabled) {
emitDefaultEmitPromotedOps = enabled;
}
} // namespace jit
} // namespace torch