mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Re-add stft option to align window for center = false (#146379)
Skips advancing the fc window on https://github.com/pytorch/pytorch/pull/145437, since I just found that there were non-trivial efforts to do so a while ago that eventually was reverted: https://github.com/pytorch/pytorch/pull/73434 Works around the issue by keeping the stft sans center overload Pull Request resolved: https://github.com/pytorch/pytorch/pull/146379 Approved by: https://github.com/justinchuby, https://github.com/iseeyuan
This commit is contained in:
committed by
PyTorch MergeBot
parent
1b79d47635
commit
ed309b9156
@ -826,7 +826,7 @@ static Stream& write_opt(Stream& SS, const std::optional<T>& value) {
|
||||
Tensor stft(const Tensor& self, const int64_t n_fft, const std::optional<int64_t> hop_lengthOpt,
|
||||
const std::optional<int64_t> win_lengthOpt, const std::optional<Tensor>& window_opt,
|
||||
const bool center, std::string_view mode, const bool normalized,
|
||||
const std::optional<bool> onesidedOpt, const std::optional<bool> return_complexOpt) {
|
||||
const std::optional<bool> onesidedOpt, const std::optional<bool> return_complexOpt, const std::optional<bool> align_to_windowOpt) {
|
||||
// See [Note: hacky wrapper removal for optional tensor]
|
||||
c10::MaybeOwned<Tensor> window_maybe_owned = at::borrow_from_optional_tensor(window_opt);
|
||||
const Tensor& window = *window_maybe_owned;
|
||||
@ -853,11 +853,14 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const std::optional<int64_t
|
||||
} \
|
||||
SS << ", normalized=" << normalized << ", onesided="; \
|
||||
write_opt(SS, onesidedOpt) << ", return_complex="; \
|
||||
write_opt(SS, return_complexOpt) << ") "
|
||||
write_opt(SS, return_complexOpt) << ", align_to_window="; \
|
||||
write_opt(SS, align_to_windowOpt) << ") "
|
||||
|
||||
TORCH_CHECK(!window.defined() || window.device() == self.device(),
|
||||
"stft input and window must be on the same device but got self on ",
|
||||
self.device(), " and window on ", window.device())
|
||||
TORCH_CHECK(!center || !align_to_windowOpt.has_value(),
|
||||
"stft align_to_window should only be set when center = false.")
|
||||
|
||||
// default_init hop_length and win_length
|
||||
auto hop_length = hop_lengthOpt.value_or(n_fft >> 2);
|
||||
@ -869,7 +872,6 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const std::optional<int64_t
|
||||
"stft requires the return_complex parameter be given for real inputs, "
|
||||
"and will further require that return_complex=True in a future PyTorch release.");
|
||||
|
||||
|
||||
TORCH_WARN_ONCE(
|
||||
"stft with return_complex=False is deprecated. In a future pytorch "
|
||||
"release, stft will return complex tensors for all inputs, and "
|
||||
@ -943,7 +945,17 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const std::optional<int64_t
|
||||
window_.narrow(0, left, win_length).fill_(1);
|
||||
}
|
||||
}
|
||||
int64_t n_frames = 1 + (len - n_fft) / hop_length;
|
||||
|
||||
const bool align_to_window = align_to_windowOpt.value_or(false);
|
||||
int64_t n_frames;
|
||||
if (!center && align_to_window) {
|
||||
// Calculate n_frames based on window length, since we are aligning start of window with t = 0.
|
||||
n_frames = 1 + (len - win_length) / hop_length;
|
||||
// Window-based padding.
|
||||
input = at::pad(input, {(n_fft - win_length) / 2, (n_fft - win_length) / 2}, mode);
|
||||
} else {
|
||||
n_frames = 1 + (len - n_fft) / hop_length;
|
||||
}
|
||||
// time2col
|
||||
input = input.as_strided(
|
||||
{batch, n_frames, n_fft},
|
||||
@ -982,11 +994,12 @@ Tensor stft(
|
||||
const Tensor& self, const int64_t n_fft, const std::optional<int64_t> hop_lengthOpt,
|
||||
const std::optional<int64_t> win_lengthOpt, const std::optional<Tensor>& window_opt,
|
||||
const bool normalized,
|
||||
const std::optional<bool> onesidedOpt, const std::optional<bool> return_complexOpt) {
|
||||
const std::optional<bool> onesidedOpt, const std::optional<bool> return_complexOpt,
|
||||
const std::optional<bool> align_to_windowOpt) {
|
||||
return at::stft(
|
||||
self, n_fft, hop_lengthOpt, win_lengthOpt, window_opt,
|
||||
/*center=*/false, /*mode=*/"constant", normalized, onesidedOpt,
|
||||
return_complexOpt);
|
||||
return_complexOpt, align_to_windowOpt);
|
||||
}
|
||||
|
||||
// Create complex tensor from the old style of real tensor with size=(..., 2)
|
||||
|
@ -5758,11 +5758,11 @@
|
||||
- func: dstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)
|
||||
|
||||
# Overload without center & pad mode, needed for forward-compatibility
|
||||
- func: stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor
|
||||
- func: stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None, bool? align_to_window=None) -> Tensor
|
||||
variants: function, method
|
||||
cpp_no_default_args: ['hop_length', 'win_length', 'window', 'normalized']
|
||||
|
||||
- func: stft.center(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, str pad_mode="reflect", bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor
|
||||
- func: stft.center(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, str pad_mode="reflect", bool normalized=False, bool? onesided=None, bool? return_complex=None, bool? align_to_window=None) -> Tensor
|
||||
variants: function, method
|
||||
|
||||
- func: istft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, bool normalized=False, bool? onesided=None, int? length=None, bool return_complex=False) -> Tensor
|
||||
|
@ -1226,6 +1226,14 @@ class TestFFT(TestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, 'stft requires the return_complex parameter'):
|
||||
y = x.stft(10, pad_mode='constant')
|
||||
|
||||
@onlyNativeDeviceTypes
|
||||
@skipCPUIfNoFFT
|
||||
def test_stft_align_to_window_only_requires_non_center(self, device):
|
||||
x = torch.rand(100)
|
||||
for align_to_window in [True, False]:
|
||||
with self.assertRaisesRegex(RuntimeError, 'stft align_to_window should only be set when center = false'):
|
||||
y = x.stft(10, center=True, return_complex=True, align_to_window=align_to_window)
|
||||
|
||||
# stft and istft are currently warning if a window is not provided
|
||||
@onlyNativeDeviceTypes
|
||||
@skipCPUIfNoFFT
|
||||
|
@ -3362,6 +3362,7 @@ def stft(
|
||||
normalized: bool = False,
|
||||
onesided: Optional[bool] = None,
|
||||
return_complex: Optional[bool] = None,
|
||||
align_to_window: Optional[bool] = None,
|
||||
) -> Tensor:
|
||||
torch._check(
|
||||
window is None or window.device == input.device,
|
||||
@ -3370,6 +3371,10 @@ def stft(
|
||||
+ f" and window on {window.device}" # type: ignore[union-attr]
|
||||
),
|
||||
)
|
||||
torch._check(
|
||||
not center or align_to_window is None,
|
||||
"stft only supports align_to_window for center = False.",
|
||||
)
|
||||
|
||||
hop_length_ = hop_length if hop_length is not None else n_fft // 4
|
||||
win_length_ = win_length if win_length is not None else n_fft
|
||||
@ -3433,6 +3438,9 @@ def stft(
|
||||
window = aten.constant_pad_nd(window, [left, n_fft - win_length_ - left])
|
||||
|
||||
input = input.unfold(dimension=-1, size=n_fft, step=hop_length_)
|
||||
if not center and align_to_window:
|
||||
input_pad_amount = (n_fft - win_length_) // 2
|
||||
input = aten.pad(input, [input_pad_amount, input_pad_amount], pad_mode)
|
||||
if window is not None:
|
||||
input = input * window
|
||||
|
||||
|
@ -940,6 +940,7 @@ class Tensor(torch._C.TensorBase):
|
||||
normalized: bool = False,
|
||||
onesided: Optional[bool] = None,
|
||||
return_complex: Optional[bool] = None,
|
||||
align_to_window: Optional[bool] = None,
|
||||
):
|
||||
r"""See :func:`torch.stft`
|
||||
|
||||
@ -961,6 +962,7 @@ class Tensor(torch._C.TensorBase):
|
||||
normalized=normalized,
|
||||
onesided=onesided,
|
||||
return_complex=return_complex,
|
||||
align_to_window=align_to_window,
|
||||
)
|
||||
return torch.stft(
|
||||
self,
|
||||
@ -973,6 +975,7 @@ class Tensor(torch._C.TensorBase):
|
||||
normalized,
|
||||
onesided,
|
||||
return_complex=return_complex,
|
||||
align_to_window=align_to_window,
|
||||
)
|
||||
|
||||
def istft(
|
||||
|
@ -6398,7 +6398,8 @@ See :func:`torch.dsplit`
|
||||
add_docstr_all(
|
||||
"stft",
|
||||
r"""
|
||||
stft(frame_length, hop, fft_size=None, return_onesided=True, window=None, pad_end=0) -> Tensor
|
||||
stft(frame_length, hop, fft_size=None, return_onesided=True, window=None,
|
||||
pad_end=0, align_to_window=None) -> Tensor
|
||||
|
||||
See :func:`torch.stft`
|
||||
""",
|
||||
|
@ -551,6 +551,7 @@ def stft(
|
||||
normalized: bool = False,
|
||||
onesided: Optional[bool] = None,
|
||||
return_complex: Optional[bool] = None,
|
||||
align_to_window: Optional[bool] = None,
|
||||
) -> Tensor:
|
||||
r"""Short-time Fourier transform (STFT).
|
||||
|
||||
@ -698,6 +699,11 @@ def stft(
|
||||
normalized=normalized,
|
||||
onesided=onesided,
|
||||
return_complex=return_complex,
|
||||
align_to_window=align_to_window,
|
||||
)
|
||||
if center and align_to_window is not None:
|
||||
raise RuntimeError(
|
||||
"stft align_to_window should only be set when center = false"
|
||||
)
|
||||
# NOTE: Do not edit. This code will be removed once the forward-compatibility
|
||||
# period is over for PR #73432
|
||||
@ -716,6 +722,7 @@ def stft(
|
||||
normalized,
|
||||
onesided,
|
||||
return_complex,
|
||||
align_to_window,
|
||||
)
|
||||
|
||||
|
||||
|
@ -98,7 +98,7 @@ def _compute_edge_sizes(n_fft, window_size):
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::stft")
|
||||
@symbolic_helper.parse_args("v", "i", "i", "i", "v", "b", "b", "b")
|
||||
@symbolic_helper.parse_args("v", "i", "i", "i", "v", "b", "b", "b", "b")
|
||||
def stft(
|
||||
g: jit_utils.GraphContext,
|
||||
input: _C.Value,
|
||||
@ -109,6 +109,7 @@ def stft(
|
||||
normalized: bool = False,
|
||||
onesided: Optional[bool] = True,
|
||||
return_complex: Optional[bool] = False,
|
||||
align_to_window: Optional[bool] = None,
|
||||
) -> _C.Value:
|
||||
"""Associates `torch.stft` with the `STFT` ONNX operator.
|
||||
Note that torch.stft calls _VF.stft, without centering or padding options.
|
||||
@ -137,6 +138,12 @@ def stft(
|
||||
msg="STFT does not currently support complex types", value=input
|
||||
)
|
||||
|
||||
if align_to_window is not None:
|
||||
raise errors.SymbolicValueError(
|
||||
msg="STFT does not currently support the align_to_window option",
|
||||
value=input,
|
||||
) # TODO(#145944): add compatibility with align_to_window option.
|
||||
|
||||
# Get STFT sizes
|
||||
frame_step_value = hop_length if hop_length is not None else n_fft // 4
|
||||
frame_step_const = g.op(
|
||||
|
@ -1130,7 +1130,7 @@ def get_testing_overrides() -> dict[Callable, Callable]:
|
||||
torch.std: lambda input, dim=None: -1,
|
||||
torch.std_mean: lambda input, dim=None: -1,
|
||||
torch.stft: (
|
||||
lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True, pad_mode="reflect", normalized=False, onesided=True, return_complex=None: -1 # noqa: B950
|
||||
lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True, pad_mode="reflect", normalized=False, onesided=True, return_complex=None, align_to_window=None: -1 # noqa: B950
|
||||
),
|
||||
torch.sub: lambda input, other, out=None: -1,
|
||||
torch.subtract: lambda input, other, out=None: -1,
|
||||
|
Reference in New Issue
Block a user