mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Revert "stft: remove non-center overload and python functional wrapper"
This reverts commit d23ecbfc9ac157560611b242f015743f189dbf48. Reverted https://github.com/pytorch/pytorch/pull/73434 on behalf of https://github.com/albanD
This commit is contained in:
@ -541,7 +541,8 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
|
||||
KERNEL_CPU(ADD_NS(quantile), "quantile.scalar", Tensor(const Tensor &, double, c10::optional<int64_t>, bool, c10::string_view), fp32)
|
||||
KERNEL_CPU(ADD_NS(nanquantile), "nanquantile", Tensor(const Tensor &, const Tensor &, c10::optional<int64_t>, bool, c10::string_view), fp32)
|
||||
KERNEL_CPU(ADD_NS(nanquantile), "nanquantile.scalar", Tensor(const Tensor &, double, c10::optional<int64_t>, bool, c10::string_view), fp32)
|
||||
KERNEL_CPU(ADD_NS(stft), "stft", Tensor(const Tensor &, int64_t, c10::optional<int64_t>, c10::optional<int64_t>, const c10::optional<Tensor> &, bool, c10::string_view, bool, c10::optional<bool>, c10::optional<bool>), fp32)
|
||||
KERNEL_CPU(ADD_NS(stft), "stft", Tensor(const Tensor &, int64_t, c10::optional<int64_t>, c10::optional<int64_t>, const c10::optional<Tensor> &, bool, c10::optional<bool>, c10::optional<bool>), fp32)
|
||||
KERNEL_CPU(ADD_NS(stft), "stft.center", Tensor(const Tensor &, int64_t, c10::optional<int64_t>, c10::optional<int64_t>, const c10::optional<Tensor> &, bool, c10::string_view, bool, c10::optional<bool>, c10::optional<bool>), fp32)
|
||||
KERNEL_CPU(ADD_NS(cdist), "cdist", Tensor(const Tensor &, const Tensor &, double, c10::optional<int64_t>), fp32)
|
||||
KERNEL_CPU(ADD_NS(cross), "cross", Tensor(const Tensor &, const Tensor &, c10::optional<int64_t>), fp32)
|
||||
KERNEL_CPU(ADD_NS(cumprod), "cumprod", Tensor(const Tensor &, int64_t, c10::optional<at::ScalarType>), fp32)
|
||||
|
@ -907,6 +907,17 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop
|
||||
}
|
||||
}
|
||||
|
||||
Tensor stft(
|
||||
const Tensor& self, const int64_t n_fft, const optional<int64_t> hop_lengthOpt,
|
||||
const optional<int64_t> win_lengthOpt, const c10::optional<Tensor>& window_opt,
|
||||
const bool normalized,
|
||||
const optional<bool> onesidedOpt, const optional<bool> return_complexOpt) {
|
||||
return at::stft(
|
||||
self, n_fft, hop_lengthOpt, win_lengthOpt, window_opt,
|
||||
/*center=*/false, /*mode=*/"constant", normalized, onesidedOpt,
|
||||
return_complexOpt);
|
||||
}
|
||||
|
||||
// Create complex tensor from the old style of real tensor with size=(..., 2)
|
||||
// This is to support istft in the transition to requiring complex input.
|
||||
// NOTE: This may return a view of the input tensor, or might clone if necessary
|
||||
@ -1100,6 +1111,15 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const optional<int64_t> ho
|
||||
#undef REPR
|
||||
}
|
||||
|
||||
Tensor istft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop_lengthOpt,
|
||||
const optional<int64_t> win_lengthOpt, const Tensor& window,
|
||||
const bool center, const bool normalized, const optional<bool> onesidedOpt,
|
||||
const optional<int64_t> lengthOpt) {
|
||||
return at::native::istft(
|
||||
self, n_fft, hop_lengthOpt, win_lengthOpt, window, center, normalized,
|
||||
onesidedOpt, lengthOpt, /*return_complex=*/false);
|
||||
}
|
||||
|
||||
void _fft_fill_with_conjugate_symmetry_(const Tensor& input, IntArrayRef dim_) {
|
||||
const auto input_sizes = input.sizes();
|
||||
const auto input_strides = input.strides();
|
||||
|
@ -4320,7 +4320,12 @@
|
||||
|
||||
- func: dstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)
|
||||
|
||||
- func: stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, str pad_mode="reflect", bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor
|
||||
# Overload without center & pad mode, needed for forward-compatibility
|
||||
- func: stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor
|
||||
variants: function, method
|
||||
cpp_no_default_args: ['hop_length', 'win_length', 'window', 'normalized']
|
||||
|
||||
- func: stft.center(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, str pad_mode="reflect", bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor
|
||||
variants: function, method
|
||||
|
||||
- func: istft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, bool normalized=False, bool? onesided=None, int? length=None, bool return_complex=False) -> Tensor
|
||||
|
@ -12,7 +12,7 @@ namespace serialize {
|
||||
constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L;
|
||||
|
||||
#if ENABLE_UPGRADERS
|
||||
constexpr uint64_t kMaxSupportedFileFormatVersion = 11;
|
||||
constexpr uint64_t kMaxSupportedFileFormatVersion = 0xAL;
|
||||
#else
|
||||
constexpr uint64_t kMaxSupportedFileFormatVersion = 0x6L;
|
||||
#endif
|
||||
@ -83,9 +83,7 @@ constexpr uint64_t kMaxSupportedFileFormatVersion = 0x6L;
|
||||
// Bump the version number to 10 to update aten::gelu and
|
||||
// and aten::gelu.out to support the new approximate kwarg.
|
||||
// (see: https://github.com/pytorch/pytorch/pull/61439)
|
||||
// 4) [02/25/2022]
|
||||
// Bump version number to 11 to update aten::stft to do padding in ATen
|
||||
constexpr uint64_t kProducedFileFormatVersion = 11L;
|
||||
constexpr uint64_t kProducedFileFormatVersion = 0xAL;
|
||||
#else
|
||||
constexpr uint64_t kProducedFileFormatVersion = 0x3L;
|
||||
#endif
|
||||
|
@ -118,7 +118,6 @@ ALLOW_LIST = [
|
||||
("aten::grid_sampler_3d_backward", datetime.date(9999, 1, 1)),
|
||||
("aten::_transform_bias_rescale_qkv", datetime.date(9999, 1, 1)),
|
||||
("aten::scatter_reduce.two", datetime.date(2022, 4, 15)),
|
||||
("aten::stft", datetime.date(2022, 6, 1)),
|
||||
("aten::_s_where", datetime.date(2022, 9, 30)),
|
||||
("quantized::conv2d_cudnn", datetime.date(2022, 3, 22)),
|
||||
("quantized::conv2d_relu_cudnn", datetime.date(2022, 3, 22)),
|
||||
|
Binary file not shown.
@ -57,11 +57,3 @@ class TestVersionedGeluOutV9(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
out = torch.zeros_like(x)
|
||||
return torch._C._nn.gelu(x, out=out)
|
||||
|
||||
class TestVersionedStftV10(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x, n_fft: int, window):
|
||||
# calling aten::stft direct instead of torch.functional.stft
|
||||
return torch.ops.aten.stft(x, n_fft=n_fft, window=window, return_complex=True)
|
||||
|
@ -96,7 +96,6 @@ ALL_MODULES = {
|
||||
TestVersionedLogspaceOutV8(): "aten::logspace.out",
|
||||
TestVersionedGeluV9(): "aten::gelu",
|
||||
TestVersionedGeluOutV9(): "aten::gelu.out",
|
||||
TestVersionedStftV10(): "aten::stft",
|
||||
}
|
||||
|
||||
"""
|
||||
|
@ -540,20 +540,3 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||
self.assertTrue(output.size(dim=0) == 100)
|
||||
# "Upgraded" model should match the new version output
|
||||
self.assertEqual(output, output_current)
|
||||
|
||||
def test_versioned_stft_v10(self):
|
||||
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_stft_v10.ptl"
|
||||
loaded_model = torch.jit.load(model_path)
|
||||
buffer = io.BytesIO(loaded_model._save_to_buffer_for_lite_interpreter())
|
||||
buffer.seek(0)
|
||||
v10_mobile_module = _load_for_lite_interpreter(buffer)
|
||||
|
||||
for in_dtype, window_dtype in product(
|
||||
[torch.float32, torch.complex64], repeat=2):
|
||||
input = torch.rand((100,), dtype=in_dtype)
|
||||
window = torch.rand((10,), dtype=window_dtype)
|
||||
n_fft = 10
|
||||
output = v10_mobile_module(input, n_fft, window)
|
||||
output_expected = torch.stft(input, n_fft=n_fft, window=window,
|
||||
center=False, return_complex=True)
|
||||
self.assertEqual(output, output_expected)
|
||||
|
@ -109,6 +109,7 @@ blocklist = [
|
||||
"block_diag",
|
||||
"norm",
|
||||
"chain_matmul",
|
||||
"stft",
|
||||
"tensordot",
|
||||
"split",
|
||||
"unique_consecutive",
|
||||
|
@ -2,7 +2,7 @@ from collections import OrderedDict
|
||||
import enum
|
||||
import functools
|
||||
from numbers import Number
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
import warnings
|
||||
import copyreg
|
||||
from copy import deepcopy
|
||||
@ -545,6 +545,40 @@ class Tensor(torch._C._TensorBase):
|
||||
else:
|
||||
return LU, pivots
|
||||
|
||||
def stft(self, n_fft: int, hop_length: Optional[int] = None,
|
||||
win_length: Optional[int] = None, window: 'Optional[Tensor]' = None,
|
||||
center: bool = True, pad_mode: str = 'reflect', normalized: bool = False,
|
||||
onesided: Optional[bool] = None, return_complex: Optional[bool] = None):
|
||||
r"""See :func:`torch.stft`
|
||||
|
||||
.. warning::
|
||||
This function changed signature at version 0.4.1. Calling with
|
||||
the previous signature may cause error or return incorrect result.
|
||||
"""
|
||||
if has_torch_function_unary(self):
|
||||
return handle_torch_function(
|
||||
Tensor.stft, (self,), self, n_fft, hop_length=hop_length,
|
||||
win_length=win_length, window=window, center=center, pad_mode=pad_mode, normalized=normalized,
|
||||
onesided=onesided, return_complex=return_complex
|
||||
)
|
||||
return torch.stft(self, n_fft, hop_length, win_length, window, center,
|
||||
pad_mode, normalized, onesided, return_complex=return_complex)
|
||||
|
||||
def istft(self, n_fft: int, hop_length: Optional[int] = None,
|
||||
win_length: Optional[int] = None, window: 'Optional[Tensor]' = None,
|
||||
center: bool = True, normalized: bool = False,
|
||||
onesided: Optional[bool] = None, length: Optional[int] = None,
|
||||
return_complex: bool = False):
|
||||
r"""See :func:`torch.istft`"""
|
||||
if has_torch_function_unary(self):
|
||||
return handle_torch_function(
|
||||
Tensor.istft, (self,), self, n_fft, hop_length=hop_length, win_length=win_length,
|
||||
window=window, center=center, normalized=normalized, onesided=onesided, length=length,
|
||||
return_complex=return_complex
|
||||
)
|
||||
return torch.istft(self, n_fft, hop_length, win_length, window, center,
|
||||
normalized, onesided, length, return_complex=return_complex)
|
||||
|
||||
def resize(self, *sizes):
|
||||
if has_torch_function_unary(self):
|
||||
return handle_torch_function(Tensor.resize, (self,), self, *sizes)
|
||||
|
@ -4752,21 +4752,16 @@ See :func:`torch.dsplit`
|
||||
""")
|
||||
|
||||
add_docstr_all('stft',
|
||||
"stft(n_fft, hop_length=None, win_length=None, window=None, center=True, "
|
||||
"pad_mode='reflect', normalized=False, onesided=None, return_complex=None) -> Tensor"
|
||||
r"""
|
||||
stft(frame_length, hop, fft_size=None, return_onesided=True, window=None, pad_end=0) -> Tensor
|
||||
|
||||
See :func:`torch.stft`
|
||||
|
||||
.. warning::
|
||||
This function changed signature at version 0.4.1. Calling with
|
||||
the previous signature may cause error or return incorrect result.
|
||||
""")
|
||||
|
||||
add_docstr_all('istft',
|
||||
"istft(input, n_fft, hop_length=None, win_length=None, window=None, center=True, "
|
||||
"normalized=False, onesided=None, length=None, return_complex=False) -> Tensor"
|
||||
r"""
|
||||
istft(n_fft, hop_length=None, win_length=None, window=None,
|
||||
center=True, normalized=False, onesided=True, length=None) -> Tensor
|
||||
|
||||
See :func:`torch.istft`
|
||||
""")
|
||||
|
@ -67,10 +67,6 @@ getOperatorVersionMapForMobile() {
|
||||
std::vector<Upgrader>({
|
||||
Upgrader({0, 8, "logspace_out_0_8", 10})
|
||||
})},
|
||||
{std::string("aten::stft"),
|
||||
std::vector<Upgrader>({
|
||||
Upgrader({0, 10, "stft_0_10", 11})
|
||||
})},
|
||||
});
|
||||
return operatorVersionMapForMobile;
|
||||
}
|
||||
@ -531,35 +527,6 @@ const std::vector<ByteCodeFunctionWithOperator>& getUpgraderBytecodeList() {
|
||||
OperatorString({"prim::unchecked_cast", "", 1}),
|
||||
}), // operators list
|
||||
}),
|
||||
ByteCodeFunctionWithOperator({
|
||||
mobile::Function::registerFunc(
|
||||
"stft_0_10",
|
||||
std::vector<Instruction>({
|
||||
Instruction{OpCode::STOREN, 1, 8},
|
||||
Instruction{OpCode::MOVE, 1, 0},
|
||||
Instruction{OpCode::MOVE, 2, 0},
|
||||
Instruction{OpCode::MOVE, 3, 0},
|
||||
Instruction{OpCode::MOVE, 4, 0},
|
||||
Instruction{OpCode::MOVE, 5, 0},
|
||||
Instruction{OpCode::LOADC, 1, 0},
|
||||
Instruction{OpCode::LOADC, 0, 0},
|
||||
Instruction{OpCode::MOVE, 6, 0},
|
||||
Instruction{OpCode::MOVE, 7, 0},
|
||||
Instruction{OpCode::MOVE, 8, 0},
|
||||
Instruction{OpCode::OP, 0, 0},
|
||||
Instruction{OpCode::RET, 0, 0},
|
||||
}), // instructions list,
|
||||
std::vector<c10::IValue>({
|
||||
c10::IValue("reflect"),
|
||||
c10::IValue(false),
|
||||
}), // constants list,
|
||||
std::vector<c10::TypePtr>(), // types list,
|
||||
8
|
||||
),
|
||||
std::vector<OperatorString>({
|
||||
OperatorString({"aten::stft", "", 10}),
|
||||
}), // operators list
|
||||
}),
|
||||
});
|
||||
for (const auto& upgrader_function : upgrader_function_list) {
|
||||
for (const auto& op : upgrader_function.operators) {
|
||||
|
@ -15,17 +15,6 @@ namespace torch {
|
||||
namespace jit {
|
||||
|
||||
static std::unordered_map<std::string, std::string> kUpgradersEntryMap({
|
||||
{"stft_0_10", R"SCRIPT(
|
||||
def stft_0_10(
|
||||
self: Tensor, n_fft: int, hop_length: Optional[int] = None,
|
||||
win_length: Optional[int] = None, window: Optional[Tensor] = None,
|
||||
normalized: bool = False, onesided: Optional[bool] = None,
|
||||
return_complex: Optional[bool] = None) -> Tensor:
|
||||
return torch.stft(
|
||||
self, n_fft=n_fft, hop_length=hop_length, win_length=win_length,
|
||||
window=window, center=False, normalized=normalized, onesided=onesided,
|
||||
return_complex=return_complex)
|
||||
)SCRIPT"},
|
||||
{"logspace_0_8", R"SCRIPT(
|
||||
def logspace_0_8(start: Union[int, float, complex], end: Union[int, float, complex], steps: Optional[int], base: float, *, dtype: Optional[int], layout: Optional[int],
|
||||
device: Optional[Device], pin_memory: Optional[bool]):
|
||||
|
@ -16,11 +16,7 @@ static bool isVersionMapSorted = false;
|
||||
// Note for developers: The list of upgraders should be SORTED
|
||||
// by the version number where the upgrader is registered.
|
||||
static std::unordered_map<std::string, std::vector<UpgraderEntry>> operatorVersionMap(
|
||||
{{"aten::stft",
|
||||
{{11,
|
||||
"stft_0_10",
|
||||
"aten::stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor"}}},
|
||||
{"aten::logspace",
|
||||
{{"aten::logspace",
|
||||
{{9,
|
||||
"logspace_0_8",
|
||||
"aten::logspace(Scalar start, Scalar end, int? steps=None, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"}}},
|
||||
|
@ -4,7 +4,7 @@ from typing import (
|
||||
|
||||
import torch
|
||||
from torch._C import _add_docstr
|
||||
import torch.nn.functional
|
||||
import torch.nn.functional as F
|
||||
from ._lowrank import svd_lowrank, pca_lowrank
|
||||
from .overrides import (
|
||||
has_torch_function, has_torch_function_unary, has_torch_function_variadic,
|
||||
@ -478,13 +478,12 @@ def _meshgrid(*tensors, indexing: Optional[str]):
|
||||
return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
stft = _add_docstr(
|
||||
torch.stft,
|
||||
"stft(input, n_fft, hop_length=None, win_length=None, window=None, center=True, "
|
||||
"pad_mode='reflect', normalized=False, onesided=None, return_complex=None) -> Tensor"
|
||||
r"""
|
||||
|
||||
Short-time Fourier transform (STFT).
|
||||
def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,
|
||||
win_length: Optional[int] = None, window: Optional[Tensor] = None,
|
||||
center: bool = True, pad_mode: str = 'reflect', normalized: bool = False,
|
||||
onesided: Optional[bool] = None,
|
||||
return_complex: Optional[bool] = None) -> Tensor:
|
||||
r"""Short-time Fourier transform (STFT).
|
||||
|
||||
.. warning::
|
||||
From version 1.8.0, :attr:`return_complex` must always be given
|
||||
@ -590,9 +589,22 @@ Args:
|
||||
Returns:
|
||||
Tensor: A tensor containing the STFT result with shape described above
|
||||
|
||||
""")
|
||||
# TODO: Fix via https://github.com/pytorch/pytorch/issues/75798
|
||||
stft.__module__ = "torch.functional"
|
||||
"""
|
||||
if has_torch_function_unary(input):
|
||||
return handle_torch_function(
|
||||
stft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length,
|
||||
window=window, center=center, pad_mode=pad_mode, normalized=normalized,
|
||||
onesided=onesided, return_complex=return_complex)
|
||||
# NOTE: Do not edit. This code will be removed once the forward-compatibility
|
||||
# period is over for PR #73432
|
||||
if center:
|
||||
signal_dim = input.dim()
|
||||
extended_shape = [1] * (3 - signal_dim) + list(input.size())
|
||||
pad = int(n_fft // 2)
|
||||
input = F.pad(input.view(extended_shape), [pad, pad], pad_mode)
|
||||
input = input.view(input.shape[-signal_dim:])
|
||||
return _VF.stft(input, n_fft, hop_length, win_length, window, # type: ignore[attr-defined]
|
||||
normalized, onesided, return_complex)
|
||||
|
||||
|
||||
istft = _add_docstr(
|
||||
|
Reference in New Issue
Block a user