From 0ab67299c35481b62ac2c1f2ada80f12eb4093f5 Mon Sep 17 00:00:00 2001
From: Isalia20
Date: Sat, 8 Feb 2025 00:16:17 +0000
Subject: [PATCH] [MPS] lu unpack (#146681)

Implements the lu_unpack function on MPS. No new tests are added: removing
lu_unpack from the UNIMPLEMENTED_XFAILLIST in test_mps means the op is now
covered by the existing `test_output_match` tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146681
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
---
 .../native/mps/kernels/LinearAlgebra.metal | 52 ++++++++++++++++
 .../native/mps/operations/LinearAlgebra.mm | 61 +++++++++++++++++++
 aten/src/ATen/native/native_functions.yaml |  1 +
 test/test_mps.py                           |  4 --
 4 files changed, 114 insertions(+), 4 deletions(-)

diff --git a/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal b/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal
index 89c5ab6cdb42..ffa739d3b455 100644
--- a/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal
+++ b/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal
@@ -395,6 +395,58 @@ kernel void applySYRK(
   }
 }
 
+kernel void applyPivots(
+    device float* P [[buffer(0)]],
+    device const int* pivots [[buffer(1)]],
+    constant uint& R [[buffer(2)]],
+    constant uint& K [[buffer(3)]],
+    uint3 tid [[thread_position_in_threadgroup]],
+    uint3 bid [[threadgroup_position_in_grid]],
+    uint3 tpg [[threads_per_threadgroup]]) {
+  uint tx = tid.x;
+  uint group_size = tpg.x * tpg.y;
+  uint batch_idx = bid.x;
+
+  for (int i = static_cast<int>(K) - 1; i >= 0; i--) {
+    int pivot = pivots[batch_idx * K + i];
+    if (pivot == i) {
+      // no swap needed
+      continue;
+    }
+
+    for (uint j = tx * 4; j < R; j += group_size * 4) {
+      uint elementsRemaining = R - j;
+
+      // decide whether a vectorized float4 swap is possible or a scalar tail is needed
+      if (elementsRemaining < 4) {
+        for (uint e = 0; e < elementsRemaining; e++) {
+          float row_i_value = P[batch_idx * R * R + i * R + (j + e)];
+          float pivot_row_value = P[batch_idx * R * R + pivot * R + (j + e)];
+
+          P[batch_idx * R * R + i * R + (j + e)] = pivot_row_value;
+          P[batch_idx * R * R + pivot * R + (j + e)] = row_i_value;
+        }
+      } else {
+        // vectorized load/stores
+        device float4* rowIPtr =
+            reinterpret_cast<device float4*>(&P[batch_idx * R * R + i * R + j]);
+        device float4* pivotPtr = reinterpret_cast<device float4*>(
+            &P[batch_idx * R * R + pivot * R + j]);
+
+        float4 row_i_val = *rowIPtr;
+        float4 pivot_val = *pivotPtr;
+
+        *rowIPtr = pivot_val;
+        *pivotPtr = row_i_val;
+      }
+    }
+    // barrier so that no thread starts swapping rows for the next iteration
+    // while other threads in the group are still swapping rows
+    // for the current one
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+  }
+}
+
 #define INSTANTIATE_NAIVE_MM(DTYPE)                          \
   template [[host_name("naive_matmul_" #DTYPE)]] kernel void \
   naive_matmul<DTYPE>(                                       \
diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
index 02ae024049be..82d3ca25671d 100644
--- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
+++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
@@ -24,7 +24,9 @@
 #include 
 #include 
 #include 
+#include 
 #include 
+#include 
 #include 
 #include 
 #endif
@@ -1001,6 +1003,54 @@ static Tensor& linalg_solve_triangular_mps_impl(const Tensor& A,
   return out;
 }
 
+static void lu_unpack_mps_impl(const Tensor& LU_data,
+                               const Tensor& LU_pivots,
+                               bool unpack_data,
+                               bool unpack_pivots,
+                               const Tensor& P,
+                               const Tensor& L,
+                               const Tensor& U) {
+  const auto ndim = LU_data.dim();
+  TORCH_CHECK(ndim >= 2, "LU_data must have at least 2 dimensions");
+
+  const auto r = LU_data.size(-2);
+  const auto c = LU_data.size(-1);
+  const auto k = std::min(r, c);
+
+  const auto batchSize = c10::multiply_integers(LU_data.sizes().begin(), LU_data.sizes().end() - 2);
+
+  if (unpack_data) {
+    Tensor L_part = r < c ? slice(LU_data, -1, 0, k) : LU_data;
+    L.copy_(L_part.tril());
+    (ndim == 2 ? L.diagonal() : L.diagonal(0, -2, -1)).fill_(1);
+
+    Tensor U_part = r < c ? LU_data : slice(LU_data, -2, 0, k);
+    U.copy_(U_part.triu());
+  }
+
+  if (unpack_pivots) {
+    // initialize P as an identity matrix; the pivot row swaps are applied on the GPU below
+    P.fill_(0);
+    LU_pivots.dim() == 1 ? P.diagonal().fill_(1) : P.diagonal(0, -2, -1).fill_(1);
+
+    auto stream = getCurrentMPSStream();
+    auto device = MPSDevice::getInstance()->device();
+    auto applyPivotsPSO = lib.getPipelineStateForFunc("applyPivots");
+    uint32_t maxThreadsPerGroup = [applyPivotsPSO maxTotalThreadsPerThreadgroup];
+
+    auto pivots = (LU_pivots.dim() == 1) ? LU_pivots.sub(1) : LU_pivots.view({batchSize, -1}).sub(1);
+
+    @autoreleasepool {
+      dispatch_sync_with_rethrow(stream->queue(), ^() {
+        auto computeEncoder = stream->commandEncoder();
+        mtl_setArgs(computeEncoder, P, pivots, r, k);
+        [computeEncoder setComputePipelineState:applyPivotsPSO];
+        mtl_dispatch1DJob(computeEncoder, applyPivotsPSO, batchSize * maxThreadsPerGroup);
+      });
+    }
+  }
+}
+
 static Tensor& linalg_cholesky_mps_impl(const Tensor& input, bool upper, Tensor& out) {
   using namespace mps;
 
@@ -1321,6 +1371,17 @@ std::tuple<Tensor, Tensor> linalg_lu_factor_mps(const Tensor& A, bool pivot) {
   return std::make_tuple(std::move(LU), std::move(pivots));
 }
 
+TORCH_IMPL_FUNC(lu_unpack_out_mps)
+(const Tensor& LU_data,
+ const Tensor& LU_pivots,
+ bool unpack_data,
+ bool unpack_pivots,
+ const Tensor& P,
+ const Tensor& L,
+ const Tensor& U) {
+  mps::lu_unpack_mps_impl(LU_data, LU_pivots, unpack_data, unpack_pivots, P, L, U);
+}
+
 TORCH_IMPL_FUNC(linalg_lu_factor_ex_out_mps)
 (const Tensor& A, bool pivot, bool check_errors, const Tensor& LU, const Tensor& pivots, const Tensor& info) {
   mps::linalg_lu_factor_ex_out_mps_impl(A, pivot, LU, pivots, info, check_errors);
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 4e2b97c066e6..9fe24ec7e9a6 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -9520,6 +9520,7 @@
   structured: True
   dispatch:
     CPU, CUDA: lu_unpack_out
+    MPS: lu_unpack_out_mps
 
 # TODO: remove dispatch section when porting TH CUDA to ATen
 - func: multinomial.out(Tensor self, int num_samples, bool replacement=False, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
diff --git a/test/test_mps.py b/test/test_mps.py
index 3d739741b103..ffe209e49254 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -88,9 +88,6 @@ def mps_ops_grad_modifier(ops):
         'cdist': [torch.float32],
         'masked.scatter': [torch.float16, torch.float32],
         'index_fill': [torch.float16, torch.float32],  # missing `aten::_unique`.
-        'lu': [torch.float16, torch.float32],  # missing `aten::lu_unpack`.
-        'linalg.lu_factor': [torch.float16, torch.float32],  # missing `aten::lu_unpack`.
-        'linalg.lu_factor_ex': [torch.float16, torch.float32],  # missing `aten::lu_unpack`.
         'linalg.solve': [torch.float16, torch.float32],  # missing `aten::lu_solve`.
         'linalg.solve_ex': [torch.float16, torch.float32],  # missing `aten::lu_solve`.
         'linalg.tensorsolve': [torch.float16, torch.float32],  # missing `aten::lu_solve`.
@@ -724,7 +721,6 @@ def mps_ops_modifier(ops):
         'logcumsumexp': None,
         'logdet': None,
         'lu_solve': None,
-        'lu_unpack': None,
         'masked.median': None,
         'matrix_exp': None,
         'mode': None,
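
For reference, below is a minimal usage sketch (not part of the patch) of how the new MPS path is reached from Python once the `MPS: lu_unpack_out_mps` dispatch entry above is in place; it mirrors the coverage that `test_output_match` picks up after `lu_unpack` is removed from the xfail lists. The shapes, tolerances, and the CPU cross-check are illustrative assumptions, not values taken from the PR.

# Illustrative only; assumes an MPS-enabled PyTorch build on an Apple-silicon machine.
import torch

if torch.backends.mps.is_available():
    A = torch.randn(3, 4, 4, device="mps")   # batch of square matrices
    LU, pivots = torch.linalg.lu_factor(A)   # existing MPS LU factorization
    P, L, U = torch.lu_unpack(LU, pivots)    # now dispatches to lu_unpack_out_mps

    # P @ L @ U should reconstruct A up to float32 rounding error.
    torch.testing.assert_close(P @ L @ U, A, rtol=1e-5, atol=1e-5)

    # Cross-check against the CPU implementation, as test_output_match does.
    P_ref, L_ref, U_ref = torch.lu_unpack(LU.cpu(), pivots.cpu())
    torch.testing.assert_close(P.cpu(), P_ref)
    torch.testing.assert_close(L.cpu(), L_ref)
    torch.testing.assert_close(U.cpu(), U_ref)

The applyPivots kernel builds P by applying the recorded row interchanges to an identity matrix, walking the pivots from last to first, so the unpacked P satisfies A == P @ L @ U.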