[primTorch] Adds broadcast_shapes reference (#78612)

1. Added reference `_refs.broadcast_shapes`
2. Added an OpInfo test for `torch.broadcast_shapes`

A few minor changes:
- updated `test_python_ref_meta` and `_ref_test_helper` to skip meta checks on non-tensor outputs
- simplified the type annotation of `_resize_meta`
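
For context, `torch.broadcast_shapes` computes the shape that broadcasting its arguments together would produce, without allocating any tensors; the new ref mirrors this behavior. A minimal illustration:

import torch

# Shapes are right-aligned; size-1 dimensions broadcast against larger ones.
print(torch.broadcast_shapes((4, 1, 1), (3, 2)))  # torch.Size([4, 3, 2])
print(torch.broadcast_shapes((5, 1), (5,)))       # torch.Size([5, 5])
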
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78612
Approved by: https://github.com/mruberry
Author: jjsjann123
Date: 2022-06-02 08:56:37 +00:00
Committed by: PyTorch MergeBot
Parent: 4858c56334
Commit: fea909b43e

5 changed files with 59 additions and 42 deletions

test/test_ops.py

@@ -371,7 +371,8 @@ class TestCommon(TestCase):
             prims.utils.compare_tensor_meta(result, meta_result)
         elif isinstance(result, Sequence):
             for a, b in zip(result, meta_result):
-                prims.utils.compare_tensor_meta(a, b)
+                if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor):
+                    prims.utils.compare_tensor_meta(a, b)
 
     def _ref_test_helper(self, ctx, device, dtype, op):
         if dtype is torch.chalf:
@@ -385,9 +386,10 @@ class TestCommon(TestCase):
             torch_result = op.torch_opinfo(sample.input, *sample.args, **sample.kwargs)
 
             for a, b in zip(tree_flatten(ref_result)[0], tree_flatten(torch_result)[0]):
-                prims.utils.compare_tensor_meta(a, b)
-                if getattr(op, 'validate_view_consistency', True):
-                    self.assertEqual(a._is_view(), b._is_view())
+                if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor):
+                    prims.utils.compare_tensor_meta(a, b)
+                    if getattr(op, 'validate_view_consistency', True):
+                        self.assertEqual(a._is_view(), b._is_view())
 
         # Computes the dtype the more precise computation would occur in
         precise_dtype = torch.bool
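
The guard is needed because `broadcast_shapes` returns a `torch.Size` rather than a `Tensor`, so running `compare_tensor_meta` on its flattened outputs would fail. A small sketch of the distinction:

import torch

# broadcast_shapes yields a torch.Size (a tuple subclass), not a Tensor,
# so tensor-meta comparisons must be skipped for its outputs.
out = torch.broadcast_shapes((4, 1), (3,))
print(type(out), isinstance(out, torch.Tensor))  # <class 'torch.Size'> False
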

test/test_view_ops.py

@@ -1533,40 +1533,6 @@ class TestOldViewOps(TestCase):
             res2 = torch.broadcast_tensors(*map(torch.empty, s0))[0].shape
             self.assertEqual(res1, res2)
 
-    @unittest.skipIf(np.__version__ < '1.20',
-                     "NumPy does not support broadcast_shapes before the 1.20 version")
-    @onlyCPU
-    def test_broadcast_shapes_numpy_ref(self, device):
-        examples = [(), (1,), (2,), (1, 1), (3, 1), (3, 2), (4, 1, 1), (4, 3, 2)]
-        for s0 in examples:
-            x0 = torch.randn(s0)
-            actual = torch.broadcast_shapes(s0)
-            numpy_expected = np.broadcast_shapes(s0)
-            self.assertEqual(actual, numpy_expected)
-            for s1 in examples:
-                x1 = torch.randn(s1)
-                actual = torch.broadcast_shapes(s0, s1)
-                numpy_expected = np.broadcast_shapes(s0, s1)
-                self.assertEqual(actual, numpy_expected)
-
-        inputs_list = [[1, 4], [4, 1], [1, 1, 3]]
-        for integral_inputs in inputs_list:
-            res1 = torch.broadcast_shapes(*integral_inputs)
-            res2_numpy = np.broadcast_shapes(*integral_inputs)
-            self.assertEqual(res1, res2_numpy)
-
-        for list_inputs in inputs_list:
-            res1 = torch.broadcast_shapes(list_inputs)
-            res2 = np.broadcast_shapes(list_inputs)
-            self.assertEqual(res1, res2)
-
-        diff_input_types = [(1, (5,)), (3, (1,)), (1, (3, 4))]
-        for s0 in diff_input_types:
-            res1 = torch.broadcast_shapes(*s0)
-            res2_numpy = np.broadcast_shapes(*s0)
-            self.assertEqual(res1, res2_numpy)
-
     # Skip BFloat16 since numpy does not support it
     @dtypes(*all_types_and_complex_and(torch.half, torch.bool))
     def test_broadcast_to(self, device, dtype):

torch/_prims/__init__.py

@@ -2176,9 +2176,7 @@ copy_to = _make_prim(
 )
 
-def _resize_meta(
-    a: TensorLikeType, shape: Union[torch.Size, List[int], Tuple[int, ...]]
-):
+def _resize_meta(a: TensorLikeType, shape: ShapeType):
     return TensorMeta(a, shape=shape, strides=utils.make_contiguous_strides_for(shape))

torch/_refs/__init__.py

@@ -171,6 +171,7 @@ __all__ = [
     # View & Shape Ops
     #
     "as_strided",
+    "broadcast_shapes",
     "broadcast_tensors",
     "broadcast_to",
     "cat",
@@ -216,7 +217,10 @@ Tensor = torch.Tensor
 def _broadcast_shapes(*_shapes):
-    shapes = tuple(filter(lambda x: x is not None, _shapes))
+    shapes = tuple(
+        (x,) if isinstance(x, int) else x
+        for x in filter(lambda x: x is not None, _shapes)
+    )
 
     # Short-circuits on no input
     if len(shapes) == 0:
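
The rewritten comprehension normalizes bare ints into one-element shapes before broadcasting. A standalone sketch of the same normalization (`normalize` is a hypothetical name for illustration only):

# Bare ints become one-element shapes; None entries are dropped.
def normalize(*_shapes):
    return tuple(
        (x,) if isinstance(x, int) else x
        for x in filter(lambda x: x is not None, _shapes)
    )

print(normalize(3, (1,), None))  # ((3,), (1,))
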
@@ -1568,6 +1572,10 @@ def as_strided(
     return prims.as_strided(a, size, stride, storage_offset)
 
+
+def broadcast_shapes(*shapes) -> ShapeType:
+    return torch.Size(_broadcast_shapes(*shapes))
+
 def broadcast_tensors(*tensors) -> List[TensorLikeType]:
     return list(_maybe_broadcast(*tensors, preserve_cpu_scalar_tensors=False))
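
With this commit applied, the ref should agree with the eager op, including for bare-int arguments. A quick consistency check, assuming `torch._refs` is importable:

import torch
import torch._refs as refs

# The ref wraps _broadcast_shapes' result in a torch.Size, matching eager.
assert refs.broadcast_shapes((4, 1, 1), (3, 2)) == torch.broadcast_shapes((4, 1, 1), (3, 2))
print(refs.broadcast_shapes(3, (1,)))  # torch.Size([3])
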

torch/testing/_internal/common_methods_invocations.py

@@ -2364,6 +2364,24 @@ def sample_inputs_jiterator(op, device, dtype, requires_grad, **kwargs):
         yield SampleInput(lhs, args=tuple(args), kwargs=sample_kwargs, broadcasts_input=broadcasts_input)
 
+def sample_inputs_broadcast_shapes(op, device, dtype, requires_grad, **kwargs):
+    shapes = (
+        ((), ()),
+        ((S,), ()),
+        ((S, 1), (S,)),
+        ((S, 1), S),
+        ((M, S), ()),
+        ((S, M, S), (M, S)),
+        ((S, M, S), (S, M, S)),
+        ((M, 1, S), (M, S)),
+        ((M, 1, S), (1, M, S)),
+        ((0, 1, 3), (0, 10, 3))
+    )
+
+    for shape in shapes:
+        inp, *arg0 = shape
+        yield SampleInput(inp, args=arg0)
+
 # The base reference input generation for elementwise binary operations
 def _reference_inputs_elementwise_binary(op, device, dtype, requires_grad, exclude_zero, **kwargs):
     yield from op.sample_inputs_func(op, device, dtype, requires_grad, **kwargs)
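
Roughly how the harness consumes these samples: the first shape becomes the `SampleInput`'s input and the remaining shapes its positional args (`S` and `M` are the small test sizes defined in this file, e.g. S = 5 and M = 10):

import torch

# Mirrors the ((S, 1), (S,)) sample with S = 5.
inp, *args = ((5, 1), (5,))
print(torch.broadcast_shapes(inp, *args))  # torch.Size([5, 5])
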
@@ -10648,6 +10666,27 @@ op_db: List[OpInfo] = [
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
            sample_inputs_func=sample_inputs_broadcast_to),
+    OpInfo('broadcast_shapes',
+           op=torch.broadcast_shapes,
+           ref=np.broadcast_shapes if np.lib.NumpyVersion(np.__version__) >= '1.20.0' else None,
+           dtypes=_dispatch_dtypes((torch.float32,)),
+           supports_out=False,
+           supports_gradgrad=False,
+           assert_autodiffed=False,
+           supports_autograd=False,
+           supports_scripting=False,
+           sample_inputs_func=sample_inputs_broadcast_shapes,
+           skips=(
+               # https://github.com/pytorch/pytorch/issues/64997
+               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
+               # Skip dtype tests since broadcast_shapes is not device dependent;
+               # limiting dtypes to torch.float32 would cause test_dtypes to report an unexpected success.
+               DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_dtypes'),
+               # Skip these tests since the samples have non-tensor inputs.
+               DecorateInfo(unittest.skip('Skipped!'), "TestCommon", "test_noncontiguous_samples"),
+               DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'),
+               DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'),
+           )),
     OpInfo('broadcast_tensors',
            ref=np.broadcast_arrays,
            dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
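
When NumPy >= 1.20 is available, the `ref` above lets the suite compare each sample's result against `np.broadcast_shapes`. The gist of that comparison:

import numpy as np
import torch

# torch and numpy should agree on broadcast result shapes, including
# zero-sized dimensions.
for shapes in [((4, 1, 1), (3, 2)), ((0, 1, 3), (0, 10, 3))]:
    assert tuple(torch.broadcast_shapes(*shapes)) == np.broadcast_shapes(*shapes)
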
@@ -19783,6 +19822,10 @@ python_ref_db = [
             DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_conj_view'),
         ),
     ),
+    PythonRefInfo(
+        "_refs.broadcast_shapes",
+        torch_opinfo_name="broadcast_shapes",
+    ),
     PythonRefInfo(
         "_refs.broadcast_tensors",
         torch_opinfo_name="broadcast_tensors",