Refactor saving jit::Module to mobile .pt in 2 steps: (#66494)
Summary: Saving a jit::Module to a mobile .pt file is now done in two steps:
1. Convert Function -> mobile::Function.
2. Serialize the mobile::Function.

This also opens the opportunity to create a mobile::Module directly, without a save/reload round trip.

Fixes #{issue number}

Pull Request resolved: https://github.com/pytorch/pytorch/pull/66494
Reviewed By: zhxchen17
Differential Revision: D32293022
Pulled By: qihqi
fbshipit-source-id: 29b43d47ff86071d5e2f9d6ca4dba4445711ce3d
Committed by: Facebook GitHub Bot
Parent: e2aeb4a7af
Commit: 4eb772fde6
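The summary above describes the flow exercised by the test file added in this commit: a jit::Module is converted in memory to a mobile::Module and run directly, with no intermediate .ptl file. A minimal sketch of that pattern, assuming only the APIs visible in this diff (CompilationOptions, jitModuleToMobile) and a hypothetical toy module:

    // Sketch only: mirrors the pattern used throughout the new tests below.
    torch::jit::Module m("m");
    m.define(R"(
      def forward(self, x):
        return x + 1
    )");

    torch::jit::CompilationOptions options;          // step 1: jit::Module -> mobile::Function(s)
    torch::jit::mobile::Module bc =
        torch::jit::jitModuleToMobile(m, options);
    auto out = bc.forward({torch::ones({2, 2})});    // runnable without saving/reloading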
@ -67,6 +67,7 @@ set(JIT_TEST_SRCS
   ${JIT_TEST_ROOT}/test_irparser.cpp
   ${JIT_TEST_ROOT}/test_jit_type.cpp
   ${JIT_TEST_ROOT}/test_lite_interpreter.cpp
+  ${JIT_TEST_ROOT}/test_lite_interpreter_direct.cpp
   ${JIT_TEST_ROOT}/test_lite_trainer.cpp
   ${JIT_TEST_ROOT}/test_memory_dag.cpp
   ${JIT_TEST_ROOT}/test_misc.cpp
test/cpp/jit/test_lite_interpreter_direct.cpp (new file, 921 lines)
@ -0,0 +1,921 @@
#include <test/cpp/jit/test_utils.h>

#include <gtest/gtest.h>

#include <c10/core/TensorOptions.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/frontend/resolver.h>
#include <torch/csrc/jit/mobile/backport.h>
#include <torch/csrc/jit/mobile/backport_manager.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/interpreter.h>
#include <torch/csrc/jit/mobile/model_compatibility.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/parse_bytecode.h>
#include <torch/csrc/jit/mobile/parse_operators.h>
#include <torch/csrc/jit/mobile/runtime_compatibility.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/export_bytecode.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/custom_class.h>
#include <torch/torch.h>

#include <unordered_set>

// Tests go in torch::jit
namespace torch {
namespace jit {

TEST(LiteInterpreterDirectTest, UpsampleNearest2d) {
  Module m("m");
  m.define(R"(
    def forward(self, input: Tensor, scale:float):
      return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(torch::rand({1, 3, 128, 128}));
  inputs.emplace_back(at::Scalar(2.0));
  auto ref = m.forward(inputs);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  res = bc.forward(inputs);

  auto resd = res.toTensor();
  auto refd = ref.toTensor();
  ASSERT_TRUE(resd.equal(refd));
}

TEST(LiteInterpreterDirectTest, CheckAttrAccess) {
  Module m("m");
  m.register_attribute("mobile_optimized", BoolType::get(), true);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  bool mobile_optimized = bc.attr("mobile_optimized", false).toBool();

  AT_ASSERT(mobile_optimized);
  m.setattr("mobile_optimized", false);
  bc = jitModuleToMobile(m, options);
  mobile_optimized = bc.attr("mobile_optimized", false).toBool();
  AT_ASSERT(!mobile_optimized);
}

TEST(
    LiteInterpreterDirectTest,
    MethodInvocation) { // NOLINT (use =delete in gtest)
  const std::vector<std::string> test_programs{
      // test invoking a method with default parameter
      R"(
      def test_func(self, x, b : int = 4):
        return self.foo + x + b
      )",
      // inner method call with default parameter (gets inlined)
      R"(
      def add_with_default_arg(self, x, b : int = 4):
        return self.foo + x + b
      def test_func(self, x):
        return self.add_with_default_arg(x) # invoke method w/ default arg
      )",
      // simple method call
      R"(
      def test_func(self, x):
        b = 4
        return self.foo + x + b
      )",
  };
  for (const auto& test_program : test_programs) {
    Module m("m");
    m.register_parameter("foo", torch::ones({}), false);
    m.define(test_program);

    const int fortyTwo = 42; // (keep linter happy)
    auto minput = fortyTwo * torch::ones({});
    auto ref = m.run_method("test_func", minput);

    CompilationOptions options;
    mobile::Module bc = jitModuleToMobile(m, options);
    const auto& test_func = bc.get_method("test_func");
    std::cerr << "hello " << std::endl;
    IValue res;
    for (int i = 0; i < 3; ++i) {
      res = test_func({minput});
    }
    std::cerr << "hello 3" << std::endl;

    auto resd = res.toTensor().item<float>();
    auto refd = ref.toTensor().item<float>();
    AT_ASSERT(resd == refd);
  }
}

TEST(LiteInterpreterDirectTest, Conv) {
  auto s = std::getenv("PYTORCH_TEST_WITH_TSAN");
  if (s && strcmp(s, "1") == 0)
    return;

  std::vector<torch::jit::IValue> inputs;

  Module m("m");
  m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
  m.register_parameter("bias", torch::ones({20}), false);
  m.define(R"(
    def forward(self, input):
      return torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True)
  )");

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace)
  inputs.push_back(torch::ones({1, 1, 28, 28}));

  auto outputref = m.forward(inputs).toTensor();

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  for (int i = 0; i < 3; ++i) {
    res = bc.get_method("forward")(inputs);
  }
  auto output = res.toTensor();
  AT_ASSERT(outputref.dim() == output.dim());
  AT_ASSERT(
      outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
}

TEST(LiteInterpreterDirectTest, Inline) {
  Module m("m");
  m.define(R"JIT(
  def foo1(self, x):
      return x + 1

  def foo2(self, x):
      return self.foo1(x) + 2

  def foo3(self, x):
      return self.foo2(x) + 3
  )JIT");
  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  std::vector<torch::jit::IValue> inputs({torch::ones({})});
  auto output = bc.get_method("foo3")(inputs);
  AT_ASSERT(output.toTensor().item<float>() == 7.0);
}

TEST(LiteInterpreterDirectTest, Tuple) {
  Module m("m");
  m.define(R"JIT(
  def foo(self, x):
      return (1, 2, x + 3)

  def forward(self, x):
      tuple = self.foo(x)
      return tuple
  )JIT");
  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  std::vector<torch::jit::IValue> inputs({torch::ones({})});
  auto output = bc.get_method("forward")(inputs);
  AT_ASSERT(output.toTupleRef().elements()[1].toInt() == 2);
}

TEST(LiteInterpreterDirectTest, Dict) {
  Module m("m");
  m.define(R"JIT(
  def foo(self, x):
      return {"result": x + 1}

  def forward(self, x):
      d = self.foo(x)
      return d
  )JIT");
  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  std::vector<torch::jit::IValue> inputs({torch::ones({})});
  auto output = bc.get_method("forward")(inputs);
  AT_ASSERT(output.toGenericDict().at("result").toTensor().item().toInt() == 2);
}

TEST(LiteInterpreterDirectTest, Prim) {
  Module m("m");
  m.define(R"JIT(
  def forward(self, x):
      return int(x)
  )JIT");

  std::vector<IValue> inputs;
  auto minput = 3.5 * torch::ones({});
  inputs.emplace_back(minput);
  auto ref = m.run_method("forward", minput);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);

  IValue res;
  for (int i = 0; i < 3; ++i) {
    // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
    auto bcinputs = inputs;
    res = bc.get_method("forward")(bcinputs);
  }

  auto resi = res.toInt();
  auto refi = ref.toInt();
  AT_ASSERT(resi == refi);
}

TEST(LiteInterpreterDirectTest, PrimScalar) {
  Module m("m");
  m.define(R"JIT(
  def forward(self, x):
      return int(x.item())
  )JIT");

  std::vector<IValue> inputs;
  auto minput = 3.5 * torch::ones({});
  inputs.emplace_back(minput);
  auto ref = m.run_method("forward", minput);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  for (int i = 0; i < 3; ++i) {
    // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
    auto bcinputs = inputs;
    res = bc.get_method("forward")(bcinputs);
  }

  auto resi = res.toInt();
  auto refi = ref.toInt();
  AT_ASSERT(resi == refi);
}

TEST(LiteInterpreterDirectTest, WrongMethodName) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def add(self, x):
      b = 4
      return self.foo + x + b
  )");
  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  std::vector<IValue> inputs;
  auto minput = 5 * torch::ones({});
  inputs.emplace_back(minput);
  ASSERT_THROWS_WITH_MESSAGE(
      bc.get_method("forward")(inputs), "is not defined");
}

TEST(LiteInterpreterDirectTest, SetState) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def __getstate__(self):
      return self.foo
    def __setstate__(self, a):
      self.foo = a
    def forward(self, x):
      b = 4
      return self.foo + x + b
  )");

  std::vector<IValue> inputs;
  auto minput = 5 * torch::ones({});
  inputs.emplace_back(minput);

  std::stringstream ms;
  m.save(ms);
  auto loaded_m = load(ms);
  auto ref = loaded_m.run_method("forward", minput);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  for (int i = 0; i < 3; ++i) {
    // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
    auto bcinputs = inputs;
    res = bc.get_method("forward")(bcinputs);
  }

  auto resd = res.toTensor().item<float>();
  auto refd = ref.toTensor().item<float>();
  AT_ASSERT(resd == refd);
}

class TorchBindLiteInterpreterDirectTestStruct
    : public torch::jit::CustomClassHolder {
 public:
  std::string get(at::Tensor t) {
    std::stringstream ss;
    ss << "Hello! Your tensor has ";
    ss << t.numel();
    ss << " elements!";
    return ss.str();
  }
};

namespace {
struct ClassNamespaceValue : public SugaredValue {
  explicit ClassNamespaceValue(c10::QualifiedName name)
      : basename_(std::move(name)) {}

  std::shared_ptr<SugaredValue> attr(
      const SourceRange&,
      GraphFunction&,
      const std::string& name) override {
    const auto fullName = c10::QualifiedName(basename_, name);

    // Check to see if it is a custom class.
    if (auto custom_class = getCustomClass(fullName.qualifiedName())) {
      return std::make_shared<ClassValue>(custom_class);
    }

    // If it's not a custom class, assume it's another namespace
    // NOLINTNEXTLINE(performance-move-const-arg)
    return std::make_shared<ClassNamespaceValue>(fullName);
  }

  std::string kind() const override {
    return "Class Namespace";
  }

 private:
  c10::QualifiedName basename_;
};

struct TestModuleResolver : public Resolver {
  std::shared_ptr<SugaredValue> resolveValue(
      const std::string& name,
      GraphFunction&,
      const SourceRange&) override {
    if (name == "torch") {
      return std::make_shared<BuiltinModule>("aten");
    } else if (name == "__torch__") {
      return std::make_shared<ClassNamespaceValue>(c10::QualifiedName(name));
    }

    return nullptr;
  }

  TypePtr resolveType(const std::string&, const SourceRange&) override {
    return nullptr;
  }
};
} // namespace

TEST(LiteInterpreterDirectTest, BuiltinFunction) {
  script::Module m("m");
  auto custom_class_obj =
      make_custom_class<TorchBindLiteInterpreterDirectTestStruct>();
  m.register_attribute("my_obj", custom_class_obj.type(), custom_class_obj);
  m.define(R"(
    def forward(self, x) -> str:
      return self.my_obj.get(x)
  )");

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  auto res =
      bc.get_method("forward")(std::vector<IValue>{torch::zeros({3, 4})});
  // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
  auto str = res.toStringRef();
  std::string expected = "Hello! Your tensor has 12 elements!";
  AT_ASSERT(str == expected);
}

#if !defined FB_XPLAT_BUILD
TEST(LiteInterpreterDirectTest, GetRuntimeByteCodeVersion) {
  auto runtime_bytecode_version = _get_runtime_bytecode_version();
  AT_ASSERT(
      runtime_bytecode_version ==
      caffe2::serialize::kMaxSupportedBytecodeVersion);
}

TEST(LiteInterpreterDirectTest, GetRuntimeOperatorsVersion) {
  auto runtime_operators_version = _get_runtime_operators_min_max_versions();
  AT_ASSERT(
      runtime_operators_version.first ==
          caffe2::serialize::kMinSupportedFileFormatVersion &&
      runtime_operators_version.second ==
          caffe2::serialize::kMaxSupportedFileFormatVersion);
}

/**
 * The test below is disarmed for FB internal xplat builds since
 * BUCK requires us to pass in the script_module_v4.ptl file in
 * as a resource dependency of the build rule for this file, and
 * we would need to access it via the C++ Resources API instead
 * of directly reading from disk (which is what the open source
 * build/run does).
 */
TEST(LiteInterpreterDirectTest, GetByteCodeVersion) {
  std::string filePath(__FILE__);
  auto test_model_file_v4 =
      filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file_v4.append("script_module_v4.ptl");

  auto version_v4 = _get_model_bytecode_version(test_model_file_v4);
  AT_ASSERT(version_v4 == 4);
}

#endif // !defined(FB_XPLAT_BUILD)

TEST(LiteInterpreterDirectTest, GetRuntimeOpsAndInfo) {
  auto runtime_ops = _get_runtime_ops_and_info();
  // Ballpark estimate of the minimal number of ops; just used to
  // verify API returns a reasonably large number.
  AT_ASSERT(runtime_ops.size() > 2900);
}

TEST(LiteInterpreterDirectTest, Eval) {
  std::vector<torch::jit::IValue> inputs;

  Module m("m");
  m.define(R"(
    def __init__(self, x):
      self.training = True

    def forward(self, input):
      return torch.dropout(input, 1.0, self.training)
  )");

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace)
  inputs.push_back(torch::ones({1, 1, 28, 28}));
  m.eval();
  auto outputref = m.forward(inputs).toTensor();

  // save m in training mode to make sure that mobile eval() will correctly
  // change back to eval mode
  m.train();
  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  bc.eval();
  IValue res;
  for (int i = 0; i < 3; ++i) {
    res = bc.get_method("forward")(inputs);
  }
  auto output = res.toTensor();
  AT_ASSERT(outputref.dim() == output.dim());
  AT_ASSERT(
      outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
}

TEST(LiteInterpreterDirectTest, FindWrongMethodName) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def add(self, x):
      b = 4
      return self.foo + x + b
  )");
  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  ASSERT_TRUE(bc.find_method("forward") == c10::nullopt);
}

TEST(LiteInterpreterDirectTest, FindAndRunMethod) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def add_it(self, x):
      b = 4
      return self.foo + x + b
  )");

  std::vector<IValue> inputs;
  auto minput = 5 * torch::ones({});
  inputs.emplace_back(minput);
  auto ref = m.get_method("add_it")(inputs);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  for (int i = 0; i < 3; ++i) {
    auto bcinputs = inputs;
    auto method = bc.find_method("add_it");
    AT_ASSERT(method != c10::nullopt);
    res = (*method)(std::move(bcinputs));
  }

  auto resd = res.toTensor().item<float>();
  auto refd = ref.toTensor().item<float>();
  AT_ASSERT(resd == refd);
}

TEST(LiteInterpreterDirectTest, RunMethodVariadic) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def add_three(self, x, y):
      return self.foo + x + y
  )");

  std::vector<IValue> inputs;
  auto inputx = 5 * torch::ones({});
  auto inputy = 4 * torch::ones({});
  auto ref = m.run_method("add_three", inputx, inputy);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res = bc.run_method("add_three", inputx, inputy);

  auto resd = res.toTensor().item<float>();
  auto refd = ref.toTensor().item<float>();
  AT_ASSERT(resd == refd);
}

TEST(LiteInterpreterDirectTest, DuplicateSetState) {
  Module m("M");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def __getstate__(self):
      return self.foo + self.foo
    def __setstate__(self, a):
      self.foo = a
    def forward(self, x):
      b = 4
      return self.foo + x + b
  )");

  Module b("B");
  b.register_module("M0", m);
  b.register_module("M1", m);
  b.define(R"(
    def forward(self, x):
      return self.M0.forward(x) + self.M1.forward(x)
  )");

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  const auto methods = bc.get_methods();
  const size_t expected_n = 3;
  ASSERT_EQ(methods.size(), expected_n);
}

TEST(LiteInterpreterDirectTest, OpNameExportFetchRootOperators) {
  torch::jit::Module m("m");
  m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
  m.register_parameter("bias", torch::ones({20}), false);
  m.define(R"(
    def forward(self, input):
      x1 = torch.zeros(2, 2)
      x2 = torch.empty_like(torch.empty(2, 2))
      x3 = torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True)
      return (x1, x2, x3)
  )");
  m.eval();

  CompilationOptions options;
  mobile::Module ptl_model = jitModuleToMobile(m, options);
  std::set<std::string> operator_names =
      torch::jit::mobile::_export_operator_list(ptl_model);
  std::set<std::string> expected_operator_names = {
      "aten::_convolution",
      "aten::empty.memory_format",
      "aten::empty_like",
      "aten::zeros",
  };
  EXPECT_EQ(operator_names, expected_operator_names)
      << "Expected the root operator lists to be the same";
}

TEST(LiteInterpreterDirectTest, DefaultArgsConv) {
  auto s = std::getenv("PYTORCH_TEST_WITH_TSAN");
  if (s && strcmp(s, "1") == 0)
    return;

  std::vector<torch::jit::IValue> inputs;

  Module m("m");
  m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
  m.register_parameter("bias", torch::ones({20}), false);
  m.define(R"(
    def forward(self, input):
      return torch.conv2d(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], 1)
  )");

  inputs.emplace_back(torch::ones({1, 1, 28, 28}));

  auto outputref = m.forward(inputs).toTensor();

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  for (int i = 0; i < 1; ++i) {
    res = bc.get_method("forward")(inputs);
  }
  auto output = res.toTensor();
  AT_ASSERT(outputref.dim() == output.dim());
  AT_ASSERT(output.equal(outputref));
}

namespace {
void testLiteModuleCompareResultTensors(
    Module& m,
    const std::vector<torch::jit::IValue>& inputs,
    const std::string& method_name = "forward") {
  auto outputref = m.get_method(method_name)(inputs).toTensor();

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  for (int i = 0; i < 3; ++i) {
    res = bc.get_method(method_name)(inputs);
  }
  auto output = res.toTensor();
  AT_ASSERT(outputref.dim() == output.dim());
  AT_ASSERT(output.equal(outputref));
}

void testDefaultArgsPinv2(int num_args) {
  Module m("m");
  if (num_args == 1) {
    m.define(R"(
      def forward(self, input):
        return torch.linalg_pinv(input)
    )");
  } else if (num_args == 2) {
    m.define(R"(
      def forward(self, input):
        return torch.linalg_pinv(input, 1e-5)
    )");
  } else if (num_args == 3) {
    m.define(R"(
      def forward(self, input):
        return torch.linalg_pinv(input, 1e-5, True)
    )");
  }

  std::vector<torch::jit::IValue> inputs;
  const int N = 28;
  auto input = torch::range(1, N * N, 1);
  input[0] = 1; // a more stable matrix
  input = input.view({N, N});
  inputs.emplace_back(input);
  testLiteModuleCompareResultTensors(m, inputs);
}
} // namespace

#if !defined FB_XPLAT_BUILD
TEST(LiteInterpreterDirectTest, DefaultArgsPinv) {
  // Test with different number of specified arguments.
  // Arguments not specified take default value.
  for (int num_args = 1; num_args <= 3; ++num_args) {
    testDefaultArgsPinv2(num_args);
  }

  // bytecode with one specified argument:
  // (6,
  //  ('__torch__.m.forward',
  //   (('instructions',
  //     (('STOREN', 1, 2),
  //      ('DROPR', 1, 0),
  //      ('MOVE', 2, 0),
  //      ('OP', 0, 0),
  //      ('RET', 0, 0))),
  //    ('operators', (('aten::linalg_pinv', '', 1),)),
  //    ('constants', (False, 1e-15)), # default constants are not
  //    used
  //    ('types', ()),
  //    ('register_size', 2)),
  //   (('arguments',
  //     ((('name', 'self'), ('type', '__torch__.m'), ('default_value',
  //     None)),
  //      (('name', 'input'), ('type', 'Tensor'), ('default_value',
  //      None)))),
  //    ('returns',
  //     ((('name', ''), ('type', 'Tensor'), ('default_value',
  //     None)),)))))

  // bytecode with 2 specified argument:
  // (6,
  //  ('__torch__.m.forward',
  //   (('instructions',
  //     (('STOREN', 1, 2),
  //      ('DROPR', 1, 0),
  //      ('MOVE', 2, 0),
  //      ('LOADC', 1, 0), # added LOADC for specified argument
  //      ('OP', 0, 0),
  //      ('RET', 0, 0))),
  //    ('operators', (('aten::linalg_pinv', '', 2),)),
  //    ('constants', (False, 1e-05)), # updated constant table
  //    ('types', ()),
  //    ('register_size', 2)),
  //   (('arguments',
  //     ((('name', 'self'), ('type', '__torch__.m'), ('default_value',
  //     None)),
  //      (('name', 'input'), ('type', 'Tensor'), ('default_value',
  //      None)))),
  //    ('returns',
  //     ((('name', ''), ('type', 'Tensor'), ('default_value',
  //     None)),)))))

  // bytecode with 3 specified arguments:
  // (6,
  //  ('__torch__.m.forward',
  //   (('instructions',
  //     (('STOREN', 1, 2),
  //      ('DROPR', 1, 0),
  //      ('MOVE', 2, 0),
  //      ('LOADC', 1, 0),
  //      ('LOADC', 0, 0),
  //      ('OP', 0, 0),
  //      ('RET', 0, 0))),
  //    ('operators', (('aten::linalg_pinv', '', 3),)),
  //    ('constants', (True, 1e-05)),
  //    ('types', ()),
  //    ('register_size', 2)),
  //   (('arguments',
  //     ((('name', 'self'), ('type', '__torch__.m'), ('default_value',
  //     None)),
  //      (('name', 'input'), ('type', 'Tensor'), ('default_value',
  //      None)))),
  //    ('returns',
  //     ((('name', ''), ('type', 'Tensor'), ('default_value',
  //     None)),)))))
}

TEST(LiteInterpreterDirectTest, DefaultArgsTensorinvSpecifyDefault) {
  // The second argument is specified, but the value is the same as the default
  // value. It's treated as "not specified" since the value can be fetched from
  // schema.
  Module m("m");
  m.define(R"(
    def forward(self, input):
      return torch.linalg_tensorinv(input, 2)
  )");
  torch::jit::MobileCode code(m.get_method("forward").graph(), "forward");
  auto arg_nums = code.op_to_num_specified_args();
  ASSERT_EQ(arg_nums.size(), 1);
  ASSERT_EQ(arg_nums["aten::linalg_tensorinv"], 1);
  std::vector<torch::jit::IValue> inputs;
  const int N = 4;
  auto input = torch::rand({N, N, N, N});
  inputs.emplace_back(input);
  testLiteModuleCompareResultTensors(m, inputs);
}

void testDefaultArgsPinvWithOutArg2(int num_args) {
  Module m("m");
  if (num_args == 1) {
    m.define(R"(
      def forward(self, input):
        return torch.linalg_pinv(input, out=input)
    )");
  } else if (num_args == 2) {
    m.define(R"(
      def forward(self, input):
        return torch.linalg_pinv(input, 1e-5, out=input)
    )");
  } else if (num_args == 3) {
    m.define(R"(
      def forward(self, input):
        return torch.linalg_pinv(input, 1e-5, True, out=input)
    )");
  }

  const int N = 28;
  auto input = torch::range(1, N * N, 1);
  input[0] = 10000; // a more stable matrix
  input = input.view({N, N});
  auto ref = m.run_method("forward", input);
  TORCH_CHECK(!input.equal(torch::range(1, N * N, 1)));
  TORCH_CHECK(input.equal(ref.toTensor()));
}

TEST(LiteInterpreterDirectTest, DefaultArgsPinvWithOutArg) {
  // Test with different number of specified arguments + out arg.
  // Arguments not specified take default value.
  for (int num_args = 1; num_args <= 3; ++num_args) {
    testDefaultArgsPinvWithOutArg2(num_args);
  }
}

TEST(LiteInterpreterDirectTest, DefaultArgsWithOutArg) {
  Module m("m");
  m.define(R"(
    def forward(self, x, h):
      torch.add(x, h, out=x)
  )");

  std::vector<IValue> inputs;
  auto input_x = 2 * torch::ones({});
  auto input_h = torch::ones({});
  auto ref = m.run_method("forward", input_x, input_h);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  bc.run_method("forward", input_x, input_h);
  AT_ASSERT(input_x.equal(4 * torch::ones({})));
}

TEST(LiteInterpreterDirectTest, TestExceptionStackWithTwoLevelModuleHierarchy) {
  Module a("A");
  a.define(R"(
    def bar(self, x, y):
      return x + y
  )");
  Module b("B");
  b.register_module("A0", a);
  b.define(R"(
    def foo(self, x, y):
      return self.A0.bar(x, y) + 2
  )");
  Module c("C");
  c.register_module("B0", b);
  c.define(R"(
    def forward(self, x, y):
      return self.B0.foo(x, y) + 3
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(torch::rand({2, 4}));
  inputs.emplace_back(torch::rand({13, 9}));

  CompilationOptions options;
  auto lite_m = jitModuleToMobile(c, options);
  std::string error_pattern = R"(
  Module hierarchy:top(C)::<unknown>.B0(B)::foo.A0(A)::bar.aten::add
Traceback of TorchScript (most recent call last):
  File "<string>", line 3, in <unknown>

    def forward(self, x, y):
      return self.B0.foo(x, y) + 3
             ~~~~~~~~~~~ <--- HERE

  File "<string>", line 3, in foo

    def foo(self, x, y):
      return self.A0.bar(x, y) + 2
             ~~~~~~~~~~~ <--- HERE

  File "<string>", line 3, in bar

    def bar(self, x, y):
      return x + y
             ~~~~~ <--- HERE
)";
  ASSERT_THROWS_WITH_MESSAGE(lite_m.forward(inputs), error_pattern);
}
#endif // !defined(FB_XPLAT_BUILD)

namespace {
static auto reg =
    torch::class_<TorchBindLiteInterpreterDirectTestStruct>(
        "_TorchScriptTesting",
        "_LiteInterpreterDirectTest")
        .def(torch::init<>())
        .def("get", &TorchBindLiteInterpreterDirectTestStruct::get)
        .def_pickle(
            // __getattr__
            [](const c10::intrusive_ptr<
                TorchBindLiteInterpreterDirectTestStruct>&) -> int64_t {
              return 0;
            },
            // __setattr__
            [](int64_t) {
              return c10::make_intrusive<
                  TorchBindLiteInterpreterDirectTestStruct>();
            });

} // namespace

TEST(LiteInterpreterDirectTest, OperatorCacheDifferentiatesDefaultArgs) {
  // Create 3 methods:
  //
  // 1. forward() returns a tensor with dtype=torch.int64 (4)
  // 2. forward2() returns a tensor with dtype=torch.float32 (6)
  // 3. forward3() returns a tensor with dtype=torch.float32 but
  //    the dtype is inferred by the input tensor's dtype
  //
  // If caching works correctly, then the result from the full-jit
  // module and the lite module will be the same. Otherwise, it
  // will be different if we don't correctly ignore the cache
  // entry for an operator that has a different number of
  // arguments.
  Module m("m");
  m.define(R"(
    def forward(self):
      ret1 = torch.new_empty(torch.zeros(10), [10], dtype=4)
      return ret1.fill_(25)
  )");
  m.define(R"(
    def forward2(self):
      ret1 = torch.new_empty(torch.zeros(10), [10], dtype=6)
      return ret1.fill_(32.0)
  )");
  m.define(R"(
    def forward3(self):
      ret1 = torch.new_empty(torch.zeros(10), [10])
      return ret1.fill_(12.0)
  )");

  std::vector<torch::jit::IValue> inputs;
  testLiteModuleCompareResultTensors(m, inputs, "forward");
  testLiteModuleCompareResultTensors(m, inputs, "forward2");
  testLiteModuleCompareResultTensors(m, inputs, "forward3");
}

} // namespace jit
} // namespace torch
@ -24,6 +24,10 @@ struct TORCH_API GraphFunction : public Function {
   void run(Stack& stack) override;
 
+  std::function<void(GraphFunction&)> function_creator() const {
+    return function_creator_;
+  }
+
   c10::intrusive_ptr<c10::ivalue::Future> runAsync(
       Stack& stack,
       TaskLauncher taskLauncher = at::launch) override;
@ -20,6 +20,7 @@ struct Code {
   std::vector<Instruction> instructions_;
   std::vector<DebugHandle> debug_handles_;
   std::vector<c10::OperatorName> op_names_;
+  std::vector<int> operator_input_sizes_;
   std::vector<std::function<void(Stack&)>> operators_;
   std::vector<c10::IValue> constants_;
   std::vector<c10::TypePtr> types_;
@ -23,6 +23,10 @@ class MobileDebugTable {
   MobileDebugTable(
       std::unique_ptr<caffe2::serialize::PyTorchStreamReader>& reader,
       const std::shared_ptr<CompilationUnit>& cu);
+
+  template <typename It>
+  MobileDebugTable(It begin, It end) : callstack_ptr_map_(begin, end) {}
+
   std::string getSourceDebugString(
       const int64_t debug_handle,
       const std::string& top_module_type_name = "ModuleTypeUnknown") const;
@ -36,6 +40,11 @@ class MobileDebugTable {
       const std::vector<int64_t>& debug_handles,
       const std::string& top_module_type_name = "ModuleTypeUnknown") const;
+
+  const ska::flat_hash_map<int64_t, DebugInfoTuple>& getCallStackPtrMap()
+      const {
+    return callstack_ptr_map_;
+  }
+
  private:
   std::pair<std::string, std::string> getSourceDebugModuleHierarchyInfo(
       const std::vector<int64_t>& debug_handles,
@ -13,6 +13,14 @@ namespace mobile {
 Function::Function(c10::QualifiedName name)
     : name_(std::move(name)), code_(std::make_shared<Code>()) {}
 
+Function::Function(
+    c10::QualifiedName name,
+    std::shared_ptr<Code> code,
+    at::optional<c10::FunctionSchema> schema)
+    : name_(std::move(name)),
+      code_(std::move(code)),
+      schema_(std::move(schema)) {}
+
 const c10::QualifiedName& Function::qualname() const {
   return name_;
 }
@ -43,89 +51,11 @@ bool Function::append_operator(
   // Keep the original opname in code_
   code_->op_names_.emplace_back(name, overload_name);
   const auto& opname = code_->op_names_.back();
-  const auto full_name = c10::toString(opname);
-  std::function<void(Stack&)> fn;
-
-  const std::vector<c10::Argument>* pArgs = nullptr;
-  bool promoted_op = mobile::hasPrimOpsFn(full_name);
-  if (promoted_op) {
-    fn = mobile::getPrimOpsFn(full_name);
-  } else {
-    std::shared_ptr<Operator> jit_op = findOperatorFor(opname);
-    if (jit_op) {
-      fn = [jit_op](Stack& stack) { jit_op->getOperation()(stack); };
-      pArgs = &jit_op->schema().arguments();
-    } else {
-      auto op = c10::Dispatcher::singleton().findSchema(opname);
-      if (op.has_value()) {
-        fn = [op](Stack& stack) { op->callBoxed(&stack); };
-        if (op->hasSchema()) {
-          pArgs = &op->schema().arguments();
-        } else {
-          TORCH_CHECK(false, "arguments are missing for operator ", opname);
-        }
-      } else {
-        return false;
-      }
-    }
-  }
-
-  if (!promoted_op) {
-    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(pArgs);
-    const auto& args = *pArgs;
-    if (model_version == 0x3LL && opname.name == "aten::_convolution" &&
-        opname.overload_name.empty()) {
-      // Since byte-code versions 0x4L, convolution has an additional
-      // default-value argument (allow_tf32=True, see
-      // https://github.com/pytorch/pytorch/pull/40737). This wrapper handles
-      // backward compatibility with models of byte-code version <= 0x3L, where
-      // this bool argument does not yet exist.
-      fn = [fn](Stack& stack) {
-        stack.push_back(true);
-        fn(stack);
-      };
-    } else {
-      // num_specified_args >= 0 indicates number of arguments are available
-      // from model. We can use it to handle backward compatibility.
-      if (num_specified_args &&
-          num_specified_args.value() < static_cast<int64_t>(args.size())) {
-        fn = [fn, num_specified_args, &args](Stack& stack) {
-          std::vector<IValue> out_args;
-          // The following logic pops and temporarily stores all out arguments
-          // from the stack (which can be 0 or more, and always appended to the
-          // schema), in order to push the necessary default values. Finally,
-          // the out arguments are pushed back into the stack.
-          for (size_t i = args.size() - 1; i > 0 && args.at(i).is_out(); i--) {
-            out_args.push_back(stack.back());
-            stack.pop_back();
-          }
-          size_t start_index = num_specified_args.value() - out_args.size();
-          TORCH_CHECK(
-              start_index >= 0,
-              "The number of output arguments is: ",
-              out_args.size(),
-              ", which is more then the number of specified arguments: ",
-              num_specified_args.value());
-          for (size_t i = start_index; i < (args.size() - out_args.size());
-               ++i) {
-            TORCH_CHECK(
-                args[i].default_value().has_value(),
-                "Error happened at preparing for default values for the argument. The ",
-                i,
-                "th argument ",
-                args[i].name(),
-                " does not have a specified value or default value. ");
-
-            stack.push_back(args[i].default_value());
-          }
-          stack.insert(stack.end(), out_args.rbegin(), out_args.rend());
-          fn(stack);
-        };
-      }
-    }
-  }
-  code_->operators_.emplace_back(fn);
+  auto func = makeOperatorFunction(opname, num_specified_args, model_version);
+  if (!func.has_value()) {
+    return false;
+  }
+  code_->operators_.emplace_back(*func);
   return true;
 }
@ -197,6 +127,93 @@ const std::vector<int64_t>& Function::getExceptionDebugHandles() const {
   return getInterpretersExceptionDebugHandles();
 }
 
+c10::optional<std::function<void(Stack&)>> makeOperatorFunction(
+    c10::OperatorName opname,
+    c10::optional<int> num_specified_args,
+    int64_t model_version) {
+  std::function<void(Stack&)> fn;
+  const auto full_name = c10::toString(opname);
+  const std::vector<c10::Argument>* pArgs = nullptr;
+  bool promoted_op = mobile::hasPrimOpsFn(full_name);
+  if (promoted_op) {
+    fn = mobile::getPrimOpsFn(full_name);
+  } else {
+    std::shared_ptr<Operator> jit_op = findOperatorFor(opname);
+    if (jit_op) {
+      fn = [jit_op](Stack& stack) { jit_op->getOperation()(stack); };
+      pArgs = &jit_op->schema().arguments();
+    } else {
+      auto op = c10::Dispatcher::singleton().findSchema(opname);
+      if (op.has_value()) {
+        fn = [op](Stack& stack) { op->callBoxed(&stack); };
+        if (op->hasSchema()) {
+          pArgs = &op->schema().arguments();
+        } else {
+          TORCH_CHECK(false, "arguments are missing for operator ", opname);
+        }
+      } else {
+        return c10::nullopt;
+      }
+    }
+  }
+
+  if (!promoted_op) {
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(pArgs);
+    const auto& args = *pArgs;
+    if (model_version == 0x3LL && opname.name == "aten::_convolution" &&
+        opname.overload_name.empty()) {
+      // Since byte-code versions 0x4L, convolution has an additional
+      // default-value argument (allow_tf32=True, see
+      // https://github.com/pytorch/pytorch/pull/40737). This wrapper handles
+      // backward compatibility with models of byte-code version <= 0x3L, where
+      // this bool argument does not yet exist.
+      fn = [fn](Stack& stack) {
+        stack.push_back(true);
+        fn(stack);
+      };
+    } else {
+      // num_specified_args >= 0 indicates number of arguments are available
+      // from model. We can use it to handle backward compatibility.
+      if (num_specified_args &&
+          num_specified_args.value() < static_cast<int64_t>(args.size())) {
+        fn = [fn, num_specified_args, &args](Stack& stack) {
+          std::vector<IValue> out_args;
+          // The following logic pops and temporarily stores all out arguments
+          // from the stack (which can be 0 or more, and always appended to the
+          // schema), in order to push the necessary default values. Finally,
+          // the out arguments are pushed back into the stack.
+          for (size_t i = args.size() - 1; i > 0 && args.at(i).is_out(); i--) {
+            out_args.push_back(stack.back());
+            stack.pop_back();
+          }
+          size_t start_index = num_specified_args.value() - out_args.size();
+          TORCH_CHECK(
+              start_index >= 0,
+              "The number of output arguments is: ",
+              out_args.size(),
+              ", which is more then the number of specified arguments: ",
+              num_specified_args.value());
+          for (size_t i = start_index; i < (args.size() - out_args.size());
+               ++i) {
+            TORCH_CHECK(
+                args[i].default_value().has_value(),
+                "Error happened at preparing for default values for the argument. The ",
+                i,
+                "th argument ",
+                args[i].name(),
+                " does not have a specified value or default value. ");
+
+            stack.push_back(args[i].default_value());
+          }
+          stack.insert(stack.end(), out_args.rbegin(), out_args.rend());
+          fn(stack);
+        };
+      }
+    }
+  }
+  return fn;
+}
+
 } // namespace mobile
 } // namespace jit
 } // namespace torch
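The hunk above factors the operator-resolution and default-argument logic out of Function::append_operator into a free function, makeOperatorFunction, so that the same lookup can be reused when building mobile code directly rather than only during deserialization. A rough usage sketch (the operator name, argument count, and model version below are illustrative values chosen for the example, not taken from the diff):

    using torch::jit::Stack;
    auto fn = torch::jit::mobile::makeOperatorFunction(
        c10::OperatorName("aten::add", "Tensor"),
        /*num_specified_args=*/c10::nullopt,
        /*model_version=*/0x6LL);
    if (fn.has_value()) {
      Stack stack{torch::ones({2}), torch::ones({2}), 1}; // self, other, alpha
      (*fn)(stack);                                       // result is left on the stack
    }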
@ -17,6 +17,10 @@ struct Code;
 class TORCH_API Function : public torch::jit::Function {
  public:
   explicit Function(c10::QualifiedName name);
+  Function(
+      c10::QualifiedName name,
+      std::shared_ptr<Code> code,
+      at::optional<c10::FunctionSchema> schema);
   void run(Stack& stack) override;
   at::IValue operator()(Stack& stack);
   void ensure_defined() override {}
@ -24,6 +28,9 @@ class TORCH_API Function : public torch::jit::Function {
   const c10::QualifiedName& qualname() const override;
   bool call(Stack&, c10::function_ref<void(const mobile::Code&)>) override;
 
+  // NOTE: the APIs below is dangerous: if you call append_instruction with
+  // dbg_handle and then call it without; then the dbg_handle will become
+  // misaligned. Therefore only use ONE variant at time.
   void append_instruction(OpCode op, int X, int N, int64_t dbg_handle);
   void append_instruction(OpCode op, int X, int N);
   bool append_operator(
@ -56,6 +63,11 @@ class TORCH_API Function : public torch::jit::Function {
   at::optional<c10::FunctionSchema> schema_; // (byte-code version 4+)
 };
 
+c10::optional<std::function<void(Stack&)>> makeOperatorFunction(
+    c10::OperatorName opname,
+    c10::optional<int> num_specified_args,
+    int64_t model_version);
+
 } // namespace mobile
 } // namespace jit
 } // namespace torch
@ -94,15 +94,15 @@ bool InterpreterState::run(Stack& stack) {
         debug_handle = *handle;
       }
 
-      // std::cout << "RUNNING " << pc << " "
-      //           << code_->instructions_with_handles_[pc].instruction;
+      // std::cout << "RUNNING " << pc << " " << code.instructions_[pc];
       // if (inst.op == OP) {
-      //   std::cout << ", " << code_->op_names_[inst.X].name;
-      //   if (!code_->op_names_[inst.X].overload_name.empty()) {
-      //     std::cout << "." << code_->op_names_[inst.X].overload_name;
+      //   std::cout << ", " << code.op_names_[inst.X].name;
+      //   if (!code.op_names_[inst.X].overload_name.empty()) {
+      //     std::cout << "." << code.op_names_[inst.X].overload_name;
       //   }
       // }
       // std::cout << std::endl;
+      // std::cout << "top " << stack.back().tagKind() << std::endl;
 
       // TODO(iliacher): remove the workaround after RecordFunction is in
       // Dispatcher
@ -135,7 +135,7 @@ class TORCH_API Module {
   std::unordered_map<std::string, std::string> metadata_;
   std::shared_ptr<CompilationUnit> cu_;
   MobileDebugTable debug_table_;
-  bool has_debug_handles_;
+  bool has_debug_handles_ = false;
 };
 } // namespace mobile
 } // namespace jit
@ -33,7 +33,6 @@ using torch::distributed::autograd::DistAutogradContainer;
 #endif
 
 #include <exception>
-#include <iostream>
 #include <memory>
 #include <mutex>
 #include <ostream>
@ -1,23 +1,333 @@
 #include <torch/csrc/jit/serialization/export_bytecode.h>
 
+#include <utility>
+
 #include <torch/csrc/jit/runtime/instruction.h>
 #include <torch/csrc/jit/serialization/export.h>
 
+#include <c10/util/Exception.h>
+#include <torch/csrc/jit/api/function_impl.h>
+#include <torch/csrc/jit/api/method.h>
+#include <torch/csrc/jit/backends/backend_debug_handler.h>
+#include <torch/csrc/jit/backends/backend_debug_info.h>
+#include <torch/csrc/jit/frontend/source_range.h>
+#include <torch/csrc/jit/ir/attributes.h>
+#include <torch/csrc/jit/ir/ir.h>
+#include <torch/csrc/jit/ir/type_hashing.h>
+#include <torch/csrc/jit/mobile/function.h>
+#include <torch/csrc/jit/mobile/interpreter.h>
+#include <torch/csrc/jit/mobile/method.h>
+#include <torch/csrc/jit/mobile/module.h>
+#include <torch/csrc/jit/passes/inliner.h>
+#include <torch/csrc/jit/serialization/callstack_debug_info_serialization.h>
+#include <torch/csrc/jit/serialization/import_export_constants.h>
+#include <torch/csrc/jit/serialization/import_export_functions.h>
+#include <torch/csrc/jit/serialization/import_export_helpers.h>
+#include <torch/csrc/jit/serialization/pickle.h>
+#include <torch/csrc/jit/serialization/python_print.h>
+#include <torch/csrc/jit/serialization/source_range_serialization.h>
+#include <torch/csrc/jit/serialization/type_name_uniquer.h>
+
+#include <caffe2/serialize/inline_container.h>
+
 namespace torch {
 namespace jit {
 
-void BytecodeExportSet::add(
-    const c10::QualifiedName& qn,
-    ExportedFunction exported) {
-  items_.emplace(qn, std::move(exported));
-}
-
-void BytecodeExportSet::update(const c10::QualifiedName& qn, bool toplevel) {
-  items_.at(qn).toplevel = toplevel;
-}
-
-bool BytecodeExportSet::contains(const c10::QualifiedName& qn) const {
-  return items_.find(qn) != items_.end();
-}
+std::vector<Method> gatherGetSetStates(ObjectPtr obj) {
+  std::vector<Method> methods;
+  // Use DFS on IValue's to traverse dependencies of module._ivalue and
+  // add all setstate/getstates to initial stack.
+  std::vector<ObjectPtr> ivalue_stack;
+  ivalue_stack.emplace_back(obj);
+  while (!ivalue_stack.empty()) {
+    ObjectPtr cur = ivalue_stack.back();
+    ivalue_stack.pop_back();
+    auto type = cur->type();
+    Function* setstate = type->findMethod("__setstate__");
+    Function* getstate = type->findMethod("__getstate__");
+    if (getstate && setstate) {
+      if (setstate->isGraphFunction()) {
+        methods.emplace_back(cur, setstate);
+      }
+      if (getstate->isGraphFunction()) {
+        methods.emplace_back(cur, getstate);
+      }
+    } else {
+      for (size_t i = 0, n = type->numAttributes(); i < n; ++i) {
+        IValue field = cur->getSlot(i);
+        if (field.isObject()) {
+          ivalue_stack.emplace_back(field.toObject());
+        }
+      }
+    }
+  }
+  return methods;
+}
+
+std::vector<Method> findAllDependentFunctions(
+    const Module& module,
+    Graph& graph) {
+  std::vector<Method> methods;
+  std::unordered_set<c10::string_view> called_method_names;
+  auto nodes = findAllNodes(graph, c10::prim::CallMethod, true);
+  for (Node* node : nodes) {
+    if (auto iface = node->input(0)->type()->castRaw<InterfaceType>()) {
+      const FunctionSchema* schema = iface->getMethod(node->s(attr::name));
+      called_method_names.insert(schema->name());
+    }
+  }
+
+  for (const auto& submodule : module.modules()) {
+    for (const auto& m : submodule.get_methods()) {
+      if (called_method_names.find(m.function().qualname().name()) !=
+          called_method_names.end()) {
+        methods.emplace_back(m);
+      }
+    }
+  }
+  return methods;
+}
+
+// NOTE: order of functions returned will be:
+// 1. functions originated from the methods passed in will be first
+// 2. All the dependent functions will come afterwards.
+// This order is meaningful because currently mobile Module looks up
+// methods with linear search.
+std::vector<std::unique_ptr<GraphFunction>> inlineFunctions(
+    const std::vector<Method>& initial_methods,
+    bool incl_dependent_functions) {
+  std::set<std::pair<std::string, Function*>> visited;
+  std::deque<Method> stack;
+  std::copy(
+      initial_methods.begin(),
+      initial_methods.end(),
+      std::back_inserter(stack));
+  std::vector<std::unique_ptr<GraphFunction>> inlined_functions;
+  while (!stack.empty()) {
+    Method cur = stack.front();
+    stack.pop_front();
+    auto tup = std::make_pair(
+        cur.owner()._ivalue()->type()->name()->qualifiedName(),
+        &cur.function());
+    if (visited.find(tup) != visited.end()) {
+      continue;
+    }
+    visited.insert(tup);
+    const auto& f = toGraphFunction(cur.function());
+    auto graph = f.graph()->copyUnique();
+    Inline(*graph);
+    c10::QualifiedName qn(*cur.owner()._ivalue()->type()->name(), f.name());
+
+    if (incl_dependent_functions) {
+      std::vector<Method> dependent_methods =
+          findAllDependentFunctions(cur.owner(), *graph);
+      std::copy(
+          dependent_methods.begin(),
+          dependent_methods.end(),
+          std::back_inserter(stack));
+    }
+    auto inlined_func = std::make_unique<GraphFunction>(
+        qn, std::move(graph), f.function_creator());
+    inlined_func->setSchema(f.getSchema());
+    inlined_functions.emplace_back(std::move(inlined_func));
+  }
+  return inlined_functions;
+}
+
+std::unique_ptr<mobile::Code> compileGraphToMobileCode(
+    const std::string& name,
+    const std::shared_ptr<Graph>& graph,
+    const CompilationOptions& compilation_options,
+    BackendDebugInfoRecorder& debug_info_recorder) {
+  MobileCode code(
+      graph,
+      name,
+      compilation_options.enable_default_value_for_unspecified_arg,
+      compilation_options.enable_default_args_before_out_args);
+
+  std::unique_ptr<mobile::Code> mobile_code_ptr =
+      std::make_unique<mobile::Code>();
+  mobile::Code& mobile_code = *mobile_code_ptr;
+
+  // operator names
+  std::vector<std::string> method_names;
+  std::vector<int64_t> op_debug_handles;
+  int next_new_op_index = 0;
+
+  auto op_to_specified_args = code.op_to_num_specified_args();
+
+  for (size_t i = 0; i < code.instructions().size(); ++i) {
+    Instruction ins = code.instructions()[i];
+
+    if ((ins.op == OP || ins.op == OPN) && ins.X == next_new_op_index) {
+      // Found a new op (assumes new operators ordered by ascending ins.X)
+      auto node = code.instructions_source()[i];
+      const c10::OperatorName& opname = node->schema().operator_name();
+      auto unique_name = c10::toString(opname);
+      // For operator with vararg, adding default arguments would be confusing
+      // and is not allowed. For an operator with num_args = -1, it means the
+      // number of arguments is not available for this operator, we don't do any
+      // backward compatibility adaptation at runtime.
|
||||||
|
c10::optional<int> num_args = c10::nullopt;
|
||||||
|
auto it = op_to_specified_args.find(unique_name);
|
||||||
|
if (it != op_to_specified_args.end()) {
|
||||||
|
num_args = it->second;
|
||||||
|
}
|
||||||
|
mobile_code.operator_input_sizes_.emplace_back(num_args.value_or(-1));
|
||||||
|
mobile_code.op_names_.emplace_back(opname);
|
||||||
|
auto func = mobile::makeOperatorFunction(
|
||||||
|
opname, num_args, compilation_options.model_version);
|
||||||
|
TORCH_INTERNAL_ASSERT(
|
||||||
|
func.has_value(),
|
||||||
|
"Operator with name: ",
|
||||||
|
toString(opname),
|
||||||
|
" not found");
|
||||||
|
mobile_code.operators_.emplace_back(*func);
|
||||||
|
next_new_op_index++;
|
||||||
|
}
|
||||||
|
// CALL nodes at this point represent built-in (i.e. non-Graph)
|
||||||
|
// functions that were not inlined. Here we convert the CALL
|
||||||
|
// instructions for these functions into INTERFACE_CALL instructions
|
||||||
|
// s.t. at runtime, we will look up the Function* on the Type of the
|
||||||
|
// 0th argument in the stack and call that directly.
|
||||||
|
if (ins.op == CALL) {
|
||||||
|
auto node = code.instructions_source()[i];
|
||||||
|
if (node->kind() == prim::CallMethod) {
|
||||||
|
// NB: replacing instruction
|
||||||
|
auto method_name_idx =
|
||||||
|
code.constant_table().size() + method_names.size();
|
||||||
|
method_names.emplace_back(node->s(attr::name));
|
||||||
|
ins = Instruction{
|
||||||
|
INTERFACE_CALL,
|
||||||
|
static_cast<int32_t>(method_name_idx),
|
||||||
|
static_cast<uint16_t>(node->inputs().size())};
|
||||||
|
} else {
|
||||||
|
TORCH_INTERNAL_ASSERT(
|
||||||
|
false, "Unsupported node kind on CALL opcode for mobile");
|
||||||
|
}
|
||||||
|
} else if (ins.op == RET) {
|
||||||
|
auto node = code.instructions_source()[i];
|
||||||
|
for (const auto& input : node->inputs()) {
|
||||||
|
const auto& input_type = input->type();
|
||||||
|
if (input_type->kind() == TypeKind::ListType ||
|
||||||
|
input_type->kind() == TypeKind::DictType) {
|
||||||
|
for (const TypePtr& element_type : input_type->containedTypes()) {
|
||||||
|
TORCH_CHECK(
|
||||||
|
element_type->kind() != TypeKind::ClassType,
|
||||||
|
"Returining a list or dictionary with pytorch class type ",
|
||||||
|
"is not supported in mobile module "
|
||||||
|
"(List[Foo] or Dict[int, Foo] for class Foo(torch.nn.Module)). "
|
||||||
|
"Workaround: instead of using pytorch class as their element type, ",
|
||||||
|
"use a combination of list, dictionary, and single types.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
TORCH_CHECK(
|
||||||
|
isOpSupportedInMobile(ins.op),
|
||||||
|
toString(ins.op),
|
||||||
|
" is not supported in mobile module.");
|
||||||
|
}
|
||||||
|
auto node = code.instructions_source()[i];
|
||||||
|
int64_t debug_handle = debug_info_recorder.getNextDebugHandle(node);
|
||||||
|
// Note 1-to-1 correspondence between instructions and debug handles
|
||||||
|
mobile_code.instructions_.emplace_back(ins);
|
||||||
|
mobile_code.debug_handles_.emplace_back(debug_handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
// copy constants
|
||||||
|
mobile_code.constants_ = code.constant_table();
|
||||||
|
|
||||||
|
// Make a copy of the constants and append the method names
|
||||||
|
// that we emitted for the converted INTERFACE_CALL nodes above.
|
||||||
|
for (auto& method_name : method_names) {
|
||||||
|
mobile_code.constants_.emplace_back(method_name);
|
||||||
|
}
|
||||||
|
|
||||||
|
mobile_code.types_ = code.type_table();
|
||||||
|
mobile_code.register_size_ = code.register_size();
|
||||||
|
return mobile_code_ptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
void checkSchema(const FunctionSchema& schema) {
|
||||||
|
TORCH_CHECK(
|
||||||
|
schema.overload_name().empty(), // @TODO: is this check correct?
|
||||||
|
"Overloads are not supported in mobile modules.");
|
||||||
|
TORCH_CHECK(
|
||||||
|
!schema.is_vararg(), "Python *args are not supported in mobile modules.");
|
||||||
|
TORCH_CHECK(
|
||||||
|
!schema.is_varret(),
|
||||||
|
"A variable number of return values is not supported in mobile modules.");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isLoweredModule(const Module& m) {
|
||||||
|
c10::QualifiedName type_name;
|
||||||
|
if (m.type()->name()) {
|
||||||
|
type_name = m.type()->name().value();
|
||||||
|
}
|
||||||
|
bool isLoweredModule = false;
|
||||||
|
for (const auto& atom : type_name.atoms()) {
|
||||||
|
if (atom == "LoweredModule") {
|
||||||
|
isLoweredModule = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return isLoweredModule;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the global static map of backend debug info
|
||||||
|
// contains debug info for this module and any of its children.
|
||||||
|
// If so combine all the maps together and return one.
|
||||||
|
void getBackendDebugInfoMap(
|
||||||
|
const Module& m,
|
||||||
|
BackendDebugInfoMapType& debug_map) {
|
||||||
|
if (isLoweredModule(m)) {
|
||||||
|
auto backend_debug_info =
|
||||||
|
m.attr("__backend_debug_info").toCustomClass<PyTorchBackendDebugInfo>();
|
||||||
|
const auto& map = backend_debug_info->getDebugInfoMap();
|
||||||
|
if (map) {
|
||||||
|
debug_map.insert(map.value().begin(), map.value().end());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (const auto& c : m.children()) {
|
||||||
|
getBackendDebugInfoMap(c, debug_map);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mobile::Module jitModuleToMobile(
|
||||||
|
const Module& module,
|
||||||
|
const CompilationOptions& options) {
|
||||||
|
std::shared_ptr<mobile::CompilationUnit> mcu =
|
||||||
|
std::make_shared<mobile::CompilationUnit>();
|
||||||
|
BackendDebugInfoRecorder debug_info_recorder;
|
||||||
|
|
||||||
|
std::vector<Method> methods_to_export = module.get_methods();
|
||||||
|
std::vector<Method> getsetstates = gatherGetSetStates(module._ivalue());
|
||||||
|
std::copy(
|
||||||
|
getsetstates.begin(),
|
||||||
|
getsetstates.end(),
|
||||||
|
std::back_inserter(methods_to_export));
|
||||||
|
|
||||||
|
for (const auto& func :
|
||||||
|
inlineFunctions(methods_to_export, options.incl_interface_call)) {
|
||||||
|
std::shared_ptr<mobile::Code> mobile_code_ptr = compileGraphToMobileCode(
|
||||||
|
func->name(), func->graph(), options, debug_info_recorder);
|
||||||
|
const auto& schema = func->getSchema();
|
||||||
|
checkSchema(schema);
|
||||||
|
auto mobile_func = std::make_unique<mobile::Function>(
|
||||||
|
func->qualname(), mobile_code_ptr, schema);
|
||||||
|
mcu->register_function(std::move(mobile_func));
|
||||||
|
}
|
||||||
|
|
||||||
|
mobile::Module m(module._ivalue(), mcu);
|
||||||
|
m.setHasDebugHandles(true);
|
||||||
|
BackendDebugInfoMapType backend_debug_info_map;
|
||||||
|
getBackendDebugInfoMap(module, backend_debug_info_map);
|
||||||
|
auto debug_handle_cs_ptr_map = debug_info_recorder.stopRecording();
|
||||||
|
debug_handle_cs_ptr_map.insert(
|
||||||
|
backend_debug_info_map.begin(), backend_debug_info_map.end());
|
||||||
|
m.setDebugTable(MobileDebugTable(
|
||||||
|
debug_handle_cs_ptr_map.begin(), debug_handle_cs_ptr_map.end()));
|
||||||
|
return m;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
|
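Taken together, jitModuleToMobile() simply composes the helpers above: gather the methods (plus any __getstate__/__setstate__), inline them, compile each graph to mobile::Code, and register the resulting mobile::Function objects. A minimal sketch of lowering one inlined function the way that loop does, purely for illustration (the helper name lowerOne is an assumption, not part of this commit, and it presumes the surrounding torch::jit namespace and headers above):

    // Hypothetical helper, not in the commit; mirrors the body of the loop in
    // jitModuleToMobile() for a single function produced by inlineFunctions().
    static std::unique_ptr<mobile::Function> lowerOne(
        const GraphFunction& func,
        const CompilationOptions& options,
        BackendDebugInfoRecorder& recorder) {
      // Emit mobile::Code (instructions, operators, constants, types, debug
      // handles) from the already-inlined graph.
      std::shared_ptr<mobile::Code> code = compileGraphToMobileCode(
          func.name(), func.graph(), options, recorder);
      // Reject schemas the mobile runtime cannot represent.
      checkSchema(func.getSchema());
      return std::make_unique<mobile::Function>(
          func.qualname(), code, func.getSchema());
    }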
@ -1,59 +1,31 @@
 #pragma once
 
+#include <tuple>
 #include <unordered_map>
 
 #include <ATen/core/function_schema.h>
 #include <ATen/core/ivalue.h>
+#include <ATen/core/jit_type.h>
 #include <ATen/core/qualified_name.h>
 #include <torch/csrc/jit/backends/backend_debug_handler.h>
+#include <torch/csrc/jit/mobile/function.h>
+#include <torch/csrc/jit/mobile/module.h>
 #include <torch/csrc/jit/runtime/interpreter.h>
 #include <torch/csrc/jit/serialization/type_name_uniquer.h>
 
 namespace torch {
 namespace jit {
 
-struct ExportedFunction {
-  ExportedFunction(
-      const Module& m,
-      const Function& f,
-      std::unique_ptr<Graph> g,
-      bool t)
-      : mod(m), function(f), optimizedGraph(std::move(g)), toplevel(t) {}
-  Module mod;
-  const Function& function;
-  std::unique_ptr<Graph> optimizedGraph;
-  bool toplevel;
-};
-
-class TORCH_API BytecodeExportSet {
- public:
-  BytecodeExportSet() = default;
-  BytecodeExportSet(const BytecodeExportSet&) = delete;
-  BytecodeExportSet& operator=(const BytecodeExportSet&) = delete;
-  BytecodeExportSet(BytecodeExportSet&&) = default;
-  BytecodeExportSet& operator=(BytecodeExportSet&&) = default;
-
-  void add(const c10::QualifiedName& qn, ExportedFunction);
-  void update(const c10::QualifiedName& qn, bool toplevel);
-  bool contains(const c10::QualifiedName& qn) const;
-
-  template <typename F>
-  void visit(F&& f) {
-    for (auto& item : items_) {
-      if (item.second.toplevel) {
-        f(item.first, item.second);
-      }
-    }
-    for (auto& item : items_) {
-      if (!item.second.toplevel) {
-        f(item.first, item.second);
-      }
-    }
-  }
-
- private:
-  std::unordered_map<c10::QualifiedName, ExportedFunction> items_;
-};
+struct TORCH_API CompilationOptions {
+  bool incl_interface_call = false;
+  bool enable_default_value_for_unspecified_arg = false;
+  bool enable_default_args_before_out_args = true;
+  int model_version = caffe2::serialize::kProducedBytecodeVersion;
+};
+
+TORCH_API mobile::Module jitModuleToMobile(
+    const Module& module,
+    const CompilationOptions& options);
 
 } // namespace jit
 } // namespace torch
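The new header reduces the public surface to one options struct and one entry point. A hypothetical usage sketch (not part of the diff): converting a scripted module in memory and obtaining a runnable mobile::Module without a save/reload round trip. The wrapper name toMobile and the choice to flip incl_interface_call are assumptions for illustration only.

    // Assumes a torch::jit::Module `scripted` built elsewhere.
    #include <torch/csrc/jit/serialization/export_bytecode.h>

    torch::jit::mobile::Module toMobile(const torch::jit::Module& scripted) {
      torch::jit::CompilationOptions options;  // field defaults shown above
      options.incl_interface_call = true;      // assumption: also export interface calls
      return torch::jit::jitModuleToMobile(scripted, options);
    }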
@ -38,6 +38,18 @@
 namespace torch {
 namespace jit {
 
+CompilationOptions getOptionsFromGlobal() {
+  CompilationOptions compilation_options;
+  compilation_options.enable_default_args_before_out_args =
+      BytecodeEmitMode::is_default_args_before_out_args_enabled();
+  compilation_options.enable_default_value_for_unspecified_arg =
+      BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled();
+  compilation_options.incl_interface_call = getMobileInterfaceCallExport();
+  compilation_options.model_version =
+      caffe2::serialize::kProducedBytecodeVersion;
+  return compilation_options;
+}
+
 IValue to_tuple(std::initializer_list<IValue> ivalues) {
   return c10::ivalue::Tuple::create(ivalues);
 }
@ -63,138 +75,49 @@ ExportModuleExtraFilesHook& GetExtraFilesHook() {
 }
 
 std::pair<IValue, IValue> getFunctionTuple(
-    const Module& module,
-    const Function& func,
-    std::unique_ptr<Graph> optimizedGraph,
+    const CompilationUnit& compilation_unit,
+    const mobile::Function& func,
     BackendDebugInfoRecorder& debug_info_recorder,
-    const std::string& qn,
     TypeNameUniquer& type_name_uniquer_) {
-  TORCH_INTERNAL_ASSERT(optimizedGraph);
-  std::shared_ptr<MobileCode> code;
-  code = std::make_shared<MobileCode>(
-      std::move(optimizedGraph),
-      func.name(),
-      BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled() /* emit_default_input_instructions */,
-      BytecodeEmitMode::is_default_args_before_out_args_enabled() /* enable_defaults_args_with_out_args */);
-  auto instructions_copy = code->instructions();
-
-  // operator names
-  std::vector<c10::OperatorName> opnames;
-  std::vector<std::string> method_names;
-  std::vector<int64_t> op_debug_handles;
-  int next_new_op_index = 0;
-  for (size_t i = 0; i < instructions_copy.size(); ++i) {
-    Instruction ins = instructions_copy[i];
-    if ((ins.op == OP || ins.op == OPN) && ins.X == next_new_op_index) {
-      // Found a new op (assumes new operators ordered by ascending ins.X)
-      auto node = code->instructions_source()[i];
-      opnames.emplace_back(node->schema().operator_name());
-      next_new_op_index++;
-    }
-    // CALL nodes at this point represent built-in (i.e. non-Graph)
-    // functions that were not inlined. Here we convert the CALL
-    // instructions for these functions into INTERFACE_CALL instructions
-    // s.t. at runtime, we will look up the Function* on the Type of the
-    // 0th argument in the stack and call that directly.
-    if (ins.op == CALL) {
-      auto node = code->instructions_source()[i];
-      if (node->kind() == prim::CallMethod) {
-        // NB: replacing instruction
-        auto method_name_idx =
-            code->constant_table().size() + method_names.size();
-        method_names.emplace_back(node->s(attr::name));
-        Instruction new_instr{
-            INTERFACE_CALL,
-            static_cast<int32_t>(method_name_idx),
-            static_cast<uint16_t>(node->inputs().size())};
-        instructions_copy[i] = new_instr;
-      } else {
-        TORCH_INTERNAL_ASSERT(
-            false, "Unsupported node kind on CALL opcode for mobile");
-      }
-    } else if (ins.op == RET) {
-      auto node = code->instructions_source()[i];
-      for (const auto& input : node->inputs()) {
-        const auto& input_type = input->type();
-        if (input_type->kind() == TypeKind::ListType ||
-            input_type->kind() == TypeKind::DictType) {
-          for (const TypePtr& element_type : input_type->containedTypes()) {
-            TORCH_CHECK(
-                element_type->kind() != TypeKind::ClassType,
-                "Returining a list or dictionary with pytorch class type ",
-                "is not supported in mobile module "
-                "(List[Foo] or Dict[int, Foo] for class Foo(torch.nn.Module)). "
-                "Workaround: instead of using pytorch class as their element type, ",
-                "use a combination of list, dictionary, and single types.");
-          }
-        }
-      }
-    } else {
-      TORCH_CHECK(
-          isOpSupportedInMobile(ins.op),
-          toString(ins.op),
-          " is not supported in mobile module.");
-    }
-    auto node = code->instructions_source()[i];
-    int64_t debug_handle = debug_info_recorder.getNextDebugHandle(node);
-    // Note 1-to-1 correspondence between instructions and debug handles
-    op_debug_handles.emplace_back(debug_handle);
-  }
+  const std::shared_ptr<mobile::Code> mobile_code_ptr = func.get_code();
 
   // instructions
   std::vector<IValue> instructions;
-  instructions.reserve(instructions_copy.size());
-  for (Instruction ins : instructions_copy) {
+  instructions.reserve(mobile_code_ptr->instructions_.size());
+  for (Instruction ins : mobile_code_ptr->instructions_) {
     instructions.emplace_back(to_tuple({toString(ins.op), ins.X, ins.N}));
   }
 
   // operators
   std::vector<IValue> operators;
-  auto op_to_specified_args = code->op_to_num_specified_args();
-  operators.reserve(opnames.size());
-  for (const auto& opname : opnames) {
-    auto unique_name = c10::toString(opname);
-    // For operator with vararg, adding default arguments would be confusing and
-    // is not allowed. For an operator with num_args = -1, it means the number
-    // of arguments is not available for this operator, we don't do any backward
-    // compatibility adaptation at runtime.
-    int num_args = -1;
-    auto it = op_to_specified_args.find(unique_name);
-    if (it != op_to_specified_args.end()) {
-      num_args = it->second;
-    }
+  operators.reserve(mobile_code_ptr->op_names_.size());
+  for (int i = 0; i < mobile_code_ptr->op_names_.size(); ++i) {
+    const auto& opname = mobile_code_ptr->op_names_[i];
+    const int size = mobile_code_ptr->operator_input_sizes_[i];
     if (BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled()) {
       operators.emplace_back(to_tuple({opname.name, opname.overload_name}));
     } else {
       operators.emplace_back(
-          to_tuple({opname.name, opname.overload_name, num_args}));
+          to_tuple({opname.name, opname.overload_name, size}));
     }
   }
 
-  // constants
-  //
-  // Make a copy of the constants and append the method names
-  // that we emitted for the converted INTERFACE_CALL nodes above.
-  auto constants = code->constant_table();
-  for (auto& method_name : method_names) {
-    constants.emplace_back(std::move(method_name));
-  }
-
   // types
   std::vector<IValue> types;
-  types.reserve(code->type_table().size());
+  types.reserve(mobile_code_ptr->types_.size());
   static const std::string torch_prefix("__torch__");
   static const std::string class_prefix("__torch__.torch.classes");
-  std::shared_ptr<torch::jit::CompilationUnit> cu =
-      module._ivalue()->compilation_unit();
 
-  for (const TypePtr& t : code->type_table()) {
+  for (const TypePtr& t : mobile_code_ptr->types_) {
     std::string type_str = t->annotation_str();
     if (t->kind() == TypeKind::TupleType) {
       TORCH_CHECK(
-          cu->get_named_tuple(t->str()),
+          compilation_unit.get_named_tuple(t->str()),
           "Can't find definition for the qualified name: ",
           t->str(),
           "(TypeKind::TupleType) in compilation unit.",
           "Please report a bug to PyTorch.");
-      auto named_tuple_type = cu->get_named_tuple(t->str());
+      auto named_tuple_type = compilation_unit.get_named_tuple(t->str());
       if (named_tuple_type != nullptr) {
         std::string named_tuple_str = t->str();
         named_tuple_str.append("[NamedTuple, [");
@ -254,12 +177,12 @@ std::pair<IValue, IValue> getFunctionTuple(
 
   // since the register location is embedded into the bytecode, pass the
   // register size
-  auto register_size = static_cast<int>(code->register_size());
+  auto register_size = static_cast<int>(mobile_code_ptr->register_size_);
 
   auto codeTable = Table(
       {{"instructions", to_tuple(instructions)},
        {"operators", to_tuple(operators)},
-       {"constants", to_tuple(constants)},
+       {"constants", to_tuple(mobile_code_ptr->constants_)},
        {"types", to_tuple(types)},
        {"register_size", register_size}});
 
@ -273,14 +196,7 @@ std::pair<IValue, IValue> getFunctionTuple(
     }
     return c10::nullopt;
   };
-  TORCH_CHECK(
-      schema.overload_name().empty(), // @TODO: is this check correct?
-      "Overloads are not supported in mobile modules.");
-  TORCH_CHECK(
-      !schema.is_vararg(), "Python *args are not supported in mobile modules.");
-  TORCH_CHECK(
-      !schema.is_varret(),
-      "A variable number of return values is not supported in mobile modules.");
   auto makeArgTuple = [&](const std::vector<Argument>& args) {
     std::vector<IValue> argTables;
     for (auto&& arg : args) {
@ -315,6 +231,17 @@ std::pair<IValue, IValue> getFunctionTuple(
   });
 
   // function tuple
+  std::string qn;
+  if (func.name() == "__setstate__" || func.name() == "__getstate__") {
+    auto classtype = func.getSchema().arguments()[0].type()->cast<ClassType>();
+    TORCH_INTERNAL_ASSERT(
+        classtype, "class is null ", func.qualname().qualifiedName());
+    qn = c10::QualifiedName(
+             type_name_uniquer_.getUniqueName(classtype), func.name())
+             .qualifiedName();
+  } else {
+    qn = func.qualname().qualifiedName();
+  }
   auto bytecode_vals = to_tuple({qn, codeTable, schemaTable});
 
   c10::optional<IValue> debug_info_vals;
@ -324,41 +251,27 @@ std::pair<IValue, IValue> getFunctionTuple(
   // debug handles generated by debug_handle_manager
   // will correspond to {source_range, inlinedCallStackPtr} which we will
   // serialize separately.
-  IValue module_debug_tuple = c10::ivalue::Tuple::create(op_debug_handles);
+  IValue module_debug_tuple =
+      c10::ivalue::Tuple::create(mobile_code_ptr->debug_handles_);
   auto function_debug_info =
       Table({{"function_debug_handles", module_debug_tuple}});
   debug_info_vals = to_tuple({qn, function_debug_info});
   return std::make_pair(bytecode_vals, debug_info_vals);
 }
 
-void pushFunctionToIValues(
-    BytecodeExportSet exportSet,
+void pushMobileFunctionsToIValues(
+    const CompilationUnit& compilation_unit,
+    const mobile::Module& module,
     std::vector<c10::IValue>& elements,
     std::vector<c10::IValue>& debugInfoElements,
     BackendDebugInfoRecorder& recorder,
     TypeNameUniquer& uniquer) {
-  exportSet.visit(
-      [&](const c10::QualifiedName& qn, ExportedFunction& exported) {
-        auto tuple = getFunctionTuple(
-            exported.mod,
-            exported.function,
-            std::move(exported.optimizedGraph),
-            recorder,
-            qn.qualifiedName(),
-            uniquer);
-        elements.push_back(std::move(tuple.first));
-        debugInfoElements.push_back(std::move(tuple.second));
-      });
-}
-
-void pushFunctionToIValues(
-    BytecodeExportSet exportSet,
-    std::vector<c10::IValue>& elements,
-    BackendDebugInfoRecorder& recorder,
-    TypeNameUniquer& uniquer) {
-  std::vector<c10::IValue> debugInfoElements;
-  pushFunctionToIValues(
-      std::move(exportSet), elements, debugInfoElements, recorder, uniquer);
+  for (const auto& method : module.get_methods()) {
+    auto tuple = getFunctionTuple(
+        compilation_unit, method.function(), recorder, uniquer);
+    elements.push_back(std::move(tuple.first));
+    debugInfoElements.push_back(std::move(tuple.second));
+  }
 }
 
 std::unordered_set<const FunctionSchema*> getInterfaceCalls(Graph& graph) {
@ -402,61 +315,6 @@ std::vector<ModuleMethod> getModuleInterfaceExports(
   return ret;
 }
 
-void exportFunction(
-    BytecodeExportSet& exportSet,
-    const ModuleMethod& method,
-    bool toplevel = false) {
-  const auto& func = method.function;
-  const auto& qn = method.exportName;
-  if (exportSet.contains(qn)) {
-    if (toplevel) {
-      exportSet.update(qn, toplevel);
-    }
-    return;
-  }
-  auto graph = func.graph()->copyUnique();
-  Inline(*graph);
-  auto interfaceCalls = getInterfaceCalls(*graph);
-  exportSet.add(
-      qn, ExportedFunction{method.module, func, std::move(graph), toplevel});
-
-  if (!getMobileInterfaceCallExport()) {
-    return;
-  }
-
-  auto interfaces = getModuleInterfaceExports(method.module, interfaceCalls);
-  for (const auto& interface : interfaces) {
-    exportFunction(exportSet, interface);
-  }
-}
-
-void setstateTuple(
-    BytecodeExportSet& exportSet,
-    const Module& module,
-    const IValue& ivalue,
-    TypeNameUniquer& type_name_uniquer_,
-    bool toplevel = false) {
-  if (!ivalue.isObject())
-    return;
-  auto obj = ivalue.toObject();
-  auto type = obj->type();
-  if (checkHasValidSetGetState(type)) {
-    Function& setstate = type->getMethod("__setstate__");
-    auto qn = type_name_uniquer_.getUniqueName(obj->type()).qualifiedName() +
-        "." + setstate.name();
-    if (exportSet.contains(qn)) {
-      return;
-    }
-    if (auto f = tryToGraphFunction(setstate)) {
-      exportFunction(exportSet, ModuleMethod{module, *f, qn}, toplevel);
-    }
-  } else {
-    for (size_t i = 0, n = type->numAttributes(); i < n; ++i) {
-      setstateTuple(exportSet, module, obj->getSlot(i), type_name_uniquer_);
-    }
-  }
-}
-
 bool isLoweredModule(const Module& m) {
   c10::QualifiedName type_name;
   if (m.type()->name()) {
@ -544,24 +402,6 @@ bool getMobileInterfaceCallExport() {
   return mobileInterfaceCallExport().load(std::memory_order_relaxed);
 }
 
-BytecodeExportSet moduleMethodsTuple(
-    const Module& module,
-    TypeNameUniquer& type_name_uniquer_) {
-  BytecodeExportSet exportSet;
-  auto methods = module.get_methods();
-  // top level methods
-  for (const auto& method : methods) {
-    const auto& f = toGraphFunction(method.function());
-    exportFunction(
-        exportSet, ModuleMethod{module, f, f.qualname()}, /* toplevel */ true);
-  }
-
-  // __setstate__ of all components
-  setstateTuple(exportSet, module, module._ivalue(), type_name_uniquer_, true);
-
-  return exportSet;
-}
-
 void SetExportModuleExtraFilesHook(ExportModuleExtraFilesHook hook) {
   GetExtraFilesHook() = std::move(hook);
 }
@ -774,9 +614,12 @@ void ScriptModuleSerializer::writeByteCode(
   // Always save debug handles
   debug_info_elements.emplace_back(static_cast<int64_t>(version_to_write));
 
-  BytecodeExportSet exportSet = moduleMethodsTuple(module, type_name_uniquer_);
-  pushFunctionToIValues(
-      std::move(exportSet),
+  mobile::Module mobile_module =
+      jitModuleToMobile(module, getOptionsFromGlobal());
+  pushMobileFunctionsToIValues(
+      *module._ivalue()->compilation_unit(),
+      mobile_module,
       elements,
       debug_info_elements,
       debug_info_recorder,
@ -840,9 +683,9 @@ void ScriptModuleSerializer::writeByteCode(
   getBackendDebugInfoMap(module, backend_debug_info_map);
   // Now get the debug-handles-to-inlined-cs-ptr-map
   // And serialize that in a separate archive
-  auto debug_handle_cs_ptr_map = debug_info_recorder.stopRecording();
-  debug_handle_cs_ptr_map.insert(
-      backend_debug_info_map.begin(), backend_debug_info_map.end());
+  const auto& debug_info = mobile_module.getDebugTable().getCallStackPtrMap();
+  BackendDebugInfoMapType debug_handle_cs_ptr_map(
+      debug_info.begin(), debug_info.end());
   CallStackDebugInfoPickler cs_debug_info_pickler;
   auto cs_data = cs_debug_info_pickler.pickle(
       debug_handle_cs_ptr_map, source_range_tags_);
@ -962,31 +805,13 @@ void ExportModule(
 
 namespace {
 void export_opnames(const script::Module& m, std::set<std::string>& opnames) {
-  std::vector<c10::IValue> elements;
-  BackendDebugInfoRecorder dummy;
-  TypeNameUniquer dummy_uniquer = TypeNameUniquer();
-  BytecodeExportSet exportSet = moduleMethodsTuple(m, dummy_uniquer);
-  pushFunctionToIValues(std::move(exportSet), elements, dummy, dummy_uniquer);
-  for (const auto& element : elements) {
-    auto table = element.toTupleRef().elements()[1];
-    auto row =
-        table.toTupleRef().elements().at(BYTECODE_INDEX_OPERATOR).toTuple();
-    TORCH_INTERNAL_ASSERT(
-        row->elements().at(0).toStringRef() == "operators",
-        "Expected operators but found ",
-        row->elements().at(0).toStringRef());
-    const auto& ops_list = row->elements().at(1).toTupleRef().elements();
-    for (const auto& op : ops_list) {
-      const auto& op_item = op.toTupleRef().elements();
-      TORCH_CHECK(
-          op_item.size() >= 2,
-          "There should be either two parts (name and overload name), ",
-          "or three parts (name, overload name and number of specified args) ",
-          "for an operator.");
-      auto opname = op_item[0].toString()->string();
-      auto overload = op_item[1].toString()->string();
+  mobile::Module mobile_m = jitModuleToMobile(m, getOptionsFromGlobal());
+  for (const auto& method : mobile_m.get_methods()) {
+    for (const auto& op : method.function().get_code()->op_names_) {
       // NOLINTNEXTLINE(performance-inefficient-string-concatenation)
-      opnames.emplace(overload.empty() ? opname : opname + "." + overload);
+      opnames.emplace(
+          op.overload_name.empty() ? op.name
+                                   : op.name + "." + op.overload_name);
     }
   }
 }
|
Reference in New Issue
Block a user