mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add decomposition for transpose_copy (#130943)
* Extracted from #128416 Pull Request resolved: https://github.com/pytorch/pytorch/pull/130943 Approved by: https://github.com/amjames, https://github.com/eellison
This commit is contained in:
committed by
PyTorch MergeBot
parent
ad75b09d89
commit
e05ea2b179
@ -1301,8 +1301,6 @@ aten::to_padded_tensor.out
|
|||||||
aten::topk
|
aten::topk
|
||||||
aten::topk.values
|
aten::topk.values
|
||||||
aten::transpose_
|
aten::transpose_
|
||||||
aten::transpose_copy.int
|
|
||||||
aten::transpose_copy.int_out
|
|
||||||
aten::triangular_solve
|
aten::triangular_solve
|
||||||
aten::triangular_solve.X
|
aten::triangular_solve.X
|
||||||
aten::unbind_copy.int
|
aten::unbind_copy.int
|
||||||
|
@ -1429,6 +1429,7 @@ class TestOperators(TestCase):
|
|||||||
xfail("masked.cumprod", ""),
|
xfail("masked.cumprod", ""),
|
||||||
xfail("renorm"), # hit vmap fallback, which is disabled
|
xfail("renorm"), # hit vmap fallback, which is disabled
|
||||||
xfail("t_copy"),
|
xfail("t_copy"),
|
||||||
|
xfail("transpose_copy"),
|
||||||
xfail("unsqueeze_copy"),
|
xfail("unsqueeze_copy"),
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
@ -1567,6 +1568,7 @@ class TestOperators(TestCase):
|
|||||||
"index_fill"
|
"index_fill"
|
||||||
), # aten::_unique hit the vmap fallback which is currently disabled
|
), # aten::_unique hit the vmap fallback which is currently disabled
|
||||||
xfail("t_copy"),
|
xfail("t_copy"),
|
||||||
|
xfail("transpose_copy"),
|
||||||
xfail("unsqueeze_copy"),
|
xfail("unsqueeze_copy"),
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
|
@ -4446,6 +4446,7 @@ class TestVmapOperatorsOpInfo(TestCase):
|
|||||||
xfail("resize_as_"),
|
xfail("resize_as_"),
|
||||||
xfail("take"),
|
xfail("take"),
|
||||||
xfail("tensor_split"),
|
xfail("tensor_split"),
|
||||||
|
xfail("transpose_copy"),
|
||||||
xfail("to_sparse"),
|
xfail("to_sparse"),
|
||||||
# TypeError: expected Tensor as element 0 in argument 0, but got float
|
# TypeError: expected Tensor as element 0 in argument 0, but got float
|
||||||
xfail("item"),
|
xfail("item"),
|
||||||
|
@ -344,6 +344,7 @@ def mps_ops_modifier(ops):
|
|||||||
'tanh',
|
'tanh',
|
||||||
'tensor_split',
|
'tensor_split',
|
||||||
'transpose',
|
'transpose',
|
||||||
|
'transpose_copy',
|
||||||
'T',
|
'T',
|
||||||
'unbind',
|
'unbind',
|
||||||
'unflatten',
|
'unflatten',
|
||||||
|
@ -197,6 +197,7 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
|
|||||||
"nanmean",
|
"nanmean",
|
||||||
"nansum",
|
"nansum",
|
||||||
"transpose",
|
"transpose",
|
||||||
|
"transpose_copy",
|
||||||
"permute",
|
"permute",
|
||||||
"squeeze",
|
"squeeze",
|
||||||
"unsqueeze",
|
"unsqueeze",
|
||||||
|
@ -452,6 +452,7 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
|
|||||||
aten.threshold_backward,
|
aten.threshold_backward,
|
||||||
aten.trace,
|
aten.trace,
|
||||||
aten.transpose.int,
|
aten.transpose.int,
|
||||||
|
aten.transpose_copy,
|
||||||
aten.tril,
|
aten.tril,
|
||||||
aten.tril_,
|
aten.tril_,
|
||||||
aten.triu,
|
aten.triu,
|
||||||
|
@ -290,6 +290,7 @@ __all__ = [
|
|||||||
"take_along_dim",
|
"take_along_dim",
|
||||||
"tensor_split",
|
"tensor_split",
|
||||||
"transpose",
|
"transpose",
|
||||||
|
"transpose_copy",
|
||||||
"unfold",
|
"unfold",
|
||||||
"unfold_copy",
|
"unfold_copy",
|
||||||
"unsqueeze",
|
"unsqueeze",
|
||||||
@ -6335,6 +6336,7 @@ expand_copy = _make_copy_from_view(aten.expand)
|
|||||||
# no sparse support. See narrow_copy_sparse in core.
|
# no sparse support. See narrow_copy_sparse in core.
|
||||||
narrow_copy = _make_copy_from_view(aten.narrow)
|
narrow_copy = _make_copy_from_view(aten.narrow)
|
||||||
t_copy = _make_copy_from_view(aten.t)
|
t_copy = _make_copy_from_view(aten.t)
|
||||||
|
transpose_copy = _make_copy_from_view(aten.transpose)
|
||||||
unsqueeze_copy = _make_copy_from_view(aten.unsqueeze)
|
unsqueeze_copy = _make_copy_from_view(aten.unsqueeze)
|
||||||
view_copy = _make_copy_from_view(aten.view)
|
view_copy = _make_copy_from_view(aten.view)
|
||||||
|
|
||||||
|
@ -19621,6 +19621,24 @@ op_db: List[OpInfo] = [
|
|||||||
# vmap does not support inplace views
|
# vmap does not support inplace views
|
||||||
check_inplace_batched_forward_grad=False,
|
check_inplace_batched_forward_grad=False,
|
||||||
sample_inputs_func=sample_inputs_transpose_swapdims),
|
sample_inputs_func=sample_inputs_transpose_swapdims),
|
||||||
|
OpInfo('transpose_copy',
|
||||||
|
assert_jit_shape_analysis=True,
|
||||||
|
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf),
|
||||||
|
supports_out=True,
|
||||||
|
supports_forward_ad=True,
|
||||||
|
supports_fwgrad_bwgrad=True,
|
||||||
|
# vmap does not support inplace views
|
||||||
|
check_inplace_batched_forward_grad=False,
|
||||||
|
sample_inputs_func=sample_inputs_transpose_swapdims,
|
||||||
|
skips=(
|
||||||
|
DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'),
|
||||||
|
DecorateInfo(
|
||||||
|
unittest.expectedFailure,
|
||||||
|
'TestJit',
|
||||||
|
'test_variant_consistency_jit',
|
||||||
|
dtypes=(torch.float32,)
|
||||||
|
),
|
||||||
|
)),
|
||||||
OpInfo('T',
|
OpInfo('T',
|
||||||
op=lambda x: x.T,
|
op=lambda x: x.T,
|
||||||
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf),
|
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf),
|
||||||
@ -23835,6 +23853,15 @@ python_ref_db = [
|
|||||||
"_refs.transpose",
|
"_refs.transpose",
|
||||||
torch_opinfo_name="transpose",
|
torch_opinfo_name="transpose",
|
||||||
),
|
),
|
||||||
|
PythonRefInfo(
|
||||||
|
"_refs.transpose_copy",
|
||||||
|
torch_opinfo_name="transpose_copy",
|
||||||
|
skips=(
|
||||||
|
# RuntimeError: no _refs support for torch.Tensor.is_conj
|
||||||
|
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),
|
||||||
|
),
|
||||||
|
supports_out=True,
|
||||||
|
),
|
||||||
PythonRefInfo(
|
PythonRefInfo(
|
||||||
"_refs.t",
|
"_refs.t",
|
||||||
torch_opinfo_name="t",
|
torch_opinfo_name="t",
|
||||||
|
Reference in New Issue
Block a user