mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-29 03:04:55 +08:00
[jit] add inlined_graph method to ScriptFunctions (#33508)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/33508 Ever since we switched to not inlining by default, some users have complained since they relied on inlining occuring to, e.g. process the graph with some other tool. Add an inlined_graph for convenience in those cases. Test Plan: Imported from OSS Differential Revision: D19977638 Pulled By: suo fbshipit-source-id: fe1fa92ff888959203d5d1995930d488b5f9e24c
This commit is contained in:
committed by
Facebook Github Bot
parent
5e80ca12bb
commit
416413dec4
@ -4014,6 +4014,39 @@ class TestFrontend(JitTestCase):
|
||||
|
||||
|
||||
class TestScript(JitTestCase):
|
||||
def test_inlined_graph(self):
|
||||
"""
|
||||
Check that the `inlined_graph` property correctly returns an inlined
|
||||
graph, both through function calls and method calls.
|
||||
"""
|
||||
@torch.jit.script
|
||||
def foo(x):
|
||||
return torch.add(x, x)
|
||||
|
||||
class MyNestedMod(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(MyNestedMod, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return torch.sub(x, x)
|
||||
|
||||
|
||||
class MyMod(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(MyMod, self).__init__()
|
||||
self.nested = MyNestedMod()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.nested(x) # sub
|
||||
x = foo(x) # add
|
||||
return torch.mul(x, x)
|
||||
|
||||
m = torch.jit.script(MyMod())
|
||||
FileCheck().check("aten::sub") \
|
||||
.check("aten::add") \
|
||||
.check("aten::mul") \
|
||||
.run(m.inlined_graph)
|
||||
|
||||
def test_oneline_func(self):
|
||||
def fn(x): return x # noqa: E704
|
||||
|
||||
|
||||
@ -11,17 +11,18 @@
|
||||
#include <torch/csrc/jit/testing/file_check.h>
|
||||
|
||||
#include <torch/csrc/jit/constants.h>
|
||||
#include <torch/csrc/jit/export.h>
|
||||
#include <torch/csrc/jit/graph_executor.h>
|
||||
#include <torch/csrc/jit/hooks_for_testing.h>
|
||||
#include <torch/csrc/jit/import_source.h>
|
||||
#include <torch/csrc/jit/irparser.h>
|
||||
#include <torch/csrc/jit/passes/inliner.h>
|
||||
#include <torch/csrc/jit/passes/python_print.h>
|
||||
#include <torch/csrc/jit/pybind_utils.h>
|
||||
#include <torch/csrc/jit/python_tracer.h>
|
||||
#include <torch/csrc/jit/script/logging.h>
|
||||
#include <torch/csrc/jit/script/parser.h>
|
||||
#include <torch/csrc/jit/tracer.h>
|
||||
#include <torch/csrc/jit/export.h>
|
||||
|
||||
#include <torch/csrc/api/include/torch/ordered_dict.h>
|
||||
|
||||
@ -939,7 +940,6 @@ void initJitScriptBindings(PyObject* module) {
|
||||
// see: [pybind11 varargs]
|
||||
auto strongPtr = py::cast<StrongFunctionPtr>(args[0]);
|
||||
Function& callee = *strongPtr.function_;
|
||||
bool tracing = tracer::isTracing();
|
||||
py::object result = invokeScriptFunctionFromPython(
|
||||
callee, tuple_slice(std::move(args), 1), std::move(kwargs));
|
||||
return result;
|
||||
@ -980,6 +980,13 @@ void initJitScriptBindings(PyObject* module) {
|
||||
.def_property_readonly(
|
||||
"graph",
|
||||
[](const StrongFunctionPtr& self) { return self.function_->graph(); })
|
||||
.def_property_readonly(
|
||||
"inlined_graph",
|
||||
[](const StrongFunctionPtr& self) {
|
||||
auto g = self.function_->graph()->copy();
|
||||
Inline(*g);
|
||||
return g;
|
||||
})
|
||||
.def_property_readonly(
|
||||
"schema",
|
||||
[](const StrongFunctionPtr& self) {
|
||||
@ -1017,6 +1024,13 @@ void initJitScriptBindings(PyObject* module) {
|
||||
method, tuple_slice(std::move(args), 1), std::move(kwargs));
|
||||
})
|
||||
.def_property_readonly("graph", &Method::graph)
|
||||
.def_property_readonly(
|
||||
"inlined_graph",
|
||||
[](const Method& self) {
|
||||
auto g = self.function().graph()->copy();
|
||||
Inline(*g);
|
||||
return g;
|
||||
})
|
||||
.def_property_readonly(
|
||||
"schema", [](Method& m) { return m.function().getSchema(); })
|
||||
.def_property_readonly("name", &Method::name)
|
||||
|
||||
@ -1627,6 +1627,15 @@ if _enabled:
|
||||
"""
|
||||
return self.forward.graph
|
||||
|
||||
@property
|
||||
def inlined_graph(self):
|
||||
r"""
|
||||
Returns a string representation of the internal graph for the
|
||||
``forward`` method. This graph will be preprocessed to inline all function and method calls.
|
||||
See `Interpreting Graphs`_ for details.
|
||||
"""
|
||||
return self.forward.inlined_graph
|
||||
|
||||
@property
|
||||
def code(self):
|
||||
r"""
|
||||
|
||||
Reference in New Issue
Block a user