[MPS] Do not crash if tensor dim > INT_MAX (#158824)

Looks like all MPS operations will crash if one of the tensor dimensions is
greater than `2**31-1`.

Change it into a structured exception by checking the tensor size before
attempting to create the MPS Tensor.

Add a regression test for it. Before this change, running the following would abort with an exception:
```
% python3 -c "import torch; torch.randint(0, 10, (2**31,), dtype=torch.uint8, device='mps')"
/AppleInternal/Library/BuildRoots/1c8f7852-1ca9-11f0-b28b-226177e5bb69/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Types/MPSNDArray.mm:829: failed assertion `[MPSNDArray initWithDevice:descriptor:isTextureBacked:] Error: NDArray dimension length > INT_MAX'
zsh: abort      python3 -c
```

Skip the test on MacOS-13, as it crashes somewhere deep in MPSGraph framework with
```
/AppleInternal/Library/BuildRoots/c651a45f-806e-11ed-a221-7ef33c48bc85/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Types/MPSNDArray.mm:724: failed assertion `[MPSTemporaryNDArray initWithDevice:descriptor:] Error: total bytes of NDArray > 2**32'
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158824
Approved by: https://github.com/dcci
ghstack dependencies: #158690, #158823
This commit is contained in:
Nikita Shulga
2025-07-22 07:36:01 -07:00
committed by PyTorch MergeBot
parent 371ffaf415
commit d0c00d9a69
2 changed files with 28 additions and 3 deletions

View File

@ -427,12 +427,22 @@ static MPSNDArray* permuteNDArray(MPSNDArray* inArray, const std::vector<int64_t
return result;
}
// Validates that every dimension of `shape` fits in a signed 32-bit integer.
// Should be called before initWithBuffer to prevent hard crashes with
// '[MPSNDArray initWithDevice:descriptor:isTextureBacked:] Error: NDArray dimension length > INT_MAX'
// Raises a catchable c10::Error (via TORCH_CHECK) instead of letting MPS abort the process.
static void check_mps_shape(MPSShape* shape) {
  for (NSNumber* elem in shape) {
    // NSNumber holds the dim as a 64-bit value; compare against INT32_MAX explicitly.
    const auto val = [elem longValue];
    // Fixed typo in user-facing message: "MPSGaph" -> "MPSGraph".
    TORCH_CHECK(val <= std::numeric_limits<int32_t>::max(), "MPSGraph does not support tensor dims larger than INT_MAX");
  }
}
MPSNDArray* getMPSNDArray(const TensorBase& t, MPSShape* sizes, MPSShape* strides) {
id<MTLBuffer> srcBuf = getMTLBufferStorage(t);
MPSDataType mpsDataType = getMPSDataType(t.scalar_type());
MPSNDArrayDescriptor* srcTensorDesc = [MPSNDArrayDescriptor descriptorWithDataType:mpsDataType shape:sizes];
srcTensorDesc.preferPackedRows = YES;
check_mps_shape(sizes);
MPSNDArray* srcNDArray = [[[MPSNDArray alloc] initWithBuffer:srcBuf
offset:t.storage_offset() * t.element_size()
descriptor:srcTensorDesc] autorelease];
@ -542,9 +552,9 @@ Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor,
// Tensor is contiguous and has no storage offset.
// Wrap it directly inside MPSGraphTensorData
if ((_tensor.is_contiguous() && !_tensor.storage_offset()) || !useMPSStridedAPI || !is_macOS_15_0_or_newer) {
_value = [[[MPSGraphTensorData alloc] initWithMTLBuffer:srcBuf
shape:mpsShape_ ? mpsShape_ : getMPSShape(_tensor)
dataType:dataType] autorelease];
auto shape = mpsShape_ ? mpsShape_ : getMPSShape(_tensor);
check_mps_shape(shape);
_value = [[[MPSGraphTensorData alloc] initWithMTLBuffer:srcBuf shape:shape dataType:dataType] autorelease];
} else {
IntArrayRef view_shape;
if (mpsShape_) {
@ -553,8 +563,11 @@ Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor,
MPSShape* mpsShape = getMPSShape(_tensor);
MPSShape* mpsStrides = getMPSShape(_tensor.strides());
check_mps_shape(mpsShape);
auto storage_numel = src.storage().nbytes() / src.element_size();
TORCH_CHECK(storage_numel <= std::numeric_limits<int32_t>::max(),
"MPSGaph does not support tensor dims larger than INT_MAX");
MPSNDArrayDescriptor* srcTensorDesc = [MPSNDArrayDescriptor descriptorWithDataType:dataType
shape:@[ @(storage_numel) ]];
srcTensorDesc.preferPackedRows = YES;

View File

@ -8013,6 +8013,18 @@ class TestLargeTensors(TestCaseMPS):
gc.collect()
torch.mps.empty_cache()
@serialTest()
def test_rand_2b_raises(self):
    # MacOS-13 aborts inside the MPSGraph framework instead of raising; skip there.
    if MACOS_VERSION < 14.0:
        raise unittest.SkipTest("Crashes on MacOS-13")
    max_dim = torch.iinfo(torch.int32).max
    # A single dimension one past INT32_MAX must surface as a RuntimeError
    # rather than the old hard crash ("NDArray dimension length > INT_MAX").
    with self.assertRaises(RuntimeError):
        x = torch.randint(0, 10, (max_dim + 1,), dtype=torch.int8, device='mps')
    # Exactly INT32_MAX elements is still a legal allocation.
    x = torch.randint(0, 10, (max_dim,), dtype=torch.int8, device='mps')
    self.assertEqual(x.numel(), max_dim)
    del x
class TestLogical(TestCaseMPS):
def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):