Raise BufferError for DLPack buffer-related errors. (#150691)

This PR follows the Array API specification for [`__dlpack__`][1] and
[`from_dlpack`][2] by raising `BufferError` instead of `RuntimeError` for
buffer-related errors, e.g. an incompatible dtype, strides, or device.

[1]: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html
[2]: https://data-apis.org/array-api/latest/API_specification/generated/array_api.from_dlpack.html#from-dlpack
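
For example, exporting a tensor that requires gradient now surfaces as a `BufferError` on the Python side (a minimal sketch of the new behavior):

    import torch

    x = torch.zeros(10, dtype=torch.float32, requires_grad=True)
    try:
        x.__dlpack__()
    except BufferError as e:
        # Before this PR the same failure raised RuntimeError.
        print(e)  # Can't export tensors that require gradient, use tensor.detach()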
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150691
Approved by: https://github.com/Skylion007, https://github.com/albanD
ghstack dependencies: #150216, #150217, #150218
Author: Yukio Siraichi
Date: 2025-07-19 16:36:08 -03:00
Committed by: PyTorch MergeBot
Parent: a10f15718d
Commit: b4abf41425
5 changed files with 60 additions and 25 deletions

aten/src/ATen/DLConvertor.cpp

@@ -69,29 +69,29 @@ DLDataType getDLDataType(const Tensor& t) {
     case ScalarType::Float8_e4m3fn:
     case ScalarType::Float8_e4m3fnuz:
     case ScalarType::Float8_e8m0fnu:
-      TORCH_CHECK(false, "float8 types are not supported by dlpack");
+      TORCH_CHECK_BUFFER(false, "float8 types are not supported by dlpack");
       break;
     case ScalarType::Float4_e2m1fn_x2:
-      TORCH_CHECK(false, "float4 types are not supported by dlpack");
+      TORCH_CHECK_BUFFER(false, "float4 types are not supported by dlpack");
       break;
     case ScalarType::QInt8:
     case ScalarType::QUInt8:
     case ScalarType::QInt32:
     case ScalarType::QUInt4x2:
     case ScalarType::QUInt2x4:
-      TORCH_CHECK(false, "QUInt/QInt types are not supported by dlpack");
+      TORCH_CHECK_BUFFER(false, "QUInt/QInt types are not supported by dlpack");
       break;
     case ScalarType::Bits1x8:
     case ScalarType::Bits2x4:
     case ScalarType::Bits4x2:
     case ScalarType::Bits8:
     case ScalarType::Bits16:
-      TORCH_CHECK(false, "Bit types are not supported by dlpack");
+      TORCH_CHECK_BUFFER(false, "Bit types are not supported by dlpack");
       break;
     case ScalarType::Undefined:
-      TORCH_CHECK(false, "Undefined is not a valid ScalarType");
+      TORCH_CHECK_BUFFER(false, "Undefined is not a valid ScalarType");
     case ScalarType::NumOptions:
-      TORCH_CHECK(false, "NumOptions is not a valid ScalarType");
+      TORCH_CHECK_BUFFER(false, "NumOptions is not a valid ScalarType");
   }
   return dtype;
 }
@@ -133,7 +133,7 @@ DLDevice torchDeviceToDLDevice(at::Device device) {
       ctx.device_type = DLDeviceType::kDLExtDev;
       break;
     default:
-      TORCH_CHECK(false, "Cannot pack tensors on " + device.str());
+      TORCH_CHECK_BUFFER(false, "Cannot pack tensors on " + device.str());
   }
   return ctx;
@@ -165,14 +165,14 @@ static Device getATenDevice(DLDeviceType type, c10::DeviceIndex index, void* dat
     case DLDeviceType::kDLExtDev:
       return at::Device(DeviceType::PrivateUse1, index);
     default:
-      TORCH_CHECK(
+      TORCH_CHECK_BUFFER(
           false, "Unsupported device_type: ", std::to_string(type));
   }
 }

 ScalarType toScalarType(const DLDataType& dtype) {
   ScalarType stype = ScalarType::Undefined;
-  TORCH_CHECK(dtype.lanes == 1, "ATen does not support lanes != 1");
+  TORCH_CHECK_BUFFER(dtype.lanes == 1, "ATen does not support lanes != 1");
   switch (dtype.code) {
     case DLDataTypeCode::kDLUInt:
       switch (dtype.bits) {
@@ -189,7 +189,7 @@ ScalarType toScalarType(const DLDataType& dtype) {
           stype = ScalarType::UInt64;
           break;
         default:
-          TORCH_CHECK(
+          TORCH_CHECK_BUFFER(
               false, "Unsupported kUInt bits ", std::to_string(dtype.bits));
       }
       break;
@@ -208,7 +208,7 @@ ScalarType toScalarType(const DLDataType& dtype) {
           stype = ScalarType::Long;
           break;
         default:
-          TORCH_CHECK(
+          TORCH_CHECK_BUFFER(
               false, "Unsupported kInt bits ", std::to_string(dtype.bits));
       }
       break;
@@ -224,7 +224,7 @@ ScalarType toScalarType(const DLDataType& dtype) {
           stype = ScalarType::Double;
           break;
         default:
-          TORCH_CHECK(
+          TORCH_CHECK_BUFFER(
               false, "Unsupported kFloat bits ", std::to_string(dtype.bits));
       }
       break;
@@ -234,7 +234,7 @@ ScalarType toScalarType(const DLDataType& dtype) {
           stype = ScalarType::BFloat16;
           break;
         default:
-          TORCH_CHECK(
+          TORCH_CHECK_BUFFER(
               false, "Unsupported kFloat bits ", std::to_string(dtype.bits));
       }
       break;
@@ -250,7 +250,7 @@ ScalarType toScalarType(const DLDataType& dtype) {
           stype = ScalarType::ComplexDouble;
           break;
         default:
-          TORCH_CHECK(
+          TORCH_CHECK_BUFFER(
               false, "Unsupported kFloat bits ", std::to_string(dtype.bits));
       }
       break;
@@ -260,12 +260,12 @@ ScalarType toScalarType(const DLDataType& dtype) {
           stype = ScalarType::Bool;
           break;
         default:
-          TORCH_CHECK(
+          TORCH_CHECK_BUFFER(
               false, "Unsupported kDLBool bits ", std::to_string(dtype.bits));
       }
       break;
     default:
-      TORCH_CHECK(false, "Unsupported code ", std::to_string(dtype.code));
+      TORCH_CHECK_BUFFER(false, "Unsupported code ", std::to_string(dtype.code));
   }
   return stype;
 }

c10/util/Exception.h

@@ -267,6 +267,13 @@ class C10_API NotImplementedError : public Error {
   using Error::Error;
 };

+// Used in ATen for buffer-related errors, e.g. trying to create a DLPack of
+// an unsupported device. These turn into BufferError when they cross to
+// Python.
+class C10_API BufferError : public Error {
+  using Error::Error;
+};
+
 // Used in ATen for non finite indices. These turn into
 // ExitException when they cross to Python.
 class C10_API EnforceFiniteError : public Error {
@@ -635,6 +642,10 @@ namespace c10::detail {
 #define TORCH_CHECK_NOT_IMPLEMENTED(cond, ...) \
   TORCH_CHECK_WITH_MSG(NotImplementedError, cond, "TYPE", __VA_ARGS__)

+// Like TORCH_CHECK, but raises BufferError instead of Errors.
+#define TORCH_CHECK_BUFFER(cond, ...) \
+  TORCH_CHECK_WITH_MSG(BufferError, cond, "TYPE", __VA_ARGS__)
+
 #define TORCH_CHECK_ALWAYS_SHOW_CPP_STACKTRACE(cond, ...) \
   TORCH_CHECK_WITH_MSG( \
       ErrorAlwaysShowCppStacktrace, cond, "TYPE", ##__VA_ARGS__)

test/test_dlpack.py

@@ -20,7 +20,7 @@ from torch.testing._internal.common_utils import (
     skipIfTorchDynamo,
     TestCase,
 )
-from torch.utils.dlpack import from_dlpack, to_dlpack
+from torch.utils.dlpack import DLDeviceType, from_dlpack, to_dlpack


 # Wraps a tensor, exposing only DLPack methods:
@@ -304,21 +304,21 @@ class TestTorchDlPack(TestCase):
     @skipMeta
     def test_dlpack_export_requires_grad(self):
         x = torch.zeros(10, dtype=torch.float32, requires_grad=True)
-        with self.assertRaisesRegex(RuntimeError, r"require gradient"):
+        with self.assertRaisesRegex(BufferError, r"require gradient"):
             x.__dlpack__()

     @skipMeta
     def test_dlpack_export_is_conj(self):
         x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j])
         y = torch.conj(x)
-        with self.assertRaisesRegex(RuntimeError, r"conjugate bit"):
+        with self.assertRaisesRegex(BufferError, r"conjugate bit"):
             y.__dlpack__()

     @skipMeta
     def test_dlpack_export_non_strided(self):
         x = torch.sparse_coo_tensor([[0]], [1], size=(1,))
         y = torch.conj(x)
-        with self.assertRaisesRegex(RuntimeError, r"strided"):
+        with self.assertRaisesRegex(BufferError, r"strided"):
             y.__dlpack__()

     @skipMeta
@@ -459,6 +459,29 @@ class TestTorchDlPack(TestCase):
         with self.assertRaisesRegex(ValueError, r"cannot move .* tensor from .*"):
             self._test_from_dlpack(device, out_device="cpu", copy=False)

+    @skipMeta
+    @onlyNativeDeviceTypes
+    def test_unsupported_device_error(self, device):
+        inp = make_tensor((5,), dtype=torch.float32, device=device)
+        dl_device_type = DLDeviceType.kDLHexagon
+        with self.assertRaisesRegex(
+            BufferError, f"Unsupported device_type: {int(dl_device_type)}"
+        ):
+            inp.__dlpack__(max_version=(1, 0), dl_device=(dl_device_type, 0))
+
+    @skipMeta
+    @onlyCPU
+    def test_dlpack_unsupported_dtype_error(self, device):
+        inp = make_tensor((5,), dtype=torch.float32, device=device).to(
+            torch.float8_e4m3fn
+        )
+        with self.assertRaisesRegex(
+            BufferError, ".* types are not supported by dlpack"
+        ):
+            from_dlpack(inp)
+

 instantiate_device_type_tests(TestTorchDlPack, globals())

torch/_tensor.py

@@ -1709,13 +1709,13 @@ class Tensor(torch._C.TensorBase):
         # so we prohibit exporting tensors that would lose their properties like
         # requires_grad and having the conjugate bit set.
         if self.requires_grad:
-            raise RuntimeError(
+            raise BufferError(
                 "Can't export tensors that require gradient, use tensor.detach()"
             )
         if self.is_conj():
-            raise RuntimeError("Can't export tensors with the conjugate bit set")
+            raise BufferError("Can't export tensors with the conjugate bit set")
         if self.layout != torch.strided:
-            raise RuntimeError(
+            raise BufferError(
                 "Can't export tensors with layout other than torch.strided"
             )
@@ -1724,8 +1724,8 @@ class Tensor(torch._C.TensorBase):
             and self.device.index != torch.cuda.current_device()
         ):
             raise BufferError(
-                "Can't export tensors on a different CUDA device. "
-                f"Expected: {self.device}. "
+                "Can't export tensors on a different CUDA device index. "
+                f"Expected: {self.device.index}. "
                 f"Current device: {torch.cuda.current_device()}."

torch/csrc/Exceptions.h

@@ -74,6 +74,7 @@ inline void PyErr_SetString(PyObject* type, const std::string& message) {
   _CATCH_GENERIC_ERROR(TypeError, PyExc_TypeError, retstmnt) \
   _CATCH_GENERIC_ERROR( \
       NotImplementedError, PyExc_NotImplementedError, retstmnt) \
+  _CATCH_GENERIC_ERROR(BufferError, PyExc_BufferError, retstmnt) \
   _CATCH_GENERIC_ERROR(SyntaxError, PyExc_SyntaxError, retstmnt) \
   _CATCH_GENERIC_ERROR(LinAlgError, THPException_LinAlgError, retstmnt) \
   _CATCH_GENERIC_ERROR( \