This should fix two types of failures that started with https://github.com/pytorch/pytorch/pull/163665. They happen when building with CMAKE_BUILD_TYPE=RelWithAssert.

Disclaimer that I used a lot of AI since I don't know how pybind works or what refcounts and pointers are, so I don't know if this is a good solution, or even a solution at all (fwiw the tests pass now).

The first type of failure is (truncated):

```
default_pg, _ = _new_process_group_helper(
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 2096, in _new_process_group_helper
    backend_class = creator_fn(dist_backend_opts, backend_options)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/distributed/fake_pg.py", line 25, in _create_fake_pg
    return FakeProcessGroup._create_internal(
RuntimeError: new_refcount != 1 INTERNAL ASSERT FAILED at "/var/lib/jenkins/workspace/c10/util/intrusive_ptr.h":319, please report a bug to PyTorch. intrusive_ptr: Cannot increase refcount after it reached zero.
Exception raised from retain_ at /var/lib/jenkins/workspace/c10/util/intrusive_ptr.h:319 (most recent call first):
C++ CapturedTraceback:
#4 std::_Function_handler<std::shared_ptr<c10::LazyValue<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > const> (), c10::SetStackTraceFetcher(std::function<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > ()>)::{lambda()#1}>::_M_invoke(std::_Any_data const&) from Logging.cpp:0
#5 c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) from ??:0
#6 c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) from ??:0
#7 c10::detail::torchInternalAssertFail(char const*, char const*, unsigned int, char const*, char const*) from ??:0
#8 void pybind11::class_<c10d::FakeProcessGroup, (anonymous namespace)::IntrusivePtrNoGilDestructor<c10d::FakeProcessGroup> >::init_instance<(anonymous namespace)::IntrusivePtrNoGilDestructor<c10d::FakeProcessGroup>, 0>(pybind11::detail::instance*, void const*) from init.cpp:0
#9 pybind11::detail::type_caster_generic::cast(void const*, pybind11::return_value_policy, pybind11::handle, pybind11::detail::type_info const*, void* (*)(void const*), void* (*)(void const*), void const*) from :0
#10 pybind11::cpp_function::initialize<torch::distributed::c10d::(anonymous namespace)::c10d_init(_object*, _object*)::{lambda(int, int, c10::intrusive_ptr<c10d::FakeProcessGroup::Options, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup::Options> >)#127}, c10::intrusive_ptr<c10d::FakeProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup> >, int, int, c10::intrusive_ptr<c10d::FakeProcessGroup::Options, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup::Options> >, pybind11::name, pybind11::scope, pybind11::sibling, pybind11::arg, pybind11::arg, pybind11::arg_v>(torch::distributed::c10d::(anonymous namespace)::c10d_init(_object*, _object*)::{lambda(int, int, c10::intrusive_ptr<c10d::FakeProcessGroup::Options, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup::Options> >)#127}&&, c10::intrusive_ptr<c10d::FakeProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup> > (*)(int, int, c10::intrusive_ptr<c10d::FakeProcessGroup::Options, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup::Options> >), pybind11::name const&, pybind11::scope const&, pybind11::sibling const&, pybind11::arg const&, pybind11::arg const&, pybind11::arg_v const&)::{lambda(pybind11::detail::function_call&)#3}::_FUN(pybind11::detail::function_call&) from init.cpp:0
```

I fix it here by getting rid of `DontIncreaseRefcount` and using `make_intrusive` to do the refcount handling instead. However, I also had to make the constructor public, which I think is not good, based on the reasoning of the original PR.

The second type of failure is:

```
Traceback (most recent call last):
  File "/var/lib/jenkins/workspace/test/test_testing.py", line 2415, in test_no_warning_on_import
    self.assertEqual(out, "")
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 4233, in assertEqual
    raise error_metas.pop()[0].to_error(  # type: ignore[index]
AssertionError: String comparison failed: "/opt/conda/envs/py_3.10/lib/python3.10/s[352 chars]):\n" != ''
- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/distributed/__init__.py:29: FutureWarning: pybind11-bound class 'torch._C._distributed_c10d.FakeProcessGroup' is using an old-style placement-new '__init__' which has been deprecated. See the upgrade guide in pybind11's docs. This message is only visible when compiled in debug mode.
-   if is_available() and not torch._C._c10d_init():

To execute this test, run the following from the base repo dir:
    python test/test_testing.py TestImports.test_no_warning_on_import
```

I fix this by getting rid of the `__init__`, which I think is OK since it'll just error if you try to construct one directly.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165479
Approved by: https://github.com/ezyang
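As a sanity check of the intended end state, here is a minimal sketch (mine, not part of the PR). The only pieces taken from the change itself are `FakeProcessGroup._create_internal` and the removal of the placement-new `__init__`; the `FakeStore` helper and the `"fake"` backend registration in `torch.testing._internal.distributed.fake_pg` are assumed to keep their current behavior.

```python
# Hedged sketch: illustrates the intended behavior after the fix, not the fix itself.
import torch.distributed as dist
from torch._C._distributed_c10d import FakeProcessGroup
from torch.testing._internal.distributed.fake_pg import FakeStore  # assumed test helper

# Usual route: the "fake" backend's creator_fn calls FakeProcessGroup._create_internal
# (this is the call that hit the refcount assert before the fix).
dist.init_process_group(backend="fake", rank=0, world_size=2, store=FakeStore())

# The factory can also be called directly, as fake_pg.py does.
pg = FakeProcessGroup._create_internal(0, 2)

# With the old placement-new __init__ removed, constructing the class directly should
# simply raise (pybind11's "No constructor defined") instead of emitting the
# FutureWarning that broke test_no_warning_on_import.
try:
    FakeProcessGroup(0, 2)
except TypeError:
    pass

dist.destroy_process_group()
```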
866 lines · 25 KiB · Python
# mypy: allow-untyped-defs
# mypy: disable-error-code="type-arg"
from datetime import timedelta
from enum import Enum
from typing import Any, Optional, overload, Union

import torch
from torch import Tensor
from torch._C import ScriptObject
from torch._C._autograd import DeviceType
from torch.futures import Future

# This module is defined in torch/csrc/distributed/c10d/init.cpp

_DEFAULT_FIRST_BUCKET_BYTES: int
_DEFAULT_NO_TIMEOUT: timedelta
_DEFAULT_PG_TIMEOUT: timedelta
_DEFAULT_PG_NCCL_TIMEOUT: timedelta

class BuiltinCommHookType(Enum):
    ALLREDUCE = ...
    FP16_COMPRESS = ...

def _register_comm_hook(reducer: Reducer, state: Any, comm_hook: Any): ...
def _register_builtin_comm_hook(
    reducer: Reducer,
    comm_hook_type: BuiltinCommHookType,
): ...
def _set_global_rank(rank: int) -> None: ...
def _hash_tensors(tensors: list[Tensor]) -> int: ...

class GradBucket:
    def index(self) -> int: ...
    def buffer(self) -> Tensor: ...
    def gradients(self) -> list[Tensor]: ...
    def is_last(self) -> bool: ...
    def set_buffer(self, tensor: Tensor) -> None: ...
    def parameters(self) -> list[Tensor]: ...

class Reducer:
    def __init__(
        self,
        params: list[Tensor],
        bucket_indices: list[list[int]],
        per_bucket_size_limits: list[int],
        process_group: ProcessGroup,
        expect_sparse_gradients: list[bool] = ...,
        bucket_bytes_cap: int = ..., # kDefaultBucketBytesCap in reducer.hpp
        find_unused_parameters: bool = ...,
        gradient_as_bucket_view: bool = ...,
        param_to_name_mapping: dict[int, str] = ...,
        first_bucket_types_cap: int = ..., # kDefaultFirstBucketBytes in reducer.hpp
        skip_all_reduce_unused_params: bool = ...,
        use_python_reducer: bool = ...,
    ) -> None: ...
    def prepare_for_forward(self) -> None: ...
    def prepare_for_backward(self, output: list[Tensor]) -> None: ...
    def get_backward_stats(self) -> list[int]: ...
    def _install_post_backward_futures(self, futures: list[Future]) -> None: ...
    def _rebuild_buckets(self) -> bool: ...
    def _get_zeros_like_grad_buckets(self) -> list[GradBucket]: ...
    def _push_all_rebuilt_params(self) -> None: ...
    def _set_forward_pass_work_handle(
        self,
        work: Work,
        use_static_world_size: bool,
    ): ...
    def _get_local_used_map(self) -> Tensor: ...
    def _set_ddp_runtime_logging_sample_rate(self, sample_rate: int) -> None: ...
    def _set_static_graph(self) -> None: ...
    def _run_comm_hook(self, bucket: GradBucket) -> Future: ...
    def set_logger(self, logger: Logger) -> None: ...
    def _remove_autograd_hooks(self) -> None: ...
    def _check_reducer_finalized(self) -> None: ...
    def _set_sparse_metadata(self, global_unique_ids: dict[str, Tensor]) -> None: ...
    def _reset_state(self) -> None: ...
    def _update_process_group(self, new_process_group: ProcessGroup) -> None: ...

class DDPLoggingData:
    strs_map: dict[str, str]
    ints_map: dict[str, int]

class Logger:
    def __init__(self, reducer: Reducer) -> None: ...
    def set_construction_data_and_log(
        self,
        module_name: str,
        device_ids: list[int],
        output_device: int,
        broadcast_buffers: bool,
        has_sync_bn: bool,
        static_graph: bool,
    ): ...
    def set_runtime_stats_and_log(self) -> None: ...
    def set_error_and_log(self, error: str) -> None: ...
    def _get_ddp_logging_data(self) -> DDPLoggingData: ...
    def _set_comm_hook_name(self, comm_hook: str) -> None: ...
    def _set_uneven_input_join(self) -> None: ...
    def _set_static_graph(self) -> None: ...

class _WorkerServer:
    def __init__(self, socket_path: str) -> None: ...
    def shutdown(self) -> None: ...

def get_debug_level(): ...
def set_debug_level(): ...
def set_debug_level_from_env(): ...

class DebugLevel(Enum):
    OFF = ...
    INFO = ...
    DETAIL = ...

class ReduceOp:
    # pyrefly: ignore # unknown-name
    def __init__(self, op: RedOpType) -> None: ...

    # pyrefly: ignore # unknown-name
    SUM: RedOpType = ...
    # pyrefly: ignore # unknown-name
    AVG: RedOpType = ...
    # pyrefly: ignore # unknown-name
    PRODUCT: RedOpType = ...
    # pyrefly: ignore # unknown-name
    MIN: RedOpType = ...
    # pyrefly: ignore # unknown-name
    MAX: RedOpType = ...
    # pyrefly: ignore # unknown-name
    BAND: RedOpType = ...
    # pyrefly: ignore # unknown-name
    BOR: RedOpType = ...
    # pyrefly: ignore # unknown-name
    BXOR: RedOpType = ...
    # pyrefly: ignore # unknown-name
    PREMUL_SUM: RedOpType = ...
    # pyrefly: ignore # unknown-name
    UNUSED: RedOpType = ...

    # mypy error being ignored:
    # Detected enum "torch._C._distributed_c10d.ReduceOp.RedOpType" in a type
    # stub with zero members. There is a chance this is due to a recent change
    # in the semantics of enum membership. If so, use `member = value` to mark
    # an enum member, instead of `member: type`
    class RedOpType(Enum): ... # type: ignore[misc]

class BroadcastOptions:
    rootRank: int
    rootTensor: int
    timeout: timedelta
    asyncOp: bool

class AllreduceOptions:
    reduceOp: ReduceOp
    timeout: timedelta
    asyncOp: bool
    sparseIndices: Optional[Tensor]

class AllreduceCoalescedOptions(AllreduceOptions): ...

class ReduceOptions:
    reduceOp: ReduceOp
    rootRank: int
    rootTensor: int
    timeout: timedelta
    asyncOp: bool

class AllgatherOptions:
    timeout: timedelta
    asyncOp: bool

class GatherOptions:
    rootRank: int
    timeout: timedelta
    asyncOp: bool

class ScatterOptions:
    rootRank: int
    timeout: timedelta
    asyncOp: bool

class ReduceScatterOptions:
    reduceOp: ReduceOp
    timeout: timedelta
    asyncOp: bool

class BarrierOptions:
    device_ids: list[int]
    device: torch.device
    timeout: timedelta
    asyncOp: bool

class AllToAllOptions:
    timeout: timedelta
    asyncOp: bool

class Store:
    def set(self, key: str, value: str): ...
    def get(self, key: str) -> bytes: ...
    def add(self, key: str, value: int) -> int: ...
    def check(self, keys: list[str]) -> bool: ...
    def compare_set(
        self,
        key: str,
        expected_value: str,
        desired_value: str,
    ) -> bytes: ...
    def delete_key(self, key: str) -> bool: ...
    def num_keys(self) -> int: ...
    def set_timeout(self, timeout: timedelta): ...
    @overload
    def wait(self, keys: list[str]): ...
    @overload
    def wait(self, keys: list[str], timeout: timedelta): ...
    def queue_pop(self, key: str, block: bool = True) -> bytes: ...
    def queue_push(self, key: str, value: Union[bytes, str]) -> None: ...
    def queue_len(self, key: str) -> int: ...

class FileStore(Store):
    def __init__(self, path: str, numWorkers: int = ...) -> None: ...

class HashStore(Store):
    def __init__(self) -> None: ...

class TCPStore(Store):
    def __init__(
        self,
        host_name: str,
        port: int,
        world_size: int | None = ...,
        is_master: bool = ...,
        timeout: timedelta = ...,
        wait_for_workers: bool = ...,
        multi_tenant: bool = ...,
        master_listen_fd: int | None = ...,
        use_libuv: bool | None = ...,
    ) -> None: ...
    @property
    def host(self) -> str: ...
    @property
    def port(self) -> int: ...

class PrefixStore(Store):
    def __init__(self, prefix: str, store: Store) -> None: ...
    @property
    def underlying_store(self) -> Store: ...

class _ControlCollectives:
    def barrier(self, key: str, timeout: timedelta, blocking: bool) -> None: ...
    def broadcast_send(self, key: str, data: str, timeout: timedelta) -> None: ...
    def broadcast_recv(self, key: str, timeout: timedelta) -> str: ...
    def gather_send(self, key: str, data: str, timeout: timedelta) -> None: ...
    def gather_recv(self, key: str, timeout: timedelta) -> str: ...
    def scatter_send(self, key: str, data: str, timeout: timedelta) -> None: ...
    def scatter_recv(self, key: str, timeout: timedelta) -> str: ...
    def all_gather(self, key: str, data: str, timeout: timedelta) -> str: ...
    def all_sum(self, key: str, data: int, timeout: timedelta) -> int: ...

class _StoreCollectives(_ControlCollectives):
    def __init__(self, store: Store, rank: int, world_size: int) -> None: ...

class _DistributedBackendOptions:
    def __init__(self) -> None: ...
    @property
    def store(self) -> Store: ...
    @store.setter
    def store(self, store: Store) -> None: ...
    @property
    def group_rank(self) -> int: ...
    @group_rank.setter
    def group_rank(self, rank: int) -> None: ...
    @property
    def group_size(self) -> int: ...
    @group_size.setter
    def group_size(self, size: int) -> None: ...
    @property
    def timeout(self) -> timedelta: ...
    @timeout.setter
    def timeout(self, timeout: timedelta) -> None: ...
    @property
    def group_id(self) -> str: ...
    @group_id.setter
    def group_id(self, group_id: str) -> None: ...
    @property
    def global_ranks_in_group(self) -> list[int]: ...
    @global_ranks_in_group.setter
    def global_ranks_in_group(self, ranks: list[int]) -> None: ...

class Work:
    def is_completed(self) -> bool: ...
    def is_success(self) -> bool: ...
    def exception(self) -> Any: ...
    def wait(self, timeout: timedelta = ...) -> bool: ...
    def block_current_stream(self) -> None: ...
    def get_future(self) -> Future: ...
    def source_rank(self) -> int: ...
    def _source_rank(self) -> int: ...
    def result(self) -> list[Tensor]: ...
    def synchronize(self) -> None: ...
    def boxed(self) -> ScriptObject: ...
    @staticmethod
    def unbox(obj: ScriptObject) -> Work: ...

class Backend:
    class Options:
        def __init__(self, backend: str, timeout: timedelta = ...) -> None: ...
        @property
        def backend(self) -> str: ...
        @property
        def _timeout(self) -> timedelta: ...
        @_timeout.setter
        def _timeout(self, val: timedelta) -> None: ...
        global_ranks_in_group: list[int]
        group_name: str

    def __init__(
        self,
        rank: int,
        size: int,
    ) -> None: ...
    @property
    def supports_splitting(self) -> bool: ...
    @property
    def supports_coalescing(self) -> bool: ...
    @property
    def supports_time_estimate(self) -> bool: ...
    def set_timeout(self, timeout: timedelta) -> None: ...
    @property
    def options(self) -> Options: ...
    def rank(self) -> int: ...
    def size(self) -> int: ...
    def name(self) -> str: ...
    def abort(self) -> None: ...
    def shutdown(self) -> None: ...
    def eager_connect_single_device(self, device: torch.device | None) -> None: ...
    def _set_sequence_number_for_group(self) -> None: ...
    def _set_default_timeout(self, timeout: timedelta) -> None: ...
    def get_error(self) -> ErrorType: ...
    def supports_tensor_alloc(self, device: torch.device) -> bool: ...
    def allocate_tensor(
        self,
        size: int,
        *,
        dtype: torch.dtype,
        device: torch.device,
    ) -> Tensor: ...
    @property
    def mem_allocator(self) -> Any: ...

class ProcessGroup:
    class BackendType(Enum):
        UNDEFINED = ...
        GLOO = ...
        NCCL = ...
        UCC = ...
        MPI = ...
        XCCL = ...
        CUSTOM = ...

    def __init__(
        self,
        store: Store,
        rank: int,
        size: int,
    ) -> None: ...
    def rank(self) -> int: ...
    def size(self) -> int: ...
    def get_group_store(self) -> Store: ...
    def split_group(
        self,
        new_ranks: list[int],
        timeout: Optional[timedelta] = None,
        opts: Optional[Backend.Options] = None,
        group_name: Optional[str] = None,
        group_desc: Optional[str] = None,
    ) -> Optional[ProcessGroup]: ...
    def merge_remote_group(
        self,
        store: Store,
        size: int,
        timeout: timedelta,
        group_name: Optional[str] = None,
        group_desc: Optional[str] = None,
    ) -> ProcessGroup: ...
    def abort(self) -> None: ...
    def set_timeout(self, timeout: timedelta) -> None: ...
    def shutdown(self) -> None: ...
    @overload
    def broadcast(
        self,
        tensors: list[Tensor],
        opts=...,
    ) -> Work: ...
    @overload
    def broadcast(
        self,
        tensor: Tensor,
        root: int,
        timeout: timedelta | None = None,
    ) -> Work: ...
    @overload
    def allreduce(
        self,
        tensors: list[Tensor],
        opts: AllreduceOptions = ...,
    ) -> Work: ...
    @overload
    def allreduce(
        self,
        tensors: list[Tensor],
        op=...,
        timeout: timedelta | None = None,
    ) -> Work: ...
    @overload
    def allreduce(
        self,
        tensor: Tensor,
        op=...,
        timeout: timedelta | None = None,
    ) -> Work: ...
    def allreduce_coalesced(
        self,
        tensors: list[Tensor],
        opts=...,
    ) -> Work: ...
    def reduce_scatter_tensor_coalesced(
        self,
        outputTensors: list[Tensor],
        inputTensors: list[Tensor],
        opts: ReduceScatterOptions | None = None,
    ) -> Work: ...
    @overload
    def reduce(
        self,
        tensors: list[Tensor],
        opts=...,
    ) -> Work: ...
    @overload
    def reduce(
        self,
        tensor: Tensor,
        root: int,
        op=...,
        timeout: timedelta | None = None,
    ) -> Work: ...
    @overload
    def allgather(
        self,
        output_tensors: list[list[Tensor]],
        input_tensors: list[Tensor],
        opts=...,
    ) -> Work: ...
    @overload
    def allgather(
        self,
        output_tensors: list[Tensor],
        input_tensor: Tensor,
        timeout: timedelta | None = None,
    ) -> Work: ...
    def _allgather_base(
        self,
        output: Tensor,
        input: Tensor,
        opts=...,
    ) -> Work: ...
    def allgather_coalesced(
        self,
        output_lists: list[list[Tensor]],
        input_list: list[Tensor],
        opts=...,
    ) -> Work: ...
    def allgather_into_tensor_coalesced(
        self,
        output_lists: list[Tensor],
        input_list: list[Tensor],
        opts=...,
    ) -> Work: ...
    @overload
    def gather(
        self,
        output_tensors: list[list[Tensor]],
        input_tensors: list[Tensor],
        opts=...,
    ) -> Work: ...
    @overload
    def gather(
        self,
        output_tensors: list[Tensor],
        input_tensor: Tensor,
        root: int,
        timeout: timedelta | None = None,
    ) -> Work: ...
    @overload
    def scatter(
        self,
        output_tensors: list[Tensor],
        input_tensors: list[list[Tensor]],
        opts=...,
    ) -> Work: ...
    @overload
    def scatter(
        self,
        output_tensor: Tensor,
        input_tensors: list[Tensor],
        root: int,
        timeout: timedelta | None = None,
    ) -> Work: ...
    @overload
    def reduce_scatter(
        self,
        output_tensors: list[Tensor],
        input_tensors: list[list[Tensor]],
        opts=...,
    ) -> Work: ...
    @overload
    def reduce_scatter(
        self,
        output_tensors: Tensor,
        input_tensor: list[Tensor],
        op=...,
        timeout: timedelta | None = None,
    ) -> Work: ...
    def _reduce_scatter_base(
        self,
        outputTensor: Tensor,
        inputTensor: Tensor,
        opts: ReduceScatterOptions | None,
    ) -> Work: ...
    @overload
    def alltoall_base(
        self,
        output_tensor: Tensor,
        input_tensor: Tensor,
        output_split_sizes: list[int],
        input_split_sizes: list[int],
        opts=...,
    ) -> Work: ...
    @overload
    def alltoall_base(
        self,
        output: Tensor,
        input: Tensor,
        output_split_sizes: list[int],
        input_split_sizes: list[int],
        timeout: timedelta | None = None,
    ) -> Work: ...
    @overload
    def alltoall(
        self,
        output_tensor: list[Tensor],
        input_tensor: list[Tensor],
        opts=...,
    ) -> Work: ...
    @overload
    def alltoall(
        self,
        output: list[Tensor],
        input: list[Tensor],
        timeout: timedelta | None = None,
    ) -> Work: ...
    def send(
        self,
        tensors: list[Tensor],
        dstRank: int,
        tag: int,
    ) -> Work: ...
    def recv(
        self,
        tensors: list[Tensor],
        srcRank: int,
        tag: int,
    ) -> Work: ...
    def recv_anysource(self, tensors: list[Tensor], tag: int) -> Work: ...
    @overload
    def barrier(self, opts=...) -> Work: ...
    @overload
    def barrier(self, timeout: timedelta | None = None) -> Work: ...
    def boxed(self) -> ScriptObject: ...
    @staticmethod
    def unbox(obj: ScriptObject) -> ProcessGroup: ...
    def _start_coalescing(self, device: torch.device) -> None: ...
    def _end_coalescing(self, device: torch.device) -> Work: ...
    def _get_backend_name(self) -> str: ...
    def _backend_id(self, backend_type: BackendType) -> int: ...
    @property
    def _device_types(self) -> list[torch.device]: ...
    def _get_backend(self, device: torch.device) -> Backend: ...
    def _set_default_backend(self, backend_type: BackendType) -> None: ...
    def _register_backend(
        self,
        device: torch.device,
        backend_type: BackendType,
        backend: Backend | None,
    ) -> None: ...
    def _set_group_name(self, name: str) -> None: ...
    def _set_group_desc(self, desc: str) -> None: ...
    def name(self) -> str: ...
    def _has_hooks(self) -> bool: ...
    def _wait_for_pending_works(self) -> None: ...
    def _set_sequence_number_for_group(self) -> None: ...
    @property
    def bound_device_id(self) -> torch.device | None: ...
    @bound_device_id.setter
    def bound_device_id(self, device: torch.device | None) -> None: ...
    @property
    def group_name(self) -> str: ...
    @property
    def group_desc(self) -> str: ...

class FakeProcessGroup(Backend):
    @staticmethod
    def _create_internal(rank: int, world_size: int) -> FakeProcessGroup: ...

class FakeWork(Work):
    seq_id: int
    def __init__(self) -> None: ...
    def wait(self, timeout: timedelta = ...) -> bool: ...
    def getFuture(self) -> Future: ...

class ProcessGroupGloo(Backend):
    class Device: ...

    class Options(Backend.Options):
        devices: list[ProcessGroupGloo.Device]
        threads: int

        def __init__(self): ...

    def __init__(
        self,
        store: Store,
        rank: int,
        size: int,
        timeout: timedelta,
    ) -> None: ...
    @staticmethod
    def create_device(hostname="", interface="", lazy_init=None) -> Device: ...
    @staticmethod
    def create_default_device(lazy_init=None) -> Device: ...
    def _set_default_timeout(self, timeout) -> None: ...
    @property
    def options(self) -> Options: ... # type: ignore[override]

class _ProcessGroupWrapper(Backend):
    def __init__(self, pg: Backend, gloo_pg: ProcessGroupGloo) -> None: ...
    wrapped_pg: Backend

class ErrorType(Enum):
    SUCCESS = ...
    TIMEOUT = ...
    COMM_ERROR = ...
    REMOTE_ERROR = ...

class ProcessGroupNCCL(Backend):
    class NCCLConfig:
        blocking: int
        cga_cluster_size: int
        min_ctas: int
        max_ctas: int
        def unsafe_get_ptr(self) -> int: ...

    class Options(Backend.Options):
        config: ProcessGroupNCCL.NCCLConfig
        is_high_priority_stream: bool
        split_from: ProcessGroupNCCL
        split_color: int

        def __init__(self, is_high_priority_stream: bool = False): ...

    def __init__(
        self,
        store: Store,
        rank: int,
        size: int,
        options: Options,
    ) -> None: ...
    def _group_start(self) -> None: ...
    def _group_end(self) -> None: ...
    def _start_time_estimate(self) -> None: ...
    def _end_time_estimate(self) -> float: ...
    def _set_default_timeout(self, timeout) -> None: ...
    def perform_nocolor_split(self, device: torch.device) -> None: ...
    def register_mem_pool(self, pool: torch.cuda.MemPool) -> None: ...
    def deregister_mem_pool(self, pool: torch.cuda.MemPool) -> None: ...
    def comm_split_count(self) -> int: ...
    def _add_ephemeral_timeout(self, timeout: timedelta) -> None: ...
    def abort(self) -> None: ...
    def _is_initialized(self) -> bool: ...
    @property
    def uid(self) -> int: ...
    @property
    def options(self) -> Options: ... # type: ignore[override]
    @staticmethod
    def get_build_nccl_version(self) -> tuple[int, int, int]: ...
    @staticmethod
    def get_runtime_nccl_version(self) -> tuple[int, int, int]: ...

class ProcessGroupUCC(Backend):
    def __init__(
        self,
        store: Store,
        rank: int,
        size: int,
        timeout: timedelta,
    ) -> None: ...

class ProcessGroupMPI(Backend):
    def __init__(
        self,
        rank: int,
        size: int,
        pgComm: int,
    ) -> None: ...
    @staticmethod
    def create(ranks: list[int]) -> ProcessGroupMPI: ...

def _compute_bucket_assignment_by_size(
    tensors: list[Tensor],
    bucket_size_limits: list[int],
    expect_sparse_gradient: list[bool] = ...,
    tensor_indices: list[int] = ...,
) -> tuple[list[list[int]], list[int]]: ...
def _broadcast_coalesced(
    process_group: ProcessGroup,
    tensors: list[Tensor],
    buffer_size: int,
    src: int,
): ...
def _test_python_store(store: Store): ...
def _verify_params_across_processes(
    process_group: ProcessGroup,
    params: list[Tensor],
    logger: Logger | None,
): ...
def _make_nccl_premul_sum(factor: float | list[Tensor]) -> ReduceOp: ...
def _register_process_group(
    group_name: str,
    process_group: ProcessGroup,
) -> None: ...
def _resolve_process_group(group_name: str) -> ProcessGroup: ...
def _register_work(tensor: torch.Tensor, work: Work) -> ProcessGroup: ...
def _get_work_registry_size() -> int: ...
def _set_allow_inflight_collective_as_graph_input(
    value: bool,
) -> None: ...
def _allow_inflight_collective_as_graph_input() -> bool: ...
def _unregister_all_process_groups() -> None: ...
def _unregister_process_group(group_name: str) -> None: ...

# Initializes the device state in CUmodule so that it's able to perform NVSHMEM
# operations. CUmodule is a pointer to a CUDA module, carried by a int64 in
# Python. At C++ interface, it is converted to a uintptr_t.
def _nvshmemx_cumodule_init(module: int) -> None: ...

# Check if NVSHMEM is available on current system.
def _is_nvshmem_available() -> bool: ...

class _SymmetricMemory:
    @staticmethod
    def set_group_info(
        group_name: str,
        rank: int,
        world_size: int,
        store: Store,
    ) -> None: ...
    @staticmethod
    def empty_strided_p2p(
        size: torch.types._size,
        stride: torch.types._size,
        dtype: torch.dtype,
        device: torch.device,
        group_name: str | None = None,
        alloc_id: int | None = None,
    ) -> torch.Tensor: ...
    @staticmethod
    def has_multicast_support(
        device_type: DeviceType,
        device_idx: int,
    ) -> bool: ...
    # Set Symmetric Memory allocation backend.
    @staticmethod
    def set_backend(name: str) -> None: ...
    @staticmethod
    def get_backend(device: torch.device) -> Optional[str]: ...
    @staticmethod
    def get_mempool_allocator(device: torch.device) -> Any: ...
    @property
    def rank(self) -> int: ...
    @property
    def world_size(self) -> int: ...
    @staticmethod
    def rendezvous(
        tensor: torch.Tensor, group_name: str | None = None
    ) -> _SymmetricMemory: ...
    def get_buffer(
        self,
        rank: int,
        sizes: torch.types._size,
        dtype: torch.dtype,
        storage_offset: int | None = 0,
    ) -> torch.Tensor: ...
    def get_signal_pad(
        self,
        rank: int,
        sizes: torch.types._size = [],
        dtype: torch.dtype | None = None,
        storage_offset: int | None = 0,
    ) -> torch.Tensor: ...
    def barrier(self, channel: int = 0, timeout_ms: int = 0) -> None: ...
    def put_signal(
        self,
        dst_rank: int,
        channel: int = 0,
        timeout_ms: int = 0,
    ) -> None: ...
    def wait_signal(
        self,
        src_rank: int,
        channel: int = 0,
        timeout_ms: int = 0,
    ) -> None: ...
    def get_remote_tensor(
        self,
        peer: int,
        sizes: torch.types._size,
        dtype: torch.dtype,
    ) -> torch.Tensor: ...
    @staticmethod
    def memset32(
        tensor: torch.Tensor, offset: int, val: int, count: int = 1
    ) -> torch.Tensor: ...
    @staticmethod
    def stream_write_value32(
        tensor: torch.Tensor, offset: int, val: int
    ) -> torch.Tensor: ...
    @property
    def buffer_ptrs(self) -> list[int]: ...
    @property
    def buffer_ptrs_dev(self) -> int: ...
    @property
    def signal_pad_ptrs(self) -> list[int]: ...
    @property
    def signal_pad_ptrs_dev(self) -> int: ...
    @property
    def multicast_ptr(self) -> int: ...
    @property
    def buffer_size(self) -> int: ...
    @property
    def signal_pad_size(self) -> int: ...

class ProcessGroupXCCL(Backend):
    class Options(Backend.Options):
        def __init__(self): ...

    def __init__(
        self,
        store: Store,
        rank: int,
        size: int,
        options: Options,
    ) -> None: ...
    @property
    def options(self) -> Options: ... # type: ignore[override]

def _set_process_group(pg: ProcessGroup) -> None: ...
def _current_process_group() -> ProcessGroup: ...