PEP585 update - torch/fx (#145166)

See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145166
Approved by: https://github.com/bobrenjc93
This commit is contained in:
Aaron Orenstein
2025-01-19 19:32:07 -08:00
committed by PyTorch MergeBot
parent 6374332d33
commit 0b2a3687b9
57 changed files with 904 additions and 917 deletions

View File

@ -1524,6 +1524,29 @@ class {test_classname}(torch.nn.Module):
(int, type(torch.float)),
(Union[int, float], int),
(Union[int, float], float),
(list[int], int),
(list[int], create_type_hint([int, int])),
(list[int], create_type_hint((int, int))),
(list[torch.Tensor], create_type_hint([torch.Tensor, torch.Tensor])),
(
list[torch.Tensor],
create_type_hint([torch.nn.Parameter, torch.nn.Parameter]),
),
(torch.Tensor, torch.nn.Parameter),
(list[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])),
(list[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])),
(list[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))),
(
list[torch.Tensor],
create_type_hint((torch.nn.Parameter, torch.nn.Parameter)),
),
(torch.Tensor, torch.nn.Parameter),
(list[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))),
(list[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))),
(Optional[list[torch.Tensor]], list[torch.Tensor]),
(Optional[list[int]], list[int]),
] + [
# pre-PEP585 signatures
(List[int], int),
(List[int], create_type_hint([int, int])),
(List[int], create_type_hint((int, int))),
@ -1532,7 +1555,6 @@ class {test_classname}(torch.nn.Module):
List[torch.Tensor],
create_type_hint([torch.nn.Parameter, torch.nn.Parameter]),
),
(torch.Tensor, torch.nn.Parameter),
(List[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])),
(List[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])),
(List[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))),
@ -1540,18 +1562,21 @@ class {test_classname}(torch.nn.Module):
List[torch.Tensor],
create_type_hint((torch.nn.Parameter, torch.nn.Parameter)),
),
(torch.Tensor, torch.nn.Parameter),
(List[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))),
(List[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))),
(Optional[List[torch.Tensor]], List[torch.Tensor]),
(Optional[List[int]], List[int]),
]
for sig_type, arg_type in should_be_equal:
self.assertTrue(type_matches(sig_type, arg_type))
should_fail = [
(int, float),
(Union[int, float], str),
(list[torch.Tensor], List[int]),
] + [
# pre-PEP585 signatures
(List[torch.Tensor], List[int]),
]