Fix: ShapeEnv not propagated properly to inductor SizeVars (#162927)

Summary:
I am really skeptical about inductor SizeVars creating an empty ShapeEnv when one is not provided;
I think we should fail there if the graph has dynamic shapes and no ShapeEnv is provided.

However, I wonder whether there are actually use cases that depend on the ShapeEnv not being there.
Reasoning APIs depend on facts in the ShapeEnv and assume certain information exists for specific symbols.
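
A hedged, illustrative sketch of the concern (the module paths are real, but the scenario is mine, not taken from this PR): a SizeVarAllocator built without a shape_env silently gets a brand-new, empty ShapeEnv that knows nothing about the symbols the graph actually uses, so reasoning queries about those symbols have no facts to work with.

    # Illustrative only: shows why the silent fallback is risky.
    from torch._inductor.sizevars import SizeVarAllocator
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    tracing_env = ShapeEnv()  # in practice, the env populated during tracing,
                              # holding each symbol's value ranges, replacements, etc.

    good = SizeVarAllocator(tracing_env)  # reasoning APIs see the graph's facts
    bad = SizeVarAllocator()              # empty ShapeEnv: queries about the graph's
                                          # symbols have no facts to reason from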

Test Plan:
Fixes the bug reported in https://www.internalfb.com/diff/D82337184.
Creating a simple e2e unit test is not trivial.

Rollback Plan:

Differential Revision: D82412384

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162927
Approved by: https://github.com/ezyang, https://github.com/eellison, https://github.com/jansel
Author: Laith Sakka
Date: 2025-09-18 00:56:18 +00:00
Committed by: PyTorch MergeBot
Parent: 57a54a04b6
Commit: 04ddea44fd
5 changed files with 17 additions and 1 deletion


@@ -1227,7 +1227,9 @@ class _InProcessFxCompile(FxCompile):
             # structured logs...
             # trace_structured("inductor_input_graph", payload_fn=lambda: gm.print_readable(print_output=False))
-            shape_env = shape_env_from_inputs(example_inputs)
+            shape_env = gm.shape_env
+            if shape_env is None:
+                shape_env = shape_env_from_inputs(example_inputs)
             # Convert view to reshape in the graph. This is necessary primarily for
             # layout optimization. Do it unconditionally for uniformity.


@@ -70,6 +70,8 @@ class SizeVarAllocator:
     def __init__(self, shape_env=None) -> None:
         super().__init__()
+        # Note: this can lead to bugs. Reasoning APIs depend on information that
+        # already exists in the shape_env; for example, var_to_ranges can't be empty!
         if shape_env is None:
             shape_env = ShapeEnv()
         self.shape_env = shape_env
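
The Summary floats a stricter alternative: fail instead of fabricating an empty ShapeEnv. A hypothetical sketch of that idea (not part of this diff; the class name and error message are invented):

    class StrictSizeVarAllocator:  # hypothetical; mirrors SizeVarAllocator.__init__
        def __init__(self, shape_env=None) -> None:
            super().__init__()
            if shape_env is None:
                # A real version might only fail when the graph actually has dynamic
                # shapes; failing loudly beats reasoning against an empty environment.
                raise AssertionError(
                    "SizeVarAllocator requires the ShapeEnv used during tracing"
                )
            self.shape_env = shape_env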


@@ -2968,6 +2968,15 @@ def shape_env_from_inputs(inputs: Sequence[InputType]) -> Optional[ShapeEnv]:
         if isinstance(input, torch.SymInt):
             return input.node.shape_env
+        # Check tensor sizes and strides for SymInt values
+        if isinstance(input, torch.Tensor):
+            for size in input.size():
+                if isinstance(size, torch.SymInt):
+                    return size.node.shape_env
+            for stride in input.stride():
+                if isinstance(stride, torch.SymInt):
+                    return stride.node.shape_env
     # TODO(voz): Should we always have one anyway?
     return None


@@ -1756,12 +1756,14 @@ def fx_placeholder_targets(gm: torch.fx.GraphModule) -> list[str]:
 def eval_guards(
     gm: torch.fx.GraphModule, *args: Tensor, ignore_static: bool = True
 ) -> bool:
+    assert gm.shape_env is not None
     return gm.shape_env.evaluate_guards_for_args(  # type: ignore[operator, union-attr]
         fx_placeholder_vals(gm), args, ignore_static=ignore_static
     )


 def bind_symbols(gm: torch.fx.GraphModule, *args: Tensor) -> dict[sympy.Symbol, int]:
+    assert gm.shape_env is not None
     return gm.shape_env.bind_symbols(fx_placeholder_vals(gm), args)  # type: ignore[operator, union-attr]


@@ -546,6 +546,7 @@ class GraphModule(torch.nn.Module):
         self._erase_node_hooks: list[Callable] = []
         # Used to remove hooks from deepcopied graph modules within a context manager.
         self._deepcopy_hooks: list[Callable] = []
+        self.shape_env = None  # Optional; not always set, even when dynamic shapes exist.
         # TorchScript breaks trying to compile the graph setter because of the
         # continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842
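
Taken together, the propagation order after this change looks roughly like the sketch below (the helper name is invented; gm.shape_env and shape_env_from_inputs are the real names from the hunks above):

    # Rough sketch of how the pieces fit together (illustrative only).
    from torch._inductor.utils import shape_env_from_inputs  # module path assumed

    def resolve_shape_env(gm, example_inputs):
        shape_env = getattr(gm, "shape_env", None)  # preferred: the env used during tracing
        if shape_env is None:
            # Fallback: recover the env from SymInts in the example inputs
            # (which, per the utils.py hunk, now also checks tensor sizes and strides).
            shape_env = shape_env_from_inputs(example_inputs)
        return shape_env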