PEP585 update - torch/utils (#145201)

See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145201
Approved by: https://github.com/bobrenjc93
This commit is contained in:
Aaron Orenstein
2025-01-20 16:17:30 -08:00
committed by PyTorch MergeBot
parent 693d8c7e94
commit 2f9d378f7b
70 changed files with 491 additions and 550 deletions

View File

@ -4,7 +4,7 @@
import torch
from enum import Enum
from torch._C import _MobileOptimizerType as MobileOptimizerType
from typing import Optional, Set, List, AnyStr
from typing import Optional, AnyStr
class LintCode(Enum):
BUNDLED_INPUT = 1
@ -14,8 +14,8 @@ class LintCode(Enum):
def optimize_for_mobile(
script_module: torch.jit.ScriptModule,
optimization_blocklist: Optional[Set[MobileOptimizerType]] = None,
preserved_methods: Optional[List[AnyStr]] = None,
optimization_blocklist: Optional[set[MobileOptimizerType]] = None,
preserved_methods: Optional[list[AnyStr]] = None,
backend: str = 'CPU') -> torch.jit.RecursiveScriptModule:
"""
Optimize a torch script module for mobile deployment.
@ -43,7 +43,7 @@ def optimize_for_mobile(
# Convert potential byte arrays into strings (if there is any) to pass type checking
# Here we use a new name as assigning it back to preserved_methods will invoke
# mypy errors (i.e. List[AnyStr] = List[str])
preserved_methods_str: List[str] = [str(method) for method in preserved_methods]
preserved_methods_str: list[str] = [str(method) for method in preserved_methods]
bundled_inputs_attributes = _get_bundled_inputs_preserved_attributes(script_module, preserved_methods_str)
if all(hasattr(script_module, method) for method in bundled_inputs_attributes):
@ -114,7 +114,7 @@ def generate_mobile_module_lints(script_module: torch.jit.ScriptModule):
return lint_list
def _get_bundled_inputs_preserved_attributes(script_module: torch.jit.ScriptModule, preserved_methods: List[str]) -> List[str]:
def _get_bundled_inputs_preserved_attributes(script_module: torch.jit.ScriptModule, preserved_methods: list[str]) -> list[str]:
bundled_inputs_attributes = []
# Has bundled inputs for forward