mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
unify broadcast_shapes functions and avoid duplicates (#160251)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160251 Approved by: https://github.com/jingsh, https://github.com/ColinPeppler ghstack dependencies: #160250
This commit is contained in:
committed by
PyTorch MergeBot
parent
c03809e8a5
commit
65dc4df74d
@ -1656,7 +1656,7 @@ class TestOldViewOps(TestCase):
|
||||
inputs_with_neg_vals = [[1, 1, -12], [-1, 1], [-11]]
|
||||
for integral_inputs_with_neg_vals in inputs_with_neg_vals:
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Trying to create tensor with negative dimension"
|
||||
ValueError, "Attempting to broadcast a dimension with negative length!"
|
||||
):
|
||||
torch.broadcast_shapes(*integral_inputs_with_neg_vals)
|
||||
|
||||
@ -1664,20 +1664,21 @@ class TestOldViewOps(TestCase):
|
||||
for error_input in integral_inputs_error_case:
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Shape mismatch: objects cannot be broadcast to a single shape",
|
||||
".*expected shape should be broadcastable to*",
|
||||
):
|
||||
torch.broadcast_shapes(*error_input)
|
||||
|
||||
negative_inputs = [(-1,), (1, -12), (4, -11), (-4, 1), (1, 1, -2)]
|
||||
for s0 in negative_inputs:
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Trying to create tensor with negative dimension"
|
||||
ValueError, "Attempting to broadcast a dimension with negative length!"
|
||||
):
|
||||
torch.broadcast_shapes(s0)
|
||||
|
||||
for s1 in negative_inputs:
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Trying to create tensor with negative dimension"
|
||||
ValueError,
|
||||
"Attempting to broadcast a dimension with negative length!",
|
||||
):
|
||||
torch.broadcast_shapes(s0, s1)
|
||||
|
||||
|
@ -385,7 +385,7 @@ def handle_noncontiguous_outputs(input_tlist, output):
|
||||
|
||||
|
||||
def _broadcast_shapes(*_shapes):
|
||||
from torch.fx.experimental.symbolic_shapes import guard_or_false
|
||||
from torch.fx.experimental.symbolic_shapes import guard_or_false, is_nested_int
|
||||
|
||||
shapes = tuple(
|
||||
(x,) if isinstance(x, IntLike) else x
|
||||
@ -396,10 +396,12 @@ def _broadcast_shapes(*_shapes):
|
||||
if len(shapes) == 0:
|
||||
return None
|
||||
|
||||
# Type checking
|
||||
# TODO: make common validations available as utils
|
||||
for shape in shapes:
|
||||
assert isinstance(shape, Sequence)
|
||||
if not isinstance(shape, Sequence):
|
||||
raise RuntimeError(
|
||||
"Input shapes should be of type ints, a tuple of ints, or a list of ints, got ",
|
||||
shape,
|
||||
)
|
||||
|
||||
# Computes common shape
|
||||
common_shape: list[Union[int, torch.SymInt]] = [
|
||||
@ -407,16 +409,26 @@ def _broadcast_shapes(*_shapes):
|
||||
] * reduce(max, (len(shape) for shape in shapes))
|
||||
for arg_idx, shape in enumerate(shapes):
|
||||
for idx in range(-1, -1 - len(shape), -1):
|
||||
# if both 1, or statically known the same, we rather pick non-broadcast path.
|
||||
if guard_or_false(common_shape[idx] == shape[idx]):
|
||||
# NB: handle nested ints specially to avoid invalid guarding on Ne(j0, 1).
|
||||
if is_nested_int(shape[idx]):
|
||||
# Broadcasting is allowed for (j0, 1) or (j0, j0);
|
||||
# not (j0, j1), (j0, 5), etc.
|
||||
if is_nested_int(common_shape[idx]) and guard_or_false(
|
||||
shape[idx] == common_shape[idx]
|
||||
):
|
||||
continue
|
||||
elif guard_or_false(common_shape[idx] == 1):
|
||||
else:
|
||||
if guard_or_false(shape[idx] == common_shape[idx]):
|
||||
continue
|
||||
|
||||
if guard_or_false(common_shape[idx] == 1):
|
||||
if shape[idx] < 0:
|
||||
raise ValueError(
|
||||
"Attempting to broadcast a dimension with negative length!"
|
||||
)
|
||||
common_shape[idx] = shape[idx]
|
||||
elif guard_or_false(shape[idx] == 1):
|
||||
|
||||
if not is_nested_int(shape[idx]) and guard_or_false(shape[idx] == 1):
|
||||
# broadcast case .
|
||||
continue
|
||||
else:
|
||||
|
@ -105,58 +105,9 @@ def broadcast_shapes(*shapes):
|
||||
# This wrapper exists to support variadic args.
|
||||
# TODO Move this to C++ once the jit has better support for torch.Size.
|
||||
if not torch.jit.is_tracing():
|
||||
max_len = 0
|
||||
for shape in shapes:
|
||||
if isinstance(shape, (int, torch.SymInt)):
|
||||
if max_len < 1:
|
||||
max_len = 1
|
||||
elif isinstance(shape, (tuple, list)):
|
||||
s = len(shape)
|
||||
if max_len < s:
|
||||
max_len = s
|
||||
result = [1] * max_len
|
||||
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
guard_size_oblivious,
|
||||
is_nested_int,
|
||||
)
|
||||
|
||||
for shape in shapes:
|
||||
if isinstance(shape, (int, torch.SymInt)):
|
||||
shape = (shape,)
|
||||
if isinstance(shape, (tuple, list)):
|
||||
for i in range(-1, -1 - len(shape), -1):
|
||||
if shape[i] < 0:
|
||||
raise RuntimeError(
|
||||
f"Trying to create tensor with negative dimension ({shape[i]}): ({shape[i]})"
|
||||
)
|
||||
|
||||
# NB: handle nested ints specially to avoid invalid guarding on Ne(j0, 1).
|
||||
if is_nested_int(shape[i]):
|
||||
# Broadcasting is allowed for (j0, 1) or (j0, j0);
|
||||
# not (j0, j1), (j0, 5), etc.
|
||||
if is_nested_int(result[i]) and guard_size_oblivious(
|
||||
shape[i] == result[i]
|
||||
):
|
||||
continue
|
||||
else:
|
||||
# NB: result is initialized to 1 so this is effectively an
|
||||
# equals one test
|
||||
if guard_size_oblivious(shape[i] == 1) or guard_size_oblivious(
|
||||
shape[i] == result[i]
|
||||
):
|
||||
continue
|
||||
|
||||
if result[i] != 1:
|
||||
raise RuntimeError(
|
||||
"Shape mismatch: objects cannot be broadcast to a single shape"
|
||||
)
|
||||
result[i] = shape[i]
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Input shapes should be of type ints, a tuple of ints, or a list of ints, got ",
|
||||
shape,
|
||||
)
|
||||
result = torch._refs._broadcast_shapes(*shapes)
|
||||
if result is None:
|
||||
return torch.Size([])
|
||||
return torch.Size(result)
|
||||
else:
|
||||
# with implementation above, torch.jit.trace hardcodes the sizes which makes subsequent replays fail
|
||||
|
Reference in New Issue
Block a user