[AOTI] Pass in shape_env for get_stride_order (#163925)

Summary:
As titled.
Without the diff, we got P1963055009

With the diff, passing in the shape environment lets us do the correct SymInt deduction:
https://fburl.com/mlhub/p5zy7o28
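
For context: computing a stride order is just ranking the strides from smallest to largest, but when a stride is an unbacked SymInt (say `u0*64`), comparing it against a constant would add a data-dependent guard. Passing the ShapeEnv gives the ranking access to hints for those symbols. The snippet below is only a simplified illustration of that idea under assumed names (`get_stride_order_sketch`, `hint_for`); it is not Inductor's actual `ir.get_stride_order`.

```
# A minimal, self-contained sketch (NOT Inductor's implementation) of why a
# shape environment is needed: strides containing unbacked SymInts cannot be
# ranked by plain comparison, but they can be ranked using the hints a
# ShapeEnv tracks for those symbols.
import sympy


def get_stride_order_sketch(strides, hint_for=None):
    """Rank strides smallest-to-largest; rank 0 is the innermost dimension.

    `hint_for` stands in for what a ShapeEnv provides: a concrete integer
    hint for a symbolic (unbacked) stride expression.
    """
    def key(s):
        if isinstance(s, sympy.Expr) and not s.is_number:
            if hint_for is None:
                # Without hints, comparing u0*16 against 16 would need a
                # data-dependent guard such as Eq(u0, 0).
                raise RuntimeError("cannot order unbacked strides without hints")
            return hint_for(s)
        return int(s)

    sorted_idx = sorted(range(len(strides)), key=lambda i: key(strides[i]))
    order = [0] * len(strides)
    for rank, i in enumerate(sorted_idx):
        order[i] = rank
    return order


u0 = sympy.Symbol("u0", positive=True, integer=True)  # unbacked seq_len
strides = (4 * u0 * 16, u0 * 16, 16, 1)  # same pattern as the test's strides
print(get_stride_order_sketch(strides, hint_for=lambda e: int(e.subs(u0, 8))))
# -> [3, 2, 1, 0]  (last dim has the smallest stride, i.e. contiguous order)
```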

Test Plan:
```
buck2 test 'fbcode//mode/opt' fbcode//caffe2/test/inductor:unbacked_symints -- test_sdfpa_unbacked_strides --print-passing-details --env TORCHDYNAMO_EXTENDED_DEBUG_CPP=1 --env TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(u0, 0)"
```
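
The TORCHDYNAMO_EXTENDED_DEBUG_* variables make Dynamo emit extended debug output (including C++ backtraces) whenever the guard Eq(u0, 0) is added, so the failing run shows exactly where that data-dependent guard comes from.
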
Without the fix: P1964887260
With the fix: P1964888579

Differential Revision: D83211018

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163925
Approved by: https://github.com/ColinPeppler

Author: Zhijing Li
Date: 2025-09-26 21:10:03 +00:00
Committed by: PyTorch MergeBot
Parent: a60c6ed99f
Commit: 28c7d11428

2 changed files with 31 additions and 2 deletions


```
@@ -531,6 +531,35 @@ class TestUnbackedSymints(InductorTestCase):
         x = torch.tensor([1.0, 0.0, 1.0, 0.0], device=device)
         torch.compile(fn, fullgraph=True)(x)
 
+    @skipGPUIf(not HAS_GPU, "requires gpu and triton")
+    @skipIfXpu(msg="scaled_dot_product_attention is not supported on XPU yet")
+    @dynamo_config.patch({"capture_dynamic_output_shape_ops": True})
+    def test_sdfpa_unbacked_strides(self, device):
+        if device == "cpu":
+            raise unittest.SkipTest("scaled_dot_product_attention has no CPU backend")
+
+        def fn(x, y):
+            B, H, d_h = 2, 4, 16
+            nz = torch.nonzero(x)
+            seq_len = nz.size(0)
+            y = torch.nonzero(y).size(0)
+            strides = (H * seq_len * d_h, seq_len * d_h, d_h, y)
+            q = torch.randn(B, H, seq_len, d_h, device=device, dtype=torch.float16)
+            k = torch.randn(B, H, seq_len, d_h, device=device, dtype=torch.float16)
+            v = torch.randn(B, H, seq_len, d_h, device=device, dtype=torch.float16)
+            q = torch.as_strided(q, size=(B, H, seq_len, d_h), stride=strides)
+            k = torch.as_strided(k, size=(B, H, seq_len, d_h), stride=strides)
+            v = torch.as_strided(v, size=(B, H, seq_len, d_h), stride=strides)
+            result = torch.ops.aten._scaled_dot_product_flash_attention.default(
+                q, k, v, dropout_p=0.0, is_causal=False, scale=None
+            )
+            return result
+
+        x = torch.tensor([1.0, 0.0] * 8, device=device)
+        y = torch.tensor([1.0, 0.0], device=device)
+        torch.compile(fn, fullgraph=True)(x, y)
+
     @skipGPUIf(not HAS_GPU, "torch.compile for gpu requires triton")
     @dynamo_config.patch({"capture_dynamic_output_shape_ops": True})
     def test_unbacked_linear_layer_norm_input(self, device):
```
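
The new test makes both the sequence length and the last stride unbacked (each comes from `torch.nonzero(...).size(0)`), builds q/k/v with `as_strided` on those symbolic strides, and runs them through the flash-attention op, which exercises the lowering path touched by the fix below.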


```
@@ -2614,8 +2614,8 @@ def sdpa_constraint(fx_node, *args, **kwargs):
     meta_stride_expr = [
         s.node.expr if isinstance(s, torch.SymInt) else s for s in meta_val.stride()
     ]
-    stride_order = ir.get_stride_order(meta_val.stride())
+    shape_env = V.graph.sizevars.shape_env
+    stride_order = ir.get_stride_order(meta_val.stride(), shape_env)
     if stride_order and stride_order[-1] != 0:
         # contiguous stride order
```
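
The functional change is just these two lines: the graph's ShapeEnv (reached through `V.graph.sizevars.shape_env`) is now threaded into `ir.get_stride_order`, so the stride order of tensors whose strides contain unbacked SymInts can be deduced from the shape environment instead of through raw comparisons that add a data-dependent guard like `Eq(u0, 0)`.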