[BE][PYFMT] migrate PYFMT for torch.{distributed,distributions} to ruff format (#144547)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144547
Approved by: https://github.com/kwen2501
This commit is contained in:
Xuehai Pan
2025-02-28 11:10:58 +08:00
committed by PyTorch MergeBot
parent 4e160d5fd9
commit 995df34b19
143 changed files with 920 additions and 774 deletions

View File

@ -137,13 +137,13 @@ def _gather_to_leader(sequence_id, worker_name, obj, worker_names=None):
with _all_gather_dict_lock:
if not worker_names:
worker_names = _ALL_WORKER_NAMES
assert (
worker_name in worker_names
), f"{worker_name} is not expected by leader."
assert worker_name in worker_names, (
f"{worker_name} is not expected by leader."
)
states = _all_gather_sequence_id_to_states[sequence_id]
assert (
worker_name not in states.gathered_objects
), f"{worker_name} reported intent sequence id {sequence_id} twice. "
assert worker_name not in states.gathered_objects, (
f"{worker_name} reported intent sequence id {sequence_id} twice. "
)
states.gathered_objects[worker_name] = obj
if worker_names == set(states.gathered_objects.keys()):
states.proceed_signal.set()
@ -153,9 +153,9 @@ def _broadcast_to_followers(sequence_id, objects_map):
with _all_gather_dict_lock:
states = _all_gather_sequence_id_to_states[sequence_id]
assert (
not states.proceed_signal.is_set()
), f"Termination signal sequence id {sequence_id} got set twice."
assert not states.proceed_signal.is_set(), (
f"Termination signal sequence id {sequence_id} got set twice."
)
states.gathered_objects = objects_map
states.proceed_signal.set()
@ -202,9 +202,9 @@ def _all_gather(obj, worker_names=None, timeout: float = UNSET_RPC_TIMEOUT):
function blocks until all workers have received the gathered results.
"""
if not worker_names:
assert (
_ALL_WORKER_NAMES is not None
), "`_ALL_WORKER_NAMES` is not initialized for `def _all_gather`."
assert _ALL_WORKER_NAMES is not None, (
"`_ALL_WORKER_NAMES` is not initialized for `def _all_gather`."
)
worker_names = _ALL_WORKER_NAMES
leader_name = min(worker_names)
@ -930,8 +930,7 @@ def _get_should_profile():
ActiveProfilerType = torch._C._profiler.ActiveProfilerType
return (
torch.autograd._profiler_enabled()
and torch._C._autograd._profiler_type()
== ActiveProfilerType.LEGACY # type: ignore[attr-defined]
and torch._C._autograd._profiler_type() == ActiveProfilerType.LEGACY # type: ignore[attr-defined]
)