set unbacked bindings in reinplace pass for newly created nodes during generalize_scatter decomp (#164948)

Two fixes:
1. in rein_place pass, set unbacked bindings for newly created nodes.
2. In inductor, ComputeBuffer used to miss detecting some used symbols, fixed that.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164948
Approved by: https://github.com/bobrenjc93
ghstack dependencies: #164341
This commit is contained in:
Laith Sakka
2025-10-17 11:01:15 -07:00
committed by PyTorch MergeBot
parent c6a8db0b9a
commit 017d2985f3
3 changed files with 42 additions and 4 deletions

View File

@ -24,7 +24,10 @@ from torch._inductor.lowering import (
inplaceable_foreach_ops as inplaceable_foreach_ops_lowerings,
)
from torch._inductor.virtualized import V
from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode
from torch.fx.experimental.symbolic_shapes import (
compute_unbacked_bindings,
GuardOnDataDependentSymNode,
)
from torch.fx.immutable_collections import immutable_dict, immutable_list
from torch.fx.passes.reinplace import _is_view_op
from torch.utils import _pytree as pytree
@ -60,7 +63,9 @@ def graph_call_function(graph: torch.fx.Graph, fn, *args, **kwargs):
fake_result = fn(*fake_args, **fake_kwargs)
node = graph.call_function(fn, args, kwargs)
node.meta["val"] = fake_result
return node
@ -171,6 +176,13 @@ def _decompose_scatter_mutating(
tmp = inp
for view in view_ops: # type: ignore[union-attr]
tmp = graph_call_function(graph, view.target, tmp, *view.args, **view.kwargs) # type: ignore[union-attr]
# we need to set unbacked bindings that could have been created in the view ops.
if (V.fake_mode.shape_env) and (
symbol_to_path := compute_unbacked_bindings(
V.fake_mode.shape_env, tmp.meta["val"]
)
):
tmp.meta["unbacked_bindings"] = symbol_to_path
graph_call_function(graph, aten.copy_.default, tmp, src)
return inp # type: ignore[return-value]

View File

@ -4542,9 +4542,7 @@ class ComputedBuffer(OperationBuffer):
unbacked_only
) | self.data.get_free_symbol_uses(unbacked_only)
if self.has_store_function() and isinstance(
self.get_store_function(), LoopBody
):
if self.has_store_function():
result |= self.get_read_writes().get_free_symbol_uses(unbacked_only)
return result