Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 13:44:15 +08:00
The issue cannot be reproduced using the original repro code provided in the issue description. However, the underlying issue mentioned by the maintainer (missing functions in `builder.py` and `trace_rules.py`) was never addressed and can still be reproduced with this test case:

```python
import torch
from torch.nn.attention import _cur_sdpa_kernel_backends


@torch.compile(fullgraph=True)
def test_function_that_triggers_error():
    return _cur_sdpa_kernel_backends()


print("Calling torch.compile function...")
try:
    result = test_function_that_triggers_error()
    print(f"Success: {result}")
except Exception as e:
    print(f"ERROR: {e}")
    print(f"Error type: {type(e)}")
```

The original repro likely no longer triggers the issue due to code-path changes in the SDPA implementation, while the direct call to `_cur_sdpa_kernel_backends()` exposes the underlying problem: certain `torch._C` functions that return non-Tensor values are not handled properly by Dynamo tracing. I have implemented the changes by adding the missing functions to both `builder.py` and `trace_rules.py` so these cases are handled properly during compilation.

@guilhermeleobas

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161169
Approved by: https://github.com/guilhermeleobas, https://github.com/StrongerXi
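For a quick standalone sanity check of the behavior the new test exercises, the sketch below compiles the same query with `fullgraph=True` and asserts it succeeds instead of printing the error path. The function name `current_backends` and the assertion that every entry is an `SDPBackend` member are assumptions made for illustration, not taken from the PR:

```python
import torch
from torch.nn.attention import SDPBackend, _cur_sdpa_kernel_backends


@torch.compile(fullgraph=True)
def current_backends():
    # With the missing entries added to builder.py / trace_rules.py, this
    # torch._C-backed query should compile without a graph break even though
    # it returns a plain Python list rather than a Tensor.
    return _cur_sdpa_kernel_backends()


result = current_backends()
assert isinstance(result, list)
# Assumption: entries are SDPBackend enum members (e.g. SDPBackend.MATH).
assert all(isinstance(b, SDPBackend) for b in result)
print(result)
```

On a build without the fix, the `fullgraph=True` compile is expected to fail with an unsupported-function error rather than return the list.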
145 lines · 4.7 KiB · Python
# Owner(s): ["module: dynamo"]
import contextlib

import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import CompileCounter
from torch.backends.cuda import SDPAParams
from torch.nn.attention import _cur_sdpa_kernel_backends, sdpa_kernel, SDPBackend


@contextlib.contextmanager
def allow_in_graph_sdpa_params():
    global SDPAParams
    try:
        old = SDPAParams
        SDPAParams = torch._dynamo.allow_in_graph(SDPAParams)
        yield
    finally:
        SDPAParams = old


class TestSDPA(torch._dynamo.test_case.TestCase):
    def assert_ref_equals_params(self, actual, expected):
        self.assertIs(actual.query, expected.query)
        self.assertIs(actual.key, expected.key)
        self.assertIs(actual.value, expected.value)
        self.assertIs(actual.attn_mask, expected.attn_mask)

    def test_returns_SDPAParams(self):
        with allow_in_graph_sdpa_params():
            counter = CompileCounter()

            @torch.compile(fullgraph=True, backend=counter)
            def fn(q, k, v, m):
                return SDPAParams(q, k, v, m, 0.1, True, False)

            q = torch.randn(10)
            k = torch.randn(10)
            v = torch.randn(10)
            m = torch.randn(10)
            o = fn(q, k, v, m)
            self.assertTrue(isinstance(o, SDPAParams))
            self.assert_ref_equals_params(o, SDPAParams(q, k, v, m, 0.1, True, False))
            self.assertEqual(counter.frame_count, 1)

    def test_graph_break_SDPAParams(self):
        with allow_in_graph_sdpa_params():
            counter = CompileCounter()

            @torch.compile(backend=counter)
            def fn(q, k, v, m):
                z = SDPAParams(q, k, v, m, 0.1, True, False)
                torch._dynamo.graph_break()
                return z, q + 1

            q = torch.randn(10)
            k = torch.randn(10)
            v = torch.randn(10)
            m = torch.randn(10)
            o, _ = fn(q, k, v, m)
            self.assertTrue(isinstance(o, SDPAParams))
            self.assert_ref_equals_params(o, SDPAParams(q, k, v, m, 0.1, True, False))
            self.assertEqual(counter.frame_count, 2)

    def test_input_SDPAParams(self):
        with allow_in_graph_sdpa_params():
            counter = CompileCounter()

            @torch.compile(backend=counter)
            def fn(sdpap, q):
                torch._dynamo.graph_break()
                return sdpap, sdpap.query + q

            q = torch.randn(10)
            k = torch.randn(10)
            v = torch.randn(10)
            m = torch.randn(10)
            s = SDPAParams(q, k, v, m, 0.1, True, False)
            o, _ = fn(s, q)
            self.assertIs(o, s)
            self.assertEqual(counter.frame_count, 1)

    def test_intermediate_attr_access_SDPAParams(self):
        with allow_in_graph_sdpa_params():
            counter = CompileCounter()

            @torch.compile(fullgraph=True, backend=counter)
            def fn(q, k, v, m):
                q += 1
                z = SDPAParams(q, k, v, m, 0.1, True, False)
                a = z.query
                return a + 1, z, q

            q = torch.randn(10)
            k = torch.randn(10)
            v = torch.randn(10)
            m = torch.randn(10)
            _, o, _ = fn(q, k, v, m)
            expected = SDPAParams(q, k, v, m, 0.1, True, False)
            self.assert_ref_equals_params(o, expected)
            self.assertEqual(counter.frame_count, 1)

    def test_sdpa_c_functions_no_graph_break(self):
        counter = CompileCounter()

        @torch.compile(fullgraph=True, backend=counter)
        def test_cur_sdpa_kernel_backends():
            return _cur_sdpa_kernel_backends()

        result = test_cur_sdpa_kernel_backends()

        self.assertIsInstance(result, list)
        self.assertEqual(counter.frame_count, 1)

    def test_sdpa_kernel_decorator_with_compile(self):
        SDPA_BACKEND_PRIORITY = [
            SDPBackend.MATH,
            SDPBackend.EFFICIENT_ATTENTION,
            SDPBackend.FLASH_ATTENTION,
        ]

        @sdpa_kernel(backends=SDPA_BACKEND_PRIORITY, set_priority=True)
        def scaled_dot_product_attention(q, k, v, *args, **kwargs):
            return torch.nn.functional.scaled_dot_product_attention(
                q, k, v, *args, **kwargs
            )

        counter = CompileCounter()

        @torch.compile(fullgraph=True, backend=counter)
        def f(x):
            return scaled_dot_product_attention(x, x, x)

        x = torch.rand(128, 64, 64, 256, dtype=torch.float16)
        result = f(x)

        self.assertEqual(result.shape, x.shape)
        self.assertEqual(counter.frame_count, 1)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()