Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Add pyrefly suppressions to torch/distributed (7/n) (#165002)
Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283. One more PR after this one.

Test plan: dmypy restart && python3 scripts/lintrunner.py -a pyrefly check

Step 1: delete lines in the pyrefly.toml file from the project-excludes field
Step 2: run pyrefly check
Step 3: add suppressions, clean up unused suppressions

Before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199
After: INFO 0 errors (6,884 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165002
Approved by: https://github.com/oulgen
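For context, a pyrefly suppression is an inline comment placed on the line directly above the reported error, naming the error code after a second "#". Below is a minimal, hypothetical sketch of the pattern this PR applies throughout the diff; the Backend and normalize names here are stand-ins, not code from torch/distributed (though torch's real Backend is likewise a str subclass):

# Hypothetical stand-in for torch.distributed's Backend (also a str subclass).
class Backend(str):
    pass


def normalize(backend: Backend) -> str:
    # Re-binding `backend` to a plain str no longer matches its declared type,
    # so pyrefly reports bad-assignment; the suppression sits on the preceding line.
    # pyrefly: ignore # bad-assignment
    backend = str(backend)
    return backend.lower()

Re-running pyrefly check after adding suppressions (step 3 above) is what surfaces any suppression that no longer matches an error so it can be cleaned up.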
Committed by: PyTorch MergeBot
Parent: ab94a0d544
Commit: 7457d139c5
@@ -372,6 +372,7 @@ class BackendConfig:
     def __init__(self, backend: Backend):
         """Init."""
         self.device_backend_map: dict[str, Backend] = {}
+        # pyrefly: ignore # bad-assignment
         backend = str(backend)

         if backend == Backend.UNDEFINED:
@@ -392,6 +393,7 @@ class BackendConfig:
             # e.g. "nccl", "gloo", "ucc", "mpi"
             supported_devices = Backend.backend_capability[backend.lower()]
             backend_val = Backend(backend)
+            # pyrefly: ignore # bad-assignment
             self.device_backend_map = dict.fromkeys(supported_devices, backend_val)
         elif ":" in backend.lower():
             # Backend specified in "device:backend" format
@@ -410,6 +412,7 @@
                         f"Invalid device:backend pairing: \
                         {device_backend_pair_str}. {backend_str_error_message}"
                     )
+                # pyrefly: ignore # bad-assignment
                 device, backend = device_backend_pair
                 if device in self.device_backend_map:
                     raise ValueError(
@@ -1182,6 +1185,7 @@ def _as_iterable(obj) -> collections.abc.Iterable:

 def _ensure_all_tensors_same_dtype(*tensors) -> None:
     last_dtype = None
+    # pyrefly: ignore # bad-assignment
     for tensor in itertools.chain.from_iterable(map(_as_iterable, tensors)):
         tensor_dtype = tensor.dtype
         # Mixing complex and its element type is allowed
@@ -1837,6 +1841,7 @@ def _get_split_source(pg):
         split_from = pg._get_backend(pg.bound_device_id)
     elif pg is _world.default_pg:
         try:
+            # pyrefly: ignore # missing-attribute
             split_from = pg._get_backend(torch.device("cuda"))
         except RuntimeError:
             # no cuda device associated with this backend
@@ -1997,7 +2002,12 @@ def _new_process_group_helper(
             if not is_gloo_available():
                 raise RuntimeError("Distributed package doesn't have Gloo built in")
             backend_class = ProcessGroupGloo(
-                backend_prefix_store, group_rank, group_size, timeout=timeout
+                # pyrefly: ignore # bad-argument-type
+                backend_prefix_store,
+                group_rank,
+                group_size,
+                # pyrefly: ignore # bad-argument-type
+                timeout=timeout,
             )
             backend_class.options.global_ranks_in_group = global_ranks_in_group
             backend_class.options.group_name = group_name
@@ -2018,6 +2028,7 @@ def _new_process_group_helper(
                 # default backend_options for NCCL
                 backend_options = ProcessGroupNCCL.Options()
                 backend_options.is_high_priority_stream = False
+            # pyrefly: ignore # bad-argument-type
             backend_options._timeout = timeout

             if split_from:
@@ -2037,7 +2048,12 @@ def _new_process_group_helper(
             # RuntimeError if is_ucc_available() returns false.

             backend_class = ProcessGroupUCC(
-                backend_prefix_store, group_rank, group_size, timeout=timeout
+                # pyrefly: ignore # bad-argument-type
+                backend_prefix_store,
+                group_rank,
+                group_size,
+                # pyrefly: ignore # bad-argument-type
+                timeout=timeout,
             )
             backend_type = ProcessGroup.BackendType.UCC
         elif backend_str == Backend.XCCL:
@@ -2046,6 +2062,7 @@ def _new_process_group_helper(
             backend_options = ProcessGroupXCCL.Options()
             backend_options.global_ranks_in_group = global_ranks_in_group
             backend_options.group_name = group_name
+            # pyrefly: ignore # bad-argument-type
             backend_options._timeout = timeout
             backend_class = ProcessGroupXCCL(
                 backend_prefix_store, group_rank, group_size, backend_options
@@ -2070,6 +2087,7 @@ def _new_process_group_helper(
                     dist_backend_opts.store = backend_prefix_store
                     dist_backend_opts.group_rank = group_rank
                     dist_backend_opts.group_size = group_size
+                    # pyrefly: ignore # bad-argument-type
                     dist_backend_opts.timeout = timeout
                     dist_backend_opts.group_id = group_name
                     dist_backend_opts.global_ranks_in_group = global_ranks_in_group
@@ -2113,6 +2131,7 @@ def _new_process_group_helper(
                             store=backend_prefix_store,
                             rank=group_rank,
                             world_size=group_size,
+                            # pyrefly: ignore # bad-argument-type
                             timeout=timeout,
                         )

@@ -3322,6 +3341,7 @@ def gather_object(
         return

     assert object_gather_list is not None, "Must provide object_gather_list on dst rank"
+    # pyrefly: ignore # unbound-name
     for i, tensor in enumerate(output_tensors):
         tensor = tensor.type(torch.uint8)
         tensor_size = object_size_list[i]
@@ -3698,8 +3718,10 @@ def broadcast_object_list(
     # has only one element, we can skip the copy.
     if my_group_rank == group_src:
         if len(tensor_list) == 1:  # type: ignore[possibly-undefined]
+            # pyrefly: ignore # unbound-name
             object_tensor = tensor_list[0]
         else:
+            # pyrefly: ignore # unbound-name
             object_tensor = torch.cat(tensor_list)
     else:
         object_tensor = torch.empty(  # type: ignore[call-overload]
@@ -3828,6 +3850,7 @@ def scatter_object_list(
     broadcast(max_tensor_size, group_src=group_src, group=group)

     # Scatter actual serialized objects
+    # pyrefly: ignore # no-matching-overload
     output_tensor = torch.empty(
         max_tensor_size.item(), dtype=torch.uint8, device=pg_device
     )
@@ -4864,16 +4887,19 @@ def barrier(
         if isinstance(device_ids, list):
             opts.device_ids = device_ids
             # use only the first device id
+            # pyrefly: ignore # read-only
             opts.device = torch.device(device.type, device_ids[0])
     elif getattr(group, "bound_device_id", None) is not None:
         # Use device id from `init_process_group(device_id=...)`
         opts.device = group.bound_device_id  # type: ignore[assignment]
     elif device.type == "cpu" or _get_object_coll_device(group) == "cpu":
+        # pyrefly: ignore # read-only
         opts.device = torch.device("cpu")
     else:
         # Use the current device set by the user. If user did not set any, this
         # may use default device 0, causing issues like hang or all processes
         # creating context on device 0.
+        # pyrefly: ignore # read-only
         opts.device = device
         if group.rank() == 0:
             warnings.warn(  # warn only once
@@ -5004,6 +5030,7 @@ def _hash_ranks_to_str(ranks: list[int]) -> str:
 # Takes a list of ranks and computes an integer color
 def _process_group_color(ranks: list[int]) -> int:
     # Convert list to tuple to make it hashable
+    # pyrefly: ignore # bad-assignment
     ranks = tuple(ranks)
     hash_value = hash(ranks)
     # Split color must be: