mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
PEP585 update - torch/fx (#145166)
See #145101 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145166 Approved by: https://github.com/bobrenjc93
This commit is contained in:
committed by
PyTorch MergeBot
parent
6374332d33
commit
0b2a3687b9
@ -1524,6 +1524,29 @@ class {test_classname}(torch.nn.Module):
|
||||
(int, type(torch.float)),
|
||||
(Union[int, float], int),
|
||||
(Union[int, float], float),
|
||||
(list[int], int),
|
||||
(list[int], create_type_hint([int, int])),
|
||||
(list[int], create_type_hint((int, int))),
|
||||
(list[torch.Tensor], create_type_hint([torch.Tensor, torch.Tensor])),
|
||||
(
|
||||
list[torch.Tensor],
|
||||
create_type_hint([torch.nn.Parameter, torch.nn.Parameter]),
|
||||
),
|
||||
(torch.Tensor, torch.nn.Parameter),
|
||||
(list[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])),
|
||||
(list[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])),
|
||||
(list[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))),
|
||||
(
|
||||
list[torch.Tensor],
|
||||
create_type_hint((torch.nn.Parameter, torch.nn.Parameter)),
|
||||
),
|
||||
(torch.Tensor, torch.nn.Parameter),
|
||||
(list[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))),
|
||||
(list[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))),
|
||||
(Optional[list[torch.Tensor]], list[torch.Tensor]),
|
||||
(Optional[list[int]], list[int]),
|
||||
] + [
|
||||
# pre-PEP585 signatures
|
||||
(List[int], int),
|
||||
(List[int], create_type_hint([int, int])),
|
||||
(List[int], create_type_hint((int, int))),
|
||||
@ -1532,7 +1555,6 @@ class {test_classname}(torch.nn.Module):
|
||||
List[torch.Tensor],
|
||||
create_type_hint([torch.nn.Parameter, torch.nn.Parameter]),
|
||||
),
|
||||
(torch.Tensor, torch.nn.Parameter),
|
||||
(List[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])),
|
||||
(List[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])),
|
||||
(List[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))),
|
||||
@ -1540,18 +1562,21 @@ class {test_classname}(torch.nn.Module):
|
||||
List[torch.Tensor],
|
||||
create_type_hint((torch.nn.Parameter, torch.nn.Parameter)),
|
||||
),
|
||||
(torch.Tensor, torch.nn.Parameter),
|
||||
(List[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))),
|
||||
(List[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))),
|
||||
(Optional[List[torch.Tensor]], List[torch.Tensor]),
|
||||
(Optional[List[int]], List[int]),
|
||||
]
|
||||
|
||||
for sig_type, arg_type in should_be_equal:
|
||||
self.assertTrue(type_matches(sig_type, arg_type))
|
||||
|
||||
should_fail = [
|
||||
(int, float),
|
||||
(Union[int, float], str),
|
||||
(list[torch.Tensor], List[int]),
|
||||
] + [
|
||||
# pre-PEP585 signatures
|
||||
(List[torch.Tensor], List[int]),
|
||||
]
|
||||
|
||||
|
Reference in New Issue
Block a user