Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[AOTI] Pass in shape_env for get_stride_order (#163925)
Summary: As titled. Without this diff we got P1963055009. With the diff, passing in the shape environment lets us do correct sym_int deduction: https://fburl.com/mlhub/p5zy7o28

Test Plan:
```
buck2 test 'fbcode//mode/opt' fbcode//caffe2/test/inductor:unbacked_symints -- test_sdfpa_unbacked_strides --print-passing-details --env TORCHDYNAMO_EXTENDED_DEBUG_CPP=1 --env TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(u0, 0)"
```

Without the fix: P1964887260
With the fix: P1964888579

Differential Revision: D83211018

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163925
Approved by: https://github.com/ColinPeppler
Committed by: PyTorch MergeBot
Parent: a60c6ed99f
Commit: 28c7d11428
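For context (my note, not part of the commit message): the guard `Eq(u0, 0)` named in the test plan involves an unbacked SymInt, i.e. a size that is only known at runtime. A minimal sketch of how such a symbol arises, assuming a recent PyTorch build with `torch._dynamo.config.capture_dynamic_output_shape_ops` available (it is the same flag the new test patches):

```python
# Minimal sketch (not from this PR): torch.nonzero has a data-dependent output size,
# so under torch.compile its size(0) becomes an unbacked SymInt (u0, u1, ...) that
# Inductor must reason about symbolically rather than guard on.
import torch

torch._dynamo.config.capture_dynamic_output_shape_ops = True

@torch.compile(fullgraph=True)
def f(x):
    nz = torch.nonzero(x)           # number of rows is unknown at trace time
    return torch.zeros(nz.size(0))  # tensor whose size is an unbacked SymInt

print(f(torch.tensor([1.0, 0.0, 1.0])).shape)  # torch.Size([2])
```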
```diff
@@ -531,6 +531,35 @@ class TestUnbackedSymints(InductorTestCase):
         x = torch.tensor([1.0, 0.0, 1.0, 0.0], device=device)
         torch.compile(fn, fullgraph=True)(x)
 
+    @skipGPUIf(not HAS_GPU, "requires gpu and triton")
+    @skipIfXpu(msg="scaled_dot_product_attention is not supported on XPU yet")
+    @dynamo_config.patch({"capture_dynamic_output_shape_ops": True})
+    def test_sdfpa_unbacked_strides(self, device):
+        if device == "cpu":
+            raise unittest.SkipTest("scaled_dot_product_attention has no CPU backend")
+
+        def fn(x, y):
+            B, H, d_h = 2, 4, 16
+            nz = torch.nonzero(x)
+            seq_len = nz.size(0)
+            y = torch.nonzero(y).size(0)
+            strides = (H * seq_len * d_h, seq_len * d_h, d_h, y)
+
+            q = torch.randn(B, H, seq_len, d_h, device=device, dtype=torch.float16)
+            k = torch.randn(B, H, seq_len, d_h, device=device, dtype=torch.float16)
+            v = torch.randn(B, H, seq_len, d_h, device=device, dtype=torch.float16)
+            q = torch.as_strided(q, size=(B, H, seq_len, d_h), stride=strides)
+            k = torch.as_strided(k, size=(B, H, seq_len, d_h), stride=strides)
+            v = torch.as_strided(v, size=(B, H, seq_len, d_h), stride=strides)
+            result = torch.ops.aten._scaled_dot_product_flash_attention.default(
+                q, k, v, dropout_p=0.0, is_causal=False, scale=None
+            )
+            return result
+
+        x = torch.tensor([1.0, 0.0] * 8, device=device)
+        y = torch.tensor([1.0, 0.0], device=device)
+        torch.compile(fn, fullgraph=True)(x, y)
+
     @skipGPUIf(not HAS_GPU, "torch.compile for gpu requires triton")
     @dynamo_config.patch({"capture_dynamic_output_shape_ops": True})
     def test_unbacked_linear_layer_norm_input(self, device):
```
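A brief aside on what the new test stresses (my illustration, not part of the diff): the last entry of `strides` is `y`, itself an unbacked value, so the SDPA constraint has to compute a stride order over symbolic strides. For a plain contiguous tensor the notion of stride order looks like this, under the assumed convention that each dimension is ranked by its stride with the smallest stride getting rank 0:

```python
# Illustration only (assumed convention, not Inductor's code): rank each dimension by its
# stride; for a contiguous (B, H, S, D) tensor the last dim has the smallest stride, so it
# ranks 0. With an unbacked stride like `y`, this ranking cannot be computed from concrete
# integers alone.
import torch

t = torch.empty(2, 4, 8, 16)
print(t.stride())                                    # (512, 128, 16, 1)

dims_by_stride = sorted(range(t.dim()), key=lambda i: t.stride(i))
rank = [dims_by_stride.index(i) for i in range(t.dim())]
print(rank)                                          # [3, 2, 1, 0] -> last dim ranks 0
```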
```diff
@@ -2614,8 +2614,8 @@ def sdpa_constraint(fx_node, *args, **kwargs):
         meta_stride_expr = [
             s.node.expr if isinstance(s, torch.SymInt) else s for s in meta_val.stride()
         ]
-
-        stride_order = ir.get_stride_order(meta_val.stride())
+        shape_env = V.graph.sizevars.shape_env
+        stride_order = ir.get_stride_order(meta_val.stride(), shape_env)
 
         if stride_order and stride_order[-1] != 0:
             # contiguous stride order
```
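Why the `shape_env` argument helps (my reading, hedged): when strides contain unbacked symbols, the comparisons behind `get_stride_order` cannot be settled from the expressions alone, and falling back to a guard such as `Eq(u0, 0)` is exactly the failure the test plan flags. Facts recorded by the shape environment can let such comparisons be decided symbolically. A hypothetical sympy-only illustration of the idea (not Inductor's implementation):

```python
# Hypothetical illustration (not Inductor's code): without extra facts, equality against 0
# for an unbacked symbol stays undecided; with a recorded fact (here modeled as a sympy
# assumption standing in for what a ShapeEnv knows), it resolves without a runtime guard.
import sympy

u0 = sympy.Symbol("u0", integer=True)
print(sympy.Eq(u0, 0))        # Eq(u0, 0) -- cannot be decided, would need a guard

u0_known = sympy.Symbol("u0", integer=True, positive=True)  # "fact": u0 > 0
print(sympy.Eq(u0_known, 0))  # False -- decided purely symbolically
```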