[jit] Remove graph() call from abstract Function interface. (#65967)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65967

Graph is an implementation detail. If a user wants access to the underlying graph, they should explicitly dynamic-cast the Function to GraphFunction instead (the new toGraphFunction/tryToGraphFunction helpers wrap this cast); a short sketch of the new call pattern follows.
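A minimal sketch of the migration, assuming a torch::jit::Function& obtained from a Method or CompilationUnit (graphOrNull is an illustrative helper, not part of this change; toGraphFunction/tryToGraphFunction are the helpers this diff adds in torch/csrc/jit/api/function_impl.h):

    #include <torch/csrc/jit/api/function_impl.h>

    using namespace torch::jit;

    // Checked lookup: returns nullptr for functions that have no graph,
    // e.g. BuiltinOpFunction.
    std::shared_ptr<Graph> graphOrNull(Function& fn) {
      if (GraphFunction* gf = tryToGraphFunction(fn)) {
        return gf->graph();
      }
      return nullptr;
    }

    // When the caller knows fn is graph-backed, the asserting form is shorter:
    //   auto graph = toGraphFunction(fn).graph();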
ghstack-source-id: 141659819

Test Plan: no behavior change.

Reviewed By: gmagogsfm

Differential Revision: D31326153

fbshipit-source-id: a0e984f57c6013494b92a7095bf5bb660035eb84
Zhengxu Chen
2021-10-27 11:52:48 -07:00
committed by Facebook GitHub Bot
parent 7c48b9ee25
commit b55a2500d2
43 changed files with 324 additions and 261 deletions

View File

@ -68,12 +68,6 @@ struct BuiltinOpFunction : public Function {
// nop
}
std::shared_ptr<Graph> graph() const override {
TORCH_INTERNAL_ASSERT(false , "BuiltinFunction had a graph requested "
"from it. This probably indicates that the JIT calling context needs a "
"special case on Function::isGraphFunction()");
}
std::shared_ptr<Graph> optimized_graph() const override {
TORCH_INTERNAL_ASSERT(false , "BuiltinFunction had a graph requested "
"from it. This probably indicates that the JIT calling context needs a "

View File

@ -56,8 +56,6 @@ struct TORCH_API Function {
// if this isn't yet defined, run its method_creator function
virtual void ensure_defined() = 0;
virtual std::shared_ptr<Graph> graph() const = 0;
virtual std::shared_ptr<Graph> optimized_graph() const = 0;
virtual void clear_execution_info() = 0;

View File

@ -115,7 +115,7 @@ c10::IValue preprocess(
}
auto method = mod.get_method(FLAGS_method_name);
auto graph = method.function().graph()->copy();
auto graph = toGraphFunction(method.function()).graph()->copy();
auto sizes = getInputSizes(compile_spec);
std::string llvm_asm_code;

View File

@ -1,8 +1,10 @@
#include <gtest/gtest.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/runtime/argument_spec.h>
#include <torch/jit.h>
#include "test/cpp/jit/test_utils.h"
#include "torch/csrc/jit/runtime/argument_spec.h"
namespace torch {
namespace jit {
@ -136,11 +138,11 @@ TEST(ArgumentSpecTest, Basic_CUDA) {
auto& GF = at::CUDA(at::kFloat);
auto& GD = at::CUDA(at::kDouble);
auto graph = jit::compile(R"JIT(
auto graph = toGraphFunction(jit::compile(R"JIT(
def fn(a, b, c, d, e):
return a, b, c, d, e
)JIT")
->get_function("fn")
->get_function("fn"))
.graph();
ArgumentSpecCreator arg_spec_creator(*graph);

View File

@ -20,7 +20,7 @@ c10::IValue preprocess(
c10::Dict<IValue, IValue> compiled(StringType::get(), StringType::get());
for (const auto& method : mod.get_methods()) {
auto graph = method.function().graph()->copy();
auto graph = toGraphFunction(method.function()).graph()->copy();
// Must inline the graph for debug info map.
Inline(*graph);
// This is here because to test module hierarchy we will have

View File

@ -43,7 +43,7 @@ TEST(InlinerTest, Basic) {
CompilationUnit cu(testSource);
auto& fn = cu.get_function("foo3");
auto g = fn.graph();
auto g = toGraphFunction(fn).graph();
Inline(*g);
FileCheck().check_count("prim::Print", 3)->run(*g);
}

View File

@ -362,7 +362,7 @@ struct ClassNamespaceValue : public SugaredValue {
std::shared_ptr<SugaredValue> attr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& name) override {
const auto fullName = c10::QualifiedName(basename_, name);
@ -387,7 +387,7 @@ struct ClassNamespaceValue : public SugaredValue {
struct TestModuleResolver : public Resolver {
std::shared_ptr<SugaredValue> resolveValue(
const std::string& name,
Function& m,
GraphFunction& m,
const SourceRange& loc) override {
if (name == "torch") {
return std::make_shared<BuiltinModule>("aten");

View File

@ -11,6 +11,7 @@
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/codegen/fuser/interface.h>
#include <torch/csrc/jit/frontend/code_template.h>
@ -473,7 +474,7 @@ TEST(ControlFlowTest, Basic) {
auto cu = compile(cf_examples);
auto run = [&](const std::string& name, std::vector<IValue> stack) {
auto graph = cu->get_function(name).graph();
auto graph = toGraphFunction(cu->get_function(name)).graph();
Code code(graph, "");
InterpreterState interp(code);
interp.run(stack);
@ -1609,7 +1610,7 @@ TEST(LoopPeelerTest, NoInductionVariableUse) {
)JIT";
auto cu = compile(str_func_def);
auto& f = cu->get_function("test_peel_n_times");
auto& f = toGraphFunction(cu->get_function("test_peel_n_times"));
auto stack = createStack({});
// peeling loop once
{
@ -1651,7 +1652,7 @@ TEST(LoopPeelerTest, YesInductionVariableUse) {
)JIT";
auto cu = compile(str_func_def);
auto& f = cu->get_function("test_peel_n_times");
auto& f = toGraphFunction(cu->get_function("test_peel_n_times"));
auto stack = createStack({});
// peeling loop once
{
@ -1697,7 +1698,7 @@ TEST(LoopPeelerTest, LoopWithTerminationCondition) {
// the peel changes the termination condition to false
// so the original loop doesn't run
auto cu = compile(str_func_def);
auto& f = cu->get_function("test_with_cond_times");
auto& f = toGraphFunction(cu->get_function("test_with_cond_times"));
auto stack = createStack({});
// peeling 5 iterations should update the termination
// condition to false
@ -1742,7 +1743,7 @@ TEST(LoopPeelerTest, SimpleNestedLoops) {
)JIT";
auto cu = compile(str_func_def);
auto& f = cu->get_function("test_nested_loops");
auto& f = toGraphFunction(cu->get_function("test_nested_loops"));
auto stack = createStack({});
{
@ -1782,7 +1783,7 @@ TEST(LoopPeelerTest, SimpleNestedLoops2) {
)JIT";
auto cu = compile(str_func_def);
auto& f = cu->get_function("test_nested_loops");
auto& f = toGraphFunction(cu->get_function("test_nested_loops"));
auto stack = createStack({});
{
LoopsPeeler peeler(true_pred, 1);
@ -1859,7 +1860,7 @@ TEST(InsertAndEliminateRedundantGuardsTest, Basic) {
)JIT";
auto cu = compile(basic_example);
auto& fun = cu->get_function("basic");
auto& fun = toGraphFunction(cu->get_function("basic"));
auto pr = ProfilingRecord::instrumentGraph(fun.graph());
auto x = at::randn({2, 3}, at::kCPU);
auto y = at::randn({2, 3}, at::kCPU);
@ -1910,7 +1911,7 @@ TEST(InsertBailOutsTest, Basic) {
)JIT";
auto cu = compile(basic_example);
auto& fun = cu->get_function("basic_loop");
auto& fun = toGraphFunction(cu->get_function("basic_loop"));
auto pr = ProfilingRecord::instrumentGraph(fun.graph());
auto x = at::randn({2, 3}, at::kCPU);
auto y = at::randn({2, 3}, at::kCPU);
@ -2004,7 +2005,7 @@ def foo(x):
return bar(x)*baz(x)*11
)";
auto cu = compile(text);
const Function& foo = cu->get_function("foo");
const auto& foo = toGraphFunction(cu->get_function("foo"));
for (Node* n : foo.optimized_graph()->nodes()) {
if (n->kind() == prim::Constant) {
if (!n->hasAttribute(attr::value) ||
@ -2086,7 +2087,7 @@ def c(x):
return x
)";
auto cu = compile(text);
const Function& baz = cu->get_function("c");
const auto& baz = toGraphFunction(cu->get_function("c"));
std::unordered_map<std::string, InlinedCallStack*> callstack_objects;
for (Node* n : baz.optimized_graph()->nodes()) {
if (n->kind() == prim::Constant) {
@ -2131,7 +2132,8 @@ TEST(InlinedCallStackTest, BlockAnnotation) {
return self.A0.forward(x, y, z) + self.B0.forward(x)
)");
auto graph = c.get_method("forward").function().optimized_graph();
auto graph =
toGraphFunction(c.get_method("forward").function()).optimized_graph();
std::stringstream add_ss, mul_ss;
for (Node* n : graph->nodes()) {
if (n->kind() == prim::If) {
@ -2192,7 +2194,8 @@ TEST(InlinedCallStackTest, SelfCallMethods) {
return self.A0.forward(x, y) + self.call_b(x)
)");
auto graph = c.get_method("forward").function().optimized_graph();
auto graph =
toGraphFunction(c.get_method("forward").function()).optimized_graph();
std::unordered_map<std::string, size_t> module_hierarchies;
for (Node* n : graph->nodes()) {
auto hierarchy = torch::jit::utils::getNodesModuleHierarchy(*n);

View File

@ -32,7 +32,7 @@ constexpr auto kInternalModule = "torch.distributed.rpc.internal";
struct PythonTypeResolver : public jit::Resolver {
std::shared_ptr<jit::SugaredValue> resolveValue(
const std::string& /* unused */,
torch::jit::Function& /* unused */,
torch::jit::GraphFunction& /* unused */,
const jit::SourceRange& /* unused */) override {
TORCH_INTERNAL_ASSERT(
false, "RPC Type resolver does not need to resolve value");

View File

@ -10,7 +10,7 @@
namespace torch {
namespace jit {
namespace {
c10::FunctionSchema defaultSchemaFor(const Function& function) {
c10::FunctionSchema defaultSchemaFor(const GraphFunction& function) {
std::vector<c10::Argument> args;
std::vector<c10::Argument> returns;
Graph& g = *function.graph();
@ -26,6 +26,29 @@ c10::FunctionSchema defaultSchemaFor(const Function& function) {
}
return {function.name(), "", std::move(args), std::move(returns)};
}
template <typename T, typename F>
T* tryToGraphFunctionImpl(F& function) noexcept {
if (!function.isGraphFunction()) {
return nullptr;
}
return static_cast<T*>(&function);
}
template <typename T, typename F>
T& toGraphFunctionImpl(F& function) {
if (auto* g = tryToGraphFunctionImpl<T>(function)) {
return *g;
}
TORCH_INTERNAL_ASSERT(
false,
"Failed to downcast a Function to a GraphFunction. "
"This probably indicates that the JIT calling context needs a "
"special case on tryToGraphFunction() instead.");
}
} // namespace
void placeholderCreator(GraphFunction&) {
@ -82,5 +105,17 @@ void preoptimizeGraph(std::shared_ptr<Graph>& graph) {
ConstantPooling(graph);
}
GraphFunction* tryToGraphFunction(Function& function) noexcept {
return tryToGraphFunctionImpl<GraphFunction>(function);
}
GraphFunction& toGraphFunction(Function& function) {
return toGraphFunctionImpl<GraphFunction>(function);
}
const GraphFunction& toGraphFunction(const Function& function) {
return toGraphFunctionImpl<const GraphFunction>(function);
}
} // namespace jit
} // namespace torch

View File

@ -33,7 +33,7 @@ struct TORCH_API GraphFunction : public Function {
IValue operator()(std::vector<IValue> stack, const Kwargs& kwargs = Kwargs())
override;
std::shared_ptr<Graph> graph() const override {
std::shared_ptr<Graph> graph() const {
return graph_;
}
@ -143,5 +143,11 @@ struct TORCH_API GraphFunction : public Function {
// before a call to setSchema
mutable std::unique_ptr<FunctionSchema> schema_;
};
// Short hands for dynamic_cast<GraphFunction*>.
TORCH_API GraphFunction* tryToGraphFunction(Function&) noexcept;
TORCH_API GraphFunction& toGraphFunction(Function&);
TORCH_API const GraphFunction& toGraphFunction(const Function&);
} // namespace jit
} // namespace torch

View File

@ -43,7 +43,7 @@ struct TORCH_API Method : public torch::IMethod {
TaskLauncher taskLauncher = at::launch);
std::shared_ptr<Graph> graph() const {
return function_->graph();
return toGraphFunction(*function_).graph();
}
const std::string& name() const override {

View File

@ -4,6 +4,7 @@
#include <c10/util/StringUtil.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/frontend/ir_emitter.h>
@ -27,14 +28,14 @@ std::string getInputDebugName(const Node& n, const int idx) {
}
void assert_ignored_methods_not_called(
torch::jit::Function* fn,
torch::jit::Function& fn,
const std::unordered_set<std::string>& ignored_methods) {
if (ignored_methods.empty()) {
return;
}
const bool recurse = true;
std::vector<Node*> all_nodes =
findAllNodes(*fn->graph().get(), c10::prim::CallMethod, recurse);
std::vector<Node*> all_nodes = findAllNodes(
*toGraphFunction(fn).graph(), c10::prim::CallMethod, recurse);
// Extract method names from these nodes.
std::unordered_set<std::string> encountered_ignored_methods;
@ -56,14 +57,14 @@ void assert_ignored_methods_not_called(
TORCH_CHECK(
false,
"Preserved method '",
fn->name(),
fn.name(),
"' references ignored method(s) '",
encountered_ignored_methods_str,
"'. This is not permitted.");
}
void assert_ignored_attributes_not_referenced(
torch::jit::Function* fn,
torch::jit::Function& fn,
const std::unordered_set<std::string>& ignored_attributes) {
if (ignored_attributes.empty()) {
return;
@ -71,7 +72,7 @@ void assert_ignored_attributes_not_referenced(
const bool recurse = true;
std::vector<Node*> all_nodes =
findAllNodes(*fn->graph().get(), c10::prim::GetAttr, recurse);
findAllNodes(*toGraphFunction(fn).graph(), c10::prim::GetAttr, recurse);
// Extract attribute names from these nodes.
std::unordered_set<std::string> encountered_ignored_attributes;
@ -93,7 +94,7 @@ void assert_ignored_attributes_not_referenced(
TORCH_CHECK(
false,
"Preserved method '",
fn->name(),
fn.name(),
"' references ignored attribute(s) '",
encountered_ignored_attributes_str,
"'. This is not permitted.");
@ -282,7 +283,7 @@ void Module::clone_method(
return in;
return it->second;
};
auto graph = method.graph()->copy();
auto graph = toGraphFunction(method).graph()->copy();
graph->remapTypes(type_remap_fn);
auto schema = method.getSchema().cloneWithRemappedTypes(type_remap_fn);
const auto this_method_name = getNameForMethod(method.name());
@ -411,8 +412,8 @@ Module Module::clone_impl(
for (auto& fn : type()->methods()) {
// If this method is not in the list of ignored methods, clone it.
if (ignored_methods.count(fn->name()) == 0) {
assert_ignored_methods_not_called(fn, ignored_methods);
assert_ignored_attributes_not_referenced(fn, ignored_attributes);
assert_ignored_methods_not_called(*fn, ignored_methods);
assert_ignored_attributes_not_referenced(*fn, ignored_attributes);
r.clone_method(*this, *fn, type_remap);
}
}

View File

@ -15,7 +15,7 @@ struct ClassNamespaceValue : public SugaredValue {
std::shared_ptr<SugaredValue> attr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& name) override {
auto fullName = c10::QualifiedName(basename_, name);
@ -41,7 +41,7 @@ struct ClassNamespaceValue : public SugaredValue {
struct LoweredModuleResolver : public Resolver {
std::shared_ptr<SugaredValue> resolveValue(
const std::string& name,
Function& m,
GraphFunction& m,
const SourceRange& loc) override {
if (name == "torch") {
return std::make_shared<BuiltinModule>("aten");

View File

@ -227,7 +227,7 @@ static std::shared_ptr<MagicMethod> makeMagic(
struct Environment {
Environment(
Function& method,
GraphFunction& method,
ResolverPtr resolver,
Block* b,
std::shared_ptr<Environment> next = nullptr)
@ -237,7 +237,7 @@ struct Environment {
next(std::move(next)) {}
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
Function& method;
GraphFunction& method;
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
ResolverPtr resolver;
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
@ -638,7 +638,7 @@ struct to_ir {
const Def& def,
ResolverPtr resolver_,
const Self* self,
Function& method) // method being constructed
GraphFunction& method) // method being constructed
: method(method),
graph(method.graph()),
resolver(std::move(resolver_)),
@ -675,7 +675,7 @@ struct to_ir {
}
private:
Function& method;
GraphFunction& method;
std::shared_ptr<Graph> graph;
ResolverPtr resolver;
std::unordered_map<int64_t, Value*, std::hash<int64_t>> integral_constants;
@ -5083,7 +5083,7 @@ struct FunctionResolver : public Resolver {
std::shared_ptr<SugaredValue> resolveValue(
const std::string& name,
Function& m,
GraphFunction& m,
const SourceRange& loc) override {
auto it = functionTable_.find(name);
if (it != functionTable_.end()) {
@ -5180,7 +5180,7 @@ std::unique_ptr<Function> CompilationUnit::define(
_resolver =
std::make_shared<FunctionResolver>(resolver.get(), function_table);
}
auto creator = [def, _resolver, self](Function& method) {
auto creator = [def, _resolver, self](GraphFunction& method) {
// Store the function name so that it can be referenced if there is an error
// while compiling this function
std::string call_name = method.qualname().name();

View File

@ -32,7 +32,7 @@ struct Resolver {
// the graph to create a value.
virtual std::shared_ptr<SugaredValue> resolveValue(
const std::string& name,
Function& m,
GraphFunction& m,
const SourceRange& loc) {
return nullptr;
}
@ -47,7 +47,7 @@ struct Resolver {
struct NativeResolver : public Resolver {
std::shared_ptr<SugaredValue> resolveValue(
const std::string& name,
Function& m,
GraphFunction& m,
const SourceRange& loc) override {
if (name == "torch") {
return std::make_shared<BuiltinModule>("aten");

View File

@ -652,10 +652,12 @@ Value* emitBuiltinCall(
if (matched.first < variants.size()) {
return emitBuiltinNode(matched.second, loc, graph, name);
} else {
Function* fn = builtin_functions[matched.first - variants.size()];
auto& fn = *builtin_functions[matched.first - variants.size()];
// we inline builtin calls because they are normally very small
// wrappers and are not useful for keeping around to debug
return insertGraph(graph, *fn->graph(), matched.second.inputs).at(0);
return insertGraph(
graph, *toGraphFunction(fn).graph(), matched.second.inputs)
.at(0);
}
}

View File

@ -18,7 +18,7 @@ struct NoneValue : SugaredValue {
std::shared_ptr<SugaredValue> PrintValue::call(
const SourceRange& loc,
Function& m,
GraphFunction& m,
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs,
size_t n_binders) {
@ -49,7 +49,7 @@ builtin_cast_method_to_scalar_type() {
std::shared_ptr<SugaredValue> BuiltinFunction::call(
const SourceRange& loc,
Function& m,
GraphFunction& m,
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs,
size_t n_binders) {
@ -69,7 +69,7 @@ struct EnumClassHash {
bool SimpleValue::hasAttr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field) {
auto class_type = value_->type()->cast<ClassType>();
if (!class_type) {
@ -85,7 +85,7 @@ bool SimpleValue::hasAttr(
// callable value that will resolve to foo(x, y, z) when called.
std::shared_ptr<SugaredValue> SimpleValue::attr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field) {
// Allow method-style casts on Tensor types. e.g. x.int()
if (value_->type()->isSubtypeOf(*TensorType::get())) {
@ -239,7 +239,7 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
std::vector<std::shared_ptr<SugaredValue>> SimpleValue::asTuple(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const c10::optional<size_t>& size_hint) {
static const auto make_simple_value =
[](Value* v) -> std::shared_ptr<SugaredValue> {
@ -283,7 +283,7 @@ static bool isRecursive(const TypePtr& classType, const TypePtr& attrType) {
void SimpleValue::setAttr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field,
Value* newValue) {
const auto classType = value_->type()->cast<ClassType>();
@ -361,7 +361,7 @@ void SimpleValue::setAttr(
std::shared_ptr<SugaredValue> SimpleValue::call(
const SourceRange& loc,
Function& m,
GraphFunction& m,
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs,
size_t n_binders) {
@ -399,7 +399,7 @@ std::shared_ptr<SugaredValue> SimpleValue::call(
return SugaredValue::call(loc, m, args, kwargs, n_binders);
}
Value* SimpleValue::len(const SourceRange& loc, Function& m) {
Value* SimpleValue::len(const SourceRange& loc, GraphFunction& m) {
// List, Tuple, Tensor, fill in missing information desugaring
Value* val = getValue();
TypePtr val_type = val->type();
@ -415,7 +415,7 @@ Value* SimpleValue::len(const SourceRange& loc, Function& m) {
SugaredValuePtr SimpleValue::getitem(
const SourceRange& loc,
Function& m,
GraphFunction& m,
Value* idx,
TypePtr type_hint) {
Value* val = getValue();
@ -452,7 +452,7 @@ SugaredValuePtr SimpleValue::getitem(
}
}
SugaredValuePtr SimpleValue::iter(const SourceRange& loc, Function& m) {
SugaredValuePtr SimpleValue::iter(const SourceRange& loc, GraphFunction& m) {
auto value = getValue();
auto type = value->type();
// built-in iterable types
@ -480,7 +480,7 @@ SugaredValuePtr SimpleValue::iter(const SourceRange& loc, Function& m) {
RangeValue::RangeValue(
const SourceRange& loc,
Function& m,
GraphFunction& m,
std::vector<Value*> inputs,
c10::optional<int64_t> static_len) {
for (const auto i : c10::irange(inputs.size())) {
@ -518,11 +518,11 @@ RangeValue::RangeValue(
static_len_ = static_len;
}
SugaredValuePtr RangeValue::iter(const SourceRange& loc, Function& m) {
SugaredValuePtr RangeValue::iter(const SourceRange& loc, GraphFunction& m) {
return shared_from_this();
};
Value* RangeValue::len(const SourceRange& loc, Function& m) {
Value* RangeValue::len(const SourceRange& loc, GraphFunction& m) {
if (static_len_) {
return insertConstant(*m.graph(), *static_len_, loc);
}
@ -536,7 +536,7 @@ Value* RangeValue::len(const SourceRange& loc, Function& m) {
SugaredValuePtr RangeValue::getitem(
const SourceRange& loc,
Function& m,
GraphFunction& m,
Value* idx,
TypePtr type_hint) {
if (has_only_end_) {
@ -568,7 +568,7 @@ std::vector<SugaredValuePtr> IterableTree::get_base_iterables() {
return base_iters;
}
Value* IterableTree::len(const SourceRange& loc, Function& m) {
Value* IterableTree::len(const SourceRange& loc, GraphFunction& m) {
// if it's a iterable tree, we get the base iterables that consists of
// SimpleValue or RangeValue, and then calculate the minimum length of all the
// base iterables to be max_trip_count_val
@ -587,7 +587,7 @@ Value* IterableTree::len(const SourceRange& loc, Function& m) {
SugaredValuePtr IterableTree::getitem(
const SourceRange& loc,
Function& m,
GraphFunction& m,
Value* idx,
TypePtr type_hint) {
std::vector<SugaredValuePtr> child_items;
@ -599,7 +599,7 @@ SugaredValuePtr IterableTree::getitem(
void IterableTree::addChild(
const SourceRange& range,
Function& m,
GraphFunction& m,
const SugaredValuePtr& iter_value) {
c10::optional<int64_t> child_len = iter_value->staticLen();
if (children_.size() == 0) {
@ -622,7 +622,7 @@ void IterableTree::addChild(
std::shared_ptr<SugaredValue> MagicMethod::call(
const SourceRange& loc,
Function& m,
GraphFunction& m,
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs,
size_t n_binders) {
@ -640,7 +640,7 @@ std::shared_ptr<SugaredValue> MagicMethod::call(
std::shared_ptr<SugaredValue> ClassValue::call(
const SourceRange& loc,
Function& m,
GraphFunction& m,
// note: names for args will be 'argument 0', 'argument 1', etc..
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs,
@ -663,7 +663,7 @@ std::shared_ptr<SugaredValue> ClassValue::call(
std::shared_ptr<SugaredValue> ClassValue::attr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field) {
// Allow import_source.cpp to resolve calls to a submodule's
// hooks. Edge case because normally you wouldn't allow a module to
@ -681,7 +681,7 @@ std::shared_ptr<SugaredValue> ClassValue::attr(
std::shared_ptr<SugaredValue> NamedTupleConstructor::call(
const SourceRange& loc,
Function& m,
GraphFunction& m,
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs,
size_t n_binders) {
@ -728,7 +728,7 @@ std::shared_ptr<BuiltinFunction> BuiltinFunction::tryCreate(
std::shared_ptr<SugaredValue> SugaredEnumClass::attr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field) {
const auto& names_values = enum_type_->enumNamesValues();
auto it = std::find_if(
@ -745,7 +745,9 @@ std::shared_ptr<SugaredValue> SugaredEnumClass::attr(
m.graph()->insertConstant(IValue(enum_holder), loc));
}
SugaredValuePtr SugaredEnumClass::iter(const SourceRange& loc, Function& m) {
SugaredValuePtr SugaredEnumClass::iter(
const SourceRange& loc,
GraphFunction& m) {
const auto& names_values = enum_type_->enumNamesValues();
auto enum_value_ivalues = c10::impl::GenericList(enum_type_);
enum_value_ivalues.reserve(names_values.size());

View File

@ -31,21 +31,21 @@ struct TORCH_API SugaredValue
// what can we do with this thing?
// use it as a value e.g. `this + 4`
virtual Value* asValue(const SourceRange& loc, Function& m) {
virtual Value* asValue(const SourceRange& loc, GraphFunction& m) {
throw ErrorReport(loc) << kind() << " cannot be used as a value";
}
// select an attribute on it, e.g. `this.field`
virtual std::shared_ptr<SugaredValue> attr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field) {
throw ErrorReport(loc) << "attribute lookup is not defined on " << kind();
}
virtual bool hasAttr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field) {
throw ErrorReport(loc) << "attribute lookup is not defined on " << kind();
}
@ -53,7 +53,7 @@ struct TORCH_API SugaredValue
// assign an attribute on it, e.g. `this.field = newValue`
virtual void setAttr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field,
Value* newValue) {
throw ErrorReport(loc) << "attribute assignment is not defined on "
@ -64,13 +64,15 @@ struct TORCH_API SugaredValue
// a method invocation
virtual std::vector<std::shared_ptr<SugaredValue>> asTuple(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const c10::optional<size_t>& size_hint = {}) {
throw ErrorReport(loc) << kind() << " cannot be used as a tuple";
}
// TODO @wconstab refactor to use ModuleValue::asTuple instead of new API
virtual SugaredValuePtr asTupleValue(const SourceRange& loc, Function& m) {
virtual SugaredValuePtr asTupleValue(
const SourceRange& loc,
GraphFunction& m) {
throw ErrorReport(loc) << kind() << " cannot be used as a tuplevalue";
}
@ -83,7 +85,7 @@ struct TORCH_API SugaredValue
// call it like a function, e.g. `outputs = this(inputs)`
virtual std::shared_ptr<SugaredValue> call(
const SourceRange& loc,
Function& m,
GraphFunction& m,
// note: names for args will be 'argument 0', 'argument 1', etc..
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs,
@ -109,7 +111,7 @@ struct TORCH_API SugaredValue
// For example, when iterating through a Dict we iterate over its keys
virtual std::shared_ptr<SugaredValue> iter(
const SourceRange& loc,
Function& m) {
GraphFunction& m) {
throw ErrorReport(loc) << kind() << " cannot be used as an iterable";
}
@ -131,7 +133,7 @@ struct TORCH_API SugaredValue
// If it does not have a statically-determinable length, then it cannot
// be iterated over with a modulelist. If it does it must return a constant
// Value *
virtual Value* len(const SourceRange& loc, Function& m) {
virtual Value* len(const SourceRange& loc, GraphFunction& m) {
throw ErrorReport(loc) << "'" << kind() << "'"
<< " object is not iterable";
}
@ -139,7 +141,7 @@ struct TORCH_API SugaredValue
// expression for ith elemement for iterable value
virtual std::shared_ptr<SugaredValue> getitem(
const SourceRange& loc,
Function& m,
GraphFunction& m,
Value* idx,
TypePtr type_hint = nullptr) {
throw ErrorReport(loc) << "'" << kind() << "'"
@ -159,46 +161,48 @@ struct TORCH_API SimpleValue : public SugaredValue {
ss << "value of type '" << value_->type()->annotation_str() << "'";
return ss.str();
}
Value* asValue(const SourceRange& range, Function& m) override {
Value* asValue(const SourceRange& range, GraphFunction& m) override {
return value_;
}
std::vector<std::shared_ptr<SugaredValue>> asTuple(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const c10::optional<size_t>& size_hint = {}) override;
std::shared_ptr<SugaredValue> attr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field) override;
bool hasAttr(const SourceRange& loc, Function& m, const std::string& field)
override;
bool hasAttr(
const SourceRange& loc,
GraphFunction& m,
const std::string& field) override;
void setAttr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field,
Value* newValue) override;
std::shared_ptr<SugaredValue> call(
const SourceRange& loc,
Function& m,
GraphFunction& m,
// note: names for args will be 'argument 0', 'argument 1', etc..
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs,
size_t n_binders) override;
std::shared_ptr<SugaredValue> iter(const SourceRange& loc, Function& m)
std::shared_ptr<SugaredValue> iter(const SourceRange& loc, GraphFunction& m)
override;
Value* getValue() const {
return value_;
}
Value* len(const SourceRange& loc, Function& m) override;
Value* len(const SourceRange& loc, GraphFunction& m) override;
SugaredValuePtr getitem(
const SourceRange& loc,
Function& m,
GraphFunction& m,
Value* idx,
TypePtr type_hint = nullptr) override;
@ -220,7 +224,7 @@ struct TORCH_API BuiltinFunction : public SugaredValue {
}
std::shared_ptr<SugaredValue> call(
const SourceRange& loc,
Function& m,
GraphFunction& m,
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs,
size_t n_binders) override;
@ -239,12 +243,12 @@ struct TORCH_API SugaredTupleValue : public SugaredValue {
std::vector<std::shared_ptr<SugaredValue>> asTuple(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const c10::optional<size_t>& size_hint = {}) override {
return tup_;
};
Value* asValue(const SourceRange& loc, Function& m) override {
Value* asValue(const SourceRange& loc, GraphFunction& m) override {
std::vector<Value*> vec;
for (const auto& sv : tup_) {
vec.push_back(sv->asValue(loc, m));
@ -259,7 +263,7 @@ struct TORCH_API SugaredTupleValue : public SugaredValue {
SugaredValuePtr getitem(
const SourceRange& loc,
Function& m,
GraphFunction& m,
Value* idx,
TypePtr type_hint = nullptr) override {
if (!(idx->type()->cast<IntType>() && toIValue(idx))) {
@ -281,7 +285,7 @@ struct TORCH_API SugaredTupleValue : public SugaredValue {
// This function is called when a SugaredValue is used to convert a
// SugaredValue to its iterator. For example, when iterating through a Dict we
// iterate over its keys
std::shared_ptr<SugaredValue> iter(const SourceRange& loc, Function& m)
std::shared_ptr<SugaredValue> iter(const SourceRange& loc, GraphFunction& m)
override {
return shared_from_this();
};
@ -305,7 +309,7 @@ struct TORCH_API BuiltinModule : public SugaredValue {
}
std::shared_ptr<SugaredValue> attr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field) override {
if (field == "autograd") {
// When refering torch.autograd, it is also considered to be a
@ -340,14 +344,14 @@ struct TORCH_API ClassValue : public SugaredValue {
// n = Foo(constructor_arg)
std::shared_ptr<SugaredValue> call(
const SourceRange& loc,
Function& m,
GraphFunction& m,
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs,
size_t n_binders) override;
std::shared_ptr<SugaredValue> attr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field) override;
std::string kind() const override {
@ -362,7 +366,7 @@ struct TORCH_API NamedTupleConstructor : public SugaredValue {
std::shared_ptr<SugaredValue> call(
const SourceRange& loc,
Function& m,
GraphFunction& m,
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs,
size_t n_binders) override;
@ -392,7 +396,7 @@ struct FunctionValue : public SugaredValue {
std::shared_ptr<SugaredValue> call(
const SourceRange& loc,
Function& f,
GraphFunction& f,
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs,
size_t n_binders) override {
@ -431,7 +435,7 @@ struct TORCH_API ClosureValue : public SugaredValue {
std::string kind() const override {
return "closure";
}
Value* asValue(const SourceRange& range, Function& m) override {
Value* asValue(const SourceRange& range, GraphFunction& m) override {
return value_;
}
Value* value_;
@ -450,7 +454,7 @@ struct MethodValue : public SugaredValue {
std::shared_ptr<SugaredValue> call(
const SourceRange& loc,
Function& f,
GraphFunction& f,
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs,
size_t n_binders) override {
@ -493,7 +497,7 @@ struct TORCH_API PrintValue : public SugaredValue {
}
std::shared_ptr<SugaredValue> call(
const SourceRange& loc,
Function& m,
GraphFunction& m,
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs,
size_t n_binders) override;
@ -507,7 +511,7 @@ struct TORCH_API CastValue : public BuiltinFunction {
: BuiltinFunction(method, c10::nullopt), type_(std::move(type)) {}
std::shared_ptr<SugaredValue> call(
const SourceRange& loc,
Function& m,
GraphFunction& m,
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs,
size_t n_binders) override {
@ -545,7 +549,7 @@ struct TORCH_API TensorCastValue : public SugaredValue {
std::shared_ptr<SugaredValue> call(
const SourceRange& loc,
Function& m,
GraphFunction& m,
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs,
size_t n_binders) override {
@ -578,7 +582,7 @@ struct TORCH_API MagicMethod : public SugaredValue {
std::shared_ptr<SugaredValue> call(
const SourceRange& loc,
Function& m,
GraphFunction& m,
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs,
size_t n_binders) override;
@ -615,20 +619,20 @@ struct TORCH_API SpecialFormValue : public SugaredValue {
struct TORCH_API RangeValue : SugaredValue {
RangeValue(
const SourceRange& loc,
Function& m,
GraphFunction& m,
std::vector<Value*> input,
c10::optional<int64_t> static_len = c10::nullopt);
std::string kind() const override {
return "range";
}
Value* len(const SourceRange& loc, Function& m) override;
Value* len(const SourceRange& loc, GraphFunction& m) override;
SugaredValuePtr getitem(
const SourceRange& loc,
Function& m,
GraphFunction& m,
Value* idx,
TypePtr type_hint = nullptr) override;
std::shared_ptr<SugaredValue> iter(const SourceRange& loc, Function& m)
std::shared_ptr<SugaredValue> iter(const SourceRange& loc, GraphFunction& m)
override;
// When Range is instantiated via enumerate(iterable_with_static_len),
@ -665,7 +669,7 @@ struct TORCH_API IterableTree : SugaredValue {
IterableTree() = default;
IterableTree(
const SourceRange& range,
Function& m,
GraphFunction& m,
at::ArrayRef<SugaredValuePtr> children) {
for (const auto& child : children) {
addChild(range, m, child);
@ -675,14 +679,14 @@ struct TORCH_API IterableTree : SugaredValue {
return "iterabletree";
}
std::shared_ptr<SugaredValue> iter(const SourceRange& loc, Function& m)
std::shared_ptr<SugaredValue> iter(const SourceRange& loc, GraphFunction& m)
override {
return shared_from_this();
}
void addChild(
const SourceRange& range,
Function& m,
GraphFunction& m,
const SugaredValuePtr& iter_value);
std::vector<SugaredValuePtr> get_children() {
@ -701,10 +705,10 @@ struct TORCH_API IterableTree : SugaredValue {
// with len() and getitem()
std::vector<SugaredValuePtr> get_base_iterables();
Value* len(const SourceRange& loc, Function& m) override;
Value* len(const SourceRange& loc, GraphFunction& m) override;
SugaredValuePtr getitem(
const SourceRange& loc,
Function& m,
GraphFunction& m,
Value* idx,
TypePtr type_hint = nullptr) override;
@ -759,7 +763,7 @@ struct TORCH_API ExceptionValue : public SugaredValue {
std::shared_ptr<SugaredValue> call(
const SourceRange& loc,
Function& m,
GraphFunction& m,
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> /*attributes*/,
size_t /*n_binders*/) override {
@ -789,10 +793,10 @@ struct TORCH_API SugaredEnumClass : public SugaredValue {
SugaredValuePtr attr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field) override;
SugaredValuePtr iter(const SourceRange& loc, Function& m) override;
SugaredValuePtr iter(const SourceRange& loc, GraphFunction& m) override;
private:
EnumTypePtr enum_type_;

View File

@ -2087,7 +2087,7 @@ void inlineCallStackOfNode(
// ONNX conversion
std::vector<Value*> inlineCallTo(
Node* to_replace,
Function* callee,
GraphFunction* callee,
bool inline_optimized_graph /*=true*/) {
WithInsertPoint guard(to_replace);
TORCH_INTERNAL_ASSERT(callee->isGraphFunction());

View File

@ -82,6 +82,7 @@ using namespace ::c10::cuda;
} // namespace cuda
struct Function;
struct GraphFunction;
struct MatchedSchema;
// A Graph represents one "function" of computation.
@ -1541,7 +1542,7 @@ TORCH_API std::vector<Value*> insertGraph(
*/
TORCH_API std::vector<Value*> inlineCallTo(
Node* to_replace,
Function* callee,
GraphFunction* callee,
bool use_graph = true);
/** If there is only one value in \p OUTPUTS and its kind is Tuple, insert a

View File

@ -106,7 +106,7 @@ bool DecomposeOps(Block* block, CompilationUnit& decompose_funcs) {
decomposed = true;
WithInsertPoint guard(*it);
std::shared_ptr<Graph> d_graph =
decompose_funcs.get_function("addmm").graph();
toGraphFunction(decompose_funcs.get_function("addmm")).graph();
Value* new_output =
insertGraph(*it->owningGraph(), *d_graph, it->inputs()).at(0);
// Set the output of the decomposed graph to have the same output type as
@ -136,7 +136,7 @@ bool DecomposeOps(Block* block, CompilationUnit& decompose_funcs) {
// inline the compiled decomposed batchnorm
std::shared_ptr<Graph> d_graph =
decompose_funcs.get_function("batch_norm").graph();
toGraphFunction(decompose_funcs.get_function("batch_norm")).graph();
Value* new_output = insertGraph(*graph, *d_graph, inputs).at(0);
// post processing the graph
@ -171,7 +171,7 @@ bool DecomposeOps(Block* block, CompilationUnit& decompose_funcs) {
// inline the compiled decomposed layernorm
std::shared_ptr<Graph> d_graph =
decompose_funcs.get_function("layer_norm").graph();
toGraphFunction(decompose_funcs.get_function("layer_norm")).graph();
Value* new_output = insertGraph(*graph, *d_graph, inputs).at(0);
// post processing the graph

View File

@ -3,6 +3,7 @@
#include <torch/csrc/jit/jit_log.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/passes/clear_profiling.h>
#include <torch/csrc/jit/passes/inliner.h>
@ -108,7 +109,7 @@ class AttributePropagator {
for (auto function : preservedMethods_) {
GRAPH_DEBUG("Analyzing function: " + function->name());
auto graph = function->graph();
auto graph = toGraphFunction(*function).graph();
optimizeSubGraphs(graph, applyInline);
if (freezeInterfaces_) {
inlineInterfaceCalls(graph);
@ -120,7 +121,7 @@ class AttributePropagator {
for (auto function : preservedMethods_) {
GRAPH_DEBUG("Propagating function: " + function->name());
auto graph = function->graph();
auto graph = toGraphFunction(*function).graph();
propagateAttributes(graph);
optimizeSubGraphs(graph, applyOptimizations);
}
@ -412,18 +413,17 @@ class AttributePropagator {
if (user_node->kind() == prim::CallMethod) {
const std::string& methodName = user_node->s(attr::name);
Function& function = class_type->getMethod(methodName);
if (!function.isGraphFunction()) {
continue;
}
GRAPH_UPDATE(
"Inlining interface method '",
function.name(),
"' to ",
*user_node);
if (auto graphFunction = tryToGraphFunction(function)) {
GRAPH_UPDATE(
"Inlining interface method '",
function.name(),
"' to ",
*user_node);
GRAPH_UPDATE("Function body: ", *function.optimized_graph());
inlineCallTo(user_node, &function);
inlined = true;
GRAPH_UPDATE("Function body: ", *function.optimized_graph());
inlineCallTo(user_node, graphFunction);
inlined = true;
}
}
}
return inlined;
@ -643,7 +643,7 @@ class AttributePropagator {
// 3) Remove non public unreferenced methods.
void cleanupFrozenModule() {
for (auto function : preservedMethods_) {
auto graph = function->graph();
auto graph = toGraphFunction(*function).graph();
recordReferencedAttrs(graph);
handleSharedClassType(module_, graph);
removeExtraWaitCalls(graph->block());

View File

@ -1,5 +1,6 @@
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/jit_log.h>
@ -21,26 +22,28 @@ void inlineCalls(Block* block) {
auto function_constant = cur->input(0)->node();
auto fun_type =
function_constant->output()->type()->expect<FunctionType>();
if (!fun_type->function()->isGraphFunction()) {
continue;
if (auto graphFunction = tryToGraphFunction(*fun_type->function())) {
cur->removeInput(0);
GRAPH_UPDATE(
"Inlining function '",
fun_type->function()->name(),
"' to ",
*cur);
GRAPH_UPDATE(
"Function body: ", *fun_type->function()->optimized_graph());
inlineCallTo(cur, graphFunction);
}
cur->removeInput(0);
GRAPH_UPDATE(
"Inlining function '", fun_type->function()->name(), "' to ", *cur);
GRAPH_UPDATE(
"Function body: ", *fun_type->function()->optimized_graph());
inlineCallTo(cur, fun_type->function());
} break;
case prim::CallMethod: {
const std::string& name = cur->s(attr::name);
if (auto class_type = cur->input(0)->type()->cast<ClassType>()) {
Function& function = class_type->getMethod(name);
if (!function.isGraphFunction()) {
continue;
if (auto graphFunction = tryToGraphFunction(function)) {
GRAPH_UPDATE("Inlining method '", function.name(), "' to ", *cur);
GRAPH_UPDATE("Function body: ", *function.optimized_graph());
inlineCallTo(cur, graphFunction);
}
GRAPH_UPDATE("Inlining method '", function.name(), "' to ", *cur);
GRAPH_UPDATE("Function body: ", *function.optimized_graph());
inlineCallTo(cur, &function);
}
} break;
default: {

View File

@ -51,19 +51,19 @@ void functionCallSubstitution(Block* block) {
if (!input_node_0->hasUses()) {
input_node_0->destroy();
}
functionCallSubstitution(fun_type->function()->graph()->block());
inlineCallTo(cur, fun_type->function(), false);
auto& graphFunction = toGraphFunction(*fun_type->function());
functionCallSubstitution(graphFunction.graph()->block());
inlineCallTo(cur, &graphFunction, false);
}
} break;
case prim::CallMethod: {
const std::string& name = cur->s(attr::name);
if (auto class_type = cur->input(0)->type()->cast<ClassType>()) {
Function& function = class_type->getMethod(name);
if (!function.isGraphFunction()) {
continue;
if (auto graphFunction = tryToGraphFunction(function)) {
functionCallSubstitution(graphFunction->graph()->block());
inlineCallTo(cur, graphFunction, false);
}
functionCallSubstitution(function.graph()->block());
inlineCallTo(cur, &function, false);
}
} break;
default: {

View File

@ -61,7 +61,8 @@ Value* addParamAsArgument(Function* function, std::string& name, IValue& attr) {
schema.is_vararg(),
schema.is_varret());
function->setSchema(new_schema);
return function->graph()->addInput(name)->setType(attr.type());
return toGraphFunction(*function).graph()->addInput(name)->setType(
attr.type());
}
std::vector<IValue> getParamAttributes(
@ -177,7 +178,7 @@ std::pair<Module, std::vector<IValue>> list_module_parameters(
Module moduleClone = module.clone(true);
Method method = moduleClone.get_method("forward");
auto function = &method.function();
auto graph = function->graph();
auto graph = toGraphFunction(*function).graph();
// A map of names and values of referenced attributes, to avoid duplicates.
std::unordered_map<std::string, Value*> attrValues = {};

View File

@ -89,7 +89,7 @@ Module Finalize(
// To prevent the JIT optimizations from leveraging the annotated shape info,
// clear shape information in the graph.
for (auto func : module.type()->methods()) {
ClearProfilingInformation(func->graph());
ClearProfilingInformation(toGraphFunction(*func).graph());
}
auto graph = module.get_method("forward").graph();

View File

@ -1,5 +1,6 @@
#include <torch/csrc/jit/passes/quantization/helper.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
namespace torch {
@ -533,9 +534,9 @@ bool useQuantizable(const Use& use, QuantType quant_type) {
std::shared_ptr<Graph> getCallFunctionGraph(Node* n) {
auto* func_node = n->input(0)->node();
auto func = func_node->output()->type()->expectRef<FunctionType>().function();
TORCH_CHECK(
func->isGraphFunction(), "Quantization only works for graph function");
return func->graph();
auto graphFunc = tryToGraphFunction(*func);
TORCH_CHECK(graphFunc, "Quantization only works for graph function");
return graphFunc->graph();
}
// Block helper functions

View File

@ -263,7 +263,7 @@ class ModuleCloneHelper {
}
return type_ptr;
};
auto graph = method.graph()->copy();
auto graph = toGraphFunction(method).graph()->copy();
remapTypes(graph.get(), source, target, module_qconfig_map, type_remap_fn);
// remap self
graph->inputs()[0]->setType(target.type());

View File

@ -931,7 +931,7 @@ std::unique_ptr<GraphFunction> SubGraphCloneHelper::buildGraphFromNodes(
const std::vector<Node*>& nodes,
const std::string& name) {
auto observer_subgraph = std::make_shared<Graph>();
auto build_observer_graph = [&](Function& func) {
auto build_observer_graph = [&](GraphFunction& func) {
buildObserverSubgraph(nodes, func.graph());
};
return torch::make_unique<GraphFunction>(

View File

@ -61,7 +61,7 @@ void SubgraphRewriter::RegisterRewritePattern(
Module SubgraphRewriter::runOnModule(const Module& module) {
nodes_to_delete_.clear();
for (const auto& m : module.get_methods()) {
auto g = m.function().graph();
auto g = toGraphFunction(m.function()).graph();
runOnGraph(g);
}
return module;

View File

@ -979,7 +979,7 @@ inline py::object runAndInsertCall(
// and then run the callee with tracing disabled.
// Get the graph `Value`s that represent the input IValues
auto inputs = last(stack, callee.graph()->inputs().size());
auto inputs = last(stack, toGraphFunction(callee).graph()->inputs().size());
auto input_values =
fmap(inputs, [](const IValue& v) { return tracer::getValueTrace(v); });
TORCH_INTERNAL_ASSERT(callee.getSchema().returns().size() == 1)

View File

@ -114,7 +114,7 @@ FunctionSchema PythonValue::getSchema(
std::shared_ptr<SugaredValue> PythonValue::call(
const SourceRange& loc,
Function& m,
GraphFunction& m,
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs,
size_t n_binders) {
@ -168,7 +168,7 @@ std::string PythonValue::kind() const {
std::vector<std::shared_ptr<SugaredValue>> PythonValue::asTuple(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const c10::optional<size_t>& size_hint) {
const std::string type_str = typeString(self);
std::stringstream ss;
@ -179,7 +179,7 @@ std::vector<std::shared_ptr<SugaredValue>> PythonValue::asTuple(
std::shared_ptr<SugaredValue> PythonValue::attr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field) {
const std::string type_str = typeString(self);
std::stringstream ss;
@ -208,7 +208,7 @@ void PythonValue::checkForAddToConstantsError(std::stringstream& ss) {
std::shared_ptr<SugaredValue> PythonModuleValue::attr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field) {
py::object member = getattr(loc, field);
// note: is_constant = true because we consider that global properties
@ -220,7 +220,7 @@ std::shared_ptr<SugaredValue> PythonModuleValue::attr(
#if !defined(USE_ROCM)
std::shared_ptr<SugaredValue> CUDAPythonModuleValue::attr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field) {
// List of all the cuda operators which are supported in JIT
const std::unordered_set<std::string> cuda_ops = {
@ -259,11 +259,13 @@ std::shared_ptr<SugaredValue> CUDAPythonModuleValue::attr(
}
#endif
Value* ModuleValue::asValue(const SourceRange& loc, Function& m) {
Value* ModuleValue::asValue(const SourceRange& loc, GraphFunction& m) {
return self_;
}
SugaredValuePtr ModuleValue::asTupleValue(const SourceRange& loc, Function& m) {
SugaredValuePtr ModuleValue::asTupleValue(
const SourceRange& loc,
GraphFunction& m) {
if (concreteType_->getIterableModuleKind() == IterableModuleKind::LIST) {
auto dict = getSugaredDict(loc, m);
auto mods = dict->getModules();
@ -298,7 +300,7 @@ bool ModuleValue::areAllSubmodulesSubtypeOf(
SugaredValuePtr ModuleValue::getitem(
const SourceRange& loc,
Function& m,
GraphFunction& m,
Value* idx,
TypePtr type_hint) {
if (concreteType_->getIterableModuleKind() == IterableModuleKind::LIST) {
@ -365,7 +367,7 @@ SugaredValuePtr ModuleValue::getitem(
void checkInterface(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::shared_ptr<ModuleValue>& self,
const std::string& field) {
if (self->asValue(loc, m)->type()->cast<InterfaceType>()) {
@ -377,7 +379,7 @@ void checkInterface(
void recurseThroughNestedModules(
const SourceRange& loc,
Function& m,
GraphFunction& m,
std::vector<SugaredValuePtr>& keys,
std::vector<SugaredValuePtr>& values,
std::shared_ptr<ModuleValue>& self,
@ -413,7 +415,7 @@ void recurseThroughNestedModules(
std::shared_ptr<SugaredDict> ModuleValue::getSugaredNamedBufferDict(
const SourceRange& loc,
Function& m) {
GraphFunction& m) {
std::vector<std::string> paramNames;
std::vector<SugaredValuePtr> values;
@ -441,7 +443,7 @@ std::shared_ptr<SugaredDict> ModuleValue::getSugaredNamedBufferDict(
std::shared_ptr<SugaredDict> ModuleValue::getSugaredDict(
const SourceRange& loc,
Function& m) {
GraphFunction& m) {
std::vector<std::string> submoduleNames;
const auto& selfType = concreteType_->getJitType()->expect<ClassType>();
for (size_t i = 0; i < selfType->numAttributes(); ++i) {
@ -472,7 +474,7 @@ std::shared_ptr<SugaredDict> ModuleValue::getSugaredDict(
std::shared_ptr<SugaredValue> SugaredDict::attr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field) {
// Recursive compilation does not maintain module aliasing,
// so we do not add uniqueness checks on
@ -508,7 +510,7 @@ std::shared_ptr<SugaredValue> SugaredDict::attr(
std::shared_ptr<SugaredEnumClass> createSugaredEnumClassFromObj(
const py::object& obj,
Function& m,
GraphFunction& m,
const SourceRange& loc) {
auto annotation_type = py::module::import("torch.jit.annotations")
.attr("try_ann_to_type")(obj, loc);
@ -521,7 +523,7 @@ std::shared_ptr<SugaredEnumClass> createSugaredEnumClassFromObj(
// helper function for instantiating a SugaredValue from an IValue
std::shared_ptr<SugaredValue> toSugaredValue(
const IValue& v,
Function& m,
GraphFunction& m,
const SourceRange& loc) {
if (v.isTuple()) {
auto tp = v.toTuple();
@ -540,7 +542,7 @@ std::shared_ptr<SugaredValue> toSugaredValue(
// This method controls how we desugar attribute lookups on ScriptModules
std::shared_ptr<SugaredValue> ModuleValue::tryGetAttr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field) {
// 1. Look inside Module object for the field.
const auto& selfType_ = concreteType_->getJitType();
@ -661,14 +663,14 @@ std::shared_ptr<SugaredValue> ModuleValue::tryGetAttr(
bool ModuleValue::hasAttr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field) {
return tryGetAttr(loc, m, field) != nullptr;
}
std::shared_ptr<SugaredValue> ModuleValue::call(
const SourceRange& loc,
Function& caller,
GraphFunction& caller,
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs,
size_t n_binders) {
@ -759,7 +761,7 @@ std::shared_ptr<SugaredValue> ModuleValue::call(
// This method controls how we desugar attribute lookups on ScriptModules.
std::shared_ptr<SugaredValue> ModuleValue::attr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field) {
if (auto attr = tryGetAttr(loc, m, field)) {
return attr;
@ -788,7 +790,7 @@ std::shared_ptr<SugaredValue> ModuleValue::attr(
<< " has no attribute '" << field << "' " << hint;
}
SugaredValuePtr ModuleValue::iter(const SourceRange& loc, Function& m) {
SugaredValuePtr ModuleValue::iter(const SourceRange& loc, GraphFunction& m) {
const auto iterableModuleKind = concreteType_->getIterableModuleKind();
if (iterableModuleKind == IterableModuleKind::NONE) {
throw ErrorReport(loc)
@ -807,7 +809,7 @@ SugaredValuePtr ModuleValue::iter(const SourceRange& loc, Function& m) {
std::shared_ptr<SugaredValue> PythonClassValue::attr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field) {
// Resolve values from the Python object first (e.g. for static methods on
// this type, resolve them as functions)
@ -824,7 +826,7 @@ std::shared_ptr<SugaredValue> PythonClassValue::attr(
bool PythonClassValue::hasAttr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field) {
try {
py::getattr(py_type_, field.c_str());
@ -836,7 +838,7 @@ bool PythonClassValue::hasAttr(
void ModuleValue::setAttr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field,
Value* newValue) {
// Forward to SimpleValue::setAttr
@ -846,7 +848,7 @@ void ModuleValue::setAttr(
std::shared_ptr<SugaredValue> BooleanDispatchValue::call(
const SourceRange& loc,
Function& caller,
GraphFunction& caller,
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs,
size_t n_binders) {
@ -888,7 +890,7 @@ std::shared_ptr<SugaredValue> BooleanDispatchValue::call(
std::shared_ptr<SugaredValue> PythonExceptionValue::call(
const SourceRange& loc,
Function& caller,
GraphFunction& caller,
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs,
size_t /*n_binders*/) {
@ -984,7 +986,7 @@ bool isEnumClass(py::object obj) {
std::shared_ptr<SugaredValue> createSimpleEnumValue(
const py::object& obj,
Function& m,
GraphFunction& m,
const SourceRange& loc) {
auto enum_class = obj.attr("__class__");
auto enum_type =
@ -996,7 +998,7 @@ std::shared_ptr<SugaredValue> createSimpleEnumValue(
std::shared_ptr<SugaredValue> PythonSliceClass::call(
const SourceRange& loc,
Function& caller,
GraphFunction& caller,
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs,
size_t /*n_binders*/) {
@ -1046,7 +1048,7 @@ std::shared_ptr<SugaredValue> PythonSliceClass::call(
std::shared_ptr<SugaredValue> toSugaredValue(
py::object obj,
Function& m,
GraphFunction& m,
const SourceRange& loc,
bool is_constant) {
// directly create SimpleValues when possible, because they are first-class

View File

@ -24,7 +24,7 @@ inline std::shared_ptr<SugaredValue> toSimple(Value* v) {
// type, *add it in this function's implementation*.
std::shared_ptr<SugaredValue> toSugaredValue(
py::object obj,
Function& m,
GraphFunction& m,
const SourceRange& loc,
bool is_constant = false);
@ -47,7 +47,7 @@ struct VISIBILITY_HIDDEN PythonValue : public SugaredValue {
// call it like a function, e.g. `outputs = this(inputs)`
std::shared_ptr<SugaredValue> call(
const SourceRange& loc,
Function& m,
GraphFunction& m,
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs,
size_t n_binders) override;
@ -56,15 +56,15 @@ struct VISIBILITY_HIDDEN PythonValue : public SugaredValue {
std::vector<std::shared_ptr<SugaredValue>> asTuple(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const c10::optional<size_t>& size_hint = {}) override;
std::shared_ptr<SugaredValue> attr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field) override;
Value* asValue(const SourceRange& loc, Function& m) override {
Value* asValue(const SourceRange& loc, GraphFunction& m) override {
throw ErrorReport(loc)
<< kind() << " cannot be used as a value. "
<< "Perhaps it is a closed over global variable? If so, please "
@ -90,7 +90,7 @@ struct VISIBILITY_HIDDEN PythonModuleValue : public PythonValue {
std::shared_ptr<SugaredValue> attr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field) override;
};
@ -103,7 +103,7 @@ struct VISIBILITY_HIDDEN CUDAPythonModuleValue : public PythonValue {
std::shared_ptr<SugaredValue> attr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field) override;
};
#endif
@ -116,7 +116,7 @@ struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue {
}
std::shared_ptr<SugaredValue> call(
const SourceRange& loc,
Function& caller,
GraphFunction& caller,
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs,
size_t n_binders) override {
@ -137,7 +137,7 @@ struct VISIBILITY_HIDDEN ModuleDictMethod : public SugaredValue {
std::shared_ptr<SugaredValue> call(
const SourceRange& loc,
Function& f,
GraphFunction& f,
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs,
size_t n_binders) override {
@ -169,53 +169,56 @@ struct VISIBILITY_HIDDEN ModuleValue : public SugaredValue {
return "module";
}
Value* asValue(const SourceRange& loc, Function& m) override;
Value* asValue(const SourceRange& loc, GraphFunction& m) override;
SugaredValuePtr asTupleValue(const SourceRange& loc, Function& m) override;
SugaredValuePtr asTupleValue(const SourceRange& loc, GraphFunction& m)
override;
// select an attribute on it, e.g. `this.field`
std::shared_ptr<SugaredValue> tryGetAttr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field);
// select an attribute on it, e.g. `this.field`
std::shared_ptr<SugaredValue> attr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field) override;
// select an attribute on it, e.g. `this.field`
bool hasAttr(const SourceRange& loc, Function& m, const std::string& field)
override;
bool hasAttr(
const SourceRange& loc,
GraphFunction& m,
const std::string& field) override;
// call module.forward with pre_hooks and hooks
std::shared_ptr<SugaredValue> call(
const SourceRange& loc,
Function& caller,
GraphFunction& caller,
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs,
size_t n_binders) override;
std::shared_ptr<SugaredDict> getSugaredDict(
const SourceRange& loc,
Function& m);
GraphFunction& m);
std::shared_ptr<SugaredDict> getSugaredNamedBufferDict(
const SourceRange& loc,
Function& m);
GraphFunction& m);
void setAttr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field,
Value* newValue) override;
SugaredValuePtr iter(const SourceRange& loc, Function& m) override;
SugaredValuePtr iter(const SourceRange& loc, GraphFunction& m) override;
std::shared_ptr<SugaredValue> getitem(
const SourceRange& loc,
Function& m,
GraphFunction& m,
Value* idx,
TypePtr type_hint) override;
@ -237,7 +240,7 @@ TypePtr registerNamedTuple(const py::object& obj, const SourceRange& loc);
void recurseThroughNestedModules(
const SourceRange& loc,
Function& m,
GraphFunction& m,
std::vector<SugaredValuePtr>& keys,
std::vector<SugaredValuePtr>& values,
std::shared_ptr<ModuleValue>& self,
@ -269,10 +272,10 @@ struct VISIBILITY_HIDDEN SugaredDict : public SugaredValue {
std::shared_ptr<SugaredValue> attr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field) override;
SugaredValuePtr iter(const SourceRange& loc, Function& m) override {
SugaredValuePtr iter(const SourceRange& loc, GraphFunction& m) override {
return keys_;
};
@ -291,7 +294,7 @@ struct VISIBILITY_HIDDEN BooleanDispatchValue : public SugaredValue {
std::shared_ptr<SugaredValue> call(
const SourceRange& loc,
Function& caller,
GraphFunction& caller,
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs,
size_t n_binders) override;
@ -310,11 +313,13 @@ struct VISIBILITY_HIDDEN PythonClassValue : public ClassValue {
std::shared_ptr<SugaredValue> attr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field) override;
bool hasAttr(const SourceRange& loc, Function& m, const std::string& field)
override;
bool hasAttr(
const SourceRange& loc,
GraphFunction& m,
const std::string& field) override;
private:
py::object py_type_;
@ -331,7 +336,7 @@ struct VISIBILITY_HIDDEN PythonExceptionValue : public ExceptionValue {
std::shared_ptr<SugaredValue> call(
const SourceRange& loc,
Function& caller,
GraphFunction& caller,
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs,
size_t n_binders) override;
@ -347,7 +352,7 @@ struct VISIBILITY_HIDDEN PythonSliceClass : public SugaredValue {
std::shared_ptr<SugaredValue> call(
const SourceRange& loc,
Function& caller,
GraphFunction& caller,
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs,
size_t n_binders) override;
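
The Function& to GraphFunction& change in these overrides ripples into any SugaredValue subclass defined outside this file, which must update its signatures to keep overriding the virtuals. A minimal sketch of such a subclass after the change, using a hypothetical MyNamespaceValue that is not part of this patch:

#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/frontend/sugared_value.h>

namespace torch {
namespace jit {

// Hypothetical downstream SugaredValue, shown only to illustrate the
// signature change; an override that keeps the old `Function& m` parameter
// no longer matches the virtual and fails to compile.
struct MyNamespaceValue : public SugaredValue {
  std::string kind() const override {
    return "my namespace";
  }

  std::shared_ptr<SugaredValue> attr(
      const SourceRange& loc,
      GraphFunction& m, // was `Function& m` before this change
      const std::string& field) override {
    // Resolve every attribute to a builtin operator namespace, purely as an
    // example body (same pattern as OpsValue::attr in the source importer).
    return std::make_shared<BuiltinModule>(field);
  }
};

} // namespace jit
} // namespace torch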

View File

@ -92,7 +92,7 @@ struct PythonResolver : public Resolver {
std::shared_ptr<SugaredValue> resolveValue(
const std::string& name,
Function& m,
GraphFunction& m,
const SourceRange& loc) override {
pybind11::gil_scoped_acquire ag;
py::object obj = rcb_(name);
@ -531,7 +531,7 @@ static std::shared_ptr<Graph> _propagate_and_assign_input_shapes(
void addFunctionToModule(Module& module, const StrongFunctionPtr& func) {
// Make a graph with a fake self argument
auto graph = func.function_->graph()->copy();
auto graph = toGraphFunction(*func.function_).graph()->copy();
auto v = graph->insertInput(0, "self");
v->setType(module._ivalue()->type());
const auto name = QualifiedName(*module.type()->name(), "forward");
@ -1414,11 +1414,13 @@ void initJitScriptBindings(PyObject* module) {
py::arg("_extra_files") = ExtraFilesMap())
.def_property_readonly(
"graph",
[](const StrongFunctionPtr& self) { return self.function_->graph(); })
[](const StrongFunctionPtr& self) {
return toGraphFunction(*self.function_).graph();
})
.def_property_readonly(
"inlined_graph",
[](const StrongFunctionPtr& self) {
auto g = self.function_->graph()->copy();
auto g = toGraphFunction(*self.function_).graph()->copy();
Inline(*g);
return g;
})
@ -1479,7 +1481,7 @@ void initJitScriptBindings(PyObject* module) {
.def_property_readonly(
"inlined_graph",
[](const Method& self) {
auto g = self.function().graph()->copy();
auto g = toGraphFunction(self.function()).graph()->copy();
Inline(*g);
return g;
})
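
These bindings show the caller-side migration in full: code that used to call Function::graph() now casts to GraphFunction first via toGraphFunction(). The same pattern as a stand-alone sketch, using a hypothetical helper inlinedGraphOf that mirrors the inlined_graph property above and assumes, as the binding does, that the function is graph-backed:

#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/inliner.h>

namespace torch {
namespace jit {

// Hypothetical helper, not part of this patch.
std::shared_ptr<Graph> inlinedGraphOf(Function& fn) {
  // Before: auto g = fn.graph()->copy();
  // After: the graph is only reachable through the GraphFunction subclass;
  // toGraphFunction() asserts if `fn` is not graph-backed.
  auto g = toGraphFunction(fn).graph()->copy();
  Inline(*g);
  return g;
}

} // namespace jit
} // namespace torch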

View File

@ -463,7 +463,7 @@ struct CodeImpl {
TORCH_INTERNAL_ASSERT(bailout_index >= 0);
auto build_bailout_graph = [bailout_index,
unoptimized_graph](Function& func) {
unoptimized_graph](GraphFunction& func) {
BuildBailOutGraphFrom(bailout_index, unoptimized_graph, func.graph());
};

View File

@ -721,7 +721,7 @@ GraphExecutorState ProfilingGraphExecutorImpl::getDebugState() {
Node* insertFallbackFunctionCall(
Graph* graph,
Function* func,
GraphFunction* func,
ArrayRef<Value*> inputs) {
auto tuple_type = func->graph()->return_node()->input(0)->type();
Value* fn_constant = graph->insertNode(graph->create(prim::Constant))
@ -740,7 +740,7 @@ Node* insertFallbackFunctionCall(
return fun_unpack_tuple;
}
Function* createFallbackPathFunction(
GraphFunction* createFallbackPathFunction(
Block* b,
const std::string& function_name) {
auto value_map = [](Value* v) { return v; };

View File

@ -1546,7 +1546,7 @@ void loadModule(const CompilationUnit& module) {
continue;
GradientPair pair;
pair.forward = method->graph();
pair.forward = toGraphFunction(*method).graph();
// lookup the backward function
Node* forward_tuple = pair.forward->outputs().at(0)->node();

View File

@ -735,7 +735,8 @@ void loadModule(const CompilationUnit& module) {
Function& shape_compute_function =
module.get_function(shape_compute_function_name);
std::shared_ptr<Graph> graph = shape_compute_function.graph();
std::shared_ptr<Graph> graph =
toGraphFunction(shape_compute_function).graph();
Inline(*graph);
// ATEN operators can return multiple unboxed values, this in contrast to

View File

@ -1,6 +1,7 @@
#include <torch/csrc/jit/serialization/export.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/backends/backend_debug_handler.h>
#include <torch/csrc/jit/backends/backend_debug_info.h>
#include <torch/csrc/jit/frontend/source_range.h>
@ -367,10 +368,10 @@ std::unordered_set<const FunctionSchema*> getInterfaceCalls(Graph& graph) {
}
struct ModuleMethod {
ModuleMethod(const Module& m, const Function& f, c10::QualifiedName n)
ModuleMethod(const Module& m, const GraphFunction& f, c10::QualifiedName n)
: module(m), function(f), exportName(std::move(n)) {}
Module module;
const Function& function;
const GraphFunction& function;
c10::QualifiedName exportName;
};
@ -387,9 +388,9 @@ std::vector<ModuleMethod> getModuleInterfaceExports(
std::vector<ModuleMethod> ret;
for (const auto& submodule : module.modules()) {
for (const auto& method : submodule.get_methods()) {
if (names.find(method.function().qualname().name()) != names.end()) {
ret.emplace_back(
submodule, method.function(), method.function().qualname());
const auto& f = toGraphFunction(method.function());
if (names.find(f.qualname().name()) != names.end()) {
ret.emplace_back(submodule, f, f.qualname());
}
}
}
@ -441,8 +442,8 @@ void setstateTuple(
if (exportSet.contains(qn)) {
return;
}
if (setstate.isGraphFunction()) {
exportFunction(exportSet, ModuleMethod{module, setstate, qn}, toplevel);
if (auto f = tryToGraphFunction(setstate)) {
exportFunction(exportSet, ModuleMethod{module, *f, qn}, toplevel);
}
} else {
for (size_t i = 0, n = type->numAttributes(); i < n; ++i) {
@ -545,10 +546,9 @@ BytecodeExportSet moduleMethodsTuple(
auto methods = module.get_methods();
// top level methods
for (const auto& method : methods) {
const auto& f = toGraphFunction(method.function());
exportFunction(
exportSet,
ModuleMethod{module, method.function(), method.function().qualname()},
/* toplevel */ true);
exportSet, ModuleMethod{module, f, f.qualname()}, /* toplevel */ true);
}
// __setstate__ of all components
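
Where a function may not be graph-backed at all (a BuiltinOpFunction, for instance), the export code above probes with tryToGraphFunction() instead of asserting. The same guarded pattern as a stand-alone sketch; maybeGraphOf is a hypothetical helper, not part of the patch:

#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/ir/ir.h>

namespace torch {
namespace jit {

// Hypothetical helper: returns the underlying graph when one exists and
// nullptr otherwise, mirroring the check added to setstateTuple above.
std::shared_ptr<Graph> maybeGraphOf(Function& fn) {
  if (auto gf = tryToGraphFunction(fn)) {
    return gf->graph();
  }
  return nullptr;
}

} // namespace jit
} // namespace torch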

View File

@ -19,7 +19,7 @@ struct OpsValue : public SugaredValue {
}
std::shared_ptr<SugaredValue> attr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field) override {
return std::make_shared<BuiltinModule>(field, version_);
}
@ -42,7 +42,7 @@ struct TORCH_API ClassNamespaceValue : public SugaredValue {
std::shared_ptr<SugaredValue> attr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& name) override;
std::string kind() const override {
return "Class Namespace";
@ -65,7 +65,7 @@ struct ConstantTableValue : public SugaredValue {
// select an attribute on it, e.g. `this.field`
std::shared_ptr<SugaredValue> attr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& field) override {
const char* field_s = field.c_str();
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
@ -222,7 +222,7 @@ void SourceImporterImpl::LEGACY_import_methods(
std::shared_ptr<SugaredValue> SourceImporterImpl::resolveValue(
const std::string& name,
Function& m,
GraphFunction& m,
const SourceRange& loc) {
auto it = env_.find(name);
if (it != env_.end()) {
@ -720,7 +720,7 @@ void SourceImporterImpl::parseImports(Lexer& L) {
std::shared_ptr<SugaredValue> ClassNamespaceValue::attr(
const SourceRange& loc,
Function& m,
GraphFunction& m,
const std::string& name) {
auto fullName = c10::QualifiedName(basename_, name);
// Could be a ClassType or NamedTuple constructor

View File

@ -38,7 +38,7 @@ struct SourceImporterImpl : public Resolver,
std::shared_ptr<SugaredValue> resolveValue(
const std::string& name,
Function& m,
GraphFunction& m,
const SourceRange& loc) override;
TypePtr resolveType(const std::string& name, const SourceRange& loc) override;

View File

@ -1,9 +1,12 @@
#include <torch/csrc/jit/serialization/python_print.h>
#include <algorithm>
#include <ATen/core/qualified_name.h>
#include <c10/util/Exception.h>
#include <c10/util/StringUtil.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/frontend/versioned_symbols.h>
@ -13,8 +16,6 @@
#include <torch/csrc/jit/resource_guard.h>
#include <torch/csrc/jit/runtime/calculate_necessary_args.h>
#include <algorithm>
using c10::QualifiedName;
namespace torch {
@ -1301,9 +1302,8 @@ struct PythonPrintImpl {
void printFunction(
const Function& func,
bool print_first_argument_type = true) {
TORCH_INTERNAL_ASSERT(func.isGraphFunction());
const FunctionSchema& schema = func.getSchema();
Graph& graph = *func.graph();
Graph& graph = *toGraphFunction(func).graph();
used_names_.clear(); // each graph can reuse local names
WithSourceRange guard(&source_range_stack_, graph.param_node());