mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[MPS] Do not crash if tensor dim > INT_MAX (#158824)
Looks like all MPS operations will crash if one of the tensor dimensions is greater than `2**31-1` Change it into a structured exception, by checking tensor size before attempting to create MPS Tensor Add regression test for it. Before this change running following will abort with exception ``` % python3 -c "import torch; torch.randint(0, 10, (2**31,), dtype=torch.uint8, device='mps')" /AppleInternal/Library/BuildRoots/1c8f7852-1ca9-11f0-b28b-226177e5bb69/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Types/MPSNDArray.mm:829: failed assertion `[MPSNDArray initWithDevice:descriptor:isTextureBacked:] Error: NDArray dimension length > INT_MAX' zsh: abort python3 -c· ``` Skip the test on MacOS-13, as it crashes somewhere deep in MPSGraph framework with ``` /AppleInternal/Library/BuildRoots/c651a45f-806e-11ed-a221-7ef33c48bc85/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Types/MPSNDArray.mm:724: failed assertion `[MPSTemporaryNDArray initWithDevice:descriptor:] Error: total bytes of NDArray > 2**32' ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158824 Approved by: https://github.com/dcci ghstack dependencies: #158690, #158823
This commit is contained in:
committed by
PyTorch MergeBot
parent
371ffaf415
commit
d0c00d9a69
@ -427,12 +427,22 @@ static MPSNDArray* permuteNDArray(MPSNDArray* inArray, const std::vector<int64_t
|
||||
return result;
|
||||
}
|
||||
|
||||
// Validates that every dimension of `shape` fits in a signed 32-bit integer.
// Should be called before initWithBuffer to prevent hard crashes with
// '[MPSNDArray initWithDevice:descriptor:isTextureBacked:] Error: NDArray dimension length > INT_MAX'
// Raises a catchable c10::Error (via TORCH_CHECK) instead of letting
// MPSNDArray abort the process with a failed assertion.
static void check_mps_shape(MPSShape* shape) {
  for (NSNumber* elem in shape) {
    const auto val = [elem longValue];
    // Fixed typo in the user-facing message: "MPSGaph" -> "MPSGraph".
    TORCH_CHECK(val <= std::numeric_limits<int32_t>::max(), "MPSGraph does not support tensor dims larger than INT_MAX");
  }
}
|
||||
|
||||
MPSNDArray* getMPSNDArray(const TensorBase& t, MPSShape* sizes, MPSShape* strides) {
|
||||
id<MTLBuffer> srcBuf = getMTLBufferStorage(t);
|
||||
|
||||
MPSDataType mpsDataType = getMPSDataType(t.scalar_type());
|
||||
MPSNDArrayDescriptor* srcTensorDesc = [MPSNDArrayDescriptor descriptorWithDataType:mpsDataType shape:sizes];
|
||||
srcTensorDesc.preferPackedRows = YES;
|
||||
check_mps_shape(sizes);
|
||||
MPSNDArray* srcNDArray = [[[MPSNDArray alloc] initWithBuffer:srcBuf
|
||||
offset:t.storage_offset() * t.element_size()
|
||||
descriptor:srcTensorDesc] autorelease];
|
||||
@ -542,9 +552,9 @@ Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor,
|
||||
// Tensor is contiguous and has no storage offset.
|
||||
// Wrap it directly inside MPSGraphTensorData
|
||||
if ((_tensor.is_contiguous() && !_tensor.storage_offset()) || !useMPSStridedAPI || !is_macOS_15_0_or_newer) {
|
||||
_value = [[[MPSGraphTensorData alloc] initWithMTLBuffer:srcBuf
|
||||
shape:mpsShape_ ? mpsShape_ : getMPSShape(_tensor)
|
||||
dataType:dataType] autorelease];
|
||||
auto shape = mpsShape_ ? mpsShape_ : getMPSShape(_tensor);
|
||||
check_mps_shape(shape);
|
||||
_value = [[[MPSGraphTensorData alloc] initWithMTLBuffer:srcBuf shape:shape dataType:dataType] autorelease];
|
||||
} else {
|
||||
IntArrayRef view_shape;
|
||||
if (mpsShape_) {
|
||||
@ -553,8 +563,11 @@ Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor,
|
||||
|
||||
MPSShape* mpsShape = getMPSShape(_tensor);
|
||||
MPSShape* mpsStrides = getMPSShape(_tensor.strides());
|
||||
check_mps_shape(mpsShape);
|
||||
|
||||
auto storage_numel = src.storage().nbytes() / src.element_size();
|
||||
TORCH_CHECK(storage_numel <= std::numeric_limits<int32_t>::max(),
|
||||
"MPSGaph does not support tensor dims larger than INT_MAX");
|
||||
MPSNDArrayDescriptor* srcTensorDesc = [MPSNDArrayDescriptor descriptorWithDataType:dataType
|
||||
shape:@[ @(storage_numel) ]];
|
||||
srcTensorDesc.preferPackedRows = YES;
|
||||
|
@ -8013,6 +8013,18 @@ class TestLargeTensors(TestCaseMPS):
|
||||
gc.collect()
|
||||
torch.mps.empty_cache()
|
||||
|
||||
@serialTest()
def test_rand_2b_raises(self):
    # A tensor whose dim exceeds INT_MAX must raise a RuntimeError instead
    # of hard-crashing inside MPSNDArray; see check_mps_shape in OperationUtils.
    if MACOS_VERSION < 14.0:
        raise unittest.SkipTest("Crashes on MacOS-13")
    int32_max = torch.iinfo(torch.int32).max
    with self.assertRaises(RuntimeError):
        # This used to crash with NDArray dimension length > INT_MAX
        torch.randint(0, 10, (int32_max + 1,), dtype=torch.int8, device='mps')
    # Exactly INT_MAX elements is still within the supported range.
    largest = torch.randint(0, 10, (int32_max,), dtype=torch.int8, device='mps')
    self.assertEqual(largest.numel(), int32_max)
    # Free the ~2GB allocation promptly rather than waiting for GC.
    del largest
|
||||
|
||||
|
||||
class TestLogical(TestCaseMPS):
|
||||
def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):
|
||||
|
Reference in New Issue
Block a user