Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
[BE][PYFMT] migrate PYFMT for torch.{distributed,distributions} to ruff format (#144547)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144547
Approved by: https://github.com/kwen2501
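The change itself is mechanical: these files were reformatted with ruff's formatter in place of the previous black-based PYFMT setup. Assuming a checkout with ruff installed and the repo's lint config in place, something like the following command regenerates the style (the paths are illustrative, not quoted from the PR):

    ruff format torch/distributed/ torch/distributions/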
committed by PyTorch MergeBot
parent 4e160d5fd9
commit 995df34b19
@@ -137,13 +137,13 @@ def _gather_to_leader(sequence_id, worker_name, obj, worker_names=None):
     with _all_gather_dict_lock:
         if not worker_names:
             worker_names = _ALL_WORKER_NAMES
-            assert (
-                worker_name in worker_names
-            ), f"{worker_name} is not expected by leader."
+            assert worker_name in worker_names, (
+                f"{worker_name} is not expected by leader."
+            )
         states = _all_gather_sequence_id_to_states[sequence_id]
-        assert (
-            worker_name not in states.gathered_objects
-        ), f"{worker_name} reported intent sequence id {sequence_id} twice. "
+        assert worker_name not in states.gathered_objects, (
+            f"{worker_name} reported intent sequence id {sequence_id} twice. "
+        )
         states.gathered_objects[worker_name] = obj
         if worker_names == set(states.gathered_objects.keys()):
             states.proceed_signal.set()
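The rewritten asserts are behavior-preserving: in the new form the parentheses only group the message expression, so `assert` still receives a (condition, message) pair. A minimal sketch (mine, not part of the diff, with hypothetical names) demonstrating the equivalence and the classic pitfall this style deliberately avoids:

    # Hypothetical values for illustration only.
    worker_name, worker_names = "trainer3", {"trainer0", "trainer1"}

    try:
        assert worker_name in worker_names, (
            f"{worker_name} is not expected by leader."
        )
    except AssertionError as err:
        print(err)  # -> trainer3 is not expected by leader.

    # What ruff does NOT produce: parenthesizing the whole statement builds
    # a non-empty tuple, which is always truthy, so the assert never fires.
    # assert (worker_name in worker_names, "message")  # bug: always passes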
@@ -153,9 +153,9 @@ def _broadcast_to_followers(sequence_id, objects_map):
     with _all_gather_dict_lock:
         states = _all_gather_sequence_id_to_states[sequence_id]

-        assert (
-            not states.proceed_signal.is_set()
-        ), f"Termination signal sequence id {sequence_id} got set twice."
+        assert not states.proceed_signal.is_set(), (
+            f"Termination signal sequence id {sequence_id} got set twice."
+        )
         states.gathered_objects = objects_map
         states.proceed_signal.set()
@@ -202,9 +202,9 @@ def _all_gather(obj, worker_names=None, timeout: float = UNSET_RPC_TIMEOUT):
     function blocks until all workers have received the gathered results.
     """
     if not worker_names:
-        assert (
-            _ALL_WORKER_NAMES is not None
-        ), "`_ALL_WORKER_NAMES` is not initialized for `def _all_gather`."
+        assert _ALL_WORKER_NAMES is not None, (
+            "`_ALL_WORKER_NAMES` is not initialized for `def _all_gather`."
+        )
         worker_names = _ALL_WORKER_NAMES
     leader_name = min(worker_names)
@@ -930,8 +930,7 @@ def _get_should_profile():
     ActiveProfilerType = torch._C._profiler.ActiveProfilerType
     return (
         torch.autograd._profiler_enabled()
-        and torch._C._autograd._profiler_type()
-        == ActiveProfilerType.LEGACY  # type: ignore[attr-defined]
+        and torch._C._autograd._profiler_type() == ActiveProfilerType.LEGACY  # type: ignore[attr-defined]
     )
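Note that the last hunk goes the other way: rather than splitting a long assert, ruff joins the `==` comparison onto a single line, presumably because it now fits within the configured line length. A useful side effect is that the `# type: ignore[attr-defined]` pragma stays on the same physical line as the `torch._C._autograd._profiler_type()` access it suppresses, which matters because mypy applies `type: ignore` comments per line.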