Revert "stft: remove non-center overload and python functional wrapper"

This reverts commit d23ecbfc9ac157560611b242f015743f189dbf48.

Reverted https://github.com/pytorch/pytorch/pull/73434 on behalf of https://github.com/albanD
This commit is contained in:
PyTorch MergeBot
2022-05-09 19:59:43 +00:00
parent ce3857e73c
commit 2c5bf12584
16 changed files with 177 additions and 186 deletions

View File

@ -541,7 +541,8 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
KERNEL_CPU(ADD_NS(quantile), "quantile.scalar", Tensor(const Tensor &, double, c10::optional<int64_t>, bool, c10::string_view), fp32)
KERNEL_CPU(ADD_NS(nanquantile), "nanquantile", Tensor(const Tensor &, const Tensor &, c10::optional<int64_t>, bool, c10::string_view), fp32)
KERNEL_CPU(ADD_NS(nanquantile), "nanquantile.scalar", Tensor(const Tensor &, double, c10::optional<int64_t>, bool, c10::string_view), fp32)
KERNEL_CPU(ADD_NS(stft), "stft", Tensor(const Tensor &, int64_t, c10::optional<int64_t>, c10::optional<int64_t>, const c10::optional<Tensor> &, bool, c10::string_view, bool, c10::optional<bool>, c10::optional<bool>), fp32)
KERNEL_CPU(ADD_NS(stft), "stft", Tensor(const Tensor &, int64_t, c10::optional<int64_t>, c10::optional<int64_t>, const c10::optional<Tensor> &, bool, c10::optional<bool>, c10::optional<bool>), fp32)
KERNEL_CPU(ADD_NS(stft), "stft.center", Tensor(const Tensor &, int64_t, c10::optional<int64_t>, c10::optional<int64_t>, const c10::optional<Tensor> &, bool, c10::string_view, bool, c10::optional<bool>, c10::optional<bool>), fp32)
KERNEL_CPU(ADD_NS(cdist), "cdist", Tensor(const Tensor &, const Tensor &, double, c10::optional<int64_t>), fp32)
KERNEL_CPU(ADD_NS(cross), "cross", Tensor(const Tensor &, const Tensor &, c10::optional<int64_t>), fp32)
KERNEL_CPU(ADD_NS(cumprod), "cumprod", Tensor(const Tensor &, int64_t, c10::optional<at::ScalarType>), fp32)

View File

@ -907,6 +907,17 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop
}
}
Tensor stft(
const Tensor& self, const int64_t n_fft, const optional<int64_t> hop_lengthOpt,
const optional<int64_t> win_lengthOpt, const c10::optional<Tensor>& window_opt,
const bool normalized,
const optional<bool> onesidedOpt, const optional<bool> return_complexOpt) {
return at::stft(
self, n_fft, hop_lengthOpt, win_lengthOpt, window_opt,
/*center=*/false, /*mode=*/"constant", normalized, onesidedOpt,
return_complexOpt);
}
// Create complex tensor from the old style of real tensor with size=(..., 2)
// This is to support istft in the transition to requiring complex input.
// NOTE: This may return a view of the input tensor, or might clone if necessary
@ -1100,6 +1111,15 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const optional<int64_t> ho
#undef REPR
}
Tensor istft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop_lengthOpt,
const optional<int64_t> win_lengthOpt, const Tensor& window,
const bool center, const bool normalized, const optional<bool> onesidedOpt,
const optional<int64_t> lengthOpt) {
return at::native::istft(
self, n_fft, hop_lengthOpt, win_lengthOpt, window, center, normalized,
onesidedOpt, lengthOpt, /*return_complex=*/false);
}
void _fft_fill_with_conjugate_symmetry_(const Tensor& input, IntArrayRef dim_) {
const auto input_sizes = input.sizes();
const auto input_strides = input.strides();

View File

@ -4320,7 +4320,12 @@
- func: dstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)
- func: stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, str pad_mode="reflect", bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor
# Overload without center & pad mode, needed for forward-compatibility
- func: stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor
variants: function, method
cpp_no_default_args: ['hop_length', 'win_length', 'window', 'normalized']
- func: stft.center(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, str pad_mode="reflect", bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor
variants: function, method
- func: istft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, bool normalized=False, bool? onesided=None, int? length=None, bool return_complex=False) -> Tensor

View File

@ -12,7 +12,7 @@ namespace serialize {
constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L;
#if ENABLE_UPGRADERS
constexpr uint64_t kMaxSupportedFileFormatVersion = 11;
constexpr uint64_t kMaxSupportedFileFormatVersion = 0xAL;
#else
constexpr uint64_t kMaxSupportedFileFormatVersion = 0x6L;
#endif
@ -83,9 +83,7 @@ constexpr uint64_t kMaxSupportedFileFormatVersion = 0x6L;
// Bump the version number to 10 to update aten::gelu and
// and aten::gelu.out to support the new approximate kwarg.
// (see: https://github.com/pytorch/pytorch/pull/61439)
// 4) [02/25/2022]
// Bump version number to 11 to update aten::stft to do padding in ATen
constexpr uint64_t kProducedFileFormatVersion = 11L;
constexpr uint64_t kProducedFileFormatVersion = 0xAL;
#else
constexpr uint64_t kProducedFileFormatVersion = 0x3L;
#endif

View File

@ -118,7 +118,6 @@ ALLOW_LIST = [
("aten::grid_sampler_3d_backward", datetime.date(9999, 1, 1)),
("aten::_transform_bias_rescale_qkv", datetime.date(9999, 1, 1)),
("aten::scatter_reduce.two", datetime.date(2022, 4, 15)),
("aten::stft", datetime.date(2022, 6, 1)),
("aten::_s_where", datetime.date(2022, 9, 30)),
("quantized::conv2d_cudnn", datetime.date(2022, 3, 22)),
("quantized::conv2d_relu_cudnn", datetime.date(2022, 3, 22)),

View File

@ -57,11 +57,3 @@ class TestVersionedGeluOutV9(torch.nn.Module):
def forward(self, x):
out = torch.zeros_like(x)
return torch._C._nn.gelu(x, out=out)
class TestVersionedStftV10(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, n_fft: int, window):
# calling aten::stft direct instead of torch.functional.stft
return torch.ops.aten.stft(x, n_fft=n_fft, window=window, return_complex=True)

View File

@ -96,7 +96,6 @@ ALL_MODULES = {
TestVersionedLogspaceOutV8(): "aten::logspace.out",
TestVersionedGeluV9(): "aten::gelu",
TestVersionedGeluOutV9(): "aten::gelu.out",
TestVersionedStftV10(): "aten::stft",
}
"""

View File

@ -540,20 +540,3 @@ class TestSaveLoadForOpVersion(JitTestCase):
self.assertTrue(output.size(dim=0) == 100)
# "Upgraded" model should match the new version output
self.assertEqual(output, output_current)
def test_versioned_stft_v10(self):
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_stft_v10.ptl"
loaded_model = torch.jit.load(model_path)
buffer = io.BytesIO(loaded_model._save_to_buffer_for_lite_interpreter())
buffer.seek(0)
v10_mobile_module = _load_for_lite_interpreter(buffer)
for in_dtype, window_dtype in product(
[torch.float32, torch.complex64], repeat=2):
input = torch.rand((100,), dtype=in_dtype)
window = torch.rand((10,), dtype=window_dtype)
n_fft = 10
output = v10_mobile_module(input, n_fft, window)
output_expected = torch.stft(input, n_fft=n_fft, window=window,
center=False, return_complex=True)
self.assertEqual(output, output_expected)

View File

@ -109,6 +109,7 @@ blocklist = [
"block_diag",
"norm",
"chain_matmul",
"stft",
"tensordot",
"split",
"unique_consecutive",

View File

@ -2,7 +2,7 @@ from collections import OrderedDict
import enum
import functools
from numbers import Number
from typing import Any, Dict, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union
import warnings
import copyreg
from copy import deepcopy
@ -545,6 +545,40 @@ class Tensor(torch._C._TensorBase):
else:
return LU, pivots
def stft(self, n_fft: int, hop_length: Optional[int] = None,
win_length: Optional[int] = None, window: 'Optional[Tensor]' = None,
center: bool = True, pad_mode: str = 'reflect', normalized: bool = False,
onesided: Optional[bool] = None, return_complex: Optional[bool] = None):
r"""See :func:`torch.stft`
.. warning::
This function changed signature at version 0.4.1. Calling with
the previous signature may cause error or return incorrect result.
"""
if has_torch_function_unary(self):
return handle_torch_function(
Tensor.stft, (self,), self, n_fft, hop_length=hop_length,
win_length=win_length, window=window, center=center, pad_mode=pad_mode, normalized=normalized,
onesided=onesided, return_complex=return_complex
)
return torch.stft(self, n_fft, hop_length, win_length, window, center,
pad_mode, normalized, onesided, return_complex=return_complex)
def istft(self, n_fft: int, hop_length: Optional[int] = None,
win_length: Optional[int] = None, window: 'Optional[Tensor]' = None,
center: bool = True, normalized: bool = False,
onesided: Optional[bool] = None, length: Optional[int] = None,
return_complex: bool = False):
r"""See :func:`torch.istft`"""
if has_torch_function_unary(self):
return handle_torch_function(
Tensor.istft, (self,), self, n_fft, hop_length=hop_length, win_length=win_length,
window=window, center=center, normalized=normalized, onesided=onesided, length=length,
return_complex=return_complex
)
return torch.istft(self, n_fft, hop_length, win_length, window, center,
normalized, onesided, length, return_complex=return_complex)
def resize(self, *sizes):
if has_torch_function_unary(self):
return handle_torch_function(Tensor.resize, (self,), self, *sizes)

View File

@ -4752,21 +4752,16 @@ See :func:`torch.dsplit`
""")
add_docstr_all('stft',
"stft(n_fft, hop_length=None, win_length=None, window=None, center=True, "
"pad_mode='reflect', normalized=False, onesided=None, return_complex=None) -> Tensor"
r"""
stft(frame_length, hop, fft_size=None, return_onesided=True, window=None, pad_end=0) -> Tensor
See :func:`torch.stft`
.. warning::
This function changed signature at version 0.4.1. Calling with
the previous signature may cause error or return incorrect result.
""")
add_docstr_all('istft',
"istft(input, n_fft, hop_length=None, win_length=None, window=None, center=True, "
"normalized=False, onesided=None, length=None, return_complex=False) -> Tensor"
r"""
istft(n_fft, hop_length=None, win_length=None, window=None,
center=True, normalized=False, onesided=True, length=None) -> Tensor
See :func:`torch.istft`
""")

View File

@ -67,10 +67,6 @@ getOperatorVersionMapForMobile() {
std::vector<Upgrader>({
Upgrader({0, 8, "logspace_out_0_8", 10})
})},
{std::string("aten::stft"),
std::vector<Upgrader>({
Upgrader({0, 10, "stft_0_10", 11})
})},
});
return operatorVersionMapForMobile;
}
@ -531,35 +527,6 @@ const std::vector<ByteCodeFunctionWithOperator>& getUpgraderBytecodeList() {
OperatorString({"prim::unchecked_cast", "", 1}),
}), // operators list
}),
ByteCodeFunctionWithOperator({
mobile::Function::registerFunc(
"stft_0_10",
std::vector<Instruction>({
Instruction{OpCode::STOREN, 1, 8},
Instruction{OpCode::MOVE, 1, 0},
Instruction{OpCode::MOVE, 2, 0},
Instruction{OpCode::MOVE, 3, 0},
Instruction{OpCode::MOVE, 4, 0},
Instruction{OpCode::MOVE, 5, 0},
Instruction{OpCode::LOADC, 1, 0},
Instruction{OpCode::LOADC, 0, 0},
Instruction{OpCode::MOVE, 6, 0},
Instruction{OpCode::MOVE, 7, 0},
Instruction{OpCode::MOVE, 8, 0},
Instruction{OpCode::OP, 0, 0},
Instruction{OpCode::RET, 0, 0},
}), // instructions list,
std::vector<c10::IValue>({
c10::IValue("reflect"),
c10::IValue(false),
}), // constants list,
std::vector<c10::TypePtr>(), // types list,
8
),
std::vector<OperatorString>({
OperatorString({"aten::stft", "", 10}),
}), // operators list
}),
});
for (const auto& upgrader_function : upgrader_function_list) {
for (const auto& op : upgrader_function.operators) {

View File

@ -15,17 +15,6 @@ namespace torch {
namespace jit {
static std::unordered_map<std::string, std::string> kUpgradersEntryMap({
{"stft_0_10", R"SCRIPT(
def stft_0_10(
self: Tensor, n_fft: int, hop_length: Optional[int] = None,
win_length: Optional[int] = None, window: Optional[Tensor] = None,
normalized: bool = False, onesided: Optional[bool] = None,
return_complex: Optional[bool] = None) -> Tensor:
return torch.stft(
self, n_fft=n_fft, hop_length=hop_length, win_length=win_length,
window=window, center=False, normalized=normalized, onesided=onesided,
return_complex=return_complex)
)SCRIPT"},
{"logspace_0_8", R"SCRIPT(
def logspace_0_8(start: Union[int, float, complex], end: Union[int, float, complex], steps: Optional[int], base: float, *, dtype: Optional[int], layout: Optional[int],
device: Optional[Device], pin_memory: Optional[bool]):

View File

@ -16,11 +16,7 @@ static bool isVersionMapSorted = false;
// Note for developers: The list of upgraders should be SORTED
// by the version number where the upgrader is registered.
static std::unordered_map<std::string, std::vector<UpgraderEntry>> operatorVersionMap(
{{"aten::stft",
{{11,
"stft_0_10",
"aten::stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor"}}},
{"aten::logspace",
{{"aten::logspace",
{{9,
"logspace_0_8",
"aten::logspace(Scalar start, Scalar end, int? steps=None, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"}}},

View File

@ -4,7 +4,7 @@ from typing import (
import torch
from torch._C import _add_docstr
import torch.nn.functional
import torch.nn.functional as F
from ._lowrank import svd_lowrank, pca_lowrank
from .overrides import (
has_torch_function, has_torch_function_unary, has_torch_function_variadic,
@ -478,121 +478,133 @@ def _meshgrid(*tensors, indexing: Optional[str]):
return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]
stft = _add_docstr(
torch.stft,
"stft(input, n_fft, hop_length=None, win_length=None, window=None, center=True, "
"pad_mode='reflect', normalized=False, onesided=None, return_complex=None) -> Tensor"
r"""
def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,
win_length: Optional[int] = None, window: Optional[Tensor] = None,
center: bool = True, pad_mode: str = 'reflect', normalized: bool = False,
onesided: Optional[bool] = None,
return_complex: Optional[bool] = None) -> Tensor:
r"""Short-time Fourier transform (STFT).
Short-time Fourier transform (STFT).
.. warning::
From version 1.8.0, :attr:`return_complex` must always be given
explicitly for real inputs and `return_complex=False` has been
deprecated. Strongly prefer `return_complex=True` as in a future
pytorch release, this function will only return complex tensors.
.. warning::
From version 1.8.0, :attr:`return_complex` must always be given
explicitly for real inputs and `return_complex=False` has been
deprecated. Strongly prefer `return_complex=True` as in a future
pytorch release, this function will only return complex tensors.
Note that :func:`torch.view_as_real` can be used to recover a real
tensor with an extra last dimension for real and imaginary components.
Note that :func:`torch.view_as_real` can be used to recover a real
tensor with an extra last dimension for real and imaginary components.
The STFT computes the Fourier transform of short overlapping windows of the
input. This giving frequency components of the signal as they change over
time. The interface of this function is modeled after (but *not* a drop-in
replacement for) librosa_ stft function.
The STFT computes the Fourier transform of short overlapping windows of the
input. This giving frequency components of the signal as they change over
time. The interface of this function is modeled after (but *not* a drop-in
replacement for) librosa_ stft function.
.. _librosa: https://librosa.org/doc/latest/generated/librosa.stft.html
.. _librosa: https://librosa.org/doc/latest/generated/librosa.stft.html
Ignoring the optional batch dimension, this method computes the following
expression:
Ignoring the optional batch dimension, this method computes the following
expression:
.. math::
X[\omega, m] = \sum_{k = 0}^{\text{win\_length-1}}%
\text{window}[k]\ \text{input}[m \times \text{hop\_length} + k]\ %
\exp\left(- j \frac{2 \pi \cdot \omega k}{\text{win\_length}}\right),
.. math::
X[\omega, m] = \sum_{k = 0}^{\text{win\_length-1}}%
\text{window}[k]\ \text{input}[m \times \text{hop\_length} + k]\ %
\exp\left(- j \frac{2 \pi \cdot \omega k}{\text{win\_length}}\right),
where :math:`m` is the index of the sliding window, and :math:`\omega` is
the frequency :math:`0 \leq \omega < \text{n\_fft}` for ``onesided=False``,
or :math:`0 \leq \omega < \lfloor \text{n\_fft} / 2 \rfloor + 1` for ``onesided=True``.
where :math:`m` is the index of the sliding window, and :math:`\omega` is
the frequency :math:`0 \leq \omega < \text{n\_fft}` for ``onesided=False``,
or :math:`0 \leq \omega < \lfloor \text{n\_fft} / 2 \rfloor + 1` for ``onesided=True``.
* :attr:`input` must be either a 1-D time sequence or a 2-D batch of time
sequences.
* :attr:`input` must be either a 1-D time sequence or a 2-D batch of time
sequences.
* If :attr:`hop_length` is ``None`` (default), it is treated as equal to
``floor(n_fft / 4)``.
* If :attr:`hop_length` is ``None`` (default), it is treated as equal to
``floor(n_fft / 4)``.
* If :attr:`win_length` is ``None`` (default), it is treated as equal to
:attr:`n_fft`.
* If :attr:`win_length` is ``None`` (default), it is treated as equal to
:attr:`n_fft`.
* :attr:`window` can be a 1-D tensor of size :attr:`win_length`, e.g., from
:meth:`torch.hann_window`. If :attr:`window` is ``None`` (default), it is
treated as if having :math:`1` everywhere in the window. If
:math:`\text{win\_length} < \text{n\_fft}`, :attr:`window` will be padded on
both sides to length :attr:`n_fft` before being applied.
* :attr:`window` can be a 1-D tensor of size :attr:`win_length`, e.g., from
:meth:`torch.hann_window`. If :attr:`window` is ``None`` (default), it is
treated as if having :math:`1` everywhere in the window. If
:math:`\text{win\_length} < \text{n\_fft}`, :attr:`window` will be padded on
both sides to length :attr:`n_fft` before being applied.
* If :attr:`center` is ``True`` (default), :attr:`input` will be padded on
both sides so that the :math:`t`-th frame is centered at time
:math:`t \times \text{hop\_length}`. Otherwise, the :math:`t`-th frame
begins at time :math:`t \times \text{hop\_length}`.
* If :attr:`center` is ``True`` (default), :attr:`input` will be padded on
both sides so that the :math:`t`-th frame is centered at time
:math:`t \times \text{hop\_length}`. Otherwise, the :math:`t`-th frame
begins at time :math:`t \times \text{hop\_length}`.
* :attr:`pad_mode` determines the padding method used on :attr:`input` when
:attr:`center` is ``True``. See :meth:`torch.nn.functional.pad` for
all available options. Default is ``"reflect"``.
* :attr:`pad_mode` determines the padding method used on :attr:`input` when
:attr:`center` is ``True``. See :meth:`torch.nn.functional.pad` for
all available options. Default is ``"reflect"``.
* If :attr:`onesided` is ``True`` (default for real input), only values for
:math:`\omega` in :math:`\left[0, 1, 2, \dots, \left\lfloor
\frac{\text{n\_fft}}{2} \right\rfloor + 1\right]` are returned because
the real-to-complex Fourier transform satisfies the conjugate symmetry,
i.e., :math:`X[m, \omega] = X[m, \text{n\_fft} - \omega]^*`.
Note if the input or window tensors are complex, then :attr:`onesided`
output is not possible.
* If :attr:`onesided` is ``True`` (default for real input), only values for
:math:`\omega` in :math:`\left[0, 1, 2, \dots, \left\lfloor
\frac{\text{n\_fft}}{2} \right\rfloor + 1\right]` are returned because
the real-to-complex Fourier transform satisfies the conjugate symmetry,
i.e., :math:`X[m, \omega] = X[m, \text{n\_fft} - \omega]^*`.
Note if the input or window tensors are complex, then :attr:`onesided`
output is not possible.
* If :attr:`normalized` is ``True`` (default is ``False``), the function
returns the normalized STFT results, i.e., multiplied by :math:`(\text{frame\_length})^{-0.5}`.
* If :attr:`normalized` is ``True`` (default is ``False``), the function
returns the normalized STFT results, i.e., multiplied by :math:`(\text{frame\_length})^{-0.5}`.
* If :attr:`return_complex` is ``True`` (default if input is complex), the
return is a ``input.dim() + 1`` dimensional complex tensor. If ``False``,
the output is a ``input.dim() + 2`` dimensional real tensor where the last
dimension represents the real and imaginary components.
* If :attr:`return_complex` is ``True`` (default if input is complex), the
return is a ``input.dim() + 1`` dimensional complex tensor. If ``False``,
the output is a ``input.dim() + 2`` dimensional real tensor where the last
dimension represents the real and imaginary components.
Returns either a complex tensor of size :math:`(* \times N \times T)` if
:attr:`return_complex` is true, or a real tensor of size :math:`(* \times N
\times T \times 2)`. Where :math:`*` is the optional batch size of
:attr:`input`, :math:`N` is the number of frequencies where STFT is applied
and :math:`T` is the total number of frames used.
Returns either a complex tensor of size :math:`(* \times N \times T)` if
:attr:`return_complex` is true, or a real tensor of size :math:`(* \times N
\times T \times 2)`. Where :math:`*` is the optional batch size of
:attr:`input`, :math:`N` is the number of frequencies where STFT is applied
and :math:`T` is the total number of frames used.
.. warning::
This function changed signature at version 0.4.1. Calling with the
previous signature may cause error or return incorrect result.
.. warning::
This function changed signature at version 0.4.1. Calling with the
previous signature may cause error or return incorrect result.
Args:
input (Tensor): the input tensor
n_fft (int): size of Fourier transform
hop_length (int, optional): the distance between neighboring sliding window
frames. Default: ``None`` (treated as equal to ``floor(n_fft / 4)``)
win_length (int, optional): the size of window frame and STFT filter.
Default: ``None`` (treated as equal to :attr:`n_fft`)
window (Tensor, optional): the optional window function.
Default: ``None`` (treated as window of all :math:`1` s)
center (bool, optional): whether to pad :attr:`input` on both sides so
that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
Default: ``True``
pad_mode (string, optional): controls the padding method used when
:attr:`center` is ``True``. Default: ``"reflect"``
normalized (bool, optional): controls whether to return the normalized STFT results
Default: ``False``
onesided (bool, optional): controls whether to return half of results to
avoid redundancy for real inputs.
Default: ``True`` for real :attr:`input` and :attr:`window`, ``False`` otherwise.
return_complex (bool, optional): whether to return a complex tensor, or
a real tensor with an extra last dimension for the real and
imaginary components.
Args:
input (Tensor): the input tensor
n_fft (int): size of Fourier transform
hop_length (int, optional): the distance between neighboring sliding window
frames. Default: ``None`` (treated as equal to ``floor(n_fft / 4)``)
win_length (int, optional): the size of window frame and STFT filter.
Default: ``None`` (treated as equal to :attr:`n_fft`)
window (Tensor, optional): the optional window function.
Default: ``None`` (treated as window of all :math:`1` s)
center (bool, optional): whether to pad :attr:`input` on both sides so
that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
Default: ``True``
pad_mode (string, optional): controls the padding method used when
:attr:`center` is ``True``. Default: ``"reflect"``
normalized (bool, optional): controls whether to return the normalized STFT results
Default: ``False``
onesided (bool, optional): controls whether to return half of results to
avoid redundancy for real inputs.
Default: ``True`` for real :attr:`input` and :attr:`window`, ``False`` otherwise.
return_complex (bool, optional): whether to return a complex tensor, or
a real tensor with an extra last dimension for the real and
imaginary components.
Returns:
Tensor: A tensor containing the STFT result with shape described above
Returns:
Tensor: A tensor containing the STFT result with shape described above
""")
# TODO: Fix via https://github.com/pytorch/pytorch/issues/75798
stft.__module__ = "torch.functional"
"""
if has_torch_function_unary(input):
return handle_torch_function(
stft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length,
window=window, center=center, pad_mode=pad_mode, normalized=normalized,
onesided=onesided, return_complex=return_complex)
# NOTE: Do not edit. This code will be removed once the forward-compatibility
# period is over for PR #73432
if center:
signal_dim = input.dim()
extended_shape = [1] * (3 - signal_dim) + list(input.size())
pad = int(n_fft // 2)
input = F.pad(input.view(extended_shape), [pad, pad], pad_mode)
input = input.view(input.shape[-signal_dim:])
return _VF.stft(input, n_fft, hop_length, win_length, window, # type: ignore[attr-defined]
normalized, onesided, return_complex)
istft = _add_docstr(