mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Summary: Adds support for `__getstate__` and `__setstate__` on modules; these methods are called as part of export (`torch.save()`) and import (`torch.jit.load`). * `__getstate__` and `__setstate__` must be TorchScript functions with the signatures `() -> T` and `(T) -> None` respectively * The results of `__getstate__` are stored using the pickler in `states.pkl`, with one entry for each module in definition order (`__getstate__` returns `None` by default if an implementation is not provided) * This prevents sharing between `__getstate__` and attributes, but this should be fine since their uses are mostly unrelated (attributes are for storing values to be used in script methods, `__getstate__` for running arbitrary computations during import) Follow-up: * Somehow replacing `__getstate__`/`__setstate__` with a `ScriptMethodStub` makes `MyScriptModule().__getstate__()` call `ScriptModule.__getstate__()` when used in Python. This should be fixed so semantics in Python are preserved, but it doesn't affect the typical usage. (Internal diff: https://our.intern.facebook.com/intern/diff/15287161/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/20242 Pulled By: driazati Differential Revision: D15287161 fbshipit-source-id: b3f5f33ab74a21a89e6d15460af63aff75cab2d8
198 lines
5.6 KiB
C++
198 lines
5.6 KiB
C++
|
|
#pragma once
|
|
|
|
#include <torch/csrc/jit/pybind_utils.h>
|
|
#include <torch/csrc/jit/script/module.h>
|
|
#include <torch/csrc/jit/script/sugared_value.h>
|
|
#include <memory>
|
|
#include <sstream>
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
namespace script {
|
|
|
|
// Returns a string for the Python object `h`. Judging by the name this is a
// description of its Python type (e.g. for use in error messages).
// NOTE(review): defined out of line; confirm the exact format there.
std::string typeString(py::handle h);
|
|
|
|
inline std::shared_ptr<SugaredValue> toSimple(Value* v) {
|
|
return std::make_shared<SimpleValue>(v);
|
|
}
|
|
|
|
// NB: This should be the single entry-point for instantiating a SugaredValue
|
|
// from a Python object. If you are adding support for converting a new Python
|
|
// type, *add it in this function's implementation*.
|
|
std::shared_ptr<SugaredValue> toSugaredValue(
|
|
py::object obj,
|
|
Function& m,
|
|
SourceRange loc,
|
|
bool is_constant = false);
|
|
|
|
// Presumably returns the script Function backing `obj` when `obj` refers to
// one. NOTE(review): the result for objects that are not script functions
// (e.g. whether it is nullptr) is defined out of line; confirm before relying
// on it.
std::shared_ptr<Function> as_function(const py::object& obj);
|
|
|
|
// SugaredValue wrapping an arbitrary Python object (`self`) that TorchScript
// source refers to. Only declarations live here; the out-of-line members are
// defined in the matching .cpp file.
struct VISIBILITY_HIDDEN PythonValue : public SugaredValue {
  // `explicit` for consistency with PythonModuleValue and
  // ConstantPythonTupleValue, and so a py::object never converts to a
  // PythonValue implicitly. The parameter is named `the_self` so it does not
  // shadow the `self` member (-Wshadow).
  explicit PythonValue(py::object the_self) : self(std::move(the_self)) {}

  // Returns a FunctionSchema for calling `self` with `n_args` arguments and
  // `n_binders` outputs. NOTE(review): how the schema is derived is decided
  // in the .cpp; confirm there.
  FunctionSchema getSchema(const size_t n_args, const size_t n_binders);

  // Call it like a function, e.g. `outputs = this(inputs)`.
  std::shared_ptr<SugaredValue> call(
      const SourceRange& loc,
      Function& m,
      at::ArrayRef<NamedValue> inputs_,
      at::ArrayRef<NamedValue> attributes,
      size_t n_binders) override;

  std::string kind() const override;

  // Presumably unpacks `self` into its elements; `size_hint` defaults to an
  // empty optional.
  std::vector<std::shared_ptr<SugaredValue>> asTuple(
      const SourceRange& loc,
      Function& m,
      const c10::optional<size_t>& size_hint = {}) override;

  // Resolves attribute access, e.g. `this.field`.
  std::shared_ptr<SugaredValue> attr(
      const SourceRange& loc,
      Function& m,
      const std::string& field) override;

 protected:
  // Looks up attribute `name` on `self`. `loc` is available for error
  // reporting (presumably used when the lookup fails; confirm).
  py::object getattr(const SourceRange& loc, const std::string& name);

  // NOTE(review): appears to add a hint about `__constants__` to the message
  // being built in `ss`; defined out of line, confirm.
  void checkForAddToConstantsError(std::stringstream& ss);

  // The wrapped Python object.
  py::object self;
};
|
|
|
|
// PythonValue specialization for a Python module object (as the name
// suggests; confirm which objects are routed here in toSugaredValue). The only
// behavior it changes is attribute lookup (`mod.field`), which is implemented
// out of line.
struct VISIBILITY_HIDDEN PythonModuleValue : public PythonValue {
  explicit PythonModuleValue(py::object module_obj)
      : PythonValue(std::move(module_obj)) {}

  std::shared_ptr<SugaredValue> attr(
      const SourceRange& loc,
      Function& m,
      const std::string& field) override;
};
|
|
|
|
// PythonValue for a Python tuple. Judging by the name, its contents are
// treated as constants (confirm in the .cpp). It overrides both conversion to
// a single graph Value and unpacking into per-element SugaredValues.
struct VISIBILITY_HIDDEN ConstantPythonTupleValue : public PythonValue {
  explicit ConstantPythonTupleValue(py::object tuple_obj)
      : PythonValue(std::move(tuple_obj)) {}

  Value* asValue(const SourceRange& loc, Function& m) override;

  std::vector<std::shared_ptr<SugaredValue>> asTuple(
      const SourceRange& loc,
      Function& m,
      const c10::optional<size_t>& size_hint = {}) override;
};
|
|
|
|
// Represents all the parameters of a module as a List[Tensor]
|
|
struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue {
|
|
ConstantParameterList(Value* the_list) : the_list_(the_list) {}
|
|
std::string kind() const override {
|
|
return "constant parameter list";
|
|
}
|
|
std::shared_ptr<SugaredValue> call(
|
|
const SourceRange& loc,
|
|
Function& caller,
|
|
at::ArrayRef<NamedValue> inputs,
|
|
at::ArrayRef<NamedValue> attributes,
|
|
size_t n_binders) override {
|
|
return toSimple(the_list_);
|
|
}
|
|
|
|
private:
|
|
Value* the_list_;
|
|
};
|
|
|
|
// Callable standing for several methods, named in `method_names_`, of the
// module value `module_` (reported kind: "overloaded function"). Picking which
// method actually runs presumably happens in call(), which is defined out of
// line.
struct VISIBILITY_HIDDEN OverloadedMethodValue : public SugaredValue {
  OverloadedMethodValue(Value* module_value, std::vector<std::string> names)
      : module_(module_value), method_names_(std::move(names)) {}

  std::string kind() const override {
    return "overloaded function";
  }

  std::shared_ptr<SugaredValue> call(
      const SourceRange& loc,
      Function& caller,
      at::ArrayRef<NamedValue> inputs,
      at::ArrayRef<NamedValue> attributes,
      size_t n_binders) override;

 private:
  // Graph value for the module the methods belong to; not owned.
  Value* module_;
  // Candidate method names.
  std::vector<std::string> method_names_;
};
|
|
|
|
// defines how modules/methods behave inside the script subset.
|
|
// for now this does not have any interaction with python.
|
|
// in the future, we will add the ability to resolve `self.foo` to python
|
|
// {functions, modules, contants} so this SugaredValue is defined here
|
|
// anticipating we will eventually need to replace Module with a py::object
|
|
// holding the actual nn.Module class.
|
|
|
|
struct VISIBILITY_HIDDEN ModuleValue : public SugaredValue {
|
|
ModuleValue(Value* self, std::shared_ptr<Module> module, py::object py_module)
|
|
: self_(self),
|
|
module_(std::move(module)),
|
|
py_module_(std::move(py_module)) {}
|
|
|
|
std::string kind() const override {
|
|
return "module";
|
|
}
|
|
|
|
// select an attribute on it, e.g. `this.field`
|
|
std::shared_ptr<SugaredValue> attr(
|
|
const SourceRange& loc,
|
|
Function& m,
|
|
const std::string& field) override;
|
|
|
|
// call module.forward
|
|
std::shared_ptr<SugaredValue> call(
|
|
const SourceRange& loc,
|
|
Function& caller,
|
|
at::ArrayRef<NamedValue> inputs,
|
|
at::ArrayRef<NamedValue> attributes,
|
|
size_t n_binders) override {
|
|
return attr(loc, caller, "forward")
|
|
->call(loc, caller, inputs, attributes, n_binders);
|
|
}
|
|
|
|
std::vector<std::shared_ptr<SugaredValue>> asTuple(
|
|
const SourceRange& loc,
|
|
Function& m,
|
|
const c10::optional<size_t>& size_hint = {}) override;
|
|
|
|
void setAttr(
|
|
const SourceRange& loc,
|
|
Function& m,
|
|
const std::string& field,
|
|
Value* newValue) override;
|
|
|
|
private:
|
|
Value* self_;
|
|
std::shared_ptr<Module> module_;
|
|
py::object py_module_;
|
|
};
|
|
|
|
// Callable whose behavior is driven by the Python dict `dispatched_fn_`
// (reported kind: "boolean dispatch"). NOTE(review): presumably the dict names
// two candidate functions plus the boolean argument that selects between
// them; the keys are read in call(), which is defined out of line. Confirm
// there.
struct VISIBILITY_HIDDEN BooleanDispatchValue : public SugaredValue {
  // `explicit`, matching the other single-argument SugaredValue constructors
  // in this header, so a py::dict never silently becomes a
  // BooleanDispatchValue.
  explicit BooleanDispatchValue(py::dict dispatched_fn)
      : dispatched_fn_(std::move(dispatched_fn)) {}

  std::string kind() const override {
    return "boolean dispatch";
  }

  std::shared_ptr<SugaredValue> call(
      const SourceRange& loc,
      Function& caller,
      at::ArrayRef<NamedValue> inputs,
      at::ArrayRef<NamedValue> attributes,
      size_t n_binders) override;

 private:
  py::dict dispatched_fn_;
};
|
|
|
|
} // namespace script
|
|
} // namespace jit
|
|
} // namespace torch
|