Fix out_wrapper, _make_copy_from_view to handle all signatures (#130937)
* See #128416 and #129476
* Simplify xskip lists in test/functorch/test_ops.py
* Add supports_out=True to OpInfos for copy ops

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130937
Approved by: https://github.com/peterbell10
Committed by: PyTorch MergeBot
Parent: b193894b94
Commit: f628813066
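For context on the "supports_out=True" OpInfo entries below: the *_copy ops take an out= tensor, and that overload is what the new OpInfo coverage exercises. A quick illustrative sanity check through the public torch API (not part of this diff):

```python
import torch

x = torch.arange(9.0).reshape(3, 3)
buf = torch.empty(3)

# diagonal_copy materializes the diagonal instead of returning a view;
# with out= it writes the result into a preallocated tensor.
torch.diagonal_copy(x, out=buf)
print(buf)  # tensor([0., 4., 8.])
```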
@@ -934,7 +934,6 @@ class TestOperators(TestCase):
skip("ormqr"), # Takes too long
xfail("as_strided"), # incorrect output
xfail("as_strided", "partial_views"), # incorrect output
xfail("as_strided_copy"), # incorrect output
xfail("as_strided_scatter"), # incorrect output
skip("bernoulli"), # calls random op
xfail("bfloat16"), # rank 4 tensor for channels_last

@@ -1062,6 +1061,7 @@ class TestOperators(TestCase):
"test_vmapvjpvjp",
{
xfail("as_strided", "partial_views"),
xfail("as_strided_copy"),
},
)
def test_vmapvjpvjp(self, device, dtype, op):

@@ -1175,7 +1175,6 @@ class TestOperators(TestCase):
xfail("nn.functional.max_unpool2d", "grad"),
xfail("sparse.sampled_addmm", ""),
xfail("sparse.mm", "reduce"),
xfail("as_strided_copy", ""), # calls as_strided
xfail("as_strided_scatter", ""), # calls as_strided
xfail("index_reduce", "prod"), # .item() call
# ---------------------------------------------------------------------

@@ -1290,7 +1289,6 @@ class TestOperators(TestCase):
xfail("quantile"), # at::equal batching rule (cpu), also, in-place vmap (cuda)
skip("as_strided"), # Test runner cannot handle this
# requires special handling, and does not yet have a batching rule. Feel free to file a github issue!
xfail("as_strided_copy"),
xfail("as_strided_scatter"),
xfail(
"nn.functional.gaussian_nll_loss"

@@ -1343,6 +1341,7 @@ class TestOperators(TestCase):
"test_vmapjvpall",
vmapjvpall_fail.union(
{
xfail("as_strided_copy"),
decorate(
"linalg.det",
"singular",

@@ -1430,11 +1429,6 @@ class TestOperators(TestCase):
xfail("masked.cumprod", ""),
xfail("renorm"), # hit vmap fallback, which is disabled
}
).difference(
{
# as_strided_copy fails test_vmapvjp, succeeds here
xfail("as_strided_copy", ""),
}
),
)
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})

@@ -1571,11 +1565,6 @@ class TestOperators(TestCase):
"index_fill"
), # aten::_unique hit the vmap fallback which is currently disabled
}
).difference(
{
# as_strided_copy fails test_vmapvjp, succeeds here
xfail("as_strided_copy", ""),
}
),
)
def test_vmapvjp_has_batch_rule(self, device, dtype, op):
@@ -240,6 +240,7 @@ EXPECTED_SKIPS_OR_FAILS_WITH_DTYPES: Tuple[onnx_test_common.DecorateMeta, ...] =
),
xfail(
"alias_copy",
dtypes=(torch.int8, torch.uint8, torch.int16, torch.float64),
reason="OnnxExporterError: Failed to export model",
),
xfail(
@@ -130,6 +130,10 @@ class TorchRefsMode(torch.overrides.TorchFunctionMode):
# see https://github.com/pytorch/pytorch/pull/82657#discussion_r939776417
if func is None and isinstance(orig_func, torch._ops.OpOverload):
    func = torch._decomp.decomposition_table.get(orig_func, None)
elif func is None and isinstance(orig_func, torch._ops.OpOverloadPacket):
    default = getattr(orig_func, "default", None)
    if default is not None:
        func = torch._decomp.decomposition_table.get(default, None)

if func is not None:
    # If the ref exists query whether we should use it or not
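The new elif branch above falls back from an OpOverloadPacket to its default overload before consulting the decomposition table, since packets themselves are not table keys. A small sketch of that lookup (illustrative only; aten.diagonal is just an example packet, and whether a ref is registered for it depends on the build):

```python
import torch
import torch._decomp

packet = torch.ops.aten.diagonal             # an OpOverloadPacket
overload = getattr(packet, "default", None)  # its OpOverload, if any
decomp = None
if overload is not None:
    # Overloads, not packets, are the keys of the decomposition table.
    decomp = torch._decomp.decomposition_table.get(overload, None)
print(decomp)  # a Python ref if one is registered, otherwise None
```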
@@ -2,7 +2,6 @@
import inspect
import warnings
from functools import wraps
from itertools import chain

from typing import Callable, NamedTuple, Optional, overload, Sequence, Tuple

@@ -214,7 +213,7 @@ def out_wrapper(
*out_names: str,
exact_dtype: bool = False,
pass_is_out: bool = False,
preserve_memory_format=False,
preserve_memory_format: bool = False,
):
# The wrapped function needs to convert the output parameters to ensure
# compatibility between the Python API (which always uses "out" as the

@@ -320,12 +319,18 @@ def out_wrapper(
sig.empty,
out_type,
)
params = chain(sig.parameters.values(), (out_param,))
params = *sig.parameters.values(), out_param

# If there's a Parameter.VAR_KEYWORD parameter (like **kwds), it must appear
# after the out= parameter, which is Parameter.KEYWORD_ONLY. Sorting by
# Parameter.kind guarantees that all the parameters are in legal order.
params = sorted(params, key=lambda p: p.kind)

_fn.__signature__ = inspect.Signature(  # type: ignore[attr-defined]
    parameters=params, return_annotation=return_type  # type: ignore[arg-type]
)

_fn.__annotations__ = fn.__annotations__
_fn.__annotations__ = dict(getattr(fn, "__annotations__", {}))
_fn.__annotations__["out"] = out_type
_fn.__annotations__["return"] = return_type
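The sort by Parameter.kind above is what lets out_wrapper handle refs whose signatures end in **kwargs: the synthetic keyword-only out parameter must precede any VAR_KEYWORD parameter for the rebuilt signature to be valid. A standalone sketch of that step, using a made-up ref signature:

```python
import inspect
from typing import Optional

import torch


def ref(a, *, alpha=1, **kwargs):  # hypothetical ref ending in **kwargs
    ...


sig = inspect.signature(ref)
out_param = inspect.Parameter(
    "out",
    kind=inspect.Parameter.KEYWORD_ONLY,
    default=None,
    annotation=Optional[torch.Tensor],
)

# Naively appending out= would place it after **kwargs, which
# inspect.Signature rejects; Parameter.kind is an ordered enum, so a
# stable sort restores a legal parameter order.
params = (*sig.parameters.values(), out_param)
params = sorted(params, key=lambda p: p.kind)
print(inspect.Signature(params))
# (a, *, alpha=1, out: Optional[Tensor] = None, **kwargs)
```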
@@ -16,6 +16,7 @@ import torch

import torch._prims as prims
import torch._prims_common as utils
import torch.utils._pytree as pytree
from torch import sym_float, sym_int
from torch._prims_common import (
BoolLike,

@@ -2198,18 +2199,25 @@ def _make_copy_from_view(fn):
"""
Given a view function (e.g. torch.diagonal) generates its copy variant (e.g. torch.diagonal_copy)
"""
name = fn.__name__
fn = out_wrapper()(fn)
aten_fn = getattr(aten, fn.__name__)
annotations = getattr(fn, "__annotations__", {})
fn = out_wrapper()(aten_fn)

@wraps(fn)
def _fn(*args, out=None, **kwargs):
result = fn(*args, out=out, **kwargs)
if out is None:
    return result.clone(memory_format=torch.contiguous_format)
return result
if out is not None:
    return result

copy_name = f"{name}_copy"
return pytree.tree_map(
    lambda x: x.clone(memory_format=torch.contiguous_format),
    result,
)

copy_name = f"{fn.__name__}_copy"
_fn.__name__ = copy_name
_fn = register_decomposition(getattr(aten, copy_name))(_fn)
_fn.__annotations__.update(annotations)
register_decomposition(getattr(aten, copy_name))(_fn)
return _fn
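A self-contained sketch of the copy-from-view pattern implemented above (the names here are illustrative; the real helper also routes out= through out_wrapper, copies the view function's annotations, and registers the result as a decomposition for the corresponding aten *_copy overload). The pytree.tree_map call is what lets it handle view ops whose result is a tuple or list of tensors rather than a single tensor:

```python
from functools import wraps

import torch
import torch.utils._pytree as pytree


def make_copy_from_view(view_fn):
    """Given a view function (e.g. torch.diagonal), build a *_copy variant
    that returns fresh contiguous tensors instead of views."""

    @wraps(view_fn)
    def _fn(*args, out=None, **kwargs):
        result = view_fn(*args, **kwargs)
        if out is not None:
            # An explicit out= buffer already owns its storage (single-tensor case).
            return out.copy_(result)
        # Clone every tensor leaf so multi-output views are handled too.
        return pytree.tree_map(
            lambda x: x.clone(memory_format=torch.contiguous_format),
            result,
        )

    _fn.__name__ = f"{view_fn.__name__}_copy"
    return _fn


diagonal_copy = make_copy_from_view(torch.diagonal)
x = torch.arange(9.0).reshape(3, 3)
d = diagonal_copy(x)
x.zero_()   # mutating x leaves d untouched, unlike a torch.diagonal view
print(d)    # tensor([0., 4., 8.])
```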
@@ -2671,9 +2679,6 @@ def as_strided(
return prims.as_strided(a, size, stride, storage_offset_int)

as_strided_copy = _make_copy_from_view(as_strided)

@register_decomposition(aten.as_strided_scatter)
@out_wrapper()
def as_strided_scatter(

@@ -3071,11 +3076,6 @@ def narrow(
return prims.slice_in_dim(a, start, start + length, axis=dim)

# TODO: This must return a sparse tensor if the input is sparse, but refs have
# no sparse support. See narrow_copy_sparse in core.
narrow_copy = _make_copy_from_view(narrow)

def _normalize(
    a: Tensor, norm_dims: DimsType, eps: float
) -> Tuple[Tensor, Tensor, Tensor]:

@@ -4322,9 +4322,6 @@ def diagonal(
return result

diagonal_copy = _make_copy_from_view(diagonal)

@register_decomposition(aten.diag_embed)
@out_wrapper()
def diag_embed(

@@ -4480,9 +4477,6 @@ def alias(a: TensorLikeType) -> TensorLikeType:
return prims.view_of(a)

alias_copy = _make_copy_from_view(alias)

@register_decomposition(aten.transpose)
def transpose(a: TensorLikeType, dim0: int, dim1: int) -> TensorLikeType:
_dim0, _dim1 = utils.canonicalize_dims(a.ndim, (dim0, dim1))  # type: ignore[misc]

@@ -6317,6 +6311,13 @@ geometric_ = _make_inplace(geometric)
log_normal_ = _make_inplace(log_normal)
zero_ = _make_inplace(zero)

alias_copy = _make_copy_from_view(aten.alias)
as_strided_copy = _make_copy_from_view(aten.as_strided)
diagonal_copy = _make_copy_from_view(aten.diagonal)
# TODO: This must return a sparse tensor if the input is sparse, but refs have
# no sparse support. See narrow_copy_sparse in core.
narrow_copy = _make_copy_from_view(aten.narrow)

# xref: isStorage in torch/csrc/DynamicTypes.cpp
def _isStorage(obj):
@@ -13216,7 +13216,8 @@ op_db: List[OpInfo] = [
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
sample_inputs_func=sample_inputs_alias_copy,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True),
supports_fwgrad_bwgrad=True,
supports_out=True),
BinaryUfuncInfo('eq',
ref=np.equal,
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),

@@ -14624,6 +14625,7 @@ op_db: List[OpInfo] = [
DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'),
DecorateInfo(unittest.skip("Numerous errors"), 'TestFwdGradients'),
DecorateInfo(unittest.skip("Numerous errors"), 'TestBwdGradients'),
DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'),
)),
OpInfo('as_strided_scatter',
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),

@@ -16861,9 +16863,6 @@ op_db: List[OpInfo] = [
# https://github.com/pytorch/pytorch/issues/84577
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
# Lazy tensor failures: mutating and aliasing ops should all have codegen'd kernels
DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_correctness'),
DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_correctness_with_reusing_ir'),
# Could not run 'aten::narrow_copy.out' with arguments from the 'CUDA' backend
DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_outplace',
             device_type='cuda'),

@@ -23569,6 +23568,7 @@ python_ref_db = [
PythonRefInfo(
    "_refs.alias_copy",
    torch_opinfo_name="alias_copy",
    supports_out=True,
),
PythonRefInfo(
    "_refs.atleast_1d",

@@ -23600,6 +23600,7 @@ python_ref_db = [
PythonRefInfo(
    "_refs.as_strided_copy",
    torch_opinfo_name="as_strided_copy",
    supports_out=True,
    # FIXME: doesn't support chalf
    dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
    skips=(

@@ -23607,6 +23608,8 @@ python_ref_db = [
DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'),
DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'),
DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_conj_view'),
# The view function this decompose into does not have a ref
DecorateInfo(unittest.expectedFailure, "TestCommon", "test_python_ref"),
),
),
PythonRefInfo(

@@ -23694,6 +23697,7 @@ python_ref_db = [
PythonRefInfo(
    "_refs.diagonal_copy",
    torch_opinfo_name="diagonal_copy",
    supports_out=True,
),
PythonRefInfo(
    "_refs.diagonal_scatter",

@@ -23756,6 +23760,10 @@ python_ref_db = [
torch_opinfo_name="narrow_copy",
supports_out=True,
error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=False, is_ref=True),
skips=(
    # The view function this decompose into does not have a ref
    DecorateInfo(unittest.expectedFailure, "TestCommon", "test_python_ref"),
),
),
PythonRefInfo(
    "_refs.nn.functional.group_norm",

@@ -24404,6 +24412,12 @@ python_ref_db = [
DecorateInfo(
    unittest.expectedFailure, 'TestCommon', 'test_python_ref'
),
DecorateInfo(
    unittest.skip("Expected: unfold_backward() got an unexpected keyword argument 'input_sizes'"),
    'TestCommon',
    'test_python_ref_executor',
    dtypes=(torch.complex64, torch.complex128),
),
],
),
PythonRefInfo(