stft: Require return_complex to be passed explicitly for real input (#86724)

This behavior has been deprecated since PyTorch 1.8 but this step of
the deprecation cycle was put on hold in #50102 waiting for JIT
upgraders functionality which doesn't seem to have panned out. I'd say
there has been more than enough of a deprecation period, so we should
just continue.

BC-breaking message:

`torch.stft` takes an optional `return_complex` parameter that
indicates whether the output should be a floating point tensor or a
complex tensor. `return_complex` previously defaulted to `False` for
real input tensors. This PR removes the default and makes
`return_complex` a required argument for real inputs. However, complex
inputs will continue to default to `return_complex=True`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86724
Approved by: https://github.com/mruberry, https://github.com/albanD
This commit is contained in:
Peter Bell
2022-10-16 20:23:08 +01:00
committed by PyTorch MergeBot
parent 2b7236a0e1
commit 3007efda08
4 changed files with 30 additions and 19 deletions

View File

@ -797,20 +797,17 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop
const bool return_complex = return_complexOpt.value_or(
self.is_complex() || (window.defined() && window.is_complex()));
if (!return_complex) {
if (!return_complexOpt.has_value()) {
TORCH_WARN_ONCE(
"stft will soon require the return_complex parameter be given for real inputs, "
"and will further require that return_complex=True in a future PyTorch release."
);
}
TORCH_CHECK(return_complexOpt.has_value(),
"stft requires the return_complex parameter be given for real inputs, "
"and will further require that return_complex=True in a future PyTorch release.");
// TORCH_WARN_ONCE(
// "stft with return_complex=False is deprecated. In a future pytorch "
// "release, stft will return complex tensors for all inputs, and "
// "return_complex=False will raise an error.\n"
// "Note: you can still call torch.view_as_real on the complex output to "
// "recover the old return format.");
TORCH_WARN_ONCE(
"stft with return_complex=False is deprecated. In a future pytorch "
"release, stft will return complex tensors for all inputs, and "
"return_complex=False will raise an error.\n"
"Note: you can still call torch.view_as_real on the complex output to "
"recover the old return format.");
}
if (!at::isFloatingType(self.scalar_type()) && !at::isComplexType(self.scalar_type())) {

View File

@ -1181,9 +1181,8 @@ class TestFFT(TestCase):
@skipCPUIfNoFFT
def test_stft_requires_complex(self, device):
x = torch.rand(100)
y = x.stft(10, pad_mode='constant')
# with self.assertRaisesRegex(RuntimeError, 'stft requires the return_complex parameter'):
# y = x.stft(10, pad_mode='constant')
with self.assertRaisesRegex(RuntimeError, 'stft requires the return_complex parameter'):
y = x.stft(10, pad_mode='constant')
@skipCPUIfNoFFT
def test_fft_input_modification(self, device):

View File

@ -612,6 +612,15 @@ def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,
a real tensor with an extra last dimension for the real and
imaginary components.
.. versionchanged:: 1.14.0
``return_complex`` is now a required argument for real inputs,
as the default is being transitioned to ``True``.
.. deprecated:: 1.14.0
``return_complex=False`` is deprecated, instead use ``return_complex=True``
Note that calling :func:`torch.view_as_real` on the output will
recover the deprecated output format.
Returns:
Tensor: A tensor containing the STFT result with shape described above

View File

@ -4508,11 +4508,16 @@ def sample_inputs_stft(op_info, device, dtype, requires_grad, **kwargs):
def mt(shape, **kwargs):
return make_tensor(shape, device=device, dtype=dtype,
requires_grad=requires_grad, **kwargs)
yield SampleInput(mt(100), kwargs=dict(n_fft=10))
yield SampleInput(mt(100), n_fft=10, return_complex=True)
yield SampleInput(mt(100), n_fft=10, return_complex=False)
if dtype.is_complex:
yield SampleInput(mt(100), n_fft=10)
for center in [False, True]:
yield SampleInput(mt(10), kwargs=dict(n_fft=7, center=center))
yield SampleInput(mt((10, 100)), kwargs=dict(n_fft=16, hop_length=4, center=center))
yield SampleInput(mt(10), n_fft=7, center=center, return_complex=True)
yield SampleInput(mt((10, 100)), n_fft=16, hop_length=4,
center=center, return_complex=True)
window = mt(16, low=.5, high=2.0)
yield SampleInput(
@ -4521,7 +4526,8 @@ def sample_inputs_stft(op_info, device, dtype, requires_grad, **kwargs):
mt((3, 100)), kwargs=dict(n_fft=16, window=window, return_complex=True, center=center))
if not dtype.is_complex:
yield SampleInput(
mt((10, 100)), kwargs=dict(n_fft=16, window=window, onesided=False))
mt((10, 100)), n_fft=16, window=window, onesided=False,
return_complex=True)
def sample_inputs_istft(op_info, device, dtype, requires_grad, **kwargs):