DDE-Free select with unbacked index. (#157605)

When select has data dependent input, we cant tell if the actual index shall be index+size or index.
to avoid throwing dde, we allocate a new unbacked symbol to represent the storage offset of the
output view and we compute its value dynamically at runtime when inductor is lowered.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157605
Approved by: https://github.com/ColinPeppler
This commit is contained in:
Laith Sakka
2025-07-10 20:28:30 -07:00
committed by PyTorch MergeBot
parent 9faef3d17c
commit 0b2ef76e85
15 changed files with 349 additions and 53 deletions

View File

@ -69,13 +69,20 @@ OPTIMUS_EXCLUDE_POST_GRAD = [
"inductor_autotune_lookup_table",
]
from torch.fx.experimental.symbolic_shapes import (
free_symbols,
free_unbacked_symbols,
IterateExprs,
ShapeEnv,
)
if TYPE_CHECKING:
from collections.abc import Iterable, Sequence, ValuesView
from torch import SymBool, SymFloat, SymInt
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
from torch.fx import GraphModule
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.fx.node import Node
from .codegen.common import WorkspaceArg
@ -3355,3 +3362,10 @@ def aoti_model_name_from_config() -> str:
model_name = config.aot_inductor.model_name_for_generated_files
model_name = "aoti_model" if model_name is None else model_name
return model_name
def get_free_symbols(x: IterateExprs, unbacked_only: bool) -> OrderedSet[sympy.Symbol]:
if unbacked_only:
return free_unbacked_symbols(x)
else:
return free_symbols(x)