[JIT] Introduce BuiltinOpFunction and integrate into torchbind (#34098)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34098

* #33900 [JIT] Move stuff out of class_type.cpp

Test Plan: Imported from OSS

Differential Revision: D20229166

Pulled By: jamesr66a

fbshipit-source-id: d658a63a5d6e372e675f35b8456adc8de82b49f3
This commit is contained in:
James Reed
2020-03-07 09:59:11 -08:00
committed by Facebook Github Bot
parent 60e8615a6d
commit 45a504dd2d
26 changed files with 359 additions and 188 deletions

View File

@ -297,8 +297,8 @@ namespace detail {
};
template<class FuncType>
std::unique_ptr<FunctionSchema> inferFunctionSchema_() {
return std::make_unique<FunctionSchema>(inferFunctionSchema<FuncType>("", ""));
std::unique_ptr<FunctionSchema> inferFunctionSchemaFlattenedReturns_() {
return std::make_unique<FunctionSchema>(inferFunctionSchemaFlattenedReturns<FuncType>("", ""));
}
template<class KernelFunctor>
@ -306,7 +306,7 @@ namespace detail {
public:
using func_type = typename c10::guts::infer_function_traits_t<KernelFunctor>::func_type;
std::unique_ptr<FunctionSchema> operator()() const {
return inferFunctionSchema_<func_type>();
return inferFunctionSchemaFlattenedReturns_<func_type>();
}
};
}

View File

@ -0,0 +1,105 @@
#pragma once
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/function.h>
namespace torch {
namespace jit {
struct BuiltinOpFunction : public Function {
BuiltinOpFunction(
c10::QualifiedName qualname,
c10::FunctionSchema schema,
std::function<void(Stack&)> callable)
: name_(std::move(qualname)),
callable_(std::move(callable)),
schema_(std::move(schema)) {
TORCH_INTERNAL_ASSERT(schema_.returns().size() == 1);
}
bool isGraphFunction() const override {
return false;
}
void run(Stack& stack) override {
callable_(stack);
}
void run(Stack&& stack) override {
callable_(stack);
}
at::IValue operator()(std::vector<at::IValue> stack, const Kwargs& kwargs)
override {
getSchema().checkAndNormalizeInputs(stack, kwargs);
callable_(stack);
return stack.front();
}
const c10::QualifiedName& qualname() const override {
return name_;
}
const std::string& name() const override {
return name_.name();
}
// if this isn't yet defined, run its method_creator function
void ensure_defined() override {
// nop
}
std::shared_ptr<Graph> graph() const override {
TORCH_INTERNAL_ASSERT(false , "BuiltinFunction had a graph requested "
"from it. This probably indicates that the JIT calling context needs a "
"special case on Function::isGraphFunction()");
}
std::shared_ptr<Graph> optimized_graph() const override {
TORCH_INTERNAL_ASSERT(false , "BuiltinFunction had a graph requested "
"from it. This probably indicates that the JIT calling context needs a "
"special case on Function::isGraphFunction()");
}
GraphExecutor& get_executor() override {
TORCH_INTERNAL_ASSERT(false , "BuiltinFunction had a GraphExecutor requested "
"from it. This probably indicates that the JIT calling context needs a "
"special case on Function::isGraphFunction()");
}
const c10::FunctionSchema& getSchema() const override {
return schema_;
}
size_t num_inputs() const override {
return schema_.arguments().size();
}
void check_single_output() override {
TORCH_CHECK(schema_.returns().size() == 1);
}
std::string pretty_print_schema() const override {
TORCH_INTERNAL_ASSERT(false);
std::stringstream ss;
ss << getSchema();
return ss.str();
}
Function& setSchema(c10::FunctionSchema schema) override {
schema_ = std::move(schema);
return *this;
}
~BuiltinOpFunction() {}
private:
c10::QualifiedName name_;
std::function<void(Stack&)> callable_;
c10::FunctionSchema schema_;
};
} // namespace jit
} // namespace torch

View File

@ -1,27 +1,41 @@
#include <torch/custom_class.h>
#include <ATen/core/jit_type.h>
#include <torch/csrc/jit/api/custom_class.h>
#include <atomic>
#include <unordered_map>
namespace torch {
namespace jit {
namespace {
at::TypePtr noOpGetter(const std::string& /*unused*/) {
return nullptr;
std::unordered_map<std::string, at::ClassTypePtr>& customClasses() {
static std::unordered_map<std::string, at::ClassTypePtr> customClasses;
return customClasses;
}
std::atomic<GetCustomClassFnType> custom_class_fn{noOpGetter};
} // namespace
void setGetCustomClassFn(GetCustomClassFnType fn) {
custom_class_fn.store(fn);
void registerCustomClass(at::ClassTypePtr class_type) {
TORCH_INTERNAL_ASSERT(class_type->name());
auto name = class_type->name()->qualifiedName();
TORCH_CHECK(!customClasses().count(name))
customClasses()[name] = std::move(class_type);
}
at::TypePtr getCustomClass(const std::string& name) {
return custom_class_fn.load()(name);
at::ClassTypePtr getCustomClass(const std::string& name) {
return customClasses().count(name) ? customClasses()[name] : nullptr;
}
bool isCustomClass(const c10::IValue& v) {
return v.isObject() && v.toObject()->type()->name() &&
getCustomClass(v.toObject()->type()->name()->qualifiedName());
}
std::vector<std::shared_ptr<Function>>& customClassMethods() {
static std::vector<std::shared_ptr<Function>> customClassMethods;
return customClassMethods;
}
void registerCustomClassMethod(std::shared_ptr<Function> fn) {
customClassMethods().emplace_back(std::move(fn));
}
} // namespace jit

View File

@ -25,6 +25,8 @@ TORCH_API void preoptimizeGraph(std::shared_ptr<Graph>& graph);
// execution of the function. script::Method is a wrapper around a
// underlying Function that also provides a `self` object.
struct TORCH_API Function {
virtual bool isGraphFunction() const = 0;
virtual void run(Stack& stack) = 0;
virtual void run(Stack&& stack) = 0;

View File

@ -367,14 +367,10 @@ StrongTypePtr::StrongTypePtr(
cu_ = std::move(cu);
type_ = type;
TORCH_INTERNAL_ASSERT(type_);
if (type_->cast<ClassType>()) {
TORCH_INTERNAL_ASSERT(
cu_, "class type's owning compilation unit is nullptr");
}
}
std::unordered_map<std::string, c10::StrongTypePtr>& getCustomClassTypeMap() {
static std::unordered_map<std::string, c10::StrongTypePtr> tmap;
std::unordered_map<std::string, c10::ClassTypePtr>& getCustomClassTypeMap() {
static std::unordered_map<std::string, c10::ClassTypePtr> tmap;
return tmap;
}

View File

@ -25,6 +25,10 @@ struct ClassType;
struct Type;
class RRefInterface;
using TypePtr = std::shared_ptr<Type>;
struct ClassType;
using ClassTypePtr = std::shared_ptr<ClassType>;
namespace ivalue {
struct Tuple;
struct Future;
@ -695,12 +699,12 @@ struct TORCH_API StrongTypePtr {
std::shared_ptr<Type> type_;
};
TORCH_API std::unordered_map<std::string, c10::StrongTypePtr>& getCustomClassTypeMap();
TORCH_API std::unordered_map<std::string, c10::ClassTypePtr>& getCustomClassTypeMap();
#ifndef C10_MOBILE
template<typename T>
c10::StrongTypePtr getCustomClassType() {
c10::ClassTypePtr getCustomClassType() {
auto tmap = c10::getCustomClassTypeMap();
auto res = tmap.find(typeid(T).name());
if (res == tmap.end()) {
@ -718,7 +722,7 @@ inline bool isCustomClassRegistered() {
#else // C10_MOBILE
template<typename T>
c10::StrongTypePtr getCustomClassType() {
c10::ClassTypePtr getCustomClassType() {
throw c10::Error("Custom class is not supported on mobile.", "");
}

View File

@ -866,7 +866,11 @@ IValue from_(c10::intrusive_ptr<T> x, std::false_type) {
throw c10::Error("Trying to return a class that we don't support and isn't a registered custom class.", "");
}
auto res = getCustomClassType<inputType>();
auto retObject = ivalue::Object::create(std::move(res), 1);
auto retObject = ivalue::Object::create(
StrongTypePtr(
std::shared_ptr<torch::jit::script::CompilationUnit>(),
std::move(res)),
1);
auto objPtr = c10::static_intrusive_pointer_cast<torch::jit::CustomClassHolder>(std::move(x));
retObject->setSlot(0, IValue(std::move(objPtr)));

View File

@ -1317,7 +1317,7 @@ struct getTypePtr_ final {
throw c10::Error("Type could not be converted to any of the known types.", "");
}
auto res = getCustomClassType<T>();
return std::dynamic_pointer_cast<Type>(std::move(res.type_));
return std::dynamic_pointer_cast<Type>(std::move(res));
}
};
@ -1436,6 +1436,12 @@ struct getTypePtr_<std::tuple<Contained...>> final {
return TupleType::create(std::move(contained_types));
}
};
template <>
struct getTypePtr_<void> final {
static TypePtr call() {
return NoneType::get();
}
};
} // namespace detail
template <class T>
inline TypePtr getTypePtr() {

View File

@ -97,6 +97,13 @@ struct createReturns<void, void> final {
}
};
template <typename ReturnType>
struct createSingleReturn {
static constexpr std::array<ArgumentDef, 1> call() {
return createArgumentVectorFromTypes<ReturnType>(std::make_index_sequence<1>());
}
};
template<size_t NumArgs>
std::vector<Argument> createArgumentVector(const std::array<ArgumentDef, NumArgs>& args) {
std::vector<Argument> result;
@ -120,9 +127,9 @@ inline FunctionSchema make_function_schema(std::string&& name, std::string&& ove
}
/// Creates a `FunctionSchema` object from a `FunctionTraits` type for a
/// function.
/// function. Flattens std::tuple returns into multiple return types
template <typename FunctionTraits>
FunctionSchema createFunctionSchemaFromTraits(std::string&& name, std::string&& overload_name) {
FunctionSchema createFunctionSchemaFromTraitsFlattenedReturns(std::string&& name, std::string&& overload_name) {
using ReturnType = typename FunctionTraits::return_type;
using ParameterTypes = typename FunctionTraits::parameter_types;
@ -131,12 +138,31 @@ FunctionSchema createFunctionSchemaFromTraits(std::string&& name, std::string&&
return make_function_schema(std::move(name), std::move(overload_name), arguments, returns);
}
/// Creates a `FunctionSchema` object from a `FunctionTraits` type for a
/// function. Preserves std::tuple returns as a Tuple return type
template <typename FunctionTraits>
FunctionSchema createFunctionSchemaFromTraitsSingleReturn(std::string&& name, std::string&& overload_name) {
using ReturnType = typename FunctionTraits::return_type;
using ParameterTypes = typename FunctionTraits::parameter_types;
constexpr auto arguments = createArguments<ParameterTypes>::call();
constexpr auto returns = createSingleReturn<ReturnType>::call();
return make_function_schema(std::move(name), std::move(overload_name), arguments, returns);
}
}
}
template<class FuncType>
FunctionSchema inferFunctionSchema(std::string&& name, std::string&& overload_name) {
return detail::infer_schema::createFunctionSchemaFromTraits<guts::infer_function_traits_t<FuncType>>(std::move(name), std::move(overload_name));
FunctionSchema inferFunctionSchemaFlattenedReturns(std::string&& name, std::string&& overload_name) {
return detail::infer_schema::createFunctionSchemaFromTraitsFlattenedReturns<guts::infer_function_traits_t<FuncType>>(std::move(name), std::move(overload_name));
}
template<class FuncType>
FunctionSchema inferFunctionSchemaSingleReturn(std::string&& name, std::string&& overload_name) {
return detail::infer_schema::createFunctionSchemaFromTraitsSingleReturn<guts::infer_function_traits_t<FuncType>>(std::move(name), std::move(overload_name));
}
CAFFE2_API c10::optional<std::string> findSchemaDifferences(const FunctionSchema& inferred, const FunctionSchema& specified);

View File

@ -979,13 +979,11 @@ void ClassType::unsafeRemoveConstant(const std::string& name) {
std::shared_ptr<CompilationUnit> ClassType::compilation_unit() {
auto cu = compilation_unit_.lock();
TORCH_INTERNAL_ASSERT(cu);
return cu;
}
std::shared_ptr<const CompilationUnit> ClassType::compilation_unit() const {
auto cu = compilation_unit_.lock();
TORCH_INTERNAL_ASSERT(cu);
return cu;
}

View File

@ -366,7 +366,6 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/jit/runtime/autodiff.cpp
${TORCH_SRC_DIR}/csrc/jit/ir/attributes.cpp
${TORCH_SRC_DIR}/csrc/jit/runtime/argument_spec.cpp
${TORCH_SRC_DIR}/csrc/jit/api/custom_class.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/pass_manager.cpp
${TORCH_SRC_DIR}/csrc/jit/serialization/pickler.cpp
${TORCH_SRC_DIR}/csrc/jit/serialization/unpickler.cpp

View File

@ -1,4 +1,5 @@
#include <torch/custom_class.h>
#include <torch/script.h>
#include <iostream>
#include <string>
@ -121,11 +122,10 @@ static auto testPickle =
});
at::Tensor take_an_instance(const c10::intrusive_ptr<PickleTester>& instance) {
return at::zeros({instance->vals.back(), 4});
return torch::zeros({instance->vals.back(), 4});
}
torch::RegisterOperators& register_take_instance() {
static int ensure_custom_class_handler_registered = register_custom_class_handler();
static auto instance_registry = torch::RegisterOperators().op(
torch::RegisterOperators::options()
.schema(

View File

@ -82,7 +82,6 @@ libtorch_sources = [
"torch/csrc/jit/runtime/argument_spec.cpp",
"torch/csrc/jit/ir/constants.cpp",
"torch/csrc/jit/api/function_impl.cpp",
"torch/csrc/jit/api/custom_class.cpp",
"torch/csrc/jit/ir/node_hashing.cpp",
"torch/csrc/jit/ir/type_hashing.cpp",
"torch/csrc/jit/serialization/export.cpp",

View File

@ -1,40 +0,0 @@
#include <torch/custom_class.h>
#include <atomic>
namespace torch {
namespace jit {
std::vector<c10::RegisterOperators>& registeredOps() {
static std::vector<c10::RegisterOperators> ops;
return ops;
}
std::shared_ptr<script::CompilationUnit>& classCU() {
static std::shared_ptr<script::CompilationUnit> cu =
std::make_shared<script::CompilationUnit>();
return cu;
}
bool isCustomClass(const c10::IValue& v) {
return v.isObject() && v.toObject()->type()->name() &&
getCustomClass(v.toObject()->type()->name()->qualifiedName());
}
namespace {
TypePtr realCustomClassHandler(const std::string& name) {
return classCU()->get_type(name);
}
} // namespace
int register_custom_class_handler() {
setGetCustomClassFn(realCustomClassHandler);
return 0;
};
static int ensure_custom_class_handler_registered = register_custom_class_handler();
} // namespace jit
} // namespace torch

View File

@ -1,24 +0,0 @@
#pragma once
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
namespace torch {
namespace jit {
TORCH_API at::TypePtr getCustomClass(const std::string& name);
TORCH_API bool isCustomClass(const c10::IValue& v);
using GetCustomClassFnType = at::TypePtr (*)(const std::string&);
// Use this to set the function for retrieving custom classes
//
// This is necessary because the custom classes implementation
// is not in ATen core, but the schema type parser is, which
// can resolve custom classes as type expressions.
TORCH_API void setGetCustomClassFn(GetCustomClassFnType fn);
TORCH_API int register_custom_class_handler();
} // namespace jit
} // namespace torch

View File

@ -17,6 +17,10 @@ struct TORCH_API GraphFunction : public Function {
graph_(std::move(graph)),
function_creator_(std::move(function_creator)) {}
bool isGraphFunction() const override {
return true;
}
void run(Stack& stack) override;
void run(Stack&& stack) override;

View File

@ -2,10 +2,10 @@
#include <ATen/core/interned_strings.h>
#include <ATen/core/jit_type.h>
#include <c10/util/string_utils.h>
#include <torch/csrc/jit/api/custom_class.h>
#include <torch/csrc/jit/frontend/lexer.h>
#include <torch/csrc/jit/frontend/parse_string_literal.h>
#include <torch/csrc/jit/frontend/schema_type_parser.h>
#include <torch/custom_class.h>
#include <string>
using c10::AliasInfo;

View File

@ -1,8 +1,9 @@
#include <torch/csrc/jit/ir/ir.h>
#include <ATen/core/builtin_function.h>
#include <ATen/core/function.h>
#include <c10/util/Exception.h>
#include <c10/util/StringUtil.h>
#include <ATen/core/function.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/frontend/schema_matching.h>
@ -1822,6 +1823,7 @@ at::ArrayRef<Value*> createTupleUnpack(Value* v) {
std::vector<Value*> inlineCallTo(Node* to_replace, Function* callee) {
WithInsertPoint guard(to_replace);
TORCH_INTERNAL_ASSERT(callee->isGraphFunction());
std::unordered_map<Value*, Value*> value_map;
auto new_outputs = insertGraph(
*to_replace->owningGraph(),

View File

@ -31,6 +31,9 @@ void inlineCalls(Block* block) {
const std::string& name = cur->s(attr::name);
if (auto class_type = cur->input(0)->type()->cast<ClassType>()) {
auto function = class_type->getMethod(name);
if (!function->isGraphFunction()) {
continue;
}
GRAPH_UPDATE("Inlining method '", function->name(), "' to ", *cur);
GRAPH_UPDATE("Function body: ", *function->optimized_graph());
inlineCallTo(cur, function);

View File

@ -29,17 +29,11 @@ void initPythonCustomClassBindings(PyObject* module) {
// directly, we need a wrapper that at least returns the instance
// rather than the None return value from __init__
m.def("_get_custom_class_python_wrapper", [](const std::string& qualname) {
auto cu = classCU();
std::string full_qualname = "__torch__.torch.classes." + qualname;
c10::NamedTypePtr named_type = cu->get_type(full_qualname);
if (!named_type || !named_type->cast<ClassType>()) {
std::stringstream err;
err << "Class " << qualname << " not registered!";
throw std::runtime_error(err.str());
}
auto named_type = getCustomClass(full_qualname);
c10::ClassTypePtr class_type = named_type->cast<ClassType>();
return ScriptClass(
c10::StrongTypePtr(std::move(cu), std::move(class_type)));
return ScriptClass(c10::StrongTypePtr(
std::shared_ptr<script::CompilationUnit>(), std::move(class_type)));
});
}

View File

@ -779,7 +779,7 @@ void initJitScriptBindings(PyObject* module) {
py::object state;
std::string qualname;
std::tie(state, qualname) = state_tup;
auto class_type = classCU()->get_class(qualname);
auto class_type = getCustomClass(qualname);
TORCH_CHECK(
class_type,
"Tried to deserialize class ",
@ -789,7 +789,10 @@ void initJitScriptBindings(PyObject* module) {
"sure the appropriate code is linked.");
auto self = script::Object(c10::ivalue::Object::create(
c10::StrongTypePtr(classCU(), class_type), 1));
c10::StrongTypePtr(
std::shared_ptr<torch::jit::script::CompilationUnit>(),
class_type),
1));
if (auto setstate_method = self.find_method("__setstate__")) {
auto setstate_schema = setstate_method->function().getSchema();
TORCH_INTERNAL_ASSERT(

View File

@ -965,6 +965,32 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
}
}
void runBuiltinFunction(Stack &stack, Function *fn, ActiveFrame *af) {
// BuiltinOpFunction directly invokes a void(Stack&) to implement
// custom C++ classes. Call run() here with the stack, and we will
// get the results from that C++ method back in the stack. Advance
// the PC by 1 without adding any new frame.
fn->run(stack);
++af->pc;
}
void runGraphFunction(Stack &stack, Function *fn, ActiveFrame *af) {
const Code& code =
// consider passing
// `frames.back().function->remaining_bailout_depth_` into
// `get_executor().getPlanFor()` to propagate caller's depth
// restrictions onto children while this strategy has a
// potential to reduce the number of compilations for too
// dynamic callers we might miss opportunities where a caller is
// dynamic but a callee gets stable arguments
fn->get_executor()
.getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts())
.code;
frames.back().pc = af->pc + 1;
enterFrame(code, stack.size() - code.num_inputs());
*af = ActiveFrame(frames.back());
}
bool runImpl(Stack& stack) {
// if we have never run before, then we might have to return the
// stack when we suspend, record where it starts so we return the right
@ -1062,21 +1088,12 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
}
} break;
case CALL: {
const Code& code =
// consider passing
// `frames.back().function->remaining_bailout_depth_` into
// `get_executor().getPlanFor()` to propagate caller's depth
// restrictions onto children while this strategy has a
// potential to reduce the number of compilations for too
// dynamic callers we might miss opportunities where a caller is
// dynamic but a callee gets stable arguments
af.functions[inst.X]
->get_executor()
.getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts())
.code;
frames.back().pc = af.pc + 1;
enterFrame(code, stack.size() - code.num_inputs());
af = ActiveFrame(frames.back());
Function* fn = af.functions[inst.X];
if (!fn->isGraphFunction()) {
runBuiltinFunction(stack, fn, &af);
} else {
runGraphFunction(stack, fn, &af);
}
} break;
case INTERFACE_CALL: {
// note the hash table lookup to find the function
@ -1095,13 +1112,11 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
.toObject()
->type()
->getMethod(af.constants[inst.X].toStringRef());
const Code& code =
function->get_executor()
.getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts())
.code;
frames.back().pc = af.pc + 1;
enterFrame(code, stack.size() - inst.N);
af = ActiveFrame(frames.back());
if (!function->isGraphFunction()) {
runBuiltinFunction(stack, function, &af);
} else {
runGraphFunction(stack, function, &af);
}
} break;
case RET:
if (frames.size() > 1) {

View File

@ -1,11 +1,11 @@
#include "import_source.h"
#include <ATen/core/qualified_name.h>
#include <torch/csrc/jit/api/custom_class.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/frontend/parser.h>
#include <torch/csrc/jit/frontend/resolver.h>
#include <torch/csrc/jit/frontend/script_type_parser.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/custom_class.h>
namespace torch {
namespace jit {

View File

@ -1147,6 +1147,7 @@ struct PythonPrintImpl {
void printFunction(
const Function& func,
bool print_first_argument_type = true) {
TORCH_INTERNAL_ASSERT(func.isGraphFunction());
const FunctionSchema& schema = func.getSchema();
Graph& graph = *func.graph();
used_names_.clear(); // each graph can reuse local names
@ -1192,6 +1193,16 @@ struct PythonPrintImpl {
enforce_importable_(enforce_importable) {}
void printClass(const ClassTypePtr& classType) {
// If any of the methods are not Graph funtions, this indicates that
// this class is a custom-bound C++ class. Skip serialization
// of this class, we will depend on the ClassType being defined
// in the target process.
for (auto& method : classType->methods()) {
if (!method->isGraphFunction()) {
return;
}
}
bool is_module = classType->is_module();
body_ << "class " << classType->name()->name();
if (is_module) {

View File

@ -1,28 +1,31 @@
#pragma once
#include <ATen/core/builtin_function.h>
#include <ATen/core/function_schema.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/op_registration/infer_schema.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/core/stack.h>
#include <c10/util/C++17.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/TypeList.h>
#include <c10/util/TypeTraits.h>
#include <torch/csrc/jit/api/custom_class.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/utils/variadic.h>
#include <torch/custom_class_detail.h>
#include <iostream>
#include <sstream>
namespace torch {
namespace jit {
namespace script {
struct CompilationUnit;
}
TORCH_API at::ClassTypePtr getCustomClass(const std::string& name);
TORCH_API bool isCustomClass(const c10::IValue& v);
template <class... Types>
detail::types<void, Types...> init() {
return detail::types<void, Types...>{};
@ -47,7 +50,7 @@ class class_ {
std::string className;
std::string qualClassName;
ClassTypePtr classTypePtr;
at::ClassTypePtr classTypePtr;
const std::string parentModule = "classes";
const std::string topModule = "__torch__.torch";
@ -58,16 +61,17 @@ class class_ {
// We currently represent custom classes as torchscript classes with a
// capsule attribute
classTypePtr =
ClassType::create(c10::QualifiedName(qualClassName), classCU());
classTypePtr->addAttribute("capsule", CapsuleType::get());
classTypePtr = at::ClassType::create(
c10::QualifiedName(qualClassName),
std::weak_ptr<script::CompilationUnit>());
classTypePtr->addAttribute("capsule", at::CapsuleType::get());
c10::getCustomClassTypeMap().insert({typeid(c10::intrusive_ptr<CurClass>).name(),
c10::StrongTypePtr(classCU(), classTypePtr)});
c10::getCustomClassTypeMap().insert({typeid(c10::tagged_capsule<CurClass>).name(),
c10::StrongTypePtr(classCU(), classTypePtr)});
c10::getCustomClassTypeMap().insert(
{typeid(c10::intrusive_ptr<CurClass>).name(), classTypePtr});
c10::getCustomClassTypeMap().insert(
{typeid(c10::tagged_capsule<CurClass>).name(), classTypePtr});
classCU()->register_type(classTypePtr);
registerCustomClass(classTypePtr);
}
template <typename... Types>
@ -163,40 +167,26 @@ class class_ {
private:
template <typename Func>
void defineMethod(std::string name, Func func) {
auto graph = std::make_shared<Graph>();
auto qualFuncName = className + "::" + name;
ensure_c10_registerer_defined();
registeredOps().push_back(
torch::RegisterOperators().op(qualFuncName, std::move(func)));
auto func_symbol = c10::Symbol::fromQualString(qualFuncName);
auto ops = torch::jit::getAllOperatorsFor(func_symbol);
TORCH_CHECK(ops.size() == 1);
auto &schema = ops[0]->schema();
auto qualMethodName = qualClassName + "." + name;
auto schema = c10::inferFunctionSchemaSingleReturn<Func>(std::move(name), "");
for (const auto& arg : schema.arguments()) {
graph->addInput()->setType(arg.type());
}
auto wrapped_func = [func = std::move(func)](Stack& stack) mutable -> void {
// TODO: we need to figure out how to profile calls to custom functions
// like this! Currently can't do it because the profiler stuff is in
// libtorch and not ATen
using RetType =
typename c10::guts::infer_function_traits_t<Func>::return_type;
detail::BoxedProxy<RetType, Func>()(stack, func);
};
auto method = std::make_shared<BuiltinOpFunction>(
qualMethodName, std::move(schema), std::move(wrapped_func));
auto opCall = graph->insertNode(graph->create(
func_symbol, graph->inputs(), schema.returns().size()));
Value* res;
if (schema.returns().size() > 1) {
const auto& returns = schema.returns();
size_t op_invocation_idx = 0;
for (const auto& ret : returns) {
opCall->output(op_invocation_idx++)->setType(ret.type());
}
res = graph->insertNode(graph->createTuple(opCall->outputs()))->output();
} else if (schema.returns().size() == 1) {
const auto& returns = schema.returns();
res = opCall->output()->setType(returns[0].type());
} else {
res = graph->insertConstant(IValue())->setType(NoneType::get());
}
graph->registerOutput(res);
auto method = classCU()->create_function(qualClassName + "." + name, graph);
classTypePtr->addMethod(method);
// Register the method here to keep the Method alive.
// ClassTypes do not hold ownership of their methods (normally it
// those are held by the CompilationUnit), so we need a proxy for
// that behavior here.
registerCustomClassMethod(method);
classTypePtr->addMethod(method.get());
}
};

View File

@ -1,5 +1,7 @@
#pragma once
#include <ATen/core/boxing/kernel_functor.h>
#include <ATen/core/function.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/TypeTraits.h>
@ -60,10 +62,68 @@ Func wrap_func(Func f) {
return f;
}
template <
class Functor,
bool AllowDeprecatedTypes,
size_t... ivalue_arg_indices>
typename c10::guts::infer_function_traits_t<Functor>::return_type
call_torchbind_method_from_stack(
Functor& functor,
Stack& stack,
std::index_sequence<ivalue_arg_indices...>) {
(void)(stack); // when sizeof...(ivalue_arg_indices) == 0, this argument would
// be unused and we have to silence the compiler warning.
constexpr size_t num_ivalue_args = sizeof...(ivalue_arg_indices);
using IValueArgTypes =
typename c10::guts::infer_function_traits_t<Functor>::parameter_types;
return (functor)(c10::detail::ivalue_to_arg<
std::remove_cv_t<std::remove_reference_t<
c10::guts::typelist::
element_t<ivalue_arg_indices, IValueArgTypes>>>,
AllowDeprecatedTypes>(std::move(
torch::jit::peek(stack, ivalue_arg_indices, num_ivalue_args)))...);
}
template <class Functor, bool AllowDeprecatedTypes>
typename c10::guts::infer_function_traits_t<Functor>::return_type
call_torchbind_method_from_stack(Functor& functor, Stack& stack) {
constexpr size_t num_ivalue_args =
c10::guts::infer_function_traits_t<Functor>::number_of_parameters;
return call_torchbind_method_from_stack<Functor, AllowDeprecatedTypes>(
functor, stack, std::make_index_sequence<num_ivalue_args>());
}
template <class RetType, class Func>
struct BoxedProxy;
template <class RetType, class Func>
struct BoxedProxy {
void operator()(Stack& stack, Func& func) {
auto retval = call_torchbind_method_from_stack<Func, false>(func, stack);
constexpr size_t num_ivalue_args =
c10::guts::infer_function_traits_t<Func>::number_of_parameters;
torch::jit::drop(stack, num_ivalue_args);
stack.emplace_back(c10::ivalue::from(std::move(retval)));
}
};
template <class Func>
struct BoxedProxy<void, Func> {
void operator()(Stack& stack, Func& func) {
call_torchbind_method_from_stack<Func, false>(func, stack);
constexpr size_t num_ivalue_args =
c10::guts::infer_function_traits_t<Func>::number_of_parameters;
torch::jit::drop(stack, num_ivalue_args);
stack.emplace_back(IValue());
}
};
} // namespace detail
TORCH_API std::vector<c10::RegisterOperators>& registeredOps();
TORCH_API std::shared_ptr<script::CompilationUnit>& classCU();
TORCH_API void registerCustomClass(at::ClassTypePtr class_type);
TORCH_API void registerCustomClassMethod(std::shared_ptr<Function> method);
} // namespace jit
} // namespace torch