Revert "stft: remove non-center overload and python functional wrapper"

This reverts commit d23ecbfc9ac157560611b242f015743f189dbf48.

Reverted https://github.com/pytorch/pytorch/pull/73434 on behalf of https://github.com/albanD
PyTorch MergeBot
2022-05-09 19:59:43 +00:00
parent ce3857e73c
commit 2c5bf12584

16 changed files with 177 additions and 186 deletions

View File: aten/src/ATen/autocast_mode.cpp

@@ -541,7 +541,8 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
  KERNEL_CPU(ADD_NS(quantile), "quantile.scalar", Tensor(const Tensor &, double, c10::optional<int64_t>, bool, c10::string_view), fp32)
  KERNEL_CPU(ADD_NS(nanquantile), "nanquantile", Tensor(const Tensor &, const Tensor &, c10::optional<int64_t>, bool, c10::string_view), fp32)
  KERNEL_CPU(ADD_NS(nanquantile), "nanquantile.scalar", Tensor(const Tensor &, double, c10::optional<int64_t>, bool, c10::string_view), fp32)
- KERNEL_CPU(ADD_NS(stft), "stft", Tensor(const Tensor &, int64_t, c10::optional<int64_t>, c10::optional<int64_t>, const c10::optional<Tensor> &, bool, c10::string_view, bool, c10::optional<bool>, c10::optional<bool>), fp32)
+ KERNEL_CPU(ADD_NS(stft), "stft", Tensor(const Tensor &, int64_t, c10::optional<int64_t>, c10::optional<int64_t>, const c10::optional<Tensor> &, bool, c10::optional<bool>, c10::optional<bool>), fp32)
+ KERNEL_CPU(ADD_NS(stft), "stft.center", Tensor(const Tensor &, int64_t, c10::optional<int64_t>, c10::optional<int64_t>, const c10::optional<Tensor> &, bool, c10::string_view, bool, c10::optional<bool>, c10::optional<bool>), fp32)
  KERNEL_CPU(ADD_NS(cdist), "cdist", Tensor(const Tensor &, const Tensor &, double, c10::optional<int64_t>), fp32)
  KERNEL_CPU(ADD_NS(cross), "cross", Tensor(const Tensor &, const Tensor &, c10::optional<int64_t>), fp32)
  KERNEL_CPU(ADD_NS(cumprod), "cumprod", Tensor(const Tensor &, int64_t, c10::optional<at::ScalarType>), fp32)
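
For context, the practical effect of both registrations is that stft always executes in float32 under CPU autocast, since FFT kernels have no reduced-precision implementation. A minimal sketch (not part of the diff; assumes a build where CPU autocast is available):

    import torch

    x = torch.randn(1024, dtype=torch.bfloat16)
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        # The fp32 policy casts the bfloat16 input up to float32
        # before dispatching to the stft kernel.
        out = torch.stft(x, n_fft=64, return_complex=True)
    print(out.dtype)  # torch.complex64, i.e. float32 precision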

View File: aten/src/ATen/native/SpectralOps.cpp

@@ -907,6 +907,17 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop_lengthOpt,
    }
  }

+ Tensor stft(
+     const Tensor& self, const int64_t n_fft, const optional<int64_t> hop_lengthOpt,
+     const optional<int64_t> win_lengthOpt, const c10::optional<Tensor>& window_opt,
+     const bool normalized,
+     const optional<bool> onesidedOpt, const optional<bool> return_complexOpt) {
+   return at::stft(
+       self, n_fft, hop_lengthOpt, win_lengthOpt, window_opt,
+       /*center=*/false, /*mode=*/"constant", normalized, onesidedOpt,
+       return_complexOpt);
+ }
+
  // Create complex tensor from the old style of real tensor with size=(..., 2)
  // This is to support istft in the transition to requiring complex input.
  // NOTE: This may return a view of the input tensor, or might clone if necessary
@@ -1100,6 +1111,15 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop_lengthOpt,
  #undef REPR
  }

+ Tensor istft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop_lengthOpt,
+              const optional<int64_t> win_lengthOpt, const Tensor& window,
+              const bool center, const bool normalized, const optional<bool> onesidedOpt,
+              const optional<int64_t> lengthOpt) {
+   return at::native::istft(
+       self, n_fft, hop_lengthOpt, win_lengthOpt, window, center, normalized,
+       onesidedOpt, lengthOpt, /*return_complex=*/false);
+ }
+
  void _fft_fill_with_conjugate_symmetry_(const Tensor& input, IntArrayRef dim_) {
    const auto input_sizes = input.sizes();
    const auto input_strides = input.strides();
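
Both restored C++ overloads are thin forwarding shims. In rough Python terms (an illustration of the forwarding, not code from this diff), the non-center stft overload behaves like:

    import torch

    def stft_non_center(x, n_fft, hop_length=None, win_length=None, window=None,
                        normalized=False, onesided=None, return_complex=None):
        # With center=False no padding is applied, so the "constant" pad mode
        # hard-coded by the C++ shim is never actually used.
        return torch.stft(x, n_fft, hop_length, win_length, window,
                          center=False, pad_mode="constant", normalized=normalized,
                          onesided=onesided, return_complex=return_complex)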

View File: aten/src/ATen/native/native_functions.yaml

@@ -4320,7 +4320,12 @@
  - func: dstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)

- - func: stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, str pad_mode="reflect", bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor
+ # Overload without center & pad mode, needed for forward-compatibility
+ - func: stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor
+   variants: function, method
+   cpp_no_default_args: ['hop_length', 'win_length', 'window', 'normalized']
+
+ - func: stft.center(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, str pad_mode="reflect", bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor
    variants: function, method

  - func: istft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, bool normalized=False, bool? onesided=None, int? length=None, bool return_complex=False) -> Tensor
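
With both schemas registered, the forward-compatibility overload and the full stft.center overload agree whenever centering is off. A quick sketch (assumes a build containing this revert):

    import torch

    x = torch.randn(256)
    w = torch.hann_window(32)
    # Eight positional arguments match the FC overload's schema exactly.
    a = torch.ops.aten.stft(x, 32, None, None, w, False, None, True)
    b = torch.stft(x, 32, window=w, center=False, return_complex=True)
    torch.testing.assert_close(a, b)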

View File: caffe2/serialize/versions.h

@@ -12,7 +12,7 @@ namespace serialize {
  constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L;

  #if ENABLE_UPGRADERS
- constexpr uint64_t kMaxSupportedFileFormatVersion = 11;
+ constexpr uint64_t kMaxSupportedFileFormatVersion = 0xAL;
  #else
  constexpr uint64_t kMaxSupportedFileFormatVersion = 0x6L;
  #endif

@@ -83,9 +83,7 @@ constexpr uint64_t kMaxSupportedFileFormatVersion = 0x6L;
  // Bump the version number to 10 to update aten::gelu and
  // and aten::gelu.out to support the new approximate kwarg.
  // (see: https://github.com/pytorch/pytorch/pull/61439)
- // 4) [02/25/2022]
- //     Bump version number to 11 to update aten::stft to do padding in ATen
- constexpr uint64_t kProducedFileFormatVersion = 11L;
+ constexpr uint64_t kProducedFileFormatVersion = 0xAL;
  #else
  constexpr uint64_t kProducedFileFormatVersion = 0x3L;
  #endif
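
Note the hex literals: 0xAL is 10 and 0x6L is 6, so this rolls the produced version back from 11 to 10 now that the aten::stft upgrader is gone. A hypothetical Python rendering of the gate these constants implement:

    K_MIN_SUPPORTED, K_MAX_SUPPORTED = 0x1, 0xA  # 1 and 10 after this revert

    def check_file_format_version(v: int) -> None:
        # A serialized file loads only if its version is inside the window;
        # new files are always written at kProducedFileFormatVersion (10 here).
        if not (K_MIN_SUPPORTED <= v <= K_MAX_SUPPORTED):
            raise RuntimeError(
                f"Unsupported file format version {v}; expected "
                f"[{K_MIN_SUPPORTED}, {K_MAX_SUPPORTED}]")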

View File: test/forward_backward_compatibility/check_forward_backward_compatibility.py

@@ -118,7 +118,6 @@ ALLOW_LIST = [
      ("aten::grid_sampler_3d_backward", datetime.date(9999, 1, 1)),
      ("aten::_transform_bias_rescale_qkv", datetime.date(9999, 1, 1)),
      ("aten::scatter_reduce.two", datetime.date(2022, 4, 15)),
-     ("aten::stft", datetime.date(2022, 6, 1)),
      ("aten::_s_where", datetime.date(2022, 9, 30)),
      ("quantized::conv2d_cudnn", datetime.date(2022, 3, 22)),
      ("quantized::conv2d_relu_cudnn", datetime.date(2022, 3, 22)),

View File: test/jit/fixtures_srcs/fixtures_src.py

@@ -57,11 +57,3 @@ class TestVersionedGeluOutV9(torch.nn.Module):
      def forward(self, x):
          out = torch.zeros_like(x)
          return torch._C._nn.gelu(x, out=out)
-
- class TestVersionedStftV10(torch.nn.Module):
-     def __init__(self):
-         super().__init__()
-
-     def forward(self, x, n_fft: int, window):
-         # calling aten::stft direct instead of torch.functional.stft
-         return torch.ops.aten.stft(x, n_fft=n_fft, window=window, return_complex=True)

View File: test/jit/fixtures_srcs/generate_models.py

@@ -96,7 +96,6 @@ ALL_MODULES = {
      TestVersionedLogspaceOutV8(): "aten::logspace.out",
      TestVersionedGeluV9(): "aten::gelu",
      TestVersionedGeluOutV9(): "aten::gelu.out",
-     TestVersionedStftV10(): "aten::stft",
  }
  """

View File: test/jit/test_save_load_for_op_version.py

@@ -540,20 +540,3 @@ class TestSaveLoadForOpVersion(JitTestCase):
              self.assertTrue(output.size(dim=0) == 100)
              # "Upgraded" model should match the new version output
              self.assertEqual(output, output_current)
-
-     def test_versioned_stft_v10(self):
-         model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_stft_v10.ptl"
-         loaded_model = torch.jit.load(model_path)
-         buffer = io.BytesIO(loaded_model._save_to_buffer_for_lite_interpreter())
-         buffer.seek(0)
-         v10_mobile_module = _load_for_lite_interpreter(buffer)
-
-         for in_dtype, window_dtype in product(
-                 [torch.float32, torch.complex64], repeat=2):
-             input = torch.rand((100,), dtype=in_dtype)
-             window = torch.rand((10,), dtype=window_dtype)
-             n_fft = 10
-             output = v10_mobile_module(input, n_fft, window)
-             output_expected = torch.stft(input, n_fft=n_fft, window=window,
-                                          center=False, return_complex=True)
-             self.assertEqual(output, output_expected)
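
For reference, the deleted test swept all four (input dtype, window dtype) combinations and compared the lite-interpreter module against eager stft with center=False. Its essence, as a standalone sketch:

    import torch
    from itertools import product

    for in_dtype, win_dtype in product([torch.float32, torch.complex64], repeat=2):
        x = torch.rand(100, dtype=in_dtype)
        w = torch.rand(10, dtype=win_dtype)
        # An upgraded v10 model must reproduce this center=False result
        # under the v11 schema.
        expected = torch.stft(x, n_fft=10, window=w, center=False,
                              return_complex=True)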

View File: tools/pyi/gen_pyi.py

@@ -109,6 +109,7 @@ blocklist = [
      "block_diag",
      "norm",
      "chain_matmul",
+     "stft",
      "tensordot",
      "split",
      "unique_consecutive",

View File: torch/_tensor.py

@@ -2,7 +2,7 @@ from collections import OrderedDict
  import enum
  import functools
  from numbers import Number
- from typing import Any, Dict, Tuple, Union
+ from typing import Any, Dict, Optional, Tuple, Union
  import warnings

  import copyreg
  from copy import deepcopy
@@ -545,6 +545,40 @@ class Tensor(torch._C._TensorBase):
          else:
              return LU, pivots

+     def stft(self, n_fft: int, hop_length: Optional[int] = None,
+              win_length: Optional[int] = None, window: 'Optional[Tensor]' = None,
+              center: bool = True, pad_mode: str = 'reflect', normalized: bool = False,
+              onesided: Optional[bool] = None, return_complex: Optional[bool] = None):
+         r"""See :func:`torch.stft`
+
+         .. warning::
+           This function changed signature at version 0.4.1. Calling with
+           the previous signature may cause error or return incorrect result.
+         """
+         if has_torch_function_unary(self):
+             return handle_torch_function(
+                 Tensor.stft, (self,), self, n_fft, hop_length=hop_length,
+                 win_length=win_length, window=window, center=center, pad_mode=pad_mode, normalized=normalized,
+                 onesided=onesided, return_complex=return_complex
+             )
+         return torch.stft(self, n_fft, hop_length, win_length, window, center,
+                           pad_mode, normalized, onesided, return_complex=return_complex)
+
+     def istft(self, n_fft: int, hop_length: Optional[int] = None,
+               win_length: Optional[int] = None, window: 'Optional[Tensor]' = None,
+               center: bool = True, normalized: bool = False,
+               onesided: Optional[bool] = None, length: Optional[int] = None,
+               return_complex: bool = False):
+         r"""See :func:`torch.istft`"""
+         if has_torch_function_unary(self):
+             return handle_torch_function(
+                 Tensor.istft, (self,), self, n_fft, hop_length=hop_length, win_length=win_length,
+                 window=window, center=center, normalized=normalized, onesided=onesided, length=length,
+                 return_complex=return_complex
+             )
+         return torch.istft(self, n_fft, hop_length, win_length, window, center,
+                            normalized, onesided, length, return_complex=return_complex)
+
      def resize(self, *sizes):
          if has_torch_function_unary(self):
              return handle_torch_function(Tensor.resize, (self,), self, *sizes)
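
The restored methods are thin wrappers over torch.stft and torch.istft with __torch_function__ handling. A usage sketch of the round trip they enable:

    import torch

    x = torch.randn(4096)
    w = torch.hann_window(400)
    spec = x.stft(n_fft=400, hop_length=160, window=w, return_complex=True)
    recon = spec.istft(n_fft=400, hop_length=160, window=w, length=x.numel())
    print(torch.allclose(x, recon, atol=1e-5))  # True for this window/hop choice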

View File: torch/_tensor_docs.py

@@ -4752,21 +4752,16 @@ See :func:`torch.dsplit`
  """)

  add_docstr_all('stft',
-                "stft(n_fft, hop_length=None, win_length=None, window=None, center=True, "
-                "pad_mode='reflect', normalized=False, onesided=None, return_complex=None) -> Tensor"
                 r"""
+ stft(frame_length, hop, fft_size=None, return_onesided=True, window=None, pad_end=0) -> Tensor
+
  See :func:`torch.stft`

  .. warning::
    This function changed signature at version 0.4.1. Calling with
    the previous signature may cause error or return incorrect result.
  """)

  add_docstr_all('istft',
-                "istft(input, n_fft, hop_length=None, win_length=None, window=None, center=True, "
-                "normalized=False, onesided=None, length=None, return_complex=False) -> Tensor"
                 r"""
+ istft(n_fft, hop_length=None, win_length=None, window=None,
+  center=True, normalized=False, onesided=True, length=None) -> Tensor
+
  See :func:`torch.istft`
  """)

View File: torch/csrc/jit/mobile/upgrader_mobile.cpp

@@ -67,10 +67,6 @@ getOperatorVersionMapForMobile() {
           std::vector<Upgrader>({
               Upgrader({0, 8, "logspace_out_0_8", 10})
           })},
-         {std::string("aten::stft"),
-          std::vector<Upgrader>({
-              Upgrader({0, 10, "stft_0_10", 11})
-          })},
      });
    return operatorVersionMapForMobile;
  }
@@ -531,35 +527,6 @@ const std::vector<ByteCodeFunctionWithOperator>& getUpgraderBytecodeList() {
              OperatorString({"prim::unchecked_cast", "", 1}),
          }), // operators list
      }),
-     ByteCodeFunctionWithOperator({
-         mobile::Function::registerFunc(
-             "stft_0_10",
-             std::vector<Instruction>({
-                 Instruction{OpCode::STOREN, 1, 8},
-                 Instruction{OpCode::MOVE, 1, 0},
-                 Instruction{OpCode::MOVE, 2, 0},
-                 Instruction{OpCode::MOVE, 3, 0},
-                 Instruction{OpCode::MOVE, 4, 0},
-                 Instruction{OpCode::MOVE, 5, 0},
-                 Instruction{OpCode::LOADC, 1, 0},
-                 Instruction{OpCode::LOADC, 0, 0},
-                 Instruction{OpCode::MOVE, 6, 0},
-                 Instruction{OpCode::MOVE, 7, 0},
-                 Instruction{OpCode::MOVE, 8, 0},
-                 Instruction{OpCode::OP, 0, 0},
-                 Instruction{OpCode::RET, 0, 0},
-             }), // instructions list,
-             std::vector<c10::IValue>({
-                 c10::IValue("reflect"),
-                 c10::IValue(false),
-             }), // constants list,
-             std::vector<c10::TypePtr>(), // types list,
-             8
-         ),
-         std::vector<OperatorString>({
-             OperatorString({"aten::stft", "", 10}),
-         }), // operators list
-     }),
  });
  for (const auto& upgrader_function : upgrader_function_list) {
    for (const auto& op : upgrader_function.operators) {

View File: torch/csrc/jit/operator_upgraders/upgraders_entry.cpp

@@ -15,17 +15,6 @@ namespace torch {
  namespace jit {

  static std::unordered_map<std::string, std::string> kUpgradersEntryMap({
-     {"stft_0_10", R"SCRIPT(
- def stft_0_10(
-     self: Tensor, n_fft: int, hop_length: Optional[int] = None,
-     win_length: Optional[int] = None, window: Optional[Tensor] = None,
-     normalized: bool = False, onesided: Optional[bool] = None,
-     return_complex: Optional[bool] = None) -> Tensor:
-   return torch.stft(
-       self, n_fft=n_fft, hop_length=hop_length, win_length=win_length,
-       window=window, center=False, normalized=normalized, onesided=onesided,
-       return_complex=return_complex)
- )SCRIPT"},
      {"logspace_0_8", R"SCRIPT(
  def logspace_0_8(start: Union[int, float, complex], end: Union[int, float, complex], steps: Optional[int], base: float, *, dtype: Optional[int], layout: Optional[int],
                   device: Optional[Device], pin_memory: Optional[bool]):
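
The deleted entry is ordinary TorchScript held as a raw string; when a model older than operator version 11 is loaded, the JIT compiles the string and substitutes it for the old aten::stft call. Roughly how such a snippet becomes callable (a hypothetical driver using the public API, not the loader's actual path):

    import textwrap
    import torch

    src = '''
    def stft_0_10(self: Tensor, n_fft: int) -> Tensor:
        return torch.stft(self, n_fft=n_fft, center=False, return_complex=True)
    '''
    cu = torch.jit.CompilationUnit(textwrap.dedent(src))
    print(cu.stft_0_10.graph)  # inspect the compiled upgrader graph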

View File: torch/csrc/jit/operator_upgraders/version_map.cpp

@@ -16,11 +16,7 @@ static bool isVersionMapSorted = false;
  // Note for developers: The list of upgraders should be SORTED
  // by the version number where the upgrader is registered.
  static std::unordered_map<std::string, std::vector<UpgraderEntry>> operatorVersionMap(
-     {{"aten::stft",
-       {{11,
-         "stft_0_10",
-         "aten::stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor"}}},
-      {"aten::logspace",
+     {{"aten::logspace",
        {{9,
         "logspace_0_8",
         "aten::logspace(Scalar start, Scalar end, int? steps=None, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"}}},

View File: torch/functional.py

@@ -4,7 +4,7 @@ from typing import (
  import torch
  from torch._C import _add_docstr
- import torch.nn.functional
+ import torch.nn.functional as F
  from ._lowrank import svd_lowrank, pca_lowrank
  from .overrides import (
      has_torch_function, has_torch_function_unary, has_torch_function_variadic,
@@ -478,13 +478,12 @@ def _meshgrid(*tensors, indexing: Optional[str]):
      return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]

- stft = _add_docstr(
-     torch.stft,
-     "stft(input, n_fft, hop_length=None, win_length=None, window=None, center=True, "
-     "pad_mode='reflect', normalized=False, onesided=None, return_complex=None) -> Tensor"
-     r"""
- Short-time Fourier transform (STFT).
+ def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,
+          win_length: Optional[int] = None, window: Optional[Tensor] = None,
+          center: bool = True, pad_mode: str = 'reflect', normalized: bool = False,
+          onesided: Optional[bool] = None,
+          return_complex: Optional[bool] = None) -> Tensor:
+     r"""Short-time Fourier transform (STFT).

  .. warning::
    From version 1.8.0, :attr:`return_complex` must always be given
@@ -590,9 +589,22 @@ Args:
  Returns:
      Tensor: A tensor containing the STFT result with shape described above

- """)
-
- # TODO: Fix via https://github.com/pytorch/pytorch/issues/75798
- stft.__module__ = "torch.functional"
+     """
+     if has_torch_function_unary(input):
+         return handle_torch_function(
+             stft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length,
+             window=window, center=center, pad_mode=pad_mode, normalized=normalized,
+             onesided=onesided, return_complex=return_complex)
+     # NOTE: Do not edit. This code will be removed once the forward-compatibility
+     # period is over for PR #73432
+     if center:
+         signal_dim = input.dim()
+         extended_shape = [1] * (3 - signal_dim) + list(input.size())
+         pad = int(n_fft // 2)
+         input = F.pad(input.view(extended_shape), [pad, pad], pad_mode)
+         input = input.view(input.shape[-signal_dim:])
+     return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore[attr-defined]
+                     normalized, onesided, return_complex)
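
The restored wrapper implements center=True in Python: pad the signal by n_fft // 2 on both sides, then call the non-center native overload. An equivalence sketch:

    import torch
    import torch.nn.functional as F

    x = torch.randn(1000)
    n_fft = 64
    pad = n_fft // 2
    # F.pad's reflect mode needs a 3-D (N, C, W) view for a 1-D signal.
    padded = F.pad(x.view(1, 1, -1), [pad, pad], mode="reflect").view(-1)
    a = torch.stft(padded, n_fft, center=False, return_complex=True)
    b = torch.stft(x, n_fft, center=True, return_complex=True)  # pad_mode='reflect'
    torch.testing.assert_close(a, b)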
istft = _add_docstr(