mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[primTorch] Adds broadcast_shapes reference (#78612)
1. Added references `_refs.broadcast_shapes` 2. Added OpInfo test for `torch.broadcast_shapes` A few minor changes: - `test_python_ref_meta` and `_ref_test_helper` update to avoid non-tensor outputs - type annotation update for `_resize_meta` Pull Request resolved: https://github.com/pytorch/pytorch/pull/78612 Approved by: https://github.com/mruberry
This commit is contained in:
committed by
PyTorch MergeBot
parent
4858c56334
commit
fea909b43e
@ -1533,40 +1533,6 @@ class TestOldViewOps(TestCase):
|
||||
res2 = torch.broadcast_tensors(*map(torch.empty, s0))[0].shape
|
||||
self.assertEqual(res1, res2)
|
||||
|
||||
@unittest.skipIf(np.__version__ < '1.20',
|
||||
"NumPy does not support broadcast_shapes before the 1.20 version")
|
||||
@onlyCPU
|
||||
def test_broadcast_shapes_numpy_ref(self, device):
|
||||
examples = [(), (1,), (2,), (1, 1), (3, 1), (3, 2), (4, 1, 1), (4, 3, 2)]
|
||||
for s0 in examples:
|
||||
x0 = torch.randn(s0)
|
||||
actual = torch.broadcast_shapes(s0)
|
||||
numpy_expected = np.broadcast_shapes(s0)
|
||||
self.assertEqual(actual, numpy_expected)
|
||||
|
||||
for s1 in examples:
|
||||
x1 = torch.randn(s1)
|
||||
actual = torch.broadcast_shapes(s0, s1)
|
||||
numpy_expected = np.broadcast_shapes(s0, s1)
|
||||
self.assertEqual(actual, numpy_expected)
|
||||
|
||||
inputs_list = [[1, 4], [4, 1], [1, 1, 3]]
|
||||
for integral_inputs in inputs_list:
|
||||
res1 = torch.broadcast_shapes(*integral_inputs)
|
||||
res2_numpy = np.broadcast_shapes(*integral_inputs)
|
||||
self.assertEqual(res1, res2_numpy)
|
||||
|
||||
for list_inputs in inputs_list:
|
||||
res1 = torch.broadcast_shapes(list_inputs)
|
||||
res2 = np.broadcast_shapes(list_inputs)
|
||||
self.assertEqual(res1, res2)
|
||||
|
||||
diff_input_types = [(1, (5,)), (3, (1,)), (1, (3, 4))]
|
||||
for s0 in diff_input_types:
|
||||
res1 = torch.broadcast_shapes(*s0)
|
||||
res2_numpy = np.broadcast_shapes(*s0)
|
||||
self.assertEqual(res1, res2_numpy)
|
||||
|
||||
# Skip BFloat16 since numpy does not support it
|
||||
@dtypes(*all_types_and_complex_and(torch.half, torch.bool))
|
||||
def test_broadcast_to(self, device, dtype):
|
||||
|
Reference in New Issue
Block a user