[CUDA] Allow cuDNN or flash attn in test_activation_checkpointing pattern match check (#153272)
This seems more robust than maintaining a mirror of the dispatch condition based on compute capability, etc.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153272
Approved by: https://github.com/soulitzer
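As a rough illustration of the idea (a minimal sketch with hypothetical helper names, not the test's actual count_ops utility): instead of predicting which SDPA backend the dispatcher will pick from compute capability, the check accepts either the flash or the cuDNN attention op in the captured graph.

    import torch
    import torch.fx

    def graph_calls_op(gm: torch.fx.GraphModule, op) -> bool:
        # True if any call_function node in the FX graph targets the given ATen overload.
        return any(
            node.op == "call_function" and node.target is op
            for node in gm.graph.nodes
        )

    def assert_some_sdpa_in_graph(gm: torch.fx.GraphModule) -> None:
        flash_op = torch.ops.aten._scaled_dot_product_flash_attention.default
        cudnn_op = torch.ops.aten._scaled_dot_product_cudnn_attention.default
        # Accept whichever backend was actually dispatched, mirroring the
        # "count_ops(..., op=op1) or count_ops(..., op=op2)" check in the diff below.
        assert graph_calls_op(gm, flash_op) or graph_calls_op(gm, cudnn_op)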
@@ -17,10 +17,6 @@ from functorch.compile import min_cut_rematerialization_partition
 from torch._dynamo.backends.common import aot_autograd
 from torch._dynamo.testing import CompileCounterWithBackend
 from torch._higher_order_ops.wrap import tag_activation_checkpoint
-from torch.testing._internal.common_cuda import (
-    PLATFORM_SUPPORTS_CUDNN_ATTENTION,
-    SM90OrLater,
-)
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu, skipIfRocm
 from torch.testing._internal.inductor_utils import HAS_CUDA
@@ -1368,21 +1364,24 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
 
         opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
         opt_fn(*args1).sum().backward()
-        if PLATFORM_SUPPORTS_CUDNN_ATTENTION and SM90OrLater:
-            op = torch.ops.aten._scaled_dot_product_cudnn_attention.default
-        else:
-            op = torch.ops.aten._scaled_dot_product_flash_attention.default
-
         fwd_graph = aot_graphs[0]
+        op1 = torch.ops.aten._scaled_dot_product_flash_attention.default
+        op2 = torch.ops.aten._scaled_dot_product_cudnn_attention.default
         self.assertTrue(
             count_ops(
                 fwd_graph,
                 [],
                 freq=1,
-                op=op,
+                op=op1,
+            )
+            or count_ops(
+                fwd_graph,
+                [],
+                freq=1,
+                op=op2,
             )
         )
 
         bwd_graph = aot_graphs[1]
         # Check that sin is not recomputed in the backward graph - checks percolate tags
         self.assertTrue(count_ops(bwd_graph, [], freq=0, op=torch.ops.aten.sin.default))
@@ -1392,7 +1391,13 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
                 bwd_graph,
                 [],
                 freq=1,
-                op=op,
+                op=op1,
+            )
+            or count_ops(
+                bwd_graph,
+                [],
+                freq=1,
+                op=op2,
             )
         )
 
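For readers unfamiliar with the helper used above: count_ops is defined elsewhere in this test file, but a simplified, hypothetical version of the op-frequency check it performs on an AOTAutograd-captured FX graph might look like the sketch below (names and signature are illustrative only, not the real helper).

    import torch
    import torch.fx

    def count_calls(gm: torch.fx.GraphModule, op) -> int:
        # Count call_function nodes whose target is the given ATen overload.
        return sum(
            1
            for node in gm.graph.nodes
            if node.op == "call_function" and node.target is op
        )

With a helper of this shape, the assertions above read as: the forward and backward graphs each contain exactly one scaled-dot-product-attention call, regardless of whether the flash or the cuDNN kernel was chosen at dispatch time.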
|