mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[fix] torch.broadcast_shapes
should not handle shapes with negative dimensions. (#72999)
Summary: Hi, The PR fixes https://github.com/pytorch/pytorch/issues/68957. It aims to include the following: - Fixes the code in `torch/functional.py`. - Adds the missing tests for negative input values and non-iterable inputs. ~#### TODO~ ~- [x] Add OpInfo~ EDIT: `broadcast_shapes` doesn't take any tensor inputs, so we don't need OpInfo here. Thanks, kshitij12345, for the guidance. #### Earlier ```python >>> shapes = [1, -12] >>> torch.broadcast_shapes(*shapes) torch.Size([-12]) # MUST RAISE ERROR ``` #### Now ```python >>> shapes = [1, -12] >>> torch.broadcast_shapes(*shapes) RuntimeError: Trying to create tensor with negative dimension -12: [-12] ``` #### NumPy's Output ```python >>> shapes = [1, -12] >>> numpy.broadcast_shapes(*shapes) ValueError: negative dimensions are not allowed ``` #### `torch.broadcast_tensors()` Output As mentioned in the [doc](https://pytorch.org/docs/stable/generated/torch.broadcast_shapes.html): ```python >>> shapes = [1, -12] >>> torch.broadcast_tensors(*map(torch.empty, shapes))[0].shape RuntimeError: Trying to create tensor with negative dimension -12: [-12] ``` Looking forward to hearing from you and your questions. Thanks! :) cc: mruberry kshitij12345 Pull Request resolved: https://github.com/pytorch/pytorch/pull/72999 Reviewed By: albanD Differential Revision: D34543995 Pulled By: ngimel fbshipit-source-id: e32b1f266500a5e002c8f353b1e02f44c23d4f6e (cherry picked from commit a6253ce6bb8455a3c89398f12b7d790a0b7e8d95)
This commit is contained in:
committed by
PyTorch MergeBot
parent
9ad0578c59
commit
905efa82ff
@ -1462,6 +1462,78 @@ class TestOldViewOps(TestCase):
|
||||
actual = torch.broadcast_shapes(s0, s1)
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
inputs_list = [[1, 4], [4, 1], [1, 1, 3]]
|
||||
for integral_inputs in inputs_list:
|
||||
res1 = torch.broadcast_shapes(*integral_inputs)
|
||||
res2 = torch.broadcast_tensors(*map(torch.empty, integral_inputs))[0].shape
|
||||
self.assertEqual(res1, res2)
|
||||
|
||||
inputs_with_neg_vals = [[1, 1, -12], [-1, 1], [-11, ]]
|
||||
for integral_inputs_with_neg_vals in inputs_with_neg_vals:
|
||||
with self.assertRaisesRegex(RuntimeError, "Trying to create tensor with negative dimension"):
|
||||
torch.broadcast_shapes(*integral_inputs_with_neg_vals)
|
||||
|
||||
integral_inputs_error_case = [(3, 5), (2, 4, 1)]
|
||||
for error_input in integral_inputs_error_case:
|
||||
with self.assertRaisesRegex(RuntimeError, "Shape mismatch: objects cannot be broadcast to a single shape"):
|
||||
torch.broadcast_shapes(*error_input)
|
||||
|
||||
negative_inputs = [(-1,), (1, -12), (4, -11), (-4, 1), (1, 1, -2)]
|
||||
for s0 in negative_inputs:
|
||||
with self.assertRaisesRegex(RuntimeError, "Trying to create tensor with negative dimension"):
|
||||
torch.broadcast_shapes(s0)
|
||||
|
||||
for s1 in negative_inputs:
|
||||
with self.assertRaisesRegex(RuntimeError, "Trying to create tensor with negative dimension"):
|
||||
torch.broadcast_shapes(s0, s1)
|
||||
|
||||
float_inputs_error_case = [(1.1, 2.0), (1.1, 1.0)]
|
||||
for error_case in float_inputs_error_case:
|
||||
for float_input in error_case:
|
||||
with self.assertRaisesRegex(RuntimeError, "Input shapes "
|
||||
"should be of type ints, a tuple of ints, or a list of ints"):
|
||||
torch.broadcast_shapes(float_input)
|
||||
|
||||
diff_input_types = [(1, (5,)), (3, (1,)), (1, (3, 4))]
|
||||
for s0 in diff_input_types:
|
||||
res1 = torch.broadcast_shapes(*s0)
|
||||
res2 = torch.broadcast_tensors(*map(torch.empty, s0))[0].shape
|
||||
self.assertEqual(res1, res2)
|
||||
|
||||
# NOTE(review): compare version components numerically; the previous string
# comparison (np.__version__ < '1.20') mis-orders versions such as '1.9'.
@unittest.skipIf(tuple(int(p) for p in np.__version__.split('.')[:2]) < (1, 20),
                 "NumPy does not support broadcast_shapes before the 1.20 version")
@onlyCPU
def test_broadcast_shapes_numpy_ref(self, device):
    """Check torch.broadcast_shapes against np.broadcast_shapes on equal inputs.

    Covers single shapes, all pairs of shapes, variadic integral inputs,
    a single list passed as one shape, and mixed int/tuple inputs.
    """
    examples = [(), (1,), (2,), (1, 1), (3, 1), (3, 2), (4, 1, 1), (4, 3, 2)]
    for s0 in examples:
        actual = torch.broadcast_shapes(s0)
        numpy_expected = np.broadcast_shapes(s0)
        self.assertEqual(actual, numpy_expected)

        for s1 in examples:
            actual = torch.broadcast_shapes(s0, s1)
            numpy_expected = np.broadcast_shapes(s0, s1)
            self.assertEqual(actual, numpy_expected)

    inputs_list = [[1, 4], [4, 1], [1, 1, 3]]
    for integral_inputs in inputs_list:
        res1 = torch.broadcast_shapes(*integral_inputs)
        res2_numpy = np.broadcast_shapes(*integral_inputs)
        self.assertEqual(res1, res2_numpy)

    # A single list argument is interpreted as one shape, not unpacked.
    for list_inputs in inputs_list:
        res1 = torch.broadcast_shapes(list_inputs)
        res2 = np.broadcast_shapes(list_inputs)
        self.assertEqual(res1, res2)

    # Mixed input kinds: a bare int broadcast against a tuple shape.
    diff_input_types = [(1, (5,)), (3, (1,)), (1, (3, 4))]
    for s0 in diff_input_types:
        res1 = torch.broadcast_shapes(*s0)
        res2_numpy = np.broadcast_shapes(*s0)
        self.assertEqual(res1, res2_numpy)
|
||||
# Skip BFloat16 since numpy does not support it
|
||||
@dtypes(*get_all_dtypes(include_bfloat16=False))
|
||||
def test_broadcast_to(self, device, dtype):
|
||||
|
def broadcast_shapes(*shapes):
    r"""Return the shape that results from broadcasting the given shapes.

    Equivalent to ``broadcast_tensors(*map(torch.empty, shapes))[0].shape``
    but computed without materializing any tensors.

    Args:
        *shapes (torch.Size, int, or tuple/list of ints): shapes to broadcast.

    Returns:
        torch.Size: the broadcasted shape.

    Raises:
        RuntimeError: if any shape contains a negative dimension, if the
            shapes are not broadcastable to a single shape, or if an input
            is not an int or a tuple/list of ints.
    """
    # This wrapper exists to support variadic args.
    # TODO Move this to C++ once the jit has better support for torch.Size.
    if not torch.jit.is_tracing():
        # First pass: the result rank is the maximum rank of the inputs
        # (a bare int counts as a 1-d shape).
        max_len = 0
        for shape in shapes:
            if isinstance(shape, int):
                if max_len < 1:
                    max_len = 1
            elif isinstance(shape, tuple) or isinstance(shape, list):
                s = len(shape)
                if max_len < s:
                    max_len = s
        result = [1] * max_len
        # Second pass: apply broadcasting rules right-aligned, dim by dim.
        for shape in shapes:
            if isinstance(shape, int):
                shape = (shape,)
            if isinstance(shape, tuple) or isinstance(shape, list):
                for i in range(-1, -1 - len(shape), -1):
                    if shape[i] < 0:
                        raise RuntimeError("Trying to create tensor with negative dimension ({}): ({})"
                                           .format(shape[i], shape[i]))
                    if shape[i] == 1 or shape[i] == result[i]:
                        continue
                    if result[i] != 1:
                        raise RuntimeError("Shape mismatch: objects cannot be broadcast to a single shape")
                    result[i] = shape[i]
            else:
                # Format the offending value into the message instead of passing
                # it as a second exception argument (which renders as a tuple).
                raise RuntimeError("Input shapes should be of type ints, a tuple of ints, or a list of ints, got {}"
                                   .format(shape))
        return torch.Size(result)
    else:
        # With the eager implementation above, torch.jit.trace hardcodes the
        # sizes, which makes subsequent replays fail; fall back to actually
        # broadcasting zero-filled scalars.
        with torch.no_grad():
            scalar = torch.zeros((), device="cpu")
            tensors = [scalar.expand(shape) for shape in shapes]
            tensors = broadcast_tensors(*tensors)
            return tensors[0].shape
||||
def split(tensor, split_size_or_sections, dim=0):
|
||||
|
@ -8739,12 +8739,14 @@ op_db: List[OpInfo] = [
|
||||
),
|
||||
supports_out=False),
|
||||
OpInfo('broadcast_to',
|
||||
ref=np.broadcast_to,
|
||||
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
|
||||
supports_out=False,
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
sample_inputs_func=sample_inputs_broadcast_to),
|
||||
OpInfo('broadcast_tensors',
|
||||
ref=np.broadcast_arrays,
|
||||
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
|
||||
supports_out=False,
|
||||
supports_forward_ad=True,
|
||||
|
Reference in New Issue
Block a user