mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[jit] Remove graph() call from abstract Function interface. (#65967)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65967 Graph is an implementation detail. If user wants to get access to the underlying graph, they should be able to explicitly dynamic cast instead. ghstack-source-id: 141659819 Test Plan: no behavior change. Reviewed By: gmagogsfm Differential Revision: D31326153 fbshipit-source-id: a0e984f57c6013494b92a7095bf5bb660035eb84
This commit is contained in:
committed by
Facebook GitHub Bot
parent
7c48b9ee25
commit
b55a2500d2
@ -68,12 +68,6 @@ struct BuiltinOpFunction : public Function {
|
||||
// nop
|
||||
}
|
||||
|
||||
std::shared_ptr<Graph> graph() const override {
|
||||
TORCH_INTERNAL_ASSERT(false , "BuiltinFunction had a graph requested "
|
||||
"from it. This probably indicates that the JIT calling context needs a "
|
||||
"special case on Function::isGraphFunction()");
|
||||
}
|
||||
|
||||
std::shared_ptr<Graph> optimized_graph() const override {
|
||||
TORCH_INTERNAL_ASSERT(false , "BuiltinFunction had a graph requested "
|
||||
"from it. This probably indicates that the JIT calling context needs a "
|
||||
|
@ -56,8 +56,6 @@ struct TORCH_API Function {
|
||||
// if this isn't yet defined, run its method_creator function
|
||||
virtual void ensure_defined() = 0;
|
||||
|
||||
virtual std::shared_ptr<Graph> graph() const = 0;
|
||||
|
||||
virtual std::shared_ptr<Graph> optimized_graph() const = 0;
|
||||
|
||||
virtual void clear_execution_info() = 0;
|
||||
|
@ -115,7 +115,7 @@ c10::IValue preprocess(
|
||||
}
|
||||
|
||||
auto method = mod.get_method(FLAGS_method_name);
|
||||
auto graph = method.function().graph()->copy();
|
||||
auto graph = toGraphFunction(method.function()).graph()->copy();
|
||||
auto sizes = getInputSizes(compile_spec);
|
||||
|
||||
std::string llvm_asm_code;
|
||||
|
@ -1,8 +1,10 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <torch/csrc/jit/api/function_impl.h>
|
||||
#include <torch/csrc/jit/runtime/argument_spec.h>
|
||||
#include <torch/jit.h>
|
||||
|
||||
#include "test/cpp/jit/test_utils.h"
|
||||
#include "torch/csrc/jit/runtime/argument_spec.h"
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
@ -136,11 +138,11 @@ TEST(ArgumentSpecTest, Basic_CUDA) {
|
||||
auto& GF = at::CUDA(at::kFloat);
|
||||
auto& GD = at::CUDA(at::kDouble);
|
||||
|
||||
auto graph = jit::compile(R"JIT(
|
||||
auto graph = toGraphFunction(jit::compile(R"JIT(
|
||||
def fn(a, b, c, d, e):
|
||||
return a, b, c, d, e
|
||||
)JIT")
|
||||
->get_function("fn")
|
||||
->get_function("fn"))
|
||||
.graph();
|
||||
|
||||
ArgumentSpecCreator arg_spec_creator(*graph);
|
||||
|
@ -20,7 +20,7 @@ c10::IValue preprocess(
|
||||
c10::Dict<IValue, IValue> compiled(StringType::get(), StringType::get());
|
||||
|
||||
for (const auto& method : mod.get_methods()) {
|
||||
auto graph = method.function().graph()->copy();
|
||||
auto graph = toGraphFunction(method.function()).graph()->copy();
|
||||
// Must inline the graph for debug info map.
|
||||
Inline(*graph);
|
||||
// This is here because to test module hierarchy we will have
|
||||
|
@ -43,7 +43,7 @@ TEST(InlinerTest, Basic) {
|
||||
CompilationUnit cu(testSource);
|
||||
auto& fn = cu.get_function("foo3");
|
||||
|
||||
auto g = fn.graph();
|
||||
auto g = toGraphFunction(fn).graph();
|
||||
Inline(*g);
|
||||
FileCheck().check_count("prim::Print", 3)->run(*g);
|
||||
}
|
||||
|
@ -362,7 +362,7 @@ struct ClassNamespaceValue : public SugaredValue {
|
||||
|
||||
std::shared_ptr<SugaredValue> attr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& name) override {
|
||||
const auto fullName = c10::QualifiedName(basename_, name);
|
||||
|
||||
@ -387,7 +387,7 @@ struct ClassNamespaceValue : public SugaredValue {
|
||||
struct TestModuleResolver : public Resolver {
|
||||
std::shared_ptr<SugaredValue> resolveValue(
|
||||
const std::string& name,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const SourceRange& loc) override {
|
||||
if (name == "torch") {
|
||||
return std::make_shared<BuiltinModule>("aten");
|
||||
|
@ -11,6 +11,7 @@
|
||||
#include <torch/csrc/autograd/engine.h>
|
||||
#include <torch/csrc/autograd/generated/variable_factories.h>
|
||||
#include <torch/csrc/autograd/variable.h>
|
||||
#include <torch/csrc/jit/api/function_impl.h>
|
||||
#include <torch/csrc/jit/api/module.h>
|
||||
#include <torch/csrc/jit/codegen/fuser/interface.h>
|
||||
#include <torch/csrc/jit/frontend/code_template.h>
|
||||
@ -473,7 +474,7 @@ TEST(ControlFlowTest, Basic) {
|
||||
auto cu = compile(cf_examples);
|
||||
|
||||
auto run = [&](const std::string& name, std::vector<IValue> stack) {
|
||||
auto graph = cu->get_function(name).graph();
|
||||
auto graph = toGraphFunction(cu->get_function(name)).graph();
|
||||
Code code(graph, "");
|
||||
InterpreterState interp(code);
|
||||
interp.run(stack);
|
||||
@ -1609,7 +1610,7 @@ TEST(LoopPeelerTest, NoInductionVariableUse) {
|
||||
)JIT";
|
||||
|
||||
auto cu = compile(str_func_def);
|
||||
auto& f = cu->get_function("test_peel_n_times");
|
||||
auto& f = toGraphFunction(cu->get_function("test_peel_n_times"));
|
||||
auto stack = createStack({});
|
||||
// peeling loop once
|
||||
{
|
||||
@ -1651,7 +1652,7 @@ TEST(LoopPeelerTest, YesInductionVariableUse) {
|
||||
)JIT";
|
||||
|
||||
auto cu = compile(str_func_def);
|
||||
auto& f = cu->get_function("test_peel_n_times");
|
||||
auto& f = toGraphFunction(cu->get_function("test_peel_n_times"));
|
||||
auto stack = createStack({});
|
||||
// peeling loop once
|
||||
{
|
||||
@ -1697,7 +1698,7 @@ TEST(LoopPeelerTest, LoopWithTerminationCondition) {
|
||||
// the peel changes the termination condition to false
|
||||
// so the original loop doesn't run
|
||||
auto cu = compile(str_func_def);
|
||||
auto& f = cu->get_function("test_with_cond_times");
|
||||
auto& f = toGraphFunction(cu->get_function("test_with_cond_times"));
|
||||
auto stack = createStack({});
|
||||
// peeling 5 iterations should update the termination
|
||||
// condition to false
|
||||
@ -1742,7 +1743,7 @@ TEST(LoopPeelerTest, SimpleNestedLoops) {
|
||||
)JIT";
|
||||
|
||||
auto cu = compile(str_func_def);
|
||||
auto& f = cu->get_function("test_nested_loops");
|
||||
auto& f = toGraphFunction(cu->get_function("test_nested_loops"));
|
||||
auto stack = createStack({});
|
||||
|
||||
{
|
||||
@ -1782,7 +1783,7 @@ TEST(LoopPeelerTest, SimpleNestedLoops2) {
|
||||
)JIT";
|
||||
|
||||
auto cu = compile(str_func_def);
|
||||
auto& f = cu->get_function("test_nested_loops");
|
||||
auto& f = toGraphFunction(cu->get_function("test_nested_loops"));
|
||||
auto stack = createStack({});
|
||||
{
|
||||
LoopsPeeler peeler(true_pred, 1);
|
||||
@ -1859,7 +1860,7 @@ TEST(InsertAndEliminateRedundantGuardsTest, Basic) {
|
||||
)JIT";
|
||||
|
||||
auto cu = compile(basic_example);
|
||||
auto& fun = cu->get_function("basic");
|
||||
auto& fun = toGraphFunction(cu->get_function("basic"));
|
||||
auto pr = ProfilingRecord::instrumentGraph(fun.graph());
|
||||
auto x = at::randn({2, 3}, at::kCPU);
|
||||
auto y = at::randn({2, 3}, at::kCPU);
|
||||
@ -1910,7 +1911,7 @@ TEST(InsertBailOutsTest, Basic) {
|
||||
)JIT";
|
||||
|
||||
auto cu = compile(basic_example);
|
||||
auto& fun = cu->get_function("basic_loop");
|
||||
auto& fun = toGraphFunction(cu->get_function("basic_loop"));
|
||||
auto pr = ProfilingRecord::instrumentGraph(fun.graph());
|
||||
auto x = at::randn({2, 3}, at::kCPU);
|
||||
auto y = at::randn({2, 3}, at::kCPU);
|
||||
@ -2004,7 +2005,7 @@ def foo(x):
|
||||
return bar(x)*baz(x)*11
|
||||
)";
|
||||
auto cu = compile(text);
|
||||
const Function& foo = cu->get_function("foo");
|
||||
const auto& foo = toGraphFunction(cu->get_function("foo"));
|
||||
for (Node* n : foo.optimized_graph()->nodes()) {
|
||||
if (n->kind() == prim::Constant) {
|
||||
if (!n->hasAttribute(attr::value) ||
|
||||
@ -2086,7 +2087,7 @@ def c(x):
|
||||
return x
|
||||
)";
|
||||
auto cu = compile(text);
|
||||
const Function& baz = cu->get_function("c");
|
||||
const auto& baz = toGraphFunction(cu->get_function("c"));
|
||||
std::unordered_map<std::string, InlinedCallStack*> callstack_objects;
|
||||
for (Node* n : baz.optimized_graph()->nodes()) {
|
||||
if (n->kind() == prim::Constant) {
|
||||
@ -2131,7 +2132,8 @@ TEST(InlinedCallStackTest, BlockAnnotation) {
|
||||
return self.A0.forward(x, y, z) + self.B0.forward(x)
|
||||
)");
|
||||
|
||||
auto graph = c.get_method("forward").function().optimized_graph();
|
||||
auto graph =
|
||||
toGraphFunction(c.get_method("forward").function()).optimized_graph();
|
||||
std::stringstream add_ss, mul_ss;
|
||||
for (Node* n : graph->nodes()) {
|
||||
if (n->kind() == prim::If) {
|
||||
@ -2192,7 +2194,8 @@ TEST(InlinedCallStackTest, SelfCallMethods) {
|
||||
return self.A0.forward(x, y) + self.call_b(x)
|
||||
)");
|
||||
|
||||
auto graph = c.get_method("forward").function().optimized_graph();
|
||||
auto graph =
|
||||
toGraphFunction(c.get_method("forward").function()).optimized_graph();
|
||||
std::unordered_map<std::string, size_t> module_hierarchies;
|
||||
for (Node* n : graph->nodes()) {
|
||||
auto hierarchy = torch::jit::utils::getNodesModuleHierarchy(*n);
|
||||
|
@ -32,7 +32,7 @@ constexpr auto kInternalModule = "torch.distributed.rpc.internal";
|
||||
struct PythonTypeResolver : public jit::Resolver {
|
||||
std::shared_ptr<jit::SugaredValue> resolveValue(
|
||||
const std::string& /* unused */,
|
||||
torch::jit::Function& /* unused */,
|
||||
torch::jit::GraphFunction& /* unused */,
|
||||
const jit::SourceRange& /* unused */) override {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
false, "RPC Type resolver does not need to resolve value");
|
||||
|
@ -10,7 +10,7 @@
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace {
|
||||
c10::FunctionSchema defaultSchemaFor(const Function& function) {
|
||||
c10::FunctionSchema defaultSchemaFor(const GraphFunction& function) {
|
||||
std::vector<c10::Argument> args;
|
||||
std::vector<c10::Argument> returns;
|
||||
Graph& g = *function.graph();
|
||||
@ -26,6 +26,29 @@ c10::FunctionSchema defaultSchemaFor(const Function& function) {
|
||||
}
|
||||
return {function.name(), "", std::move(args), std::move(returns)};
|
||||
}
|
||||
|
||||
template <typename T, typename F>
|
||||
T* tryToGraphFunctionImpl(F& function) noexcept {
|
||||
if (!function.isGraphFunction()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return static_cast<T*>(&function);
|
||||
}
|
||||
|
||||
template <typename T, typename F>
|
||||
T& toGraphFunctionImpl(F& function) {
|
||||
if (auto* g = tryToGraphFunctionImpl<T>(function)) {
|
||||
return *g;
|
||||
}
|
||||
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
false,
|
||||
"Failed to downcast a Function to a GraphFunction. "
|
||||
"This probably indicates that the JIT calling context needs a "
|
||||
"special case on tryToGraphFunction() instead.");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void placeholderCreator(GraphFunction&) {
|
||||
@ -82,5 +105,17 @@ void preoptimizeGraph(std::shared_ptr<Graph>& graph) {
|
||||
ConstantPooling(graph);
|
||||
}
|
||||
|
||||
GraphFunction* tryToGraphFunction(Function& function) noexcept {
|
||||
return tryToGraphFunctionImpl<GraphFunction>(function);
|
||||
}
|
||||
|
||||
GraphFunction& toGraphFunction(Function& function) {
|
||||
return toGraphFunctionImpl<GraphFunction>(function);
|
||||
}
|
||||
|
||||
const GraphFunction& toGraphFunction(const Function& function) {
|
||||
return toGraphFunctionImpl<const GraphFunction>(function);
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
@ -33,7 +33,7 @@ struct TORCH_API GraphFunction : public Function {
|
||||
IValue operator()(std::vector<IValue> stack, const Kwargs& kwargs = Kwargs())
|
||||
override;
|
||||
|
||||
std::shared_ptr<Graph> graph() const override {
|
||||
std::shared_ptr<Graph> graph() const {
|
||||
return graph_;
|
||||
}
|
||||
|
||||
@ -143,5 +143,11 @@ struct TORCH_API GraphFunction : public Function {
|
||||
// before a call to setSchema
|
||||
mutable std::unique_ptr<FunctionSchema> schema_;
|
||||
};
|
||||
|
||||
// Short hands for dynamic_cast<GraphFunction*>.
|
||||
TORCH_API GraphFunction* tryToGraphFunction(Function&) noexcept;
|
||||
TORCH_API GraphFunction& toGraphFunction(Function&);
|
||||
TORCH_API const GraphFunction& toGraphFunction(const Function&);
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
@ -43,7 +43,7 @@ struct TORCH_API Method : public torch::IMethod {
|
||||
TaskLauncher taskLauncher = at::launch);
|
||||
|
||||
std::shared_ptr<Graph> graph() const {
|
||||
return function_->graph();
|
||||
return toGraphFunction(*function_).graph();
|
||||
}
|
||||
|
||||
const std::string& name() const override {
|
||||
|
@ -4,6 +4,7 @@
|
||||
#include <c10/util/StringUtil.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <torch/csrc/autograd/generated/variable_factories.h>
|
||||
#include <torch/csrc/jit/api/function_impl.h>
|
||||
#include <torch/csrc/jit/api/module.h>
|
||||
#include <torch/csrc/jit/frontend/error_report.h>
|
||||
#include <torch/csrc/jit/frontend/ir_emitter.h>
|
||||
@ -27,14 +28,14 @@ std::string getInputDebugName(const Node& n, const int idx) {
|
||||
}
|
||||
|
||||
void assert_ignored_methods_not_called(
|
||||
torch::jit::Function* fn,
|
||||
torch::jit::Function& fn,
|
||||
const std::unordered_set<std::string>& ignored_methods) {
|
||||
if (ignored_methods.empty()) {
|
||||
return;
|
||||
}
|
||||
const bool recurse = true;
|
||||
std::vector<Node*> all_nodes =
|
||||
findAllNodes(*fn->graph().get(), c10::prim::CallMethod, recurse);
|
||||
std::vector<Node*> all_nodes = findAllNodes(
|
||||
*toGraphFunction(fn).graph(), c10::prim::CallMethod, recurse);
|
||||
|
||||
// Extract method names from these nodes.
|
||||
std::unordered_set<std::string> encountered_ignored_methods;
|
||||
@ -56,14 +57,14 @@ void assert_ignored_methods_not_called(
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Preserved method '",
|
||||
fn->name(),
|
||||
fn.name(),
|
||||
"' references ignored method(s) '",
|
||||
encountered_ignored_methods_str,
|
||||
"'. This is not permitted.");
|
||||
}
|
||||
|
||||
void assert_ignored_attributes_not_referenced(
|
||||
torch::jit::Function* fn,
|
||||
torch::jit::Function& fn,
|
||||
const std::unordered_set<std::string>& ignored_attributes) {
|
||||
if (ignored_attributes.empty()) {
|
||||
return;
|
||||
@ -71,7 +72,7 @@ void assert_ignored_attributes_not_referenced(
|
||||
|
||||
const bool recurse = true;
|
||||
std::vector<Node*> all_nodes =
|
||||
findAllNodes(*fn->graph().get(), c10::prim::GetAttr, recurse);
|
||||
findAllNodes(*toGraphFunction(fn).graph(), c10::prim::GetAttr, recurse);
|
||||
|
||||
// Extract attribute names from these nodes.
|
||||
std::unordered_set<std::string> encountered_ignored_attributes;
|
||||
@ -93,7 +94,7 @@ void assert_ignored_attributes_not_referenced(
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Preserved method '",
|
||||
fn->name(),
|
||||
fn.name(),
|
||||
"' references ignored attribute(s) '",
|
||||
encountered_ignored_attributes_str,
|
||||
"'. This is not permitted.");
|
||||
@ -282,7 +283,7 @@ void Module::clone_method(
|
||||
return in;
|
||||
return it->second;
|
||||
};
|
||||
auto graph = method.graph()->copy();
|
||||
auto graph = toGraphFunction(method).graph()->copy();
|
||||
graph->remapTypes(type_remap_fn);
|
||||
auto schema = method.getSchema().cloneWithRemappedTypes(type_remap_fn);
|
||||
const auto this_method_name = getNameForMethod(method.name());
|
||||
@ -411,8 +412,8 @@ Module Module::clone_impl(
|
||||
for (auto& fn : type()->methods()) {
|
||||
// If this method is not in the list of ignored methods, clone it.
|
||||
if (ignored_methods.count(fn->name()) == 0) {
|
||||
assert_ignored_methods_not_called(fn, ignored_methods);
|
||||
assert_ignored_attributes_not_referenced(fn, ignored_attributes);
|
||||
assert_ignored_methods_not_called(*fn, ignored_methods);
|
||||
assert_ignored_attributes_not_referenced(*fn, ignored_attributes);
|
||||
r.clone_method(*this, *fn, type_remap);
|
||||
}
|
||||
}
|
||||
|
@ -15,7 +15,7 @@ struct ClassNamespaceValue : public SugaredValue {
|
||||
|
||||
std::shared_ptr<SugaredValue> attr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& name) override {
|
||||
auto fullName = c10::QualifiedName(basename_, name);
|
||||
|
||||
@ -41,7 +41,7 @@ struct ClassNamespaceValue : public SugaredValue {
|
||||
struct LoweredModuleResolver : public Resolver {
|
||||
std::shared_ptr<SugaredValue> resolveValue(
|
||||
const std::string& name,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const SourceRange& loc) override {
|
||||
if (name == "torch") {
|
||||
return std::make_shared<BuiltinModule>("aten");
|
||||
|
@ -227,7 +227,7 @@ static std::shared_ptr<MagicMethod> makeMagic(
|
||||
|
||||
struct Environment {
|
||||
Environment(
|
||||
Function& method,
|
||||
GraphFunction& method,
|
||||
ResolverPtr resolver,
|
||||
Block* b,
|
||||
std::shared_ptr<Environment> next = nullptr)
|
||||
@ -237,7 +237,7 @@ struct Environment {
|
||||
next(std::move(next)) {}
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
||||
Function& method;
|
||||
GraphFunction& method;
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
||||
ResolverPtr resolver;
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
||||
@ -638,7 +638,7 @@ struct to_ir {
|
||||
const Def& def,
|
||||
ResolverPtr resolver_,
|
||||
const Self* self,
|
||||
Function& method) // method being constructed
|
||||
GraphFunction& method) // method being constructed
|
||||
: method(method),
|
||||
graph(method.graph()),
|
||||
resolver(std::move(resolver_)),
|
||||
@ -675,7 +675,7 @@ struct to_ir {
|
||||
}
|
||||
|
||||
private:
|
||||
Function& method;
|
||||
GraphFunction& method;
|
||||
std::shared_ptr<Graph> graph;
|
||||
ResolverPtr resolver;
|
||||
std::unordered_map<int64_t, Value*, std::hash<int64_t>> integral_constants;
|
||||
@ -5083,7 +5083,7 @@ struct FunctionResolver : public Resolver {
|
||||
|
||||
std::shared_ptr<SugaredValue> resolveValue(
|
||||
const std::string& name,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const SourceRange& loc) override {
|
||||
auto it = functionTable_.find(name);
|
||||
if (it != functionTable_.end()) {
|
||||
@ -5180,7 +5180,7 @@ std::unique_ptr<Function> CompilationUnit::define(
|
||||
_resolver =
|
||||
std::make_shared<FunctionResolver>(resolver.get(), function_table);
|
||||
}
|
||||
auto creator = [def, _resolver, self](Function& method) {
|
||||
auto creator = [def, _resolver, self](GraphFunction& method) {
|
||||
// Store the function name so that it can be referenced if there is an error
|
||||
// while compiling this function
|
||||
std::string call_name = method.qualname().name();
|
||||
|
@ -32,7 +32,7 @@ struct Resolver {
|
||||
// the graph to create a value.
|
||||
virtual std::shared_ptr<SugaredValue> resolveValue(
|
||||
const std::string& name,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const SourceRange& loc) {
|
||||
return nullptr;
|
||||
}
|
||||
@ -47,7 +47,7 @@ struct Resolver {
|
||||
struct NativeResolver : public Resolver {
|
||||
std::shared_ptr<SugaredValue> resolveValue(
|
||||
const std::string& name,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const SourceRange& loc) override {
|
||||
if (name == "torch") {
|
||||
return std::make_shared<BuiltinModule>("aten");
|
||||
|
@ -652,10 +652,12 @@ Value* emitBuiltinCall(
|
||||
if (matched.first < variants.size()) {
|
||||
return emitBuiltinNode(matched.second, loc, graph, name);
|
||||
} else {
|
||||
Function* fn = builtin_functions[matched.first - variants.size()];
|
||||
auto& fn = *builtin_functions[matched.first - variants.size()];
|
||||
// we inline builtin calls because they are normally very small
|
||||
// wrappers and are not useful for keeping around to debug
|
||||
return insertGraph(graph, *fn->graph(), matched.second.inputs).at(0);
|
||||
return insertGraph(
|
||||
graph, *toGraphFunction(fn).graph(), matched.second.inputs)
|
||||
.at(0);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -18,7 +18,7 @@ struct NoneValue : SugaredValue {
|
||||
|
||||
std::shared_ptr<SugaredValue> PrintValue::call(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
at::ArrayRef<NamedValue> args,
|
||||
at::ArrayRef<NamedValue> kwargs,
|
||||
size_t n_binders) {
|
||||
@ -49,7 +49,7 @@ builtin_cast_method_to_scalar_type() {
|
||||
|
||||
std::shared_ptr<SugaredValue> BuiltinFunction::call(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
at::ArrayRef<NamedValue> args,
|
||||
at::ArrayRef<NamedValue> kwargs,
|
||||
size_t n_binders) {
|
||||
@ -69,7 +69,7 @@ struct EnumClassHash {
|
||||
|
||||
bool SimpleValue::hasAttr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field) {
|
||||
auto class_type = value_->type()->cast<ClassType>();
|
||||
if (!class_type) {
|
||||
@ -85,7 +85,7 @@ bool SimpleValue::hasAttr(
|
||||
// callable value that will resolve to foo(x, y, z) when called.
|
||||
std::shared_ptr<SugaredValue> SimpleValue::attr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field) {
|
||||
// Allow method-style casts on Tensor types. e.g. x.int()
|
||||
if (value_->type()->isSubtypeOf(*TensorType::get())) {
|
||||
@ -239,7 +239,7 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
|
||||
|
||||
std::vector<std::shared_ptr<SugaredValue>> SimpleValue::asTuple(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const c10::optional<size_t>& size_hint) {
|
||||
static const auto make_simple_value =
|
||||
[](Value* v) -> std::shared_ptr<SugaredValue> {
|
||||
@ -283,7 +283,7 @@ static bool isRecursive(const TypePtr& classType, const TypePtr& attrType) {
|
||||
|
||||
void SimpleValue::setAttr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field,
|
||||
Value* newValue) {
|
||||
const auto classType = value_->type()->cast<ClassType>();
|
||||
@ -361,7 +361,7 @@ void SimpleValue::setAttr(
|
||||
|
||||
std::shared_ptr<SugaredValue> SimpleValue::call(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
at::ArrayRef<NamedValue> args,
|
||||
at::ArrayRef<NamedValue> kwargs,
|
||||
size_t n_binders) {
|
||||
@ -399,7 +399,7 @@ std::shared_ptr<SugaredValue> SimpleValue::call(
|
||||
return SugaredValue::call(loc, m, args, kwargs, n_binders);
|
||||
}
|
||||
|
||||
Value* SimpleValue::len(const SourceRange& loc, Function& m) {
|
||||
Value* SimpleValue::len(const SourceRange& loc, GraphFunction& m) {
|
||||
// List, Tuple, Tensor, fill in missing information desugaring
|
||||
Value* val = getValue();
|
||||
TypePtr val_type = val->type();
|
||||
@ -415,7 +415,7 @@ Value* SimpleValue::len(const SourceRange& loc, Function& m) {
|
||||
|
||||
SugaredValuePtr SimpleValue::getitem(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
Value* idx,
|
||||
TypePtr type_hint) {
|
||||
Value* val = getValue();
|
||||
@ -452,7 +452,7 @@ SugaredValuePtr SimpleValue::getitem(
|
||||
}
|
||||
}
|
||||
|
||||
SugaredValuePtr SimpleValue::iter(const SourceRange& loc, Function& m) {
|
||||
SugaredValuePtr SimpleValue::iter(const SourceRange& loc, GraphFunction& m) {
|
||||
auto value = getValue();
|
||||
auto type = value->type();
|
||||
// built-in iterable types
|
||||
@ -480,7 +480,7 @@ SugaredValuePtr SimpleValue::iter(const SourceRange& loc, Function& m) {
|
||||
|
||||
RangeValue::RangeValue(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
std::vector<Value*> inputs,
|
||||
c10::optional<int64_t> static_len) {
|
||||
for (const auto i : c10::irange(inputs.size())) {
|
||||
@ -518,11 +518,11 @@ RangeValue::RangeValue(
|
||||
static_len_ = static_len;
|
||||
}
|
||||
|
||||
SugaredValuePtr RangeValue::iter(const SourceRange& loc, Function& m) {
|
||||
SugaredValuePtr RangeValue::iter(const SourceRange& loc, GraphFunction& m) {
|
||||
return shared_from_this();
|
||||
};
|
||||
|
||||
Value* RangeValue::len(const SourceRange& loc, Function& m) {
|
||||
Value* RangeValue::len(const SourceRange& loc, GraphFunction& m) {
|
||||
if (static_len_) {
|
||||
return insertConstant(*m.graph(), *static_len_, loc);
|
||||
}
|
||||
@ -536,7 +536,7 @@ Value* RangeValue::len(const SourceRange& loc, Function& m) {
|
||||
|
||||
SugaredValuePtr RangeValue::getitem(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
Value* idx,
|
||||
TypePtr type_hint) {
|
||||
if (has_only_end_) {
|
||||
@ -568,7 +568,7 @@ std::vector<SugaredValuePtr> IterableTree::get_base_iterables() {
|
||||
return base_iters;
|
||||
}
|
||||
|
||||
Value* IterableTree::len(const SourceRange& loc, Function& m) {
|
||||
Value* IterableTree::len(const SourceRange& loc, GraphFunction& m) {
|
||||
// if it's a iterable tree, we get the base iterables that consists of
|
||||
// SimpleValue or RangeValue, and then calculate the minimum length of all the
|
||||
// base iterables to be max_trip_count_val
|
||||
@ -587,7 +587,7 @@ Value* IterableTree::len(const SourceRange& loc, Function& m) {
|
||||
|
||||
SugaredValuePtr IterableTree::getitem(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
Value* idx,
|
||||
TypePtr type_hint) {
|
||||
std::vector<SugaredValuePtr> child_items;
|
||||
@ -599,7 +599,7 @@ SugaredValuePtr IterableTree::getitem(
|
||||
|
||||
void IterableTree::addChild(
|
||||
const SourceRange& range,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const SugaredValuePtr& iter_value) {
|
||||
c10::optional<int64_t> child_len = iter_value->staticLen();
|
||||
if (children_.size() == 0) {
|
||||
@ -622,7 +622,7 @@ void IterableTree::addChild(
|
||||
|
||||
std::shared_ptr<SugaredValue> MagicMethod::call(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
at::ArrayRef<NamedValue> args,
|
||||
at::ArrayRef<NamedValue> kwargs,
|
||||
size_t n_binders) {
|
||||
@ -640,7 +640,7 @@ std::shared_ptr<SugaredValue> MagicMethod::call(
|
||||
|
||||
std::shared_ptr<SugaredValue> ClassValue::call(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
// note: names for args will be 'argument 0', 'argument 1', etc..
|
||||
at::ArrayRef<NamedValue> args,
|
||||
at::ArrayRef<NamedValue> kwargs,
|
||||
@ -663,7 +663,7 @@ std::shared_ptr<SugaredValue> ClassValue::call(
|
||||
|
||||
std::shared_ptr<SugaredValue> ClassValue::attr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field) {
|
||||
// Allow import_source.cpp to resolve calls to a submodule's
|
||||
// hooks. Edge case because normally you wouldn't allow a module to
|
||||
@ -681,7 +681,7 @@ std::shared_ptr<SugaredValue> ClassValue::attr(
|
||||
|
||||
std::shared_ptr<SugaredValue> NamedTupleConstructor::call(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
at::ArrayRef<NamedValue> args,
|
||||
at::ArrayRef<NamedValue> kwargs,
|
||||
size_t n_binders) {
|
||||
@ -728,7 +728,7 @@ std::shared_ptr<BuiltinFunction> BuiltinFunction::tryCreate(
|
||||
|
||||
std::shared_ptr<SugaredValue> SugaredEnumClass::attr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field) {
|
||||
const auto& names_values = enum_type_->enumNamesValues();
|
||||
auto it = std::find_if(
|
||||
@ -745,7 +745,9 @@ std::shared_ptr<SugaredValue> SugaredEnumClass::attr(
|
||||
m.graph()->insertConstant(IValue(enum_holder), loc));
|
||||
}
|
||||
|
||||
SugaredValuePtr SugaredEnumClass::iter(const SourceRange& loc, Function& m) {
|
||||
SugaredValuePtr SugaredEnumClass::iter(
|
||||
const SourceRange& loc,
|
||||
GraphFunction& m) {
|
||||
const auto& names_values = enum_type_->enumNamesValues();
|
||||
auto enum_value_ivalues = c10::impl::GenericList(enum_type_);
|
||||
enum_value_ivalues.reserve(names_values.size());
|
||||
|
@ -31,21 +31,21 @@ struct TORCH_API SugaredValue
|
||||
|
||||
// what can we do with this thing?
|
||||
// use it as a value e.g. `this + 4`
|
||||
virtual Value* asValue(const SourceRange& loc, Function& m) {
|
||||
virtual Value* asValue(const SourceRange& loc, GraphFunction& m) {
|
||||
throw ErrorReport(loc) << kind() << " cannot be used as a value";
|
||||
}
|
||||
|
||||
// select an attribute on it, e.g. `this.field`
|
||||
virtual std::shared_ptr<SugaredValue> attr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field) {
|
||||
throw ErrorReport(loc) << "attribute lookup is not defined on " << kind();
|
||||
}
|
||||
|
||||
virtual bool hasAttr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field) {
|
||||
throw ErrorReport(loc) << "attribute lookup is not defined on " << kind();
|
||||
}
|
||||
@ -53,7 +53,7 @@ struct TORCH_API SugaredValue
|
||||
// assign an attribute on it, e.g. `this.field = newValue`
|
||||
virtual void setAttr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field,
|
||||
Value* newValue) {
|
||||
throw ErrorReport(loc) << "attribute assignment is not defined on "
|
||||
@ -64,13 +64,15 @@ struct TORCH_API SugaredValue
|
||||
// a method invocation
|
||||
virtual std::vector<std::shared_ptr<SugaredValue>> asTuple(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const c10::optional<size_t>& size_hint = {}) {
|
||||
throw ErrorReport(loc) << kind() << " cannot be used as a tuple";
|
||||
}
|
||||
|
||||
// TODO @wconstab refactor to use ModuleValue::asTuple instead of new API
|
||||
virtual SugaredValuePtr asTupleValue(const SourceRange& loc, Function& m) {
|
||||
virtual SugaredValuePtr asTupleValue(
|
||||
const SourceRange& loc,
|
||||
GraphFunction& m) {
|
||||
throw ErrorReport(loc) << kind() << " cannot be used as a tuplevalue";
|
||||
}
|
||||
|
||||
@ -83,7 +85,7 @@ struct TORCH_API SugaredValue
|
||||
// call it like a function, e.g. `outputs = this(inputs)`
|
||||
virtual std::shared_ptr<SugaredValue> call(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
// note: names for args will be 'argument 0', 'argument 1', etc..
|
||||
at::ArrayRef<NamedValue> args,
|
||||
at::ArrayRef<NamedValue> kwargs,
|
||||
@ -109,7 +111,7 @@ struct TORCH_API SugaredValue
|
||||
// For example, when iterating through a Dict we iterate over its keys
|
||||
virtual std::shared_ptr<SugaredValue> iter(
|
||||
const SourceRange& loc,
|
||||
Function& m) {
|
||||
GraphFunction& m) {
|
||||
throw ErrorReport(loc) << kind() << " cannot be used as an iterable";
|
||||
}
|
||||
|
||||
@ -131,7 +133,7 @@ struct TORCH_API SugaredValue
|
||||
// If it does not have a statically-determinable length, then it cannot
|
||||
// be iterated over with a modulelist. If it does it must return a constant
|
||||
// Value *
|
||||
virtual Value* len(const SourceRange& loc, Function& m) {
|
||||
virtual Value* len(const SourceRange& loc, GraphFunction& m) {
|
||||
throw ErrorReport(loc) << "'" << kind() << "'"
|
||||
<< " object is not iterable";
|
||||
}
|
||||
@ -139,7 +141,7 @@ struct TORCH_API SugaredValue
|
||||
// expression for ith elemement for iterable value
|
||||
virtual std::shared_ptr<SugaredValue> getitem(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
Value* idx,
|
||||
TypePtr type_hint = nullptr) {
|
||||
throw ErrorReport(loc) << "'" << kind() << "'"
|
||||
@ -159,46 +161,48 @@ struct TORCH_API SimpleValue : public SugaredValue {
|
||||
ss << "value of type '" << value_->type()->annotation_str() << "'";
|
||||
return ss.str();
|
||||
}
|
||||
Value* asValue(const SourceRange& range, Function& m) override {
|
||||
Value* asValue(const SourceRange& range, GraphFunction& m) override {
|
||||
return value_;
|
||||
}
|
||||
std::vector<std::shared_ptr<SugaredValue>> asTuple(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const c10::optional<size_t>& size_hint = {}) override;
|
||||
std::shared_ptr<SugaredValue> attr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field) override;
|
||||
|
||||
bool hasAttr(const SourceRange& loc, Function& m, const std::string& field)
|
||||
override;
|
||||
bool hasAttr(
|
||||
const SourceRange& loc,
|
||||
GraphFunction& m,
|
||||
const std::string& field) override;
|
||||
|
||||
void setAttr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field,
|
||||
Value* newValue) override;
|
||||
|
||||
std::shared_ptr<SugaredValue> call(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
// note: names for args will be 'argument 0', 'argument 1', etc..
|
||||
at::ArrayRef<NamedValue> args,
|
||||
at::ArrayRef<NamedValue> kwargs,
|
||||
size_t n_binders) override;
|
||||
|
||||
std::shared_ptr<SugaredValue> iter(const SourceRange& loc, Function& m)
|
||||
std::shared_ptr<SugaredValue> iter(const SourceRange& loc, GraphFunction& m)
|
||||
override;
|
||||
|
||||
Value* getValue() const {
|
||||
return value_;
|
||||
}
|
||||
|
||||
Value* len(const SourceRange& loc, Function& m) override;
|
||||
Value* len(const SourceRange& loc, GraphFunction& m) override;
|
||||
SugaredValuePtr getitem(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
Value* idx,
|
||||
TypePtr type_hint = nullptr) override;
|
||||
|
||||
@ -220,7 +224,7 @@ struct TORCH_API BuiltinFunction : public SugaredValue {
|
||||
}
|
||||
std::shared_ptr<SugaredValue> call(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
at::ArrayRef<NamedValue> args,
|
||||
at::ArrayRef<NamedValue> kwargs,
|
||||
size_t n_binders) override;
|
||||
@ -239,12 +243,12 @@ struct TORCH_API SugaredTupleValue : public SugaredValue {
|
||||
|
||||
std::vector<std::shared_ptr<SugaredValue>> asTuple(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const c10::optional<size_t>& size_hint = {}) override {
|
||||
return tup_;
|
||||
};
|
||||
|
||||
Value* asValue(const SourceRange& loc, Function& m) override {
|
||||
Value* asValue(const SourceRange& loc, GraphFunction& m) override {
|
||||
std::vector<Value*> vec;
|
||||
for (const auto& sv : tup_) {
|
||||
vec.push_back(sv->asValue(loc, m));
|
||||
@ -259,7 +263,7 @@ struct TORCH_API SugaredTupleValue : public SugaredValue {
|
||||
|
||||
SugaredValuePtr getitem(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
Value* idx,
|
||||
TypePtr type_hint = nullptr) override {
|
||||
if (!(idx->type()->cast<IntType>() && toIValue(idx))) {
|
||||
@ -281,7 +285,7 @@ struct TORCH_API SugaredTupleValue : public SugaredValue {
|
||||
// This function is called when a SugaredValue is used to convert a
|
||||
// SugaredValue to its iterator. For example, when iterating through a Dict we
|
||||
// iterate over its keys
|
||||
std::shared_ptr<SugaredValue> iter(const SourceRange& loc, Function& m)
|
||||
std::shared_ptr<SugaredValue> iter(const SourceRange& loc, GraphFunction& m)
|
||||
override {
|
||||
return shared_from_this();
|
||||
};
|
||||
@ -305,7 +309,7 @@ struct TORCH_API BuiltinModule : public SugaredValue {
|
||||
}
|
||||
std::shared_ptr<SugaredValue> attr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field) override {
|
||||
if (field == "autograd") {
|
||||
// When refering torch.autograd, it is also considered to be a
|
||||
@ -340,14 +344,14 @@ struct TORCH_API ClassValue : public SugaredValue {
|
||||
// n = Foo(constructor_arg)
|
||||
std::shared_ptr<SugaredValue> call(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
at::ArrayRef<NamedValue> args,
|
||||
at::ArrayRef<NamedValue> kwargs,
|
||||
size_t n_binders) override;
|
||||
|
||||
std::shared_ptr<SugaredValue> attr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field) override;
|
||||
|
||||
std::string kind() const override {
|
||||
@ -362,7 +366,7 @@ struct TORCH_API NamedTupleConstructor : public SugaredValue {
|
||||
|
||||
std::shared_ptr<SugaredValue> call(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
at::ArrayRef<NamedValue> args,
|
||||
at::ArrayRef<NamedValue> kwargs,
|
||||
size_t n_binders) override;
|
||||
@ -392,7 +396,7 @@ struct FunctionValue : public SugaredValue {
|
||||
|
||||
std::shared_ptr<SugaredValue> call(
|
||||
const SourceRange& loc,
|
||||
Function& f,
|
||||
GraphFunction& f,
|
||||
at::ArrayRef<NamedValue> args,
|
||||
at::ArrayRef<NamedValue> kwargs,
|
||||
size_t n_binders) override {
|
||||
@ -431,7 +435,7 @@ struct TORCH_API ClosureValue : public SugaredValue {
|
||||
std::string kind() const override {
|
||||
return "closure";
|
||||
}
|
||||
Value* asValue(const SourceRange& range, Function& m) override {
|
||||
Value* asValue(const SourceRange& range, GraphFunction& m) override {
|
||||
return value_;
|
||||
}
|
||||
Value* value_;
|
||||
@ -450,7 +454,7 @@ struct MethodValue : public SugaredValue {
|
||||
|
||||
std::shared_ptr<SugaredValue> call(
|
||||
const SourceRange& loc,
|
||||
Function& f,
|
||||
GraphFunction& f,
|
||||
at::ArrayRef<NamedValue> args,
|
||||
at::ArrayRef<NamedValue> kwargs,
|
||||
size_t n_binders) override {
|
||||
@ -493,7 +497,7 @@ struct TORCH_API PrintValue : public SugaredValue {
|
||||
}
|
||||
std::shared_ptr<SugaredValue> call(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
at::ArrayRef<NamedValue> args,
|
||||
at::ArrayRef<NamedValue> kwargs,
|
||||
size_t n_binders) override;
|
||||
@ -507,7 +511,7 @@ struct TORCH_API CastValue : public BuiltinFunction {
|
||||
: BuiltinFunction(method, c10::nullopt), type_(std::move(type)) {}
|
||||
std::shared_ptr<SugaredValue> call(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
at::ArrayRef<NamedValue> args,
|
||||
at::ArrayRef<NamedValue> kwargs,
|
||||
size_t n_binders) override {
|
||||
@ -545,7 +549,7 @@ struct TORCH_API TensorCastValue : public SugaredValue {
|
||||
|
||||
std::shared_ptr<SugaredValue> call(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
at::ArrayRef<NamedValue> args,
|
||||
at::ArrayRef<NamedValue> kwargs,
|
||||
size_t n_binders) override {
|
||||
@ -578,7 +582,7 @@ struct TORCH_API MagicMethod : public SugaredValue {
|
||||
|
||||
std::shared_ptr<SugaredValue> call(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
at::ArrayRef<NamedValue> args,
|
||||
at::ArrayRef<NamedValue> kwargs,
|
||||
size_t n_binders) override;
|
||||
@ -615,20 +619,20 @@ struct TORCH_API SpecialFormValue : public SugaredValue {
|
||||
struct TORCH_API RangeValue : SugaredValue {
|
||||
RangeValue(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
std::vector<Value*> input,
|
||||
c10::optional<int64_t> static_len = c10::nullopt);
|
||||
|
||||
std::string kind() const override {
|
||||
return "range";
|
||||
}
|
||||
Value* len(const SourceRange& loc, Function& m) override;
|
||||
Value* len(const SourceRange& loc, GraphFunction& m) override;
|
||||
SugaredValuePtr getitem(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
Value* idx,
|
||||
TypePtr type_hint = nullptr) override;
|
||||
std::shared_ptr<SugaredValue> iter(const SourceRange& loc, Function& m)
|
||||
std::shared_ptr<SugaredValue> iter(const SourceRange& loc, GraphFunction& m)
|
||||
override;
|
||||
|
||||
// When Range is instantiated via enumerate(iterable_with_static_len),
|
||||
@ -665,7 +669,7 @@ struct TORCH_API IterableTree : SugaredValue {
|
||||
IterableTree() = default;
|
||||
IterableTree(
|
||||
const SourceRange& range,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
at::ArrayRef<SugaredValuePtr> children) {
|
||||
for (const auto& child : children) {
|
||||
addChild(range, m, child);
|
||||
@ -675,14 +679,14 @@ struct TORCH_API IterableTree : SugaredValue {
|
||||
return "iterabletree";
|
||||
}
|
||||
|
||||
std::shared_ptr<SugaredValue> iter(const SourceRange& loc, Function& m)
|
||||
std::shared_ptr<SugaredValue> iter(const SourceRange& loc, GraphFunction& m)
|
||||
override {
|
||||
return shared_from_this();
|
||||
}
|
||||
|
||||
void addChild(
|
||||
const SourceRange& range,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const SugaredValuePtr& iter_value);
|
||||
|
||||
std::vector<SugaredValuePtr> get_children() {
|
||||
@ -701,10 +705,10 @@ struct TORCH_API IterableTree : SugaredValue {
|
||||
// with len() and getitem()
|
||||
std::vector<SugaredValuePtr> get_base_iterables();
|
||||
|
||||
Value* len(const SourceRange& loc, Function& m) override;
|
||||
Value* len(const SourceRange& loc, GraphFunction& m) override;
|
||||
SugaredValuePtr getitem(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
Value* idx,
|
||||
TypePtr type_hint = nullptr) override;
|
||||
|
||||
@ -759,7 +763,7 @@ struct TORCH_API ExceptionValue : public SugaredValue {
|
||||
|
||||
std::shared_ptr<SugaredValue> call(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
at::ArrayRef<NamedValue> args,
|
||||
at::ArrayRef<NamedValue> /*attributes*/,
|
||||
size_t /*n_binders*/) override {
|
||||
@ -789,10 +793,10 @@ struct TORCH_API SugaredEnumClass : public SugaredValue {
|
||||
|
||||
SugaredValuePtr attr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field) override;
|
||||
|
||||
SugaredValuePtr iter(const SourceRange& loc, Function& m) override;
|
||||
SugaredValuePtr iter(const SourceRange& loc, GraphFunction& m) override;
|
||||
|
||||
private:
|
||||
EnumTypePtr enum_type_;
|
||||
|
@ -2087,7 +2087,7 @@ void inlineCallStackOfNode(
|
||||
// ONNX conversion
|
||||
std::vector<Value*> inlineCallTo(
|
||||
Node* to_replace,
|
||||
Function* callee,
|
||||
GraphFunction* callee,
|
||||
bool inline_optimized_graph /*=true*/) {
|
||||
WithInsertPoint guard(to_replace);
|
||||
TORCH_INTERNAL_ASSERT(callee->isGraphFunction());
|
||||
|
@ -82,6 +82,7 @@ using namespace ::c10::cuda;
|
||||
} // namespace cuda
|
||||
|
||||
struct Function;
|
||||
struct GraphFunction;
|
||||
struct MatchedSchema;
|
||||
|
||||
// A Graph represents one "function" of computation.
|
||||
@ -1541,7 +1542,7 @@ TORCH_API std::vector<Value*> insertGraph(
|
||||
*/
|
||||
TORCH_API std::vector<Value*> inlineCallTo(
|
||||
Node* to_replace,
|
||||
Function* callee,
|
||||
GraphFunction* callee,
|
||||
bool use_graph = true);
|
||||
|
||||
/** If there is only one value in \p OUTPUTS and its kind is Tuple, insert a
|
||||
|
@ -106,7 +106,7 @@ bool DecomposeOps(Block* block, CompilationUnit& decompose_funcs) {
|
||||
decomposed = true;
|
||||
WithInsertPoint guard(*it);
|
||||
std::shared_ptr<Graph> d_graph =
|
||||
decompose_funcs.get_function("addmm").graph();
|
||||
toGraphFunction(decompose_funcs.get_function("addmm")).graph();
|
||||
Value* new_output =
|
||||
insertGraph(*it->owningGraph(), *d_graph, it->inputs()).at(0);
|
||||
// Set the output of the decomposed graph to have the same output type as
|
||||
@ -136,7 +136,7 @@ bool DecomposeOps(Block* block, CompilationUnit& decompose_funcs) {
|
||||
|
||||
// inline the compiled decomposed batchnorm
|
||||
std::shared_ptr<Graph> d_graph =
|
||||
decompose_funcs.get_function("batch_norm").graph();
|
||||
toGraphFunction(decompose_funcs.get_function("batch_norm")).graph();
|
||||
Value* new_output = insertGraph(*graph, *d_graph, inputs).at(0);
|
||||
|
||||
// post processing the graph
|
||||
@ -171,7 +171,7 @@ bool DecomposeOps(Block* block, CompilationUnit& decompose_funcs) {
|
||||
|
||||
// inline the compiled decomposed layernorm
|
||||
std::shared_ptr<Graph> d_graph =
|
||||
decompose_funcs.get_function("layer_norm").graph();
|
||||
toGraphFunction(decompose_funcs.get_function("layer_norm")).graph();
|
||||
Value* new_output = insertGraph(*graph, *d_graph, inputs).at(0);
|
||||
|
||||
// post processing the graph
|
||||
|
@ -3,6 +3,7 @@
|
||||
#include <torch/csrc/jit/jit_log.h>
|
||||
|
||||
#include <c10/util/irange.h>
|
||||
#include <torch/csrc/jit/api/function_impl.h>
|
||||
#include <torch/csrc/jit/ir/alias_analysis.h>
|
||||
#include <torch/csrc/jit/passes/clear_profiling.h>
|
||||
#include <torch/csrc/jit/passes/inliner.h>
|
||||
@ -108,7 +109,7 @@ class AttributePropagator {
|
||||
|
||||
for (auto function : preservedMethods_) {
|
||||
GRAPH_DEBUG("Analyzing function: " + function->name());
|
||||
auto graph = function->graph();
|
||||
auto graph = toGraphFunction(*function).graph();
|
||||
optimizeSubGraphs(graph, applyInline);
|
||||
if (freezeInterfaces_) {
|
||||
inlineInterfaceCalls(graph);
|
||||
@ -120,7 +121,7 @@ class AttributePropagator {
|
||||
|
||||
for (auto function : preservedMethods_) {
|
||||
GRAPH_DEBUG("Propagating function: " + function->name());
|
||||
auto graph = function->graph();
|
||||
auto graph = toGraphFunction(*function).graph();
|
||||
propagateAttributes(graph);
|
||||
optimizeSubGraphs(graph, applyOptimizations);
|
||||
}
|
||||
@ -412,18 +413,17 @@ class AttributePropagator {
|
||||
if (user_node->kind() == prim::CallMethod) {
|
||||
const std::string& methodName = user_node->s(attr::name);
|
||||
Function& function = class_type->getMethod(methodName);
|
||||
if (!function.isGraphFunction()) {
|
||||
continue;
|
||||
}
|
||||
GRAPH_UPDATE(
|
||||
"Inlining interface method '",
|
||||
function.name(),
|
||||
"' to ",
|
||||
*user_node);
|
||||
if (auto graphFunction = tryToGraphFunction(function)) {
|
||||
GRAPH_UPDATE(
|
||||
"Inlining interface method '",
|
||||
function.name(),
|
||||
"' to ",
|
||||
*user_node);
|
||||
|
||||
GRAPH_UPDATE("Function body: ", *function.optimized_graph());
|
||||
inlineCallTo(user_node, &function);
|
||||
inlined = true;
|
||||
GRAPH_UPDATE("Function body: ", *function.optimized_graph());
|
||||
inlineCallTo(user_node, graphFunction);
|
||||
inlined = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return inlined;
|
||||
@ -643,7 +643,7 @@ class AttributePropagator {
|
||||
// 3) Remove non public unreferenced methods.
|
||||
void cleanupFrozenModule() {
|
||||
for (auto function : preservedMethods_) {
|
||||
auto graph = function->graph();
|
||||
auto graph = toGraphFunction(*function).graph();
|
||||
recordReferencedAttrs(graph);
|
||||
handleSharedClassType(module_, graph);
|
||||
removeExtraWaitCalls(graph->block());
|
||||
|
@ -1,5 +1,6 @@
|
||||
#include <torch/csrc/jit/passes/inliner.h>
|
||||
|
||||
#include <torch/csrc/jit/api/function_impl.h>
|
||||
#include <torch/csrc/jit/api/module.h>
|
||||
#include <torch/csrc/jit/frontend/error_report.h>
|
||||
#include <torch/csrc/jit/jit_log.h>
|
||||
@ -21,26 +22,28 @@ void inlineCalls(Block* block) {
|
||||
auto function_constant = cur->input(0)->node();
|
||||
auto fun_type =
|
||||
function_constant->output()->type()->expect<FunctionType>();
|
||||
if (!fun_type->function()->isGraphFunction()) {
|
||||
continue;
|
||||
|
||||
if (auto graphFunction = tryToGraphFunction(*fun_type->function())) {
|
||||
cur->removeInput(0);
|
||||
GRAPH_UPDATE(
|
||||
"Inlining function '",
|
||||
fun_type->function()->name(),
|
||||
"' to ",
|
||||
*cur);
|
||||
GRAPH_UPDATE(
|
||||
"Function body: ", *fun_type->function()->optimized_graph());
|
||||
inlineCallTo(cur, graphFunction);
|
||||
}
|
||||
cur->removeInput(0);
|
||||
GRAPH_UPDATE(
|
||||
"Inlining function '", fun_type->function()->name(), "' to ", *cur);
|
||||
GRAPH_UPDATE(
|
||||
"Function body: ", *fun_type->function()->optimized_graph());
|
||||
inlineCallTo(cur, fun_type->function());
|
||||
} break;
|
||||
case prim::CallMethod: {
|
||||
const std::string& name = cur->s(attr::name);
|
||||
if (auto class_type = cur->input(0)->type()->cast<ClassType>()) {
|
||||
Function& function = class_type->getMethod(name);
|
||||
if (!function.isGraphFunction()) {
|
||||
continue;
|
||||
if (auto graphFunction = tryToGraphFunction(function)) {
|
||||
GRAPH_UPDATE("Inlining method '", function.name(), "' to ", *cur);
|
||||
GRAPH_UPDATE("Function body: ", *function.optimized_graph());
|
||||
inlineCallTo(cur, graphFunction);
|
||||
}
|
||||
GRAPH_UPDATE("Inlining method '", function.name(), "' to ", *cur);
|
||||
GRAPH_UPDATE("Function body: ", *function.optimized_graph());
|
||||
inlineCallTo(cur, &function);
|
||||
}
|
||||
} break;
|
||||
default: {
|
||||
|
@ -51,19 +51,19 @@ void functionCallSubstitution(Block* block) {
|
||||
if (!input_node_0->hasUses()) {
|
||||
input_node_0->destroy();
|
||||
}
|
||||
functionCallSubstitution(fun_type->function()->graph()->block());
|
||||
inlineCallTo(cur, fun_type->function(), false);
|
||||
auto& graphFunction = toGraphFunction(*fun_type->function());
|
||||
functionCallSubstitution(graphFunction.graph()->block());
|
||||
inlineCallTo(cur, &graphFunction, false);
|
||||
}
|
||||
} break;
|
||||
case prim::CallMethod: {
|
||||
const std::string& name = cur->s(attr::name);
|
||||
if (auto class_type = cur->input(0)->type()->cast<ClassType>()) {
|
||||
Function& function = class_type->getMethod(name);
|
||||
if (!function.isGraphFunction()) {
|
||||
continue;
|
||||
if (auto graphFunction = tryToGraphFunction(function)) {
|
||||
functionCallSubstitution(graphFunction->graph()->block());
|
||||
inlineCallTo(cur, graphFunction, false);
|
||||
}
|
||||
functionCallSubstitution(function.graph()->block());
|
||||
inlineCallTo(cur, &function, false);
|
||||
}
|
||||
} break;
|
||||
default: {
|
||||
|
@ -61,7 +61,8 @@ Value* addParamAsArgument(Function* function, std::string& name, IValue& attr) {
|
||||
schema.is_vararg(),
|
||||
schema.is_varret());
|
||||
function->setSchema(new_schema);
|
||||
return function->graph()->addInput(name)->setType(attr.type());
|
||||
return toGraphFunction(*function).graph()->addInput(name)->setType(
|
||||
attr.type());
|
||||
}
|
||||
|
||||
std::vector<IValue> getParamAttributes(
|
||||
@ -177,7 +178,7 @@ std::pair<Module, std::vector<IValue>> list_module_parameters(
|
||||
Module moduleClone = module.clone(true);
|
||||
Method method = moduleClone.get_method("forward");
|
||||
auto function = &method.function();
|
||||
auto graph = function->graph();
|
||||
auto graph = toGraphFunction(*function).graph();
|
||||
// A map of names and values of referenced attributes, to avoid duplicates.
|
||||
std::unordered_map<std::string, Value*> attrValues = {};
|
||||
|
||||
|
@ -89,7 +89,7 @@ Module Finalize(
|
||||
// To prevent the JIT optimizations from leveraging the annotated shape info,
|
||||
// clear shape information in the graph.
|
||||
for (auto func : module.type()->methods()) {
|
||||
ClearProfilingInformation(func->graph());
|
||||
ClearProfilingInformation(toGraphFunction(*func).graph());
|
||||
}
|
||||
|
||||
auto graph = module.get_method("forward").graph();
|
||||
|
@ -1,5 +1,6 @@
|
||||
#include <torch/csrc/jit/passes/quantization/helper.h>
|
||||
|
||||
#include <torch/csrc/jit/api/function_impl.h>
|
||||
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
|
||||
|
||||
namespace torch {
|
||||
@ -533,9 +534,9 @@ bool useQuantizable(const Use& use, QuantType quant_type) {
|
||||
std::shared_ptr<Graph> getCallFunctionGraph(Node* n) {
|
||||
auto* func_node = n->input(0)->node();
|
||||
auto func = func_node->output()->type()->expectRef<FunctionType>().function();
|
||||
TORCH_CHECK(
|
||||
func->isGraphFunction(), "Quantization only works for graph function");
|
||||
return func->graph();
|
||||
auto graphFunc = tryToGraphFunction(*func);
|
||||
TORCH_CHECK(graphFunc, "Quantization only works for graph function");
|
||||
return graphFunc->graph();
|
||||
}
|
||||
|
||||
// Block helper functions
|
||||
|
@ -263,7 +263,7 @@ class ModuleCloneHelper {
|
||||
}
|
||||
return type_ptr;
|
||||
};
|
||||
auto graph = method.graph()->copy();
|
||||
auto graph = toGraphFunction(method).graph()->copy();
|
||||
remapTypes(graph.get(), source, target, module_qconfig_map, type_remap_fn);
|
||||
// remap self
|
||||
graph->inputs()[0]->setType(target.type());
|
||||
|
@ -931,7 +931,7 @@ std::unique_ptr<GraphFunction> SubGraphCloneHelper::buildGraphFromNodes(
|
||||
const std::vector<Node*>& nodes,
|
||||
const std::string& name) {
|
||||
auto observer_subgraph = std::make_shared<Graph>();
|
||||
auto build_observer_graph = [&](Function& func) {
|
||||
auto build_observer_graph = [&](GraphFunction& func) {
|
||||
buildObserverSubgraph(nodes, func.graph());
|
||||
};
|
||||
return torch::make_unique<GraphFunction>(
|
||||
|
@ -61,7 +61,7 @@ void SubgraphRewriter::RegisterRewritePattern(
|
||||
Module SubgraphRewriter::runOnModule(const Module& module) {
|
||||
nodes_to_delete_.clear();
|
||||
for (const auto& m : module.get_methods()) {
|
||||
auto g = m.function().graph();
|
||||
auto g = toGraphFunction(m.function()).graph();
|
||||
runOnGraph(g);
|
||||
}
|
||||
return module;
|
||||
|
@ -979,7 +979,7 @@ inline py::object runAndInsertCall(
|
||||
// and then run the callee with tracing disabled.
|
||||
|
||||
// Get the graph `Value`s that represent the input IValues
|
||||
auto inputs = last(stack, callee.graph()->inputs().size());
|
||||
auto inputs = last(stack, toGraphFunction(callee).graph()->inputs().size());
|
||||
auto input_values =
|
||||
fmap(inputs, [](const IValue& v) { return tracer::getValueTrace(v); });
|
||||
TORCH_INTERNAL_ASSERT(callee.getSchema().returns().size() == 1)
|
||||
|
@ -114,7 +114,7 @@ FunctionSchema PythonValue::getSchema(
|
||||
|
||||
std::shared_ptr<SugaredValue> PythonValue::call(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
at::ArrayRef<NamedValue> args,
|
||||
at::ArrayRef<NamedValue> kwargs,
|
||||
size_t n_binders) {
|
||||
@ -168,7 +168,7 @@ std::string PythonValue::kind() const {
|
||||
|
||||
std::vector<std::shared_ptr<SugaredValue>> PythonValue::asTuple(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const c10::optional<size_t>& size_hint) {
|
||||
const std::string type_str = typeString(self);
|
||||
std::stringstream ss;
|
||||
@ -179,7 +179,7 @@ std::vector<std::shared_ptr<SugaredValue>> PythonValue::asTuple(
|
||||
|
||||
std::shared_ptr<SugaredValue> PythonValue::attr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field) {
|
||||
const std::string type_str = typeString(self);
|
||||
std::stringstream ss;
|
||||
@ -208,7 +208,7 @@ void PythonValue::checkForAddToConstantsError(std::stringstream& ss) {
|
||||
|
||||
std::shared_ptr<SugaredValue> PythonModuleValue::attr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field) {
|
||||
py::object member = getattr(loc, field);
|
||||
// note: is_constant = true because we consider that global properties
|
||||
@ -220,7 +220,7 @@ std::shared_ptr<SugaredValue> PythonModuleValue::attr(
|
||||
#if !defined(USE_ROCM)
|
||||
std::shared_ptr<SugaredValue> CUDAPythonModuleValue::attr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field) {
|
||||
// List of all the cuda operators which are supported in JIT
|
||||
const std::unordered_set<std::string> cuda_ops = {
|
||||
@ -259,11 +259,13 @@ std::shared_ptr<SugaredValue> CUDAPythonModuleValue::attr(
|
||||
}
|
||||
#endif
|
||||
|
||||
Value* ModuleValue::asValue(const SourceRange& loc, Function& m) {
|
||||
Value* ModuleValue::asValue(const SourceRange& loc, GraphFunction& m) {
|
||||
return self_;
|
||||
}
|
||||
|
||||
SugaredValuePtr ModuleValue::asTupleValue(const SourceRange& loc, Function& m) {
|
||||
SugaredValuePtr ModuleValue::asTupleValue(
|
||||
const SourceRange& loc,
|
||||
GraphFunction& m) {
|
||||
if (concreteType_->getIterableModuleKind() == IterableModuleKind::LIST) {
|
||||
auto dict = getSugaredDict(loc, m);
|
||||
auto mods = dict->getModules();
|
||||
@ -298,7 +300,7 @@ bool ModuleValue::areAllSubmodulesSubtypeOf(
|
||||
|
||||
SugaredValuePtr ModuleValue::getitem(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
Value* idx,
|
||||
TypePtr type_hint) {
|
||||
if (concreteType_->getIterableModuleKind() == IterableModuleKind::LIST) {
|
||||
@ -365,7 +367,7 @@ SugaredValuePtr ModuleValue::getitem(
|
||||
|
||||
void checkInterface(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::shared_ptr<ModuleValue>& self,
|
||||
const std::string& field) {
|
||||
if (self->asValue(loc, m)->type()->cast<InterfaceType>()) {
|
||||
@ -377,7 +379,7 @@ void checkInterface(
|
||||
|
||||
void recurseThroughNestedModules(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
std::vector<SugaredValuePtr>& keys,
|
||||
std::vector<SugaredValuePtr>& values,
|
||||
std::shared_ptr<ModuleValue>& self,
|
||||
@ -413,7 +415,7 @@ void recurseThroughNestedModules(
|
||||
|
||||
std::shared_ptr<SugaredDict> ModuleValue::getSugaredNamedBufferDict(
|
||||
const SourceRange& loc,
|
||||
Function& m) {
|
||||
GraphFunction& m) {
|
||||
std::vector<std::string> paramNames;
|
||||
std::vector<SugaredValuePtr> values;
|
||||
|
||||
@ -441,7 +443,7 @@ std::shared_ptr<SugaredDict> ModuleValue::getSugaredNamedBufferDict(
|
||||
|
||||
std::shared_ptr<SugaredDict> ModuleValue::getSugaredDict(
|
||||
const SourceRange& loc,
|
||||
Function& m) {
|
||||
GraphFunction& m) {
|
||||
std::vector<std::string> submoduleNames;
|
||||
const auto& selfType = concreteType_->getJitType()->expect<ClassType>();
|
||||
for (size_t i = 0; i < selfType->numAttributes(); ++i) {
|
||||
@ -472,7 +474,7 @@ std::shared_ptr<SugaredDict> ModuleValue::getSugaredDict(
|
||||
|
||||
std::shared_ptr<SugaredValue> SugaredDict::attr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field) {
|
||||
// Recursive compilation does not maintain module aliasing,
|
||||
// so we do not add uniqueness checks on
|
||||
@ -508,7 +510,7 @@ std::shared_ptr<SugaredValue> SugaredDict::attr(
|
||||
|
||||
std::shared_ptr<SugaredEnumClass> createSugaredEnumClassFromObj(
|
||||
const py::object& obj,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const SourceRange& loc) {
|
||||
auto annotation_type = py::module::import("torch.jit.annotations")
|
||||
.attr("try_ann_to_type")(obj, loc);
|
||||
@ -521,7 +523,7 @@ std::shared_ptr<SugaredEnumClass> createSugaredEnumClassFromObj(
|
||||
// helper function for instantiating a SugaredValue from an IValue
|
||||
std::shared_ptr<SugaredValue> toSugaredValue(
|
||||
const IValue& v,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const SourceRange& loc) {
|
||||
if (v.isTuple()) {
|
||||
auto tp = v.toTuple();
|
||||
@ -540,7 +542,7 @@ std::shared_ptr<SugaredValue> toSugaredValue(
|
||||
// This method controls how we desugar attribute lookups on ScriptModules
|
||||
std::shared_ptr<SugaredValue> ModuleValue::tryGetAttr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field) {
|
||||
// 1. Look inside Module object for the field.
|
||||
const auto& selfType_ = concreteType_->getJitType();
|
||||
@ -661,14 +663,14 @@ std::shared_ptr<SugaredValue> ModuleValue::tryGetAttr(
|
||||
|
||||
bool ModuleValue::hasAttr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field) {
|
||||
return tryGetAttr(loc, m, field) != nullptr;
|
||||
}
|
||||
|
||||
std::shared_ptr<SugaredValue> ModuleValue::call(
|
||||
const SourceRange& loc,
|
||||
Function& caller,
|
||||
GraphFunction& caller,
|
||||
at::ArrayRef<NamedValue> args,
|
||||
at::ArrayRef<NamedValue> kwargs,
|
||||
size_t n_binders) {
|
||||
@ -759,7 +761,7 @@ std::shared_ptr<SugaredValue> ModuleValue::call(
|
||||
// This method controls how we desugar attribute lookups on ScriptModules.
|
||||
std::shared_ptr<SugaredValue> ModuleValue::attr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field) {
|
||||
if (auto attr = tryGetAttr(loc, m, field)) {
|
||||
return attr;
|
||||
@ -788,7 +790,7 @@ std::shared_ptr<SugaredValue> ModuleValue::attr(
|
||||
<< " has no attribute '" << field << "' " << hint;
|
||||
}
|
||||
|
||||
SugaredValuePtr ModuleValue::iter(const SourceRange& loc, Function& m) {
|
||||
SugaredValuePtr ModuleValue::iter(const SourceRange& loc, GraphFunction& m) {
|
||||
const auto iterableModuleKind = concreteType_->getIterableModuleKind();
|
||||
if (iterableModuleKind == IterableModuleKind::NONE) {
|
||||
throw ErrorReport(loc)
|
||||
@ -807,7 +809,7 @@ SugaredValuePtr ModuleValue::iter(const SourceRange& loc, Function& m) {
|
||||
|
||||
std::shared_ptr<SugaredValue> PythonClassValue::attr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field) {
|
||||
// Resolve values from the Python object first (e.g. for static methods on
|
||||
// this type, resolve them as functions)
|
||||
@ -824,7 +826,7 @@ std::shared_ptr<SugaredValue> PythonClassValue::attr(
|
||||
|
||||
bool PythonClassValue::hasAttr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field) {
|
||||
try {
|
||||
py::getattr(py_type_, field.c_str());
|
||||
@ -836,7 +838,7 @@ bool PythonClassValue::hasAttr(
|
||||
|
||||
void ModuleValue::setAttr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field,
|
||||
Value* newValue) {
|
||||
// Forward to SimpleValue::setAttr
|
||||
@ -846,7 +848,7 @@ void ModuleValue::setAttr(
|
||||
|
||||
std::shared_ptr<SugaredValue> BooleanDispatchValue::call(
|
||||
const SourceRange& loc,
|
||||
Function& caller,
|
||||
GraphFunction& caller,
|
||||
at::ArrayRef<NamedValue> args,
|
||||
at::ArrayRef<NamedValue> kwargs,
|
||||
size_t n_binders) {
|
||||
@ -888,7 +890,7 @@ std::shared_ptr<SugaredValue> BooleanDispatchValue::call(
|
||||
|
||||
std::shared_ptr<SugaredValue> PythonExceptionValue::call(
|
||||
const SourceRange& loc,
|
||||
Function& caller,
|
||||
GraphFunction& caller,
|
||||
at::ArrayRef<NamedValue> args,
|
||||
at::ArrayRef<NamedValue> kwargs,
|
||||
size_t /*n_binders*/) {
|
||||
@ -984,7 +986,7 @@ bool isEnumClass(py::object obj) {
|
||||
|
||||
std::shared_ptr<SugaredValue> createSimpleEnumValue(
|
||||
const py::object& obj,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const SourceRange& loc) {
|
||||
auto enum_class = obj.attr("__class__");
|
||||
auto enum_type =
|
||||
@ -996,7 +998,7 @@ std::shared_ptr<SugaredValue> createSimpleEnumValue(
|
||||
|
||||
std::shared_ptr<SugaredValue> PythonSliceClass::call(
|
||||
const SourceRange& loc,
|
||||
Function& caller,
|
||||
GraphFunction& caller,
|
||||
at::ArrayRef<NamedValue> args,
|
||||
at::ArrayRef<NamedValue> kwargs,
|
||||
size_t /*n_binders*/) {
|
||||
@ -1046,7 +1048,7 @@ std::shared_ptr<SugaredValue> PythonSliceClass::call(
|
||||
|
||||
std::shared_ptr<SugaredValue> toSugaredValue(
|
||||
py::object obj,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const SourceRange& loc,
|
||||
bool is_constant) {
|
||||
// directly create SimpleValues when possible, because they are first-class
|
||||
|
@ -24,7 +24,7 @@ inline std::shared_ptr<SugaredValue> toSimple(Value* v) {
|
||||
// type, *add it in this function's implementation*.
|
||||
std::shared_ptr<SugaredValue> toSugaredValue(
|
||||
py::object obj,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const SourceRange& loc,
|
||||
bool is_constant = false);
|
||||
|
||||
@ -47,7 +47,7 @@ struct VISIBILITY_HIDDEN PythonValue : public SugaredValue {
|
||||
// call it like a function, e.g. `outputs = this(inputs)`
|
||||
std::shared_ptr<SugaredValue> call(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
at::ArrayRef<NamedValue> args,
|
||||
at::ArrayRef<NamedValue> kwargs,
|
||||
size_t n_binders) override;
|
||||
@ -56,15 +56,15 @@ struct VISIBILITY_HIDDEN PythonValue : public SugaredValue {
|
||||
|
||||
std::vector<std::shared_ptr<SugaredValue>> asTuple(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const c10::optional<size_t>& size_hint = {}) override;
|
||||
|
||||
std::shared_ptr<SugaredValue> attr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field) override;
|
||||
|
||||
Value* asValue(const SourceRange& loc, Function& m) override {
|
||||
Value* asValue(const SourceRange& loc, GraphFunction& m) override {
|
||||
throw ErrorReport(loc)
|
||||
<< kind() << " cannot be used as a value. "
|
||||
<< "Perhaps it is a closed over global variable? If so, please "
|
||||
@ -90,7 +90,7 @@ struct VISIBILITY_HIDDEN PythonModuleValue : public PythonValue {
|
||||
|
||||
std::shared_ptr<SugaredValue> attr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field) override;
|
||||
};
|
||||
|
||||
@ -103,7 +103,7 @@ struct VISIBILITY_HIDDEN CUDAPythonModuleValue : public PythonValue {
|
||||
|
||||
std::shared_ptr<SugaredValue> attr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field) override;
|
||||
};
|
||||
#endif
|
||||
@ -116,7 +116,7 @@ struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue {
|
||||
}
|
||||
std::shared_ptr<SugaredValue> call(
|
||||
const SourceRange& loc,
|
||||
Function& caller,
|
||||
GraphFunction& caller,
|
||||
at::ArrayRef<NamedValue> args,
|
||||
at::ArrayRef<NamedValue> kwargs,
|
||||
size_t n_binders) override {
|
||||
@ -137,7 +137,7 @@ struct VISIBILITY_HIDDEN ModuleDictMethod : public SugaredValue {
|
||||
|
||||
std::shared_ptr<SugaredValue> call(
|
||||
const SourceRange& loc,
|
||||
Function& f,
|
||||
GraphFunction& f,
|
||||
at::ArrayRef<NamedValue> args,
|
||||
at::ArrayRef<NamedValue> kwargs,
|
||||
size_t n_binders) override {
|
||||
@ -169,53 +169,56 @@ struct VISIBILITY_HIDDEN ModuleValue : public SugaredValue {
|
||||
return "module";
|
||||
}
|
||||
|
||||
Value* asValue(const SourceRange& loc, Function& m) override;
|
||||
Value* asValue(const SourceRange& loc, GraphFunction& m) override;
|
||||
|
||||
SugaredValuePtr asTupleValue(const SourceRange& loc, Function& m) override;
|
||||
SugaredValuePtr asTupleValue(const SourceRange& loc, GraphFunction& m)
|
||||
override;
|
||||
|
||||
// select an attribute on it, e.g. `this.field`
|
||||
std::shared_ptr<SugaredValue> tryGetAttr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field);
|
||||
|
||||
// select an attribute on it, e.g. `this.field`
|
||||
std::shared_ptr<SugaredValue> attr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field) override;
|
||||
|
||||
// select an attribute on it, e.g. `this.field`
|
||||
bool hasAttr(const SourceRange& loc, Function& m, const std::string& field)
|
||||
override;
|
||||
bool hasAttr(
|
||||
const SourceRange& loc,
|
||||
GraphFunction& m,
|
||||
const std::string& field) override;
|
||||
|
||||
// call module.forward with pre_hooks and hooks
|
||||
std::shared_ptr<SugaredValue> call(
|
||||
const SourceRange& loc,
|
||||
Function& caller,
|
||||
GraphFunction& caller,
|
||||
at::ArrayRef<NamedValue> args,
|
||||
at::ArrayRef<NamedValue> kwargs,
|
||||
size_t n_binders) override;
|
||||
|
||||
std::shared_ptr<SugaredDict> getSugaredDict(
|
||||
const SourceRange& loc,
|
||||
Function& m);
|
||||
GraphFunction& m);
|
||||
|
||||
std::shared_ptr<SugaredDict> getSugaredNamedBufferDict(
|
||||
const SourceRange& loc,
|
||||
Function& m);
|
||||
GraphFunction& m);
|
||||
|
||||
void setAttr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field,
|
||||
Value* newValue) override;
|
||||
|
||||
SugaredValuePtr iter(const SourceRange& loc, Function& m) override;
|
||||
SugaredValuePtr iter(const SourceRange& loc, GraphFunction& m) override;
|
||||
|
||||
std::shared_ptr<SugaredValue> getitem(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
Value* idx,
|
||||
TypePtr type_hint) override;
|
||||
|
||||
@ -237,7 +240,7 @@ TypePtr registerNamedTuple(const py::object& obj, const SourceRange& loc);
|
||||
|
||||
void recurseThroughNestedModules(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
std::vector<SugaredValuePtr>& keys,
|
||||
std::vector<SugaredValuePtr>& values,
|
||||
std::shared_ptr<ModuleValue>& self,
|
||||
@ -269,10 +272,10 @@ struct VISIBILITY_HIDDEN SugaredDict : public SugaredValue {
|
||||
|
||||
std::shared_ptr<SugaredValue> attr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field) override;
|
||||
|
||||
SugaredValuePtr iter(const SourceRange& loc, Function& m) override {
|
||||
SugaredValuePtr iter(const SourceRange& loc, GraphFunction& m) override {
|
||||
return keys_;
|
||||
};
|
||||
|
||||
@ -291,7 +294,7 @@ struct VISIBILITY_HIDDEN BooleanDispatchValue : public SugaredValue {
|
||||
|
||||
std::shared_ptr<SugaredValue> call(
|
||||
const SourceRange& loc,
|
||||
Function& caller,
|
||||
GraphFunction& caller,
|
||||
at::ArrayRef<NamedValue> args,
|
||||
at::ArrayRef<NamedValue> kwargs,
|
||||
size_t n_binders) override;
|
||||
@ -310,11 +313,13 @@ struct VISIBILITY_HIDDEN PythonClassValue : public ClassValue {
|
||||
|
||||
std::shared_ptr<SugaredValue> attr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field) override;
|
||||
|
||||
bool hasAttr(const SourceRange& loc, Function& m, const std::string& field)
|
||||
override;
|
||||
bool hasAttr(
|
||||
const SourceRange& loc,
|
||||
GraphFunction& m,
|
||||
const std::string& field) override;
|
||||
|
||||
private:
|
||||
py::object py_type_;
|
||||
@ -331,7 +336,7 @@ struct VISIBILITY_HIDDEN PythonExceptionValue : public ExceptionValue {
|
||||
|
||||
std::shared_ptr<SugaredValue> call(
|
||||
const SourceRange& loc,
|
||||
Function& caller,
|
||||
GraphFunction& caller,
|
||||
at::ArrayRef<NamedValue> args,
|
||||
at::ArrayRef<NamedValue> kwargs,
|
||||
size_t n_binders) override;
|
||||
@ -347,7 +352,7 @@ struct VISIBILITY_HIDDEN PythonSliceClass : public SugaredValue {
|
||||
|
||||
std::shared_ptr<SugaredValue> call(
|
||||
const SourceRange& loc,
|
||||
Function& caller,
|
||||
GraphFunction& caller,
|
||||
at::ArrayRef<NamedValue> args,
|
||||
at::ArrayRef<NamedValue> kwargs,
|
||||
size_t n_binders) override;
|
||||
|
@ -92,7 +92,7 @@ struct PythonResolver : public Resolver {
|
||||
|
||||
std::shared_ptr<SugaredValue> resolveValue(
|
||||
const std::string& name,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const SourceRange& loc) override {
|
||||
pybind11::gil_scoped_acquire ag;
|
||||
py::object obj = rcb_(name);
|
||||
@ -531,7 +531,7 @@ static std::shared_ptr<Graph> _propagate_and_assign_input_shapes(
|
||||
|
||||
void addFunctionToModule(Module& module, const StrongFunctionPtr& func) {
|
||||
// Make a graph with a fake self argument
|
||||
auto graph = func.function_->graph()->copy();
|
||||
auto graph = toGraphFunction(*func.function_).graph()->copy();
|
||||
auto v = graph->insertInput(0, "self");
|
||||
v->setType(module._ivalue()->type());
|
||||
const auto name = QualifiedName(*module.type()->name(), "forward");
|
||||
@ -1414,11 +1414,13 @@ void initJitScriptBindings(PyObject* module) {
|
||||
py::arg("_extra_files") = ExtraFilesMap())
|
||||
.def_property_readonly(
|
||||
"graph",
|
||||
[](const StrongFunctionPtr& self) { return self.function_->graph(); })
|
||||
[](const StrongFunctionPtr& self) {
|
||||
return toGraphFunction(*self.function_).graph();
|
||||
})
|
||||
.def_property_readonly(
|
||||
"inlined_graph",
|
||||
[](const StrongFunctionPtr& self) {
|
||||
auto g = self.function_->graph()->copy();
|
||||
auto g = toGraphFunction(*self.function_).graph()->copy();
|
||||
Inline(*g);
|
||||
return g;
|
||||
})
|
||||
@ -1479,7 +1481,7 @@ void initJitScriptBindings(PyObject* module) {
|
||||
.def_property_readonly(
|
||||
"inlined_graph",
|
||||
[](const Method& self) {
|
||||
auto g = self.function().graph()->copy();
|
||||
auto g = toGraphFunction(self.function()).graph()->copy();
|
||||
Inline(*g);
|
||||
return g;
|
||||
})
|
||||
|
@ -463,7 +463,7 @@ struct CodeImpl {
|
||||
TORCH_INTERNAL_ASSERT(bailout_index >= 0);
|
||||
|
||||
auto build_bailout_graph = [bailout_index,
|
||||
unoptimized_graph](Function& func) {
|
||||
unoptimized_graph](GraphFunction& func) {
|
||||
BuildBailOutGraphFrom(bailout_index, unoptimized_graph, func.graph());
|
||||
};
|
||||
|
||||
|
@ -721,7 +721,7 @@ GraphExecutorState ProfilingGraphExecutorImpl::getDebugState() {
|
||||
|
||||
Node* insertFallbackFunctionCall(
|
||||
Graph* graph,
|
||||
Function* func,
|
||||
GraphFunction* func,
|
||||
ArrayRef<Value*> inputs) {
|
||||
auto tuple_type = func->graph()->return_node()->input(0)->type();
|
||||
Value* fn_constant = graph->insertNode(graph->create(prim::Constant))
|
||||
@ -740,7 +740,7 @@ Node* insertFallbackFunctionCall(
|
||||
return fun_unpack_tuple;
|
||||
}
|
||||
|
||||
Function* createFallbackPathFunction(
|
||||
GraphFunction* createFallbackPathFunction(
|
||||
Block* b,
|
||||
const std::string& function_name) {
|
||||
auto value_map = [](Value* v) { return v; };
|
||||
|
@ -1546,7 +1546,7 @@ void loadModule(const CompilationUnit& module) {
|
||||
continue;
|
||||
|
||||
GradientPair pair;
|
||||
pair.forward = method->graph();
|
||||
pair.forward = toGraphFunction(*method).graph();
|
||||
|
||||
// lookup the backward function
|
||||
Node* forward_tuple = pair.forward->outputs().at(0)->node();
|
||||
|
@ -735,7 +735,8 @@ void loadModule(const CompilationUnit& module) {
|
||||
|
||||
Function& shape_compute_function =
|
||||
module.get_function(shape_compute_function_name);
|
||||
std::shared_ptr<Graph> graph = shape_compute_function.graph();
|
||||
std::shared_ptr<Graph> graph =
|
||||
toGraphFunction(shape_compute_function).graph();
|
||||
Inline(*graph);
|
||||
|
||||
// ATEN operators can return multiple unboxed values, this in contrast to
|
||||
|
@ -1,6 +1,7 @@
|
||||
#include <torch/csrc/jit/serialization/export.h>
|
||||
|
||||
#include <c10/util/Exception.h>
|
||||
#include <torch/csrc/jit/api/function_impl.h>
|
||||
#include <torch/csrc/jit/backends/backend_debug_handler.h>
|
||||
#include <torch/csrc/jit/backends/backend_debug_info.h>
|
||||
#include <torch/csrc/jit/frontend/source_range.h>
|
||||
@ -367,10 +368,10 @@ std::unordered_set<const FunctionSchema*> getInterfaceCalls(Graph& graph) {
|
||||
}
|
||||
|
||||
struct ModuleMethod {
|
||||
ModuleMethod(const Module& m, const Function& f, c10::QualifiedName n)
|
||||
ModuleMethod(const Module& m, const GraphFunction& f, c10::QualifiedName n)
|
||||
: module(m), function(f), exportName(std::move(n)) {}
|
||||
Module module;
|
||||
const Function& function;
|
||||
const GraphFunction& function;
|
||||
c10::QualifiedName exportName;
|
||||
};
|
||||
|
||||
@ -387,9 +388,9 @@ std::vector<ModuleMethod> getModuleInterfaceExports(
|
||||
std::vector<ModuleMethod> ret;
|
||||
for (const auto& submodule : module.modules()) {
|
||||
for (const auto& method : submodule.get_methods()) {
|
||||
if (names.find(method.function().qualname().name()) != names.end()) {
|
||||
ret.emplace_back(
|
||||
submodule, method.function(), method.function().qualname());
|
||||
const auto& f = toGraphFunction(method.function());
|
||||
if (names.find(f.qualname().name()) != names.end()) {
|
||||
ret.emplace_back(submodule, f, f.qualname());
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -441,8 +442,8 @@ void setstateTuple(
|
||||
if (exportSet.contains(qn)) {
|
||||
return;
|
||||
}
|
||||
if (setstate.isGraphFunction()) {
|
||||
exportFunction(exportSet, ModuleMethod{module, setstate, qn}, toplevel);
|
||||
if (auto f = tryToGraphFunction(setstate)) {
|
||||
exportFunction(exportSet, ModuleMethod{module, *f, qn}, toplevel);
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0, n = type->numAttributes(); i < n; ++i) {
|
||||
@ -545,10 +546,9 @@ BytecodeExportSet moduleMethodsTuple(
|
||||
auto methods = module.get_methods();
|
||||
// top level methods
|
||||
for (const auto& method : methods) {
|
||||
const auto& f = toGraphFunction(method.function());
|
||||
exportFunction(
|
||||
exportSet,
|
||||
ModuleMethod{module, method.function(), method.function().qualname()},
|
||||
/* toplevel */ true);
|
||||
exportSet, ModuleMethod{module, f, f.qualname()}, /* toplevel */ true);
|
||||
}
|
||||
|
||||
// __setstate__ of all components
|
||||
|
@ -19,7 +19,7 @@ struct OpsValue : public SugaredValue {
|
||||
}
|
||||
std::shared_ptr<SugaredValue> attr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field) override {
|
||||
return std::make_shared<BuiltinModule>(field, version_);
|
||||
}
|
||||
@ -42,7 +42,7 @@ struct TORCH_API ClassNamespaceValue : public SugaredValue {
|
||||
|
||||
std::shared_ptr<SugaredValue> attr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& name) override;
|
||||
std::string kind() const override {
|
||||
return "Class Namespace";
|
||||
@ -65,7 +65,7 @@ struct ConstantTableValue : public SugaredValue {
|
||||
// select an attribute on it, e.g. `this.field`
|
||||
std::shared_ptr<SugaredValue> attr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& field) override {
|
||||
const char* field_s = field.c_str();
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
@ -222,7 +222,7 @@ void SourceImporterImpl::LEGACY_import_methods(
|
||||
|
||||
std::shared_ptr<SugaredValue> SourceImporterImpl::resolveValue(
|
||||
const std::string& name,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const SourceRange& loc) {
|
||||
auto it = env_.find(name);
|
||||
if (it != env_.end()) {
|
||||
@ -720,7 +720,7 @@ void SourceImporterImpl::parseImports(Lexer& L) {
|
||||
|
||||
std::shared_ptr<SugaredValue> ClassNamespaceValue::attr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const std::string& name) {
|
||||
auto fullName = c10::QualifiedName(basename_, name);
|
||||
// Could be a ClassType or NamedTuple constructor
|
||||
|
@ -38,7 +38,7 @@ struct SourceImporterImpl : public Resolver,
|
||||
|
||||
std::shared_ptr<SugaredValue> resolveValue(
|
||||
const std::string& name,
|
||||
Function& m,
|
||||
GraphFunction& m,
|
||||
const SourceRange& loc) override;
|
||||
TypePtr resolveType(const std::string& name, const SourceRange& loc) override;
|
||||
|
||||
|
@ -1,9 +1,12 @@
|
||||
#include <torch/csrc/jit/serialization/python_print.h>
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include <ATen/core/qualified_name.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/StringUtil.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <torch/csrc/jit/api/function_impl.h>
|
||||
#include <torch/csrc/jit/api/module.h>
|
||||
#include <torch/csrc/jit/frontend/error_report.h>
|
||||
#include <torch/csrc/jit/frontend/versioned_symbols.h>
|
||||
@ -13,8 +16,6 @@
|
||||
#include <torch/csrc/jit/resource_guard.h>
|
||||
#include <torch/csrc/jit/runtime/calculate_necessary_args.h>
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
using c10::QualifiedName;
|
||||
|
||||
namespace torch {
|
||||
@ -1301,9 +1302,8 @@ struct PythonPrintImpl {
|
||||
void printFunction(
|
||||
const Function& func,
|
||||
bool print_first_argument_type = true) {
|
||||
TORCH_INTERNAL_ASSERT(func.isGraphFunction());
|
||||
const FunctionSchema& schema = func.getSchema();
|
||||
Graph& graph = *func.graph();
|
||||
Graph& graph = *toGraphFunction(func).graph();
|
||||
used_names_.clear(); // each graph can reuse local names
|
||||
|
||||
WithSourceRange guard(&source_range_stack_, graph.param_node());
|
||||
|
Reference in New Issue
Block a user