[FX] Add a default_value arg to Graph.placeholder and fix split_module (#71016)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71016

I found out that `split_module` doesn't preserve default values for arguments. In trying to fix that, I noticed that `Graph.placeholder` doesn't make it easy to add a default argument when making a placeholder. This PR addresses both of those issues

Test Plan: Imported from OSS

Reviewed By: ansley

Differential Revision: D33482218

Pulled By: jamesr66a

fbshipit-source-id: 57ebcdab25d267333fb1034994e08fc1bdb128ee
This commit is contained in:
James Reed
2022-01-12 13:57:33 -08:00
committed by Facebook GitHub Bot
parent 5749be4678
commit de902b5d02
4 changed files with 37 additions and 4 deletions

View File

@ -20,7 +20,7 @@ torch.fx.graph.Graph.inserting_before(self, n: Optional[torch.fx.node.Node] = No
torch.fx.graph.Graph.lint(self)
torch.fx.graph.Graph.node_copy(self, node: torch.fx.node.Node, arg_transform: Callable[[torch.fx.node.Node], Argument] = <function <lambda>>) -> torch.fx.node.Node
torch.fx.graph.Graph.output(self, result: 'Argument', type_expr: Optional[Any] = None)
torch.fx.graph.Graph.placeholder(self, name: str, type_expr: Optional[Any] = None) -> torch.fx.node.Node
torch.fx.graph.Graph.placeholder(self, name: str, type_expr: Optional[Any] = None, default_value: Any) -> torch.fx.node.Node
torch.fx.graph.Graph.print_tabular(self)
torch.fx.graph.Graph.python_code(self, root_module: str) -> torch.fx.graph.PythonCode
torch.fx.graph_module.GraphModule.__init__(self, root: Union[torch.nn.modules.module.Module, Dict[str, Any]], graph: torch.fx.graph.Graph, class_name: str = 'GraphModule')

View File

@ -852,6 +852,28 @@ terrible spacing
module_with_submodules = split_module(traced, m, lambda node: 0)
module_with_submodules(a)
def test_split_module_default_arg(self):
class ModelToTrace(torch.nn.Module):
def __init__(self):
super().__init__()
self.lin = torch.nn.Linear(512, 512)
def forward(self, x, targets=None):
x = self.lin(x)
if targets is not None:
x = x + targets
return x
mtt = ModelToTrace()
traced = torch.fx.symbolic_trace(mtt, concrete_args={'targets': None})
split = split_module(traced, mtt, lambda node: 0)
x = torch.randn(50, 512)
torch.testing.assert_allclose(split(x), traced(x))
def test_normalize_binary_operators(self):
ops_to_test = {
torch.add,

View File

@ -14,6 +14,7 @@ import re
import builtins
import math
import warnings
import inspect
if TYPE_CHECKING:
@ -522,7 +523,8 @@ class Graph:
return _InsertPoint(self, n.append)
@compatibility(is_backward_compatible=True)
def placeholder(self, name: str, type_expr: Optional[Any] = None) -> Node:
def placeholder(self, name: str, type_expr: Optional[Any] = None,
default_value : Any = inspect.Signature.empty) -> Node:
"""
Insert a ``placeholder`` node into the Graph. A ``placeholder`` represents
a function input.
@ -537,11 +539,17 @@ class Graph:
cases for proper code generation (e.g. when the function is used
subsequently in TorchScript compilation).
default_value (Any): The default value this function argument should take
on. NOTE: to allow for `None` as a default value, `inspect.Signature.empty`
should be passed as this argument to specify that the parameter does _not_
have a default value.
.. note::
The same insertion point and type expression rules apply for this method
as ``Graph.create_node``.
"""
return self.create_node('placeholder', name, type_expr=type_expr)
args = () if default_value is inspect.Signature.empty else (default_value,)
return self.create_node('placeholder', name, args=args, type_expr=type_expr)
@compatibility(is_backward_compatible=True)
def get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> Node:

View File

@ -2,6 +2,7 @@ import torch
from torch.fx.graph_module import GraphModule
from typing import Callable, List, Dict, Any, Optional
from torch.fx._compatibility import compatibility
import inspect
@compatibility(is_backward_compatible=True)
class Partition:
@ -224,7 +225,9 @@ def split_module(
base_mod_attrs : Dict[str, torch.fx.graph_module.GraphModule] = {}
for node in m.graph.nodes:
if node.op == 'placeholder':
base_mod_env[node.name] = base_mod_graph.placeholder(node.name)
default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
base_mod_env[node.name] = base_mod_graph.placeholder(
node.name, type_expr=node.type, default_value=default_value)
base_mod_env[node.name].meta = node.meta.copy()
elif node.op == 'get_attr':
base_mod_env[node.name] = base_mod_graph.get_attr(node.target)