[MPS] fix attention for >4d tensors (#147545)

Fixes #147443

and adds tests for >4d tensors
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147545
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Author: Isalia20
Date: 2025-02-25 13:55:28 +00:00
Committed by: PyTorch MergeBot
Parent: 0b9da1ae0a
Commit: a695aae89b
2 changed files with 74 additions and 6 deletions
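For context, a minimal sketch of the kind of call this change enables on the MPS backend, mirroring the new 5D tests below (shapes are illustrative and it assumes a machine with MPS available):

import torch
import torch.nn.functional as F

# 5D inputs: [batch, extra, num_heads, seq_len, head_size] -- illustrative shapes.
q = torch.randn(2, 3, 4, 10, 16, device="mps")
k = torch.randn(2, 3, 4, 10, 16, device="mps")
v = torch.randn(2, 3, 4, 10, 16, device="mps")

# Pin the math backend, as the new tests do; before this fix the >4d case failed on MPS.
with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
    y = F.scaled_dot_product_attention(q, k, v)

print(y.shape)  # torch.Size([2, 3, 4, 10, 16])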


@@ -23,6 +23,9 @@ namespace native {
 static inline std::tuple<Tensor, bool> ensure_4d(const Tensor& x) {
   if (x.dim() == 3) {
     return {x.unsqueeze(0), true};
+  } else if (x.dim() > 4) {
+    auto batchSize = c10::multiply_integers(x.sizes().begin(), x.sizes().end() - 3);
+    return {x.view({batchSize, x.size(-3), x.size(-2), x.size(-1)}), true};
   } else {
     return {x, false};
   }
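A rough Python analogue of the new >4d branch in ensure_4d above (the helper name and shapes are only for illustration): all leading dimensions are collapsed into a single batch dimension, and the caller is told a reshape happened so it can be undone at the end.

import math
import torch

def ensure_4d_sketch(x: torch.Tensor):
    # Mirrors the C++ helper: 3D gets a leading batch dim, >4D is flattened to 4D.
    if x.dim() == 3:
        return x.unsqueeze(0), True
    elif x.dim() > 4:
        batch = math.prod(x.shape[:-3])        # like c10::multiply_integers over the leading dims
        return x.view(batch, *x.shape[-3:]), True
    else:
        return x, False

t = torch.randn(2, 3, 4, 10, 16)
flat, reshaped = ensure_4d_sketch(t)
print(flat.shape, reshaped)  # torch.Size([6, 4, 10, 16]) True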
@@ -52,6 +55,7 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math_mps(const Tensor&
   auto [q_, sq] = ensure_4d(query);
   auto [k_, sk] = ensure_4d(key);
   auto [v_, sv] = ensure_4d(value);
+  std::optional<Tensor> mask_;
 
   using namespace mps;
   struct CachedGraph : public MPSCachedGraph {
@@ -113,7 +117,11 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math_mps(const Tensor&
                                 falsePredicateTensor:minusInf
                                                 name:nil];
       } else if (attn_mask) {
-        graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask);
+        auto maskExpandedDims = query.sizes().vec();
+        maskExpandedDims[maskExpandedDims.size() - 1] = maxSeqLength;
+        mask_ = attn_mask->expand(maskExpandedDims);
+        std::tie(*mask_, std::ignore) = ensure_4d(*mask_);
+        graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *mask_);
         maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM secondaryTensor:graph->maskTensor name:nil];
       }
       auto sm = [mpsGraph softMaxWithTensor:maskedMM axis:3 name:nil];
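In plain terms, the mask handling above first broadcasts attn_mask up to the query's full shape with the last dimension replaced by the key sequence length, then flattens it the same way ensure_4d flattens q/k/v so the batch dimensions line up. A Python sketch of that shape arithmetic (all shapes illustrative; the C++ path reuses ensure_4d/view, reshape is used here only for simplicity):

import torch

query = torch.randn(2, 3, 4, 10, 16)                  # [B, extra, NH, L, HS]
attn_mask = torch.tril(torch.ones(10, 10, dtype=torch.bool)).unsqueeze(0).unsqueeze(0)
max_seq_length = 10                                    # key/value sequence length S

expanded_dims = list(query.shape)
expanded_dims[-1] = max_seq_length                     # [2, 3, 4, 10, 10]
mask = attn_mask.expand(expanded_dims)                 # broadcast [1, 1, 10, 10] -> [2, 3, 4, 10, 10]
mask = mask.reshape(-1, *mask.shape[-3:])              # flatten leading dims -> [6, 4, 10, 10]
print(mask.shape)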
@@ -130,19 +138,23 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math_mps(const Tensor&
     auto outputPlaceholder = Placeholder(cachedGraph->outputTensor, out);
     auto attnPlaceholder = Placeholder(cachedGraph->attnTensor, attn);
     NSDictionary* feeds = nil;
-    if (!attn_mask) {
+    if (!mask_) {
       feeds = dictionaryFromPlaceholders(qPlaceholder, kPlaceholder, vPlaceholder);
     } else {
-      auto mPlaceholder = Placeholder(cachedGraph->maskTensor, *attn_mask);
+      auto mPlaceholder = Placeholder(cachedGraph->maskTensor, *mask_);
       feeds = dictionaryFromPlaceholders(qPlaceholder, kPlaceholder, vPlaceholder, mPlaceholder);
     }
     NSDictionary* outs = dictionaryFromPlaceholders(outputPlaceholder, attnPlaceholder);
     runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outs);
   }
-  // Squeeze back to 3D
-  auto final_out = (sq ? out.squeeze(0) : out);
-  auto final_attn = (sq ? attn.squeeze(0) : attn);
+  // reshape back to original dimension
+  auto final_out = sq ? out.view_as(query) : out;
+  auto final_attn = sq ? (query.dim() == 3 ? attn.squeeze(0) : [&]{
+    std::vector<int64_t> shape(query.sizes().begin(), query.sizes().end() - 3);
+    shape.insert(shape.end(), {attn.size(1), attn.size(2), attn.size(3)});
+    return attn.view(shape);
+  }()) : attn;
 
   return {std::move(final_out), std::move(final_attn)};
 }
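The final reshape back, in Python terms (illustrative shapes; only the >4d path is shown, since the 3D path just squeezes the added batch dim): the output has the same element count as the query, so view_as restores it directly, while the attention weights take the query's leading dims plus their own last three.

import torch

query = torch.randn(2, 3, 4, 10, 16)   # original 5D query
out = torch.randn(6, 4, 10, 16)         # flattened 4D output
attn = torch.randn(6, 4, 10, 10)        # flattened 4D attention weights

final_out = out.view_as(query)                                   # [2, 3, 4, 10, 16]
final_attn = attn.view(*query.shape[:-3], *attn.shape[-3:])      # [2, 3, 4, 10, 10]
print(final_out.shape, final_attn.shape)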


@@ -9956,6 +9956,62 @@ class TestSDPA(TestCaseMPS):
     def test_sdpa_3d_input_fp16(self):
         self._test_sdpa_3d_input(torch.float16)
 
+    def _test_sdpa_no_mask_5d(
+        self,
+        dtype: torch.dtype,
+        B: int = 2,
+        extra: int = 3,
+        NH: int = 4,
+        L: int = 10,
+        HS: int = 16,
+        requires_grad: bool = False
+    ):
+        torch.manual_seed(1729)
+        q = torch.randn(B, extra, NH, L, HS, dtype=dtype, device="mps", requires_grad=requires_grad)
+        k = torch.randn(B, extra, NH, L, HS, dtype=dtype, device="mps")
+        v = torch.randn(B, extra, NH, L, HS, dtype=dtype, device="mps")
+
+        with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
+            y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
+
+        y_ref = F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(), dropout_p=0.0, is_causal=False)
+        self._compare_tensors(y.cpu(), y_ref)
+
+        if requires_grad and torch.is_grad_enabled():
+            y.sum().backward()
+            y_ref.sum().backward()
+            self._compare_tensors(q.grad.cpu(), q.cpu().grad)
+
+    def test_sdpa_no_mask_5d_fp32(self):
+        self._test_sdpa_no_mask_5d(torch.float32)
+
+    def test_sdpa_no_mask_5d_fp16(self):
+        self._test_sdpa_no_mask_5d(torch.float16)
+
+    def _test_sdpa_mask_5d(
+        self,
+        dtype: torch.dtype,
+        B: int = 2,
+        extra: int = 3,
+        NH: int = 4,
+        L: int = 10,
+        HS: int = 16
+    ):
+        torch.manual_seed(1729)
+        q = torch.randn(B, extra, NH, L, HS, dtype=dtype, device="mps")
+        k = torch.randn(B, extra, NH, L, HS, dtype=dtype, device="mps")
+        v = torch.randn(B, extra, NH, L, HS, dtype=dtype, device="mps")
+        mask = torch.tril(torch.ones(L, L, dtype=torch.bool, device="mps")).unsqueeze(0).unsqueeze(0)
+
+        with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
+            y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+
+        y_ref = F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(), attn_mask=mask.cpu(), dropout_p=0.0, is_causal=False)
+        self._compare_tensors(y.cpu(), y_ref)
+
+    def test_sdpa_mask_5d_fp32(self):
+        self._test_sdpa_mask_5d(torch.float32)
+
+    def test_sdpa_mask_5d_fp16(self):
+        self._test_sdpa_mask_5d(torch.float16)
+
 
 class TestGatherScatter(TestCaseMPS):
     def test_slicing_with_step(self):