[MPS] cholesky implementation (#145701)
Requested in #77764. Supersedes #144193, which was closed after accumulating too many rebase conflicts.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145701
Approved by: https://github.com/malfet
committed by PyTorch MergeBot
parent c6ad08357b
commit b75afa2e2e
@@ -1,4 +1,5 @@
#include <metal_array>
#include <metal_stdlib>

using namespace metal;

template <typename T>
@@ -31,6 +32,271 @@ kernel void naive_matmul(
  outputData[x * strides[2].x + y * strides[2].y] = rc;
}

inline float blockReduceSum(
    threadgroup float* sharedScratch,
    float val,
    uint tid,
    uint tpg) {
  sharedScratch[tid] = val;
  threadgroup_barrier(mem_flags::mem_threadgroup);

  for (uint offset = tpg >> 1; offset > 0; offset >>= 1) {
    if (tid < offset) {
      sharedScratch[tid] += sharedScratch[tid + offset];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  return sharedScratch[0];
}

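blockReduceSum is a standard shared-memory tree reduction: each round halves the number of active threads until thread 0 holds the total. A minimal NumPy sketch of the same halving schedule (exact only when the thread count is a power of two, which the 256-thread launch on the host side guarantees; block_reduce_sum is a hypothetical name for illustration):

import numpy as np

def block_reduce_sum(vals):
    # Mirror of the threadgroup reduction: halve the active range each round.
    scratch = np.array(vals, dtype=np.float32)
    offset = len(scratch) // 2
    while offset > 0:
        scratch[:offset] += scratch[offset:2 * offset]
        offset //= 2
    return scratch[0]

assert block_reduce_sum([1.0] * 256) == 256.0
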
kernel void factorDiagonalBlock(
    device float* A [[buffer(0)]],
    device int* success [[buffer(1)]],
    constant uint& N [[buffer(2)]],
    constant uint& NB [[buffer(3)]],
    constant uint& k [[buffer(4)]],
    uint tid [[thread_position_in_threadgroup]],
    uint bid [[threadgroup_position_in_grid]],
    uint tpg [[threads_per_threadgroup]]) {
  const uint actSize = min(N - k * NB, NB); // size of this diagonal block; the last block may be smaller than NB
  const uint batch_offset = bid * N * N;

  const uint row0 = k * NB;
  const uint col0 = k * NB;

  threadgroup float tile[32][33];
  threadgroup float reduceScratch[256];
  const uint tileSize = actSize * actSize;

  for (uint i = tid; i < tileSize; i += tpg) {
    uint r = i / actSize;
    uint c = i % actSize;
    tile[r][c] = A[batch_offset + (row0 + r) * N + (col0 + c)];
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  for (uint kk = 0; kk < actSize; kk++) {
    float diagElt = 0.0f;
    if (kk > 0) {
      float partialSum = 0.0f;
      for (uint i = tid; i < kk; i += tpg) {
        float val = tile[kk][i];
        partialSum = fma(val, val, partialSum);
      }
      diagElt = blockReduceSum(reduceScratch, partialSum, tid, tpg);
    }

    if (tid == 0) {
      float diagVal = tile[kk][kk] - diagElt;
      // Check for positive definiteness
      if (diagVal <= 0.0f) {
        success[bid] = 0; // matrix is not positive definite
        return;
      }
      tile[kk][kk] = sqrt(diagVal);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    float pivot = tile[kk][kk];

    for (uint j = kk + 1 + tid; j < actSize; j += tpg) {
      float partialSum = 0.0f;
      for (uint i = 0; i < kk; i++) {
        partialSum = fma(tile[j][i], tile[kk][i], partialSum);
      }

      float val = tile[j][kk];
      val -= partialSum;
      val /= pivot;
      tile[j][kk] = val;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  for (uint i = tid; i < tileSize; i += tpg) {
    uint r = i / actSize;
    uint c = i % actSize;
    A[batch_offset + (row0 + r) * N + (col0 + c)] = tile[r][c];
  }
}

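The per-tile loop above is the classic unblocked Cholesky recurrence: subtract the squared prefix of row kk from the diagonal entry, take the square root, then scale the column below the pivot. A NumPy sketch of the same recurrence (factor_tile is a hypothetical helper for illustration, not part of the PR):

import numpy as np

def factor_tile(tile):
    # Unblocked lower Cholesky of one diagonal tile, as factorDiagonalBlock does in-place.
    n = tile.shape[0]
    L = tile.copy()
    for kk in range(n):
        diag = L[kk, kk] - np.dot(L[kk, :kk], L[kk, :kk])
        if diag <= 0.0:
            raise ValueError("matrix is not positive definite")
        L[kk, kk] = np.sqrt(diag)
        for j in range(kk + 1, n):
            L[j, kk] = (L[j, kk] - np.dot(L[j, :kk], L[kk, :kk])) / L[kk, kk]
    return np.tril(L)
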
kernel void applyTRSM(
    device float* A [[buffer(0)]],
    constant uint& N [[buffer(2)]],
    constant uint& NB [[buffer(3)]],
    constant uint& k [[buffer(4)]],
    uint3 tid [[thread_position_in_threadgroup]],
    uint3 tgid [[threadgroup_position_in_grid]],
    uint3 tpg [[threads_per_threadgroup]]) {
  uint b = tgid.x;
  uint idxJ = tgid.y;

  const uint actSize_k = uint(min(int64_t(N - k * NB), int64_t(NB)));
  const uint batch_offset = b * N * N;
  const uint j = (k + 1) + idxJ;

  uint row0 = j * NB;
  uint col0 = k * NB;

  uint actSize_j = (uint)min((int)(N - row0), (int)NB);
  if (actSize_k == 0 || actSize_j == 0) {
    return;
  }
  if (j == k) {
    return;
  }

  threadgroup float diag[32 * 32];
  threadgroup float target[32 * 32];

  for (uint i = tid.x; i < actSize_k * actSize_k; i += tpg.x) {
    uint r = i / actSize_k;
    uint c = i % actSize_k;
    diag[i] = A[batch_offset + (k * NB + r) * N + (k * NB + c)];
  }
  for (uint i = tid.x; i < actSize_j * actSize_k; i += tpg.x) {
    uint r = i / actSize_k;
    uint c = i % actSize_k;
    target[i] = A[batch_offset + (row0 + r) * N + (col0 + c)];
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  for (uint col = 0; col < actSize_k; col++) {
    float diag_val = diag[col * actSize_k + col];
    if (abs(diag_val) < 1e-6f) {
      diag_val = (diag_val < 0.0f) ? -1e-6f : 1e-6f;
    }

    for (uint row = tid.x; row < actSize_j; row += tpg.x) {
      float sum = target[row * actSize_k + col];

      // Kahan-compensated summation
      float c = 0.0f;
      for (uint p = 0; p < col; p++) {
        float y = -target[row * actSize_k + p] * diag[col * actSize_k + p] - c;
        float t = sum + y;
        c = (t - sum) - y;
        sum = t;
      }

      target[row * actSize_k + col] = sum / diag_val;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  for (uint i = tid.x; i < actSize_j * actSize_k; i += tpg.x) {
    uint r = i / actSize_k;
    uint c = i % actSize_k;
    A[batch_offset + (row0 + r) * N + (col0 + c)] = target[i];
  }
}

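applyTRSM solves each panel block against the freshly factored diagonal block: it computes X from X · L11ᵀ = A21 by forward substitution, one column at a time (with Kahan-compensated accumulation in the kernel). The equivalent NumPy recurrence, as a sketch (trsm_panel is a hypothetical name):

import numpy as np

def trsm_panel(L11, A21):
    # Solve X @ L11.T = A21 for X, column by column, as applyTRSM does per block.
    X = A21.astype(np.float32).copy()
    for col in range(L11.shape[0]):
        X[:, col] -= X[:, :col] @ L11[col, :col]  # subtract already-solved columns
        X[:, col] /= L11[col, col]
    return X

For a well-conditioned lower-triangular L11, np.allclose(trsm_panel(L11, A21) @ L11.T, A21) holds.
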
kernel void applySYRK(
    device float* A [[buffer(0)]],
    constant uint& N [[buffer(2)]],
    constant uint& NB [[buffer(3)]],
    constant uint& k [[buffer(4)]],
    uint3 tid [[thread_position_in_threadgroup]],
    uint3 tgid [[threadgroup_position_in_grid]],
    uint3 tpg [[threads_per_threadgroup]]) {
  uint b = tgid.x;
  uint pairID = tgid.y;

  uint jRel = (-1 + sqrt(1 + 8 * float(pairID))) / 2;
  uint hRel = pairID - (jRel * (jRel + 1) >> 1);

  const uint startJ = (k + 1);
  uint j = startJ + jRel;
  uint h = startJ + hRel;
  uint row0 = j * NB;
  uint col0 = h * NB;

  const uint actSize_k = uint(min(int64_t(N - k * NB), int64_t(NB)));
  const uint actSize_j = min((uint)(N - row0), NB);
  const uint actSize_h = min((uint)(N - col0), NB);
  const uint batch_offset = b * N * N;

  if (actSize_j == 0 || actSize_h == 0 || actSize_k == 0)
    return;

  threadgroup float left[32 * 33];
  threadgroup float right_t[32 * 33];
  threadgroup float tile[32 * 33];

  const uint threads = min(tpg.x, actSize_j * actSize_k);

  for (uint i = tid.x; i < actSize_j * actSize_k; i += threads) {
    uint r = i / actSize_k;
    uint c = i % actSize_k;
    left[r * actSize_k + c] = A[batch_offset + (j * NB + r) * N + (k * NB + c)];
  }

  for (uint i = tid.x; i < actSize_h * actSize_k; i += threads) {
    uint r = i / actSize_k;
    uint c = i % actSize_k;
    right_t[c * actSize_h + r] =
        A[batch_offset + (h * NB + r) * N + (k * NB + c)];
  }

  for (uint i = tid.x; i < actSize_j * actSize_h; i += threads) {
    uint r = i / actSize_h;
    uint c = i % actSize_h;
    tile[r * actSize_h + c] = A[batch_offset + (row0 + r) * N + (col0 + c)];
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  for (uint idx = tid.x; idx < actSize_j * actSize_h; idx += threads) {
    uint r = idx / actSize_h;
    uint c = idx % actSize_h;

    if ((j == h) && (r < c))
      continue;

    uint tile_idx = r * actSize_h + c;
    float sum = tile[tile_idx];

    uint left_row = r * actSize_k;
    uint right_col = c;

    uint k = 0;
    float4 sum4 = {0.0f, 0.0f, 0.0f, 0.0f};

    for (; k + 4 <= actSize_k; k += 4) {
      float4 left4 = {
          left[left_row + k],
          left[left_row + k + 1],
          left[left_row + k + 2],
          left[left_row + k + 3]};

      float4 right4 = {
          right_t[(k + 0) * actSize_h + right_col],
          right_t[(k + 1) * actSize_h + right_col],
          right_t[(k + 2) * actSize_h + right_col],
          right_t[(k + 3) * actSize_h + right_col]};

      sum4 = fma(left4, right4, sum4);
    }

    sum -= dot(sum4, 1.0);

    for (; k < actSize_k; k++) {
      sum = fma(-left[left_row + k], right_t[k * actSize_h + right_col], sum);
    }

    tile[tile_idx] = sum;
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  for (uint i = tid.x; i < actSize_j * actSize_h; i += threads) {
    uint r = i / actSize_h;
    uint c = i % actSize_h;
    A[batch_offset + (row0 + r) * N + (col0 + c)] = tile[r * actSize_h + c];
  }
}

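applySYRK flattens the lower-triangular set of (j, h) block pairs into a single grid dimension; the quadratic formula at the top of the kernel inverts pairID = jRel·(jRel+1)/2 + hRel. A quick Python check of that decoding (using integer square root to avoid the float rounding the kernel tolerates):

import math

def decode_pair(pair_id):
    # Invert pair_id = j * (j + 1) // 2 + h, with 0 <= h <= j.
    j = (math.isqrt(8 * pair_id + 1) - 1) // 2
    h = pair_id - j * (j + 1) // 2
    return j, h

assert [decode_pair(p) for p in range(6)] == [(0, 0), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2)]
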
#define INSTANTIATE_NAIVE_MM(DTYPE)                          \
  template [[host_name("naive_matmul_" #DTYPE)]] kernel void \
  naive_matmul<DTYPE>(                                       \
@@ -18,6 +18,8 @@
#include <ATen/ops/addr_native.h>
#include <ATen/ops/baddbmm_native.h>
#include <ATen/ops/bmm_native.h>
#include <ATen/ops/cholesky_native.h>
#include <ATen/ops/linalg_cholesky_native.h>
#include <ATen/ops/linalg_lu_factor_native.h>
#include <ATen/ops/linalg_solve_triangular_native.h>
#include <ATen/ops/mm_native.h>
@@ -780,6 +782,83 @@ static Tensor& linalg_solve_triangular_mps_impl(const Tensor& A,
  return out;
}

static Tensor& linalg_cholesky_mps_impl(const Tensor& input, bool upper, Tensor& out) {
  using namespace mps;

  TORCH_CHECK(out.is_mps());
  TORCH_CHECK(input.scalar_type() == at::ScalarType::Float, "linalg.cholesky: Input tensor must be float32");
  TORCH_CHECK(input.dim() >= 2, "linalg.cholesky: Input tensor must be at least 2D");
  TORCH_CHECK(input.size(-2) == input.size(-1), "linalg.cholesky: Input tensor must be square");

  if (input.numel() == 0 || out.numel() == 0) {
    out.zero_();
    return out;
  }
  resize_output(out, input.sizes());
  out.copy_(input);

  int64_t ndim = out.dim();
  int64_t N = out.size(-1);
  int64_t B = 1;
  for (int64_t i = 0; i < ndim - 2; i++) {
    B *= out.size(i);
  }

  auto stream = getCurrentMPSStream();
  auto device = MPSDevice::getInstance()->device();

  auto factorDiagonalPSO = lib.getPipelineStateForFunc("factorDiagonalBlock");
  auto applyTRSMPSO = lib.getPipelineStateForFunc("applyTRSM");
  auto applySYRKPSO = lib.getPipelineStateForFunc("applySYRK");

  int64_t NB = std::min<int64_t>(32, N);
  int64_t numBlocks = (N + NB - 1) / NB;

  Tensor success = at::empty({B}, input.options().dtype(kInt)).fill_(1);
  id<MTLBuffer> successBuffer = getMTLBufferStorage(success);

  MTLSize threadGroupSize = MTLSizeMake(256, 1, 1);
  id<MTLBuffer> outBuffer = getMTLBufferStorage(out);
  id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder();
  [computeEncoder setBuffer:outBuffer offset:0 atIndex:0];
  [computeEncoder setBytes:&N length:sizeof(int64_t) atIndex:2];
  [computeEncoder setBytes:&NB length:sizeof(int64_t) atIndex:3];

  @autoreleasepool {
    dispatch_sync_with_rethrow(stream->queue(), ^() {
      for (int64_t k = 0; k < numBlocks; k++) {
        [computeEncoder setComputePipelineState:factorDiagonalPSO];
        [computeEncoder setBuffer:successBuffer offset:0 atIndex:1];
        [computeEncoder setBytes:&k length:sizeof(int64_t) atIndex:4];
        MTLSize gridSize = MTLSizeMake(B, 1, 1);
        [computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize];

        // process all remaining blocks in this row/column in parallel
        if (k < numBlocks - 1) {
          int64_t startJ = k + 1;
          int64_t nBlocksJ = (numBlocks - startJ);

          if (nBlocksJ > 0) {
            // TRSM for all blocks in parallel
            MTLSize trsmGridSize = MTLSizeMake(B, nBlocksJ, 1);
            [computeEncoder setComputePipelineState:applyTRSMPSO];
            [computeEncoder dispatchThreadgroups:trsmGridSize threadsPerThreadgroup:threadGroupSize];

            // SYRK for all independent block pairs in parallel
            uint32_t nPairs = nBlocksJ * (nBlocksJ + 1) / 2;
            MTLSize syrkGridSize = MTLSizeMake(B, nPairs, 1);
            [computeEncoder setComputePipelineState:applySYRKPSO];
            [computeEncoder dispatchThreadgroups:syrkGridSize threadsPerThreadgroup:threadGroupSize];
          }
        }
      }
    });
  }

  TORCH_CHECK(success.all().item<bool>(), "linalg.cholesky: Input matrix is not positive definite");
  out.tril_(); // zero out the strictly upper triangle
  return upper ? out.transpose_(ndim - 2, ndim - 1) : out;
}
} // namespace mps

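The host loop dispatches the three kernels once per block column, which is the classic right-looking blocked Cholesky schedule. A NumPy sketch of that schedule for a single (unbatched) matrix, for reference (blocked_cholesky is a hypothetical helper):

import numpy as np

def blocked_cholesky(A, NB=32):
    A = np.array(A, dtype=np.float32)
    n = A.shape[0]
    for k0 in range(0, n, NB):
        k1 = min(k0 + NB, n)
        A[k0:k1, k0:k1] = np.linalg.cholesky(A[k0:k1, k0:k1])  # factorDiagonalBlock
        L11 = A[k0:k1, k0:k1]
        # applyTRSM: panel below the diagonal, A21 <- A21 @ inv(L11).T
        A[k1:, k0:k1] = np.linalg.solve(L11, A[k1:, k0:k1].T).T
        # applySYRK: trailing update; only the lower triangle is read later
        A[k1:, k1:] -= A[k1:, k0:k1] @ A[k1:, k0:k1].T
    return np.tril(A)
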
Tensor addr_mps(const Tensor& self, const Tensor& vec1, const Tensor& vec2, const Scalar& beta, const Scalar& alpha) {
@@ -940,6 +1019,25 @@ Tensor& addbmm_out_mps(const Tensor& self,
  return result;
}

Tensor cholesky_mps(const Tensor& self, bool upper) {
  auto out = at::empty_like(self, MemoryFormat::Contiguous);
  mps::linalg_cholesky_mps_impl(self, upper, out);
  return out;
}

Tensor& cholesky_mps_out(const Tensor& self, bool upper, Tensor& out) {
  return mps::linalg_cholesky_mps_impl(self, upper, out);
}

Tensor& linalg_cholesky_out_mps(const Tensor& self, bool upper, Tensor& out) {
  return mps::linalg_cholesky_mps_impl(self, upper, out);
}

Tensor linalg_cholesky_mps(const Tensor& self, bool upper) {
  auto out = at::empty_like(self, MemoryFormat::Contiguous);
  return mps::linalg_cholesky_mps_impl(self, upper, out);
}

Tensor addbmm_mps(const Tensor& self,
                  const Tensor& batch1,
                  const Tensor& batch2,
@@ -9439,11 +9439,13 @@
- func: cholesky.out(Tensor self, bool upper=False, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU, CUDA: cholesky_out
    MPS: cholesky_mps_out

- func: cholesky(Tensor self, bool upper=False) -> Tensor
  variants: method, function
  dispatch:
    CPU, CUDA: cholesky
    MPS: cholesky_mps

- func: cholesky_solve.out(Tensor self, Tensor input2, bool upper=False, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
@@ -13900,9 +13902,15 @@

- func: linalg_cholesky(Tensor self, *, bool upper=False) -> Tensor
  python_module: linalg
  dispatch:
    CompositeImplicitAutograd: linalg_cholesky
    MPS: linalg_cholesky_mps

- func: linalg_cholesky.out(Tensor self, *, bool upper=False, Tensor(a!) out) -> Tensor(a!)
  python_module: linalg
  dispatch:
    CompositeImplicitAutograd: linalg_cholesky_out
    MPS: linalg_cholesky_out_mps

- func: linalg_cross(Tensor self, Tensor other, *, int dim=-1) -> Tensor
  python_module: linalg
@@ -673,7 +673,6 @@ def mps_ops_modifier(ops):
        '__rsub__': None,
        'cauchy_': None,
        'cauchy': None,
        'cholesky': None,
        'cholesky_inverse': None,
        'cholesky_solve': None,
        'cummax': None,
@@ -693,7 +692,6 @@ def mps_ops_modifier(ops):
        'index_reduceamin': None,
        'kthvalue': None,
        'lcm': None,
        'linalg.cholesky': None,
        'linalg.cholesky_ex': None,
        'linalg.cond': None,
        'linalg.det': None,
@@ -6388,6 +6386,30 @@ class TestMPS(TestCaseMPS):
            atol=0, rtol=0
        )

    def test_cholesky(self):
        from torch.testing._internal.common_utils import random_hermitian_pd_matrix

        def run_cholesky_test(size, *batch_dims, upper):
            input_cpu = random_hermitian_pd_matrix(size, *batch_dims, dtype=torch.float32, device="cpu")
            input_mps = input_cpu.to('mps')
            output_cpu = torch.linalg.cholesky(input_cpu, upper=upper)
            output_mps = torch.linalg.cholesky(input_mps, upper=upper)
            self.assertEqual(output_cpu, output_mps, atol=2e-5, rtol=1e-6)

        # test with different even/odd matrix sizes
        matrix_sizes = [1, 2, 3, 4, 8, 17, 64, 128, 154]
        # even/odd batch sizes
        batch_sizes = [1, 2, 4, 8, 16, 17]

        for upper in [True, False]:
            for size in matrix_sizes:
                for batch_size in batch_sizes:
                    run_cholesky_test(size, batch_size, upper=upper)

        # test >3D matrices
        run_cholesky_test(128, 10, 10, upper=False)
        run_cholesky_test(128, 2, 2, 2, 2, 10, 10, upper=True)

    def test_upsample_nearest2d(self):
        def helper(N, C, H, W, memory_format):
            inputCPU = torch.arange(N * C * H * W, device='cpu', dtype=torch.float,
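The new test drives torch.linalg.cholesky on MPS against the CPU reference. A minimal standalone repro of the same check, assuming a machine with MPS support (the tolerances here are illustrative, not the test suite's):

import torch

A = torch.randn(4, 64, 64)
A = A @ A.mT + 64 * torch.eye(64)       # symmetric positive definite batch
L = torch.linalg.cholesky(A.to("mps"))  # runs the new Metal kernels
torch.testing.assert_close(L.cpu() @ L.cpu().mT, A, atol=1e-3, rtol=1e-4)
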
@@ -410,6 +410,11 @@
  self: cholesky_backward(grad, upper, L)
  L: cholesky_jvp(self_t, L, upper)

# temporarily here before linalg_cholesky dispatches to linalg_cholesky_ex on MPS device
- name: linalg_cholesky(Tensor self, *, bool upper=False) -> Tensor
  self: cholesky_backward(grad, upper, result)
  result: cholesky_jvp(self_t, result, upper)

- name: cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor
  self, input2: cholesky_solve_backward(grad, self, input2, result, upper, grad_input_mask)
  result: cholesky_solve_jvp(result, input2_p, input2_t, self_t, upper)
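Because the derivatives.yaml entry above routes linalg_cholesky through cholesky_backward and cholesky_jvp, autograd should flow through the MPS path without further changes. A small sketch, again assuming MPS is available:

import torch

x = torch.randn(8, 8)
A = (x @ x.T + 8 * torch.eye(8)).to("mps").requires_grad_()
L = torch.linalg.cholesky(A)
L.sum().backward()
print(A.grad.shape)  # torch.Size([8, 8])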