unify broadcast_shapes functions and avoid duplicates (#160251)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160251
Approved by: https://github.com/jingsh, https://github.com/ColinPeppler
ghstack dependencies: #160250
Author: Laith Sakka
Date: 2025-08-15 13:58:19 -07:00
Committed by: PyTorch MergeBot
Parent: c03809e8a5
Commit: 65dc4df74d

3 changed files with 29 additions and 65 deletions
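
This change makes the Python wrapper in torch/functional.py delegate to torch._refs._broadcast_shapes, so the two implementations can no longer drift apart. The user-visible effect, as the updated tests below show, is in how invalid shapes are reported. A small sketch of the behavior after this change (the commented outputs follow the new test expectations):

```python
import torch

# Ordinary broadcasting is unchanged: shapes align from the right and
# dimensions of length 1 stretch to match the others.
torch.broadcast_shapes((7, 1, 1), (2, 1), (3,))  # torch.Size([7, 2, 3])

# Negative lengths now raise ValueError instead of the old RuntimeError
# ("Trying to create tensor with negative dimension").
try:
    torch.broadcast_shapes((1, -12))
except ValueError as e:
    print(e)  # Attempting to broadcast a dimension with negative length!
```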

test/test_view_ops.py

@@ -1656,7 +1656,7 @@ class TestOldViewOps(TestCase):
         inputs_with_neg_vals = [[1, 1, -12], [-1, 1], [-11]]
         for integral_inputs_with_neg_vals in inputs_with_neg_vals:
             with self.assertRaisesRegex(
-                RuntimeError, "Trying to create tensor with negative dimension"
+                ValueError, "Attempting to broadcast a dimension with negative length!"
             ):
                 torch.broadcast_shapes(*integral_inputs_with_neg_vals)
@@ -1664,20 +1664,21 @@ class TestOldViewOps(TestCase):
         for error_input in integral_inputs_error_case:
             with self.assertRaisesRegex(
                 RuntimeError,
-                "Shape mismatch: objects cannot be broadcast to a single shape",
+                ".*expected shape should be broadcastable to*",
             ):
                 torch.broadcast_shapes(*error_input)

         negative_inputs = [(-1,), (1, -12), (4, -11), (-4, 1), (1, 1, -2)]
         for s0 in negative_inputs:
             with self.assertRaisesRegex(
-                RuntimeError, "Trying to create tensor with negative dimension"
+                ValueError, "Attempting to broadcast a dimension with negative length!"
             ):
                 torch.broadcast_shapes(s0)

             for s1 in negative_inputs:
                 with self.assertRaisesRegex(
-                    RuntimeError, "Trying to create tensor with negative dimension"
+                    ValueError,
+                    "Attempting to broadcast a dimension with negative length!",
                 ):
                     torch.broadcast_shapes(s0, s1)
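
The mismatch case above still raises RuntimeError, but with the wording from the _refs path rather than the old "Shape mismatch" message; roughly:

```python
import torch

# 3 vs 4: neither dimension is 1 and they differ, so broadcasting fails.
try:
    torch.broadcast_shapes((3,), (4,))
except RuntimeError as e:
    print(e)  # matched by the test regex ".*expected shape should be broadcastable to*"
```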

torch/_refs/__init__.py

@@ -385,7 +385,7 @@ def handle_noncontiguous_outputs(input_tlist, output):
 def _broadcast_shapes(*_shapes):
-    from torch.fx.experimental.symbolic_shapes import guard_or_false
+    from torch.fx.experimental.symbolic_shapes import guard_or_false, is_nested_int

     shapes = tuple(
         (x,) if isinstance(x, IntLike) else x
@@ -396,10 +396,12 @@ def _broadcast_shapes(*_shapes):
     if len(shapes) == 0:
         return None

     # Type checking
     # TODO: make common validations available as utils
     for shape in shapes:
-        assert isinstance(shape, Sequence)
+        if not isinstance(shape, Sequence):
+            raise RuntimeError(
+                "Input shapes should be of type ints, a tuple of ints, or a list of ints, got ",
+                shape,
+            )

     # Computes common shape
     common_shape: list[Union[int, torch.SymInt]] = [
@@ -407,16 +409,26 @@ def _broadcast_shapes(*_shapes):
     ] * reduce(max, (len(shape) for shape in shapes))
     for arg_idx, shape in enumerate(shapes):
         for idx in range(-1, -1 - len(shape), -1):
-            # if both 1, or statically known the same, we rather pick non-broadcast path.
-            if guard_or_false(common_shape[idx] == shape[idx]):
-                continue
-            elif guard_or_false(common_shape[idx] == 1):
+            # NB: handle nested ints specially to avoid invalid guarding on Ne(j0, 1).
+            if is_nested_int(shape[idx]):
+                # Broadcasting is allowed for (j0, 1) or (j0, j0);
+                # not (j0, j1), (j0, 5), etc.
+                if is_nested_int(common_shape[idx]) and guard_or_false(
+                    shape[idx] == common_shape[idx]
+                ):
+                    continue
+            else:
+                if guard_or_false(shape[idx] == common_shape[idx]):
+                    continue
+
+            if guard_or_false(common_shape[idx] == 1):
                 if shape[idx] < 0:
                     raise ValueError(
                         "Attempting to broadcast a dimension with negative length!"
                     )
                 common_shape[idx] = shape[idx]
+
-            elif guard_or_false(shape[idx] == 1):
+            if not is_nested_int(shape[idx]) and guard_or_false(shape[idx] == 1):
                 # broadcast case .
                 continue
             else:
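
For context, guard_or_false(expr) answers the comparison when it can be decided at trace time and otherwise falls back to False rather than guarding, which is why every branch above is a positive, statically decidable check. For ordinary (non-nested) ints the per-dimension logic reduces to the following sketch (merge_dim is a hypothetical helper, with plain Python ints standing in for SymInts):

```python
def merge_dim(common: int, s: int) -> int:
    # Condensed mirror of the loop body above, non-nested ints only.
    if common == s:   # statically known equal: keep the current value
        return common
    if common == 1:   # the accumulated side broadcasts: adopt s
        if s < 0:
            raise ValueError(
                "Attempting to broadcast a dimension with negative length!"
            )
        return s
    if s == 1:        # the incoming side broadcasts
        return common
    # Neither side is 1 and they differ: the shapes are incompatible.
    raise RuntimeError(f"expected shape should be broadcastable, got {common} and {s}")
```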

torch/functional.py

@@ -105,58 +105,9 @@ def broadcast_shapes(*shapes):
     # This wrapper exists to support variadic args.
     # TODO Move this to C++ once the jit has better support for torch.Size.
     if not torch.jit.is_tracing():
-        max_len = 0
-        for shape in shapes:
-            if isinstance(shape, (int, torch.SymInt)):
-                if max_len < 1:
-                    max_len = 1
-            elif isinstance(shape, (tuple, list)):
-                s = len(shape)
-                if max_len < s:
-                    max_len = s
-        result = [1] * max_len
-        from torch.fx.experimental.symbolic_shapes import (
-            guard_size_oblivious,
-            is_nested_int,
-        )
-        for shape in shapes:
-            if isinstance(shape, (int, torch.SymInt)):
-                shape = (shape,)
-            if isinstance(shape, (tuple, list)):
-                for i in range(-1, -1 - len(shape), -1):
-                    if shape[i] < 0:
-                        raise RuntimeError(
-                            f"Trying to create tensor with negative dimension ({shape[i]}): ({shape[i]})"
-                        )
-                    # NB: handle nested ints specially to avoid invalid guarding on Ne(j0, 1).
-                    if is_nested_int(shape[i]):
-                        # Broadcasting is allowed for (j0, 1) or (j0, j0);
-                        # not (j0, j1), (j0, 5), etc.
-                        if is_nested_int(result[i]) and guard_size_oblivious(
-                            shape[i] == result[i]
-                        ):
-                            continue
-                    else:
-                        # NB: result is initialized to 1 so this is effectively an
-                        # equals one test
-                        if guard_size_oblivious(shape[i] == 1) or guard_size_oblivious(
-                            shape[i] == result[i]
-                        ):
-                            continue
-                    if result[i] != 1:
-                        raise RuntimeError(
-                            "Shape mismatch: objects cannot be broadcast to a single shape"
-                        )
-                    result[i] = shape[i]
-            else:
-                raise RuntimeError(
-                    "Input shapes should be of type ints, a tuple of ints, or a list of ints, got ",
-                    shape,
-                )
+        result = torch._refs._broadcast_shapes(*shapes)
+        if result is None:
+            return torch.Size([])
         return torch.Size(result)
     else:
         # with implementation above, torch.jit.trace hardcodes the sizes which makes subsequent replays fail
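
With the wrapper reduced to a delegation, the remaining edge cases are explicit: _broadcast_shapes returns None when called with no shapes, which the wrapper maps to an empty torch.Size, and bare ints are still accepted as one-dimensional shapes (the IntLike handling shown above). For example:

```python
import torch

torch.broadcast_shapes()             # torch.Size([]), the scalar shape
torch.broadcast_shapes(5)            # torch.Size([5]), a bare int is a 1-d shape
torch.broadcast_shapes((2, 1), 3)    # torch.Size([2, 3])
```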