mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-01 22:14:53 +08:00
[GraphPartition] cache get_free_symbol_uses (#166338)
Graph partition relies on `get_free_symbol_uses()` to collect symbol inputs.ee7434be82/torch/_inductor/scheduler.py (L4869-L4885)I empirically observed that `get_free_symbol_uses()` becomes slower for larger graphs. Specifically, I tried to aten fallback for torchtitan which results in 10k+ aten nodes. When processing the 600-th node, it takes seconds to `get_free_symbol_uses()` for 1 node. Why? Because `get_free_symbol_uses()` may recursively call another `get_free_symbol_uses()`, which could recursively run many times.ee7434be82/torch/_inductor/ir.py (L4541-L4543)This PR fixes the issue by caching the results of `get_free_symbol_uses()`. I validated on torchtitan that the issue is fixed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166338 Approved by: https://github.com/eellison
This commit is contained in:
committed by
PyTorch MergeBot
parent
12577064dd
commit
a6b1ef1717
@ -11005,6 +11005,29 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
|
||||
p = torch.tensor(0.50, device=self.device)
|
||||
get_mask(x, p)
|
||||
|
||||
def test_flexible_layout_immutable_free_symbols(self):
|
||||
import sympy
|
||||
|
||||
x = sympy.Symbol("x")
|
||||
y = sympy.Symbol("y")
|
||||
z = sympy.Symbol("z")
|
||||
|
||||
layout = torch._inductor.ir.FlexibleLayout(
|
||||
self.device, torch.float32, size=(x, y)
|
||||
)
|
||||
|
||||
# pad_strides works since it does not add new symints
|
||||
layout.pad_strides()
|
||||
|
||||
# same symints and different order should work
|
||||
layout.size = (y, x)
|
||||
|
||||
# adding new symints should fail
|
||||
with self.assertRaisesRegex(
|
||||
AssertionError, "Expected free symbols unchanged, but got"
|
||||
):
|
||||
layout.size = (z,)
|
||||
|
||||
def test_sqrt_dynamic_shapes(self):
|
||||
# TIMM convit_base model: https://github.com/pytorch/pytorch/issues/97877.
|
||||
# TODO: support cuda path.
|
||||
|
||||
@ -64,6 +64,7 @@ from torch.fx.experimental.symbolic_shapes import (
|
||||
compute_unbacked_bindings,
|
||||
free_symbols,
|
||||
free_unbacked_symbols,
|
||||
IterateExprs,
|
||||
rebind_unbacked,
|
||||
resolve_unbacked_bindings,
|
||||
ShapeEnv,
|
||||
@ -97,6 +98,7 @@ from .utils import (
|
||||
argsort,
|
||||
argsort_sym,
|
||||
cache_on_self,
|
||||
cache_on_self_and_args,
|
||||
ceildiv,
|
||||
convert_shape_to_inductor,
|
||||
convert_shape_to_symint,
|
||||
@ -933,6 +935,7 @@ class Loops(IRNode):
|
||||
inner_fn: Callable[..., Any]
|
||||
ranges: Sequence[_IntLike]
|
||||
|
||||
@cache_on_self_and_args("Loops")
|
||||
def get_free_symbol_uses(
|
||||
self, unbacked_only: bool = False
|
||||
) -> OrderedSet[sympy.Symbol]:
|
||||
@ -1228,6 +1231,7 @@ class Reduction(Loops):
|
||||
|
||||
__repr__ = __str__
|
||||
|
||||
@cache_on_self_and_args("Reduction")
|
||||
def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
|
||||
return super().get_free_symbol_uses(unbacked_only) | OrderedSet().union(
|
||||
*(get_free_symbols(e, unbacked_only) for e in self.reduction_ranges)
|
||||
@ -2327,6 +2331,7 @@ class Scan(Loops):
|
||||
|
||||
# HACK we mimic reduction
|
||||
|
||||
@cache_on_self_and_args("Scan")
|
||||
def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
|
||||
# TODO: Can combine_fn/reindex close over unbacked symbols? If so, we
|
||||
# need to explicitly represent the closure so we can pull out unbacked
|
||||
@ -2537,6 +2542,7 @@ class Sort(Loops):
|
||||
|
||||
# HACK we mimic reduction
|
||||
|
||||
@cache_on_self_and_args("Sort")
|
||||
def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
|
||||
return (
|
||||
super().get_free_symbol_uses(unbacked_only)
|
||||
@ -2785,6 +2791,7 @@ def is_unaligned(node: IRNode) -> bool:
|
||||
class BaseView(IRNode):
|
||||
data: IRNode
|
||||
|
||||
@cache_on_self_and_args("BaseView")
|
||||
def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
|
||||
return self.data.get_free_symbol_uses(unbacked_only)
|
||||
|
||||
@ -3359,6 +3366,7 @@ class ReinterpretView(BaseView):
|
||||
def freeze_layout(self) -> None:
|
||||
pass
|
||||
|
||||
@cache_on_self_and_args("ReinterpretView")
|
||||
def get_free_symbol_uses(
|
||||
self, unbacked_only: bool = False
|
||||
) -> OrderedSet[sympy.Symbol]:
|
||||
@ -3643,13 +3651,37 @@ class Layout(OutputSpec):
|
||||
self.dtype = dtype
|
||||
assert len(size) == len(stride), f"size={size}, stride={stride}"
|
||||
assert all(isinstance(s, (Expr, int)) for s in size)
|
||||
self.size = size
|
||||
self.stride = stride
|
||||
self.offset = offset
|
||||
self._size = size
|
||||
self._stride = stride
|
||||
self._offset = offset
|
||||
self.is_pinned = is_pinned
|
||||
# is_pinned implies cpu
|
||||
assert (not self.is_pinned) or (self.device.type == "cpu")
|
||||
|
||||
@property
|
||||
def size(self) -> Sequence[Expr]:
|
||||
return self._size
|
||||
|
||||
@size.setter
|
||||
def size(self, value: Sequence[Expr]) -> None:
|
||||
self._size = value
|
||||
|
||||
@property
|
||||
def stride(self) -> Sequence[Expr]:
|
||||
return self._stride
|
||||
|
||||
@stride.setter
|
||||
def stride(self, value: Sequence[Expr]) -> None:
|
||||
self._stride = value
|
||||
|
||||
@property
|
||||
def offset(self) -> Expr:
|
||||
return self._offset
|
||||
|
||||
@offset.setter
|
||||
def offset(self, value: Expr) -> None:
|
||||
self._offset = value
|
||||
|
||||
def __str__(self) -> str:
|
||||
offset = ""
|
||||
if self.offset != 0:
|
||||
@ -3869,6 +3901,7 @@ class Layout(OutputSpec):
|
||||
def storage_size(self) -> Expr:
|
||||
return compute_required_storage_length(self.size, self.stride, self.offset) # type: ignore[arg-type]
|
||||
|
||||
@cache_on_self_and_args("Layout")
|
||||
def get_free_symbol_uses(
|
||||
self, unbacked_only: bool = False
|
||||
) -> OrderedSet[sympy.Symbol]:
|
||||
@ -3888,7 +3921,11 @@ class FixedLayout(Layout):
|
||||
|
||||
|
||||
class FlexibleLayout(Layout):
|
||||
"""A Tensor layout that we are allowed to change"""
|
||||
"""
|
||||
A Tensor layout that we are allowed to change
|
||||
|
||||
Assumption: layout change should NOT add or remove free symbols
|
||||
"""
|
||||
|
||||
allow_indexing = False
|
||||
|
||||
@ -3973,6 +4010,33 @@ class FlexibleLayout(Layout):
|
||||
fill_order = sorted(range(len(stride)), key=stride.__getitem__)
|
||||
return FlexibleLayout.fill_ordered(sizes, fill_order)
|
||||
|
||||
@property
|
||||
def size(self) -> Sequence[Expr]:
|
||||
return self._size
|
||||
|
||||
@size.setter
|
||||
def size(self, value: Sequence[Expr]) -> None:
|
||||
self.assert_free_symbol_uses_unchanged("size", value)
|
||||
self._size = value
|
||||
|
||||
@property
|
||||
def stride(self) -> Sequence[Expr]:
|
||||
return self._stride
|
||||
|
||||
@stride.setter
|
||||
def stride(self, value: Sequence[Expr]) -> None:
|
||||
self.assert_free_symbol_uses_unchanged("stride", value)
|
||||
self._stride = value
|
||||
|
||||
@property
|
||||
def offset(self) -> Expr:
|
||||
return self._offset
|
||||
|
||||
@offset.setter
|
||||
def offset(self, value: Expr) -> None:
|
||||
self.assert_free_symbol_uses_unchanged("offset", value)
|
||||
self._offset = value
|
||||
|
||||
def as_stride_order(
|
||||
self, order: Sequence[int], allow_padding: bool = False
|
||||
) -> FixedLayout:
|
||||
@ -4031,6 +4095,25 @@ class FlexibleLayout(Layout):
|
||||
self.is_pinned,
|
||||
)
|
||||
|
||||
def get_initial_free_symbol_uses(self) -> dict[tuple[str, bool], sympy.Symbol]:
|
||||
initial_free_symbols = {}
|
||||
for name in ["size", "stride", "offset"]:
|
||||
for unbacked_only in [True, False]:
|
||||
key = (name, unbacked_only)
|
||||
initial_free_symbols[key] = OrderedSet(
|
||||
get_free_symbols(getattr(self, name), unbacked_only)
|
||||
)
|
||||
|
||||
return initial_free_symbols
|
||||
|
||||
def assert_free_symbol_uses_unchanged(self, name: str, value: IterateExprs) -> None:
|
||||
for unbacked_only in [True, False]:
|
||||
old_free_symbols = self.initial_free_symbols[(name, unbacked_only)]
|
||||
new_free_symbols = OrderedSet(get_free_symbols(value, unbacked_only))
|
||||
assert new_free_symbols == old_free_symbols, (
|
||||
f"Expected free symbols unchanged, but got {new_free_symbols} vs {old_free_symbols}"
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device: torch.device,
|
||||
@ -4045,6 +4128,10 @@ class FlexibleLayout(Layout):
|
||||
strides = FlexibleLayout.contiguous_strides(size)
|
||||
super().__init__(device, dtype, size, strides, is_pinned=is_pinned)
|
||||
|
||||
# record the initial free symbols to check that we do not add new free symbols
|
||||
# later when modifying sizes, strides, and offsets.
|
||||
self.initial_free_symbols = self.get_initial_free_symbol_uses()
|
||||
|
||||
|
||||
class NonOwningLayout(Layout):
|
||||
"""Is a view into the storage of another tensor"""
|
||||
@ -4070,6 +4157,7 @@ class NonOwningLayout(Layout):
|
||||
|
||||
return V.graph.sizevars.statically_known_multiple_of(offset, ALIGNMENT)
|
||||
|
||||
@cache_on_self_and_args("NonOwningLayout")
|
||||
def get_free_symbol_uses(
|
||||
self, unbacked_only: bool = False
|
||||
) -> OrderedSet[sympy.Symbol]:
|
||||
@ -4358,6 +4446,7 @@ class Buffer(IRNode, CodegenSymbol):
|
||||
def get_read_names(self) -> OrderedSet[str]:
|
||||
return OrderedSet([self.get_name()])
|
||||
|
||||
@cache_on_self_and_args("Buffer")
|
||||
def get_free_symbol_uses(
|
||||
self, unbacked_only: bool = False
|
||||
) -> OrderedSet[sympy.Symbol]:
|
||||
@ -4430,6 +4519,7 @@ class NoneAsConstantBuffer(IRNode):
|
||||
def get_reads(self) -> OrderedSet[Dep]:
|
||||
return OrderedSet()
|
||||
|
||||
@cache_on_self_and_args("NoneAsConstantBuffer")
|
||||
def get_free_symbol_uses(
|
||||
self, unbacked_only: bool = False
|
||||
) -> OrderedSet[sympy.Symbol]:
|
||||
@ -4449,6 +4539,7 @@ class NoneAsConstantBuffer(IRNode):
|
||||
class ShapeAsConstantBuffer(IRNode):
|
||||
expr: Expr
|
||||
|
||||
@cache_on_self_and_args("ShapeAsConstantBuffer")
|
||||
def get_free_symbol_uses(
|
||||
self, unbacked_only: bool = False
|
||||
) -> OrderedSet[sympy.Symbol]:
|
||||
@ -4521,6 +4612,7 @@ class ComputedBuffer(OperationBuffer):
|
||||
self.data.get_size(),
|
||||
)
|
||||
|
||||
@cache_on_self_and_args("ComputedBuffer")
|
||||
def get_free_symbol_uses(
|
||||
self, unbacked_only: bool = False
|
||||
) -> OrderedSet[sympy.Symbol]:
|
||||
@ -4974,6 +5066,7 @@ class TritonTemplateBuffer(TemplateBuffer):
|
||||
self.subgraph_inps: Optional[list[Optional[Union[IRNode, sympy.Expr]]]] = None
|
||||
self.subgraph_outs: Optional[list[Optional[IRNode]]] = None
|
||||
|
||||
@cache_on_self_and_args("TritonTemplateBuffer")
|
||||
def get_free_symbol_uses(
|
||||
self, unbacked_only: bool = False
|
||||
) -> OrderedSet[sympy.Symbol]:
|
||||
@ -5340,6 +5433,7 @@ class InputsKernel(OperationBuffer):
|
||||
def num_reads(self) -> int:
|
||||
return 1
|
||||
|
||||
@cache_on_self_and_args("InputsKernel")
|
||||
def get_free_symbol_uses(
|
||||
self, unbacked_only: bool = False
|
||||
) -> OrderedSet[sympy.Symbol]:
|
||||
@ -5514,6 +5608,7 @@ class ConcatKernel(NopKernel):
|
||||
and not isinstance(src.data, ExternKernelAlloc)
|
||||
)
|
||||
|
||||
@cache_on_self_and_args("ConcatKernel")
|
||||
def get_free_symbol_uses(
|
||||
self, unbacked_only: bool = False
|
||||
) -> OrderedSet[sympy.Symbol]:
|
||||
@ -6430,6 +6525,7 @@ class ExternKernel(InputsKernel):
|
||||
index = sympy_subs(sympy.expand(index), replacement)
|
||||
return index, tuple(new_sizes)
|
||||
|
||||
@cache_on_self_and_args("ExternKernel")
|
||||
def get_free_symbol_uses(
|
||||
self, unbacked_only: bool = False
|
||||
) -> OrderedSet[sympy.Symbol]:
|
||||
@ -6889,6 +6985,7 @@ class UserDefinedTritonKernel(ExternKernel):
|
||||
original_fxnode_name=self.fx_node.name,
|
||||
)
|
||||
|
||||
@cache_on_self_and_args("UserDefinedTritonKernel")
|
||||
def get_free_symbol_uses(
|
||||
self, unbacked_only: bool = False
|
||||
) -> OrderedSet[sympy.Symbol]:
|
||||
@ -7327,6 +7424,7 @@ class DynamicSelectStorageOffset(ExternKernel):
|
||||
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
|
||||
return OrderedSet([self.unbacked_offset_symbol])
|
||||
|
||||
@cache_on_self_and_args("DynamicSelectStorageOffset")
|
||||
def get_free_symbol_uses(
|
||||
self, unbacked_only: bool = False
|
||||
) -> OrderedSet[sympy.Symbol]:
|
||||
@ -7377,6 +7475,7 @@ class DynamicSliceSize(ExternKernel):
|
||||
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
|
||||
return OrderedSet([self.unbacked_size_symbol])
|
||||
|
||||
@cache_on_self_and_args("DynamicSliceSize")
|
||||
def get_free_symbol_uses(
|
||||
self, unbacked_only: bool = False
|
||||
) -> OrderedSet[sympy.Symbol]:
|
||||
@ -7441,6 +7540,7 @@ class AssertScalar(ExternKernel):
|
||||
def has_side_effects(self) -> bool:
|
||||
return True
|
||||
|
||||
@cache_on_self_and_args("AssertScalar")
|
||||
def get_free_symbol_uses(
|
||||
self, unbacked_only: bool = False
|
||||
) -> OrderedSet[sympy.Symbol]:
|
||||
@ -8115,6 +8215,7 @@ class MultiOutput(ExternKernel):
|
||||
self.indices = indices
|
||||
self.skip_size_stride_alignment_checks = skip_size_stride_alignment_checks
|
||||
|
||||
@cache_on_self_and_args("MultiOutput")
|
||||
def get_free_symbol_uses(
|
||||
self, unbacked_only: bool = False
|
||||
) -> OrderedSet[sympy.Symbol]:
|
||||
@ -8237,6 +8338,7 @@ class MutableBox(IRNode):
|
||||
def realize(self) -> Optional[str]:
|
||||
return self.data.realize()
|
||||
|
||||
@cache_on_self_and_args("MutableBox")
|
||||
def get_free_symbol_uses(
|
||||
self, unbacked_only: bool = False
|
||||
) -> OrderedSet[sympy.Symbol]:
|
||||
@ -9073,6 +9175,7 @@ class EffectfulKernel(FallbackKernel):
|
||||
|
||||
|
||||
class NonTensorObj(IRNode):
|
||||
@cache_on_self_and_args("NonTensorObj")
|
||||
def get_free_symbol_uses(
|
||||
self, unbacked_only: bool = False
|
||||
) -> OrderedSet[sympy.Symbol]:
|
||||
|
||||
@ -662,6 +662,7 @@ def tuple_sorted(x: tuple[_T, ...]) -> list[_T]:
|
||||
|
||||
P = ParamSpec("P")
|
||||
RV = TypeVar("RV", covariant=True)
|
||||
FN_TYPE = Callable[Concatenate[Any, P], RV]
|
||||
|
||||
|
||||
class CachedMethod(Protocol, Generic[P, RV]):
|
||||
@ -709,6 +710,52 @@ def cache_property_on_self(fn: Callable[P, RV]) -> CachedMethod[P, RV]:
|
||||
return cache_on_self(fn)
|
||||
|
||||
|
||||
def cache_on_self_and_args(
|
||||
class_name: str,
|
||||
) -> Callable[[FN_TYPE[P, RV]], FN_TYPE[P, RV]]:
|
||||
# include both class_name and fn_name in the key to support `super().fn(self, **args, **kwargs)` calls.
|
||||
|
||||
def wrapper(
|
||||
fn: FN_TYPE[P, RV],
|
||||
) -> FN_TYPE[P, RV]:
|
||||
key = f"__{class_name}_{fn.__name__}_cache"
|
||||
|
||||
# wrapper is likely on the hot path, compile a specialized version of it
|
||||
ctx = {"fn": fn}
|
||||
exec(
|
||||
f"""\
|
||||
def inner(self: Any, *args: P.args, **kwargs: P.kwargs) -> RV:
|
||||
args_kwargs = (args, tuple(sorted(kwargs.items())))
|
||||
|
||||
if not hasattr(self, "{key}"):
|
||||
object.__setattr__(self, "{key}", {{}})
|
||||
|
||||
cache = self.{key}
|
||||
|
||||
try:
|
||||
return cache[args_kwargs]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
rv = fn(self, *args, **kwargs)
|
||||
|
||||
cache[args_kwargs] = rv
|
||||
return rv
|
||||
""".lstrip(),
|
||||
ctx,
|
||||
)
|
||||
inner = functools.wraps(fn)(ctx["inner"])
|
||||
|
||||
def clear_cache(self: Any) -> None:
|
||||
if hasattr(self, key):
|
||||
delattr(self, key)
|
||||
|
||||
inner.clear_cache = clear_cache # type: ignore[attr-defined]
|
||||
return inner
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def aggregate_origins(
|
||||
node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel],
|
||||
) -> OrderedSet[Node]:
|
||||
|
||||
Reference in New Issue
Block a user