mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 15:33:51 +08:00
Added device detection to communication logging (#7398)
In `comms_logging.py`, when calling log_all and the `show_straggler` option is enabled, an all_reduce is performed across all nodes to calculate the minimum latency to find stragglers. However, the tensors on which this is performed are not sent to the configured devices. This commit adds this capability using deepspeed's abstract accelerator api. Resolves #7397 Signed-off-by: Alex Kiefer <alexkiefer51@gmail.com> Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
This commit is contained in:
@ -128,6 +128,8 @@ class CommsLogger:
|
||||
from deepspeed.utils.timer import trim_mean
|
||||
import deepspeed.comm as dist
|
||||
from deepspeed.comm.reduce_op import ReduceOp
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
|
||||
if print_log:
|
||||
print(
|
||||
f"{'Comm. Op': <20}{'Message Size': <20}{'Count': <20}{'Total Latency(ms)': <20}{'Avg Latency(ms)': <20}{'tput_avg (Gbps)': <20}{'busbw_avg (Gbps)': <20}"
|
||||
@ -158,6 +160,7 @@ class CommsLogger:
|
||||
print(
|
||||
f"{'Comm. Op': <20}{'Message Size': <20}{'Count': <20}{'Total comm lat(ms)': <20}{'Total straggler(ms)': <20}{'Avg comm lat(ms)': <20}{'Avg straggler(ms)': <20}"
|
||||
)
|
||||
device = get_accelerator().current_device_name()
|
||||
for record_name in self.comms_dict.keys():
|
||||
if print_log:
|
||||
print(record_name)
|
||||
@ -165,8 +168,8 @@ class CommsLogger:
|
||||
# vals[0] is the count for each msg size
|
||||
count = vals[0]
|
||||
# vals[1] is a list of latency records for each msg size
|
||||
lats = torch.tensor(vals[1])
|
||||
min_lats = torch.tensor(vals[1])
|
||||
lats = torch.tensor(vals[1], device=device)
|
||||
min_lats = torch.tensor(vals[1], device=device)
|
||||
dist.all_reduce(min_lats, op=ReduceOp.MIN)
|
||||
total_lat = min_lats.sum().item()
|
||||
total_straggler = (lats - min_lats).sum().item()
|
||||
|
Reference in New Issue
Block a user