mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add support for __getstate__/__setstate__ on module (#20242)
Summary: Adds support for `__getstate__` and `__setstate__` on modules that are called as part of export (`torch.save()`) and import (`torch.jit.load`). * `__getstate__` and `__setstate__` must be TorchScript functions with the signatures `() -> T` and `(T) -> None` respectively * The results of `__getstate__` are stored using the pickler in `states.pkl` with one for each module in definition order (`__getstate__` returns `None` by default if an imlpementation is not provided) * This prevents sharing between `__getstate__` and attributes, but this should be fine since their use is mostly unrelated (attributes are for storing values to be used in script methods, `__getstate__` for running arbitrary computations during import) Follow up * Somehow replacing `__getstate__`/`__setstate__` with a `ScriptMethodStub` makes `MyScriptModule().__getstate__()` call `ScriptModule.__getstate__()` when used in Python. This should be fixed so semantics in Python are preserved, but it doesn't affect the typical usage. ](https://our.intern.facebook.com/intern/diff/15287161/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/20242 Pulled By: driazati Differential Revision: D15287161 fbshipit-source-id: b3f5f33ab74a21a89e6d15460af63aff75cab2d8
This commit is contained in:
committed by
Facebook Github Bot
parent
5f14ef8cc1
commit
cd28ff5395
@ -69,6 +69,9 @@ message ModuleDef {
|
||||
optional bool optimize = 8;
|
||||
|
||||
repeated AttributeDef attributes = 9;
|
||||
|
||||
// Used for retrieving module state from the pickled IValues table
|
||||
optional int64 get_state_attribute_id = 10;
|
||||
}
|
||||
|
||||
// Represents all non-module code that the model depends on.
|
||||
@ -79,7 +82,7 @@ message LibDef {
|
||||
}
|
||||
|
||||
enum ProtoVersion {
|
||||
PROTO_VERSION_NEWEST = 0x0000000000000004;
|
||||
PROTO_VERSION_NEWEST = 0x0000000000000005;
|
||||
}
|
||||
|
||||
message ModelDef {
|
||||
|
@ -12039,6 +12039,48 @@ a")
|
||||
s = u'\u00a3'.encode('utf8')[:1]
|
||||
self.checkScript(index_str_to_tensor, (s,))
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: TemporaryFileName support for Windows or Sandcastle")
|
||||
def test_get_set_state(self):
|
||||
class M(torch.jit.ScriptModule):
|
||||
__constants__ = ['number']
|
||||
|
||||
def __init__(self, number, submodule=None):
|
||||
super(M, self).__init__()
|
||||
self.register_buffer('buffer1', torch.ones(2, 2))
|
||||
self.register_buffer('buffer2', torch.ones(2, 2))
|
||||
self.number = number
|
||||
if submodule:
|
||||
self.submodule = submodule
|
||||
|
||||
@torch.jit.script_method
|
||||
def __getstate__(self):
|
||||
# type: () -> Tuple[Tensor, Tensor, int]
|
||||
return (self.buffer1, self.buffer2, 74)
|
||||
|
||||
@torch.jit.script_method
|
||||
def __setstate__(self, state):
|
||||
# type: (Tuple[Tensor, Tensor, int]) -> None
|
||||
self.buffer1 = state[0] + 10
|
||||
self.buffer2 = state[1] + 10
|
||||
|
||||
with TemporaryFileName() as fname:
|
||||
m = M(23, submodule=M(99))
|
||||
m.save(fname)
|
||||
loaded = torch.jit.load(fname)
|
||||
|
||||
# Check original module
|
||||
self.assertEqual(m.buffer1, torch.ones(2, 2))
|
||||
self.assertEqual(m.buffer2, torch.ones(2, 2))
|
||||
|
||||
# Check top level module
|
||||
self.assertEqual(loaded.buffer1, torch.ones(2, 2) + 10)
|
||||
self.assertEqual(loaded.buffer2, torch.ones(2, 2) + 10)
|
||||
|
||||
# Check submodule
|
||||
self.assertEqual(loaded.submodule.buffer1, torch.ones(2, 2) + 10)
|
||||
self.assertEqual(loaded.submodule.buffer2, torch.ones(2, 2) + 10)
|
||||
|
||||
|
||||
def test_string_slicing(self):
|
||||
def fn1(x):
|
||||
# type: (str) -> str
|
||||
|
@ -504,7 +504,10 @@ class ScriptModuleSerializer final {
|
||||
// to dump the content of a tensor
|
||||
void writeTensorTable(torch::ModelDef* model_def);
|
||||
|
||||
void writeAttributeTable();
|
||||
// Write the list of ivalues to a file as a pickle program
|
||||
void writePickleArchive(
|
||||
const std::string& name,
|
||||
const std::vector<IValue>& ivalues);
|
||||
void writeLibs(torch::ModelDef* model_def);
|
||||
|
||||
void convertModule(
|
||||
@ -513,10 +516,8 @@ class ScriptModuleSerializer final {
|
||||
const std::string& name,
|
||||
torch::ModuleDef* module_def);
|
||||
|
||||
void convertParameter(
|
||||
const script::Slot& param,
|
||||
torch::ParameterDef* param_def,
|
||||
bool is_parameter);
|
||||
IValue moduleGetState(const script::Module& module);
|
||||
bool moduleHasValidGetSetState(const script::Module& module);
|
||||
|
||||
void convertClass(const ClassTypePtr& type, torch::ModelDef* model_def);
|
||||
|
||||
@ -526,7 +527,9 @@ class ScriptModuleSerializer final {
|
||||
// all tensors that will be stored
|
||||
std::vector<at::Tensor> tensor_table_;
|
||||
|
||||
std::vector<IValue> attribute_table_;
|
||||
// A list of attributes (indexed by attr_def->id()) and module state (indexed
|
||||
// by module_def->id())
|
||||
std::vector<IValue> pickled_ivalues_;
|
||||
|
||||
// all classes used by this module hierarchy
|
||||
std::vector<ClassTypePtr> class_table_;
|
||||
@ -656,8 +659,8 @@ void ScriptModuleSerializer::convertModel(
|
||||
convertModule(
|
||||
module, "", writer_.archiveName(), model_def->mutable_main_module());
|
||||
|
||||
// This may write some attributes to the tensor_table_
|
||||
writeAttributeTable();
|
||||
|
||||
writePickleArchive("attributes.pkl", pickled_ivalues_);
|
||||
|
||||
writeTensorTable(model_def);
|
||||
writeLibs(model_def);
|
||||
@ -669,6 +672,82 @@ void ScriptModuleSerializer::convertModel(
|
||||
}
|
||||
}
|
||||
|
||||
bool ScriptModuleSerializer::moduleHasValidGetSetState(
|
||||
const script::Module& module) {
|
||||
// Check that the schemas for __getstate__ and __setstate__ are correct
|
||||
auto getstate = module.module_object()->type()->getMethod("__getstate__");
|
||||
if (getstate == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto get_schema =
|
||||
module.module_object()->type()->getMethod("__getstate__")->getSchema();
|
||||
|
||||
// Check __getstate__
|
||||
// __getstate__ is expected to be (self) -> T
|
||||
AT_CHECK(
|
||||
get_schema.arguments().size() == 1,
|
||||
"'__getstate__' must have 'self' as its only argument, but found ",
|
||||
get_schema.arguments().size(),
|
||||
" arguments");
|
||||
AT_CHECK(
|
||||
get_schema.returns().size() == 1,
|
||||
"'__getstate__' must return 1 value, but found ",
|
||||
get_schema.returns().size());
|
||||
|
||||
// Check __setstate__ if the method exists
|
||||
// __setstate__ is expected to be (self, T) -> None
|
||||
// TODO: use getMethod("__getstate__") once methods are not lowered
|
||||
auto setstate = module.class_compilation_unit().find_function("__setstate__");
|
||||
if (setstate == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto set_schema = setstate->getSchema();
|
||||
|
||||
AT_CHECK(
|
||||
set_schema.arguments().size() == 2,
|
||||
"'__setstate__' must have 'self' and the state as its "
|
||||
"only arguments, but found ",
|
||||
set_schema.arguments().size(),
|
||||
" arguments");
|
||||
AT_CHECK(
|
||||
set_schema.returns().size() == 1,
|
||||
"'__setstate__' must return None, but found ",
|
||||
set_schema.returns().size(),
|
||||
" return values");
|
||||
AT_CHECK(
|
||||
set_schema.returns().at(0).type()->isSubtypeOf(NoneType::get()),
|
||||
"'__setstate__' must return None, but found value of type",
|
||||
set_schema.returns().at(0).type()->python_str());
|
||||
|
||||
// Check that the return type of __getstate__ matches the input to
|
||||
// __setstate__
|
||||
auto get_type = get_schema.returns().at(0).type();
|
||||
auto set_type = set_schema.arguments().at(1).type();
|
||||
|
||||
AT_CHECK(
|
||||
set_type->isSubtypeOf(get_type),
|
||||
"'__getstate__'s return type (",
|
||||
get_type->python_str(),
|
||||
" does not match '__setstate__'s argument type (",
|
||||
set_type->python_str(),
|
||||
"))");
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Run module.__getstate__() and return the result
|
||||
IValue ScriptModuleSerializer::moduleGetState(const script::Module& module) {
|
||||
auto getstate = module.find_method("__getstate__");
|
||||
AT_CHECK(
|
||||
getstate != nullptr,
|
||||
"Cannot call '__getstate__' method because"
|
||||
" it does not exist");
|
||||
|
||||
Stack stack;
|
||||
getstate->run(stack);
|
||||
return stack.at(0);
|
||||
}
|
||||
|
||||
size_t ScriptModuleSerializer::addTensor(const at::Tensor& tensor) {
|
||||
tensor_table_.push_back(tensor);
|
||||
return tensor_table_.size() - 1;
|
||||
@ -720,17 +799,18 @@ void ScriptModuleSerializer::writeTensorTable(torch::ModelDef* model_def) {
|
||||
}
|
||||
}
|
||||
|
||||
void ScriptModuleSerializer::writeAttributeTable() {
|
||||
void ScriptModuleSerializer::writePickleArchive(
|
||||
const std::string& name,
|
||||
const std::vector<IValue>& ivalues) {
|
||||
Pickler pickler(&tensor_table_);
|
||||
pickler.start();
|
||||
pickler.startTuple();
|
||||
for (const IValue& ivalue : attribute_table_) {
|
||||
for (const IValue& ivalue : ivalues) {
|
||||
pickler.addIValue(ivalue);
|
||||
}
|
||||
pickler.endTuple();
|
||||
pickler.finish();
|
||||
writer_.writeRecord(
|
||||
"attributes.pkl", pickler.stack().data(), pickler.stack().size());
|
||||
writer_.writeRecord(name, pickler.stack().data(), pickler.stack().size());
|
||||
}
|
||||
|
||||
void ScriptModuleSerializer::convertModule(
|
||||
@ -740,19 +820,47 @@ void ScriptModuleSerializer::convertModule(
|
||||
torch::ModuleDef* module_def) {
|
||||
module_def->set_name(name);
|
||||
module_def->set_optimize(module.is_optimized());
|
||||
for (const auto& elem : module.get_parameters()) {
|
||||
torch::ParameterDef* param_def = module_def->add_parameters();
|
||||
convertParameter(elem, param_def, /*is_buffer=*/false);
|
||||
|
||||
// If __getstate__ and __setstate__ methods are provided, use those for
|
||||
// serializing instead of serializing the attributes directly
|
||||
bool user_provided_serialization = moduleHasValidGetSetState(module);
|
||||
if (user_provided_serialization) {
|
||||
// Run the '__getstate__' method on the module and store the result
|
||||
pickled_ivalues_.emplace_back(moduleGetState(module));
|
||||
module_def->set_get_state_attribute_id(pickled_ivalues_.size() - 1);
|
||||
}
|
||||
|
||||
// Add all the parameters
|
||||
for (const auto& param : module.get_parameters()) {
|
||||
torch::ParameterDef* param_def = module_def->add_parameters();
|
||||
param_def->set_name(param.name());
|
||||
param_def->set_is_buffer(false);
|
||||
if (user_provided_serialization) {
|
||||
// If a __getstate__ was used, don't write the actual tensor
|
||||
param_def->set_tensor_id(-1);
|
||||
} else {
|
||||
param_def->set_tensor_id(addTensor(param.value().toTensor()));
|
||||
}
|
||||
}
|
||||
|
||||
// Add all the attributes
|
||||
for (const auto& attribute : module.get_attributes()) {
|
||||
// Add attribute to ModuleDef
|
||||
torch::AttributeDef* attribute_def = module_def->add_attributes();
|
||||
attribute_def->set_name(attribute.name());
|
||||
attribute_def->set_type(attribute.type()->python_str());
|
||||
|
||||
attribute_table_.push_back(attribute.value());
|
||||
attribute_def->set_id(attribute_table_.size() - 1);
|
||||
if (!user_provided_serialization) {
|
||||
// Write the attribute's index if it's actually saved, -1 if it needs to
|
||||
// come from __getstate__
|
||||
pickled_ivalues_.push_back(attribute.value());
|
||||
attribute_def->set_id(pickled_ivalues_.size() - 1);
|
||||
} else {
|
||||
// The module had a __setstate__, so write the attribute name/type so
|
||||
// it can be correctly imported, but it has no entry in the
|
||||
// pickled_ivalues_ table
|
||||
attribute_def->set_id(-1);
|
||||
}
|
||||
}
|
||||
|
||||
std::stringstream module_name;
|
||||
@ -760,7 +868,7 @@ void ScriptModuleSerializer::convertModule(
|
||||
module_name << prefix << "_";
|
||||
module_name << name;
|
||||
|
||||
if (module.get_methods().size() > 0) {
|
||||
if (module.class_compilation_unit().get_functions().size() > 0) {
|
||||
std::ostringstream methods;
|
||||
methods << "op_version_set = " << CURRENT_OP_VERSION_SET << "\n";
|
||||
PythonPrint(
|
||||
@ -786,15 +894,6 @@ void ScriptModuleSerializer::convertModule(
|
||||
}
|
||||
}
|
||||
|
||||
void ScriptModuleSerializer::convertParameter(
|
||||
const script::Slot& param,
|
||||
torch::ParameterDef* param_def,
|
||||
bool is_parameter) {
|
||||
param_def->set_name(param.name());
|
||||
param_def->set_is_buffer(is_parameter);
|
||||
param_def->set_tensor_id(addTensor(param.value().toTensor()));
|
||||
}
|
||||
|
||||
// Pretty printing for ONNX
|
||||
constexpr char indent_char = ' ';
|
||||
constexpr size_t indent_multiplier = 2;
|
||||
|
@ -58,8 +58,11 @@ class ScriptModuleDeserializer final {
|
||||
void convertModule(const torch::ModuleDef& module_def);
|
||||
|
||||
void loadTensorTable(torch::ModelDef* model_def);
|
||||
void loadAttributeTable();
|
||||
std::vector<IValue> loadPickleArchive(const std::string& name);
|
||||
void importCallback(const std::string& qualifier);
|
||||
void moduleSetState(
|
||||
const std::shared_ptr<script::Module>& module,
|
||||
IValue state);
|
||||
|
||||
caffe2::serialize::PyTorchStreamReader reader_;
|
||||
// this is a hack to make sure the script module created in C++ is the
|
||||
@ -69,7 +72,8 @@ class ScriptModuleDeserializer final {
|
||||
std::vector<std::string> moduleStack_;
|
||||
|
||||
std::vector<at::Tensor> tensor_table_;
|
||||
std::vector<IValue> attribute_table_;
|
||||
std::vector<IValue> pickled_ivalues_;
|
||||
|
||||
std::unordered_set<std::string> imported_libs_;
|
||||
|
||||
std::shared_ptr<script::Module> main_module_;
|
||||
@ -139,7 +143,7 @@ void ScriptModuleDeserializer::deserialize(
|
||||
|
||||
loadTensorTable(&model_def);
|
||||
if (model_def.proto_version() >= 2) {
|
||||
loadAttributeTable();
|
||||
pickled_ivalues_ = loadPickleArchive("attributes.pkl");
|
||||
}
|
||||
|
||||
// TODO: this can be simplified when C++/Python interop lands,
|
||||
@ -154,13 +158,13 @@ void ScriptModuleDeserializer::loadTensorTable(torch::ModelDef* model_def) {
|
||||
}
|
||||
}
|
||||
|
||||
void ScriptModuleDeserializer::loadAttributeTable() {
|
||||
std::vector<IValue> ScriptModuleDeserializer::loadPickleArchive(const std::string& name) {
|
||||
at::DataPtr attributes_ptr;
|
||||
size_t attributes_size;
|
||||
std::tie(attributes_ptr, attributes_size) =
|
||||
reader_.getRecord("attributes.pkl");
|
||||
reader_.getRecord(name);
|
||||
Unpickler unpickler(attributes_ptr.get(), attributes_size, &tensor_table_);
|
||||
attribute_table_ = unpickler.parse_ivalue_list();
|
||||
return unpickler.parse_ivalue_list();
|
||||
}
|
||||
|
||||
at::Tensor ScriptModuleDeserializer::loadTensor(
|
||||
@ -255,6 +259,21 @@ void ScriptModuleDeserializer::importCallback(const std::string& qualifier) {
|
||||
import_callback);
|
||||
}
|
||||
|
||||
void ScriptModuleDeserializer::moduleSetState(
|
||||
const std::shared_ptr<script::Module>& module,
|
||||
IValue state) {
|
||||
auto setstate = module->class_compilation_unit().find_function("__setstate__");
|
||||
|
||||
AT_CHECK(
|
||||
setstate != nullptr,
|
||||
"Cannot call '__setstate__' method because"
|
||||
" it does not exist");
|
||||
|
||||
// TODO: once modules are first class in the interpreter and methods are not
|
||||
// lowered, change this to `module->run_method("__setstate__", {state});`
|
||||
setstate->run({module->module_object(), state});
|
||||
}
|
||||
|
||||
void ScriptModuleDeserializer::convertModule(
|
||||
const torch::ModuleDef& module_def) {
|
||||
std::shared_ptr<script::Module> module = moduleLookup_(moduleStack_);
|
||||
@ -282,10 +301,16 @@ void ScriptModuleDeserializer::convertModule(
|
||||
continue;
|
||||
}
|
||||
|
||||
IValue ivalue;
|
||||
if (attr_def.id() >= 0) {
|
||||
// attribute has no value in the table, set it to None for now. After
|
||||
// __getstate__, check that all the attributes that are not Optional
|
||||
// can't be None
|
||||
ivalue = pickled_ivalues_.at(attr_def.id());
|
||||
}
|
||||
|
||||
module->register_attribute(
|
||||
attr_def.name(),
|
||||
typeParser.parseType(attr_def.type()),
|
||||
attribute_table_.at(attr_def.id()));
|
||||
attr_def.name(), typeParser.parseType(attr_def.type()), ivalue);
|
||||
}
|
||||
if (module_def.has_torchscript_arena()) {
|
||||
at::DataPtr data;
|
||||
@ -303,6 +328,26 @@ void ScriptModuleDeserializer::convertModule(
|
||||
tensor_table_,
|
||||
import_callback);
|
||||
}
|
||||
|
||||
if (module_def.has_get_state_attribute_id()) {
|
||||
moduleSetState(
|
||||
module, pickled_ivalues_.at(module_def.get_state_attribute_id()));
|
||||
}
|
||||
|
||||
for (const auto& slot : module->get_attributes()) {
|
||||
// Verify that all the non-optional attributes have been initialized
|
||||
// TODO: Issue #20497
|
||||
if (slot.type()->kind() != TypeKind::OptionalType) {
|
||||
AT_CHECK(
|
||||
!slot.value().isNone(),
|
||||
"The field '",
|
||||
slot.name(),
|
||||
"' was left unitialized after __setstate__, but expected a ",
|
||||
"value of type '",
|
||||
slot.type()->python_str(),
|
||||
"'");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -143,7 +143,7 @@ struct TORCH_API Module {
|
||||
~Module() {
|
||||
// ClassType own the compilation unit of their Functions, but each
|
||||
// Function has a self argument which owns the ClassType, created a
|
||||
// referernce cycle. By dropping all the methods of the module's class
|
||||
// reference cycle. By dropping all the methods of the module's class
|
||||
// here we break the cycle.
|
||||
class_compilation_unit().drop_all_functions();
|
||||
}
|
||||
|
@ -330,6 +330,16 @@ std::vector<std::shared_ptr<SugaredValue>> ModuleValue::asTuple(
|
||||
return result;
|
||||
}
|
||||
|
||||
void ModuleValue::setAttr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
const std::string& field,
|
||||
Value* newValue) {
|
||||
// Forward to SimpleValue::setAttr
|
||||
SimpleValue simple(self_);
|
||||
simple.setAttr(loc, m, field, newValue);
|
||||
}
|
||||
|
||||
std::shared_ptr<SugaredValue> BooleanDispatchValue::call(
|
||||
const SourceRange& loc,
|
||||
Function& caller,
|
||||
|
@ -1,3 +1,4 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/jit/pybind_utils.h>
|
||||
@ -160,6 +161,12 @@ struct VISIBILITY_HIDDEN ModuleValue : public SugaredValue {
|
||||
Function& m,
|
||||
const c10::optional<size_t>& size_hint = {}) override;
|
||||
|
||||
void setAttr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
const std::string& field,
|
||||
Value* newValue) override;
|
||||
|
||||
private:
|
||||
Value* self_;
|
||||
std::shared_ptr<Module> module_;
|
||||
|
Reference in New Issue
Block a user