PEP585 update - torch/testing (#145200)

See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145200
Approved by: https://github.com/bobrenjc93
This commit is contained in:
Aaron Orenstein
2025-01-19 21:32:39 -08:00
committed by PyTorch MergeBot
parent 805c4b597a
commit dea7ad3371
37 changed files with 262 additions and 298 deletions

View File

@ -21,7 +21,7 @@ from datetime import timedelta
from enum import Enum
from functools import partial, reduce, wraps
from io import StringIO
from typing import Dict, NamedTuple, Optional, Union, List, Any, Callable
from typing import NamedTuple, Optional, Union, Any, Callable
from unittest.mock import patch
from torch._logging._internal import trace_log
@ -963,7 +963,7 @@ class DistributedTestBase(MultiProcessTestCase):
def run_subtests(
cls_inst,
subtest_config: Dict[str, List[Any]],
subtest_config: dict[str, list[Any]],
test_fn: Callable,
*test_args,
**test_kwargs: Any,
@ -982,9 +982,9 @@ def run_subtests(
test_kwargs: Keyword arguments to pass to ``test_fn``.
"""
# Convert the config mapping to a list to have a fixed order
subtest_config_items: List[tuple[str, List[Any]]] = list(subtest_config.items())
subtest_config_keys: List[str] = [item[0] for item in subtest_config_items]
subtest_config_values: List[List[Any]] = [item[1] for item in subtest_config_items]
subtest_config_items: list[tuple[str, list[Any]]] = list(subtest_config.items())
subtest_config_keys: list[str] = [item[0] for item in subtest_config_items]
subtest_config_values: list[list[Any]] = [item[1] for item in subtest_config_items]
for values in itertools.product(*subtest_config_values):
# Map keyword to chosen value
subtest_kwargs = dict(zip(subtest_config_keys, values))
@ -1314,7 +1314,7 @@ class MultiThreadedTestCase(TestCase):
class SaveForwardInputsModule(nn.Module):
def __init__(
self,
forward_inputs: Dict[nn.Module, torch.Tensor],
forward_inputs: dict[nn.Module, torch.Tensor],
cast_forward_inputs: bool,
) -> None:
super().__init__()
@ -1330,7 +1330,7 @@ class SaveForwardInputsModule(nn.Module):
class SaveForwardInputsModel(nn.Module):
def __init__(
self,
forward_inputs: Dict[nn.Module, torch.Tensor],
forward_inputs: dict[nn.Module, torch.Tensor],
cast_forward_inputs: bool,
) -> None:
super().__init__()