mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Throw invalid_argument instead of RuntimeError when parameters exceed… (#158267)
Throw invalid_argument instead of RuntimeError when parameters exceed limits (for torch.int32 dtype) Fixes #157707 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158267 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
21a95bdf7c
commit
f5cf05c983
@ -9432,7 +9432,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
|
||||
f"after calling manual_seed({seed:x}), but got {actual_initial_seed:x} instead")
|
||||
self.assertEqual(expected_initial_seed, actual_initial_seed, msg=msg)
|
||||
for invalid_seed in [min_int64 - 1, max_uint64 + 1]:
|
||||
with self.assertRaisesRegex(RuntimeError, r'Overflow when unpacking long long'):
|
||||
with self.assertRaisesRegex(ValueError, r'Overflow when unpacking long long'):
|
||||
torch.manual_seed(invalid_seed)
|
||||
|
||||
torch.set_rng_state(rng_state)
|
||||
@ -10851,8 +10851,8 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
|
||||
def test_invalid_arg_error_handling(self) -> None:
|
||||
""" Tests that errors from old TH functions are propagated back """
|
||||
for invalid_val in [-1, 2**65]:
|
||||
self.assertRaises(RuntimeError, lambda: torch.set_num_threads(invalid_val))
|
||||
self.assertRaises(RuntimeError, lambda: torch.set_num_interop_threads(invalid_val))
|
||||
self.assertRaises((ValueError, RuntimeError), lambda: torch.set_num_threads(invalid_val))
|
||||
self.assertRaises((ValueError, RuntimeError), lambda: torch.set_num_interop_threads(invalid_val))
|
||||
|
||||
def _get_tensor_prop(self, t):
|
||||
preserved = (
|
||||
|
Reference in New Issue
Block a user