mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
stft: Require return_complex to be passed explicitly for real input (#86724)
This behavior has been deprecated since PyTorch 1.8 but this step of the deprecation cycle was put on hold in #50102 waiting for JIT upgraders functionality which doesn't seem to have panned out. I'd say there has been more than enough of a deprecation period, so we should just continue. BC-breaking message: `torch.stft` takes an optional `return_complex` parameter that indicates whether the output should be a floating point tensor or a complex tensor. `return_complex` previously defaulted to `False` for real input tensors. This PR removes the default and makes `return_complex` a required argument for real inputs. However, complex inputs will continue to default to `return_complex=True`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/86724 Approved by: https://github.com/mruberry, https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
2b7236a0e1
commit
3007efda08
@ -797,20 +797,17 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop
|
||||
const bool return_complex = return_complexOpt.value_or(
|
||||
self.is_complex() || (window.defined() && window.is_complex()));
|
||||
if (!return_complex) {
|
||||
if (!return_complexOpt.has_value()) {
|
||||
TORCH_WARN_ONCE(
|
||||
"stft will soon require the return_complex parameter be given for real inputs, "
|
||||
"and will further require that return_complex=True in a future PyTorch release."
|
||||
);
|
||||
}
|
||||
TORCH_CHECK(return_complexOpt.has_value(),
|
||||
"stft requires the return_complex parameter be given for real inputs, "
|
||||
"and will further require that return_complex=True in a future PyTorch release.");
|
||||
|
||||
|
||||
// TORCH_WARN_ONCE(
|
||||
// "stft with return_complex=False is deprecated. In a future pytorch "
|
||||
// "release, stft will return complex tensors for all inputs, and "
|
||||
// "return_complex=False will raise an error.\n"
|
||||
// "Note: you can still call torch.view_as_real on the complex output to "
|
||||
// "recover the old return format.");
|
||||
TORCH_WARN_ONCE(
|
||||
"stft with return_complex=False is deprecated. In a future pytorch "
|
||||
"release, stft will return complex tensors for all inputs, and "
|
||||
"return_complex=False will raise an error.\n"
|
||||
"Note: you can still call torch.view_as_real on the complex output to "
|
||||
"recover the old return format.");
|
||||
}
|
||||
|
||||
if (!at::isFloatingType(self.scalar_type()) && !at::isComplexType(self.scalar_type())) {
|
||||
|
@ -1181,9 +1181,8 @@ class TestFFT(TestCase):
|
||||
@skipCPUIfNoFFT
|
||||
def test_stft_requires_complex(self, device):
|
||||
x = torch.rand(100)
|
||||
y = x.stft(10, pad_mode='constant')
|
||||
# with self.assertRaisesRegex(RuntimeError, 'stft requires the return_complex parameter'):
|
||||
# y = x.stft(10, pad_mode='constant')
|
||||
with self.assertRaisesRegex(RuntimeError, 'stft requires the return_complex parameter'):
|
||||
y = x.stft(10, pad_mode='constant')
|
||||
|
||||
@skipCPUIfNoFFT
|
||||
def test_fft_input_modification(self, device):
|
||||
|
@ -612,6 +612,15 @@ def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,
|
||||
a real tensor with an extra last dimension for the real and
|
||||
imaginary components.
|
||||
|
||||
.. versionchanged:: 1.14.0
|
||||
``return_complex`` is now a required argument for real inputs,
|
||||
as the default is being transitioned to ``True``.
|
||||
|
||||
.. deprecated:: 1.14.0
|
||||
``return_complex=False`` is deprecated, instead use ``return_complex=True``
|
||||
Note that calling :func:`torch.view_as_real` on the output will
|
||||
recover the deprecated output format.
|
||||
|
||||
Returns:
|
||||
Tensor: A tensor containing the STFT result with shape described above
|
||||
|
||||
|
@ -4508,11 +4508,16 @@ def sample_inputs_stft(op_info, device, dtype, requires_grad, **kwargs):
|
||||
def mt(shape, **kwargs):
|
||||
return make_tensor(shape, device=device, dtype=dtype,
|
||||
requires_grad=requires_grad, **kwargs)
|
||||
yield SampleInput(mt(100), kwargs=dict(n_fft=10))
|
||||
|
||||
yield SampleInput(mt(100), n_fft=10, return_complex=True)
|
||||
yield SampleInput(mt(100), n_fft=10, return_complex=False)
|
||||
if dtype.is_complex:
|
||||
yield SampleInput(mt(100), n_fft=10)
|
||||
|
||||
for center in [False, True]:
|
||||
yield SampleInput(mt(10), kwargs=dict(n_fft=7, center=center))
|
||||
yield SampleInput(mt((10, 100)), kwargs=dict(n_fft=16, hop_length=4, center=center))
|
||||
yield SampleInput(mt(10), n_fft=7, center=center, return_complex=True)
|
||||
yield SampleInput(mt((10, 100)), n_fft=16, hop_length=4,
|
||||
center=center, return_complex=True)
|
||||
|
||||
window = mt(16, low=.5, high=2.0)
|
||||
yield SampleInput(
|
||||
@ -4521,7 +4526,8 @@ def sample_inputs_stft(op_info, device, dtype, requires_grad, **kwargs):
|
||||
mt((3, 100)), kwargs=dict(n_fft=16, window=window, return_complex=True, center=center))
|
||||
if not dtype.is_complex:
|
||||
yield SampleInput(
|
||||
mt((10, 100)), kwargs=dict(n_fft=16, window=window, onesided=False))
|
||||
mt((10, 100)), n_fft=16, window=window, onesided=False,
|
||||
return_complex=True)
|
||||
|
||||
|
||||
def sample_inputs_istft(op_info, device, dtype, requires_grad, **kwargs):
|
||||
|
Reference in New Issue
Block a user