[GraphPartition] cache get_free_symbol_uses (#166338)

Graph partition relies on `get_free_symbol_uses()` to collect symbol inputs.
ee7434be82/torch/_inductor/scheduler.py (L4869-L4885)
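
For intuition, here is a minimal sketch of that collection step (not the
actual scheduler code; `partition_nodes` and `collect_symbol_inputs` are
hypothetical names for illustration):

    from torch.utils._ordered_set import OrderedSet

    def collect_symbol_inputs(partition_nodes):
        # union the free symbols used by each IR node in the partition;
        # these symbols must be passed into the partitioned graph as inputs
        symbols = OrderedSet()
        for node in partition_nodes:
            symbols |= node.get_free_symbol_uses(unbacked_only=False)
        return symbols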

I empirically observed that `get_free_symbol_uses()` becomes slower as graphs grow larger. Specifically, I tried falling back to aten for torchtitan, which results in 10k+ aten nodes. By the time the 600th node was processed, a single call to `get_free_symbol_uses()` took seconds.

Why? Because `get_free_symbol_uses()` may recursively call `get_free_symbol_uses()` on the IR nodes it wraps, so a single call can walk a deep chain of nodes, and the same subtrees are re-traversed on every call.
ee7434be82/torch/_inductor/ir.py (L4541-L4543)
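
As a toy model of this delegation pattern (simplified; the real classes are in
the ir.py diff below), a chain of N wrapper nodes makes every call walk the
whole chain, so querying all N nodes does O(N^2) work in total:

    class ToyBuffer:
        def __init__(self, symbols):
            self.symbols = symbols

        def get_free_symbol_uses(self, unbacked_only=False):
            return self.symbols

    class ToyView:
        # mirrors BaseView: delegates to the node it wraps
        def __init__(self, data):
            self.data = data

        def get_free_symbol_uses(self, unbacked_only=False):
            # recursive call: re-traverses everything below this node
            return self.data.get_free_symbol_uses(unbacked_only)

    node = ToyBuffer({"s0"})
    nodes = [node]
    for _ in range(1000):
        node = ToyView(node)
        nodes.append(node)
    for n in nodes:  # 1001 calls, each walking up to 1000 links
        n.get_free_symbol_uses()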

This PR fixes the issue by caching the results of `get_free_symbol_uses()`. I validated on torchtitan that the issue is fixed.
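
The cache lives in a new `cache_on_self_and_args` decorator in inductor utils,
applied to each `get_free_symbol_uses` override (full diff below). A
simplified, exec-free equivalent of what the decorator does:

    import functools
    from typing import Any, Callable

    def cache_on_self_and_args(class_name: str) -> Callable:
        # the class name is part of the cache attribute so that an override
        # and its super() call keep separate caches on the same instance
        def wrapper(fn: Callable) -> Callable:
            key = f"__{class_name}_{fn.__name__}_cache"

            @functools.wraps(fn)
            def inner(self: Any, *args: Any, **kwargs: Any) -> Any:
                cache = getattr(self, key, None)
                if cache is None:
                    cache = {}
                    # object.__setattr__ also works on frozen dataclasses
                    object.__setattr__(self, key, cache)
                args_kwargs = (args, tuple(sorted(kwargs.items())))
                if args_kwargs not in cache:
                    cache[args_kwargs] = fn(self, *args, **kwargs)
                return cache[args_kwargs]

            return inner

        return wrapper

After the first call, repeated `get_free_symbol_uses()` lookups on a node are
O(1), which removes the superlinear blowup.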

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166338
Approved by: https://github.com/eellison
Author: Boyuan Feng
Date: 2025-10-31 02:50:10 +00:00
Committed by: PyTorch MergeBot
Parent: 12577064dd
Commit: a6b1ef1717
3 changed files with 177 additions and 4 deletions


@@ -11005,6 +11005,29 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
         p = torch.tensor(0.50, device=self.device)
         get_mask(x, p)
 
+    def test_flexible_layout_immutable_free_symbols(self):
+        import sympy
+
+        x = sympy.Symbol("x")
+        y = sympy.Symbol("y")
+        z = sympy.Symbol("z")
+        layout = torch._inductor.ir.FlexibleLayout(
+            self.device, torch.float32, size=(x, y)
+        )
+
+        # pad_strides works since it does not add new symints
+        layout.pad_strides()
+
+        # same symints and different order should work
+        layout.size = (y, x)
+
+        # adding new symints should fail
+        with self.assertRaisesRegex(
+            AssertionError, "Expected free symbols unchanged, but got"
+        ):
+            layout.size = (z,)
+
     def test_sqrt_dynamic_shapes(self):
         # TIMM convit_base model: https://github.com/pytorch/pytorch/issues/97877.
         # TODO: support cuda path.


@@ -64,6 +64,7 @@ from torch.fx.experimental.symbolic_shapes import (
     compute_unbacked_bindings,
     free_symbols,
     free_unbacked_symbols,
+    IterateExprs,
     rebind_unbacked,
     resolve_unbacked_bindings,
     ShapeEnv,
@@ -97,6 +98,7 @@ from .utils import (
     argsort,
     argsort_sym,
     cache_on_self,
+    cache_on_self_and_args,
     ceildiv,
     convert_shape_to_inductor,
     convert_shape_to_symint,
@@ -933,6 +935,7 @@ class Loops(IRNode):
     inner_fn: Callable[..., Any]
     ranges: Sequence[_IntLike]
 
+    @cache_on_self_and_args("Loops")
     def get_free_symbol_uses(
         self, unbacked_only: bool = False
     ) -> OrderedSet[sympy.Symbol]:
@@ -1228,6 +1231,7 @@ class Reduction(Loops):
     __repr__ = __str__
 
+    @cache_on_self_and_args("Reduction")
     def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
         return super().get_free_symbol_uses(unbacked_only) | OrderedSet().union(
             *(get_free_symbols(e, unbacked_only) for e in self.reduction_ranges)
@@ -2327,6 +2331,7 @@ class Scan(Loops):
     # HACK we mimic reduction
 
+    @cache_on_self_and_args("Scan")
     def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
         # TODO: Can combine_fn/reindex close over unbacked symbols? If so, we
         # need to explicitly represent the closure so we can pull out unbacked
@@ -2537,6 +2542,7 @@ class Sort(Loops):
     # HACK we mimic reduction
 
+    @cache_on_self_and_args("Sort")
     def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
         return (
             super().get_free_symbol_uses(unbacked_only)
@@ -2785,6 +2791,7 @@ def is_unaligned(node: IRNode) -> bool:
 class BaseView(IRNode):
     data: IRNode
 
+    @cache_on_self_and_args("BaseView")
     def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
         return self.data.get_free_symbol_uses(unbacked_only)
@@ -3359,6 +3366,7 @@ class ReinterpretView(BaseView):
     def freeze_layout(self) -> None:
         pass
 
+    @cache_on_self_and_args("ReinterpretView")
     def get_free_symbol_uses(
         self, unbacked_only: bool = False
     ) -> OrderedSet[sympy.Symbol]:
@@ -3643,13 +3651,37 @@ class Layout(OutputSpec):
         self.dtype = dtype
         assert len(size) == len(stride), f"size={size}, stride={stride}"
         assert all(isinstance(s, (Expr, int)) for s in size)
-        self.size = size
-        self.stride = stride
-        self.offset = offset
+        self._size = size
+        self._stride = stride
+        self._offset = offset
         self.is_pinned = is_pinned
         # is_pinned implies cpu
         assert (not self.is_pinned) or (self.device.type == "cpu")
 
+    @property
+    def size(self) -> Sequence[Expr]:
+        return self._size
+
+    @size.setter
+    def size(self, value: Sequence[Expr]) -> None:
+        self._size = value
+
+    @property
+    def stride(self) -> Sequence[Expr]:
+        return self._stride
+
+    @stride.setter
+    def stride(self, value: Sequence[Expr]) -> None:
+        self._stride = value
+
+    @property
+    def offset(self) -> Expr:
+        return self._offset
+
+    @offset.setter
+    def offset(self, value: Expr) -> None:
+        self._offset = value
+
     def __str__(self) -> str:
         offset = ""
         if self.offset != 0:
@@ -3869,6 +3901,7 @@ class Layout(OutputSpec):
     def storage_size(self) -> Expr:
         return compute_required_storage_length(self.size, self.stride, self.offset)  # type: ignore[arg-type]
 
+    @cache_on_self_and_args("Layout")
     def get_free_symbol_uses(
         self, unbacked_only: bool = False
     ) -> OrderedSet[sympy.Symbol]:
@@ -3888,7 +3921,11 @@ class FixedLayout(Layout):
 
 class FlexibleLayout(Layout):
-    """A Tensor layout that we are allowed to change"""
+    """
+    A Tensor layout that we are allowed to change
+
+    Assumption: layout change should NOT add or remove free symbols
+    """
 
     allow_indexing = False
@@ -3973,6 +4010,33 @@ class FlexibleLayout(Layout):
         fill_order = sorted(range(len(stride)), key=stride.__getitem__)
         return FlexibleLayout.fill_ordered(sizes, fill_order)
 
+    @property
+    def size(self) -> Sequence[Expr]:
+        return self._size
+
+    @size.setter
+    def size(self, value: Sequence[Expr]) -> None:
+        self.assert_free_symbol_uses_unchanged("size", value)
+        self._size = value
+
+    @property
+    def stride(self) -> Sequence[Expr]:
+        return self._stride
+
+    @stride.setter
+    def stride(self, value: Sequence[Expr]) -> None:
+        self.assert_free_symbol_uses_unchanged("stride", value)
+        self._stride = value
+
+    @property
+    def offset(self) -> Expr:
+        return self._offset
+
+    @offset.setter
+    def offset(self, value: Expr) -> None:
+        self.assert_free_symbol_uses_unchanged("offset", value)
+        self._offset = value
+
     def as_stride_order(
         self, order: Sequence[int], allow_padding: bool = False
     ) -> FixedLayout:
@@ -4031,6 +4095,25 @@ class FlexibleLayout(Layout):
             self.is_pinned,
         )
 
+    def get_initial_free_symbol_uses(self) -> dict[tuple[str, bool], sympy.Symbol]:
+        initial_free_symbols = {}
+        for name in ["size", "stride", "offset"]:
+            for unbacked_only in [True, False]:
+                key = (name, unbacked_only)
+                initial_free_symbols[key] = OrderedSet(
+                    get_free_symbols(getattr(self, name), unbacked_only)
+                )
+        return initial_free_symbols
+
+    def assert_free_symbol_uses_unchanged(self, name: str, value: IterateExprs) -> None:
+        for unbacked_only in [True, False]:
+            old_free_symbols = self.initial_free_symbols[(name, unbacked_only)]
+            new_free_symbols = OrderedSet(get_free_symbols(value, unbacked_only))
+            assert new_free_symbols == old_free_symbols, (
+                f"Expected free symbols unchanged, but got {new_free_symbols} vs {old_free_symbols}"
+            )
+
     def __init__(
         self,
         device: torch.device,
@@ -4045,6 +4128,10 @@ class FlexibleLayout(Layout):
             strides = FlexibleLayout.contiguous_strides(size)
         super().__init__(device, dtype, size, strides, is_pinned=is_pinned)
 
+        # record the initial free symbols to check that we do not add new free symbols
+        # later when modifying sizes, strides, and offsets.
+        self.initial_free_symbols = self.get_initial_free_symbol_uses()
+
 
 class NonOwningLayout(Layout):
     """Is a view into the storage of another tensor"""
@@ -4070,6 +4157,7 @@ class NonOwningLayout(Layout):
         return V.graph.sizevars.statically_known_multiple_of(offset, ALIGNMENT)
 
+    @cache_on_self_and_args("NonOwningLayout")
     def get_free_symbol_uses(
         self, unbacked_only: bool = False
     ) -> OrderedSet[sympy.Symbol]:
@@ -4358,6 +4446,7 @@ class Buffer(IRNode, CodegenSymbol):
     def get_read_names(self) -> OrderedSet[str]:
         return OrderedSet([self.get_name()])
 
+    @cache_on_self_and_args("Buffer")
     def get_free_symbol_uses(
         self, unbacked_only: bool = False
     ) -> OrderedSet[sympy.Symbol]:
@@ -4430,6 +4519,7 @@ class NoneAsConstantBuffer(IRNode):
     def get_reads(self) -> OrderedSet[Dep]:
         return OrderedSet()
 
+    @cache_on_self_and_args("NoneAsConstantBuffer")
     def get_free_symbol_uses(
         self, unbacked_only: bool = False
     ) -> OrderedSet[sympy.Symbol]:
@@ -4449,6 +4539,7 @@ class NoneAsConstantBuffer(IRNode):
 class ShapeAsConstantBuffer(IRNode):
     expr: Expr
 
+    @cache_on_self_and_args("ShapeAsConstantBuffer")
     def get_free_symbol_uses(
         self, unbacked_only: bool = False
     ) -> OrderedSet[sympy.Symbol]:
@@ -4521,6 +4612,7 @@ class ComputedBuffer(OperationBuffer):
             self.data.get_size(),
         )
 
+    @cache_on_self_and_args("ComputedBuffer")
     def get_free_symbol_uses(
         self, unbacked_only: bool = False
     ) -> OrderedSet[sympy.Symbol]:
@@ -4974,6 +5066,7 @@ class TritonTemplateBuffer(TemplateBuffer):
         self.subgraph_inps: Optional[list[Optional[Union[IRNode, sympy.Expr]]]] = None
         self.subgraph_outs: Optional[list[Optional[IRNode]]] = None
 
+    @cache_on_self_and_args("TritonTemplateBuffer")
     def get_free_symbol_uses(
         self, unbacked_only: bool = False
     ) -> OrderedSet[sympy.Symbol]:
@@ -5340,6 +5433,7 @@ class InputsKernel(OperationBuffer):
     def num_reads(self) -> int:
         return 1
 
+    @cache_on_self_and_args("InputsKernel")
     def get_free_symbol_uses(
         self, unbacked_only: bool = False
     ) -> OrderedSet[sympy.Symbol]:
@@ -5514,6 +5608,7 @@ class ConcatKernel(NopKernel):
             and not isinstance(src.data, ExternKernelAlloc)
         )
 
+    @cache_on_self_and_args("ConcatKernel")
     def get_free_symbol_uses(
         self, unbacked_only: bool = False
     ) -> OrderedSet[sympy.Symbol]:
@@ -6430,6 +6525,7 @@ class ExternKernel(InputsKernel):
         index = sympy_subs(sympy.expand(index), replacement)
         return index, tuple(new_sizes)
 
+    @cache_on_self_and_args("ExternKernel")
     def get_free_symbol_uses(
         self, unbacked_only: bool = False
     ) -> OrderedSet[sympy.Symbol]:
@@ -6889,6 +6985,7 @@ class UserDefinedTritonKernel(ExternKernel):
             original_fxnode_name=self.fx_node.name,
         )
 
+    @cache_on_self_and_args("UserDefinedTritonKernel")
     def get_free_symbol_uses(
         self, unbacked_only: bool = False
     ) -> OrderedSet[sympy.Symbol]:
@@ -7327,6 +7424,7 @@ class DynamicSelectStorageOffset(ExternKernel):
     def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
         return OrderedSet([self.unbacked_offset_symbol])
 
+    @cache_on_self_and_args("DynamicSelectStorageOffset")
     def get_free_symbol_uses(
         self, unbacked_only: bool = False
     ) -> OrderedSet[sympy.Symbol]:
@@ -7377,6 +7475,7 @@ class DynamicSliceSize(ExternKernel):
     def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
         return OrderedSet([self.unbacked_size_symbol])
 
+    @cache_on_self_and_args("DynamicSliceSize")
     def get_free_symbol_uses(
         self, unbacked_only: bool = False
     ) -> OrderedSet[sympy.Symbol]:
@@ -7441,6 +7540,7 @@ class AssertScalar(ExternKernel):
     def has_side_effects(self) -> bool:
         return True
 
+    @cache_on_self_and_args("AssertScalar")
     def get_free_symbol_uses(
         self, unbacked_only: bool = False
     ) -> OrderedSet[sympy.Symbol]:
@@ -8115,6 +8215,7 @@ class MultiOutput(ExternKernel):
         self.indices = indices
         self.skip_size_stride_alignment_checks = skip_size_stride_alignment_checks
 
+    @cache_on_self_and_args("MultiOutput")
     def get_free_symbol_uses(
         self, unbacked_only: bool = False
     ) -> OrderedSet[sympy.Symbol]:
@@ -8237,6 +8338,7 @@ class MutableBox(IRNode):
     def realize(self) -> Optional[str]:
         return self.data.realize()
 
+    @cache_on_self_and_args("MutableBox")
     def get_free_symbol_uses(
         self, unbacked_only: bool = False
     ) -> OrderedSet[sympy.Symbol]:
@@ -9073,6 +9175,7 @@ class EffectfulKernel(FallbackKernel):
 
 class NonTensorObj(IRNode):
+    @cache_on_self_and_args("NonTensorObj")
     def get_free_symbol_uses(
         self, unbacked_only: bool = False
     ) -> OrderedSet[sympy.Symbol]:


@@ -662,6 +662,7 @@ def tuple_sorted(x: tuple[_T, ...]) -> list[_T]:
 P = ParamSpec("P")
 RV = TypeVar("RV", covariant=True)
+FN_TYPE = Callable[Concatenate[Any, P], RV]
 
 
 class CachedMethod(Protocol, Generic[P, RV]):
@@ -709,6 +710,52 @@ def cache_property_on_self(fn: Callable[P, RV]) -> CachedMethod[P, RV]:
     return cache_on_self(fn)
 
 
+def cache_on_self_and_args(
+    class_name: str,
+) -> Callable[[FN_TYPE[P, RV]], FN_TYPE[P, RV]]:
+    # include both class_name and fn_name in the key to support `super().fn(self, *args, **kwargs)` calls
+    def wrapper(
+        fn: FN_TYPE[P, RV],
+    ) -> FN_TYPE[P, RV]:
+        key = f"__{class_name}_{fn.__name__}_cache"
+
+        # wrapper is likely on the hot path, compile a specialized version of it
+        ctx = {"fn": fn}
+        exec(
+            f"""\
+def inner(self: Any, *args: P.args, **kwargs: P.kwargs) -> RV:
+    args_kwargs = (args, tuple(sorted(kwargs.items())))
+    if not hasattr(self, "{key}"):
+        object.__setattr__(self, "{key}", {{}})
+    cache = self.{key}
+    try:
+        return cache[args_kwargs]
+    except KeyError:
+        pass
+    rv = fn(self, *args, **kwargs)
+    cache[args_kwargs] = rv
+    return rv
+""".lstrip(),
+            ctx,
+        )
+        inner = functools.wraps(fn)(ctx["inner"])
+
+        def clear_cache(self: Any) -> None:
+            if hasattr(self, key):
+                delattr(self, key)
+
+        inner.clear_cache = clear_cache  # type: ignore[attr-defined]
+        return inner
+
+    return wrapper
+
+
 def aggregate_origins(
     node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel],
 ) -> OrderedSet[Node]: