Fix get_free_symbol_uses for several nodes (#160314)

get_free_symbol_uses is used to know what unbacked symbols are used by a given node.
not having correct get_free_symbol_uses defined properly leads to :

- eliminating of some nodes due to not detection of any users. (See the added unit test)
- Incorrect topological sort.

Fix get_free_symbol_uses , NopKernel , ConcarKernel, InputsKerenl, external kernel.
for ComputedBuffer with NonOwningLayout its interesting case.
when layout is NonOwningLayout we need to access the actual view op base layout and use
detect symbols in it. Because when we codegen the ComputedBuffer we uses those symbols.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160314
Approved by: https://github.com/eellison
This commit is contained in:
Laith Sakka
2025-08-13 00:19:28 -07:00
committed by PyTorch MergeBot
parent ecde76c764
commit 96bd33b2de
4 changed files with 110 additions and 38 deletions

View File

@ -74,15 +74,15 @@ aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10650000000,0
mm_loop_inductor_gpu,compile_time_instruction_count,4461000000,0.1
mm_loop_inductor_gpu,compile_time_instruction_count,4495000000,0.1
mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8417000000,0.1
mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8462000000,0.1
basic_NestedModule_eager,compile_time_instruction_count,9199000000,0.1
basic_NestedModule_eager,compile_time_instruction_count,9554000000,0.1

1 add_loop_eager compile_time_instruction_count 3070000000 0.1
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88

View File

@ -3616,6 +3616,17 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
def test_unbacked_select_index_cpp_wrapper(self):
self.test_unbacked_select_index()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_unbacked_select2(self):
def f(idx, x):
x = x.select(0, idx.item())
return x @ x
x = torch.randn(3, 3, 3)
idx = torch.tensor(1, dtype=torch.int64)
out = torch.compile(f)(idx, x)
self.assertEqual(out, f(idx, x))
instantiate_parametrized_tests(TestUnbacked)

View File

@ -3576,6 +3576,11 @@ class OutputSpec:
def storage_size(self) -> int:
raise NotImplementedError(type(self).__name__)
def get_free_symbol_uses(
self, unbacked_only: bool = False
) -> OrderedSet[sympy.Symbol]:
raise NotImplementedError(type(self).__name__)
@ir_dataclass
class Layout(OutputSpec):
@ -3807,6 +3812,15 @@ class Layout(OutputSpec):
def storage_size(self) -> Expr:
return compute_required_storage_length(self.size, self.stride, self.offset) # type: ignore[arg-type]
def get_free_symbol_uses(
self, unbacked_only: bool = False
) -> OrderedSet[sympy.Symbol]:
return (
get_free_symbols(self.size, unbacked_only)
| get_free_symbols(self.stride, unbacked_only)
| get_free_symbols(self.offset, unbacked_only)
)
class FixedLayout(Layout):
"""A Tensor layout we cannot change"""
@ -3999,6 +4013,16 @@ class NonOwningLayout(Layout):
return V.graph.sizevars.statically_known_multiple_of(offset, ALIGNMENT)
def get_free_symbol_uses(
self, unbacked_only: bool = False
) -> OrderedSet[sympy.Symbol]:
assert isinstance(self.view, ReinterpretView)
box = self.view.data
assert isinstance(box, StorageBox), type(box)
input_buffer = box.data
assert isinstance(input_buffer, Buffer), type(box)
return input_buffer.layout.get_free_symbol_uses(unbacked_only)
class CommBufferType(Enum):
SYMM_MEM = "symm_mem"
@ -4382,6 +4406,10 @@ class ShapeAsConstantBuffer(IRNode):
@ir_dataclass(frozen=False)
class ComputedBuffer(OperationBuffer):
"""
Represents a buffer that is computed during kernel execution rather than being an input.
"""
data: Loops
def get_computed_buffer_name(self) -> Optional[str]:
@ -4437,21 +4465,20 @@ class ComputedBuffer(OperationBuffer):
# those symbols that establishes a dependency). However, we haven't
# started codegen yet so we can't directly reuse that logic.
#
# For now, I'm just yoloing with the size of the buffer. Not sure if
# it is enough.
#
# One thing you might wonder is if this is enough for a ComputedBuffer
# denoting a reduction over i0. Empirically, it is enough, but for an
# unusual reason: we only need accurate dependencies for item() call,
# but it's impossible to end up with a reduction over i0 from an
# item() call without a regular non-reduction buffer first.
return (
get_free_symbols(self.get_size(), unbacked_only)
| get_free_symbols(self.get_stride(), unbacked_only)
| get_free_symbols(self.get_offset(), unbacked_only)
| self.data.get_free_symbol_uses(unbacked_only)
| self.get_read_writes().get_free_symbol_uses(unbacked_only)
)
result = self.layout.get_free_symbol_uses(
unbacked_only
) | self.data.get_free_symbol_uses(unbacked_only)
if self.has_store_function() and isinstance(
self.get_store_function(), LoopBody
):
result |= self.get_read_writes().get_free_symbol_uses(unbacked_only)
return result
def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
if (
@ -4463,6 +4490,9 @@ class ComputedBuffer(OperationBuffer):
return self.data.make_loader()
return super().make_loader()
def has_store_function(self) -> bool:
return isinstance(self.data, (Reduction, Scan, Sort, Pointwise))
def get_store_function(self) -> Callable[..., None]:
indexer = self.get_layout().as_fixed().make_indexer()
if isinstance(self.data, (Reduction, Scan, Sort)):
@ -5170,6 +5200,18 @@ class InputsKernel(OperationBuffer):
def num_reads(self) -> int:
return 1
def get_free_symbol_uses(
self, unbacked_only: bool = False
) -> OrderedSet[sympy.Symbol]:
r = OrderedSet[sympy.Symbol]()
for inp in self.inputs:
if isinstance(inp, IRNode):
r |= inp.get_free_symbol_uses(unbacked_only)
else:
for inner_inp in inp:
r |= inner_inp.get_free_symbol_uses(unbacked_only)
return r
class NopKernel(InputsKernel):
def is_no_op(self) -> bool:
@ -5332,6 +5374,11 @@ class ConcatKernel(NopKernel):
and not isinstance(src.data, ExternKernelAlloc)
)
def get_free_symbol_uses(
self, unbacked_only: bool = False
) -> OrderedSet[sympy.Symbol]:
return NopKernel.get_free_symbol_uses(self, unbacked_only)
@classmethod
def realize_into(cls, src: IRNode, dst: IRNode) -> IRNode:
# Attempt to turn this into a ReinterpretView rather than assert.
@ -6232,7 +6279,7 @@ class ExternKernel(InputsKernel):
maybe_get_symbols = (
maybe_free_unbacked_symbols if unbacked_only else maybe_free_symbols
)
r = OrderedSet[sympy.Symbol]()
r = InputsKernel.get_free_symbol_uses(self, unbacked_only)
for arg in self.constant_args:
r |= maybe_get_symbols(arg)
for arg in self.kwargs.values():
@ -8690,7 +8737,10 @@ class EffectfulKernel(FallbackKernel):
class NonTensorObj(IRNode):
pass
def get_free_symbol_uses(
self, unbacked_only: bool = False
) -> OrderedSet[sympy.Symbol]:
return OrderedSet()
@ir_dataclass

View File

@ -2396,11 +2396,11 @@ class Scheduler:
for fs in s.free_symbols:
unbacked_symbol_to_origin_node[fs] = None
has_non_input_unbacked_defs = False
for node in self.nodes:
log.debug("scheduling %s", node.node)
assert node.node is not None
# unbacked symbols don't follow ordinary buffer dependencies, so
# we track their def/uses separately
assert node.node is not None
unbacked_symbol_defs = sorted(
node.node.get_unbacked_symbol_defs(), key=lambda x: x.name
)
@ -2409,20 +2409,28 @@ class Scheduler:
# Pick the first definer as canonical. There may be multiple
# because if a MultiOutputLayout buffer propagates an unbacked
# symint to multiple outputs, they will all claim to def it.
has_non_input_unbacked_defs = True
if s not in unbacked_symbol_to_origin_node:
unbacked_symbol_to_origin_node[s] = node.get_name()
unbacked_symbol_uses = sorted(
node.node.get_free_symbol_uses(unbacked_only=True), key=lambda x: x.name
)
# if a kernel takes unbacked symints, register dependencies
for s in unbacked_symbol_uses:
assert s in unbacked_symbol_to_origin_node, (
f"{s} not in {unbacked_symbol_to_origin_node}"
for node in self.nodes:
log.debug("scheduling %s", node.node)
if has_non_input_unbacked_defs:
assert node.node is not None
unbacked_symbol_uses = sorted(
node.node.get_free_symbol_uses(unbacked_only=True),
key=lambda x: x.name,
)
if (r := unbacked_symbol_to_origin_node[s]) is not None:
for buf in self.name_to_node[r].get_outputs():
node.add_fake_dep(StarDep(buf.get_name()))
# if a kernel takes unbacked symints, register dependencies
for s in unbacked_symbol_uses:
assert s in unbacked_symbol_to_origin_node, (
f"{s} not in {unbacked_symbol_to_origin_node}"
)
if (r := unbacked_symbol_to_origin_node[s]) is not None:
for buf in self.name_to_node[r].get_outputs():
node.add_fake_dep(StarDep(buf.get_name()))
if (
len(node.read_writes.writes) == 1
@ -2477,17 +2485,20 @@ class Scheduler:
add_user(buf_name, OutputNode(StarDep(buf_name)))
# make sure unbacked symints aren't dead-code-eliminated
for out in V.graph.graph_outputs:
for s in out.get_free_symbol_uses(unbacked_only=True):
assert s in unbacked_symbol_to_origin_node, (
f"{s} not in {unbacked_symbol_to_origin_node.keys()}"
)
if r := unbacked_symbol_to_origin_node[s]:
for buf_name in self.name_to_node[r].get_buffer_names():
log.debug(
"scheduling output %s for unbacked symint %s", buf_name, s
)
add_user(buf_name, OutputNode(StarDep(buf_name)))
if has_non_input_unbacked_defs:
for out in V.graph.graph_outputs:
for s in out.get_free_symbol_uses(unbacked_only=True):
assert s in unbacked_symbol_to_origin_node, (
f"{s} not in {unbacked_symbol_to_origin_node.keys()}"
)
if r := unbacked_symbol_to_origin_node[s]:
for buf_name in self.name_to_node[r].get_buffer_names():
log.debug(
"scheduling output %s for unbacked symint %s",
buf_name,
s,
)
add_user(buf_name, OutputNode(StarDep(buf_name)))
# make sure input mutation isn't dead-code-eliminated
for name in self.mutation_renames: