Throw invalid_argument instead of RuntimeError when parameters exceed… (#158267)

Throw invalid_argument instead of RuntimeError when parameters exceed limits (for torch.int32 dtype)

Fixes #157707

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158267
Approved by: https://github.com/albanD
This commit is contained in:
gaoyufeng
2025-07-25 23:49:42 +00:00
committed by PyTorch MergeBot
parent 21a95bdf7c
commit f5cf05c983
4 changed files with 14 additions and 18 deletions

View File

@ -9432,7 +9432,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
f"after calling manual_seed({seed:x}), but got {actual_initial_seed:x} instead")
self.assertEqual(expected_initial_seed, actual_initial_seed, msg=msg)
for invalid_seed in [min_int64 - 1, max_uint64 + 1]:
with self.assertRaisesRegex(RuntimeError, r'Overflow when unpacking long long'):
with self.assertRaisesRegex(ValueError, r'Overflow when unpacking long long'):
torch.manual_seed(invalid_seed)
torch.set_rng_state(rng_state)
@ -10851,8 +10851,8 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
def test_invalid_arg_error_handling(self) -> None:
""" Tests that errors from old TH functions are propagated back """
for invalid_val in [-1, 2**65]:
self.assertRaises(RuntimeError, lambda: torch.set_num_threads(invalid_val))
self.assertRaises(RuntimeError, lambda: torch.set_num_interop_threads(invalid_val))
self.assertRaises((ValueError, RuntimeError), lambda: torch.set_num_threads(invalid_val))
self.assertRaises((ValueError, RuntimeError), lambda: torch.set_num_interop_threads(invalid_val))
def _get_tensor_prop(self, t):
preserved = (