diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp
index 2f0eefa88bec..c6d0260229b0 100644
--- a/aten/src/ATen/autocast_mode.cpp
+++ b/aten/src/ATen/autocast_mode.cpp
@@ -541,8 +541,7 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
   KERNEL_CPU(ADD_NS(quantile), "quantile.scalar", Tensor(const Tensor &, double, c10::optional<int64_t>, bool, c10::string_view), fp32)
   KERNEL_CPU(ADD_NS(nanquantile), "nanquantile", Tensor(const Tensor &, const Tensor &, c10::optional<int64_t>, bool, c10::string_view), fp32)
   KERNEL_CPU(ADD_NS(nanquantile), "nanquantile.scalar", Tensor(const Tensor &, double, c10::optional<int64_t>, bool, c10::string_view), fp32)
-  KERNEL_CPU(ADD_NS(stft), "stft", Tensor(const Tensor &, int64_t, c10::optional<int64_t>, c10::optional<int64_t>, const c10::optional<Tensor> &, bool, c10::optional<bool>, c10::optional<bool>), fp32)
-  KERNEL_CPU(ADD_NS(stft), "stft.center", Tensor(const Tensor &, int64_t, c10::optional<int64_t>, c10::optional<int64_t>, const c10::optional<Tensor> &, bool, c10::string_view, bool, c10::optional<bool>, c10::optional<bool>), fp32)
+  KERNEL_CPU(ADD_NS(stft), "stft", Tensor(const Tensor &, int64_t, c10::optional<int64_t>, c10::optional<int64_t>, const c10::optional<Tensor> &, bool, c10::string_view, bool, c10::optional<bool>, c10::optional<bool>), fp32)
   KERNEL_CPU(ADD_NS(cdist), "cdist", Tensor(const Tensor &, const Tensor &, double, c10::optional<int64_t>), fp32)
   KERNEL_CPU(ADD_NS(cross), "cross", Tensor(const Tensor &, const Tensor &, c10::optional<int64_t>), fp32)
   KERNEL_CPU(ADD_NS(cumprod), "cumprod", Tensor(const Tensor &, int64_t, c10::optional<ScalarType>), fp32)
diff --git a/aten/src/ATen/native/SpectralOps.cpp b/aten/src/ATen/native/SpectralOps.cpp
index af000cc70d9f..5b9b273e9239 100644
--- a/aten/src/ATen/native/SpectralOps.cpp
+++ b/aten/src/ATen/native/SpectralOps.cpp
@@ -907,17 +907,6 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop
   }
 }
 
-Tensor stft(
-    const Tensor& self, const int64_t n_fft, const optional<int64_t> hop_lengthOpt,
-    const optional<int64_t> win_lengthOpt, const c10::optional<Tensor>& window_opt,
-    const bool normalized,
-    const optional<bool> onesidedOpt, const optional<bool> return_complexOpt) {
-  return at::stft(
-      self, n_fft, hop_lengthOpt, win_lengthOpt, window_opt,
-      /*center=*/false, /*mode=*/"constant", normalized, onesidedOpt,
-      return_complexOpt);
-}
-
 // Create complex tensor from the old style of real tensor with size=(..., 2)
 // This is to support istft in the transition to requiring complex input.
 // NOTE: This may return a view of the input tensor, or might clone if necessary
@@ -1111,15 +1100,6 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const optional<int64_t> ho
   #undef REPR
 }
 
-Tensor istft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop_lengthOpt,
-             const optional<int64_t> win_lengthOpt, const Tensor& window,
-             const bool center, const bool normalized, const optional<bool> onesidedOpt,
-             const optional<int64_t> lengthOpt) {
-  return at::native::istft(
-      self, n_fft, hop_lengthOpt, win_lengthOpt, window, center, normalized,
-      onesidedOpt, lengthOpt, /*return_complex=*/false);
-}
-
 void _fft_fill_with_conjugate_symmetry_(const Tensor& input, IntArrayRef dim_) {
   const auto input_sizes = input.sizes();
   const auto input_strides = input.strides();
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index cf5707d5f86d..f8ad9b0e9214 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -4297,12 +4297,7 @@
 - func: dstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)
-# Overload without center & pad mode, needed for forward-compatibility
-- func: stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor
-  variants: function, method
-  cpp_no_default_args: ['hop_length', 'win_length', 'window', 'normalized']
-
-- func: stft.center(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, str pad_mode="reflect", bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor
+- func: stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, str pad_mode="reflect", bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor
   variants: function, method
 
 - func: istft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, bool normalized=False, bool? onesided=None, int? length=None, bool return_complex=False) -> Tensor
diff --git a/caffe2/serialize/versions.h b/caffe2/serialize/versions.h
index 78a91c64fe84..88b0c19b095b 100644
--- a/caffe2/serialize/versions.h
+++ b/caffe2/serialize/versions.h
@@ -12,7 +12,7 @@ namespace serialize {
 constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L;
 
 #if ENABLE_UPGRADERS
-constexpr uint64_t kMaxSupportedFileFormatVersion = 0xAL;
+constexpr uint64_t kMaxSupportedFileFormatVersion = 11;
 #else
 constexpr uint64_t kMaxSupportedFileFormatVersion = 0x6L;
 #endif
@@ -83,7 +83,9 @@ constexpr uint64_t kMaxSupportedFileFormatVersion = 0x6L;
 //     Bump the version number to 10 to update aten::gelu and
 //     and aten::gelu.out to support the new approximate kwarg.
 //     (see: https://github.com/pytorch/pytorch/pull/61439)
-constexpr uint64_t kProducedFileFormatVersion = 0xAL;
+// 4) [02/25/2022]
+//     Bump version number to 11 to update aten::stft to do padding in ATen
+constexpr uint64_t kProducedFileFormatVersion = 11L;
 #else
 constexpr uint64_t kProducedFileFormatVersion = 0x3L;
 #endif
diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py
index b4eacefc72cf..31dc8b47c299 100644
--- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py
+++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py
@@ -110,6 +110,7 @@ ALLOW_LIST = [
     ("aten::grid_sampler_3d_backward", datetime.date(9999, 1, 1)),
     ("aten::_transform_bias_rescale_qkv", datetime.date(9999, 1, 1)),
     ("aten::scatter_reduce.two", datetime.date(2022, 4, 15)),
+    ("aten::stft", datetime.date(2022, 5, 1)),
     ("aten::_s_where", datetime.date(2022, 9, 30)),
     ("quantized::conv2d_cudnn", datetime.date(2022, 3, 22)),
     ("quantized::conv2d_relu_cudnn", datetime.date(2022, 3, 22)),
diff --git a/test/jit/fixtures/test_versioned_stft_v10.ptl b/test/jit/fixtures/test_versioned_stft_v10.ptl
new file mode 100644
index 000000000000..7dcb8cc8f715
Binary files /dev/null and b/test/jit/fixtures/test_versioned_stft_v10.ptl differ
diff --git a/test/jit/fixtures_srcs/fixtures_src.py b/test/jit/fixtures_srcs/fixtures_src.py
index dff23702311a..ba1322fff2e8 100644
--- a/test/jit/fixtures_srcs/fixtures_src.py
+++ b/test/jit/fixtures_srcs/fixtures_src.py
@@ -57,3 +57,11 @@ class TestVersionedGeluOutV9(torch.nn.Module):
     def forward(self, x):
         out = torch.zeros_like(x)
         return torch._C._nn.gelu(x, out=out)
+
+class TestVersionedStftV10(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, n_fft: int, window):
+        # calling aten::stft directly instead of torch.functional.stft
+        return torch.ops.aten.stft(x, n_fft=n_fft, window=window, return_complex=True)
diff --git a/test/jit/fixtures_srcs/generate_models.py b/test/jit/fixtures_srcs/generate_models.py
index 980e7dd0324e..e7e00262533d 100644
--- a/test/jit/fixtures_srcs/generate_models.py
+++ b/test/jit/fixtures_srcs/generate_models.py
@@ -96,6 +96,7 @@ ALL_MODULES = {
     TestVersionedLogspaceOutV8(): "aten::logspace.out",
     TestVersionedGeluV9(): "aten::gelu",
     TestVersionedGeluOutV9(): "aten::gelu.out",
+    TestVersionedStftV10(): "aten::stft",
 }
 """
diff --git a/test/jit/test_save_load_for_op_version.py b/test/jit/test_save_load_for_op_version.py
index b5e38b37d3eb..ff793404e3b9 100644
--- a/test/jit/test_save_load_for_op_version.py
+++ b/test/jit/test_save_load_for_op_version.py
@@ -540,3 +540,20 @@ class TestSaveLoadForOpVersion(JitTestCase):
         self.assertTrue(output.size(dim=0) == 100)
         # "Upgraded" model should match the new version output
         self.assertEqual(output, output_current)
+
+    def test_versioned_stft_v10(self):
+        model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_stft_v10.ptl"
+        loaded_model = torch.jit.load(model_path)
+        buffer = io.BytesIO(loaded_model._save_to_buffer_for_lite_interpreter())
+        buffer.seek(0)
+        v10_mobile_module = _load_for_lite_interpreter(buffer)
+
+        for in_dtype, window_dtype in product(
+                [torch.float32, torch.complex64], repeat=2):
+            input = torch.rand((100,), dtype=in_dtype)
+            window = torch.rand((10,), dtype=window_dtype)
+            n_fft = 10
+            output = v10_mobile_module(input, n_fft, window)
+            output_expected = torch.stft(input, n_fft=n_fft, window=window,
+                                         center=False, return_complex=True)
+            self.assertEqual(output, output_expected)
diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py
index f85d99b43ed2..3a178528b5af 100644
--- a/tools/pyi/gen_pyi.py
+++ b/tools/pyi/gen_pyi.py
@@ -109,7 +109,6 @@ blocklist = [
     "block_diag",
     "norm",
     "chain_matmul",
-    "stft",
     "tensordot",
     "split",
     "unique_consecutive",
diff --git a/torch/_tensor.py b/torch/_tensor.py
index b26f979bd3f7..92b8f5e1d4f5 100644
--- a/torch/_tensor.py
+++ b/torch/_tensor.py
@@ -2,7 +2,7 @@ from collections import OrderedDict
 import enum
 import functools
 from numbers import Number
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Tuple, Union
 import warnings
 import copyreg
 from copy import deepcopy
@@ -542,40 +542,6 @@ class Tensor(torch._C._TensorBase):
         else:
             return LU, pivots
 
-    def stft(self, n_fft: int, hop_length: Optional[int] = None,
-             win_length: Optional[int] = None, window: 'Optional[Tensor]' = None,
-             center: bool = True, pad_mode: str = 'reflect', normalized: bool = False,
-             onesided: Optional[bool] = None, return_complex: Optional[bool] = None):
-        r"""See :func:`torch.stft`
-
-        .. warning::
-          This function changed signature at version 0.4.1. Calling with
-          the previous signature may cause error or return incorrect result.
- """ - if has_torch_function_unary(self): - return handle_torch_function( - Tensor.stft, (self,), self, n_fft, hop_length=hop_length, - win_length=win_length, window=window, center=center, pad_mode=pad_mode, normalized=normalized, - onesided=onesided, return_complex=return_complex - ) - return torch.stft(self, n_fft, hop_length, win_length, window, center, - pad_mode, normalized, onesided, return_complex=return_complex) - - def istft(self, n_fft: int, hop_length: Optional[int] = None, - win_length: Optional[int] = None, window: 'Optional[Tensor]' = None, - center: bool = True, normalized: bool = False, - onesided: Optional[bool] = None, length: Optional[int] = None, - return_complex: bool = False): - r"""See :func:`torch.istft`""" - if has_torch_function_unary(self): - return handle_torch_function( - Tensor.istft, (self,), self, n_fft, hop_length=hop_length, win_length=win_length, - window=window, center=center, normalized=normalized, onesided=onesided, length=length, - return_complex=return_complex - ) - return torch.istft(self, n_fft, hop_length, win_length, window, center, - normalized, onesided, length, return_complex=return_complex) - def resize(self, *sizes): if has_torch_function_unary(self): return handle_torch_function(Tensor.resize, (self,), self, *sizes) diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index e4b35eb72cd5..a912d1a06019 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -4752,16 +4752,21 @@ See :func:`torch.dsplit` """) add_docstr_all('stft', + "stft(n_fft, hop_length=None, win_length=None, window=None, center=True, " + "pad_mode='reflect', normalized=False, onesided=None, return_complex=None) -> Tensor" r""" -stft(frame_length, hop, fft_size=None, return_onesided=True, window=None, pad_end=0) -> Tensor See :func:`torch.stft` + +.. warning:: + This function changed signature at version 0.4.1. Calling with + the previous signature may cause error or return incorrect result. 
""") add_docstr_all('istft', + "istft(input, n_fft, hop_length=None, win_length=None, window=None, center=True, " + "normalized=False, onesided=None, length=None, return_complex=False) -> Tensor" r""" -istft(n_fft, hop_length=None, win_length=None, window=None, - center=True, normalized=False, onesided=True, length=None) -> Tensor See :func:`torch.istft` """) diff --git a/torch/csrc/jit/mobile/upgrader_mobile.cpp b/torch/csrc/jit/mobile/upgrader_mobile.cpp index 0e52829255d0..eed4a676a9c4 100644 --- a/torch/csrc/jit/mobile/upgrader_mobile.cpp +++ b/torch/csrc/jit/mobile/upgrader_mobile.cpp @@ -67,6 +67,10 @@ getOperatorVersionMapForMobile() { std::vector({ Upgrader({0, 8, "logspace_out_0_8", 10}) })}, + {std::string("aten::stft"), + std::vector({ + Upgrader({0, 10, "stft_0_10", 11}) + })}, }); return operatorVersionMapForMobile; } @@ -527,6 +531,35 @@ const std::vector& getUpgraderBytecodeList() { OperatorString({"prim::unchecked_cast", "", 1}), }), // operators list }), + ByteCodeFunctionWithOperator({ + mobile::Function::registerFunc( + "stft_0_10", + std::vector({ + Instruction{OpCode::STOREN, 1, 8}, + Instruction{OpCode::MOVE, 1, 0}, + Instruction{OpCode::MOVE, 2, 0}, + Instruction{OpCode::MOVE, 3, 0}, + Instruction{OpCode::MOVE, 4, 0}, + Instruction{OpCode::MOVE, 5, 0}, + Instruction{OpCode::LOADC, 1, 0}, + Instruction{OpCode::LOADC, 0, 0}, + Instruction{OpCode::MOVE, 6, 0}, + Instruction{OpCode::MOVE, 7, 0}, + Instruction{OpCode::MOVE, 8, 0}, + Instruction{OpCode::OP, 0, 0}, + Instruction{OpCode::RET, 0, 0}, + }), // instructions list, + std::vector({ + c10::IValue("reflect"), + c10::IValue(false), + }), // constants list, + std::vector(), // types list, + 8 + ), + std::vector({ + OperatorString({"aten::stft", "", 10}), + }), // operators list + }), }); for (const auto& upgrader_function : upgrader_function_list) { for (const auto& op : upgrader_function.operators) { diff --git a/torch/csrc/jit/operator_upgraders/upgraders_entry.cpp b/torch/csrc/jit/operator_upgraders/upgraders_entry.cpp index 7b09cc409a44..e50227d18ae9 100644 --- a/torch/csrc/jit/operator_upgraders/upgraders_entry.cpp +++ b/torch/csrc/jit/operator_upgraders/upgraders_entry.cpp @@ -15,6 +15,17 @@ namespace torch { namespace jit { static std::unordered_map kUpgradersEntryMap({ + {"stft_0_10", R"SCRIPT( +def stft_0_10( + self: Tensor, n_fft: int, hop_length: Optional[int] = None, + win_length: Optional[int] = None, window: Optional[Tensor] = None, + normalized: bool = False, onesided: Optional[bool] = None, + return_complex: Optional[bool] = None) -> Tensor: + return torch.stft( + self, n_fft=n_fft, hop_length=hop_length, win_length=win_length, + window=window, center=False, normalized=normalized, onesided=onesided, + return_complex=return_complex) +)SCRIPT"}, {"logspace_0_8", R"SCRIPT( def logspace_0_8(start: Union[int, float, complex], end: Union[int, float, complex], steps: Optional[int], base: float, *, dtype: Optional[int], layout: Optional[int], device: Optional[Device], pin_memory: Optional[bool]): diff --git a/torch/csrc/jit/operator_upgraders/version_map.cpp b/torch/csrc/jit/operator_upgraders/version_map.cpp index 1e19f4cc39db..d96527b66fcf 100644 --- a/torch/csrc/jit/operator_upgraders/version_map.cpp +++ b/torch/csrc/jit/operator_upgraders/version_map.cpp @@ -16,7 +16,11 @@ static bool isVersionMapSorted = false; // Note for developers: The list of upgraders should be SORTED // by the version number where the upgrader is registered. 
 static std::unordered_map<std::string, std::vector<UpgraderEntry>> operatorVersionMap(
-    {{"aten::logspace",
+    {{"aten::stft",
+      {{11,
+        "stft_0_10",
+        "aten::stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor"}}},
+     {"aten::logspace",
       {{9,
         "logspace_0_8",
         "aten::logspace(Scalar start, Scalar end, int? steps=None, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"}}},
diff --git a/torch/functional.py b/torch/functional.py
index efb98c2b9253..bd6f3f0a195a 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -4,7 +4,7 @@ from typing import (
 
 import torch
 from torch._C import _add_docstr
-import torch.nn.functional as F
+import torch.nn.functional
 from ._lowrank import svd_lowrank, pca_lowrank
 from .overrides import (
     has_torch_function, has_torch_function_unary, has_torch_function_variadic,
@@ -476,133 +476,121 @@ def _meshgrid(*tensors, indexing: Optional[str]):
     return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
 
-def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,
-         win_length: Optional[int] = None, window: Optional[Tensor] = None,
-         center: bool = True, pad_mode: str = 'reflect', normalized: bool = False,
-         onesided: Optional[bool] = None,
-         return_complex: Optional[bool] = None) -> Tensor:
-    r"""Short-time Fourier transform (STFT).
+stft = _add_docstr(
+    torch.stft,
+    "stft(input, n_fft, hop_length=None, win_length=None, window=None, center=True, "
+    "pad_mode='reflect', normalized=False, onesided=None, return_complex=None) -> Tensor"
+    r"""
 
-    .. warning::
-      From version 1.8.0, :attr:`return_complex` must always be given
-      explicitly for real inputs and `return_complex=False` has been
-      deprecated. Strongly prefer `return_complex=True` as in a future
-      pytorch release, this function will only return complex tensors.
+Short-time Fourier transform (STFT).
 
-      Note that :func:`torch.view_as_real` can be used to recover a real
-      tensor with an extra last dimension for real and imaginary components.
+.. warning::
+  From version 1.8.0, :attr:`return_complex` must always be given
+  explicitly for real inputs and `return_complex=False` has been
+  deprecated. Strongly prefer `return_complex=True` as in a future
+  pytorch release, this function will only return complex tensors.
 
-    The STFT computes the Fourier transform of short overlapping windows of the
-    input. This giving frequency components of the signal as they change over
-    time. The interface of this function is modeled after (but *not* a drop-in
-    replacement for) librosa_ stft function.
+  Note that :func:`torch.view_as_real` can be used to recover a real
+  tensor with an extra last dimension for real and imaginary components.
 
-    .. _librosa: https://librosa.org/doc/latest/generated/librosa.stft.html
+The STFT computes the Fourier transform of short overlapping windows of the
+input. This gives the frequency components of the signal as they change over
+time. The interface of this function is modeled after (but *not* a drop-in
+replacement for) the librosa_ stft function.
 
-    Ignoring the optional batch dimension, this method computes the following
-    expression:
+.. _librosa: https://librosa.org/doc/latest/generated/librosa.stft.html
 
-    .. math::
-        X[\omega, m] = \sum_{k = 0}^{\text{win\_length-1}}%
-            \text{window}[k]\ \text{input}[m \times \text{hop\_length} + k]\ %
-            \exp\left(- j \frac{2 \pi \cdot \omega k}{\text{win\_length}}\right),
+Ignoring the optional batch dimension, this method computes the following
+expression:
 
-    where :math:`m` is the index of the sliding window, and :math:`\omega` is
-    the frequency :math:`0 \leq \omega < \text{n\_fft}` for ``onesided=False``,
-    or :math:`0 \leq \omega < \lfloor \text{n\_fft} / 2 \rfloor + 1` for ``onesided=True``.
+.. math::
+    X[\omega, m] = \sum_{k = 0}^{\text{win\_length-1}}%
+        \text{window}[k]\ \text{input}[m \times \text{hop\_length} + k]\ %
+        \exp\left(- j \frac{2 \pi \cdot \omega k}{\text{win\_length}}\right),
 
-    * :attr:`input` must be either a 1-D time sequence or a 2-D batch of time
-      sequences.
+where :math:`m` is the index of the sliding window, and :math:`\omega` is
+the frequency :math:`0 \leq \omega < \text{n\_fft}` for ``onesided=False``,
+or :math:`0 \leq \omega < \lfloor \text{n\_fft} / 2 \rfloor + 1` for ``onesided=True``.
 
-    * If :attr:`hop_length` is ``None`` (default), it is treated as equal to
-      ``floor(n_fft / 4)``.
+* :attr:`input` must be either a 1-D time sequence or a 2-D batch of time
+  sequences.
 
-    * If :attr:`win_length` is ``None`` (default), it is treated as equal to
-      :attr:`n_fft`.
+* If :attr:`hop_length` is ``None`` (default), it is treated as equal to
+  ``floor(n_fft / 4)``.
 
-    * :attr:`window` can be a 1-D tensor of size :attr:`win_length`, e.g., from
-      :meth:`torch.hann_window`. If :attr:`window` is ``None`` (default), it is
-      treated as if having :math:`1` everywhere in the window. If
-      :math:`\text{win\_length} < \text{n\_fft}`, :attr:`window` will be padded on
-      both sides to length :attr:`n_fft` before being applied.
+* If :attr:`win_length` is ``None`` (default), it is treated as equal to
+  :attr:`n_fft`.
 
-    * If :attr:`center` is ``True`` (default), :attr:`input` will be padded on
-      both sides so that the :math:`t`-th frame is centered at time
-      :math:`t \times \text{hop\_length}`. Otherwise, the :math:`t`-th frame
-      begins at time :math:`t \times \text{hop\_length}`.
+* :attr:`window` can be a 1-D tensor of size :attr:`win_length`, e.g., from
+  :meth:`torch.hann_window`. If :attr:`window` is ``None`` (default), it is
+  treated as if having :math:`1` everywhere in the window. If
+  :math:`\text{win\_length} < \text{n\_fft}`, :attr:`window` will be padded on
+  both sides to length :attr:`n_fft` before being applied.
 
-    * :attr:`pad_mode` determines the padding method used on :attr:`input` when
-      :attr:`center` is ``True``. See :meth:`torch.nn.functional.pad` for
-      all available options. Default is ``"reflect"``.
+* If :attr:`center` is ``True`` (default), :attr:`input` will be padded on
+  both sides so that the :math:`t`-th frame is centered at time
+  :math:`t \times \text{hop\_length}`. Otherwise, the :math:`t`-th frame
+  begins at time :math:`t \times \text{hop\_length}`.
 
-    * If :attr:`onesided` is ``True`` (default for real input), only values for
-      :math:`\omega` in :math:`\left[0, 1, 2, \dots, \left\lfloor
-      \frac{\text{n\_fft}}{2} \right\rfloor + 1\right]` are returned because
-      the real-to-complex Fourier transform satisfies the conjugate symmetry,
-      i.e., :math:`X[m, \omega] = X[m, \text{n\_fft} - \omega]^*`.
-      Note if the input or window tensors are complex, then :attr:`onesided`
-      output is not possible.
+* :attr:`pad_mode` determines the padding method used on :attr:`input` when
+  :attr:`center` is ``True``. See :meth:`torch.nn.functional.pad` for
+  all available options. Default is ``"reflect"``.
 
-    * If :attr:`normalized` is ``True`` (default is ``False``), the function
-      returns the normalized STFT results, i.e., multiplied by :math:`(\text{frame\_length})^{-0.5}`.
+* If :attr:`onesided` is ``True`` (default for real input), only values for
+  :math:`\omega` in :math:`\left[0, 1, 2, \dots, \left\lfloor
+  \frac{\text{n\_fft}}{2} \right\rfloor + 1\right]` are returned because
+  the real-to-complex Fourier transform satisfies the conjugate symmetry,
+  i.e., :math:`X[m, \omega] = X[m, \text{n\_fft} - \omega]^*`.
+  Note if the input or window tensors are complex, then :attr:`onesided`
+  output is not possible.
 
-    * If :attr:`return_complex` is ``True`` (default if input is complex), the
-      return is a ``input.dim() + 1`` dimensional complex tensor. If ``False``,
-      the output is a ``input.dim() + 2`` dimensional real tensor where the last
-      dimension represents the real and imaginary components.
+* If :attr:`normalized` is ``True`` (default is ``False``), the function
+  returns the normalized STFT results, i.e., multiplied by :math:`(\text{frame\_length})^{-0.5}`.
 
-    Returns either a complex tensor of size :math:`(* \times N \times T)` if
-    :attr:`return_complex` is true, or a real tensor of size :math:`(* \times N
-    \times T \times 2)`. Where :math:`*` is the optional batch size of
-    :attr:`input`, :math:`N` is the number of frequencies where STFT is applied
-    and :math:`T` is the total number of frames used.
+* If :attr:`return_complex` is ``True`` (default if input is complex), the
+  return is a ``input.dim() + 1`` dimensional complex tensor. If ``False``,
+  the output is a ``input.dim() + 2`` dimensional real tensor where the last
+  dimension represents the real and imaginary components.
 
-    .. warning::
-      This function changed signature at version 0.4.1. Calling with the
-      previous signature may cause error or return incorrect result.
+Returns either a complex tensor of size :math:`(* \times N \times T)` if
+:attr:`return_complex` is true, or a real tensor of size :math:`(* \times N
+\times T \times 2)`, where :math:`*` is the optional batch size of
+:attr:`input`, :math:`N` is the number of frequencies where STFT is applied
+and :math:`T` is the total number of frames used.
 
-    Args:
-        input (Tensor): the input tensor
-        n_fft (int): size of Fourier transform
-        hop_length (int, optional): the distance between neighboring sliding window
-            frames. Default: ``None`` (treated as equal to ``floor(n_fft / 4)``)
-        win_length (int, optional): the size of window frame and STFT filter.
-            Default: ``None`` (treated as equal to :attr:`n_fft`)
-        window (Tensor, optional): the optional window function.
-            Default: ``None`` (treated as window of all :math:`1` s)
-        center (bool, optional): whether to pad :attr:`input` on both sides so
-            that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
-            Default: ``True``
-        pad_mode (string, optional): controls the padding method used when
-            :attr:`center` is ``True``. Default: ``"reflect"``
-        normalized (bool, optional): controls whether to return the normalized STFT results
-            Default: ``False``
-        onesided (bool, optional): controls whether to return half of results to
-            avoid redundancy for real inputs.
-            Default: ``True`` for real :attr:`input` and :attr:`window`, ``False`` otherwise.
-        return_complex (bool, optional): whether to return a complex tensor, or
-            a real tensor with an extra last dimension for the real and
-            imaginary components.
+.. warning::
+  This function changed signature at version 0.4.1. Calling with the
+  previous signature may cause an error or return an incorrect result.
 
-    Returns:
-        Tensor: A tensor containing the STFT result with shape described above
+Args:
+    input (Tensor): the input tensor
+    n_fft (int): size of Fourier transform
+    hop_length (int, optional): the distance between neighboring sliding window
+        frames. Default: ``None`` (treated as equal to ``floor(n_fft / 4)``)
+    win_length (int, optional): the size of window frame and STFT filter.
+        Default: ``None`` (treated as equal to :attr:`n_fft`)
+    window (Tensor, optional): the optional window function.
+        Default: ``None`` (treated as window of all :math:`1` s)
+    center (bool, optional): whether to pad :attr:`input` on both sides so
+        that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
+        Default: ``True``
+    pad_mode (string, optional): controls the padding method used when
+        :attr:`center` is ``True``. Default: ``"reflect"``
+    normalized (bool, optional): controls whether to return the normalized STFT results
+        Default: ``False``
+    onesided (bool, optional): controls whether to return half of results to
+        avoid redundancy for real inputs.
+        Default: ``True`` for real :attr:`input` and :attr:`window`, ``False`` otherwise.
+    return_complex (bool, optional): whether to return a complex tensor, or
+        a real tensor with an extra last dimension for the real and
+        imaginary components.
 
-    """
-    if has_torch_function_unary(input):
-        return handle_torch_function(
-            stft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length,
-            window=window, center=center, pad_mode=pad_mode, normalized=normalized,
-            onesided=onesided, return_complex=return_complex)
-    # NOTE: Do not edit. This code will be removed once the forward-compatibility
-    # period is over for PR #73432
-    if center:
-        signal_dim = input.dim()
-        extended_shape = [1] * (3 - signal_dim) + list(input.size())
-        pad = int(n_fft // 2)
-        input = F.pad(input.view(extended_shape), [pad, pad], pad_mode)
-        input = input.view(input.shape[-signal_dim:])
-    return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore[attr-defined]
-                    normalized, onesided, return_complex)
+Returns:
+    Tensor: A tensor containing the STFT result with shape described above
+
+""")
+# TODO: Fix via https://github.com/pytorch/pytorch/issues/75798
+stft.__module__ = "torch.functional"
 
 istft = _add_docstr(
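
Editor's note, not part of the patch: a minimal sketch of the compatibility contract this change encodes. The example values (a length-100 signal, n_fft=10) mirror the fixture test above and are otherwise arbitrary. The upgrader `stft_0_10` pins `center=False` because models serialized against the old 8-argument `aten::stft` schema never padded their input, while the new unified schema pads inside ATen with `center=True` by default.

    import torch

    x = torch.rand(100)      # hypothetical 1-D signal
    window = torch.rand(10)  # hypothetical window, length == n_fft
    n_fft = 10

    # New default after this patch: padding happens inside ATen (center=True).
    centered = torch.stft(x, n_fft=n_fft, window=window, return_complex=True)

    # What the upgrader emits when replaying an old serialized call: the same
    # op with center pinned to False, i.e. the old no-padding behavior.
    uncentered = torch.stft(x, n_fft=n_fft, window=window, center=False,
                            return_complex=True)

    # Centering pads n_fft // 2 samples on each side, so it yields more frames.
    assert centered.shape[-1] > uncentered.shape[-1]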