[torchbind] Add more comprehensive docscrings (#34906)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/34906

Test Plan: Imported from OSS

Differential Revision: D20496221

Pulled By: jamesr66a

fbshipit-source-id: 3863ec77324564f6f0f1c54b0cbd6c29d12f3c74
This commit is contained in:
James Reed
2020-03-17 20:35:58 -07:00
committed by Facebook GitHub Bot
parent 09a7788a2f
commit 130e720784

View File

@ -17,45 +17,55 @@
namespace torch {
// Given a qualified name (e.g. __torch__.torch.classes.Foo), return
// the ClassType pointer to the Type that describes that custom class,
// or nullptr if no class by that name was found.
TORCH_API at::ClassTypePtr getCustomClass(const std::string& name);
// Given an IValue, return true if the object contained in that IValue
// is a custom C++ class, otherwise return false.
TORCH_API bool isCustomClass(const c10::IValue& v);
// This function is used in conjunction with `class_::def()` to register
// a constructor for a given C++ class type. For example,
// torch::init<int, std::string>() would register a two-argument constructor
// taking an int and a std::string as argument.
template <class... Types>
detail::types<void, Types...> init() {
return detail::types<void, Types...>{};
}
// To bind custom classes into Torchscript, use an API very similar to Pybind's.
// Currently exposes one class `torch::class_<T>` and 2 methods.
// - Constructing `torch::class_<Foo>` registers `Foo` in Python and
// Torchscript, and puts it under `torch.classes.Foo` in Python.
// - torch::class_<Foo>.def("method1", &Foo::method1) does some template
// metaprogramming to introspect the function types and register the operator
// for use in Torchscript.
// - torch::class_<Foo>.def(torch::init<int64_t, int64_t>()) registers
// the Foo(int, int) constructor.
// see test/custom_operator/classes.cpp and
// test/custom_operator/test_custom_classes.py for example usages
// Entry point for custom C++ class registration. To register a C++ class
// in PyTorch, instantiate `torch::class_` with the desired class as the
// template parameter. Typically, this instantiation should be done in
// the initialization of a global variable, so that the class will be
// made available on dynamic library loading without any additional API
// calls needed. For example, to register a class named Foo, you might
// create a global variable like so:
//
// static auto register_foo = torch::class_<Foo>("Foo")
// .def("myMethod", &Foo::myMethod)
// .def("lambdaMethod", [](const c10::intrusive_ptr<Foo>& self) {
// // Do something with `self`
// });
//
// In addition to registering the class, this registration also chains
// `def()` calls to register methods. `myMethod()` is registered with
// a pointer to the Foo class's `myMethod()` method. `lambdaMethod()`
// is registered with a C++ lambda expression.
template <class CurClass>
class class_ {
static_assert(std::is_base_of<CustomClassHolder, CurClass>::value,
"torch::class_<T> requires T to inherit from CustomClassHolder");
std::string className;
std::string qualClassName;
at::ClassTypePtr classTypePtr;
const std::string parentModule = "classes";
const std::string topModule = "__torch__.torch";
public:
// Constructor. String argument className_ is the name you would like to
// see this class exposed as in Python and TorchScript. For example, if
// you pass in "MyStack" here, the class will appear as
// `torch.classes.MyStack` in both Python and TorchScript.
class_(std::string className_) : className(std::move(className_)) {
qualClassName = topModule + "." + parentModule + "." + className;
// We currently represent custom classes as torchscript classes with a
// capsule attribute
classTypePtr = at::ClassType::create(
c10::QualifiedName(qualClassName),
std::weak_ptr<jit::CompilationUnit>());
@ -69,6 +79,10 @@ class class_ {
registerCustomClass(classTypePtr);
}
// def() can be used in conjunction with `torch::init()` to register
// a constructor for a given C++ class type. For example, passing
// torch::init<int, std::string>() would register a two-argument constructor
// taking an int and a std::string as argument.
template <typename... Types>
class_& def(detail::types<void, Types...>) { // Used in combination with
// torch::init<...>()
@ -81,6 +95,24 @@ class class_ {
defineMethod("__init__", std::move(func));
return *this;
}
// This is the normal method registration API. `name` is the name that
// the method will be made accessible by in Python and TorchScript.
// `f` is a callable object that defines the method. Typically `f`
// will either be a pointer to a method on `CurClass`, or a lambda
// expression that takes a c10::intrusive_ptr<CurClass> as the first
// argument (emulating a `this` argument in a C++ method.)
//
// Examples:
// // Exposes method `foo` on C++ class `Foo` as `call_foo()` in
// // Python and TorchScript
// .def("call_foo", &Foo::foo)
//
// // Exposes the given lambda expression as method `call_lambda()`
// // in Python and TorchScript.
// .def("call_lambda", [](const c10::intrusive_ptr<Foo>& self) {
// // do something
// })
template <typename Func>
class_& def(std::string name, Func f) {
auto wrapped_f = detail::wrap_func<CurClass, Func>(std::move(f));
@ -88,7 +120,22 @@ class class_ {
return *this;
}
// Pickle
// def_pickle() is used to define exactly what state gets serialized
// or deserialized for a given instance of a custom C++ class in
// Python or TorchScript. This protocol is equivalent to the Pickle
// concept of __getstate__ and __setstate__ from Python
// (https://docs.python.org/2/library/pickle.html#object.__getstate__)
//
// Currently, both the `get_state` and `set_state` callables must be
// C++ lambda expressions. They should have the following signatures,
// where CurClass is the class you're registering and T is some object
// that encapsulates the state of the object.
//
// __getstate__(intrusive_ptr<CurClass>) -> T
// __setstate__(T) -> intrusive_ptr<CurClass>
//
// T must be an object that is convertable to IValue by the same rules
// for custom op/method registration.
template <typename GetStateFn, typename SetStateFn>
class_& def_pickle(GetStateFn&& get_state, SetStateFn&& set_state) {
static_assert(
@ -177,8 +224,24 @@ class class_ {
registerCustomClassMethod(method);
classTypePtr->addMethod(method.get());
}
std::string className;
std::string qualClassName;
at::ClassTypePtr classTypePtr;
const std::string parentModule = "classes";
const std::string topModule = "__torch__.torch";
};
// make_custom_class() is a convenient way to create an instance of a registered
// custom class and wrap it in an IValue, for example when you want to pass the
// object to TorchScript. Its syntax is equivalent to APIs like std::make_shared<>
// or c10::make_intrusive<>.
//
// For example, if you have a custom C++ class that can be constructed from an int
// and std::string, you might use this API like so:
//
// IValue custom_class_iv = torch::make_custom_class<MyClass>(3, "foobarbaz");
template <typename CurClass, typename... CtorArgs>
c10::IValue make_custom_class(CtorArgs&&... args) {
if (!c10::isCustomClassRegistered<c10::intrusive_ptr<CurClass>>()) {