SDP Backend function fix (#161169)
The issue cannot be reproduced using the original repro code provided in the issue description. However, the underlying issue mentioned by the maintainer (missing functions in `builder.py` and `trace_rules.py`) was never addressed and can still be reproduced with this test case:

```python
import torch
from torch.nn.attention import _cur_sdpa_kernel_backends


@torch.compile(fullgraph=True)
def test_function_that_triggers_error():
    return _cur_sdpa_kernel_backends()


print("Calling torch.compile function...")
try:
    result = test_function_that_triggers_error()
    print(f"Success: {result}")
except Exception as e:
    print(f"ERROR: {e}")
    print(f"Error type: {type(e)}")
```

The original repro likely no longer triggers the issue due to code-path changes in the SDPA implementation, while a direct call to `_cur_sdpa_kernel_backends()` still exposes the underlying problem: certain `torch._C` functions that return non-Tensor values are not handled properly by dynamo tracing. I have implemented the fix by adding the missing functions to both `builder.py` and `trace_rules.py` so that these cases are handled correctly during compilation.

@guilhermeleobas

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161169
Approved by: https://github.com/guilhermeleobas, https://github.com/StrongerXi
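For context, `_cur_sdpa_kernel_backends()` consults the per-backend `torch._C` "enabled" getters (the same bindings touched by this PR) and returns the currently enabled backends as a plain Python list of `SDPBackend` members, which is why dynamo has to know how to trace C bindings that return non-Tensor values. A rough sketch of that behavior (an approximation for illustration, not the actual PyTorch source):

```python
import torch
from torch.nn.attention import SDPBackend


def sketch_cur_sdpa_kernel_backends():
    # Map each backend to its torch._C "enabled" flag; these getters return
    # plain Python bools rather than Tensors.
    flags = {
        SDPBackend.CUDNN_ATTENTION: torch._C._get_cudnn_sdp_enabled(),
        SDPBackend.FLASH_ATTENTION: torch._C._get_flash_sdp_enabled(),
        SDPBackend.EFFICIENT_ATTENTION: torch._C._get_mem_efficient_sdp_enabled(),
        SDPBackend.MATH: torch._C._get_math_sdp_enabled(),
        SDPBackend.OVERRIDEABLE: torch._C._get_overrideable_sdp_enabled(),
    }
    return [backend for backend, enabled in flags.items() if enabled]


print(sketch_cur_sdpa_kernel_backends())
```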
Committed by: PyTorch MergeBot
Parent: 7130b174e0
Commit: ba3c2c80ab
@@ -5,6 +5,7 @@ import torch._dynamo.test_case

import torch._dynamo.testing
from torch._dynamo.testing import CompileCounter
from torch.backends.cuda import SDPAParams
from torch.nn.attention import _cur_sdpa_kernel_backends, sdpa_kernel, SDPBackend


@contextlib.contextmanager
@@ -99,6 +100,43 @@ class TestSDPA(torch._dynamo.test_case.TestCase):

        self.assert_ref_equals_params(o, expected)
        self.assertEqual(counter.frame_count, 1)

    def test_sdpa_c_functions_no_graph_break(self):
        counter = CompileCounter()

        @torch.compile(fullgraph=True, backend=counter)
        def test_cur_sdpa_kernel_backends():
            return _cur_sdpa_kernel_backends()

        result = test_cur_sdpa_kernel_backends()

        self.assertIsInstance(result, list)
        self.assertEqual(counter.frame_count, 1)

    def test_sdpa_kernel_decorator_with_compile(self):
        SDPA_BACKEND_PRIORITY = [
            SDPBackend.MATH,
            SDPBackend.EFFICIENT_ATTENTION,
            SDPBackend.FLASH_ATTENTION,
        ]

        @sdpa_kernel(backends=SDPA_BACKEND_PRIORITY, set_priority=True)
        def scaled_dot_product_attention(q, k, v, *args, **kwargs):
            return torch.nn.functional.scaled_dot_product_attention(
                q, k, v, *args, **kwargs
            )

        counter = CompileCounter()

        @torch.compile(fullgraph=True, backend=counter)
        def f(x):
            return scaled_dot_product_attention(x, x, x)

        x = torch.rand(128, 64, 64, 256, dtype=torch.float16)
        result = f(x)

        self.assertEqual(result.shape, x.shape)
        self.assertEqual(counter.frame_count, 1)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests
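The `backend=counter` argument in these tests works because a `torch.compile` backend can be any callable that receives the captured FX graph and its example inputs; `CompileCounter` uses that hook to count how many graphs dynamo produced, so `frame_count == 1` together with `fullgraph=True` asserts there was no graph break. A minimal counting backend in the same spirit (an illustrative sketch, not PyTorch's `CompileCounter` implementation):

```python
import torch


class CountingBackend:
    """Counts how many graphs dynamo hands to the backend."""

    def __init__(self):
        self.frame_count = 0

    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
        self.frame_count += 1
        return gm.forward  # run the captured graph as-is


counter = CountingBackend()


@torch.compile(fullgraph=True, backend=counter)
def f(x):
    return x.sin() + 1


f(torch.randn(4))
print(counter.frame_count)  # 1 -> compiled as a single graph, no breaks
```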
@@ -684,6 +684,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys(

        "torch._C._get_mem_efficient_sdp_enabled",
        "torch._C._get_mkldnn_enabled",
        "torch._C._get_cudnn_sdp_enabled",
        "torch._C._get_overrideable_sdp_enabled",
        "torch._C._set_sdp_use_cudnn",
        "torch._C._get_mobile_model_contained_types_from_buffer",
        "torch._C._get_mobile_model_contained_types",
@@ -1220,6 +1221,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys(

        "torch._C._set_sdp_use_math",
        "torch._C._set_math_sdp_allow_fp16_bf16_reduction",
        "torch._C._set_sdp_use_mem_efficient",
        "torch._C._set_sdp_use_overrideable",
        "torch._C._set_should_use_format_with_string_table",
        "torch._C._set_sm_carveout_experimental",
        "torch._C._set_storage_access_error_msg",
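Both hunks above extend `torch_c_binding_in_graph_functions`, the registry in `trace_rules.py` of `torch._C` bindings that dynamo is allowed to call while tracing. A quick sanity check on a build that includes this change (a hypothetical snippet, not part of the PR):

```python
from torch._dynamo import trace_rules

# The SDP-related C bindings shown above should be present in the registry.
for name in (
    "torch._C._get_cudnn_sdp_enabled",
    "torch._C._set_sdp_use_overrideable",
):
    assert name in trace_rules.torch_c_binding_in_graph_functions, name
print("SDP C bindings are registered with dynamo's trace rules")
```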
@@ -3033,6 +3033,11 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe

            torch.backends.cuda.is_flash_attention_available,
            torch.backends.cuda.can_use_flash_attention,
            torch.backends.cuda.can_use_efficient_attention,
            torch._C._get_cudnn_sdp_enabled,
            torch._C._get_flash_sdp_enabled,
            torch._C._get_mem_efficient_sdp_enabled,
            torch._C._get_math_sdp_enabled,
            torch._C._get_overrideable_sdp_enabled,
            "is_integer",
        ]
        + list(supported_const_comparison_op_values.keys())
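This last hunk adds the SDP "enabled" getters to the allowlist in `handle_traced_output` (in dynamo's variable `builder.py`, per the PR description), so their non-Tensor results (booleans here) are recognized and handled during tracing. A small end-to-end check, assuming a build that contains this fix (the compiled function below is illustrative, not taken from the PR):

```python
import torch


@torch.compile(fullgraph=True)
def sdpa_flags():
    # Each binding returns a plain Python bool; with the trace_rules and
    # builder changes in place, fullgraph compilation should not graph-break.
    return (
        torch._C._get_flash_sdp_enabled(),
        torch._C._get_math_sdp_enabled(),
        torch._C._get_mem_efficient_sdp_enabled(),
    )


print(sdpa_flags())
```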