Re-add stft option to align window for center = false (#146379)

Skips advancing the forward-compatibility (FC) window on https://github.com/pytorch/pytorch/pull/145437, since I just found that there were non-trivial efforts to do so a while ago that were eventually reverted: https://github.com/pytorch/pytorch/pull/73434

Works around the issue by keeping the stft overload that lacks the center argument

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146379
Approved by: https://github.com/justinchuby, https://github.com/iseeyuan
This commit is contained in:
Jack Zhang
2025-02-06 14:07:13 +00:00
committed by PyTorch MergeBot
parent 1b79d47635
commit ed309b9156
9 changed files with 58 additions and 11 deletions

View File

@ -826,7 +826,7 @@ static Stream& write_opt(Stream& SS, const std::optional<T>& value) {
Tensor stft(const Tensor& self, const int64_t n_fft, const std::optional<int64_t> hop_lengthOpt,
const std::optional<int64_t> win_lengthOpt, const std::optional<Tensor>& window_opt,
const bool center, std::string_view mode, const bool normalized,
const std::optional<bool> onesidedOpt, const std::optional<bool> return_complexOpt) {
const std::optional<bool> onesidedOpt, const std::optional<bool> return_complexOpt, const std::optional<bool> align_to_windowOpt) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> window_maybe_owned = at::borrow_from_optional_tensor(window_opt);
const Tensor& window = *window_maybe_owned;
@ -853,11 +853,14 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const std::optional<int64_t
} \
SS << ", normalized=" << normalized << ", onesided="; \
write_opt(SS, onesidedOpt) << ", return_complex="; \
write_opt(SS, return_complexOpt) << ") "
write_opt(SS, return_complexOpt) << ", align_to_window="; \
write_opt(SS, align_to_windowOpt) << ") "
TORCH_CHECK(!window.defined() || window.device() == self.device(),
"stft input and window must be on the same device but got self on ",
self.device(), " and window on ", window.device())
TORCH_CHECK(!center || !align_to_windowOpt.has_value(),
"stft align_to_window should only be set when center = false.")
// default_init hop_length and win_length
auto hop_length = hop_lengthOpt.value_or(n_fft >> 2);
@ -869,7 +872,6 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const std::optional<int64_t
"stft requires the return_complex parameter be given for real inputs, "
"and will further require that return_complex=True in a future PyTorch release.");
TORCH_WARN_ONCE(
"stft with return_complex=False is deprecated. In a future pytorch "
"release, stft will return complex tensors for all inputs, and "
@ -943,7 +945,17 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const std::optional<int64_t
window_.narrow(0, left, win_length).fill_(1);
}
}
int64_t n_frames = 1 + (len - n_fft) / hop_length;
const bool align_to_window = align_to_windowOpt.value_or(false);
int64_t n_frames;
if (!center && align_to_window) {
// Calculate n_frames based on window length, since we are aligning start of window with t = 0.
n_frames = 1 + (len - win_length) / hop_length;
// Window-based padding.
input = at::pad(input, {(n_fft - win_length) / 2, (n_fft - win_length) / 2}, mode);
} else {
n_frames = 1 + (len - n_fft) / hop_length;
}
// time2col
input = input.as_strided(
{batch, n_frames, n_fft},
@ -982,11 +994,12 @@ Tensor stft(
const Tensor& self, const int64_t n_fft, const std::optional<int64_t> hop_lengthOpt,
const std::optional<int64_t> win_lengthOpt, const std::optional<Tensor>& window_opt,
const bool normalized,
const std::optional<bool> onesidedOpt, const std::optional<bool> return_complexOpt) {
const std::optional<bool> onesidedOpt, const std::optional<bool> return_complexOpt,
const std::optional<bool> align_to_windowOpt) {
return at::stft(
self, n_fft, hop_lengthOpt, win_lengthOpt, window_opt,
/*center=*/false, /*mode=*/"constant", normalized, onesidedOpt,
return_complexOpt);
return_complexOpt, align_to_windowOpt);
}
// Create complex tensor from the old style of real tensor with size=(..., 2)

View File

@ -5758,11 +5758,11 @@
- func: dstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)
# Overload without center & pad mode, needed for forward-compatibility
- func: stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor
- func: stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None, bool? align_to_window=None) -> Tensor
variants: function, method
cpp_no_default_args: ['hop_length', 'win_length', 'window', 'normalized']
- func: stft.center(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, str pad_mode="reflect", bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor
- func: stft.center(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, str pad_mode="reflect", bool normalized=False, bool? onesided=None, bool? return_complex=None, bool? align_to_window=None) -> Tensor
variants: function, method
- func: istft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, bool normalized=False, bool? onesided=None, int? length=None, bool return_complex=False) -> Tensor

View File

@ -1226,6 +1226,14 @@ class TestFFT(TestCase):
with self.assertRaisesRegex(RuntimeError, 'stft requires the return_complex parameter'):
y = x.stft(10, pad_mode='constant')
@onlyNativeDeviceTypes
@skipCPUIfNoFFT
def test_stft_align_to_window_only_requires_non_center(self, device):
x = torch.rand(100)
for align_to_window in [True, False]:
with self.assertRaisesRegex(RuntimeError, 'stft align_to_window should only be set when center = false'):
y = x.stft(10, center=True, return_complex=True, align_to_window=align_to_window)
# stft and istft are currently warning if a window is not provided
@onlyNativeDeviceTypes
@skipCPUIfNoFFT

View File

@ -3362,6 +3362,7 @@ def stft(
normalized: bool = False,
onesided: Optional[bool] = None,
return_complex: Optional[bool] = None,
align_to_window: Optional[bool] = None,
) -> Tensor:
torch._check(
window is None or window.device == input.device,
@ -3370,6 +3371,10 @@ def stft(
+ f" and window on {window.device}" # type: ignore[union-attr]
),
)
torch._check(
not center or align_to_window is None,
"stft only supports align_to_window for center = False.",
)
hop_length_ = hop_length if hop_length is not None else n_fft // 4
win_length_ = win_length if win_length is not None else n_fft
@ -3433,6 +3438,9 @@ def stft(
window = aten.constant_pad_nd(window, [left, n_fft - win_length_ - left])
input = input.unfold(dimension=-1, size=n_fft, step=hop_length_)
if not center and align_to_window:
input_pad_amount = (n_fft - win_length_) // 2
input = aten.pad(input, [input_pad_amount, input_pad_amount], pad_mode)
if window is not None:
input = input * window

View File

@ -940,6 +940,7 @@ class Tensor(torch._C.TensorBase):
normalized: bool = False,
onesided: Optional[bool] = None,
return_complex: Optional[bool] = None,
align_to_window: Optional[bool] = None,
):
r"""See :func:`torch.stft`
@ -961,6 +962,7 @@ class Tensor(torch._C.TensorBase):
normalized=normalized,
onesided=onesided,
return_complex=return_complex,
align_to_window=align_to_window,
)
return torch.stft(
self,
@ -973,6 +975,7 @@ class Tensor(torch._C.TensorBase):
normalized,
onesided,
return_complex=return_complex,
align_to_window=align_to_window,
)
def istft(

View File

@ -6398,7 +6398,8 @@ See :func:`torch.dsplit`
add_docstr_all(
"stft",
r"""
stft(frame_length, hop, fft_size=None, return_onesided=True, window=None, pad_end=0) -> Tensor
stft(frame_length, hop, fft_size=None, return_onesided=True, window=None,
pad_end=0, align_to_window=None) -> Tensor
See :func:`torch.stft`
""",

View File

@ -551,6 +551,7 @@ def stft(
normalized: bool = False,
onesided: Optional[bool] = None,
return_complex: Optional[bool] = None,
align_to_window: Optional[bool] = None,
) -> Tensor:
r"""Short-time Fourier transform (STFT).
@ -698,6 +699,11 @@ def stft(
normalized=normalized,
onesided=onesided,
return_complex=return_complex,
align_to_window=align_to_window,
)
if center and align_to_window is not None:
raise RuntimeError(
"stft align_to_window should only be set when center = false"
)
# NOTE: Do not edit. This code will be removed once the forward-compatibility
# period is over for PR #73432
@ -716,6 +722,7 @@ def stft(
normalized,
onesided,
return_complex,
align_to_window,
)

View File

@ -98,7 +98,7 @@ def _compute_edge_sizes(n_fft, window_size):
@_onnx_symbolic("aten::stft")
@symbolic_helper.parse_args("v", "i", "i", "i", "v", "b", "b", "b")
@symbolic_helper.parse_args("v", "i", "i", "i", "v", "b", "b", "b", "b")
def stft(
g: jit_utils.GraphContext,
input: _C.Value,
@ -109,6 +109,7 @@ def stft(
normalized: bool = False,
onesided: Optional[bool] = True,
return_complex: Optional[bool] = False,
align_to_window: Optional[bool] = None,
) -> _C.Value:
"""Associates `torch.stft` with the `STFT` ONNX operator.
Note that torch.stft calls _VF.stft, without centering or padding options.
@ -137,6 +138,12 @@ def stft(
msg="STFT does not currently support complex types", value=input
)
if align_to_window is not None:
raise errors.SymbolicValueError(
msg="STFT does not currently support the align_to_window option",
value=input,
) # TODO(#145944): add compatibility with align_to_window option.
# Get STFT sizes
frame_step_value = hop_length if hop_length is not None else n_fft // 4
frame_step_const = g.op(

View File

@ -1130,7 +1130,7 @@ def get_testing_overrides() -> dict[Callable, Callable]:
torch.std: lambda input, dim=None: -1,
torch.std_mean: lambda input, dim=None: -1,
torch.stft: (
lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True, pad_mode="reflect", normalized=False, onesided=True, return_complex=None: -1 # noqa: B950
lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True, pad_mode="reflect", normalized=False, onesided=True, return_complex=None, align_to_window=None: -1 # noqa: B950
),
torch.sub: lambda input, other, out=None: -1,
torch.subtract: lambda input, other, out=None: -1,