Fix out_wrapper, _make_copy_from_view to handle all signatures (#130937)

* See #128416 and #129476
* Simplify xfail/skip lists in test/functorch/test_ops.py
* Add supports_out=True to OpInfos for copy ops
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130937
Approved by: https://github.com/peterbell10
Authored by Tom Ritchford on 2024-07-18 13:15:07 +00:00, committed by PyTorch MergeBot
parent b193894b94
commit f628813066
6 changed files with 56 additions and 42 deletions

View File

@@ -934,7 +934,6 @@ class TestOperators(TestCase):
skip("ormqr"), # Takes too long
xfail("as_strided"), # incorrect output
xfail("as_strided", "partial_views"), # incorrect output
xfail("as_strided_copy"), # incorrect output
xfail("as_strided_scatter"), # incorrect output
skip("bernoulli"), # calls random op
xfail("bfloat16"), # rank 4 tensor for channels_last
@@ -1062,6 +1061,7 @@ class TestOperators(TestCase):
"test_vmapvjpvjp",
{
xfail("as_strided", "partial_views"),
xfail("as_strided_copy"),
},
)
def test_vmapvjpvjp(self, device, dtype, op):
@@ -1175,7 +1175,6 @@ class TestOperators(TestCase):
xfail("nn.functional.max_unpool2d", "grad"),
xfail("sparse.sampled_addmm", ""),
xfail("sparse.mm", "reduce"),
xfail("as_strided_copy", ""), # calls as_strided
xfail("as_strided_scatter", ""), # calls as_strided
xfail("index_reduce", "prod"), # .item() call
# ---------------------------------------------------------------------
@@ -1290,7 +1289,6 @@ class TestOperators(TestCase):
xfail("quantile"), # at::equal batching rule (cpu), also, in-place vmap (cuda)
skip("as_strided"), # Test runner cannot handle this
# requires special handling, and does not yet have a batching rule. Feel free to file a github issue!
xfail("as_strided_copy"),
xfail("as_strided_scatter"),
xfail(
"nn.functional.gaussian_nll_loss"
@@ -1343,6 +1341,7 @@ class TestOperators(TestCase):
"test_vmapjvpall",
vmapjvpall_fail.union(
{
xfail("as_strided_copy"),
decorate(
"linalg.det",
"singular",
@@ -1430,11 +1429,6 @@ class TestOperators(TestCase):
xfail("masked.cumprod", ""),
xfail("renorm"), # hit vmap fallback, which is disabled
}
).difference(
{
# as_strided_copy fails test_vmapvjp, succeeds here
xfail("as_strided_copy", ""),
}
),
)
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
@@ -1571,11 +1565,6 @@ class TestOperators(TestCase):
"index_fill"
), # aten::_unique hit the vmap fallback which is currently disabled
}
).difference(
{
# as_strided_copy fails test_vmapvjp, succeeds here
xfail("as_strided_copy", ""),
}
),
)
def test_vmapvjp_has_batch_rule(self, device, dtype, op):

View File

@@ -240,6 +240,7 @@ EXPECTED_SKIPS_OR_FAILS_WITH_DTYPES: Tuple[onnx_test_common.DecorateMeta, ...] =
),
xfail(
"alias_copy",
dtypes=(torch.int8, torch.uint8, torch.int16, torch.float64),
reason="OnnxExporterError: Failed to export model",
),
xfail(

View File

@@ -130,6 +130,10 @@ class TorchRefsMode(torch.overrides.TorchFunctionMode):
# see https://github.com/pytorch/pytorch/pull/82657#discussion_r939776417
if func is None and isinstance(orig_func, torch._ops.OpOverload):
func = torch._decomp.decomposition_table.get(orig_func, None)
elif func is None and isinstance(orig_func, torch._ops.OpOverloadPacket):
default = getattr(orig_func, "default", None)
if default is not None:
func = torch._decomp.decomposition_table.get(default, None)
if func is not None:
# If the ref exists, query whether we should use it or not
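This fallback matters because torch._decomp.decomposition_table is keyed by OpOverload objects, while TorchRefsMode may also be handed a bare OpOverloadPacket; resolving the packet to its default overload is what lets the lookup succeed. A minimal standalone sketch of that lookup (not the PyTorch source; whether a decomposition is found depends on what is registered in the running build):

import torch
from torch._decomp import decomposition_table

packet = torch.ops.aten.diagonal_copy        # an OpOverloadPacket
overload = getattr(packet, "default", None)  # the OpOverload, if the packet has one

# Packets are never table keys, so the lookup goes through the default overload.
decomp = decomposition_table.get(overload, None) if overload is not None else None
print(decomp)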

View File

@@ -2,7 +2,6 @@
import inspect
import warnings
from functools import wraps
from itertools import chain
from typing import Callable, NamedTuple, Optional, overload, Sequence, Tuple
@@ -214,7 +213,7 @@ def out_wrapper(
*out_names: str,
exact_dtype: bool = False,
pass_is_out: bool = False,
preserve_memory_format=False,
preserve_memory_format: bool = False,
):
# The wrapped function needs to convert the output parameters to ensure
# compatibility between the Python API (which always uses "out" as the
@@ -320,12 +319,18 @@ def out_wrapper(
sig.empty,
out_type,
)
params = chain(sig.parameters.values(), (out_param,))
params = *sig.parameters.values(), out_param
# If there's a Parameter.VAR_KEYWORD parameter (like **kwds), it must appear
# after the out= parameter, which is Parameter.KEYWORD_ONLY. Sorting by
# Parameter.kind guarantees that all the parameters are in legal order.
params = sorted(params, key=lambda p: p.kind)
_fn.__signature__ = inspect.Signature( # type: ignore[attr-defined]
parameters=params, return_annotation=return_type # type: ignore[arg-type]
)
_fn.__annotations__ = fn.__annotations__
_fn.__annotations__ = dict(getattr(fn, "__annotations__", {}))
_fn.__annotations__["out"] = out_type
_fn.__annotations__["return"] = return_type

View File

@@ -16,6 +16,7 @@ import torch
import torch._prims as prims
import torch._prims_common as utils
import torch.utils._pytree as pytree
from torch import sym_float, sym_int
from torch._prims_common import (
BoolLike,
@@ -2198,18 +2199,25 @@ def _make_copy_from_view(fn):
"""
Given a view function (e.g. torch.diagonal), generates its copy variant (e.g. torch.diagonal_copy)
"""
name = fn.__name__
fn = out_wrapper()(fn)
aten_fn = getattr(aten, fn.__name__)
annotations = getattr(fn, "__annotations__", {})
fn = out_wrapper()(aten_fn)
@wraps(fn)
def _fn(*args, out=None, **kwargs):
result = fn(*args, out=out, **kwargs)
if out is None:
return result.clone(memory_format=torch.contiguous_format)
return result
if out is not None:
return result
copy_name = f"{name}_copy"
return pytree.tree_map(
lambda x: x.clone(memory_format=torch.contiguous_format),
result,
)
copy_name = f"{fn.__name__}_copy"
_fn.__name__ = copy_name
_fn = register_decomposition(getattr(aten, copy_name))(_fn)
_fn.__annotations__.update(annotations)
register_decomposition(getattr(aten, copy_name))(_fn)
return _fn
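In short, the generated ref routes through the underlying aten view op via out_wrapper, returns out unchanged when one is supplied, and otherwise clones every tensor in the result into contiguous memory. A rough usage sketch of that behaviour, using the diagonal_copy ref registered later in this file's diff (assumes a build where these decompositions exist):

import torch
import torch._refs as refs

x = torch.arange(9.0).reshape(3, 3)

d = refs.diagonal_copy(x)        # a contiguous clone, not a view of x
d[0] = 100.0
assert x[0, 0].item() == 0.0     # mutating the copy leaves x untouched

out = torch.empty(3)
refs.diagonal_copy(x, out=out)   # the out= variant is handled by out_wrapper
assert torch.equal(out, torch.diagonal(x))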
@@ -2671,9 +2679,6 @@ def as_strided(
return prims.as_strided(a, size, stride, storage_offset_int)
as_strided_copy = _make_copy_from_view(as_strided)
@register_decomposition(aten.as_strided_scatter)
@out_wrapper()
def as_strided_scatter(
@@ -3071,11 +3076,6 @@ def narrow(
return prims.slice_in_dim(a, start, start + length, axis=dim)
# TODO: This must return a sparse tensor if the input is sparse, but refs have
# no sparse support. See narrow_copy_sparse in core.
narrow_copy = _make_copy_from_view(narrow)
def _normalize(
a: Tensor, norm_dims: DimsType, eps: float
) -> Tuple[Tensor, Tensor, Tensor]:
@@ -4322,9 +4322,6 @@ def diagonal(
return result
diagonal_copy = _make_copy_from_view(diagonal)
@register_decomposition(aten.diag_embed)
@out_wrapper()
def diag_embed(
@@ -4480,9 +4477,6 @@ def alias(a: TensorLikeType) -> TensorLikeType:
return prims.view_of(a)
alias_copy = _make_copy_from_view(alias)
@register_decomposition(aten.transpose)
def transpose(a: TensorLikeType, dim0: int, dim1: int) -> TensorLikeType:
_dim0, _dim1 = utils.canonicalize_dims(a.ndim, (dim0, dim1)) # type: ignore[misc]
@@ -6317,6 +6311,13 @@ geometric_ = _make_inplace(geometric)
log_normal_ = _make_inplace(log_normal)
zero_ = _make_inplace(zero)
alias_copy = _make_copy_from_view(aten.alias)
as_strided_copy = _make_copy_from_view(aten.as_strided)
diagonal_copy = _make_copy_from_view(aten.diagonal)
# TODO: This must return a sparse tensor if the input is sparse, but refs have
# no sparse support. See narrow_copy_sparse in core.
narrow_copy = _make_copy_from_view(aten.narrow)
# xref: isStorage in torch/csrc/DynamicTypes.cpp
def _isStorage(obj):

View File

@@ -13216,7 +13216,8 @@ op_db: List[OpInfo] = [
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
sample_inputs_func=sample_inputs_alias_copy,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True),
supports_fwgrad_bwgrad=True,
supports_out=True),
BinaryUfuncInfo('eq',
ref=np.equal,
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
@@ -14624,6 +14625,7 @@ op_db: List[OpInfo] = [
DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'),
DecorateInfo(unittest.skip("Numerous errors"), 'TestFwdGradients'),
DecorateInfo(unittest.skip("Numerous errors"), 'TestBwdGradients'),
DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'),
)),
OpInfo('as_strided_scatter',
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
@@ -16861,9 +16863,6 @@ op_db: List[OpInfo] = [
# https://github.com/pytorch/pytorch/issues/84577
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
# Lazy tensor failures: mutating and aliasing ops should all have codegen'd kernels
DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_correctness'),
DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_correctness_with_reusing_ir'),
# Could not run 'aten::narrow_copy.out' with arguments from the 'CUDA' backend
DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_outplace',
device_type='cuda'),
@@ -23569,6 +23568,7 @@ python_ref_db = [
PythonRefInfo(
"_refs.alias_copy",
torch_opinfo_name="alias_copy",
supports_out=True,
),
PythonRefInfo(
"_refs.atleast_1d",
@@ -23600,6 +23600,7 @@ python_ref_db = [
PythonRefInfo(
"_refs.as_strided_copy",
torch_opinfo_name="as_strided_copy",
supports_out=True,
# FIXME: doesn't support chalf
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
skips=(
@@ -23607,6 +23608,8 @@ python_ref_db = [
DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'),
DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'),
DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_conj_view'),
# The view function this decomposes into does not have a ref
DecorateInfo(unittest.expectedFailure, "TestCommon", "test_python_ref"),
),
),
PythonRefInfo(
@@ -23694,6 +23697,7 @@ python_ref_db = [
PythonRefInfo(
"_refs.diagonal_copy",
torch_opinfo_name="diagonal_copy",
supports_out=True,
),
PythonRefInfo(
"_refs.diagonal_scatter",
@@ -23756,6 +23760,10 @@ python_ref_db = [
torch_opinfo_name="narrow_copy",
supports_out=True,
error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=False, is_ref=True),
skips=(
# The view function this decomposes into does not have a ref
DecorateInfo(unittest.expectedFailure, "TestCommon", "test_python_ref"),
),
),
PythonRefInfo(
"_refs.nn.functional.group_norm",
@@ -24404,6 +24412,12 @@ python_ref_db = [
DecorateInfo(
unittest.expectedFailure, 'TestCommon', 'test_python_ref'
),
DecorateInfo(
unittest.skip("Expected: unfold_backward() got an unexpected keyword argument 'input_sizes'"),
'TestCommon',
'test_python_ref_executor',
dtypes=(torch.complex64, torch.complex128),
),
],
),
PythonRefInfo(