mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
PEP585 update - torch/testing (#145200)
See #145101 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145200 Approved by: https://github.com/bobrenjc93
This commit is contained in:
committed by
PyTorch MergeBot
parent
805c4b597a
commit
dea7ad3371
@ -21,7 +21,7 @@ from datetime import timedelta
|
||||
from enum import Enum
|
||||
from functools import partial, reduce, wraps
|
||||
from io import StringIO
|
||||
from typing import Dict, NamedTuple, Optional, Union, List, Any, Callable
|
||||
from typing import NamedTuple, Optional, Union, Any, Callable
|
||||
from unittest.mock import patch
|
||||
|
||||
from torch._logging._internal import trace_log
|
||||
@ -963,7 +963,7 @@ class DistributedTestBase(MultiProcessTestCase):
|
||||
|
||||
def run_subtests(
|
||||
cls_inst,
|
||||
subtest_config: Dict[str, List[Any]],
|
||||
subtest_config: dict[str, list[Any]],
|
||||
test_fn: Callable,
|
||||
*test_args,
|
||||
**test_kwargs: Any,
|
||||
@ -982,9 +982,9 @@ def run_subtests(
|
||||
test_kwargs: Keyword arguments to pass to ``test_fn``.
|
||||
"""
|
||||
# Convert the config mapping to a list to have a fixed order
|
||||
subtest_config_items: List[tuple[str, List[Any]]] = list(subtest_config.items())
|
||||
subtest_config_keys: List[str] = [item[0] for item in subtest_config_items]
|
||||
subtest_config_values: List[List[Any]] = [item[1] for item in subtest_config_items]
|
||||
subtest_config_items: list[tuple[str, list[Any]]] = list(subtest_config.items())
|
||||
subtest_config_keys: list[str] = [item[0] for item in subtest_config_items]
|
||||
subtest_config_values: list[list[Any]] = [item[1] for item in subtest_config_items]
|
||||
for values in itertools.product(*subtest_config_values):
|
||||
# Map keyword to chosen value
|
||||
subtest_kwargs = dict(zip(subtest_config_keys, values))
|
||||
@ -1314,7 +1314,7 @@ class MultiThreadedTestCase(TestCase):
|
||||
class SaveForwardInputsModule(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
forward_inputs: Dict[nn.Module, torch.Tensor],
|
||||
forward_inputs: dict[nn.Module, torch.Tensor],
|
||||
cast_forward_inputs: bool,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@ -1330,7 +1330,7 @@ class SaveForwardInputsModule(nn.Module):
|
||||
class SaveForwardInputsModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
forward_inputs: Dict[nn.Module, torch.Tensor],
|
||||
forward_inputs: dict[nn.Module, torch.Tensor],
|
||||
cast_forward_inputs: bool,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
Reference in New Issue
Block a user