[CUDA] Allow cuDNN or flash attn in test_activation_checkpointing pattern match check (#153272)
Accepting either attention op seems more robust than maintaining a mirror of the dispatch condition (based on compute capability, etc.) in the test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153272
Approved by: https://github.com/soulitzer
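In isolation, the pattern the diff below introduces looks like the following minimal sketch (the helper name graph_has_one_sdpa is hypothetical; count_ops is the checker used by the test file, and the two ATen ops are the ones scaled dot product attention may lower to on CUDA):

    import torch

    # The two ATen ops that SDPA may dispatch to on CUDA, depending on
    # hardware, cuDNN availability, and dispatch heuristics.
    FLASH_OP = torch.ops.aten._scaled_dot_product_flash_attention.default
    CUDNN_OP = torch.ops.aten._scaled_dot_product_cudnn_attention.default

    def graph_has_one_sdpa(graph, count_ops):
        # Accept whichever backend the dispatcher actually picked, rather
        # than re-deriving the dispatch condition inside the test.
        return count_ops(graph, [], freq=1, op=FLASH_OP) or count_ops(
            graph, [], freq=1, op=CUDNN_OP
        )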
@@ -17,10 +17,6 @@ from functorch.compile import min_cut_rematerialization_partition
 from torch._dynamo.backends.common import aot_autograd
 from torch._dynamo.testing import CompileCounterWithBackend
 from torch._higher_order_ops.wrap import tag_activation_checkpoint
-from torch.testing._internal.common_cuda import (
-    PLATFORM_SUPPORTS_CUDNN_ATTENTION,
-    SM90OrLater,
-)
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu, skipIfRocm
 from torch.testing._internal.inductor_utils import HAS_CUDA
@@ -1368,21 +1364,24 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
         opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
         opt_fn(*args1).sum().backward()
-        if PLATFORM_SUPPORTS_CUDNN_ATTENTION and SM90OrLater:
-            op = torch.ops.aten._scaled_dot_product_cudnn_attention.default
-        else:
-            op = torch.ops.aten._scaled_dot_product_flash_attention.default

         fwd_graph = aot_graphs[0]
+        op1 = torch.ops.aten._scaled_dot_product_flash_attention.default
+        op2 = torch.ops.aten._scaled_dot_product_cudnn_attention.default
         self.assertTrue(
             count_ops(
                 fwd_graph,
                 [],
                 freq=1,
-                op=op,
+                op=op1,
             )
+            or count_ops(
+                fwd_graph,
+                [],
+                freq=1,
+                op=op2,
+            )
         )

         bwd_graph = aot_graphs[1]
         # Check that sin is not recomputed in the backward graph - checks percolate tags
         self.assertTrue(count_ops(bwd_graph, [], freq=0, op=torch.ops.aten.sin.default))
@@ -1392,7 +1391,13 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
                 bwd_graph,
                 [],
                 freq=1,
-                op=op,
+                op=op1,
             )
+            or count_ops(
+                bwd_graph,
+                [],
+                freq=1,
+                op=op2,
+            )
         )
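Note: count_ops is defined in the surrounding test file and is not part of this diff. A minimal sketch of what such a helper could look like over a torch.fx graph module, assuming it reports whether a given ATen op occurs exactly freq times among call_function nodes (the args parameter is unused in this sketch):

    def count_ops(gm, args, freq, op):
        # Count call_function nodes in the FX graph whose target is `op`
        # and report whether the count matches the expected frequency.
        actual = sum(
            1
            for node in gm.graph.nodes
            if node.op == "call_function" and node.target == op
        )
        return actual == freq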