mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Revert "[Dynamo] Optimize dedupe region ancestor tracking (#152589)"
This reverts commit b5f1345f72ec6d1b004b05284e9553e65ee03abc. Reverted https://github.com/pytorch/pytorch/pull/152589 on behalf of https://github.com/jeanschmidt due to it breaking the internal signal citadel-fbcode-test-mode-opt-for-pt2_stack_for_internal-linux-0; please see diff [D74531503](https://www.internalfb.com/diff/D74531503) for more details ([comment](https://github.com/pytorch/pytorch/pull/152410#issuecomment-2871168679))
@@ -1,20 +1,20 @@
add_loop_eager,compile_time_instruction_count,3040000000,0.015
add_loop_eager,compile_time_instruction_count,2960000000,0.015

add_loop_eager_dynamic,compile_time_instruction_count,6037000000,0.025
add_loop_eager_dynamic,compile_time_instruction_count,5806000000,0.025

add_loop_inductor,compile_time_instruction_count,29440000000,0.015
add_loop_inductor,compile_time_instruction_count,29160000000,0.015

add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44130000000,0.025
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,42960000000,0.025

add_loop_inductor_gpu,compile_time_instruction_count,25840000000,0.015
add_loop_inductor_gpu,compile_time_instruction_count,25630000000,0.015

@@ -30,11 +30,11 @@ basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instructio
basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10370000000,0.2
basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,9714000000,0.2

update_hint_regression,compile_time_instruction_count,1731000000,0.02
update_hint_regression,compile_time_instruction_count,1677500000,0.02

@@ -42,19 +42,19 @@ float_args,compile_time_instruction_count,439200000,0.015
sum_floordiv_regression,compile_time_instruction_count,1005000000,0.015
sum_floordiv_regression,compile_time_instruction_count,998400000,0.015

symint_sum,compile_time_instruction_count,3234000000,0.015
symint_sum,compile_time_instruction_count,3227000000,0.015

symint_sum_loop,compile_time_instruction_count,4239000000,0.015
symint_sum_loop,compile_time_instruction_count,4224000000,0.015

aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2080000000,0.015
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2075364055,0.015

@@ -62,15 +62,15 @@ aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5944000000,0
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8630000000,0.015
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8586000000,0.015

aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1895000000,0.015
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1884000000,0.015

aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3818000000,0.015
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3795000000,0.015
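Each row above follows the layout `benchmark_name,metric,expected_value,relative_tolerance`; the two values in each pair are the expected instruction counts touched by this revert. Below is a minimal sketch of how such a row could be checked against a freshly measured count; the helper names and the reading of the last column as a relative tolerance are assumptions for illustration, not the actual benchmark harness.

```python
import csv
from typing import Iterator


def load_expected(path: str) -> Iterator[tuple[str, str, float, float]]:
    # Yield (benchmark, metric, expected_value, rel_tolerance) rows,
    # skipping the blank separator lines seen in the file above.
    with open(path) as f:
        for row in csv.reader(f):
            if not row:
                continue
            name, metric, expected, tol = row
            yield name, metric, float(expected), float(tol)


def within_tolerance(measured: float, expected: float, rel_tol: float) -> bool:
    # A measurement passes if it deviates from the expected value by no more
    # than the per-row relative tolerance.
    return abs(measured - expected) <= rel_tol * expected


# Example: a measured count of 3.01e9 passes the add_loop_eager row with
# expected value 3040000000 and tolerance 0.015 (roughly +/-1.5%).
assert within_tolerance(3.01e9, 3040000000, 0.015)
```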
@@ -1025,12 +1025,12 @@ def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
    o1 = x0_1.view((10, 10)); x0_1 = None
    add_ = l_x_.add_(l_x_); add_ = None
    add_2 = o0 + o1; o0 = o1 = None
    invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_)
    invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_))
    mul_ = l_y_.mul_(l_y_); mul_ = None
    getitem = invoke_subgraph[0]; invoke_subgraph = None
    sum_5 = getitem.sum(); getitem = None
    add_3 = add_2 + sum_5; add_2 = sum_5 = None
    invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = l_y_ = None
    invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_)); subgraph_0 = l_x_ = l_y_ = None
    getitem_1 = invoke_subgraph_1[0]; invoke_subgraph_1 = None
    sum_6 = getitem_1.sum(); getitem_1 = None
    add_4 = add_3 + sum_6; add_3 = sum_6 = None
@@ -1063,11 +1063,11 @@ def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
    mul_ = l_y_.mul_(l_y_); mul_ = None
    add_2 = o0 + o1; o0 = o1 = None
    add_ = l_x_.add_(l_x_); add_ = None
    invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_)
    invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_))
    getitem = invoke_subgraph[0]; invoke_subgraph = None
    sum_5 = getitem.sum(); getitem = None
    add_3 = add_2 + sum_5; add_2 = sum_5 = None
    invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = l_y_ = None
    invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_)); subgraph_0 = l_x_ = l_y_ = None
    getitem_1 = invoke_subgraph_1[0]; invoke_subgraph_1 = None
    sum_6 = getitem_1.sum(); getitem_1 = None
    add_4 = add_3 + sum_6; add_3 = sum_6 = None
@@ -1093,12 +1093,12 @@ def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
    o1 = x0_1.view((10, 10)); x0_1 = None
    add_2 = o0 + o1; o0 = o1 = None
    add_ = l_x_.add_(l_x_); add_ = None
    invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_)
    invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_))
    mul_ = l_y_.mul_(l_y_); mul_ = None
    getitem = invoke_subgraph[0]; invoke_subgraph = None
    sum_5 = getitem.sum(); getitem = None
    add_3 = add_2 + sum_5; add_2 = sum_5 = None
    invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = l_y_ = None
    invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_)); subgraph_0 = l_x_ = l_y_ = None
    getitem_1 = invoke_subgraph_1[0]; invoke_subgraph_1 = None
    sum_6 = getitem_1.sum(); getitem_1 = None
    add_4 = add_3 + sum_6; add_3 = sum_6 = None
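The expected Dynamo output above shows the graph deduplication pass replacing repeated computations with calls to `torch.ops.higher_order.invoke_subgraph`; the only change in these hunks is whether the subgraph operands are spelled positionally or as a single tuple. A rough way to exercise the pass yourself is sketched below, using the `use_graph_deduplication` flag that appears in the `get_fake_value` hunk further down; whether that flag alone is sufficient to trigger deduplication in a given build is an assumption.

```python
import torch
import torch._dynamo  # makes torch._dynamo.config accessible

# Config flag referenced later in this diff; treated here as the switch for the pass.
torch._dynamo.config.use_graph_deduplication = True


def repeated(x, y):
    # The same subexpression appears twice, so the deduplication pass may extract
    # it into one subgraph that is invoked twice, as in the expected output above.
    a = (x * 2 + y).sum()
    b = (x * 2 + y).sum()
    return a + b


compiled = torch.compile(repeated, backend="eager")
print(compiled(torch.randn(4), torch.randn(4)))
```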
@@ -61,12 +61,12 @@ class GraphRegionTrackerTests(TestCase):
            return z

        def fn(x, y):
            o0 = inner_fn(x, y)
            _o0 = inner_fn(x, y)
            o1 = torch.sin(y)
            o2 = inner_fn(x, o1)
            o3 = inner_fn(x, y)
            o4 = o3 * o3
            return o2 * o4 + o0
            return o2 * o4

        self.assertExpectedInline(
            self.get_result(
@@ -151,9 +151,6 @@ class BackwardBfsArgIter:
    def create(origin: Node) -> "BackwardBfsArgIter":
        it = BackwardBfsArgIter(origin)
        it.add_children(origin)
        # pop the origin node, since it is the origin of
        # the region and does not need to be considered for addition
        assert it.next()
        return it

    def next(self) -> Optional[Node]:
@@ -168,11 +165,17 @@ class BackwardBfsArgIter:
        return self._cur

    def add_children(self, node: Node) -> None:
        flat_args = _get_flat_args_unique(node, {})
        arg: Any
        flat_args, _ = tree_flatten(node.args)
        for arg in flat_args:
            if isinstance(arg, Node):
                self._append(arg)

        flat_kwargs, _ = tree_flatten(node.kwargs)
        for kwarg in flat_kwargs:
            if isinstance(kwarg, Node):
                self._append(kwarg)

    def _append(self, arg: Node) -> None:
        if self._cur is None:
            self._cur = arg
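`BackwardBfsArgIter` walks backwards from a region's origin node through argument edges, and the hunk above toggles `add_children` between `_get_flat_args_unique` and flattening `node.args`/`node.kwargs` with `tree_flatten`. The following is a standalone approximation of that backward walk on a toy `torch.fx` graph; the real iterator pops the origin and handles de-duplication slightly differently.

```python
from collections import deque

import torch
import torch.fx as fx
from torch.utils._pytree import tree_flatten


def g(x, y):
    a = torch.add(x, y)
    b = torch.mul(a, a)
    return torch.relu(b)


gm = fx.symbolic_trace(g)
origin = next(n for n in gm.graph.nodes if n.op == "output")

# Breadth-first walk from `origin` back through argument edges, visiting each
# producer before that producer's own inputs.
queue, seen, order = deque([origin]), {origin}, []
while queue:
    node = queue.popleft()
    order.append(node)
    flat_args, _ = tree_flatten((node.args, node.kwargs))  # same flattening as add_children
    for arg in flat_args:
        if isinstance(arg, fx.Node) and arg not in seen:
            seen.add(arg)
            queue.append(arg)

print([n.name for n in order])  # output, relu, mul, add, then the placeholders
```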
@@ -325,38 +328,6 @@ class GraphRegionTracker:
        return f"GraphRegionTracker(hash_to_duplicates={self.hash_to_duplicates}, node_to_duplicates={self.node_to_duplicates})"


class RegionWrapper:
    """Holds state for regions e.g. ancestors and new candidate nodes for consideration"""

    def __init__(
        self, region: Region, node_to_recursive_ancestors: dict[Node, set[Node]]
    ) -> None:
        assert len(region) == 1, "all regions should start with one node"
        node = region[0]
        self.node_to_recursive_ancestors = node_to_recursive_ancestors
        self.iter = BackwardBfsArgIter.create(node)
        self.nodes_unique = OrderedSet([node])
        self.ancestors = set(node_to_recursive_ancestors[node])
        self.region = region

    def next_candidate(self) -> Optional[Node]:
        return self.iter.next()

    def will_inclusion_create_cycle(self, node: Node) -> bool:
        external_users = [user for user in node.users if user not in self.nodes_unique]
        for user in external_users:
            if user in self.ancestors:
                return True

        return False

    def add(self, node: Node) -> None:
        self.nodes_unique.add(node)
        self.region.append(node)
        self.iter.add_children(node)
        self.ancestors.update(self.node_to_recursive_ancestors[node])


def fully_expand_region_group(
    regions: list[Region],
    seen_nodes: set[Node],
@@ -368,12 +339,20 @@ def fully_expand_region_group(

    # All regions should start with 1 node
    assert all(len(region) == 1 for region in regions)
    region_wrappers = [
        RegionWrapper(region, node_to_recursive_ancestors) for region in regions
    ]
    region_iters = []
    for region in regions:
        (origin,) = region  # Only works for 1 element sets
        region_iters.append(BackwardBfsArgIter.create(origin))

    nodes_to_add = OrderedSet[Node]()
    current_node = region_wrappers[0].next_candidate()
    nodes_to_add: list[Node] = []

    # we already have the origin node in each region
    for region_it in region_iters:
        node = region_it.next()
        assert node
        region_it.add_children(node)

    current_node = region_iters[0].next()

    # No children
    if current_node is None:
@@ -383,51 +362,46 @@ def fully_expand_region_group(
    # regions are only expanded if the node to add is valid
    # for ALL regions
    while current_node:
        add_to_all_regions = not region_wrappers[0].will_inclusion_create_cycle(
            current_node
        add_node = not _will_create_cycle(
            current_node, regions[0], node_to_recursive_ancestors
        )
        nodes_to_add.clear()
        nodes_to_add.add(current_node)
        for region_wrapper in region_wrappers[1:]:
            candidate = region_wrapper.next_candidate()
        nodes_to_add.append(current_node)
        nodes_to_add_set = set(nodes_to_add)
        for ind, region_it in enumerate(region_iters[1:]):
            ind += 1  # compensate for the 0th region
            node = region_it.next()

            debug_log("--------------------")
            debug_log(
                "considering candidate: %s, cur_node: %s", candidate, current_node
            )
            debug_log("considering adding: %s, cur_node: %s", node, current_node)
            debug_log("previously claimed nodes: %s", node in seen_nodes)
            if node:
                debug_log("is_identical: %s", is_identical_fn(node, current_node))
                add_node &= (
                    node not in seen_nodes
                    and node not in nodes_to_add_set
                    and node.op != "placeholder"
                    and is_identical_fn(node, current_node)
                    and not _will_create_cycle(
                        node, regions[ind], node_to_recursive_ancestors
                    )
                )
                nodes_to_add.append(node)
                nodes_to_add_set.add(node)
            else:
                add_node = False

            if not candidate or not add_to_all_regions:
                add_to_all_regions = False
                continue

            debug_log(
                "candidate in previously claimed nodes?: %s", candidate in seen_nodes
            )
            debug_log("is_identical: %s", is_identical_fn(candidate, current_node))

            add_to_all_regions &= (
                candidate not in seen_nodes
                and candidate not in nodes_to_add
                and candidate.op != "placeholder"
                and is_identical_fn(candidate, current_node)
                and not region_wrapper.will_inclusion_create_cycle(candidate)
            )
            nodes_to_add.add(candidate)

        debug_log(f"add_to_all_regions: {add_to_all_regions}")
        debug_log("--------------------")

        if add_to_all_regions:
            assert len(region_wrappers) == len(nodes_to_add), (
                "Numer of nodes to add must equal the number of regions"
            )
            for region_wrapper, node in zip(region_wrappers, nodes_to_add):
                region_wrapper.add(node)
        if add_node:
            for region, region_it, node in zip(regions, region_iters, nodes_to_add):
                region.append(node)
                debug_log("adding %s's children", node)
                debug_log("%s %s", node.args, list(node.kwargs.items()))
                region_it.add_children(node)
                seen_nodes.add(node)

        current_node = region_wrappers[0].next_candidate()
        current_node = region_iters[0].next()

    # Ensure regions are sorted in topological order
    for region in regions:
@@ -450,3 +424,20 @@ def _populate_recursive_ancestor_map(graph: torch.fx.Graph) -> dict[Node, set[No
        )
        node_to_recursive_ancestors[node].add(arg)
    return node_to_recursive_ancestors


def _will_create_cycle(
    node_to_add: Node,
    region: Region,
    node_to_recursive_ancestors: dict[Node, set[Node]],
) -> bool:
    region_set: set[Node] = set(region)
    region_ancestors: set[Node] = set(
        tree_flatten([list(node_to_recursive_ancestors[node]) for node in region])[0]
    )
    external_users = [user for user in node_to_add.users if user not in region_set]
    for user in external_users:
        if user in region_ancestors:
            return True

    return False
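`_will_create_cycle` above rejects a candidate node whenever one of its users outside the region is also an ancestor of the region, because the extracted subgraph would then depend on a node that depends on it. Below is a minimal, self-contained sketch of the ancestor map and that check on a toy `torch.fx` graph; the names are local to the example, and `_populate_recursive_ancestor_map` itself is only partially visible in the hunk.

```python
import torch
import torch.fx as fx


def f(x):
    a = x + 1
    b = a * 2
    c = a * 3
    return b + c


gm = fx.symbolic_trace(f)

# Map every node to the transitive set of its producers, walking in topological
# order so each input's ancestor set is already available.
ancestors: dict[fx.Node, set[fx.Node]] = {}
for node in gm.graph.nodes:
    acc: set[fx.Node] = set()
    for inp in node.all_input_nodes:
        acc.add(inp)
        acc |= ancestors[inp]
    ancestors[node] = acc


def will_create_cycle(node_to_add: fx.Node, region: list[fx.Node]) -> bool:
    # Adding node_to_add is unsafe if one of its users outside the region is an
    # ancestor of the region: the fused region would both feed and be fed by it.
    region_set = set(region)
    region_ancestors = set().union(*(ancestors[n] for n in region))
    return any(
        user in region_ancestors
        for user in node_to_add.users
        if user not in region_set
    )
```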
@@ -3152,20 +3152,12 @@ def get_fake_value(node, tx, allow_non_graph_fake=False):
    args, kwargs = get_fake_values_from_nodes(
        tx, (node.args, node.kwargs), allow_non_graph_fake
    )

    if (
        torch._dynamo.config.use_graph_deduplication
        or torch._dynamo.config.track_nodes_for_deduplication
    ):
        flat_args_kwargs = get_fake_values_from_nodes(
            tx, _get_flat_args(node, {}), allow_non_graph_fake
        )
        id_to_initial_version = {
            id(arg): arg._version for arg in flat_args_kwargs if is_fake(arg)
        }
    else:
        flat_args_kwargs = []
        id_to_initial_version = {}
    flat_args_kwargs = get_fake_values_from_nodes(
        tx, _get_flat_args(node, {}), allow_non_graph_fake
    )
    id_to_initial_version = {
        id(arg): arg._version for arg in flat_args_kwargs if is_fake(arg)
    }

    nnmodule = None
    if op == "call_method" and len(args) > 0 and isinstance(args[0], torch.nn.Module):
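The `id_to_initial_version` map above snapshots each fake tensor's autograd version counter so that later deduplication logic can tell which inputs an op mutated in place. A minimal sketch of that version-counter idea on ordinary tensors, with hypothetical helper names, follows.

```python
import torch


def snapshot_versions(tensors):
    # Record each tensor's autograd version counter keyed by identity,
    # mirroring how id_to_initial_version is built above.
    return {id(t): t._version for t in tensors if isinstance(t, torch.Tensor)}


def mutated_since(tensors, before):
    # A tensor whose version counter advanced was mutated in place
    # (e.g. by add_ or mul_) after the snapshot was taken.
    return [t for t in tensors if id(t) in before and t._version != before[id(t)]]


x, y = torch.randn(3), torch.randn(3)
before = snapshot_versions([x, y])
x.add_(1)   # in-place op: bumps x._version
_ = y + 1   # out-of-place op: leaves y._version unchanged
changed = mutated_since([x, y], before)
assert len(changed) == 1 and changed[0] is x
```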