[MPS] cholesky implementation (#145701)

Requested in #77764

Closed #144193 due to excessive conflicts when rebasing
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145701
Approved by: https://github.com/malfet
Author: Isalia20
Date: 2025-01-27 01:53:03 +00:00
Committed by: PyTorch MergeBot
Parent: c6ad08357b
Commit: b75afa2e2e
5 changed files with 401 additions and 2 deletions
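
For orientation, a minimal usage sketch of the operator this commit enables; it assumes a build of this branch on a machine where torch.backends.mps.is_available() returns True:

import torch

# Batched symmetric positive-definite input on the MPS device.
a = torch.randn(4, 64, 64, device="mps")
spd = a @ a.mT + 64 * torch.eye(64, device="mps")

# Lower and upper factors now come from the Metal kernels added below.
L = torch.linalg.cholesky(spd)               # lower triangular
U = torch.linalg.cholesky(spd, upper=True)   # upper triangular

# Cross-check against the CPU implementation, mirroring the tolerances in the new test.
torch.testing.assert_close(L.cpu(), torch.linalg.cholesky(spd.cpu()), atol=2e-5, rtol=1e-6)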


@@ -1,4 +1,5 @@
#include <metal_array>
#include <metal_stdlib>
using namespace metal;
template <typename T>
@@ -31,6 +32,271 @@ kernel void naive_matmul(
  outputData[x * strides[2].x + y * strides[2].y] = rc;
}

inline float blockReduceSum(
    threadgroup float* sharedScratch,
    float val,
    uint tid,
    uint tpg) {
  sharedScratch[tid] = val;
  threadgroup_barrier(mem_flags::mem_threadgroup);
  for (uint offset = tpg >> 1; offset > 0; offset >>= 1) {
    if (tid < offset) {
      sharedScratch[tid] += sharedScratch[tid + offset];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }
  return sharedScratch[0];
}

kernel void factorDiagonalBlock(
    device float* A [[buffer(0)]],
    device int* success [[buffer(1)]],
    constant uint& N [[buffer(2)]],
    constant uint& NB [[buffer(3)]],
    constant uint& k [[buffer(4)]],
    uint tid [[thread_position_in_threadgroup]],
    uint bid [[threadgroup_position_in_grid]],
    uint tpg [[threads_per_threadgroup]]) {
  const uint actSize = min(N - k * NB, NB); // uint64 before NB
  const uint batch_offset = bid * N * N;
  const uint row0 = k * NB;
  const uint col0 = k * NB;

  threadgroup float tile[32][33];
  threadgroup float reduceScratch[256];
  const uint tileSize = actSize * actSize;

  for (uint i = tid; i < tileSize; i += tpg) {
    uint r = i / actSize;
    uint c = i % actSize;
    tile[r][c] = A[batch_offset + (row0 + r) * N + (col0 + c)];
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  for (uint kk = 0; kk < actSize; kk++) {
    float diagElt = 0.0f;
    if (kk > 0) {
      float partialSum = 0.0f;
      for (uint i = tid; i < kk; i += tpg) {
        float val = tile[kk][i];
        partialSum = fma(val, val, partialSum);
      }
      diagElt = blockReduceSum(reduceScratch, partialSum, tid, tpg);
    }
    if (tid == 0) {
      float diagVal = tile[kk][kk] - diagElt;
      // Check for positive definiteness
      if (diagVal <= 0.0f) {
        success[bid] = 0; // matrix is not positive definite
        return;
      }
      tile[kk][kk] = sqrt(diagVal);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    float pivot = tile[kk][kk];
    for (uint j = kk + 1 + tid; j < actSize; j += tpg) {
      float partialSum = 0.0f;
      for (uint i = 0; i < kk; i++) {
        partialSum = fma(tile[j][i], tile[kk][i], partialSum);
      }
      float val = tile[j][kk];
      val -= partialSum;
      val /= pivot;
      tile[j][kk] = val;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  for (uint i = tid; i < tileSize; i += tpg) {
    uint r = i / actSize;
    uint c = i % actSize;
    A[batch_offset + (row0 + r) * N + (col0 + c)] = tile[r][c];
  }
}
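
factorDiagonalBlock above is an unblocked, in-place Cholesky of one NB x NB diagonal tile staged in threadgroup memory: each step reduces the squared row prefix to form the pivot, then scales the remaining rows of that column by it. For reference, a NumPy sketch of the same recurrence (the helper name is illustrative, not part of the shader):

import numpy as np

def factor_diagonal_tile(tile):
    # Unblocked lower Cholesky matching the kk loop above; modifies the tile in place.
    n = tile.shape[0]
    for kk in range(n):
        diag = tile[kk, kk] - np.dot(tile[kk, :kk], tile[kk, :kk])
        if diag <= 0.0:
            raise ValueError("matrix is not positive definite")  # mirrors success[bid] = 0
        tile[kk, kk] = np.sqrt(diag)
        for j in range(kk + 1, n):
            tile[j, kk] = (tile[j, kk] - np.dot(tile[j, :kk], tile[kk, :kk])) / tile[kk, kk]
    return tile
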
kernel void applyTRSM(
    device float* A [[buffer(0)]],
    constant uint& N [[buffer(2)]],
    constant uint& NB [[buffer(3)]],
    constant uint& k [[buffer(4)]],
    uint3 tid [[thread_position_in_threadgroup]],
    uint3 tgid [[threadgroup_position_in_grid]],
    uint3 tpg [[threads_per_threadgroup]]) {
  uint b = tgid.x;
  uint idxJ = tgid.y;
  const uint actSize_k = uint(min(int64_t(N - k * NB), int64_t(NB)));
  const uint batch_offset = b * N * N;
  const uint j = (k + 1) + idxJ;
  uint row0 = j * NB;
  uint col0 = k * NB;
  uint actSize_j = (uint)min((int)(N - row0), (int)NB);

  if (actSize_k == 0 || actSize_j == 0) {
    return;
  }
  if (j == k) {
    return;
  }

  threadgroup float diag[32 * 32];
  threadgroup float target[32 * 32];

  for (uint i = tid.x; i < actSize_k * actSize_k; i += tpg.x) {
    uint r = i / actSize_k;
    uint c = i % actSize_k;
    diag[i] = A[batch_offset + (k * NB + r) * N + (k * NB + c)];
  }
  for (uint i = tid.x; i < actSize_j * actSize_k; i += tpg.x) {
    uint r = i / actSize_k;
    uint c = i % actSize_k;
    target[i] = A[batch_offset + (row0 + r) * N + (col0 + c)];
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  for (uint col = 0; col < actSize_k; col++) {
    float diag_val = diag[col * actSize_k + col];
    if (abs(diag_val) < 1e-6f) {
      diag_val = (diag_val < 0.0f) ? -1e-6f : 1e-6f;
    }
    for (uint row = tid.x; row < actSize_j; row += tpg.x) {
      float sum = target[row * actSize_k + col];
      // kahan sum
      float c = 0.0f;
      for (uint p = 0; p < col; p++) {
        float y = -target[row * actSize_k + p] * diag[col * actSize_k + p] - c;
        float t = sum + y;
        c = (t - sum) - y;
        sum = t;
      }
      target[row * actSize_k + col] = sum / diag_val;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  for (uint i = tid.x; i < actSize_j * actSize_k; i += tpg.x) {
    uint r = i / actSize_k;
    uint c = i % actSize_k;
    A[batch_offset + (row0 + r) * N + (col0 + c)] = target[i];
  }
}
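
applyTRSM solves X @ L_kk.T = A_jk for one panel block below the freshly factored diagonal block, i.e. a forward substitution over columns with a Kahan-compensated inner sum. A plain NumPy sketch of the same column sweep (helper name illustrative; no compensation shown):

import numpy as np

def apply_trsm(diag_L, target):
    # Solve X @ diag_L.T = target column by column; diag_L is the factored diagonal block.
    X = target.copy()
    for col in range(diag_L.shape[0]):
        X[:, col] -= X[:, :col] @ diag_L[col, :col]
        X[:, col] /= diag_L[col, col]
    return X
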
kernel void applySYRK(
    device float* A [[buffer(0)]],
    constant uint& N [[buffer(2)]],
    constant uint& NB [[buffer(3)]],
    constant uint& k [[buffer(4)]],
    uint3 tid [[thread_position_in_threadgroup]],
    uint3 tgid [[threadgroup_position_in_grid]],
    uint3 tpg [[threads_per_threadgroup]]) {
  uint b = tgid.x;
  uint pairID = tgid.y;
  uint jRel = (-1 + sqrt(1 + 8 * float(pairID))) / 2;
  uint hRel = pairID - (jRel * (jRel + 1) >> 1);

  const uint startJ = (k + 1);
  uint j = startJ + jRel;
  uint h = startJ + hRel;
  uint row0 = j * NB;
  uint col0 = h * NB;

  const uint actSize_k = uint(min(int64_t(N - k * NB), int64_t(NB)));
  const uint actSize_j = min((uint)(N - row0), NB);
  const uint actSize_h = min((uint)(N - col0), NB);
  const uint batch_offset = b * N * N;

  if (actSize_j == 0 || actSize_h == 0 || actSize_k == 0)
    return;

  threadgroup float left[32 * 33];
  threadgroup float right_t[32 * 33];
  threadgroup float tile[32 * 33];
  const uint threads = min(tpg.x, actSize_j * actSize_k);

  for (uint i = tid.x; i < actSize_j * actSize_k; i += threads) {
    uint r = i / actSize_k;
    uint c = i % actSize_k;
    left[r * actSize_k + c] = A[batch_offset + (j * NB + r) * N + (k * NB + c)];
  }
  for (uint i = tid.x; i < actSize_h * actSize_k; i += threads) {
    uint r = i / actSize_k;
    uint c = i % actSize_k;
    right_t[c * actSize_h + r] =
        A[batch_offset + (h * NB + r) * N + (k * NB + c)];
  }
  for (uint i = tid.x; i < actSize_j * actSize_h; i += threads) {
    uint r = i / actSize_h;
    uint c = i % actSize_h;
    tile[r * actSize_h + c] = A[batch_offset + (row0 + r) * N + (col0 + c)];
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  for (uint idx = tid.x; idx < actSize_j * actSize_h; idx += threads) {
    uint r = idx / actSize_h;
    uint c = idx % actSize_h;
    if ((j == h) && (r < c))
      continue;

    uint tile_idx = r * actSize_h + c;
    float sum = tile[tile_idx];
    uint left_row = r * actSize_k;
    uint right_col = c;

    uint k = 0;
    float4 sum4 = {0.0f, 0.0f, 0.0f, 0.0f};
    for (; k + 4 <= actSize_k; k += 4) {
      float4 left4 = {
          left[left_row + k],
          left[left_row + k + 1],
          left[left_row + k + 2],
          left[left_row + k + 3]};
      float4 right4 = {
          right_t[(k + 0) * actSize_h + right_col],
          right_t[(k + 1) * actSize_h + right_col],
          right_t[(k + 2) * actSize_h + right_col],
          right_t[(k + 3) * actSize_h + right_col]};
      sum4 = fma(left4, right4, sum4);
    }
    sum -= dot(sum4, 1.0);
    for (; k < actSize_k; k++) {
      sum = fma(-left[left_row + k], right_t[k * actSize_h + right_col], sum);
    }
    tile[tile_idx] = sum;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  for (uint i = tid.x; i < actSize_j * actSize_h; i += threads) {
    uint r = i / actSize_h;
    uint c = i % actSize_h;
    A[batch_offset + (row0 + r) * N + (col0 + c)] = tile[r * actSize_h + c];
  }
}
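
applySYRK decodes the flat pairID into lower-triangular block coordinates (jRel, hRel) with jRel >= hRel, then applies the trailing update A_jh -= L_jk @ L_hk^T, vectorized four columns at a time. A quick Python check of the index decoding and the per-block update (helper names illustrative):

import math
import numpy as np

def decode_pair(pair_id):
    # Invert pair_id = jRel * (jRel + 1) / 2 + hRel for 0 <= hRel <= jRel.
    j_rel = int((-1 + math.sqrt(1 + 8 * pair_id)) // 2)
    h_rel = pair_id - j_rel * (j_rel + 1) // 2
    return j_rel, h_rel

assert [decode_pair(p) for p in range(6)] == [(0, 0), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2)]

def apply_syrk(A_jh, L_jk, L_hk):
    # Rank-NB trailing update of one block; only the lower triangle matters when j == h.
    return A_jh - L_jk @ L_hk.T
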
#define INSTANTIATE_NAIVE_MM(DTYPE) \
template [[host_name("naive_matmul_" #DTYPE)]] kernel void \
naive_matmul<DTYPE>( \


@@ -18,6 +18,8 @@
#include <ATen/ops/addr_native.h>
#include <ATen/ops/baddbmm_native.h>
#include <ATen/ops/bmm_native.h>
#include <ATen/ops/cholesky_native.h>
#include <ATen/ops/linalg_cholesky_native.h>
#include <ATen/ops/linalg_lu_factor_native.h>
#include <ATen/ops/linalg_solve_triangular_native.h>
#include <ATen/ops/mm_native.h>
@@ -780,6 +782,83 @@ static Tensor& linalg_solve_triangular_mps_impl(const Tensor& A,
  return out;
}

static Tensor& linalg_cholesky_mps_impl(const Tensor& input, bool upper, Tensor& out) {
  using namespace mps;

  TORCH_CHECK(out.is_mps());
  TORCH_CHECK(input.scalar_type() == at::ScalarType::Float, "linalg.cholesky: Input tensor must be float32");
  TORCH_CHECK(input.dim() >= 2, "linalg.cholesky: Input tensor must be at least 2D");
  TORCH_CHECK(input.size(-2) == input.size(-1), "linalg.cholesky: Input tensor must be square");

  if (input.numel() == 0 || out.numel() == 0) {
    out.zero_();
    return out;
  }
  resize_output(out, input.sizes());
  out.copy_(input);

  int64_t ndim = out.dim();
  int64_t N = out.size(-1);
  int64_t B = 1;
  for (int64_t i = 0; i < ndim - 2; i++) {
    B *= out.size(i);
  }

  auto stream = getCurrentMPSStream();
  auto device = MPSDevice::getInstance()->device();
  auto factorDiagonalPSO = lib.getPipelineStateForFunc("factorDiagonalBlock");
  auto applyTRSMPSO = lib.getPipelineStateForFunc("applyTRSM");
  auto applySYRKPSO = lib.getPipelineStateForFunc("applySYRK");

  int64_t NB = std::min<int64_t>(32, N);
  int64_t numBlocks = (N + NB - 1) / NB;

  Tensor success = at::empty({B}, input.options().dtype(kInt)).fill_(1);
  id<MTLBuffer> successBuffer = getMTLBufferStorage(success);

  MTLSize threadGroupSize = MTLSizeMake(256, 1, 1);
  id<MTLBuffer> outBuffer = getMTLBufferStorage(out);
  id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder();
  [computeEncoder setBuffer:outBuffer offset:0 atIndex:0];
  [computeEncoder setBytes:&N length:sizeof(int64_t) atIndex:2];
  [computeEncoder setBytes:&NB length:sizeof(int64_t) atIndex:3];

  @autoreleasepool {
    dispatch_sync_with_rethrow(stream->queue(), ^() {
      for (int64_t k = 0; k < numBlocks; k++) {
        [computeEncoder setComputePipelineState:factorDiagonalPSO];
        [computeEncoder setBuffer:successBuffer offset:0 atIndex:1];
        [computeEncoder setBytes:&k length:sizeof(int64_t) atIndex:4];
        MTLSize gridSize = MTLSizeMake(B, 1, 1);
        [computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize];

        // process all remaining blocks in this row/column in parallel
        if (k < numBlocks - 1) {
          int64_t startJ = k + 1;
          int64_t nBlocksJ = (numBlocks - startJ);
          if (nBlocksJ > 0) {
            // TRSM for all blocks in parallel
            MTLSize trsmGridSize = MTLSizeMake(B, nBlocksJ, 1);
            [computeEncoder setComputePipelineState:applyTRSMPSO];
            [computeEncoder dispatchThreadgroups:trsmGridSize threadsPerThreadgroup:threadGroupSize];

            // SYRK for all independent block pairs in parallel
            uint32_t nPairs = nBlocksJ * (nBlocksJ + 1) / 2;
            MTLSize syrkGridSize = MTLSizeMake(B, nPairs, 1);
            [computeEncoder setComputePipelineState:applySYRKPSO];
            [computeEncoder dispatchThreadgroups:syrkGridSize threadsPerThreadgroup:threadGroupSize];
          }
        }
      }
    });
  }
  TORCH_CHECK(success.all().item<bool>(), "linalg.cholesky: Input matrix is not positive definite");
  out.tril_();
  return upper ? out.transpose_(ndim - 2, ndim - 1) : out;
}
} // namespace mps
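
Taken together, the dispatch loop in linalg_cholesky_mps_impl runs a right-looking blocked Cholesky: for each block column k it factors the diagonal block, launches one TRSM per panel block below it, and one SYRK per (j, h) block pair of the trailing submatrix (nPairs = nBlocksJ * (nBlocksJ + 1) / 2, matching the SYRK grid above). A compact NumPy sketch of that schedule, for reference only (helper name illustrative):

import numpy as np

def blocked_cholesky(A, NB=32):
    # Right-looking blocked lower Cholesky mirroring the factor/TRSM/SYRK schedule above.
    A = A.copy()
    N = A.shape[0]
    num_blocks = (N + NB - 1) // NB
    blk = lambda i: slice(i * NB, (i + 1) * NB)
    for k in range(num_blocks):
        A[blk(k), blk(k)] = np.linalg.cholesky(A[blk(k), blk(k)])          # factorDiagonalBlock
        L_kk = A[blk(k), blk(k)]
        for j in range(k + 1, num_blocks):                                 # applyTRSM
            A[blk(j), blk(k)] = np.linalg.solve(L_kk, A[blk(j), blk(k)].T).T
        for j in range(k + 1, num_blocks):                                 # applySYRK over (j, h) pairs
            for h in range(k + 1, j + 1):
                A[blk(j), blk(h)] -= A[blk(j), blk(k)] @ A[blk(h), blk(k)].T
    return np.tril(A)

For a well-conditioned SPD matrix M, blocked_cholesky(M) @ blocked_cholesky(M).T reconstructs M up to rounding, which is the invariant the tril_() / transpose_() epilogue above exposes to the caller.
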
Tensor addr_mps(const Tensor& self, const Tensor& vec1, const Tensor& vec2, const Scalar& beta, const Scalar& alpha) {
@@ -940,6 +1019,25 @@ Tensor& addbmm_out_mps(const Tensor& self,
  return result;
}

Tensor cholesky_mps(const Tensor& self, bool upper) {
  auto out = at::empty_like(self, MemoryFormat::Contiguous);
  mps::linalg_cholesky_mps_impl(self, upper, out);
  return out;
}

Tensor& cholesky_mps_out(const Tensor& self, bool upper, Tensor& out) {
  return mps::linalg_cholesky_mps_impl(self, upper, out);
}

Tensor& linalg_cholesky_out_mps(const Tensor& self, bool upper, Tensor& out) {
  return mps::linalg_cholesky_mps_impl(self, upper, out);
}

Tensor linalg_cholesky_mps(const Tensor& self, bool upper) {
  auto out = at::empty_like(self, MemoryFormat::Contiguous);
  return mps::linalg_cholesky_mps_impl(self, upper, out);
}
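
The four wrappers above back both the legacy torch.cholesky entry points and the torch.linalg.cholesky ones, including their out= variants. A hedged sketch of the user-facing calls that route to them on an MPS tensor (assumes an MPS-enabled build; torch.cholesky is deprecated upstream but still dispatches here):

import torch

spd = 3.0 * torch.eye(8, device="mps")
out = torch.empty_like(spd)

L1 = torch.cholesky(spd)                    # cholesky_mps
torch.cholesky(spd, upper=True, out=out)    # cholesky_mps_out
L2 = torch.linalg.cholesky(spd)             # linalg_cholesky_mps
torch.linalg.cholesky(spd, out=out)         # linalg_cholesky_out_mps
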
Tensor addbmm_mps(const Tensor& self,
                  const Tensor& batch1,
                  const Tensor& batch2,


@@ -9439,11 +9439,13 @@
- func: cholesky.out(Tensor self, bool upper=False, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU, CUDA: cholesky_out
    MPS: cholesky_mps_out

- func: cholesky(Tensor self, bool upper=False) -> Tensor
  variants: method, function
  dispatch:
    CPU, CUDA: cholesky
    MPS: cholesky_mps

- func: cholesky_solve.out(Tensor self, Tensor input2, bool upper=False, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:

@@ -13900,9 +13902,15 @@
- func: linalg_cholesky(Tensor self, *, bool upper=False) -> Tensor
  python_module: linalg
  dispatch:
    CompositeImplicitAutograd: linalg_cholesky
    MPS: linalg_cholesky_mps

- func: linalg_cholesky.out(Tensor self, *, bool upper=False, Tensor(a!) out) -> Tensor(a!)
  python_module: linalg
  dispatch:
    CompositeImplicitAutograd: linalg_cholesky_out
    MPS: linalg_cholesky_out_mps

- func: linalg_cross(Tensor self, Tensor other, *, int dim=-1) -> Tensor
  python_module: linalg


@@ -673,7 +673,6 @@ def mps_ops_modifier(ops):
        '__rsub__': None,
        'cauchy_': None,
        'cauchy': None,
        'cholesky_inverse': None,
        'cholesky_solve': None,
        'cummax': None,

@@ -693,7 +692,6 @@ def mps_ops_modifier(ops):
        'index_reduceamin': None,
        'kthvalue': None,
        'lcm': None,
        'linalg.cholesky_ex': None,
        'linalg.cond': None,
        'linalg.det': None,
@@ -6388,6 +6386,30 @@ class TestMPS(TestCaseMPS):
                         atol=0, rtol=0
                         )

    def test_cholesky(self):
        from torch.testing._internal.common_utils import random_hermitian_pd_matrix

        def run_cholesky_test(size, *batch_dims, upper):
            input_cpu = random_hermitian_pd_matrix(size, *batch_dims, dtype=torch.float32, device="cpu")
            input_mps = input_cpu.to('mps')
            output_cpu = torch.linalg.cholesky(input_cpu, upper=upper)
            output_mps = torch.linalg.cholesky(input_mps, upper=upper)
            self.assertEqual(output_cpu, output_mps, atol=2e-5, rtol=1e-6)

        # test with different even/odd matrix sizes
        matrix_sizes = [1, 2, 3, 4, 8, 17, 64, 128, 154]
        # even/odd batch sizes
        batch_sizes = [1, 2, 4, 8, 16, 17]
        for upper in [True, False]:
            for size in matrix_sizes:
                for batch_size in batch_sizes:
                    run_cholesky_test(size, batch_size, upper=upper)

        # test >3D matrices
        run_cholesky_test(128, 10, 10, upper=False)
        run_cholesky_test(128, 2, 2, 2, 2, 10, 10, upper=True)

    def test_upsample_nearest2d(self):
        def helper(N, C, H, W, memory_format):
            inputCPU = torch.arange(N * C * H * W, device='cpu', dtype=torch.float,


@@ -410,6 +410,11 @@
  self: cholesky_backward(grad, upper, L)
  L: cholesky_jvp(self_t, L, upper)

# temporarily here before linalg_cholesky dispatches to linalg_cholesky_ex on MPS device
- name: linalg_cholesky(Tensor self, *, bool upper=False) -> Tensor
  self: cholesky_backward(grad, upper, result)
  result: cholesky_jvp(self_t, result, upper)

- name: cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor
  self, input2: cholesky_solve_backward(grad, self, input2, result, upper, grad_input_mask)
  result: cholesky_solve_jvp(result, input2_p, input2_t, self_t, upper)
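
Because the derivative for linalg_cholesky is registered directly here (rather than via linalg_cholesky_ex, per the comment above), autograd and forward-mode AD also cover the new MPS path. A small sketch, assuming an MPS-enabled build:

import torch

a = torch.randn(16, 16, device="mps")
spd = (a @ a.mT + 16 * torch.eye(16, device="mps")).requires_grad_()

L = torch.linalg.cholesky(spd)
L.sum().backward()        # uses cholesky_backward(grad, upper, result) as registered above
print(spd.grad.shape)     # torch.Size([16, 16])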