mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[JIT] pickle serialization for custom bound classes
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/32604 Test Plan: Imported from OSS Differential Revision: D19566633 fbshipit-source-id: 9387d3ff45cbd6ccde49ce190a52859481cc301c
This commit is contained in:
committed by
Facebook Github Bot
parent
34ccfba403
commit
465ebd58ba
@ -129,7 +129,5 @@ struct is_type_condition<C, std::enable_if_t<std::is_same<bool, std::remove_cv_t
|
||||
*/
|
||||
template <class T>
|
||||
struct is_fundamental : std::is_fundamental<T> {};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@ -753,7 +753,7 @@ ENDIF()
|
||||
install(DIRECTORY "${TORCH_SRC_DIR}/csrc"
|
||||
DESTINATION ${TORCH_INSTALL_INCLUDE_DIR}/torch
|
||||
FILES_MATCHING PATTERN "*.h")
|
||||
install(FILES "${TORCH_SRC_DIR}/script.h" "${TORCH_SRC_DIR}/extension.h" "${TORCH_SRC_DIR}/custom_class.h"
|
||||
install(FILES "${TORCH_SRC_DIR}/script.h" "${TORCH_SRC_DIR}/extension.h" "${TORCH_SRC_DIR}/custom_class.h" "${TORCH_SRC_DIR}/custom_class_detail.h"
|
||||
DESTINATION ${TORCH_INSTALL_INCLUDE_DIR}/torch)
|
||||
|
||||
|
||||
|
||||
@ -55,19 +55,16 @@ struct Stack : torch::jit::CustomClassHolder {
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::string> __getstate__() const {
|
||||
return stack_;
|
||||
}
|
||||
|
||||
void __setstate__(std::vector<std::string> state) {
|
||||
stack_ = std::move(state);
|
||||
}
|
||||
|
||||
std::tuple<double, int64_t> return_a_tuple() const {
|
||||
return std::make_tuple(1337.0f, 123);
|
||||
}
|
||||
};
|
||||
|
||||
struct PickleTester : torch::jit::CustomClassHolder {
|
||||
PickleTester(std::vector<int64_t> vals) : vals(std::move(vals)) {}
|
||||
std::vector<int64_t> vals;
|
||||
};
|
||||
|
||||
static auto test = torch::jit::class_<Foo>("_TorchScriptTesting_Foo")
|
||||
.def(torch::jit::init<int64_t, int64_t>())
|
||||
// .def(torch::jit::init<>())
|
||||
@ -83,8 +80,14 @@ static auto testStack =
|
||||
.def("pop", &Stack<std::string>::pop)
|
||||
.def("clone", &Stack<std::string>::clone)
|
||||
.def("merge", &Stack<std::string>::merge)
|
||||
.def("__getstate__", &Stack<std::string>::__getstate__)
|
||||
.def("__setstate__", &Stack<std::string>::__setstate__)
|
||||
.def_pickle(
|
||||
[](const c10::intrusive_ptr<Stack<std::string>>& self) {
|
||||
return self->stack_;
|
||||
},
|
||||
[](std::vector<std::string> state) { // __setstate__
|
||||
return c10::make_intrusive<Stack<std::string>>(
|
||||
std::vector<std::string>{"i", "was", "deserialized"});
|
||||
})
|
||||
.def("return_a_tuple", &Stack<std::string>::return_a_tuple)
|
||||
.def(
|
||||
"top",
|
||||
@ -95,6 +98,28 @@ static auto testStack =
|
||||
// take an intrusive_ptr<Stack> as the first argument.
|
||||
// .def("foo", [](int64_t a) -> int64_t{ return 3;});
|
||||
// clang-format on
|
||||
|
||||
static auto testPickle =
|
||||
torch::jit::class_<PickleTester>("_TorchScriptTesting_PickleTester")
|
||||
.def(torch::jit::init<std::vector<int64_t>>())
|
||||
.def_pickle(
|
||||
[](c10::intrusive_ptr<PickleTester> self) { // __getstate__
|
||||
return std::vector<int64_t>{1, 3, 3, 7};
|
||||
},
|
||||
[](std::vector<int64_t> state) { // __setstate__
|
||||
return c10::make_intrusive<PickleTester>(std::move(state));
|
||||
})
|
||||
.def(
|
||||
"top",
|
||||
[](const c10::intrusive_ptr<PickleTester>& self) {
|
||||
return self->vals.back();
|
||||
})
|
||||
.def("pop", [](const c10::intrusive_ptr<PickleTester>& self) {
|
||||
auto val = self->vals.back();
|
||||
self->vals.pop_back();
|
||||
return val;
|
||||
});
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace jit
|
||||
|
||||
@ -4932,22 +4932,6 @@ def foo(x):
|
||||
self.assertEqual(out.pop(), "hi")
|
||||
self.assertEqual(out.pop(), "mom")
|
||||
|
||||
@skipIfRocm
|
||||
@unittest.skipIf(IS_WINDOWS, "TODO: Fix this test case")
|
||||
def test_torchbind_getstate_setstate(self):
|
||||
def f():
|
||||
val = torch.classes._TorchScriptTesting_StackString(["3", "5"])
|
||||
s = val.__getstate__()
|
||||
# TODO: sort out whether unpickler should call __new__ or __init__
|
||||
val2 = torch.classes._TorchScriptTesting_StackString(["0", "0"])
|
||||
val2.__setstate__(s)
|
||||
return val.pop(), val2.pop()
|
||||
ret = f()
|
||||
self.assertEqual(ret[0], ret[1])
|
||||
|
||||
ret = torch.jit.script(f)()
|
||||
self.assertEqual(ret[0], ret[1])
|
||||
|
||||
@skipIfRocm
|
||||
@unittest.skipIf(IS_WINDOWS, "TODO: Fix this test case")
|
||||
def test_torchbind_return_tuple(self):
|
||||
@ -5000,6 +4984,47 @@ def foo(x):
|
||||
scripted = torch.jit.script(foo)
|
||||
self.assertEqual(scripted(), "mom")
|
||||
|
||||
@skipIfRocm
|
||||
@unittest.skipIf(IS_WINDOWS, "TODO: Fix this test case")
|
||||
def test_torchbind_class_attribute(self):
|
||||
class FooBar1234(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(FooBar1234, self).__init__()
|
||||
self.f = torch.classes._TorchScriptTesting_StackString(["3", "4"])
|
||||
|
||||
def forward(self):
|
||||
return self.f.top()
|
||||
|
||||
inst = FooBar1234()
|
||||
scripted = torch.jit.script(inst)
|
||||
eic = self.getExportImportCopy(scripted)
|
||||
assert eic() == "deserialized"
|
||||
for expected in ["deserialized", "was", "i"]:
|
||||
assert eic.f.pop() == expected
|
||||
|
||||
@skipIfRocm
|
||||
@unittest.skipIf(IS_WINDOWS, "TODO: Fix this test case")
|
||||
def test_torchbind_getstate(self):
|
||||
class FooBar4321(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(FooBar4321, self).__init__()
|
||||
self.f = torch.classes._TorchScriptTesting_PickleTester([3, 4])
|
||||
|
||||
def forward(self):
|
||||
return self.f.top()
|
||||
|
||||
inst = FooBar4321()
|
||||
scripted = torch.jit.script(inst)
|
||||
eic = self.getExportImportCopy(scripted)
|
||||
# NB: we expect the values {7, 3, 3, 1} as __getstate__ is defined to
|
||||
# return {1, 3, 3, 7}. I tried to make this actually depend on the
|
||||
# values at instantiation in the test with some transformation, but
|
||||
# because it seems we serialize/deserialize multiple times, that
|
||||
# transformation isn't as you would it expect it to be.
|
||||
assert eic() == 7
|
||||
for expected in [7, 3, 3, 1]:
|
||||
assert eic.f.pop() == expected
|
||||
|
||||
def test_jitter_bug(self):
|
||||
@torch.jit.script
|
||||
def fn2(input, kernel_size):
|
||||
|
||||
@ -115,6 +115,17 @@ void Pickler::pushIValueImpl(const IValue& ivalue) {
|
||||
push<PickleOpCode>(PickleOpCode::BUILD);
|
||||
} else if (ivalue.isDevice()) {
|
||||
pushDevice(ivalue);
|
||||
} else if (ivalue.isCapsule()) {
|
||||
std::stringstream err;
|
||||
err << "Cannot serialize custom bound C++ class";
|
||||
if (memorized_class_types_ && memorized_class_types_->size()) {
|
||||
if (auto qualname = memorized_class_types_->back()->name()) {
|
||||
err << " " << qualname->qualifiedName();
|
||||
}
|
||||
}
|
||||
err << ". Please define serialization methods via torch::jit::pickle_ for "
|
||||
"this class.";
|
||||
AT_ERROR(err.str());
|
||||
} else {
|
||||
AT_ERROR("Unknown IValue type for pickling: ", ivalue.tagKind());
|
||||
}
|
||||
|
||||
@ -170,6 +170,11 @@ inline InferredType tryToInferType(py::handle input) {
|
||||
}
|
||||
}
|
||||
|
||||
if (py::isinstance<script::Object>(input)) {
|
||||
auto object = py::cast<script::Object>(input);
|
||||
return InferredType(object.type());
|
||||
}
|
||||
|
||||
// Try container types
|
||||
return tryToInferContainerType(input);
|
||||
}
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
#include <torch/csrc/jit/script/compilation_unit.h>
|
||||
#include <torch/csrc/jit/tracer.h>
|
||||
#include <torch/csrc/utils/variadic.h>
|
||||
#include <torch/custom_class_detail.h>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
@ -22,25 +23,10 @@
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
TORCH_API std::vector<c10::RegisterOperators>& registeredOps();
|
||||
TORCH_API std::shared_ptr<script::CompilationUnit>& classCU();
|
||||
|
||||
namespace detail {
|
||||
template <class R, class...>
|
||||
struct types {
|
||||
using type = types;
|
||||
};
|
||||
template <class Sig>
|
||||
struct args;
|
||||
template <class R, class CurClass, class... Args>
|
||||
struct args<R (CurClass::*)(Args...)> : types<R, Args...> {};
|
||||
template <class R, class CurClass, class... Args>
|
||||
struct args<R (CurClass::*)(Args...) const> : types<R, Args...> {};
|
||||
template <class Sig>
|
||||
using args_t = typename args<Sig>::type;
|
||||
} // namespace detail
|
||||
template <class... Types>
|
||||
detail::types<void, Types...> init() { return detail::types<void, Types...>{}; }
|
||||
detail::types<void, Types...> init() {
|
||||
return detail::types<void, Types...>{};
|
||||
}
|
||||
|
||||
// To bind custom classes into Torchscript, use an API very similar to Pybind's.
|
||||
// Currently exposes one class `torch::jit::class_<T>` and 2 methods.
|
||||
@ -68,10 +54,6 @@ class class_ {
|
||||
|
||||
public:
|
||||
class_(std::string className_) : className(std::move(className_)) {
|
||||
// Currently we register everything as a python class just for convenience.
|
||||
// We'll want to remove this at some point to get rid of the python
|
||||
// dependency. It would require significant changes to class registration,
|
||||
// (I think)?
|
||||
qualClassName = topModule + "." + parentModule + "." + className;
|
||||
|
||||
// We currently represent custom classes as torchscript classes with a
|
||||
@ -127,6 +109,72 @@ class class_ {
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Pickle
|
||||
template <typename GetStateFn, typename SetStateFn>
|
||||
class_& def_pickle(GetStateFn&& get_state, SetStateFn&& set_state) {
|
||||
static_assert(
|
||||
c10::guts::is_stateless_lambda<std::decay_t<GetStateFn>>::value &&
|
||||
c10::guts::is_stateless_lambda<std::decay_t<SetStateFn>>::value,
|
||||
"torch::jit::pickle_ currently only supports lambdas as "
|
||||
"__getstate__ and __setstate__ arguments.");
|
||||
def("__getstate__", std::forward<GetStateFn>(get_state));
|
||||
|
||||
// __setstate__ needs to be registered with some custom handling:
|
||||
// We need to wrap the invocation of of the user-provided function
|
||||
// such that we take the return value (i.e. c10::intrusive_ptr<CurrClass>)
|
||||
// and assign it to the `capsule` attribute.
|
||||
using SetStateTraits =
|
||||
c10::guts::infer_function_traits_t<std::decay_t<SetStateFn>>;
|
||||
using SetStateArg = typename c10::guts::typelist::head_t<
|
||||
typename SetStateTraits::parameter_types>;
|
||||
auto setstate_wrapper = [set_state = std::move(set_state)](
|
||||
c10::tagged_capsule<CurClass> self,
|
||||
SetStateArg&& arg) {
|
||||
c10::intrusive_ptr<CurClass> classObj =
|
||||
at::guts::invoke(set_state, std::forward<SetStateArg>(arg));
|
||||
auto genericPtr =
|
||||
c10::static_intrusive_pointer_cast<torch::jit::CustomClassHolder>(
|
||||
classObj);
|
||||
auto capsule = IValue(genericPtr);
|
||||
auto object = self.ivalue.toObject();
|
||||
object->setSlot(0, capsule);
|
||||
};
|
||||
defineMethod<void>("__setstate__", std::move(setstate_wrapper));
|
||||
|
||||
// type validation
|
||||
auto getstate_schema = classTypePtr->getMethod("__getstate__")->getSchema();
|
||||
auto format_getstate_schema = [&getstate_schema]() {
|
||||
std::stringstream ss;
|
||||
ss << getstate_schema;
|
||||
return ss.str();
|
||||
};
|
||||
TORCH_CHECK(
|
||||
getstate_schema.arguments().size() == 1,
|
||||
"__getstate__ should take exactly one argument: self. Got: ",
|
||||
format_getstate_schema());
|
||||
auto first_arg_type = getstate_schema.arguments().at(0).type();
|
||||
TORCH_CHECK(
|
||||
*first_arg_type == *classTypePtr,
|
||||
"self argument of __getstate__ must be the custom class type. Got ",
|
||||
first_arg_type->python_str());
|
||||
TORCH_CHECK(
|
||||
getstate_schema.returns().size() == 1,
|
||||
"__getstate__ should return exactly one value for serialization. Got: ",
|
||||
format_getstate_schema());
|
||||
auto ser_type = getstate_schema.returns().at(0).type();
|
||||
auto setstate_schema = classTypePtr->getMethod("__setstate__")->getSchema();
|
||||
auto arg_type = setstate_schema.arguments().at(1).type();
|
||||
TORCH_CHECK(
|
||||
(*arg_type == *ser_type),
|
||||
"__setstate__'s argument should be the same type as the "
|
||||
"return value of __getstate__. Got ",
|
||||
arg_type->python_str(),
|
||||
" but expected ",
|
||||
ser_type->python_str());
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
private:
|
||||
template<typename R, typename Func>
|
||||
void defineMethod(std::string name, Func func) {
|
||||
|
||||
37
torch/custom_class_detail.h
Normal file
37
torch/custom_class_detail.h
Normal file
@ -0,0 +1,37 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/util/Metaprogramming.h>
|
||||
#include <c10/util/TypeTraits.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Argument type utilities
|
||||
template <class R, class...>
|
||||
struct types {
|
||||
using type = types;
|
||||
};
|
||||
|
||||
template <class Sig>
|
||||
struct args;
|
||||
|
||||
// Method
|
||||
template <class R, class CurClass, class... Args>
|
||||
struct args<R (CurClass::*)(Args...)> : types<R, Args...> {};
|
||||
|
||||
// Const method
|
||||
template <class R, class CurClass, class... Args>
|
||||
struct args<R (CurClass::*)(Args...) const> : types<R, Args...> {};
|
||||
|
||||
template <class Sig>
|
||||
using args_t = typename args<Sig>::type;
|
||||
|
||||
} // namespace detail
|
||||
|
||||
TORCH_API std::vector<c10::RegisterOperators>& registeredOps();
|
||||
TORCH_API std::shared_ptr<script::CompilationUnit>& classCU();
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
Reference in New Issue
Block a user