mirror of https://github.com/pytorch/pytorch.git
[MPS] Fix torch.mm correctness for large matrices (#117549)
Currently `matrixMultiplicationWithPrimaryTensor:secondaryTensor:` returns incorrect results if one of the matrix dimensions is greater than 32K. Work around this by providing a very naive matrix-multiplication Metal shader and calling it whenever a stride exceeds 32768 elements. Slicing inside the MPSGraph does not help either, since `-sliceTensor:starts:ends:strides:` somehow affects the matmul result as well when the multiplication is tiled as follows:

```objc
NSMutableArray<MPSGraphTensor*>* rows = [NSMutableArray new];
for (int64_t i = 0; i < M; i += tile_size) {
  const auto i_end = std::min(i + tile_size, M);
  NSMutableArray<MPSGraphTensor*>* row_chunks = [NSMutableArray new];
  for (int64_t j = 0; j < K; j += tile_size) {
    const auto j_end = std::min(j + tile_size, K);
    MPSGraphTensor* tile = nil;
    for (int64_t k = 0; k < N; k += tile_size) {
      const auto k_end = std::min(k + tile_size, N);
      auto selfChunk = [graph sliceTensor:selfTensor
                                   starts:@[ @(i), @(k) ]
                                     ends:@[ @(i_end), @(k_end) ]
                                  strides:@[ @(1), @(1) ]
                                     name:nil];
      auto otherChunk = [graph sliceTensor:otherTensor
                                    starts:@[ @(k), @(j) ]
                                      ends:@[ @(k_end), @(j_end) ]
                                   strides:@[ @(1), @(1) ]
                                      name:nil];
      auto chunkMM = [graph matrixMultiplicationWithPrimaryTensor:selfChunk
                                                  secondaryTensor:otherChunk
                                                             name:nil];
      tile = tile ? [graph additionWithPrimaryTensor:tile secondaryTensor:chunkMM name:nil] : chunkMM;
    }
    [row_chunks addObject:tile];
  }
  auto row = row_chunks.count > 1 ? [graph concatTensors:row_chunks dimension:1 name:nil] : row_chunks.firstObject;
  [rows addObject:row];
}
return rows.count > 1 ? [graph concatTensors:rows dimension:0 name:nil] : rows.firstObject;
```

One can always force the Metal matmul path by defining the `PYTORCH_MPS_PREFER_METAL` environment variable.

Fixes https://github.com/pytorch/pytorch/issues/116769

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117549
Approved by: https://github.com/kulinseth
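As a quick sanity check of the new behavior, here is a minimal Python sketch (not part of the commit) that exercises the large-`k` case from the test added below and compares against the CPU result; note that `PYTORCH_MPS_PREFER_METAL` is read once into a `static`, so it must be set before the process starts if you want to force the Metal kernel for small matrices too:

```python
# Minimal sketch (assumes an MPS-enabled PyTorch build that includes this fix).
# Run with PYTORCH_MPS_PREFER_METAL=1 to force the naive Metal kernel even when
# no stride exceeds 32768 elements.
import torch

m, n, k = 1024, 1, 32769  # k > 32768, so the Metal fallback kicks in on its own
x = torch.rand(m, n, device="mps")
y = torch.rand(n, k, device="mps")

z_mps = torch.mm(x, y).cpu()
z_cpu = torch.mm(x.cpu(), y.cpu())
torch.testing.assert_close(z_mps, z_cpu)  # used to fail before this change
print("MPS matmul matches CPU for", (m, n, k))
```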
committed by PyTorch MergeBot
parent f518cf811d
commit 1872834247
@@ -9,6 +9,7 @@ namespace at::native {
namespace {
static const char* METAL_CROSS = R"CROSS_METAL(
#include <metal_array>
#include <metal_stdlib>
using namespace metal;
@@ -22,12 +22,119 @@
namespace at::native {
namespace mps {
namespace {
static const char* METAL_LINALG = R"MATMUL_METAL(
#include <metal_array>
enum LinearAlgebraOpType { ADDBMM_OP_TYPE, BADDBMM_OP_TYPE };
using namespace metal;
template<typename T>
T dot_product(constant T *v1, constant T* v2, ulong2 strides, uint32_t size) {
  T rc = 0.0;
  for (uint32_t i = 0; i < size; ++i) {
    rc += v1[i * strides.x] * v2[i * strides.y];
  }
  return rc;
}
static std::tuple<MPSGraphTensor*, MPSGraphTensor*, MPSGraphTensor*> do_mm(MPSGraph* graph,
                                                                           const Tensor& self,
                                                                           const Tensor& other) {
template<typename T>
kernel void naive_matmul(
    constant T * mat1Data [[buffer(0)]],
    constant T * mat2Data [[buffer(1)]],
    device T * outputData [[buffer(2)]],
    constant array<ulong2, 3> & strides [[buffer(3)]],
    constant uint3 & sizes [[buffer(4)]],
    uint thread_index [[thread_position_in_grid]]) {
  uint y = thread_index / sizes.x;
  uint x = thread_index % sizes.x;
  if (x >= sizes.x || y >= sizes.z) {
    return;
  }
  auto rc = dot_product(mat1Data + x * strides[0].x,
                        mat2Data + y * strides[1].y,
                        ulong2(strides[0].y, strides[1].x),
                        sizes.y);
  outputData[x * strides[2].x + y * strides[2].y] = rc;
}

#define INSTANTIATE_NAIVE_MM(DTYPE) \
template \
[[host_name("naive_matmul_" #DTYPE)]] \
kernel void naive_matmul<DTYPE>( \
    constant DTYPE * mat1Data [[buffer(0)]], \
    constant DTYPE * mat2Data [[buffer(1)]], \
    device DTYPE * outputData [[buffer(2)]], \
    constant array<ulong2, 3> & strides [[buffer(3)]], \
    constant uint3 & sizes [[buffer(4)]], \
    uint thread_index [[thread_position_in_grid]])

INSTANTIATE_NAIVE_MM(float);
INSTANTIATE_NAIVE_MM(half);
)MATMUL_METAL";

id<MTLLibrary> compileLinalgOpLibrary(id<MTLDevice> device) {
  static id<MTLLibrary> linalgLibrary = nil;
  if (linalgLibrary) {
    return linalgLibrary;
  }

  NSError* error = nil;
  MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
  [options setLanguageVersion:MTLLanguageVersion2_3];
  linalgLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_LINALG encoding:NSASCIIStringEncoding]
                                       options:options
                                         error:&error];
  TORCH_CHECK(linalgLibrary, "Failed to create metal linalg library, error: ", [[error description] UTF8String]);
  return linalgLibrary;
}

id<MTLComputePipelineState> matmulPipelineState(id<MTLDevice> device, ScalarType scalar_type) {
  std::string kernel = "naive_matmul_" + mps::scalarToMetalTypeString(scalar_type);
  static std::unordered_map<std::string, id<MTLComputePipelineState>> psoCache;
  id<MTLComputePipelineState> pso = psoCache[kernel];
  if (pso) {
    return pso;
  }

  NSError* error = nil;
  id<MTLLibrary> linalgLib = compileLinalgOpLibrary(device);
  id<MTLFunction> matmulFunc = [linalgLib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]];
  TORCH_CHECK(matmulFunc, "Failed to create function state object for: ", kernel);
  pso = [device newComputePipelineStateWithFunction:matmulFunc error:&error];
  TORCH_CHECK(pso, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);

  psoCache[kernel] = pso;
  return pso;
}

Tensor& do_metal_mm(const Tensor& self, const Tensor& other, Tensor& output) {
  auto stream = getCurrentMPSStream();
  auto device = MPSDevice::getInstance()->device();
  auto matmulPSO = matmulPipelineState(device, output.scalar_type());
  dispatch_sync_with_rethrow(stream->queue(), ^() {
    @autoreleasepool {
      getMPSProfiler().beginProfileKernel(matmulPSO, "naive_matmul", {self, other});
      auto computeEncoder = stream->commandEncoder();
      [computeEncoder setComputePipelineState:matmulPSO];
      std::array<uint32_t, 3> sizes = {static_cast<uint32_t>(self.size(0)),
                                       static_cast<uint32_t>(self.size(1)),
                                       static_cast<uint32_t>(output.size(1))};
      std::array<int64_t, 6> strides = {
          self.stride(0), self.stride(1), other.stride(0), other.stride(1), output.stride(0), output.stride(1)};
      mtl_setBuffer(computeEncoder, self, 0);
      mtl_setBuffer(computeEncoder, other, 1);
      mtl_setBuffer(computeEncoder, output, 2);
      [computeEncoder setBytes:strides.data() length:sizeof(uint64_t) * strides.size() atIndex:3];
      [computeEncoder setBytes:sizes.data() length:sizeof(uint32_t) * sizes.size() atIndex:4];
      mtl_dispatch1DJob(computeEncoder, matmulPSO, output.numel());
      getMPSProfiler().endProfileKernel(matmulPSO);
    }
  });
  return output;
}

std::tuple<MPSGraphTensor*, MPSGraphTensor*, MPSGraphTensor*> do_mm(MPSGraph* graph,
                                                                    const Tensor& self,
                                                                    const Tensor& other) {
  if (self.numel() == 0 || other.numel() == 0) {
    auto output = [graph constantWithScalar:0.0
                                      shape:getMPSShape({self.size(0), other.size(1)})
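For readers less familiar with Metal, the following plain-Python sketch (illustrative only; `naive_matmul_reference` is a made-up name) mirrors the indexing of the `naive_matmul` kernel above: `sizes` is `(self.size(0), self.size(1), output.size(1))`, each entry of `strides` holds one matrix's element strides, and each simulated thread computes one output element, matching the `mtl_dispatch1DJob` over `output.numel()`:

```python
# Illustrative sketch only: mirrors how naive_matmul maps thread_index to an
# output element and walks the two inputs by their element strides.
def naive_matmul_reference(mat1, mat2, out, strides, sizes):
    # sizes = (M, K, N); strides[i] = (row_stride, col_stride) in elements
    M, K, N = sizes
    for thread_index in range(M * N):  # mtl_dispatch1DJob(..., output.numel())
        y = thread_index // M          # column of the output
        x = thread_index % M           # row of the output
        rc = 0.0
        for i in range(K):             # dot_product over the shared dimension
            rc += (mat1[x * strides[0][0] + i * strides[0][1]] *
                   mat2[i * strides[1][0] + y * strides[1][1]])
        out[x * strides[2][0] + y * strides[2][1]] = rc

# Example on flat buffers holding row-major 2x3 and 3x2 matrices:
a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]   # 2x3, strides (3, 1)
b = [1.0, 0.0, 0.0, 1.0, 1.0, 1.0]   # 3x2, strides (2, 1)
c = [0.0] * 4                        # 2x2 output, strides (2, 1)
naive_matmul_reference(a, b, c, [(3, 1), (2, 1), (2, 1)], (2, 3, 2))
print(c)  # -> [4.0, 5.0, 10.0, 11.0]
```

This is the same convention `do_metal_mm` encodes when it packs the six strides and three sizes into buffers 3 and 4.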
@@ -40,6 +147,15 @@ static std::tuple<MPSGraphTensor*, MPSGraphTensor*, MPSGraphTensor*> do_mm(MPSGr
  return {selfTensor, otherTensor, output};
}

bool use_metal_mm(const Tensor& self, const Tensor& other, const Tensor& output) {
  static bool always_use_metal = std::getenv("PYTORCH_MPS_PREFER_METAL") != nullptr;
  constexpr auto max_stride_size = 32768;
  return always_use_metal || self.stride(0) > max_stride_size || self.stride(1) > max_stride_size ||
      other.stride(0) > max_stride_size || other.stride(1) > max_stride_size;
}

} // anonymous namespace

static Tensor& mm_out_mps_impl(const Tensor& self, const Tensor& other, Tensor& output) {
  using namespace mps;
  using CachedGraph = MPSBinaryCachedGraph;
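A quick way to see which inputs trip the heuristic above is this rough Python mirror of `use_metal_mm` (minus the environment-variable override; the helper name is made up, and strides are expressed in elements, so the check is device-independent):

```python
# Rough illustration only; not a PyTorch API.
import torch

def would_use_metal_mm(a, b, max_stride_size=32768):
    # Mirrors the stride check in use_metal_mm above.
    return any(s > max_stride_size for s in (*a.stride(), *b.stride()))

x = torch.rand(10, 1)     # strides (1, 1)
y = torch.rand(1, 32769)  # strides (32769, 1)
print(would_use_metal_mm(x, y))  # True: y.stride(0) == 32769 > 32768
```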
@@ -58,6 +174,14 @@ static Tensor& mm_out_mps_impl(const Tensor& self, const Tensor& other, Tensor&
    return output;
  }

  // MPS matmul returns silently incorrect results if one of the matrix dimensions is greater than 2**15
  // and crashes if it's a view of a matrix with dimensions larger than 2**15.
  // See https://github.com/pytorch/pytorch/issues/116769#issuecomment-1888302095
  // In such cases, fall back to the naive but accurate Metal shader.
  if (use_metal_mm(self, other, output)) {
    return do_metal_mm(self, other, output);
  }

  @autoreleasepool {
    string key = "mm_out_mps_impl" + getTensorsStringKey({self, other});
@@ -85,6 +209,8 @@ static Tensor& mm_out_mps_impl(const Tensor& self, const Tensor& other, Tensor&
  return output;
}

enum LinearAlgebraOpType { ADDBMM_OP_TYPE, BADDBMM_OP_TYPE };

static Tensor& addbmm_or_baddbmm_out_mps_impl(const Tensor& input,
                                              const Tensor& batch1,
                                              const Tensor& batch2,
@@ -6869,6 +6869,22 @@ class TestMPS(TestCaseMPS):
        gc.collect()
        torch.mps.empty_cache()

    def test_mm_large(self):
        """ Test that MM works for matrices with index larger than 32K """
        x = torch.rand(10, 1, device="mps")
        y = torch.rand(1, 32769, device="mps")
        # This used to crash with:
        # error: subRange.start (24576) is not less than length of dimension[0] (16384)
        # See https://github.com/pytorch/pytorch/issues/116769#issuecomment-1888302095
        self.assertNotEqual(torch.mm(x, y[:, 16384:32768]).abs().max().item(), 0.0)
        # And below used to produce incorrect results
        m, n, k = 1024, 1, 32769
        x = torch.rand(m, n, device="mps")
        y = torch.rand(n, k, device="mps")
        z = torch.mm(x, y).to("cpu")
        z_cpu = torch.mm(x.to("cpu"), y.to("cpu"))
        self.assertEqual(z, z_cpu)

    # Test flip
    def test_flip(self):
        def helper(shape, dims):