Added device detection to communication logging (#7398)

In `comms_logging.py`, when `log_all` is called with the `show_straggler`
option enabled, an `all_reduce` is performed across all nodes to
calculate the minimum latency and find stragglers. However, the tensors
on which this is performed are not placed on the configured device. This
commit creates them on the current accelerator device, obtained through
DeepSpeed's abstract accelerator API.

Resolves #7397

Signed-off-by: Alex Kiefer <alexkiefer51@gmail.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
This commit is contained in:
Alexander Kiefer
2025-06-28 02:56:19 -04:00
committed by GitHub
parent 6594c266c2
commit 4c687bfdac

View File

@ -128,6 +128,8 @@ class CommsLogger:
from deepspeed.utils.timer import trim_mean
import deepspeed.comm as dist
from deepspeed.comm.reduce_op import ReduceOp
from deepspeed.accelerator import get_accelerator
if print_log:
print(
f"{'Comm. Op': <20}{'Message Size': <20}{'Count': <20}{'Total Latency(ms)': <20}{'Avg Latency(ms)': <20}{'tput_avg (Gbps)': <20}{'busbw_avg (Gbps)': <20}"
@ -158,6 +160,7 @@ class CommsLogger:
print(
f"{'Comm. Op': <20}{'Message Size': <20}{'Count': <20}{'Total comm lat(ms)': <20}{'Total straggler(ms)': <20}{'Avg comm lat(ms)': <20}{'Avg straggler(ms)': <20}"
)
device = get_accelerator().current_device_name()
for record_name in self.comms_dict.keys():
if print_log:
print(record_name)
@ -165,8 +168,8 @@ class CommsLogger:
# vals[0] is the count for each msg size
count = vals[0]
# vals[1] is a list of latency records for each msg size
lats = torch.tensor(vals[1])
min_lats = torch.tensor(vals[1])
lats = torch.tensor(vals[1], device=device)
min_lats = torch.tensor(vals[1], device=device)
dist.all_reduce(min_lats, op=ReduceOp.MIN)
total_lat = min_lats.sum().item()
total_straggler = (lats - min_lats).sum().item()