Add pyrefly suppressions to torch/distributed (7/n) (#165002)

Adds suppressions to torch/distributed so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

One more PR after this one.

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete the relevant lines from the project-excludes field in the pyrefly.toml file
step 2: run pyrefly check
step 3: add suppressions and clean up unused suppressions (see the illustrative sketch below for the suppression format)
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199

after:
INFO 0 errors (6,884 ignored)
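
For reference, the suppression format this diff adds throughout torch/distributed is a comment on the line directly above the flagged statement, naming the error class being silenced. A minimal illustrative sketch (the Backend stand-in and describe() helper are hypothetical, not code from this PR; only the comment format mirrors the diff below):

class Backend:  # hypothetical stand-in for torch.distributed.Backend
    def __init__(self, name: str) -> None:
        self.name = name

    def __str__(self) -> str:
        return self.name


def describe(backend: Backend) -> str:
    # Rebinding a parameter annotated as Backend to a plain str is the kind of
    # mismatch pyrefly reports as bad-assignment; the comment below silences
    # exactly that error on the next line and nothing else.
    # pyrefly: ignore # bad-assignment
    backend = str(backend)
    return f"backend={backend}"

Re-running pyrefly check after adding such comments is what produces the "0 errors (N ignored)" summary above.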

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165002
Approved by: https://github.com/oulgen
Author: Maggie Moss
Date: 2025-10-09 04:08:21 +00:00
Committed by: PyTorch MergeBot
Parent: ab94a0d544
Commit: 7457d139c5
100 changed files with 354 additions and 24 deletions


@@ -372,6 +372,7 @@ class BackendConfig:
def __init__(self, backend: Backend):
"""Init."""
self.device_backend_map: dict[str, Backend] = {}
# pyrefly: ignore # bad-assignment
backend = str(backend)
if backend == Backend.UNDEFINED:
@@ -392,6 +393,7 @@ class BackendConfig:
# e.g. "nccl", "gloo", "ucc", "mpi"
supported_devices = Backend.backend_capability[backend.lower()]
backend_val = Backend(backend)
# pyrefly: ignore # bad-assignment
self.device_backend_map = dict.fromkeys(supported_devices, backend_val)
elif ":" in backend.lower():
# Backend specified in "device:backend" format
@@ -410,6 +412,7 @@
f"Invalid device:backend pairing: \
{device_backend_pair_str}. {backend_str_error_message}"
)
# pyrefly: ignore # bad-assignment
device, backend = device_backend_pair
if device in self.device_backend_map:
raise ValueError(
@@ -1182,6 +1185,7 @@ def _as_iterable(obj) -> collections.abc.Iterable:
def _ensure_all_tensors_same_dtype(*tensors) -> None:
last_dtype = None
# pyrefly: ignore # bad-assignment
for tensor in itertools.chain.from_iterable(map(_as_iterable, tensors)):
tensor_dtype = tensor.dtype
# Mixing complex and its element type is allowed
@@ -1837,6 +1841,7 @@ def _get_split_source(pg):
split_from = pg._get_backend(pg.bound_device_id)
elif pg is _world.default_pg:
try:
# pyrefly: ignore # missing-attribute
split_from = pg._get_backend(torch.device("cuda"))
except RuntimeError:
# no cuda device associated with this backend
@@ -1997,7 +2002,12 @@ def _new_process_group_helper(
if not is_gloo_available():
raise RuntimeError("Distributed package doesn't have Gloo built in")
backend_class = ProcessGroupGloo(
backend_prefix_store, group_rank, group_size, timeout=timeout
# pyrefly: ignore # bad-argument-type
backend_prefix_store,
group_rank,
group_size,
# pyrefly: ignore # bad-argument-type
timeout=timeout,
)
backend_class.options.global_ranks_in_group = global_ranks_in_group
backend_class.options.group_name = group_name
@@ -2018,6 +2028,7 @@ def _new_process_group_helper(
# default backend_options for NCCL
backend_options = ProcessGroupNCCL.Options()
backend_options.is_high_priority_stream = False
# pyrefly: ignore # bad-argument-type
backend_options._timeout = timeout
if split_from:
@@ -2037,7 +2048,12 @@
# RuntimeError if is_ucc_available() returns false.
backend_class = ProcessGroupUCC(
backend_prefix_store, group_rank, group_size, timeout=timeout
# pyrefly: ignore # bad-argument-type
backend_prefix_store,
group_rank,
group_size,
# pyrefly: ignore # bad-argument-type
timeout=timeout,
)
backend_type = ProcessGroup.BackendType.UCC
elif backend_str == Backend.XCCL:
@@ -2046,6 +2062,7 @@
backend_options = ProcessGroupXCCL.Options()
backend_options.global_ranks_in_group = global_ranks_in_group
backend_options.group_name = group_name
# pyrefly: ignore # bad-argument-type
backend_options._timeout = timeout
backend_class = ProcessGroupXCCL(
backend_prefix_store, group_rank, group_size, backend_options
@@ -2070,6 +2087,7 @@ def _new_process_group_helper(
dist_backend_opts.store = backend_prefix_store
dist_backend_opts.group_rank = group_rank
dist_backend_opts.group_size = group_size
# pyrefly: ignore # bad-argument-type
dist_backend_opts.timeout = timeout
dist_backend_opts.group_id = group_name
dist_backend_opts.global_ranks_in_group = global_ranks_in_group
@@ -2113,6 +2131,7 @@ def _new_process_group_helper(
store=backend_prefix_store,
rank=group_rank,
world_size=group_size,
# pyrefly: ignore # bad-argument-type
timeout=timeout,
)
@@ -3322,6 +3341,7 @@ def gather_object(
return
assert object_gather_list is not None, "Must provide object_gather_list on dst rank"
# pyrefly: ignore # unbound-name
for i, tensor in enumerate(output_tensors):
tensor = tensor.type(torch.uint8)
tensor_size = object_size_list[i]
@@ -3698,8 +3718,10 @@ def broadcast_object_list(
# has only one element, we can skip the copy.
if my_group_rank == group_src:
if len(tensor_list) == 1: # type: ignore[possibly-undefined]
# pyrefly: ignore # unbound-name
object_tensor = tensor_list[0]
else:
# pyrefly: ignore # unbound-name
object_tensor = torch.cat(tensor_list)
else:
object_tensor = torch.empty( # type: ignore[call-overload]
@@ -3828,6 +3850,7 @@ def scatter_object_list(
broadcast(max_tensor_size, group_src=group_src, group=group)
# Scatter actual serialized objects
# pyrefly: ignore # no-matching-overload
output_tensor = torch.empty(
max_tensor_size.item(), dtype=torch.uint8, device=pg_device
)
@@ -4864,16 +4887,19 @@ def barrier(
if isinstance(device_ids, list):
opts.device_ids = device_ids
# use only the first device id
# pyrefly: ignore # read-only
opts.device = torch.device(device.type, device_ids[0])
elif getattr(group, "bound_device_id", None) is not None:
# Use device id from `init_process_group(device_id=...)`
opts.device = group.bound_device_id # type: ignore[assignment]
elif device.type == "cpu" or _get_object_coll_device(group) == "cpu":
# pyrefly: ignore # read-only
opts.device = torch.device("cpu")
else:
# Use the current device set by the user. If user did not set any, this
# may use default device 0, causing issues like hang or all processes
# creating context on device 0.
# pyrefly: ignore # read-only
opts.device = device
if group.rank() == 0:
warnings.warn( # warn only once
@@ -5004,6 +5030,7 @@ def _hash_ranks_to_str(ranks: list[int]) -> str:
# Takes a list of ranks and computes an integer color
def _process_group_color(ranks: list[int]) -> int:
# Convert list to tuple to make it hashable
# pyrefly: ignore # bad-assignment
ranks = tuple(ranks)
hash_value = hash(ranks)
# Split color must be: