Files
pytorch/torch/csrc/jit/script/python_sugared_value.h
davidriazati cd28ff5395 Add support for __getstate__/__setstate__ on module (#20242)
Summary:
Adds support for `__getstate__` and `__setstate__` on modules that are called as part of export (`torch.save()`) and import (`torch.jit.load`).
* `__getstate__` and `__setstate__` must be TorchScript functions with the signatures `() -> T` and `(T) -> None` respectively
* The results of `__getstate__` are stored using the pickler in `states.pkl` with one for each module in definition order (`__getstate__` returns `None` by default if an implementation is not provided)
    * This prevents sharing between `__getstate__` and attributes, but this should be fine since their use is mostly unrelated (attributes are for storing values to be used in script methods, `__getstate__` for running arbitrary computations during import)

Follow up
* Somehow replacing `__getstate__`/`__setstate__` with a `ScriptMethodStub` makes `MyScriptModule().__getstate__()` call `ScriptModule.__getstate__()` when used in Python. This should be fixed so semantics in Python are preserved, but it doesn't affect the typical usage.
[D15287161](https://our.intern.facebook.com/intern/diff/15287161/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20242

Pulled By: driazati

Differential Revision: D15287161

fbshipit-source-id: b3f5f33ab74a21a89e6d15460af63aff75cab2d8
2019-05-17 14:43:14 -07:00

198 lines
5.6 KiB
C++

#pragma once
#include <torch/csrc/jit/pybind_utils.h>
#include <torch/csrc/jit/script/module.h>
#include <torch/csrc/jit/script/sugared_value.h>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
namespace torch {
namespace jit {
namespace script {
// Declaration only; the definition is not part of this header.
// Takes a Python handle and returns a std::string.
// NOTE(review): presumably a printable description of the Python type of `h`
// (e.g. for diagnostics) -- confirm against the definition.
std::string typeString(py::handle h);
// Wraps a raw Value* in a freshly allocated SimpleValue and returns it through
// a SugaredValue base-class pointer.
inline std::shared_ptr<SugaredValue> toSimple(Value* v) {
  std::shared_ptr<SugaredValue> wrapped = std::make_shared<SimpleValue>(v);
  return wrapped;
}
// NB: This should be the single entry-point for instantiating a SugaredValue
// from a Python object. If you are adding support for converting a new Python
// type, *add it in this function's implementation*.
//
// Declaration only; the implementation is not part of this header.
//   obj         - the Python object to convert (taken by value)
//   m           - Function handed to the conversion; NOTE(review): presumably
//                 the function currently being compiled -- confirm
//   loc         - source range; NOTE(review): presumably used for error
//                 reporting -- confirm
//   is_constant - defaults to false; NOTE(review): its exact meaning is not
//                 visible here -- confirm against the implementation
std::shared_ptr<SugaredValue> toSugaredValue(
py::object obj,
Function& m,
SourceRange loc,
bool is_constant = false);
// Declaration only; the definition is not part of this header.
// Takes a Python object by const reference and returns a shared_ptr<Function>.
// NOTE(review): presumably yields the Function backing `obj` when there is
// one, and a null pointer otherwise -- confirm against the definition.
std::shared_ptr<Function> as_function(const py::object& obj);
// SugaredValue that holds an arbitrary Python object in `self`.
// Only declarations live in this header; the member functions are defined
// elsewhere.
struct VISIBILITY_HIDDEN PythonValue : public SugaredValue {
  // `explicit` prevents a py::object from silently converting into a
  // PythonValue, matching the constructors of PythonModuleValue and
  // ConstantPythonTupleValue. The parameter is named `the_self` so that it no
  // longer shadows the `self` member it initializes.
  explicit PythonValue(py::object the_self) : self(std::move(the_self)) {}

  // Produces a FunctionSchema for a call with `n_args` arguments and
  // `n_binders` binders.
  // NOTE(review): how the schema is derived from `self` is not visible here.
  FunctionSchema getSchema(const size_t n_args, const size_t n_binders);

  // call it like a function, e.g. `outputs = this(inputs)`
  std::shared_ptr<SugaredValue> call(
      const SourceRange& loc,
      Function& m,
      at::ArrayRef<NamedValue> inputs_,
      at::ArrayRef<NamedValue> attributes,
      size_t n_binders) override;

  // Short description of this kind of sugared value.
  std::string kind() const override;

  // View this value as a tuple of sugared values; `size_hint` is optional and
  // defaults to empty.
  std::vector<std::shared_ptr<SugaredValue>> asTuple(
      const SourceRange& loc,
      Function& m,
      const c10::optional<size_t>& size_hint = {}) override;

  // Select the attribute named `field` on this value.
  std::shared_ptr<SugaredValue> attr(
      const SourceRange& loc,
      Function& m,
      const std::string& field) override;

 protected:
  // NOTE(review): presumably fetches attribute `name` from `self` and uses
  // `loc` for error reporting -- confirm against the definition.
  py::object getattr(const SourceRange& loc, const std::string& name);

  // NOTE(review): presumably appends a hint to `ss` when the failure looks
  // like a missing constants entry -- confirm against the definition.
  void checkForAddToConstantsError(std::stringstream& ss);

  // The wrapped Python object.
  py::object self;
};
// PythonValue specialization that supplies its own attr() lookup.
// NOTE(review): the name suggests it wraps a Python module object -- confirm.
struct VISIBILITY_HIDDEN PythonModuleValue : public PythonValue {
  // Takes ownership of the Python object by moving it into the base class.
  explicit PythonModuleValue(py::object module_obj)
      : PythonValue(std::move(module_obj)) {}

  // Select the attribute named `field`; defined outside this header.
  std::shared_ptr<SugaredValue> attr(
      const SourceRange& loc,
      Function& m,
      const std::string& field) override;
};
// PythonValue specialization that overrides asTuple() and asValue().
// NOTE(review): the name suggests the wrapped object is a Python tuple treated
// as a constant -- confirm against the definitions.
struct VISIBILITY_HIDDEN ConstantPythonTupleValue : public PythonValue {
  // Takes ownership of the Python object by moving it into the base class.
  explicit ConstantPythonTupleValue(py::object tuple_obj)
      : PythonValue(std::move(tuple_obj)) {}

  // View this value as a tuple of sugared values; `size_hint` is optional and
  // defaults to empty. Defined outside this header.
  std::vector<std::shared_ptr<SugaredValue>> asTuple(
      const SourceRange& loc,
      Function& m,
      const c10::optional<size_t>& size_hint = {}) override;

  // Produce a single Value* for this object. Defined outside this header.
  Value* asValue(const SourceRange& loc, Function& m) override;
};
// Represents all the parameters of a module as a List[Tensor]
struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue {
  // `explicit` prevents a bare Value* from silently converting into a
  // ConstantParameterList. The pointer is stored as-is, not owned.
  explicit ConstantParameterList(Value* the_list) : the_list_(the_list) {}

  std::string kind() const override {
    return "constant parameter list";
  }

  // Calling this value ignores every argument and simply returns the stored
  // list wrapped as a simple value.
  std::shared_ptr<SugaredValue> call(
      const SourceRange& loc,
      Function& caller,
      at::ArrayRef<NamedValue> inputs,
      at::ArrayRef<NamedValue> attributes,
      size_t n_binders) override {
    return toSimple(the_list_);
  }

 private:
  // Value supplied at construction and returned by call().
  Value* the_list_;
};
// SugaredValue that pairs a module Value* with a list of method names.
// NOTE(review): call() presumably picks among the named methods based on the
// supplied arguments -- confirm against the out-of-line definition.
struct VISIBILITY_HIDDEN OverloadedMethodValue : public SugaredValue {
  // Stores the module pointer as-is and takes ownership of the name list.
  OverloadedMethodValue(Value* module_value, std::vector<std::string> names)
      : module_(module_value), method_names_(std::move(names)) {}

  std::string kind() const override {
    return "overloaded function";
  }

  // Defined outside this header.
  std::shared_ptr<SugaredValue> call(
      const SourceRange& loc,
      Function& caller,
      at::ArrayRef<NamedValue> inputs,
      at::ArrayRef<NamedValue> attributes,
      size_t n_binders) override;

 private:
  // Module value supplied at construction (not owned).
  Value* module_;
  // Method names supplied at construction.
  std::vector<std::string> method_names_;
};
// Describes how modules and their methods behave inside the script subset.
// For now this does not interact with Python. In the future we will add the
// ability to resolve `self.foo` to Python {functions, modules, constants}, so
// this SugaredValue is defined here in anticipation of eventually replacing
// Module with a py::object holding the actual nn.Module class.
struct VISIBILITY_HIDDEN ModuleValue : public SugaredValue {
  // Stores the `self` value pointer as-is and takes ownership of the Module
  // handle and the Python module object.
  ModuleValue(Value* self, std::shared_ptr<Module> module, py::object py_module)
      : self_(self),
        module_(std::move(module)),
        py_module_(std::move(py_module)) {}

  std::string kind() const override {
    return "module";
  }

  // Attribute selection, e.g. `this.field`. Defined outside this header.
  std::shared_ptr<SugaredValue> attr(
      const SourceRange& loc,
      Function& m,
      const std::string& field) override;

  // Calling a module means calling its `forward` attribute: look it up via
  // attr() and forward every argument to it unchanged.
  std::shared_ptr<SugaredValue> call(
      const SourceRange& loc,
      Function& caller,
      at::ArrayRef<NamedValue> inputs,
      at::ArrayRef<NamedValue> attributes,
      size_t n_binders) override {
    const std::shared_ptr<SugaredValue> forward_method =
        attr(loc, caller, "forward");
    return forward_method->call(loc, caller, inputs, attributes, n_binders);
  }

  // View this module as a tuple of sugared values; `size_hint` is optional and
  // defaults to empty. Defined outside this header.
  std::vector<std::shared_ptr<SugaredValue>> asTuple(
      const SourceRange& loc,
      Function& m,
      const c10::optional<size_t>& size_hint = {}) override;

  // Assign `newValue` to the attribute named `field`. Defined outside this
  // header.
  void setAttr(
      const SourceRange& loc,
      Function& m,
      const std::string& field,
      Value* newValue) override;

 private:
  // Value supplied as `self` at construction (not owned).
  Value* self_;
  // Shared handle to the Module supplied at construction.
  std::shared_ptr<Module> module_;
  // Python object supplied at construction.
  py::object py_module_;
};
// SugaredValue that holds a Python dict (`dispatched_fn_`).
// NOTE(review): call() presumably selects one of the functions stored in the
// dict based on a boolean argument -- confirm against the out-of-line
// definition and the dict's expected keys.
struct VISIBILITY_HIDDEN BooleanDispatchValue : public SugaredValue {
  // `explicit` prevents a py::dict from silently converting into a
  // BooleanDispatchValue. The dict is moved into the member.
  explicit BooleanDispatchValue(py::dict dispatched_fn)
      : dispatched_fn_(std::move(dispatched_fn)) {}

  std::string kind() const override {
    return "boolean dispatch";
  }

  // Defined outside this header.
  std::shared_ptr<SugaredValue> call(
      const SourceRange& loc,
      Function& caller,
      at::ArrayRef<NamedValue> inputs,
      at::ArrayRef<NamedValue> attributes,
      size_t n_binders) override;

 private:
  // Dict supplied at construction.
  py::dict dispatched_fn_;
};
} // namespace script
} // namespace jit
} // namespace torch