Compare commits


3 Commits

SHA1        Message                                          Date
332491ccfd  Merge mess                                       2024-07-10 20:29:57 -07:00
5f622ca2aa  MPS kernel working for the base case             2024-07-10 20:23:11 -07:00
5b7e452f08  Boilerplate for the fused sdpa mps kernel call   2024-07-10 20:23:02 -07:00
6 changed files with 152 additions and 2 deletions

View File

@@ -156,6 +156,7 @@ file(GLOB native_quantized_hip_hip "native/quantized/hip/*.hip")
file(GLOB native_quantized_hip_cpp "native/quantized/hip/*.cpp")
file(GLOB native_transformers_cuda_cu "native/transformers/cuda/*.cu")
file(GLOB native_transformers_cuda_cpp "native/transformers/cuda/*.cpp")
file(GLOB native_transformers_mps_mm "native/transformers/mps/*.mm")
file(GLOB native_transformers_hip_hip "native/transformers/hip/*.hip")
file(GLOB native_transformers_hip_cpp "native/transformers/hip/*.cpp")
file(GLOB native_quantized_cudnn_hip_cpp "native/quantized/cudnn/hip/*.cpp")
@@ -550,7 +551,7 @@ if(USE_CUDA)
endif()
if(USE_MPS)
  set(ATen_MPS_SRCS ${ATen_MPS_SRCS} ${mps_cpp} ${mps_mm} ${mps_h} ${native_mps_cpp} ${native_mps_mm} ${native_mps_h})
  set(ATen_MPS_SRCS ${ATen_MPS_SRCS} ${mps_cpp} ${mps_mm} ${mps_h} ${native_mps_cpp} ${native_mps_mm} ${native_mps_h} ${native_transformers_mps_mm})
endif()
if(USE_ROCM)

View File

@@ -14720,6 +14720,7 @@
    Meta: _fused_sdp_choice_meta
    CPU, NestedTensorCPU: _fused_sdp_choice_cpp
    CUDA, NestedTensorCUDA: _fused_sdp_choice_cuda
    MPS: _fused_sdp_choice_mps
  tags: nondeterministic_seeded
- func: _scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, Tensor? dropout_mask=None, *, float? scale=None) -> (Tensor, Tensor)
@@ -14783,6 +14784,11 @@
    CUDA: _scaled_dot_product_cudnn_attention_backward_cuda
  tags: nondeterministic_seeded

- func: _scaled_dot_product_attention_mps(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> Tensor
  dispatch:
    MPS: _fused_scaled_dot_product_attention_mps
  tags: nondeterministic_seeded

- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
  variants: function
  dispatch:

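For a quick smoke test, the new primitive should be callable straight through the dispatcher once this branch is built with MPS support. A minimal sketch (argument order follows the func: line above; this op only exists on this branch, not in upstream PyTorch):

import torch

q = torch.randn(4, 2, 3, 2, device="mps")
k = torch.randn(4, 2, 3, 2, device="mps")
v = torch.randn(4, 2, 3, 2, device="mps")

# Calls the newly registered primitive directly.
out = torch.ops.aten._scaled_dot_product_attention_mps(q, k, v, 0.0, False, scale=None)
print(out.shape)  # torch.Size([4, 2, 3, 2])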
View File

@@ -39,6 +39,7 @@
#include <ATen/ops/_scaled_dot_product_flash_attention_backward_native.h>
#include <ATen/ops/_scaled_dot_product_flash_attention_native.h>
#include <ATen/ops/_scaled_dot_product_cudnn_attention.h>
#include <ATen/ops/_scaled_dot_product_attention_mps.h>
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu.h>
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_native.h>
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_backward.h>
@@ -70,6 +71,7 @@
#endif
#include <ATen/native/nested/NestedTensorTransformerFunctions.h>
#include <iostream>
namespace at {
namespace native {
@@ -663,6 +665,11 @@ Tensor scaled_dot_product_attention(
          query_, key, value, attn_mask_, compute_logsumexp, dropout_p, is_causal, false /*return_debug_mask*/, scale);
      return std::get<0>(out_lse_softmax);
    }
    case sdp::SDPBackend::mps_attention: {
      auto out_softmax = at::_scaled_dot_product_attention_mps(
          query_, key, value, dropout_p, is_causal, scale);
      return out_softmax;
    }
    case sdp::SDPBackend::flash_attention: {
      if (query_.device().type() == DeviceType::CUDA) {
        c10::SymInt og_size = query_.sym_size(-1);

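The dispatcher only takes this branch when _fused_sdp_choice reports the MPS backend (value 5 after the enum change further down). A rough way to inspect the selection from Python, assuming torch._fused_sdp_choice is exposed the same way the existing CPU/CUDA tests use it:

import torch

q = torch.randn(4, 2, 3, 2, device="mps")
k = torch.randn(4, 2, 3, 2, device="mps")
v = torch.randn(4, 2, 3, 2, device="mps")

# Expected to print 5, i.e. SDPBackend::mps_attention from sdp_utils_cpp.h.
print(torch._fused_sdp_choice(q, k, v, dropout_p=0.0, is_causal=False))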
View File

@@ -0,0 +1,115 @@
// These headers are needed regardless of AT_PER_OPERATOR_HEADERS, so keep them outside the guard.
#include <ATen/native/transformers/attention.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/mps/MPSGraphSonomaOps.h>
#include <ATen/mps/MPSProfiler.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_fused_sdp_choice_native.h>
#include <ATen/ops/_scaled_dot_product_attention_mps.h>
#include <ATen/ops/zeros_like.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/ones_like_native.h>
#endif
#include <cmath>
#include <limits>
#include <iostream>
namespace at {
namespace native {
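// Cached MPSGraph together with its input placeholders (query, key, value, mask) and the attention output tensor.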
struct CachedGraph : public mps::MPSCachedGraph {
  CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
  std::vector<MPSGraphTensor*> inputTensors_;
  MPSGraphTensor* outputTensor_ = nil;
};
// TODO: find the right home for this declaration; the generated ATen/ops/_scaled_dot_product_attention_mps_native.h
// header would normally provide it. For now it is kept here to prevent a compilation error.
Tensor _fused_scaled_dot_product_attention_mps(Tensor const& query, Tensor const& key, Tensor const& value, double dropout_p, bool is_causal, c10::optional<double> scale);
// TODO: does this need to live in the mps namespace? Putting it there directly causes an issue,
// so the change would need to be reflected elsewhere as well.
Tensor _fused_scaled_dot_product_attention_mps(Tensor const& query, Tensor const& key, Tensor const& value, double dropout_p, bool is_causal, c10::optional<double> scale) {
  using namespace mps;
  if (query.numel() == 0 || key.numel() == 0 || value.numel() == 0) {
    // TODO: check whether zeros is the expected result here, or whether an empty tensor would do.
    return at::zeros_like(query);
  }
  // NOTE: dropout_p is not applied yet; the base-case kernel only handles the no-dropout path.
  double scale_;
  if (scale.has_value()) {
    scale_ = scale.value();
  } else {
    scale_ = 1.0 / std::sqrt(static_cast<double>(query.size(-1)));
  }
  const int64_t batch_size = query.size(0);
  const int64_t num_heads = query.size(1);
  const int64_t max_seqlen_batch_q = query.size(2);
  const int64_t head_dim = query.size(3);
  const int64_t max_seqlen_batch_k = key.size(2);
  const int64_t max_seqlen_batch_v = value.size(2);
  // Output buffer; assumes the value head dim matches the query head dim (true for the base case).
  Tensor out = at::zeros_like(query, query.options());
  const auto L = query.size(-2), S = key.size(-2);
  // Additive attention mask: 0 where attention is allowed, -inf where it is masked out.
  auto mask = at::zeros({L, S}, query.options());
  if (is_causal) {
    auto temp = at::ones({L, S}, query.options().dtype(at::kBool)).tril();
    mask.masked_fill_(temp.logical_not(), -std::numeric_limits<double>::infinity());
  }
  MPSStream* stream = getCurrentMPSStream();
  @autoreleasepool {
    // The scale is baked into the compiled graph, so it must be part of the cache key;
    // the mask is a runtime feed and does not need to be.
    std::string cacheKey = "fused_sdpa_" + getTensorsStringKey({query}) + ":" + std::to_string(scale_);
    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(cacheKey, [&](auto mpsGraph, auto newCachedGraph) {
      auto mpsDtype = getMPSDataType(query);
      MPSGraphTensor* queryTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, mpsDtype);
      MPSGraphTensor* keyTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, mpsDtype);
      MPSGraphTensor* valueTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, mpsDtype);
      MPSGraphTensor* maskTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, mpsDtype);
      newCachedGraph->inputTensors_ = {queryTensor, keyTensor, valueTensor, maskTensor};
      // MPSGraph's fused SDPA op; the additive mask carries the causal structure.
      MPSGraphTensor* sdpa = [mpsGraph scaledDotProductAttentionWithQueryTensor:queryTensor
                                                                      keyTensor:keyTensor
                                                                    valueTensor:valueTensor
                                                                     maskTensor:maskTensor
                                                                          scale:scale_
                                                                           name:nil];
      newCachedGraph->outputTensor_ = sdpa;
      return newCachedGraph;
    });
    Placeholder queryPlaceholder = Placeholder(cachedGraph->inputTensors_[0], query, getMPSShape(query));
    Placeholder keyPlaceholder = Placeholder(cachedGraph->inputTensors_[1], key, getMPSShape(key));
    Placeholder valuePlaceholder = Placeholder(cachedGraph->inputTensors_[2], value, getMPSShape(value));
    Placeholder maskPlaceholder = Placeholder(cachedGraph->inputTensors_[3], mask, getMPSShape(mask));
    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out);
    auto feeds = dictionaryFromPlaceholders(queryPlaceholder, keyPlaceholder, valuePlaceholder, maskPlaceholder);
    runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
  }
  return out;
}
// Backend chooser registered for the MPS dispatch stub; mirrors _fused_sdp_choice_cpp/_cuda.
int64_t _fused_sdp_choice_mps(const Tensor& query_, const Tensor& key, const Tensor& value, const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, c10::optional<double> scale) {
  sdp::sdp_params kernel_params{query_, key, value, attn_mask_, dropout_p, is_causal};
  // TODO: route through select_sdp_backend(kernel_params) once MPS is handled there;
  // for now the fused MPS kernel is always chosen.
  auto backend = sdp::SDPBackend::mps_attention;
  if (backend == sdp::SDPBackend::error) {
    TORCH_CHECK(
        false,
        "No viable backend for scaled_dot_product_attention was found. ",
        "This is likely due to turning off both the math kernel and the fused kernels.");
  }
  return static_cast<int64_t>(backend);
}

REGISTER_MPS_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_mps);
}} // namespace at::native

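For reference, the semantics the MPSGraph call above is expected to match can be written out eagerly. This is only a sketch of the math (dropout_p is ignored here, just as in the base-case kernel), not part of the change:

import math
import torch

def sdpa_reference(query, key, value, is_causal=False, scale=None):
    # Default scale matches the kernel: 1 / sqrt(head_dim).
    scale = scale if scale is not None else 1.0 / math.sqrt(query.size(-1))
    L, S = query.size(-2), key.size(-2)
    # Additive mask: 0 where attention is allowed, -inf above the diagonal when causal.
    mask = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        keep = torch.ones(L, S, dtype=torch.bool, device=query.device).tril()
        mask = mask.masked_fill(keep.logical_not(), float("-inf"))
    attn = torch.softmax(query @ key.transpose(-2, -1) * scale + mask, dim=-1)
    return attn @ value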
View File

@@ -29,7 +29,8 @@ enum class SDPBackend {
  flash_attention = 1,
  efficient_attention = 2,
  cudnn_attention = 3,
  overrideable = 4
  overrideable = 4,
  mps_attention = 5,
};
// Note that if this changed make sure to update

View File

@@ -1844,6 +1844,26 @@ class TestMPS(TestCaseMPS):
            expected = self._brute_cdist(x, y, p=p)
            self.assertEqual(expected, actual)

    def test_sdpa(self):
        for is_causal in [True, False]:
            b = 4  # batch
            h = 2  # heads
            L = 3  # prompt size
            E = 2  # features
            S = 2  # context
            # NOTE: scaled_dot_product_attention takes (query, key, value); here the tensor named k
            # plays the query role (shape b, h, L, E) and q/v are the key/value (shape b, h, S, E).
            k = torch.randn(b, h, L, E, device='cpu')  # b, h, L, E
            q = torch.randn(b, h, S, E, device='cpu')  # b, h, S, E
            v = torch.randn(b, h, S, E, device='cpu')  # b, h, S, E
            cpu_ref = F.scaled_dot_product_attention(k, q, v, is_causal=is_causal)
            device = 'mps'
            k_mps = k.detach().clone().to(device)
            q_mps = q.detach().clone().to(device)
            v_mps = v.detach().clone().to(device)
            mps = F.scaled_dot_product_attention(k_mps, q_mps, v_mps, is_causal=is_causal)
            torch.testing.assert_close(mps.to('cpu'), cpu_ref)

    def test_mm(self):
        B = torch.ones(5, 6).to("mps")
        C = torch.ones(6, 5).to("mps")