stft: Require return_complex to be passed explicitly for real input (#86724)

This behavior has been deprecated since PyTorch 1.8 but this step of
the deprecation cycle was put on hold in #50102 waiting for JIT
upgraders functionality which doesn't seem to have panned out. I'd say
there has been more than enough of a deprecation period, so we should
just continue.

BC-breaking message:

`torch.stft` takes an optional `return_complex` parameter that
indicates whether the output should be a floating point tensor or a
complex tensor. `return_complex` previously defaulted to `False` for
real input tensors. This PR removes the default and makes
`return_complex` a required argument for real inputs. However, complex
inputs will continue to default to `return_complex=True`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86724
Approved by: https://github.com/mruberry, https://github.com/albanD
This commit is contained in:
Peter Bell
2022-10-16 20:23:08 +01:00
committed by PyTorch MergeBot
parent 2b7236a0e1
commit 3007efda08
4 changed files with 30 additions and 19 deletions

View File

@ -1181,9 +1181,8 @@ class TestFFT(TestCase):
@skipCPUIfNoFFT
def test_stft_requires_complex(self, device):
x = torch.rand(100)
y = x.stft(10, pad_mode='constant')
# with self.assertRaisesRegex(RuntimeError, 'stft requires the return_complex parameter'):
# y = x.stft(10, pad_mode='constant')
with self.assertRaisesRegex(RuntimeError, 'stft requires the return_complex parameter'):
y = x.stft(10, pad_mode='constant')
@skipCPUIfNoFFT
def test_fft_input_modification(self, device):