[JIT] pickle serialization for custom bound classes

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/32604

Test Plan: Imported from OSS

Differential Revision: D19566633

fbshipit-source-id: 9387d3ff45cbd6ccde49ce190a52859481cc301c
This commit is contained in:
James Reed
2020-01-28 10:58:28 -08:00
committed by Facebook Github Bot
parent 34ccfba403
commit 465ebd58ba
8 changed files with 200 additions and 51 deletions

View File

@ -129,7 +129,5 @@ struct is_type_condition<C, std::enable_if_t<std::is_same<bool, std::remove_cv_t
*/
template <typename T>
struct is_fundamental : std::is_fundamental<T> {
};
}
}

View File

@ -753,7 +753,7 @@ ENDIF()
install(DIRECTORY "${TORCH_SRC_DIR}/csrc"
DESTINATION ${TORCH_INSTALL_INCLUDE_DIR}/torch
FILES_MATCHING PATTERN "*.h")
install(FILES "${TORCH_SRC_DIR}/script.h" "${TORCH_SRC_DIR}/extension.h" "${TORCH_SRC_DIR}/custom_class.h"
install(FILES "${TORCH_SRC_DIR}/script.h" "${TORCH_SRC_DIR}/extension.h" "${TORCH_SRC_DIR}/custom_class.h" "${TORCH_SRC_DIR}/custom_class_detail.h"
DESTINATION ${TORCH_INSTALL_INCLUDE_DIR}/torch)

View File

@ -55,19 +55,16 @@ struct Stack : torch::jit::CustomClassHolder {
}
}
// Returns `stack_` by value; this is the state handed out for serialization.
std::vector<std::string> __getstate__() const {
return stack_;
}
// Replaces `stack_` with `state`. `state` is taken by value and moved in, so
// the caller's argument is consumed rather than copied a second time.
void __setstate__(std::vector<std::string> state) {
stack_ = std::move(state);
}
// Returns the fixed pair (1337.0, 123) as a (double, int64_t) tuple.
std::tuple<double, int64_t> return_a_tuple() const {
  const double first = 1337.0f;
  const int64_t second = 123;
  return std::make_tuple(first, second);
}
};
struct PickleTester : torch::jit::CustomClassHolder {
PickleTester(std::vector<int64_t> vals) : vals(std::move(vals)) {}
std::vector<int64_t> vals;
};
static auto test = torch::jit::class_<Foo>("_TorchScriptTesting_Foo")
.def(torch::jit::init<int64_t, int64_t>())
// .def(torch::jit::init<>())
@ -83,8 +80,14 @@ static auto testStack =
.def("pop", &Stack<std::string>::pop)
.def("clone", &Stack<std::string>::clone)
.def("merge", &Stack<std::string>::merge)
.def("__getstate__", &Stack<std::string>::__getstate__)
.def("__setstate__", &Stack<std::string>::__setstate__)
.def_pickle(
[](const c10::intrusive_ptr<Stack<std::string>>& self) {
return self->stack_;
},
[](std::vector<std::string> state) { // __setstate__
return c10::make_intrusive<Stack<std::string>>(
std::vector<std::string>{"i", "was", "deserialized"});
})
.def("return_a_tuple", &Stack<std::string>::return_a_tuple)
.def(
"top",
@ -95,6 +98,28 @@ static auto testStack =
// take an intrusive_ptr<Stack> as the first argument.
// .def("foo", [](int64_t a) -> int64_t{ return 3;});
// clang-format on
static auto testPickle =
torch::jit::class_<PickleTester>("_TorchScriptTesting_PickleTester")
.def(torch::jit::init<std::vector<int64_t>>())
.def_pickle(
[](c10::intrusive_ptr<PickleTester> self) { // __getstate__
return std::vector<int64_t>{1, 3, 3, 7};
},
[](std::vector<int64_t> state) { // __setstate__
return c10::make_intrusive<PickleTester>(std::move(state));
})
.def(
"top",
[](const c10::intrusive_ptr<PickleTester>& self) {
return self->vals.back();
})
.def("pop", [](const c10::intrusive_ptr<PickleTester>& self) {
auto val = self->vals.back();
self->vals.pop_back();
return val;
});
} // namespace
} // namespace jit

View File

@ -4932,22 +4932,6 @@ def foo(x):
self.assertEqual(out.pop(), "hi")
self.assertEqual(out.pop(), "mom")
@skipIfRocm
@unittest.skipIf(IS_WINDOWS, "TODO: Fix this test case")
def test_torchbind_getstate_setstate(self):
    # Copies the state of one StackString into another via
    # __getstate__/__setstate__ and checks that both then pop the same
    # value, first in eager mode and then under torch.jit.script.
    def f():
        source = torch.classes._TorchScriptTesting_StackString(["3", "5"])
        state = source.__getstate__()
        # TODO: sort out whether unpickler should call __new__ or __init__
        target = torch.classes._TorchScriptTesting_StackString(["0", "0"])
        target.__setstate__(state)
        return source.pop(), target.pop()

    eager_first, eager_second = f()
    self.assertEqual(eager_first, eager_second)

    scripted_first, scripted_second = torch.jit.script(f)()
    self.assertEqual(scripted_first, scripted_second)
@skipIfRocm
@unittest.skipIf(IS_WINDOWS, "TODO: Fix this test case")
def test_torchbind_return_tuple(self):
@ -5000,6 +4984,47 @@ def foo(x):
scripted = torch.jit.script(foo)
self.assertEqual(scripted(), "mom")
@skipIfRocm
@unittest.skipIf(IS_WINDOWS, "TODO: Fix this test case")
def test_torchbind_class_attribute(self):
    # A module holding a custom-class attribute is scripted and sent
    # through an export/import round trip; the attribute must survive it.
    class FooBar1234(torch.nn.Module):
        def __init__(self):
            super(FooBar1234, self).__init__()
            self.f = torch.classes._TorchScriptTesting_StackString(["3", "4"])

        def forward(self):
            return self.f.top()

    inst = FooBar1234()
    scripted = torch.jit.script(inst)
    eic = self.getExportImportCopy(scripted)
    # The expected contents differ from the ["3", "4"] used at construction:
    # presumably the C++ __setstate__ bound for this class discards the saved
    # state and rebuilds the stack as ["i", "was", "deserialized"] — confirm
    # against the C++ registration.
    # assertEqual (not bare `assert`) so the checks survive `python -O` and
    # report both values on failure.
    self.assertEqual(eic(), "deserialized")
    for expected in ["deserialized", "was", "i"]:
        self.assertEqual(eic.f.pop(), expected)
@skipIfRocm
@unittest.skipIf(IS_WINDOWS, "TODO: Fix this test case")
def test_torchbind_getstate(self):
    # Same round trip as a module attribute, but for PickleTester.
    class FooBar4321(torch.nn.Module):
        def __init__(self):
            super(FooBar4321, self).__init__()
            self.f = torch.classes._TorchScriptTesting_PickleTester([3, 4])

        def forward(self):
            return self.f.top()

    inst = FooBar4321()
    scripted = torch.jit.script(inst)
    eic = self.getExportImportCopy(scripted)
    # NB: we expect the values {7, 3, 3, 1} because __getstate__ is defined
    # to always return {1, 3, 3, 7}. I tried to make this actually depend on
    # the values at instantiation in the test with some transformation, but
    # because it seems we serialize/deserialize multiple times, that
    # transformation isn't what you would expect it to be.
    # assertEqual (not bare `assert`) so the checks survive `python -O` and
    # report both values on failure.
    self.assertEqual(eic(), 7)
    for expected in [7, 3, 3, 1]:
        self.assertEqual(eic.f.pop(), expected)
def test_jitter_bug(self):
@torch.jit.script
def fn2(input, kernel_size):

View File

@ -115,6 +115,17 @@ void Pickler::pushIValueImpl(const IValue& ivalue) {
push<PickleOpCode>(PickleOpCode::BUILD);
} else if (ivalue.isDevice()) {
pushDevice(ivalue);
} else if (ivalue.isCapsule()) {
std::stringstream err;
err << "Cannot serialize custom bound C++ class";
if (memorized_class_types_ && memorized_class_types_->size()) {
if (auto qualname = memorized_class_types_->back()->name()) {
err << " " << qualname->qualifiedName();
}
}
err << ". Please define serialization methods via torch::jit::pickle_ for "
"this class.";
AT_ERROR(err.str());
} else {
AT_ERROR("Unknown IValue type for pickling: ", ivalue.tagKind());
}

View File

@ -170,6 +170,11 @@ inline InferredType tryToInferType(py::handle input) {
}
}
if (py::isinstance<script::Object>(input)) {
auto object = py::cast<script::Object>(input);
return InferredType(object.type());
}
// Try container types
return tryToInferContainerType(input);
}

View File

@ -15,6 +15,7 @@
#include <torch/csrc/jit/script/compilation_unit.h>
#include <torch/csrc/jit/tracer.h>
#include <torch/csrc/utils/variadic.h>
#include <torch/custom_class_detail.h>
#include <iostream>
#include <sstream>
@ -22,25 +23,10 @@
namespace torch {
namespace jit {
TORCH_API std::vector<c10::RegisterOperators>& registeredOps();
TORCH_API std::shared_ptr<script::CompilationUnit>& classCU();
namespace detail {

// Type list whose first entry is a return type, followed by argument types.
template <typename Ret, typename... Rest>
struct types {
  using type = types<Ret, Rest...>;
};

// Maps a pointer-to-member-function type onto types<Return, Args...>.
// The primary template is intentionally left undefined.
template <typename Signature>
struct args;

// Non-const member function.
template <typename Ret, typename Class, typename... Params>
struct args<Ret (Class::*)(Params...)> : types<Ret, Params...> {};

// Const-qualified member function.
template <typename Ret, typename Class, typename... Params>
struct args<Ret (Class::*)(Params...) const> : types<Ret, Params...> {};

// Shorthand for the types<...> list computed by args<Signature>.
template <typename Signature>
using args_t = typename args<Signature>::type;

} // namespace detail
template <class... Types>
detail::types<void, Types...> init() { return detail::types<void, Types...>{}; }
detail::types<void, Types...> init() {
return detail::types<void, Types...>{};
}
// To bind custom classes into Torchscript, use an API very similar to Pybind's.
// Currently exposes one class `torch::jit::class_<T>` and 2 methods.
@ -68,10 +54,6 @@ class class_ {
public:
class_(std::string className_) : className(std::move(className_)) {
// Currently we register everything as a python class just for convenience.
// We'll want to remove this at some point to get rid of the python
// dependency. It would require significant changes to class registration,
// (I think)?
qualClassName = topModule + "." + parentModule + "." + className;
// We currently represent custom classes as torchscript classes with a
@ -127,6 +109,72 @@ class class_ {
return *this;
}
// Pickle
//
// Registers `__getstate__` and `__setstate__` for this custom class.
//  - `get_state` is bound unchanged as `__getstate__`. Its schema must take
//    exactly one argument (self, of this class type) and return exactly one
//    value.
//  - `set_state` is called with the serialized state; its result initializes
//    a c10::intrusive_ptr<CurClass>, which is stored into the object being
//    deserialized. The type of its first parameter must equal the return
//    type of `get_state`.
// Both callables must be stateless lambdas (checked at compile time). Schema
// mismatches are reported through TORCH_CHECK. Returns *this for chaining.
template <typename GetStateFn, typename SetStateFn>
class_& def_pickle(GetStateFn&& get_state, SetStateFn&& set_state) {
  static_assert(
      c10::guts::is_stateless_lambda<std::decay_t<GetStateFn>>::value &&
          c10::guts::is_stateless_lambda<std::decay_t<SetStateFn>>::value,
      "torch::jit::class_::def_pickle currently only supports lambdas as "
      "__getstate__ and __setstate__ arguments.");
  def("__getstate__", std::forward<GetStateFn>(get_state));

  // __setstate__ needs to be registered with some custom handling:
  // We need to wrap the invocation of the user-provided function
  // such that we take the return value (i.e. c10::intrusive_ptr<CurrClass>)
  // and assign it to the `capsule` attribute.
  using SetStateTraits =
      c10::guts::infer_function_traits_t<std::decay_t<SetStateFn>>;
  using SetStateArg = typename c10::guts::typelist::head_t<
      typename SetStateTraits::parameter_types>;
  // `set_state` is a forwarding reference, so it is forwarded (not
  // unconditionally moved) into the closure: an lvalue argument is copied,
  // an rvalue argument is moved.
  auto setstate_wrapper =
      [set_state = std::forward<SetStateFn>(set_state)](
          c10::tagged_capsule<CurClass> self, SetStateArg&& arg) {
        c10::intrusive_ptr<CurClass> classObj =
            at::guts::invoke(set_state, std::forward<SetStateArg>(arg));
        auto genericPtr =
            c10::static_intrusive_pointer_cast<torch::jit::CustomClassHolder>(
                classObj);
        auto capsule = IValue(genericPtr);
        auto object = self.ivalue.toObject();
        // Slot 0 receives the freshly constructed instance.
        object->setSlot(0, capsule);
      };
  defineMethod<void>("__setstate__", std::move(setstate_wrapper));

  // type validation
  auto getstate_schema = classTypePtr->getMethod("__getstate__")->getSchema();
  // Only called when a check below fails, to render the schema in the error.
  auto format_getstate_schema = [&getstate_schema]() {
    std::stringstream ss;
    ss << getstate_schema;
    return ss.str();
  };
  TORCH_CHECK(
      getstate_schema.arguments().size() == 1,
      "__getstate__ should take exactly one argument: self. Got: ",
      format_getstate_schema());
  auto first_arg_type = getstate_schema.arguments().at(0).type();
  TORCH_CHECK(
      *first_arg_type == *classTypePtr,
      "self argument of __getstate__ must be the custom class type. Got ",
      first_arg_type->python_str());
  TORCH_CHECK(
      getstate_schema.returns().size() == 1,
      "__getstate__ should return exactly one value for serialization. Got: ",
      format_getstate_schema());
  auto ser_type = getstate_schema.returns().at(0).type();
  auto setstate_schema = classTypePtr->getMethod("__setstate__")->getSchema();
  // Argument 0 of the wrapper is self; argument 1 is the serialized state.
  auto arg_type = setstate_schema.arguments().at(1).type();
  TORCH_CHECK(
      (*arg_type == *ser_type),
      "__setstate__'s argument should be the same type as the "
      "return value of __getstate__. Got ",
      arg_type->python_str(),
      " but expected ",
      ser_type->python_str());

  return *this;
}
private:
template<typename R, typename Func>
void defineMethod(std::string name, Func func) {

View File

@ -0,0 +1,37 @@
#pragma once
#include <c10/util/Metaprogramming.h>
#include <c10/util/TypeTraits.h>
namespace torch {
namespace jit {
namespace detail {

// Argument type utilities.
//
// types<Ret, Rest...> is a plain type list: a return type followed by the
// argument types. Its nested `type` names the list itself.
template <typename Ret, typename... Rest>
struct types {
  using type = types<Ret, Rest...>;
};

// args<Signature> maps a pointer-to-member-function type onto the matching
// types<Return, Args...> list. The primary template is left undefined.
template <typename Signature>
struct args;

// Non-const method.
template <typename Ret, typename Class, typename... Params>
struct args<Ret (Class::*)(Params...)> : types<Ret, Params...> {};

// Const-qualified method.
template <typename Ret, typename Class, typename... Params>
struct args<Ret (Class::*)(Params...) const> : types<Ret, Params...> {};

// Shorthand for the types<...> list computed by args<Signature>.
template <typename Signature>
using args_t = typename args<Signature>::type;

} // namespace detail
TORCH_API std::vector<c10::RegisterOperators>& registeredOps();
TORCH_API std::shared_ptr<script::CompilationUnit>& classCU();
} // namespace jit
} // namespace torch