mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	[JIT] pickle serialization for custom bound classes
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/32604 Test Plan: Imported from OSS Differential Revision: D19566633 fbshipit-source-id: 9387d3ff45cbd6ccde49ce190a52859481cc301c
This commit is contained in:
		
				
					committed by
					
						 Facebook Github Bot
						Facebook Github Bot
					
				
			
			
				
	
			
			
			
						parent
						
							34ccfba403
						
					
				
				
					commit
					465ebd58ba
				
			| @ -129,7 +129,5 @@ struct is_type_condition<C, std::enable_if_t<std::is_same<bool, std::remove_cv_t | ||||
|  */ | ||||
| template <class T> | ||||
| struct is_fundamental : std::is_fundamental<T> {}; | ||||
|  | ||||
|  | ||||
| } | ||||
| } | ||||
|  | ||||
| @ -753,7 +753,7 @@ ENDIF() | ||||
|   install(DIRECTORY "${TORCH_SRC_DIR}/csrc" | ||||
|     DESTINATION ${TORCH_INSTALL_INCLUDE_DIR}/torch | ||||
|     FILES_MATCHING PATTERN "*.h") | ||||
|   install(FILES "${TORCH_SRC_DIR}/script.h" "${TORCH_SRC_DIR}/extension.h" "${TORCH_SRC_DIR}/custom_class.h" | ||||
|   install(FILES "${TORCH_SRC_DIR}/script.h" "${TORCH_SRC_DIR}/extension.h" "${TORCH_SRC_DIR}/custom_class.h" "${TORCH_SRC_DIR}/custom_class_detail.h" | ||||
|     DESTINATION ${TORCH_INSTALL_INCLUDE_DIR}/torch) | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -55,19 +55,16 @@ struct Stack : torch::jit::CustomClassHolder { | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   std::vector<std::string> __getstate__() const { | ||||
|     return stack_; | ||||
|   } | ||||
|  | ||||
|   void __setstate__(std::vector<std::string> state) { | ||||
|     stack_ = std::move(state); | ||||
|   } | ||||
|  | ||||
|   std::tuple<double, int64_t> return_a_tuple() const { | ||||
|     return std::make_tuple(1337.0f, 123); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| struct PickleTester : torch::jit::CustomClassHolder { | ||||
|   PickleTester(std::vector<int64_t> vals) : vals(std::move(vals)) {} | ||||
|   std::vector<int64_t> vals; | ||||
| }; | ||||
|  | ||||
| static auto test = torch::jit::class_<Foo>("_TorchScriptTesting_Foo") | ||||
|                        .def(torch::jit::init<int64_t, int64_t>()) | ||||
|                        // .def(torch::jit::init<>()) | ||||
| @ -83,8 +80,14 @@ static auto testStack = | ||||
|         .def("pop", &Stack<std::string>::pop) | ||||
|         .def("clone", &Stack<std::string>::clone) | ||||
|         .def("merge", &Stack<std::string>::merge) | ||||
|         .def("__getstate__", &Stack<std::string>::__getstate__) | ||||
|         .def("__setstate__", &Stack<std::string>::__setstate__) | ||||
|         .def_pickle( | ||||
|             [](const c10::intrusive_ptr<Stack<std::string>>& self) { | ||||
|               return self->stack_; | ||||
|             }, | ||||
|             [](std::vector<std::string> state) { // __setstate__ | ||||
|               return c10::make_intrusive<Stack<std::string>>( | ||||
|                   std::vector<std::string>{"i", "was", "deserialized"}); | ||||
|             }) | ||||
|         .def("return_a_tuple", &Stack<std::string>::return_a_tuple) | ||||
|         .def( | ||||
|             "top", | ||||
| @ -95,6 +98,28 @@ static auto testStack = | ||||
|         // take an intrusive_ptr<Stack> as the first argument. | ||||
|         // .def("foo", [](int64_t a) -> int64_t{ return 3;}); | ||||
| // clang-format on | ||||
|  | ||||
| static auto testPickle = | ||||
|     torch::jit::class_<PickleTester>("_TorchScriptTesting_PickleTester") | ||||
|         .def(torch::jit::init<std::vector<int64_t>>()) | ||||
|         .def_pickle( | ||||
|             [](c10::intrusive_ptr<PickleTester> self) { // __getstate__ | ||||
|               return std::vector<int64_t>{1, 3, 3, 7}; | ||||
|             }, | ||||
|             [](std::vector<int64_t> state) { // __setstate__ | ||||
|               return c10::make_intrusive<PickleTester>(std::move(state)); | ||||
|             }) | ||||
|         .def( | ||||
|             "top", | ||||
|             [](const c10::intrusive_ptr<PickleTester>& self) { | ||||
|               return self->vals.back(); | ||||
|             }) | ||||
|         .def("pop", [](const c10::intrusive_ptr<PickleTester>& self) { | ||||
|           auto val = self->vals.back(); | ||||
|           self->vals.pop_back(); | ||||
|           return val; | ||||
|         }); | ||||
|  | ||||
| } // namespace | ||||
|  | ||||
| } // namespace jit | ||||
|  | ||||
| @ -4932,22 +4932,6 @@ def foo(x): | ||||
|         self.assertEqual(out.pop(), "hi") | ||||
|         self.assertEqual(out.pop(), "mom") | ||||
|  | ||||
|     @skipIfRocm | ||||
|     @unittest.skipIf(IS_WINDOWS, "TODO: Fix this test case") | ||||
|     def test_torchbind_getstate_setstate(self): | ||||
|         def f(): | ||||
|             val = torch.classes._TorchScriptTesting_StackString(["3", "5"]) | ||||
|             s = val.__getstate__() | ||||
|             # TODO: sort out whether unpickler should call __new__ or __init__ | ||||
|             val2 = torch.classes._TorchScriptTesting_StackString(["0", "0"]) | ||||
|             val2.__setstate__(s) | ||||
|             return val.pop(), val2.pop() | ||||
|         ret = f() | ||||
|         self.assertEqual(ret[0], ret[1]) | ||||
|  | ||||
|         ret = torch.jit.script(f)() | ||||
|         self.assertEqual(ret[0], ret[1]) | ||||
|  | ||||
|     @skipIfRocm | ||||
|     @unittest.skipIf(IS_WINDOWS, "TODO: Fix this test case") | ||||
|     def test_torchbind_return_tuple(self): | ||||
| @ -5000,6 +4984,47 @@ def foo(x): | ||||
|         scripted = torch.jit.script(foo) | ||||
|         self.assertEqual(scripted(), "mom") | ||||
|  | ||||
|     @skipIfRocm | ||||
|     @unittest.skipIf(IS_WINDOWS, "TODO: Fix this test case") | ||||
|     def test_torchbind_class_attribute(self): | ||||
|         class FooBar1234(torch.nn.Module): | ||||
|             def __init__(self): | ||||
|                 super(FooBar1234, self).__init__() | ||||
|                 self.f = torch.classes._TorchScriptTesting_StackString(["3", "4"]) | ||||
|  | ||||
|             def forward(self): | ||||
|                 return self.f.top() | ||||
|  | ||||
|         inst = FooBar1234() | ||||
|         scripted = torch.jit.script(inst) | ||||
|         eic = self.getExportImportCopy(scripted) | ||||
|         assert eic() == "deserialized" | ||||
|         for expected in ["deserialized", "was", "i"]: | ||||
|             assert eic.f.pop() == expected | ||||
|  | ||||
|     @skipIfRocm | ||||
|     @unittest.skipIf(IS_WINDOWS, "TODO: Fix this test case") | ||||
|     def test_torchbind_getstate(self): | ||||
|         class FooBar4321(torch.nn.Module): | ||||
|             def __init__(self): | ||||
|                 super(FooBar4321, self).__init__() | ||||
|                 self.f = torch.classes._TorchScriptTesting_PickleTester([3, 4]) | ||||
|  | ||||
|             def forward(self): | ||||
|                 return self.f.top() | ||||
|  | ||||
|         inst = FooBar4321() | ||||
|         scripted = torch.jit.script(inst) | ||||
|         eic = self.getExportImportCopy(scripted) | ||||
|         # NB: we expect the values {7, 3, 3, 1} as __getstate__ is defined to | ||||
|         # return {1, 3, 3, 7}. I tried to make this actually depend on the | ||||
|         # values at instantiation in the test with some transformation, but | ||||
|         # because it seems we serialize/deserialize multiple times, that | ||||
|         # transformation isn't as you would it expect it to be. | ||||
|         assert eic() == 7 | ||||
|         for expected in [7, 3, 3, 1]: | ||||
|             assert eic.f.pop() == expected | ||||
|  | ||||
|     def test_jitter_bug(self): | ||||
|         @torch.jit.script | ||||
|         def fn2(input, kernel_size): | ||||
|  | ||||
| @ -115,6 +115,17 @@ void Pickler::pushIValueImpl(const IValue& ivalue) { | ||||
|     push<PickleOpCode>(PickleOpCode::BUILD); | ||||
|   } else if (ivalue.isDevice()) { | ||||
|     pushDevice(ivalue); | ||||
|   } else if (ivalue.isCapsule()) { | ||||
|     std::stringstream err; | ||||
|     err << "Cannot serialize custom bound C++ class"; | ||||
|     if (memorized_class_types_ && memorized_class_types_->size()) { | ||||
|       if (auto qualname = memorized_class_types_->back()->name()) { | ||||
|         err << " " << qualname->qualifiedName(); | ||||
|       } | ||||
|     } | ||||
|     err << ". Please define serialization methods via torch::jit::pickle_ for " | ||||
|            "this class."; | ||||
|     AT_ERROR(err.str()); | ||||
|   } else { | ||||
|     AT_ERROR("Unknown IValue type for pickling: ", ivalue.tagKind()); | ||||
|   } | ||||
|  | ||||
| @ -170,6 +170,11 @@ inline InferredType tryToInferType(py::handle input) { | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   if (py::isinstance<script::Object>(input)) { | ||||
|     auto object = py::cast<script::Object>(input); | ||||
|     return InferredType(object.type()); | ||||
|   } | ||||
|  | ||||
|   // Try container types | ||||
|   return tryToInferContainerType(input); | ||||
| } | ||||
|  | ||||
| @ -15,6 +15,7 @@ | ||||
| #include <torch/csrc/jit/script/compilation_unit.h> | ||||
| #include <torch/csrc/jit/tracer.h> | ||||
| #include <torch/csrc/utils/variadic.h> | ||||
| #include <torch/custom_class_detail.h> | ||||
| #include <iostream> | ||||
| #include <sstream> | ||||
|  | ||||
| @ -22,25 +23,10 @@ | ||||
| namespace torch { | ||||
| namespace jit { | ||||
|  | ||||
| TORCH_API std::vector<c10::RegisterOperators>& registeredOps(); | ||||
| TORCH_API std::shared_ptr<script::CompilationUnit>& classCU(); | ||||
|  | ||||
| namespace detail { | ||||
| template <class R, class...> | ||||
| struct types { | ||||
|   using type = types; | ||||
| }; | ||||
| template <class Sig> | ||||
| struct args; | ||||
| template <class R, class CurClass, class... Args> | ||||
| struct args<R (CurClass::*)(Args...)> : types<R, Args...> {}; | ||||
| template <class R, class CurClass, class... Args> | ||||
| struct args<R (CurClass::*)(Args...) const> : types<R, Args...> {}; | ||||
| template <class Sig> | ||||
| using args_t = typename args<Sig>::type; | ||||
| } // namespace detail | ||||
| template <class... Types> | ||||
| detail::types<void, Types...> init() { return detail::types<void, Types...>{}; } | ||||
| detail::types<void, Types...> init() { | ||||
|   return detail::types<void, Types...>{}; | ||||
| } | ||||
|  | ||||
| // To bind custom classes into Torchscript, use an API very similar to Pybind's. | ||||
| // Currently exposes one class `torch::jit::class_<T>` and 2 methods. | ||||
| @ -68,10 +54,6 @@ class class_ { | ||||
|  | ||||
|  public: | ||||
|   class_(std::string className_) : className(std::move(className_)) { | ||||
|     // Currently we register everything as a python class just for convenience. | ||||
|     // We'll want to remove this at some point to get rid of the python | ||||
|     // dependency. It would require significant changes to class registration, | ||||
|     // (I think)? | ||||
|     qualClassName = topModule + "." + parentModule + "." + className; | ||||
|  | ||||
|     // We currently represent custom classes as torchscript classes with a | ||||
| @ -127,6 +109,72 @@ class class_ { | ||||
|     return *this; | ||||
|   } | ||||
|  | ||||
|   // Pickle | ||||
|   template <typename GetStateFn, typename SetStateFn> | ||||
|   class_& def_pickle(GetStateFn&& get_state, SetStateFn&& set_state) { | ||||
|     static_assert( | ||||
|         c10::guts::is_stateless_lambda<std::decay_t<GetStateFn>>::value && | ||||
|             c10::guts::is_stateless_lambda<std::decay_t<SetStateFn>>::value, | ||||
|         "torch::jit::pickle_ currently only supports lambdas as " | ||||
|         "__getstate__ and __setstate__ arguments."); | ||||
|     def("__getstate__", std::forward<GetStateFn>(get_state)); | ||||
|  | ||||
|     // __setstate__ needs to be registered with some custom handling: | ||||
|     // We need to wrap the invocation of of the user-provided function | ||||
|     // such that we take the return value (i.e. c10::intrusive_ptr<CurrClass>) | ||||
|     // and assign it to the `capsule` attribute. | ||||
|     using SetStateTraits = | ||||
|         c10::guts::infer_function_traits_t<std::decay_t<SetStateFn>>; | ||||
|     using SetStateArg = typename c10::guts::typelist::head_t< | ||||
|         typename SetStateTraits::parameter_types>; | ||||
|     auto setstate_wrapper = [set_state = std::move(set_state)]( | ||||
|                                 c10::tagged_capsule<CurClass> self, | ||||
|                                 SetStateArg&& arg) { | ||||
|       c10::intrusive_ptr<CurClass> classObj = | ||||
|           at::guts::invoke(set_state, std::forward<SetStateArg>(arg)); | ||||
|       auto genericPtr = | ||||
|           c10::static_intrusive_pointer_cast<torch::jit::CustomClassHolder>( | ||||
|               classObj); | ||||
|       auto capsule = IValue(genericPtr); | ||||
|       auto object = self.ivalue.toObject(); | ||||
|       object->setSlot(0, capsule); | ||||
|     }; | ||||
|     defineMethod<void>("__setstate__", std::move(setstate_wrapper)); | ||||
|  | ||||
|     // type validation | ||||
|     auto getstate_schema = classTypePtr->getMethod("__getstate__")->getSchema(); | ||||
|     auto format_getstate_schema = [&getstate_schema]() { | ||||
|       std::stringstream ss; | ||||
|       ss << getstate_schema; | ||||
|       return ss.str(); | ||||
|     }; | ||||
|     TORCH_CHECK( | ||||
|         getstate_schema.arguments().size() == 1, | ||||
|         "__getstate__ should take exactly one argument: self. Got: ", | ||||
|         format_getstate_schema()); | ||||
|     auto first_arg_type = getstate_schema.arguments().at(0).type(); | ||||
|     TORCH_CHECK( | ||||
|         *first_arg_type == *classTypePtr, | ||||
|         "self argument of __getstate__ must be the custom class type. Got ", | ||||
|         first_arg_type->python_str()); | ||||
|     TORCH_CHECK( | ||||
|         getstate_schema.returns().size() == 1, | ||||
|         "__getstate__ should return exactly one value for serialization. Got: ", | ||||
|         format_getstate_schema()); | ||||
|     auto ser_type = getstate_schema.returns().at(0).type(); | ||||
|     auto setstate_schema = classTypePtr->getMethod("__setstate__")->getSchema(); | ||||
|     auto arg_type = setstate_schema.arguments().at(1).type(); | ||||
|     TORCH_CHECK( | ||||
|         (*arg_type == *ser_type), | ||||
|         "__setstate__'s argument should be the same type as the " | ||||
|         "return value of __getstate__. Got ", | ||||
|         arg_type->python_str(), | ||||
|         " but expected ", | ||||
|         ser_type->python_str()); | ||||
|  | ||||
|     return *this; | ||||
|   } | ||||
|  | ||||
|  private: | ||||
|   template<typename R, typename Func> | ||||
|   void defineMethod(std::string name, Func func) { | ||||
|  | ||||
							
								
								
									
										37
									
								
								torch/custom_class_detail.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										37
									
								
								torch/custom_class_detail.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,37 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include <c10/util/Metaprogramming.h> | ||||
| #include <c10/util/TypeTraits.h> | ||||
|  | ||||
| namespace torch { | ||||
| namespace jit { | ||||
|  | ||||
| namespace detail { | ||||
|  | ||||
| // Argument type utilities | ||||
| template <class R, class...> | ||||
| struct types { | ||||
|   using type = types; | ||||
| }; | ||||
|  | ||||
| template <class Sig> | ||||
| struct args; | ||||
|  | ||||
| // Method | ||||
| template <class R, class CurClass, class... Args> | ||||
| struct args<R (CurClass::*)(Args...)> : types<R, Args...> {}; | ||||
|  | ||||
| // Const method | ||||
| template <class R, class CurClass, class... Args> | ||||
| struct args<R (CurClass::*)(Args...) const> : types<R, Args...> {}; | ||||
|  | ||||
| template <class Sig> | ||||
| using args_t = typename args<Sig>::type; | ||||
|  | ||||
| } // namespace detail | ||||
|  | ||||
| TORCH_API std::vector<c10::RegisterOperators>& registeredOps(); | ||||
| TORCH_API std::shared_ptr<script::CompilationUnit>& classCU(); | ||||
|  | ||||
| } // namespace jit | ||||
| } // namespace torch | ||||
		Reference in New Issue
	
	Block a user