mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[JIT] Add static method support for TorchBind (#51177)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/51177 **Summary** This commit adds support for static methods to TorchBind. Just like pybind, the API for declaring a static method is `def_static(...)`. A static method must be called on the class directly, and can be called both in Python as well as TorchScript. Support for static methods is implemented in a manner similar to that of instance methods. Registered static functions are wrapped in a layer of unboxing logic, their schemas are inferred using templates and metaprogramming, and they are added to the `ClassType` object corresponding to the TorchBind class on which they are registered. ScriptClass has been extended to support a `__getattr__` function so that static methods of TorchBind classes can be invoked in Python. The implementation of `__getattr__` returns `ScriptClassFunctionPtr`, a version of `StrongFunctionPtr` without a compilation unit (since the functions of a TorchBind class live inside the TorchBind registry). Within TorchScript, TorchBind static functions are desugared in `PythonClassValue::attr` by looking them up on the class type of the `PythonClassValue` instance. **Test Plan** This commit adds a unit test that tests a simple static method on a TorchBind class. Test Plan: Imported from OSS Reviewed By: pbelevich Differential Revision: D26356942 Pulled By: SplitInfinity fbshipit-source-id: 1b6a9bc2e5f3e22071ad78e331a0201fbbf7ab30
This commit is contained in:
committed by
Facebook GitHub Bot
parent
de4c9ecc35
commit
73de98204d
@ -153,6 +153,30 @@ class class_ {
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Method registration API for static methods.
|
||||
template <typename Func>
|
||||
class_& def_static(std::string name, Func func, std::string doc_string = "") {
|
||||
auto qualMethodName = qualClassName + "." + name;
|
||||
auto schema =
|
||||
c10::inferFunctionSchemaSingleReturn<Func>(std::move(name), "");
|
||||
|
||||
auto wrapped_func =
|
||||
[func = std::move(func)](jit::Stack& stack) mutable -> void {
|
||||
using RetType =
|
||||
typename c10::guts::infer_function_traits_t<Func>::return_type;
|
||||
detail::BoxedProxy<RetType, Func>()(stack, func);
|
||||
};
|
||||
auto method = std::make_unique<jit::BuiltinOpFunction>(
|
||||
qualMethodName,
|
||||
std::move(schema),
|
||||
std::move(wrapped_func),
|
||||
std::move(doc_string));
|
||||
|
||||
classTypePtr->addStaticMethod(method.get());
|
||||
registerCustomClassMethod(std::move(method));
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// This is an unsafe method registration API added for adding custom JIT backend support via custom
|
||||
/// C++ classes. It is not for general purpose use.
|
||||
class_& _def_unboxed(std::string name, std::function<void(jit::Stack&)> func, c10::FunctionSchema schema, std::string doc_string = "") {
|
||||
|
Reference in New Issue
Block a user