mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Raise BufferError
for DLPack buffer-related errors. (#150691)
This PR addresses the Array API documentation for [`__dlpack__`][1] and [`from_dlpack`][2] by making some buffer-related errors `BufferError` instead of `RuntimeError`, e.g. incompatible dtype, strides, or device. [1]: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html [2]: https://data-apis.org/array-api/latest/API_specification/generated/array_api.from_dlpack.html#from-dlpack Pull Request resolved: https://github.com/pytorch/pytorch/pull/150691 Approved by: https://github.com/Skylion007, https://github.com/albanD ghstack dependencies: #150216, #150217, #150218
This commit is contained in:
committed by
PyTorch MergeBot
parent
a10f15718d
commit
b4abf41425
@ -69,29 +69,29 @@ DLDataType getDLDataType(const Tensor& t) {
|
||||
case ScalarType::Float8_e4m3fn:
|
||||
case ScalarType::Float8_e4m3fnuz:
|
||||
case ScalarType::Float8_e8m0fnu:
|
||||
TORCH_CHECK(false, "float8 types are not supported by dlpack");
|
||||
TORCH_CHECK_BUFFER(false, "float8 types are not supported by dlpack");
|
||||
break;
|
||||
case ScalarType::Float4_e2m1fn_x2:
|
||||
TORCH_CHECK(false, "float4 types are not supported by dlpack");
|
||||
TORCH_CHECK_BUFFER(false, "float4 types are not supported by dlpack");
|
||||
break;
|
||||
case ScalarType::QInt8:
|
||||
case ScalarType::QUInt8:
|
||||
case ScalarType::QInt32:
|
||||
case ScalarType::QUInt4x2:
|
||||
case ScalarType::QUInt2x4:
|
||||
TORCH_CHECK(false, "QUInt/QInt types are not supported by dlpack");
|
||||
TORCH_CHECK_BUFFER(false, "QUInt/QInt types are not supported by dlpack");
|
||||
break;
|
||||
case ScalarType::Bits1x8:
|
||||
case ScalarType::Bits2x4:
|
||||
case ScalarType::Bits4x2:
|
||||
case ScalarType::Bits8:
|
||||
case ScalarType::Bits16:
|
||||
TORCH_CHECK(false, "Bit types are not supported by dlpack");
|
||||
TORCH_CHECK_BUFFER(false, "Bit types are not supported by dlpack");
|
||||
break;
|
||||
case ScalarType::Undefined:
|
||||
TORCH_CHECK(false, "Undefined is not a valid ScalarType");
|
||||
TORCH_CHECK_BUFFER(false, "Undefined is not a valid ScalarType");
|
||||
case ScalarType::NumOptions:
|
||||
TORCH_CHECK(false, "NumOptions is not a valid ScalarType");
|
||||
TORCH_CHECK_BUFFER(false, "NumOptions is not a valid ScalarType");
|
||||
}
|
||||
return dtype;
|
||||
}
|
||||
@ -133,7 +133,7 @@ DLDevice torchDeviceToDLDevice(at::Device device) {
|
||||
ctx.device_type = DLDeviceType::kDLExtDev;
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Cannot pack tensors on " + device.str());
|
||||
TORCH_CHECK_BUFFER(false, "Cannot pack tensors on " + device.str());
|
||||
}
|
||||
|
||||
return ctx;
|
||||
@ -165,14 +165,14 @@ static Device getATenDevice(DLDeviceType type, c10::DeviceIndex index, void* dat
|
||||
case DLDeviceType::kDLExtDev:
|
||||
return at::Device(DeviceType::PrivateUse1, index);
|
||||
default:
|
||||
TORCH_CHECK(
|
||||
TORCH_CHECK_BUFFER(
|
||||
false, "Unsupported device_type: ", std::to_string(type));
|
||||
}
|
||||
}
|
||||
|
||||
ScalarType toScalarType(const DLDataType& dtype) {
|
||||
ScalarType stype = ScalarType::Undefined;
|
||||
TORCH_CHECK(dtype.lanes == 1, "ATen does not support lanes != 1");
|
||||
TORCH_CHECK_BUFFER(dtype.lanes == 1, "ATen does not support lanes != 1");
|
||||
switch (dtype.code) {
|
||||
case DLDataTypeCode::kDLUInt:
|
||||
switch (dtype.bits) {
|
||||
@ -189,7 +189,7 @@ ScalarType toScalarType(const DLDataType& dtype) {
|
||||
stype = ScalarType::UInt64;
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(
|
||||
TORCH_CHECK_BUFFER(
|
||||
false, "Unsupported kUInt bits ", std::to_string(dtype.bits));
|
||||
}
|
||||
break;
|
||||
@ -208,7 +208,7 @@ ScalarType toScalarType(const DLDataType& dtype) {
|
||||
stype = ScalarType::Long;
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(
|
||||
TORCH_CHECK_BUFFER(
|
||||
false, "Unsupported kInt bits ", std::to_string(dtype.bits));
|
||||
}
|
||||
break;
|
||||
@ -224,7 +224,7 @@ ScalarType toScalarType(const DLDataType& dtype) {
|
||||
stype = ScalarType::Double;
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(
|
||||
TORCH_CHECK_BUFFER(
|
||||
false, "Unsupported kFloat bits ", std::to_string(dtype.bits));
|
||||
}
|
||||
break;
|
||||
@ -234,7 +234,7 @@ ScalarType toScalarType(const DLDataType& dtype) {
|
||||
stype = ScalarType::BFloat16;
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(
|
||||
TORCH_CHECK_BUFFER(
|
||||
false, "Unsupported kFloat bits ", std::to_string(dtype.bits));
|
||||
}
|
||||
break;
|
||||
@ -250,7 +250,7 @@ ScalarType toScalarType(const DLDataType& dtype) {
|
||||
stype = ScalarType::ComplexDouble;
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(
|
||||
TORCH_CHECK_BUFFER(
|
||||
false, "Unsupported kFloat bits ", std::to_string(dtype.bits));
|
||||
}
|
||||
break;
|
||||
@ -260,12 +260,12 @@ ScalarType toScalarType(const DLDataType& dtype) {
|
||||
stype = ScalarType::Bool;
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(
|
||||
TORCH_CHECK_BUFFER(
|
||||
false, "Unsupported kDLBool bits ", std::to_string(dtype.bits));
|
||||
}
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unsupported code ", std::to_string(dtype.code));
|
||||
TORCH_CHECK_BUFFER(false, "Unsupported code ", std::to_string(dtype.code));
|
||||
}
|
||||
return stype;
|
||||
}
|
||||
|
@ -267,6 +267,13 @@ class C10_API NotImplementedError : public Error {
|
||||
using Error::Error;
|
||||
};
|
||||
|
||||
// Used in ATen for buffer-related errors, e.g. trying to create a DLPack of
|
||||
// an unsupported device. These turn into BufferError when they cross to
|
||||
// Python.
|
||||
class C10_API BufferError : public Error {
|
||||
using Error::Error;
|
||||
};
|
||||
|
||||
// Used in ATen for non finite indices. These turn into
|
||||
// ExitException when they cross to Python.
|
||||
class C10_API EnforceFiniteError : public Error {
|
||||
@ -635,6 +642,10 @@ namespace c10::detail {
|
||||
#define TORCH_CHECK_NOT_IMPLEMENTED(cond, ...) \
|
||||
TORCH_CHECK_WITH_MSG(NotImplementedError, cond, "TYPE", __VA_ARGS__)
|
||||
|
||||
// Like TORCH_CHECK, but raises BufferError instead of Errors.
|
||||
#define TORCH_CHECK_BUFFER(cond, ...) \
|
||||
TORCH_CHECK_WITH_MSG(BufferError, cond, "TYPE", __VA_ARGS__)
|
||||
|
||||
#define TORCH_CHECK_ALWAYS_SHOW_CPP_STACKTRACE(cond, ...) \
|
||||
TORCH_CHECK_WITH_MSG( \
|
||||
ErrorAlwaysShowCppStacktrace, cond, "TYPE", ##__VA_ARGS__)
|
||||
|
@ -20,7 +20,7 @@ from torch.testing._internal.common_utils import (
|
||||
skipIfTorchDynamo,
|
||||
TestCase,
|
||||
)
|
||||
from torch.utils.dlpack import from_dlpack, to_dlpack
|
||||
from torch.utils.dlpack import DLDeviceType, from_dlpack, to_dlpack
|
||||
|
||||
|
||||
# Wraps a tensor, exposing only DLPack methods:
|
||||
@ -304,21 +304,21 @@ class TestTorchDlPack(TestCase):
|
||||
@skipMeta
|
||||
def test_dlpack_export_requires_grad(self):
|
||||
x = torch.zeros(10, dtype=torch.float32, requires_grad=True)
|
||||
with self.assertRaisesRegex(RuntimeError, r"require gradient"):
|
||||
with self.assertRaisesRegex(BufferError, r"require gradient"):
|
||||
x.__dlpack__()
|
||||
|
||||
@skipMeta
|
||||
def test_dlpack_export_is_conj(self):
|
||||
x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j])
|
||||
y = torch.conj(x)
|
||||
with self.assertRaisesRegex(RuntimeError, r"conjugate bit"):
|
||||
with self.assertRaisesRegex(BufferError, r"conjugate bit"):
|
||||
y.__dlpack__()
|
||||
|
||||
@skipMeta
|
||||
def test_dlpack_export_non_strided(self):
|
||||
x = torch.sparse_coo_tensor([[0]], [1], size=(1,))
|
||||
y = torch.conj(x)
|
||||
with self.assertRaisesRegex(RuntimeError, r"strided"):
|
||||
with self.assertRaisesRegex(BufferError, r"strided"):
|
||||
y.__dlpack__()
|
||||
|
||||
@skipMeta
|
||||
@ -459,6 +459,29 @@ class TestTorchDlPack(TestCase):
|
||||
with self.assertRaisesRegex(ValueError, r"cannot move .* tensor from .*"):
|
||||
self._test_from_dlpack(device, out_device="cpu", copy=False)
|
||||
|
||||
@skipMeta
|
||||
@onlyNativeDeviceTypes
|
||||
def test_unsupported_device_error(self, device):
|
||||
inp = make_tensor((5,), dtype=torch.float32, device=device)
|
||||
dl_device_type = DLDeviceType.kDLHexagon
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
BufferError, f"Unsupported device_type: {int(dl_device_type)}"
|
||||
):
|
||||
inp.__dlpack__(max_version=(1, 0), dl_device=(dl_device_type, 0))
|
||||
|
||||
@skipMeta
|
||||
@onlyCPU
|
||||
def test_dlpack_unsupported_dtype_error(self, device):
|
||||
inp = make_tensor((5,), dtype=torch.float32, device=device).to(
|
||||
torch.float8_e4m3fn
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
BufferError, ".* types are not supported by dlpack"
|
||||
):
|
||||
from_dlpack(inp)
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestTorchDlPack, globals())
|
||||
|
||||
|
@ -1709,13 +1709,13 @@ class Tensor(torch._C.TensorBase):
|
||||
# so we prohibit exporting tensors that would lose their properties like
|
||||
# requires_grad and having the conjugate bit set.
|
||||
if self.requires_grad:
|
||||
raise RuntimeError(
|
||||
raise BufferError(
|
||||
"Can't export tensors that require gradient, use tensor.detach()"
|
||||
)
|
||||
if self.is_conj():
|
||||
raise RuntimeError("Can't export tensors with the conjugate bit set")
|
||||
raise BufferError("Can't export tensors with the conjugate bit set")
|
||||
if self.layout != torch.strided:
|
||||
raise RuntimeError(
|
||||
raise BufferError(
|
||||
"Can't export tensors with layout other than torch.strided"
|
||||
)
|
||||
|
||||
@ -1724,8 +1724,8 @@ class Tensor(torch._C.TensorBase):
|
||||
and self.device.index != torch.cuda.current_device()
|
||||
):
|
||||
raise BufferError(
|
||||
"Can't export tensors on a different CUDA device. "
|
||||
f"Expected: {self.device}. "
|
||||
"Can't export tensors on a different CUDA device index. "
|
||||
f"Expected: {self.device.index}. "
|
||||
f"Current device: {torch.cuda.current_device()}."
|
||||
)
|
||||
|
||||
|
@ -74,6 +74,7 @@ inline void PyErr_SetString(PyObject* type, const std::string& message) {
|
||||
_CATCH_GENERIC_ERROR(TypeError, PyExc_TypeError, retstmnt) \
|
||||
_CATCH_GENERIC_ERROR( \
|
||||
NotImplementedError, PyExc_NotImplementedError, retstmnt) \
|
||||
_CATCH_GENERIC_ERROR(BufferError, PyExc_BufferError, retstmnt) \
|
||||
_CATCH_GENERIC_ERROR(SyntaxError, PyExc_SyntaxError, retstmnt) \
|
||||
_CATCH_GENERIC_ERROR(LinAlgError, THPException_LinAlgError, retstmnt) \
|
||||
_CATCH_GENERIC_ERROR( \
|
||||
|
Reference in New Issue
Block a user