From b85f21fc1dc220044c22ee8a718693f42bf4693f Mon Sep 17 00:00:00 2001 From: Tom Ritchford Date: Thu, 26 Sep 2024 10:25:14 +0000 Subject: [PATCH] Add decomposition for squeeze_copy (#130941) * Extracted from #128416 Pull Request resolved: https://github.com/pytorch/pytorch/pull/130941 Approved by: https://github.com/amjames, https://github.com/eellison ghstack dependencies: #136653 --- test/distributed/_tensor/test_dtensor_ops.py | 1 + ...asDecompTest.test_has_decomposition.expect | 6 ----- test/functorch/test_ops.py | 3 +++ test/functorch/test_vmap.py | 1 + test/test_mps.py | 1 + tools/autograd/gen_variable_type.py | 1 + torch/_decomp/__init__.py | 1 + torch/_refs/__init__.py | 2 ++ .../_internal/common_methods_invocations.py | 25 +++++++++++++++++++ 9 files changed, 35 insertions(+), 6 deletions(-) diff --git a/test/distributed/_tensor/test_dtensor_ops.py b/test/distributed/_tensor/test_dtensor_ops.py index 532bd3facae5..2fe8df1f1812 100644 --- a/test/distributed/_tensor/test_dtensor_ops.py +++ b/test/distributed/_tensor/test_dtensor_ops.py @@ -426,6 +426,7 @@ dtensor_fails = { xfail("special.xlog1py"), xfail("special.zeta"), xfail("squeeze", "multiple"), + xfail("squeeze_copy"), xfail("signal.windows.bartlett"), xfail("signal.windows.blackman"), xfail("signal.windows.cosine"), diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index 08a2435cec21..444462b35dc7 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -1283,12 +1283,6 @@ aten::split_copy.Tensor_out aten::squeeze_ aten::squeeze_.dim aten::squeeze_.dims -aten::squeeze_copy -aten::squeeze_copy.dim -aten::squeeze_copy.dim_out -aten::squeeze_copy.dims -aten::squeeze_copy.dims_out -aten::squeeze_copy.out aten::sspaddmm.out aten::t_ aten::to_mkldnn diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index 93e8f23d1ea4..54136a4f7bab 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -1416,6 +1416,7 @@ class TestOperators(TestCase): xfail("as_strided_scatter", ""), xfail("masked.cumprod", ""), xfail("renorm"), # hit vmap fallback, which is disabled + xfail("squeeze_copy"), xfail("t_copy"), xfail("transpose_copy"), xfail("unsqueeze_copy"), @@ -1482,6 +1483,7 @@ class TestOperators(TestCase): xfail("put"), xfail("quantile"), xfail("renorm"), + xfail("squeeze_copy"), xfail("take"), xfail("tensor_split"), xfail("to_sparse"), @@ -1542,6 +1544,7 @@ class TestOperators(TestCase): xfail( "index_fill" ), # aten::_unique hit the vmap fallback which is currently disabled + xfail("squeeze_copy"), xfail("t_copy"), xfail("transpose_copy"), xfail("unsqueeze_copy"), diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index c051e675e578..870b0e61b26e 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -4440,6 +4440,7 @@ class TestVmapOperatorsOpInfo(TestCase): xfail("put"), xfail("quantile"), xfail("renorm"), + xfail("squeeze_copy"), xfail("resize_as_"), xfail("take"), xfail("tensor_split"), diff --git a/test/test_mps.py b/test/test_mps.py index ecff9312e6c7..079de2b3359b 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -339,6 +339,7 @@ def mps_ops_modifier(ops): 'split_with_sizes_copy', 'splitlist_args', 'squeeze', + 'squeeze_copy', 'squeezemultiple', 'sub', 'svd', diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index fa6c578dea04..d26a83713a68 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -200,6 +200,7 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = { "transpose_copy", "permute", "squeeze", + "squeeze_copy", "unsqueeze", "unsqueeze_copy", "resize", diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index a602892fe9be..5449189f92ad 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -622,6 +622,7 @@ def _core_aten_decompositions_post_autograd() -> ( aten.special_xlog1py, aten.split.Tensor, aten.split_with_sizes_copy, + aten.squeeze_copy, aten.squeeze.default, aten.squeeze.dim, aten.std, diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index 4ed5aa76cf02..443026be2980 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -296,6 +296,7 @@ __all__ = [ "stack", "swap_axes", # alias for transpose "squeeze", + "squeeze_copy", "t", "t_copy", "T", @@ -6376,6 +6377,7 @@ expand_copy = _make_copy_from_view(aten.expand) # TODO: This must return a sparse tensor if the input is sparse, but refs have # no sparse support. See narrow_copy_sparse in core. narrow_copy = _make_copy_from_view(aten.narrow) +squeeze_copy = _make_copy_from_view(aten.squeeze) t_copy = _make_copy_from_view(aten.t) transpose_copy = _make_copy_from_view(aten.transpose) unsqueeze_copy = _make_copy_from_view(aten.unsqueeze) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 99439c6ac6a2..ec6fad938b43 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -19599,6 +19599,26 @@ op_db: List[OpInfo] = [ # https://github.com/pytorch/pytorch/issues/66357 check_batched_forward_grad=False, sample_inputs_func=sample_inputs_squeeze_multiple), + OpInfo('squeeze_copy', + ref=_squeeze_ref, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=True, + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + # https://github.com/pytorch/pytorch/issues/66357 + check_batched_forward_grad=False, + sample_inputs_func=sample_inputs_squeeze, + skips=( + DecorateInfo( + unittest.expectedFailure, + 'TestJit', + 'test_variant_consistency_jit', + dtypes=(torch.float32,), + ), + )), UnaryUfuncInfo( 'fill', ref=_fill_np, @@ -23993,6 +24013,11 @@ python_ref_db = [ "_refs.squeeze", torch_opinfo_name="squeeze", ), + PythonRefInfo( + "_refs.squeeze_copy", + torch_opinfo_name="squeeze_copy", + supports_out=True, + ), PythonRefInfo( "_refs.squeeze", torch_opinfo_name="squeeze",