Use Generic TypeAlias (PEP 585) and Union Type (PEP 604) in .pyi stub files (#129419)

------

- [Generic TypeAlias (PEP 585)](https://peps.python.org/pep-0585): e.g. `typing.List[T] -> list[T]`, `typing.Dict[KT, VT] -> dict[KT, VT]`, `typing.Type[T] -> type[T]`.
- [Union Type (PEP 604)](https://peps.python.org/pep-0604): e.g. `Union[X, Y] -> X | Y`, `Optional[X] -> X | None`, `Optional[Union[X, Y]] -> X | Y | None`.

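For illustration, a minimal before/after sketch of the rewrite applied throughout the diff below; the function names here are hypothetical and not taken from this PR:

```python
# Hypothetical stub fragment, before the rewrite:
#
#   from typing import Dict, List, Optional, Type, Union
#
#   def lookup(table: Dict[str, List[int]], key: Optional[str] = None) -> Union[int, None]: ...
#   def make(cls: Type[BaseException]) -> BaseException: ...

# The same fragment after applying PEP 585 and PEP 604 -- builtin generics and
# the `|` operator replace the typing imports entirely:
def lookup(table: dict[str, list[int]], key: str | None = None) -> int | None: ...
def make(cls: type[BaseException]) -> BaseException: ...
```
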
Note that `.pyi` stub files are only parsed by type checkers and are never executed at runtime, so they do not need `from __future__ import annotations` to use the new syntax. This PR therefore does not conflict with issue #117449:

- #117449

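To make the distinction concrete, a minimal sketch with a hypothetical module name: a regular `.py` module is executed at import time, so on older Python versions the new annotation syntax only works if annotation evaluation is deferred, whereas a `.pyi` stub is never executed and can use PEP 585/604 syntax unconditionally.

```python
# runtime_module.py (hypothetical) -- executed at import time. On Python < 3.10,
# evaluating `str | None` in an annotation raises TypeError unless annotation
# evaluation is deferred with the future import:
from __future__ import annotations

def find(name: str | None = None) -> list[str]:
    return [] if name is None else [name]

# runtime_module.pyi (hypothetical) -- never imported or executed; type checkers
# only parse it, so PEP 585/604 syntax needs no future import at all:
#
#   def find(name: str | None = ...) -> list[str]: ...
```
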
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129419
Approved by: https://github.com/ezyang
ghstack dependencies: #129375, #129376
Author: Xuehai Pan
Date: 2024-06-29 12:48:07 +08:00
Committed by: PyTorch MergeBot
Parent: 9120992c72
Commit: 56935684c3

21 changed files with 414 additions and 406 deletions

View File

@ -1,18 +1,17 @@
from ctypes import c_void_p
from typing import List
from torch import Tensor
# Defined in torch/csrc/inductor/aoti_runner/pybind.cpp
# Tensor to AtenTensorHandle
def unsafe_alloc_void_ptrs_from_tensors(tensors: List[Tensor]) -> List[c_void_p]: ...
def unsafe_alloc_void_ptrs_from_tensors(tensors: list[Tensor]) -> list[c_void_p]: ...
def unsafe_alloc_void_ptr_from_tensor(tensor: Tensor) -> c_void_p: ...
# AtenTensorHandle to Tensor
def alloc_tensors_by_stealing_from_void_ptrs(
handles: List[c_void_p],
) -> List[Tensor]: ...
handles: list[c_void_p],
) -> list[Tensor]: ...
def alloc_tensor_by_stealing_from_void_ptr(
handle: c_void_p,
) -> Tensor: ...

View File

@ -1,10 +1,9 @@
# mypy: allow-untyped-defs
from enum import Enum
from typing import Any, Callable, List, Optional, Set
from typing import Any, Callable
import torch
from ._profiler import (
from torch._C._profiler import (
_ProfilerEvent,
ActiveProfilerType,
ProfilerActivity,
@ -47,7 +46,7 @@ class ProfilerEvent:
def name(self) -> str: ...
def node_id(self) -> int: ...
def sequence_nr(self) -> int: ...
def shapes(self) -> List[List[int]]: ...
def shapes(self) -> list[list[int]]: ...
def thread_id(self) -> int: ...
def flops(self) -> float: ...
def is_async(self) -> bool: ...
@ -61,15 +60,15 @@ class _KinetoEvent:
def duration_ns(self) -> int: ...
def is_async(self) -> bool: ...
def linked_correlation_id(self) -> int: ...
def shapes(self) -> List[List[int]]: ...
def dtypes(self) -> List[str]: ...
def concrete_inputs(self) -> List[Any]: ...
def shapes(self) -> list[list[int]]: ...
def dtypes(self) -> list[str]: ...
def concrete_inputs(self) -> list[Any]: ...
def device_type(self) -> DeviceType: ...
def start_thread_id(self) -> int: ...
def end_thread_id(self) -> int: ...
def correlation_id(self) -> int: ...
def fwd_thread_id(self) -> int: ...
def stack(self) -> List[str]: ...
def stack(self) -> list[str]: ...
def scope(self) -> int: ...
def sequence_nr(self) -> int: ...
def flops(self) -> int: ...
@ -77,21 +76,21 @@ class _KinetoEvent:
def privateuse1_elapsed_us(self) -> int: ...
class _ProfilerResult:
def events(self) -> List[_KinetoEvent]: ...
def legacy_events(self) -> List[List[ProfilerEvent]]: ...
def events(self) -> list[_KinetoEvent]: ...
def legacy_events(self) -> list[list[ProfilerEvent]]: ...
def save(self, path: str) -> None: ...
def experimental_event_tree(self) -> List[_ProfilerEvent]: ...
def experimental_event_tree(self) -> list[_ProfilerEvent]: ...
def trace_start_ns(self) -> int: ...
class SavedTensor: ...
def _enable_profiler(
config: ProfilerConfig,
activities: Set[ProfilerActivity],
activities: set[ProfilerActivity],
) -> None: ...
def _prepare_profiler(
config: ProfilerConfig,
activities: Set[ProfilerActivity],
activities: set[ProfilerActivity],
) -> None: ...
def _disable_profiler() -> _ProfilerResult: ...
def _profiler_enabled() -> bool: ...
@ -101,7 +100,7 @@ def _get_sequence_nr() -> int: ...
def kineto_available() -> bool: ...
def _record_function_with_args_enter(name: str, *args) -> torch.Tensor: ...
def _record_function_with_args_exit(handle: torch.Tensor) -> None: ...
def _supported_activities() -> Set[ProfilerActivity]: ...
def _supported_activities() -> set[ProfilerActivity]: ...
def _enable_record_function(enable: bool) -> None: ...
def _set_empty_test_observer(is_global: bool, sampling_prob: float) -> None: ...
def _push_saved_tensors_default_hooks(
@ -111,11 +110,11 @@ def _push_saved_tensors_default_hooks(
def _pop_saved_tensors_default_hooks() -> None: ...
def _unsafe_set_version_counter(t: torch.Tensor, prev_version: int) -> None: ...
def _enable_profiler_legacy(config: ProfilerConfig) -> None: ...
def _disable_profiler_legacy() -> List[List[ProfilerEvent]]: ...
def _disable_profiler_legacy() -> list[list[ProfilerEvent]]: ...
def _profiler_type() -> ActiveProfilerType: ...
def _saved_tensors_hooks_enable() -> None: ...
def _saved_tensors_hooks_disable(message: str) -> None: ...
def _saved_tensors_hooks_get_disabled_error_message() -> Optional[str]: ...
def _saved_tensors_hooks_get_disabled_error_message() -> str | None: ...
def _saved_tensors_hooks_set_tracing(is_tracing: bool) -> bool: ...
class CreationMeta(Enum):

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Any, Dict, List, Set
from typing import Any
import torch
@ -7,9 +7,9 @@ import torch
class DistAutogradContext:
def _context_id(self) -> int: ...
def _recv_functions(self) -> Dict[int, Any]: ...
def _send_functions(self) -> Dict[int, Any]: ...
def _known_worker_ids(self) -> Set[int]: ...
def _recv_functions(self) -> dict[int, Any]: ...
def _send_functions(self) -> dict[int, Any]: ...
def _known_worker_ids(self) -> set[int]: ...
def _new_context() -> DistAutogradContext: ...
def _release_context(context_id: int) -> None: ...
@ -18,10 +18,10 @@ def _is_valid_context(worker_id: int) -> bool: ...
def _retrieve_context(context_id: int) -> DistAutogradContext: ...
def _current_context() -> DistAutogradContext: ...
def _init(worker_id: int) -> None: ...
def _get_debug_info() -> Dict[str, str]: ...
def _get_debug_info() -> dict[str, str]: ...
def backward(
context_id: int,
roots: List[torch.Tensor],
roots: list[torch.Tensor],
retain_graph=False,
) -> None: ...
def get_gradients(context_id: int) -> Dict[torch.Tensor, torch.Tensor]: ...
def get_gradients(context_id: int) -> dict[torch.Tensor, torch.Tensor]: ...

View File

@ -2,7 +2,7 @@
# mypy: disable-error-code="type-arg"
from datetime import timedelta
from enum import Enum
from typing import Any, Dict, List, Optional, overload, Tuple, Union
from typing import Any, overload
import torch
from torch import Tensor
@ -26,36 +26,36 @@ def _register_builtin_comm_hook(
comm_hook_type: BuiltinCommHookType,
): ...
def _set_global_rank(rank: int) -> None: ...
def _hash_tensors(tensors: List[Tensor]) -> int: ...
def _hash_tensors(tensors: list[Tensor]) -> int: ...
class GradBucket:
def index(self) -> int: ...
def buffer(self) -> Tensor: ...
def gradients(self) -> List[Tensor]: ...
def gradients(self) -> list[Tensor]: ...
def is_last(self) -> bool: ...
def set_buffer(self, tensor: Tensor) -> None: ...
def parameters(self) -> List[Tensor]: ...
def parameters(self) -> list[Tensor]: ...
class Reducer:
def __init__(
self,
params: List[Tensor],
bucket_indices: List[List[int]],
per_bucket_size_limits: List[int],
params: list[Tensor],
bucket_indices: list[list[int]],
per_bucket_size_limits: list[int],
process_group: ProcessGroup,
expect_sparse_gradients: List[bool] = ...,
expect_sparse_gradients: list[bool] = ...,
bucket_bytes_cap: int = ..., # kDefaultBucketBytesCap in reducer.hpp
find_unused_parameters: bool = ...,
gradient_as_bucket_view: bool = ...,
param_to_name_mapping: Dict[int, str] = ...,
param_to_name_mapping: dict[int, str] = ...,
first_bucket_types_cap: int = ..., # kDefaultFirstBucketBytes in reducer.hpp
): ...
) -> None: ...
def prepare_for_forward(self) -> None: ...
def prepare_for_backward(self, output: List[Tensor]) -> None: ...
def get_backward_stats(self) -> List[int]: ...
def _install_post_backward_futures(self, futures: List[Future]) -> None: ...
def prepare_for_backward(self, output: list[Tensor]) -> None: ...
def get_backward_stats(self) -> list[int]: ...
def _install_post_backward_futures(self, futures: list[Future]) -> None: ...
def _rebuild_buckets(self) -> bool: ...
def _get_zeros_like_grad_buckets(self) -> List[GradBucket]: ...
def _get_zeros_like_grad_buckets(self) -> list[GradBucket]: ...
def _push_all_rebuilt_params(self) -> None: ...
def _set_forward_pass_work_handle(
self,
@ -69,20 +69,20 @@ class Reducer:
def set_logger(self, logger: Logger) -> None: ...
def _remove_autograd_hooks(self) -> None: ...
def _check_reducer_finalized(self) -> None: ...
def _set_sparse_metadata(self, global_unique_ids: Dict[str, Tensor]) -> None: ...
def _set_sparse_metadata(self, global_unique_ids: dict[str, Tensor]) -> None: ...
def _reset_state(self) -> None: ...
def _update_process_group(self, new_process_group: ProcessGroup) -> None: ...
class DDPLoggingData:
strs_map: Dict[str, str]
ints_map: Dict[str, int]
strs_map: dict[str, str]
ints_map: dict[str, int]
class Logger:
def __init__(self, reducer: Reducer): ...
def __init__(self, reducer: Reducer) -> None: ...
def set_construction_data_and_log(
self,
module_name: str,
device_ids: List[int],
device_ids: list[int],
output_device: int,
broadcast_buffers: bool,
has_sync_bn: bool,
@ -109,7 +109,7 @@ class DebugLevel(Enum):
DETAIL = ...
class ReduceOp:
def __init__(self, op: RedOpType): ...
def __init__(self, op: RedOpType) -> None: ...
SUM: RedOpType = ...
AVG: RedOpType = ...
@ -161,7 +161,7 @@ class ReduceScatterOptions:
asyncOp: bool
class BarrierOptions:
device_ids: List[int]
device_ids: list[int]
device: torch.device
timeout: timedelta
@ -182,36 +182,36 @@ class Store:
def num_keys(self) -> int: ...
def set_timeout(self, timeout: timedelta): ...
@overload
def wait(self, keys: List[str]): ...
def wait(self, keys: list[str]): ...
@overload
def wait(self, keys: List[str], timeout: timedelta): ...
def wait(self, keys: list[str], timeout: timedelta): ...
class FileStore(Store):
def __init__(self, path: str, numWorkers: int = ...): ...
def __init__(self, path: str, numWorkers: int = ...) -> None: ...
class HashStore(Store):
def __init__(self): ...
def __init__(self) -> None: ...
class TCPStore(Store):
def __init__(
self,
host_name: str,
port: int,
world_size: Optional[int] = ...,
world_size: int | None = ...,
is_master: bool = ...,
timeout: timedelta = ...,
wait_for_workers: bool = ...,
multi_tenant: bool = ...,
master_listen_fd: Optional[int] = ...,
use_libuv: Optional[bool] = ...,
): ...
master_listen_fd: int | None = ...,
use_libuv: bool | None = ...,
) -> None: ...
@property
def host(self) -> str: ...
@property
def port(self) -> int: ...
class PrefixStore(Store):
def __init__(self, prefix: str, store: Store): ...
def __init__(self, prefix: str, store: Store) -> None: ...
@property
def underlying_store(self) -> Store: ...
@ -230,7 +230,7 @@ class _StoreCollectives(_ControlCollectives):
def __init__(self, store: Store, rank: int, world_size: int) -> None: ...
class _DistributedBackendOptions:
def __init__(self): ...
def __init__(self) -> None: ...
@property
def store(self) -> Store: ...
@store.setter
@ -252,9 +252,9 @@ class _DistributedBackendOptions:
@group_id.setter
def group_id(self, group_id: str) -> None: ...
@property
def global_ranks_in_group(self) -> List[int]: ...
def global_ranks_in_group(self) -> list[int]: ...
@global_ranks_in_group.setter
def global_ranks_in_group(self, ranks: List[int]) -> None: ...
def global_ranks_in_group(self, ranks: list[int]) -> None: ...
class Work:
def is_completed(self) -> bool: ...
@ -264,7 +264,7 @@ class Work:
def get_future(self) -> Future: ...
def source_rank(self) -> int: ...
def _source_rank(self) -> int: ...
def result(self) -> List[Tensor]: ...
def result(self) -> list[Tensor]: ...
def synchronize(self): ...
def boxed(self) -> ScriptObject: ...
@staticmethod
@ -275,17 +275,17 @@ class Backend:
self,
rank: int,
size: int,
): ...
) -> None: ...
@property
def supports_splitting(self) -> bool: ...
def rank(self) -> int: ...
def size(self) -> int: ...
def eager_connect_single_device(self, device: Optional[torch.device]) -> None: ...
def eager_connect_single_device(self, device: torch.device | None) -> None: ...
def _set_sequence_number_for_group(self) -> None: ...
class ProcessGroup:
class Options:
def __init__(self, backend: str, timeout: timedelta = ...): ...
def __init__(self, backend: str, timeout: timedelta = ...) -> None: ...
@property
def backend(self) -> str: ...
@property
@ -300,13 +300,19 @@ class ProcessGroup:
UCC = ...
MPI = ...
CUSTOM = ...
def __init__(self, store: Store, rank: int, size: int, options: Options): ...
def __init__(
self,
store: Store,
rank: int,
size: int,
options: Options,
) -> None: ...
def rank(self) -> int: ...
def size(self) -> int: ...
@overload
def broadcast(
self,
tensors: List[Tensor],
tensors: list[Tensor],
opts=...,
) -> Work: ...
@overload
@ -318,13 +324,13 @@ class ProcessGroup:
@overload
def allreduce(
self,
tensors: List[Tensor],
tensors: list[Tensor],
opts: AllreduceOptions = ...,
) -> Work: ...
@overload
def allreduce(
self,
tensors: List[Tensor],
tensors: list[Tensor],
op=...,
) -> Work: ...
@overload
@ -335,19 +341,19 @@ class ProcessGroup:
) -> Work: ...
def allreduce_coalesced(
self,
tensors: List[Tensor],
tensors: list[Tensor],
opts=...,
) -> Work: ...
def reduce_scatter_tensor_coalesced(
self,
outputTensors: List[Tensor],
inputTensors: List[Tensor],
opts: Optional[ReduceScatterOptions] = None,
outputTensors: list[Tensor],
inputTensors: list[Tensor],
opts: ReduceScatterOptions | None = None,
) -> Work: ...
@overload
def reduce(
self,
tensors: List[Tensor],
tensors: list[Tensor],
opts=...,
) -> Work: ...
@overload
@ -360,14 +366,14 @@ class ProcessGroup:
@overload
def allgather(
self,
output_tensors: List[List[Tensor]],
input_tensors: List[Tensor],
output_tensors: list[list[Tensor]],
input_tensors: list[Tensor],
opts=...,
) -> Work: ...
@overload
def allgather(
self,
output_tensors: List[Tensor],
output_tensors: list[Tensor],
input_tensor: Tensor,
) -> Work: ...
def _allgather_base(
@ -378,70 +384,70 @@ class ProcessGroup:
) -> Work: ...
def allgather_coalesced(
self,
output_lists: List[List[Tensor]],
input_list: List[Tensor],
output_lists: list[list[Tensor]],
input_list: list[Tensor],
opts=...,
) -> Work: ...
def allgather_into_tensor_coalesced(
self,
output_lists: List[Tensor],
input_list: List[Tensor],
output_lists: list[Tensor],
input_list: list[Tensor],
opts=...,
) -> Work: ...
@overload
def gather(
self,
output_tensors: List[List[Tensor]],
input_tensors: List[Tensor],
output_tensors: list[list[Tensor]],
input_tensors: list[Tensor],
opts=...,
) -> Work: ...
@overload
def gather(
self,
output_tensors: List[Tensor],
output_tensors: list[Tensor],
input_tensor: Tensor,
root: int,
) -> Work: ...
@overload
def scatter(
self,
output_tensors: List[Tensor],
input_tensors: List[List[Tensor]],
output_tensors: list[Tensor],
input_tensors: list[list[Tensor]],
opts=...,
) -> Work: ...
@overload
def scatter(
self,
output_tensor: Tensor,
input_tensors: List[Tensor],
input_tensors: list[Tensor],
root: int,
) -> Work: ...
@overload
def reduce_scatter(
self,
output_tensors: List[Tensor],
input_tensors: List[List[Tensor]],
output_tensors: list[Tensor],
input_tensors: list[list[Tensor]],
opts=...,
) -> Work: ...
@overload
def reduce_scatter(
self,
output_tensors: Tensor,
input_tensor: List[Tensor],
input_tensor: list[Tensor],
) -> Work: ...
def _reduce_scatter_base(
self,
outputTensor: Tensor,
inputTensor: Tensor,
opts: Optional[ReduceScatterOptions],
opts: ReduceScatterOptions | None,
) -> Work: ...
@overload
def alltoall_base(
self,
output_tensor: Tensor,
input_tensor: Tensor,
output_split_sizes: List[int],
input_split_sizes: List[int],
output_split_sizes: list[int],
input_split_sizes: list[int],
opts=...,
) -> Work: ...
@overload
@ -449,35 +455,35 @@ class ProcessGroup:
self,
output: Tensor,
input: Tensor,
output_split_sizes: List[int],
input_split_sizes: List[int],
output_split_sizes: list[int],
input_split_sizes: list[int],
) -> Work: ...
@overload
def alltoall(
self,
output_tensor: List[Tensor],
input_tensor: List[Tensor],
output_tensor: list[Tensor],
input_tensor: list[Tensor],
opts=...,
) -> Work: ...
@overload
def alltoall(
self,
output: List[Tensor],
input: List[Tensor],
output: list[Tensor],
input: list[Tensor],
) -> Work: ...
def send(
self,
tensors: List[Tensor],
tensors: list[Tensor],
dstRank: int,
tag: int,
) -> Work: ...
def recv(
self,
tensors: List[Tensor],
tensors: list[Tensor],
srcRank: int,
tag: int,
) -> Work: ...
def recv_anysource(self, tensors: List[Tensor], tag: int) -> Work: ...
def recv_anysource(self, tensors: list[Tensor], tag: int) -> Work: ...
def barrier(self, opts=...) -> Work: ...
def boxed(self) -> ScriptObject: ...
@staticmethod
@ -487,13 +493,13 @@ class ProcessGroup:
def _get_backend_name(self) -> str: ...
def _backend_id(self, backend_type: BackendType) -> int: ...
@property
def _device_types(self) -> List[torch.device]: ...
def _device_types(self) -> list[torch.device]: ...
def _get_backend(self, device: torch.device) -> Backend: ...
def _register_backend(
self,
device: torch.device,
backend_type: BackendType,
backend: Optional[Backend],
backend: Backend | None,
) -> None: ...
def _set_group_name(self, name: str) -> None: ...
def _set_group_desc(self, desc: str) -> None: ...
@ -502,9 +508,9 @@ class ProcessGroup:
def _wait_for_pending_works(self) -> None: ...
def _set_sequence_number_for_group(self) -> None: ...
@property
def bound_device_id(self) -> Optional[torch.device]: ...
def bound_device_id(self) -> torch.device | None: ...
@bound_device_id.setter
def bound_device_id(self, device: Optional[torch.device]) -> None: ...
def bound_device_id(self, device: torch.device | None) -> None: ...
@property
def group_name(self) -> str: ...
@property
@ -513,7 +519,7 @@ class ProcessGroup:
class ProcessGroupRoundRobin(ProcessGroup): ...
def _round_robin_process_groups(
process_groups: List[ProcessGroup],
process_groups: list[ProcessGroup],
) -> ProcessGroupRoundRobin: ...
class ProcessGroupGloo(Backend):
@ -526,7 +532,7 @@ class ProcessGroupGloo(Backend):
rank: int,
size: int,
timeout: timedelta,
): ...
) -> None: ...
@staticmethod
def create_device(hostname="", interface="") -> Device: ...
@staticmethod
@ -534,12 +540,12 @@ class ProcessGroupGloo(Backend):
def _set_default_timeout(self, timeout) -> None: ...
class _ProcessGroupWrapper(Backend):
def __init__(self, pg: Backend, gloo_pg: ProcessGroupGloo): ...
def __init__(self, pg: Backend, gloo_pg: ProcessGroupGloo) -> None: ...
wrapped_pg: Backend
class ProcessGroupNCCL(Backend):
class Options:
def __init__(self, timeout: Optional[timedelta] = None): ...
def __init__(self, timeout: timedelta | None = None) -> None: ...
@property
def backend(self) -> str: ...
@property
@ -557,7 +563,7 @@ class ProcessGroupNCCL(Backend):
rank: int,
size: int,
timeout: timedelta,
): ...
) -> None: ...
def _group_start(self) -> None: ...
def _group_end(self) -> None: ...
def _set_default_timeout(self, timeout) -> None: ...
@ -572,7 +578,7 @@ class ProcessGroupUCC(Backend):
rank: int,
size: int,
timeout: timedelta,
): ...
) -> None: ...
class ProcessGroupMPI(Backend):
def __init__(
@ -580,29 +586,29 @@ class ProcessGroupMPI(Backend):
rank: int,
size: int,
pgComm: int,
): ...
) -> None: ...
@staticmethod
def create(ranks: List[int]) -> ProcessGroupMPI: ...
def create(ranks: list[int]) -> ProcessGroupMPI: ...
def _compute_bucket_assignment_by_size(
tensors: List[Tensor],
bucket_size_limits: List[int],
expect_sparse_gradient: List[bool] = ...,
tensor_indices: List[int] = ...,
) -> Tuple[List[List[int]], List[int]]: ...
tensors: list[Tensor],
bucket_size_limits: list[int],
expect_sparse_gradient: list[bool] = ...,
tensor_indices: list[int] = ...,
) -> tuple[list[list[int]], list[int]]: ...
def _broadcast_coalesced(
process_group: ProcessGroup,
tensors: List[Tensor],
tensors: list[Tensor],
buffer_size: int,
src: int,
): ...
def _test_python_store(store: Store): ...
def _verify_params_across_processes(
process_group: ProcessGroup,
params: List[Tensor],
logger: Optional[Logger],
params: list[Tensor],
logger: Logger | None,
): ...
def _make_nccl_premul_sum(factor: Union[float, List[Tensor]]) -> ReduceOp: ...
def _make_nccl_premul_sum(factor: float | list[Tensor]) -> ReduceOp: ...
def _register_process_group(
group_name: str,
process_group: ProcessGroup,
@ -614,7 +620,10 @@ def _unregister_process_group(group_name: str) -> None: ...
class _SymmetricMemory:
@staticmethod
def set_group_info(
group_name: str, rank: int, world_size: int, store: Store
group_name: str,
rank: int,
world_size: int,
store: Store,
) -> None: ...
@staticmethod
def empty_strided_p2p(
@ -635,7 +644,7 @@ class _SymmetricMemory:
rank: int,
sizes: torch.types._size,
dtype: torch.dtype,
storage_offset: Optional[int] = 0,
storage_offset: int | None = 0,
) -> torch.Tensor: ...
def barrier(self, channel: int = 0) -> None: ...
def put_signal(self, dst_rank: int, channel: int = 0) -> None: ...

View File

@ -1,14 +1,13 @@
# mypy: allow-untyped-defs
# mypy: disable-error-code="type-arg"
from datetime import timedelta
from typing import Any, Dict, Generic, List, Optional, overload, Tuple, Type, TypeVar
from typing import Any, Generic, overload, TypeVar
import torch
from . import Future
from ._autograd import ProfilerEvent
from ._distributed_c10d import Store
from ._profiler import ProfilerConfig
from torch._C import Future
from torch._C._autograd import ProfilerEvent
from torch._C._distributed_c10d import Store
from torch._C._profiler import ProfilerConfig
# This module is defined in torch/csrc/distributed/rpc/init.cpp
@ -26,10 +25,10 @@ class RpcBackendOptions:
self,
rpc_timeout: float = ...,
init_method: str = ...,
): ...
) -> None: ...
class WorkerInfo:
def __init__(self, name: str, worker_id: int): ...
def __init__(self, name: str, worker_id: int) -> None: ...
@property
def name(self) -> str: ...
@property
@ -44,10 +43,10 @@ class RpcAgent:
def get_worker_info(self) -> WorkerInfo: ...
@overload
def get_worker_info(self, workerName: str) -> WorkerInfo: ...
def get_worker_infos(self) -> List[WorkerInfo]: ...
def _get_device_map(self, dst: WorkerInfo) -> Dict[torch.device, torch.device]: ...
def get_debug_info(self) -> Dict[str, str]: ...
def get_metrics(self) -> Dict[str, str]: ...
def get_worker_infos(self) -> list[WorkerInfo]: ...
def _get_device_map(self, dst: WorkerInfo) -> dict[torch.device, torch.device]: ...
def get_debug_info(self) -> dict[str, str]: ...
def get_metrics(self) -> dict[str, str]: ...
class PyRRef(Generic[_T]):
def __init__(self, value: _T, type_hint: Any = None) -> None: ...
@ -60,32 +59,32 @@ class PyRRef(Generic[_T]):
def rpc_sync(self, timeout: float = ...) -> Any: ...
def rpc_async(self, timeout: float = ...) -> Any: ...
def remote(self, timeout: float = ...) -> Any: ...
def _serialize(self) -> Tuple: ...
def _serialize(self) -> tuple: ...
@staticmethod
def _deserialize(tp: Tuple) -> PyRRef: ...
def _get_type(self) -> Type[_T]: ...
def _deserialize(tp: tuple) -> PyRRef: ...
def _get_type(self) -> type[_T]: ...
def _get_future(self) -> Future[_T]: ...
def _get_profiling_future(self) -> Future[_T]: ...
def _set_profiling_future(self, profilingFuture: Future[_T]): ...
class _TensorPipeRpcBackendOptionsBase(RpcBackendOptions):
num_worker_threads: int
device_maps: Dict[str, Dict[torch.device, torch.device]]
devices: List[torch.device]
device_maps: dict[str, dict[torch.device, torch.device]]
devices: list[torch.device]
def __init__(
self,
num_worker_threads: int,
_transports: Optional[List],
_channels: Optional[List],
_transports: list | None,
_channels: list | None,
rpc_timeout: float = ...,
init_method: str = ...,
device_maps: Dict[str, Dict[torch.device, torch.device]] = {}, # noqa: B006
devices: List[torch.device] = [], # noqa: B006
): ...
device_maps: dict[str, dict[torch.device, torch.device]] = {}, # noqa: B006
devices: list[torch.device] = [], # noqa: B006
) -> None: ...
def _set_device_map(
self,
to: str,
device_map: Dict[torch.device, torch.device],
device_map: dict[torch.device, torch.device],
): ...
class TensorPipeAgent(RpcAgent):
@ -94,11 +93,11 @@ class TensorPipeAgent(RpcAgent):
store: Store,
name: str,
worker_id: int,
world_size: Optional[int],
world_size: int | None,
opts: _TensorPipeRpcBackendOptionsBase,
reverse_device_maps: Dict[str, Dict[torch.device, torch.device]],
devices: List[torch.device],
): ...
reverse_device_maps: dict[str, dict[torch.device, torch.device]],
devices: list[torch.device],
) -> None: ...
def join(self, shutdown: bool = False, timeout: float = 0): ...
def shutdown(self): ...
@overload
@ -107,13 +106,13 @@ class TensorPipeAgent(RpcAgent):
def get_worker_info(self, workerName: str) -> WorkerInfo: ...
@overload
def get_worker_info(self, id: int) -> WorkerInfo: ...
def get_worker_infos(self) -> List[WorkerInfo]: ...
def _get_device_map(self, dst: WorkerInfo) -> Dict[torch.device, torch.device]: ...
def get_worker_infos(self) -> list[WorkerInfo]: ...
def _get_device_map(self, dst: WorkerInfo) -> dict[torch.device, torch.device]: ...
def _update_group_membership(
self,
worker_info: WorkerInfo,
my_devices: List[torch.device],
reverse_device_map: Dict[str, Dict[torch.device, torch.device]],
my_devices: list[torch.device],
reverse_device_map: dict[str, dict[torch.device, torch.device]],
is_join: bool,
): ...
def _get_backend_options(self) -> _TensorPipeRpcBackendOptionsBase: ...
@ -128,7 +127,7 @@ def _set_and_start_rpc_agent(agent: RpcAgent): ...
def _reset_current_rpc_agent(): ...
def _delete_all_user_and_unforked_owner_rrefs(timeout: timedelta = ...): ...
def _destroy_rref_context(ignoreRRefLeak: bool): ...
def _rref_context_get_debug_info() -> Dict[str, str]: ...
def _rref_context_get_debug_info() -> dict[str, str]: ...
def _cleanup_python_rpc_handler(): ...
def _invoke_rpc_builtin(
dst: WorkerInfo,
@ -140,15 +139,15 @@ def _invoke_rpc_builtin(
def _invoke_rpc_python_udf(
dst: WorkerInfo,
pickledPythonUDF: str,
tensors: List[torch.Tensor],
tensors: list[torch.Tensor],
rpcTimeoutSeconds: float,
isAsyncExecution: bool,
): ...
def _invoke_rpc_torchscript(
dstWorkerName: str,
qualifiedNameStr: str,
argsTuple: Tuple,
kwargsDict: Dict,
argsTuple: tuple,
kwargsDict: dict,
rpcTimeoutSeconds: float,
isAsyncExecution: bool,
): ...
@ -162,7 +161,7 @@ def _invoke_remote_builtin(
def _invoke_remote_python_udf(
dst: WorkerInfo,
pickledPythonUDF: str,
tensors: List[torch.Tensor],
tensors: list[torch.Tensor],
rpcTimeoutSeconds: float,
isAsyncExecution: bool,
): ...
@ -183,7 +182,7 @@ class RemoteProfilerManager:
def set_current_profiling_key(key: str): ...
def _enable_server_process_global_profiler(new_config: ProfilerConfig): ...
def _disable_server_process_global_profiler() -> List[List[List[ProfilerEvent]]]: ...
def _disable_server_process_global_profiler() -> list[list[list[ProfilerEvent]]]: ...
def _set_profiler_node_id(default_node_id: int): ...
def _enable_jit_rref_pickle(): ...
def _disable_jit_rref_pickle(): ...

View File

@ -1,9 +1,6 @@
from typing import Dict, List
import torch
from ._distributed_c10d import Store
from ._distributed_rpc import _TensorPipeRpcBackendOptionsBase, TensorPipeAgent
from torch._C._distributed_c10d import Store
from torch._C._distributed_rpc import _TensorPipeRpcBackendOptionsBase, TensorPipeAgent
# This module is defined in torch/csrc/distributed/rpc/testing/init.cpp
@ -13,13 +10,13 @@ class FaultyTensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
num_worker_threads: int,
rpc_timeout: float,
init_method: str,
messages_to_fail: List[str],
messages_to_delay: Dict[str, float],
messages_to_fail: list[str],
messages_to_delay: dict[str, float],
num_fail_sends: int,
): ...
) -> None: ...
num_send_recv_threads: int
messages_to_fail: List[str]
messages_to_delay: Dict[str, float]
messages_to_fail: list[str]
messages_to_delay: dict[str, float]
num_fail_sends: int
class FaultyTensorPipeAgent(TensorPipeAgent):
@ -30,6 +27,6 @@ class FaultyTensorPipeAgent(TensorPipeAgent):
rank: int,
world_size: int,
options: FaultyTensorPipeRpcBackendOptions,
reverse_device_maps: Dict[str, Dict[torch.device, torch.device]],
devices: List[torch.device],
): ...
reverse_device_maps: dict[str, dict[torch.device, torch.device]],
devices: list[torch.device],
) -> None: ...

View File

@ -1,10 +1,10 @@
from typing import Callable, Optional
from typing import Callable
from torch._dynamo.compiled_autograd import AutogradCompilerInstance
def set_autograd_compiler(
autograd_compiler: Optional[Callable[[], AutogradCompilerInstance]]
) -> Optional[Callable[[], AutogradCompilerInstance]]: ...
autograd_compiler: Callable[[], AutogradCompilerInstance] | None,
) -> Callable[[], AutogradCompilerInstance] | None: ...
def clear_cache() -> None: ...
def is_cache_empty() -> bool: ...
def set_verbose_logger(fn: Optional[Callable[[str], None]]) -> bool: ...
def set_verbose_logger(fn: Callable[[str], None] | None) -> bool: ...

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import types
from typing import List, NewType, Optional
from typing import NewType
from torch._dynamo.types import DynamoCallback, DynamoGuardHook
@ -17,11 +17,11 @@ def set_guard_error_hook(hook: DynamoGuardHook) -> None: ...
class _CacheEntry:
def check_fn(self, *args, **kwargs): ...
code: types.CodeType
next: Optional[_CacheEntry]
next: _CacheEntry | None
class _ExtraState:
def invalidate(self, cache_entry: _CacheEntry): ...
def _debug_get_cache_entry_list(code: types.CodeType) -> List[_CacheEntry]: ...
def _debug_get_cache_entry_list(code: types.CodeType) -> list[_CacheEntry]: ...
py_opcode_caches: List[int]
py_opcode_caches: list[int]

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Any, Dict, List, Optional, Union
from typing import Any
import torch
@ -17,73 +17,106 @@ class GuardManager:
# Accessors
def globals_dict_manager(
self,
f_globals: Dict[str, Any],
f_globals: dict[str, Any],
source,
example_value,
guard_manager_enum,
) -> GuardManager: ...
def dict_getitem_manager(
self, key, source, example_value, guard_manager_enum
self,
key,
source,
example_value,
guard_manager_enum,
) -> GuardManager: ...
def global_weakref_manager(
self, global_name: str, source, example_value, guard_manager_enum
self,
global_name: str,
source,
example_value,
guard_manager_enum,
) -> GuardManager: ...
def type_manager(
self, source, example_value, guard_manager_enum
self,
source,
example_value,
guard_manager_enum,
) -> GuardManager: ...
def getattr_manager(
self, attr: str, source, example_value, guard_manager_enum
self,
attr: str,
source,
example_value,
guard_manager_enum,
) -> GuardManager: ...
def lambda_manager(
self, python_lambda, source, example_value, guard_manager_enum
self,
python_lambda,
source,
example_value,
guard_manager_enum,
) -> GuardManager: ...
# Leaf guards
def add_lambda_guard(self, user_lambda, verbose_code_parts: List[str]) -> None: ...
def add_id_match_guard(self, id_val, verbose_code_parts: List[str]) -> None: ...
def add_lambda_guard(self, user_lambda, verbose_code_parts: list[str]) -> None: ...
def add_id_match_guard(self, id_val, verbose_code_parts: list[str]) -> None: ...
def add_equals_match_guard(
self, equals_val, verbose_code_parts: List[str]
self,
equals_val,
verbose_code_parts: list[str],
) -> None: ...
def add_global_state_guard(self, verbose_code_parts: List[str]) -> None: ...
def add_global_state_guard(self, verbose_code_parts: list[str]) -> None: ...
class RootGuardManager(GuardManager):
def get_epilogue_lambda_guards(self) -> List[LeafGuard]: ...
def get_epilogue_lambda_guards(self) -> list[LeafGuard]: ...
def add_epilogue_lambda_guard(
self, guard: LeafGuard, verbose_code_parts: List[str]
self,
guard: LeafGuard,
verbose_code_parts: list[str],
) -> None: ...
class DictGuardManager(GuardManager):
def get_key_manager(
self, index, source, example_value, guard_manager_enum
self,
index,
source,
example_value,
guard_manager_enum,
) -> GuardManager: ...
def get_value_manager(
self, index, source, example_value, guard_manager_enum
self,
index,
source,
example_value,
guard_manager_enum,
) -> GuardManager: ...
def install_tensor_aliasing_guard(
guard_managers: List[GuardManager],
tensor_names: List[str],
verbose_code_parts: List[str],
guard_managers: list[GuardManager],
tensor_names: list[str],
verbose_code_parts: list[str],
): ...
def install_no_tensor_aliasing_guard(
guard_managers: List[GuardManager],
tensor_names: List[str],
verbose_code_parts: List[str],
guard_managers: list[GuardManager],
tensor_names: list[str],
verbose_code_parts: list[str],
): ...
class TensorGuards:
def __init__(
self,
*,
dynamic_dims_sizes: Optional[List[Optional[torch.SymInt]]] = None,
dynamic_dims_strides: Optional[List[Optional[torch.SymInt]]] = None,
): ...
dynamic_dims_sizes: list[torch.SymInt | None] | None = None,
dynamic_dims_strides: list[torch.SymInt | None] | None = None,
) -> None: ...
def check(self, *args) -> bool: ...
def check_verbose(self, *args, tensor_check_names=None) -> Union[bool, str]: ...
def check_verbose(self, *args, tensor_check_names=None) -> bool | str: ...
def assert_size_stride(
item: torch.Tensor, size: torch.types._size, stride: torch.types._size
item: torch.Tensor,
size: torch.types._size,
stride: torch.types._size,
): ...
def check_obj_id(obj: object, expected: int) -> bool: ...
def check_type_id(obj: object, expected: int) -> bool: ...
def dict_version(d: Dict[Any, Any]) -> int: ...
def dict_version(d: dict[Any, Any]) -> int: ...

View File

@ -1,11 +1,11 @@
from typing import AnyStr, List
from typing import AnyStr
from torch import Tensor
class UndefinedGrad:
def __init__(self) -> None: ...
def __call__(self, *inputs: Tensor) -> List[Tensor]: ...
def __call__(self, *inputs: Tensor) -> list[Tensor]: ...
class DelayedError:
def __init__(self, msg: AnyStr, num_inputs: int) -> None: ...
def __call__(self, inputs: List[Tensor]) -> List[Tensor]: ...
def __call__(self, inputs: list[Tensor]) -> list[Tensor]: ...

View File

@ -1,6 +1,5 @@
# mypy: allow-untyped-defs
from enum import Enum
from typing import Optional, Tuple
from torch import Tensor
@ -15,11 +14,11 @@ def is_gradtrackingtensor(tensor: Tensor) -> bool: ...
def is_legacy_batchedtensor(tensor: Tensor) -> bool: ...
def maybe_get_bdim(tensor: Tensor) -> int: ...
def maybe_get_level(tensor: Tensor) -> int: ...
def maybe_current_level() -> Optional[int]: ...
def maybe_current_level() -> int | None: ...
def unwrap_if_dead(tensor: Tensor) -> Tensor: ...
def _unwrap_for_grad(tensor: Tensor, level: int) -> Tensor: ...
def _wrap_for_grad(tensor: Tensor, level: int) -> Tensor: ...
def _unwrap_batched(tensor: Tensor, level: int) -> Tuple[Tensor, Optional[int]]: ...
def _unwrap_batched(tensor: Tensor, level: int) -> tuple[Tensor, int | None]: ...
def current_level() -> int: ...
def count_jvp_interpreters() -> int: ...
def _add_batch_dim(tensor: Tensor, bdim: int, level: int) -> Tensor: ...
@ -52,23 +51,23 @@ class CInterpreter:
def level(self) -> int: ...
class CGradInterpreterPtr:
def __init__(self, interpreter: CInterpreter): ...
def __init__(self, interpreter: CInterpreter) -> None: ...
def lift(self, Tensor) -> Tensor: ...
def prevGradMode(self) -> bool: ...
class CJvpInterpreterPtr:
def __init__(self, interpreter: CInterpreter): ...
def __init__(self, interpreter: CInterpreter) -> None: ...
def lift(self, Tensor) -> Tensor: ...
def prevFwdGradMode(self) -> bool: ...
class CFunctionalizeInterpreterPtr:
def __init__(self, interpreter: CInterpreter): ...
def __init__(self, interpreter: CInterpreter) -> None: ...
def key(self) -> TransformType: ...
def level(self) -> int: ...
def functionalizeAddBackViews(self) -> bool: ...
class CVmapInterpreterPtr:
def __init__(self, interpreter: CInterpreter): ...
def __init__(self, interpreter: CInterpreter) -> None: ...
def key(self) -> TransformType: ...
def level(self) -> int: ...
def batchSize(self) -> int: ...

View File

@ -1,26 +1,24 @@
# mypy: allow-untyped-defs
from typing import List
from torch import Tensor
# defined in torch/csrc/lazy/python/init.cpp
def _mark_step(device: str, devices: List[str], wait: bool): ...
def _wait_device_ops(devices: List[str]): ...
def _mark_step(device: str, devices: list[str], wait: bool): ...
def _wait_device_ops(devices: list[str]): ...
def _reset_metrics(): ...
def _counter_names() -> List[str]: ...
def _counter_names() -> list[str]: ...
def _counter_value(name: str) -> int: ...
def _metrics_report() -> str: ...
def _get_graph_hash(tensors: List[Tensor]) -> str: ...
def _get_graph_hash(tensors: list[Tensor]) -> str: ...
def _sync_multi(
tensors: List[Tensor],
devices: List[str],
tensors: list[Tensor],
devices: list[str],
wait: bool = True,
sync_ltc_data: bool = True,
): ...
def _get_tensor_id(tensor: Tensor) -> int: ...
def _get_tensors_text(tensors: List[Tensor]) -> str: ...
def _get_tensors_dot(tensors: List[Tensor]) -> str: ...
def _get_tensors_backend(tensors: List[Tensor]) -> str: ...
def _get_tensors_text(tensors: list[Tensor]) -> str: ...
def _get_tensors_dot(tensors: list[Tensor]) -> str: ...
def _get_tensors_backend(tensors: list[Tensor]) -> str: ...
def _get_force_fallback() -> str: ...
def _set_force_fallback(newval: str): ...
def _clear_ir_cache(): ...

View File

@ -1,12 +1,12 @@
# mypy: allow-untyped-defs
# defined in torch/csrc/lazy/python/init.cpp
from typing import Any, List, Tuple
from typing import Any
from torch import Tensor
def _init(): ...
def _get_tensors_ts_device_data_node(
tensors: List[Tensor],
) -> Tuple[List[int], List[Any]]: ...
def _run_cached_graph(hash_str: str, graph_inputs: List[Any]) -> List[Tensor]: ...
tensors: list[Tensor],
) -> tuple[list[int], list[Any]]: ...
def _run_cached_graph(hash_str: str, graph_inputs: list[Any]) -> list[Tensor]: ...

View File

@ -2,7 +2,7 @@
import datetime
from enum import Enum
from typing import Callable, Dict, List, Union
from typing import Callable
class Aggregation(Enum):
VALUE = ...
@ -18,22 +18,22 @@ class Stat:
def __init__(
self,
name: str,
aggregations: List[Aggregation],
aggregations: list[Aggregation],
window_size: int,
max_samples: int = -1,
) -> None: ...
def add(self, v: float) -> None: ...
def get(self) -> Dict[Aggregation, float]: ...
def get(self) -> dict[Aggregation, float]: ...
class Event:
name: str
timestamp: datetime.datetime
data: Dict[str, Union[int, float, bool, str]]
data: dict[str, int | float | bool | str]
def __init__(
self,
name: str,
timestamp: datetime.datetime,
data: Dict[str, Union[int, float, bool, str]],
data: dict[str, int | float | bool | str],
) -> None: ...
def log_event(e: Event) -> None: ...

View File

@ -1,5 +1,5 @@
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from typing import Any, Literal
from typing_extensions import TypeAlias
from torch._C import device, dtype, layout
@ -55,10 +55,10 @@ class _EventType(Enum):
class _ExperimentalConfig:
def __init__(
self,
profiler_metrics: List[str] = ...,
profiler_metrics: list[str] = ...,
profiler_measure_per_kernel: bool = ...,
verbose: bool = ...,
performance_events: List[str] = ...,
performance_events: list[str] = ...,
enable_cuda_sync_events: bool = ...,
) -> None: ...
@ -77,31 +77,23 @@ class ProfilerConfig:
class _ProfilerEvent:
start_tid: int
start_time_ns: int
children: List[_ProfilerEvent]
children: list[_ProfilerEvent]
# TODO(robieta): remove in favor of `self.typed`
extra_fields: Union[
_ExtraFields_TorchOp,
_ExtraFields_Backend,
_ExtraFields_Allocation,
_ExtraFields_OutOfMemory,
_ExtraFields_PyCall,
_ExtraFields_PyCCall,
_ExtraFields_Kineto,
]
extra_fields: _ExtraFields_TorchOp | _ExtraFields_Backend | _ExtraFields_Allocation | _ExtraFields_OutOfMemory | _ExtraFields_PyCall | _ExtraFields_PyCCall | _ExtraFields_Kineto
@property
def typed(
self,
) -> Union[
Tuple[Literal[_EventType.TorchOp], _ExtraFields_TorchOp],
Tuple[Literal[_EventType.Backend], _ExtraFields_Backend],
Tuple[Literal[_EventType.Allocation], _ExtraFields_Allocation],
Tuple[Literal[_EventType.OutOfMemory], _ExtraFields_OutOfMemory],
Tuple[Literal[_EventType.PyCall], _ExtraFields_PyCall],
Tuple[Literal[_EventType.PyCCall], _ExtraFields_PyCCall],
Tuple[Literal[_EventType.Kineto], _ExtraFields_Kineto],
]: ...
) -> (
tuple[Literal[_EventType.TorchOp], _ExtraFields_TorchOp]
| tuple[Literal[_EventType.Backend], _ExtraFields_Backend]
| tuple[Literal[_EventType.Allocation], _ExtraFields_Allocation]
| tuple[Literal[_EventType.OutOfMemory], _ExtraFields_OutOfMemory]
| tuple[Literal[_EventType.PyCall], _ExtraFields_PyCall]
| tuple[Literal[_EventType.PyCCall], _ExtraFields_PyCCall]
| tuple[Literal[_EventType.Kineto], _ExtraFields_Kineto]
): ...
@property
def name(self) -> str: ...
@property
@ -109,7 +101,7 @@ class _ProfilerEvent:
@property
def id(self) -> int: ...
@property
def parent(self) -> Optional[_ProfilerEvent]: ...
def parent(self) -> _ProfilerEvent | None: ...
@property
def correlation_id(self) -> int: ...
@property
@ -118,12 +110,12 @@ class _ProfilerEvent:
def duration_time_ns(self) -> int: ...
class _TensorMetadata:
impl_ptr: Optional[int]
storage_data_ptr: Optional[int]
id: Optional[int]
impl_ptr: int | None
storage_data_ptr: int | None
id: int | None
@property
def allocation_id(self) -> Optional[int]: ...
def allocation_id(self) -> int | None: ...
@property
def layout(self) -> layout: ...
@property
@ -131,12 +123,12 @@ class _TensorMetadata:
@property
def dtype(self) -> dtype: ...
@property
def sizes(self) -> List[int]: ...
def sizes(self) -> list[int]: ...
@property
def strides(self) -> List[int]: ...
def strides(self) -> list[int]: ...
Scalar: TypeAlias = Union[int, float, bool, complex]
Input: TypeAlias = Optional[Union[_TensorMetadata, List[_TensorMetadata], Scalar]]
Scalar: TypeAlias = int | float | bool | complex
Input: TypeAlias = _TensorMetadata | list[_TensorMetadata] | Scalar | None
class _ExtraFields_TorchOp:
name: str
@ -144,7 +136,7 @@ class _ExtraFields_TorchOp:
allow_tf32_cublas: bool
@property
def inputs(self) -> List[Input]: ...
def inputs(self) -> list[Input]: ...
@property
def scope(self) -> RecordScope: ...
@ -152,13 +144,13 @@ class _ExtraFields_Backend: ...
class _ExtraFields_Allocation:
ptr: int
id: Optional[int]
id: int | None
alloc_size: int
total_allocated: int
total_reserved: int
@property
def allocation_id(self) -> Optional[int]: ...
def allocation_id(self) -> int | None: ...
@property
def device(self) -> device: ...
@ -181,22 +173,22 @@ class _NNModuleInfo:
@property
def parameters(
self,
) -> List[Tuple[str, _TensorMetadata, Optional[_TensorMetadata]]]: ...
) -> list[tuple[str, _TensorMetadata, _TensorMetadata | None]]: ...
class _OptimizerInfo:
@property
def parameters(
self,
) -> List[
Tuple[
) -> list[
tuple[
# Parameter
_TensorMetadata,
#
# Gradient (if present during optimizer.step())
Optional[_TensorMetadata],
_TensorMetadata | None,
#
# Optimizer state for Parameter as (name, tensor) pairs
List[Tuple[str, _TensorMetadata]],
list[tuple[str, _TensorMetadata]],
]
]: ...
@ -210,9 +202,9 @@ class _ExtraFields_PyCall:
@property
def caller(self) -> _PyFrameState: ...
@property
def module(self) -> Optional[_NNModuleInfo]: ...
def module(self) -> _NNModuleInfo | None: ...
@property
def optimizer(self) -> Optional[_OptimizerInfo]: ...
def optimizer(self) -> _OptimizerInfo | None: ...
class _ExtraFields_Kineto: ...
@ -230,15 +222,15 @@ def gather_traceback(python: bool, script: bool, cpp: bool) -> CapturedTraceback
# The Dict has name, filename, line
def symbolize_tracebacks(
to_symbolize: List[CapturedTraceback],
) -> List[List[Dict[str, str]]]: ...
to_symbolize: list[CapturedTraceback],
) -> list[list[dict[str, str]]]: ...
class _RecordFunctionFast:
def __init__(
self,
name: str,
input_values: Optional[Union[list, tuple]] = None,
keyword_values: Optional[dict] = None,
input_values: list | tuple | None = None,
keyword_values: dict | None = None,
) -> None: ...
def __enter__(self) -> None: ...
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: ...
def __exit__(self, *args: Any) -> None: ...

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import enum
from typing import Any, Callable, Dict, List, Optional, overload, Set, Type
from typing import Any, Callable, overload
import torch
from torch.distributed.algorithms.join import Joinable, JoinHook
@ -13,10 +13,10 @@ class _ZeROJoinHook(JoinHook):
class _DDPBucketAssignment:
bucket_index: int
parameters: List[torch.Tensor]
parameters: list[torch.Tensor]
offset: int
device: torch.device
tensor: Optional[torch.Tensor]
tensor: torch.Tensor | None
class _OverlapStatus(enum.IntEnum):
UNINITIALIZED: int = ...
@ -32,7 +32,7 @@ class _OverlapInfo:
bucket_index_to_future: Any = ...
bucket_index_to_bucket: Any = ...
bucket_indices_seen: Any = ...
assigned_ranks_per_bucket: List[Set[int]] = ...
assigned_ranks_per_bucket: list[set[int]] = ...
total_size: int = ...
shard_buckets: bool = ...
def __init__(self) -> None: ...
@ -48,34 +48,34 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
global_rank: int = ...
parameters_as_bucket_view: bool = ...
optim: Any = ...
_device_to_device_index: Dict[torch.device, int] = ...
_device_to_device_index: dict[torch.device, int] = ...
_overlap_with_ddp: bool = ...
_overlap_info: _OverlapInfo = ...
_buckets: List[List[torch.Tensor]] = ...
_bucket_assignments_per_rank: List[Dict[int, _DDPBucketAssignment]] = ...
_buckets: list[list[torch.Tensor]] = ...
_bucket_assignments_per_rank: list[dict[int, _DDPBucketAssignment]] = ...
def __init__(
self,
params: Any,
optimizer_class: Type[Optimizer],
process_group: Optional[Any] = ...,
optimizer_class: type[Optimizer],
process_group: Any | None = ...,
parameters_as_bucket_view: bool = ...,
overlap_with_ddp: bool = ...,
**defaults: Any,
) -> None: ...
def add_param_group(self, param_group: Dict[str, Any]) -> None: ...
def add_param_group(self, param_group: dict[str, Any]) -> None: ...
def consolidate_state_dict(self, to: int = ...) -> None: ...
@overload
def step(self, closure: None = ..., **kwargs: Any) -> None: ...
@overload
def step(self, closure: Callable[[], float], **kwargs: Any) -> float: ...
def load_state_dict(self, state_dict: Dict[str, Any]) -> None: ...
def state_dict(self) -> Dict[str, Any]: ...
def load_state_dict(self, state_dict: dict[str, Any]) -> None: ...
def state_dict(self) -> dict[str, Any]: ...
def _local_step(
self,
gradients: Optional[List[Optional[torch.Tensor]]] = None,
closure: Optional[Callable[[], float]] = None,
gradients: list[torch.Tensor | None] | None = None,
closure: Callable[[], float] | None = None,
**kwargs: Any,
) -> Optional[float]: ...
) -> float | None: ...
def _get_assigned_rank(self, bucket_index: int) -> int: ...
def _init_zero_for_overlap(self) -> None: ...
def join_hook(self, **kwargs): ...

View File

@ -1,11 +1,15 @@
from ._symbolic_trace import (
from torch.fx._symbolic_trace import (
symbolic_trace as symbolic_trace,
Tracer as Tracer,
wrap as wrap,
)
from .graph import Graph as Graph
from .graph_module import GraphModule as GraphModule
from .interpreter import Interpreter as Interpreter, Transformer as Transformer
from .node import has_side_effect as has_side_effect, map_arg as map_arg, Node as Node
from .proxy import Proxy as Proxy
from .subgraph_rewriter import replace_pattern as replace_pattern
from torch.fx.graph import Graph as Graph
from torch.fx.graph_module import GraphModule as GraphModule
from torch.fx.interpreter import Interpreter as Interpreter, Transformer as Transformer
from torch.fx.node import (
has_side_effect as has_side_effect,
map_arg as map_arg,
Node as Node,
)
from torch.fx.proxy import Proxy as Proxy
from torch.fx.subgraph_rewriter import replace_pattern as replace_pattern

View File

@ -1,18 +1,6 @@
# mypy: allow-untyped-defs
# mypy: disable-error-code="type-arg"
from typing import (
Any,
Callable,
Dict,
List,
NamedTuple,
Optional,
overload,
Tuple,
Type,
TypeVar,
Union,
)
from typing import Any, Callable, NamedTuple, overload, TypeVar
from typing_extensions import Never, TypeAlias
from _typeshed import Incomplete
@ -36,6 +24,7 @@ from torch.jit._recursive import (
ScriptMethodStub as ScriptMethodStub,
wrap_cpp_module as wrap_cpp_module,
)
from torch.jit._serialization import validate_map_location as validate_map_location
from torch.jit._state import (
_enabled as _enabled,
_set_jit_function_cache as _set_jit_function_cache,
@ -60,8 +49,6 @@ from torch.package import (
)
from torch.utils import set_module as set_module
from ._serialization import validate_map_location as validate_map_location
ScriptFunction = torch._C.ScriptFunction
type_trace_db: JitTypeTraceStore
@ -116,7 +103,8 @@ class ConstMap:
def __getattr__(self, attr): ...
def unpackage_script_module(
importer: PackageImporter, script_module_id: str
importer: PackageImporter,
script_module_id: str,
) -> torch.nn.Module: ...
_magic_methods: Incomplete
@ -126,7 +114,7 @@ class RecursiveScriptClass:
_props: Incomplete
def __init__(self, cpp_class) -> None: ...
def __getattr__(self, attr): ...
def __setattr__(self, attr, value): ...
def __setattr__(self, attr, value) -> None: ...
def forward_magic_method(self, method_name, *args, **kwargs): ...
def __getstate__(self) -> None: ...
def __iadd__(self, other): ...
@ -138,7 +126,7 @@ class ScriptModule(Module, metaclass=ScriptMeta):
def __init__(self) -> None: ...
forward: Callable[..., Any]
def __getattr__(self, attr): ...
def __setattr__(self, attr, value): ...
def __setattr__(self, attr, value) -> None: ...
def define(self, src): ...
def _replicate_for_data_parallel(self): ...
def __reduce_package__(self, exporter: PackageExporter): ...
@ -146,7 +134,7 @@ class ScriptModule(Module, metaclass=ScriptMeta):
@property
def code(self) -> str: ...
@property
def code_with_constants(self) -> Tuple[str, ConstMap]: ...
def code_with_constants(self) -> tuple[str, ConstMap]: ...
@property
def graph(self) -> torch.Graph: ...
@property
@ -177,7 +165,7 @@ class RecursiveScriptModule(ScriptModule):
def graph_for(self, *args, **kwargs): ...
def define(self, src) -> None: ...
def __getattr__(self, attr): ...
def __setattr__(self, attr, value): ...
def __setattr__(self, attr, value) -> None: ...
def __copy__(self): ...
def __deepcopy__(self, memo): ...
def forward_magic_method(self, method_name, *args, **kwargs): ...
@ -200,59 +188,59 @@ def create_script_dict(obj): ...
def create_script_list(obj, type_hint: Incomplete | None = ...): ...
@overload
def script(
obj: Type[Module],
optimize: Optional[bool] = None,
obj: type[Module],
optimize: bool | None = None,
_frames_up: int = 0,
_rcb: Optional[ResolutionCallback] = None,
example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None,
_rcb: ResolutionCallback | None = None,
example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
) -> Never: ...
@overload
def script( # type: ignore[misc]
obj: Dict,
optimize: Optional[bool] = None,
obj: dict,
optimize: bool | None = None,
_frames_up: int = 0,
_rcb: Optional[ResolutionCallback] = None,
example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None,
_rcb: ResolutionCallback | None = None,
example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
) -> torch.ScriptDict: ...
@overload
def script( # type: ignore[misc]
obj: List,
optimize: Optional[bool] = None,
obj: list,
optimize: bool | None = None,
_frames_up: int = 0,
_rcb: Optional[ResolutionCallback] = None,
example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None,
_rcb: ResolutionCallback | None = None,
example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
) -> torch.ScriptList: ...
@overload
def script( # type: ignore[misc]
obj: Module,
optimize: Optional[bool] = None,
optimize: bool | None = None,
_frames_up: int = 0,
_rcb: Optional[ResolutionCallback] = None,
example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None,
_rcb: ResolutionCallback | None = None,
example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
) -> RecursiveScriptModule: ...
@overload
def script( # type: ignore[misc]
obj: _ClassVar,
optimize: Optional[bool] = None,
optimize: bool | None = None,
_frames_up: int = 0,
_rcb: Optional[ResolutionCallback] = None,
example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None,
_rcb: ResolutionCallback | None = None,
example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
) -> _ClassVar: ...
@overload
def script( # type: ignore[misc]
obj: Callable,
optimize: Optional[bool] = None,
optimize: bool | None = None,
_frames_up: int = 0,
_rcb: Optional[ResolutionCallback] = None,
example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None,
_rcb: ResolutionCallback | None = None,
example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
) -> ScriptFunction: ...
@overload
def script(
obj: Any,
optimize: Optional[bool] = None,
optimize: bool | None = None,
_frames_up: int = 0,
_rcb: Optional[ResolutionCallback] = None,
example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None,
_rcb: ResolutionCallback | None = None,
example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
) -> RecursiveScriptClass: ...
@overload
def script(
@ -260,7 +248,7 @@ def script(
optimize: Incomplete | None = ...,
_frames_up: int = ...,
_rcb: Incomplete | None = ...,
example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = ...,
example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = ...,
): ...
def _check_overload_defaults(impl_defaults, overload_defaults, loc) -> None: ...
def _compile_function_with_overload(overload_fn, qual_name, impl_fn): ...
@ -279,7 +267,10 @@ class _ScriptProfileColumn:
offset: Incomplete
rows: Incomplete
def __init__(
self, header: str, alignment: int = ..., offset: int = ...
self,
header: str,
alignment: int = ...,
offset: int = ...,
) -> None: ...
def add_row(self, lineno: int, value: Any): ...
def materialize(self): ...
@ -288,7 +279,9 @@ class _ScriptProfileTable:
cols: Incomplete
source_range: Incomplete
def __init__(
self, cols: List[_ScriptProfileColumn], source_range: List[int]
self,
cols: list[_ScriptProfileColumn],
source_range: list[int],
) -> None: ...
def dump_string(self): ...

View File

@ -1,41 +1,29 @@
# mypy: allow-untyped-defs
import builtins
from typing import Optional, Tuple
from typing_extensions import TypeGuard
import torch
from torch import Tensor
from torch import device, dtype, Tensor
class Parameter(Tensor):
def __init__(
self,
data: Tensor = ...,
requires_grad: builtins.bool = ...,
): ...
def __init__(self, data: Tensor = ..., requires_grad: bool = ...) -> None: ...
def is_lazy(param: Tensor): ...
def is_lazy(
param: Tensor,
) -> TypeGuard[UninitializedParameter | UninitializedBuffer]: ...
class UninitializedParameter(Tensor):
def __init__(
self,
data: Tensor = ...,
requires_grad: builtins.bool = ...,
): ...
def __init__(self, data: Tensor = ..., requires_grad: bool = ...) -> None: ...
def materialize(
self,
shape: Tuple[int, ...],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
): ...
shape: tuple[int, ...],
device: device | None = None,
dtype: dtype | None = None,
) -> None: ...
class UninitializedBuffer(Tensor):
def __init__(
self,
data: Tensor = ...,
requires_grad: builtins.bool = ...,
): ...
def __init__(self, data: Tensor = ..., requires_grad: bool = ...) -> None: ...
def materialize(
self,
shape: Tuple[int, ...],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
): ...
shape: tuple[int, ...],
device: device | None = None,
dtype: dtype | None = None,
) -> None: ...

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Any, Iterable, NamedTuple, Optional, overload, Sequence, Tuple, Union
from typing import Any, Iterable, NamedTuple, overload, Sequence
from typing_extensions import Self
from torch import Tensor
@ -9,8 +9,8 @@ from torch.types import _dtype
class PackedSequence_(NamedTuple):
data: Tensor
batch_sizes: Tensor
sorted_indices: Optional[Tensor]
unsorted_indices: Optional[Tensor]
sorted_indices: Tensor | None
unsorted_indices: Tensor | None
def bind(optional: Any, fn: Any): ...
@ -18,9 +18,9 @@ class PackedSequence(PackedSequence_):
def __new__(
cls,
data: Tensor,
batch_sizes: Optional[Tensor] = ...,
sorted_indices: Optional[Tensor] = ...,
unsorted_indices: Optional[Tensor] = ...,
batch_sizes: Tensor | None = ...,
sorted_indices: Tensor | None = ...,
unsorted_indices: Tensor | None = ...,
) -> Self: ...
def pin_memory(self: Self) -> Self: ...
def cuda(self: Self, *args: Any, **kwargs: Any) -> Self: ...
@ -43,8 +43,8 @@ class PackedSequence(PackedSequence_):
@overload
def to(
self: Self,
device: Optional[DeviceLikeType] = None,
dtype: Optional[_dtype] = None,
device: DeviceLikeType | None = None,
dtype: _dtype | None = None,
non_blocking: bool = False,
copy: bool = False,
) -> Self: ...
@ -59,7 +59,7 @@ class PackedSequence(PackedSequence_):
def is_cuda(self) -> bool: ...
def is_pinned(self) -> bool: ...
def invert_permutation(permutation: Optional[Tensor]): ...
def invert_permutation(permutation: Tensor | None): ...
def pack_padded_sequence(
input: Tensor,
lengths: Tensor,
@ -70,10 +70,10 @@ def pad_packed_sequence(
sequence: PackedSequence,
batch_first: bool = ...,
padding_value: float = ...,
total_length: Optional[int] = ...,
) -> Tuple[Tensor, ...]: ...
total_length: int | None = ...,
) -> tuple[Tensor, ...]: ...
def pad_sequence(
sequences: Union[Tensor, Iterable[Tensor]],
sequences: Tensor | Iterable[Tensor],
batch_first: bool = False,
padding_value: float = ...,
) -> Tensor: ...
@ -83,7 +83,7 @@ def pack_sequence(
) -> PackedSequence: ...
def get_packed_sequence(
data: Tensor,
batch_sizes: Optional[Tensor],
sorted_indices: Optional[Tensor],
unsorted_indices: Optional[Tensor],
batch_sizes: Tensor | None,
sorted_indices: Tensor | None,
unsorted_indices: Tensor | None,
) -> PackedSequence: ...

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Any, Dict, Optional, TYPE_CHECKING, Union
from typing import Any, TYPE_CHECKING
"""
This was semi-automatically generated by running
@ -24,13 +24,11 @@ Note that the import should happen before the call to install_config_module(), o
assert TYPE_CHECKING, "Do not use at runtime"
def save_config() -> bytes: ...
def save_config_portable() -> Dict[str, Any]: ...
def save_config_portable() -> dict[str, Any]: ...
def codegen_config() -> str: ...
def get_hash() -> bytes: ...
def to_dict() -> Dict[str, Any]: ...
def shallow_copy_dict() -> Dict[str, Any]: ...
def load_config(config: Union[bytes, Dict[str, Any]]) -> None: ...
def get_config_copy() -> Dict[str, Any]: ...
def patch(
arg1: Optional[Union[str, Dict[str, Any]]] = None, arg2: Any = None, **kwargs
): ...
def to_dict() -> dict[str, Any]: ...
def shallow_copy_dict() -> dict[str, Any]: ...
def load_config(config: bytes | dict[str, Any]) -> None: ...
def get_config_copy() -> dict[str, Any]: ...
def patch(arg1: str | dict[str, Any] | None = None, arg2: Any = None, **kwargs): ...