#pragma once #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace torch { /// This function is used in conjunction with `class_::def()` to register /// a constructor for a given C++ class type. For example, /// `torch::init()` would register a two-argument constructor /// taking an `int` and a `std::string` as argument. template detail::types init() { return detail::types{}; } /// Entry point for custom C++ class registration. To register a C++ class /// in PyTorch, instantiate `torch::class_` with the desired class as the /// template parameter. Typically, this instantiation should be done in /// the initialization of a global variable, so that the class will be /// made available on dynamic library loading without any additional API /// calls needed. For example, to register a class named Foo, you might /// create a global variable like so: /// /// static auto register_foo = torch::class_("myclasses", "Foo") /// .def("myMethod", &Foo::myMethod) /// .def("lambdaMethod", [](const c10::intrusive_ptr& self) { /// // Do something with `self` /// }); /// /// In addition to registering the class, this registration also chains /// `def()` calls to register methods. `myMethod()` is registered with /// a pointer to the Foo class's `myMethod()` method. `lambdaMethod()` /// is registered with a C++ lambda expression. template class class_ { static_assert(std::is_base_of::value, "torch::class_ requires T to inherit from CustomClassHolder"); public: /// This constructor actually registers the class type. /// String argument `namespaceName` is an identifier for the /// namespace you would like this class to appear in. /// String argument `className` is the name you would like to /// see this class exposed as in Python and TorchScript. For example, if /// you pass `foo` as the namespace name and `Bar` as the className, the /// class will appear as `torch.classes.foo.Bar` in Python and TorchScript explicit class_(const std::string& namespaceName, const std::string& className) { detail::checkValidIdent(namespaceName, "Namespace name"); detail::checkValidIdent(className, "Class name"); qualClassName = std::string("__torch__.torch.classes.") + namespaceName + "." + className; classTypePtr = at::ClassType::create( c10::QualifiedName(qualClassName), std::weak_ptr()); classTypePtr->addAttribute("capsule", at::CapsuleType::get()); c10::getCustomClassTypeMap().insert( {typeid(c10::intrusive_ptr).name(), classTypePtr}); c10::getCustomClassTypeMap().insert( {typeid(c10::tagged_capsule).name(), classTypePtr}); registerCustomClass(classTypePtr); } /// def() can be used in conjunction with `torch::init()` to register /// a constructor for a given C++ class type. For example, passing /// `torch::init()` would register a two-argument constructor /// taking an `int` and a `std::string` as argument. template class_& def(detail::types) { // Used in combination with // torch::init<...>() auto func = [](c10::tagged_capsule self, Types... args) { auto classObj = c10::make_intrusive(args...); auto object = self.ivalue.toObject(); object->setSlot(0, c10::IValue::make_capsule(std::move(classObj))); }; defineMethod("__init__", std::move(func)); return *this; } /// This is the normal method registration API. `name` is the name that /// the method will be made accessible by in Python and TorchScript. /// `f` is a callable object that defines the method. Typically `f` /// will either be a pointer to a method on `CurClass`, or a lambda /// expression that takes a `c10::intrusive_ptr` as the first /// argument (emulating a `this` argument in a C++ method.) /// /// Examples: /// /// // Exposes method `foo` on C++ class `Foo` as `call_foo()` in /// // Python and TorchScript /// .def("call_foo", &Foo::foo) /// /// // Exposes the given lambda expression as method `call_lambda()` /// // in Python and TorchScript. /// .def("call_lambda", [](const c10::intrusive_ptr& self) { /// // do something /// }) template class_& def(std::string name, Func f) { auto wrapped_f = detail::wrap_func(std::move(f)); defineMethod(std::move(name), std::move(wrapped_f)); return *this; } /// This is an unsafe method registration API added for adding custom JIT backend support via custom /// C++ classes. It is not for general purpose use. class_& _def_unboxed(std::string name, std::function func, c10::FunctionSchema schema) { auto qualMethodName = qualClassName + "." + name; auto method = std::make_unique( qualMethodName, std::move(schema), std::move(func)); classTypePtr->addMethod(method.get()); registerCustomClassMethod(std::move(method)); return *this; } /// def_pickle() is used to define exactly what state gets serialized /// or deserialized for a given instance of a custom C++ class in /// Python or TorchScript. This protocol is equivalent to the Pickle /// concept of `__getstate__` and `__setstate__` from Python /// (https://docs.python.org/2/library/pickle.html#object.__getstate__) /// /// Currently, both the `get_state` and `set_state` callables must be /// C++ lambda expressions. They should have the following signatures, /// where `CurClass` is the class you're registering and `T` is some object /// that encapsulates the state of the object. /// /// __getstate__(intrusive_ptr) -> T /// __setstate__(T) -> intrusive_ptr /// /// `T` must be an object that is convertable to IValue by the same rules /// for custom op/method registration. /// /// Example: /// /// .def_pickle( /// // __getstate__ /// [](const c10::intrusive_ptr>& self) { /// return self->stack_; /// }, /// [](std::vector state) { // __setstate__ /// return c10::make_intrusive>( /// std::vector{"i", "was", "deserialized"}); /// }) template class_& def_pickle(GetStateFn&& get_state, SetStateFn&& set_state) { static_assert( c10::guts::is_stateless_lambda>::value && c10::guts::is_stateless_lambda>::value, "def_pickle() currently only supports lambdas as " "__getstate__ and __setstate__ arguments."); def("__getstate__", std::forward(get_state)); // __setstate__ needs to be registered with some custom handling: // We need to wrap the invocation of of the user-provided function // such that we take the return value (i.e. c10::intrusive_ptr) // and assign it to the `capsule` attribute. using SetStateTraits = c10::guts::infer_function_traits_t>; using SetStateArg = typename c10::guts::typelist::head_t< typename SetStateTraits::parameter_types>; auto setstate_wrapper = [set_state = std::move(set_state)]( c10::tagged_capsule self, SetStateArg&& arg) { c10::intrusive_ptr classObj = at::guts::invoke(set_state, std::forward(arg)); auto object = self.ivalue.toObject(); object->setSlot(0, c10::IValue::make_capsule(classObj)); }; defineMethod( "__setstate__", detail::wrap_func( std::move(setstate_wrapper))); // type validation auto getstate_schema = classTypePtr->getMethod("__getstate__").getSchema(); auto format_getstate_schema = [&getstate_schema]() { std::stringstream ss; ss << getstate_schema; return ss.str(); }; TORCH_CHECK( getstate_schema.arguments().size() == 1, "__getstate__ should take exactly one argument: self. Got: ", format_getstate_schema()); auto first_arg_type = getstate_schema.arguments().at(0).type(); TORCH_CHECK( *first_arg_type == *classTypePtr, "self argument of __getstate__ must be the custom class type. Got ", first_arg_type->repr_str()); TORCH_CHECK( getstate_schema.returns().size() == 1, "__getstate__ should return exactly one value for serialization. Got: ", format_getstate_schema()); auto ser_type = getstate_schema.returns().at(0).type(); auto setstate_schema = classTypePtr->getMethod("__setstate__").getSchema(); auto arg_type = setstate_schema.arguments().at(1).type(); TORCH_CHECK( (*arg_type == *ser_type), "__setstate__'s argument should be the same type as the " "return value of __getstate__. Got ", arg_type->repr_str(), " but expected ", ser_type->repr_str()); return *this; } private: template void defineMethod(std::string name, Func func) { auto qualMethodName = qualClassName + "." + name; auto schema = c10::inferFunctionSchemaSingleReturn(std::move(name), ""); auto wrapped_func = [func = std::move(func)](jit::Stack& stack) mutable -> void { // TODO: we need to figure out how to profile calls to custom functions // like this! Currently can't do it because the profiler stuff is in // libtorch and not ATen using RetType = typename c10::guts::infer_function_traits_t::return_type; detail::BoxedProxy()(stack, func); }; auto method = std::make_unique( qualMethodName, std::move(schema), std::move(wrapped_func)); // Register the method here to keep the Method alive. // ClassTypes do not hold ownership of their methods (normally it // those are held by the CompilationUnit), so we need a proxy for // that behavior here. classTypePtr->addMethod(method.get()); registerCustomClassMethod(std::move(method)); } std::string qualClassName; at::ClassTypePtr classTypePtr; }; /// make_custom_class() is a convenient way to create an instance of a registered /// custom class and wrap it in an IValue, for example when you want to pass the /// object to TorchScript. Its syntax is equivalent to APIs like `std::make_shared<>` /// or `c10::make_intrusive<>`. /// /// For example, if you have a custom C++ class that can be constructed from an `int` /// and `std::string`, you might use this API like so: /// /// IValue custom_class_iv = torch::make_custom_class(3, "foobarbaz"); template c10::IValue make_custom_class(CtorArgs&&... args) { if (!c10::isCustomClassRegistered>()) { throw c10::Error( "Trying to instantiate a class that isn't a registered custom class.", ""); } auto userClassInstance = c10::make_intrusive(std::forward(args)...); return c10::IValue(std::move(userClassInstance)); } // jit namespace for backward-compatibility // We previously defined everything in torch::jit but moved it out to // better reflect that these features are not limited only to TorchScript namespace jit { using ::torch::getCustomClass; using ::torch::isCustomClass; using ::torch::init; using ::torch::class_; } // namespace jit template inline class_ Library::class_(const std::string& className) { TORCH_CHECK(kind_ == DEF || kind_ == FRAGMENT, "class_(\"", className, "\"): Cannot define a class inside of a TORCH_LIBRARY_IMPL block. " "All class_()s should be placed in the (unique) TORCH_LIBRARY block for their namespace. " "(Error occurred at ", file_, ":", line_, ")"); TORCH_INTERNAL_ASSERT(ns_.has_value(), file_, ":", line_); return torch::class_(*ns_, className); } }