Compare commits

...

3 Commits

6 changed files with 174 additions and 80 deletions

View File

@ -426,7 +426,7 @@ enable_cpp_framelocals_guard_eval = True
# Whether to automatically find and replace identical graph
# regions with a call to invoke_subgraph
use_graph_deduplication = False
use_graph_deduplication = True
# Whether to track nodes for deduplication (testing only)
# This flag is ignored if use_graph_deduplication is True

View File

@ -12,17 +12,18 @@ import operator
from collections.abc import Iterable
from typing import Any
import torch
import torch.fx
from torch._higher_order_ops.utils import has_potential_input_alias_or_mutation
from torch.utils._pytree import tree_flatten
from torch._dynamo import config
from .graph_region_tracker import Node, Region
from .graph_utils import _detect_cycles, _flatten_args_kwargs
log = logging.getLogger(__name__)
def apply_graph_deduplication(output_graph) -> dict[Node, Node]: # type: ignore[no-untyped-def]
def apply_graph_deduplication(output_graph) -> dict[str, torch.fx.GraphModule]: # type: ignore[no-untyped-def]
"""
This is the main entry point for applying the graph deduplication pass. \
Deduplication occurs in two phases:
@ -49,15 +50,14 @@ The deduplication mutates the output_graph argument in place.
Returns a mapping of nodes to their subgraph output replacement node to remap outputs
when they are created in output_graph.
"""
from torch._inductor.pattern_matcher import stable_topological_sort
duplicated_region_groups = output_graph.region_tracker.get_identical_regions(
output_graph.graph
)
# Used to track which nodes were replaced with subgraph outputs
# today, we have to register the new subgraph submodules before the
# graph outputs have been created, so we pass the replacement mapping
# back to output graph to do the replacements at the site of output creation
output_replacements: dict[Node, Node] = {}
sub_gms: dict[str, torch.fx.GraphModule] = {}
for region_group in duplicated_region_groups:
inds_with_external_users = _get_all_output_indices(region_group)
region = region_group[0]
@ -65,8 +65,14 @@ when they are created in output_graph.
subgraph,
node_ind_arg_inds,
) = _create_subgraph(region, inds_with_external_users)
# Ignore regions with no args for now, could they possibly be evaluated at compile time?
if not list(node_ind_arg_inds):
continue
sub_gm = torch.fx.GraphModule(output_graph.nn_modules, subgraph)
subgraph_name = output_graph.install_subgraph("subgraph", sub_gm)
sub_gms[subgraph_name] = sub_gm
with output_graph.graph.inserting_before():
get_subgraph_node = output_graph.graph.create_node(
"get_attr", subgraph_name, (), {}
@ -78,36 +84,11 @@ when they are created in output_graph.
get_subgraph_node,
node_ind_arg_inds.keys(),
inds_with_external_users,
sub_gm,
subgraph_name,
output_replacements,
)
return output_replacements
# flattens with support for slices
# Note: a better way to do this would
# be register/unregister slices as pytree nodes
# but there is no unregister API in the pytorch
# pytree impl
def _flatten_args_kwargs(args: Any) -> list[Node]:
fully_flattened = []
def flatten(args: Any) -> None:
flattened, _ = tree_flatten(args)
for arg in flattened:
if isinstance(arg, slice):
start = arg.start
stop = arg.stop
step = arg.step
flatten((start, stop, step))
else:
fully_flattened.append(arg)
flatten(args)
return fully_flattened
stable_topological_sort(output_graph.graph)
return sub_gms
def _replace_region_with_subgraph(
@ -116,9 +97,7 @@ def _replace_region_with_subgraph(
get_subgraph_node: Node,
node_ind_arg_ind: Iterable[tuple[int, int]],
inds_with_external_users: list[int],
sub_gm: torch.fx.GraphModule,
subgraph_name: str,
output_replacements: dict[Node, Node],
) -> None:
sub_args = []
for node_ind, arg_ind in node_ind_arg_ind:
@ -127,32 +106,35 @@ def _replace_region_with_subgraph(
sub_args.append(flattened_args_kwargs[arg_ind])
invoke_args = (get_subgraph_node, subgraph_name, tuple(sub_args))
fake_inputs = [node.meta["example_value"] for node in sub_args]
if has_potential_input_alias_or_mutation(sub_gm, fake_inputs):
log.debug(
"NYI: Failed to substitute region %s due to input alias or mutation",
region,
# fake_inputs = [node.meta["example_value"] for node in sub_args]
# if has_potential_input_alias_or_mutation(sub_gm, fake_inputs):
# log.debug(
# "NYI: Failed to substitute region %s due to input alias or mutation",
# region,
# )
# return
from torch._inductor.pattern_matcher import stable_topological_sort
invoke_subgraph_node = graph.create_node(
"call_function", torch.ops.higher_order.invoke_subgraph, invoke_args, {}
)
for ind, external_user_ind in enumerate(inds_with_external_users):
node = region[external_user_ind]
subgraph_output = graph.create_node(
"call_function", operator.getitem, (invoke_subgraph_node, ind), {}
)
return
node.replace_all_uses_with(subgraph_output, propagate_meta=True)
latest_region_node = region[-1]
with graph.inserting_after(latest_region_node):
invoke_subgraph_node = graph.create_node(
"call_function", torch.ops.higher_order.invoke_subgraph, invoke_args, {}
)
with graph.inserting_after(invoke_subgraph_node):
for ind, external_user_ind in enumerate(inds_with_external_users):
node = region[external_user_ind]
subgraph_output = graph.create_node(
"call_function", operator.getitem, (invoke_subgraph_node, ind), {}
)
output_replacements[node] = subgraph_output
node.replace_all_uses_with(subgraph_output, propagate_meta=True)
# Erase in reverse topological order
for node in reversed(region):
graph.erase_node(node)
# Erase in reverse topological order
for node in reversed(region):
graph.erase_node(node)
if config.graph_deduplication_lint:
_detect_cycles(graph)
stable_topological_sort(graph)
graph.lint()
def _get_external_inputs(

View File

@ -27,6 +27,8 @@ import torch.fx
from torch._subclasses.fake_tensor import FakeTensor
from torch.utils._pytree import tree_flatten
from .graph_utils import _flatten_args_kwargs
T = TypeVar("T")
@ -253,6 +255,8 @@ class GraphRegionTracker:
"""
topological_ranking = {node: i for i, node in enumerate(graph.nodes)}
region_groups_with_rank = []
# needed to detect if replacing a region will create cycles
node_to_recursive_ancestors = _populate_recursive_ancestor_map(graph)
# Create region groups; a region group is a group
# of regions that are all identical. In this initial state
@ -281,7 +285,12 @@ class GraphRegionTracker:
# overlap.
seen_nodes: set[Node] = set()
for region_group in region_groups:
fully_expand_region_group(region_group, seen_nodes, self._is_identical)
fully_expand_region_group(
region_group,
seen_nodes,
node_to_recursive_ancestors,
self._is_identical,
)
# sort topologically
for region in region_group:
region.sort(key=lambda n: topological_ranking[n])
@ -297,6 +306,7 @@ class GraphRegionTracker:
def fully_expand_region_group(
regions: list[Region],
seen_nodes: set[Node],
node_to_recursive_ancestors: dict[Node, set[Node]],
is_identical_fn: Callable[[Node, Node], bool],
) -> None:
debug_log("--------------------------------------------------")
@ -327,17 +337,19 @@ def fully_expand_region_group(
# regions are only expanded if the node to add is valid
# for ALL regions
while current_node:
add_node = True
add_node = not _will_create_cycle(
current_node, regions[0], node_to_recursive_ancestors
)
nodes_to_add.clear()
nodes_to_add.append(current_node)
nodes_to_add_set = set(nodes_to_add)
for region_it in region_iters[1:]:
for ind, region_it in enumerate(region_iters[1:]):
ind += 1 # compensate for the 0th region
node = region_it.next()
debug_log("--------------------")
debug_log("considering adding: %s, cur_node: %s", node, current_node)
debug_log("previously claimed nodes: %s", node in seen_nodes)
debug_log("%s", seen_nodes)
if node:
debug_log("is_identical: %s", is_identical_fn(node, current_node))
add_node &= (
@ -345,6 +357,9 @@ def fully_expand_region_group(
and node not in nodes_to_add_set
and node.op != "placeholder"
and is_identical_fn(node, current_node)
and not _will_create_cycle(
node, regions[ind], node_to_recursive_ancestors
)
)
nodes_to_add.append(node)
nodes_to_add_set.add(node)
@ -369,3 +384,35 @@ def fully_expand_region_group(
debug_log("end expand new region group: %s", regions)
debug_log("--------------------------------------------------")
def _populate_recursive_ancestor_map(graph: torch.fx.Graph) -> dict[Node, set[Node]]:
node_to_recursive_ancestors: dict[Node, set[Node]] = {}
for node in graph.nodes:
node_to_recursive_ancestors[node] = set()
for node in graph.nodes:
all_args = _flatten_args_kwargs((node.args, node.kwargs))
for arg in all_args:
if isinstance(arg, Node):
node_to_recursive_ancestors[node].update(
node_to_recursive_ancestors[arg]
)
node_to_recursive_ancestors[node].add(node)
return node_to_recursive_ancestors
def _will_create_cycle(
node_to_add: Node,
region: Region,
node_to_recursive_ancestors: dict[Node, set[Node]],
) -> bool:
region_set: set[Node] = set(region)
region_ancestors: set[Node] = set(
tree_flatten([list(node_to_recursive_ancestors[node]) for node in region])[0]
)
external_users = [user for user in node_to_add.users if user not in region_set]
for user in external_users:
if user in region_ancestors:
return True
return False

View File

@ -0,0 +1,68 @@
from collections import deque
from typing import Any
from torch.fx import Graph, Node
from torch.utils._pytree import tree_flatten
# flattens with support for slices
# Note: a better way to do this would
# be register/unregister slices as pytree nodes
# but there is no unregister API in the pytorch
# pytree impl
def _flatten_args_kwargs(args: Any) -> list[Node]:
fully_flattened = []
def flatten(args: Any) -> None:
flattened, _ = tree_flatten(args)
for arg in flattened:
if isinstance(arg, slice):
start = arg.start
stop = arg.stop
step = arg.step
flatten((start, stop, step))
else:
fully_flattened.append(arg)
flatten(args)
return fully_flattened
def _detect_cycles(graph: Graph) -> str:
current_path: deque[Node] = deque()
current_path_set: set[Node] = set()
pending: deque[tuple[Node, Node]] = deque()
def add_to_current_path(node: Node) -> None:
current_path.append(node)
current_path_set.add(node)
def pop_current_path() -> None:
node = current_path.pop()
current_path_set.remove(node)
def current_path_head() -> Node:
return current_path[-1]
for origin in graph.find_nodes(op="placeholder"):
current_path.clear()
current_path_set.clear()
add_to_current_path(origin)
for child in origin.users:
pending.append((child, origin))
while pending:
cur_node, parent = pending.pop()
while current_path_head() != parent:
pop_current_path()
if cur_node in current_path_set:
return f"cycle detected in path: {current_path}"
add_to_current_path(cur_node)
for child in cur_node.users:
pending.append((child, cur_node))
return "no cycle detected"

View File

@ -227,6 +227,10 @@ class FakeRootModule(torch.nn.Module):
def __repr__(self) -> str:
return "FakeRootModule(...)"
def add_nn_modules(self, nn_modules: dict[str, torch.nn.Module]):
for k, v in nn_modules.items():
setattr(self, k, v)
class WrapperBackend:
def __init__(self, backend: CompilerFn):
@ -1057,8 +1061,6 @@ class OutputGraph:
for value in stack_values:
value.realize()
output_replacements = self.dedup_pass()
# Use nn.Module "proxies" in the constructed GraphModule so that
# the resulting GM does not hold additional strong references to the original modules.
# This prevents a strong ref cycle where Dynamo created code holds on to references
@ -1142,9 +1144,7 @@ class OutputGraph:
append_prefix_insts()
# optimization to generate better code in a common case
self.add_output_instructions(
self.compile_and_call_fx_graph(
tx, list(reversed(stack_values)), root, output_replacements
)
self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
+ [create_instruction("UNPACK_SEQUENCE", arg=len(stack_values))]
)
# restore all the live local vars
@ -1177,9 +1177,7 @@ class OutputGraph:
output = []
if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0:
output.extend(
self.compile_and_call_fx_graph(
tx, pass2.graph_output_vars(), root, output_replacements
)
self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
)
if len(pass2.graph_outputs) != 0:
@ -1343,7 +1341,7 @@ class OutputGraph:
tx.speculation_log.clear()
raise exc.CompileCollectiveRestartAnalysis
def compile_and_call_fx_graph(self, tx, rv, root, replaced_outputs):
def compile_and_call_fx_graph(self, tx, rv, root):
"""
Generate code from self.graph and return the Instruction()s to
call that generated code.
@ -1366,9 +1364,8 @@ class OutputGraph:
(self.current_tracer.create_arg(tuple(x.as_proxy() for x in rv)),),
{},
)
for old_node, new_node in replaced_outputs.items():
old_node.replace_all_uses_with(new_node)
sub_gms = self.dedup_pass()
root.add_nn_modules(sub_gms)
tx.output.current_tracer._maybe_preserve_original_meta(tx, output_node)
if not config.do_not_emit_runtime_asserts:
@ -1563,7 +1560,7 @@ class OutputGraph:
if torch._dynamo.config.use_graph_deduplication:
return apply_graph_deduplication(self)
else:
return dict()
return {}
def install_subgraph(self, name, sub_gm):
next_name = get_unique_name_wrt(name, self.nn_modules, requires_suffix=True)

View File

@ -583,9 +583,9 @@ class MetalKernel(SIMDKernel):
dtype=dtype,
)
if reduction_type == "welford_reduce":
assert not self.multistage_reduction, (
f"Multistage reduction not yet supported for {reduction_type}"
)
assert (
not self.multistage_reduction
), f"Multistage reduction not yet supported for {reduction_type}"
acc_buf = self._new_accvar(src_dtype, acc_buf_size)
self.compute.splice(f"{acc_buf}[{reduction_dim.name}] = {value};")
wf_res = self.cse.generate(