[MPS] lu unpack (#146681)

Implements the `lu_unpack` function on MPS. No new tests are added: coverage comes from removing `lu_unpack` from the `UNIMPLEMENTED_XFAILLIST` in test_mps, so the existing `test_output_match` tests now exercise the op.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146681
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Author: Isalia20
Date: 2025-02-08 00:16:17 +00:00
Committed by: PyTorch MergeBot
Parent: 803661526e
Commit: 0ab67299c3

4 changed files with 114 additions and 4 deletions
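
A quick way to exercise the new op once this commit is in place (a hypothetical snippet, not part of the change; it assumes a machine where the MPS backend is available):

import torch

# Factor a batch of matrices on the MPS device, then unpack the result.
A = torch.randn(3, 4, 4, device="mps")
LU, pivots = torch.linalg.lu_factor(A)

# lu_unpack now dispatches to the new MPS kernel.
P, L, U = torch.lu_unpack(LU, pivots)

# P @ L @ U reconstructs A up to floating-point error.
print(torch.allclose(P @ L @ U, A, atol=1e-5))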

aten/src/ATen/native/mps/kernels/LinearAlgebra.metal

@@ -395,6 +395,58 @@ kernel void applySYRK(
   }
 }
+
+kernel void applyPivots(
+    device float* P [[buffer(0)]],
+    device const int* pivots [[buffer(1)]],
+    constant uint& R [[buffer(2)]],
+    constant uint& K [[buffer(3)]],
+    uint3 tid [[thread_position_in_threadgroup]],
+    uint3 bid [[threadgroup_position_in_grid]],
+    uint3 tpg [[threads_per_threadgroup]]) {
+  uint tx = tid.x;
+  uint group_size = tpg.x * tpg.y;
+  uint batch_idx = bid.x;
+
+  for (int i = static_cast<int>(K) - 1; i >= 0; i--) {
+    int pivot = pivots[batch_idx * K + i];
+    if (pivot == i) {
+      // no swap needed
+      continue;
+    }
+    for (uint j = tx * 4; j < R; j += group_size * 4) {
+      uint elementsRemaining = R - j;
+      if (elementsRemaining < 4) {
+        // tail of the row: fewer than 4 elements left, swap them one by one
+        for (uint e = 0; e < elementsRemaining; e++) {
+          float row_i_value = P[batch_idx * R * R + i * R + (j + e)];
+          float pivot_row_value = P[batch_idx * R * R + pivot * R + (j + e)];
+          P[batch_idx * R * R + i * R + (j + e)] = pivot_row_value;
+          P[batch_idx * R * R + pivot * R + (j + e)] = row_i_value;
+        }
+      } else {
+        // vectorized loads/stores: swap 4 elements at a time via float4
+        device float4* rowIPtr =
+            reinterpret_cast<device float4*>(&P[batch_idx * R * R + i * R + j]);
+        device float4* pivotPtr =
+            reinterpret_cast<device float4*>(&P[batch_idx * R * R + pivot * R + j]);
+        float4 row_i_val = *rowIPtr;
+        float4 pivot_val = *pivotPtr;
+        *rowIPtr = pivot_val;
+        *pivotPtr = row_i_val;
+      }
+    }
+    // barrier so no thread starts swapping rows for the next pivot
+    // while other threads are still swapping rows for the current one
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+  }
+}
+
 #define INSTANTIATE_NAIVE_MM(DTYPE)                                 \
   template [[host_name("naive_matmul_" #DTYPE)]] kernel void        \
   naive_matmul<DTYPE>(                                              \
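
For readers unfamiliar with Metal, the row-swap semantics the kernel implements can be sketched as a hypothetical Python reference (not part of the commit; it assumes 0-based pivots of shape (batch, K), which is what the host code below passes):

import torch

def apply_pivots_reference(P: torch.Tensor, pivots: torch.Tensor) -> torch.Tensor:
    # P: (batch, R, R) permutation matrices, initialized to the identity.
    # pivots: (batch, K) 0-based row indices, LAPACK-style.
    batch, K = pivots.shape
    for b in range(batch):
        # Walk the pivots in reverse, as the kernel does, swapping row i
        # with row pivots[b, i] whenever they differ.
        for i in reversed(range(K)):
            p = int(pivots[b, i])
            if p != i:
                P[b, [i, p]] = P[b, [p, i]]
    return P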

aten/src/ATen/native/mps/operations/LinearAlgebra.mm

@@ -24,7 +24,9 @@
 #include <ATen/ops/linalg_lu_factor_ex_native.h>
 #include <ATen/ops/linalg_lu_factor_native.h>
 #include <ATen/ops/linalg_solve_triangular_native.h>
+#include <ATen/ops/lu_unpack_native.h>
 #include <ATen/ops/mm_native.h>
+#include <ATen/ops/slice.h>
 #include <ATen/ops/stack.h>
 #include <ATen/ops/triangular_solve_native.h>
 #endif
@@ -1001,6 +1003,54 @@ static Tensor& linalg_solve_triangular_mps_impl(const Tensor& A,
   return out;
 }
+
+static void lu_unpack_mps_impl(const Tensor& LU_data,
+                               const Tensor& LU_pivots,
+                               bool unpack_data,
+                               bool unpack_pivots,
+                               const Tensor& P,
+                               const Tensor& L,
+                               const Tensor& U) {
+  const auto ndim = LU_data.dim();
+  TORCH_CHECK(ndim >= 2, "LU_data must have at least 2 dimensions");
+  const auto r = LU_data.size(-2);
+  const auto c = LU_data.size(-1);
+  const auto k = std::min<int64_t>(r, c);
+  const auto batchSize = c10::multiply_integers(LU_data.sizes().begin(), LU_data.sizes().end() - 2);
+
+  if (unpack_data) {
+    // L is the r x k unit lower-triangular factor (first k columns when r < c)
+    Tensor L_part = r < c ? slice(LU_data, -1, 0, k) : LU_data;
+    L.copy_(L_part.tril());
+    (ndim == 2 ? L.diagonal() : L.diagonal(0, -2, -1)).fill_(1);
+    // U is the k x c upper-triangular factor (first k rows when r >= c)
+    Tensor U_part = r < c ? LU_data : slice(LU_data, -2, 0, k);
+    U.copy_(U_part.triu());
+  }
+
+  if (unpack_pivots) {
+    // initialize P to the identity, then apply the row swaps on the GPU
+    P.fill_(0);
+    LU_pivots.dim() == 1 ? P.diagonal().fill_(1) : P.diagonal(0, -2, -1).fill_(1);
+
+    auto stream = getCurrentMPSStream();
+    auto device = MPSDevice::getInstance()->device();
+    auto applyPivotsPSO = lib.getPipelineStateForFunc("applyPivots");
+    uint32_t maxThreadsPerGroup = [applyPivotsPSO maxTotalThreadsPerThreadgroup];
+    // LAPACK-style pivots are 1-based; the kernel expects 0-based row indices
+    auto pivots = (LU_pivots.dim() == 1) ? LU_pivots.sub(1) : LU_pivots.view({batchSize, -1}).sub(1);
+    @autoreleasepool {
+      dispatch_sync_with_rethrow(stream->queue(), ^() {
+        auto computeEncoder = stream->commandEncoder();
+        mtl_setArgs(computeEncoder, P, pivots, r, k);
+        [computeEncoder setComputePipelineState:applyPivotsPSO];
+        mtl_dispatch1DJob(computeEncoder, applyPivotsPSO, batchSize * maxThreadsPerGroup);
+      });
+    }
+  }
+}
+
 static Tensor& linalg_cholesky_mps_impl(const Tensor& input, bool upper, Tensor& out) {
   using namespace mps;
@@ -1321,6 +1371,17 @@ std::tuple<Tensor, Tensor> linalg_lu_factor_mps(const Tensor& A, bool pivot) {
   return std::make_tuple(std::move(LU), std::move(pivots));
 }
+
+TORCH_IMPL_FUNC(lu_unpack_out_mps)
+(const Tensor& LU_data,
+ const Tensor& LU_pivots,
+ bool unpack_data,
+ bool unpack_pivots,
+ const Tensor& P,
+ const Tensor& L,
+ const Tensor& U) {
+  mps::lu_unpack_mps_impl(LU_data, LU_pivots, unpack_data, unpack_pivots, P, L, U);
+}
+
 TORCH_IMPL_FUNC(linalg_lu_factor_ex_out_mps)
 (const Tensor& A, bool pivot, bool check_errors, const Tensor& LU, const Tensor& pivots, const Tensor& info) {
   mps::linalg_lu_factor_ex_out_mps_impl(A, pivot, LU, pivots, info, check_errors);
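
As a cross-check of the `unpack_data` branch above, the same L/U split can be written in a few lines of Python (a hypothetical reference, not part of the commit):

import torch

def unpack_lu_data_reference(LU: torch.Tensor):
    r, c = LU.shape[-2], LU.shape[-1]
    k = min(r, c)
    # L is (r, k): for a wide matrix, keep only the first k columns.
    L = (LU[..., :, :k] if r < c else LU).tril()
    L.diagonal(0, -2, -1).fill_(1)  # unit diagonal, as in the MPS impl
    # U is (k, c): for a tall matrix, keep only the first k rows.
    U = (LU if r < c else LU[..., :k, :]).triu()
    return L, U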

aten/src/ATen/native/native_functions.yaml

@@ -9520,6 +9520,7 @@
   structured: True
   dispatch:
     CPU, CUDA: lu_unpack_out
+    MPS: lu_unpack_out_mps
 
 # TODO: remove dispatch section when porting TH CUDA to ATen
 - func: multinomial.out(Tensor self, int num_samples, bool replacement=False, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)

test/test_mps.py

@@ -88,9 +88,6 @@ def mps_ops_grad_modifier(ops):
         'cdist': [torch.float32],
         'masked.scatter': [torch.float16, torch.float32],
         'index_fill': [torch.float16, torch.float32],  # missing `aten::_unique`.
-        'lu': [torch.float16, torch.float32],  # missing `aten::lu_unpack`.
-        'linalg.lu_factor': [torch.float16, torch.float32],  # missing `aten::lu_unpack`.
-        'linalg.lu_factor_ex': [torch.float16, torch.float32],  # missing `aten::lu_unpack`.
         'linalg.solve': [torch.float16, torch.float32],  # missing `aten::lu_solve`.
         'linalg.solve_ex': [torch.float16, torch.float32],  # missing `aten::lu_solve`.
         'linalg.tensorsolve': [torch.float16, torch.float32],  # missing `aten::lu_solve`.
@@ -724,7 +721,6 @@ def mps_ops_modifier(ops):
         'logcumsumexp': None,
         'logdet': None,
         'lu_solve': None,
-        'lu_unpack': None,
         'masked.median': None,
         'matrix_exp': None,
         'mode': None,
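
With these xfail entries removed, the existing consistency tests exercise the new op. A hypothetical local invocation (assuming a PyTorch build with MPS enabled):

pytest test/test_mps.py -k lu_unpack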