[JIT] Virtualize Function (#33921)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33921

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.intern.facebook.com/intern/diff/D20153092/)!

Test Plan: Imported from OSS

Differential Revision: D20177227

Pulled By: jamesr66a

fbshipit-source-id: 87f3e484c4f873d60f76f50f6789c1b4a73bdfde
This commit is contained in:
James Reed
2020-03-07 09:59:11 -08:00
committed by Facebook Github Bot
parent bb1114258c
commit 60e8615a6d
20 changed files with 269 additions and 230 deletions

View File

@ -0,0 +1,62 @@
#pragma once
#include <ATen/core/function_schema.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/qualified_name.h>
#include <mutex>
namespace c10 {
struct FunctionSchema;
};
namespace torch {
namespace jit {
struct Graph;
struct GraphExecutor;
using Stack = std::vector<at::IValue>;
using Kwargs = std::unordered_map<std::string, at::IValue>;
struct RecursiveMethodCallError : public std::exception {};
TORCH_API void preoptimizeGraph(std::shared_ptr<Graph>& graph);
// A Function is a pure Graph with no implicit `self` object bound.
// It contains schema information, and the executor that manages the
// execution of the function. script::Method is a wrapper around a
// underlying Function that also provides a `self` object.
struct TORCH_API Function {
virtual void run(Stack& stack) = 0;
virtual void run(Stack&& stack) = 0;
virtual at::IValue operator()(
std::vector<at::IValue> stack,
const Kwargs& kwargs = Kwargs()) = 0;
virtual const c10::QualifiedName& qualname() const = 0;
virtual const std::string& name() const = 0;
// if this isn't yet defined, run its method_creator function
virtual void ensure_defined() = 0;
virtual std::shared_ptr<Graph> graph() const = 0;
virtual std::shared_ptr<Graph> optimized_graph() const = 0;
virtual GraphExecutor& get_executor() = 0;
virtual const c10::FunctionSchema& getSchema() const = 0;
virtual size_t num_inputs() const = 0;
virtual void check_single_output() = 0;
virtual std::string pretty_print_schema() const = 0;
virtual Function& setSchema(c10::FunctionSchema schema) = 0;
virtual ~Function() {}
};
} // namespace jit
} // namespace torch

View File

@ -17,7 +17,6 @@
struct ClassType;
namespace torch {
namespace jit {
struct Function;
namespace script {
struct CompilationUnit;
}
@ -828,7 +827,6 @@ struct CAFFE2_API RRefType
};
using ::torch::jit::Function;
struct NamedType;
using NamedTypePtr = std::shared_ptr<NamedType>;
@ -1060,9 +1058,8 @@ struct CAFFE2_API StringType : public Type {
struct FunctionType;
using FunctionTypePtr = std::shared_ptr<FunctionType>;
using ::torch::jit::Function;
struct CAFFE2_API FunctionType : public NamedType {
static FunctionTypePtr create(Function* function) {
static FunctionTypePtr create(torch::jit::Function* function) {
return FunctionTypePtr(
new FunctionType(function)); // NOLINT(modernize-make-shared)
}
@ -1079,14 +1076,14 @@ struct CAFFE2_API FunctionType : public NamedType {
std::string python_str() const override {
return "Function";
}
Function* function() const {
torch::jit::Function* function() const {
return function_;
}
static const TypeKind Kind = TypeKind::FunctionType;
private:
FunctionType(Function* function);
Function* function_;
FunctionType(torch::jit::Function* function);
torch::jit::Function* function_;
};
struct NoneType;
@ -1516,7 +1513,7 @@ struct CAFFE2_API ClassType : public NamedType {
return n.qualifiedName();
}
const std::vector<Function*>& methods() const;
const std::vector<torch::jit::Function*>& methods() const;
TypePtr findAttribute(const std::string& name) const {
TORCH_INTERNAL_ASSERT(attributeNames_.size() == attributeTypes_.size());
@ -1738,8 +1735,8 @@ struct CAFFE2_API ClassType : public NamedType {
return parameterSlots_->at(slot);
}
void addMethod(Function* method);
Function* getMethod(const std::string& name) const;
void addMethod(torch::jit::Function* method);
torch::jit::Function* getMethod(const std::string& name) const;
// [Internal Only] Remove method from the ClassType
// caller is responsible to make sure the modification is safe:
@ -1791,14 +1788,13 @@ struct CAFFE2_API ClassType : public NamedType {
std::shared_ptr<std::vector<bool>> parameterSlots_;
// List of methods associated with this class.
std::vector<Function*> methods_;
std::vector<torch::jit::Function*> methods_;
};
struct InterfaceType;
using InterfaceTypePtr = std::shared_ptr<InterfaceType>;
using ::torch::jit::script::CompilationUnit;
using ::torch::jit::Function;
// Interfaces are a list of abstract methods that a class might meet.
// If a class provides those methods, it implicitly meets the interface.

View File

@ -4,6 +4,7 @@
#include <ATen/core/jit_type.h>
#include <c10/macros/Macros.h>
#include <ATen/core/grad_mode.h>
#include <ATen/core/function.h>
#include <iostream>
namespace c10 {
@ -685,6 +686,101 @@ InterfaceTypePtr InterfaceType::create(QualifiedName qualifiedName, bool is_modu
new InterfaceType(std::move(qualifiedName), is_module));
}
void ClassType::addMethod(torch::jit::Function* method) {
TORCH_CHECK(
getMethod(method->name()) == nullptr,
"Can't redefine method: ",
method->name(),
" on class: ",
python_str());
methods_.push_back(method);
}
torch::jit::Function* ClassType::getMethod(const std::string& name) const {
for (auto method : methods_) {
if (name == method->name()) {
return method;
}
}
return nullptr;
}
void ClassType::unsafeRemoveMethod(const std::string& name) {
size_t slot = 0;
for (auto method : methods_) {
if (method->name() == name) {
methods_.erase(methods_.begin() + slot);
return;
}
slot++;
}
TORCH_CHECK(
false,
"Can't delete undefined method ",
name,
" on class: ",
python_str());
}
ClassTypePtr ClassType::refine(at::ArrayRef<TypePtr> refined_slots) const {
auto ptr = ClassType::create(name(), compilation_unit_);
AT_ASSERT(numAttributes() == refined_slots.size());
for (size_t i = 0; i < attributeNames_.size(); ++i) {
AT_ASSERT(refined_slots[i]->isSubtypeOf(attributeTypes_[i]));
ptr->addAttribute(attributeNames_[i], refined_slots[i]);
}
// Copy methods over
for (const auto& method : methods()) {
ptr->addMethod(method);
}
return ptr;
}
bool ClassType::isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const {
// to improve performance, this check can be cached
if (auto iface = rhs->cast<InterfaceType>()) {
// ClassType is not a subtype of InterfaceType if the InterfaceType is a
// Module Interface Type but the Class Type is not a Module Class Type
if (!is_module() && iface->is_module()) {
if (why_not) {
*why_not << "Class '" << python_str() << "' is not a subtype of "
<< "the module interface '" << rhs->python_str()
<< "' , only ScriptModule class can be subtype of module"
<< " interface.\n";
}
return false;
}
for (const FunctionSchema& schema : iface->methods()) {
auto self_method = getMethod(schema.name());
if (!self_method) {
if (why_not) {
*why_not << "Class '" << python_str() << "' does not have method '"
<< schema.name() << "' but '" << rhs->python_str()
<< "' does.\n";
}
return false;
}
if (!self_method->getSchema().isSubtypeOf(
schema, /*is_method=*/true, why_not)) {
if (why_not) {
*why_not << "Method on class '" << python_str()
<< "' (1) is not compatible with interface '"
<< rhs->python_str() << "' (2)\n"
<< " (1) " << self_method->getSchema() << "\n"
<< " (2) " << schema << "\n";
}
return false;
}
}
return true;
}
return Type::isSubtypeOfExt(rhs, why_not);
}
FunctionType::FunctionType(torch::jit::Function* function)
: NamedType(TypeKind::FunctionType, function->qualname()),
function_(function) {}
bool InterfaceType::isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const {
// to improve performance this check can be cached
if (auto iface = rhs->cast<InterfaceType>()) {
@ -740,7 +836,26 @@ InterfaceType::InterfaceType(QualifiedName name, bool is_module)
InterfaceType::~InterfaceType() = default;
const std::vector<Function*>& ClassType::methods() const {
ClassTypePtr ClassType::create(
c10::optional<QualifiedName> qualifiedName,
std::weak_ptr<CompilationUnit> cu,
bool is_module) {
return ClassTypePtr(
new ClassType(std::move(qualifiedName), std::move(cu), is_module));
}
ClassType::ClassType(
c10::optional<QualifiedName> name,
std::weak_ptr<CompilationUnit> cu,
bool is_module)
: NamedType(TypeKind::ClassType, std::move(name)),
compilation_unit_(std::move(cu)) {
if (is_module) {
parameterSlots_ = std::make_shared<std::vector<bool>>();
}
}
const std::vector<torch::jit::Function*>& ClassType::methods() const {
return methods_;
}

View File

@ -449,7 +449,6 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/jit/frontend/schema_matching.cpp
${TORCH_SRC_DIR}/csrc/jit/frontend/script_type_parser.cpp
${TORCH_SRC_DIR}/csrc/jit/frontend/sugared_value.cpp
${TORCH_SRC_DIR}/csrc/jit/ir/class_type.cpp
${TORCH_SRC_DIR}/csrc/jit/frontend/parser.cpp
${TORCH_SRC_DIR}/csrc/jit/frontend/builtin_functions.cpp
${TORCH_SRC_DIR}/csrc/jit/frontend/canonicalize_modified_loop.cpp
@ -469,7 +468,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/jit/codegen/fuser/executor.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/fuser/codegen.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/fuser/fallback.cpp
${TORCH_SRC_DIR}/csrc/jit/api/function.cpp
${TORCH_SRC_DIR}/csrc/jit/api/function_impl.cpp
${TORCH_SRC_DIR}/csrc/jit/runtime/vararg_functions.cpp
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/mem_arena.cpp

View File

@ -81,6 +81,7 @@ libtorch_sources = [
"torch/csrc/jit/ir/attributes.cpp",
"torch/csrc/jit/runtime/argument_spec.cpp",
"torch/csrc/jit/ir/constants.cpp",
"torch/csrc/jit/api/function_impl.cpp",
"torch/csrc/jit/api/custom_class.cpp",
"torch/csrc/jit/ir/node_hashing.cpp",
"torch/csrc/jit/ir/type_hashing.cpp",
@ -163,7 +164,6 @@ libtorch_sources = [
"torch/csrc/jit/frontend/script_type_parser.cpp",
"torch/csrc/jit/frontend/sugared_value.cpp",
"torch/csrc/jit/frontend/schema_matching.cpp",
"torch/csrc/jit/ir/class_type.cpp",
"torch/csrc/jit/frontend/parser.cpp",
"torch/csrc/jit/runtime/jit_exception.cpp",
"torch/csrc/jit/serialization/source_range_serialization.cpp",
@ -183,7 +183,6 @@ libtorch_sources = [
"torch/csrc/jit/codegen/fuser/fallback.cpp",
"torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp",
"torch/csrc/jit/codegen/fuser/interface.cpp",
"torch/csrc/jit/api/function.cpp",
"torch/csrc/jit/runtime/vararg_functions.cpp",
"torch/csrc/jit/python/update_graph_executor_opt.cpp",
"torch/csrc/jit/mobile/function.cpp",

View File

@ -30,7 +30,7 @@ namespace {
struct PythonTypeResolver : public jit::script::Resolver {
std::shared_ptr<jit::script::SugaredValue> resolveValue(
const std::string& /* unused */,
Function& /* unused */,
torch::jit::Function& /* unused */,
const jit::SourceRange& /* unused */) override {
TORCH_INTERNAL_ASSERT(
false, "RPC Type resolver does not need to resolve value");

View File

@ -1,9 +1,10 @@
#pragma once
#include <c10/util/Exception.h>
#include <torch/csrc/jit/api/function.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/ir/ir.h>
#include <ATen/core/function.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/frontend/source_range.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/utils/memory.h>
@ -117,7 +118,7 @@ struct TORCH_API CompilationUnit {
if (shouldMangle) {
name = mangle(name);
}
auto fn = torch::make_unique<Function>(
auto fn = torch::make_unique<GraphFunction>(
std::move(name), std::move(graph), nullptr);
auto ret = fn.get();
register_function(std::move(fn));

View File

@ -1,4 +1,4 @@
#include <torch/csrc/jit/api/function.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/frontend/error_report.h>
@ -6,9 +6,9 @@
namespace torch {
namespace jit {
namespace {
FunctionSchema defaultSchemaFor(const Function& function) {
std::vector<Argument> args;
std::vector<Argument> returns;
c10::FunctionSchema defaultSchemaFor(const Function& function) {
std::vector<c10::Argument> args;
std::vector<c10::Argument> returns;
Graph& g = *function.graph();
size_t num_inputs = function.num_inputs();
for (size_t i = 0; i < num_inputs; ++i) {
@ -24,19 +24,19 @@ FunctionSchema defaultSchemaFor(const Function& function) {
}
} // namespace
void placeholderCreator(Function&) {
void placeholderCreator(GraphFunction&) {
throw RecursiveMethodCallError();
}
void Function::run(Stack& stack) {
void GraphFunction::run(Stack& stack) {
get_executor().run(stack);
}
void Function::run(Stack&& stack) {
void GraphFunction::run(Stack&& stack) {
run(stack);
}
IValue Function::operator()(
IValue GraphFunction::operator()(
std::vector<IValue> stack,
const Kwargs& kwargs) {
getSchema().checkAndNormalizeInputs(stack, kwargs);
@ -44,7 +44,7 @@ IValue Function::operator()(
return stack.front();
}
void Function::ensure_defined() {
void GraphFunction::ensure_defined() {
if (function_creator_) {
auto creator = function_creator_;
function_creator_ = placeholderCreator;
@ -54,9 +54,9 @@ void Function::ensure_defined() {
check_single_output();
}
const FunctionSchema& Function::getSchema() const {
const c10::FunctionSchema& GraphFunction::getSchema() const {
if (schema_ == nullptr) {
schema_ = make_unique<FunctionSchema>(defaultSchemaFor(*this));
schema_ = std::make_unique<c10::FunctionSchema>(defaultSchemaFor(*this));
}
return *schema_;
}

View File

@ -1,43 +1,34 @@
#pragma once
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <ATen/core/function.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/utils/memory.h>
#include <mutex>
namespace torch {
namespace jit {
using Kwargs = std::unordered_map<std::string, IValue>;
struct RecursiveMethodCallError : public std::exception {};
TORCH_API void preoptimizeGraph(std::shared_ptr<Graph>& graph);
// A Function is a pure Graph with no implicit `self` object bound.
// It contains schema information, and the executor that manages the
// execution of the function. script::Method is a wrapper around a
// underlying Function that also provides a `self` object.
struct TORCH_API Function {
Function(
struct TORCH_API GraphFunction : public Function {
GraphFunction(
c10::QualifiedName name,
std::shared_ptr<Graph> graph,
std::function<void(Function&)> function_creator)
std::function<void(GraphFunction&)> function_creator)
: name_(std::move(name)),
graph_(std::move(graph)),
function_creator_(std::move(function_creator)) {}
void run(Stack& stack);
void run(Stack& stack) override;
void run(Stack&& stack);
void run(Stack&& stack) override;
IValue operator()(
std::vector<IValue> stack,
const Kwargs& kwargs = Kwargs());
IValue operator()(std::vector<IValue> stack, const Kwargs& kwargs = Kwargs())
override;
std::shared_ptr<Graph> graph() const {
std::shared_ptr<Graph> graph() const override {
return graph_;
}
std::shared_ptr<Graph> optimized_graph() const {
std::shared_ptr<Graph> optimized_graph() const override {
std::lock_guard<std::recursive_mutex> lock(compile_mutex);
if (optimized_graph_) {
return *optimized_graph_;
@ -47,29 +38,29 @@ struct TORCH_API Function {
return *optimized_graph_;
}
const c10::QualifiedName& qualname() const {
const c10::QualifiedName& qualname() const override {
return name_;
}
const std::string& name() const {
const std::string& name() const override {
return name_.name();
}
// if this isn't yet defined, run its method_creator function
void ensure_defined();
void ensure_defined() override;
size_t num_inputs() const {
size_t num_inputs() const override {
return graph()->inputs().size();
}
Function& setSchema(FunctionSchema schema) {
Function& setSchema(FunctionSchema schema) override {
schema_ = make_unique<FunctionSchema>(std::move(schema));
return *this;
}
const FunctionSchema& getSchema() const;
const FunctionSchema& getSchema() const override;
std::string pretty_print_schema() const {
std::string pretty_print_schema() const override {
AT_ASSERT(schema_);
std::stringstream ss;
ss << *schema_;
@ -82,18 +73,18 @@ struct TORCH_API Function {
bool is_optimized() const {
AT_WARN(
"Function::is_optimized() is deprecated and always returns true. "
"GraphFunction::is_optimized() is deprecated and always returns true. "
"Please use getGraphExecutorOptimize()");
return true;
}
void check_single_output() {
void check_single_output() override {
TORCH_CHECK(
graph()->outputs().size() == 1,
"Method (but not graphs in general) require a single output. Use None/Tuple for 0 or 2+ outputs");
}
GraphExecutor& get_executor() {
GraphExecutor& get_executor() override {
ensure_defined();
std::lock_guard<std::recursive_mutex> lock(compile_mutex);
if (executor_) {
@ -114,12 +105,11 @@ struct TORCH_API Function {
// here.
mutable c10::optional<std::shared_ptr<Graph>> optimized_graph_;
// Functions are invokable from multiple threads, so this lock needs to be
// held when we're initializing graph executor for the first time or computing
// the optimized graph.
// We're using reentrant mutex so that we don't need to worry about causing a
// deadlock by calling one method from another (e.g. optimized_graph() from
// get_executor()).
// GraphFunctions are invokable from multiple threads, so this lock needs to
// be held when we're initializing graph executor for the first time or
// computing the optimized graph. We're using reentrant mutex so that we don't
// need to worry about causing a deadlock by calling one method from another
// (e.g. optimized_graph() from get_executor()).
mutable std::recursive_mutex compile_mutex;
GraphExecutor executor_; // for execution
@ -127,7 +117,7 @@ struct TORCH_API Function {
// an optional function that actually creates the method when
// ensure_defined() is called. This is used by the compiler so
// that it can construct methods out of order
std::function<void(Function&)> function_creator_;
std::function<void(GraphFunction&)> function_creator_;
// if absent, then we generate a default schema based on the graph
// mutable because getSchema caches the default schema if one is requested

View File

@ -1,7 +1,8 @@
#include <ATen/core/ivalue.h>
#include <ATen/core/stack.h>
#include <torch/csrc/jit/api/function.h>
#include <ATen/core/function.h>
#include <torch/csrc/jit/api/function_impl.h>
namespace torch {
namespace jit {

View File

@ -1,5 +1,6 @@
#pragma once
#include <ATen/core/functional.h>
#include <ATen/core/ivalue.h>
#include <torch/csrc/jit/api/method.h>
@ -94,7 +95,7 @@ struct TORCH_API Object {
}
const std::vector<Method> get_methods() const {
return fmap(type()->methods(), [&](Function* func) {
return c10::fmap(type()->methods(), [&](Function* func) {
return Method(_ivalue(), func);
});
}

View File

@ -1,10 +1,13 @@
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <c10/util/Exception.h>
#include <c10/util/StringUtil.h>
#include <torch/csrc/jit/testing/hooks_for_testing.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/frontend/canonicalize_modified_loop.h>
#include <torch/csrc/jit/frontend/convert_to_ssa.h>
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/frontend/parser.h>
#include <torch/csrc/jit/frontend/schema_matching.h>
#include <torch/csrc/jit/frontend/script_type_parser.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
@ -13,11 +16,9 @@
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/passes/lift_closures.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/frontend/canonicalize_modified_loop.h>
#include <torch/csrc/jit/frontend/convert_to_ssa.h>
#include <torch/csrc/jit/frontend/parser.h>
#include <torch/csrc/jit/frontend/schema_matching.h>
#include <torch/csrc/jit/frontend/script_type_parser.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/testing/hooks_for_testing.h>
#include <torch/csrc/jit/ir/constants.h>
@ -3528,7 +3529,7 @@ std::unique_ptr<Function> CompilationUnit::define(
name = mangle(name);
}
}
auto fn = torch::make_unique<Function>(
auto fn = torch::make_unique<GraphFunction>(
std::move(name), std::make_shared<Graph>(), creator);
if (self) {
// Register this as a method on `self`'s type

View File

@ -1,129 +0,0 @@
#include <ATen/core/jit_type.h>
#include <c10/macros/Macros.h>
#include <torch/csrc/jit/api/module.h>
namespace c10 {
ClassTypePtr ClassType::create(
c10::optional<QualifiedName> qualifiedName,
std::weak_ptr<CompilationUnit> cu,
bool is_module) {
return ClassTypePtr(
new ClassType(std::move(qualifiedName), std::move(cu), is_module));
}
ClassType::ClassType(
c10::optional<QualifiedName> name,
std::weak_ptr<CompilationUnit> cu,
bool is_module)
: NamedType(TypeKind::ClassType, std::move(name)),
compilation_unit_(std::move(cu)) {
if (is_module) {
parameterSlots_ = std::make_shared<std::vector<bool>>();
}
}
void ClassType::addMethod(Function* method) {
TORCH_CHECK(
getMethod(method->name()) == nullptr,
"Can't redefine method: ",
method->name(),
" on class: ",
python_str());
methods_.push_back(method);
}
Function* ClassType::getMethod(const std::string& name) const {
for (auto method : methods_) {
if (name == method->name()) {
return method;
}
}
return nullptr;
}
void ClassType::unsafeRemoveMethod(const std::string& name) {
size_t slot = 0;
for (auto method : methods_) {
if (method->name() == name) {
methods_.erase(methods_.begin() + slot);
return;
}
slot++;
}
TORCH_CHECK(
false,
"Can't delete undefined method ",
name,
" on class: ",
python_str());
}
#ifndef USE_MOBILE_CLASSTYPE
// This file exists because we need to reference module.h, which we can't from
// c10. Sigh...
FunctionType::FunctionType(Function* function)
: NamedType(TypeKind::FunctionType, function->qualname()),
function_(function) {}
ClassTypePtr ClassType::refine(at::ArrayRef<TypePtr> refined_slots) const {
auto ptr = ClassType::create(name(), compilation_unit_);
AT_ASSERT(numAttributes() == refined_slots.size());
for (size_t i = 0; i < attributeNames_.size(); ++i) {
AT_ASSERT(refined_slots[i]->isSubtypeOf(attributeTypes_[i]));
ptr->addAttribute(attributeNames_[i], refined_slots[i]);
}
// Copy methods over
for (const auto& method : methods()) {
ptr->addMethod(method);
}
return ptr;
}
bool ClassType::isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const {
// to improve performance, this check can be cached
if (auto iface = rhs->cast<InterfaceType>()) {
// ClassType is not a subtype of InterfaceType if the InterfaceType is a
// Module Interface Type but the Class Type is not a Module Class Type
if (!is_module() && iface->is_module()) {
if (why_not) {
*why_not << "Class '" << python_str() << "' is not a subtype of "
<< "the module interface '" << rhs->python_str()
<< "' , only ScriptModule class can be subtype of module"
<< " interface.\n";
}
return false;
}
for (const FunctionSchema& schema : iface->methods()) {
auto self_method = getMethod(schema.name());
if (!self_method) {
if (why_not) {
*why_not << "Class '" << python_str() << "' does not have method '"
<< schema.name() << "' but '" << rhs->python_str()
<< "' does.\n";
}
return false;
}
if (!self_method->getSchema().isSubtypeOf(
schema, /*is_method=*/true, why_not)) {
if (why_not) {
*why_not << "Method on class '" << python_str()
<< "' (1) is not compatible with interface '"
<< rhs->python_str() << "' (2)\n"
<< " (1) " << self_method->getSchema() << "\n"
<< " (2) " << schema << "\n";
}
return false;
}
}
return true;
}
return Type::isSubtypeOfExt(rhs, why_not);
}
#else
bool ClassType::isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const {
return Type::isSubtypeOfExt(rhs, why_not);
}
#endif
} // namespace c10

View File

@ -2,12 +2,13 @@
#include <c10/util/Exception.h>
#include <c10/util/StringUtil.h>
#include <ATen/core/function.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/frontend/schema_matching.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/api/function.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/serialization/python_print.h>
#include <torch/csrc/jit/frontend/schema_matching.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <algorithm>
#include <iostream>

View File

@ -1,5 +1,5 @@
#include <torch/csrc/jit/ir/scope.h>
#include <torch/csrc/jit/api/function.h>
#include <ATen/core/function.h>
namespace torch {
namespace jit {

View File

@ -6,11 +6,12 @@
#include <c10/util/Exception.h>
#include <c10/util/StringUtil.h>
#include <torch/csrc/jit/api/function.h>
#include <ATen/core/function.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/serialization/python_print.h>
#include <torch/csrc/jit/frontend/error_report.h>
namespace torch {
namespace jit {
@ -76,7 +77,7 @@ bool is_enabled(const char *cfname, JitLoggingLevels level) {
// we won't have access to an original function, so we have to construct
// a dummy function to give to PythonPrint
std::string log_function(const std::shared_ptr<torch::jit::Graph> &graph) {
torch::jit::Function func("source_dump", graph, nullptr);
torch::jit::GraphFunction func("source_dump", graph, nullptr);
std::vector<at::Tensor> tensors;
std::vector<c10::NamedTypePtr> deps;
PythonPrint pp(tensors, deps, false);

View File

@ -1,4 +1,4 @@
#include <torch/csrc/jit/api/function.h>
#include <ATen/core/function.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/ir/alias_analysis.h>

View File

@ -7,16 +7,17 @@
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/bailout_graph.h>
#include <torch/csrc/jit/runtime/exception_message.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/passes/bailout_graph.h>
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/runtime/jit_exception.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/vararg_functions.h>
#include <exception>
@ -654,13 +655,13 @@ struct CodeImpl {
TORCH_INTERNAL_ASSERT(bailout_index >= 0);
auto build_bailout_graph = [bailout_index,
unoptimized_graph](Function &func) {
unoptimized_graph](Function& func) {
BuildBailOutGraphFrom(bailout_index, unoptimized_graph, func.graph());
};
auto empty_graph = std::make_shared<Graph>();
auto func = torch::make_unique<Function>(
auto func = torch::make_unique<GraphFunction>(
"bailout", empty_graph, build_bailout_graph);
function_table_.emplace_back(func.get());
bailout_functions_.emplace_back(std::move(func));

View File

@ -3,9 +3,9 @@
#ifdef USE_DISTRIBUTED
#include <torch/csrc/distributed/rpc/rref_context.h>
#endif
#include <torch/csrc/jit/api/function.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <aten/src/ATen/quantized/Quantizer.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <string>
namespace torch {

View File

@ -3,7 +3,7 @@
#ifdef USE_DISTRIBUTED
#include <torch/csrc/distributed/rpc/rref_context.h>
#endif
#include <torch/csrc/jit/api/function.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <string>
#include "unpickler.h"