Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[MPS] fix attention for >4d tensors (#147545)
Fixes #147443 and adds tests for >4d tensors.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147545
Approved by: https://github.com/malfet
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Committed by: PyTorch MergeBot
Parent: 0b9da1ae0a
Commit: a695aae89b
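To see what the fix enables, here is a minimal usage sketch based on the new 5D tests added in this commit (shapes are illustrative); before this change, scaled_dot_product_attention on MPS did not handle inputs with more than 4 dimensions correctly (#147443):

import torch
import torch.nn.functional as F

# Illustrative shapes matching the new tests: (B, extra, NH, L, HS)
q = torch.randn(2, 3, 4, 10, 16, device="mps")
k = torch.randn(2, 3, 4, 10, 16, device="mps")
v = torch.randn(2, 3, 4, 10, 16, device="mps")

with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
    y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)

# Compare against the CPU reference, as the new tests do.
y_ref = F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(), dropout_p=0.0, is_causal=False)
print(torch.allclose(y.cpu(), y_ref, atol=1e-5))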
@@ -23,6 +23,9 @@ namespace native {
 static inline std::tuple<Tensor, bool> ensure_4d(const Tensor& x) {
   if (x.dim() == 3) {
     return {x.unsqueeze(0), true};
+  } else if (x.dim() > 4) {
+    auto batchSize = c10::multiply_integers(x.sizes().begin(), x.sizes().end() - 3);
+    return {x.view({batchSize, x.size(-3), x.size(-2), x.size(-1)}), true};
   } else {
     return {x, false};
   }
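For reference, a minimal Python sketch of what the extended ensure_4d helper does (the helper itself is the Objective-C++ code above; the Python names here are only for illustration): 3D inputs gain a leading batch dim, and inputs with more than 4 dims have all leading dims flattened into a single batch dim.

import torch
from math import prod

def ensure_4d(x: torch.Tensor):
    # 3D: add a batch dimension; >4D: fold every leading dim into one batch dim.
    if x.dim() == 3:
        return x.unsqueeze(0), True
    elif x.dim() > 4:
        batch_size = prod(x.shape[:-3])  # mirrors c10::multiply_integers over the leading dims
        return x.view(batch_size, x.size(-3), x.size(-2), x.size(-1)), True
    else:
        return x, False

q = torch.randn(2, 3, 4, 10, 16)      # (B, extra, NH, L, HS)
q4d, reshaped = ensure_4d(q)
assert reshaped and q4d.shape == (6, 4, 10, 16)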
@@ -52,6 +55,7 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math_mps(const Tensor&
   auto [q_, sq] = ensure_4d(query);
   auto [k_, sk] = ensure_4d(key);
   auto [v_, sv] = ensure_4d(value);
+  std::optional<Tensor> mask_;

   using namespace mps;
   struct CachedGraph : public MPSCachedGraph {
@@ -113,7 +117,11 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math_mps(const Tensor&
                             falsePredicateTensor:minusInf
                                             name:nil];
     } else if (attn_mask) {
-      graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask);
+      auto maskExpandedDims = query.sizes().vec();
+      maskExpandedDims[maskExpandedDims.size() - 1] = maxSeqLength;
+      mask_ = attn_mask->expand(maskExpandedDims);
+      std::tie(*mask_, std::ignore) = ensure_4d(*mask_);
+      graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *mask_);
       maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM secondaryTensor:graph->maskTensor name:nil];
     }
     auto sm = [mpsGraph softMaxWithTensor:maskedMM axis:3 name:nil];
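A rough Python equivalent of the new mask handling (illustrative only; maxSeqLength is assumed here to equal the key sequence length, and reshape stands in for the view done by ensure_4d): the broadcastable attention mask is expanded to the query's shape with the last dim set to the max sequence length, then flattened to 4D just like q/k/v so it lines up with the flattened score tensor.

import torch

B, extra, NH, L, HS = 2, 3, 4, 10, 16
max_seq_length = L                                            # assumption: key length == query length
query = torch.randn(B, extra, NH, L, HS)
attn_mask = torch.tril(torch.ones(L, L, dtype=torch.bool))    # broadcastable (L, L) mask

mask_expanded_dims = list(query.shape)
mask_expanded_dims[-1] = max_seq_length                       # (B, extra, NH, L, max_seq_length)
mask_ = attn_mask.expand(mask_expanded_dims)
mask_4d = mask_.reshape(B * extra, NH, L, max_seq_length)     # same flattening as ensure_4d
assert mask_4d.shape == (6, 4, 10, 10)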
@@ -130,19 +138,23 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math_mps(const Tensor&
     auto outputPlaceholder = Placeholder(cachedGraph->outputTensor, out);
     auto attnPlaceholder = Placeholder(cachedGraph->attnTensor, attn);
     NSDictionary* feeds = nil;
-    if (!attn_mask) {
+    if (!mask_) {
       feeds = dictionaryFromPlaceholders(qPlaceholder, kPlaceholder, vPlaceholder);
     } else {
-      auto mPlaceholder = Placeholder(cachedGraph->maskTensor, *attn_mask);
+      auto mPlaceholder = Placeholder(cachedGraph->maskTensor, *mask_);
       feeds = dictionaryFromPlaceholders(qPlaceholder, kPlaceholder, vPlaceholder, mPlaceholder);
     }
     NSDictionary* outs = dictionaryFromPlaceholders(outputPlaceholder, attnPlaceholder);
     runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outs);
   }

-  // Squeeze back to 3D
-  auto final_out = (sq ? out.squeeze(0) : out);
-  auto final_attn = (sq ? attn.squeeze(0) : attn);
+  // reshape back to original dimension
+  auto final_out = sq ? out.view_as(query) : out;
+  auto final_attn = sq ? (query.dim() == 3 ? attn.squeeze(0) : [&]{
+    std::vector<int64_t> shape(query.sizes().begin(), query.sizes().end() - 3);
+    shape.insert(shape.end(), {attn.size(1), attn.size(2), attn.size(3)});
+    return attn.view(shape);
+  }()) : attn;

   return {std::move(final_out), std::move(final_attn)};
 }
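As a quick sanity check of the reshape-back logic above, a small Python sketch (shapes are illustrative, not part of the commit): the output is viewed back to the query's shape, and the attention weights get the query's leading dims prepended to their last three dims.

import torch

B, extra, NH, L, HS = 2, 3, 4, 10, 16
query = torch.randn(B, extra, NH, L, HS)

out_4d = torch.randn(B * extra, NH, L, HS)    # output computed on the flattened 4D view
attn_4d = torch.randn(B * extra, NH, L, L)    # attention weights on the flattened view

final_out = out_4d.view_as(query)
shape = list(query.shape[:-3]) + [attn_4d.size(1), attn_4d.size(2), attn_4d.size(3)]
final_attn = attn_4d.view(shape)

assert final_out.shape == (B, extra, NH, L, HS)
assert final_attn.shape == (B, extra, NH, L, L)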
@@ -9956,6 +9956,62 @@ class TestSDPA(TestCaseMPS):
     def test_sdpa_3d_input_fp16(self):
         self._test_sdpa_3d_input(torch.float16)

+    def _test_sdpa_no_mask_5d(
+        self,
+        dtype: torch.dtype,
+        B: int = 2,
+        extra: int = 3,
+        NH: int = 4,
+        L: int = 10,
+        HS: int = 16,
+        requires_grad: bool = False
+    ):
+        torch.manual_seed(1729)
+        q = torch.randn(B, extra, NH, L, HS, dtype=dtype, device="mps", requires_grad=requires_grad)
+        k = torch.randn(B, extra, NH, L, HS, dtype=dtype, device="mps")
+        v = torch.randn(B, extra, NH, L, HS, dtype=dtype, device="mps")
+
+        with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
+            y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
+            y_ref = F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(), dropout_p=0.0, is_causal=False)
+            self._compare_tensors(y.cpu(), y_ref)
+
+            if requires_grad and torch.is_grad_enabled():
+                y.sum().backward()
+                y_ref.sum().backward()
+                self._compare_tensors(q.grad.cpu(), q.cpu().grad)
+
+    def test_sdpa_no_mask_5d_fp32(self):
+        self._test_sdpa_no_mask_5d(torch.float32)
+
+    def test_sdpa_no_mask_5d_fp16(self):
+        self._test_sdpa_no_mask_5d(torch.float16)
+
+    def _test_sdpa_mask_5d(
+        self,
+        dtype: torch.dtype,
+        B: int = 2,
+        extra: int = 3,
+        NH: int = 4,
+        L: int = 10,
+        HS: int = 16
+    ):
+        torch.manual_seed(1729)
+        q = torch.randn(B, extra, NH, L, HS, dtype=dtype, device="mps")
+        k = torch.randn(B, extra, NH, L, HS, dtype=dtype, device="mps")
+        v = torch.randn(B, extra, NH, L, HS, dtype=dtype, device="mps")
+        mask = torch.tril(torch.ones(L, L, dtype=torch.bool, device="mps")).unsqueeze(0).unsqueeze(0)
+
+        with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
+            y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+            y_ref = F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(), attn_mask=mask.cpu(), dropout_p=0.0, is_causal=False)
+            self._compare_tensors(y.cpu(), y_ref)
+
+    def test_sdpa_mask_5d_fp32(self):
+        self._test_sdpa_mask_5d(torch.float32)
+
+    def test_sdpa_mask_5d_fp16(self):
+        self._test_sdpa_mask_5d(torch.float16)
+
 class TestGatherScatter(TestCaseMPS):
     def test_slicing_with_step(self):