Add decomposition for transpose_copy (#130943)

* Extracted from #128416
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130943
Approved by: https://github.com/amjames, https://github.com/eellison
This commit is contained in:
Tom Ritchford
2024-09-10 13:33:01 +00:00
committed by PyTorch MergeBot
parent ad75b09d89
commit e05ea2b179
8 changed files with 35 additions and 2 deletions

View File

@ -1301,8 +1301,6 @@ aten::to_padded_tensor.out
aten::topk aten::topk
aten::topk.values aten::topk.values
aten::transpose_ aten::transpose_
aten::transpose_copy.int
aten::transpose_copy.int_out
aten::triangular_solve aten::triangular_solve
aten::triangular_solve.X aten::triangular_solve.X
aten::unbind_copy.int aten::unbind_copy.int

View File

@ -1429,6 +1429,7 @@ class TestOperators(TestCase):
xfail("masked.cumprod", ""), xfail("masked.cumprod", ""),
xfail("renorm"), # hit vmap fallback, which is disabled xfail("renorm"), # hit vmap fallback, which is disabled
xfail("t_copy"), xfail("t_copy"),
xfail("transpose_copy"),
xfail("unsqueeze_copy"), xfail("unsqueeze_copy"),
} }
), ),
@ -1567,6 +1568,7 @@ class TestOperators(TestCase):
"index_fill" "index_fill"
), # aten::_unique hit the vmap fallback which is currently disabled ), # aten::_unique hit the vmap fallback which is currently disabled
xfail("t_copy"), xfail("t_copy"),
xfail("transpose_copy"),
xfail("unsqueeze_copy"), xfail("unsqueeze_copy"),
} }
), ),

View File

@ -4446,6 +4446,7 @@ class TestVmapOperatorsOpInfo(TestCase):
xfail("resize_as_"), xfail("resize_as_"),
xfail("take"), xfail("take"),
xfail("tensor_split"), xfail("tensor_split"),
xfail("transpose_copy"),
xfail("to_sparse"), xfail("to_sparse"),
# TypeError: expected Tensor as element 0 in argument 0, but got float # TypeError: expected Tensor as element 0 in argument 0, but got float
xfail("item"), xfail("item"),

View File

@ -344,6 +344,7 @@ def mps_ops_modifier(ops):
'tanh', 'tanh',
'tensor_split', 'tensor_split',
'transpose', 'transpose',
'transpose_copy',
'T', 'T',
'unbind', 'unbind',
'unflatten', 'unflatten',

View File

@ -197,6 +197,7 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
"nanmean", "nanmean",
"nansum", "nansum",
"transpose", "transpose",
"transpose_copy",
"permute", "permute",
"squeeze", "squeeze",
"unsqueeze", "unsqueeze",

View File

@ -452,6 +452,7 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
aten.threshold_backward, aten.threshold_backward,
aten.trace, aten.trace,
aten.transpose.int, aten.transpose.int,
aten.transpose_copy,
aten.tril, aten.tril,
aten.tril_, aten.tril_,
aten.triu, aten.triu,

View File

@ -290,6 +290,7 @@ __all__ = [
"take_along_dim", "take_along_dim",
"tensor_split", "tensor_split",
"transpose", "transpose",
"transpose_copy",
"unfold", "unfold",
"unfold_copy", "unfold_copy",
"unsqueeze", "unsqueeze",
@ -6335,6 +6336,7 @@ expand_copy = _make_copy_from_view(aten.expand)
# no sparse support. See narrow_copy_sparse in core. # no sparse support. See narrow_copy_sparse in core.
narrow_copy = _make_copy_from_view(aten.narrow) narrow_copy = _make_copy_from_view(aten.narrow)
t_copy = _make_copy_from_view(aten.t) t_copy = _make_copy_from_view(aten.t)
transpose_copy = _make_copy_from_view(aten.transpose)
unsqueeze_copy = _make_copy_from_view(aten.unsqueeze) unsqueeze_copy = _make_copy_from_view(aten.unsqueeze)
view_copy = _make_copy_from_view(aten.view) view_copy = _make_copy_from_view(aten.view)

View File

@ -19621,6 +19621,24 @@ op_db: List[OpInfo] = [
# vmap does not support inplace views # vmap does not support inplace views
check_inplace_batched_forward_grad=False, check_inplace_batched_forward_grad=False,
sample_inputs_func=sample_inputs_transpose_swapdims), sample_inputs_func=sample_inputs_transpose_swapdims),
OpInfo('transpose_copy',
assert_jit_shape_analysis=True,
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf),
supports_out=True,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
# vmap does not support inplace views
check_inplace_batched_forward_grad=False,
sample_inputs_func=sample_inputs_transpose_swapdims,
skips=(
DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'),
DecorateInfo(
unittest.expectedFailure,
'TestJit',
'test_variant_consistency_jit',
dtypes=(torch.float32,)
),
)),
OpInfo('T', OpInfo('T',
op=lambda x: x.T, op=lambda x: x.T,
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf),
@ -23835,6 +23853,15 @@ python_ref_db = [
"_refs.transpose", "_refs.transpose",
torch_opinfo_name="transpose", torch_opinfo_name="transpose",
), ),
PythonRefInfo(
"_refs.transpose_copy",
torch_opinfo_name="transpose_copy",
skips=(
# RuntimeError: no _refs support for torch.Tensor.is_conj
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),
),
supports_out=True,
),
PythonRefInfo( PythonRefInfo(
"_refs.t", "_refs.t",
torch_opinfo_name="t", torch_opinfo_name="t",