mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Create method to map JIT module to (source, constant) and back. (#74119)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/74119 implemented function to generate source as ExtraFilesMap and constants wrote function to construct jit module given (ivalue, source, constant) tripple. Test Plan: unittest Reviewed By: pavithranrao Differential Revision: D34803945 fbshipit-source-id: 2edc798407fe68294cb4c3c7516f5bd143df88c3 (cherry picked from commit 35e54e166b8f0f5cfe8f08c07866b59ae61ee79d)
This commit is contained in:
@ -3,7 +3,9 @@
|
||||
#include <test/cpp/jit/test_utils.h>
|
||||
#include <sstream>
|
||||
|
||||
#include <torch/csrc/jit/mobile/module.h>
|
||||
#include <torch/csrc/jit/serialization/export.h>
|
||||
#include <torch/csrc/jit/serialization/export_bytecode.h>
|
||||
#include <torch/csrc/jit/serialization/import.h>
|
||||
#include <torch/csrc/jit/serialization/import_source.h>
|
||||
#include <torch/torch.h>
|
||||
@ -13,6 +15,20 @@
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
namespace {
|
||||
|
||||
Module roundtripThroughMobile(const Module& m) {
|
||||
ExtraFilesMap files;
|
||||
std::vector<IValue> constants;
|
||||
jitModuleToPythonCodeAndConstants(m, &files, &constants);
|
||||
CompilationOptions options;
|
||||
mobile::Module mobilem = jitModuleToMobile(m, options);
|
||||
return jitModuleFromSourceAndConstants(
|
||||
mobilem._ivalue(), files, constants, 8);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TEST(SerializationTest, ExtraFilesHookPreference) {
|
||||
// Tests that an extra file written explicitly has precedence over
|
||||
// extra files written by a hook
|
||||
@ -149,5 +165,78 @@ TEST(SerializationTest, TestJitStream_CUDA) {
|
||||
// Check if both the output tensors are equal
|
||||
ASSERT_TRUE(op.equal(c));
|
||||
}
|
||||
|
||||
TEST(TestSourceRoundTrip, UpsampleNearest2d) {
|
||||
Module m("m");
|
||||
m.define(R"(
|
||||
def forward(self, input: Tensor, scale:float):
|
||||
return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
|
||||
)");
|
||||
|
||||
std::vector<IValue> inputs;
|
||||
inputs.emplace_back(torch::rand({1, 3, 128, 128}));
|
||||
inputs.emplace_back(at::Scalar(2.0));
|
||||
auto ref = m.forward(inputs);
|
||||
|
||||
Module m2 = roundtripThroughMobile(m);
|
||||
auto res = m2.forward(inputs);
|
||||
|
||||
auto resd = res.toTensor();
|
||||
auto refd = ref.toTensor();
|
||||
ASSERT_TRUE(resd.equal(refd));
|
||||
}
|
||||
|
||||
TEST(TestSourceRoundTrip, CheckAttrAccess) {
|
||||
Module m("m");
|
||||
m.register_attribute("mobile_optimized", BoolType::get(), true);
|
||||
Module m2 = roundtripThroughMobile(m);
|
||||
bool mobile_optimized = m2.attr("mobile_optimized", false).toBool();
|
||||
AT_ASSERT(mobile_optimized);
|
||||
}
|
||||
|
||||
TEST(TestSourceRoundTrip,
|
||||
MethodInvocation) { // NOLINT (use =delete in gtest)
|
||||
const std::vector<std::string> test_programs{
|
||||
// test invoking a method with default parameter
|
||||
R"(
|
||||
def test_func(self, x, b : int = 4):
|
||||
return self.foo + x + b
|
||||
)",
|
||||
// inner method call with default parameter (gets inlined)
|
||||
R"(
|
||||
def add_with_default_arg(self, x, b : int = 4):
|
||||
return self.foo + x + b
|
||||
def test_func(self, x):
|
||||
return self.add_with_default_arg(x) # invoke method w/ default arg
|
||||
)",
|
||||
// simple method call
|
||||
R"(
|
||||
def test_func(self, x):
|
||||
b = 4
|
||||
return self.foo + x + b
|
||||
)",
|
||||
};
|
||||
for (const auto& test_program : test_programs) {
|
||||
Module m("m");
|
||||
m.register_parameter("foo", torch::ones({}), false);
|
||||
m.define(test_program);
|
||||
|
||||
const int fortyTwo = 42; // (keep linter happy)
|
||||
auto minput = fortyTwo * torch::ones({});
|
||||
auto ref = m.run_method("test_func", minput);
|
||||
|
||||
Module m2 = roundtripThroughMobile(m);
|
||||
const auto& test_func = m2.get_method("test_func");
|
||||
IValue res;
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
res = test_func({minput});
|
||||
}
|
||||
|
||||
auto resd = res.toTensor().item<float>();
|
||||
auto refd = ref.toTensor().item<float>();
|
||||
AT_ASSERT(resd == refd);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
||||
@ -403,5 +403,75 @@ Module load(
|
||||
return deserializer.deserialize(device, extra_files);
|
||||
}
|
||||
|
||||
// Replace object with a newly created but equivalent object.
|
||||
// The goal is to replace object's methods. However, since object's
|
||||
// methods are attached to type; we need to replace it's type.
|
||||
// Non-objects are unchanged; however, nested structures such as list, dict
|
||||
// are also reconstructed because they might contain an object.
|
||||
static IValue recreateObject(IValue ivalue, TypeResolver resolver) {
|
||||
if (ivalue.isObject()) {
|
||||
auto obj = ivalue.toObject();
|
||||
auto classtype_old = obj->type();
|
||||
auto newtype = resolver(*classtype_old->name());
|
||||
size_t n = classtype_old->numAttributes();
|
||||
auto newobj = c10::ivalue::Object::create(newtype, n);
|
||||
for (const auto i : c10::irange(n)) {
|
||||
newobj->setSlot(i, recreateObject(obj->getSlot(i), resolver));
|
||||
}
|
||||
return newobj;
|
||||
} else if (ivalue.isList()) {
|
||||
auto res = c10::impl::GenericList(ivalue.type()->containedType(0));
|
||||
for (const auto& ival : ivalue.toList()) {
|
||||
res.emplace_back(recreateObject(ival, resolver));
|
||||
}
|
||||
return res;
|
||||
} else if (ivalue.isGenericDict()) {
|
||||
auto result = c10::impl::GenericDict(
|
||||
ivalue.type()->containedType(0), ivalue.type()->containedType(1));
|
||||
for (const auto& kv : ivalue.toGenericDict()) {
|
||||
result.insert_or_assign(
|
||||
recreateObject(kv.key(), resolver),
|
||||
recreateObject(kv.value(), resolver));
|
||||
}
|
||||
return result;
|
||||
} else if (ivalue.isTuple()) {
|
||||
std::vector<IValue> res;
|
||||
for (const auto& ival : ivalue.toTuple()->elements()) {
|
||||
res.push_back(recreateObject(ival, resolver));
|
||||
}
|
||||
return c10::ivalue::Tuple::create(res);
|
||||
}
|
||||
// Leaf types are returned verbatim.
|
||||
return ivalue;
|
||||
}
|
||||
|
||||
Module jitModuleFromSourceAndConstants(
|
||||
const IValue& ivalue,
|
||||
const ExtraFilesMap& source,
|
||||
const std::vector<IValue>& constants,
|
||||
int32_t version) {
|
||||
auto compilation_unit = std::make_shared<CompilationUnit>();
|
||||
SourceImporter importer(
|
||||
compilation_unit,
|
||||
&constants,
|
||||
[&source](const std::string& qualifier) -> std::shared_ptr<SourceView> {
|
||||
auto source_iter = source.find(qualifier);
|
||||
if (source_iter == source.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
return std::make_shared<Source>(
|
||||
source_iter->second, qualifier, 1, nullptr);
|
||||
},
|
||||
version);
|
||||
auto type_resolver = [&](const c10::QualifiedName& qn) {
|
||||
auto cls = importer.loadType(qn);
|
||||
return c10::StrongTypePtr(compilation_unit, std::move(cls));
|
||||
};
|
||||
auto newIvalue = recreateObject(ivalue, type_resolver).toObject();
|
||||
Module m(newIvalue);
|
||||
rewriteQuantizedConvForBC(m);
|
||||
return m;
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
||||
@ -98,5 +98,11 @@ TORCH_API Module load(
|
||||
c10::optional<c10::Device> device,
|
||||
ExtraFilesMap& extra_files);
|
||||
|
||||
TORCH_API Module jitModuleFromSourceAndConstants(
|
||||
const IValue& ivalue,
|
||||
const ExtraFilesMap& source,
|
||||
const std::vector<IValue>& constants,
|
||||
int32_t version);
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include <ATen/core/ivalue.h>
|
||||
#include <ATen/core/qualified_name.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/StringUtil.h>
|
||||
@ -17,6 +18,7 @@
|
||||
#include <torch/csrc/jit/operator_upgraders/version_map.h>
|
||||
#include <torch/csrc/jit/resource_guard.h>
|
||||
#include <torch/csrc/jit/runtime/calculate_necessary_args.h>
|
||||
#include <torch/csrc/jit/serialization/type_name_uniquer.h>
|
||||
|
||||
using c10::QualifiedName;
|
||||
|
||||
@ -1662,5 +1664,98 @@ uint64_t PythonPrint::minVersion() const {
|
||||
|
||||
PythonPrint::~PythonPrint() = default;
|
||||
|
||||
std::vector<IValue> traverseIValueAndGetObjects(IValue ivalue) {
|
||||
std::vector<IValue> result;
|
||||
std::vector<IValue> stack;
|
||||
stack.emplace_back(ivalue);
|
||||
while (!stack.empty()) {
|
||||
IValue head = stack.back();
|
||||
stack.pop_back();
|
||||
if (head.isObject()) {
|
||||
result.push_back(head);
|
||||
auto obj = head.toObject();
|
||||
ClassTypePtr type = obj->type();
|
||||
if (type->hasMethod("__getstate__")) {
|
||||
Function& getstate = type->getMethod("__getstate__");
|
||||
stack.emplace_back(getstate({obj}));
|
||||
} else {
|
||||
for (size_t i = 0, n = type->numAttributes(); i < n; ++i) {
|
||||
stack.emplace_back(obj->getSlot(i));
|
||||
}
|
||||
}
|
||||
} else if (ivalue.isGenericDict()) {
|
||||
for (const auto& kv : ivalue.toGenericDict()) {
|
||||
// skip key because key cannot be an object
|
||||
stack.emplace_back(kv.value());
|
||||
}
|
||||
} else if (ivalue.isList()) {
|
||||
for (const auto& v : ivalue.toList()) {
|
||||
stack.emplace_back(v);
|
||||
}
|
||||
} else if (ivalue.isTuple()) {
|
||||
for (const auto& v : ivalue.toTuple()->elements()) {
|
||||
stack.emplace_back(v);
|
||||
}
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
c10::optional<std::string> printType(
|
||||
const c10::Type& type,
|
||||
torch::jit::TypeNameUniquer& type_name_uniquer) {
|
||||
if (auto dyn = type.castRaw<c10::DynamicType>()) {
|
||||
return dyn->fallback()->annotation_str(
|
||||
[&](auto&& t) { return printType(t, type_name_uniquer); });
|
||||
}
|
||||
auto namedType = type.cast<c10::NamedType>();
|
||||
if (namedType && namedType->name()) {
|
||||
return type_name_uniquer.getUniqueName(namedType).qualifiedName();
|
||||
}
|
||||
return c10::nullopt;
|
||||
}
|
||||
|
||||
void jitModuleToPythonCodeAndConstants(
|
||||
const Module& module,
|
||||
ExtraFilesMap* jit_sources, // output
|
||||
std::vector<IValue>* constants // output
|
||||
) {
|
||||
std::vector<IValue> objects = traverseIValueAndGetObjects(module._ivalue());
|
||||
std::unordered_set<c10::QualifiedName> visited;
|
||||
PrintDepsTable class_deps;
|
||||
TypeNameUniquer uniquer;
|
||||
auto type_printer = [&](const c10::Type& t) { return printType(t, uniquer); };
|
||||
|
||||
// Group by prefix; because every prefix is a file.
|
||||
std::unordered_map<std::string, PythonPrint> grouped_by_prefix;
|
||||
for (const IValue& obj : objects) {
|
||||
ObjectPtr obj_ptr = obj.toObject();
|
||||
ClassTypePtr class_type = obj_ptr->type();
|
||||
class_deps.add(class_type);
|
||||
}
|
||||
|
||||
for (int i = 0; i < class_deps.size(); ++i) {
|
||||
auto type = class_deps[i];
|
||||
auto qualname = uniquer.getUniqueName(type);
|
||||
std::string qualifier = qualname.prefix();
|
||||
auto pp_iter = grouped_by_prefix.find(qualifier);
|
||||
if (pp_iter == grouped_by_prefix.end()) {
|
||||
pp_iter = grouped_by_prefix
|
||||
.emplace(
|
||||
qualifier,
|
||||
PythonPrint(
|
||||
*constants,
|
||||
class_deps,
|
||||
type_printer,
|
||||
/*enforce_importable=*/true))
|
||||
.first;
|
||||
}
|
||||
pp_iter->second.printNamedType(type);
|
||||
}
|
||||
for (const auto& kv : grouped_by_prefix) {
|
||||
(*jit_sources)[kv.first] = kv.second.str();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
#include <torch/csrc/Export.h>
|
||||
#include <torch/csrc/jit/api/module.h>
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
@ -49,5 +50,12 @@ struct TORCH_API PythonPrint {
|
||||
};
|
||||
|
||||
TORCH_API bool printerHasSpecialCaseFor(c10::Symbol sym);
|
||||
|
||||
TORCH_API void jitModuleToPythonCodeAndConstants(
|
||||
const Module& module,
|
||||
ExtraFilesMap* jit_sources, // output
|
||||
std::vector<IValue>* constants // output
|
||||
);
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
||||
Reference in New Issue
Block a user