[MPS] Support large tensors in torch.cat (#164416)

Fixes #164415
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164416
Approved by: https://github.com/malfet
Kurt Mohler
2025-10-09 11:32:13 -05:00
committed by PyTorch MergeBot
parent 684df93975
commit 83cbba8759
6 changed files with 271 additions and 3 deletions

View File

@@ -99,6 +99,9 @@ Tensor getTensorView(const Tensor& t, MPSShape* shape);
MPSShape* getMPSShape(const TensorBase& t, c10::MemoryFormat memory_format = MemoryFormat::Contiguous);
MPSShape* getMPSShape(IntArrayRef sizes, c10::MemoryFormat memory_format = MemoryFormat::Contiguous);
// Determines whether a tensor is too large to use MPSGraph
bool isTooLargeForMPSGraph(const Tensor& tensor, bool useMPSStridedAPI = true);
static inline id<MTLBuffer> getMTLBufferStorage(const TensorBase& tensor) {
return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
}

View File

@@ -439,6 +439,22 @@ static void check_mps_shape(MPSShape* shape) {
}
}
bool isTooLargeForMPSGraph(const Tensor& tensor, bool useMPSStridedAPI) {
static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
if ((!tensor.is_contiguous() || tensor.storage_offset()) && useMPSStridedAPI && is_macOS_15_0_or_newer) {
auto storage_numel = tensor.storage().nbytes() / tensor.element_size() - tensor.storage_offset();
if (storage_numel > std::numeric_limits<int32_t>::max()) {
return true;
}
}
for (auto size : tensor.sizes()) {
if (size > std::numeric_limits<int32_t>::max()) {
return true;
}
}
return false;
}
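For reference, a rough Python sketch of the same check (illustrative only, not part of the PR); it assumes macOS 15+ so the strided-API branch applies:

import torch

INT32_MAX = 2**31 - 1

def is_too_large_for_mps_graph(t: torch.Tensor, use_strided_api: bool = True) -> bool:
    # A strided view over a huge storage cannot be addressed with MPSGraph's 32-bit indexing.
    if (not t.is_contiguous() or t.storage_offset()) and use_strided_api:
        storage_numel = t.untyped_storage().nbytes() // t.element_size() - t.storage_offset()
        if storage_numel > INT32_MAX:
            return True
    # Any single dimension larger than INT32_MAX also rules out MPSGraph.
    return any(size > INT32_MAX for size in t.shape)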
MPSNDArray* getMPSNDArray(const TensorBase& t, MPSShape* sizes, MPSShape* strides) {
id<MTLBuffer> srcBuf = getMTLBufferStorage(t);

View File

@@ -0,0 +1,18 @@
#pragma once
#include <c10/metal/common.h>
template <unsigned N = c10::metal::max_ndim, typename idx_type_t = int64_t>
struct CatLargeSharedParams {
int32_t ndim;
int32_t cat_dim;
::c10::metal::array<idx_type_t, N> output_strides;
::c10::metal::array<idx_type_t, N> output_sizes;
};
template <unsigned N = c10::metal::max_ndim, typename idx_type_t = int64_t>
struct CatLargeInputParams {
idx_type_t cat_dim_offset;
idx_type_t input_element_offset;
::c10::metal::array<idx_type_t, N> input_strides;
::c10::metal::array<idx_type_t, N> input_sizes;
};
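As a hypothetical illustration (values chosen here, not taken from the PR), concatenating contiguous tensors of shapes (1, 2**31 + 11, 1) and (1, 100, 1) along dim 1 would fill these structs roughly as follows:

shared_params = dict(
    ndim=3,
    cat_dim=1,
    output_sizes=[1, 2**31 + 111, 1],
    output_strides=[2**31 + 111, 1, 1],  # contiguous output
)
# Params for the second (smaller) input; its slice starts where the first
# input ends along the cat dimension.
second_input_params = dict(
    cat_dim_offset=2**31 + 11,
    input_element_offset=0,
    input_sizes=[1, 100, 1],
    input_strides=[100, 1, 1],
)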

View File

@@ -0,0 +1,82 @@
#include <ATen/native/mps/kernels/Shape.h>
#include <c10/metal/utils.h>
#include <metal_array>
#include <metal_stdlib>
using namespace metal;
using namespace c10::metal;
template <typename T_in, typename T_out>
kernel void cat_large(
constant T_in* input [[buffer(0)]],
device T_out* output [[buffer(1)]],
constant CatLargeSharedParams<>& shared_params [[buffer(2)]],
constant CatLargeInputParams<>& input_params [[buffer(3)]],
uint tid [[thread_position_in_grid]]) {
auto ndim = shared_params.ndim;
auto cat_dim = shared_params.cat_dim;
constant auto& output_strides = shared_params.output_strides;
constant auto& output_sizes = shared_params.output_sizes;
auto cat_dim_offset = input_params.cat_dim_offset;
auto input_element_offset = input_params.input_element_offset;
constant auto& input_strides = input_params.input_strides;
constant auto& input_sizes = input_params.input_sizes;
auto input_element_idx = static_cast<int64_t>(tid) + input_element_offset;
int64_t input_offset = 0;
int64_t output_offset = 0;
for (auto dim = ndim - 1; dim >= 0; dim--) {
auto dim_size = input_sizes[dim];
auto input_dim_idx = input_element_idx % dim_size;
auto output_dim_idx =
input_dim_idx + ((dim == cat_dim) ? cat_dim_offset : 0);
input_offset += input_strides[dim] * input_dim_idx;
output_offset += output_strides[dim] * output_dim_idx;
input_element_idx = input_element_idx / dim_size;
}
output[output_offset] = static_cast<T_out>(input[input_offset]);
}
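Restated in plain Python for clarity (a sketch mirroring the kernel above, not code from the PR): each thread decomposes its flat input-element index into per-dimension indices, innermost dimension first, and shifts the index along cat_dim by cat_dim_offset when accumulating the output offset.

def cat_large_index(tid, ndim, cat_dim, cat_dim_offset, input_element_offset,
                    input_sizes, input_strides, output_strides):
    # Returns (input_offset, output_offset) for one thread.
    input_element_idx = tid + input_element_offset
    input_offset = 0
    output_offset = 0
    for dim in range(ndim - 1, -1, -1):
        dim_size = input_sizes[dim]
        input_dim_idx = input_element_idx % dim_size
        output_dim_idx = input_dim_idx + (cat_dim_offset if dim == cat_dim else 0)
        input_offset += input_strides[dim] * input_dim_idx
        output_offset += output_strides[dim] * output_dim_idx
        input_element_idx //= dim_size
    return input_offset, output_offset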
#define REGISTER_CAT_LARGE_OP(T_in, T_out) \
template [[host_name("cat_large_" #T_in "_" #T_out)]] \
kernel void cat_large<T_in, T_out>( \
constant T_in * input [[buffer(0)]], \
device T_out * output [[buffer(1)]], \
constant CatLargeSharedParams<> & shared_params [[buffer(2)]], \
constant CatLargeInputParams<> & input_params [[buffer(3)]], \
uint tid [[thread_position_in_grid]]);
#define REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(T_out) \
REGISTER_CAT_LARGE_OP(float, T_out); \
REGISTER_CAT_LARGE_OP(half, T_out); \
REGISTER_CAT_LARGE_OP(bfloat, T_out); \
REGISTER_CAT_LARGE_OP(int, T_out); \
REGISTER_CAT_LARGE_OP(uint, T_out); \
REGISTER_CAT_LARGE_OP(long, T_out); \
REGISTER_CAT_LARGE_OP(ulong, T_out); \
REGISTER_CAT_LARGE_OP(short, T_out); \
REGISTER_CAT_LARGE_OP(ushort, T_out); \
REGISTER_CAT_LARGE_OP(char, T_out); \
REGISTER_CAT_LARGE_OP(uchar, T_out); \
REGISTER_CAT_LARGE_OP(bool, T_out);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(float);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(half);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(bfloat);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(int);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(uint);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(long);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(ulong);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(short);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(ushort);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(char);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(uchar);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(bool);
REGISTER_CAT_LARGE_OP(float2, float2);
REGISTER_CAT_LARGE_OP(half2, half2);

View File

@@ -2,9 +2,13 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/MemoryOverlap.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/TensorShape.h>
#include <ATen/native/TypeProperties.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/mps/kernels/Shape.h>
#include <fmt/format.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@@ -16,6 +20,13 @@
#endif
namespace at::native {
#ifndef PYTORCH_JIT_COMPILE_SHADERS
static auto& lib = mps::MetalShaderLibrary::getBundledLibrary();
#else
#include <ATen/native/mps/Shape_metallib.h>
#endif
namespace mps {
// Produces a shape with the `dim` dimension set to 0.
@@ -57,6 +68,70 @@ static void check_shape_except_dim(const Tensor& first, const Tensor& second, in
")");
}
}
// This implementation of cat is used only if one of the inputs or the output is
// too large to use MPSGraph.
// NOTE: `output` is expected to already have the correct size.
static void cat_out_large_tensor_mps(const ITensorListRef& inputs, int64_t dimension, const Tensor& output) {
CatLargeSharedParams shared_params;
shared_params.ndim = output.dim();
shared_params.cat_dim = dimension;
for (const auto dim : c10::irange(output.dim())) {
shared_params.output_strides[dim] = output.stride(dim);
shared_params.output_sizes[dim] = output.size(dim);
}
int64_t cat_dim_offset = 0;
size_t input_idx = 0;
MPSStream* stream = getCurrentMPSStream();
// Launch a separate kernel for each input. This will produce some overhead,
// but that should be relatively minimal since at least one of the inputs is
// very large. In order to launch only one kernel to process all inputs, we
// would have to copy all the input tensor data into a packed buffer, which
// would not be ideal.
for (const Tensor& input : inputs) {
if (input.numel() == 0) {
continue;
}
// Metal can only launch up to MAX_INT threads at one time. If the input has
// more than that number of elements, launch multiple kernels with different
// offsets into the data.
const int64_t max_num_threads = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
for (int64_t numel_remaining = input.numel(); numel_remaining > 0; numel_remaining -= max_num_threads) {
auto num_threads = std::min(max_num_threads, numel_remaining);
CatLargeInputParams input_params;
input_params.cat_dim_offset = cat_dim_offset;
input_params.input_element_offset = input.numel() - numel_remaining;
for (const auto dim : c10::irange(input.dim())) {
input_params.input_strides[dim] = input.stride(dim);
input_params.input_sizes[dim] = input.size(dim);
}
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder();
auto pipeline_state = lib.getPipelineStateForFunc(
fmt::format("cat_large_{}_{}", scalarToMetalTypeString(input), scalarToMetalTypeString(output)));
getMPSProfiler().beginProfileKernel(pipeline_state, "cat", {input});
[computeEncoder setComputePipelineState:pipeline_state];
mtl_setArgs(computeEncoder, input, output, shared_params, input_params);
mtl_dispatch1DJob(computeEncoder, pipeline_state, num_threads);
getMPSProfiler().endProfileKernel(pipeline_state);
}
});
}
cat_dim_offset += input.size(dimension);
input_idx++;
}
}
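To make the per-input chunking concrete (a hypothetical walk-through, not part of the PR): an input with 5,000,000,000 elements is covered by three dispatches, since each launch handles at most 2**31 - 1 threads.

INT32_MAX = 2**31 - 1

def dispatch_plan(numel):
    # Yields (input_element_offset, num_threads) per kernel launch,
    # mirroring the inner loop above.
    remaining = numel
    while remaining > 0:
        yield numel - remaining, min(INT32_MAX, remaining)
        remaining -= INT32_MAX

# For a 5e9-element input this yields three launches:
#   (0, 2147483647), (2147483647, 2147483647), (4294967294, 705032706)
print(list(dispatch_plan(5_000_000_000)))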
} // namespace mps
// topk
@@ -231,7 +306,11 @@ TORCH_IMPL_FUNC(cat_out_mps)
// Compute size of the result in the cat dimension
int64_t cat_dim_size = 0;
idx = 0;
bool has_large_tensor = false;
for (const Tensor& tensor : materialized_inputs) {
if (isTooLargeForMPSGraph(tensor)) {
has_large_tensor |= true;
}
if (!should_skip(tensor)) {
// TODO: Factor out `check_shape_except_dim`
check_shape_except_dim(notSkippedTensor, tensor, dimension, idx);
@@ -249,6 +328,12 @@ TORCH_IMPL_FUNC(cat_out_mps)
return;
}
has_large_tensor |= isTooLargeForMPSGraph(out);
if (has_large_tensor) {
return mps::cat_out_large_tensor_mps(materialized_inputs, dimension, out);
}
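From Python the fallback is transparent; a minimal usage sketch (illustrative only, and it needs roughly 4 GiB of unified memory for int8 data):

import torch

# One dimension exceeds INT32_MAX, so the output cannot be represented by
# MPSGraph and torch.cat takes the large-tensor kernel path instead.
a = torch.empty((1, 2**31 + 11, 1), dtype=torch.int8, device="mps")
b = torch.empty((1, 100, 1), dtype=torch.int8, device="mps")
out = torch.cat([a, b], dim=1)
print(out.shape)  # torch.Size([1, 2147483759, 1])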
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
std::vector<MPSGraphTensor*> inputTensors_;

View File

@@ -80,6 +80,9 @@ if not torch.backends.mps.is_available():
total_memory = int(subprocess.check_output(["sysctl", "-n", "hw.memsize"]))
MPS_UNSUPPORTED_TYPES = [torch.double, torch.cdouble]
MPS_DTYPES = [t for t in get_all_dtypes() if t not in MPS_UNSUPPORTED_TYPES]
# Determine whether to enable MPS memory leak check (uses same code as CUDA).
TEST_MPS_MEM_LEAK_CHECK = os.getenv('PYTORCH_TEST_MPS_MEM_LEAK_CHECK', '0') == '1'
@@ -3637,6 +3640,70 @@ class TestMPS(TestCaseMPS):
# TODO: enable memory format test
# self.assertEqual(cpu_result.is_contiguous(), mps_result.is_contiguous())
# Skip if a test needs more memory than the system has.
def _skip_if_exceeds_total_memory(self, required_memory):
if total_memory < required_memory:
self.skipTest(
f"Needs {required_memory / (1024**3):0.01f} GiB RAM, "
f"but only {total_memory / (1024**3):0.01f} GiB is available.")
@parametrize("dtype", MPS_DTYPES)
def test_cat_large_tensor(self, dtype):
a_shape = (1, 11 + (1 << 31), 1)
b_shape = (1, 100, 1)
# Assume up to 1% extra overhead memory might be required.
required_memory = 1.01 * (math.prod(a_shape) + math.prod(b_shape)) * dtype.itemsize
self._skip_if_exceeds_total_memory(required_memory)
a_cpu = make_tensor((1,), dtype=dtype, device='cpu').expand(a_shape)
b_cpu = make_tensor(b_shape, dtype=dtype, device='cpu')
r_cpu = torch.cat([a_cpu, b_cpu], dim=1)
# Pick a subset of output elements to compare, because comparing all of
# them takes too long.
rand_indices = torch.randint(0, a_cpu.shape[1] + b_cpu.shape[1], (10_000,))
r_cpu_part0 = r_cpu[:, rand_indices, :].clone()
r_cpu_part1 = r_cpu[:, -200:, :].clone()
r_cpu_part2 = r_cpu[:, :200, :].clone()
# Delete the CPU result to free up memory for the MPS run.
del r_cpu
a_mps = (
torch.empty(0, dtype=dtype, device='mps')
.set_(a_cpu.untyped_storage().mps())
.as_strided(size=a_cpu.size(), stride=a_cpu.stride())
)
b_mps = b_cpu.to('mps')
try:
r_mps = torch.cat([a_mps, b_mps], dim=1)
except RuntimeError as e:
if "Invalid buffer size" in str(e):
self.skipTest(f"Exceeds max buffer size for MPS: {str(e)}.")
raise e
self.assertEqual(r_mps[:, rand_indices, :], r_cpu_part0)
self.assertEqual(r_mps[:, -200:, :], r_cpu_part1)
self.assertEqual(r_mps[:, :200, :], r_cpu_part2)
def test_large_tensor_to_string(self):
shape = (2, 1 << 31)
# Assume up to 1% extra overhead memory might be required.
required_memory = 1.01 * 2 * math.prod(shape)
self._skip_if_exceeds_total_memory(required_memory)
self.assertEqual(
str(torch.ones(shape, dtype=torch.int8, device='mps')),
(
"tensor([[1, 1, 1, ..., 1, 1, 1],\n"
" [1, 1, 1, ..., 1, 1, 1]], device='mps:0', dtype=torch.int8)"
),
)
# See https://github.com/pytorch/pytorch/issues/152701
def test_jacfwd_cat(self):
def fn(x, y):
@@ -12167,9 +12234,6 @@ class TestNoRegression(TestCase):
self.assertEqual(x2.device.type, "mps")
MPS_UNSUPPORTED_TYPES = [torch.double, torch.cdouble]
MPS_DTYPES = [t for t in get_all_dtypes() if t not in MPS_UNSUPPORTED_TYPES]
MPS_GRAD_DTYPES = [torch.float32, torch.float16]