[JIT] Virtualize Function (#33921)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33921

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.intern.facebook.com/intern/diff/D20153092/)!

Test Plan: Imported from OSS

Differential Revision: D20177227

Pulled By: jamesr66a

fbshipit-source-id: 87f3e484c4f873d60f76f50f6789c1b4a73bdfde
committed by Facebook Github Bot
parent bb1114258c
commit 60e8615a6d
aten/src/ATen/core/function.h (new file, 62 lines)
@@ -0,0 +1,62 @@
#pragma once
#include <ATen/core/function_schema.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/qualified_name.h>
#include <mutex>

namespace c10 {
struct FunctionSchema;
};

namespace torch {
namespace jit {

struct Graph;
struct GraphExecutor;

using Stack = std::vector<at::IValue>;
using Kwargs = std::unordered_map<std::string, at::IValue>;
struct RecursiveMethodCallError : public std::exception {};

TORCH_API void preoptimizeGraph(std::shared_ptr<Graph>& graph);

// A Function is a pure Graph with no implicit `self` object bound.
// It contains schema information, and the executor that manages the
// execution of the function. script::Method is a wrapper around a
// underlying Function that also provides a `self` object.
struct TORCH_API Function {
  virtual void run(Stack& stack) = 0;

  virtual void run(Stack&& stack) = 0;

  virtual at::IValue operator()(
      std::vector<at::IValue> stack,
      const Kwargs& kwargs = Kwargs()) = 0;

  virtual const c10::QualifiedName& qualname() const = 0;

  virtual const std::string& name() const = 0;

  // if this isn't yet defined, run its method_creator function
  virtual void ensure_defined() = 0;

  virtual std::shared_ptr<Graph> graph() const = 0;

  virtual std::shared_ptr<Graph> optimized_graph() const = 0;

  virtual GraphExecutor& get_executor() = 0;

  virtual const c10::FunctionSchema& getSchema() const = 0;

  virtual size_t num_inputs() const = 0;

  virtual void check_single_output() = 0;

  virtual std::string pretty_print_schema() const = 0;

  virtual Function& setSchema(c10::FunctionSchema schema) = 0;

  virtual ~Function() {}
};
} // namespace jit
} // namespace torch
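The header above turns `Function` into a pure-virtual interface: callers hold a `Function*` (or reference) and dispatch through the vtable, so ATen core no longer needs to know that the concrete implementation executes a TorchScript graph. A minimal, self-contained sketch of that pattern follows (standalone C++, not the real PyTorch types; `MiniFunction`, `MiniGraphFunction`, and the `int`-based `Stack` are all stand-ins):

// Hypothetical mini-version of the virtualized Function interface.
#include <iostream>
#include <memory>
#include <string>
#include <vector>

using Stack = std::vector<int>; // stand-in for std::vector<at::IValue>

struct MiniFunction {
  virtual void run(Stack& stack) = 0;
  virtual const std::string& name() const = 0;
  virtual ~MiniFunction() = default;
};

// A concrete implementation callers never need to see directly.
struct MiniGraphFunction : MiniFunction {
  explicit MiniGraphFunction(std::string name) : name_(std::move(name)) {}
  void run(Stack& stack) override {
    // Pretend to "execute a graph": sum the inputs into a single output.
    int sum = 0;
    for (int v : stack) sum += v;
    stack.assign({sum});
  }
  const std::string& name() const override { return name_; }
 private:
  std::string name_;
};

int main() {
  std::unique_ptr<MiniFunction> fn =
      std::make_unique<MiniGraphFunction>("add_all");
  Stack stack = {1, 2, 3};
  fn->run(stack); // virtual dispatch; the caller only knows the interface
  std::cout << fn->name() << " -> " << stack.front() << "\n"; // add_all -> 6
}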
@@ -17,7 +17,6 @@
 struct ClassType;
 namespace torch {
 namespace jit {
 struct Function;
 namespace script {
 struct CompilationUnit;
 }
@@ -828,7 +827,6 @@ struct CAFFE2_API RRefType
 };
 
 
-using ::torch::jit::Function;
 struct NamedType;
 using NamedTypePtr = std::shared_ptr<NamedType>;
 
@@ -1060,9 +1058,8 @@ struct CAFFE2_API StringType : public Type {
 
 struct FunctionType;
 using FunctionTypePtr = std::shared_ptr<FunctionType>;
-using ::torch::jit::Function;
 struct CAFFE2_API FunctionType : public NamedType {
-  static FunctionTypePtr create(Function* function) {
+  static FunctionTypePtr create(torch::jit::Function* function) {
     return FunctionTypePtr(
         new FunctionType(function)); // NOLINT(modernize-make-shared)
   }
@@ -1079,14 +1076,14 @@ struct CAFFE2_API FunctionType : public NamedType {
   std::string python_str() const override {
     return "Function";
   }
-  Function* function() const {
+  torch::jit::Function* function() const {
     return function_;
   }
   static const TypeKind Kind = TypeKind::FunctionType;
 
 private:
-  FunctionType(Function* function);
-  Function* function_;
+  FunctionType(torch::jit::Function* function);
+  torch::jit::Function* function_;
 };
 
 struct NoneType;
@@ -1516,7 +1513,7 @@ struct CAFFE2_API ClassType : public NamedType {
     return n.qualifiedName();
   }
 
-  const std::vector<Function*>& methods() const;
+  const std::vector<torch::jit::Function*>& methods() const;
 
   TypePtr findAttribute(const std::string& name) const {
     TORCH_INTERNAL_ASSERT(attributeNames_.size() == attributeTypes_.size());
@@ -1738,8 +1735,8 @@ struct CAFFE2_API ClassType : public NamedType {
     return parameterSlots_->at(slot);
   }
 
-  void addMethod(Function* method);
-  Function* getMethod(const std::string& name) const;
+  void addMethod(torch::jit::Function* method);
+  torch::jit::Function* getMethod(const std::string& name) const;
 
   // [Internal Only] Remove method from the ClassType
   // caller is responsible to make sure the modification is safe:
@@ -1791,14 +1788,13 @@ struct CAFFE2_API ClassType : public NamedType {
   std::shared_ptr<std::vector<bool>> parameterSlots_;
 
   // List of methods associated with this class.
-  std::vector<Function*> methods_;
+  std::vector<torch::jit::Function*> methods_;
 
 };
 
 struct InterfaceType;
 using InterfaceTypePtr = std::shared_ptr<InterfaceType>;
 using ::torch::jit::script::CompilationUnit;
-using ::torch::jit::Function;
 
 // Interfaces are a list of abstract methods that a class might meet.
 // If a class provides those methods, it implicitly meets the interface.
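These hunks drop the `using ::torch::jit::Function;` declarations and spell out `torch::jit::Function*` everywhere instead. The key point is that c10 only stores and passes the pointer, so a forward declaration is enough; the class definition stays in the higher-level JIT library. A standalone sketch of that layering, under assumed names (`c10_sketch::FunctionType` is illustrative, not the real class):

// Forward declaration is all the lower layer needs for an opaque pointer.
#include <memory>

namespace torch { namespace jit { struct Function; } } // never defined here

namespace c10_sketch {
struct FunctionType {
  static std::shared_ptr<FunctionType> create(torch::jit::Function* function) {
    return std::shared_ptr<FunctionType>(new FunctionType(function));
  }
  torch::jit::Function* function() const { return function_; }
 private:
  explicit FunctionType(torch::jit::Function* function) : function_(function) {}
  torch::jit::Function* function_; // opaque: never dereferenced in this TU
};
} // namespace c10_sketch

int main() {
  // A null pointer suffices to show this compiles without Function's
  // definition; real code would pass a live torch::jit::Function*.
  auto t = c10_sketch::FunctionType::create(nullptr);
  return t->function() == nullptr ? 0 : 1;
}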
@@ -4,6 +4,7 @@
 #include <ATen/core/jit_type.h>
 #include <c10/macros/Macros.h>
 #include <ATen/core/grad_mode.h>
+#include <ATen/core/function.h>
 #include <iostream>
 
 namespace c10 {
@@ -685,6 +686,101 @@ InterfaceTypePtr InterfaceType::create(QualifiedName qualifiedName, bool is_module
       new InterfaceType(std::move(qualifiedName), is_module));
 }
 
+void ClassType::addMethod(torch::jit::Function* method) {
+  TORCH_CHECK(
+      getMethod(method->name()) == nullptr,
+      "Can't redefine method: ",
+      method->name(),
+      " on class: ",
+      python_str());
+  methods_.push_back(method);
+}
+
+torch::jit::Function* ClassType::getMethod(const std::string& name) const {
+  for (auto method : methods_) {
+    if (name == method->name()) {
+      return method;
+    }
+  }
+  return nullptr;
+}
+
+void ClassType::unsafeRemoveMethod(const std::string& name) {
+  size_t slot = 0;
+  for (auto method : methods_) {
+    if (method->name() == name) {
+      methods_.erase(methods_.begin() + slot);
+      return;
+    }
+    slot++;
+  }
+  TORCH_CHECK(
+      false,
+      "Can't delete undefined method ",
+      name,
+      " on class: ",
+      python_str());
+}
+
+ClassTypePtr ClassType::refine(at::ArrayRef<TypePtr> refined_slots) const {
+  auto ptr = ClassType::create(name(), compilation_unit_);
+  AT_ASSERT(numAttributes() == refined_slots.size());
+  for (size_t i = 0; i < attributeNames_.size(); ++i) {
+    AT_ASSERT(refined_slots[i]->isSubtypeOf(attributeTypes_[i]));
+    ptr->addAttribute(attributeNames_[i], refined_slots[i]);
+  }
+  // Copy methods over
+  for (const auto& method : methods()) {
+    ptr->addMethod(method);
+  }
+  return ptr;
+}
+
+bool ClassType::isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const {
+  // to improve performance, this check can be cached
+  if (auto iface = rhs->cast<InterfaceType>()) {
+    // ClassType is not a subtype of InterfaceType if the InterfaceType is a
+    // Module Interface Type but the Class Type is not a Module Class Type
+    if (!is_module() && iface->is_module()) {
+      if (why_not) {
+        *why_not << "Class '" << python_str() << "' is not a subtype of "
+                 << "the module interface '" << rhs->python_str()
+                 << "' , only ScriptModule class can be subtype of module"
+                 << " interface.\n";
+      }
+      return false;
+    }
+    for (const FunctionSchema& schema : iface->methods()) {
+      auto self_method = getMethod(schema.name());
+      if (!self_method) {
+        if (why_not) {
+          *why_not << "Class '" << python_str() << "' does not have method '"
+                   << schema.name() << "' but '" << rhs->python_str()
+                   << "' does.\n";
+        }
+        return false;
+      }
+      if (!self_method->getSchema().isSubtypeOf(
+              schema, /*is_method=*/true, why_not)) {
+        if (why_not) {
+          *why_not << "Method on class '" << python_str()
+                   << "' (1) is not compatible with interface '"
+                   << rhs->python_str() << "' (2)\n"
+                   << "  (1) " << self_method->getSchema() << "\n"
+                   << "  (2) " << schema << "\n";
+        }
+        return false;
+      }
+    }
+    return true;
+  }
+  return Type::isSubtypeOfExt(rhs, why_not);
+}
+
+FunctionType::FunctionType(torch::jit::Function* function)
+    : NamedType(TypeKind::FunctionType, function->qualname()),
+      function_(function) {}
+
 bool InterfaceType::isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const {
   // to improve performance this check can be cached
   if (auto iface = rhs->cast<InterfaceType>()) {
@@ -740,7 +836,26 @@ InterfaceType::InterfaceType(QualifiedName name, bool is_module)
 
 InterfaceType::~InterfaceType() = default;
 
-const std::vector<Function*>& ClassType::methods() const {
+ClassTypePtr ClassType::create(
+    c10::optional<QualifiedName> qualifiedName,
+    std::weak_ptr<CompilationUnit> cu,
+    bool is_module) {
+  return ClassTypePtr(
+      new ClassType(std::move(qualifiedName), std::move(cu), is_module));
+}
+
+ClassType::ClassType(
+    c10::optional<QualifiedName> name,
+    std::weak_ptr<CompilationUnit> cu,
+    bool is_module)
+    : NamedType(TypeKind::ClassType, std::move(name)),
+      compilation_unit_(std::move(cu)) {
+  if (is_module) {
+    parameterSlots_ = std::make_shared<std::vector<bool>>();
+  }
+}
+
+const std::vector<torch::jit::Function*>& ClassType::methods() const {
   return methods_;
 }
 
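`ClassType::isSubtypeOfExt` above implements structural subtyping: a class conforms to an interface exactly when it supplies every method the interface declares, with a compatible schema, and `why_not` collects a human-readable reason on failure. A reduced, self-contained sketch of the same shape (method "compatibility" is collapsed to name equality here, whereas the real code compares `FunctionSchema`s):

// Structural interface check, greatly simplified from the real one.
#include <algorithm>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

struct Iface { std::vector<std::string> methods; };
struct Cls   { std::vector<std::string> methods; };

bool isSubtypeOf(const Cls& cls, const Iface& iface, std::ostream* why_not) {
  for (const auto& m : iface.methods) {
    bool found = std::find(cls.methods.begin(), cls.methods.end(), m) !=
                 cls.methods.end();
    if (!found) {
      if (why_not) *why_not << "class lacks method '" << m << "'\n";
      return false; // the first missing method decides, as in the real check
    }
  }
  return true;
}

int main() {
  Cls c{{"forward", "reset"}};
  Iface i{{"forward", "reset", "extra"}};
  std::ostringstream why;
  std::cout << std::boolalpha << isSubtypeOf(c, i, &why) << "\n"; // false
  std::cout << why.str(); // class lacks method 'extra'
}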
@@ -449,7 +449,6 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
     ${TORCH_SRC_DIR}/csrc/jit/frontend/schema_matching.cpp
     ${TORCH_SRC_DIR}/csrc/jit/frontend/script_type_parser.cpp
     ${TORCH_SRC_DIR}/csrc/jit/frontend/sugared_value.cpp
-    ${TORCH_SRC_DIR}/csrc/jit/ir/class_type.cpp
     ${TORCH_SRC_DIR}/csrc/jit/frontend/parser.cpp
     ${TORCH_SRC_DIR}/csrc/jit/frontend/builtin_functions.cpp
     ${TORCH_SRC_DIR}/csrc/jit/frontend/canonicalize_modified_loop.cpp
@@ -469,7 +468,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
     ${TORCH_SRC_DIR}/csrc/jit/codegen/fuser/executor.cpp
     ${TORCH_SRC_DIR}/csrc/jit/codegen/fuser/codegen.cpp
     ${TORCH_SRC_DIR}/csrc/jit/codegen/fuser/fallback.cpp
-    ${TORCH_SRC_DIR}/csrc/jit/api/function.cpp
+    ${TORCH_SRC_DIR}/csrc/jit/api/function_impl.cpp
     ${TORCH_SRC_DIR}/csrc/jit/runtime/vararg_functions.cpp
 
     ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/mem_arena.cpp
@@ -81,6 +81,7 @@ libtorch_sources = [
     "torch/csrc/jit/ir/attributes.cpp",
     "torch/csrc/jit/runtime/argument_spec.cpp",
     "torch/csrc/jit/ir/constants.cpp",
+    "torch/csrc/jit/api/function_impl.cpp",
     "torch/csrc/jit/api/custom_class.cpp",
     "torch/csrc/jit/ir/node_hashing.cpp",
     "torch/csrc/jit/ir/type_hashing.cpp",
@@ -163,7 +164,6 @@ libtorch_sources = [
     "torch/csrc/jit/frontend/script_type_parser.cpp",
     "torch/csrc/jit/frontend/sugared_value.cpp",
     "torch/csrc/jit/frontend/schema_matching.cpp",
-    "torch/csrc/jit/ir/class_type.cpp",
     "torch/csrc/jit/frontend/parser.cpp",
     "torch/csrc/jit/runtime/jit_exception.cpp",
     "torch/csrc/jit/serialization/source_range_serialization.cpp",
@@ -183,7 +183,6 @@ libtorch_sources = [
     "torch/csrc/jit/codegen/fuser/fallback.cpp",
     "torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp",
     "torch/csrc/jit/codegen/fuser/interface.cpp",
-    "torch/csrc/jit/api/function.cpp",
    "torch/csrc/jit/runtime/vararg_functions.cpp",
     "torch/csrc/jit/python/update_graph_executor_opt.cpp",
     "torch/csrc/jit/mobile/function.cpp",
@@ -30,7 +30,7 @@ namespace {
 struct PythonTypeResolver : public jit::script::Resolver {
   std::shared_ptr<jit::script::SugaredValue> resolveValue(
       const std::string& /* unused */,
-      Function& /* unused */,
+      torch::jit::Function& /* unused */,
       const jit::SourceRange& /* unused */) override {
     TORCH_INTERNAL_ASSERT(
         false, "RPC Type resolver does not need to resolve value");
@@ -1,9 +1,10 @@
 #pragma once
 #include <c10/util/Exception.h>
-#include <torch/csrc/jit/api/function.h>
-#include <torch/csrc/jit/runtime/graph_executor.h>
-#include <torch/csrc/jit/ir/ir.h>
+#include <ATen/core/function.h>
+#include <torch/csrc/jit/api/function_impl.h>
 #include <torch/csrc/jit/frontend/source_range.h>
+#include <torch/csrc/jit/ir/ir.h>
+#include <torch/csrc/jit/runtime/graph_executor.h>
 
 #include <torch/csrc/WindowsTorchApiMacro.h>
 #include <torch/csrc/utils/memory.h>
@@ -117,7 +118,7 @@ struct TORCH_API CompilationUnit {
     if (shouldMangle) {
       name = mangle(name);
     }
-    auto fn = torch::make_unique<Function>(
+    auto fn = torch::make_unique<GraphFunction>(
         std::move(name), std::move(graph), nullptr);
     auto ret = fn.get();
     register_function(std::move(fn));
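`create_function` above now constructs the concrete `GraphFunction` but registers and returns it through the abstract `Function` side of the hierarchy. A small standalone sketch of that ownership pattern (the names `Fn`, `GraphFn`, and `Registry` are hypothetical): build the derived type, keep a raw non-owning pointer for the caller, and let `std::unique_ptr<Derived>` convert to `std::unique_ptr<Base>` as it enters the registry.

#include <memory>
#include <string>
#include <vector>

struct Fn {
  virtual const std::string& name() const = 0;
  virtual ~Fn() = default;
};

struct GraphFn : Fn {
  explicit GraphFn(std::string n) : n_(std::move(n)) {}
  const std::string& name() const override { return n_; }
 private:
  std::string n_;
};

struct Registry {
  Fn* create_function(std::string name) {
    auto fn = std::make_unique<GraphFn>(std::move(name)); // concrete type
    Fn* ret = fn.get();            // non-owning handle handed to the caller
    fns_.push_back(std::move(fn)); // unique_ptr<GraphFn> upcasts to unique_ptr<Fn>
    return ret;
  }
 private:
  std::vector<std::unique_ptr<Fn>> fns_; // registry owns every function
};

int main() {
  Registry cu;
  Fn* f = cu.create_function("forward");
  return f->name() == "forward" ? 0 : 1;
}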
@@ -1,4 +1,4 @@
-#include <torch/csrc/jit/api/function.h>
+#include <torch/csrc/jit/api/function_impl.h>
 #include <torch/csrc/jit/passes/inliner.h>
 
 #include <torch/csrc/jit/frontend/error_report.h>
@@ -6,9 +6,9 @@
 namespace torch {
 namespace jit {
 namespace {
-FunctionSchema defaultSchemaFor(const Function& function) {
-  std::vector<Argument> args;
-  std::vector<Argument> returns;
+c10::FunctionSchema defaultSchemaFor(const Function& function) {
+  std::vector<c10::Argument> args;
+  std::vector<c10::Argument> returns;
   Graph& g = *function.graph();
   size_t num_inputs = function.num_inputs();
   for (size_t i = 0; i < num_inputs; ++i) {
@@ -24,19 +24,19 @@ FunctionSchema defaultSchemaFor(const Function& function) {
 }
 } // namespace
 
-void placeholderCreator(Function&) {
+void placeholderCreator(GraphFunction&) {
   throw RecursiveMethodCallError();
 }
 
-void Function::run(Stack& stack) {
+void GraphFunction::run(Stack& stack) {
   get_executor().run(stack);
 }
 
-void Function::run(Stack&& stack) {
+void GraphFunction::run(Stack&& stack) {
   run(stack);
 }
 
-IValue Function::operator()(
+IValue GraphFunction::operator()(
     std::vector<IValue> stack,
     const Kwargs& kwargs) {
   getSchema().checkAndNormalizeInputs(stack, kwargs);
@@ -44,7 +44,7 @@ IValue GraphFunction::operator()(
   return stack.front();
 }
 
-void Function::ensure_defined() {
+void GraphFunction::ensure_defined() {
   if (function_creator_) {
     auto creator = function_creator_;
     function_creator_ = placeholderCreator;
@@ -54,9 +54,9 @@ void GraphFunction::ensure_defined() {
   check_single_output();
 }
 
-const FunctionSchema& Function::getSchema() const {
+const c10::FunctionSchema& GraphFunction::getSchema() const {
   if (schema_ == nullptr) {
-    schema_ = make_unique<FunctionSchema>(defaultSchemaFor(*this));
+    schema_ = std::make_unique<c10::FunctionSchema>(defaultSchemaFor(*this));
  }
  return *schema_;
}
@@ -1,43 +1,34 @@
 #pragma once
-#include <torch/csrc/jit/runtime/graph_executor.h>
 
+#include <ATen/core/function.h>
 #include <torch/csrc/jit/ir/ir.h>
+#include <torch/csrc/jit/runtime/graph_executor.h>
 #include <torch/csrc/utils/memory.h>
-#include <mutex>
 
 namespace torch {
 namespace jit {
 
-using Kwargs = std::unordered_map<std::string, IValue>;
-struct RecursiveMethodCallError : public std::exception {};
-
-TORCH_API void preoptimizeGraph(std::shared_ptr<Graph>& graph);
-
-// A Function is a pure Graph with no implicit `self` object bound.
-// It contains schema information, and the executor that manages the
-// execution of the function. script::Method is a wrapper around a
-// underlying Function that also provides a `self` object.
-struct TORCH_API Function {
-  Function(
+struct TORCH_API GraphFunction : public Function {
+  GraphFunction(
       c10::QualifiedName name,
       std::shared_ptr<Graph> graph,
-      std::function<void(Function&)> function_creator)
+      std::function<void(GraphFunction&)> function_creator)
      : name_(std::move(name)),
        graph_(std::move(graph)),
        function_creator_(std::move(function_creator)) {}
 
-  void run(Stack& stack);
+  void run(Stack& stack) override;
 
-  void run(Stack&& stack);
+  void run(Stack&& stack) override;
 
-  IValue operator()(
-      std::vector<IValue> stack,
-      const Kwargs& kwargs = Kwargs());
+  IValue operator()(std::vector<IValue> stack, const Kwargs& kwargs = Kwargs())
+      override;
 
-  std::shared_ptr<Graph> graph() const {
+  std::shared_ptr<Graph> graph() const override {
     return graph_;
   }
 
-  std::shared_ptr<Graph> optimized_graph() const {
+  std::shared_ptr<Graph> optimized_graph() const override {
     std::lock_guard<std::recursive_mutex> lock(compile_mutex);
     if (optimized_graph_) {
       return *optimized_graph_;
@@ -47,29 +38,29 @@ struct TORCH_API GraphFunction : public Function {
     return *optimized_graph_;
   }
 
-  const c10::QualifiedName& qualname() const {
+  const c10::QualifiedName& qualname() const override {
     return name_;
   }
 
-  const std::string& name() const {
+  const std::string& name() const override {
     return name_.name();
   }
 
   // if this isn't yet defined, run its method_creator function
-  void ensure_defined();
+  void ensure_defined() override;
 
-  size_t num_inputs() const {
+  size_t num_inputs() const override {
    return graph()->inputs().size();
  }
 
-  Function& setSchema(FunctionSchema schema) {
+  Function& setSchema(FunctionSchema schema) override {
     schema_ = make_unique<FunctionSchema>(std::move(schema));
     return *this;
   }
 
-  const FunctionSchema& getSchema() const;
+  const FunctionSchema& getSchema() const override;
 
-  std::string pretty_print_schema() const {
+  std::string pretty_print_schema() const override {
     AT_ASSERT(schema_);
     std::stringstream ss;
     ss << *schema_;
@@ -82,18 +73,18 @@ struct TORCH_API GraphFunction : public Function {
 
   bool is_optimized() const {
     AT_WARN(
-        "Function::is_optimized() is deprecated and always returns true. "
+        "GraphFunction::is_optimized() is deprecated and always returns true. "
         "Please use getGraphExecutorOptimize()");
     return true;
   }
 
-  void check_single_output() {
+  void check_single_output() override {
     TORCH_CHECK(
         graph()->outputs().size() == 1,
         "Method (but not graphs in general) require a single output. Use None/Tuple for 0 or 2+ outputs");
   }
 
-  GraphExecutor& get_executor() {
+  GraphExecutor& get_executor() override {
     ensure_defined();
     std::lock_guard<std::recursive_mutex> lock(compile_mutex);
     if (executor_) {
@@ -114,12 +105,11 @@ struct TORCH_API GraphFunction : public Function {
   // here.
   mutable c10::optional<std::shared_ptr<Graph>> optimized_graph_;
 
-  // Functions are invokable from multiple threads, so this lock needs to be
-  // held when we're initializing graph executor for the first time or computing
-  // the optimized graph.
-  // We're using reentrant mutex so that we don't need to worry about causing a
-  // deadlock by calling one method from another (e.g. optimized_graph() from
-  // get_executor()).
+  // GraphFunctions are invokable from multiple threads, so this lock needs to
+  // be held when we're initializing graph executor for the first time or
+  // computing the optimized graph. We're using reentrant mutex so that we don't
+  // need to worry about causing a deadlock by calling one method from another
+  // (e.g. optimized_graph() from get_executor()).
   mutable std::recursive_mutex compile_mutex;
 
   GraphExecutor executor_; // for execution
@@ -127,7 +117,7 @@ struct TORCH_API GraphFunction : public Function {
   // an optional function that actually creates the method when
   // ensure_defined() is called. This is used by the compiler so
   // that it can construct methods out of order
-  std::function<void(Function&)> function_creator_;
+  std::function<void(GraphFunction&)> function_creator_;
 
   // if absent, then we generate a default schema based on the graph
   // mutable because getSchema caches the default schema if one is requested
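The comment in this hunk explains the choice of `std::recursive_mutex`: `get_executor()` holds `compile_mutex` and then calls `optimized_graph()`, which locks the same mutex again on the same thread; with a plain `std::mutex` that second lock would deadlock. A standalone sketch of that re-entrant lazy-initialization pattern (strings stand in for the graph and executor; `LazyCompiled` is a hypothetical name):

#include <iostream>
#include <memory>
#include <mutex>
#include <string>

struct LazyCompiled {
  const std::string& optimized_graph() {
    std::lock_guard<std::recursive_mutex> lock(compile_mutex_);
    if (!graph_) graph_ = std::make_unique<std::string>("optimized-ir");
    return *graph_;
  }

  const std::string& get_executor() {
    std::lock_guard<std::recursive_mutex> lock(compile_mutex_);
    if (!executor_) {
      // Re-entrant: optimized_graph() locks compile_mutex_ a second time
      // on this thread, which a recursive_mutex permits.
      executor_ =
          std::make_unique<std::string>("executor(" + optimized_graph() + ")");
    }
    return *executor_;
  }

 private:
  std::recursive_mutex compile_mutex_;
  std::unique_ptr<std::string> graph_;
  std::unique_ptr<std::string> executor_;
};

int main() {
  LazyCompiled f;
  std::cout << f.get_executor() << "\n"; // executor(optimized-ir)
}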
@@ -1,7 +1,8 @@
 
 #include <ATen/core/ivalue.h>
 #include <ATen/core/stack.h>
-#include <torch/csrc/jit/api/function.h>
+#include <ATen/core/function.h>
+#include <torch/csrc/jit/api/function_impl.h>
 
 namespace torch {
 namespace jit {
@@ -1,5 +1,6 @@
 #pragma once
 
+#include <ATen/core/functional.h>
 #include <ATen/core/ivalue.h>
 #include <torch/csrc/jit/api/method.h>
 
@@ -94,7 +95,7 @@ struct TORCH_API Object {
   }
 
   const std::vector<Method> get_methods() const {
-    return fmap(type()->methods(), [&](Function* func) {
+    return c10::fmap(type()->methods(), [&](Function* func) {
       return Method(_ivalue(), func);
     });
   }
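`get_methods()` above maps each registered `Function*` into a `Method` wrapper via `c10::fmap`. A minimal stand-in for that transform (this local `fmap` is illustrative, not c10's implementation):

#include <iostream>
#include <string>
#include <vector>

// Apply f to each element and collect the results.
template <typename T, typename F>
auto fmap(const std::vector<T>& xs, F f) -> std::vector<decltype(f(xs[0]))> {
  std::vector<decltype(f(xs[0]))> out;
  out.reserve(xs.size());
  for (const auto& x : xs) out.push_back(f(x));
  return out;
}

struct Method { std::string name; };

int main() {
  std::vector<std::string> fns = {"forward", "reset"};
  auto methods = fmap(fns, [](const std::string& n) { return Method{n}; });
  for (const auto& m : methods) std::cout << m.name << "\n";
}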
@@ -1,10 +1,13 @@
-#include <torch/csrc/jit/frontend/ir_emitter.h>
 #include <c10/util/Exception.h>
 #include <c10/util/StringUtil.h>
-#include <torch/csrc/jit/testing/hooks_for_testing.h>
-#include <torch/csrc/jit/runtime/interpreter.h>
+#include <torch/csrc/jit/api/function_impl.h>
+#include <torch/csrc/jit/frontend/canonicalize_modified_loop.h>
+#include <torch/csrc/jit/frontend/convert_to_ssa.h>
+#include <torch/csrc/jit/frontend/ir_emitter.h>
+#include <torch/csrc/jit/frontend/parser.h>
+#include <torch/csrc/jit/frontend/schema_matching.h>
+#include <torch/csrc/jit/frontend/script_type_parser.h>
 #include <torch/csrc/jit/ir/ir.h>
-#include <torch/csrc/jit/runtime/operator.h>
 #include <torch/csrc/jit/passes/canonicalize.h>
 #include <torch/csrc/jit/passes/constant_pooling.h>
 #include <torch/csrc/jit/passes/constant_propagation.h>
@@ -13,11 +16,9 @@
 #include <torch/csrc/jit/passes/inliner.h>
 #include <torch/csrc/jit/passes/lift_closures.h>
 #include <torch/csrc/jit/passes/lower_tuples.h>
-#include <torch/csrc/jit/frontend/canonicalize_modified_loop.h>
-#include <torch/csrc/jit/frontend/convert_to_ssa.h>
-#include <torch/csrc/jit/frontend/parser.h>
-#include <torch/csrc/jit/frontend/schema_matching.h>
-#include <torch/csrc/jit/frontend/script_type_parser.h>
+#include <torch/csrc/jit/runtime/interpreter.h>
+#include <torch/csrc/jit/runtime/operator.h>
+#include <torch/csrc/jit/testing/hooks_for_testing.h>
 
 #include <torch/csrc/jit/ir/constants.h>
 
@@ -3528,7 +3529,7 @@ std::unique_ptr<Function> CompilationUnit::define(
       name = mangle(name);
     }
   }
-  auto fn = torch::make_unique<Function>(
+  auto fn = torch::make_unique<GraphFunction>(
      std::move(name), std::make_shared<Graph>(), creator);
   if (self) {
     // Register this as a method on `self`'s type
@@ -1,129 +0,0 @@
-#include <ATen/core/jit_type.h>
-#include <c10/macros/Macros.h>
-#include <torch/csrc/jit/api/module.h>
-
-namespace c10 {
-
-ClassTypePtr ClassType::create(
-    c10::optional<QualifiedName> qualifiedName,
-    std::weak_ptr<CompilationUnit> cu,
-    bool is_module) {
-  return ClassTypePtr(
-      new ClassType(std::move(qualifiedName), std::move(cu), is_module));
-}
-
-ClassType::ClassType(
-    c10::optional<QualifiedName> name,
-    std::weak_ptr<CompilationUnit> cu,
-    bool is_module)
-    : NamedType(TypeKind::ClassType, std::move(name)),
-      compilation_unit_(std::move(cu)) {
-  if (is_module) {
-    parameterSlots_ = std::make_shared<std::vector<bool>>();
-  }
-}
-
-void ClassType::addMethod(Function* method) {
-  TORCH_CHECK(
-      getMethod(method->name()) == nullptr,
-      "Can't redefine method: ",
-      method->name(),
-      " on class: ",
-      python_str());
-  methods_.push_back(method);
-}
-
-Function* ClassType::getMethod(const std::string& name) const {
-  for (auto method : methods_) {
-    if (name == method->name()) {
-      return method;
-    }
-  }
-  return nullptr;
-}
-
-void ClassType::unsafeRemoveMethod(const std::string& name) {
-  size_t slot = 0;
-  for (auto method : methods_) {
-    if (method->name() == name) {
-      methods_.erase(methods_.begin() + slot);
-      return;
-    }
-    slot++;
-  }
-  TORCH_CHECK(
-      false,
-      "Can't delete undefined method ",
-      name,
-      " on class: ",
-      python_str());
-}
-
-#ifndef USE_MOBILE_CLASSTYPE
-
-// This file exists because we need to reference module.h, which we can't from
-// c10. Sigh...
-FunctionType::FunctionType(Function* function)
-    : NamedType(TypeKind::FunctionType, function->qualname()),
-      function_(function) {}
-
-ClassTypePtr ClassType::refine(at::ArrayRef<TypePtr> refined_slots) const {
-  auto ptr = ClassType::create(name(), compilation_unit_);
-  AT_ASSERT(numAttributes() == refined_slots.size());
-  for (size_t i = 0; i < attributeNames_.size(); ++i) {
-    AT_ASSERT(refined_slots[i]->isSubtypeOf(attributeTypes_[i]));
-    ptr->addAttribute(attributeNames_[i], refined_slots[i]);
-  }
-  // Copy methods over
-  for (const auto& method : methods()) {
-    ptr->addMethod(method);
-  }
-  return ptr;
-}
-
-bool ClassType::isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const {
-  // to improve performance, this check can be cached
-  if (auto iface = rhs->cast<InterfaceType>()) {
-    // ClassType is not a subtype of InterfaceType if the InterfaceType is a
-    // Module Interface Type but the Class Type is not a Module Class Type
-    if (!is_module() && iface->is_module()) {
-      if (why_not) {
-        *why_not << "Class '" << python_str() << "' is not a subtype of "
-                 << "the module interface '" << rhs->python_str()
-                 << "' , only ScriptModule class can be subtype of module"
-                 << " interface.\n";
-      }
-      return false;
-    }
-    for (const FunctionSchema& schema : iface->methods()) {
-      auto self_method = getMethod(schema.name());
-      if (!self_method) {
-        if (why_not) {
-          *why_not << "Class '" << python_str() << "' does not have method '"
-                   << schema.name() << "' but '" << rhs->python_str()
-                   << "' does.\n";
-        }
-        return false;
-      }
-      if (!self_method->getSchema().isSubtypeOf(
-              schema, /*is_method=*/true, why_not)) {
-        if (why_not) {
-          *why_not << "Method on class '" << python_str()
-                   << "' (1) is not compatible with interface '"
-                   << rhs->python_str() << "' (2)\n"
-                   << "  (1) " << self_method->getSchema() << "\n"
-                   << "  (2) " << schema << "\n";
-        }
-        return false;
-      }
-    }
-    return true;
-  }
-  return Type::isSubtypeOfExt(rhs, why_not);
-}
-#else
-bool ClassType::isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const {
-  return Type::isSubtypeOfExt(rhs, why_not);
-}
-#endif
-} // namespace c10
@@ -2,12 +2,13 @@
 
 #include <c10/util/Exception.h>
 #include <c10/util/StringUtil.h>
+#include <ATen/core/function.h>
+#include <torch/csrc/jit/api/function_impl.h>
+#include <torch/csrc/jit/frontend/error_report.h>
+#include <torch/csrc/jit/frontend/schema_matching.h>
 #include <torch/csrc/jit/ir/constants.h>
-#include <torch/csrc/jit/api/function.h>
 #include <torch/csrc/jit/runtime/operator.h>
 #include <torch/csrc/jit/serialization/python_print.h>
-#include <torch/csrc/jit/frontend/schema_matching.h>
-#include <torch/csrc/jit/frontend/error_report.h>
 
 #include <algorithm>
 #include <iostream>
@@ -1,5 +1,5 @@
 #include <torch/csrc/jit/ir/scope.h>
-#include <torch/csrc/jit/api/function.h>
+#include <ATen/core/function.h>
 
 namespace torch {
 namespace jit {
@@ -6,11 +6,12 @@
 
 #include <c10/util/Exception.h>
 #include <c10/util/StringUtil.h>
-#include <torch/csrc/jit/api/function.h>
+#include <ATen/core/function.h>
+#include <torch/csrc/jit/api/function_impl.h>
+#include <torch/csrc/jit/frontend/error_report.h>
 #include <torch/csrc/jit/ir/ir.h>
 #include <torch/csrc/jit/jit_log.h>
 #include <torch/csrc/jit/serialization/python_print.h>
-#include <torch/csrc/jit/frontend/error_report.h>
 
 namespace torch {
 namespace jit {
@@ -76,7 +77,7 @@ bool is_enabled(const char *cfname, JitLoggingLevels level) {
 // we won't have access to an original function, so we have to construct
 // a dummy function to give to PythonPrint
 std::string log_function(const std::shared_ptr<torch::jit::Graph> &graph) {
-  torch::jit::Function func("source_dump", graph, nullptr);
+  torch::jit::GraphFunction func("source_dump", graph, nullptr);
   std::vector<at::Tensor> tensors;
   std::vector<c10::NamedTypePtr> deps;
   PythonPrint pp(tensors, deps, false);
@@ -1,4 +1,4 @@
-#include <torch/csrc/jit/api/function.h>
+#include <ATen/core/function.h>
 #include <torch/csrc/jit/ir/ir_views.h>
 #include <torch/csrc/jit/jit_log.h>
 #include <torch/csrc/jit/ir/alias_analysis.h>
@@ -7,16 +7,17 @@
 #include <torch/csrc/autograd/edge.h>
 #include <torch/csrc/autograd/grad_mode.h>
 #include <torch/csrc/autograd/variable.h>
+#include <torch/csrc/jit/api/compilation_unit.h>
+#include <torch/csrc/jit/api/function_impl.h>
 #include <torch/csrc/jit/ir/constants.h>
+#include <torch/csrc/jit/ir/ir.h>
+#include <torch/csrc/jit/jit_log.h>
+#include <torch/csrc/jit/passes/bailout_graph.h>
 #include <torch/csrc/jit/runtime/exception_message.h>
 #include <torch/csrc/jit/runtime/graph_executor.h>
 #include <torch/csrc/jit/runtime/instruction.h>
-#include <torch/csrc/jit/ir/ir.h>
-#include <torch/csrc/jit/jit_log.h>
-#include <torch/csrc/jit/runtime/operator.h>
-#include <torch/csrc/jit/passes/bailout_graph.h>
-#include <torch/csrc/jit/api/compilation_unit.h>
 #include <torch/csrc/jit/runtime/jit_exception.h>
+#include <torch/csrc/jit/runtime/operator.h>
 #include <torch/csrc/jit/runtime/vararg_functions.h>
 
 #include <exception>
@@ -654,13 +655,13 @@ struct CodeImpl {
     TORCH_INTERNAL_ASSERT(bailout_index >= 0);
 
     auto build_bailout_graph = [bailout_index,
-                               unoptimized_graph](Function &func) {
+                                unoptimized_graph](Function& func) {
 
       BuildBailOutGraphFrom(bailout_index, unoptimized_graph, func.graph());
     };
 
     auto empty_graph = std::make_shared<Graph>();
-    auto func = torch::make_unique<Function>(
+    auto func = torch::make_unique<GraphFunction>(
         "bailout", empty_graph, build_bailout_graph);
     function_table_.emplace_back(func.get());
     bailout_functions_.emplace_back(std::move(func));
@@ -3,9 +3,9 @@
 #ifdef USE_DISTRIBUTED
 #include <torch/csrc/distributed/rpc/rref_context.h>
 #endif
-#include <torch/csrc/jit/api/function.h>
-#include <torch/csrc/jit/serialization/pickler.h>
 #include <aten/src/ATen/quantized/Quantizer.h>
+#include <torch/csrc/jit/api/function_impl.h>
+#include <torch/csrc/jit/serialization/pickler.h>
 #include <string>
 
 namespace torch {
@@ -3,7 +3,7 @@
 #ifdef USE_DISTRIBUTED
 #include <torch/csrc/distributed/rpc/rref_context.h>
 #endif
-#include <torch/csrc/jit/api/function.h>
+#include <torch/csrc/jit/api/function_impl.h>
 #include <torch/csrc/jit/serialization/pickler.h>
 #include <string>
 #include "unpickler.h"