Revert "[Dynamo] Optimize dedupe region ancestor tracking (#152589)"

This reverts commit b5f1345f72ec6d1b004b05284e9553e65ee03abc.

Reverted https://github.com/pytorch/pytorch/pull/152589 on behalf of https://github.com/jeanschmidt due to breaking the internal signal citadel-fbcode-test-mode-opt-for-pt2_stack_for_internal-linux-0; please see diff [D74531503](https://www.internalfb.com/diff/D74531503) for more details ([comment](https://github.com/pytorch/pytorch/pull/152410#issuecomment-2871168679))
PyTorch MergeBot
2025-05-12 07:15:09 +00:00
parent 7243c69421
commit aa7fe6af41
5 changed files with 94 additions and 111 deletions


@@ -1,20 +1,20 @@
add_loop_eager,compile_time_instruction_count,3040000000,0.015
add_loop_eager,compile_time_instruction_count,2960000000,0.015
add_loop_eager_dynamic,compile_time_instruction_count,6037000000,0.025
add_loop_eager_dynamic,compile_time_instruction_count,5806000000,0.025
add_loop_inductor,compile_time_instruction_count,29440000000,0.015
add_loop_inductor,compile_time_instruction_count,29160000000,0.015
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44130000000,0.025
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,42960000000,0.025
add_loop_inductor_gpu,compile_time_instruction_count,25840000000,0.015
add_loop_inductor_gpu,compile_time_instruction_count,25630000000,0.015
@@ -30,11 +30,11 @@ basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instructio
basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10370000000,0.2
basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,9714000000,0.2
update_hint_regression,compile_time_instruction_count,1731000000,0.02
update_hint_regression,compile_time_instruction_count,1677500000,0.02
@@ -42,19 +42,19 @@ float_args,compile_time_instruction_count,439200000,0.015
sum_floordiv_regression,compile_time_instruction_count,1005000000,0.015
sum_floordiv_regression,compile_time_instruction_count,998400000,0.015
symint_sum,compile_time_instruction_count,3234000000,0.015
symint_sum,compile_time_instruction_count,3227000000,0.015
symint_sum_loop,compile_time_instruction_count,4239000000,0.015
symint_sum_loop,compile_time_instruction_count,4224000000,0.015
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2080000000,0.015
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2075364055,0.015
@@ -62,15 +62,15 @@ aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5944000000,0
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8630000000,0.015
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8586000000,0.015
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1895000000,0.015
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1884000000,0.015
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3818000000,0.015
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3795000000,0.015

| # | benchmark | metric | expected (before) | expected (after) | noise tolerance |
|---|---|---|---|---|---|
| 1 | add_loop_eager | compile_time_instruction_count | 3040000000 | 2960000000 | 0.015 |
| 2 | add_loop_eager_dynamic | compile_time_instruction_count | 6037000000 | 5806000000 | 0.025 |
| 3 | add_loop_inductor | compile_time_instruction_count | 29440000000 | 29160000000 | 0.015 |
| 4 | add_loop_inductor_dynamic_gpu | compile_time_instruction_count | 44130000000 | 42960000000 | 0.025 |
| 5 | add_loop_inductor_gpu | compile_time_instruction_count | 25840000000 | 25630000000 | 0.015 |
| 6 | basic_modules_ListOfLinears_eager | compile_time_instruction_count | 1011000000 | 1011000000 | 0.015 |
| 7 | basic_modules_ListOfLinears_inductor | compile_time_instruction_count | 18150000000 | 18150000000 | 0.015 |
| 8 | basic_modules_ListOfLinears_inductor_gpu_force_shape_pad | compile_time_instruction_count | 16340000000 | 16340000000 | 0.015 |
| 9 | basic_modules_ListOfLinears_inductor_gpu | compile_time_instruction_count | 10370000000 | 9714000000 | 0.2 |
| 10 | update_hint_regression | compile_time_instruction_count | 1731000000 | 1677500000 | 0.02 |
| 11 | float_args | compile_time_instruction_count | 439200000 | 439200000 | 0.015 |
| 12 | sum_floordiv_regression | compile_time_instruction_count | 1005000000 | 998400000 | 0.015 |
| 13 | symint_sum | compile_time_instruction_count | 3234000000 | 3227000000 | 0.015 |
| 14 | symint_sum_loop | compile_time_instruction_count | 4239000000 | 4224000000 | 0.015 |
| 15 | aotdispatcher_inference_nosubclass_cpu | compile_time_instruction_count | 2080000000 | 2075364055 | 0.015 |
| 16 | aotdispatcher_inference_subclass_cpu | compile_time_instruction_count | 5944000000 | 5944000000 | 0.015 |
| 17 | aotdispatcher_partitioner_cpu | compile_time_instruction_count | 8630000000 | 8586000000 | 0.015 |
| 18 | aotdispatcher_partitioner_cpu2 | compile_time_instruction_count | 1895000000 | 1884000000 | 0.015 |
| 19 | aotdispatcher_training_nosubclass_cpu | compile_time_instruction_count | 3818000000 | 3795000000 | 0.015 |
| 20 | aotdispatcher_training_subclass_cpu | compile_time_instruction_count | 10280000000 | 10280000000 | 0.015 |
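Each row above pairs a benchmark and a metric with an expected instruction count and a relative noise tolerance. As a rough illustration of how such a budget check works (a hedged sketch, not the actual CI harness; the function name is made up):

```python
def within_budget(measured: int, expected: int, noise_tolerance: float) -> bool:
    # Illustrative only: a run passes if the measured instruction count stays
    # within the relative noise tolerance of the expected count.
    return abs(measured - expected) <= expected * noise_tolerance

# Using the add_loop_eager row above: a tolerance of 0.015 allows roughly a 1.5% swing.
print(within_budget(2_990_000_000, 2_960_000_000, 0.015))  # True, about 1% away
```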


@@ -1025,12 +1025,12 @@ def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
o1 = x0_1.view((10, 10)); x0_1 = None
add_ = l_x_.add_(l_x_); add_ = None
add_2 = o0 + o1; o0 = o1 = None
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_)
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_))
mul_ = l_y_.mul_(l_y_); mul_ = None
getitem = invoke_subgraph[0]; invoke_subgraph = None
sum_5 = getitem.sum(); getitem = None
add_3 = add_2 + sum_5; add_2 = sum_5 = None
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = l_y_ = None
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_)); subgraph_0 = l_x_ = l_y_ = None
getitem_1 = invoke_subgraph_1[0]; invoke_subgraph_1 = None
sum_6 = getitem_1.sum(); getitem_1 = None
add_4 = add_3 + sum_6; add_3 = sum_6 = None
@@ -1063,11 +1063,11 @@ def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
mul_ = l_y_.mul_(l_y_); mul_ = None
add_2 = o0 + o1; o0 = o1 = None
add_ = l_x_.add_(l_x_); add_ = None
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_)
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_))
getitem = invoke_subgraph[0]; invoke_subgraph = None
sum_5 = getitem.sum(); getitem = None
add_3 = add_2 + sum_5; add_2 = sum_5 = None
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = l_y_ = None
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_)); subgraph_0 = l_x_ = l_y_ = None
getitem_1 = invoke_subgraph_1[0]; invoke_subgraph_1 = None
sum_6 = getitem_1.sum(); getitem_1 = None
add_4 = add_3 + sum_6; add_3 = sum_6 = None
@@ -1093,12 +1093,12 @@ def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
o1 = x0_1.view((10, 10)); x0_1 = None
add_2 = o0 + o1; o0 = o1 = None
add_ = l_x_.add_(l_x_); add_ = None
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_)
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_))
mul_ = l_y_.mul_(l_y_); mul_ = None
getitem = invoke_subgraph[0]; invoke_subgraph = None
sum_5 = getitem.sum(); getitem = None
add_3 = add_2 + sum_5; add_2 = sum_5 = None
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = l_y_ = None
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_)); subgraph_0 = l_x_ = l_y_ = None
getitem_1 = invoke_subgraph_1[0]; invoke_subgraph_1 = None
sum_6 = getitem_1.sum(); getitem_1 = None
add_4 = add_3 + sum_6; add_3 = sum_6 = None
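The expected Dynamo traces above outline two identical regions into a shared subgraph_0 and call it twice via torch.ops.higher_order.invoke_subgraph; the changed lines differ only in whether the operands are passed positionally or as a tuple. A minimal repro-style sketch of triggering that deduplication (hedged: it assumes torch._dynamo.config.use_graph_deduplication gates the pass, as referenced in the get_fake_value hunk further down; the functions are illustrative):

```python
import torch

torch._dynamo.config.use_graph_deduplication = True  # enable duplicate-region outlining

def inner(x, y):
    # repeated computation the region tracker can recognize as identical
    return (x * y).sum()

def fn(x, y):
    return inner(x, y) + inner(x, y)

compiled = torch.compile(fn, backend="eager", fullgraph=True)
print(compiled(torch.randn(10, 10), torch.randn(10, 10)))
```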


@@ -61,12 +61,12 @@ class GraphRegionTrackerTests(TestCase):
return z
def fn(x, y):
o0 = inner_fn(x, y)
_o0 = inner_fn(x, y)
o1 = torch.sin(y)
o2 = inner_fn(x, o1)
o3 = inner_fn(x, y)
o4 = o3 * o3
return o2 * o4 + o0
return o2 * o4
self.assertExpectedInline(
self.get_result(


@@ -151,9 +151,6 @@ class BackwardBfsArgIter:
def create(origin: Node) -> "BackwardBfsArgIter":
it = BackwardBfsArgIter(origin)
it.add_children(origin)
# pop the origin node, since it is the origin of
# the region and does not need to be considered for addition
assert it.next()
return it
def next(self) -> Optional[Node]:
@@ -168,11 +165,17 @@ class BackwardBfsArgIter:
return self._cur
def add_children(self, node: Node) -> None:
flat_args = _get_flat_args_unique(node, {})
arg: Any
flat_args, _ = tree_flatten(node.args)
for arg in flat_args:
if isinstance(arg, Node):
self._append(arg)
flat_kwargs, _ = tree_flatten(node.kwargs)
for kwarg in flat_kwargs:
if isinstance(kwarg, Node):
self._append(kwarg)
def _append(self, arg: Node) -> None:
if self._cur is None:
self._cur = arg
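The restored add_children walks the pytree leaves of node.args and node.kwargs separately, whereas the removed version went through the _get_flat_args_unique helper. A standalone illustration of what tree_flatten yields, using toy values instead of fx.Node arguments:

```python
from torch.utils._pytree import tree_flatten

args = (1, [2, (3, 4)])  # stand-in for node.args
kwargs = {"alpha": 5}    # stand-in for node.kwargs
flat_args, _ = tree_flatten(args)
flat_kwargs, _ = tree_flatten(kwargs)
print(flat_args)    # [1, 2, 3, 4]
print(flat_kwargs)  # [5]
```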
@@ -325,38 +328,6 @@ class GraphRegionTracker:
return f"GraphRegionTracker(hash_to_duplicates={self.hash_to_duplicates}, node_to_duplicates={self.node_to_duplicates})"
class RegionWrapper:
"""Holds state for regions e.g. ancestors and new candidate nodes for consideration"""
def __init__(
self, region: Region, node_to_recursive_ancestors: dict[Node, set[Node]]
) -> None:
assert len(region) == 1, "all regions should start with one node"
node = region[0]
self.node_to_recursive_ancestors = node_to_recursive_ancestors
self.iter = BackwardBfsArgIter.create(node)
self.nodes_unique = OrderedSet([node])
self.ancestors = set(node_to_recursive_ancestors[node])
self.region = region
def next_candidate(self) -> Optional[Node]:
return self.iter.next()
def will_inclusion_create_cycle(self, node: Node) -> bool:
external_users = [user for user in node.users if user not in self.nodes_unique]
for user in external_users:
if user in self.ancestors:
return True
return False
def add(self, node: Node) -> None:
self.nodes_unique.add(node)
self.region.append(node)
self.iter.add_children(node)
self.ancestors.update(self.node_to_recursive_ancestors[node])
def fully_expand_region_group(
regions: list[Region],
seen_nodes: set[Node],
@@ -368,12 +339,20 @@ def fully_expand_region_group(
# All regions should start with 1 node
assert all(len(region) == 1 for region in regions)
region_wrappers = [
RegionWrapper(region, node_to_recursive_ancestors) for region in regions
]
region_iters = []
for region in regions:
(origin,) = region # Only works for 1 element sets
region_iters.append(BackwardBfsArgIter.create(origin))
nodes_to_add = OrderedSet[Node]()
current_node = region_wrappers[0].next_candidate()
nodes_to_add: list[Node] = []
# we already have the origin node in each region
for region_it in region_iters:
node = region_it.next()
assert node
region_it.add_children(node)
current_node = region_iters[0].next()
# No children
if current_node is None:
@@ -383,51 +362,46 @@ def fully_expand_region_group(
# regions are only expanded if the node to add is valid
# for ALL regions
while current_node:
add_to_all_regions = not region_wrappers[0].will_inclusion_create_cycle(
current_node
add_node = not _will_create_cycle(
current_node, regions[0], node_to_recursive_ancestors
)
nodes_to_add.clear()
nodes_to_add.add(current_node)
for region_wrapper in region_wrappers[1:]:
candidate = region_wrapper.next_candidate()
nodes_to_add.append(current_node)
nodes_to_add_set = set(nodes_to_add)
for ind, region_it in enumerate(region_iters[1:]):
ind += 1 # compensate for the 0th region
node = region_it.next()
debug_log("--------------------")
debug_log(
"considering candidate: %s, cur_node: %s", candidate, current_node
)
debug_log("considering adding: %s, cur_node: %s", node, current_node)
debug_log("previously claimed nodes: %s", node in seen_nodes)
if node:
debug_log("is_identical: %s", is_identical_fn(node, current_node))
add_node &= (
node not in seen_nodes
and node not in nodes_to_add_set
and node.op != "placeholder"
and is_identical_fn(node, current_node)
and not _will_create_cycle(
node, regions[ind], node_to_recursive_ancestors
)
)
nodes_to_add.append(node)
nodes_to_add_set.add(node)
else:
add_node = False
if not candidate or not add_to_all_regions:
add_to_all_regions = False
continue
debug_log(
"candidate in previously claimed nodes?: %s", candidate in seen_nodes
)
debug_log("is_identical: %s", is_identical_fn(candidate, current_node))
add_to_all_regions &= (
candidate not in seen_nodes
and candidate not in nodes_to_add
and candidate.op != "placeholder"
and is_identical_fn(candidate, current_node)
and not region_wrapper.will_inclusion_create_cycle(candidate)
)
nodes_to_add.add(candidate)
debug_log(f"add_to_all_regions: {add_to_all_regions}")
debug_log("--------------------")
if add_to_all_regions:
assert len(region_wrappers) == len(nodes_to_add), (
"Numer of nodes to add must equal the number of regions"
)
for region_wrapper, node in zip(region_wrappers, nodes_to_add):
region_wrapper.add(node)
if add_node:
for region, region_it, node in zip(regions, region_iters, nodes_to_add):
region.append(node)
debug_log("adding %s's children", node)
debug_log("%s %s", node.args, list(node.kwargs.items()))
region_it.add_children(node)
seen_nodes.add(node)
current_node = region_wrappers[0].next_candidate()
current_node = region_iters[0].next()
# Ensure regions are sorted in topological order
for region in regions:
@@ -450,3 +424,20 @@ def _populate_recursive_ancestor_map(graph: torch.fx.Graph) -> dict[Node, set[No
)
node_to_recursive_ancestors[node].add(arg)
return node_to_recursive_ancestors
def _will_create_cycle(
node_to_add: Node,
region: Region,
node_to_recursive_ancestors: dict[Node, set[Node]],
) -> bool:
region_set: set[Node] = set(region)
region_ancestors: set[Node] = set(
tree_flatten([list(node_to_recursive_ancestors[node]) for node in region])[0]
)
external_users = [user for user in node_to_add.users if user not in region_set]
for user in external_users:
if user in region_ancestors:
return True
return False
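The restored _will_create_cycle rejects a candidate whenever one of its users outside the region is an ancestor of the region, since absorbing such a node would introduce a cycle. For orientation, a standalone sketch of the recursive-ancestor map that _populate_recursive_ancestor_map builds over an FX graph (illustrative code, not the tracker's implementation):

```python
import torch
from torch.fx import symbolic_trace

def fn(x):
    a = x + 1
    b = a * 2
    return b - x

gm = symbolic_trace(fn)
ancestors = {}
for node in gm.graph.nodes:  # graph.nodes is already in topological order
    ancestors[node] = set()
    for inp in node.all_input_nodes:
        ancestors[node].add(inp)
        ancestors[node] |= ancestors[inp]

for node in gm.graph.nodes:
    print(node.name, sorted(a.name for a in ancestors[node]))
```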


@@ -3152,20 +3152,12 @@ def get_fake_value(node, tx, allow_non_graph_fake=False):
args, kwargs = get_fake_values_from_nodes(
tx, (node.args, node.kwargs), allow_non_graph_fake
)
if (
torch._dynamo.config.use_graph_deduplication
or torch._dynamo.config.track_nodes_for_deduplication
):
flat_args_kwargs = get_fake_values_from_nodes(
tx, _get_flat_args(node, {}), allow_non_graph_fake
)
id_to_initial_version = {
id(arg): arg._version for arg in flat_args_kwargs if is_fake(arg)
}
else:
flat_args_kwargs = []
id_to_initial_version = {}
flat_args_kwargs = get_fake_values_from_nodes(
tx, _get_flat_args(node, {}), allow_non_graph_fake
)
id_to_initial_version = {
id(arg): arg._version for arg in flat_args_kwargs if is_fake(arg)
}
nnmodule = None
if op == "call_method" and len(args) > 0 and isinstance(args[0], torch.nn.Module):