Create method to map JIT module to (source, constant) and back. (#74119)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74119

  Implemented a function that generates a JIT module's source as an
  ExtraFilesMap together with its constants.

  Wrote a function that constructs a JIT module from an (ivalue, source,
  constants) triple.
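
  A minimal usage sketch of the round trip, mirroring the
  roundtripThroughMobile helper added to the tests in this diff. The include
  list is a best guess assembled from the test's includes and the headers
  touched here; roundTrip is an illustrative name, not part of the API, and
  version 8 is simply the value the test passes.

    #include <torch/csrc/jit/api/module.h>
    #include <torch/csrc/jit/mobile/module.h>
    #include <torch/csrc/jit/serialization/export_bytecode.h>
    #include <torch/csrc/jit/serialization/import.h>
    #include <torch/csrc/jit/serialization/python_print.h>

    using namespace torch::jit;

    Module roundTrip(const Module& m) {
      // Forward direction: JIT module -> (source, constants).
      ExtraFilesMap sources;
      std::vector<IValue> constants;
      jitModuleToPythonCodeAndConstants(m, &sources, &constants);

      // The object state (ivalue) comes from the lowered mobile module.
      CompilationOptions options;
      mobile::Module mobile_m = jitModuleToMobile(m, options);

      // Backward direction: (ivalue, source, constants) -> JIT module.
      // The last argument is the source version handed to the importer.
      return jitModuleFromSourceAndConstants(
          mobile_m._ivalue(), sources, constants, /*version=*/8);
    }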

Test Plan: unittest

Reviewed By: pavithranrao

Differential Revision: D34803945

fbshipit-source-id: 2edc798407fe68294cb4c3c7516f5bd143df88c3
(cherry picked from commit 35e54e166b8f0f5cfe8f08c07866b59ae61ee79d)
Author: Han Qi
Date: 2022-03-15 11:24:20 -07:00
Committed by: PyTorch MergeBot
Parent: cc5f8aea5c
Commit: ded82ad7c7
5 changed files with 268 additions and 0 deletions


@@ -3,7 +3,9 @@
#include <test/cpp/jit/test_utils.h>
#include <sstream>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/export_bytecode.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/serialization/import_source.h>
#include <torch/torch.h>
@@ -13,6 +15,20 @@
namespace torch {
namespace jit {
namespace {
Module roundtripThroughMobile(const Module& m) {
ExtraFilesMap files;
std::vector<IValue> constants;
jitModuleToPythonCodeAndConstants(m, &files, &constants);
CompilationOptions options;
mobile::Module mobilem = jitModuleToMobile(m, options);
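// Rebuild the JIT module from the mobile module's object state, the printed
// source and the constants; the trailing argument is the source `version`
// forwarded to the importer.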
return jitModuleFromSourceAndConstants(
mobilem._ivalue(), files, constants, 8);
}
} // namespace
TEST(SerializationTest, ExtraFilesHookPreference) {
// Tests that an extra file written explicitly has precedence over
// extra files written by a hook
@@ -149,5 +165,78 @@ TEST(SerializationTest, TestJitStream_CUDA) {
// Check if both the output tensors are equal
ASSERT_TRUE(op.equal(c));
}
TEST(TestSourceRoundTrip, UpsampleNearest2d) {
Module m("m");
m.define(R"(
  def forward(self, input: Tensor, scale:float):
    return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
)");
std::vector<IValue> inputs;
inputs.emplace_back(torch::rand({1, 3, 128, 128}));
inputs.emplace_back(at::Scalar(2.0));
auto ref = m.forward(inputs);
Module m2 = roundtripThroughMobile(m);
auto res = m2.forward(inputs);
auto resd = res.toTensor();
auto refd = ref.toTensor();
ASSERT_TRUE(resd.equal(refd));
}
TEST(TestSourceRoundTrip, CheckAttrAccess) {
Module m("m");
m.register_attribute("mobile_optimized", BoolType::get(), true);
Module m2 = roundtripThroughMobile(m);
bool mobile_optimized = m2.attr("mobile_optimized", false).toBool();
AT_ASSERT(mobile_optimized);
}
TEST(TestSourceRoundTrip,
MethodInvocation) { // NOLINT (use =delete in gtest)
const std::vector<std::string> test_programs{
// test invoking a method with default parameter
R"(
  def test_func(self, x, b : int = 4):
    return self.foo + x + b
)",
// inner method call with default parameter (gets inlined)
R"(
  def add_with_default_arg(self, x, b : int = 4):
    return self.foo + x + b
  def test_func(self, x):
    return self.add_with_default_arg(x)  # invoke method w/ default arg
)",
// simple method call
R"(
  def test_func(self, x):
    b = 4
    return self.foo + x + b
)",
};
for (const auto& test_program : test_programs) {
Module m("m");
m.register_parameter("foo", torch::ones({}), false);
m.define(test_program);
const int fortyTwo = 42; // (keep linter happy)
auto minput = fortyTwo * torch::ones({});
auto ref = m.run_method("test_func", minput);
Module m2 = roundtripThroughMobile(m);
const auto& test_func = m2.get_method("test_func");
IValue res;
for (int i = 0; i < 3; ++i) {
res = test_func({minput});
}
auto resd = res.toTensor().item<float>();
auto refd = ref.toTensor().item<float>();
AT_ASSERT(resd == refd);
}
}
} // namespace jit
} // namespace torch


@@ -403,5 +403,75 @@ Module load(
return deserializer.deserialize(device, extra_files);
}
// Replace the object with a newly created but equivalent object.
// The goal is to replace the object's methods; however, since an object's
// methods are attached to its type, we need to replace its type.
// Non-objects are returned unchanged. Nested structures such as lists and
// dicts are also reconstructed, because they might contain an object.
static IValue recreateObject(IValue ivalue, TypeResolver resolver) {
if (ivalue.isObject()) {
auto obj = ivalue.toObject();
auto classtype_old = obj->type();
auto newtype = resolver(*classtype_old->name());
size_t n = classtype_old->numAttributes();
auto newobj = c10::ivalue::Object::create(newtype, n);
for (const auto i : c10::irange(n)) {
newobj->setSlot(i, recreateObject(obj->getSlot(i), resolver));
}
return newobj;
} else if (ivalue.isList()) {
auto res = c10::impl::GenericList(ivalue.type()->containedType(0));
for (const auto& ival : ivalue.toList()) {
res.emplace_back(recreateObject(ival, resolver));
}
return res;
} else if (ivalue.isGenericDict()) {
auto result = c10::impl::GenericDict(
ivalue.type()->containedType(0), ivalue.type()->containedType(1));
for (const auto& kv : ivalue.toGenericDict()) {
result.insert_or_assign(
recreateObject(kv.key(), resolver),
recreateObject(kv.value(), resolver));
}
return result;
} else if (ivalue.isTuple()) {
std::vector<IValue> res;
for (const auto& ival : ivalue.toTuple()->elements()) {
res.push_back(recreateObject(ival, resolver));
}
return c10::ivalue::Tuple::create(res);
}
// Leaf types are returned verbatim.
return ivalue;
}
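// Construct a JIT module from an object state (`ivalue`), its printed Python
// source and its constants: the classes are imported from `source` and the
// object tree is rebuilt with the freshly imported types.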
Module jitModuleFromSourceAndConstants(
const IValue& ivalue,
const ExtraFilesMap& source,
const std::vector<IValue>& constants,
int32_t version) {
auto compilation_unit = std::make_shared<CompilationUnit>();
SourceImporter importer(
compilation_unit,
&constants,
[&source](const std::string& qualifier) -> std::shared_ptr<SourceView> {
auto source_iter = source.find(qualifier);
if (source_iter == source.end()) {
return nullptr;
}
return std::make_shared<Source>(
source_iter->second, qualifier, 1, nullptr);
},
version);
auto type_resolver = [&](const c10::QualifiedName& qn) {
auto cls = importer.loadType(qn);
return c10::StrongTypePtr(compilation_unit, std::move(cls));
};
auto newIvalue = recreateObject(ivalue, type_resolver).toObject();
Module m(newIvalue);
rewriteQuantizedConvForBC(m);
return m;
}
} // namespace jit
} // namespace torch


@@ -98,5 +98,11 @@ TORCH_API Module load(
c10::optional<c10::Device> device,
ExtraFilesMap& extra_files);
TORCH_API Module jitModuleFromSourceAndConstants(
const IValue& ivalue,
const ExtraFilesMap& source,
const std::vector<IValue>& constants,
int32_t version);
} // namespace jit
} // namespace torch


@@ -2,6 +2,7 @@
#include <algorithm>
#include <ATen/core/ivalue.h>
#include <ATen/core/qualified_name.h>
#include <c10/util/Exception.h>
#include <c10/util/StringUtil.h>
@@ -17,6 +18,7 @@
#include <torch/csrc/jit/operator_upgraders/version_map.h>
#include <torch/csrc/jit/resource_guard.h>
#include <torch/csrc/jit/runtime/calculate_necessary_args.h>
#include <torch/csrc/jit/serialization/type_name_uniquer.h>
using c10::QualifiedName;
@@ -1662,5 +1664,98 @@ uint64_t PythonPrint::minVersion() const {
PythonPrint::~PythonPrint() = default;
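// Walk `ivalue` and return every object reachable from it, descending into
// __getstate__ results, object attributes, lists, dicts and tuples.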
std::vector<IValue> traverseIValueAndGetObjects(IValue ivalue) {
std::vector<IValue> result;
std::vector<IValue> stack;
stack.emplace_back(ivalue);
while (!stack.empty()) {
IValue head = stack.back();
stack.pop_back();
if (head.isObject()) {
result.push_back(head);
auto obj = head.toObject();
ClassTypePtr type = obj->type();
if (type->hasMethod("__getstate__")) {
Function& getstate = type->getMethod("__getstate__");
stack.emplace_back(getstate({obj}));
} else {
for (size_t i = 0, n = type->numAttributes(); i < n; ++i) {
stack.emplace_back(obj->getSlot(i));
}
}
} else if (head.isGenericDict()) {
for (const auto& kv : head.toGenericDict()) {
// skip the key because a dict key cannot be an object
stack.emplace_back(kv.value());
}
} else if (head.isList()) {
for (const auto& v : head.toList()) {
stack.emplace_back(v);
}
} else if (head.isTuple()) {
for (const auto& v : head.toTuple()->elements()) {
stack.emplace_back(v);
}
}
}
return result;
}
c10::optional<std::string> printType(
const c10::Type& type,
torch::jit::TypeNameUniquer& type_name_uniquer) {
if (auto dyn = type.castRaw<c10::DynamicType>()) {
return dyn->fallback()->annotation_str(
[&](auto&& t) { return printType(t, type_name_uniquer); });
}
auto namedType = type.cast<c10::NamedType>();
if (namedType && namedType->name()) {
return type_name_uniquer.getUniqueName(namedType).qualifiedName();
}
return c10::nullopt;
}
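// Print the Python source for every class reachable from `module` into
// `jit_sources` (one entry per qualifier prefix, i.e. per file) and collect
// the constants referenced by the printed code into `constants`.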
void jitModuleToPythonCodeAndConstants(
const Module& module,
ExtraFilesMap* jit_sources, // output
std::vector<IValue>* constants // output
) {
std::vector<IValue> objects = traverseIValueAndGetObjects(module._ivalue());
std::unordered_set<c10::QualifiedName> visited;
PrintDepsTable class_deps;
TypeNameUniquer uniquer;
auto type_printer = [&](const c10::Type& t) { return printType(t, uniquer); };
// Group by prefix, because every prefix corresponds to one file.
std::unordered_map<std::string, PythonPrint> grouped_by_prefix;
for (const IValue& obj : objects) {
ObjectPtr obj_ptr = obj.toObject();
ClassTypePtr class_type = obj_ptr->type();
class_deps.add(class_type);
}
for (size_t i = 0; i < class_deps.size(); ++i) {
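// Note: printing a type may append new dependencies to class_deps, so
// size() is re-evaluated on every iteration.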
auto type = class_deps[i];
auto qualname = uniquer.getUniqueName(type);
std::string qualifier = qualname.prefix();
auto pp_iter = grouped_by_prefix.find(qualifier);
if (pp_iter == grouped_by_prefix.end()) {
pp_iter = grouped_by_prefix
.emplace(
qualifier,
PythonPrint(
*constants,
class_deps,
type_printer,
/*enforce_importable=*/true))
.first;
}
pp_iter->second.printNamedType(type);
}
for (const auto& kv : grouped_by_prefix) {
(*jit_sources)[kv.first] = kv.second.str();
}
}
} // namespace jit
} // namespace torch


@@ -1,5 +1,6 @@
#pragma once
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>
#include <iostream>
#include <vector>
@@ -49,5 +50,12 @@ struct TORCH_API PythonPrint {
};
TORCH_API bool printerHasSpecialCaseFor(c10::Symbol sym);
TORCH_API void jitModuleToPythonCodeAndConstants(
const Module& module,
ExtraFilesMap* jit_sources, // output
std::vector<IValue>* constants // output
);
} // namespace jit
} // namespace torch