Migrate from Tuple -> tuple in torch/_decomp (#144260)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144260
Approved by: https://github.com/aorenste
Author: bobrenjc93
Date: 2025-01-09 22:30:48 +00:00
Committed by: PyTorch MergeBot
Parent: 3607ff2c1d
Commit: 8db67e0319
2 changed files with 43 additions and 43 deletions
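
The change is mechanical: every `typing.Tuple[...]` annotation becomes the builtin `tuple[...]` generic (PEP 585), and the `Tuple` import is dropped where it is no longer used. A minimal before/after sketch (an illustrative example, not code from this PR; it assumes a Python 3.9+ interpreter, where builtin generics are valid):

```python
# Illustrative only -- not code from this PR.

# Before: typing.Tuple
from typing import Optional, Tuple

def minmax_old(xs: list) -> Tuple[Optional[float], Optional[float]]:
    return (min(xs), max(xs)) if xs else (None, None)

# After: builtin tuple (PEP 585, Python 3.9+); the Tuple import is no longer needed
from typing import Optional

def minmax_new(xs: list) -> tuple[Optional[float], Optional[float]]:
    return (min(xs), max(xs)) if xs else (None, None)
```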

View File

@@ -8,7 +8,7 @@ import sys
from enum import Enum
from functools import partial, reduce
from itertools import chain, product
-from typing import Any, Callable, cast, Iterable, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Iterable, List, Optional, Union
import torch
import torch._meta_registrations
@@ -299,7 +299,7 @@ def _prelu_kernel_backward(
grad_output: Tensor,
self: Tensor,
weight: Tensor,
-) -> Tuple[Tensor, Tensor]:
+) -> tuple[Tensor, Tensor]:
input_grad = torch.where(self > 0, grad_output, weight * grad_output)
weight_grad = torch.where(self > 0, 0.0, self * grad_output)
return (input_grad, weight_grad)
@@ -760,7 +760,7 @@ def slice_forward(
def _normalize_start_end(
x: Tensor, dim: int, start: Optional[int], end: Optional[int]
-) -> Tuple[int, int]:
+) -> tuple[int, int]:
"""
Normalize start and end such that both are in the range
[0, x.get_size()[dim]] and start <= end.
@@ -1376,19 +1376,19 @@ def split_with_sizes_copy(
@register_decomposition(aten.unsafe_split.Tensor)
-def unsafe_split(input: Tensor, split_size: int, dim: int = 0) -> Tuple[Tensor, ...]:
+def unsafe_split(input: Tensor, split_size: int, dim: int = 0) -> tuple[Tensor, ...]:
return aten.split.Tensor(input, split_size, dim)
@register_decomposition(aten.unsafe_split_with_sizes.default)
def unsafe_split_with_sizes(
input: Tensor, split_sizes: List[int], dim: int = 0
-) -> Tuple[Tensor, ...]:
+) -> tuple[Tensor, ...]:
return aten.split_with_sizes.default(input, split_sizes, dim)
@register_decomposition(aten.split.Tensor)
-def split(self: Tensor, split_size: int, dim: int = 0) -> Tuple[Tensor, ...]:
+def split(self: Tensor, split_size: int, dim: int = 0) -> tuple[Tensor, ...]:
input_sizes = self.shape
dim_size = input_sizes[dim]
if split_size == 0:
@@ -1412,7 +1412,7 @@ def tensor_split_tensor_indices_or_sections_py_impl(
self: Tensor,
tensor_indices_or_sections: Tensor,
dim: int = 0,
-) -> Tuple[Tensor, ...]:
+) -> tuple[Tensor, ...]:
assert tensor_indices_or_sections.device.type == "cpu"
assert tensor_indices_or_sections.dtype == torch.int64
split_dim = tensor_indices_or_sections.dim()
@@ -1506,7 +1506,7 @@ def native_group_norm_backward(
HxW: int,
group: int,
output_mask: List[bool],
-) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
+) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
utils.check_same_device(
grad_output, input, mean, rstd, allow_cpu_scalar_tensors=False
)
@@ -1598,7 +1598,7 @@ def native_group_norm_backward_out(
out0: torch.Tensor,
out1: torch.Tensor,
out2: torch.Tensor,
-) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
+) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
result = native_group_norm_backward(
grad_output, input, mean, rstd, gamma, N, C, HxW, group, output_mask
)
@@ -1628,7 +1628,7 @@ def native_layer_norm_backward(
weight: Optional[Tensor],
bias: Optional[Tensor],
output_mask: List[bool],
-) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
+) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
input_shape = input.shape
input_ndim = input.dim()
computation_dtype = utils.get_computation_dtype(input.dtype)
@@ -1715,7 +1715,7 @@ def native_layer_norm_backward_out(
out0: torch.Tensor,
out1: torch.Tensor,
out2: torch.Tensor,
-) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
+) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
result = native_layer_norm_backward(
grad_out, input, normalized_shape, mean, rstd, weight, bias, output_mask
)
@@ -1738,7 +1738,7 @@ def native_batch_norm_helper(
momentum: float,
eps: float,
functional: bool,
-) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
+) -> tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
reduction_dims = [0] + list(range(2, input.dim()))
computation_dtype = utils.get_computation_dtype(input.dtype)
new_running_mean = running_mean
@@ -1821,7 +1821,7 @@ def native_batch_norm(
training: bool,
momentum: float,
eps: float,
-) -> Tuple[Tensor, Tensor, Tensor]:
+) -> tuple[Tensor, Tensor, Tensor]:
output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
input, weight, bias, running_mean, running_var, training, momentum, eps, False
)
@@ -1849,7 +1849,7 @@ def native_batch_norm_decomposition(
training: bool,
momentum: float,
eps: float,
-) -> Tuple[Tensor, Tensor, Tensor]:
+) -> tuple[Tensor, Tensor, Tensor]:
if running_mean is None and running_var is None:
return aten._native_batch_norm_legit(
input, weight, bias, training, momentum, eps
@@ -1896,7 +1896,7 @@ def _native_batch_norm_legit_no_training(
running_var: Tensor,
momentum: float,
eps: float,
-) -> Tuple[Tensor, Tensor, Tensor]:
+) -> tuple[Tensor, Tensor, Tensor]:
return aten._native_batch_norm_legit.default(
input,
weight,
@@ -1919,7 +1919,7 @@ def _native_batch_norm_legit(
training: bool,
momentum: float,
eps: float,
-) -> Tuple[Tensor, Tensor, Tensor]:
+) -> tuple[Tensor, Tensor, Tensor]:
output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
input, weight, bias, running_mean, running_var, training, momentum, eps, False
)
@@ -1934,7 +1934,7 @@ def _native_batch_norm_legit_no_stats(
training: bool,
momentum: float,
eps: float,
-) -> Tuple[Tensor, Tensor, Tensor]:
+) -> tuple[Tensor, Tensor, Tensor]:
output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
input, weight, bias, None, None, training, momentum, eps, False
)
@@ -1951,7 +1951,7 @@ def _native_batch_norm_legit_functional(
training: bool,
momentum: float,
eps: float,
-) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
+) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
(
output,
save_mean,
@@ -2002,7 +2002,7 @@ def _batch_norm_with_update(
running_var: Tensor,
momentum: float,
eps: float,
-) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+) -> tuple[Tensor, Tensor, Tensor, Tensor]:
output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
input,
weight,
@@ -2029,7 +2029,7 @@ def _batch_norm_with_update_functional(
running_var: Tensor,
momentum: float,
eps: float,
-) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
+) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
(
output,
save_mean,
@@ -2056,7 +2056,7 @@ def _batch_norm_no_update(
running_var: Tensor,
momentum: float,
eps: float,
-) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+) -> tuple[Tensor, Tensor, Tensor, Tensor]:
output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
input,
weight,
@@ -2192,7 +2192,7 @@ def batch_norm_backward(
eps: float,
output_mask: List[bool],
reserve: Tensor,
-) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
+) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
return native_batch_norm_backward(
grad_out,
input,
@@ -2219,7 +2219,7 @@ def native_batch_norm_backward(
train: bool,
eps: float,
output_mask: List[bool],
-) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
+) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
input_dtype = input.dtype
if weight is not None:
weight_dtype = weight.dtype
@@ -2325,7 +2325,7 @@ def native_batch_norm_backward_out(
out0: torch.Tensor,
out1: torch.Tensor,
out2: torch.Tensor,
-) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
+) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
result = native_batch_norm_backward(
grad_out,
input,
@@ -2403,7 +2403,7 @@ def cudnn_batch_norm_backward(
@register_decomposition(aten._adaptive_avg_pool2d)
@out_wrapper()
@pw_cast_for_opmath
-def adaptive_avg_pool2d(input: Tensor, output_size: Tuple[int, int]):
+def adaptive_avg_pool2d(input: Tensor, output_size: tuple[int, int]):
# Preconditions
device = input.device
shape = input.shape
@@ -2761,7 +2761,7 @@ def _index_copy(
@register_decomposition(aten.log_sigmoid_forward)
@out_wrapper("output", "buffer")
@pw_cast_for_opmath
-def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
+def log_sigmoid_forward(self: Tensor) -> tuple[Tensor, Tensor]:
min = torch.minimum(self.new_zeros(()), self)
z = torch.exp(-torch.abs(self))
if self.is_cuda or self.is_xpu:
@@ -3937,7 +3937,7 @@ def _nll_loss_forward(
weight: Optional[Tensor],
reduction: int,
ignore_index: int,
-) -> Tuple[Tensor, Tensor]:
+) -> tuple[Tensor, Tensor]:
# self can be [N, C] or [C]
# target can be [N] or []
@@ -3992,7 +3992,7 @@ def nll_loss_forward(
weight: Optional[Tensor],
reduction: int,
ignore_index: int,
-) -> Tuple[Tensor, Tensor]:
+) -> tuple[Tensor, Tensor]:
assert self.dim() > 0 and self.dim() <= 2, "input tensor should be 1D or 2D"
assert (
target.dim() <= 1
@@ -4020,7 +4020,7 @@ def nll_loss2d_forward(
weight: Optional[Tensor],
reduction: int,
ignore_index: int,
-) -> Tuple[Tensor, Tensor]:
+) -> tuple[Tensor, Tensor]:
return _nll_loss_forward(self, target, weight, reduction, ignore_index)
@@ -4520,7 +4520,7 @@ def matmul(tensor1, tensor2, *, is_out=False):
@pw_cast_for_opmath
def upsample_bicubic2d_default(
input: Tensor,
-output_size: Tuple[int, int],
+output_size: tuple[int, int],
align_corners: bool,
scale_h: Optional[float] = None,
scale_w: Optional[float] = None,
@@ -4608,9 +4608,9 @@ def upsample_bicubic2d_default(
@pw_cast_for_opmath
def upsample_bicubic2d_vec(
a: Tensor,
-output_size: Optional[Tuple[int, int]],
+output_size: Optional[tuple[int, int]],
align_corners: bool,
-scale_factors: Optional[Tuple[float, float]] = None,
+scale_factors: Optional[tuple[float, float]] = None,
) -> Tensor:
torch._check(
bool(output_size) + bool(scale_factors) == 1,
@@ -4619,7 +4619,7 @@ def upsample_bicubic2d_vec(
if output_size is None:
assert scale_factors is not None
output_size = cast(
-Tuple[int, int],
+tuple[int, int],
tuple(
sym_int(sym_float(w) * scale)
for w, scale in zip(a.shape[2:], scale_factors)
@@ -4634,7 +4634,7 @@ def upsample_bicubic2d_vec(
@register_decomposition(aten.reflection_pad3d)
@pw_cast_for_opmath
@out_wrapper()
-def _reflection_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor:
+def _reflection_pad(a: Tensor, padding: tuple[int, ...]) -> Tensor:
def idx(left, middle, right):
dim_idx = torch.arange(-left, middle + right, device=a.device)
return middle - 1 - (middle - 1 - dim_idx.abs()).abs()
@@ -4651,7 +4651,7 @@ def _reflection_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor:
@register_decomposition(aten.replication_pad3d)
@pw_cast_for_opmath
@out_wrapper()
-def _replication_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor:
+def _replication_pad(a: Tensor, padding: tuple[int, ...]) -> Tensor:
def idx(left, middle, right):
dim_idx = torch.arange(-left, middle + right, device=a.device)
return torch.clamp(dim_idx, 0, middle - 1)
@@ -4665,7 +4665,7 @@ def _replication_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor:
def _reflection_or_replication_pad(
a: Tensor,
-padding: Tuple[int, ...],
+padding: tuple[int, ...],
idx_fn: Callable[[int, int, int], Tensor],
) -> Tensor:
dim = len(padding) // 2
@@ -4887,7 +4887,7 @@ def multilabel_margin_loss_forward(
input: Tensor,
target: Tensor,
reduction: int,
-) -> Tuple[Tensor, Tensor]:
+) -> tuple[Tensor, Tensor]:
orig_input_shape = input.shape
orig_target_shape = target.shape
input = torch.atleast_2d(input)
@@ -4953,7 +4953,7 @@ def scaled_dot_product_flash_attention_for_cpu(
*,
attn_mask: Optional[Tensor] = None,
scale: Optional[float] = None,
-) -> Tuple[Tensor, Tensor]:
+) -> tuple[Tensor, Tensor]:
torch._check(
torch.is_floating_point(query),
lambda: f"query must be FP32, FP64, BF16, FP16 but got {query.dtype}",

View File

@@ -1,7 +1,7 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import inspect
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional
import torch
import torch._decomp
@@ -104,7 +104,7 @@ def trace(self: Tensor) -> Tensor:
@maybe_register_decomposition(aten.log_sigmoid_forward.default)
-def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
+def log_sigmoid_forward(self: Tensor) -> tuple[Tensor, Tensor]:
min = torch.minimum(self.new_zeros(()), self)
z = torch.exp(-torch.abs(self))
if self.is_cuda or self.is_xpu:
@@ -138,7 +138,7 @@ def native_layer_norm_backward(
weight: Optional[Tensor],
bias: Optional[Tensor],
output_mask: List[bool],
-) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
+) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
input_shape = input.shape
input_ndim = input.dim()
@@ -224,7 +224,7 @@ def native_batch_norm_backward(
train: bool,
eps: float,
output_mask: List[bool],
-) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
+) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
input_shape = input.shape
input_rank = input.dim()
assert input_rank >= 2, "rank of the input must be at least 2"
@@ -307,7 +307,7 @@ def batch_norm_backward(
eps: float,
output_mask: List[bool],
reserve: Tensor,
-) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
+) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
return native_batch_norm_backward(
grad_out,
input,
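
A migration like this can be sanity-checked by scanning for leftover `typing.Tuple` references. The sketch below is a hypothetical standard-library helper for that purpose (illustrative only, not part of this PR):

```python
import ast
import sys

def find_typing_tuple(path: str) -> list[int]:
    """Return line numbers in `path` that still import or reference typing.Tuple."""
    with open(path, encoding="utf-8") as f:
        tree = ast.parse(f.read(), filename=path)
    hits: list[int] = []
    for node in ast.walk(tree):
        # `from typing import Tuple` (possibly among other imported names)
        if isinstance(node, ast.ImportFrom) and node.module == "typing":
            hits.extend(node.lineno for alias in node.names if alias.name == "Tuple")
        # `typing.Tuple[...]` attribute access
        elif isinstance(node, ast.Attribute) and node.attr == "Tuple":
            if isinstance(node.value, ast.Name) and node.value.id == "typing":
                hits.append(node.lineno)
    return hits

if __name__ == "__main__":
    for path in sys.argv[1:]:
        for lineno in find_typing_tuple(path):
            print(f"{path}:{lineno}: typing.Tuple still referenced")
```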