mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
unify broadcast_shapes functions and avoid duplicates (#160251)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160251 Approved by: https://github.com/jingsh, https://github.com/ColinPeppler ghstack dependencies: #160250
This commit is contained in:
committed by
PyTorch MergeBot
parent
c03809e8a5
commit
65dc4df74d
@ -1656,7 +1656,7 @@ class TestOldViewOps(TestCase):
|
||||
inputs_with_neg_vals = [[1, 1, -12], [-1, 1], [-11]]
|
||||
for integral_inputs_with_neg_vals in inputs_with_neg_vals:
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Trying to create tensor with negative dimension"
|
||||
ValueError, "Attempting to broadcast a dimension with negative length!"
|
||||
):
|
||||
torch.broadcast_shapes(*integral_inputs_with_neg_vals)
|
||||
|
||||
@ -1664,20 +1664,21 @@ class TestOldViewOps(TestCase):
|
||||
for error_input in integral_inputs_error_case:
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Shape mismatch: objects cannot be broadcast to a single shape",
|
||||
".*expected shape should be broadcastable to*",
|
||||
):
|
||||
torch.broadcast_shapes(*error_input)
|
||||
|
||||
negative_inputs = [(-1,), (1, -12), (4, -11), (-4, 1), (1, 1, -2)]
|
||||
for s0 in negative_inputs:
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Trying to create tensor with negative dimension"
|
||||
ValueError, "Attempting to broadcast a dimension with negative length!"
|
||||
):
|
||||
torch.broadcast_shapes(s0)
|
||||
|
||||
for s1 in negative_inputs:
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Trying to create tensor with negative dimension"
|
||||
ValueError,
|
||||
"Attempting to broadcast a dimension with negative length!",
|
||||
):
|
||||
torch.broadcast_shapes(s0, s1)
|
||||
|
||||
|
Reference in New Issue
Block a user