mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[primTorch] Adds broadcast_shapes reference (#78612)
1. Added references `_refs.broadcast_shapes` 2. Added OpInfo test for `torch.broadcast_shapes` A few minor changes: - `test_python_ref_meta` and `_ref_test_helper` update to avoid non-tensor outputs - type annotation update for `_resize_meta` Pull Request resolved: https://github.com/pytorch/pytorch/pull/78612 Approved by: https://github.com/mruberry
This commit is contained in:
committed by
PyTorch MergeBot
parent
4858c56334
commit
fea909b43e
@ -371,7 +371,8 @@ class TestCommon(TestCase):
|
||||
prims.utils.compare_tensor_meta(result, meta_result)
|
||||
elif isinstance(result, Sequence):
|
||||
for a, b in zip(result, meta_result):
|
||||
prims.utils.compare_tensor_meta(a, b)
|
||||
if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor):
|
||||
prims.utils.compare_tensor_meta(a, b)
|
||||
|
||||
def _ref_test_helper(self, ctx, device, dtype, op):
|
||||
if dtype is torch.chalf:
|
||||
@ -385,9 +386,10 @@ class TestCommon(TestCase):
|
||||
torch_result = op.torch_opinfo(sample.input, *sample.args, **sample.kwargs)
|
||||
|
||||
for a, b in zip(tree_flatten(ref_result)[0], tree_flatten(torch_result)[0]):
|
||||
prims.utils.compare_tensor_meta(a, b)
|
||||
if getattr(op, 'validate_view_consistency', True):
|
||||
self.assertEqual(a._is_view(), b._is_view())
|
||||
if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor):
|
||||
prims.utils.compare_tensor_meta(a, b)
|
||||
if getattr(op, 'validate_view_consistency', True):
|
||||
self.assertEqual(a._is_view(), b._is_view())
|
||||
|
||||
# Computes the dtype the more precise computatino would occur in
|
||||
precise_dtype = torch.bool
|
||||
|
@ -1533,40 +1533,6 @@ class TestOldViewOps(TestCase):
|
||||
res2 = torch.broadcast_tensors(*map(torch.empty, s0))[0].shape
|
||||
self.assertEqual(res1, res2)
|
||||
|
||||
@unittest.skipIf(np.__version__ < '1.20',
|
||||
"NumPy does not support broadcast_shapes before the 1.20 version")
|
||||
@onlyCPU
|
||||
def test_broadcast_shapes_numpy_ref(self, device):
|
||||
examples = [(), (1,), (2,), (1, 1), (3, 1), (3, 2), (4, 1, 1), (4, 3, 2)]
|
||||
for s0 in examples:
|
||||
x0 = torch.randn(s0)
|
||||
actual = torch.broadcast_shapes(s0)
|
||||
numpy_expected = np.broadcast_shapes(s0)
|
||||
self.assertEqual(actual, numpy_expected)
|
||||
|
||||
for s1 in examples:
|
||||
x1 = torch.randn(s1)
|
||||
actual = torch.broadcast_shapes(s0, s1)
|
||||
numpy_expected = np.broadcast_shapes(s0, s1)
|
||||
self.assertEqual(actual, numpy_expected)
|
||||
|
||||
inputs_list = [[1, 4], [4, 1], [1, 1, 3]]
|
||||
for integral_inputs in inputs_list:
|
||||
res1 = torch.broadcast_shapes(*integral_inputs)
|
||||
res2_numpy = np.broadcast_shapes(*integral_inputs)
|
||||
self.assertEqual(res1, res2_numpy)
|
||||
|
||||
for list_inputs in inputs_list:
|
||||
res1 = torch.broadcast_shapes(list_inputs)
|
||||
res2 = np.broadcast_shapes(list_inputs)
|
||||
self.assertEqual(res1, res2)
|
||||
|
||||
diff_input_types = [(1, (5,)), (3, (1,)), (1, (3, 4))]
|
||||
for s0 in diff_input_types:
|
||||
res1 = torch.broadcast_shapes(*s0)
|
||||
res2_numpy = np.broadcast_shapes(*s0)
|
||||
self.assertEqual(res1, res2_numpy)
|
||||
|
||||
# Skip BFloat16 since numpy does not support it
|
||||
@dtypes(*all_types_and_complex_and(torch.half, torch.bool))
|
||||
def test_broadcast_to(self, device, dtype):
|
||||
|
@ -2176,9 +2176,7 @@ copy_to = _make_prim(
|
||||
)
|
||||
|
||||
|
||||
def _resize_meta(
|
||||
a: TensorLikeType, shape: Union[torch.Size, List[int], Tuple[int, ...]]
|
||||
):
|
||||
def _resize_meta(a: TensorLikeType, shape: ShapeType):
|
||||
return TensorMeta(a, shape=shape, strides=utils.make_contiguous_strides_for(shape))
|
||||
|
||||
|
||||
|
@ -171,6 +171,7 @@ __all__ = [
|
||||
# View & Shape Ops
|
||||
#
|
||||
"as_strided",
|
||||
"broadcast_shapes",
|
||||
"broadcast_tensors",
|
||||
"broadcast_to",
|
||||
"cat",
|
||||
@ -216,7 +217,10 @@ Tensor = torch.Tensor
|
||||
|
||||
|
||||
def _broadcast_shapes(*_shapes):
|
||||
shapes = tuple(filter(lambda x: x is not None, _shapes))
|
||||
shapes = tuple(
|
||||
(x,) if isinstance(x, int) else x
|
||||
for x in filter(lambda x: x is not None, _shapes)
|
||||
)
|
||||
|
||||
# Short-circuits on no input
|
||||
if len(shapes) == 0:
|
||||
@ -1568,6 +1572,10 @@ def as_strided(
|
||||
return prims.as_strided(a, size, stride, storage_offset)
|
||||
|
||||
|
||||
def broadcast_shapes(*shapes) -> ShapeType:
|
||||
return torch.Size(_broadcast_shapes(*shapes))
|
||||
|
||||
|
||||
def broadcast_tensors(*tensors) -> List[TensorLikeType]:
|
||||
return list(_maybe_broadcast(*tensors, preserve_cpu_scalar_tensors=False))
|
||||
|
||||
|
@ -2364,6 +2364,24 @@ def sample_inputs_jiterator(op, device, dtype, requires_grad, **kwargs):
|
||||
|
||||
yield SampleInput(lhs, args=tuple(args), kwargs=sample_kwargs, broadcasts_input=broadcasts_input)
|
||||
|
||||
def sample_inputs_broadcast_shapes(op, device, dtype, requires_grad, **kwargs):
|
||||
shapes = (
|
||||
((), ()),
|
||||
((S,), ()),
|
||||
((S, 1), (S,)),
|
||||
((S, 1), S),
|
||||
((M, S), ()),
|
||||
((S, M, S), (M, S)),
|
||||
((S, M, S), (S, M, S)),
|
||||
((M, 1, S), (M, S)),
|
||||
((M, 1, S), (1, M, S)),
|
||||
((0, 1, 3), (0, 10, 3))
|
||||
)
|
||||
|
||||
for shape in shapes:
|
||||
inp, *arg0 = shape
|
||||
yield SampleInput(inp, args=arg0)
|
||||
|
||||
# The base reference input generation for elementwise binary operations
|
||||
def _reference_inputs_elementwise_binary(op, device, dtype, requires_grad, exclude_zero, **kwargs):
|
||||
yield from op.sample_inputs_func(op, device, dtype, requires_grad, **kwargs)
|
||||
@ -10648,6 +10666,27 @@ op_db: List[OpInfo] = [
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
sample_inputs_func=sample_inputs_broadcast_to),
|
||||
OpInfo('broadcast_shapes',
|
||||
op=torch.broadcast_shapes,
|
||||
ref=np.broadcast_shapes if np.lib.NumpyVersion(np.__version__) >= '1.20.0' else None,
|
||||
dtypes=_dispatch_dtypes((torch.float32,)),
|
||||
supports_out=False,
|
||||
supports_gradgrad=False,
|
||||
assert_autodiffed=False,
|
||||
supports_autograd=False,
|
||||
supports_scripting=False,
|
||||
sample_inputs_func=sample_inputs_broadcast_shapes,
|
||||
skips=(
|
||||
# https://github.com/pytorch/pytorch/issues/64997
|
||||
DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
|
||||
# skip dtype tests since broadcast_shape is not device dependent.
|
||||
# having dtypes limited to torch.float32 would cause test_dtypes to report unexpected success
|
||||
DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_dtypes'),
|
||||
# skip these tests since we have non tensor input
|
||||
DecorateInfo(unittest.skip('Skipped!'), "TestCommon", "test_noncontiguous_samples"),
|
||||
DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'),
|
||||
DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'),
|
||||
)),
|
||||
OpInfo('broadcast_tensors',
|
||||
ref=np.broadcast_arrays,
|
||||
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
|
||||
@ -19783,6 +19822,10 @@ python_ref_db = [
|
||||
DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_conj_view'),
|
||||
),
|
||||
),
|
||||
PythonRefInfo(
|
||||
"_refs.broadcast_shapes",
|
||||
torch_opinfo_name="broadcast_shapes",
|
||||
),
|
||||
PythonRefInfo(
|
||||
"_refs.broadcast_tensors",
|
||||
torch_opinfo_name="broadcast_tensors",
|
||||
|
Reference in New Issue
Block a user