mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Use Generic TypeAlias (PEP 585) and Union Type (PEP 604) in .pyi
stub files (#129419)
------ - [Generic TypeAlias (PEP 585)](https://peps.python.org/pep-0585): e.g. `typing.List[T] -> list[T]`, `typing.Dict[KT, VT] -> dict[KT, VT]`, `typing.Type[T] -> type[T]`. - [Union Type (PEP 604)](https://peps.python.org/pep-0604): e.g. `Union[X, Y] -> X | Y`, `Optional[X] -> X | None`, `Optional[Union[X, Y]] -> X | Y | None`. Note that in `.pyi` stub files, we do not need `from __future__ import annotations`. So this PR does not violate issue #117449: - #117449 Pull Request resolved: https://github.com/pytorch/pytorch/pull/129419 Approved by: https://github.com/ezyang ghstack dependencies: #129375, #129376
This commit is contained in:
committed by
PyTorch MergeBot
parent
9120992c72
commit
56935684c3
@ -1,18 +1,17 @@
|
||||
from ctypes import c_void_p
|
||||
from typing import List
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
# Defined in torch/csrc/inductor/aoti_runner/pybind.cpp
|
||||
|
||||
# Tensor to AtenTensorHandle
|
||||
def unsafe_alloc_void_ptrs_from_tensors(tensors: List[Tensor]) -> List[c_void_p]: ...
|
||||
def unsafe_alloc_void_ptrs_from_tensors(tensors: list[Tensor]) -> list[c_void_p]: ...
|
||||
def unsafe_alloc_void_ptr_from_tensor(tensor: Tensor) -> c_void_p: ...
|
||||
|
||||
# AtenTensorHandle to Tensor
|
||||
def alloc_tensors_by_stealing_from_void_ptrs(
|
||||
handles: List[c_void_p],
|
||||
) -> List[Tensor]: ...
|
||||
handles: list[c_void_p],
|
||||
) -> list[Tensor]: ...
|
||||
def alloc_tensor_by_stealing_from_void_ptr(
|
||||
handle: c_void_p,
|
||||
) -> Tensor: ...
|
||||
|
@ -1,10 +1,9 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, List, Optional, Set
|
||||
from typing import Any, Callable
|
||||
|
||||
import torch
|
||||
|
||||
from ._profiler import (
|
||||
from torch._C._profiler import (
|
||||
_ProfilerEvent,
|
||||
ActiveProfilerType,
|
||||
ProfilerActivity,
|
||||
@ -47,7 +46,7 @@ class ProfilerEvent:
|
||||
def name(self) -> str: ...
|
||||
def node_id(self) -> int: ...
|
||||
def sequence_nr(self) -> int: ...
|
||||
def shapes(self) -> List[List[int]]: ...
|
||||
def shapes(self) -> list[list[int]]: ...
|
||||
def thread_id(self) -> int: ...
|
||||
def flops(self) -> float: ...
|
||||
def is_async(self) -> bool: ...
|
||||
@ -61,15 +60,15 @@ class _KinetoEvent:
|
||||
def duration_ns(self) -> int: ...
|
||||
def is_async(self) -> bool: ...
|
||||
def linked_correlation_id(self) -> int: ...
|
||||
def shapes(self) -> List[List[int]]: ...
|
||||
def dtypes(self) -> List[str]: ...
|
||||
def concrete_inputs(self) -> List[Any]: ...
|
||||
def shapes(self) -> list[list[int]]: ...
|
||||
def dtypes(self) -> list[str]: ...
|
||||
def concrete_inputs(self) -> list[Any]: ...
|
||||
def device_type(self) -> DeviceType: ...
|
||||
def start_thread_id(self) -> int: ...
|
||||
def end_thread_id(self) -> int: ...
|
||||
def correlation_id(self) -> int: ...
|
||||
def fwd_thread_id(self) -> int: ...
|
||||
def stack(self) -> List[str]: ...
|
||||
def stack(self) -> list[str]: ...
|
||||
def scope(self) -> int: ...
|
||||
def sequence_nr(self) -> int: ...
|
||||
def flops(self) -> int: ...
|
||||
@ -77,21 +76,21 @@ class _KinetoEvent:
|
||||
def privateuse1_elapsed_us(self) -> int: ...
|
||||
|
||||
class _ProfilerResult:
|
||||
def events(self) -> List[_KinetoEvent]: ...
|
||||
def legacy_events(self) -> List[List[ProfilerEvent]]: ...
|
||||
def events(self) -> list[_KinetoEvent]: ...
|
||||
def legacy_events(self) -> list[list[ProfilerEvent]]: ...
|
||||
def save(self, path: str) -> None: ...
|
||||
def experimental_event_tree(self) -> List[_ProfilerEvent]: ...
|
||||
def experimental_event_tree(self) -> list[_ProfilerEvent]: ...
|
||||
def trace_start_ns(self) -> int: ...
|
||||
|
||||
class SavedTensor: ...
|
||||
|
||||
def _enable_profiler(
|
||||
config: ProfilerConfig,
|
||||
activities: Set[ProfilerActivity],
|
||||
activities: set[ProfilerActivity],
|
||||
) -> None: ...
|
||||
def _prepare_profiler(
|
||||
config: ProfilerConfig,
|
||||
activities: Set[ProfilerActivity],
|
||||
activities: set[ProfilerActivity],
|
||||
) -> None: ...
|
||||
def _disable_profiler() -> _ProfilerResult: ...
|
||||
def _profiler_enabled() -> bool: ...
|
||||
@ -101,7 +100,7 @@ def _get_sequence_nr() -> int: ...
|
||||
def kineto_available() -> bool: ...
|
||||
def _record_function_with_args_enter(name: str, *args) -> torch.Tensor: ...
|
||||
def _record_function_with_args_exit(handle: torch.Tensor) -> None: ...
|
||||
def _supported_activities() -> Set[ProfilerActivity]: ...
|
||||
def _supported_activities() -> set[ProfilerActivity]: ...
|
||||
def _enable_record_function(enable: bool) -> None: ...
|
||||
def _set_empty_test_observer(is_global: bool, sampling_prob: float) -> None: ...
|
||||
def _push_saved_tensors_default_hooks(
|
||||
@ -111,11 +110,11 @@ def _push_saved_tensors_default_hooks(
|
||||
def _pop_saved_tensors_default_hooks() -> None: ...
|
||||
def _unsafe_set_version_counter(t: torch.Tensor, prev_version: int) -> None: ...
|
||||
def _enable_profiler_legacy(config: ProfilerConfig) -> None: ...
|
||||
def _disable_profiler_legacy() -> List[List[ProfilerEvent]]: ...
|
||||
def _disable_profiler_legacy() -> list[list[ProfilerEvent]]: ...
|
||||
def _profiler_type() -> ActiveProfilerType: ...
|
||||
def _saved_tensors_hooks_enable() -> None: ...
|
||||
def _saved_tensors_hooks_disable(message: str) -> None: ...
|
||||
def _saved_tensors_hooks_get_disabled_error_message() -> Optional[str]: ...
|
||||
def _saved_tensors_hooks_get_disabled_error_message() -> str | None: ...
|
||||
def _saved_tensors_hooks_set_tracing(is_tracing: bool) -> bool: ...
|
||||
|
||||
class CreationMeta(Enum):
|
||||
|
@ -1,5 +1,5 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from typing import Any, Dict, List, Set
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
@ -7,9 +7,9 @@ import torch
|
||||
|
||||
class DistAutogradContext:
|
||||
def _context_id(self) -> int: ...
|
||||
def _recv_functions(self) -> Dict[int, Any]: ...
|
||||
def _send_functions(self) -> Dict[int, Any]: ...
|
||||
def _known_worker_ids(self) -> Set[int]: ...
|
||||
def _recv_functions(self) -> dict[int, Any]: ...
|
||||
def _send_functions(self) -> dict[int, Any]: ...
|
||||
def _known_worker_ids(self) -> set[int]: ...
|
||||
|
||||
def _new_context() -> DistAutogradContext: ...
|
||||
def _release_context(context_id: int) -> None: ...
|
||||
@ -18,10 +18,10 @@ def _is_valid_context(worker_id: int) -> bool: ...
|
||||
def _retrieve_context(context_id: int) -> DistAutogradContext: ...
|
||||
def _current_context() -> DistAutogradContext: ...
|
||||
def _init(worker_id: int) -> None: ...
|
||||
def _get_debug_info() -> Dict[str, str]: ...
|
||||
def _get_debug_info() -> dict[str, str]: ...
|
||||
def backward(
|
||||
context_id: int,
|
||||
roots: List[torch.Tensor],
|
||||
roots: list[torch.Tensor],
|
||||
retain_graph=False,
|
||||
) -> None: ...
|
||||
def get_gradients(context_id: int) -> Dict[torch.Tensor, torch.Tensor]: ...
|
||||
def get_gradients(context_id: int) -> dict[torch.Tensor, torch.Tensor]: ...
|
||||
|
@ -2,7 +2,7 @@
|
||||
# mypy: disable-error-code="type-arg"
|
||||
from datetime import timedelta
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, overload, Tuple, Union
|
||||
from typing import Any, overload
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
@ -26,36 +26,36 @@ def _register_builtin_comm_hook(
|
||||
comm_hook_type: BuiltinCommHookType,
|
||||
): ...
|
||||
def _set_global_rank(rank: int) -> None: ...
|
||||
def _hash_tensors(tensors: List[Tensor]) -> int: ...
|
||||
def _hash_tensors(tensors: list[Tensor]) -> int: ...
|
||||
|
||||
class GradBucket:
|
||||
def index(self) -> int: ...
|
||||
def buffer(self) -> Tensor: ...
|
||||
def gradients(self) -> List[Tensor]: ...
|
||||
def gradients(self) -> list[Tensor]: ...
|
||||
def is_last(self) -> bool: ...
|
||||
def set_buffer(self, tensor: Tensor) -> None: ...
|
||||
def parameters(self) -> List[Tensor]: ...
|
||||
def parameters(self) -> list[Tensor]: ...
|
||||
|
||||
class Reducer:
|
||||
def __init__(
|
||||
self,
|
||||
params: List[Tensor],
|
||||
bucket_indices: List[List[int]],
|
||||
per_bucket_size_limits: List[int],
|
||||
params: list[Tensor],
|
||||
bucket_indices: list[list[int]],
|
||||
per_bucket_size_limits: list[int],
|
||||
process_group: ProcessGroup,
|
||||
expect_sparse_gradients: List[bool] = ...,
|
||||
expect_sparse_gradients: list[bool] = ...,
|
||||
bucket_bytes_cap: int = ..., # kDefaultBucketBytesCap in reducer.hpp
|
||||
find_unused_parameters: bool = ...,
|
||||
gradient_as_bucket_view: bool = ...,
|
||||
param_to_name_mapping: Dict[int, str] = ...,
|
||||
param_to_name_mapping: dict[int, str] = ...,
|
||||
first_bucket_types_cap: int = ..., # kDefaultFirstBucketBytes in reducer.hpp
|
||||
): ...
|
||||
) -> None: ...
|
||||
def prepare_for_forward(self) -> None: ...
|
||||
def prepare_for_backward(self, output: List[Tensor]) -> None: ...
|
||||
def get_backward_stats(self) -> List[int]: ...
|
||||
def _install_post_backward_futures(self, futures: List[Future]) -> None: ...
|
||||
def prepare_for_backward(self, output: list[Tensor]) -> None: ...
|
||||
def get_backward_stats(self) -> list[int]: ...
|
||||
def _install_post_backward_futures(self, futures: list[Future]) -> None: ...
|
||||
def _rebuild_buckets(self) -> bool: ...
|
||||
def _get_zeros_like_grad_buckets(self) -> List[GradBucket]: ...
|
||||
def _get_zeros_like_grad_buckets(self) -> list[GradBucket]: ...
|
||||
def _push_all_rebuilt_params(self) -> None: ...
|
||||
def _set_forward_pass_work_handle(
|
||||
self,
|
||||
@ -69,20 +69,20 @@ class Reducer:
|
||||
def set_logger(self, logger: Logger) -> None: ...
|
||||
def _remove_autograd_hooks(self) -> None: ...
|
||||
def _check_reducer_finalized(self) -> None: ...
|
||||
def _set_sparse_metadata(self, global_unique_ids: Dict[str, Tensor]) -> None: ...
|
||||
def _set_sparse_metadata(self, global_unique_ids: dict[str, Tensor]) -> None: ...
|
||||
def _reset_state(self) -> None: ...
|
||||
def _update_process_group(self, new_process_group: ProcessGroup) -> None: ...
|
||||
|
||||
class DDPLoggingData:
|
||||
strs_map: Dict[str, str]
|
||||
ints_map: Dict[str, int]
|
||||
strs_map: dict[str, str]
|
||||
ints_map: dict[str, int]
|
||||
|
||||
class Logger:
|
||||
def __init__(self, reducer: Reducer): ...
|
||||
def __init__(self, reducer: Reducer) -> None: ...
|
||||
def set_construction_data_and_log(
|
||||
self,
|
||||
module_name: str,
|
||||
device_ids: List[int],
|
||||
device_ids: list[int],
|
||||
output_device: int,
|
||||
broadcast_buffers: bool,
|
||||
has_sync_bn: bool,
|
||||
@ -109,7 +109,7 @@ class DebugLevel(Enum):
|
||||
DETAIL = ...
|
||||
|
||||
class ReduceOp:
|
||||
def __init__(self, op: RedOpType): ...
|
||||
def __init__(self, op: RedOpType) -> None: ...
|
||||
|
||||
SUM: RedOpType = ...
|
||||
AVG: RedOpType = ...
|
||||
@ -161,7 +161,7 @@ class ReduceScatterOptions:
|
||||
asyncOp: bool
|
||||
|
||||
class BarrierOptions:
|
||||
device_ids: List[int]
|
||||
device_ids: list[int]
|
||||
device: torch.device
|
||||
timeout: timedelta
|
||||
|
||||
@ -182,36 +182,36 @@ class Store:
|
||||
def num_keys(self) -> int: ...
|
||||
def set_timeout(self, timeout: timedelta): ...
|
||||
@overload
|
||||
def wait(self, keys: List[str]): ...
|
||||
def wait(self, keys: list[str]): ...
|
||||
@overload
|
||||
def wait(self, keys: List[str], timeout: timedelta): ...
|
||||
def wait(self, keys: list[str], timeout: timedelta): ...
|
||||
|
||||
class FileStore(Store):
|
||||
def __init__(self, path: str, numWorkers: int = ...): ...
|
||||
def __init__(self, path: str, numWorkers: int = ...) -> None: ...
|
||||
|
||||
class HashStore(Store):
|
||||
def __init__(self): ...
|
||||
def __init__(self) -> None: ...
|
||||
|
||||
class TCPStore(Store):
|
||||
def __init__(
|
||||
self,
|
||||
host_name: str,
|
||||
port: int,
|
||||
world_size: Optional[int] = ...,
|
||||
world_size: int | None = ...,
|
||||
is_master: bool = ...,
|
||||
timeout: timedelta = ...,
|
||||
wait_for_workers: bool = ...,
|
||||
multi_tenant: bool = ...,
|
||||
master_listen_fd: Optional[int] = ...,
|
||||
use_libuv: Optional[bool] = ...,
|
||||
): ...
|
||||
master_listen_fd: int | None = ...,
|
||||
use_libuv: bool | None = ...,
|
||||
) -> None: ...
|
||||
@property
|
||||
def host(self) -> str: ...
|
||||
@property
|
||||
def port(self) -> int: ...
|
||||
|
||||
class PrefixStore(Store):
|
||||
def __init__(self, prefix: str, store: Store): ...
|
||||
def __init__(self, prefix: str, store: Store) -> None: ...
|
||||
@property
|
||||
def underlying_store(self) -> Store: ...
|
||||
|
||||
@ -230,7 +230,7 @@ class _StoreCollectives(_ControlCollectives):
|
||||
def __init__(self, store: Store, rank: int, world_size: int) -> None: ...
|
||||
|
||||
class _DistributedBackendOptions:
|
||||
def __init__(self): ...
|
||||
def __init__(self) -> None: ...
|
||||
@property
|
||||
def store(self) -> Store: ...
|
||||
@store.setter
|
||||
@ -252,9 +252,9 @@ class _DistributedBackendOptions:
|
||||
@group_id.setter
|
||||
def group_id(self, group_id: str) -> None: ...
|
||||
@property
|
||||
def global_ranks_in_group(self) -> List[int]: ...
|
||||
def global_ranks_in_group(self) -> list[int]: ...
|
||||
@global_ranks_in_group.setter
|
||||
def global_ranks_in_group(self, ranks: List[int]) -> None: ...
|
||||
def global_ranks_in_group(self, ranks: list[int]) -> None: ...
|
||||
|
||||
class Work:
|
||||
def is_completed(self) -> bool: ...
|
||||
@ -264,7 +264,7 @@ class Work:
|
||||
def get_future(self) -> Future: ...
|
||||
def source_rank(self) -> int: ...
|
||||
def _source_rank(self) -> int: ...
|
||||
def result(self) -> List[Tensor]: ...
|
||||
def result(self) -> list[Tensor]: ...
|
||||
def synchronize(self): ...
|
||||
def boxed(self) -> ScriptObject: ...
|
||||
@staticmethod
|
||||
@ -275,17 +275,17 @@ class Backend:
|
||||
self,
|
||||
rank: int,
|
||||
size: int,
|
||||
): ...
|
||||
) -> None: ...
|
||||
@property
|
||||
def supports_splitting(self) -> bool: ...
|
||||
def rank(self) -> int: ...
|
||||
def size(self) -> int: ...
|
||||
def eager_connect_single_device(self, device: Optional[torch.device]) -> None: ...
|
||||
def eager_connect_single_device(self, device: torch.device | None) -> None: ...
|
||||
def _set_sequence_number_for_group(self) -> None: ...
|
||||
|
||||
class ProcessGroup:
|
||||
class Options:
|
||||
def __init__(self, backend: str, timeout: timedelta = ...): ...
|
||||
def __init__(self, backend: str, timeout: timedelta = ...) -> None: ...
|
||||
@property
|
||||
def backend(self) -> str: ...
|
||||
@property
|
||||
@ -300,13 +300,19 @@ class ProcessGroup:
|
||||
UCC = ...
|
||||
MPI = ...
|
||||
CUSTOM = ...
|
||||
def __init__(self, store: Store, rank: int, size: int, options: Options): ...
|
||||
def __init__(
|
||||
self,
|
||||
store: Store,
|
||||
rank: int,
|
||||
size: int,
|
||||
options: Options,
|
||||
) -> None: ...
|
||||
def rank(self) -> int: ...
|
||||
def size(self) -> int: ...
|
||||
@overload
|
||||
def broadcast(
|
||||
self,
|
||||
tensors: List[Tensor],
|
||||
tensors: list[Tensor],
|
||||
opts=...,
|
||||
) -> Work: ...
|
||||
@overload
|
||||
@ -318,13 +324,13 @@ class ProcessGroup:
|
||||
@overload
|
||||
def allreduce(
|
||||
self,
|
||||
tensors: List[Tensor],
|
||||
tensors: list[Tensor],
|
||||
opts: AllreduceOptions = ...,
|
||||
) -> Work: ...
|
||||
@overload
|
||||
def allreduce(
|
||||
self,
|
||||
tensors: List[Tensor],
|
||||
tensors: list[Tensor],
|
||||
op=...,
|
||||
) -> Work: ...
|
||||
@overload
|
||||
@ -335,19 +341,19 @@ class ProcessGroup:
|
||||
) -> Work: ...
|
||||
def allreduce_coalesced(
|
||||
self,
|
||||
tensors: List[Tensor],
|
||||
tensors: list[Tensor],
|
||||
opts=...,
|
||||
) -> Work: ...
|
||||
def reduce_scatter_tensor_coalesced(
|
||||
self,
|
||||
outputTensors: List[Tensor],
|
||||
inputTensors: List[Tensor],
|
||||
opts: Optional[ReduceScatterOptions] = None,
|
||||
outputTensors: list[Tensor],
|
||||
inputTensors: list[Tensor],
|
||||
opts: ReduceScatterOptions | None = None,
|
||||
) -> Work: ...
|
||||
@overload
|
||||
def reduce(
|
||||
self,
|
||||
tensors: List[Tensor],
|
||||
tensors: list[Tensor],
|
||||
opts=...,
|
||||
) -> Work: ...
|
||||
@overload
|
||||
@ -360,14 +366,14 @@ class ProcessGroup:
|
||||
@overload
|
||||
def allgather(
|
||||
self,
|
||||
output_tensors: List[List[Tensor]],
|
||||
input_tensors: List[Tensor],
|
||||
output_tensors: list[list[Tensor]],
|
||||
input_tensors: list[Tensor],
|
||||
opts=...,
|
||||
) -> Work: ...
|
||||
@overload
|
||||
def allgather(
|
||||
self,
|
||||
output_tensors: List[Tensor],
|
||||
output_tensors: list[Tensor],
|
||||
input_tensor: Tensor,
|
||||
) -> Work: ...
|
||||
def _allgather_base(
|
||||
@ -378,70 +384,70 @@ class ProcessGroup:
|
||||
) -> Work: ...
|
||||
def allgather_coalesced(
|
||||
self,
|
||||
output_lists: List[List[Tensor]],
|
||||
input_list: List[Tensor],
|
||||
output_lists: list[list[Tensor]],
|
||||
input_list: list[Tensor],
|
||||
opts=...,
|
||||
) -> Work: ...
|
||||
def allgather_into_tensor_coalesced(
|
||||
self,
|
||||
output_lists: List[Tensor],
|
||||
input_list: List[Tensor],
|
||||
output_lists: list[Tensor],
|
||||
input_list: list[Tensor],
|
||||
opts=...,
|
||||
) -> Work: ...
|
||||
@overload
|
||||
def gather(
|
||||
self,
|
||||
output_tensors: List[List[Tensor]],
|
||||
input_tensors: List[Tensor],
|
||||
output_tensors: list[list[Tensor]],
|
||||
input_tensors: list[Tensor],
|
||||
opts=...,
|
||||
) -> Work: ...
|
||||
@overload
|
||||
def gather(
|
||||
self,
|
||||
output_tensors: List[Tensor],
|
||||
output_tensors: list[Tensor],
|
||||
input_tensor: Tensor,
|
||||
root: int,
|
||||
) -> Work: ...
|
||||
@overload
|
||||
def scatter(
|
||||
self,
|
||||
output_tensors: List[Tensor],
|
||||
input_tensors: List[List[Tensor]],
|
||||
output_tensors: list[Tensor],
|
||||
input_tensors: list[list[Tensor]],
|
||||
opts=...,
|
||||
) -> Work: ...
|
||||
@overload
|
||||
def scatter(
|
||||
self,
|
||||
output_tensor: Tensor,
|
||||
input_tensors: List[Tensor],
|
||||
input_tensors: list[Tensor],
|
||||
root: int,
|
||||
) -> Work: ...
|
||||
@overload
|
||||
def reduce_scatter(
|
||||
self,
|
||||
output_tensors: List[Tensor],
|
||||
input_tensors: List[List[Tensor]],
|
||||
output_tensors: list[Tensor],
|
||||
input_tensors: list[list[Tensor]],
|
||||
opts=...,
|
||||
) -> Work: ...
|
||||
@overload
|
||||
def reduce_scatter(
|
||||
self,
|
||||
output_tensors: Tensor,
|
||||
input_tensor: List[Tensor],
|
||||
input_tensor: list[Tensor],
|
||||
) -> Work: ...
|
||||
def _reduce_scatter_base(
|
||||
self,
|
||||
outputTensor: Tensor,
|
||||
inputTensor: Tensor,
|
||||
opts: Optional[ReduceScatterOptions],
|
||||
opts: ReduceScatterOptions | None,
|
||||
) -> Work: ...
|
||||
@overload
|
||||
def alltoall_base(
|
||||
self,
|
||||
output_tensor: Tensor,
|
||||
input_tensor: Tensor,
|
||||
output_split_sizes: List[int],
|
||||
input_split_sizes: List[int],
|
||||
output_split_sizes: list[int],
|
||||
input_split_sizes: list[int],
|
||||
opts=...,
|
||||
) -> Work: ...
|
||||
@overload
|
||||
@ -449,35 +455,35 @@ class ProcessGroup:
|
||||
self,
|
||||
output: Tensor,
|
||||
input: Tensor,
|
||||
output_split_sizes: List[int],
|
||||
input_split_sizes: List[int],
|
||||
output_split_sizes: list[int],
|
||||
input_split_sizes: list[int],
|
||||
) -> Work: ...
|
||||
@overload
|
||||
def alltoall(
|
||||
self,
|
||||
output_tensor: List[Tensor],
|
||||
input_tensor: List[Tensor],
|
||||
output_tensor: list[Tensor],
|
||||
input_tensor: list[Tensor],
|
||||
opts=...,
|
||||
) -> Work: ...
|
||||
@overload
|
||||
def alltoall(
|
||||
self,
|
||||
output: List[Tensor],
|
||||
input: List[Tensor],
|
||||
output: list[Tensor],
|
||||
input: list[Tensor],
|
||||
) -> Work: ...
|
||||
def send(
|
||||
self,
|
||||
tensors: List[Tensor],
|
||||
tensors: list[Tensor],
|
||||
dstRank: int,
|
||||
tag: int,
|
||||
) -> Work: ...
|
||||
def recv(
|
||||
self,
|
||||
tensors: List[Tensor],
|
||||
tensors: list[Tensor],
|
||||
srcRank: int,
|
||||
tag: int,
|
||||
) -> Work: ...
|
||||
def recv_anysource(self, tensors: List[Tensor], tag: int) -> Work: ...
|
||||
def recv_anysource(self, tensors: list[Tensor], tag: int) -> Work: ...
|
||||
def barrier(self, opts=...) -> Work: ...
|
||||
def boxed(self) -> ScriptObject: ...
|
||||
@staticmethod
|
||||
@ -487,13 +493,13 @@ class ProcessGroup:
|
||||
def _get_backend_name(self) -> str: ...
|
||||
def _backend_id(self, backend_type: BackendType) -> int: ...
|
||||
@property
|
||||
def _device_types(self) -> List[torch.device]: ...
|
||||
def _device_types(self) -> list[torch.device]: ...
|
||||
def _get_backend(self, device: torch.device) -> Backend: ...
|
||||
def _register_backend(
|
||||
self,
|
||||
device: torch.device,
|
||||
backend_type: BackendType,
|
||||
backend: Optional[Backend],
|
||||
backend: Backend | None,
|
||||
) -> None: ...
|
||||
def _set_group_name(self, name: str) -> None: ...
|
||||
def _set_group_desc(self, desc: str) -> None: ...
|
||||
@ -502,9 +508,9 @@ class ProcessGroup:
|
||||
def _wait_for_pending_works(self) -> None: ...
|
||||
def _set_sequence_number_for_group(self) -> None: ...
|
||||
@property
|
||||
def bound_device_id(self) -> Optional[torch.device]: ...
|
||||
def bound_device_id(self) -> torch.device | None: ...
|
||||
@bound_device_id.setter
|
||||
def bound_device_id(self, device: Optional[torch.device]) -> None: ...
|
||||
def bound_device_id(self, device: torch.device | None) -> None: ...
|
||||
@property
|
||||
def group_name(self) -> str: ...
|
||||
@property
|
||||
@ -513,7 +519,7 @@ class ProcessGroup:
|
||||
class ProcessGroupRoundRobin(ProcessGroup): ...
|
||||
|
||||
def _round_robin_process_groups(
|
||||
process_groups: List[ProcessGroup],
|
||||
process_groups: list[ProcessGroup],
|
||||
) -> ProcessGroupRoundRobin: ...
|
||||
|
||||
class ProcessGroupGloo(Backend):
|
||||
@ -526,7 +532,7 @@ class ProcessGroupGloo(Backend):
|
||||
rank: int,
|
||||
size: int,
|
||||
timeout: timedelta,
|
||||
): ...
|
||||
) -> None: ...
|
||||
@staticmethod
|
||||
def create_device(hostname="", interface="") -> Device: ...
|
||||
@staticmethod
|
||||
@ -534,12 +540,12 @@ class ProcessGroupGloo(Backend):
|
||||
def _set_default_timeout(self, timeout) -> None: ...
|
||||
|
||||
class _ProcessGroupWrapper(Backend):
|
||||
def __init__(self, pg: Backend, gloo_pg: ProcessGroupGloo): ...
|
||||
def __init__(self, pg: Backend, gloo_pg: ProcessGroupGloo) -> None: ...
|
||||
wrapped_pg: Backend
|
||||
|
||||
class ProcessGroupNCCL(Backend):
|
||||
class Options:
|
||||
def __init__(self, timeout: Optional[timedelta] = None): ...
|
||||
def __init__(self, timeout: timedelta | None = None) -> None: ...
|
||||
@property
|
||||
def backend(self) -> str: ...
|
||||
@property
|
||||
@ -557,7 +563,7 @@ class ProcessGroupNCCL(Backend):
|
||||
rank: int,
|
||||
size: int,
|
||||
timeout: timedelta,
|
||||
): ...
|
||||
) -> None: ...
|
||||
def _group_start(self) -> None: ...
|
||||
def _group_end(self) -> None: ...
|
||||
def _set_default_timeout(self, timeout) -> None: ...
|
||||
@ -572,7 +578,7 @@ class ProcessGroupUCC(Backend):
|
||||
rank: int,
|
||||
size: int,
|
||||
timeout: timedelta,
|
||||
): ...
|
||||
) -> None: ...
|
||||
|
||||
class ProcessGroupMPI(Backend):
|
||||
def __init__(
|
||||
@ -580,29 +586,29 @@ class ProcessGroupMPI(Backend):
|
||||
rank: int,
|
||||
size: int,
|
||||
pgComm: int,
|
||||
): ...
|
||||
) -> None: ...
|
||||
@staticmethod
|
||||
def create(ranks: List[int]) -> ProcessGroupMPI: ...
|
||||
def create(ranks: list[int]) -> ProcessGroupMPI: ...
|
||||
|
||||
def _compute_bucket_assignment_by_size(
|
||||
tensors: List[Tensor],
|
||||
bucket_size_limits: List[int],
|
||||
expect_sparse_gradient: List[bool] = ...,
|
||||
tensor_indices: List[int] = ...,
|
||||
) -> Tuple[List[List[int]], List[int]]: ...
|
||||
tensors: list[Tensor],
|
||||
bucket_size_limits: list[int],
|
||||
expect_sparse_gradient: list[bool] = ...,
|
||||
tensor_indices: list[int] = ...,
|
||||
) -> tuple[list[list[int]], list[int]]: ...
|
||||
def _broadcast_coalesced(
|
||||
process_group: ProcessGroup,
|
||||
tensors: List[Tensor],
|
||||
tensors: list[Tensor],
|
||||
buffer_size: int,
|
||||
src: int,
|
||||
): ...
|
||||
def _test_python_store(store: Store): ...
|
||||
def _verify_params_across_processes(
|
||||
process_group: ProcessGroup,
|
||||
params: List[Tensor],
|
||||
logger: Optional[Logger],
|
||||
params: list[Tensor],
|
||||
logger: Logger | None,
|
||||
): ...
|
||||
def _make_nccl_premul_sum(factor: Union[float, List[Tensor]]) -> ReduceOp: ...
|
||||
def _make_nccl_premul_sum(factor: float | list[Tensor]) -> ReduceOp: ...
|
||||
def _register_process_group(
|
||||
group_name: str,
|
||||
process_group: ProcessGroup,
|
||||
@ -614,7 +620,10 @@ def _unregister_process_group(group_name: str) -> None: ...
|
||||
class _SymmetricMemory:
|
||||
@staticmethod
|
||||
def set_group_info(
|
||||
group_name: str, rank: int, world_size: int, store: Store
|
||||
group_name: str,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
store: Store,
|
||||
) -> None: ...
|
||||
@staticmethod
|
||||
def empty_strided_p2p(
|
||||
@ -635,7 +644,7 @@ class _SymmetricMemory:
|
||||
rank: int,
|
||||
sizes: torch.types._size,
|
||||
dtype: torch.dtype,
|
||||
storage_offset: Optional[int] = 0,
|
||||
storage_offset: int | None = 0,
|
||||
) -> torch.Tensor: ...
|
||||
def barrier(self, channel: int = 0) -> None: ...
|
||||
def put_signal(self, dst_rank: int, channel: int = 0) -> None: ...
|
||||
|
@ -1,14 +1,13 @@
|
||||
# mypy: allow-untyped-defs
|
||||
# mypy: disable-error-code="type-arg"
|
||||
from datetime import timedelta
|
||||
from typing import Any, Dict, Generic, List, Optional, overload, Tuple, Type, TypeVar
|
||||
from typing import Any, Generic, overload, TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
from . import Future
|
||||
from ._autograd import ProfilerEvent
|
||||
from ._distributed_c10d import Store
|
||||
from ._profiler import ProfilerConfig
|
||||
from torch._C import Future
|
||||
from torch._C._autograd import ProfilerEvent
|
||||
from torch._C._distributed_c10d import Store
|
||||
from torch._C._profiler import ProfilerConfig
|
||||
|
||||
# This module is defined in torch/csrc/distributed/rpc/init.cpp
|
||||
|
||||
@ -26,10 +25,10 @@ class RpcBackendOptions:
|
||||
self,
|
||||
rpc_timeout: float = ...,
|
||||
init_method: str = ...,
|
||||
): ...
|
||||
) -> None: ...
|
||||
|
||||
class WorkerInfo:
|
||||
def __init__(self, name: str, worker_id: int): ...
|
||||
def __init__(self, name: str, worker_id: int) -> None: ...
|
||||
@property
|
||||
def name(self) -> str: ...
|
||||
@property
|
||||
@ -44,10 +43,10 @@ class RpcAgent:
|
||||
def get_worker_info(self) -> WorkerInfo: ...
|
||||
@overload
|
||||
def get_worker_info(self, workerName: str) -> WorkerInfo: ...
|
||||
def get_worker_infos(self) -> List[WorkerInfo]: ...
|
||||
def _get_device_map(self, dst: WorkerInfo) -> Dict[torch.device, torch.device]: ...
|
||||
def get_debug_info(self) -> Dict[str, str]: ...
|
||||
def get_metrics(self) -> Dict[str, str]: ...
|
||||
def get_worker_infos(self) -> list[WorkerInfo]: ...
|
||||
def _get_device_map(self, dst: WorkerInfo) -> dict[torch.device, torch.device]: ...
|
||||
def get_debug_info(self) -> dict[str, str]: ...
|
||||
def get_metrics(self) -> dict[str, str]: ...
|
||||
|
||||
class PyRRef(Generic[_T]):
|
||||
def __init__(self, value: _T, type_hint: Any = None) -> None: ...
|
||||
@ -60,32 +59,32 @@ class PyRRef(Generic[_T]):
|
||||
def rpc_sync(self, timeout: float = ...) -> Any: ...
|
||||
def rpc_async(self, timeout: float = ...) -> Any: ...
|
||||
def remote(self, timeout: float = ...) -> Any: ...
|
||||
def _serialize(self) -> Tuple: ...
|
||||
def _serialize(self) -> tuple: ...
|
||||
@staticmethod
|
||||
def _deserialize(tp: Tuple) -> PyRRef: ...
|
||||
def _get_type(self) -> Type[_T]: ...
|
||||
def _deserialize(tp: tuple) -> PyRRef: ...
|
||||
def _get_type(self) -> type[_T]: ...
|
||||
def _get_future(self) -> Future[_T]: ...
|
||||
def _get_profiling_future(self) -> Future[_T]: ...
|
||||
def _set_profiling_future(self, profilingFuture: Future[_T]): ...
|
||||
|
||||
class _TensorPipeRpcBackendOptionsBase(RpcBackendOptions):
|
||||
num_worker_threads: int
|
||||
device_maps: Dict[str, Dict[torch.device, torch.device]]
|
||||
devices: List[torch.device]
|
||||
device_maps: dict[str, dict[torch.device, torch.device]]
|
||||
devices: list[torch.device]
|
||||
def __init__(
|
||||
self,
|
||||
num_worker_threads: int,
|
||||
_transports: Optional[List],
|
||||
_channels: Optional[List],
|
||||
_transports: list | None,
|
||||
_channels: list | None,
|
||||
rpc_timeout: float = ...,
|
||||
init_method: str = ...,
|
||||
device_maps: Dict[str, Dict[torch.device, torch.device]] = {}, # noqa: B006
|
||||
devices: List[torch.device] = [], # noqa: B006
|
||||
): ...
|
||||
device_maps: dict[str, dict[torch.device, torch.device]] = {}, # noqa: B006
|
||||
devices: list[torch.device] = [], # noqa: B006
|
||||
) -> None: ...
|
||||
def _set_device_map(
|
||||
self,
|
||||
to: str,
|
||||
device_map: Dict[torch.device, torch.device],
|
||||
device_map: dict[torch.device, torch.device],
|
||||
): ...
|
||||
|
||||
class TensorPipeAgent(RpcAgent):
|
||||
@ -94,11 +93,11 @@ class TensorPipeAgent(RpcAgent):
|
||||
store: Store,
|
||||
name: str,
|
||||
worker_id: int,
|
||||
world_size: Optional[int],
|
||||
world_size: int | None,
|
||||
opts: _TensorPipeRpcBackendOptionsBase,
|
||||
reverse_device_maps: Dict[str, Dict[torch.device, torch.device]],
|
||||
devices: List[torch.device],
|
||||
): ...
|
||||
reverse_device_maps: dict[str, dict[torch.device, torch.device]],
|
||||
devices: list[torch.device],
|
||||
) -> None: ...
|
||||
def join(self, shutdown: bool = False, timeout: float = 0): ...
|
||||
def shutdown(self): ...
|
||||
@overload
|
||||
@ -107,13 +106,13 @@ class TensorPipeAgent(RpcAgent):
|
||||
def get_worker_info(self, workerName: str) -> WorkerInfo: ...
|
||||
@overload
|
||||
def get_worker_info(self, id: int) -> WorkerInfo: ...
|
||||
def get_worker_infos(self) -> List[WorkerInfo]: ...
|
||||
def _get_device_map(self, dst: WorkerInfo) -> Dict[torch.device, torch.device]: ...
|
||||
def get_worker_infos(self) -> list[WorkerInfo]: ...
|
||||
def _get_device_map(self, dst: WorkerInfo) -> dict[torch.device, torch.device]: ...
|
||||
def _update_group_membership(
|
||||
self,
|
||||
worker_info: WorkerInfo,
|
||||
my_devices: List[torch.device],
|
||||
reverse_device_map: Dict[str, Dict[torch.device, torch.device]],
|
||||
my_devices: list[torch.device],
|
||||
reverse_device_map: dict[str, dict[torch.device, torch.device]],
|
||||
is_join: bool,
|
||||
): ...
|
||||
def _get_backend_options(self) -> _TensorPipeRpcBackendOptionsBase: ...
|
||||
@ -128,7 +127,7 @@ def _set_and_start_rpc_agent(agent: RpcAgent): ...
|
||||
def _reset_current_rpc_agent(): ...
|
||||
def _delete_all_user_and_unforked_owner_rrefs(timeout: timedelta = ...): ...
|
||||
def _destroy_rref_context(ignoreRRefLeak: bool): ...
|
||||
def _rref_context_get_debug_info() -> Dict[str, str]: ...
|
||||
def _rref_context_get_debug_info() -> dict[str, str]: ...
|
||||
def _cleanup_python_rpc_handler(): ...
|
||||
def _invoke_rpc_builtin(
|
||||
dst: WorkerInfo,
|
||||
@ -140,15 +139,15 @@ def _invoke_rpc_builtin(
|
||||
def _invoke_rpc_python_udf(
|
||||
dst: WorkerInfo,
|
||||
pickledPythonUDF: str,
|
||||
tensors: List[torch.Tensor],
|
||||
tensors: list[torch.Tensor],
|
||||
rpcTimeoutSeconds: float,
|
||||
isAsyncExecution: bool,
|
||||
): ...
|
||||
def _invoke_rpc_torchscript(
|
||||
dstWorkerName: str,
|
||||
qualifiedNameStr: str,
|
||||
argsTuple: Tuple,
|
||||
kwargsDict: Dict,
|
||||
argsTuple: tuple,
|
||||
kwargsDict: dict,
|
||||
rpcTimeoutSeconds: float,
|
||||
isAsyncExecution: bool,
|
||||
): ...
|
||||
@ -162,7 +161,7 @@ def _invoke_remote_builtin(
|
||||
def _invoke_remote_python_udf(
|
||||
dst: WorkerInfo,
|
||||
pickledPythonUDF: str,
|
||||
tensors: List[torch.Tensor],
|
||||
tensors: list[torch.Tensor],
|
||||
rpcTimeoutSeconds: float,
|
||||
isAsyncExecution: bool,
|
||||
): ...
|
||||
@ -183,7 +182,7 @@ class RemoteProfilerManager:
|
||||
def set_current_profiling_key(key: str): ...
|
||||
|
||||
def _enable_server_process_global_profiler(new_config: ProfilerConfig): ...
|
||||
def _disable_server_process_global_profiler() -> List[List[List[ProfilerEvent]]]: ...
|
||||
def _disable_server_process_global_profiler() -> list[list[list[ProfilerEvent]]]: ...
|
||||
def _set_profiler_node_id(default_node_id: int): ...
|
||||
def _enable_jit_rref_pickle(): ...
|
||||
def _disable_jit_rref_pickle(): ...
|
||||
|
@ -1,9 +1,6 @@
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
|
||||
from ._distributed_c10d import Store
|
||||
from ._distributed_rpc import _TensorPipeRpcBackendOptionsBase, TensorPipeAgent
|
||||
from torch._C._distributed_c10d import Store
|
||||
from torch._C._distributed_rpc import _TensorPipeRpcBackendOptionsBase, TensorPipeAgent
|
||||
|
||||
# This module is defined in torch/csrc/distributed/rpc/testing/init.cpp
|
||||
|
||||
@ -13,13 +10,13 @@ class FaultyTensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
|
||||
num_worker_threads: int,
|
||||
rpc_timeout: float,
|
||||
init_method: str,
|
||||
messages_to_fail: List[str],
|
||||
messages_to_delay: Dict[str, float],
|
||||
messages_to_fail: list[str],
|
||||
messages_to_delay: dict[str, float],
|
||||
num_fail_sends: int,
|
||||
): ...
|
||||
) -> None: ...
|
||||
num_send_recv_threads: int
|
||||
messages_to_fail: List[str]
|
||||
messages_to_delay: Dict[str, float]
|
||||
messages_to_fail: list[str]
|
||||
messages_to_delay: dict[str, float]
|
||||
num_fail_sends: int
|
||||
|
||||
class FaultyTensorPipeAgent(TensorPipeAgent):
|
||||
@ -30,6 +27,6 @@ class FaultyTensorPipeAgent(TensorPipeAgent):
|
||||
rank: int,
|
||||
world_size: int,
|
||||
options: FaultyTensorPipeRpcBackendOptions,
|
||||
reverse_device_maps: Dict[str, Dict[torch.device, torch.device]],
|
||||
devices: List[torch.device],
|
||||
): ...
|
||||
reverse_device_maps: dict[str, dict[torch.device, torch.device]],
|
||||
devices: list[torch.device],
|
||||
) -> None: ...
|
||||
|
@ -1,10 +1,10 @@
|
||||
from typing import Callable, Optional
|
||||
from typing import Callable
|
||||
|
||||
from torch._dynamo.compiled_autograd import AutogradCompilerInstance
|
||||
|
||||
def set_autograd_compiler(
|
||||
autograd_compiler: Optional[Callable[[], AutogradCompilerInstance]]
|
||||
) -> Optional[Callable[[], AutogradCompilerInstance]]: ...
|
||||
autograd_compiler: Callable[[], AutogradCompilerInstance] | None,
|
||||
) -> Callable[[], AutogradCompilerInstance] | None: ...
|
||||
def clear_cache() -> None: ...
|
||||
def is_cache_empty() -> bool: ...
|
||||
def set_verbose_logger(fn: Optional[Callable[[str], None]]) -> bool: ...
|
||||
def set_verbose_logger(fn: Callable[[str], None] | None) -> bool: ...
|
||||
|
@ -1,6 +1,6 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import types
|
||||
from typing import List, NewType, Optional
|
||||
from typing import NewType
|
||||
|
||||
from torch._dynamo.types import DynamoCallback, DynamoGuardHook
|
||||
|
||||
@ -17,11 +17,11 @@ def set_guard_error_hook(hook: DynamoGuardHook) -> None: ...
|
||||
class _CacheEntry:
|
||||
def check_fn(self, *args, **kwargs): ...
|
||||
code: types.CodeType
|
||||
next: Optional[_CacheEntry]
|
||||
next: _CacheEntry | None
|
||||
|
||||
class _ExtraState:
|
||||
def invalidate(self, cache_entry: _CacheEntry): ...
|
||||
|
||||
def _debug_get_cache_entry_list(code: types.CodeType) -> List[_CacheEntry]: ...
|
||||
def _debug_get_cache_entry_list(code: types.CodeType) -> list[_CacheEntry]: ...
|
||||
|
||||
py_opcode_caches: List[int]
|
||||
py_opcode_caches: list[int]
|
||||
|
@ -1,5 +1,5 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
@ -17,73 +17,106 @@ class GuardManager:
|
||||
# Accessors
|
||||
def globals_dict_manager(
|
||||
self,
|
||||
f_globals: Dict[str, Any],
|
||||
f_globals: dict[str, Any],
|
||||
source,
|
||||
example_value,
|
||||
guard_manager_enum,
|
||||
) -> GuardManager: ...
|
||||
def dict_getitem_manager(
|
||||
self, key, source, example_value, guard_manager_enum
|
||||
self,
|
||||
key,
|
||||
source,
|
||||
example_value,
|
||||
guard_manager_enum,
|
||||
) -> GuardManager: ...
|
||||
def global_weakref_manager(
|
||||
self, global_name: str, source, example_value, guard_manager_enum
|
||||
self,
|
||||
global_name: str,
|
||||
source,
|
||||
example_value,
|
||||
guard_manager_enum,
|
||||
) -> GuardManager: ...
|
||||
def type_manager(
|
||||
self, source, example_value, guard_manager_enum
|
||||
self,
|
||||
source,
|
||||
example_value,
|
||||
guard_manager_enum,
|
||||
) -> GuardManager: ...
|
||||
def getattr_manager(
|
||||
self, attr: str, source, example_value, guard_manager_enum
|
||||
self,
|
||||
attr: str,
|
||||
source,
|
||||
example_value,
|
||||
guard_manager_enum,
|
||||
) -> GuardManager: ...
|
||||
def lambda_manager(
|
||||
self, python_lambda, source, example_value, guard_manager_enum
|
||||
self,
|
||||
python_lambda,
|
||||
source,
|
||||
example_value,
|
||||
guard_manager_enum,
|
||||
) -> GuardManager: ...
|
||||
|
||||
# Leaf guards
|
||||
def add_lambda_guard(self, user_lambda, verbose_code_parts: List[str]) -> None: ...
|
||||
def add_id_match_guard(self, id_val, verbose_code_parts: List[str]) -> None: ...
|
||||
def add_lambda_guard(self, user_lambda, verbose_code_parts: list[str]) -> None: ...
|
||||
def add_id_match_guard(self, id_val, verbose_code_parts: list[str]) -> None: ...
|
||||
def add_equals_match_guard(
|
||||
self, equals_val, verbose_code_parts: List[str]
|
||||
self,
|
||||
equals_val,
|
||||
verbose_code_parts: list[str],
|
||||
) -> None: ...
|
||||
def add_global_state_guard(self, verbose_code_parts: List[str]) -> None: ...
|
||||
def add_global_state_guard(self, verbose_code_parts: list[str]) -> None: ...
|
||||
|
||||
class RootGuardManager(GuardManager):
|
||||
def get_epilogue_lambda_guards(self) -> List[LeafGuard]: ...
|
||||
def get_epilogue_lambda_guards(self) -> list[LeafGuard]: ...
|
||||
def add_epilogue_lambda_guard(
|
||||
self, guard: LeafGuard, verbose_code_parts: List[str]
|
||||
self,
|
||||
guard: LeafGuard,
|
||||
verbose_code_parts: list[str],
|
||||
) -> None: ...
|
||||
|
||||
class DictGuardManager(GuardManager):
|
||||
def get_key_manager(
|
||||
self, index, source, example_value, guard_manager_enum
|
||||
self,
|
||||
index,
|
||||
source,
|
||||
example_value,
|
||||
guard_manager_enum,
|
||||
) -> GuardManager: ...
|
||||
def get_value_manager(
|
||||
self, index, source, example_value, guard_manager_enum
|
||||
self,
|
||||
index,
|
||||
source,
|
||||
example_value,
|
||||
guard_manager_enum,
|
||||
) -> GuardManager: ...
|
||||
|
||||
def install_tensor_aliasing_guard(
|
||||
guard_managers: List[GuardManager],
|
||||
tensor_names: List[str],
|
||||
verbose_code_parts: List[str],
|
||||
guard_managers: list[GuardManager],
|
||||
tensor_names: list[str],
|
||||
verbose_code_parts: list[str],
|
||||
): ...
|
||||
def install_no_tensor_aliasing_guard(
|
||||
guard_managers: List[GuardManager],
|
||||
tensor_names: List[str],
|
||||
verbose_code_parts: List[str],
|
||||
guard_managers: list[GuardManager],
|
||||
tensor_names: list[str],
|
||||
verbose_code_parts: list[str],
|
||||
): ...
|
||||
|
||||
class TensorGuards:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dynamic_dims_sizes: Optional[List[Optional[torch.SymInt]]] = None,
|
||||
dynamic_dims_strides: Optional[List[Optional[torch.SymInt]]] = None,
|
||||
): ...
|
||||
dynamic_dims_sizes: list[torch.SymInt | None] | None = None,
|
||||
dynamic_dims_strides: list[torch.SymInt | None] | None = None,
|
||||
) -> None: ...
|
||||
def check(self, *args) -> bool: ...
|
||||
def check_verbose(self, *args, tensor_check_names=None) -> Union[bool, str]: ...
|
||||
def check_verbose(self, *args, tensor_check_names=None) -> bool | str: ...
|
||||
|
||||
def assert_size_stride(
|
||||
item: torch.Tensor, size: torch.types._size, stride: torch.types._size
|
||||
item: torch.Tensor,
|
||||
size: torch.types._size,
|
||||
stride: torch.types._size,
|
||||
): ...
|
||||
def check_obj_id(obj: object, expected: int) -> bool: ...
|
||||
def check_type_id(obj: object, expected: int) -> bool: ...
|
||||
def dict_version(d: Dict[Any, Any]) -> int: ...
|
||||
def dict_version(d: dict[Any, Any]) -> int: ...
|
||||
|
@ -1,11 +1,11 @@
|
||||
from typing import AnyStr, List
|
||||
from typing import AnyStr
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
class UndefinedGrad:
|
||||
def __init__(self) -> None: ...
|
||||
def __call__(self, *inputs: Tensor) -> List[Tensor]: ...
|
||||
def __call__(self, *inputs: Tensor) -> list[Tensor]: ...
|
||||
|
||||
class DelayedError:
|
||||
def __init__(self, msg: AnyStr, num_inputs: int) -> None: ...
|
||||
def __call__(self, inputs: List[Tensor]) -> List[Tensor]: ...
|
||||
def __call__(self, inputs: list[Tensor]) -> list[Tensor]: ...
|
||||
|
@ -1,6 +1,5 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from enum import Enum
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
@ -15,11 +14,11 @@ def is_gradtrackingtensor(tensor: Tensor) -> bool: ...
|
||||
def is_legacy_batchedtensor(tensor: Tensor) -> bool: ...
|
||||
def maybe_get_bdim(tensor: Tensor) -> int: ...
|
||||
def maybe_get_level(tensor: Tensor) -> int: ...
|
||||
def maybe_current_level() -> Optional[int]: ...
|
||||
def maybe_current_level() -> int | None: ...
|
||||
def unwrap_if_dead(tensor: Tensor) -> Tensor: ...
|
||||
def _unwrap_for_grad(tensor: Tensor, level: int) -> Tensor: ...
|
||||
def _wrap_for_grad(tensor: Tensor, level: int) -> Tensor: ...
|
||||
def _unwrap_batched(tensor: Tensor, level: int) -> Tuple[Tensor, Optional[int]]: ...
|
||||
def _unwrap_batched(tensor: Tensor, level: int) -> tuple[Tensor, int | None]: ...
|
||||
def current_level() -> int: ...
|
||||
def count_jvp_interpreters() -> int: ...
|
||||
def _add_batch_dim(tensor: Tensor, bdim: int, level: int) -> Tensor: ...
|
||||
@ -52,23 +51,23 @@ class CInterpreter:
|
||||
def level(self) -> int: ...
|
||||
|
||||
class CGradInterpreterPtr:
|
||||
def __init__(self, interpreter: CInterpreter): ...
|
||||
def __init__(self, interpreter: CInterpreter) -> None: ...
|
||||
def lift(self, Tensor) -> Tensor: ...
|
||||
def prevGradMode(self) -> bool: ...
|
||||
|
||||
class CJvpInterpreterPtr:
|
||||
def __init__(self, interpreter: CInterpreter): ...
|
||||
def __init__(self, interpreter: CInterpreter) -> None: ...
|
||||
def lift(self, Tensor) -> Tensor: ...
|
||||
def prevFwdGradMode(self) -> bool: ...
|
||||
|
||||
class CFunctionalizeInterpreterPtr:
|
||||
def __init__(self, interpreter: CInterpreter): ...
|
||||
def __init__(self, interpreter: CInterpreter) -> None: ...
|
||||
def key(self) -> TransformType: ...
|
||||
def level(self) -> int: ...
|
||||
def functionalizeAddBackViews(self) -> bool: ...
|
||||
|
||||
class CVmapInterpreterPtr:
|
||||
def __init__(self, interpreter: CInterpreter): ...
|
||||
def __init__(self, interpreter: CInterpreter) -> None: ...
|
||||
def key(self) -> TransformType: ...
|
||||
def level(self) -> int: ...
|
||||
def batchSize(self) -> int: ...
|
||||
|
@ -1,26 +1,24 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from typing import List
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
# defined in torch/csrc/lazy/python/init.cpp
|
||||
def _mark_step(device: str, devices: List[str], wait: bool): ...
|
||||
def _wait_device_ops(devices: List[str]): ...
|
||||
def _mark_step(device: str, devices: list[str], wait: bool): ...
|
||||
def _wait_device_ops(devices: list[str]): ...
|
||||
def _reset_metrics(): ...
|
||||
def _counter_names() -> List[str]: ...
|
||||
def _counter_names() -> list[str]: ...
|
||||
def _counter_value(name: str) -> int: ...
|
||||
def _metrics_report() -> str: ...
|
||||
def _get_graph_hash(tensors: List[Tensor]) -> str: ...
|
||||
def _get_graph_hash(tensors: list[Tensor]) -> str: ...
|
||||
def _sync_multi(
|
||||
tensors: List[Tensor],
|
||||
devices: List[str],
|
||||
tensors: list[Tensor],
|
||||
devices: list[str],
|
||||
wait: bool = True,
|
||||
sync_ltc_data: bool = True,
|
||||
): ...
|
||||
def _get_tensor_id(tensor: Tensor) -> int: ...
|
||||
def _get_tensors_text(tensors: List[Tensor]) -> str: ...
|
||||
def _get_tensors_dot(tensors: List[Tensor]) -> str: ...
|
||||
def _get_tensors_backend(tensors: List[Tensor]) -> str: ...
|
||||
def _get_tensors_text(tensors: list[Tensor]) -> str: ...
|
||||
def _get_tensors_dot(tensors: list[Tensor]) -> str: ...
|
||||
def _get_tensors_backend(tensors: list[Tensor]) -> str: ...
|
||||
def _get_force_fallback() -> str: ...
|
||||
def _set_force_fallback(newval: str): ...
|
||||
def _clear_ir_cache(): ...
|
||||
|
@ -1,12 +1,12 @@
|
||||
# mypy: allow-untyped-defs
|
||||
# defined in torch/csrc/lazy/python/init.cpp
|
||||
|
||||
from typing import Any, List, Tuple
|
||||
from typing import Any
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
def _init(): ...
|
||||
def _get_tensors_ts_device_data_node(
|
||||
tensors: List[Tensor],
|
||||
) -> Tuple[List[int], List[Any]]: ...
|
||||
def _run_cached_graph(hash_str: str, graph_inputs: List[Any]) -> List[Tensor]: ...
|
||||
tensors: list[Tensor],
|
||||
) -> tuple[list[int], list[Any]]: ...
|
||||
def _run_cached_graph(hash_str: str, graph_inputs: list[Any]) -> list[Tensor]: ...
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
import datetime
|
||||
from enum import Enum
|
||||
from typing import Callable, Dict, List, Union
|
||||
from typing import Callable
|
||||
|
||||
class Aggregation(Enum):
|
||||
VALUE = ...
|
||||
@ -18,22 +18,22 @@ class Stat:
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
aggregations: List[Aggregation],
|
||||
aggregations: list[Aggregation],
|
||||
window_size: int,
|
||||
max_samples: int = -1,
|
||||
) -> None: ...
|
||||
def add(self, v: float) -> None: ...
|
||||
def get(self) -> Dict[Aggregation, float]: ...
|
||||
def get(self) -> dict[Aggregation, float]: ...
|
||||
|
||||
class Event:
|
||||
name: str
|
||||
timestamp: datetime.datetime
|
||||
data: Dict[str, Union[int, float, bool, str]]
|
||||
data: dict[str, int | float | bool | str]
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
timestamp: datetime.datetime,
|
||||
data: Dict[str, Union[int, float, bool, str]],
|
||||
data: dict[str, int | float | bool | str],
|
||||
) -> None: ...
|
||||
|
||||
def log_event(e: Event) -> None: ...
|
||||
|
@ -1,5 +1,5 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
from typing import Any, Literal
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from torch._C import device, dtype, layout
|
||||
@ -55,10 +55,10 @@ class _EventType(Enum):
|
||||
class _ExperimentalConfig:
|
||||
def __init__(
|
||||
self,
|
||||
profiler_metrics: List[str] = ...,
|
||||
profiler_metrics: list[str] = ...,
|
||||
profiler_measure_per_kernel: bool = ...,
|
||||
verbose: bool = ...,
|
||||
performance_events: List[str] = ...,
|
||||
performance_events: list[str] = ...,
|
||||
enable_cuda_sync_events: bool = ...,
|
||||
) -> None: ...
|
||||
|
||||
@ -77,31 +77,23 @@ class ProfilerConfig:
|
||||
class _ProfilerEvent:
|
||||
start_tid: int
|
||||
start_time_ns: int
|
||||
children: List[_ProfilerEvent]
|
||||
children: list[_ProfilerEvent]
|
||||
|
||||
# TODO(robieta): remove in favor of `self.typed`
|
||||
extra_fields: Union[
|
||||
_ExtraFields_TorchOp,
|
||||
_ExtraFields_Backend,
|
||||
_ExtraFields_Allocation,
|
||||
_ExtraFields_OutOfMemory,
|
||||
_ExtraFields_PyCall,
|
||||
_ExtraFields_PyCCall,
|
||||
_ExtraFields_Kineto,
|
||||
]
|
||||
extra_fields: _ExtraFields_TorchOp | _ExtraFields_Backend | _ExtraFields_Allocation | _ExtraFields_OutOfMemory | _ExtraFields_PyCall | _ExtraFields_PyCCall | _ExtraFields_Kineto
|
||||
|
||||
@property
|
||||
def typed(
|
||||
self,
|
||||
) -> Union[
|
||||
Tuple[Literal[_EventType.TorchOp], _ExtraFields_TorchOp],
|
||||
Tuple[Literal[_EventType.Backend], _ExtraFields_Backend],
|
||||
Tuple[Literal[_EventType.Allocation], _ExtraFields_Allocation],
|
||||
Tuple[Literal[_EventType.OutOfMemory], _ExtraFields_OutOfMemory],
|
||||
Tuple[Literal[_EventType.PyCall], _ExtraFields_PyCall],
|
||||
Tuple[Literal[_EventType.PyCCall], _ExtraFields_PyCCall],
|
||||
Tuple[Literal[_EventType.Kineto], _ExtraFields_Kineto],
|
||||
]: ...
|
||||
) -> (
|
||||
tuple[Literal[_EventType.TorchOp], _ExtraFields_TorchOp]
|
||||
| tuple[Literal[_EventType.Backend], _ExtraFields_Backend]
|
||||
| tuple[Literal[_EventType.Allocation], _ExtraFields_Allocation]
|
||||
| tuple[Literal[_EventType.OutOfMemory], _ExtraFields_OutOfMemory]
|
||||
| tuple[Literal[_EventType.PyCall], _ExtraFields_PyCall]
|
||||
| tuple[Literal[_EventType.PyCCall], _ExtraFields_PyCCall]
|
||||
| tuple[Literal[_EventType.Kineto], _ExtraFields_Kineto]
|
||||
): ...
|
||||
@property
|
||||
def name(self) -> str: ...
|
||||
@property
|
||||
@ -109,7 +101,7 @@ class _ProfilerEvent:
|
||||
@property
|
||||
def id(self) -> int: ...
|
||||
@property
|
||||
def parent(self) -> Optional[_ProfilerEvent]: ...
|
||||
def parent(self) -> _ProfilerEvent | None: ...
|
||||
@property
|
||||
def correlation_id(self) -> int: ...
|
||||
@property
|
||||
@ -118,12 +110,12 @@ class _ProfilerEvent:
|
||||
def duration_time_ns(self) -> int: ...
|
||||
|
||||
class _TensorMetadata:
|
||||
impl_ptr: Optional[int]
|
||||
storage_data_ptr: Optional[int]
|
||||
id: Optional[int]
|
||||
impl_ptr: int | None
|
||||
storage_data_ptr: int | None
|
||||
id: int | None
|
||||
|
||||
@property
|
||||
def allocation_id(self) -> Optional[int]: ...
|
||||
def allocation_id(self) -> int | None: ...
|
||||
@property
|
||||
def layout(self) -> layout: ...
|
||||
@property
|
||||
@ -131,12 +123,12 @@ class _TensorMetadata:
|
||||
@property
|
||||
def dtype(self) -> dtype: ...
|
||||
@property
|
||||
def sizes(self) -> List[int]: ...
|
||||
def sizes(self) -> list[int]: ...
|
||||
@property
|
||||
def strides(self) -> List[int]: ...
|
||||
def strides(self) -> list[int]: ...
|
||||
|
||||
Scalar: TypeAlias = Union[int, float, bool, complex]
|
||||
Input: TypeAlias = Optional[Union[_TensorMetadata, List[_TensorMetadata], Scalar]]
|
||||
Scalar: TypeAlias = int | float | bool | complex
|
||||
Input: TypeAlias = _TensorMetadata | list[_TensorMetadata] | Scalar | None
|
||||
|
||||
class _ExtraFields_TorchOp:
|
||||
name: str
|
||||
@ -144,7 +136,7 @@ class _ExtraFields_TorchOp:
|
||||
allow_tf32_cublas: bool
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[Input]: ...
|
||||
def inputs(self) -> list[Input]: ...
|
||||
@property
|
||||
def scope(self) -> RecordScope: ...
|
||||
|
||||
@ -152,13 +144,13 @@ class _ExtraFields_Backend: ...
|
||||
|
||||
class _ExtraFields_Allocation:
|
||||
ptr: int
|
||||
id: Optional[int]
|
||||
id: int | None
|
||||
alloc_size: int
|
||||
total_allocated: int
|
||||
total_reserved: int
|
||||
|
||||
@property
|
||||
def allocation_id(self) -> Optional[int]: ...
|
||||
def allocation_id(self) -> int | None: ...
|
||||
@property
|
||||
def device(self) -> device: ...
|
||||
|
||||
@ -181,22 +173,22 @@ class _NNModuleInfo:
|
||||
@property
|
||||
def parameters(
|
||||
self,
|
||||
) -> List[Tuple[str, _TensorMetadata, Optional[_TensorMetadata]]]: ...
|
||||
) -> list[tuple[str, _TensorMetadata, _TensorMetadata | None]]: ...
|
||||
|
||||
class _OptimizerInfo:
|
||||
@property
|
||||
def parameters(
|
||||
self,
|
||||
) -> List[
|
||||
Tuple[
|
||||
) -> list[
|
||||
tuple[
|
||||
# Parameter
|
||||
_TensorMetadata,
|
||||
#
|
||||
# Gradient (if present during optimizer.step())
|
||||
Optional[_TensorMetadata],
|
||||
_TensorMetadata | None,
|
||||
#
|
||||
# Optimizer state for Parameter as (name, tensor) pairs
|
||||
List[Tuple[str, _TensorMetadata]],
|
||||
list[tuple[str, _TensorMetadata]],
|
||||
]
|
||||
]: ...
|
||||
|
||||
@ -210,9 +202,9 @@ class _ExtraFields_PyCall:
|
||||
@property
|
||||
def caller(self) -> _PyFrameState: ...
|
||||
@property
|
||||
def module(self) -> Optional[_NNModuleInfo]: ...
|
||||
def module(self) -> _NNModuleInfo | None: ...
|
||||
@property
|
||||
def optimizer(self) -> Optional[_OptimizerInfo]: ...
|
||||
def optimizer(self) -> _OptimizerInfo | None: ...
|
||||
|
||||
class _ExtraFields_Kineto: ...
|
||||
|
||||
@ -230,15 +222,15 @@ def gather_traceback(python: bool, script: bool, cpp: bool) -> CapturedTraceback
|
||||
|
||||
# The Dict has name, filename, line
|
||||
def symbolize_tracebacks(
|
||||
to_symbolize: List[CapturedTraceback],
|
||||
) -> List[List[Dict[str, str]]]: ...
|
||||
to_symbolize: list[CapturedTraceback],
|
||||
) -> list[list[dict[str, str]]]: ...
|
||||
|
||||
class _RecordFunctionFast:
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
input_values: Optional[Union[list, tuple]] = None,
|
||||
keyword_values: Optional[dict] = None,
|
||||
input_values: list | tuple | None = None,
|
||||
keyword_values: dict | None = None,
|
||||
) -> None: ...
|
||||
def __enter__(self) -> None: ...
|
||||
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: ...
|
||||
def __exit__(self, *args: Any) -> None: ...
|
||||
|
@ -1,6 +1,6 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import enum
|
||||
from typing import Any, Callable, Dict, List, Optional, overload, Set, Type
|
||||
from typing import Any, Callable, overload
|
||||
|
||||
import torch
|
||||
from torch.distributed.algorithms.join import Joinable, JoinHook
|
||||
@ -13,10 +13,10 @@ class _ZeROJoinHook(JoinHook):
|
||||
|
||||
class _DDPBucketAssignment:
|
||||
bucket_index: int
|
||||
parameters: List[torch.Tensor]
|
||||
parameters: list[torch.Tensor]
|
||||
offset: int
|
||||
device: torch.device
|
||||
tensor: Optional[torch.Tensor]
|
||||
tensor: torch.Tensor | None
|
||||
|
||||
class _OverlapStatus(enum.IntEnum):
|
||||
UNINITIALIZED: int = ...
|
||||
@ -32,7 +32,7 @@ class _OverlapInfo:
|
||||
bucket_index_to_future: Any = ...
|
||||
bucket_index_to_bucket: Any = ...
|
||||
bucket_indices_seen: Any = ...
|
||||
assigned_ranks_per_bucket: List[Set[int]] = ...
|
||||
assigned_ranks_per_bucket: list[set[int]] = ...
|
||||
total_size: int = ...
|
||||
shard_buckets: bool = ...
|
||||
def __init__(self) -> None: ...
|
||||
@ -48,34 +48,34 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
||||
global_rank: int = ...
|
||||
parameters_as_bucket_view: bool = ...
|
||||
optim: Any = ...
|
||||
_device_to_device_index: Dict[torch.device, int] = ...
|
||||
_device_to_device_index: dict[torch.device, int] = ...
|
||||
_overlap_with_ddp: bool = ...
|
||||
_overlap_info: _OverlapInfo = ...
|
||||
_buckets: List[List[torch.Tensor]] = ...
|
||||
_bucket_assignments_per_rank: List[Dict[int, _DDPBucketAssignment]] = ...
|
||||
_buckets: list[list[torch.Tensor]] = ...
|
||||
_bucket_assignments_per_rank: list[dict[int, _DDPBucketAssignment]] = ...
|
||||
def __init__(
|
||||
self,
|
||||
params: Any,
|
||||
optimizer_class: Type[Optimizer],
|
||||
process_group: Optional[Any] = ...,
|
||||
optimizer_class: type[Optimizer],
|
||||
process_group: Any | None = ...,
|
||||
parameters_as_bucket_view: bool = ...,
|
||||
overlap_with_ddp: bool = ...,
|
||||
**defaults: Any,
|
||||
) -> None: ...
|
||||
def add_param_group(self, param_group: Dict[str, Any]) -> None: ...
|
||||
def add_param_group(self, param_group: dict[str, Any]) -> None: ...
|
||||
def consolidate_state_dict(self, to: int = ...) -> None: ...
|
||||
@overload
|
||||
def step(self, closure: None = ..., **kwargs: Any) -> None: ...
|
||||
@overload
|
||||
def step(self, closure: Callable[[], float], **kwargs: Any) -> float: ...
|
||||
def load_state_dict(self, state_dict: Dict[str, Any]) -> None: ...
|
||||
def state_dict(self) -> Dict[str, Any]: ...
|
||||
def load_state_dict(self, state_dict: dict[str, Any]) -> None: ...
|
||||
def state_dict(self) -> dict[str, Any]: ...
|
||||
def _local_step(
|
||||
self,
|
||||
gradients: Optional[List[Optional[torch.Tensor]]] = None,
|
||||
closure: Optional[Callable[[], float]] = None,
|
||||
gradients: list[torch.Tensor | None] | None = None,
|
||||
closure: Callable[[], float] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Optional[float]: ...
|
||||
) -> float | None: ...
|
||||
def _get_assigned_rank(self, bucket_index: int) -> int: ...
|
||||
def _init_zero_for_overlap(self) -> None: ...
|
||||
def join_hook(self, **kwargs): ...
|
||||
|
@ -1,11 +1,15 @@
|
||||
from ._symbolic_trace import (
|
||||
from torch.fx._symbolic_trace import (
|
||||
symbolic_trace as symbolic_trace,
|
||||
Tracer as Tracer,
|
||||
wrap as wrap,
|
||||
)
|
||||
from .graph import Graph as Graph
|
||||
from .graph_module import GraphModule as GraphModule
|
||||
from .interpreter import Interpreter as Interpreter, Transformer as Transformer
|
||||
from .node import has_side_effect as has_side_effect, map_arg as map_arg, Node as Node
|
||||
from .proxy import Proxy as Proxy
|
||||
from .subgraph_rewriter import replace_pattern as replace_pattern
|
||||
from torch.fx.graph import Graph as Graph
|
||||
from torch.fx.graph_module import GraphModule as GraphModule
|
||||
from torch.fx.interpreter import Interpreter as Interpreter, Transformer as Transformer
|
||||
from torch.fx.node import (
|
||||
has_side_effect as has_side_effect,
|
||||
map_arg as map_arg,
|
||||
Node as Node,
|
||||
)
|
||||
from torch.fx.proxy import Proxy as Proxy
|
||||
from torch.fx.subgraph_rewriter import replace_pattern as replace_pattern
|
||||
|
@ -1,18 +1,6 @@
|
||||
# mypy: allow-untyped-defs
|
||||
# mypy: disable-error-code="type-arg"
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
overload,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, Callable, NamedTuple, overload, TypeVar
|
||||
from typing_extensions import Never, TypeAlias
|
||||
|
||||
from _typeshed import Incomplete
|
||||
@ -36,6 +24,7 @@ from torch.jit._recursive import (
|
||||
ScriptMethodStub as ScriptMethodStub,
|
||||
wrap_cpp_module as wrap_cpp_module,
|
||||
)
|
||||
from torch.jit._serialization import validate_map_location as validate_map_location
|
||||
from torch.jit._state import (
|
||||
_enabled as _enabled,
|
||||
_set_jit_function_cache as _set_jit_function_cache,
|
||||
@ -60,8 +49,6 @@ from torch.package import (
|
||||
)
|
||||
from torch.utils import set_module as set_module
|
||||
|
||||
from ._serialization import validate_map_location as validate_map_location
|
||||
|
||||
ScriptFunction = torch._C.ScriptFunction
|
||||
|
||||
type_trace_db: JitTypeTraceStore
|
||||
@ -116,7 +103,8 @@ class ConstMap:
|
||||
def __getattr__(self, attr): ...
|
||||
|
||||
def unpackage_script_module(
|
||||
importer: PackageImporter, script_module_id: str
|
||||
importer: PackageImporter,
|
||||
script_module_id: str,
|
||||
) -> torch.nn.Module: ...
|
||||
|
||||
_magic_methods: Incomplete
|
||||
@ -126,7 +114,7 @@ class RecursiveScriptClass:
|
||||
_props: Incomplete
|
||||
def __init__(self, cpp_class) -> None: ...
|
||||
def __getattr__(self, attr): ...
|
||||
def __setattr__(self, attr, value): ...
|
||||
def __setattr__(self, attr, value) -> None: ...
|
||||
def forward_magic_method(self, method_name, *args, **kwargs): ...
|
||||
def __getstate__(self) -> None: ...
|
||||
def __iadd__(self, other): ...
|
||||
@ -138,7 +126,7 @@ class ScriptModule(Module, metaclass=ScriptMeta):
|
||||
def __init__(self) -> None: ...
|
||||
forward: Callable[..., Any]
|
||||
def __getattr__(self, attr): ...
|
||||
def __setattr__(self, attr, value): ...
|
||||
def __setattr__(self, attr, value) -> None: ...
|
||||
def define(self, src): ...
|
||||
def _replicate_for_data_parallel(self): ...
|
||||
def __reduce_package__(self, exporter: PackageExporter): ...
|
||||
@ -146,7 +134,7 @@ class ScriptModule(Module, metaclass=ScriptMeta):
|
||||
@property
|
||||
def code(self) -> str: ...
|
||||
@property
|
||||
def code_with_constants(self) -> Tuple[str, ConstMap]: ...
|
||||
def code_with_constants(self) -> tuple[str, ConstMap]: ...
|
||||
@property
|
||||
def graph(self) -> torch.Graph: ...
|
||||
@property
|
||||
@ -177,7 +165,7 @@ class RecursiveScriptModule(ScriptModule):
|
||||
def graph_for(self, *args, **kwargs): ...
|
||||
def define(self, src) -> None: ...
|
||||
def __getattr__(self, attr): ...
|
||||
def __setattr__(self, attr, value): ...
|
||||
def __setattr__(self, attr, value) -> None: ...
|
||||
def __copy__(self): ...
|
||||
def __deepcopy__(self, memo): ...
|
||||
def forward_magic_method(self, method_name, *args, **kwargs): ...
|
||||
@ -200,59 +188,59 @@ def create_script_dict(obj): ...
|
||||
def create_script_list(obj, type_hint: Incomplete | None = ...): ...
|
||||
@overload
|
||||
def script(
|
||||
obj: Type[Module],
|
||||
optimize: Optional[bool] = None,
|
||||
obj: type[Module],
|
||||
optimize: bool | None = None,
|
||||
_frames_up: int = 0,
|
||||
_rcb: Optional[ResolutionCallback] = None,
|
||||
example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None,
|
||||
_rcb: ResolutionCallback | None = None,
|
||||
example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
|
||||
) -> Never: ...
|
||||
@overload
|
||||
def script( # type: ignore[misc]
|
||||
obj: Dict,
|
||||
optimize: Optional[bool] = None,
|
||||
obj: dict,
|
||||
optimize: bool | None = None,
|
||||
_frames_up: int = 0,
|
||||
_rcb: Optional[ResolutionCallback] = None,
|
||||
example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None,
|
||||
_rcb: ResolutionCallback | None = None,
|
||||
example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
|
||||
) -> torch.ScriptDict: ...
|
||||
@overload
|
||||
def script( # type: ignore[misc]
|
||||
obj: List,
|
||||
optimize: Optional[bool] = None,
|
||||
obj: list,
|
||||
optimize: bool | None = None,
|
||||
_frames_up: int = 0,
|
||||
_rcb: Optional[ResolutionCallback] = None,
|
||||
example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None,
|
||||
_rcb: ResolutionCallback | None = None,
|
||||
example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
|
||||
) -> torch.ScriptList: ...
|
||||
@overload
|
||||
def script( # type: ignore[misc]
|
||||
obj: Module,
|
||||
optimize: Optional[bool] = None,
|
||||
optimize: bool | None = None,
|
||||
_frames_up: int = 0,
|
||||
_rcb: Optional[ResolutionCallback] = None,
|
||||
example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None,
|
||||
_rcb: ResolutionCallback | None = None,
|
||||
example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
|
||||
) -> RecursiveScriptModule: ...
|
||||
@overload
|
||||
def script( # type: ignore[misc]
|
||||
obj: _ClassVar,
|
||||
optimize: Optional[bool] = None,
|
||||
optimize: bool | None = None,
|
||||
_frames_up: int = 0,
|
||||
_rcb: Optional[ResolutionCallback] = None,
|
||||
example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None,
|
||||
_rcb: ResolutionCallback | None = None,
|
||||
example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
|
||||
) -> _ClassVar: ...
|
||||
@overload
|
||||
def script( # type: ignore[misc]
|
||||
obj: Callable,
|
||||
optimize: Optional[bool] = None,
|
||||
optimize: bool | None = None,
|
||||
_frames_up: int = 0,
|
||||
_rcb: Optional[ResolutionCallback] = None,
|
||||
example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None,
|
||||
_rcb: ResolutionCallback | None = None,
|
||||
example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
|
||||
) -> ScriptFunction: ...
|
||||
@overload
|
||||
def script(
|
||||
obj: Any,
|
||||
optimize: Optional[bool] = None,
|
||||
optimize: bool | None = None,
|
||||
_frames_up: int = 0,
|
||||
_rcb: Optional[ResolutionCallback] = None,
|
||||
example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None,
|
||||
_rcb: ResolutionCallback | None = None,
|
||||
example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
|
||||
) -> RecursiveScriptClass: ...
|
||||
@overload
|
||||
def script(
|
||||
@ -260,7 +248,7 @@ def script(
|
||||
optimize: Incomplete | None = ...,
|
||||
_frames_up: int = ...,
|
||||
_rcb: Incomplete | None = ...,
|
||||
example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = ...,
|
||||
example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = ...,
|
||||
): ...
|
||||
def _check_overload_defaults(impl_defaults, overload_defaults, loc) -> None: ...
|
||||
def _compile_function_with_overload(overload_fn, qual_name, impl_fn): ...
|
||||
@ -279,7 +267,10 @@ class _ScriptProfileColumn:
|
||||
offset: Incomplete
|
||||
rows: Incomplete
|
||||
def __init__(
|
||||
self, header: str, alignment: int = ..., offset: int = ...
|
||||
self,
|
||||
header: str,
|
||||
alignment: int = ...,
|
||||
offset: int = ...,
|
||||
) -> None: ...
|
||||
def add_row(self, lineno: int, value: Any): ...
|
||||
def materialize(self): ...
|
||||
@ -288,7 +279,9 @@ class _ScriptProfileTable:
|
||||
cols: Incomplete
|
||||
source_range: Incomplete
|
||||
def __init__(
|
||||
self, cols: List[_ScriptProfileColumn], source_range: List[int]
|
||||
self,
|
||||
cols: list[_ScriptProfileColumn],
|
||||
source_range: list[int],
|
||||
) -> None: ...
|
||||
def dump_string(self): ...
|
||||
|
||||
|
@ -1,41 +1,29 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import builtins
|
||||
from typing import Optional, Tuple
|
||||
from typing_extensions import TypeGuard
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch import device, dtype, Tensor
|
||||
|
||||
class Parameter(Tensor):
|
||||
def __init__(
|
||||
self,
|
||||
data: Tensor = ...,
|
||||
requires_grad: builtins.bool = ...,
|
||||
): ...
|
||||
def __init__(self, data: Tensor = ..., requires_grad: bool = ...) -> None: ...
|
||||
|
||||
def is_lazy(param: Tensor): ...
|
||||
def is_lazy(
|
||||
param: Tensor,
|
||||
) -> TypeGuard[UninitializedParameter | UninitializedBuffer]: ...
|
||||
|
||||
class UninitializedParameter(Tensor):
|
||||
def __init__(
|
||||
self,
|
||||
data: Tensor = ...,
|
||||
requires_grad: builtins.bool = ...,
|
||||
): ...
|
||||
def __init__(self, data: Tensor = ..., requires_grad: bool = ...) -> None: ...
|
||||
def materialize(
|
||||
self,
|
||||
shape: Tuple[int, ...],
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
): ...
|
||||
shape: tuple[int, ...],
|
||||
device: device | None = None,
|
||||
dtype: dtype | None = None,
|
||||
) -> None: ...
|
||||
|
||||
class UninitializedBuffer(Tensor):
|
||||
def __init__(
|
||||
self,
|
||||
data: Tensor = ...,
|
||||
requires_grad: builtins.bool = ...,
|
||||
): ...
|
||||
def __init__(self, data: Tensor = ..., requires_grad: bool = ...) -> None: ...
|
||||
def materialize(
|
||||
self,
|
||||
shape: Tuple[int, ...],
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
): ...
|
||||
shape: tuple[int, ...],
|
||||
device: device | None = None,
|
||||
dtype: dtype | None = None,
|
||||
) -> None: ...
|
||||
|
@ -1,5 +1,5 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from typing import Any, Iterable, NamedTuple, Optional, overload, Sequence, Tuple, Union
|
||||
from typing import Any, Iterable, NamedTuple, overload, Sequence
|
||||
from typing_extensions import Self
|
||||
|
||||
from torch import Tensor
|
||||
@ -9,8 +9,8 @@ from torch.types import _dtype
|
||||
class PackedSequence_(NamedTuple):
|
||||
data: Tensor
|
||||
batch_sizes: Tensor
|
||||
sorted_indices: Optional[Tensor]
|
||||
unsorted_indices: Optional[Tensor]
|
||||
sorted_indices: Tensor | None
|
||||
unsorted_indices: Tensor | None
|
||||
|
||||
def bind(optional: Any, fn: Any): ...
|
||||
|
||||
@ -18,9 +18,9 @@ class PackedSequence(PackedSequence_):
|
||||
def __new__(
|
||||
cls,
|
||||
data: Tensor,
|
||||
batch_sizes: Optional[Tensor] = ...,
|
||||
sorted_indices: Optional[Tensor] = ...,
|
||||
unsorted_indices: Optional[Tensor] = ...,
|
||||
batch_sizes: Tensor | None = ...,
|
||||
sorted_indices: Tensor | None = ...,
|
||||
unsorted_indices: Tensor | None = ...,
|
||||
) -> Self: ...
|
||||
def pin_memory(self: Self) -> Self: ...
|
||||
def cuda(self: Self, *args: Any, **kwargs: Any) -> Self: ...
|
||||
@ -43,8 +43,8 @@ class PackedSequence(PackedSequence_):
|
||||
@overload
|
||||
def to(
|
||||
self: Self,
|
||||
device: Optional[DeviceLikeType] = None,
|
||||
dtype: Optional[_dtype] = None,
|
||||
device: DeviceLikeType | None = None,
|
||||
dtype: _dtype | None = None,
|
||||
non_blocking: bool = False,
|
||||
copy: bool = False,
|
||||
) -> Self: ...
|
||||
@ -59,7 +59,7 @@ class PackedSequence(PackedSequence_):
|
||||
def is_cuda(self) -> bool: ...
|
||||
def is_pinned(self) -> bool: ...
|
||||
|
||||
def invert_permutation(permutation: Optional[Tensor]): ...
|
||||
def invert_permutation(permutation: Tensor | None): ...
|
||||
def pack_padded_sequence(
|
||||
input: Tensor,
|
||||
lengths: Tensor,
|
||||
@ -70,10 +70,10 @@ def pad_packed_sequence(
|
||||
sequence: PackedSequence,
|
||||
batch_first: bool = ...,
|
||||
padding_value: float = ...,
|
||||
total_length: Optional[int] = ...,
|
||||
) -> Tuple[Tensor, ...]: ...
|
||||
total_length: int | None = ...,
|
||||
) -> tuple[Tensor, ...]: ...
|
||||
def pad_sequence(
|
||||
sequences: Union[Tensor, Iterable[Tensor]],
|
||||
sequences: Tensor | Iterable[Tensor],
|
||||
batch_first: bool = False,
|
||||
padding_value: float = ...,
|
||||
) -> Tensor: ...
|
||||
@ -83,7 +83,7 @@ def pack_sequence(
|
||||
) -> PackedSequence: ...
|
||||
def get_packed_sequence(
|
||||
data: Tensor,
|
||||
batch_sizes: Optional[Tensor],
|
||||
sorted_indices: Optional[Tensor],
|
||||
unsorted_indices: Optional[Tensor],
|
||||
batch_sizes: Tensor | None,
|
||||
sorted_indices: Tensor | None,
|
||||
unsorted_indices: Tensor | None,
|
||||
) -> PackedSequence: ...
|
||||
|
@ -1,5 +1,5 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from typing import Any, Dict, Optional, TYPE_CHECKING, Union
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
"""
|
||||
This was semi-automatically generated by running
|
||||
@ -24,13 +24,11 @@ Note that the import should happen before the call to install_config_module(), o
|
||||
assert TYPE_CHECKING, "Do not use at runtime"
|
||||
|
||||
def save_config() -> bytes: ...
|
||||
def save_config_portable() -> Dict[str, Any]: ...
|
||||
def save_config_portable() -> dict[str, Any]: ...
|
||||
def codegen_config() -> str: ...
|
||||
def get_hash() -> bytes: ...
|
||||
def to_dict() -> Dict[str, Any]: ...
|
||||
def shallow_copy_dict() -> Dict[str, Any]: ...
|
||||
def load_config(config: Union[bytes, Dict[str, Any]]) -> None: ...
|
||||
def get_config_copy() -> Dict[str, Any]: ...
|
||||
def patch(
|
||||
arg1: Optional[Union[str, Dict[str, Any]]] = None, arg2: Any = None, **kwargs
|
||||
): ...
|
||||
def to_dict() -> dict[str, Any]: ...
|
||||
def shallow_copy_dict() -> dict[str, Any]: ...
|
||||
def load_config(config: bytes | dict[str, Any]) -> None: ...
|
||||
def get_config_copy() -> dict[str, Any]: ...
|
||||
def patch(arg1: str | dict[str, Any] | None = None, arg2: Any = None, **kwargs): ...
|
||||
|
Reference in New Issue
Block a user