DeepSpeed/deepspeed/utils/comms_logging.py
Alexander Kiefer f485e1369e Improvements to Communication Logger (#7404)
These changes add:
- ability to return communication logs as dictionaries, rather than only printing to stdout
- convenience helper functions for getting information about current logs
- ability to clear existing log operations
- additional documentation for logging operations

These changes address points raised in #7403

---------

Signed-off-by: Alex Kiefer <alexkiefer51@gmail.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Co-authored-by: Hongwei Chen <33092912+hwchen2017@users.noreply.github.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
2025-07-09 13:35:28 -04:00


# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import math

from deepspeed.utils import log_dist


def get_caller_func(frame=3):
    import sys
    return sys._getframe(frame).f_code.co_name


def print_rank_0(message):
    import deepspeed.comm as dist
    if dist.get_rank() == 0:
        print(message)


# Helper function to pretty-print message sizes
def convert_size(size_bytes):
    if size_bytes == 0:
        return "0B"
    size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
    i = int(math.floor(math.log(size_bytes, 1024)))
    p = math.pow(1024, i)
    s = round(size_bytes / p, 2)
    return "%s %s" % (s, size_name[i])
# Helper function to calculate algbw and busbw.
# See https://gist.github.com/jeffra/b5e80466b4c86be00ea3b6f130fb7a36 and
# https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md
def calc_bw_log(comm_op, size, duration):
    import deepspeed.comm as dist

    n = dist.get_world_size()
    tput = 0
    busbw = 0
    if comm_op == "all_to_all_single":
        tput = (size / duration)
        busbw = (size / duration) * ((n - 1) / n)
    elif comm_op in ("all_gather", "all_gather_into_tensor", "reduce_scatter", "reduce_scatter_tensor"):
        size *= n
        tput = (size / duration)
        busbw = (size / duration) * ((n - 1) / n)
    elif comm_op in ("all_reduce", "all_reduce_coalesced", "inference_all_reduce"):
        tput = (size * 2 / duration)
        busbw = (size / duration) * (2 * (n - 1) / n)
    elif comm_op in ("send", "recv", "isend", "irecv", "broadcast", "reduce", "gather", "scatter", "barrier"):
        tput = (size / duration)
        busbw = tput
    else:
        print_rank_0("wrong comm_op specified")  # noqa: F821
        exit(0)

    # convert to Gbps
    tput *= 8
    busbw *= 8

    tput /= 1e6
    busbw /= 1e6

    return tput, busbw
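

# Illustrative worked example (assumes `size` is in bytes and `duration` is in
# milliseconds, matching how CommsLogger.append() calls this function):
# an all_reduce of 1e9 bytes across n=8 ranks that takes 250 ms gives
#   algbw = 2 * 1e9 / 250 * 8 / 1e6 = 64 Gbps
#   busbw = algbw * (n - 1) / n     = 56 Gbps
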
class CommsLogger:

    def __init__(self):
        from deepspeed.comm.constants import (COMMS_LOGGER_VERBOSE_DEFAULT, COMMS_LOGGER_DEBUG_DEFAULT,
                                              COMMS_LOGGER_PROF_OPS_DEFAULT, COMMS_LOGGER_PROF_ALL_DEFAULT,
                                              COMMS_LOGGER_ENABLED_DEFAULT)
        self.comms_dict = {}
        self.verbose = COMMS_LOGGER_VERBOSE_DEFAULT
        self.debug = COMMS_LOGGER_DEBUG_DEFAULT
        self.prof_ops = COMMS_LOGGER_PROF_OPS_DEFAULT
        self.prof_all = COMMS_LOGGER_PROF_ALL_DEFAULT
        self.enabled = COMMS_LOGGER_ENABLED_DEFAULT

    def configure(self, comms_config):
        self.enabled = comms_config.comms_logger_enabled
        if self.enabled:
            self.verbose = comms_config.comms_logger.verbose
            self.debug = comms_config.comms_logger.debug
            self.prof_ops = comms_config.comms_logger.prof_ops
            self.prof_all = comms_config.comms_logger.prof_all

    # There are three settings for the op profiler:
    # - Global profiling (profile all comms)
    # - Op-type profiling (e.g. profile all all_reduce comms)
    # - Op profiling (e.g. profile a specific all_reduce op)
    def start_profiling_comms(self):
        self.prof_all = True

    def stop_profiling_comms(self):
        self.prof_all = False
    # E.g. start_profiling_op('all_reduce')
    def start_profiling_op(self, op_name_list):
        self.prof_ops = list(set(self.prof_ops) | set(op_name_list))

    def stop_profiling_op(self, op_name_list):
        self.prof_ops = [op for op in self.prof_ops if op not in op_name_list]
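
    # Illustrative sketch of the profiling controls above (assumes `logger` is a
    # configured CommsLogger with logging enabled; the variable name is hypothetical):
    #   logger.start_profiling_op(['all_reduce'])   # profile only all_reduce ops
    #   ...run some training steps...
    #   logger.stop_profiling_op(['all_reduce'])
    #   logger.start_profiling_comms()               # or profile every comm op
    #   logger.stop_profiling_comms()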
    # Add log entry
    def append(self, raw_name, record_name, latency, msg_size):
        algbw, busbw = calc_bw_log(raw_name, msg_size, latency)
        if record_name in self.comms_dict.keys():
            # If this comm_op has already been logged with this message size, just add to existing record
            if msg_size in self.comms_dict[record_name].keys():
                self.comms_dict[record_name][msg_size][0] += 1
                self.comms_dict[record_name][msg_size][1].append(latency)
                self.comms_dict[record_name][msg_size][2].append(algbw)
                self.comms_dict[record_name][msg_size][3].append(busbw)
            # If this is a new message size for this comm_op, add new record under existing comm_op
            else:
                self.comms_dict[record_name][msg_size] = [1, [latency], [algbw], [busbw]]
        else:
            # Create entirely new record
            self.comms_dict[record_name] = {msg_size: [1, [latency], [algbw], [busbw]]}
        # If verbose, print every comm op
        # TODO: Add to tensorboard
        if self.verbose:
            log_str = f"comm op: {record_name} | time (ms): {latency:.2f} | msg size: {convert_size(msg_size)} | algbw (Gbps): {algbw:.2f} | busbw (Gbps): {busbw:.2f}"
            log_dist(log_str, [0])
    def get_raw_data(self):
        """
        Get the raw communication data dictionary.

        Returns:
            dict: Raw communication data in format {record_name: {msg_size: [count, [latencies], [algbws], [busbws]]}}
        """
        return self.comms_dict.copy()

    def has_data(self):
        """
        Check if any communication data has been logged.

        Returns:
            bool: True if communication data exists, False otherwise
        """
        return len(self.comms_dict) > 0

    def reset_data(self):
        """
        Clear all logged communication data.
        """
        self.comms_dict.clear()

    def get_operation_names(self):
        """
        Get list of all logged communication operation names.

        Returns:
            list: List of operation names that have been logged
        """
        return list(self.comms_dict.keys())

    def get_total_operations(self):
        """
        Get total number of communication operations logged across all types.

        Returns:
            int: Total count of operations
        """
        total = 0
        for record_name in self.comms_dict:
            for msg_size in self.comms_dict[record_name]:
                total += self.comms_dict[record_name][msg_size][0]  # count is at index 0
        return total
    def get_operation_summary(self, operation_name):
        """
        Get summary statistics for a specific operation type.

        Args:
            operation_name (str): Name of the communication operation

        Returns:
            dict: Summary statistics for the operation, or None if operation not found
        """
        if operation_name not in self.comms_dict:
            return None

        from deepspeed.utils.timer import trim_mean

        # Create a snapshot to avoid concurrent modification issues
        op_data = self.comms_dict[operation_name].copy()
        summary = {}
        for msg_size, vals in op_data.items():
            count = vals[0]
            total_lat = sum(vals[1])
            avg_lat = trim_mean(vals[1], 0.1)
            avg_algbw = trim_mean(vals[2], 0.1)
            avg_busbw = trim_mean(vals[3], 0.1)
            summary[msg_size] = {
                "count": count,
                "total_latency_ms": total_lat,
                "avg_latency_ms": avg_lat,
                "tput_avg_gbps": avg_algbw,
                "busbw_avg_gbps": avg_busbw,
                "msg_size_bytes": msg_size,
                "msg_size_str": convert_size(msg_size)
            }
        return summary
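
    # Illustrative sketch of the dictionary-based accessors above (assumes some
    # comm ops have already been recorded; `logger` is a hypothetical CommsLogger):
    #   if logger.has_data():
    #       print(logger.get_total_operations(), "ops across", logger.get_operation_names())
    #       summary = logger.get_operation_summary("all_reduce")
    #       logger.reset_data()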
    # Print summary at end of iteration, epoch, or training
    def log_all(self, print_log=True, show_straggler=False, return_dict=False):
        """
        Print and/or return communication operation statistics.

        Args:
            print_log (bool, optional): Whether to print the summary to console. Defaults to True.
            show_straggler (bool, optional): Whether to include straggler effect analysis. Defaults to False.
            return_dict (bool, optional): Whether to return statistics as a dictionary. Defaults to False.

        Returns:
            dict or None: If return_dict=True, returns a comprehensive dictionary with the following structure:

            {
                "summary": {
                    "operation_name": {
                        message_size_bytes: {
                            "count": int,                 # Number of operations with this message size
                            "total_latency_ms": float,    # Sum of all latencies for this message size
                            "avg_latency_ms": float,      # Average latency (outliers trimmed)
                            "tput_avg_gbps": float,       # Average algorithmic bandwidth in Gbps
                            "busbw_avg_gbps": float,      # Average bus bandwidth in Gbps
                            "msg_size_bytes": int,        # Message size in bytes
                            "msg_size_str": str           # Human-readable message size (e.g., "678.86 MB")
                        }
                    }
                },
                "straggler_analysis": {                   # Only present if show_straggler=True
                    "operation_name": {
                        message_size_bytes: {
                            "count": int,                 # Number of operations
                            "total_comm_lat_ms": float,   # Total communication latency (min across ranks)
                            "total_straggler_ms": float,  # Total straggler effect
                            "avg_comm_lat_ms": float,     # Average communication latency
                            "avg_straggler_ms": float,    # Average straggler effect
                            "msg_size_bytes": int,        # Message size in bytes
                            "msg_size_str": str           # Human-readable message size
                        }
                    }
                } if show_straggler else None,
                "metadata": {
                    "world_size": int,                    # Number of processes in distributed setup
                    "rank": int,                          # Current process rank
                    "timestamp": str                      # ISO format timestamp when log_all was called
                }
            }

            Returns None if return_dict=False.

        Note:
            - Statistics use trimmed mean (10% trimmed from both ends) to remove outliers
            - Straggler analysis requires distributed communication and may impact performance
            - All bandwidth values are in Gbps (Gigabits per second)
            - Latency values are in milliseconds
        """
        import torch
        from deepspeed.utils.timer import trim_mean
        import deepspeed.comm as dist
        from deepspeed.comm.reduce_op import ReduceOp
        from deepspeed.accelerator import get_accelerator
        from datetime import datetime

        # Create a snapshot of the dictionary to avoid concurrent modification issues.
        # This prevents "dictionary changed size during iteration" errors when
        # communication operations are happening in other threads.
        comms_dict_snapshot = self.comms_dict.copy()

        # Initialize return dictionary structure
        result_dict = {
            "summary": {},
            "straggler_analysis": None,
            "metadata": {
                "world_size": dist.get_world_size() if dist.is_initialized() else 1,
                "rank": dist.get_rank() if dist.is_initialized() else 0,
                "timestamp": datetime.now().isoformat()
            }
        } if return_dict else None

        if print_log:
            print(
                f"{'Comm. Op': <20}{'Message Size': <20}{'Count': <20}{'Total Latency(ms)': <20}{'Avg Latency(ms)': <20}{'tput_avg (Gbps)': <20}{'busbw_avg (Gbps)': <20}"
            )

        for record_name in comms_dict_snapshot.keys():
            if print_log:
                print(record_name)

            # Initialize operation entry in result dict
            if return_dict:
                result_dict["summary"][record_name] = {}

            for msg_size, vals in sorted(comms_dict_snapshot[record_name].items()):
                # vals[0] is the count for each msg size
                count = vals[0]
                # vals[1] is a list of latency records for each msg size
                total_lat = sum(vals[1])
                # vals[2] and vals[3] are the lists of algbw and busbw, respectively
                # Get rid of outliers when we print
                avg_lat = trim_mean(vals[1], 0.1)
                avg_algbw = trim_mean(vals[2], 0.1)
                avg_busbw = trim_mean(vals[3], 0.1)

                # Store data in result dictionary
                if return_dict:
                    result_dict["summary"][record_name][msg_size] = {
                        "count": count,
                        "total_latency_ms": total_lat,
                        "avg_latency_ms": avg_lat,
                        "tput_avg_gbps": avg_algbw,
                        "busbw_avg_gbps": avg_busbw,
                        "msg_size_bytes": msg_size,
                        "msg_size_str": convert_size(msg_size)
                    }

                if print_log:
                    print(
                        f"{' ': <20}{convert_size(msg_size): <20}{count: <20}{total_lat: <20.2f}{avg_lat: <20.2f}{avg_algbw: <20.2f}{avg_busbw: <20.2f}"
                    )

        if show_straggler:
            if return_dict:
                result_dict["straggler_analysis"] = {}

            if print_log:
                print("_______________________________")
                print("Breakdown with straggler effect")
                print("-------------------------------")
                print(
                    f"{'Comm. Op': <20}{'Message Size': <20}{'Count': <20}{'Total comm lat(ms)': <20}{'Total straggler(ms)': <20}{'Avg comm lat(ms)': <20}{'Avg straggler(ms)': <20}"
                )

            device = get_accelerator().current_device_name()
            for record_name in comms_dict_snapshot.keys():
                if print_log:
                    print(record_name)

                # Initialize operation entry in straggler dict
                if return_dict:
                    result_dict["straggler_analysis"][record_name] = {}

                for msg_size, vals in sorted(comms_dict_snapshot[record_name].items()):
                    # vals[0] is the count for each msg size
                    count = vals[0]
                    # vals[1] is a list of latency records for each msg size
                    lats = torch.tensor(vals[1], device=device)
                    min_lats = torch.tensor(vals[1], device=device)
                    dist.all_reduce(min_lats, op=ReduceOp.MIN)
                    total_lat = min_lats.sum().item()
                    total_straggler = (lats - min_lats).sum().item()
                    avg_lat = trim_mean(min_lats.tolist(), 0.1)
                    avg_straggler = trim_mean((lats - min_lats).tolist(), 0.1)

                    # Store straggler data in result dictionary
                    if return_dict:
                        result_dict["straggler_analysis"][record_name][msg_size] = {
                            "count": count,
                            "total_comm_lat_ms": total_lat,
                            "total_straggler_ms": total_straggler,
                            "avg_comm_lat_ms": avg_lat,
                            "avg_straggler_ms": avg_straggler,
                            "msg_size_bytes": msg_size,
                            "msg_size_str": convert_size(msg_size)
                        }

                    if print_log:
                        print(
                            f"{' ': <20}{convert_size(msg_size): <20}{count: <20}{total_lat: <20.2f}{total_straggler: <20.2f}{avg_lat: <20.2f}{avg_straggler: <20.2f}"
                        )

        # Return the dictionary if requested
        return result_dict if return_dict else None
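

# Illustrative sketch of consuming log_all() programmatically (assumes a
# distributed run with comms logging enabled; `logger` is a hypothetical
# CommsLogger instance):
#   stats = logger.log_all(print_log=False, return_dict=True)
#   for op_name, sizes in stats["summary"].items():
#       for msg_size, rec in sizes.items():
#           print(op_name, rec["msg_size_str"], rec["busbw_avg_gbps"])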