mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
See #145101 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145200 Approved by: https://github.com/bobrenjc93
458 lines
15 KiB
Python
458 lines
15 KiB
Python
# mypy: ignore-errors
|
|
|
|
import unittest
|
|
from functools import partial
|
|
from itertools import product
|
|
from typing import Callable
|
|
|
|
import numpy
|
|
|
|
import torch
|
|
from torch.testing._internal.common_dtype import floating_types
|
|
from torch.testing._internal.common_utils import TEST_SCIPY
|
|
from torch.testing._internal.opinfo.core import (
|
|
DecorateInfo,
|
|
ErrorInput,
|
|
OpInfo,
|
|
SampleInput,
|
|
)
|
|
|
|
|
|
if TEST_SCIPY:
|
|
import scipy.signal
|
|
|
|
|
|
def sample_inputs_window(op_info, device, dtype, requires_grad, *args, **kwargs):
    r"""Yield the basic sample inputs shared by every window OpInfo.

    Every window length from 0 through 5 is paired with both ``sym=True`` and
    ``sym=False``. Window-specific positional arguments go through ``*args`` and
    window-specific keyword arguments through ``**kwargs``; both are forwarded
    unchanged to every sample.
    """
    # Lengths 0 and 1 are included on purpose: they are the degenerate windows.
    for window_length in range(6):
        for symmetric in (True, False):
            yield SampleInput(
                window_length,
                *args,
                sym=symmetric,
                device=device,
                dtype=dtype,
                requires_grad=requires_grad,
                **kwargs,
            )
|
|
|
|
|
|
def reference_inputs_window(op_info, device, dtype, requires_grad, *args, **kwargs):
    r"""Reference inputs for windows whose only parameters are the length and ``sym``.

    Windows with additional parameters (e.g. exponential, gaussian) define their
    own reference-input functions instead.

    Yields every sample from :func:`sample_inputs_window`, followed by larger
    power-of-two window lengths, each with ``sym=False`` and ``sym=True``.
    ``*args``/``**kwargs`` are forwarded to every sample.
    """
    yield from sample_inputs_window(
        op_info, device, dtype, requires_grad, *args, **kwargs
    )

    # The larger cases must carry the requested device/dtype (and any extra
    # window arguments) exactly like the basic samples. Without them the op
    # runs with torch's default dtype/device while the scipy reference wrapper
    # in this file falls back to float64, so the comparison would not test the
    # dtype/device that was requested.
    for size in (8, 16, 32, 64, 128, 256):
        for sym in (False, True):
            yield SampleInput(
                size,
                *args,
                sym=sym,
                device=device,
                dtype=dtype,
                requires_grad=requires_grad,
                **kwargs,
            )
|
|
|
|
|
|
def reference_inputs_exponential_window(
    op_info, device, dtype, requires_grad, **kwargs
):
    r"""Reference inputs for ``signal.windows.exponential``.

    Yields the basic window samples (``**kwargs`` carries the default ``tau``
    bound via ``partial`` in ``op_db``). It then yields larger windows with
    explicit ``center``/``tau`` values: once with ``sym=False``, and once with
    ``sym=True`` and ``center`` reset to ``None``.
    """
    yield from sample_inputs_window(op_info, device, dtype, requires_grad, **kwargs)

    cases = (
        (8, {"center": 4, "tau": 0.5}),
        (16, {"center": 8, "tau": 2.5}),
        (32, {"center": 16, "tau": 43.5}),
        (64, {"center": 20, "tau": 3.7}),
        (128, {"center": 62, "tau": 99}),
        (256, {"tau": 10}),
    )

    # Forward the requested device/dtype/requires_grad to the larger cases too,
    # matching sample_inputs_window; otherwise they silently use torch defaults.
    tensor_kwargs = dict(device=device, dtype=dtype, requires_grad=requires_grad)
    for size, kw in cases:
        yield SampleInput(size, sym=False, **kw, **tensor_kwargs)
        # Symmetric windows reject a non-None center (see
        # error_inputs_exponential_window). Build a new dict rather than
        # mutating the case entry in place.
        symmetric_kw = {**kw, "center": None}
        yield SampleInput(size, sym=True, **symmetric_kw, **tensor_kwargs)
|
|
|
|
|
|
def reference_inputs_gaussian_window(op_info, device, dtype, requires_grad, **kwargs):
    r"""Reference inputs for ``signal.windows.gaussian``.

    Yields the basic window samples (``**kwargs`` carries the default ``std``
    bound via ``partial`` in ``op_db``), then larger windows with a range of
    ``std`` values, each with ``sym=False`` and ``sym=True``.
    """
    yield from sample_inputs_window(op_info, device, dtype, requires_grad, **kwargs)

    cases = (
        (8, {"std": 0.1}),
        (16, {"std": 1.2}),
        (32, {"std": 2.1}),
        (64, {"std": 3.9}),
        (128, {"std": 4.5}),
        (256, {"std": 10}),
    )

    # Forward the requested device/dtype/requires_grad to the larger cases too,
    # matching sample_inputs_window; otherwise they silently use torch defaults.
    tensor_kwargs = dict(device=device, dtype=dtype, requires_grad=requires_grad)
    for size, kw in cases:
        for sym in (False, True):
            yield SampleInput(size, sym=sym, **kw, **tensor_kwargs)
|
|
|
|
|
|
def reference_inputs_kaiser_window(op_info, device, dtype, requires_grad, **kwargs):
    r"""Reference inputs for ``signal.windows.kaiser``.

    Yields the basic window samples (``**kwargs`` carries the default ``beta``
    bound via ``partial`` in ``op_db``), then larger windows with a range of
    ``beta`` values, each with ``sym=False`` and ``sym=True``.
    """
    yield from sample_inputs_window(op_info, device, dtype, requires_grad, **kwargs)

    cases = (
        (8, {"beta": 2}),
        (16, {"beta": 12}),
        (32, {"beta": 30}),
        (64, {"beta": 35}),
        (128, {"beta": 41.2}),
        (256, {"beta": 100}),
    )

    # Forward the requested device/dtype/requires_grad to the larger cases too,
    # matching sample_inputs_window; otherwise they silently use torch defaults.
    tensor_kwargs = dict(device=device, dtype=dtype, requires_grad=requires_grad)
    for size, kw in cases:
        for sym in (False, True):
            yield SampleInput(size, sym=sym, **kw, **tensor_kwargs)
|
|
|
|
|
|
def reference_inputs_general_cosine_window(
    op_info, device, dtype, requires_grad, **kwargs
):
    r"""Reference inputs for ``signal.windows.general_cosine``.

    Yields the basic window samples (``**kwargs`` carries the default
    coefficients ``a`` bound via ``partial`` in ``op_db``). It then yields
    larger windows with coefficient lists of varying length, each with
    ``sym=False`` and ``sym=True``.
    """
    yield from sample_inputs_window(op_info, device, dtype, requires_grad, **kwargs)

    cases = (
        (8, {"a": [0.5, 0.5]}),
        (16, {"a": [0.46, 0.54]}),
        (32, {"a": [0.46, 0.23, 0.31]}),
        (64, {"a": [0.5]}),
        (128, {"a": [0.1, 0.8, 0.05, 0.05]}),
        (256, {"a": [0.2, 0.2, 0.2, 0.2, 0.2]}),
    )

    # Forward the requested device/dtype/requires_grad to the larger cases too,
    # matching sample_inputs_window; otherwise they silently use torch defaults.
    tensor_kwargs = dict(device=device, dtype=dtype, requires_grad=requires_grad)
    for size, kw in cases:
        for sym in (False, True):
            yield SampleInput(size, sym=sym, **kw, **tensor_kwargs)
|
|
|
|
|
|
def reference_inputs_general_hamming_window(
    op_info, device, dtype, requires_grad, **kwargs
):
    r"""Reference inputs for ``signal.windows.general_hamming``.

    Yields the basic window samples (``**kwargs`` carries the default ``alpha``
    bound via ``partial`` in ``op_db``), then larger windows with a range of
    ``alpha`` values, each with ``sym=False`` and ``sym=True``.
    """
    yield from sample_inputs_window(op_info, device, dtype, requires_grad, **kwargs)

    cases = (
        (8, {"alpha": 0.54}),
        (16, {"alpha": 0.5}),
        (32, {"alpha": 0.23}),
        (64, {"alpha": 0.8}),
        (128, {"alpha": 0.9}),
        (256, {"alpha": 0.05}),
    )

    # Forward the requested device/dtype/requires_grad to the larger cases too,
    # matching sample_inputs_window; otherwise they silently use torch defaults.
    tensor_kwargs = dict(device=device, dtype=dtype, requires_grad=requires_grad)
    for size, kw in cases:
        for sym in (False, True):
            yield SampleInput(size, sym=sym, **kw, **tensor_kwargs)
|
|
|
|
|
|
def error_inputs_window(op_info, device, *args, **kwargs):
    r"""Yield the error inputs shared by every window function.

    ``*args``/``**kwargs`` carry window-specific parameters. Callers in this file
    pass e.g. ``std``, ``beta`` or ``a``, and these are forwarded to every
    sample. Each case expects a ``ValueError`` with the given message.
    """
    # A negative window length is rejected.
    yield ErrorInput(
        SampleInput(-1, *args, dtype=torch.float32, device=device, **kwargs),
        error_type=ValueError,
        error_regex="requires non-negative window length, got M=-1",
    )

    # Only strided layouts are accepted; sparse COO is used as the counterexample.
    sparse_sample = SampleInput(
        3,
        *args,
        layout=torch.sparse_coo,
        device=device,
        dtype=torch.float32,
        **kwargs,
    )
    yield ErrorInput(
        sparse_sample,
        error_type=ValueError,
        error_regex="is implemented for strided tensors only, got: torch.sparse_coo",
    )

    # Only float32/float64 are accepted: an integer dtype and the two
    # reduced-precision floating dtypes must all be rejected.
    rejected_dtypes = (
        (torch.long, "expects float32 or float64 dtypes, got: torch.int64"),
        (torch.bfloat16, "expects float32 or float64 dtypes, got: torch.bfloat16"),
        (torch.float16, "expects float32 or float64 dtypes, got: torch.float16"),
    )
    for bad_dtype, expected_message in rejected_dtypes:
        yield ErrorInput(
            SampleInput(3, *args, dtype=bad_dtype, device=device, **kwargs),
            error_type=ValueError,
            error_regex=expected_message,
        )
|
|
|
|
|
|
def error_inputs_exponential_window(op_info, device, **kwargs):
    r"""Error inputs for ``signal.windows.exponential``.

    Yields the shared window errors, then a negative ``tau`` case and a case
    that combines an explicit ``center`` with ``sym=True``.
    """
    yield from error_inputs_window(op_info, device, **kwargs)

    # The decay parameter tau must be positive.
    negative_tau = SampleInput(3, tau=-1, dtype=torch.float32, device=device, **kwargs)
    yield ErrorInput(
        negative_tau,
        error_type=ValueError,
        error_regex="Tau must be positive, got: -1 instead.",
    )

    # A center may only be supplied for sym=False windows.
    # NOTE(review): unlike the case above, **kwargs is not forwarded here.
    centered_symmetric = SampleInput(
        3, center=1, sym=True, dtype=torch.float32, device=device
    )
    yield ErrorInput(
        centered_symmetric,
        error_type=ValueError,
        error_regex="Center must be None for symmetric windows",
    )
|
|
|
|
|
|
def error_inputs_gaussian_window(op_info, device, **kwargs):
    r"""Error inputs for ``signal.windows.gaussian``.

    The shared window errors are generated with a valid ``std`` (0.5), so
    presumably only the shared check fires. A negative ``std`` is then checked.
    """
    yield from error_inputs_window(op_info, device, std=0.5, **kwargs)

    negative_std = SampleInput(3, std=-1, dtype=torch.float32, device=device, **kwargs)
    yield ErrorInput(
        negative_std,
        error_type=ValueError,
        error_regex="Standard deviation must be positive, got: -1 instead.",
    )
|
|
|
|
|
|
def error_inputs_kaiser_window(op_info, device, **kwargs):
    r"""Error inputs for ``signal.windows.kaiser``.

    The shared window errors are generated with a valid ``beta`` (12), so
    presumably only the shared check fires. A negative ``beta`` is then checked.
    """
    yield from error_inputs_window(op_info, device, beta=12, **kwargs)

    negative_beta = SampleInput(3, beta=-1, dtype=torch.float32, device=device, **kwargs)
    yield ErrorInput(
        negative_beta,
        error_type=ValueError,
        error_regex="beta must be non-negative, got: -1 instead.",
    )
|
|
|
|
|
|
def error_inputs_general_cosine_window(op_info, device, **kwargs):
    r"""Error inputs for ``signal.windows.general_cosine``.

    Yields the shared window errors (generated with valid coefficients ``a``),
    then checks that invalid coefficient arguments are rejected:

    - ``a=None`` raises ``TypeError`` (coefficients must be a list/tuple).
    - ``a=[]`` raises ``ValueError`` (coefficients cannot be empty).
    """
    # Yield common error inputs, supplying valid coefficients.
    yield from error_inputs_window(op_info, device, a=[0.54, 0.46], **kwargs)

    # Tests coefficients that are not a list/tuple (here, None).
    yield ErrorInput(
        SampleInput(3, a=None, dtype=torch.float32, device=device, **kwargs),
        error_type=TypeError,
        error_regex="Coefficients must be a list/tuple",
    )

    # Tests an empty coefficient list.
    yield ErrorInput(
        SampleInput(3, a=[], dtype=torch.float32, device=device, **kwargs),
        error_type=ValueError,
        error_regex="Coefficients cannot be empty",
    )
|
|
|
|
|
|
def reference_signal_window(fn: Callable):
    r"""Adapt a scipy signal window function for use as an OpInfo reference.

    The returned callable accepts the torch-style factory keywords
    (``dtype``, ``device``, ``layout``, ``requires_grad``). It consumes them
    rather than passing them to ``fn``, since the scipy signatures do not accept
    them. Everything else is forwarded to ``fn``, and the resulting array is cast
    to ``dtype`` (``numpy.float64`` when not given).
    """

    def _scipy_window(
        *args,
        dtype=numpy.float64,
        device=None,
        layout=torch.strided,
        requires_grad=False,
        **kwargs,
    ):
        # device/layout/requires_grad only exist to be swallowed here.
        del device, layout, requires_grad
        window = fn(*args, **kwargs)
        return window.astype(dtype)

    return _scipy_window
|
|
|
|
|
|
def make_signal_windows_opinfo(
    name: str,
    ref: Callable,
    sample_inputs_func: Callable,
    reference_inputs_func: Callable,
    error_inputs_func: Callable,
    *,
    skips: tuple[DecorateInfo, ...] = (),
):
    r"""Build the OpInfo for a single ``signal.windows`` function.

    All window OpInfos share the same setup:

    - dtypes are ``floating_types()``;
    - there is no ``out=`` variant and no autograd support;
    - a common list of skips/expected failures applies, and window-specific
      entries in ``skips`` are appended after it.

    ``ref`` is discarded when SciPy is unavailable.
    """
    # Tests that cannot work because the op's primary input is a non-tensor
    # (the window length).
    non_tensor_input_skips = tuple(
        DecorateInfo(unittest.skip("Skipped!"), test_class, test_name)
        for test_class, test_name in (
            ("TestCommon", "test_noncontiguous_samples"),
            ("TestCommon", "test_variant_consistency_eager"),
            ("TestMathBits", "test_conj_view"),
            ("TestMathBits", "test_neg_conj_view"),
            ("TestMathBits", "test_neg_view"),
            ("TestVmapOperatorsOpInfo", "test_vmap_exhaustive"),
            ("TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
        )
    )

    common_skips = (
        # Fails to match any schemas despite working in the interpreter.
        # TODO: possibly the same problem as
        # https://github.com/pytorch/pytorch/issues/81774 (also seen with arange, new_full)
        DecorateInfo(
            unittest.expectedFailure,
            "TestOperatorSignatures",
            "test_get_torch_func_signature_exhaustive",
        ),
        # Fails to match any schemas despite working in the interpreter.
        DecorateInfo(
            unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
        ),
        *non_tensor_input_skips,
        DecorateInfo(
            unittest.skip("Buggy on MPS for now (mistakenly promotes to float64)"),
            "TestCommon",
            "test_numpy_ref_mps",
        ),
    )

    return OpInfo(
        name=name,
        ref=ref if TEST_SCIPY else None,
        dtypes=floating_types(),
        sample_inputs_func=sample_inputs_func,
        reference_inputs_func=reference_inputs_func,
        error_inputs_func=error_inputs_func,
        supports_out=False,
        supports_autograd=False,
        skips=(*common_skips, *skips),
    )
|
|
|
|
|
|
# OpInfo entries for the torch.signal.windows functions under test.
# Each ``ref`` is guarded by TEST_SCIPY because ``scipy.signal`` is only
# imported when SciPy is available; the conditional expression keeps the
# attribute lookup from running otherwise.
op_db: list[OpInfo] = [
    make_signal_windows_opinfo(
        name="signal.windows.hamming",
        ref=reference_signal_window(scipy.signal.windows.hamming)
        if TEST_SCIPY
        else None,
        sample_inputs_func=sample_inputs_window,
        reference_inputs_func=reference_inputs_window,
        error_inputs_func=error_inputs_window,
    ),
    make_signal_windows_opinfo(
        name="signal.windows.hann",
        ref=reference_signal_window(scipy.signal.windows.hann) if TEST_SCIPY else None,
        sample_inputs_func=sample_inputs_window,
        reference_inputs_func=reference_inputs_window,
        error_inputs_func=error_inputs_window,
    ),
    make_signal_windows_opinfo(
        name="signal.windows.bartlett",
        ref=reference_signal_window(scipy.signal.windows.bartlett)
        if TEST_SCIPY
        else None,
        sample_inputs_func=sample_inputs_window,
        reference_inputs_func=reference_inputs_window,
        error_inputs_func=error_inputs_window,
    ),
    make_signal_windows_opinfo(
        name="signal.windows.blackman",
        ref=reference_signal_window(scipy.signal.windows.blackman)
        if TEST_SCIPY
        else None,
        sample_inputs_func=sample_inputs_window,
        reference_inputs_func=reference_inputs_window,
        error_inputs_func=error_inputs_window,
    ),
    make_signal_windows_opinfo(
        name="signal.windows.cosine",
        ref=reference_signal_window(scipy.signal.windows.cosine)
        if TEST_SCIPY
        else None,
        sample_inputs_func=sample_inputs_window,
        reference_inputs_func=reference_inputs_window,
        error_inputs_func=error_inputs_window,
    ),
    make_signal_windows_opinfo(
        name="signal.windows.exponential",
        ref=reference_signal_window(scipy.signal.windows.exponential)
        if TEST_SCIPY
        else None,
        sample_inputs_func=partial(sample_inputs_window, tau=2.78),
        reference_inputs_func=partial(reference_inputs_exponential_window, tau=2.78),
        error_inputs_func=error_inputs_exponential_window,
    ),
    make_signal_windows_opinfo(
        name="signal.windows.gaussian",
        ref=reference_signal_window(scipy.signal.windows.gaussian)
        if TEST_SCIPY
        else None,
        sample_inputs_func=partial(sample_inputs_window, std=1.92),
        reference_inputs_func=partial(reference_inputs_gaussian_window, std=1.92),
        error_inputs_func=error_inputs_gaussian_window,
        # No extra skips: the TestCommon.test_numpy_ref_mps skip is already part
        # of the common skips added by make_signal_windows_opinfo.
    ),
    make_signal_windows_opinfo(
        name="signal.windows.kaiser",
        ref=reference_signal_window(scipy.signal.windows.kaiser)
        if TEST_SCIPY
        else None,
        sample_inputs_func=partial(sample_inputs_window, beta=12.0),
        reference_inputs_func=partial(reference_inputs_kaiser_window, beta=12.0),
        error_inputs_func=error_inputs_kaiser_window,
    ),
    make_signal_windows_opinfo(
        name="signal.windows.general_cosine",
        ref=reference_signal_window(scipy.signal.windows.general_cosine)
        if TEST_SCIPY
        else None,
        sample_inputs_func=partial(sample_inputs_window, a=[0.54, 0.46]),
        reference_inputs_func=partial(
            reference_inputs_general_cosine_window, a=[0.54, 0.46]
        ),
        error_inputs_func=error_inputs_general_cosine_window,
    ),
    make_signal_windows_opinfo(
        name="signal.windows.general_hamming",
        ref=reference_signal_window(scipy.signal.windows.general_hamming)
        if TEST_SCIPY
        else None,
        sample_inputs_func=partial(sample_inputs_window, alpha=0.54),
        reference_inputs_func=partial(
            reference_inputs_general_hamming_window, alpha=0.54
        ),
        error_inputs_func=error_inputs_window,
    ),
    make_signal_windows_opinfo(
        name="signal.windows.nuttall",
        ref=reference_signal_window(scipy.signal.windows.nuttall)
        if TEST_SCIPY
        else None,
        sample_inputs_func=sample_inputs_window,
        reference_inputs_func=reference_inputs_window,
        error_inputs_func=error_inputs_window,
    ),
]
|