mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Add decomposition for squeeze_copy (#130941)
* Extracted from #128416 Pull Request resolved: https://github.com/pytorch/pytorch/pull/130941 Approved by: https://github.com/amjames, https://github.com/eellison ghstack dependencies: #136653
This commit is contained in:
committed by
PyTorch MergeBot
parent
083921852b
commit
b85f21fc1d
@ -426,6 +426,7 @@ dtensor_fails = {
|
||||
xfail("special.xlog1py"),
|
||||
xfail("special.zeta"),
|
||||
xfail("squeeze", "multiple"),
|
||||
xfail("squeeze_copy"),
|
||||
xfail("signal.windows.bartlett"),
|
||||
xfail("signal.windows.blackman"),
|
||||
xfail("signal.windows.cosine"),
|
||||
|
@ -1283,12 +1283,6 @@ aten::split_copy.Tensor_out
|
||||
aten::squeeze_
|
||||
aten::squeeze_.dim
|
||||
aten::squeeze_.dims
|
||||
aten::squeeze_copy
|
||||
aten::squeeze_copy.dim
|
||||
aten::squeeze_copy.dim_out
|
||||
aten::squeeze_copy.dims
|
||||
aten::squeeze_copy.dims_out
|
||||
aten::squeeze_copy.out
|
||||
aten::sspaddmm.out
|
||||
aten::t_
|
||||
aten::to_mkldnn
|
||||
|
@ -1416,6 +1416,7 @@ class TestOperators(TestCase):
|
||||
xfail("as_strided_scatter", ""),
|
||||
xfail("masked.cumprod", ""),
|
||||
xfail("renorm"), # hit vmap fallback, which is disabled
|
||||
xfail("squeeze_copy"),
|
||||
xfail("t_copy"),
|
||||
xfail("transpose_copy"),
|
||||
xfail("unsqueeze_copy"),
|
||||
@ -1482,6 +1483,7 @@ class TestOperators(TestCase):
|
||||
xfail("put"),
|
||||
xfail("quantile"),
|
||||
xfail("renorm"),
|
||||
xfail("squeeze_copy"),
|
||||
xfail("take"),
|
||||
xfail("tensor_split"),
|
||||
xfail("to_sparse"),
|
||||
@ -1542,6 +1544,7 @@ class TestOperators(TestCase):
|
||||
xfail(
|
||||
"index_fill"
|
||||
), # aten::_unique hit the vmap fallback which is currently disabled
|
||||
xfail("squeeze_copy"),
|
||||
xfail("t_copy"),
|
||||
xfail("transpose_copy"),
|
||||
xfail("unsqueeze_copy"),
|
||||
|
@ -4440,6 +4440,7 @@ class TestVmapOperatorsOpInfo(TestCase):
|
||||
xfail("put"),
|
||||
xfail("quantile"),
|
||||
xfail("renorm"),
|
||||
xfail("squeeze_copy"),
|
||||
xfail("resize_as_"),
|
||||
xfail("take"),
|
||||
xfail("tensor_split"),
|
||||
|
@ -339,6 +339,7 @@ def mps_ops_modifier(ops):
|
||||
'split_with_sizes_copy',
|
||||
'splitlist_args',
|
||||
'squeeze',
|
||||
'squeeze_copy',
|
||||
'squeezemultiple',
|
||||
'sub',
|
||||
'svd',
|
||||
|
@ -200,6 +200,7 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
|
||||
"transpose_copy",
|
||||
"permute",
|
||||
"squeeze",
|
||||
"squeeze_copy",
|
||||
"unsqueeze",
|
||||
"unsqueeze_copy",
|
||||
"resize",
|
||||
|
@ -622,6 +622,7 @@ def _core_aten_decompositions_post_autograd() -> (
|
||||
aten.special_xlog1py,
|
||||
aten.split.Tensor,
|
||||
aten.split_with_sizes_copy,
|
||||
aten.squeeze_copy,
|
||||
aten.squeeze.default,
|
||||
aten.squeeze.dim,
|
||||
aten.std,
|
||||
|
@ -296,6 +296,7 @@ __all__ = [
|
||||
"stack",
|
||||
"swap_axes", # alias for transpose
|
||||
"squeeze",
|
||||
"squeeze_copy",
|
||||
"t",
|
||||
"t_copy",
|
||||
"T",
|
||||
@ -6376,6 +6377,7 @@ expand_copy = _make_copy_from_view(aten.expand)
|
||||
# TODO: This must return a sparse tensor if the input is sparse, but refs have
|
||||
# no sparse support. See narrow_copy_sparse in core.
|
||||
narrow_copy = _make_copy_from_view(aten.narrow)
|
||||
squeeze_copy = _make_copy_from_view(aten.squeeze)
|
||||
t_copy = _make_copy_from_view(aten.t)
|
||||
transpose_copy = _make_copy_from_view(aten.transpose)
|
||||
unsqueeze_copy = _make_copy_from_view(aten.unsqueeze)
|
||||
|
@ -19599,6 +19599,26 @@ op_db: List[OpInfo] = [
|
||||
# https://github.com/pytorch/pytorch/issues/66357
|
||||
check_batched_forward_grad=False,
|
||||
sample_inputs_func=sample_inputs_squeeze_multiple),
|
||||
OpInfo('squeeze_copy',
|
||||
ref=_squeeze_ref,
|
||||
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
|
||||
supports_out=True,
|
||||
assert_autodiffed=True,
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
# vmap does not support inplace views
|
||||
check_inplace_batched_forward_grad=False,
|
||||
# https://github.com/pytorch/pytorch/issues/66357
|
||||
check_batched_forward_grad=False,
|
||||
sample_inputs_func=sample_inputs_squeeze,
|
||||
skips=(
|
||||
DecorateInfo(
|
||||
unittest.expectedFailure,
|
||||
'TestJit',
|
||||
'test_variant_consistency_jit',
|
||||
dtypes=(torch.float32,),
|
||||
),
|
||||
)),
|
||||
UnaryUfuncInfo(
|
||||
'fill',
|
||||
ref=_fill_np,
|
||||
@ -23993,6 +24013,11 @@ python_ref_db = [
|
||||
"_refs.squeeze",
|
||||
torch_opinfo_name="squeeze",
|
||||
),
|
||||
PythonRefInfo(
|
||||
"_refs.squeeze_copy",
|
||||
torch_opinfo_name="squeeze_copy",
|
||||
supports_out=True,
|
||||
),
|
||||
PythonRefInfo(
|
||||
"_refs.squeeze",
|
||||
torch_opinfo_name="squeeze",
|
||||
|
Reference in New Issue
Block a user