Add flag to fx.passes.split_module to normalize input names (#157733)
This is useful for vLLM, which runs AOTAutograd directly on graphs after they have been split. I created a new flag for this instead of reusing `keep_original_node_name` (please let me know if you think I should reuse this). The reasoning is:
- The names of the placeholder nodes are different from the targets of the placeholder nodes. The targets are the actual input names.
- Backwards compatibility: this API has been out for ~4 years, it looks public, and it has extensive public use. For example, this change would actually be BC-breaking for vLLM (they rely on the subgraph input names being different at the moment).

Test Plan:
- new tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157733
Approved by: https://github.com/ezyang
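As a rough usage sketch (the toy module, its input names, and the single-partition callback below are illustrative assumptions, not taken from this PR): splitting with the new flag normalizes each submodule's placeholder targets, i.e. the actual argument names of its generated forward().

import torch
import torch.fx
from torch.fx.passes.split_module import split_module

class TwoOp(torch.nn.Module):  # hypothetical toy module
    def forward(self, x, y):
        a = x + y
        return a * y

mod = TwoOp()
traced = torch.fx.symbolic_trace(mod)

# Put every node into partition 0; with keep_original_input_name=False the
# submodule's placeholder targets become arg_0, arg_1, ... instead of x, y.
split = split_module(traced, mod, lambda node: 0, keep_original_input_name=False)

for node in split.submod_0.graph.nodes:
    if node.op == "placeholder":
        print(node.target)  # prints arg_0, then arg_1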
@@ -64,7 +64,7 @@ torch.fx.node.map_aggregate(a: torch.fx.node.Argument, fn: Callable[[torch.fx.no
 torch.fx.node.map_arg(a: torch.fx.node.Argument, fn: Callable[[torch.fx.node.Node], torch.fx.node.Argument]) -> torch.fx.node.Argument
 torch.fx.passes.reinplace.reinplace(gm, *sample_args)
 torch.fx.passes.runtime_assert.insert_deferred_runtime_asserts(gm: torch.fx.graph_module.GraphModule, shape_env: Any, name: str, export: bool = False) -> None
-torch.fx.passes.split_module.split_module(m: torch.fx.graph_module.GraphModule, root_m: torch.nn.modules.module.Module, split_callback: Callable[[torch.fx.node.Node], int], qualname_map: Optional[Dict[str, str]] = None, keep_original_order: Optional[bool] = False, keep_original_node_name: Optional[bool] = False)
+torch.fx.passes.split_module.split_module(m: torch.fx.graph_module.GraphModule, root_m: torch.nn.modules.module.Module, split_callback: Callable[[torch.fx.node.Node], int], qualname_map: Optional[Dict[str, str]] = None, keep_original_order: Optional[bool] = False, keep_original_node_name: Optional[bool] = False, keep_original_input_name: bool = True)
 torch.fx.proxy.Attribute.__init__(self, root: torch.fx.proxy.Proxy, attr: str)
 torch.fx.proxy.Proxy.__init__(self, node: torch.fx.node.Node, tracer: 'Optional[TracerBase]' = None)
 torch.fx.proxy.Proxy.keys(self)
@@ -791,6 +791,46 @@ terrible spacing
         self.assertEqual(orig_out, submodules_out)
 
+    def test_split_module_input_names(self):
+        class Mod(torch.nn.Module):
+            def forward(self, x, a0, a1, b0, b1, c0, c1):
+                x = x + (a0 ** 2) + (a1 / 2)
+                x = x + (b0 ** 2) + (b1 / 2)
+                x = x + (c0 ** 2) + (c1 / 2)
+                return x
+
+        mod = Mod()
+        traced = torch.fx.symbolic_trace(mod)
+
+        seen = 0
+
+        def split(n):
+            nonlocal seen
+            result = seen // 4
+            seen += 1
+            return result
+
+        split = split_module(traced, mod, split, keep_original_input_name=False)
+
+        # All the submodules should take in the inputs in the same order.
+        args = [torch.tensor(2.), torch.tensor(3.), torch.tensor(4.)]
+        output0 = split.submod_0(*args)
+        output1 = split.submod_1(*args)
+        output2 = split.submod_2(*args)
+        self.assertEqual(output0, output1)
+        self.assertEqual(output1, output2)
+
+        # Each submodule should have normalized input names
+        def check_ph(gm):
+            nodes = list(gm.graph.nodes)
+            self.assertEqual(nodes[0].target, "arg_0")
+            self.assertEqual(nodes[1].target, "arg_1")
+            self.assertEqual(nodes[2].target, "arg_2")
+
+        check_ph(split.submod_0)
+        check_ph(split.submod_1)
+        check_ph(split.submod_2)
+
     def test_split_module_dead_code(self):
         class ModWithDeadCode(torch.nn.Module):
             def forward(self, x):
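A note on the test above: with nodes partitioned four at a time, each submodule ends up computing the same expression shape over its three inputs (roughly arg_1 + arg_0 ** 2 + arg_2 / 2), so calling all three submodules with the same argument list must yield identical outputs; check_ph then confirms the normalized placeholder targets.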
@@ -58,6 +58,7 @@ def split_module(
     qualname_map: Optional[dict[str, str]] = None,
     keep_original_order: Optional[bool] = False,
     keep_original_node_name: Optional[bool] = False,
+    keep_original_input_name: bool = True,
 ):
     """
     Creates subgraphs out of main graph
@@ -76,7 +77,10 @@ def split_module(
             names in the original module.
         keep_original_order: Optional[bool]: keep the original order of the GraphModule
             or use the Topological order of the new constructed GraphModule
+        keep_original_node_name: Optional[bool]: If the partitioned graphs should
+            have the same node names as the original graph.
+        keep_original_input_name: bool: If the partitioned graphs should
+            have the same input names as the original graph.
 
     Returns:
         GraphModule: the module after split.
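Note that the two flags document different things: keep_original_node_name concerns the names of graph nodes, while keep_original_input_name concerns placeholder targets, which become the argument names of each submodule's generated forward(). As the commit message points out, a placeholder's name and its target can differ, which is why a separate flag is warranted.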
@@ -419,11 +423,28 @@ def split_module(
     for partition_name in sorted_partitions:
         partition = partitions[partition_name]
         new_inputs: dict[str, None] = {}
+
+        counter = 0
+
         for inp in partition.inputs:
             orig_node = orig_nodes[inp]
             # We don't pass in get_attr nodes as inputs to the partition, but
             # instead set them as targets and use getattr within the module
+
+            def add_placeholder():
+                if keep_original_input_name:
+                    name = inp
+                else:
+                    nonlocal counter
+                    name = f"arg_{counter}"
+                    counter += 1
+                placeholder = partition.graph.placeholder(
+                    name,
+                    type_expr=orig_nodes[inp].type,
+                )
+                new_inputs[inp] = None
+                return placeholder
+
             if orig_node.op == "get_attr":
                 assert isinstance(orig_node.target, str)
 
@@ -432,17 +453,9 @@ def split_module(
                     placeholder = partition.graph.get_attr(orig_node.target)
                     partition.targets[orig_node.target] = orig_attr
                 else:
-                    placeholder = partition.graph.placeholder(
-                        inp,
-                        type_expr=orig_nodes[inp].type,
-                    )
-                    new_inputs[inp] = None
+                    placeholder = add_placeholder()
             else:
-                placeholder = partition.graph.placeholder(
-                    inp,
-                    type_expr=orig_nodes[inp].type,
-                )
-                new_inputs[inp] = None
+                placeholder = add_placeholder()
             placeholder.meta = orig_nodes[inp].meta.copy()
             partition.environment[orig_nodes[inp]] = placeholder
         partition.inputs = new_inputs
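Two details worth noting in the hunks above: counter is initialized once per partition, before the input loop, so numbering restarts at arg_0 for every submodule and follows the order in which inputs are first encountered; and both call sites that previously built a placeholder inline now share add_placeholder, so the counter stays consistent across the get_attr fallback branch and the ordinary input branch.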