DeepNVMe perf tuning (#6560)

Add performance tuning utilities: `ds_nvme_tune` and `ds_io`.  
Update the tutorial with a performance tuning section.

---------

Co-authored-by: Ubuntu <jomayeri@microsoft.com>
Co-authored-by: Joe Mayer <114769929+jomayeri@users.noreply.github.com>
Olatunji Ruwase
2024-09-26 09:07:19 -04:00
committed by GitHub
parent 7622cd9e68
commit a5400974df
16 changed files with 1350 additions and 2 deletions

bin/ds_io (new file, 6 lines)
@@ -0,0 +1,6 @@
#!/usr/bin/env python3
from deepspeed.nvme import ds_io_main
if __name__ == '__main__':
ds_io_main()

bin/ds_nvme_tune (new file, 9 lines)
@@ -0,0 +1,9 @@
#!/usr/bin/env python3
from deepspeed.nvme import sweep_main, generate_main, parse_sweep_arguments
if __name__ == '__main__':
args = parse_sweep_arguments()
print(f"Running DeepNVMe performance tuning on {args.nvme_dir}")
sweep_main(args)
generate_main(args.log_dir)

deepspeed/nvme/__init__.py (new file, 8 lines)
@@ -0,0 +1,8 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .perf_run_sweep import sweep_main, parse_sweep_arguments
from .perf_generate_param import generate_main
from .test_ds_aio import ds_io_main

deepspeed/nvme/ds_aio_args.py (new file, 175 lines)
@@ -0,0 +1,175 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
Functionality of swapping optimizer tensors to/from (NVMe) storage devices.
"""
import argparse
import os
from .test_ds_aio_utils import refine_integer_value
from deepspeed.accelerator import get_accelerator
MAPPING_DELIMITER = ':'
def refine_args(args):
if args.io_size and type(args.io_size) == str:
args.io_size = refine_integer_value(args.io_size)
if args.block_size and type(args.block_size) == str:
args.block_size = refine_integer_value(args.block_size)
return args
def _get_mapping_dict(args):
if args.folder is not None:
d = {i: args.folder for i in range(args.multi_process)}
else:
d = {}
for m in args.folder_to_device_mapping:
fields = m.split(MAPPING_DELIMITER)
d[fields[1]] = fields[0]
return d
def _validate_folder_mapping(args):
no_error = True
error_messages = []
invalid_mappings = [m for m in args.folder_to_device_mapping if MAPPING_DELIMITER not in m]
if len(invalid_mappings) > 0:
error_messages.append(
f'Missing delimiter ({MAPPING_DELIMITER}) in folder_to_device_mapping {invalid_mappings}')
no_error = False
folder_list = [m.split(MAPPING_DELIMITER)[0] for m in args.folder_to_device_mapping]
invalid_folders = [d for d in folder_list if not os.path.exists(d)]
if len(invalid_folders) > 0:
error_messages.append(f'Invalid folders in folder_to_device_mapping: {invalid_folders}')
no_error = False
if args.gpu:
device_list = [int(m.split(MAPPING_DELIMITER)[1]) for m in args.folder_to_device_mapping]
invalid_device_list = [dev_id for dev_id in device_list if not dev_id < get_accelerator().device_count()]
if len(invalid_device_list) > 0:
error_messages.append(f'Invalid device ids in folder_to_device_mapping: {invalid_device_list}')
no_error = False
return no_error, error_messages
def validate_args(args):
no_error = True
error_messages = []
if args.folder is not None and len(args.folder_to_device_mapping) > 0:
error_messages.append(f'--folder and --folder_to_device_mapping cannot be specified together.')
no_error = False
elif args.folder is None and len(args.folder_to_device_mapping) == 0:
error_messages.append(f'At least one of --folder or --folder_to_device_mapping must be specified.')
no_error = False
# Validate --folder
if args.folder is not None and not os.path.exists(args.folder):
no_error = False
error_messages.append(f'Invalid folder in --folder: {args.folder} ')
# Validate --folder_mapping_to_device
if len(args.folder_to_device_mapping) > 0:
no_mapping_error, mapping_error_messages = _validate_folder_mapping(args)
no_error = no_error and no_mapping_error
error_messages += mapping_error_messages
# Validate --gpu, --use_gds
if args.use_gds and not args.gpu:
        error_messages.append(f'--gpu must be set when using --use_gds')
no_error = False
if not no_error:
print(f'Found {len(error_messages)} validation errors')
for i, msg in enumerate(error_messages):
print(f'{i+1}: {msg}')
return no_error
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument('--folder', default=None, type=str, help='Folder to use for I/O.')
parser.add_argument('--folder_to_device_mapping',
default=[],
nargs='+',
                        help='Mapping of folder to (gpu) device id (ignored for cpu accesses). '
                        'Can be specified multiple times for multi-process runs, '
                        'e.g. --folder_to_device_mapping /mnt/nvme0:0 --folder_to_device_mapping /mnt/nvme1:15 --gpu '
                        'means access /mnt/nvme0 with gpu 0 and /mnt/nvme1 with gpu 15.')
parser.add_argument('--io_size', type=str, default=None, required=True, help='Number of bytes to read or write.')
parser.add_argument('--read', action='store_true', help='Perform read I/O (default is write)')
parser.add_argument('--multi_process',
type=int,
default=1,
help='Number of parallel processes doing I/O (default 1).')
parser.add_argument('--block_size',
type=str,
default='1M',
                        help='I/O block size. Can use K, M, or G suffix (default 1M for 1 megabyte).')
parser.add_argument('--queue_depth', type=int, default=32, help='I/O queue depth (default 32).')
parser.add_argument('--single_submit',
action='store_true',
                        help='Submit I/O requests one at a time (default is to submit queue_depth requests at once).')
parser.add_argument(
'--sequential_requests',
action='store_true',
help=
        'Delay I/O request submission until completion of prior requests (default is to overlap I/O submission and completion).'
)
parser.add_argument('--validate', action='store_true', help='Perform validation of I/O transfer in library.')
parser.add_argument('--handle', action='store_true', help='Use AIO handle.')
parser.add_argument('--loops', type=int, default=3, help='Count of operation repetitions')
parser.add_argument('--io_parallel', type=int, default=None, help='Per iop parallelism')
parser.add_argument('--gpu', action='store_true', help='Use GPU memory')
parser.add_argument('--use_gds', action='store_true', help='Enable GDS AIO')
parser.add_argument('--slow_bounce_buffer',
action='store_true',
help='For GPU memory transfers, measure impact of bounce buffer pinning on critical path.')
args = parser.parse_args()
print(f'args = {args}')
return args
def get_validated_args():
args = parse_arguments()
args = refine_args(args)
if not validate_args(args):
quit()
print(f'Successful validation of command line arguments')
peer_tag = 'gpu' if args.gpu else 'process'
args.mapping_dict = _get_mapping_dict(args)
args.mapping_list = [(device_id, folder) for device_id, folder in args.mapping_dict.items()]
assert len(args.mapping_dict) == len(args.mapping_list)
print(f'Configuring {len(args.mapping_list)} {peer_tag} to folder mapping')
for i, (device_id, folder) in enumerate(args.mapping_list):
print(f'[{i}]: {peer_tag} {device_id} <----> {folder}')
return args
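For illustration, a hedged sketch of driving the `ds_io` front-end programmatically with the arguments defined above; the mount point `/local_nvme` is hypothetical.

```python
# Equivalent to the command line (hypothetical mount point):
#   ds_io --folder /local_nvme --io_size 400M --read --handle --gpu
import sys
from deepspeed.nvme import ds_io_main

sys.argv = ['ds_io', '--folder', '/local_nvme', '--io_size', '400M', '--read', '--handle', '--gpu']
ds_io_main()
```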

deepspeed/nvme/ds_aio_basic.py (new executable file, 134 lines)
@@ -0,0 +1,134 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
Functionality of swapping optimizer tensors to/from (NVMe) storage devices.
"""
import torch
import os
import time
from deepspeed.ops.aio import AsyncIOBuilder
from multiprocessing import Pool, Barrier
from .test_ds_aio_utils import report_results, task_log, task_barrier
def pre_basic(args, tid, read_op):
io_string = "Read" if read_op else "Write"
num_bytes = os.path.getsize(args.read_file) if read_op else args.write_size
file = args.read_file if read_op else f'{args.write_file}.{tid}'
task_log(tid, f'Allocate tensor of size {num_bytes} bytes')
buffer = torch.empty(num_bytes, dtype=torch.uint8, device='cpu').pin_memory()
task_log(tid, f'{io_string} file {file} of size {num_bytes} bytes from buffer on device {buffer.device}')
ctxt = {}
ctxt['file'] = file
ctxt['num_bytes'] = num_bytes
ctxt['buffer'] = buffer
ctxt['elapsed_sec'] = 0
return ctxt
def pre_basic_read(pool_params):
args, tid = pool_params
ctxt = pre_basic(args, tid, True)
return ctxt
def pre_basic_write(pool_params):
args, tid = pool_params
ctxt = pre_basic(args, tid, False)
return ctxt
def post_basic(pool_params):
_, _, ctxt = pool_params
ctxt["buffer"].detach()
ctxt["buffer"] = None
return ctxt
def main_basic_read(pool_params):
args, tid, ctxt = pool_params
start_time = time.time()
AsyncIOBuilder().load().aio_read(ctxt['buffer'], ctxt['file'], args.block_size, args.queue_depth,
args.single_submit, not args.sequential_requests, args.validate)
end_time = time.time()
ctxt['elapsed_sec'] += end_time - start_time
return ctxt
def main_basic_write(pool_params):
args, tid, ctxt = pool_params
start_time = time.time()
AsyncIOBuilder().load().aio_write(ctxt['buffer'], ctxt['file'], args.block_size, args.queue_depth,
args.single_submit, not args.sequential_requests, args.validate)
end_time = time.time()
ctxt['elapsed_sec'] += end_time - start_time
return ctxt
def get_schedule(args, read_op):
schedule = {}
if read_op:
schedule['pre'] = pre_basic_read
schedule['post'] = post_basic
schedule['main'] = main_basic_read
else:
schedule['pre'] = pre_basic_write
schedule['post'] = post_basic
schedule['main'] = main_basic_write
return schedule
def _aio_handle_tasklet(pool_params):
args, tid, read_op = pool_params
num_processes = len(args.mapping_dict)
# Create schedule
schedule = get_schedule(args, read_op)
task_log(tid, f'schedule = {schedule}')
task_barrier(aio_barrier, num_processes)
# Run pre task
task_log(tid, f'running pre-task')
ctxt = schedule["pre"]((args, tid))
task_barrier(aio_barrier, num_processes)
# Run main tasks in a loop
ctxt["main_task_sec"] = 0
for i in range(args.loops):
task_log(tid, f'running main task {i}')
start_time = time.time()
ctxt = schedule["main"]((args, tid, ctxt))
task_barrier(aio_barrier, num_processes)
stop_time = time.time()
ctxt["main_task_sec"] += stop_time - start_time
# Run post task
task_log(tid, f'running post-task')
ctxt = schedule["post"]((args, tid, ctxt))
task_barrier(aio_barrier, num_processes)
return ctxt["main_task_sec"], ctxt["elapsed_sec"], ctxt["num_bytes"] * args.loops
def _init_tasklet(b):
global aio_barrier
aio_barrier = b
def aio_basic_multiprocessing(args, read_op):
num_processes = len(args.mapping_dict)
b = Barrier(num_processes)
pool_params = [(args, p, read_op) for p in range(num_processes)]
with Pool(processes=num_processes, initializer=_init_tasklet, initargs=(b, )) as p:
pool_results = p.map(_aio_handle_tasklet, pool_params)
report_results(args, read_op, pool_results)
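For reference, a minimal sketch of the non-handle read path benchmarked above, outside the multiprocessing harness; the file path and transfer parameters are hypothetical.

```python
import os
import torch
from deepspeed.ops.aio import AsyncIOBuilder

file_path = '/local_nvme/input.dat'  # hypothetical, pre-existing input file
num_bytes = os.path.getsize(file_path)
buffer = torch.empty(num_bytes, dtype=torch.uint8, device='cpu').pin_memory()

# Argument order mirrors main_basic_read above:
# (buffer, filename, block_size, queue_depth, single_submit, overlap_events, validate)
AsyncIOBuilder().load().aio_read(buffer, file_path, 1048576, 32, False, True, False)
```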

deepspeed/nvme/ds_aio_handle.py (new executable file, 222 lines)
@@ -0,0 +1,222 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
Functionality of swapping optimizer tensors to/from (NVMe) storage devices.
"""
import torch
import os
import time
from multiprocessing import Pool, Barrier
from deepspeed.ops.aio import AsyncIOBuilder
from deepspeed.ops.op_builder import GDSBuilder
from deepspeed.accelerator import get_accelerator
from .test_ds_aio_utils import report_results, task_log, task_barrier, create_filename, create_file
BUFFER = 'buffer'
BOUNCE_BUFFER = 'bounce_buffer'
def pre_handle(args, tid, read_op):
io_string = "Read" if read_op else "Write"
gds = True if args.use_gds else False
device_id, folder = args.mapping_list[tid]
filename = create_filename(folder, args.read, args.io_size, tid)
if args.read and not (os.path.isfile(filename) and os.path.getsize(filename) == args.io_size):
create_file(filename, args.io_size)
task_log(tid, f'Allocate tensor of size {args.io_size} bytes')
bounce_buffer = None
if args.gpu:
device_name = get_accelerator().device_name(device_id)
buffer = torch.randint(high=128, size=(args.io_size, ), dtype=torch.uint8, device=device_name)
if not (args.slow_bounce_buffer or gds):
bounce_buffer = torch.randint(high=128, size=(args.io_size, ), dtype=torch.uint8,
device='cpu').pin_memory()
else:
buffer = torch.randint(high=128, size=(args.io_size, ), dtype=torch.uint8, device='cpu').pin_memory()
task_log(tid,
f'{io_string} file {filename} of size {args.io_size} bytes from buffer on device {buffer.device}',
force=True)
io_parallel = args.io_parallel if args.io_parallel else 1
if gds:
handle = GDSBuilder().load().gds_handle(args.block_size, args.queue_depth, args.single_submit,
not args.sequential_requests, io_parallel)
handle.pin_device_tensor(buffer)
else:
handle = AsyncIOBuilder().load().aio_handle(args.block_size, args.queue_depth, args.single_submit,
not args.sequential_requests, io_parallel)
task_log(tid, f'created deepspeed aio handle')
ctxt = {}
ctxt['file'] = filename
ctxt['num_bytes'] = args.io_size
ctxt['handle'] = handle
ctxt['gds'] = gds
ctxt[BUFFER] = buffer
ctxt[BOUNCE_BUFFER] = bounce_buffer
ctxt['elapsed_sec'] = 0
return ctxt
def pre_handle_read(pool_params):
args, tid = pool_params
ctxt = pre_handle(args, tid, True)
return ctxt
def pre_handle_write(pool_params):
args, tid = pool_params
ctxt = pre_handle(args, tid, False)
return ctxt
def post_handle(pool_params):
_, _, ctxt = pool_params
for buf in [BUFFER, BOUNCE_BUFFER]:
if ctxt[buf] is not None:
if ctxt['gds']:
ctxt['handle'].unpin_device_tensor(ctxt[buf])
ctxt[buf].detach()
ctxt[buf] = None
return ctxt
def main_parallel_read(pool_params):
args, tid, ctxt = pool_params
handle = ctxt['handle']
start_time = time.time()
dest_buffer = BOUNCE_BUFFER if ctxt[BOUNCE_BUFFER] is not None else BUFFER
ret = handle.pread(ctxt[dest_buffer], ctxt['file'], args.validate, True)
assert ret != -1
handle.wait()
if dest_buffer == BOUNCE_BUFFER:
ctxt[BUFFER].data.copy_(ctxt[BOUNCE_BUFFER].data)
end_time = time.time()
ctxt['elapsed_sec'] += end_time - start_time
return ctxt
def main_parallel_write(pool_params):
args, tid, ctxt = pool_params
# Avoid overwriting existing files as it could be artificially faster
if os.path.isfile(ctxt['file']):
os.remove(ctxt['file'])
handle = ctxt['handle']
start_time = time.time()
if ctxt[BOUNCE_BUFFER] is not None:
source_buffer = BOUNCE_BUFFER
ctxt[BOUNCE_BUFFER].data.copy_(ctxt[BUFFER].data)
else:
source_buffer = BUFFER
ret = handle.pwrite(ctxt[source_buffer], ctxt['file'], args.validate, True)
assert ret != -1
handle.wait()
end_time = time.time()
ctxt['elapsed_sec'] += end_time - start_time
return ctxt
def main_handle_read(pool_parms):
args, tid, ctxt = pool_parms
handle = ctxt['handle']
start_time = time.time()
dest_buffer = BOUNCE_BUFFER if ctxt[BOUNCE_BUFFER] is not None else BUFFER
ret = handle.read(ctxt[dest_buffer], ctxt['file'], args.validate)
assert ret != -1
if dest_buffer == BOUNCE_BUFFER:
ctxt[BUFFER].data.copy_(ctxt[BOUNCE_BUFFER].data)
end_time = time.time()
ctxt['elapsed_sec'] += end_time - start_time
return ctxt
def main_handle_write(pool_parms):
args, tid, ctxt = pool_parms
# Avoid overwriting existing files as it could be artificially faster
if os.path.isfile(ctxt['file']):
os.remove(ctxt['file'])
handle = ctxt['handle']
start_time = time.time()
if ctxt[BOUNCE_BUFFER] is not None:
source_buffer = BOUNCE_BUFFER
ctxt[BOUNCE_BUFFER].data.copy_(ctxt[BUFFER].data)
else:
source_buffer = BUFFER
ret = handle.write(ctxt[source_buffer], ctxt['file'], args.validate)
assert ret != -1
end_time = time.time()
ctxt['elapsed_sec'] += end_time - start_time
return ctxt
def get_schedule(args, read_op):
schedule = {}
if read_op:
schedule['pre'] = pre_handle_read
schedule['post'] = post_handle
schedule['main'] = main_parallel_read
else:
schedule['pre'] = pre_handle_write
schedule['post'] = post_handle
schedule['main'] = main_parallel_write
return schedule
def _aio_handle_tasklet(pool_params):
args, tid, read_op = pool_params
num_processes = len(args.mapping_dict)
# Create schedule
schedule = get_schedule(args, read_op)
task_log(tid, f'schedule = {schedule}')
task_barrier(aio_barrier, num_processes)
# Run pre task
task_log(tid, f'running pre-task')
ctxt = schedule["pre"]((args, tid))
task_barrier(aio_barrier, num_processes)
# Run main tasks in a loop
ctxt["main_task_sec"] = 0
for i in range(args.loops):
task_log(tid, f'running main task {i}')
start_time = time.time()
ctxt = schedule["main"]((args, tid, ctxt))
task_barrier(aio_barrier, num_processes)
stop_time = time.time()
ctxt["main_task_sec"] += stop_time - start_time
# Run post task
task_log(tid, f'running post-task')
ctxt = schedule["post"]((args, tid, ctxt))
task_barrier(aio_barrier, num_processes)
return ctxt["main_task_sec"], ctxt["elapsed_sec"], ctxt["num_bytes"] * args.loops
def _init_tasklet(b):
global aio_barrier
aio_barrier = b
def aio_handle_multiprocessing(args, read_op):
num_processes = len(args.mapping_dict)
b = Barrier(num_processes)
pool_params = [(args, p, read_op) for p in range(num_processes)]
with Pool(processes=num_processes, initializer=_init_tasklet, initargs=(b, )) as p:
pool_results = p.map(_aio_handle_tasklet, pool_params)
report_results(args, read_op, pool_results)
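Similarly, a minimal sketch of the handle-based read path benchmarked above, reading into a pinned CPU tensor; the file path and tuning parameters are hypothetical.

```python
import os
import torch
from deepspeed.ops.aio import AsyncIOBuilder

file_path = '/local_nvme/input.dat'  # hypothetical, pre-existing input file
num_bytes = os.path.getsize(file_path)
buffer = torch.empty(num_bytes, dtype=torch.uint8, device='cpu').pin_memory()

# Constructor arguments mirror pre_handle above:
# (block_size, queue_depth, single_submit, overlap_events, num_threads)
handle = AsyncIOBuilder().load().aio_handle(1048576, 32, False, True, 1)

# Asynchronous parallel read followed by wait, as in main_parallel_read above.
assert handle.pread(buffer, file_path, False, True) != -1
handle.wait()
```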

deepspeed/nvme/ds_aio_job.py (new file, 50 lines)
@@ -0,0 +1,50 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
Functionality of swapping tensors to/from (NVMe) storage devices.
"""
import subprocess
import shlex
class Job(object):
def __init__(self, cmd_line, output_file=None, work_dir=None):
self.cmd_line = cmd_line
self.output_file = output_file
self.work_dir = work_dir
self.output_fd = None
def cmd(self):
return self.cmd_line
def get_stdout(self):
return self.output_fd
def get_stderr(self):
return self.output_fd
def get_cwd(self):
return self.work_dir
def open_output_file(self):
if self.output_file is not None:
self.output_fd = open(self.output_file, 'w')
def close_output_file(self):
if self.output_fd is not None:
self.output_fd.close()
self.output_fd = None
def run_job(job, verbose=False):
args = shlex.split(' '.join(job.cmd()))
if verbose:
print(f'args = {args}')
job.open_output_file()
proc = subprocess.run(args=args, stdout=job.get_stdout(), stderr=job.get_stderr(), cwd=job.get_cwd())
job.close_output_file()
assert proc.returncode == 0, \
f"This command failed: {job.cmd()}"

deepspeed/nvme/parse_nvme_stats.py (new file, 148 lines)
@@ -0,0 +1,148 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
Functionality of swapping optimizer tensors to/from (NVMe) storage devices.
"""
import os
import argparse
READ_SPEED = 'read_speed'
WRITE_SPEED = 'write_speed'
PERF_METRICS = [READ_SPEED, WRITE_SPEED]
METRIC_SEARCH = {READ_SPEED: 'E2E Read Speed', WRITE_SPEED: 'E2E Write Speed'}
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument('--log_dir', type=str, required=True, help='Folder of statistics logs')
parser.add_argument('--metric',
type=str,
required=True,
help='Performance metric to report: [read_speed|write_speed]')
args = parser.parse_args()
print(f'args = {args}')
return args
def extract_value(key, file):
INVALID_PREFIXES = ["ds"]
for p in INVALID_PREFIXES:
if key.startswith(p):
return key
try:
if key[0] in ['t', 'd', 'p']:
return int(key[1:])
if key.startswith("bs"):
if key.endswith('K'):
v = key[2:].split('K')
return int(v[0]) * 1024
elif key.endswith('M'):
v = key[2:].split('M')
return int(v[0]) * 1024 * 1024
else:
return int(key[2:])
except:
print(f"{file}: extract_value fails on {key}")
return None
return key
def get_file_key(file):
f, _ = os.path.splitext(os.path.basename(file))
fields = f.split('_')
values = [extract_value(k, file) for k in fields]
return tuple(values)
def get_thread_count(file):
f, _ = os.path.splitext(os.path.basename(file))
fields = f.split('_')
for key in fields:
if key[0] == 't':
return int(key[1:])
return 1
"""
Extract performance metric from log file.
Sample file lines are:
Task Read Latency = 0.031647682189941406 sec
Task Read Speed = 12.342926020792527 GB/sec
E2E Read Latency = 0.031697988510131836 sec
E2E Read Speed = 12.323337169333062 GB/sec
For the above sample, -metric = "read_speed" corresponds to "E2E Read Speed", and 12.32 will be returned
"""
def get_metric(file, metric):
thread_count = get_thread_count(file)
with open(file) as f:
for line in f.readlines():
if line.startswith(METRIC_SEARCH[metric]):
if metric in [READ_SPEED, WRITE_SPEED]:
fields = line.split()
return float(fields[-2])
else:
fields = line.split('=')
return float(fields[-1])
return None
def validate_args(args):
if not args.metric in PERF_METRICS:
        print(f'{args.metric} is not a valid performance metric')
return False
if not os.path.isdir(args.log_dir):
        print(f'{args.log_dir} folder does not exist')
return False
return True
def get_results(log_files, metric):
results = {}
for f in log_files:
file_key = get_file_key(f)
value = get_metric(f, metric)
results[file_key] = value
return results
def get_sorted_results(log_dir, metric):
log_files = [f for f in os.listdir(log_dir) if os.path.isfile(os.path.join(log_dir, f))]
log_files_path = [os.path.join(log_dir, f) for f in log_files]
results = get_results(log_files_path, metric)
result_keys = list(results.keys())
sorted_keys = sorted(result_keys)
return sorted_keys, results
def main():
print("Parsing aio statistics")
args = parse_arguments()
if not validate_args(args):
quit()
sorted_keys, results = get_sorted_results(args.log_dir, args.metric)
for k in sorted_keys:
print(f'{k} = {results[k]}')
if __name__ == "__main__":
main()

deepspeed/nvme/perf_generate_param.py (new file, 97 lines)
@@ -0,0 +1,97 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
Functionality of swapping optimizer tensors to/from (NVMe) storage devices.
"""
import os
import argparse
import json
from .parse_nvme_stats import READ_SPEED, WRITE_SPEED, get_sorted_results
from .perf_sweep_utils import BENCH_LOG_DIR, READ_LOG_DIR, WRITE_LOG_DIR
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument('--log_dir',
type=str,
default=BENCH_LOG_DIR,
help=f'Folder of performance sweep logs. Default is {os.path.join(".", BENCH_LOG_DIR)}')
parser.add_argument('--verbose', action='store_true', help='Print debugging information.')
args = parser.parse_args()
if args.verbose:
print(f'args = {args}')
return args
def validate_args(args):
for d in [READ_LOG_DIR, WRITE_LOG_DIR]:
log_dir = os.path.join(args.log_dir, d)
if not os.path.isdir(log_dir):
            print(f'{log_dir} folder does not exist')
return False
return True
def convert_to_param(key):
assert len(key) == 6
return {
"single_submit": "true" if key[0] == "single" else "false",
"overlap_events": "true" if key[1] == "overlap" else "false",
"num_threads": int(key[5]),
"queue_depth": int(key[3]),
"block_size": int(key[4])
}
def generate_aio_param(read_log_dir, write_log_dir):
_, read_results = get_sorted_results(read_log_dir, READ_SPEED)
_, write_results = get_sorted_results(write_log_dir, WRITE_SPEED)
combined_perf = {key[1:]: value for key, value in read_results.items()}
for key, value in write_results.items():
new_key = key[1:]
if new_key in combined_perf:
combined_perf[new_key] += value
else:
combined_perf[new_key] = 0
optimal_key = None
optimal_perf = 0.0
for key, value in combined_perf.items():
if value > optimal_perf:
optimal_perf = value
optimal_key = key
aio_param = {"aio": convert_to_param(optimal_key)}
read_perf_keys = {key[1:]: key for key in read_results.keys()}
write_perf_keys = {key[1:]: key for key in write_results.keys()}
optimal_config_read = read_results.get(read_perf_keys[optimal_key], None)
optimal_config_write = write_results.get(write_perf_keys[optimal_key], None)
print(f'Best performance (GB/sec): read = {optimal_config_read:5.2f}, write = {optimal_config_write:5.2f}')
print(json.dumps(aio_param, indent=3))
def generate_main(log_dir):
read_log_dir = os.path.join(log_dir, READ_LOG_DIR)
write_log_dir = os.path.join(log_dir, WRITE_LOG_DIR)
generate_aio_param(read_log_dir, write_log_dir)
def main():
args = parse_arguments()
if not validate_args(args):
quit()
print(f'Generate DeepNVMe configuration from {args.log_dir} logs')
generate_main(args.log_dir)
if __name__ == "__main__":
    main()

deepspeed/nvme/perf_run_sweep.py (new file, 320 lines)
@@ -0,0 +1,320 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
Functionality of swapping optimizer tensors to/from (NVMe) storage devices.
"""
import os
import sys
import argparse
import json
import itertools
import shutil
from deepspeed.ops.op_builder import AsyncIOBuilder
from deepspeed.ops.op_builder import GDSBuilder
from .ds_aio_job import Job, run_job
from .perf_sweep_utils import READ_OP_DESC, WRITE_OP_DESC, BENCH_LOG_DIR, \
READ_LOG_DIR, WRITE_LOG_DIR
OTHER_OPTIONS = '--handle'
PERF_SCRIPT = 'ds_io'
DEFAULT_SWEEP_CONFIG = {
"block_size": ["1M", "8M"],
"queue_depth": [32, 128],
"sequential_requests": [False],
"single_submit": [False],
"io_parallel": [1, 8],
}
class SweepConfig(object):
def __init__(self, args):
self.folder_to_device_mapping = get_ftd_map(args.nvme_dir)
self.search_space = get_sweep_config_dict(args.sweep_config)
self.search_space.update(self.folder_to_device_mapping)
self.read = not args.no_read
self.write = not args.no_write
self.flush_cache = args.flush_page_cache
self.log_dir = args.log_dir
self.verbose = args.verbose
self.other_options = f'{OTHER_OPTIONS} --loops {args.loops} --io_size {args.io_size}'
if args.gpu:
self.other_options += ' --gpu'
if args.gds:
self.other_options += ' --use_gds'
def validate_arguments(args):
if not async_io_setup():
error_msg = """
Failing because environment is not properly configured for deepspeed async i/o module.
Possible fix: apt install libaio-dev.
"""
print(error_msg)
quit()
if args.gds and not gds_io_setup():
error_msg = """
Failing because environment is not properly configured for deepspeed GDS I/O operator.
"""
print(error_msg)
quit()
def parse_sweep_arguments():
parser = argparse.ArgumentParser()
parser.add_argument('--nvme_dir',
nargs='+',
required=True,
                        help='Directory in which to perform I/O tests. A writeable directory on an NVMe device.')
parser.add_argument('--sweep_config', type=str, default=None, help='Performance sweep configuration json file.')
parser.add_argument('--no_read', action='store_true', help='Disable read performance measurements.')
parser.add_argument('--no_write', action='store_true', help='Disable write performance measurements.')
parser.add_argument('--io_size',
type=str,
default="400M",
help='Number of I/O bytes to read/write for performance measurements.')
parser.add_argument('--gpu', action='store_true', help='Test tensor transfers between GPU device and NVME device.')
parser.add_argument('--gds', action='store_true', help='Run the sweep over NVIDIA GPUDirectStorage operator')
parser.add_argument(
'--flush_page_cache',
action='store_true',
help=
        'Flush the page cache before each read measurement to avoid inflated read speeds. ***Requires sudo access***.'
)
parser.add_argument(
'--log_dir',
type=str,
default=BENCH_LOG_DIR,
help=f'Output directory for performance log files. Default is {os.path.join(".", BENCH_LOG_DIR)}')
parser.add_argument('--loops', type=int, default=1, help='Count of operation repetitions')
parser.add_argument('--verbose', action='store_true', help='Print debugging information.')
args = parser.parse_args()
if args.verbose:
print(f'args = {args}')
validate_arguments(args)
return args
def dump_cmd_lines(cmd_lines):
print(f'cmd line count = {len(cmd_lines)}')
for i, cmd in enumerate(cmd_lines):
print(f'{i}: {cmd}')
def get_ftd_map(nvme_dir_list):
ftd_list = [f'{dir}:{dev}' for dev, dir in enumerate(nvme_dir_list)]
ftd_arg = [' '.join(ftd for ftd in ftd_list)]
return {'folder_to_device_mapping': ftd_arg}
def get_sweep_config_dict(sweep_config_json):
if sweep_config_json is None:
return DEFAULT_SWEEP_CONFIG
with open(sweep_config_json) as fp:
sweep_config = json.load(fp)
return sweep_config
def get_sweep_cmd_lines(sweep_config_dict):
def flatten_options(key, value_list):
flat_list = []
for v in value_list:
if not type(v) is bool:
flat_list.append(f'--{key} {v}')
elif v:
flat_list.append(f'--{key}')
else:
flat_list.append(' ')
return flat_list
flat_list = [flatten_options(key, value) for key, value in sweep_config_dict.items()]
cmd_list = list(itertools.product(*flat_list))
cmd_list = [list(cmd) for cmd in cmd_list]
#dump_cmd_lines(cmd_list)
return cmd_list
def launch_sweep(sweep_jobs, sync_job, flush_cache_job, verbose):
for perf_job in sweep_jobs:
if flush_cache_job is not None:
run_job(sync_job, verbose)
run_job(flush_cache_job, verbose)
run_job(perf_job, verbose)
run_job(sync_job, verbose)
def create_cmd_tags(cmd_line):
tags = {}
for param_value in cmd_line:
fields = param_value.split()
if len(fields) == 1:
tags[fields[0]] = None
elif len(fields) == 2:
if fields[0] == '--folder_to_device_mapping':
tags[fields[0]] = len(fields[1:])
else:
tags[fields[0]] = fields[1]
elif len(fields) > 2:
tags[fields[0]] = len(fields[1:])
return tags
def get_log_file(io_op_desc, cmd_line):
QUEUE_DEPTH = "--queue_depth"
BLOCK_SIZE = "--block_size"
SINGLE_SUBMIT = "--single_submit"
SEQUENTIAL_REQUESTS = "--sequential_requests"
FTD_MAP = "--folder_to_device_mapping"
IO_PARALLEL = "--io_parallel"
tag_map = {
QUEUE_DEPTH: "d",
BLOCK_SIZE: "bs",
SINGLE_SUBMIT: "single",
SEQUENTIAL_REQUESTS: "sequential",
FTD_MAP: "ftd",
IO_PARALLEL: "p"
}
tag_default = {
QUEUE_DEPTH: 1,
BLOCK_SIZE: "1M",
SINGLE_SUBMIT: "block",
SEQUENTIAL_REQUESTS: "overlap",
FTD_MAP: 1,
IO_PARALLEL: 1
}
def get_default_value(tag):
value = tag_default[tag]
if tag in [SINGLE_SUBMIT, SEQUENTIAL_REQUESTS]:
return value
return f'{tag_map[tag]}{value}'
def get_config_value(tag, value):
tag_key = tag_map[tag]
if value is None:
return tag_key
return f'{tag_key}{value}'
tag_list = [SINGLE_SUBMIT, SEQUENTIAL_REQUESTS, FTD_MAP, QUEUE_DEPTH, BLOCK_SIZE, IO_PARALLEL]
log_tags = [io_op_desc]
cmd_tags = create_cmd_tags(cmd_line)
for tag in tag_list:
if tag in cmd_tags:
log_tags.append(get_config_value(tag, cmd_tags[tag]))
else:
log_tags.append(get_default_value(tag))
log_file = '_'.join(log_tags)
log_file += '.txt'
return log_file
def create_perf_jobs(io_op_desc, log_dir, cmd_lines):
py_cmd = [os.path.join(script_path(), PERF_SCRIPT)]
perf_jobs = []
for cmd in cmd_lines:
log_file = os.path.join(log_dir, get_log_file(io_op_desc, cmd))
job = Job(cmd_line=py_cmd + cmd, output_file=log_file)
perf_jobs.append(job)
return perf_jobs
def script_path():
return os.path.dirname(os.path.realpath(sys.argv[0]))
def async_io_setup():
return AsyncIOBuilder().is_compatible()
def gds_io_setup():
return GDSBuilder().is_compatible()
def remove_folder(folder):
assert os.path.isdir(folder), f"Error: cannot remove {folder} - folder not found"
shutil.rmtree(folder)
def run_read_sweep(sweep_config, flush_cache_job, sync_job, cmd_lines):
read_cmd_lines = [[f'--read {sweep_config.other_options}'] + cmd for cmd in cmd_lines]
# dump_cmd_lines(cmd_lines)
log_folder = os.path.join(sweep_config.log_dir, f'{READ_LOG_DIR}')
os.makedirs(log_folder, exist_ok=True)
perf_jobs = create_perf_jobs(io_op_desc=READ_OP_DESC, log_dir=log_folder, cmd_lines=read_cmd_lines)
launch_sweep(sweep_jobs=perf_jobs,
sync_job=sync_job,
flush_cache_job=flush_cache_job,
verbose=sweep_config.verbose)
def run_write_sweep(sweep_config, flush_cache_job, sync_job, cmd_lines):
write_cmd_lines = [[f'{sweep_config.other_options}'] + cmd for cmd in cmd_lines]
# dump_cmd_lines(write_cmd_lines)
log_folder = os.path.join(sweep_config.log_dir, f'{WRITE_LOG_DIR}')
os.makedirs(log_folder, exist_ok=True)
perf_jobs = create_perf_jobs(io_op_desc=WRITE_OP_DESC, log_dir=log_folder, cmd_lines=write_cmd_lines)
launch_sweep(sweep_jobs=perf_jobs,
sync_job=sync_job,
flush_cache_job=flush_cache_job,
verbose=sweep_config.verbose)
def sweep_main(args):
sweep_config = SweepConfig(args)
cmd_lines = get_sweep_cmd_lines(sweep_config.search_space)
if sweep_config.flush_cache:
flush_cache_job = Job(cmd_line=['sudo', 'bash -c', "'echo 1 > /proc/sys/vm/drop_caches'"])
else:
flush_cache_job = None
sync_job = Job(cmd_line=['sync'])
if sweep_config.read:
run_read_sweep(sweep_config, flush_cache_job, sync_job, cmd_lines)
if sweep_config.write:
run_write_sweep(sweep_config, flush_cache_job, sync_job, cmd_lines)
def main():
args = parse_sweep_arguments()
print(f"Running DeepNVMe performance sweep on {args.nvme_dir}")
sweep_main(args)
if __name__ == "__main__":
    main()

deepspeed/nvme/perf_sweep_utils.py (new file, 13 lines)
@@ -0,0 +1,13 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
SCRIPT_PREFIX = '_aio_bench'
WRITE_OP_DESC = 'write'
READ_OP_DESC = 'read'
READ_IO_DIR = f'{SCRIPT_PREFIX}_{READ_OP_DESC}_io'
WRITE_IO_DIR = f'{SCRIPT_PREFIX}_{WRITE_OP_DESC}_io'
BENCH_LOG_DIR = f'{SCRIPT_PREFIX}_logs'
READ_LOG_DIR = f'{SCRIPT_PREFIX}_{READ_OP_DESC}_logs'
WRITE_LOG_DIR = f'{SCRIPT_PREFIX}_{WRITE_OP_DESC}_logs'

deepspeed/nvme/test_ds_aio.py (new executable file, 25 lines)
@@ -0,0 +1,25 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
Functionality of swapping optimizer tensors to/from (NVMe) storage devices.
"""
import multiprocessing as mp
from .ds_aio_basic import aio_basic_multiprocessing
from .ds_aio_handle import aio_handle_multiprocessing
from .ds_aio_args import get_validated_args
def ds_io_main():
print(f'Testing deepspeed_aio python frontend')
args = get_validated_args()
mp.set_start_method('spawn')
multiprocess_function = aio_handle_multiprocessing if args.handle else aio_basic_multiprocessing
multiprocess_function(args, args.read)
if __name__ == "__main__":
ds_io_main()

deepspeed/nvme/test_ds_aio_utils.py (new file, 81 lines)
@@ -0,0 +1,81 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
Functionality of swapping optimizer tensors to/from (NVMe) storage devices.
"""
import os
from .ds_aio_job import Job, run_job
BYTES_PER_GB = 1024**3
BYTES_PER_MB = 1024**2
BYTES_PER_KB = 1024
LOG_TIDS = [0]
def task_log(tid, msg, force=False):
if force or tid in LOG_TIDS:
print(f'tid {tid}: {msg}')
def task_barrier(barrier, num_parties):
assert barrier.parties == num_parties
barrier.wait()
assert barrier.broken == False
def report_results(args, read_op, pool_results):
#print(f'pool_results = {pool_results}')
io_string = 'Read' if read_op else 'Write'
if None in pool_results:
        print(f'Failure in one of {len(args.mapping_dict)} {io_string} processes')
return
total_bytes = sum([num_bytes for _, _, num_bytes in pool_results])
task_latency_sec = max([sec for _, sec, _ in pool_results])
task_speed_GB = 0 if task_latency_sec == 0 else total_bytes / task_latency_sec / BYTES_PER_GB
print(f'Task {io_string} Latency = {task_latency_sec} sec')
print(f'Task {io_string} Speed = {task_speed_GB} GB/sec')
e2e_latency_sec = max([sec for sec, _, _ in pool_results])
e2e_speed_GB = 0 if e2e_latency_sec == 0 else total_bytes / e2e_latency_sec / BYTES_PER_GB
print(f'E2E {io_string} Latency = {e2e_latency_sec} sec')
print(f'E2E {io_string} Speed = {e2e_speed_GB} GB/sec')
def get_block_size_and_count(io_bytes):
if io_bytes > BYTES_PER_MB and io_bytes % BYTES_PER_MB == 0:
block_size = BYTES_PER_MB
block_size_string = '1M'
else:
assert io_bytes % BYTES_PER_KB == 0
block_size = BYTES_PER_KB
block_size_string = '1K'
block_count = io_bytes / block_size
return block_size_string, int(block_count)
def refine_integer_value(value):
unit_dict = {'K': 1024, 'M': 1024**2, 'G': 1024**3}
if value[-1] in list(unit_dict.keys()):
int_value = int(value[:-1]) * unit_dict[value[-1]]
return int_value
return int(value)
def create_filename(folder, read_op, size, tid):
io_string = "read" if read_op else "write"
return os.path.join(folder, f'_aio_{io_string}_{size}.pt.{tid}')
def create_file(filename, num_bytes):
block_size, block_count = get_block_size_and_count(num_bytes)
dd_job = Job(cmd_line=[f'dd if=/dev/urandom of={filename} bs={block_size} count={block_count}'])
print(f'[Start] Create {filename} of {num_bytes} bytes by running {dd_job.cmd()} ....')
run_job(dd_job)
    print(f'[Done] Create {filename} of {num_bytes} bytes by running {dd_job.cmd()} ....')

deepspeed/nvme/validate_async_io.py (new file, 10 lines)
@@ -0,0 +1,10 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
Functionality of swapping optimizer tensors to/from (NVMe) storage devices.
"""
from deepspeed.ops.op_builder import AsyncIOBuilder
assert AsyncIOBuilder().is_compatible()
assert AsyncIOBuilder().load()

docs/_tutorials/deepnvme.md (modified)
@@ -188,7 +188,7 @@ This tutorial has been significantly improved by feedback from [Guanhua Wang](ht
## Appendix
### Advanced Handle Creation
-Achieving peak I/O performance with DeepNVMe requires careful configuration of handle creation. In particular, the parameters of `aio_handle` and `gds_handle` constructors are performance-critical because they determine how efficiently DeepNVMe interacts with the underlying storage subsystem (i.e., `libaio`, GDS, and SSD). For convenience we make it possible to create handles using default parameter values which will provide decent performance in most scenarios. However, squeezing out every available performance in your environment will likely require tuning the constructor parameters, namely `block_size`, `queue_depth`, `single_submit`, `overlap_events`, and `num_threads`. The `aio_handle` constructor parameters and default values are illustrated below:
+Achieving peak I/O performance with DeepNVMe requires careful configuration of handle creation. In particular, the parameters of `aio_handle` and `gds_handle` constructors are performance-critical because they determine how efficiently DeepNVMe interacts with the underlying storage subsystem (i.e., `libaio`, GDS, PCIe, and SSD). For convenience we make it possible to create handles using default parameter values which will provide decent performance in most scenarios. However, squeezing out all available performance in your environment will likely require tuning the constructor parameters, namely `block_size`, `queue_depth`, `single_submit`, `overlap_events`, and `num_threads`. The `aio_handle` constructor parameters and default values are illustrated below:
```bash
>>> from deepspeed.ops.op_builder import AsyncIOBuilder
>>> help(AsyncIOBuilder().load().aio_handle())
@@ -208,6 +208,56 @@ class aio_handle(pybind11_builtins.pybind11_object)
| AIO handle constructor
```
### Performance Tuning
As discussed [earlier](#advanced-handle-creation), achieving peak DeepNVMe performance for a target workload or environment requires using optimally configured `aio_handle` or `gds_handle` handles. For configuration convenience, we provide a utility called `ds_nvme_tune` to automate the discovery of optimal DeepNVMe configurations. `ds_nvme_tune` automatically explores a user-specified or default configuration space and recommends the option that provides the best read and write performance. Below is an example usage of `ds_nvme_tune` to tune `aio_handle` data transfers between GPU memory and a local NVMe SSD mounted on `/local_nvme`. This example uses the default configuration space of `ds_nvme_tune`.
```bash
$ ds_nvme_tune --nvme_dir /local_nvme --gpu
Running DeepNVMe performance tuning on ['/local_nvme/']
Best performance (GB/sec): read = 3.69, write = 3.18
{
"aio": {
"single_submit": "false",
"overlap_events": "true",
"num_threads": 8,
"queue_depth": 32,
"block_size": 1048576
}
}
```
The above tuning was executed on a Lambda workstation equipped with two NVIDIA A6000-48GB GPUs, 252GB of DRAM, and a [CS3040 NVMe 2TB SSD](https://www.pny.com/CS3040-M2-NVMe-SSD?sku=M280CS3040-2TB-RB) with peak read and write speeds of 5.6 GB/s and 4.3 GB/s respectively. The tuning required about four and a half minutes. Based on the results, one can expect to achieve read and write transfer speeds of 3.69 GB/sec and 3.18 GB/sec respectively by using an `aio_handle` configured as below.
```python
>>> from deepspeed.ops.op_builder import AsyncIOBuilder
>>> h = AsyncIOBuilder().load().aio_handle(block_size=1048576,
queue_depth=32,
single_submit=False,
overlap_events=True,
num_threads=8)
```
The full command line options of `ds_nvme_tune` can be obtained via the usual `-h` or `--help` arguments.
```bash
usage: ds_nvme_tune [-h] --nvme_dir NVME_DIR [NVME_DIR ...] [--sweep_config SWEEP_CONFIG] [--no_read] [--no_write] [--io_size IO_SIZE] [--gpu] [--gds] [--flush_page_cache] [--log_dir LOG_DIR] [--loops LOOPS] [--verbose]
options:
-h, --help show this help message and exit
--nvme_dir NVME_DIR [NVME_DIR ...]
                        Directory in which to perform I/O tests. A writeable directory on an NVMe device.
--sweep_config SWEEP_CONFIG
Performance sweep configuration json file.
--no_read Disable read performance measurements.
--no_write Disable write performance measurements.
--io_size IO_SIZE Number of I/O bytes to read/write for performance measurements.
--gpu Test tensor transfers between GPU device and NVME device.
--gds Run the sweep over NVIDIA GPUDirectStorage operator
  --flush_page_cache   Flush the page cache before each read measurement to avoid inflated read speeds. ***Requires sudo access***.
--log_dir LOG_DIR Output directory for performance log files. Default is ./_aio_bench_logs
--loops LOOPS Count of operation repetitions
--verbose Print debugging information.
```
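The default search space can be overridden with `--sweep_config`. Below is a sketch of writing a custom sweep configuration; it assumes the JSON keys mirror `DEFAULT_SWEEP_CONFIG` in `perf_run_sweep.py`, where each key names a `ds_io` option and maps to the list of values to sweep.
```python
# Hedged sketch: create a custom sweep space for ds_nvme_tune --sweep_config.
import json

sweep_space = {
    "block_size": ["256K", "1M", "8M"],
    "queue_depth": [16, 32, 128],
    "sequential_requests": [False],
    "single_submit": [False],
    "io_parallel": [1, 2, 8],
}

with open("my_sweep.json", "w") as f:
    json.dump(sweep_space, f, indent=2)

# Then run (hypothetical mount point):
#   ds_nvme_tune --nvme_dir /local_nvme --gpu --sweep_config my_sweep.json
```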
### DeepNVMe APIs
For convenience, we provide listing and brief descriptions of the DeepNVMe APIs.

setup.py (modified)
@@ -298,7 +298,7 @@ if sys.platform == "win32":
else:
scripts = [
'bin/deepspeed', 'bin/deepspeed.pt', 'bin/ds', 'bin/ds_ssh', 'bin/ds_report', 'bin/ds_bench', 'bin/dsr',
-        'bin/ds_elastic'
+        'bin/ds_elastic', 'bin/ds_nvme_tune', 'bin/ds_io'
]
start_time = time.time()