[jit] add inlined_graph method to ScriptFunctions (#33508)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33508

Ever since we switched to not inlining by default, some users have
complained since they relied on inlining occuring to, e.g. process the
graph with some other tool. Add an inlined_graph for convenience in
those cases.

Test Plan: Imported from OSS

Differential Revision: D19977638

Pulled By: suo

fbshipit-source-id: fe1fa92ff888959203d5d1995930d488b5f9e24c
This commit is contained in:
Michael Suo
2020-02-19 15:39:18 -08:00
committed by Facebook Github Bot
parent 5e80ca12bb
commit 416413dec4
3 changed files with 58 additions and 2 deletions

View File

@ -4014,6 +4014,39 @@ class TestFrontend(JitTestCase):
class TestScript(JitTestCase):
def test_inlined_graph(self):
"""
Check that the `inlined_graph` property correctly returns an inlined
graph, both through function calls and method calls.
"""
@torch.jit.script
def foo(x):
return torch.add(x, x)
class MyNestedMod(torch.nn.Module):
def __init__(self):
super(MyNestedMod, self).__init__()
def forward(self, x):
return torch.sub(x, x)
class MyMod(torch.nn.Module):
def __init__(self):
super(MyMod, self).__init__()
self.nested = MyNestedMod()
def forward(self, x):
x = self.nested(x) # sub
x = foo(x) # add
return torch.mul(x, x)
m = torch.jit.script(MyMod())
FileCheck().check("aten::sub") \
.check("aten::add") \
.check("aten::mul") \
.run(m.inlined_graph)
def test_oneline_func(self):
def fn(x): return x # noqa: E704

View File

@ -11,17 +11,18 @@
#include <torch/csrc/jit/testing/file_check.h>
#include <torch/csrc/jit/constants.h>
#include <torch/csrc/jit/export.h>
#include <torch/csrc/jit/graph_executor.h>
#include <torch/csrc/jit/hooks_for_testing.h>
#include <torch/csrc/jit/import_source.h>
#include <torch/csrc/jit/irparser.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/passes/python_print.h>
#include <torch/csrc/jit/pybind_utils.h>
#include <torch/csrc/jit/python_tracer.h>
#include <torch/csrc/jit/script/logging.h>
#include <torch/csrc/jit/script/parser.h>
#include <torch/csrc/jit/tracer.h>
#include <torch/csrc/jit/export.h>
#include <torch/csrc/api/include/torch/ordered_dict.h>
@ -939,7 +940,6 @@ void initJitScriptBindings(PyObject* module) {
// see: [pybind11 varargs]
auto strongPtr = py::cast<StrongFunctionPtr>(args[0]);
Function& callee = *strongPtr.function_;
bool tracing = tracer::isTracing();
py::object result = invokeScriptFunctionFromPython(
callee, tuple_slice(std::move(args), 1), std::move(kwargs));
return result;
@ -980,6 +980,13 @@ void initJitScriptBindings(PyObject* module) {
.def_property_readonly(
"graph",
[](const StrongFunctionPtr& self) { return self.function_->graph(); })
.def_property_readonly(
"inlined_graph",
[](const StrongFunctionPtr& self) {
auto g = self.function_->graph()->copy();
Inline(*g);
return g;
})
.def_property_readonly(
"schema",
[](const StrongFunctionPtr& self) {
@ -1017,6 +1024,13 @@ void initJitScriptBindings(PyObject* module) {
method, tuple_slice(std::move(args), 1), std::move(kwargs));
})
.def_property_readonly("graph", &Method::graph)
.def_property_readonly(
"inlined_graph",
[](const Method& self) {
auto g = self.function().graph()->copy();
Inline(*g);
return g;
})
.def_property_readonly(
"schema", [](Method& m) { return m.function().getSchema(); })
.def_property_readonly("name", &Method::name)

View File

@ -1627,6 +1627,15 @@ if _enabled:
"""
return self.forward.graph
@property
def inlined_graph(self):
r"""
Returns a string representation of the internal graph for the
``forward`` method. This graph will be preprocessed to inline all function and method calls.
See `Interpreting Graphs`_ for details.
"""
return self.forward.inlined_graph
@property
def code(self):
r"""