mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
DDE-Free select with unbacked index. (#157605)
When select has data dependent input, we cant tell if the actual index shall be index+size or index. to avoid throwing dde, we allocate a new unbacked symbol to represent the storage offset of the output view and we compute its value dynamically at runtime when inductor is lowered. Pull Request resolved: https://github.com/pytorch/pytorch/pull/157605 Approved by: https://github.com/ColinPeppler
This commit is contained in:
committed by
PyTorch MergeBot
parent
9faef3d17c
commit
0b2ef76e85
@ -69,13 +69,20 @@ OPTIMUS_EXCLUDE_POST_GRAD = [
|
||||
"inductor_autotune_lookup_table",
|
||||
]
|
||||
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
free_symbols,
|
||||
free_unbacked_symbols,
|
||||
IterateExprs,
|
||||
ShapeEnv,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable, Sequence, ValuesView
|
||||
|
||||
from torch import SymBool, SymFloat, SymInt
|
||||
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
|
||||
from torch.fx import GraphModule
|
||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
||||
from torch.fx.node import Node
|
||||
|
||||
from .codegen.common import WorkspaceArg
|
||||
@ -3355,3 +3362,10 @@ def aoti_model_name_from_config() -> str:
|
||||
model_name = config.aot_inductor.model_name_for_generated_files
|
||||
model_name = "aoti_model" if model_name is None else model_name
|
||||
return model_name
|
||||
|
||||
|
||||
def get_free_symbols(x: IterateExprs, unbacked_only: bool) -> OrderedSet[sympy.Symbol]:
|
||||
if unbacked_only:
|
||||
return free_unbacked_symbols(x)
|
||||
else:
|
||||
return free_symbols(x)
|
||||
|
Reference in New Issue
Block a user