mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[torchbind] Add more comprehensive docscrings (#34906)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/34906 Test Plan: Imported from OSS Differential Revision: D20496221 Pulled By: jamesr66a fbshipit-source-id: 3863ec77324564f6f0f1c54b0cbd6c29d12f3c74
This commit is contained in:
committed by
Facebook GitHub Bot
parent
09a7788a2f
commit
130e720784
@ -17,45 +17,55 @@
|
||||
|
||||
namespace torch {
|
||||
|
||||
// Given a qualified name (e.g. __torch__.torch.classes.Foo), return
|
||||
// the ClassType pointer to the Type that describes that custom class,
|
||||
// or nullptr if no class by that name was found.
|
||||
TORCH_API at::ClassTypePtr getCustomClass(const std::string& name);
|
||||
|
||||
// Given an IValue, return true if the object contained in that IValue
|
||||
// is a custom C++ class, otherwise return false.
|
||||
TORCH_API bool isCustomClass(const c10::IValue& v);
|
||||
|
||||
// This function is used in conjunction with `class_::def()` to register
|
||||
// a constructor for a given C++ class type. For example,
|
||||
// torch::init<int, std::string>() would register a two-argument constructor
|
||||
// taking an int and a std::string as argument.
|
||||
template <class... Types>
|
||||
detail::types<void, Types...> init() {
|
||||
return detail::types<void, Types...>{};
|
||||
}
|
||||
|
||||
// To bind custom classes into Torchscript, use an API very similar to Pybind's.
|
||||
// Currently exposes one class `torch::class_<T>` and 2 methods.
|
||||
// - Constructing `torch::class_<Foo>` registers `Foo` in Python and
|
||||
// Torchscript, and puts it under `torch.classes.Foo` in Python.
|
||||
// - torch::class_<Foo>.def("method1", &Foo::method1) does some template
|
||||
// metaprogramming to introspect the function types and register the operator
|
||||
// for use in Torchscript.
|
||||
// - torch::class_<Foo>.def(torch::init<int64_t, int64_t>()) registers
|
||||
// the Foo(int, int) constructor.
|
||||
// see test/custom_operator/classes.cpp and
|
||||
// test/custom_operator/test_custom_classes.py for example usages
|
||||
|
||||
// Entry point for custom C++ class registration. To register a C++ class
|
||||
// in PyTorch, instantiate `torch::class_` with the desired class as the
|
||||
// template parameter. Typically, this instantiation should be done in
|
||||
// the initialization of a global variable, so that the class will be
|
||||
// made available on dynamic library loading without any additional API
|
||||
// calls needed. For example, to register a class named Foo, you might
|
||||
// create a global variable like so:
|
||||
//
|
||||
// static auto register_foo = torch::class_<Foo>("Foo")
|
||||
// .def("myMethod", &Foo::myMethod)
|
||||
// .def("lambdaMethod", [](const c10::intrusive_ptr<Foo>& self) {
|
||||
// // Do something with `self`
|
||||
// });
|
||||
//
|
||||
// In addition to registering the class, this registration also chains
|
||||
// `def()` calls to register methods. `myMethod()` is registered with
|
||||
// a pointer to the Foo class's `myMethod()` method. `lambdaMethod()`
|
||||
// is registered with a C++ lambda expression.
|
||||
template <class CurClass>
|
||||
class class_ {
|
||||
static_assert(std::is_base_of<CustomClassHolder, CurClass>::value,
|
||||
"torch::class_<T> requires T to inherit from CustomClassHolder");
|
||||
|
||||
std::string className;
|
||||
std::string qualClassName;
|
||||
at::ClassTypePtr classTypePtr;
|
||||
|
||||
const std::string parentModule = "classes";
|
||||
const std::string topModule = "__torch__.torch";
|
||||
|
||||
public:
|
||||
// Constructor. String argument className_ is the name you would like to
|
||||
// see this class exposed as in Python and TorchScript. For example, if
|
||||
// you pass in "MyStack" here, the class will appear as
|
||||
// `torch.classes.MyStack` in both Python and TorchScript.
|
||||
class_(std::string className_) : className(std::move(className_)) {
|
||||
qualClassName = topModule + "." + parentModule + "." + className;
|
||||
|
||||
// We currently represent custom classes as torchscript classes with a
|
||||
// capsule attribute
|
||||
classTypePtr = at::ClassType::create(
|
||||
c10::QualifiedName(qualClassName),
|
||||
std::weak_ptr<jit::CompilationUnit>());
|
||||
@ -69,6 +79,10 @@ class class_ {
|
||||
registerCustomClass(classTypePtr);
|
||||
}
|
||||
|
||||
// def() can be used in conjunction with `torch::init()` to register
|
||||
// a constructor for a given C++ class type. For example, passing
|
||||
// torch::init<int, std::string>() would register a two-argument constructor
|
||||
// taking an int and a std::string as argument.
|
||||
template <typename... Types>
|
||||
class_& def(detail::types<void, Types...>) { // Used in combination with
|
||||
// torch::init<...>()
|
||||
@ -81,6 +95,24 @@ class class_ {
|
||||
defineMethod("__init__", std::move(func));
|
||||
return *this;
|
||||
}
|
||||
|
||||
// This is the normal method registration API. `name` is the name that
|
||||
// the method will be made accessible by in Python and TorchScript.
|
||||
// `f` is a callable object that defines the method. Typically `f`
|
||||
// will either be a pointer to a method on `CurClass`, or a lambda
|
||||
// expression that takes a c10::intrusive_ptr<CurClass> as the first
|
||||
// argument (emulating a `this` argument in a C++ method.)
|
||||
//
|
||||
// Examples:
|
||||
// // Exposes method `foo` on C++ class `Foo` as `call_foo()` in
|
||||
// // Python and TorchScript
|
||||
// .def("call_foo", &Foo::foo)
|
||||
//
|
||||
// // Exposes the given lambda expression as method `call_lambda()`
|
||||
// // in Python and TorchScript.
|
||||
// .def("call_lambda", [](const c10::intrusive_ptr<Foo>& self) {
|
||||
// // do something
|
||||
// })
|
||||
template <typename Func>
|
||||
class_& def(std::string name, Func f) {
|
||||
auto wrapped_f = detail::wrap_func<CurClass, Func>(std::move(f));
|
||||
@ -88,7 +120,22 @@ class class_ {
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Pickle
|
||||
// def_pickle() is used to define exactly what state gets serialized
|
||||
// or deserialized for a given instance of a custom C++ class in
|
||||
// Python or TorchScript. This protocol is equivalent to the Pickle
|
||||
// concept of __getstate__ and __setstate__ from Python
|
||||
// (https://docs.python.org/2/library/pickle.html#object.__getstate__)
|
||||
//
|
||||
// Currently, both the `get_state` and `set_state` callables must be
|
||||
// C++ lambda expressions. They should have the following signatures,
|
||||
// where CurClass is the class you're registering and T is some object
|
||||
// that encapsulates the state of the object.
|
||||
//
|
||||
// __getstate__(intrusive_ptr<CurClass>) -> T
|
||||
// __setstate__(T) -> intrusive_ptr<CurClass>
|
||||
//
|
||||
// T must be an object that is convertable to IValue by the same rules
|
||||
// for custom op/method registration.
|
||||
template <typename GetStateFn, typename SetStateFn>
|
||||
class_& def_pickle(GetStateFn&& get_state, SetStateFn&& set_state) {
|
||||
static_assert(
|
||||
@ -177,8 +224,24 @@ class class_ {
|
||||
registerCustomClassMethod(method);
|
||||
classTypePtr->addMethod(method.get());
|
||||
}
|
||||
|
||||
std::string className;
|
||||
std::string qualClassName;
|
||||
at::ClassTypePtr classTypePtr;
|
||||
|
||||
const std::string parentModule = "classes";
|
||||
const std::string topModule = "__torch__.torch";
|
||||
};
|
||||
|
||||
// make_custom_class() is a convenient way to create an instance of a registered
|
||||
// custom class and wrap it in an IValue, for example when you want to pass the
|
||||
// object to TorchScript. Its syntax is equivalent to APIs like std::make_shared<>
|
||||
// or c10::make_intrusive<>.
|
||||
//
|
||||
// For example, if you have a custom C++ class that can be constructed from an int
|
||||
// and std::string, you might use this API like so:
|
||||
//
|
||||
// IValue custom_class_iv = torch::make_custom_class<MyClass>(3, "foobarbaz");
|
||||
template <typename CurClass, typename... CtorArgs>
|
||||
c10::IValue make_custom_class(CtorArgs&&... args) {
|
||||
if (!c10::isCustomClassRegistered<c10::intrusive_ptr<CurClass>>()) {
|
||||
|
Reference in New Issue
Block a user