[JIT] Add static method support for TorchBind (#51177)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51177

**Summary**
This commit adds support for static methods to TorchBind. Just like
pybind, the API for declaring a static method is `def_static(...)`. A
static method must be called on the class directly, and can be called
both in Python as well as TorchScript.

Support for static methods is implemented in a manner similar to that of
instance methods. Registered static functions are wrapped in a layer of
unboxing logic, their schemas are inferred using templates and
metaprogramming, and they are added to the `ClassType` object
corresponding to the TorchBind class on which they are registered.
ScriptClass has been extended to support a `__getattr__` function so
that static methods of TorchBind classes can be invoked in Python. The
implementation of `__getattr__` returns `ScriptClassFunctionPtr`, a
version of `StrongFunctionPtr` without a compilation unit (since the
functions of a TorchBind class live inside the TorchBind registry).
Within TorchScript, TorchBind static functions are desugared in
`PythonClassValue::attr` by looking them up on the class type of the
`PythonClassValue` instance.

**Test Plan**
This commit adds a unit test that tests a simple static method on a
TorchBind class.

Test Plan: Imported from OSS

Reviewed By: pbelevich

Differential Revision: D26356942

Pulled By: SplitInfinity

fbshipit-source-id: 1b6a9bc2e5f3e22071ad78e331a0201fbbf7ab30
This commit is contained in:
Meghan Lele
2021-02-13 19:35:38 -08:00
committed by Facebook GitHub Bot
parent de4c9ecc35
commit 73de98204d
8 changed files with 106 additions and 0 deletions

View File

@ -153,6 +153,30 @@ class class_ {
return *this;
}
/// Method registration API for static methods.
template <typename Func>
class_& def_static(std::string name, Func func, std::string doc_string = "") {
auto qualMethodName = qualClassName + "." + name;
auto schema =
c10::inferFunctionSchemaSingleReturn<Func>(std::move(name), "");
auto wrapped_func =
[func = std::move(func)](jit::Stack& stack) mutable -> void {
using RetType =
typename c10::guts::infer_function_traits_t<Func>::return_type;
detail::BoxedProxy<RetType, Func>()(stack, func);
};
auto method = std::make_unique<jit::BuiltinOpFunction>(
qualMethodName,
std::move(schema),
std::move(wrapped_func),
std::move(doc_string));
classTypePtr->addStaticMethod(method.get());
registerCustomClassMethod(std::move(method));
return *this;
}
/// This is an unsafe method registration API added for adding custom JIT backend support via custom
/// C++ classes. It is not for general purpose use.
class_& _def_unboxed(std::string name, std::function<void(jit::Stack&)> func, c10::FunctionSchema schema, std::string doc_string = "") {