Add decomposition for squeeze_copy (#130941)

* Extracted from #128416

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130941
Approved by: https://github.com/amjames, https://github.com/eellison
ghstack dependencies: #136653
This commit is contained in:
Tom Ritchford
2024-09-26 10:25:14 +00:00
committed by PyTorch MergeBot
parent 083921852b
commit b85f21fc1d
9 changed files with 35 additions and 6 deletions

View File

@ -426,6 +426,7 @@ dtensor_fails = {
xfail("special.xlog1py"),
xfail("special.zeta"),
xfail("squeeze", "multiple"),
xfail("squeeze_copy"),
xfail("signal.windows.bartlett"),
xfail("signal.windows.blackman"),
xfail("signal.windows.cosine"),

View File

@ -1283,12 +1283,6 @@ aten::split_copy.Tensor_out
aten::squeeze_
aten::squeeze_.dim
aten::squeeze_.dims
aten::squeeze_copy
aten::squeeze_copy.dim
aten::squeeze_copy.dim_out
aten::squeeze_copy.dims
aten::squeeze_copy.dims_out
aten::squeeze_copy.out
aten::sspaddmm.out
aten::t_
aten::to_mkldnn

View File

@ -1416,6 +1416,7 @@ class TestOperators(TestCase):
xfail("as_strided_scatter", ""),
xfail("masked.cumprod", ""),
xfail("renorm"), # hit vmap fallback, which is disabled
xfail("squeeze_copy"),
xfail("t_copy"),
xfail("transpose_copy"),
xfail("unsqueeze_copy"),
@ -1482,6 +1483,7 @@ class TestOperators(TestCase):
xfail("put"),
xfail("quantile"),
xfail("renorm"),
xfail("squeeze_copy"),
xfail("take"),
xfail("tensor_split"),
xfail("to_sparse"),
@ -1542,6 +1544,7 @@ class TestOperators(TestCase):
xfail(
"index_fill"
), # aten::_unique hit the vmap fallback which is currently disabled
xfail("squeeze_copy"),
xfail("t_copy"),
xfail("transpose_copy"),
xfail("unsqueeze_copy"),

View File

@ -4440,6 +4440,7 @@ class TestVmapOperatorsOpInfo(TestCase):
xfail("put"),
xfail("quantile"),
xfail("renorm"),
xfail("squeeze_copy"),
xfail("resize_as_"),
xfail("take"),
xfail("tensor_split"),

View File

@ -339,6 +339,7 @@ def mps_ops_modifier(ops):
'split_with_sizes_copy',
'splitlist_args',
'squeeze',
'squeeze_copy',
'squeezemultiple',
'sub',
'svd',

View File

@ -200,6 +200,7 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
"transpose_copy",
"permute",
"squeeze",
"squeeze_copy",
"unsqueeze",
"unsqueeze_copy",
"resize",

View File

@ -622,6 +622,7 @@ def _core_aten_decompositions_post_autograd() -> (
aten.special_xlog1py,
aten.split.Tensor,
aten.split_with_sizes_copy,
aten.squeeze_copy,
aten.squeeze.default,
aten.squeeze.dim,
aten.std,

View File

@ -296,6 +296,7 @@ __all__ = [
"stack",
"swap_axes", # alias for transpose
"squeeze",
"squeeze_copy",
"t",
"t_copy",
"T",
@ -6376,6 +6377,7 @@ expand_copy = _make_copy_from_view(aten.expand)
# TODO: This must return a sparse tensor if the input is sparse, but refs have
# no sparse support. See narrow_copy_sparse in core.
narrow_copy = _make_copy_from_view(aten.narrow)
squeeze_copy = _make_copy_from_view(aten.squeeze)
t_copy = _make_copy_from_view(aten.t)
transpose_copy = _make_copy_from_view(aten.transpose)
unsqueeze_copy = _make_copy_from_view(aten.unsqueeze)

View File

@ -19599,6 +19599,26 @@ op_db: List[OpInfo] = [
# https://github.com/pytorch/pytorch/issues/66357
check_batched_forward_grad=False,
sample_inputs_func=sample_inputs_squeeze_multiple),
OpInfo('squeeze_copy',
ref=_squeeze_ref,
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
supports_out=True,
assert_autodiffed=True,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
# vmap does not support inplace views
check_inplace_batched_forward_grad=False,
# https://github.com/pytorch/pytorch/issues/66357
check_batched_forward_grad=False,
sample_inputs_func=sample_inputs_squeeze,
skips=(
DecorateInfo(
unittest.expectedFailure,
'TestJit',
'test_variant_consistency_jit',
dtypes=(torch.float32,),
),
)),
UnaryUfuncInfo(
'fill',
ref=_fill_np,
@ -23993,6 +24013,11 @@ python_ref_db = [
"_refs.squeeze",
torch_opinfo_name="squeeze",
),
PythonRefInfo(
"_refs.squeeze_copy",
torch_opinfo_name="squeeze_copy",
supports_out=True,
),
PythonRefInfo(
"_refs.squeeze",
torch_opinfo_name="squeeze",