Throw invalid_argument instead of RuntimeError when parameters exceed… (#158267)

Throw invalid_argument instead of RuntimeError when parameters exceed limits (for torch.int32 dtype)

Fixes #157707

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158267
Approved by: https://github.com/albanD
This commit is contained in:
gaoyufeng
2025-07-25 23:49:42 +00:00
committed by PyTorch MergeBot
parent 21a95bdf7c
commit f5cf05c983
4 changed files with 14 additions and 18 deletions

View File

@ -488,7 +488,7 @@ class TestNumPyInterop(TestCase):
) # type: ignore[call-overload]
else:
self.assertRaisesRegex(
RuntimeError,
ValueError,
"(Overflow|an integer is required)",
lambda: torch.mean(torch.randn(1, 1), np.uint64(-1)),
) # type: ignore[call-overload]

View File

@ -9432,7 +9432,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
f"after calling manual_seed({seed:x}), but got {actual_initial_seed:x} instead")
self.assertEqual(expected_initial_seed, actual_initial_seed, msg=msg)
for invalid_seed in [min_int64 - 1, max_uint64 + 1]:
with self.assertRaisesRegex(RuntimeError, r'Overflow when unpacking long long'):
with self.assertRaisesRegex(ValueError, r'Overflow when unpacking long long'):
torch.manual_seed(invalid_seed)
torch.set_rng_state(rng_state)
@ -10851,8 +10851,8 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
def test_invalid_arg_error_handling(self) -> None:
""" Tests that errors from old TH functions are propagated back """
for invalid_val in [-1, 2**65]:
self.assertRaises(RuntimeError, lambda: torch.set_num_threads(invalid_val))
self.assertRaises(RuntimeError, lambda: torch.set_num_interop_threads(invalid_val))
self.assertRaises((ValueError, RuntimeError), lambda: torch.set_num_threads(invalid_val))
self.assertRaises((ValueError, RuntimeError), lambda: torch.set_num_interop_threads(invalid_val))
def _get_tensor_prop(self, t):
preserved = (

View File

@ -219,7 +219,7 @@ class TestIndexing(TestCase):
assert_raises(IndexError, a.__getitem__, 1 << 30)
# Index overflow produces IndexError
# Note torch raises RuntimeError here
assert_raises((IndexError, RuntimeError), a.__getitem__, 1 << 64)
assert_raises((IndexError, ValueError), a.__getitem__, 1 << 64)
def test_single_bool_index(self):
# Single boolean index

View File

@ -62,13 +62,11 @@ inline int32_t THPUtils_unpackInt(PyObject* obj) {
if (value == -1 && PyErr_Occurred()) {
throw python_error();
}
if (overflow != 0) {
throw std::runtime_error("Overflow when unpacking long");
}
if (value > std::numeric_limits<int32_t>::max() ||
value < std::numeric_limits<int32_t>::min()) {
throw std::runtime_error("Overflow when unpacking long");
}
TORCH_CHECK_VALUE(overflow == 0, "Overflow when unpacking long long");
TORCH_CHECK_VALUE(
value <= std::numeric_limits<int32_t>::max() &&
value >= std::numeric_limits<int32_t>::min(),
"Overflow when unpacking long");
return (int32_t)value;
}
@ -78,9 +76,7 @@ inline int64_t THPUtils_unpackLong(PyObject* obj) {
if (value == -1 && PyErr_Occurred()) {
throw python_error();
}
if (overflow != 0) {
throw std::runtime_error("Overflow when unpacking long long");
}
TORCH_CHECK_VALUE(overflow == 0, "Overflow when unpacking long long");
return (int64_t)value;
}
@ -89,9 +85,9 @@ inline uint32_t THPUtils_unpackUInt32(PyObject* obj) {
if (PyErr_Occurred()) {
throw python_error();
}
if (value > std::numeric_limits<uint32_t>::max()) {
throw std::runtime_error("Overflow when unpacking unsigned long");
}
TORCH_CHECK_VALUE(
value <= std::numeric_limits<uint32_t>::max(),
"Overflow when unpacking long long");
return (uint32_t)value;
}