Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[JIT] Introduce BuiltinOpFunction and integrate into torchbind (#34098)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34098

Stacked PRs:
* #33900 [JIT] Move stuff out of class_type.cpp

Test Plan: Imported from OSS

Differential Revision: D20229166

Pulled By: jamesr66a

fbshipit-source-id: d658a63a5d6e372e675f35b8456adc8de82b49f3
Committed by: Facebook Github Bot
Parent: 60e8615a6d
Commit: 45a504dd2d
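For context, the torchbind user-facing API that this diff re-plumbs looks roughly like the sketch below. This is illustrative, not code from the PR: the MyStack class and its methods are invented. After this change, each .def() call produces a BuiltinOpFunction bound directly to the ClassType, instead of a TorchScript graph owned by a CompilationUnit.

#include <torch/custom_class.h>

// Invented example class; torchbind requires inheriting CustomClassHolder.
struct MyStack : torch::jit::CustomClassHolder {
  std::vector<int64_t> vals;
  void push(int64_t v) { vals.push_back(v); }
  int64_t top() { return vals.back(); }
};

// Each .def() now infers a single-return schema, wraps the method in a
// void(Stack&) callable via BoxedProxy, and registers a BuiltinOpFunction.
static auto registration =
    torch::jit::class_<MyStack>("MyStack")
        .def(torch::jit::init<>())
        .def("push", &MyStack::push)
        .def("top", &MyStack::top);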
@@ -297,8 +297,8 @@ namespace detail {
  };

  template<class FuncType>
-  std::unique_ptr<FunctionSchema> inferFunctionSchema_() {
-    return std::make_unique<FunctionSchema>(inferFunctionSchema<FuncType>("", ""));
+  std::unique_ptr<FunctionSchema> inferFunctionSchemaFlattenedReturns_() {
+    return std::make_unique<FunctionSchema>(inferFunctionSchemaFlattenedReturns<FuncType>("", ""));
  }

  template<class KernelFunctor>
@@ -306,7 +306,7 @@ namespace detail {
  public:
    using func_type = typename c10::guts::infer_function_traits_t<KernelFunctor>::func_type;
    std::unique_ptr<FunctionSchema> operator()() const {
-      return inferFunctionSchema_<func_type>();
+      return inferFunctionSchemaFlattenedReturns_<func_type>();
    }
  };
}
aten/src/ATen/core/builtin_function.h (new file, 105 lines)
@@ -0,0 +1,105 @@
+#pragma once
+
+#include <ATen/core/dispatch/Dispatcher.h>
+#include <ATen/core/function.h>
+
+namespace torch {
+namespace jit {
+
+struct BuiltinOpFunction : public Function {
+  BuiltinOpFunction(
+      c10::QualifiedName qualname,
+      c10::FunctionSchema schema,
+      std::function<void(Stack&)> callable)
+      : name_(std::move(qualname)),
+        callable_(std::move(callable)),
+        schema_(std::move(schema)) {
+    TORCH_INTERNAL_ASSERT(schema_.returns().size() == 1);
+  }
+
+  bool isGraphFunction() const override {
+    return false;
+  }
+
+  void run(Stack& stack) override {
+    callable_(stack);
+  }
+
+  void run(Stack&& stack) override {
+    callable_(stack);
+  }
+
+  at::IValue operator()(std::vector<at::IValue> stack, const Kwargs& kwargs)
+      override {
+    getSchema().checkAndNormalizeInputs(stack, kwargs);
+    callable_(stack);
+    return stack.front();
+  }
+
+  const c10::QualifiedName& qualname() const override {
+    return name_;
+  }
+
+  const std::string& name() const override {
+    return name_.name();
+  }
+
+  // if this isn't yet defined, run its method_creator function
+  void ensure_defined() override {
+    // nop
+  }
+
+  std::shared_ptr<Graph> graph() const override {
+    TORCH_INTERNAL_ASSERT(false, "BuiltinFunction had a graph requested "
+        "from it. This probably indicates that the JIT calling context needs a "
+        "special case on Function::isGraphFunction()");
+  }
+
+  std::shared_ptr<Graph> optimized_graph() const override {
+    TORCH_INTERNAL_ASSERT(false, "BuiltinFunction had a graph requested "
+        "from it. This probably indicates that the JIT calling context needs a "
+        "special case on Function::isGraphFunction()");
+  }
+
+  GraphExecutor& get_executor() override {
+    TORCH_INTERNAL_ASSERT(false, "BuiltinFunction had a GraphExecutor requested "
+        "from it. This probably indicates that the JIT calling context needs a "
+        "special case on Function::isGraphFunction()");
+  }
+
+  const c10::FunctionSchema& getSchema() const override {
+    return schema_;
+  }
+
+  size_t num_inputs() const override {
+    return schema_.arguments().size();
+  }
+
+  void check_single_output() override {
+    TORCH_CHECK(schema_.returns().size() == 1);
+  }
+
+  std::string pretty_print_schema() const override {
+    TORCH_INTERNAL_ASSERT(false);
+    std::stringstream ss;
+    ss << getSchema();
+    return ss.str();
+  }
+
+  Function& setSchema(c10::FunctionSchema schema) override {
+    schema_ = std::move(schema);
+    return *this;
+  }
+
+  ~BuiltinOpFunction() {}
+
+ private:
+  c10::QualifiedName name_;
+
+  std::function<void(Stack&)> callable_;
+
+  c10::FunctionSchema schema_;
+};
+
+} // namespace jit
+} // namespace torch
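To illustrate the calling convention of the new class, here is a hedged sketch (not part of the diff) that builds a BuiltinOpFunction by hand and invokes it through the generic Function interface; the schema and qualified name are invented for the example:

// Illustrative only: a BuiltinOpFunction that adds two ints.
#include <ATen/core/builtin_function.h>

using namespace torch::jit;

void demo() {
  auto schema = c10::FunctionSchema(
      "demo::add", "",
      {c10::Argument("a", c10::IntType::get()),
       c10::Argument("b", c10::IntType::get())},
      {c10::Argument("sum", c10::IntType::get())});  // exactly one return

  BuiltinOpFunction fn(
      c10::QualifiedName("demo.add"),
      std::move(schema),
      [](Stack& stack) {
        // Pop the two arguments, push the single result.
        int64_t b = pop(stack).toInt();
        int64_t a = pop(stack).toInt();
        push(stack, a + b);
      });

  Stack stack{at::IValue(int64_t(2)), at::IValue(int64_t(3))};
  fn.run(stack);  // dispatches straight into the lambda, no graph, no executor
  TORCH_INTERNAL_ASSERT(stack.back().toInt() == 5);
}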
@ -1,27 +1,41 @@
|
||||
#include <torch/custom_class.h>
|
||||
|
||||
#include <ATen/core/jit_type.h>
|
||||
#include <torch/csrc/jit/api/custom_class.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
namespace {
|
||||
|
||||
at::TypePtr noOpGetter(const std::string& /*unused*/) {
|
||||
return nullptr;
|
||||
std::unordered_map<std::string, at::ClassTypePtr>& customClasses() {
|
||||
static std::unordered_map<std::string, at::ClassTypePtr> customClasses;
|
||||
return customClasses;
|
||||
}
|
||||
|
||||
std::atomic<GetCustomClassFnType> custom_class_fn{noOpGetter};
|
||||
|
||||
} // namespace
|
||||
|
||||
void setGetCustomClassFn(GetCustomClassFnType fn) {
|
||||
custom_class_fn.store(fn);
|
||||
void registerCustomClass(at::ClassTypePtr class_type) {
|
||||
TORCH_INTERNAL_ASSERT(class_type->name());
|
||||
auto name = class_type->name()->qualifiedName();
|
||||
TORCH_CHECK(!customClasses().count(name))
|
||||
customClasses()[name] = std::move(class_type);
|
||||
}
|
||||
|
||||
at::TypePtr getCustomClass(const std::string& name) {
|
||||
return custom_class_fn.load()(name);
|
||||
at::ClassTypePtr getCustomClass(const std::string& name) {
|
||||
return customClasses().count(name) ? customClasses()[name] : nullptr;
|
||||
}
|
||||
|
||||
bool isCustomClass(const c10::IValue& v) {
|
||||
return v.isObject() && v.toObject()->type()->name() &&
|
||||
getCustomClass(v.toObject()->type()->name()->qualifiedName());
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<Function>>& customClassMethods() {
|
||||
static std::vector<std::shared_ptr<Function>> customClassMethods;
|
||||
return customClassMethods;
|
||||
}
|
||||
|
||||
void registerCustomClassMethod(std::shared_ptr<Function> fn) {
|
||||
customClassMethods().emplace_back(std::move(fn));
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
|
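The new registry replaces the CompilationUnit-backed handler with a plain map. A hedged sketch of the round-trip, using the invented class name MyStack (the ClassType construction mirrors what class_'s constructor does in torch/custom_class.h below):

void register_and_lookup() {
  auto type = at::ClassType::create(
      c10::QualifiedName("__torch__.torch.classes.MyStack"),
      std::weak_ptr<torch::jit::script::CompilationUnit>());
  type->addAttribute("capsule", at::CapsuleType::get());
  torch::jit::registerCustomClass(type);

  // Later, e.g. from the schema type parser or the unpickler:
  at::ClassTypePtr found =
      torch::jit::getCustomClass("__torch__.torch.classes.MyStack");
  TORCH_INTERNAL_ASSERT(found != nullptr);  // direct map hit, no CompilationUnit
}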
@@ -25,6 +25,8 @@ TORCH_API void preoptimizeGraph(std::shared_ptr<Graph>& graph);
 // execution of the function. script::Method is a wrapper around a
 // underlying Function that also provides a `self` object.
 struct TORCH_API Function {
+  virtual bool isGraphFunction() const = 0;
+
   virtual void run(Stack& stack) = 0;

   virtual void run(Stack&& stack) = 0;
@@ -367,14 +367,10 @@ StrongTypePtr::StrongTypePtr(
   cu_ = std::move(cu);
   type_ = type;
   TORCH_INTERNAL_ASSERT(type_);
-  if (type_->cast<ClassType>()) {
-    TORCH_INTERNAL_ASSERT(
-        cu_, "class type's owning compilation unit is nullptr");
-  }
 }

-std::unordered_map<std::string, c10::StrongTypePtr>& getCustomClassTypeMap() {
-  static std::unordered_map<std::string, c10::StrongTypePtr> tmap;
+std::unordered_map<std::string, c10::ClassTypePtr>& getCustomClassTypeMap() {
+  static std::unordered_map<std::string, c10::ClassTypePtr> tmap;
   return tmap;
 }
@@ -25,6 +25,10 @@ struct ClassType;
 struct Type;
 class RRefInterface;
 using TypePtr = std::shared_ptr<Type>;
+
+struct ClassType;
+using ClassTypePtr = std::shared_ptr<ClassType>;
+
 namespace ivalue {
 struct Tuple;
 struct Future;
@@ -695,12 +699,12 @@ struct TORCH_API StrongTypePtr {
   std::shared_ptr<Type> type_;
 };

-TORCH_API std::unordered_map<std::string, c10::StrongTypePtr>& getCustomClassTypeMap();
+TORCH_API std::unordered_map<std::string, c10::ClassTypePtr>& getCustomClassTypeMap();

 #ifndef C10_MOBILE

 template<typename T>
-c10::StrongTypePtr getCustomClassType() {
+c10::ClassTypePtr getCustomClassType() {
   auto tmap = c10::getCustomClassTypeMap();
   auto res = tmap.find(typeid(T).name());
   if (res == tmap.end()) {
@@ -718,7 +722,7 @@ inline bool isCustomClassRegistered() {
 #else // C10_MOBILE

 template<typename T>
-c10::StrongTypePtr getCustomClassType() {
+c10::ClassTypePtr getCustomClassType() {
   throw c10::Error("Custom class is not supported on mobile.", "");
 }
@@ -866,7 +866,11 @@ IValue from_(c10::intrusive_ptr<T> x, std::false_type) {
     throw c10::Error("Trying to return a class that we don't support and isn't a registered custom class.", "");
   }
   auto res = getCustomClassType<inputType>();
-  auto retObject = ivalue::Object::create(std::move(res), 1);
+  auto retObject = ivalue::Object::create(
+      StrongTypePtr(
+          std::shared_ptr<torch::jit::script::CompilationUnit>(),
+          std::move(res)),
+      1);
   auto objPtr = c10::static_intrusive_pointer_cast<torch::jit::CustomClassHolder>(std::move(x));

   retObject->setSlot(0, IValue(std::move(objPtr)));
@@ -1317,7 +1317,7 @@ struct getTypePtr_ final {
       throw c10::Error("Type could not be converted to any of the known types.", "");
     }
     auto res = getCustomClassType<T>();
-    return std::dynamic_pointer_cast<Type>(std::move(res.type_));
+    return std::dynamic_pointer_cast<Type>(std::move(res));
   }
 };
@@ -1436,6 +1436,12 @@ struct getTypePtr_<std::tuple<Contained...>> final {
     return TupleType::create(std::move(contained_types));
   }
 };
+template <>
+struct getTypePtr_<void> final {
+  static TypePtr call() {
+    return NoneType::get();
+  }
+};
 } // namespace detail
 template <class T>
 inline TypePtr getTypePtr() {
@@ -97,6 +97,13 @@ struct createReturns<void, void> final {
   }
 };

+template <typename ReturnType>
+struct createSingleReturn {
+  static constexpr std::array<ArgumentDef, 1> call() {
+    return createArgumentVectorFromTypes<ReturnType>(std::make_index_sequence<1>());
+  }
+};
+
 template<size_t NumArgs>
 std::vector<Argument> createArgumentVector(const std::array<ArgumentDef, NumArgs>& args) {
   std::vector<Argument> result;
@@ -120,9 +127,9 @@ inline FunctionSchema make_function_schema(std::string&& name, std::string&& ove
 }

 /// Creates a `FunctionSchema` object from a `FunctionTraits` type for a
-/// function.
+/// function. Flattens std::tuple returns into multiple return types
 template <typename FunctionTraits>
-FunctionSchema createFunctionSchemaFromTraits(std::string&& name, std::string&& overload_name) {
+FunctionSchema createFunctionSchemaFromTraitsFlattenedReturns(std::string&& name, std::string&& overload_name) {
  using ReturnType = typename FunctionTraits::return_type;
  using ParameterTypes = typename FunctionTraits::parameter_types;

@@ -131,12 +138,31 @@ FunctionSchema createFunctionSchemaFromTraits(std::string&&

  return make_function_schema(std::move(name), std::move(overload_name), arguments, returns);
 }
+
+/// Creates a `FunctionSchema` object from a `FunctionTraits` type for a
+/// function. Preserves std::tuple returns as a Tuple return type
+template <typename FunctionTraits>
+FunctionSchema createFunctionSchemaFromTraitsSingleReturn(std::string&& name, std::string&& overload_name) {
+  using ReturnType = typename FunctionTraits::return_type;
+  using ParameterTypes = typename FunctionTraits::parameter_types;
+
+  constexpr auto arguments = createArguments<ParameterTypes>::call();
+  constexpr auto returns = createSingleReturn<ReturnType>::call();
+
+  return make_function_schema(std::move(name), std::move(overload_name), arguments, returns);
+}
 }
 }

 template<class FuncType>
-FunctionSchema inferFunctionSchema(std::string&& name, std::string&& overload_name) {
-  return detail::infer_schema::createFunctionSchemaFromTraits<guts::infer_function_traits_t<FuncType>>(std::move(name), std::move(overload_name));
+FunctionSchema inferFunctionSchemaFlattenedReturns(std::string&& name, std::string&& overload_name) {
+  return detail::infer_schema::createFunctionSchemaFromTraitsFlattenedReturns<guts::infer_function_traits_t<FuncType>>(std::move(name), std::move(overload_name));
 }

+template<class FuncType>
+FunctionSchema inferFunctionSchemaSingleReturn(std::string&& name, std::string&& overload_name) {
+  return detail::infer_schema::createFunctionSchemaFromTraitsSingleReturn<guts::infer_function_traits_t<FuncType>>(std::move(name), std::move(overload_name));
+}
+
 CAFFE2_API c10::optional<std::string> findSchemaDifferences(const FunctionSchema& inferred, const FunctionSchema& specified);
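To make the distinction between the two inference paths concrete, here is a hedged sketch (not part of the diff) for a signature returning a std::tuple; the printed schema forms in the comments are approximate:

#include <ATen/core/op_registration/infer_schema.h>
#include <ATen/Tensor.h>

// Illustrative only: a signature returning a std::tuple.
std::tuple<int64_t, at::Tensor> stats(at::Tensor t);

void compare_schemas() {
  // Operator registration keeps flattening tuples into multiple returns,
  // roughly: stats(Tensor t) -> (int, Tensor)
  c10::FunctionSchema flat =
      c10::inferFunctionSchemaFlattenedReturns<decltype(stats)>("stats", "");

  // Torchbind methods preserve the tuple as one Tuple return, roughly:
  // stats(Tensor t) -> ((int, Tensor)), which satisfies BuiltinOpFunction's
  // returns().size() == 1 invariant.
  c10::FunctionSchema single =
      c10::inferFunctionSchemaSingleReturn<decltype(stats)>("stats", "");
}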
@@ -979,13 +979,11 @@ void ClassType::unsafeRemoveConstant(const std::string& name) {

 std::shared_ptr<CompilationUnit> ClassType::compilation_unit() {
   auto cu = compilation_unit_.lock();
-  TORCH_INTERNAL_ASSERT(cu);
   return cu;
 }

 std::shared_ptr<const CompilationUnit> ClassType::compilation_unit() const {
   auto cu = compilation_unit_.lock();
-  TORCH_INTERNAL_ASSERT(cu);
   return cu;
 }

@@ -366,7 +366,6 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
     ${TORCH_SRC_DIR}/csrc/jit/runtime/autodiff.cpp
     ${TORCH_SRC_DIR}/csrc/jit/ir/attributes.cpp
     ${TORCH_SRC_DIR}/csrc/jit/runtime/argument_spec.cpp
-    ${TORCH_SRC_DIR}/csrc/jit/api/custom_class.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/pass_manager.cpp
     ${TORCH_SRC_DIR}/csrc/jit/serialization/pickler.cpp
     ${TORCH_SRC_DIR}/csrc/jit/serialization/unpickler.cpp
@@ -1,4 +1,5 @@
+#include <torch/custom_class.h>
 #include <torch/script.h>

 #include <iostream>
 #include <string>
@@ -121,11 +122,10 @@ static auto testPickle =
     });

 at::Tensor take_an_instance(const c10::intrusive_ptr<PickleTester>& instance) {
-  return at::zeros({instance->vals.back(), 4});
+  return torch::zeros({instance->vals.back(), 4});
 }

 torch::RegisterOperators& register_take_instance() {
-  static int ensure_custom_class_handler_registered = register_custom_class_handler();
   static auto instance_registry = torch::RegisterOperators().op(
       torch::RegisterOperators::options()
           .schema(
@@ -82,7 +82,6 @@ libtorch_sources = [
     "torch/csrc/jit/runtime/argument_spec.cpp",
     "torch/csrc/jit/ir/constants.cpp",
     "torch/csrc/jit/api/function_impl.cpp",
-    "torch/csrc/jit/api/custom_class.cpp",
     "torch/csrc/jit/ir/node_hashing.cpp",
     "torch/csrc/jit/ir/type_hashing.cpp",
     "torch/csrc/jit/serialization/export.cpp",
torch/csrc/jit/api/custom_class.cpp (deleted, 40 lines)
@@ -1,40 +0,0 @@
-#include <torch/custom_class.h>
-
-#include <atomic>
-
-namespace torch {
-namespace jit {
-
-std::vector<c10::RegisterOperators>& registeredOps() {
-  static std::vector<c10::RegisterOperators> ops;
-  return ops;
-}
-
-std::shared_ptr<script::CompilationUnit>& classCU() {
-  static std::shared_ptr<script::CompilationUnit> cu =
-      std::make_shared<script::CompilationUnit>();
-  return cu;
-}
-
-bool isCustomClass(const c10::IValue& v) {
-  return v.isObject() && v.toObject()->type()->name() &&
-      getCustomClass(v.toObject()->type()->name()->qualifiedName());
-}
-
-namespace {
-
-TypePtr realCustomClassHandler(const std::string& name) {
-  return classCU()->get_type(name);
-}
-
-} // namespace
-
-int register_custom_class_handler() {
-  setGetCustomClassFn(realCustomClassHandler);
-  return 0;
-};
-
-static int ensure_custom_class_handler_registered = register_custom_class_handler();
-
-} // namespace jit
-} // namespace torch
torch/csrc/jit/api/custom_class.h (deleted, 24 lines)
@@ -1,24 +0,0 @@
-#pragma once
-
-#include <ATen/core/ivalue.h>
-#include <ATen/core/jit_type.h>
-
-namespace torch {
-namespace jit {
-
-TORCH_API at::TypePtr getCustomClass(const std::string& name);
-
-TORCH_API bool isCustomClass(const c10::IValue& v);
-
-using GetCustomClassFnType = at::TypePtr (*)(const std::string&);
-// Use this to set the function for retrieving custom classes
-//
-// This is necessary because the custom classes implementation
-// is not in ATen core, but the schema type parser is, which
-// can resolve custom classes as type expressions.
-TORCH_API void setGetCustomClassFn(GetCustomClassFnType fn);
-
-TORCH_API int register_custom_class_handler();
-
-} // namespace jit
-} // namespace torch
@@ -17,6 +17,10 @@ struct TORCH_API GraphFunction : public Function {
         graph_(std::move(graph)),
         function_creator_(std::move(function_creator)) {}

+  bool isGraphFunction() const override {
+    return true;
+  }
+
   void run(Stack& stack) override;

   void run(Stack&& stack) override;
@@ -2,10 +2,10 @@
 #include <ATen/core/interned_strings.h>
 #include <ATen/core/jit_type.h>
 #include <c10/util/string_utils.h>
-#include <torch/csrc/jit/api/custom_class.h>
 #include <torch/csrc/jit/frontend/lexer.h>
 #include <torch/csrc/jit/frontend/parse_string_literal.h>
 #include <torch/csrc/jit/frontend/schema_type_parser.h>
+#include <torch/custom_class.h>
 #include <string>

 using c10::AliasInfo;
@@ -1,8 +1,9 @@
 #include <torch/csrc/jit/ir/ir.h>

+#include <ATen/core/builtin_function.h>
+#include <ATen/core/function.h>
 #include <c10/util/Exception.h>
 #include <c10/util/StringUtil.h>
-#include <ATen/core/function.h>
 #include <torch/csrc/jit/api/function_impl.h>
 #include <torch/csrc/jit/frontend/error_report.h>
 #include <torch/csrc/jit/frontend/schema_matching.h>
@@ -1822,6 +1823,7 @@ at::ArrayRef<Value*> createTupleUnpack(Value* v) {

 std::vector<Value*> inlineCallTo(Node* to_replace, Function* callee) {
   WithInsertPoint guard(to_replace);
+  TORCH_INTERNAL_ASSERT(callee->isGraphFunction());
   std::unordered_map<Value*, Value*> value_map;
   auto new_outputs = insertGraph(
       *to_replace->owningGraph(),
@@ -31,6 +31,9 @@ void inlineCalls(Block* block) {
         const std::string& name = cur->s(attr::name);
         if (auto class_type = cur->input(0)->type()->cast<ClassType>()) {
           auto function = class_type->getMethod(name);
+          if (!function->isGraphFunction()) {
+            continue;
+          }
           GRAPH_UPDATE("Inlining method '", function->name(), "' to ", *cur);
           GRAPH_UPDATE("Function body: ", *function->optimized_graph());
           inlineCallTo(cur, function);
@@ -29,17 +29,11 @@ void initPythonCustomClassBindings(PyObject* module) {
   // directly, we need a wrapper that at least returns the instance
   // rather than the None return value from __init__
   m.def("_get_custom_class_python_wrapper", [](const std::string& qualname) {
-    auto cu = classCU();
     std::string full_qualname = "__torch__.torch.classes." + qualname;
-    c10::NamedTypePtr named_type = cu->get_type(full_qualname);
-    if (!named_type || !named_type->cast<ClassType>()) {
-      std::stringstream err;
-      err << "Class " << qualname << " not registered!";
-      throw std::runtime_error(err.str());
-    }
+    auto named_type = getCustomClass(full_qualname);
     c10::ClassTypePtr class_type = named_type->cast<ClassType>();
-    return ScriptClass(
-        c10::StrongTypePtr(std::move(cu), std::move(class_type)));
+    return ScriptClass(c10::StrongTypePtr(
+        std::shared_ptr<script::CompilationUnit>(), std::move(class_type)));
   });
 }
@@ -779,7 +779,7 @@ void initJitScriptBindings(PyObject* module) {
           py::object state;
           std::string qualname;
           std::tie(state, qualname) = state_tup;
-          auto class_type = classCU()->get_class(qualname);
+          auto class_type = getCustomClass(qualname);
           TORCH_CHECK(
               class_type,
               "Tried to deserialize class ",
@@ -789,7 +789,10 @@ void initJitScriptBindings(PyObject* module) {
               "sure the appropriate code is linked.");

           auto self = script::Object(c10::ivalue::Object::create(
-              c10::StrongTypePtr(classCU(), class_type), 1));
+              c10::StrongTypePtr(
+                  std::shared_ptr<torch::jit::script::CompilationUnit>(),
+                  class_type),
+              1));
           if (auto setstate_method = self.find_method("__setstate__")) {
             auto setstate_schema = setstate_method->function().getSchema();
             TORCH_INTERNAL_ASSERT(
@@ -965,6 +965,32 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
     }
   }

+  void runBuiltinFunction(Stack &stack, Function *fn, ActiveFrame *af) {
+    // BuiltinOpFunction directly invokes a void(Stack&) to implement
+    // custom C++ classes. Call run() here with the stack, and we will
+    // get the results from that C++ method back in the stack. Advance
+    // the PC by 1 without adding any new frame.
+    fn->run(stack);
+    ++af->pc;
+  }
+
+  void runGraphFunction(Stack &stack, Function *fn, ActiveFrame *af) {
+    const Code& code =
+        // consider passing
+        // `frames.back().function->remaining_bailout_depth_` into
+        // `get_executor().getPlanFor()` to propagate caller's depth
+        // restrictions onto children while this strategy has a
+        // potential to reduce the number of compilations for too
+        // dynamic callers we might miss opportunities where a caller is
+        // dynamic but a callee gets stable arguments
+        fn->get_executor()
+            .getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts())
+            .code;
+    frames.back().pc = af->pc + 1;
+    enterFrame(code, stack.size() - code.num_inputs());
+    *af = ActiveFrame(frames.back());
+  }
+
   bool runImpl(Stack& stack) {
     // if we have never run before, then we might have to return the
     // stack when we suspend, record where it starts so we return the right
@@ -1062,21 +1088,12 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
           }
         } break;
         case CALL: {
-          const Code& code =
-              // consider passing
-              // `frames.back().function->remaining_bailout_depth_` into
-              // `get_executor().getPlanFor()` to propagate caller's depth
-              // restrictions onto children while this strategy has a
-              // potential to reduce the number of compilations for too
-              // dynamic callers we might miss opportunities where a caller is
-              // dynamic but a callee gets stable arguments
-              af.functions[inst.X]
-                  ->get_executor()
-                  .getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts())
-                  .code;
-          frames.back().pc = af.pc + 1;
-          enterFrame(code, stack.size() - code.num_inputs());
-          af = ActiveFrame(frames.back());
+          Function* fn = af.functions[inst.X];
+          if (!fn->isGraphFunction()) {
+            runBuiltinFunction(stack, fn, &af);
+          } else {
+            runGraphFunction(stack, fn, &af);
+          }
         } break;
         case INTERFACE_CALL: {
           // note the hash table lookup to find the function
@@ -1095,13 +1112,11 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
                                   .toObject()
                                   ->type()
                                   ->getMethod(af.constants[inst.X].toStringRef());
-          const Code& code =
-              function->get_executor()
-                  .getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts())
-                  .code;
-          frames.back().pc = af.pc + 1;
-          enterFrame(code, stack.size() - inst.N);
-          af = ActiveFrame(frames.back());
+          if (!function->isGraphFunction()) {
+            runBuiltinFunction(stack, function, &af);
+          } else {
+            runGraphFunction(stack, function, &af);
+          }
         } break;
         case RET:
           if (frames.size() > 1) {
@@ -1,11 +1,11 @@
 #include "import_source.h"

 #include <ATen/core/qualified_name.h>
-#include <torch/csrc/jit/api/custom_class.h>
-#include <torch/csrc/jit/serialization/export.h>
 #include <torch/csrc/jit/frontend/parser.h>
 #include <torch/csrc/jit/frontend/resolver.h>
 #include <torch/csrc/jit/frontend/script_type_parser.h>
+#include <torch/csrc/jit/serialization/export.h>
+#include <torch/custom_class.h>

 namespace torch {
 namespace jit {
@@ -1147,6 +1147,7 @@ struct PythonPrintImpl {
   void printFunction(
       const Function& func,
       bool print_first_argument_type = true) {
+    TORCH_INTERNAL_ASSERT(func.isGraphFunction());
     const FunctionSchema& schema = func.getSchema();
     Graph& graph = *func.graph();
     used_names_.clear(); // each graph can reuse local names
@@ -1192,6 +1193,16 @@ struct PythonPrintImpl {
         enforce_importable_(enforce_importable) {}

   void printClass(const ClassTypePtr& classType) {
+    // If any of the methods are not Graph funtions, this indicates that
+    // this class is a custom-bound C++ class. Skip serialization
+    // of this class, we will depend on the ClassType being defined
+    // in the target process.
+    for (auto& method : classType->methods()) {
+      if (!method->isGraphFunction()) {
+        return;
+      }
+    }
+
     bool is_module = classType->is_module();
     body_ << "class " << classType->name()->name();
     if (is_module) {
@@ -1,28 +1,31 @@

 #pragma once

+#include <ATen/core/builtin_function.h>
 #include <ATen/core/function_schema.h>
 #include <ATen/core/ivalue.h>
 #include <ATen/core/jit_type.h>
+#include <ATen/core/op_registration/infer_schema.h>
 #include <ATen/core/op_registration/op_registration.h>
 #include <ATen/core/stack.h>
 #include <c10/util/C++17.h>
 #include <c10/util/Metaprogramming.h>
 #include <c10/util/TypeList.h>
 #include <c10/util/TypeTraits.h>
-#include <torch/csrc/jit/api/custom_class.h>
-#include <torch/csrc/jit/runtime/operator.h>
-#include <torch/csrc/jit/api/compilation_unit.h>
-#include <torch/csrc/jit/frontend/tracer.h>
 #include <torch/csrc/utils/variadic.h>
 #include <torch/custom_class_detail.h>
 #include <iostream>
 #include <sstream>

 namespace torch {
 namespace jit {

+namespace script {
+struct CompilationUnit;
+}
+
+TORCH_API at::ClassTypePtr getCustomClass(const std::string& name);
+
+TORCH_API bool isCustomClass(const c10::IValue& v);
+
 template <class... Types>
 detail::types<void, Types...> init() {
   return detail::types<void, Types...>{};
@@ -47,7 +50,7 @@ class class_ {

   std::string className;
   std::string qualClassName;
-  ClassTypePtr classTypePtr;
+  at::ClassTypePtr classTypePtr;

   const std::string parentModule = "classes";
   const std::string topModule = "__torch__.torch";
@@ -58,16 +61,17 @@ class class_ {

     // We currently represent custom classes as torchscript classes with a
     // capsule attribute
-    classTypePtr =
-        ClassType::create(c10::QualifiedName(qualClassName), classCU());
-    classTypePtr->addAttribute("capsule", CapsuleType::get());
+    classTypePtr = at::ClassType::create(
+        c10::QualifiedName(qualClassName),
+        std::weak_ptr<script::CompilationUnit>());
+    classTypePtr->addAttribute("capsule", at::CapsuleType::get());

-    c10::getCustomClassTypeMap().insert({typeid(c10::intrusive_ptr<CurClass>).name(),
-                              c10::StrongTypePtr(classCU(), classTypePtr)});
-    c10::getCustomClassTypeMap().insert({typeid(c10::tagged_capsule<CurClass>).name(),
-                              c10::StrongTypePtr(classCU(), classTypePtr)});
+    c10::getCustomClassTypeMap().insert(
+        {typeid(c10::intrusive_ptr<CurClass>).name(), classTypePtr});
+    c10::getCustomClassTypeMap().insert(
+        {typeid(c10::tagged_capsule<CurClass>).name(), classTypePtr});

-    classCU()->register_type(classTypePtr);
+    registerCustomClass(classTypePtr);
   }

   template <typename... Types>
@@ -163,40 +167,26 @@ class class_ {
  private:
   template <typename Func>
   void defineMethod(std::string name, Func func) {
-    auto graph = std::make_shared<Graph>();
-    auto qualFuncName = className + "::" + name;
-    ensure_c10_registerer_defined();
-    registeredOps().push_back(
-        torch::RegisterOperators().op(qualFuncName, std::move(func)));
-    auto func_symbol = c10::Symbol::fromQualString(qualFuncName);
-    auto ops = torch::jit::getAllOperatorsFor(func_symbol);
-    TORCH_CHECK(ops.size() == 1);
-    auto &schema = ops[0]->schema();
+    auto qualMethodName = qualClassName + "." + name;
+    auto schema = c10::inferFunctionSchemaSingleReturn<Func>(std::move(name), "");

-    for (const auto& arg : schema.arguments()) {
-      graph->addInput()->setType(arg.type());
-    }
+    auto wrapped_func = [func = std::move(func)](Stack& stack) mutable -> void {
+      // TODO: we need to figure out how to profile calls to custom functions
+      // like this! Currently can't do it because the profiler stuff is in
+      // libtorch and not ATen
+      using RetType =
+          typename c10::guts::infer_function_traits_t<Func>::return_type;
+      detail::BoxedProxy<RetType, Func>()(stack, func);
+    };
+    auto method = std::make_shared<BuiltinOpFunction>(
+        qualMethodName, std::move(schema), std::move(wrapped_func));

-    auto opCall = graph->insertNode(graph->create(
-        func_symbol, graph->inputs(), schema.returns().size()));
-    Value* res;
-    if (schema.returns().size() > 1) {
-      const auto& returns = schema.returns();
-      size_t op_invocation_idx = 0;
-      for (const auto& ret : returns) {
-        opCall->output(op_invocation_idx++)->setType(ret.type());
-      }
-      res = graph->insertNode(graph->createTuple(opCall->outputs()))->output();
-    } else if (schema.returns().size() == 1) {
-      const auto& returns = schema.returns();
-      res = opCall->output()->setType(returns[0].type());
-    } else {
-      res = graph->insertConstant(IValue())->setType(NoneType::get());
-    }
-    graph->registerOutput(res);
-
-    auto method = classCU()->create_function(qualClassName + "." + name, graph);
-    classTypePtr->addMethod(method);
+    // Register the method here to keep the Method alive.
+    // ClassTypes do not hold ownership of their methods (normally it
+    // those are held by the CompilationUnit), so we need a proxy for
+    // that behavior here.
+    registerCustomClassMethod(method);
+    classTypePtr->addMethod(method.get());
   }
 };
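Condensing the new defineMethod above: a method definition no longer builds a Graph or registers a c10 operator; it becomes a schema plus a boxed callable. A hedged recap with an invented free function standing in for a bound method (this paraphrases the diff, it is not additional API):

// Invented free function standing in for a bound method.
int64_t mul(int64_t a, int64_t b) { return a * b; }

std::shared_ptr<torch::jit::BuiltinOpFunction> make_method(
    at::ClassTypePtr classTypePtr) {
  using Func = decltype(&mul);
  Func func = &mul;

  // 1. Infer a schema; a std::tuple return would stay a single Tuple value.
  auto schema = c10::inferFunctionSchemaSingleReturn<Func>("mul", "");

  // 2. Adapt the callable to the interpreter's void(Stack&) convention.
  auto wrapped = [func](torch::jit::Stack& stack) mutable {
    torch::jit::detail::BoxedProxy<int64_t, Func>()(stack, func);
  };

  // 3. Wrap it as a Function the JIT can call, and keep it alive in the
  //    method registry, since ClassType stores only a raw pointer.
  auto method = std::make_shared<torch::jit::BuiltinOpFunction>(
      c10::QualifiedName("__torch__.torch.classes.MyStack.mul"),
      std::move(schema), std::move(wrapped));
  torch::jit::registerCustomClassMethod(method);
  classTypePtr->addMethod(method.get());
  return method;
}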
@@ -1,5 +1,7 @@
 #pragma once

+#include <ATen/core/boxing/kernel_functor.h>
+#include <ATen/core/function.h>
 #include <c10/util/Metaprogramming.h>
 #include <c10/util/TypeTraits.h>

@@ -60,10 +62,68 @@ Func wrap_func(Func f) {
   return f;
 }

+template <
+    class Functor,
+    bool AllowDeprecatedTypes,
+    size_t... ivalue_arg_indices>
+typename c10::guts::infer_function_traits_t<Functor>::return_type
+call_torchbind_method_from_stack(
+    Functor& functor,
+    Stack& stack,
+    std::index_sequence<ivalue_arg_indices...>) {
+  (void)(stack); // when sizeof...(ivalue_arg_indices) == 0, this argument would
+                 // be unused and we have to silence the compiler warning.
+
+  constexpr size_t num_ivalue_args = sizeof...(ivalue_arg_indices);
+
+  using IValueArgTypes =
+      typename c10::guts::infer_function_traits_t<Functor>::parameter_types;
+  return (functor)(c10::detail::ivalue_to_arg<
+                   std::remove_cv_t<std::remove_reference_t<
+                       c10::guts::typelist::
+                           element_t<ivalue_arg_indices, IValueArgTypes>>>,
+                   AllowDeprecatedTypes>(std::move(
+      torch::jit::peek(stack, ivalue_arg_indices, num_ivalue_args)))...);
+}
+
+template <class Functor, bool AllowDeprecatedTypes>
+typename c10::guts::infer_function_traits_t<Functor>::return_type
+call_torchbind_method_from_stack(Functor& functor, Stack& stack) {
+  constexpr size_t num_ivalue_args =
+      c10::guts::infer_function_traits_t<Functor>::number_of_parameters;
+  return call_torchbind_method_from_stack<Functor, AllowDeprecatedTypes>(
+      functor, stack, std::make_index_sequence<num_ivalue_args>());
+}
+
+template <class RetType, class Func>
+struct BoxedProxy;
+
+template <class RetType, class Func>
+struct BoxedProxy {
+  void operator()(Stack& stack, Func& func) {
+    auto retval = call_torchbind_method_from_stack<Func, false>(func, stack);
+    constexpr size_t num_ivalue_args =
+        c10::guts::infer_function_traits_t<Func>::number_of_parameters;
+    torch::jit::drop(stack, num_ivalue_args);
+    stack.emplace_back(c10::ivalue::from(std::move(retval)));
+  }
+};
+
+template <class Func>
+struct BoxedProxy<void, Func> {
+  void operator()(Stack& stack, Func& func) {
+    call_torchbind_method_from_stack<Func, false>(func, stack);
+    constexpr size_t num_ivalue_args =
+        c10::guts::infer_function_traits_t<Func>::number_of_parameters;
+    torch::jit::drop(stack, num_ivalue_args);
+    stack.emplace_back(IValue());
+  }
+};
+
 } // namespace detail

-TORCH_API std::vector<c10::RegisterOperators>& registeredOps();
-TORCH_API std::shared_ptr<script::CompilationUnit>& classCU();
+TORCH_API void registerCustomClass(at::ClassTypePtr class_type);
+TORCH_API void registerCustomClassMethod(std::shared_ptr<Function> method);

 } // namespace jit
 } // namespace torch
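Finally, the boxing machinery above can be exercised on its own. A hedged sketch (invented lambda, not from the PR) showing how BoxedProxy pops arguments from the interpreter stack and pushes the boxed result back:

void boxed_proxy_demo() {
  auto add = [](int64_t a, int64_t b) { return a + b; };

  // Two boxed arguments, as the interpreter would leave them.
  torch::jit::Stack stack{c10::IValue(int64_t(2)), c10::IValue(int64_t(3))};

  // Unboxes stack[0] and stack[1] via ivalue_to_arg, calls the lambda,
  // drops the two inputs, then boxes the int64_t result with ivalue::from.
  torch::jit::detail::BoxedProxy<int64_t, decltype(add)>()(stack, add);

  TORCH_INTERNAL_ASSERT(stack.size() == 1 && stack.back().toInt() == 5);
}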