diff --git a/c10/util/TypeTraits.h b/c10/util/TypeTraits.h index f0faa7f0c169..a5756dbc8cc5 100644 --- a/c10/util/TypeTraits.h +++ b/c10/util/TypeTraits.h @@ -129,7 +129,5 @@ struct is_type_condition struct is_fundamental : std::is_fundamental {}; - - } } diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index bcfdc7145d11..99756918bc5e 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -753,7 +753,7 @@ ENDIF() install(DIRECTORY "${TORCH_SRC_DIR}/csrc" DESTINATION ${TORCH_INSTALL_INCLUDE_DIR}/torch FILES_MATCHING PATTERN "*.h") - install(FILES "${TORCH_SRC_DIR}/script.h" "${TORCH_SRC_DIR}/extension.h" "${TORCH_SRC_DIR}/custom_class.h" + install(FILES "${TORCH_SRC_DIR}/script.h" "${TORCH_SRC_DIR}/extension.h" "${TORCH_SRC_DIR}/custom_class.h" "${TORCH_SRC_DIR}/custom_class_detail.h" DESTINATION ${TORCH_INSTALL_INCLUDE_DIR}/torch) diff --git a/test/cpp/jit/test_custom_class.cpp b/test/cpp/jit/test_custom_class.cpp index f1b46b54e5ef..cffc987c300a 100644 --- a/test/cpp/jit/test_custom_class.cpp +++ b/test/cpp/jit/test_custom_class.cpp @@ -55,19 +55,16 @@ struct Stack : torch::jit::CustomClassHolder { } } - std::vector __getstate__() const { - return stack_; - } - - void __setstate__(std::vector state) { - stack_ = std::move(state); - } - std::tuple return_a_tuple() const { return std::make_tuple(1337.0f, 123); } }; +struct PickleTester : torch::jit::CustomClassHolder { + PickleTester(std::vector vals) : vals(std::move(vals)) {} + std::vector vals; +}; + static auto test = torch::jit::class_("_TorchScriptTesting_Foo") .def(torch::jit::init()) // .def(torch::jit::init<>()) @@ -83,8 +80,14 @@ static auto testStack = .def("pop", &Stack::pop) .def("clone", &Stack::clone) .def("merge", &Stack::merge) - .def("__getstate__", &Stack::__getstate__) - .def("__setstate__", &Stack::__setstate__) + .def_pickle( + [](const c10::intrusive_ptr>& self) { + return self->stack_; + }, + [](std::vector state) { // __setstate__ + return c10::make_intrusive>( + std::vector{"i", "was", "deserialized"}); + }) .def("return_a_tuple", &Stack::return_a_tuple) .def( "top", @@ -95,6 +98,28 @@ static auto testStack = // take an intrusive_ptr as the first argument. // .def("foo", [](int64_t a) -> int64_t{ return 3;}); // clang-format on + +static auto testPickle = + torch::jit::class_("_TorchScriptTesting_PickleTester") + .def(torch::jit::init>()) + .def_pickle( + [](c10::intrusive_ptr self) { // __getstate__ + return std::vector{1, 3, 3, 7}; + }, + [](std::vector state) { // __setstate__ + return c10::make_intrusive(std::move(state)); + }) + .def( + "top", + [](const c10::intrusive_ptr& self) { + return self->vals.back(); + }) + .def("pop", [](const c10::intrusive_ptr& self) { + auto val = self->vals.back(); + self->vals.pop_back(); + return val; + }); + } // namespace } // namespace jit diff --git a/test/test_jit.py b/test/test_jit.py index 26156d198bf9..b49bf2dbdcd0 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -4932,22 +4932,6 @@ def foo(x): self.assertEqual(out.pop(), "hi") self.assertEqual(out.pop(), "mom") - @skipIfRocm - @unittest.skipIf(IS_WINDOWS, "TODO: Fix this test case") - def test_torchbind_getstate_setstate(self): - def f(): - val = torch.classes._TorchScriptTesting_StackString(["3", "5"]) - s = val.__getstate__() - # TODO: sort out whether unpickler should call __new__ or __init__ - val2 = torch.classes._TorchScriptTesting_StackString(["0", "0"]) - val2.__setstate__(s) - return val.pop(), val2.pop() - ret = f() - self.assertEqual(ret[0], ret[1]) - - ret = torch.jit.script(f)() - self.assertEqual(ret[0], ret[1]) - @skipIfRocm @unittest.skipIf(IS_WINDOWS, "TODO: Fix this test case") def test_torchbind_return_tuple(self): @@ -5000,6 +4984,47 @@ def foo(x): scripted = torch.jit.script(foo) self.assertEqual(scripted(), "mom") + @skipIfRocm + @unittest.skipIf(IS_WINDOWS, "TODO: Fix this test case") + def test_torchbind_class_attribute(self): + class FooBar1234(torch.nn.Module): + def __init__(self): + super(FooBar1234, self).__init__() + self.f = torch.classes._TorchScriptTesting_StackString(["3", "4"]) + + def forward(self): + return self.f.top() + + inst = FooBar1234() + scripted = torch.jit.script(inst) + eic = self.getExportImportCopy(scripted) + assert eic() == "deserialized" + for expected in ["deserialized", "was", "i"]: + assert eic.f.pop() == expected + + @skipIfRocm + @unittest.skipIf(IS_WINDOWS, "TODO: Fix this test case") + def test_torchbind_getstate(self): + class FooBar4321(torch.nn.Module): + def __init__(self): + super(FooBar4321, self).__init__() + self.f = torch.classes._TorchScriptTesting_PickleTester([3, 4]) + + def forward(self): + return self.f.top() + + inst = FooBar4321() + scripted = torch.jit.script(inst) + eic = self.getExportImportCopy(scripted) + # NB: we expect the values {7, 3, 3, 1} as __getstate__ is defined to + # return {1, 3, 3, 7}. I tried to make this actually depend on the + # values at instantiation in the test with some transformation, but + # because it seems we serialize/deserialize multiple times, that + # transformation isn't as you would it expect it to be. + assert eic() == 7 + for expected in [7, 3, 3, 1]: + assert eic.f.pop() == expected + def test_jitter_bug(self): @torch.jit.script def fn2(input, kernel_size): diff --git a/torch/csrc/jit/pickler.cpp b/torch/csrc/jit/pickler.cpp index e98799e476a7..9d9e564f1a74 100644 --- a/torch/csrc/jit/pickler.cpp +++ b/torch/csrc/jit/pickler.cpp @@ -115,6 +115,17 @@ void Pickler::pushIValueImpl(const IValue& ivalue) { push(PickleOpCode::BUILD); } else if (ivalue.isDevice()) { pushDevice(ivalue); + } else if (ivalue.isCapsule()) { + std::stringstream err; + err << "Cannot serialize custom bound C++ class"; + if (memorized_class_types_ && memorized_class_types_->size()) { + if (auto qualname = memorized_class_types_->back()->name()) { + err << " " << qualname->qualifiedName(); + } + } + err << ". Please define serialization methods via torch::jit::pickle_ for " + "this class."; + AT_ERROR(err.str()); } else { AT_ERROR("Unknown IValue type for pickling: ", ivalue.tagKind()); } diff --git a/torch/csrc/jit/pybind_utils.h b/torch/csrc/jit/pybind_utils.h index 15ee12c1dce8..7929eb25d616 100644 --- a/torch/csrc/jit/pybind_utils.h +++ b/torch/csrc/jit/pybind_utils.h @@ -170,6 +170,11 @@ inline InferredType tryToInferType(py::handle input) { } } + if (py::isinstance(input)) { + auto object = py::cast(input); + return InferredType(object.type()); + } + // Try container types return tryToInferContainerType(input); } diff --git a/torch/custom_class.h b/torch/custom_class.h index c052063b0850..20083679b771 100644 --- a/torch/custom_class.h +++ b/torch/custom_class.h @@ -15,6 +15,7 @@ #include #include #include +#include #include #include @@ -22,25 +23,10 @@ namespace torch { namespace jit { -TORCH_API std::vector& registeredOps(); -TORCH_API std::shared_ptr& classCU(); - -namespace detail { -template -struct types { - using type = types; -}; -template -struct args; -template -struct args : types {}; -template -struct args : types {}; -template -using args_t = typename args::type; -} // namespace detail template -detail::types init() { return detail::types{}; } +detail::types init() { + return detail::types{}; +} // To bind custom classes into Torchscript, use an API very similar to Pybind's. // Currently exposes one class `torch::jit::class_` and 2 methods. @@ -68,10 +54,6 @@ class class_ { public: class_(std::string className_) : className(std::move(className_)) { - // Currently we register everything as a python class just for convenience. - // We'll want to remove this at some point to get rid of the python - // dependency. It would require significant changes to class registration, - // (I think)? qualClassName = topModule + "." + parentModule + "." + className; // We currently represent custom classes as torchscript classes with a @@ -127,6 +109,72 @@ class class_ { return *this; } + // Pickle + template + class_& def_pickle(GetStateFn&& get_state, SetStateFn&& set_state) { + static_assert( + c10::guts::is_stateless_lambda>::value && + c10::guts::is_stateless_lambda>::value, + "torch::jit::pickle_ currently only supports lambdas as " + "__getstate__ and __setstate__ arguments."); + def("__getstate__", std::forward(get_state)); + + // __setstate__ needs to be registered with some custom handling: + // We need to wrap the invocation of of the user-provided function + // such that we take the return value (i.e. c10::intrusive_ptr) + // and assign it to the `capsule` attribute. + using SetStateTraits = + c10::guts::infer_function_traits_t>; + using SetStateArg = typename c10::guts::typelist::head_t< + typename SetStateTraits::parameter_types>; + auto setstate_wrapper = [set_state = std::move(set_state)]( + c10::tagged_capsule self, + SetStateArg&& arg) { + c10::intrusive_ptr classObj = + at::guts::invoke(set_state, std::forward(arg)); + auto genericPtr = + c10::static_intrusive_pointer_cast( + classObj); + auto capsule = IValue(genericPtr); + auto object = self.ivalue.toObject(); + object->setSlot(0, capsule); + }; + defineMethod("__setstate__", std::move(setstate_wrapper)); + + // type validation + auto getstate_schema = classTypePtr->getMethod("__getstate__")->getSchema(); + auto format_getstate_schema = [&getstate_schema]() { + std::stringstream ss; + ss << getstate_schema; + return ss.str(); + }; + TORCH_CHECK( + getstate_schema.arguments().size() == 1, + "__getstate__ should take exactly one argument: self. Got: ", + format_getstate_schema()); + auto first_arg_type = getstate_schema.arguments().at(0).type(); + TORCH_CHECK( + *first_arg_type == *classTypePtr, + "self argument of __getstate__ must be the custom class type. Got ", + first_arg_type->python_str()); + TORCH_CHECK( + getstate_schema.returns().size() == 1, + "__getstate__ should return exactly one value for serialization. Got: ", + format_getstate_schema()); + auto ser_type = getstate_schema.returns().at(0).type(); + auto setstate_schema = classTypePtr->getMethod("__setstate__")->getSchema(); + auto arg_type = setstate_schema.arguments().at(1).type(); + TORCH_CHECK( + (*arg_type == *ser_type), + "__setstate__'s argument should be the same type as the " + "return value of __getstate__. Got ", + arg_type->python_str(), + " but expected ", + ser_type->python_str()); + + return *this; + } + private: template void defineMethod(std::string name, Func func) { diff --git a/torch/custom_class_detail.h b/torch/custom_class_detail.h new file mode 100644 index 000000000000..c8e5e0b66d1a --- /dev/null +++ b/torch/custom_class_detail.h @@ -0,0 +1,37 @@ +#pragma once + +#include +#include + +namespace torch { +namespace jit { + +namespace detail { + +// Argument type utilities +template +struct types { + using type = types; +}; + +template +struct args; + +// Method +template +struct args : types {}; + +// Const method +template +struct args : types {}; + +template +using args_t = typename args::type; + +} // namespace detail + +TORCH_API std::vector& registeredOps(); +TORCH_API std::shared_ptr& classCU(); + +} // namespace jit +} // namespace torch