[CUDNN][SDPA] Fix unsupported trivial stride-1 transpose case (#134031)
Fixes #134001

The backward pass incorrectly assumed that two same-shape tensors being contiguous meant they would also have the same strides; that does not hold when a size-1 dimension is involved, since the stride of a size-1 dimension is arbitrary and ignored by the contiguity check.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134031
Approved by: https://github.com/drisspg, https://github.com/Skylion007

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
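For illustration (a standalone sketch, not code from this commit): PyTorch's contiguity check ignores the stride of any size-1 dimension, so two tensors of the same shape can both report is_contiguous() == True while reporting different stride() tuples.

    import torch

    # Transposing a size-1 dim keeps the layout dense, so is_contiguous() is True,
    # but the recorded strides are not the canonical contiguous strides.
    a = torch.randn(2, 4, 1, 64).transpose(1, 2)   # shape (2, 1, 4, 64)
    b = torch.randn(2, 1, 4, 64)                   # same shape, freshly allocated

    print(a.is_contiguous(), b.is_contiguous())    # True True
    print(a.stride())                              # (256, 64, 64, 1)
    print(b.stride())                              # (256, 256, 64, 1)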
@@ -662,7 +662,10 @@ void run_cudnn_SDP_bprop(
         " Materializing a contiguous tensor which will increase memory usage...");
     dO_ = dO.contiguous();
   }
-  if (!std::equal(
+  if ( // handle trivial transposed case with a transposed dim of size 1
+      // see also: https://github.com/pytorch/pytorch/issues/134001
+      !(dO_.is_contiguous() && o.is_contiguous()) &&
+      !std::equal(
           o.strides().begin(), o.strides().end(), dO.strides().begin())) {
     TORCH_WARN(
         "cuDNN SDPA backward got grad_output.strides() != output.strides(), "
@@ -674,8 +677,9 @@ void run_cudnn_SDP_bprop(
     }
   }
   TORCH_INTERNAL_ASSERT(
-      std::equal(
-          dO_.strides().begin(), dO_.strides().end(), o.strides().begin()),
+      (dO_.is_contiguous() && o.is_contiguous()) ||
+          std::equal(
+              dO_.strides().begin(), dO_.strides().end(), o.strides().begin()),
       "cuDNN SDPA expected grad_output.strides() == output.strides(), "
       "the previous step probably failed to materialize a grad_output "
       "with matching strides...");
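The two hunks above can be read as one predicate: a stride mismatch between grad_output and output is harmless when both tensors are contiguous, since only a size-1 dimension can make their strides differ in that case. A minimal Python sketch of that predicate (the helper name strides_match_or_both_contiguous is made up for illustration and is not part of ATen):

    import torch

    def strides_match_or_both_contiguous(dO, o):
        # Sketch of the patched guard: accept two contiguous tensors even if their
        # recorded strides differ, which can only happen via size-1 dimensions.
        if dO.is_contiguous() and o.is_contiguous():
            return True
        return dO.stride() == o.stride()

    o = torch.randn(2, 4, 1, 64).transpose(1, 2)  # output viewed with a transposed size-1 dim
    dO = torch.randn(o.shape)                     # gradient with canonical contiguous strides
    print(o.stride(), dO.stride())                # (256, 64, 64, 1) (256, 256, 64, 1)
    print(strides_match_or_both_contiguous(dO, o))  # True, even though strict stride equality is False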
@@ -2398,6 +2398,22 @@ class TestSDPACudaOnly(NNTestCase):
         with self.assertRaisesRegex(RuntimeError, "No available kernel."):
             o = torch.nn.functional.scaled_dot_product_attention(q, k, v)
 
+    @skipIfRocm # No cuDNN Attention
+    @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
+    def test_cudnn_attention_trivial_output_transpose(self, device):
+        # see also: https://github.com/pytorch/pytorch/issues/134001
+        x = torch.randn(2, 4, 1, 64, device='cuda', dtype=torch.float16, requires_grad=True)
+        x2 = x.transpose(1, 2)
+        with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION):
+            o = torch.nn.functional.scaled_dot_product_attention(x2, x2, x2).transpose(1, 2).reshape(2, 64, 4)
+        o.backward(o)
+        x_cpu = x.clone().cpu().detach()
+        x_cpu.requires_grad = True
+        x2_cpu = x_cpu.transpose(1, 2)
+        o = torch.nn.functional.scaled_dot_product_attention(x2_cpu, x2_cpu, x2_cpu).transpose(1, 2).reshape(2, 64, 4)
+        o.backward(o)
+        torch.testing.assert_close(x.grad, x_cpu.grad.cuda(), atol=7e-3, rtol=7e-3)
+
     @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
     @parametrize("mask_dim", [1, 2, 3, 4])
     def test_mem_efficient_attention_mask_variants(self, device, mask_dim: List[int]):