DeepSpeed/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py
Nir Sonnenschein 1a8ad24f0d fix issues raised by Coverity scans (#7431)
This commit combines fixes for 37 potential code issues found in
Coverity scans.
The issues include, but are not limited to, potential access to
uninitialized variables and dead or redundant code.
We understand that reviewing such a commit can be difficult and will be
happy to help with any questions or changes required.

---------

Signed-off-by: Nir Sonnenschein <nsonnenschein@habana.ai>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
2025-08-02 12:16:10 -04:00

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import os
import sys
from collections import defaultdict
import csv
import time
from multiprocessing import Process, Manager
import numpy as np
import torch
from torch.utils.data import BatchSampler, SequentialSampler, DataLoader, Subset
import deepspeed.comm as dist
from deepspeed.utils import logger
from deepspeed.runtime.data_pipeline.data_sampling.indexed_dataset import MMapIndexedDataset, valid_dtypes
from deepspeed.runtime.data_pipeline.data_sampling.utils import split_dataset, split_index, create_mmap_dataset_builder, close_mmap_dataset_builder, find_fit_int_dtype
class DataAnalyzer(object):
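    """Map-reduce analyzer that computes user-provided metrics over a dataset.

    The map step (run_map) splits the dataset across workers and threads and writes
    partial results as memory-mapped indexed files under save_path; the reduce step
    (run_reduce) merges those partial files into the final metric files.
    """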
def __init__(self,
dataset,
num_workers=1,
worker_id=0,
num_threads=1,
num_threads_reduce=1,
specific_threads=[],
batch_size=1,
metric_names=[],
metric_functions=[],
metric_types=[],
metric_dtypes=[],
save_path="./",
collate_fn=None,
custom_map_init=None,
custom_map_update=None,
custom_map_finalize=None,
custom_reduce=None,
sample_indices=None):
super().__init__()
self.dataset = dataset
self.num_workers = num_workers
self.worker_id = worker_id
self.num_threads = num_threads
self.num_threads_reduce = num_threads_reduce
self.specific_threads = specific_threads
self.batch_size = batch_size
self.metric_names = metric_names
self.metric_functions = metric_functions
self.metric_types = metric_types
self.metric_dtypes = metric_dtypes
self.save_path = save_path
self.collate_fn = collate_fn
self.custom_map_init = custom_map_init
self.custom_map_update = custom_map_update
self.custom_map_finalize = custom_map_finalize
self.custom_reduce = custom_reduce
self.sample_indices = sample_indices
def init_metric_results(self, thread_id, metric_names, metric_types, metric_dtypes, save_path, worker_id):
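        """Create the per-thread output builders/buffers for every metric.

        'single_value_per_sample' metrics get a sample_to_metric mmap builder plus an
        in-memory metric_to_sample dict; 'accumulate_value_over_samples' metrics only
        record the target file name and a running value initialized to None.
        """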
metric_results = []
for m_idx in range(len(metric_names)):
metric_name, metric_type, metric_dtype = metric_names[m_idx], \
metric_types[m_idx], metric_dtypes[m_idx]
assert metric_dtype in valid_dtypes, f"metric_dtype {metric_dtype} not supported. Supported dtypes {valid_dtypes}"
metric_save_path = f"{save_path}/{metric_name}/worker{worker_id}_thread{thread_id}/"
os.makedirs(metric_save_path, exist_ok=True)
if metric_type == 'single_value_per_sample':
sample_to_metric_fname = f"{metric_save_path}/{metric_name}_sample_to_metric"
sample_to_metric_builder = create_mmap_dataset_builder(sample_to_metric_fname, metric_dtype)
metric_to_sample_fname = f"{metric_save_path}/{metric_name}_metric_to_sample"
os.system(f"rm -rf {metric_to_sample_fname}*")
metric_to_sample_dict = defaultdict(list)
metric_results.append({
"sample_to_metric_fname": sample_to_metric_fname,
"sample_to_metric_builder": sample_to_metric_builder,
"metric_to_sample_fname": metric_to_sample_fname,
"metric_to_sample_dict": metric_to_sample_dict
})
elif metric_type == 'accumulate_value_over_samples':
metric_value = None
metric_value_fname = f"{metric_save_path}/{metric_name}_metric_value"
metric_results.append({"metric_value": metric_value, "metric_value_fname": metric_value_fname})
return metric_results
def update_metric_results(self,
data,
metric_types,
metric_dtypes,
metric_functions,
metric_results,
batch_start_idx=0):
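        """Apply each metric function to a batch and accumulate the results.

        Per-sample values are appended to the sample_to_metric builder and to the
        metric_to_sample dict (flushed to a per-value CSV once it holds more than
        100 samples); accumulated metrics are summed in place.
        """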
for m_idx in range(len(metric_types)):
metric_type, metric_dtype, metric_function, metric_result = metric_types[m_idx], \
metric_dtypes[m_idx], metric_functions[m_idx], metric_results[m_idx]
metric_values = metric_function(data)
assert torch.is_tensor(metric_values) or isinstance(metric_values, np.ndarray), \
"metric_function must return a tensor or array"
assert metric_values.dtype == metric_dtype, \
f"metric_function result dtype {metric_values.dtype} does not match metric_dtype {metric_dtype}"
if isinstance(metric_values, np.ndarray):
metric_values = torch.from_numpy(metric_values)
if metric_type == 'single_value_per_sample':
for row in range(metric_values.size()[0]):
sample_idx = batch_start_idx + row # sample idx following dataset iteration order
if isinstance(data, dict) and 'index' in data: # Megatron use case, idx provided in 'index' field
sample_idx = data['index'][row][0].item()
elif self.sample_indices is not None: # user defined shuffling of indices
sample_idx = self.sample_indices[sample_idx]
metric_result["sample_to_metric_builder"].add_item(metric_values[row].reshape(-1))
metric_result["metric_to_sample_dict"][metric_values[row].item()].append(sample_idx)
for m_value in metric_result["metric_to_sample_dict"]:
if len(metric_result["metric_to_sample_dict"][m_value]) > 100:
metric_fname = metric_result["metric_to_sample_fname"]
with open(f"{metric_fname}_{m_value}.csv", 'a') as f:
writer = csv.writer(f)
writer.writerows([metric_result["metric_to_sample_dict"][m_value]])
metric_result["metric_to_sample_dict"][m_value] = []
elif metric_type == 'accumulate_value_over_samples':
if metric_result["metric_value"] is None:
metric_result["metric_value"] = metric_values
else:
metric_result["metric_value"].add_(metric_values)
def finalize_metric_results(self, metric_types, metric_dtypes, metric_results):
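        """Flush remaining per-thread results to disk and close the mmap builders."""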
for m_idx in range(len(metric_types)):
metric_type, metric_dtype, metric_result = metric_types[m_idx], \
metric_dtypes[m_idx], metric_results[m_idx]
if metric_type == 'single_value_per_sample':
metric_fname = metric_result["sample_to_metric_fname"]
close_mmap_dataset_builder(metric_result["sample_to_metric_builder"], metric_fname)
for m_value in metric_result["metric_to_sample_dict"]:
if len(metric_result["metric_to_sample_dict"][m_value]) > 0:
metric_fname = metric_result["metric_to_sample_fname"]
with open(f"{metric_fname}_{m_value}.csv", 'a') as f:
writer = csv.writer(f)
writer.writerows([metric_result["metric_to_sample_dict"][m_value]])
metric_result["metric_to_sample_dict"][m_value] = []
elif metric_type == 'accumulate_value_over_samples':
if metric_result["metric_value"] is not None:
metric_value_builder = create_mmap_dataset_builder(metric_result["metric_value_fname"],
metric_dtype)
metric_value_builder.add_item(metric_result["metric_value"].reshape(-1))
close_mmap_dataset_builder(metric_value_builder, metric_result["metric_value_fname"])
def run_map_helper(self, thread_id):
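        """Map step of a single thread: iterate over this thread's dataset split and
        update the metric results batch by batch."""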
start_idx, end_idx = self.thread_splits[thread_id][0], \
self.thread_splits[thread_id][1]
logger.info(f"worker {self.worker_id} thread {thread_id}: start working " \
f"on data subset {start_idx} to {end_idx}")
thread_dataset = Subset(self.dataset, list(range(start_idx, end_idx)))
sampler = BatchSampler(SequentialSampler(thread_dataset), batch_size=self.batch_size, drop_last=False)
iterator = iter(
DataLoader(thread_dataset,
batch_sampler=sampler,
num_workers=0,
collate_fn=self.collate_fn,
pin_memory=False))
if self.custom_map_init is None:
metric_results = self.init_metric_results(thread_id, self.metric_names, self.metric_types,
self.metric_dtypes, self.save_path, self.worker_id)
else:
metric_results = self.custom_map_init(thread_id, self.metric_names, self.metric_types, self.metric_dtypes,
self.save_path, self.worker_id)
total_sample = len(thread_dataset)
processed_sample = 0
start = time.time()
while True:
try:
data = next(iterator)
batch_start_idx = start_idx + processed_sample
if self.custom_map_update is None:
self.update_metric_results(data, self.metric_types, self.metric_dtypes, self.metric_functions,
metric_results, batch_start_idx)
else:
self.custom_map_update(data, self.metric_types, self.metric_dtypes, self.metric_functions,
metric_results, batch_start_idx)
processed_sample += len(data)
duration = (time.time() - start) / 3600.0
remain_duration = duration * total_sample / processed_sample - duration
logger.info(
f"worker {self.worker_id} thread {thread_id}: {processed_sample} " \
f"out of {total_sample} processed in {duration:.2f} hr, " \
f"estimated to finish in {remain_duration:.2f} hr")
except StopIteration:
logger.info(f"worker {self.worker_id} thread {thread_id}: reach end of file")
break
if self.custom_map_finalize is None:
self.finalize_metric_results(self.metric_types, self.metric_dtypes, metric_results)
else:
self.custom_map_finalize(self.metric_types, self.metric_dtypes, metric_results)
logger.info(f"worker {self.worker_id} thread {thread_id}: finished")
def run_map(self):
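        """Run the map step on this worker, spawning one process per thread
        (or running inline when num_threads == 1)."""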
self.worker_splits, self.thread_splits = split_dataset(self.dataset, self.num_workers, self.worker_id,
self.num_threads)
if len(self.specific_threads) > 0:
threads_to_run = self.specific_threads
else:
threads_to_run = list(range(self.num_threads))
if self.num_threads > 1:
            p = []
            for thread in threads_to_run:
                proc = Process(target=self.run_map_helper, args=(thread, ))
                p.append(proc)
                proc.start()
            for proc in p:
                proc.join()
else:
assert self.num_threads == 1
self.run_map_helper(0)
def get_metric_value_percentiles(self, metric_name, num_sample_per_value, total_num_samples):
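        """Log the metric value reached at every 5th percentile of the samples."""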
logger.info(f"Checking the value percentiles of metric {metric_name}...")
processed_samples = 0
current_percentile = 5
for key in sorted(num_sample_per_value.keys()):
processed_samples += num_sample_per_value[key]
if processed_samples >= total_num_samples * current_percentile / 100.0:
logger.info(f"Metric {metric_name} {current_percentile}th percentile: {key}")
current_percentile += 5
def merge_gather_map_stats(self, num_workers, num_threads, num_threads_reduce, t_idx_reduce, metric_save_path,
metric_name, return_dict):
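        """Reduce helper: collect unique metric values and sample counts from the
        map outputs assigned to this reduce thread into return_dict."""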
results = []
for w_idx in range(num_workers):
for t_idx in range(num_threads):
if (w_idx * num_threads + t_idx) % num_threads_reduce == t_idx_reduce:
w_metric_save_path = f"{metric_save_path}/worker{w_idx}_thread{t_idx}/"
w_sample_to_metric_fname = f"{w_metric_save_path}/{metric_name}_sample_to_metric"
w_sample_to_metric = MMapIndexedDataset(w_sample_to_metric_fname, skip_warmup=True)
unique_v = list(np.unique(w_sample_to_metric))
sample_to_metric_count = len(w_sample_to_metric)
logger.info(f"Finished gathering map stats from worker {w_idx} thread {t_idx}.")
results.append([unique_v, sample_to_metric_count])
return_dict[t_idx_reduce] = results
def merge_sample_to_metric(self, t_idx_reduce, metric_save_path, metric_name, metric_value_dtype,
map_worker_thread):
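        """Reduce helper: concatenate the per-worker/thread sample_to_metric files
        assigned to this reduce thread into one partial file."""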
sample_to_metric_fname = f"{metric_save_path}/{metric_name}_sample_to_metric_thread{t_idx_reduce}"
sample_to_metric_builder = create_mmap_dataset_builder(sample_to_metric_fname, metric_value_dtype)
for w_t in map_worker_thread:
w_metric_save_path = f"{metric_save_path}/worker{w_t[0]}_thread{w_t[1]}/"
w_sample_to_metric_fname = f"{w_metric_save_path}/{metric_name}_sample_to_metric"
w_data = MMapIndexedDataset(w_sample_to_metric_fname, skip_warmup=True)
for row in range(len(w_data)):
sample_to_metric_builder.add_item(torch.tensor(w_data[row].astype(np.int64), dtype=torch.long))
logger.info(f"Finished merge_sample_to_metric from worker {w_t[0]} thread {w_t[1]}.")
close_mmap_dataset_builder(sample_to_metric_builder, sample_to_metric_fname)
def merge_metric_to_sample(self, t_idx_reduce, metric_save_path, metric_name, sample_idx_dtype, metric_value_dtype,
unique_metric_values, num_workers, num_threads):
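        """Reduce helper: for each unique metric value assigned to this reduce thread,
        gather the sample indices from all workers/threads and write partial
        index_to_sample / index_to_metric files."""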
index_to_sample_fname = f"{metric_save_path}/{metric_name}_index_to_sample_thread{t_idx_reduce}"
index_to_sample_builder = create_mmap_dataset_builder(index_to_sample_fname, sample_idx_dtype)
index_to_metric_fname = f"{metric_save_path}/{metric_name}_index_to_metric_thread{t_idx_reduce}"
index_to_metric_builder = create_mmap_dataset_builder(index_to_metric_fname, metric_value_dtype)
for unique_v in unique_metric_values:
samples = []
for w_idx in range(num_workers):
for t_idx in range(num_threads):
w_metric_save_path = f"{metric_save_path}/worker{w_idx}_thread{t_idx}/"
w_metric_to_sample_fname = f"{w_metric_save_path}/{metric_name}_metric_to_sample_{unique_v}.csv"
if os.path.isfile(w_metric_to_sample_fname):
with open(w_metric_to_sample_fname, 'r') as f:
datareader = csv.reader(f)
for row in datareader:
samples += [int(x) for x in row]
index_to_sample_builder.add_item(torch.tensor(samples, dtype=torch.long))
index_to_metric_builder.add_item(torch.tensor([unique_v], dtype=torch.long))
logger.info(f"Finished reducing metric {metric_name} value {unique_v}.")
close_mmap_dataset_builder(index_to_sample_builder, index_to_sample_fname)
close_mmap_dataset_builder(index_to_metric_builder, index_to_metric_fname)
def merge_map_results(self, dataset, metric_names, metric_types, save_path, num_workers, num_threads,
num_threads_reduce):
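        """Reduce step: merge all partial map outputs into the final sample_to_metric,
        index_to_sample, index_to_metric and metric_value files."""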
total_num_samples = len(dataset)
sample_idx_dtype = find_fit_int_dtype(0, total_num_samples - 1)
logger.info(
f"Total number of data samples: {total_num_samples}. Will use {sample_idx_dtype} to store the sample indexes."
)
for m_idx in range(len(metric_names)):
metric_name, metric_type = metric_names[m_idx], metric_types[m_idx]
if metric_type == 'single_value_per_sample':
metric_save_path = f"{save_path}/{metric_name}/"
sample_to_metric_count = 0
unique_metric_values = set([])
manager = Manager()
return_dict = manager.dict()
p = []
for t_idx_reduce in range(num_threads_reduce):
p.append(
Process(target=self.merge_gather_map_stats,
args=(
num_workers,
num_threads,
num_threads_reduce,
t_idx_reduce,
metric_save_path,
metric_name,
return_dict,
)))
p[t_idx_reduce].start()
for t_idx_reduce in range(num_threads_reduce):
p[t_idx_reduce].join()
for t_idx_reduce in range(num_threads_reduce):
results = return_dict[t_idx_reduce]
for res in results:
unique_metric_values = unique_metric_values.union(set(res[0]))
sample_to_metric_count += res[1]
value_max = max(unique_metric_values)
value_min = min(unique_metric_values)
                assert sample_to_metric_count == total_num_samples, "The number of samples in the map result files is not correct. It's possible that some map worker didn't finish successfully."
metric_value_dtype = find_fit_int_dtype(value_min, value_max)
logger.info(
f"Metric {metric_name} has values between {value_min} and {value_max}. Will use {metric_value_dtype} to store the metric values."
)
# sample_to_metric
map_worker_thread = []
for w_idx in range(num_workers):
for t_idx in range(num_threads):
map_worker_thread.append([w_idx, t_idx])
thread_splits = split_index(0, len(map_worker_thread), num_threads_reduce)
p = []
for t_idx_reduce in range(num_threads_reduce):
start_idx, end_idx = thread_splits[t_idx_reduce][0], thread_splits[t_idx_reduce][1]
p.append(
Process(target=self.merge_sample_to_metric,
args=(
t_idx_reduce,
metric_save_path,
metric_name,
metric_value_dtype,
map_worker_thread[start_idx:end_idx],
)))
p[t_idx_reduce].start()
for t_idx_reduce in range(num_threads_reduce):
p[t_idx_reduce].join()
sample_to_metric_fname = f"{metric_save_path}/{metric_name}_sample_to_metric"
sample_to_metric_builder = create_mmap_dataset_builder(sample_to_metric_fname, metric_value_dtype)
for t_idx_reduce in range(num_threads_reduce):
chunk_fname = f"{metric_save_path}/{metric_name}_sample_to_metric_thread{t_idx_reduce}"
logger.info(f"Merging file {chunk_fname}")
sample_to_metric_builder.merge_file_(chunk_fname)
close_mmap_dataset_builder(sample_to_metric_builder, sample_to_metric_fname)
sample_to_metric = MMapIndexedDataset(sample_to_metric_fname, skip_warmup=True)
assert len(sample_to_metric) == total_num_samples
# metric_to_sample
unique_metric_values = list(sorted(unique_metric_values))
thread_splits = split_index(0, len(unique_metric_values), num_threads_reduce)
p = []
for t_idx_reduce in range(num_threads_reduce):
start_idx, end_idx = thread_splits[t_idx_reduce][0], thread_splits[t_idx_reduce][1]
p.append(
Process(target=self.merge_metric_to_sample,
args=(
t_idx_reduce,
metric_save_path,
metric_name,
sample_idx_dtype,
metric_value_dtype,
unique_metric_values[start_idx:end_idx],
num_workers,
num_threads,
)))
p[t_idx_reduce].start()
for t_idx_reduce in range(num_threads_reduce):
p[t_idx_reduce].join()
index_to_sample_fname = f"{metric_save_path}/{metric_name}_index_to_sample"
index_to_sample_builder = create_mmap_dataset_builder(index_to_sample_fname, sample_idx_dtype)
index_to_metric_fname = f"{metric_save_path}/{metric_name}_index_to_metric"
index_to_metric_builder = create_mmap_dataset_builder(index_to_metric_fname, metric_value_dtype)
for t_idx_reduce in range(num_threads_reduce):
chunk_is_fname = f"{metric_save_path}/{metric_name}_index_to_sample_thread{t_idx_reduce}"
logger.info(f"Merging file {chunk_is_fname}")
index_to_sample_builder.merge_file_(chunk_is_fname)
chunk_im_fname = f"{metric_save_path}/{metric_name}_index_to_metric_thread{t_idx_reduce}"
logger.info(f"Merging file {chunk_im_fname}")
index_to_metric_builder.merge_file_(chunk_im_fname)
close_mmap_dataset_builder(index_to_sample_builder, index_to_sample_fname)
close_mmap_dataset_builder(index_to_metric_builder, index_to_metric_fname)
num_sample_per_value = DataAnalyzer.output_index_to_sample_percentile(
index_to_sample_fname, index_to_metric_fname, metric_name, metric_save_path, total_num_samples,
sample_idx_dtype)
self.get_metric_value_percentiles(metric_name, num_sample_per_value, total_num_samples)
elif metric_type == 'accumulate_value_over_samples':
metric_save_path = f"{save_path}/{metric_name}/"
metric_value = None
for w_idx in range(num_workers):
for t_idx in range(num_threads):
w_metric_save_path = f"{metric_save_path}/worker{w_idx}_thread{t_idx}/"
w_metric_value_fname = f"{w_metric_save_path}/{metric_name}_metric_value"
w_metric_value = MMapIndexedDataset(w_metric_value_fname, skip_warmup=True)
if metric_value is None:
metric_value = np.copy(w_metric_value[0])
else:
metric_value += np.copy(w_metric_value[0])
value_max = int(max(metric_value))
value_min = int(min(metric_value))
metric_value_dtype = find_fit_int_dtype(value_min, value_max)
metric_value_fname = f"{metric_save_path}/{metric_name}_metric_value"
metric_value_builder = create_mmap_dataset_builder(metric_value_fname, metric_value_dtype)
metric_value_builder.add_item(torch.tensor(metric_value.astype(np.int64), dtype=torch.long))
close_mmap_dataset_builder(metric_value_builder, metric_value_fname)
@staticmethod
def output_index_to_sample_percentile(index_to_sample_fname, index_to_metric_fname, metric_name, metric_save_path,
total_num_samples, sample_idx_dtype):
""" read index_to_metric and index_to_sample files and write distribution to index_to_sample_percentage_merged """
num_sample_per_value = {}
index_to_sample = MMapIndexedDataset(index_to_sample_fname, skip_warmup=True)
index_to_metric = MMapIndexedDataset(index_to_metric_fname, skip_warmup=True)
index_to_sample_merged_fname = f"{metric_save_path}/{metric_name}_index_to_sample_percentile_merged"
index_to_sample_merged_builder = create_mmap_dataset_builder(index_to_sample_merged_fname, sample_idx_dtype)
for v_idx in range(len(index_to_sample)):
if v_idx > 0:
assert index_to_metric[v_idx] > index_to_metric[v_idx - 1]
num_sample_per_value[index_to_metric[v_idx][0]] = len(index_to_sample[v_idx])
assert sum(list(num_sample_per_value.values())) == total_num_samples
merge_step = max(1, len(index_to_sample) // 100)
for v_idx in range(0, len(index_to_sample), merge_step):
merged_samples = np.copy(
np.concatenate(index_to_sample[v_idx:min(len(index_to_sample), (v_idx + merge_step))], axis=None))
index_to_sample_merged_builder.add_item(torch.tensor(merged_samples.astype(np.int64), dtype=torch.long))
logger.info(f"Finished merging index_to_sample {v_idx} to {v_idx+merge_step}.")
close_mmap_dataset_builder(index_to_sample_merged_builder, index_to_sample_merged_fname)
return num_sample_per_value
def run_reduce(self):
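        """Run the reduce step (or the user-provided custom_reduce) on this worker."""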
if self.custom_reduce is None:
self.merge_map_results(self.dataset, self.metric_names, self.metric_types, self.save_path,
self.num_workers, self.num_threads, self.num_threads_reduce)
else:
self.custom_reduce(self.dataset, self.metric_names, self.metric_types, self.save_path, self.num_workers,
self.num_threads, self.num_threads_reduce)
def run_map_reduce(self, comm_group=None):
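        """Run the full pipeline: every worker maps, worker 0 reduces, with barriers
        so all nodes see the final files."""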
self.run_map()
        # wait for the mapping operation, where all nodes output their own (partial) result files
dist.barrier(group=comm_group)
if self.worker_id == 0:
self.run_reduce()
# wait for the reduce, where rank 0 merges all (partial) files. Dataset can then be used by all nodes.
dist.barrier(group=comm_group)
class DistributedDataAnalyzer(object):
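    """Variant of DataAnalyzer that keeps map results in memory on each rank and
    merges them with collective communication instead of intermediate map files."""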
def __init__(
self,
dataset,
num_workers=1,
num_threads=1,
worker_id=0,
batch_size=1,
metric_names=[],
metric_functions=[],
metric_types=[],
save_path="./",
collate_fn=None,
device='cuda',
comm_group=None,
sample_indices=None,
) -> None:
self.dataset = dataset
self.batch_size = batch_size
self.metric_names = metric_names
self.metric_functions = metric_functions
self.metric_types = metric_types
self.save_path = save_path
self.collate_fn = collate_fn
self.device = device
self.sample_indices = sample_indices
self.num_threads = num_threads
self.worker_id = worker_id
if not dist.is_initialized():
dist.init_distributed()
# comm_group and worker_id+num_workers are mutually exclusive
self.comm_group = comm_group
if self.comm_group is None:
# self.comm_group = deepspeed.utils.groups._clone_world_group()
self.num_workers = num_workers
self.worker_id = worker_id
else:
self.num_workers = self.comm_group.size()
self.worker_id = self.comm_group.rank()
if self.worker_id == 0:
logger.info(f"Distributed data analyzer initialized with {self.num_workers} workers.")
def run_map_helper(self, thread_id=0, metric_queues=None):
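        """Map step of one thread: compute metric values for this thread's dataset
        split and return them (or push them to the shared queue when multi-threaded)."""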
thread_start_idx, thread_end_idx = self.thread_splits[thread_id][0], self.thread_splits[thread_id][1]
worker_dataset = Subset(self.dataset, list(range(thread_start_idx, thread_end_idx)))
sampler = BatchSampler(SequentialSampler(worker_dataset), batch_size=self.batch_size, drop_last=False)
dataloader = DataLoader(dataset=worker_dataset,
batch_sampler=sampler,
num_workers=0,
collate_fn=self.collate_fn,
pin_memory=False)
# set initial results list
metric_results = []
for metric_type in self.metric_types:
assert metric_type in ['single_value_per_sample', 'accumulate_value_over_samples'], \
f"metric_type {metric_type} not implemented."
metric_results.append([] if metric_type == 'single_value_per_sample' else None)
# iterate dataloader and store metric results
batch_start_idx = thread_start_idx
for data in dataloader:
for m_idx in range(len(self.metric_names)):
metric_type, metric_function = self.metric_types[m_idx], self.metric_functions[m_idx]
metric_values = metric_function(data)
assert torch.is_tensor(metric_values) or isinstance(metric_values, np.ndarray), \
"metric_function must return a tensor or array"
if isinstance(metric_values, np.ndarray):
metric_values = torch.from_numpy(metric_values)
assert metric_values.dtype in valid_dtypes, \
f"metric_function result dtype {metric_values.dtype} not supported. Supported dtypes {valid_dtypes}"
if metric_type == 'single_value_per_sample':
for row in range(metric_values.size()[0]):
value = metric_values[row].item()
sample_idx = batch_start_idx + row # sample idx following dataset iteration order
if isinstance(data, dict) and 'index' in data: # Megatron use case
sample_idx = data['index'][row][0].item()
elif self.sample_indices is not None: # user defined shuffling of indices
sample_idx = self.sample_indices[sample_idx]
metric_results[m_idx].append((value, sample_idx))
elif metric_type == 'accumulate_value_over_samples':
if metric_results[m_idx] is None:
metric_results[m_idx] = metric_values
else:
metric_results[m_idx].add_(metric_values)
batch_start_idx += len(data)
if self.num_threads == 1:
return metric_results
# copy metric_results to the shared queue
assert metric_queues
for m_idx in range(len(self.metric_names)):
results = metric_results[m_idx]
if torch.is_tensor(results):
results = results.item() if results.dim() == 0 else results.tolist()
try:
metric_queues[m_idx].put((thread_id, results))
except Exception as e:
logger.error(f"Error putting metric results to queue: {e}")
sys.exit(1)
def run_map_reduce(self):
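        """Compute all metrics in memory on every worker, then merge them across
        ranks and write the final metric files."""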
# setup individual dataloaders
self.worker_splits, self.thread_splits = split_dataset(self.dataset,
self.num_workers,
self.worker_id,
num_threads=self.num_threads)
node_start_idx, node_end_idx = self.worker_splits[self.worker_id]
logger.info(f"worker {self.worker_id} working on data subset {node_start_idx} to {node_end_idx}.")
if self.num_threads in [0, 1, None]:
metric_results = self.run_map_helper()
metric_results = [torch.tensor(m).to(self.device) for m in metric_results]
else:
# create a shared queue of results per metric to be populated by individual threads
with Manager() as manager:
metric_queues = [manager.Queue() for _ in self.metric_names]
threads = [
Process(target=self.run_map_helper, args=(t, metric_queues)) for t in range(self.num_threads)
]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
# gather results from shared queues into metric_results
metric_results = [None for _ in self.metric_names]
for m_idx, (queue, metric_type) in enumerate(zip(metric_queues, self.metric_types)):
while not queue.empty():
t_idx, t_results = queue.get()
t_start_idx, t_end_idx = self.thread_splits[t_idx]
if t_start_idx >= t_end_idx: # no results from this thread
continue #corner case for small datasets and high thread count
t_results = torch.tensor(t_results)
if metric_type == 'single_value_per_sample':
# add thread results to the metric_results list, ordered by thread idx
if metric_results[m_idx] is None: # initialize if needed
metric_results[m_idx] = torch.zeros(node_end_idx - node_start_idx,
t_results.size(1)).to(self.device)
metric_results[m_idx][t_start_idx - node_start_idx:t_end_idx - node_start_idx] = t_results
else:
if metric_results[m_idx] is None: # initialize if needed
metric_results[m_idx] = torch.zeros(t_results.size()).to(self.device)
metric_results[m_idx].add_(t_results)
# compute dtype for sample ids
total_num_samples = len(self.dataset)
sample_idx_dtype = find_fit_int_dtype(0, total_num_samples - 1)
logger.info(f"Total number of data samples: {total_num_samples}.")
logger.info(f"Will use {sample_idx_dtype} to store the sample indexes.")
for m_idx in range(len(self.metric_names)):
metric_values, metric_name, metric_type = \
metric_results[m_idx], self.metric_names[m_idx], self.metric_types[m_idx]
metric_save_path = f"{self.save_path}/{metric_name}/"
os.makedirs(metric_save_path, exist_ok=True)
if metric_type == 'single_value_per_sample':
# Compute sample and metric value dtypes based on range
values, samples = metric_values[:, 0], metric_values[:, 1]
value_min, value_max = Dist.min_max(values, self.comm_group)
sample_min, sample_max = Dist.min_max(samples, self.comm_group)
metric_value_dtype = find_fit_int_dtype(value_min, value_max)
sample_value_dtype = find_fit_int_dtype(sample_min, sample_max)
# sample_to_metric maps sample ids to metric values, as a list of metric values
sample_to_metric_fname = f"{metric_save_path}/{metric_name}_sample_to_metric"
values = [torch.tensor([x]) for x in metric_values[:, 0]]
self.file_write_ordered(values, sample_to_metric_fname, metric_value_dtype)
                # distributed sort by value gives each node an ordered, disjoint subset of the keys
metric_values = Dist.sample_sort(metric_values, self.comm_group, self.num_workers)
metric_to_samples_dict = {}
if len(metric_values) > 0:
for value, sample in metric_values:
if value.item() not in metric_to_samples_dict:
metric_to_samples_dict[value.item()] = []
metric_to_samples_dict[value.item()].append(sample.item())
                # index_to_metric and index_to_sample serialize a dictionary from metric to samples
# index_to_metric stores a key per row, index_to_sample stores the values per row
values = [torch.tensor([x]) for x in metric_to_samples_dict.keys()]
samples = [torch.tensor(metric_to_samples_dict[x]) for x in metric_to_samples_dict.keys()]
index_to_metric_fname = f"{metric_save_path}/{metric_name}_index_to_metric" #dict keys
index_to_sample_fname = f"{metric_save_path}/{metric_name}_index_to_sample" #dict values
self.file_write_ordered(values, index_to_metric_fname, metric_value_dtype)
self.file_write_ordered(samples, index_to_sample_fname, sample_value_dtype)
if self.worker_id == 0:
DataAnalyzer.output_index_to_sample_percentile(index_to_sample_fname, index_to_metric_fname,
metric_name, metric_save_path, total_num_samples,
sample_idx_dtype)
dist.barrier(self.comm_group)
elif metric_type == 'accumulate_value_over_samples':
metric_value_fname = f"{metric_save_path}/{metric_name}_metric_value"
dist.reduce(metric_values, dst=0, op=dist.ReduceOp.SUM, group=self.comm_group)
metric_value_dtype = find_fit_int_dtype(metric_values.min(), metric_values.max())
if self.worker_id == 0:
builder = create_mmap_dataset_builder(metric_value_fname, metric_value_dtype)
builder.add_item(metric_values.cpu())
close_mmap_dataset_builder(builder, metric_value_fname)
dist.barrier(self.comm_group)
def file_write_ordered(self, tensor_list, fname, numpy_dtype):
""" MPI_file_write_ordered extended to write a list of tensors, by one rank, iteratively """
# each node has a list of rows (tensors) to be written to the file.
# we will serialize it in order to communicate it in one comm step.
tkwargs = dict(dtype=torch.int64, device=self.device)
# 1. gather on rank 0 the number of rows to be sent/recv
row_count = torch.tensor([len(tensor_list)], **tkwargs)
row_counts = torch.zeros(self.num_workers, **tkwargs)
dist.all_gather_into_tensor(row_counts, row_count, group=self.comm_group)
assert row_counts[self.worker_id] == row_count == len(tensor_list), "all_gather failed"
# 2. gather on rank 0 the sizes of the rows to be sent/recv
row_len = torch.tensor([len(l) for l in tensor_list], **tkwargs)
row_lens = Dist.gather_v(row_len, 0, self.comm_group, self.num_workers, self.worker_id)
        # 3. gather on rank 0 the total size (sum of all row lengths) to be received
size = torch.tensor([sum(row_len).item()], **tkwargs)
sizes = torch.zeros(self.num_workers, **tkwargs)
dist.all_gather_into_tensor(sizes, size, group=self.comm_group)
assert sizes[self.worker_id] == size.item(), "all_gather did not return the same sizes" #sanity check
        # helper to deserialize a buffer into rows of different lengths and write them to file
def write_buffer_to_file(buff, src, builder):
assert self.worker_id == 0, "only rank 0 can write to file"
# collect all buffers and write them at once
buff = buff.cpu().detach().numpy()
row_offsets = np.cumsum([0] + row_lens[src].tolist())
arr_list = []
for i in range(len(row_lens[src])):
arr_list.append(buff[row_offsets[i]:row_offsets[i + 1]])
builder.add_items(arr_list)
        # 4. rank 0 prepares the output folder and file
if self.worker_id == 0:
os.makedirs(os.path.dirname(fname), exist_ok=True)
builder = create_mmap_dataset_builder(fname, numpy_dtype)
# iterate through ranks that have data to be sent/recv/written
for src in [rank for rank, count in enumerate(row_counts) if count > 0]:
dist.barrier(group=self.comm_group)
            if self.worker_id == 0 and src == 0:  # rank 0 writes its own data
buffer = torch.cat(tensor_list, dim=0).to(self.device)
write_buffer_to_file(buffer, 0, builder)
elif self.worker_id == 0 and src > 0: # rank 0 receives other rank's data and writes it
buffer = torch.empty(sizes[src].item(), dtype=numpy_dtype, device=self.device)
err = dist.recv(buffer, src=src, group=self.comm_group, tag=src)
assert err == src and len(buffer) > 0, "recv failed"
write_buffer_to_file(buffer, src, builder)
elif self.worker_id == src: # current rank sends data to rank 0
buffer = torch.cat(tensor_list, dim=0).to(self.device)
dist.send(buffer, 0, group=self.comm_group, tag=src)
# rank 0 closes the file
if self.worker_id == 0:
close_mmap_dataset_builder(builder, fname) # close file
dist.barrier(self.comm_group)
class Dist:
""" auxiliary class to perform distributed operations on tensors"""
@staticmethod
def min_max(tensor, comm_group):
""" given a distributed tensor, return the min/max values across all ranks"""
value_min, value_max = tensor.min(), tensor.max()
dist.reduce(value_min, 0, op=dist.ReduceOp.MIN, group=comm_group)
dist.reduce(value_max, 0, op=dist.ReduceOp.MAX, group=comm_group)
return value_min.item(), value_max.item()
@staticmethod
def gather_v(tensor, dst, comm_group, num_workers, worker_id):
""" MPI_Gatherv. gather tensors of variable sizes in a single rank """
# gather the number of rows to be sent/recv
size = torch.tensor([len(tensor)], dtype=torch.int64, device=tensor.device)
sizes = torch.zeros(num_workers, dtype=torch.int64, device=tensor.device)
dist.all_gather_into_tensor(sizes, size, group=comm_group)
assert sizes[worker_id] == size, "all_gather failed"
# all_gather requires all tensors to be of same size so we need to pad them
max_size = max(sizes).item()
buffer = torch.empty(max_size, dtype=tensor.dtype, device=tensor.device)
buffer[0:size] = tensor.data
buffer_list = None
if worker_id == 0: # create padded recv buffers
buffer_list = [torch.empty(max_size, dtype=tensor.dtype, device=tensor.device) for _ in range(num_workers)]
dist.gather(buffer, buffer_list, dst=dst, group=comm_group)
# revert padding and return value
if worker_id == 0:
buffer_list = [r[:s.item()] for r, s in zip(buffer_list, sizes)]
return buffer_list
@staticmethod
def sample_sort(tensor, comm_group, num_workers, n_samples=100):
""" perform a distributed random sort of a tensor, and returns the sorted partial tensor"""
device, dims = tensor.device, tensor.size()[1]
# 1 - sort rows by first column, then second column, then third, etc...
tensor = torch.tensor(sorted(tensor.tolist()), dtype=tensor.dtype, device=tensor.device)
        # 2 - collect a few samples per rank
idx = torch.round(torch.linspace(0, len(tensor) - 1, n_samples)).to(int)
        samples = tensor[idx][:, 0].contiguous().to(device)  # only the first column (the sort key)
        # 3 - all-gather the samples from all ranks
all_samples = [torch.zeros(n_samples, dtype=samples.dtype, device=device) for _ in range(num_workers)]
dist.all_gather(all_samples, samples, group=comm_group)
all_samples = torch.cat(all_samples, dim=0).to(device)
        # 4 - sort all samples and collect the ranges of each rank as equidistant
all_samples = all_samples.sort()[0]
idx = torch.round(torch.linspace(0, len(all_samples) - 1, num_workers + 1)).to(int)
ranges = all_samples[idx] # range of each rank r as ranges[r] <= x < ranges[r+1]
ranges[-1] += 1 # increase upper limit of last rank so that x < ranges[r+1].
        # 5 - collect elements to send to each rank, based on the rank ranges
send = []
for rank in range(num_workers):
mask = (tensor[:, 0] >= ranges[rank]) & (tensor[:, 0] < ranges[rank + 1])
send.append(tensor[mask])
        # 6 - all-to-all to communicate the sizes to be sent/recv
send_count = [torch.tensor([len(s) * dims], dtype=torch.int64, device=device) for s in send]
recv_count = list(torch.empty([num_workers], dtype=torch.int64, device=device).chunk(num_workers))
dist.all_to_all(recv_count, send_count, group=comm_group)
        # 7 - all-to-all-v to communicate the elements to be sent/recv as a single tensor
send = torch.cat(send, dim=0).flatten().to(device)
recv = torch.zeros(sum(recv_count), dtype=send.dtype).to(device)
send_count = [s.item() for s in send_count] # convert to list of ints
recv_count = [r.item() for r in recv_count]
dist.all_to_all_single(recv, send, recv_count, send_count, group=comm_group)
del send
        # 8 - the received tensor is the 1D disjoint subset of the distributed tensor.
# We will recover the original dimensionality and sort it by columns again.
recv = recv.view(-1, dims)
recv = torch.tensor(sorted(recv.tolist()), dtype=recv.dtype, device=recv.device)
return recv
def test_compare_both_data_analyzers(dataset):
""" given a dataset, compare file and memory based data analyser"""
id = lambda t: t.to(torch.int64) # identity
batch_sum = lambda t: id(t).sum() #sum batch
num_threads = 4
kwargs = dict(
dataset=dataset,
batch_size=2**10,
worker_id=int(os.environ['RANK']),
num_workers=int(os.environ['WORLD_SIZE']),
metric_names=["mod", "batch_sum"],
metric_functions=[id, batch_sum],
metric_types=['single_value_per_sample', 'accumulate_value_over_samples'],
num_threads=num_threads,
)
dda = DistributedDataAnalyzer(
save_path="./output_dist",
device=f"cuda:{int(os.environ['LOCAL_RANK'])}",
**kwargs,
)
start_time = time.time()
dda.run_map_reduce()
if dda.worker_id == 0:
print("DistributedDataAnalyzer runtime: %s seconds " % (time.time() - start_time))
da = DataAnalyzer(num_threads_reduce=num_threads,
save_path="./output_disk",
metric_dtypes=[torch.int64, torch.int64],
**kwargs)
start_time = time.time()
da.run_map_reduce()
if da.worker_id == 0:
print("DataAnalyzer runtime: %s seconds " % (time.time() - start_time))
output_paths = [
"batch_sum/batch_sum_metric_value.bin", "batch_sum/batch_sum_metric_value.idx", \
"mod/mod_index_to_metric.bin", "mod/mod_index_to_metric.idx", \
"mod/mod_index_to_sample.bin", "mod/mod_index_to_sample.idx", \
"mod/mod_index_to_sample_percentile_merged.bin", "mod/mod_index_to_sample_percentile_merged.idx", \
"mod/mod_sample_to_metric.bin", "mod/mod_sample_to_metric.idx"
]
if dda.worker_id == 0:
for path in output_paths:
with open(os.path.join(da.save_path, path), 'rb') as f1, \
open(os.path.join(dda.save_path, path), 'rb') as f2:
# if files have suffix .bin, they should be identical
if path.endswith(".bin"):
assert f1.read() == f2.read(), f"files {path} are not identical."
elif f1.read() != f2.read():
print(f"files {path} are not identical.")
dist.barrier()
dist.destroy_process_group()
if __name__ == "__main__":
class TestDataset(torch.utils.data.Dataset):
def __init__(self, size=10_000_000):
self.values = [(x + 7) % 10_000 for x in range(size)]
self.size = size
__len__ = lambda self: self.size
__getitem__ = lambda self, idx: self.values[idx]
test_compare_both_data_analyzers(TestDataset())