mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add decomposition for transpose_copy (#130943)
* Extracted from #128416 Pull Request resolved: https://github.com/pytorch/pytorch/pull/130943 Approved by: https://github.com/amjames, https://github.com/eellison
This commit is contained in:
committed by
PyTorch MergeBot
parent
ad75b09d89
commit
e05ea2b179
@ -1301,8 +1301,6 @@ aten::to_padded_tensor.out
|
||||
aten::topk
|
||||
aten::topk.values
|
||||
aten::transpose_
|
||||
aten::transpose_copy.int
|
||||
aten::transpose_copy.int_out
|
||||
aten::triangular_solve
|
||||
aten::triangular_solve.X
|
||||
aten::unbind_copy.int
|
||||
|
@ -1429,6 +1429,7 @@ class TestOperators(TestCase):
|
||||
xfail("masked.cumprod", ""),
|
||||
xfail("renorm"), # hit vmap fallback, which is disabled
|
||||
xfail("t_copy"),
|
||||
xfail("transpose_copy"),
|
||||
xfail("unsqueeze_copy"),
|
||||
}
|
||||
),
|
||||
@ -1567,6 +1568,7 @@ class TestOperators(TestCase):
|
||||
"index_fill"
|
||||
), # aten::_unique hit the vmap fallback which is currently disabled
|
||||
xfail("t_copy"),
|
||||
xfail("transpose_copy"),
|
||||
xfail("unsqueeze_copy"),
|
||||
}
|
||||
),
|
||||
|
@ -4446,6 +4446,7 @@ class TestVmapOperatorsOpInfo(TestCase):
|
||||
xfail("resize_as_"),
|
||||
xfail("take"),
|
||||
xfail("tensor_split"),
|
||||
xfail("transpose_copy"),
|
||||
xfail("to_sparse"),
|
||||
# TypeError: expected Tensor as element 0 in argument 0, but got float
|
||||
xfail("item"),
|
||||
|
@ -344,6 +344,7 @@ def mps_ops_modifier(ops):
|
||||
'tanh',
|
||||
'tensor_split',
|
||||
'transpose',
|
||||
'transpose_copy',
|
||||
'T',
|
||||
'unbind',
|
||||
'unflatten',
|
||||
|
@ -197,6 +197,7 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
|
||||
"nanmean",
|
||||
"nansum",
|
||||
"transpose",
|
||||
"transpose_copy",
|
||||
"permute",
|
||||
"squeeze",
|
||||
"unsqueeze",
|
||||
|
@ -452,6 +452,7 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
|
||||
aten.threshold_backward,
|
||||
aten.trace,
|
||||
aten.transpose.int,
|
||||
aten.transpose_copy,
|
||||
aten.tril,
|
||||
aten.tril_,
|
||||
aten.triu,
|
||||
|
@ -290,6 +290,7 @@ __all__ = [
|
||||
"take_along_dim",
|
||||
"tensor_split",
|
||||
"transpose",
|
||||
"transpose_copy",
|
||||
"unfold",
|
||||
"unfold_copy",
|
||||
"unsqueeze",
|
||||
@ -6335,6 +6336,7 @@ expand_copy = _make_copy_from_view(aten.expand)
|
||||
# no sparse support. See narrow_copy_sparse in core.
|
||||
narrow_copy = _make_copy_from_view(aten.narrow)
|
||||
t_copy = _make_copy_from_view(aten.t)
|
||||
transpose_copy = _make_copy_from_view(aten.transpose)
|
||||
unsqueeze_copy = _make_copy_from_view(aten.unsqueeze)
|
||||
view_copy = _make_copy_from_view(aten.view)
|
||||
|
||||
|
@ -19621,6 +19621,24 @@ op_db: List[OpInfo] = [
|
||||
# vmap does not support inplace views
|
||||
check_inplace_batched_forward_grad=False,
|
||||
sample_inputs_func=sample_inputs_transpose_swapdims),
|
||||
OpInfo('transpose_copy',
|
||||
assert_jit_shape_analysis=True,
|
||||
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf),
|
||||
supports_out=True,
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
# vmap does not support inplace views
|
||||
check_inplace_batched_forward_grad=False,
|
||||
sample_inputs_func=sample_inputs_transpose_swapdims,
|
||||
skips=(
|
||||
DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'),
|
||||
DecorateInfo(
|
||||
unittest.expectedFailure,
|
||||
'TestJit',
|
||||
'test_variant_consistency_jit',
|
||||
dtypes=(torch.float32,)
|
||||
),
|
||||
)),
|
||||
OpInfo('T',
|
||||
op=lambda x: x.T,
|
||||
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf),
|
||||
@ -23835,6 +23853,15 @@ python_ref_db = [
|
||||
"_refs.transpose",
|
||||
torch_opinfo_name="transpose",
|
||||
),
|
||||
PythonRefInfo(
|
||||
"_refs.transpose_copy",
|
||||
torch_opinfo_name="transpose_copy",
|
||||
skips=(
|
||||
# RuntimeError: no _refs support for torch.Tensor.is_conj
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),
|
||||
),
|
||||
supports_out=True,
|
||||
),
|
||||
PythonRefInfo(
|
||||
"_refs.t",
|
||||
torch_opinfo_name="t",
|
||||
|
Reference in New Issue
Block a user