mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-04 08:00:58 +08:00
PEP585 update - torch/utils (#145201)
See #145101 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145201 Approved by: https://github.com/bobrenjc93
This commit is contained in:
committed by
PyTorch MergeBot
parent
693d8c7e94
commit
2f9d378f7b
@ -4,7 +4,7 @@
|
||||
import torch
|
||||
from enum import Enum
|
||||
from torch._C import _MobileOptimizerType as MobileOptimizerType
|
||||
from typing import Optional, Set, List, AnyStr
|
||||
from typing import Optional, AnyStr
|
||||
|
||||
class LintCode(Enum):
|
||||
BUNDLED_INPUT = 1
|
||||
@ -14,8 +14,8 @@ class LintCode(Enum):
|
||||
|
||||
def optimize_for_mobile(
|
||||
script_module: torch.jit.ScriptModule,
|
||||
optimization_blocklist: Optional[Set[MobileOptimizerType]] = None,
|
||||
preserved_methods: Optional[List[AnyStr]] = None,
|
||||
optimization_blocklist: Optional[set[MobileOptimizerType]] = None,
|
||||
preserved_methods: Optional[list[AnyStr]] = None,
|
||||
backend: str = 'CPU') -> torch.jit.RecursiveScriptModule:
|
||||
"""
|
||||
Optimize a torch script module for mobile deployment.
|
||||
@ -43,7 +43,7 @@ def optimize_for_mobile(
|
||||
# Convert potential byte arrays into strings (if there is any) to pass type checking
|
||||
# Here we use a new name as assigning it back to preserved_methods will invoke
|
||||
# mypy errors (i.e. List[AnyStr] = List[str])
|
||||
preserved_methods_str: List[str] = [str(method) for method in preserved_methods]
|
||||
preserved_methods_str: list[str] = [str(method) for method in preserved_methods]
|
||||
|
||||
bundled_inputs_attributes = _get_bundled_inputs_preserved_attributes(script_module, preserved_methods_str)
|
||||
if all(hasattr(script_module, method) for method in bundled_inputs_attributes):
|
||||
@ -114,7 +114,7 @@ def generate_mobile_module_lints(script_module: torch.jit.ScriptModule):
|
||||
|
||||
return lint_list
|
||||
|
||||
def _get_bundled_inputs_preserved_attributes(script_module: torch.jit.ScriptModule, preserved_methods: List[str]) -> List[str]:
|
||||
def _get_bundled_inputs_preserved_attributes(script_module: torch.jit.ScriptModule, preserved_methods: list[str]) -> list[str]:
|
||||
|
||||
bundled_inputs_attributes = []
|
||||
# Has bundled inputs for forward
|
||||
|
||||
Reference in New Issue
Block a user