mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Throw invalid_argument instead of RuntimeError when parameters exceed… (#158267)
Throw invalid_argument instead of RuntimeError when parameters exceed limits (for torch.int32 dtype) Fixes #157707 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158267 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
21a95bdf7c
commit
f5cf05c983
@ -488,7 +488,7 @@ class TestNumPyInterop(TestCase):
|
||||
) # type: ignore[call-overload]
|
||||
else:
|
||||
self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
ValueError,
|
||||
"(Overflow|an integer is required)",
|
||||
lambda: torch.mean(torch.randn(1, 1), np.uint64(-1)),
|
||||
) # type: ignore[call-overload]
|
||||
|
@ -9432,7 +9432,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
|
||||
f"after calling manual_seed({seed:x}), but got {actual_initial_seed:x} instead")
|
||||
self.assertEqual(expected_initial_seed, actual_initial_seed, msg=msg)
|
||||
for invalid_seed in [min_int64 - 1, max_uint64 + 1]:
|
||||
with self.assertRaisesRegex(RuntimeError, r'Overflow when unpacking long long'):
|
||||
with self.assertRaisesRegex(ValueError, r'Overflow when unpacking long long'):
|
||||
torch.manual_seed(invalid_seed)
|
||||
|
||||
torch.set_rng_state(rng_state)
|
||||
@ -10851,8 +10851,8 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
|
||||
def test_invalid_arg_error_handling(self) -> None:
|
||||
""" Tests that errors from old TH functions are propagated back """
|
||||
for invalid_val in [-1, 2**65]:
|
||||
self.assertRaises(RuntimeError, lambda: torch.set_num_threads(invalid_val))
|
||||
self.assertRaises(RuntimeError, lambda: torch.set_num_interop_threads(invalid_val))
|
||||
self.assertRaises((ValueError, RuntimeError), lambda: torch.set_num_threads(invalid_val))
|
||||
self.assertRaises((ValueError, RuntimeError), lambda: torch.set_num_interop_threads(invalid_val))
|
||||
|
||||
def _get_tensor_prop(self, t):
|
||||
preserved = (
|
||||
|
@ -219,7 +219,7 @@ class TestIndexing(TestCase):
|
||||
assert_raises(IndexError, a.__getitem__, 1 << 30)
|
||||
# Index overflow produces IndexError
|
||||
# Note torch raises RuntimeError here
|
||||
assert_raises((IndexError, RuntimeError), a.__getitem__, 1 << 64)
|
||||
assert_raises((IndexError, ValueError), a.__getitem__, 1 << 64)
|
||||
|
||||
def test_single_bool_index(self):
|
||||
# Single boolean index
|
||||
|
@ -62,13 +62,11 @@ inline int32_t THPUtils_unpackInt(PyObject* obj) {
|
||||
if (value == -1 && PyErr_Occurred()) {
|
||||
throw python_error();
|
||||
}
|
||||
if (overflow != 0) {
|
||||
throw std::runtime_error("Overflow when unpacking long");
|
||||
}
|
||||
if (value > std::numeric_limits<int32_t>::max() ||
|
||||
value < std::numeric_limits<int32_t>::min()) {
|
||||
throw std::runtime_error("Overflow when unpacking long");
|
||||
}
|
||||
TORCH_CHECK_VALUE(overflow == 0, "Overflow when unpacking long long");
|
||||
TORCH_CHECK_VALUE(
|
||||
value <= std::numeric_limits<int32_t>::max() &&
|
||||
value >= std::numeric_limits<int32_t>::min(),
|
||||
"Overflow when unpacking long");
|
||||
return (int32_t)value;
|
||||
}
|
||||
|
||||
@ -78,9 +76,7 @@ inline int64_t THPUtils_unpackLong(PyObject* obj) {
|
||||
if (value == -1 && PyErr_Occurred()) {
|
||||
throw python_error();
|
||||
}
|
||||
if (overflow != 0) {
|
||||
throw std::runtime_error("Overflow when unpacking long long");
|
||||
}
|
||||
TORCH_CHECK_VALUE(overflow == 0, "Overflow when unpacking long long");
|
||||
return (int64_t)value;
|
||||
}
|
||||
|
||||
@ -89,9 +85,9 @@ inline uint32_t THPUtils_unpackUInt32(PyObject* obj) {
|
||||
if (PyErr_Occurred()) {
|
||||
throw python_error();
|
||||
}
|
||||
if (value > std::numeric_limits<uint32_t>::max()) {
|
||||
throw std::runtime_error("Overflow when unpacking unsigned long");
|
||||
}
|
||||
TORCH_CHECK_VALUE(
|
||||
value <= std::numeric_limits<uint32_t>::max(),
|
||||
"Overflow when unpacking long long");
|
||||
return (uint32_t)value;
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user