[FX] Add a default_value arg to Graph.placeholder and fix split_module (#71016)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71016

I found out that `split_module` doesn't preserve default values for arguments. In trying to fix that, I noticed that `Graph.placeholder` doesn't make it easy to add a default argument when making a placeholder. This PR addresses both of those issues.

Test Plan: Imported from OSS

Reviewed By: ansley

Differential Revision: D33482218

Pulled By: jamesr66a

fbshipit-source-id: 57ebcdab25d267333fb1034994e08fc1bdb128ee
committed by Facebook GitHub Bot
parent 5749be4678
commit de902b5d02
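
For orientation, here is a minimal sketch of the API surface this change adds, before the diffs below. Nothing in this snippet comes from the diff itself; it just exercises the new `default_value` keyword on `Graph.placeholder`, where `inspect.Signature.empty` (the implicit default) means "no default value", so that `None` stays usable as a real default:

    import torch
    from torch.fx import Graph, GraphModule

    g = Graph()
    x = g.placeholder('x')                      # no default value
    t = g.placeholder('t', default_value=None)  # a genuine default of None
    g.output(x)

    gm = GraphModule(torch.nn.Module(), g)
    # the generated signature should read roughly: def forward(self, x, t = None):
    print(gm.code)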
test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect
@@ -20,7 +20,7 @@ torch.fx.graph.Graph.inserting_before(self, n: Optional[torch.fx.node.Node] = None)
 torch.fx.graph.Graph.lint(self)
 torch.fx.graph.Graph.node_copy(self, node: torch.fx.node.Node, arg_transform: Callable[[torch.fx.node.Node], Argument] = <function <lambda>>) -> torch.fx.node.Node
 torch.fx.graph.Graph.output(self, result: 'Argument', type_expr: Optional[Any] = None)
-torch.fx.graph.Graph.placeholder(self, name: str, type_expr: Optional[Any] = None) -> torch.fx.node.Node
+torch.fx.graph.Graph.placeholder(self, name: str, type_expr: Optional[Any] = None, default_value: Any) -> torch.fx.node.Node
 torch.fx.graph.Graph.print_tabular(self)
 torch.fx.graph.Graph.python_code(self, root_module: str) -> torch.fx.graph.PythonCode
 torch.fx.graph_module.GraphModule.__init__(self, root: Union[torch.nn.modules.module.Module, Dict[str, Any]], graph: torch.fx.graph.Graph, class_name: str = 'GraphModule')
test/test_fx_experimental.py
@@ -852,6 +852,28 @@ terrible spacing
         module_with_submodules = split_module(traced, m, lambda node: 0)
         module_with_submodules(a)
 
+    def test_split_module_default_arg(self):
+        class ModelToTrace(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.lin = torch.nn.Linear(512, 512)
+
+            def forward(self, x, targets=None):
+                x = self.lin(x)
+
+                if targets is not None:
+                    x = x + targets
+
+                return x
+
+        mtt = ModelToTrace()
+        traced = torch.fx.symbolic_trace(mtt, concrete_args={'targets': None})
+
+        split = split_module(traced, mtt, lambda node: 0)
+
+        x = torch.randn(50, 512)
+        torch.testing.assert_allclose(split(x), traced(x))
+
     def test_normalize_binary_operators(self):
         ops_to_test = {
             torch.add,
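
This test only passes with the `split_module` fix further down: previously the stitched-together base module recreated every placeholder as `base_mod_graph.placeholder(node.name)`, dropping the `= None` default, so calling the split module with a single argument raised a missing-argument TypeError. A standalone repro sketch (hypothetical module `M`, same shape as the test above):

    import torch
    import torch.fx
    from torch.fx.passes.split_module import split_module

    class M(torch.nn.Module):
        def forward(self, x, targets=None):
            return x if targets is None else x + targets

    traced = torch.fx.symbolic_trace(M(), concrete_args={'targets': None})
    split = split_module(traced, M(), lambda node: 0)
    print(split.code)             # with the fix, the signature should keep `targets = None`
    print(split(torch.randn(4)))  # before the fix, this call raised TypeError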
torch/fx/graph.py
@@ -14,6 +14,7 @@ import re
 import builtins
 import math
 import warnings
+import inspect
 
 
 if TYPE_CHECKING:
@@ -522,7 +523,8 @@ class Graph:
         return _InsertPoint(self, n.append)
 
     @compatibility(is_backward_compatible=True)
-    def placeholder(self, name: str, type_expr: Optional[Any] = None) -> Node:
+    def placeholder(self, name: str, type_expr: Optional[Any] = None,
+                    default_value : Any = inspect.Signature.empty) -> Node:
         """
         Insert a ``placeholder`` node into the Graph. A ``placeholder`` represents
         a function input.
@@ -537,11 +539,17 @@ class Graph:
                 cases for proper code generation (e.g. when the function is used
                 subsequently in TorchScript compilation).
 
+            default_value (Any): The default value this function argument should take
+                on. NOTE: to allow for `None` as a default value, `inspect.Signature.empty`
+                should be passed as this argument to specify that the parameter does _not_
+                have a default value.
+
         .. note::
             The same insertion point and type expression rules apply for this method
             as ``Graph.create_node``.
         """
-        return self.create_node('placeholder', name, type_expr=type_expr)
+        args = () if default_value is inspect.Signature.empty else (default_value,)
+        return self.create_node('placeholder', name, args=args, type_expr=type_expr)
 
     @compatibility(is_backward_compatible=True)
     def get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> Node:
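
The sentinel is the load-bearing detail in this signature: because `None` is a perfectly good default value, "no default" needs a distinct marker, and the change reuses `inspect.Signature.empty` rather than inventing a new one. A small illustration of the three cases (names and expected output are illustrative):

    import inspect
    import torch
    from torch.fx import Graph, GraphModule

    g = Graph()
    a = g.placeholder('a')                                         # no default
    b = g.placeholder('b', default_value=inspect.Signature.empty)  # also no default, stated explicitly
    c = g.placeholder('c', default_value=None)                     # a real default of None
    g.output((a, b, c))

    gm = GraphModule(torch.nn.Module(), g)
    print(gm.code)  # should read roughly: def forward(self, a, b, c = None):

Note that placeholders with defaults have to come after those without, or the generated `forward` would not be valid Python.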
torch/fx/passes/split_module.py
@@ -2,6 +2,7 @@ import torch
 from torch.fx.graph_module import GraphModule
 from typing import Callable, List, Dict, Any, Optional
 from torch.fx._compatibility import compatibility
+import inspect
 
 @compatibility(is_backward_compatible=True)
 class Partition:
@@ -224,7 +225,9 @@ def split_module(
     base_mod_attrs : Dict[str, torch.fx.graph_module.GraphModule] = {}
     for node in m.graph.nodes:
         if node.op == 'placeholder':
-            base_mod_env[node.name] = base_mod_graph.placeholder(node.name)
+            default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
+            base_mod_env[node.name] = base_mod_graph.placeholder(
+                node.name, type_expr=node.type, default_value=default_value)
             base_mod_env[node.name].meta = node.meta.copy()
         elif node.op == 'get_attr':
             base_mod_env[node.name] = base_mod_graph.get_attr(node.target)
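
The one-line recovery of the default leans on an FX convention: for a `placeholder` node, `node.args` is either empty (no default) or a one-element tuple holding the default that `symbolic_trace` captured from the original signature. A quick way to see this convention (hypothetical module, illustrative output):

    import torch
    import torch.fx

    class M(torch.nn.Module):
        def forward(self, x, scale=2.0):
            return x * scale

    traced = torch.fx.symbolic_trace(M())
    for node in traced.graph.nodes:
        if node.op == 'placeholder':
            print(node.name, node.args)
    # expected roughly:
    #   x ()
    #   scale (2.0,)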