Revert "stft: remove non-center overload and python functional wrapper"
This reverts commit d23ecbfc9ac157560611b242f015743f189dbf48. Reverted https://github.com/pytorch/pytorch/pull/73434 on behalf of https://github.com/albanD
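The hunks below touch the ATen autocast registrations, the native schema definitions, serialization versioning, the mobile/JIT operator upgraders, and the Python frontend. For orientation, a minimal sketch of the two surfaces the revert restores, assuming the reverted state (this snippet is illustrative, not part of the commit): torch.stft is again a Python wrapper defined in torch/functional.py, and aten::stft keeps an overload without center/pad_mode alongside aten::stft.center.

import torch

x = torch.randn(2048)
# torch.stft is once more a plain Python function (the wrapper restored in
# torch/functional.py), not a bare native binding produced by _add_docstr:
spec = torch.stft(x, n_fft=512, return_complex=True)
print(type(torch.stft))  # <class 'function'> in the reverted state
print(spec.shape)        # (512 // 2 + 1, 1 + 2048 // 128) == (257, 17)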
@@ -541,7 +541,8 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
   KERNEL_CPU(ADD_NS(quantile), "quantile.scalar", Tensor(const Tensor &, double, c10::optional<int64_t>, bool, c10::string_view), fp32)
   KERNEL_CPU(ADD_NS(nanquantile), "nanquantile", Tensor(const Tensor &, const Tensor &, c10::optional<int64_t>, bool, c10::string_view), fp32)
   KERNEL_CPU(ADD_NS(nanquantile), "nanquantile.scalar", Tensor(const Tensor &, double, c10::optional<int64_t>, bool, c10::string_view), fp32)
-  KERNEL_CPU(ADD_NS(stft), "stft", Tensor(const Tensor &, int64_t, c10::optional<int64_t>, c10::optional<int64_t>, const c10::optional<Tensor> &, bool, c10::string_view, bool, c10::optional<bool>, c10::optional<bool>), fp32)
+  KERNEL_CPU(ADD_NS(stft), "stft", Tensor(const Tensor &, int64_t, c10::optional<int64_t>, c10::optional<int64_t>, const c10::optional<Tensor> &, bool, c10::optional<bool>, c10::optional<bool>), fp32)
+  KERNEL_CPU(ADD_NS(stft), "stft.center", Tensor(const Tensor &, int64_t, c10::optional<int64_t>, c10::optional<int64_t>, const c10::optional<Tensor> &, bool, c10::string_view, bool, c10::optional<bool>, c10::optional<bool>), fp32)
   KERNEL_CPU(ADD_NS(cdist), "cdist", Tensor(const Tensor &, const Tensor &, double, c10::optional<int64_t>), fp32)
   KERNEL_CPU(ADD_NS(cross), "cross", Tensor(const Tensor &, const Tensor &, c10::optional<int64_t>), fp32)
   KERNEL_CPU(ADD_NS(cumprod), "cumprod", Tensor(const Tensor &, int64_t, c10::optional<at::ScalarType>), fp32)
@@ -907,6 +907,17 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop
   }
 }
 
+Tensor stft(
+    const Tensor& self, const int64_t n_fft, const optional<int64_t> hop_lengthOpt,
+    const optional<int64_t> win_lengthOpt, const c10::optional<Tensor>& window_opt,
+    const bool normalized,
+    const optional<bool> onesidedOpt, const optional<bool> return_complexOpt) {
+  return at::stft(
+      self, n_fft, hop_lengthOpt, win_lengthOpt, window_opt,
+      /*center=*/false, /*mode=*/"constant", normalized, onesidedOpt,
+      return_complexOpt);
+}
+
 // Create complex tensor from the old style of real tensor with size=(..., 2)
 // This is to support istft in the transition to requiring complex input.
 // NOTE: This may return a view of the input tensor, or might clone if necessary
@@ -1100,6 +1111,15 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const optional<int64_t> ho
 #undef REPR
 }
 
+Tensor istft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop_lengthOpt,
+             const optional<int64_t> win_lengthOpt, const Tensor& window,
+             const bool center, const bool normalized, const optional<bool> onesidedOpt,
+             const optional<int64_t> lengthOpt) {
+  return at::native::istft(
+      self, n_fft, hop_lengthOpt, win_lengthOpt, window, center, normalized,
+      onesidedOpt, lengthOpt, /*return_complex=*/false);
+}
+
 void _fft_fill_with_conjugate_symmetry_(const Tensor& input, IntArrayRef dim_) {
   const auto input_sizes = input.sizes();
   const auto input_strides = input.strides();
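The two native overloads restored above differ only in whether centering happens inside ATen: the overload without center/pad_mode never pads. A hedged check of that split, patterned on the removed fixture test shown later in this diff (it assumes torch.ops overload resolution picks the schema without center for this call, as that test did):

import torch

x = torch.randn(100)
window = torch.hann_window(10)

# Direct call to aten::stft, the overload without center/pad_mode; it
# should match the Python-level call with centering disabled.
raw = torch.ops.aten.stft(x, n_fft=10, window=window, return_complex=True)
ref = torch.stft(x, n_fft=10, window=window, center=False, return_complex=True)
assert torch.allclose(raw, ref)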
@@ -4320,7 +4320,12 @@
 
 - func: dstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)
 
-- func: stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, str pad_mode="reflect", bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor
+# Overload without center & pad mode, needed for forward-compatibility
+- func: stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor
   variants: function, method
+  cpp_no_default_args: ['hop_length', 'win_length', 'window', 'normalized']
+
+- func: stft.center(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, str pad_mode="reflect", bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor
+  variants: function, method
 
 - func: istft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, bool normalized=False, bool? onesided=None, int? length=None, bool return_complex=False) -> Tensor
@@ -12,7 +12,7 @@ namespace serialize {
 constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L;
 
 #if ENABLE_UPGRADERS
-constexpr uint64_t kMaxSupportedFileFormatVersion = 11;
+constexpr uint64_t kMaxSupportedFileFormatVersion = 0xAL;
 #else
 constexpr uint64_t kMaxSupportedFileFormatVersion = 0x6L;
 #endif
@@ -83,9 +83,7 @@ constexpr uint64_t kMaxSupportedFileFormatVersion = 0x6L;
 //     Bump the version number to 10 to update aten::gelu and
 //     and aten::gelu.out to support the new approximate kwarg.
 //     (see: https://github.com/pytorch/pytorch/pull/61439)
-// 4) [02/25/2022]
-//     Bump version number to 11 to update aten::stft to do padding in ATen
-constexpr uint64_t kProducedFileFormatVersion = 11L;
+constexpr uint64_t kProducedFileFormatVersion = 0xAL;
 #else
 constexpr uint64_t kProducedFileFormatVersion = 0x3L;
 #endif
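The produced/max-supported pair above acts as a load-time gate: bumping the produced version to 11 would have made new archives unreadable by older runtimes, which is why the revert rolls both constants back to 0xA. A minimal sketch of that gate, not PyTorch's actual loader code (the constant and helper names here are hypothetical):

K_MAX_SUPPORTED_FILE_FORMAT_VERSION = 0xA  # 10 again after this revert

def check_file_format_version(produced_version: int) -> None:
    # Archives written by a newer producer than this reader understands
    # are rejected up front rather than silently misinterpreted.
    if produced_version > K_MAX_SUPPORTED_FILE_FORMAT_VERSION:
        raise RuntimeError(
            f"Unsupported file format version {produced_version}; "
            f"this build supports up to {K_MAX_SUPPORTED_FILE_FORMAT_VERSION}")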
@@ -118,7 +118,6 @@ ALLOW_LIST = [
     ("aten::grid_sampler_3d_backward", datetime.date(9999, 1, 1)),
     ("aten::_transform_bias_rescale_qkv", datetime.date(9999, 1, 1)),
     ("aten::scatter_reduce.two", datetime.date(2022, 4, 15)),
-    ("aten::stft", datetime.date(2022, 6, 1)),
    ("aten::_s_where", datetime.date(2022, 9, 30)),
     ("quantized::conv2d_cudnn", datetime.date(2022, 3, 22)),
     ("quantized::conv2d_relu_cudnn", datetime.date(2022, 3, 22)),
Binary file not shown (test/jit/fixtures/test_versioned_stft_v10.ptl).
@@ -57,11 +57,3 @@ class TestVersionedGeluOutV9(torch.nn.Module):
     def forward(self, x):
         out = torch.zeros_like(x)
         return torch._C._nn.gelu(x, out=out)
-
-class TestVersionedStftV10(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, x, n_fft: int, window):
-        # calling aten::stft directly instead of torch.functional.stft
-        return torch.ops.aten.stft(x, n_fft=n_fft, window=window, return_complex=True)
@@ -96,7 +96,6 @@ ALL_MODULES = {
     TestVersionedLogspaceOutV8(): "aten::logspace.out",
     TestVersionedGeluV9(): "aten::gelu",
     TestVersionedGeluOutV9(): "aten::gelu.out",
-    TestVersionedStftV10(): "aten::stft",
 }
 
 """
@@ -540,20 +540,3 @@ class TestSaveLoadForOpVersion(JitTestCase):
             self.assertTrue(output.size(dim=0) == 100)
             # "Upgraded" model should match the new version output
             self.assertEqual(output, output_current)
-
-    def test_versioned_stft_v10(self):
-        model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_stft_v10.ptl"
-        loaded_model = torch.jit.load(model_path)
-        buffer = io.BytesIO(loaded_model._save_to_buffer_for_lite_interpreter())
-        buffer.seek(0)
-        v10_mobile_module = _load_for_lite_interpreter(buffer)
-
-        for in_dtype, window_dtype in product(
-                [torch.float32, torch.complex64], repeat=2):
-            input = torch.rand((100,), dtype=in_dtype)
-            window = torch.rand((10,), dtype=window_dtype)
-            n_fft = 10
-            output = v10_mobile_module(input, n_fft, window)
-            output_expected = torch.stft(input, n_fft=n_fft, window=window,
-                                         center=False, return_complex=True)
-            self.assertEqual(output, output_expected)
@@ -109,6 +109,7 @@ blocklist = [
     "block_diag",
     "norm",
     "chain_matmul",
+    "stft",
     "tensordot",
     "split",
     "unique_consecutive",
@@ -2,7 +2,7 @@ from collections import OrderedDict
 import enum
 import functools
 from numbers import Number
-from typing import Any, Dict, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union
 import warnings
 import copyreg
 from copy import deepcopy
@@ -545,6 +545,40 @@ class Tensor(torch._C._TensorBase):
         else:
             return LU, pivots
 
+    def stft(self, n_fft: int, hop_length: Optional[int] = None,
+             win_length: Optional[int] = None, window: 'Optional[Tensor]' = None,
+             center: bool = True, pad_mode: str = 'reflect', normalized: bool = False,
+             onesided: Optional[bool] = None, return_complex: Optional[bool] = None):
+        r"""See :func:`torch.stft`
+
+        .. warning::
+          This function changed signature at version 0.4.1. Calling with
+          the previous signature may cause error or return incorrect result.
+        """
+        if has_torch_function_unary(self):
+            return handle_torch_function(
+                Tensor.stft, (self,), self, n_fft, hop_length=hop_length,
+                win_length=win_length, window=window, center=center, pad_mode=pad_mode, normalized=normalized,
+                onesided=onesided, return_complex=return_complex
+            )
+        return torch.stft(self, n_fft, hop_length, win_length, window, center,
+                          pad_mode, normalized, onesided, return_complex=return_complex)
+
+    def istft(self, n_fft: int, hop_length: Optional[int] = None,
+              win_length: Optional[int] = None, window: 'Optional[Tensor]' = None,
+              center: bool = True, normalized: bool = False,
+              onesided: Optional[bool] = None, length: Optional[int] = None,
+              return_complex: bool = False):
+        r"""See :func:`torch.istft`"""
+        if has_torch_function_unary(self):
+            return handle_torch_function(
+                Tensor.istft, (self,), self, n_fft, hop_length=hop_length, win_length=win_length,
+                window=window, center=center, normalized=normalized, onesided=onesided, length=length,
+                return_complex=return_complex
+            )
+        return torch.istft(self, n_fft, hop_length, win_length, window, center,
+                           normalized, onesided, length, return_complex=return_complex)
+
     def resize(self, *sizes):
         if has_torch_function_unary(self):
             return handle_torch_function(Tensor.resize, (self,), self, *sizes)
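A short usage sketch of the restored method wrappers, assuming the reverted state: Tensor.stft and Tensor.istft forward straight to torch.stft and torch.istft, so an analysis/resynthesis round trip can be written in method form (the signal sizes below are arbitrary example values).

import torch

x = torch.randn(4000)
win = torch.hann_window(400)

# Analysis then resynthesis via the method wrappers; a Hann window at
# the default 75% overlap satisfies the overlap-add condition, so the
# reconstruction should match up to float32 rounding.
spec = x.stft(n_fft=400, window=win, return_complex=True)
recon = spec.istft(n_fft=400, window=win, length=x.numel())
assert torch.allclose(x, recon, atol=1e-4)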
@@ -4752,21 +4752,16 @@ See :func:`torch.dsplit`
 """)
 
 add_docstr_all('stft',
-               "stft(n_fft, hop_length=None, win_length=None, window=None, center=True, "
-               "pad_mode='reflect', normalized=False, onesided=None, return_complex=None) -> Tensor"
                r"""
+stft(frame_length, hop, fft_size=None, return_onesided=True, window=None, pad_end=0) -> Tensor
+
 See :func:`torch.stft`
 
 .. warning::
   This function changed signature at version 0.4.1. Calling with
   the previous signature may cause error or return incorrect result.
 """)
 
 add_docstr_all('istft',
-               "istft(input, n_fft, hop_length=None, win_length=None, window=None, center=True, "
-               "normalized=False, onesided=None, length=None, return_complex=False) -> Tensor"
                r"""
+istft(n_fft, hop_length=None, win_length=None, window=None,
+      center=True, normalized=False, onesided=True, length=None) -> Tensor
+
 See :func:`torch.istft`
 """)
@@ -67,10 +67,6 @@ getOperatorVersionMapForMobile() {
       std::vector<Upgrader>({
           Upgrader({0, 8, "logspace_out_0_8", 10})
       })},
-     {std::string("aten::stft"),
-      std::vector<Upgrader>({
-          Upgrader({0, 10, "stft_0_10", 11})
-      })},
  });
  return operatorVersionMapForMobile;
}
@@ -531,35 +527,6 @@ const std::vector<ByteCodeFunctionWithOperator>& getUpgraderBytecodeList() {
             OperatorString({"prim::unchecked_cast", "", 1}),
         }), // operators list
     }),
-    ByteCodeFunctionWithOperator({
-        mobile::Function::registerFunc(
-            "stft_0_10",
-            std::vector<Instruction>({
-                Instruction{OpCode::STOREN, 1, 8},
-                Instruction{OpCode::MOVE, 1, 0},
-                Instruction{OpCode::MOVE, 2, 0},
-                Instruction{OpCode::MOVE, 3, 0},
-                Instruction{OpCode::MOVE, 4, 0},
-                Instruction{OpCode::MOVE, 5, 0},
-                Instruction{OpCode::LOADC, 1, 0},
-                Instruction{OpCode::LOADC, 0, 0},
-                Instruction{OpCode::MOVE, 6, 0},
-                Instruction{OpCode::MOVE, 7, 0},
-                Instruction{OpCode::MOVE, 8, 0},
-                Instruction{OpCode::OP, 0, 0},
-                Instruction{OpCode::RET, 0, 0},
-            }), // instructions list,
-            std::vector<c10::IValue>({
-                c10::IValue("reflect"),
-                c10::IValue(false),
-            }), // constants list,
-            std::vector<c10::TypePtr>(), // types list,
-            8
-        ),
-        std::vector<OperatorString>({
-            OperatorString({"aten::stft", "", 10}),
-        }), // operators list
-    }),
 });
 for (const auto& upgrader_function : upgrader_function_list) {
     for (const auto& op : upgrader_function.operators) {
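A hedged reading of the stft_0_10 bytecode removed above: STOREN captures the eight serialized arguments, the MOVEs replay self through window, the two LOADC instructions inject the constants list (False for center, "reflect" for pad_mode), and the final MOVEs supply normalized, onesided, and return_complex before OP invokes the 10-argument aten::stft named in the operator string. In Python terms the upgrader behaves roughly like the sketch below (stft_0_10_equivalent is a hypothetical name for illustration):

import torch

def stft_0_10_equivalent(self, n_fft, hop_length=None, win_length=None,
                         window=None, normalized=False, onesided=None,
                         return_complex=None):
    # center=False and pad_mode="reflect" come from the constants list;
    # pad_mode is irrelevant once center is False.
    return torch.stft(self, n_fft, hop_length, win_length, window,
                      center=False, pad_mode="reflect", normalized=normalized,
                      onesided=onesided, return_complex=return_complex)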
@@ -15,17 +15,6 @@ namespace torch {
 namespace jit {
 
 static std::unordered_map<std::string, std::string> kUpgradersEntryMap({
-    {"stft_0_10", R"SCRIPT(
-def stft_0_10(
-    self: Tensor, n_fft: int, hop_length: Optional[int] = None,
-    win_length: Optional[int] = None, window: Optional[Tensor] = None,
-    normalized: bool = False, onesided: Optional[bool] = None,
-    return_complex: Optional[bool] = None) -> Tensor:
-  return torch.stft(
-    self, n_fft=n_fft, hop_length=hop_length, win_length=win_length,
-    window=window, center=False, normalized=normalized, onesided=onesided,
-    return_complex=return_complex)
-)SCRIPT"},
     {"logspace_0_8", R"SCRIPT(
 def logspace_0_8(start: Union[int, float, complex], end: Union[int, float, complex], steps: Optional[int], base: float, *, dtype: Optional[int], layout: Optional[int],
                  device: Optional[Device], pin_memory: Optional[bool]):
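The SCRIPT upgrader above pins center=False because replaying an old 8-argument call against the version-11 schema would otherwise pick up the new center=True default and silently change old models' numerics. A small illustration of the difference it guards against (example sizes chosen for clarity):

import torch

x = torch.randn(128)
old = torch.stft(x, 16, center=False, return_complex=True)  # pre-padding semantics
new = torch.stft(x, 16, center=True, return_complex=True)   # centered default
print(old.shape, new.shape)  # torch.Size([9, 29]) vs torch.Size([9, 33])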
@@ -16,11 +16,7 @@ static bool isVersionMapSorted = false;
 // Note for developers: The list of upgraders should be SORTED
 // by the version number where the upgrader is registered.
 static std::unordered_map<std::string, std::vector<UpgraderEntry>> operatorVersionMap(
-    {{"aten::stft",
-      {{11,
-        "stft_0_10",
-        "aten::stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor"}}},
-     {"aten::logspace",
+    {{"aten::logspace",
       {{9,
         "logspace_0_8",
         "aten::logspace(Scalar start, Scalar end, int? steps=None, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"}}},
@@ -4,7 +4,7 @@ from typing import (
 
 import torch
 from torch._C import _add_docstr
-import torch.nn.functional
+import torch.nn.functional as F
 from ._lowrank import svd_lowrank, pca_lowrank
 from .overrides import (
     has_torch_function, has_torch_function_unary, has_torch_function_variadic,
@@ -478,121 +478,133 @@ def _meshgrid(*tensors, indexing: Optional[str]):
     return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
 
 
-stft = _add_docstr(
-    torch.stft,
-    "stft(input, n_fft, hop_length=None, win_length=None, window=None, center=True, "
-    "pad_mode='reflect', normalized=False, onesided=None, return_complex=None) -> Tensor"
-    r"""
-Short-time Fourier transform (STFT).
+def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,
+         win_length: Optional[int] = None, window: Optional[Tensor] = None,
+         center: bool = True, pad_mode: str = 'reflect', normalized: bool = False,
+         onesided: Optional[bool] = None,
+         return_complex: Optional[bool] = None) -> Tensor:
+    r"""Short-time Fourier transform (STFT).
+
+    .. warning::
+        From version 1.8.0, :attr:`return_complex` must always be given
+        explicitly for real inputs and `return_complex=False` has been
+        deprecated. Strongly prefer `return_complex=True` as in a future
+        pytorch release, this function will only return complex tensors.
+
+        Note that :func:`torch.view_as_real` can be used to recover a real
+        tensor with an extra last dimension for real and imaginary components.
+
+    The STFT computes the Fourier transform of short overlapping windows of the
+    input. This gives frequency components of the signal as they change over
+    time. The interface of this function is modeled after (but *not* a drop-in
+    replacement for) librosa_ stft function.
+
+    .. _librosa: https://librosa.org/doc/latest/generated/librosa.stft.html
+
+    Ignoring the optional batch dimension, this method computes the following
+    expression:
+
+    .. math::
+        X[\omega, m] = \sum_{k = 0}^{\text{win\_length-1}}%
+                            \text{window}[k]\ \text{input}[m \times \text{hop\_length} + k]\ %
+                            \exp\left(- j \frac{2 \pi \cdot \omega k}{\text{win\_length}}\right),
+
+    where :math:`m` is the index of the sliding window, and :math:`\omega` is
+    the frequency :math:`0 \leq \omega < \text{n\_fft}` for ``onesided=False``,
+    or :math:`0 \leq \omega < \lfloor \text{n\_fft} / 2 \rfloor + 1` for ``onesided=True``.
+
+    * :attr:`input` must be either a 1-D time sequence or a 2-D batch of time
+      sequences.
+
+    * If :attr:`hop_length` is ``None`` (default), it is treated as equal to
+      ``floor(n_fft / 4)``.
+
+    * If :attr:`win_length` is ``None`` (default), it is treated as equal to
+      :attr:`n_fft`.
+
+    * :attr:`window` can be a 1-D tensor of size :attr:`win_length`, e.g., from
+      :meth:`torch.hann_window`. If :attr:`window` is ``None`` (default), it is
+      treated as if having :math:`1` everywhere in the window. If
+      :math:`\text{win\_length} < \text{n\_fft}`, :attr:`window` will be padded on
+      both sides to length :attr:`n_fft` before being applied.
+
+    * If :attr:`center` is ``True`` (default), :attr:`input` will be padded on
+      both sides so that the :math:`t`-th frame is centered at time
+      :math:`t \times \text{hop\_length}`. Otherwise, the :math:`t`-th frame
+      begins at time :math:`t \times \text{hop\_length}`.
+
+    * :attr:`pad_mode` determines the padding method used on :attr:`input` when
+      :attr:`center` is ``True``. See :meth:`torch.nn.functional.pad` for
+      all available options. Default is ``"reflect"``.
+
+    * If :attr:`onesided` is ``True`` (default for real input), only values for
+      :math:`\omega` in :math:`\left[0, 1, 2, \dots, \left\lfloor
+      \frac{\text{n\_fft}}{2} \right\rfloor + 1\right]` are returned because
+      the real-to-complex Fourier transform satisfies the conjugate symmetry,
+      i.e., :math:`X[m, \omega] = X[m, \text{n\_fft} - \omega]^*`.
+      Note if the input or window tensors are complex, then :attr:`onesided`
+      output is not possible.
+
+    * If :attr:`normalized` is ``True`` (default is ``False``), the function
+      returns the normalized STFT results, i.e., multiplied by :math:`(\text{frame\_length})^{-0.5}`.
+
+    * If :attr:`return_complex` is ``True`` (default if input is complex), the
+      return is a ``input.dim() + 1`` dimensional complex tensor. If ``False``,
+      the output is a ``input.dim() + 2`` dimensional real tensor where the last
+      dimension represents the real and imaginary components.
+
+    Returns either a complex tensor of size :math:`(* \times N \times T)` if
+    :attr:`return_complex` is true, or a real tensor of size :math:`(* \times N
+    \times T \times 2)`. Where :math:`*` is the optional batch size of
+    :attr:`input`, :math:`N` is the number of frequencies where STFT is applied
+    and :math:`T` is the total number of frames used.
+
+    .. warning::
+      This function changed signature at version 0.4.1. Calling with the
+      previous signature may cause error or return incorrect result.
+
+    Args:
+        input (Tensor): the input tensor
+        n_fft (int): size of Fourier transform
+        hop_length (int, optional): the distance between neighboring sliding window
+            frames. Default: ``None`` (treated as equal to ``floor(n_fft / 4)``)
+        win_length (int, optional): the size of window frame and STFT filter.
+            Default: ``None`` (treated as equal to :attr:`n_fft`)
+        window (Tensor, optional): the optional window function.
+            Default: ``None`` (treated as window of all :math:`1` s)
+        center (bool, optional): whether to pad :attr:`input` on both sides so
+            that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
+            Default: ``True``
+        pad_mode (string, optional): controls the padding method used when
+            :attr:`center` is ``True``. Default: ``"reflect"``
+        normalized (bool, optional): controls whether to return the normalized STFT results
+            Default: ``False``
+        onesided (bool, optional): controls whether to return half of results to
+            avoid redundancy for real inputs.
+            Default: ``True`` for real :attr:`input` and :attr:`window`, ``False`` otherwise.
+        return_complex (bool, optional): whether to return a complex tensor, or
+            a real tensor with an extra last dimension for the real and
+            imaginary components.
+
+    Returns:
+        Tensor: A tensor containing the STFT result with shape described above
+
-""")
-# TODO: Fix via https://github.com/pytorch/pytorch/issues/75798
-stft.__module__ = "torch.functional"
+    """
+    if has_torch_function_unary(input):
+        return handle_torch_function(
+            stft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length,
+            window=window, center=center, pad_mode=pad_mode, normalized=normalized,
+            onesided=onesided, return_complex=return_complex)
+    # NOTE: Do not edit. This code will be removed once the forward-compatibility
+    # period is over for PR #73432
+    if center:
+        signal_dim = input.dim()
+        extended_shape = [1] * (3 - signal_dim) + list(input.size())
+        pad = int(n_fft // 2)
+        input = F.pad(input.view(extended_shape), [pad, pad], pad_mode)
+        input = input.view(input.shape[-signal_dim:])
+    return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore[attr-defined]
+                    normalized, onesided, return_complex)
 
 
 istft = _add_docstr(
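The restored wrapper makes the centering contract explicit: center=True is just a reflect pad applied in Python before the non-centered native call. A hedged equivalence check of that exact code path (sizes are arbitrary example values):

import torch
import torch.nn.functional as F

x = torch.randn(1000)
n_fft = 64
pad = n_fft // 2

centered = torch.stft(x, n_fft, center=True, pad_mode="reflect",
                      return_complex=True)
# Reproduce the wrapper's padding by hand: view to 3-D for F.pad, pad
# both ends by n_fft // 2, then call the non-centered path.
padded = F.pad(x.view(1, 1, -1), [pad, pad], "reflect").view(-1)
manual = torch.stft(padded, n_fft, center=False, return_complex=True)
assert torch.allclose(centered, manual)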