[primTorch] Adds broadcast_shapes reference (#78612)

1. Added references `_refs.broadcast_shapes`
2. Added OpInfo test for `torch.broadcast_shapes`

A few minor changes:
- `test_python_ref_meta` and `_ref_test_helper` update to avoid non-tensor outputs
- type annotation update for `_resize_meta`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78612
Approved by: https://github.com/mruberry
This commit is contained in:
jjsjann123
2022-06-02 08:56:37 +00:00
committed by PyTorch MergeBot
parent 4858c56334
commit fea909b43e
5 changed files with 59 additions and 42 deletions

View File

@ -1533,40 +1533,6 @@ class TestOldViewOps(TestCase):
res2 = torch.broadcast_tensors(*map(torch.empty, s0))[0].shape
self.assertEqual(res1, res2)
@unittest.skipIf(np.__version__ < '1.20',
"NumPy does not support broadcast_shapes before the 1.20 version")
@onlyCPU
def test_broadcast_shapes_numpy_ref(self, device):
examples = [(), (1,), (2,), (1, 1), (3, 1), (3, 2), (4, 1, 1), (4, 3, 2)]
for s0 in examples:
x0 = torch.randn(s0)
actual = torch.broadcast_shapes(s0)
numpy_expected = np.broadcast_shapes(s0)
self.assertEqual(actual, numpy_expected)
for s1 in examples:
x1 = torch.randn(s1)
actual = torch.broadcast_shapes(s0, s1)
numpy_expected = np.broadcast_shapes(s0, s1)
self.assertEqual(actual, numpy_expected)
inputs_list = [[1, 4], [4, 1], [1, 1, 3]]
for integral_inputs in inputs_list:
res1 = torch.broadcast_shapes(*integral_inputs)
res2_numpy = np.broadcast_shapes(*integral_inputs)
self.assertEqual(res1, res2_numpy)
for list_inputs in inputs_list:
res1 = torch.broadcast_shapes(list_inputs)
res2 = np.broadcast_shapes(list_inputs)
self.assertEqual(res1, res2)
diff_input_types = [(1, (5,)), (3, (1,)), (1, (3, 4))]
for s0 in diff_input_types:
res1 = torch.broadcast_shapes(*s0)
res2_numpy = np.broadcast_shapes(*s0)
self.assertEqual(res1, res2_numpy)
# Skip BFloat16 since numpy does not support it
@dtypes(*all_types_and_complex_and(torch.half, torch.bool))
def test_broadcast_to(self, device, dtype):