[CUDA] Allow cuDNN or flash attn in test_activation_checkpointing pattern match check (#153272)

This seems more robust than maintaining a mirror of the dispatch condition based on compute capability, etc.
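
To illustrate the idea, here is a minimal sketch of an "accept either backend" check over an FX graph (the SDPA_OPS tuple and graph_has_sdpa helper are hypothetical names used only for illustration, not part of the test suite; the aten op names themselves come from the diff below):

import torch

# Hypothetical sketch, not part of test_activation_checkpointing.py: rather than
# predicting whether the dispatcher will pick the cuDNN or the flash attention
# kernel from compute capability, accept either op when scanning the traced graph.
SDPA_OPS = (
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    torch.ops.aten._scaled_dot_product_cudnn_attention.default,
)

def graph_has_sdpa(gm: torch.fx.GraphModule) -> bool:
    """Return True if the FX graph calls any of the accepted SDPA ops."""
    return any(
        node.op == "call_function" and node.target in SDPA_OPS
        for node in gm.graph.nodes
    )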

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153272
Approved by: https://github.com/soulitzer
Author: eqy
Date: 2025-07-11 20:58:08 +00:00
Committed by: PyTorch MergeBot
Parent: 702a304b07
Commit: 00ae620b9f


@@ -17,10 +17,6 @@ from functorch.compile import min_cut_rematerialization_partition
 from torch._dynamo.backends.common import aot_autograd
 from torch._dynamo.testing import CompileCounterWithBackend
 from torch._higher_order_ops.wrap import tag_activation_checkpoint
-from torch.testing._internal.common_cuda import (
-    PLATFORM_SUPPORTS_CUDNN_ATTENTION,
-    SM90OrLater,
-)
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu, skipIfRocm
 from torch.testing._internal.inductor_utils import HAS_CUDA
@@ -1368,21 +1364,24 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
         opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
         opt_fn(*args1).sum().backward()
-        if PLATFORM_SUPPORTS_CUDNN_ATTENTION and SM90OrLater:
-            op = torch.ops.aten._scaled_dot_product_cudnn_attention.default
-        else:
-            op = torch.ops.aten._scaled_dot_product_flash_attention.default
         fwd_graph = aot_graphs[0]
+        op1 = torch.ops.aten._scaled_dot_product_flash_attention.default
+        op2 = torch.ops.aten._scaled_dot_product_cudnn_attention.default
         self.assertTrue(
             count_ops(
                 fwd_graph,
                 [],
                 freq=1,
-                op=op,
+                op=op1,
+            )
+            or count_ops(
+                fwd_graph,
+                [],
+                freq=1,
+                op=op2,
             )
         )
         bwd_graph = aot_graphs[1]
         # Check that sin is not recomputed in the backward graph - checks percolate tags
         self.assertTrue(count_ops(bwd_graph, [], freq=0, op=torch.ops.aten.sin.default))
@@ -1392,7 +1391,13 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
                 bwd_graph,
                 [],
                 freq=1,
-                op=op,
+                op=op1,
+            )
+            or count_ops(
+                bwd_graph,
+                [],
+                freq=1,
+                op=op2,
             )
         )