mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Support doc_string for TorchBind custom classes (#46576)
Summary: With this PR, users can optionally provide a "doc_string" to describe a class or its method. doc_string for TorchBind classes and methods are stored as `doc_string` properties in `Function` and `ScriptClass`. These `dos_string` properties are then exposed in Python layer via PyBind for doc generation. Fixes https://github.com/pytorch/pytorch/issues/46047 Pull Request resolved: https://github.com/pytorch/pytorch/pull/46576 Reviewed By: wanchaol Differential Revision: D24440636 Pulled By: gmagogsfm fbshipit-source-id: bfa9b270a6c2d8bc769a88fad6be939cc6310412
This commit is contained in:
committed by
Facebook GitHub Bot
parent
7d4c1a5ab0
commit
f9b9430152
@ -10,13 +10,19 @@ struct BuiltinOpFunction : public Function {
|
||||
BuiltinOpFunction(
|
||||
c10::QualifiedName qualname,
|
||||
c10::FunctionSchema schema,
|
||||
std::function<void(Stack&)> callable)
|
||||
std::function<void(Stack&)> callable,
|
||||
std::string doc_string = "")
|
||||
: name_(std::move(qualname)),
|
||||
callable_(std::move(callable)),
|
||||
schema_(std::move(schema)) {
|
||||
schema_(std::move(schema)),
|
||||
doc_string_(std::move(doc_string)) {
|
||||
TORCH_INTERNAL_ASSERT(schema_.returns().size() == 1);
|
||||
}
|
||||
|
||||
const std::string& doc_string() const override {
|
||||
return doc_string_;
|
||||
}
|
||||
|
||||
bool isGraphFunction() const override {
|
||||
return false;
|
||||
}
|
||||
@ -110,6 +116,8 @@ struct BuiltinOpFunction : public Function {
|
||||
std::function<void(Stack&)> callable_;
|
||||
|
||||
c10::FunctionSchema schema_;
|
||||
|
||||
std::string doc_string_;
|
||||
};
|
||||
|
||||
} // namespace jit
|
||||
|
@ -25,6 +25,11 @@ TORCH_API void preoptimizeGraph(std::shared_ptr<Graph>& graph);
|
||||
// execution of the function. Method is a wrapper around an
|
||||
// underlying Function that also provides a `self` object.
|
||||
struct TORCH_API Function {
|
||||
virtual const std::string& doc_string() const {
|
||||
static const std::string no_doc_string = "";
|
||||
return no_doc_string;
|
||||
}
|
||||
|
||||
virtual bool isGraphFunction() const = 0;
|
||||
|
||||
virtual void run(Stack& stack) = 0;
|
||||
|
@ -1989,7 +1989,8 @@ struct CAFFE2_API ClassType : public NamedType {
|
||||
static ClassTypePtr create(
|
||||
c10::optional<QualifiedName> qualifiedName,
|
||||
std::weak_ptr<CompilationUnit> cu,
|
||||
bool is_module = false);
|
||||
bool is_module = false,
|
||||
std::string doc_string = "");
|
||||
|
||||
bool operator==(const Type& rhs) const override {
|
||||
if (auto user_rhs = rhs.cast<ClassType>()) {
|
||||
@ -2175,6 +2176,9 @@ struct CAFFE2_API ClassType : public NamedType {
|
||||
return constantNames_[slot];
|
||||
}
|
||||
|
||||
const std::string& doc_string() const {
|
||||
return doc_string_;
|
||||
}
|
||||
|
||||
IValue getConstant(const std::string& name) const;
|
||||
|
||||
@ -2271,7 +2275,8 @@ struct CAFFE2_API ClassType : public NamedType {
|
||||
ClassType(
|
||||
c10::optional<QualifiedName> name,
|
||||
std::weak_ptr<CompilationUnit> cu,
|
||||
bool is_module);
|
||||
bool is_module,
|
||||
std::string doc_string);
|
||||
|
||||
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
|
||||
const auto& n = name().value();
|
||||
@ -2306,6 +2311,9 @@ struct CAFFE2_API ClassType : public NamedType {
|
||||
std::vector<Property> properties_;
|
||||
|
||||
bool isModule_ = false;
|
||||
|
||||
// Doc string of class.
|
||||
std::string doc_string_ = "";
|
||||
};
|
||||
|
||||
struct InterfaceType;
|
||||
|
@ -1211,19 +1211,21 @@ InterfaceType::~InterfaceType() = default;
|
||||
ClassTypePtr ClassType::create(
|
||||
c10::optional<QualifiedName> qualifiedName,
|
||||
std::weak_ptr<CompilationUnit> cu,
|
||||
bool is_module) {
|
||||
bool is_module,
|
||||
std::string doc_string) {
|
||||
return ClassTypePtr(
|
||||
new ClassType(std::move(qualifiedName), std::move(cu), is_module));
|
||||
new ClassType(std::move(qualifiedName), std::move(cu), is_module, std::move(doc_string)));
|
||||
}
|
||||
|
||||
ClassType::ClassType(
|
||||
c10::optional<QualifiedName> name,
|
||||
std::weak_ptr<CompilationUnit> cu,
|
||||
bool is_module = false)
|
||||
bool is_module = false,
|
||||
std::string doc_string = "")
|
||||
: NamedType(TypeKind::ClassType, std::move(name)),
|
||||
compilation_unit_(std::move(cu)),
|
||||
isModule_(is_module) {
|
||||
}
|
||||
isModule_(is_module),
|
||||
doc_string_(std::move(doc_string)) {}
|
||||
|
||||
const std::vector<torch::jit::Function*>& ClassType::methods() const {
|
||||
return methods_;
|
||||
|
@ -44,5 +44,47 @@ TEST(CustomClassTest, TorchbindIValueAPI) {
|
||||
test_with_obj(new_stack_ivalue, "boo");
|
||||
}
|
||||
|
||||
class TorchBindTestClass : public torch::jit::CustomClassHolder {
|
||||
public:
|
||||
std::string get() {
|
||||
return "Hello, I am your test custom class";
|
||||
}
|
||||
};
|
||||
|
||||
constexpr char class_doc_string[] = R"(
|
||||
I am docstring for TorchBindTestClass
|
||||
Args:
|
||||
What is an argument? Oh never mind, I don't take any.
|
||||
|
||||
Return:
|
||||
How would I know? I am just a holder of some meaningless test methods.
|
||||
)";
|
||||
constexpr char method_doc_string[] =
|
||||
"I am docstring for TorchBindTestClass get_with_docstring method";
|
||||
|
||||
namespace {
|
||||
static auto reg =
|
||||
torch::class_<TorchBindTestClass>(
|
||||
"_TorchBindTest",
|
||||
"_TorchBindTestClass",
|
||||
class_doc_string)
|
||||
.def("get", &TorchBindTestClass::get)
|
||||
.def("get_with_docstring", &TorchBindTestClass::get, method_doc_string);
|
||||
|
||||
} // namespace
|
||||
|
||||
// Tests DocString is properly propagated when defining CustomClasses.
|
||||
TEST(CustomClassTest, TestDocString) {
|
||||
auto class_type = getCustomClass(
|
||||
"__torch__.torch.classes._TorchBindTest._TorchBindTestClass");
|
||||
AT_ASSERT(class_type);
|
||||
AT_ASSERT(class_type->doc_string() == class_doc_string);
|
||||
|
||||
AT_ASSERT(class_type->getMethod("get").doc_string().empty());
|
||||
AT_ASSERT(
|
||||
class_type->getMethod("get_with_docstring").doc_string() ==
|
||||
method_doc_string);
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
@ -28,7 +28,10 @@ void initPythonCustomClassBindings(PyObject* module) {
|
||||
auto m = py::handle(module).cast<py::module>();
|
||||
|
||||
py::class_<ScriptClass>(m, "ScriptClass")
|
||||
.def("__call__", &ScriptClass::__call__);
|
||||
.def("__call__", &ScriptClass::__call__)
|
||||
.def_property_readonly("__doc__", [](const ScriptClass& self) {
|
||||
return self.class_type_.type_->expect<ClassType>()->doc_string();
|
||||
});
|
||||
|
||||
// This function returns a ScriptClass that wraps the constructor
|
||||
// of the given class, specified by the qualified name passed in.
|
||||
|
@ -1187,9 +1187,13 @@ void initJitScriptBindings(PyObject* module) {
|
||||
"name",
|
||||
[](const StrongFunctionPtr& self) { return self.function_->name(); })
|
||||
.def_property_readonly(
|
||||
"qualified_name", [](const StrongFunctionPtr& self) {
|
||||
"qualified_name",
|
||||
[](const StrongFunctionPtr& self) {
|
||||
return self.function_->qualname().qualifiedName();
|
||||
});
|
||||
})
|
||||
.def_property_readonly("__doc__", [](const StrongFunctionPtr& self) {
|
||||
return self.function_->doc_string();
|
||||
});
|
||||
|
||||
py::class_<Method>(m, "ScriptMethod", py::dynamic_attr())
|
||||
.def(
|
||||
|
@ -58,14 +58,16 @@ class class_ {
|
||||
/// see this class exposed as in Python and TorchScript. For example, if
|
||||
/// you pass `foo` as the namespace name and `Bar` as the className, the
|
||||
/// class will appear as `torch.classes.foo.Bar` in Python and TorchScript
|
||||
explicit class_(const std::string& namespaceName, const std::string& className) {
|
||||
explicit class_(const std::string& namespaceName, const std::string& className, std::string doc_string = "") {
|
||||
detail::checkValidIdent(namespaceName, "Namespace name");
|
||||
detail::checkValidIdent(className, "Class name");
|
||||
qualClassName = std::string("__torch__.torch.classes.") + namespaceName + "." + className;
|
||||
|
||||
classTypePtr = at::ClassType::create(
|
||||
c10::QualifiedName(qualClassName),
|
||||
std::weak_ptr<jit::CompilationUnit>());
|
||||
std::weak_ptr<jit::CompilationUnit>(),
|
||||
/*is_module=*/false,
|
||||
std::move(doc_string));
|
||||
classTypePtr->addAttribute("capsule", at::CapsuleType::get());
|
||||
|
||||
c10::getCustomClassTypeMap().insert(
|
||||
@ -81,7 +83,7 @@ class class_ {
|
||||
/// `torch::init<int, std::string>()` would register a two-argument constructor
|
||||
/// taking an `int` and a `std::string` as argument.
|
||||
template <typename... Types>
|
||||
class_& def(detail::types<void, Types...>) { // Used in combination with
|
||||
class_& def(detail::types<void, Types...>, std::string doc_string = "") { // Used in combination with
|
||||
// torch::init<...>()
|
||||
auto func = [](c10::tagged_capsule<CurClass> self, Types... args) {
|
||||
auto classObj = c10::make_intrusive<CurClass>(args...);
|
||||
@ -89,7 +91,7 @@ class class_ {
|
||||
object->setSlot(0, c10::IValue::make_capsule(std::move(classObj)));
|
||||
};
|
||||
|
||||
defineMethod("__init__", std::move(func));
|
||||
defineMethod("__init__", std::move(func), std::move(doc_string));
|
||||
return *this;
|
||||
}
|
||||
|
||||
@ -112,18 +114,18 @@ class class_ {
|
||||
/// // do something
|
||||
/// })
|
||||
template <typename Func>
|
||||
class_& def(std::string name, Func f) {
|
||||
class_& def(std::string name, Func f, std::string doc_string = "") {
|
||||
auto wrapped_f = detail::wrap_func<CurClass, Func>(std::move(f));
|
||||
defineMethod(std::move(name), std::move(wrapped_f));
|
||||
defineMethod(std::move(name), std::move(wrapped_f), std::move(doc_string));
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// This is an unsafe method registration API added for adding custom JIT backend support via custom
|
||||
/// C++ classes. It is not for general purpose use.
|
||||
class_& _def_unboxed(std::string name, std::function<void(jit::Stack&)> func, c10::FunctionSchema schema) {
|
||||
class_& _def_unboxed(std::string name, std::function<void(jit::Stack&)> func, c10::FunctionSchema schema, std::string doc_string = "") {
|
||||
auto qualMethodName = qualClassName + "." + name;
|
||||
auto method = std::make_unique<jit::BuiltinOpFunction>(
|
||||
qualMethodName, std::move(schema), std::move(func));
|
||||
qualMethodName, std::move(schema), std::move(func), std::move(doc_string));
|
||||
classTypePtr->addMethod(method.get());
|
||||
registerCustomClassMethod(std::move(method));
|
||||
return *this;
|
||||
@ -228,7 +230,7 @@ class class_ {
|
||||
|
||||
private:
|
||||
template <typename Func>
|
||||
void defineMethod(std::string name, Func func) {
|
||||
void defineMethod(std::string name, Func func, std::string doc_string = "") {
|
||||
auto qualMethodName = qualClassName + "." + name;
|
||||
auto schema = c10::inferFunctionSchemaSingleReturn<Func>(std::move(name), "");
|
||||
|
||||
@ -241,7 +243,7 @@ class class_ {
|
||||
detail::BoxedProxy<RetType, Func>()(stack, func);
|
||||
};
|
||||
auto method = std::make_unique<jit::BuiltinOpFunction>(
|
||||
qualMethodName, std::move(schema), std::move(wrapped_func));
|
||||
qualMethodName, std::move(schema), std::move(wrapped_func), std::move(doc_string));
|
||||
|
||||
// Register the method here to keep the Method alive.
|
||||
// ClassTypes do not hold ownership of their methods (normally it
|
||||
|
Reference in New Issue
Block a user