mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[MPS] lu unpack (#146681)
Implements lu unpack function on MPS. Haven't added new tests because they are covered by removing the lu_unpack from UNIMPLEMENTED_XFAILLIST in test_mps with `test_output_match` function Pull Request resolved: https://github.com/pytorch/pytorch/pull/146681 Approved by: https://github.com/malfet Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
803661526e
commit
0ab67299c3
@ -395,6 +395,58 @@ kernel void applySYRK(
|
||||
}
|
||||
}
|
||||
|
||||
kernel void applyPivots(
|
||||
device float* P [[buffer(0)]],
|
||||
device const int* pivots [[buffer(1)]],
|
||||
constant uint& R [[buffer(2)]],
|
||||
constant uint& K [[buffer(3)]],
|
||||
uint3 tid [[thread_position_in_threadgroup]],
|
||||
uint3 bid [[threadgroup_position_in_grid]],
|
||||
uint3 tpg [[threads_per_threadgroup]]) {
|
||||
uint tx = tid.x;
|
||||
uint group_size = tpg.x * tpg.y;
|
||||
uint batch_idx = bid.x;
|
||||
|
||||
for (int i = static_cast<int>(K) - 1; i >= 0; i--) {
|
||||
int pivot = pivots[batch_idx * K + i];
|
||||
if (pivot == i) {
|
||||
// no swap needed
|
||||
continue;
|
||||
}
|
||||
|
||||
for (uint j = tx * 4; j < R; j += group_size * 4) {
|
||||
uint elementsRemaining = R - j;
|
||||
|
||||
// if we can use float4 or not
|
||||
if (elementsRemaining < 4) {
|
||||
for (uint e = 0; e < elementsRemaining; e++) {
|
||||
float row_i_value = P[batch_idx * R * R + i * R + (j + e)];
|
||||
float pivot_row_value = P[batch_idx * R * R + pivot * R + (j + e)];
|
||||
|
||||
P[batch_idx * R * R + i * R + (j + e)] = pivot_row_value;
|
||||
P[batch_idx * R * R + pivot * R + (j + e)] = row_i_value;
|
||||
}
|
||||
} else {
|
||||
// vectorized load/stores
|
||||
device float4* rowIPtr =
|
||||
reinterpret_cast<device float4*>(&P[batch_idx * R * R + i * R + j]);
|
||||
device float4* pivotPtr = reinterpret_cast<device float4*>(
|
||||
&P[batch_idx * R * R + pivot * R + j]);
|
||||
|
||||
float4 row_i_val = *rowIPtr;
|
||||
float4 pivot_val = *pivotPtr;
|
||||
|
||||
*rowIPtr = pivot_val;
|
||||
*pivotPtr = row_i_val;
|
||||
}
|
||||
}
|
||||
// barrier here so different threads do not rush after each other
|
||||
// swapping rows for the next iteration while
|
||||
// some threads are swapping the current one
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
}
|
||||
|
||||
#define INSTANTIATE_NAIVE_MM(DTYPE) \
|
||||
template [[host_name("naive_matmul_" #DTYPE)]] kernel void \
|
||||
naive_matmul<DTYPE>( \
|
||||
|
@ -24,7 +24,9 @@
|
||||
#include <ATen/ops/linalg_lu_factor_ex_native.h>
|
||||
#include <ATen/ops/linalg_lu_factor_native.h>
|
||||
#include <ATen/ops/linalg_solve_triangular_native.h>
|
||||
#include <ATen/ops/lu_unpack_native.h>
|
||||
#include <ATen/ops/mm_native.h>
|
||||
#include <ATen/ops/slice.h>
|
||||
#include <ATen/ops/stack.h>
|
||||
#include <ATen/ops/triangular_solve_native.h>
|
||||
#endif
|
||||
@ -1001,6 +1003,54 @@ static Tensor& linalg_solve_triangular_mps_impl(const Tensor& A,
|
||||
return out;
|
||||
}
|
||||
|
||||
static void lu_unpack_mps_impl(const Tensor& LU_data,
|
||||
const Tensor& LU_pivots,
|
||||
bool unpack_data,
|
||||
bool unpack_pivots,
|
||||
const Tensor& P,
|
||||
const Tensor& L,
|
||||
const Tensor& U) {
|
||||
const auto ndim = LU_data.dim();
|
||||
TORCH_CHECK(ndim >= 2, "LU_data must have at least 2 dimensions");
|
||||
|
||||
const auto r = LU_data.size(-2);
|
||||
const auto c = LU_data.size(-1);
|
||||
const auto k = std::min<int64_t>(r, c);
|
||||
|
||||
const auto batchSize = c10::multiply_integers(LU_data.sizes().begin(), LU_data.sizes().end() - 2);
|
||||
|
||||
if (unpack_data) {
|
||||
Tensor L_part = r < c ? slice(LU_data, -1, 0, k) : LU_data;
|
||||
L.copy_(L_part.tril());
|
||||
(ndim == 2 ? L.diagonal() : L.diagonal(0, -2, -1)).fill_(1);
|
||||
|
||||
Tensor U_part = r < c ? LU_data : slice(LU_data, -2, 0, k);
|
||||
U.copy_(U_part.triu());
|
||||
}
|
||||
|
||||
if (unpack_pivots) {
|
||||
// P as an identity matrix for pivots
|
||||
P.fill_(0);
|
||||
LU_pivots.dim() == 1 ? P.diagonal().fill_(1) : P.diagonal(0, -2, -1).fill_(1);
|
||||
|
||||
auto stream = getCurrentMPSStream();
|
||||
auto device = MPSDevice::getInstance()->device();
|
||||
auto applyPivotsPSO = lib.getPipelineStateForFunc("applyPivots");
|
||||
uint32_t maxThreadsPerGroup = [applyPivotsPSO maxTotalThreadsPerThreadgroup];
|
||||
|
||||
auto pivots = (LU_pivots.dim() == 1) ? LU_pivots.sub(1) : LU_pivots.view({batchSize, -1}).sub(1);
|
||||
|
||||
@autoreleasepool {
|
||||
dispatch_sync_with_rethrow(stream->queue(), ^() {
|
||||
auto computeEncoder = stream->commandEncoder();
|
||||
mtl_setArgs(computeEncoder, P, pivots, r, k);
|
||||
[computeEncoder setComputePipelineState:applyPivotsPSO];
|
||||
mtl_dispatch1DJob(computeEncoder, applyPivotsPSO, batchSize * maxThreadsPerGroup);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static Tensor& linalg_cholesky_mps_impl(const Tensor& input, bool upper, Tensor& out) {
|
||||
using namespace mps;
|
||||
|
||||
@ -1321,6 +1371,17 @@ std::tuple<Tensor, Tensor> linalg_lu_factor_mps(const Tensor& A, bool pivot) {
|
||||
return std::make_tuple(std::move(LU), std::move(pivots));
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(lu_unpack_out_mps)
|
||||
(const Tensor& LU_data,
|
||||
const Tensor& LU_pivots,
|
||||
bool unpack_data,
|
||||
bool unpack_pivots,
|
||||
const Tensor& P,
|
||||
const Tensor& L,
|
||||
const Tensor& U) {
|
||||
mps::lu_unpack_mps_impl(LU_data, LU_pivots, unpack_data, unpack_pivots, P, L, U);
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(linalg_lu_factor_ex_out_mps)
|
||||
(const Tensor& A, bool pivot, bool check_errors, const Tensor& LU, const Tensor& pivots, const Tensor& info) {
|
||||
mps::linalg_lu_factor_ex_out_mps_impl(A, pivot, LU, pivots, info, check_errors);
|
||||
|
@ -9520,6 +9520,7 @@
|
||||
structured: True
|
||||
dispatch:
|
||||
CPU, CUDA: lu_unpack_out
|
||||
MPS: lu_unpack_out_mps
|
||||
|
||||
# TODO: remove dispatch section when porting TH CUDA to ATen
|
||||
- func: multinomial.out(Tensor self, int num_samples, bool replacement=False, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
|
||||
|
@ -88,9 +88,6 @@ def mps_ops_grad_modifier(ops):
|
||||
'cdist': [torch.float32],
|
||||
'masked.scatter': [torch.float16, torch.float32],
|
||||
'index_fill': [torch.float16, torch.float32], # missing `aten::_unique`.
|
||||
'lu': [torch.float16, torch.float32], # missing `aten::lu_unpack`.
|
||||
'linalg.lu_factor': [torch.float16, torch.float32], # missing `aten::lu_unpack`.
|
||||
'linalg.lu_factor_ex': [torch.float16, torch.float32], # missing `aten::lu_unpack`.
|
||||
'linalg.solve': [torch.float16, torch.float32], # missing `aten::lu_solve`.
|
||||
'linalg.solve_ex': [torch.float16, torch.float32], # missing `aten::lu_solve`.
|
||||
'linalg.tensorsolve': [torch.float16, torch.float32], # missing `aten::lu_solve`.
|
||||
@ -724,7 +721,6 @@ def mps_ops_modifier(ops):
|
||||
'logcumsumexp': None,
|
||||
'logdet': None,
|
||||
'lu_solve': None,
|
||||
'lu_unpack': None,
|
||||
'masked.median': None,
|
||||
'matrix_exp': None,
|
||||
'mode': None,
|
||||
|
Reference in New Issue
Block a user