update formatter version and style settings (#3098)

Jeff Rasley
2023-03-27 04:55:19 -07:00
committed by GitHub
parent b3ec1c9712
commit 91d63e0228
325 changed files with 5235 additions and 12298 deletions

View File

@ -22,8 +22,8 @@ repos:
- id: requirements-txt-fixer
- id: trailing-whitespace
- repo: https://github.com/pre-commit/mirrors-yapf
rev: v0.31.0
- repo: https://github.com/google/yapf
rev: v0.32.0
hooks:
- id: yapf

View File

@ -1,3 +1,3 @@
[style]
SPLIT_ALL_COMMA_SEPARATED_VALUES = true
COLUMN_LIMIT = 89
SPLIT_ALL_COMMA_SEPARATED_VALUES = false
COLUMN_LIMIT = 119
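
These two settings drive every reformatting hunk in this commit: the old style (SPLIT_ALL_COMMA_SEPARATED_VALUES = true, COLUMN_LIMIT = 89) put each comma-separated argument on its own line, while the new style packs arguments together up to 119 columns. A minimal sketch of checking a snippet against the new repository style with the pinned yapf release, assuming yapf v0.32 is installed and the script runs from the repository root where .style.yapf lives; the sample source line is lifted from one of the benchmark hunks further down and is purely illustrative:

from yapf.yapflib.yapf_api import FormatCode

# A call written the "old" way, as it appears in the hunks below.
src = (
    "mat = torch.ones(world_size,\n"
    "                 M,\n"
    "                 dtype=getattr(\n"
    "                     torch,\n"
    "                     args.dtype)).to(\n"
    "                         get_accelerator().device_name(local_rank))\n"
)

# Reformat against the repo's .style.yapf (COLUMN_LIMIT = 119,
# SPLIT_ALL_COMMA_SEPARATED_VALUES = false). yapf v0.32 returns a
# (formatted_source, changed) tuple.
formatted, changed = FormatCode(src, style_config='.style.yapf')
print(formatted)  # arguments now stay on as few lines as fit within 119 columns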

View File

@ -5,6 +5,7 @@ from abc import ABC
class DeepSpeedAccelerator(ABC):
def __init__(self):
self._name = None
self._communication_backend_name = None

View File

@ -14,6 +14,7 @@ except ImportError:
class CUDA_Accelerator(DeepSpeedAccelerator):
def __init__(self):
self._name = 'cuda'
self._communication_backend_name = 'nccl'
@ -26,9 +27,7 @@ class CUDA_Accelerator(DeepSpeedAccelerator):
for _, module_name, _ in pkgutil.iter_modules([os.path.dirname(op_builder_module.__file__)]):
# avoid self references
if module_name != 'all_ops' and module_name != 'builder':
module = importlib.import_module("{}.{}".format(
op_builder_dir,
module_name))
module = importlib.import_module("{}.{}".format(op_builder_dir, module_name))
for member_name in module.__dir__():
if member_name.endswith(
'Builder'
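
The hunk above is cut off by this view, but the loop it reformats is a small plugin-discovery pattern: walk the modules of the op_builder package, import each one, and collect every attribute whose name ends in 'Builder'. A self-contained sketch of the same pattern, with the package name treated as a parameter rather than DeepSpeed's actual layout:

import importlib
import os
import pkgutil

def discover_builders(op_builder_dir):
    """Collect every attribute ending in 'Builder' from the submodules of a package."""
    op_builder_module = importlib.import_module(op_builder_dir)
    builders = {}
    for _, module_name, _ in pkgutil.iter_modules([os.path.dirname(op_builder_module.__file__)]):
        # avoid self references, mirroring the check in the hunk above
        if module_name in ('all_ops', 'builder'):
            continue
        module = importlib.import_module("{}.{}".format(op_builder_dir, module_name))
        for member_name in module.__dir__():
            if member_name.endswith('Builder'):
                builders[member_name] = getattr(module, member_name)
    return builders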

View File

@ -23,13 +23,8 @@ def _validate_accelerator(accel_obj):
# accelerator.abstractor_accelerator
# or deepspeed.accelerator.abstract_accelerator, consider accel_obj
# is a conforming object
if not ((dsa1 != None and isinstance(accel_obj,
dsa1)) or
(dsa2 != None and isinstance(accel_obj,
dsa2))):
raise AssertionError(
f'{accel_obj.__class__.__name__} accelerator is not subclass of DeepSpeedAccelerator'
)
if not ((dsa1 != None and isinstance(accel_obj, dsa1)) or (dsa2 != None and isinstance(accel_obj, dsa2))):
raise AssertionError(f'{accel_obj.__class__.__name__} accelerator is not subclass of DeepSpeedAccelerator')
# TODO: turn off is_available test since this breaks tests
#assert accel_obj.is_available(), \

View File

@ -22,9 +22,7 @@ def timed_all_gather(input, output, args):
if hasattr(torch.distributed, "_all_gather_base"):
dist._all_gather_base(output, input, group=None, async_op=args.async_op)
else:
output_tensors = list(
torch.chunk(output_tensor,
cdb.get_world_size(group)))
output_tensors = list(torch.chunk(output_tensor, cdb.get_world_size(group)))
dist.all_gather(output_tensors, input_tensor, group=group, async_op=True)
elif args.dist == 'deepspeed':
dist.allgather_fn(output, input, group=None, async_op=args.async_op)
@ -38,9 +36,7 @@ def timed_all_gather(input, output, args):
if hasattr(torch.distributed, "_all_gather_base"):
dist._all_gather_base(output, input, group=None, async_op=args.async_op)
else:
output_tensors = list(
torch.chunk(output_tensor,
cdb.get_world_size(group)))
output_tensors = list(torch.chunk(output_tensor, cdb.get_world_size(group)))
dist.all_gather(output_tensors, input_tensor, group=group, async_op=True)
elif args.dist == 'deepspeed':
dist.allgather_fn(output, input, group=None, async_op=args.async_op)
@ -58,8 +54,7 @@ def timed_all_gather(input, output, args):
if not args.raw:
size = convert_size(size)
print_rank_0(
f"{size:<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}")
print_rank_0(f"{size:<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}")
def run_all_gather(local_rank, args):
@ -84,22 +79,15 @@ def run_all_gather(local_rank, args):
for M in M_LIST:
global_rank = dist.get_rank()
try:
mat = torch.ones(world_size,
M,
dtype=getattr(
torch,
args.dtype)).to(
get_accelerator().device_name(local_rank))
mat = torch.ones(world_size, M,
dtype=getattr(torch, args.dtype)).to(get_accelerator().device_name(local_rank))
sync_all()
input = ((mat.mul_(float(global_rank))).view(-1))
# Delete original mat to avoid OOM
del mat
get_accelerator().empty_cache()
output = torch.zeros(input.nelement() * world_size,
dtype=getattr(
torch,
args.dtype)).to(
get_accelerator().device_name(local_rank))
dtype=getattr(torch, args.dtype)).to(get_accelerator().device_name(local_rank))
except RuntimeError as e:
if 'out of memory' in str(e):
if dist.get_rank() == 0:
@ -110,41 +98,32 @@ def run_all_gather(local_rank, args):
timed_all_gather(input, output, args)
else:
# all_gather_base saves memory
if (args.dist == 'torch'
and hasattr(torch.distributed,
"_all_gather_base")) or (args.dist == 'deepspeed'
and dist.has_allgather_base):
if (args.dist == 'torch' and hasattr(torch.distributed, "_all_gather_base")) or (args.dist == 'deepspeed'
and dist.has_allgather_base):
mem_factor = args.mem_factor + 0.2
else:
mem_factor = args.mem_factor
# Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
sync_all()
elements_per_gpu = max_numel(comm_op='all_gather',
dtype=getattr(torch,
args.dtype),
dtype=getattr(torch, args.dtype),
mem_factor=mem_factor,
local_rank=local_rank,
args=args)
try:
mat = torch.ones(elements_per_gpu,
dtype=getattr(torch,
args.dtype)).to(
get_accelerator().device_name(local_rank))
mat = torch.ones(elements_per_gpu, dtype=getattr(torch,
args.dtype)).to(get_accelerator().device_name(local_rank))
# multiply each GPU's tensor by the rank to ease debugging
input = ((mat.mul_(float(global_rank))).view(-1))
# Delete original mat to avoid OOM
del mat
get_accelerator().empty_cache()
output = torch.zeros(
elements_per_gpu * world_size,
dtype=getattr(torch,
args.dtype)).to(get_accelerator().device_name(local_rank))
output = torch.zeros(elements_per_gpu * world_size,
dtype=getattr(torch, args.dtype)).to(get_accelerator().device_name(local_rank))
except RuntimeError as e:
if 'out of memory' in str(e):
if dist.get_rank() == 0:
print(
'WARNING: Ran out of GPU memory. Try to reduce the --mem-factor argument!'
)
print('WARNING: Ran out of GPU memory. Try to reduce the --mem-factor argument!')
sync_all()
return

View File

@ -37,8 +37,7 @@ def timed_all_reduce(input, args):
if not args.raw:
size = convert_size(size)
print_rank_0(
f"{size:<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}")
print_rank_0(f"{size:<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}")
def run_all_reduce(local_rank, args):
@ -63,12 +62,8 @@ def run_all_reduce(local_rank, args):
for M in M_LIST:
global_rank = dist.get_rank()
try:
mat = torch.ones(world_size,
M,
dtype=getattr(
torch,
args.dtype)).to(
get_accelerator().device_name(local_rank))
mat = torch.ones(world_size, M,
dtype=getattr(torch, args.dtype)).to(get_accelerator().device_name(local_rank))
sync_all()
input = ((mat.mul_(float(global_rank))).view(-1))
except RuntimeError as e:
@ -83,23 +78,18 @@ def run_all_reduce(local_rank, args):
# Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
# Don't need output tensor, so we double mem_factor
elements_per_gpu = max_numel(comm_op='all_reduce',
dtype=getattr(torch,
args.dtype),
dtype=getattr(torch, args.dtype),
mem_factor=args.mem_factor * 2,
local_rank=local_rank,
args=args)
try:
mat = torch.ones(elements_per_gpu,
dtype=getattr(torch,
args.dtype)).to(
get_accelerator().device_name(local_rank))
mat = torch.ones(elements_per_gpu, dtype=getattr(torch,
args.dtype)).to(get_accelerator().device_name(local_rank))
input = ((mat.mul_(float(global_rank))).view(-1))
except RuntimeError as e:
if 'out of memory' in str(e):
if dist.get_rank() == 0:
print(
'WARNING: Ran out of GPU memory. Try to reduce the --mem-factor argument!'
)
print('WARNING: Ran out of GPU memory. Try to reduce the --mem-factor argument!')
sync_all()
return
sync_all()

View File

@ -37,8 +37,7 @@ def timed_all_to_all(input, output, args):
if not args.raw:
size = convert_size(size)
print_rank_0(
f"{size:<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}")
print_rank_0(f"{size:<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}")
def run_all_to_all(local_rank, args):
@ -62,12 +61,8 @@ def run_all_to_all(local_rank, args):
for M in M_LIST:
global_rank = dist.get_rank()
try:
mat = torch.ones(world_size,
M,
dtype=getattr(
torch,
args.dtype)).to(
get_accelerator().device_name(local_rank))
mat = torch.ones(world_size, M,
dtype=getattr(torch, args.dtype)).to(get_accelerator().device_name(local_rank))
assert mat.numel() % world_size == 0, f"tensor cannot be divided in {world_size} chunks"
sync_all()
input = ((mat.mul_(float(global_rank))).view(-1))
@ -83,31 +78,25 @@ def run_all_to_all(local_rank, args):
else:
# Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
elements_per_gpu = max_numel(comm_op='all_to_all',
dtype=getattr(torch,
args.dtype),
dtype=getattr(torch, args.dtype),
mem_factor=args.mem_factor,
local_rank=local_rank,
args=args)
try:
mat = torch.ones(elements_per_gpu,
dtype=getattr(torch,
args.dtype)).to(
get_accelerator().device_name(local_rank))
assert mat.numel() % world_size == 0, f"tensor with {mat.numel()} elements cannot be divided in {world_size} chunks"
mat = torch.ones(elements_per_gpu, dtype=getattr(torch,
args.dtype)).to(get_accelerator().device_name(local_rank))
assert mat.numel(
) % world_size == 0, f"tensor with {mat.numel()} elements cannot be divided in {world_size} chunks"
input = ((mat.mul_(float(global_rank))).view(-1))
# Delete original mat to avoid OOM
del mat
get_accelerator().empty_cache()
output = torch.zeros(
elements_per_gpu,
dtype=getattr(torch,
args.dtype)).to(get_accelerator().device_name(local_rank))
output = torch.zeros(elements_per_gpu,
dtype=getattr(torch, args.dtype)).to(get_accelerator().device_name(local_rank))
except RuntimeError as e:
if 'out of memory' in str(e):
if dist.get_rank() == 0:
print(
'WARNING: Ran out of GPU memory. Try to reduce the --mem-factor argument!'
)
print('WARNING: Ran out of GPU memory. Try to reduce the --mem-factor argument!')
sync_all()
return
sync_all()

View File

@ -38,8 +38,7 @@ def timed_broadcast(input, args):
if not args.raw:
size = convert_size(size)
print_rank_0(
f"{size:<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}")
print_rank_0(f"{size:<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}")
def run_broadcast(local_rank, args):
@ -64,12 +63,8 @@ def run_broadcast(local_rank, args):
for M in M_LIST:
global_rank = dist.get_rank()
try:
mat = torch.ones(world_size,
M,
dtype=getattr(
torch,
args.dtype)).to(
get_accelerator().device_name(local_rank))
mat = torch.ones(world_size, M,
dtype=getattr(torch, args.dtype)).to(get_accelerator().device_name(local_rank))
sync_all()
input = ((mat.mul_(float(global_rank))).view(-1))
except RuntimeError as e:
@ -84,23 +79,18 @@ def run_broadcast(local_rank, args):
# Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
# Don't need output tensor, so we double mem_factor
elements_per_gpu = max_numel(comm_op='broadcast',
dtype=getattr(torch,
args.dtype),
dtype=getattr(torch, args.dtype),
mem_factor=args.mem_factor * 2,
local_rank=local_rank,
args=args)
try:
mat = torch.ones(elements_per_gpu,
dtype=getattr(torch,
args.dtype)).to(
get_accelerator().device_name(local_rank))
mat = torch.ones(elements_per_gpu, dtype=getattr(torch,
args.dtype)).to(get_accelerator().device_name(local_rank))
input = ((mat.mul_(float(global_rank))).view(-1))
except RuntimeError as e:
if 'out of memory' in str(e):
if dist.get_rank() == 0:
print(
'WARNING: Ran out of GPU memory. Try to reduce the --mem-factor argument!'
)
print('WARNING: Ran out of GPU memory. Try to reduce the --mem-factor argument!')
sync_all()
return
sync_all()

View File

@ -56,8 +56,7 @@ def timed_pt2pt(input, args):
if not args.raw:
size = convert_size(size)
print_rank_0(
f"{size:<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}")
print_rank_0(f"{size:<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}")
def run_pt2pt(local_rank, args):
@ -82,12 +81,8 @@ def run_pt2pt(local_rank, args):
for M in M_LIST:
global_rank = dist.get_rank()
try:
mat = torch.ones(world_size,
M,
dtype=getattr(
torch,
args.dtype)).to(
get_accelerator().device_name(local_rank))
mat = torch.ones(world_size, M,
dtype=getattr(torch, args.dtype)).to(get_accelerator().device_name(local_rank))
sync_all()
input = ((mat.mul_(float(global_rank))).view(-1))
except RuntimeError as e:
@ -102,23 +97,18 @@ def run_pt2pt(local_rank, args):
# Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
# Don't need output tensor, so double mem_factor
elements_per_gpu = max_numel(comm_op='pt2pt',
dtype=getattr(torch,
args.dtype),
dtype=getattr(torch, args.dtype),
mem_factor=args.mem_factor * 2,
local_rank=local_rank,
args=args)
try:
mat = torch.ones(elements_per_gpu,
dtype=getattr(torch,
args.dtype)).to(
get_accelerator().device_name(local_rank))
mat = torch.ones(elements_per_gpu, dtype=getattr(torch,
args.dtype)).to(get_accelerator().device_name(local_rank))
input = ((mat.mul_(float(global_rank))).view(-1))
except RuntimeError as e:
if 'out of memory' in str(e):
if dist.get_rank() == 0:
print(
'WARNING: Ran out of GPU memory. Try to reduce the --mem-factor argument!'
)
print('WARNING: Ran out of GPU memory. Try to reduce the --mem-factor argument!')
sync_all()
return
sync_all()

View File

@ -120,8 +120,7 @@ def max_numel(comm_op, dtype, mem_factor, local_rank, args):
# Number of elements must be divisible by world_size
# all_to_all performance is lower for non-powers of two. Round down like all_gather.
elements_per_gpu = int(max_memory_per_gpu // dtype_size)
elements_per_gpu = int(dist.get_world_size() *
round(elements_per_gpu / dist.get_world_size()))
elements_per_gpu = int(dist.get_world_size() * round(elements_per_gpu / dist.get_world_size()))
elements_per_gpu = int(pow(2, int(math.log(elements_per_gpu, 2))))
else:
print(f"This communication operation: {comm_op} is not supported yet")
@ -162,59 +161,32 @@ def _element_size(dtype):
def benchmark_parser():
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int)
parser.add_argument("--trials",
type=int,
default=DEFAULT_TRIALS,
help='Number of timed iterations')
parser.add_argument("--warmups",
type=int,
default=DEFAULT_WARMUPS,
help='Number of warmup (non-timed) iterations')
parser.add_argument("--maxsize",
type=int,
default=24,
help='Max message size as a power of 2')
parser.add_argument("--async-op",
action="store_true",
help='Enables non-blocking communication')
parser.add_argument("--bw-unit",
type=str,
default=DEFAULT_UNIT,
choices=['Gbps',
'GBps'])
parser.add_argument("--trials", type=int, default=DEFAULT_TRIALS, help='Number of timed iterations')
parser.add_argument("--warmups", type=int, default=DEFAULT_WARMUPS, help='Number of warmup (non-timed) iterations')
parser.add_argument("--maxsize", type=int, default=24, help='Max message size as a power of 2')
parser.add_argument("--async-op", action="store_true", help='Enables non-blocking communication')
parser.add_argument("--bw-unit", type=str, default=DEFAULT_UNIT, choices=['Gbps', 'GBps'])
parser.add_argument("--backend",
type=str,
default=DEFAULT_BACKEND,
choices=['nccl',
'ccl'],
choices=['nccl', 'ccl'],
help='Communication library to use')
parser.add_argument("--dist",
type=str,
default=DEFAULT_DIST,
choices=['deepspeed',
'torch'],
choices=['deepspeed', 'torch'],
help='Distributed DL framework to use')
parser.add_argument("--scan",
action="store_true",
help='Enables scanning all message sizes')
parser.add_argument("--raw",
action="store_true",
help='Print the message size and latency without units')
parser.add_argument("--scan", action="store_true", help='Enables scanning all message sizes')
parser.add_argument("--raw", action="store_true", help='Print the message size and latency without units')
parser.add_argument("--all-reduce", action="store_true", help='Run all_reduce')
parser.add_argument("--all-gather", action="store_true", help='Run all_gather')
parser.add_argument("--all-to-all", action="store_true", help='Run all_to_all')
parser.add_argument("--pt2pt", action="store_true", help='Run pt2pt')
parser.add_argument("--broadcast", action="store_true", help='Run broadcast')
parser.add_argument("--dtype",
type=str,
default=DEFAULT_TYPE,
help='PyTorch tensor dtype')
parser.add_argument(
"--mem-factor",
type=float,
default=.4,
help='Proportion of max available GPU memory to use for single-size evals')
parser.add_argument("--debug",
action="store_true",
help='Enables all_to_all debug prints')
parser.add_argument("--dtype", type=str, default=DEFAULT_TYPE, help='PyTorch tensor dtype')
parser.add_argument("--mem-factor",
type=float,
default=.4,
help='Proportion of max available GPU memory to use for single-size evals')
parser.add_argument("--debug", action="store_true", help='Enables all_to_all debug prints')
return parser
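
A sketch of what the reflowed parser accepts; the import path is a placeholder because the file name for this module is not shown in this view:

# "comm_utils" is a hypothetical import path for the module containing benchmark_parser.
from comm_utils import benchmark_parser

args = benchmark_parser().parse_args(['--scan', '--all-reduce', '--dist', 'deepspeed', '--dtype', 'float32'])
print(args.trials, args.warmups, args.bw_unit, args.mem_factor)  # defaults unless overridden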

View File

@ -13,21 +13,9 @@ parser.add_argument(
default="./results",
help="directory containing sweep results",
)
parser.add_argument("--version",
"-v",
type=int,
default=0,
help="version to be collected")
parser.add_argument("--gen-text-n",
"-n",
type=int,
default=1,
help="expected number of generated text")
parser.add_argument("--output",
"-o",
type=str,
default="./results.csv",
help="output file")
parser.add_argument("--version", "-v", type=int, default=0, help="version to be collected")
parser.add_argument("--gen-text-n", "-n", type=int, default=1, help="expected number of generated text")
parser.add_argument("--output", "-o", type=str, default="./results.csv", help="output file")
args = parser.parse_args()
@ -107,9 +95,7 @@ if __name__ == "__main__":
params = get_benchmark_params(args.results_dir, file_path)
if not params:
print(
f"WARNING: Could not detect benchmark settings for file {file_path}, skipping"
)
print(f"WARNING: Could not detect benchmark settings for file {file_path}, skipping")
continue
# Verify that the version matches that which we want to collect
@ -121,9 +107,7 @@ if __name__ == "__main__":
perf_data = get_perf_data(file_content)
if not perf_data:
print(
f"WARNING: Could not detect benchmark performance data for file {file_path}"
)
print(f"WARNING: Could not detect benchmark performance data for file {file_path}")
generated_text = get_generated_text(file_content, args.gen_text_n)
if not generated_text:
@ -135,12 +119,7 @@ if __name__ == "__main__":
benchmarks_data.append({"branch": branch, **params, **error})
continue
benchmarks_data.append({
"branch": branch,
**params,
**perf_data,
**generated_text
})
benchmarks_data.append({"branch": branch, **params, **perf_data, **generated_text})
# Convert to a DataFrame and save
benchmarks_df = pd.DataFrame(benchmarks_data)

View File

@ -11,26 +11,12 @@ from deepspeed.accelerator import get_accelerator
parser = argparse.ArgumentParser()
parser.add_argument("--model", "-m", type=str, help="hf model name")
parser.add_argument("--deepspeed", action="store_true", help="use deepspeed inference")
parser.add_argument("--dtype",
type=str,
default="fp16",
choices=["fp16",
"fp32",
"int8"],
help="int8, fp16, or fp32")
parser.add_argument("--dtype", type=str, default="fp16", choices=["fp16", "fp32", "int8"], help="int8, fp16, or fp32")
parser.add_argument("--graphs", action="store_true", help="CUDA Graphs on")
parser.add_argument("--kernel-inject", action="store_true", help="inject kernels on")
parser.add_argument("--max-tokens", type=int, default=50, help="max new tokens")
parser.add_argument("--local_rank",
type=int,
default=int(os.getenv("LOCAL_RANK",
"0")),
help="local rank")
parser.add_argument("--world_size",
type=int,
default=int(os.getenv("WORLD_SIZE",
"1")),
help="world size")
parser.add_argument("--local_rank", type=int, default=int(os.getenv("LOCAL_RANK", "0")), help="local rank")
parser.add_argument("--world_size", type=int, default=int(os.getenv("WORLD_SIZE", "1")), help="world size")
parser.add_argument("--trials", type=int, default=30, help="number of trials")
args = parser.parse_args()
@ -81,10 +67,7 @@ elif args.dtype == "fp16":
else:
dtype = torch.float32
pipe = pipeline("text-generation",
model=args.model,
framework="pt",
device=args.local_rank)
pipe = pipeline("text-generation", model=args.model, framework="pt", device=args.local_rank)
if dtype == torch.float16:
pipe.model.half()
@ -115,9 +98,7 @@ for i in range(args.trials):
if args.local_rank == 0:
print_latency(times, "(e2e) latency")
print_latency(mtimes, "(model-only) latency")
print_latency(map(lambda t: t / (args.max_tokens - 3),
times),
"(e2e) per token latency")
print_latency(map(lambda t: t / (args.max_tokens - 3), times), "(e2e) per token latency")
print(f"RESPONSE 0:")
print("-" * 30)
print(responses[0][0]["generated_text"])

View File

@ -9,11 +9,7 @@ from deepspeed.elasticity import compute_elastic_config
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, help="DeepSpeed config json")
parser.add_argument('-w',
'--world-size',
type=int,
default=0,
help="Intended/current world size")
parser.add_argument('-w', '--world-size', type=int, default=0, help="Intended/current world size")
args = parser.parse_args()
ds_config = json.load(open(args.config, 'r'))
@ -26,7 +22,9 @@ if __name__ == '__main__':
print(json.dumps(elastic_config, indent=4, sort_keys=True))
if args.world_size > 0:
final_batch_size, valid_gpus, micro_batch_size = compute_elastic_config(ds_config=ds_config, target_deepspeed_version=ds_version, world_size=args.world_size)
final_batch_size, valid_gpus, micro_batch_size = compute_elastic_config(ds_config=ds_config,
target_deepspeed_version=ds_version,
world_size=args.world_size)
print('------------------------------------------')
print(f"Calculated results for world size {args.world_size}:")
print('------------------------------------------')

View File

@ -14,13 +14,10 @@ from perf_sweep_utils import BENCH_LOG_DIR, READ_LOG_DIR, WRITE_LOG_DIR
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument(
'--log_dir',
type=str,
default=BENCH_LOG_DIR,
help=
f'Folder of performance sweep logs. Default is {os.path.join(".", BENCH_LOG_DIR)}'
)
parser.add_argument('--log_dir',
type=str,
default=BENCH_LOG_DIR,
help=f'Folder of performance sweep logs. Default is {os.path.join(".", BENCH_LOG_DIR)}')
args = parser.parse_args()
print(f'args = {args}')
@ -75,9 +72,7 @@ def generate_aio_param(read_log_dir, write_log_dir):
optimal_config_read = read_results.get(read_perf_keys[optimal_key], None)
optimal_config_write = write_results.get(write_perf_keys[optimal_key], None)
print(
f'Best performance (GB/sec): read = {optimal_config_read:5.2f}, write = {optimal_config_write:5.2f}'
)
print(f'Best performance (GB/sec): read = {optimal_config_read:5.2f}, write = {optimal_config_write:5.2f}')
print(json.dumps(aio_param, indent=3))

View File

@ -20,20 +20,16 @@ from deepspeed.ops.op_builder import AsyncIOBuilder
OTHER_OPTIONS = '--handle'
PERF_SCRIPT = 'test_ds_aio.py'
DEFAULT_SWEEP_CONFIG = {
"block_size": ["128K",
"256K"],
"queue_depth": [4,
16,
32],
"overlap_events": [True,
False],
"io_parallel": [2,
8],
"block_size": ["128K", "256K"],
"queue_depth": [4, 16, 32],
"overlap_events": [True, False],
"io_parallel": [2, 8],
"single_submit": [False]
}
class Job(object):
def __init__(self, cmd_line, output_file=None, work_dir=None):
self.cmd_line = cmd_line
self.output_file = output_file
@ -63,6 +59,7 @@ class Job(object):
class SweepConfig(object):
def __init__(self, args):
self.nvme_dir = args.nvme_dir
self.io_size = args.io_size
@ -78,52 +75,35 @@ class SweepConfig(object):
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument(
'--nvme_dir',
required=True,
type=str,
help=
'Directory in which to perform I/O tests. A writeable directory on a NVMe device.'
)
parser.add_argument('--sweep_config',
parser.add_argument('--nvme_dir',
required=True,
type=str,
default=None,
help='Performance sweep configuration json file.')
help='Directory in which to perform I/O tests. A writeable directory on a NVMe device.')
parser.add_argument('--no_read',
action='store_true',
help='Disable read performance measurements.')
parser.add_argument('--sweep_config', type=str, default=None, help='Performance sweep configuration json file.')
parser.add_argument('--no_write',
action='store_true',
help='Disable write performance measurements.')
parser.add_argument('--no_read', action='store_true', help='Disable read performance measurements.')
parser.add_argument(
'--io_size',
type=str,
default="400M",
help='Number of I/O bytes to read/write for performance measurements.')
parser.add_argument('--no_write', action='store_true', help='Disable write performance measurements.')
parser.add_argument('--io_size',
type=str,
default="400M",
help='Number of I/O bytes to read/write for performance measurements.')
parser.add_argument(
'--no_sudo',
action='store_true',
help=
'Run without sudo access. Page cache will not be flushed and reported read speeds may be higher than actual.'
)
'Run without sudo access. Page cache will not be flushed and reported read speeds may be higher than actual.')
parser.add_argument(
'--log_dir',
type=str,
default=BENCH_LOG_DIR,
help=
f'Output directory for performance log files. Default is {os.path.join(".", BENCH_LOG_DIR)}'
)
help=f'Output directory for performance log files. Default is {os.path.join(".", BENCH_LOG_DIR)}')
parser.add_argument('--loops',
type=int,
default=1,
help='Count of operation repetitions')
parser.add_argument('--loops', type=int, default=1, help='Count of operation repetitions')
args = parser.parse_args()
print(f'args = {args}')
@ -147,6 +127,7 @@ def get_sweep_config_dict(sweep_config_json):
def get_sweep_cmd_lines(sweep_config_dict):
def flatten_options(key, value_list):
flat_list = []
for v in value_list:
@ -170,11 +151,7 @@ def run_job(job):
args = ' '.join(job.cmd())
print(f'args = {args}')
job.open_output_file()
proc = subprocess.run(args=args,
shell=True,
stdout=job.get_stdout(),
stderr=job.get_stderr(),
cwd=job.get_cwd())
proc = subprocess.run(args=args, shell=True, stdout=job.get_stdout(), stderr=job.get_stderr(), cwd=job.get_cwd())
job.close_output_file()
assert proc.returncode == 0, \
f"This command failed: {job.cmd()}"
@ -240,14 +217,7 @@ def get_log_file(io_op_desc, cmd_line):
return tag_key
return f'{tag_key}{value}'
tag_list = [
SINGLE_SUBMIT,
OVERLAP_EVENTS,
THREAD_COUNT,
IO_PARALLEL,
QUEUE_DEPTH,
BLOCK_SIZE
]
tag_list = [SINGLE_SUBMIT, OVERLAP_EVENTS, THREAD_COUNT, IO_PARALLEL, QUEUE_DEPTH, BLOCK_SIZE]
log_tags = [io_op_desc]
cmd_tags = create_cmd_tags(cmd_line)
for tag in tag_list:
@ -298,16 +268,10 @@ def create_read_file(sweep_config):
os.makedirs(read_folder, exist_ok=True)
read_file_name = os.path.join(read_folder, f'random_{sweep_config.io_size}B.pt')
block_size, block_count = get_block_size_and_count(refine_integer_value(sweep_config.io_size))
dd_job = Job(cmd_line=[
f'dd if=/dev/urandom of={read_file_name} bs={block_size} count={block_count}'
])
print(
f'[Start] Create read file of {sweep_config.io_size} bytes by running {dd_job.cmd()} ....'
)
dd_job = Job(cmd_line=[f'dd if=/dev/urandom of={read_file_name} bs={block_size} count={block_count}'])
print(f'[Start] Create read file of {sweep_config.io_size} bytes by running {dd_job.cmd()} ....')
run_job(dd_job)
print(
f'[Done] Create read file of {sweep_config.io_size} bytes by running {dd_job.cmd()} ....'
)
print(f'[Done] Create read file of {sweep_config.io_size} bytes by running {dd_job.cmd()} ....')
return read_folder, read_file_name
@ -319,20 +283,15 @@ def remove_folder(folder):
def run_read_sweep(sweep_config, flush_cache_job, sync_job, cmd_lines):
read_folder, read_file_name = create_read_file(sweep_config)
read_option = f'--read_file {read_file_name}'
read_cmd_lines = [[f'{read_option} {sweep_config.other_options}'] + cmd
for cmd in cmd_lines]
read_cmd_lines = [[f'{read_option} {sweep_config.other_options}'] + cmd for cmd in cmd_lines]
#dump_cmd_lines(read_cmd_lines)
log_folder = os.path.join(sweep_config.log_dir, f'{READ_LOG_DIR}')
os.makedirs(log_folder, exist_ok=True)
perf_jobs = create_perf_jobs(io_op_desc=READ_OP_DESC,
log_dir=log_folder,
cmd_lines=read_cmd_lines)
perf_jobs = create_perf_jobs(io_op_desc=READ_OP_DESC, log_dir=log_folder, cmd_lines=read_cmd_lines)
launch_sweep(sweep_jobs=perf_jobs,
sync_job=sync_job,
flush_cache_job=flush_cache_job)
launch_sweep(sweep_jobs=perf_jobs, sync_job=sync_job, flush_cache_job=flush_cache_job)
remove_folder(read_folder)
@ -342,20 +301,15 @@ def run_write_sweep(sweep_config, flush_cache_job, sync_job, cmd_lines):
os.makedirs(write_folder, exist_ok=True)
write_file_name = os.path.join(write_folder, f'random_{sweep_config.io_size}B.pt')
write_option = f'--write_size {sweep_config.io_size} --write_file {write_file_name}'
write_cmd_lines = [[f'{write_option} {sweep_config.other_options}'] + cmd
for cmd in cmd_lines]
write_cmd_lines = [[f'{write_option} {sweep_config.other_options}'] + cmd for cmd in cmd_lines]
#dump_cmd_lines(write_cmd_lines)
log_folder = os.path.join(sweep_config.log_dir, f'{WRITE_LOG_DIR}')
os.makedirs(log_folder, exist_ok=True)
perf_jobs = create_perf_jobs(io_op_desc=WRITE_OP_DESC,
log_dir=log_folder,
cmd_lines=write_cmd_lines)
perf_jobs = create_perf_jobs(io_op_desc=WRITE_OP_DESC, log_dir=log_folder, cmd_lines=write_cmd_lines)
launch_sweep(sweep_jobs=perf_jobs,
sync_job=sync_job,
flush_cache_job=flush_cache_job)
launch_sweep(sweep_jobs=perf_jobs, sync_job=sync_job, flush_cache_job=flush_cache_job)
remove_folder(write_folder)
@ -376,10 +330,7 @@ def main():
cmd_lines = get_sweep_cmd_lines(sweep_config.search_space)
if sweep_config.flush_cache:
flush_cache_job = Job(
cmd_line=['sudo',
'bash -c',
"'echo 1 > /proc/sys/vm/drop_caches'"])
flush_cache_job = Job(cmd_line=['sudo', 'bash -c', "'echo 1 > /proc/sys/vm/drop_caches'"])
else:
flush_cache_job = None

View File

@ -20,14 +20,8 @@ def pre_basic(args, tid, read_op):
file = args.read_file if read_op else f'{args.write_file}.{tid}'
task_log(tid, f'Allocate tensor of size {num_bytes} bytes')
buffer = get_accelerator().pin_memory(
torch.empty(num_bytes,
dtype=torch.uint8,
device='cpu'))
task_log(
tid,
f'{io_string} file {file} of size {num_bytes} bytes from buffer on device {buffer.device}'
)
buffer = get_accelerator().pin_memory(torch.empty(num_bytes, dtype=torch.uint8, device='cpu'))
task_log(tid, f'{io_string} file {file} of size {num_bytes} bytes from buffer on device {buffer.device}')
ctxt = {}
ctxt['file'] = file
@ -60,13 +54,8 @@ def post_basic(pool_params):
def main_basic_read(pool_params):
args, tid, ctxt = pool_params
start_time = time.time()
AsyncIOBuilder().load().aio_read(ctxt['buffer'],
ctxt['file'],
args.block_size,
args.queue_depth,
args.single_submit,
args.overlap_events,
args.validate)
AsyncIOBuilder().load().aio_read(ctxt['buffer'], ctxt['file'], args.block_size, args.queue_depth,
args.single_submit, args.overlap_events, args.validate)
end_time = time.time()
ctxt['elapsed_sec'] += end_time - start_time
@ -76,13 +65,8 @@ def main_basic_read(pool_params):
def main_basic_write(pool_params):
args, tid, ctxt = pool_params
start_time = time.time()
AsyncIOBuilder().load().aio_write(ctxt['buffer'],
ctxt['file'],
args.block_size,
args.queue_depth,
args.single_submit,
args.overlap_events,
args.validate)
AsyncIOBuilder().load().aio_write(ctxt['buffer'], ctxt['file'], args.block_size, args.queue_depth,
args.single_submit, args.overlap_events, args.validate)
end_time = time.time()
ctxt['elapsed_sec'] += end_time - start_time

View File

@ -20,27 +20,17 @@ def pre_handle(args, tid, read_op):
file = args.read_file if read_op else f'{args.write_file}.{tid}'
io_parallel = args.io_parallel if args.io_parallel else 1
handle = AsyncIOBuilder().load().aio_handle(args.block_size,
args.queue_depth,
args.single_submit,
args.overlap_events,
io_parallel)
handle = AsyncIOBuilder().load().aio_handle(args.block_size, args.queue_depth, args.single_submit,
args.overlap_events, io_parallel)
task_log(tid, f'Created deepspeed aio handle')
if args.gpu:
buffer = torch.empty(num_bytes,
dtype=torch.uint8,
device=get_accelerator().device_name())
buffer = torch.empty(num_bytes, dtype=torch.uint8, device=get_accelerator().device_name())
else:
if args.use_accelerator_pin_memory:
buffer = get_accelerator().pin_memory(
torch.empty(num_bytes,
dtype=torch.uint8,
device='cpu'))
buffer = get_accelerator().pin_memory(torch.empty(num_bytes, dtype=torch.uint8, device='cpu'))
else:
buffer = handle.new_cpu_locked_tensor(num_bytes,
torch.empty(0,
dtype=torch.uint8))
buffer = handle.new_cpu_locked_tensor(num_bytes, torch.empty(0, dtype=torch.uint8))
task_log(tid, f'Allocate tensor of size {num_bytes} bytes')
@ -51,10 +41,7 @@ def pre_handle(args, tid, read_op):
ctxt['buffer'] = buffer
ctxt['elapsed_sec'] = 0
task_log(
tid,
f'{io_string} file {file} of size {num_bytes} bytes from buffer on device {buffer.device}'
)
task_log(tid, f'{io_string} file {file} of size {num_bytes} bytes from buffer on device {buffer.device}')
return ctxt

View File

@ -19,10 +19,7 @@ METRIC_SEARCH = {READ_SPEED: 'E2E Read Speed', WRITE_SPEED: 'E2E Write Speed'}
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument('--log_dir',
type=str,
required=True,
help='Folder of statistics logs')
parser.add_argument('--log_dir', type=str, required=True, help='Folder of statistics logs')
parser.add_argument('--metric',
type=str,
@ -125,10 +122,7 @@ def get_results(log_files, metric):
def get_sorted_results(log_dir, metric):
log_files = [
f for f in os.listdir(log_dir) if os.path.isfile(os.path.join(log_dir,
f))
]
log_files = [f for f in os.listdir(log_dir) if os.path.isfile(os.path.join(log_dir, f))]
log_files_path = [os.path.join(log_dir, f) for f in log_files]
results = get_results(log_files_path, metric)

View File

@ -20,46 +20,29 @@ def parse_arguments():
parser.add_argument('--write_file', type=str, default=None, help='Write file.')
parser.add_argument('--write_size',
type=str,
default=None,
help='Number of bytes to write.')
parser.add_argument('--write_size', type=str, default=None, help='Number of bytes to write.')
parser.add_argument('--block_size', type=str, default='1M', help='I/O block size.')
parser.add_argument('--queue_depth', type=int, default=32, help='I/O queue depth.')
parser.add_argument('--threads',
type=int,
default=1,
help='Thread parallelism count.')
parser.add_argument('--threads', type=int, default=1, help='Thread parallelism count.')
parser.add_argument(
'--single_submit',
action='store_true',
help=
'Submit I/O requests in singles (default is submit queue_depth amount at once.).'
)
parser.add_argument('--single_submit',
action='store_true',
help='Submit I/O requests in singles (default is submit queue_depth amount at once.).')
parser.add_argument('--overlap_events',
action='store_true',
help='Overlap I/O submission and completion requests.')
parser.add_argument('--validate',
action='store_true',
help='Perform validation in library.')
parser.add_argument('--validate', action='store_true', help='Perform validation in library.')
parser.add_argument('--handle', action='store_true', help='Use AIO handle.')
parser.add_argument('--loops',
type=int,
default=1,
help='Count of operation repetitions')
parser.add_argument('--loops', type=int, default=1, help='Count of operation repetitions')
parser.add_argument('--io_parallel',
type=int,
default=None,
help='Per iop parallelism')
parser.add_argument('--io_parallel', type=int, default=None, help='Per iop parallelism')
parser.add_argument('--gpu', action='store_true', help='Use GPU memory')

View File

@ -51,12 +51,10 @@ __git_branch__ = git_branch
def initialize(args=None,
model: torch.nn.Module = None,
optimizer: Optional[Union[Optimizer,
DeepSpeedOptimizerCallable]] = None,
optimizer: Optional[Union[Optimizer, DeepSpeedOptimizerCallable]] = None,
model_parameters: Optional[torch.nn.Module] = None,
training_data: Optional[torch.utils.data.Dataset] = None,
lr_scheduler: Optional[Union[_LRScheduler,
DeepSpeedSchedulerCallable]] = None,
lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]] = None,
mpu=None,
dist_init_required: Optional[bool] = None,
collate_fn=None,
@ -110,10 +108,8 @@ def initialize(args=None,
* ``lr_scheduler``: Wrapped lr scheduler if user ``lr_scheduler`` is passed, or
if ``lr_scheduler`` specified in JSON configuration. Otherwise ``None``.
"""
log_dist("DeepSpeed info: version={}, git-hash={}, git-branch={}".format(
__version__,
__git_hash__,
__git_branch__),
log_dist("DeepSpeed info: version={}, git-hash={}, git-branch={}".format(__version__, __git_hash__,
__git_branch__),
ranks=[0])
# Disable zero.Init context if it's currently enabled
@ -147,12 +143,7 @@ def initialize(args=None,
config=config,
config_params=config_params)
return_items = [
engine,
engine.optimizer,
engine.training_dataloader,
engine.lr_scheduler
]
return_items = [engine, engine.optimizer, engine.training_dataloader, engine.lr_scheduler]
return tuple(return_items)
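
Since the return list is now on one line, it is worth spelling out how it is typically consumed. A minimal sketch, assuming a single process started through the DeepSpeed launcher; the toy model and config stand in for a real network and JSON configuration:

import torch
import deepspeed

net = torch.nn.Linear(8, 8)              # placeholder model
ds_config = {"train_batch_size": 8}      # placeholder config; real runs add optimizer/fp16/ZeRO sections

# The tuple mirrors return_items above: engine, optimizer, training_dataloader, lr_scheduler.
model_engine, optimizer, training_dataloader, lr_scheduler = deepspeed.initialize(
    model=net, model_parameters=net.parameters(), config=ds_config)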
@ -171,38 +162,28 @@ def _add_core_arguments(parser):
"""
group = parser.add_argument_group('DeepSpeed', 'DeepSpeed configurations')
group.add_argument(
'--deepspeed',
default=False,
action='store_true',
help=
'Enable DeepSpeed (helper flag for user code, no impact on DeepSpeed backend)')
group.add_argument('--deepspeed',
default=False,
action='store_true',
help='Enable DeepSpeed (helper flag for user code, no impact on DeepSpeed backend)')
group.add_argument('--deepspeed_config',
default=None,
type=str,
help='DeepSpeed json configuration file.')
group.add_argument('--deepspeed_config', default=None, type=str, help='DeepSpeed json configuration file.')
group.add_argument(
'--deepscale',
default=False,
action='store_true',
help=
'Deprecated enable DeepSpeed (helper flag for user code, no impact on DeepSpeed backend)'
)
group.add_argument('--deepscale',
default=False,
action='store_true',
help='Deprecated enable DeepSpeed (helper flag for user code, no impact on DeepSpeed backend)')
group.add_argument('--deepscale_config',
default=None,
type=str,
help='Deprecated DeepSpeed json configuration file.')
group.add_argument(
'--deepspeed_mpi',
default=False,
action='store_true',
help=
"Run via MPI, this will attempt to discover the necessary variables to initialize torch "
"distributed from the MPI environment")
group.add_argument('--deepspeed_mpi',
default=False,
action='store_true',
help="Run via MPI, this will attempt to discover the necessary variables to initialize torch "
"distributed from the MPI environment")
return parser
@ -278,10 +259,8 @@ def init_inference(model, config=None, **kwargs):
Returns:
A deepspeed.InferenceEngine wrapped model.
"""
log_dist("DeepSpeed info: version={}, git-hash={}, git-branch={}".format(
__version__,
__git_hash__,
__git_branch__),
log_dist("DeepSpeed info: version={}, git-hash={}, git-branch={}".format(__version__, __git_hash__,
__git_branch__),
ranks=[0])
# Load config_dict from config first
@ -293,17 +272,14 @@ def init_inference(model, config=None, **kwargs):
elif isinstance(config, dict):
config_dict = config
else:
raise ValueError(
f"'config' argument expected string or dictionary, got {type(config)}")
raise ValueError(f"'config' argument expected string or dictionary, got {type(config)}")
# Update with values from kwargs, ensuring no conflicting overlap between config and kwargs
overlap_keys = set(config_dict.keys()).intersection(kwargs.keys())
# If there is overlap, error out if values are different
for key in overlap_keys:
if config_dict[key] != kwargs[key]:
raise ValueError(
f"Conflicting argument '{key}' in 'config':{config_dict[key]} and kwargs:{kwargs[key]}"
)
raise ValueError(f"Conflicting argument '{key}' in 'config':{config_dict[key]} and kwargs:{kwargs[key]}")
config_dict.update(kwargs)
ds_inference_config = DeepSpeedInferenceConfig(**config_dict)

View File

@ -40,6 +40,7 @@ class Autotuner:
"""The DeepSpeed Autotuner automatically discovers the optimal DeepSpeed configuration that delivers good training speed. The Autotuner uses model information, system information, and heuristics to efficiently tune system knobs that affect compute and memory efficiencies, such as ZeRO optimization stages, micro-batch sizes, and many other ZeRO optimization configurations. It not only reduces the time and resources user spend on tuning, but also can discover configurations better than hand-tuned methods.
Autotuning with DeepSpeed requires no code change from DeepSpeed users. Please refer to the README for usage details.
"""
def __init__(self, args, active_resources):
self.args = args
self.selected_exp_dir = None
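
As the docstring notes, turning the autotuner on is a configuration change rather than a code change. A hedged sketch of the relevant DeepSpeed config fragment: only the "enabled" key is taken directly from the AUTOTUNING block constructed later in this file; the remaining autotuning knobs (metric, tuner trials, early stopping, and so on) are documented in the autotuning README rather than here.

# Fragment of a DeepSpeed config dict enabling autotuning (all other sections omitted).
ds_config = {
    "autotuning": {
        "enabled": True,
    },
}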
@ -92,7 +93,8 @@ class Autotuner:
assert self.exp_num_gpus <= self.rm.num_gpus_per_node, "num_gpus in the autotuning configuration must not be less than the --num_gpus value in the train script if any"
assert self.exp_num_nodes <= len(
self.rm.nodes), "num_nodes in the autotuning configuration must not be less than the --num_nodes value in the train script if any"
self.rm.nodes
), "num_nodes in the autotuning configuration must not be less than the --num_nodes value in the train script if any"
self.records = {}
self.optimal_cmd = None
@ -125,18 +127,10 @@ class Autotuner:
row.append(val[0]['name'])
tab.append(row)
summary = tabulate(tab,
headers=[
"tuning_space",
"num_experiments",
"best_metric_val",
"best_exp_name"
],
headers=["tuning_space", "num_experiments", "best_metric_val", "best_exp_name"],
tablefmt="pipe")
print(summary)
with open(os.path.join(self.results_dir,
'summary.txt'),
'w',
buffering=BUFSIZE) as fd:
with open(os.path.join(self.results_dir, 'summary.txt'), 'w', buffering=BUFSIZE) as fd:
fd.write(summary)
fd.flush()
os.fsync(fd)
@ -148,9 +142,7 @@ class Autotuner:
f"{best_exp['name']} is the optimal setup after tuning. The exp result is at {best_exp['result_dir']}."
)
else:
logger.info(
f"No optimal setup is found. Please check that experiments were run successfully."
)
logger.info(f"No optimal setup is found. Please check that experiments were run successfully.")
tuning_duration = datetime.timedelta(seconds=(time.time() - self.start_time))
logger.info(f"Tuning completed in {tuning_duration}")
@ -172,8 +164,8 @@ class Autotuner:
user_config_file = None
if "--deepspeed_config" in user_args:
idx = user_args.index("--deepspeed_config")
assert ".json" in user_args[idx +
1], "DeepSpeed --deepspeed_config requires a json file to specify the configuration"
assert ".json" in user_args[
idx + 1], "DeepSpeed --deepspeed_config requires a json file to specify the configuration"
user_config_file = user_args[idx + 1]
elif "--deepspeed" in user_args:
@ -183,15 +175,10 @@ class Autotuner:
logger.debug(f"user_config_file = {user_config_file}")
if user_config_file is not None:
assert os.path.isfile(
user_config_file
), "DeepSpeed configuration file: {} is not an existing file".format(
user_config_file
)
assert os.path.isfile(user_config_file), "DeepSpeed configuration file: {} is not an existing file".format(
user_config_file)
if os.path.exists(user_config_file):
return json.load(open(user_config_file,
"r"),
object_pairs_hook=dict_raise_error_on_duplicate_keys)
return json.load(open(user_config_file, "r"), object_pairs_hook=dict_raise_error_on_duplicate_keys)
return None
@ -258,13 +245,11 @@ class Autotuner:
return self.autotuning_config.mp_size
def max_train_micro_batch_size_per_gpu(self):
if self.max_train_batch_size() and self.max_train_batch_size(
) > 0: # if the user specifies a max_train_batch_size
max_train_micro_batch_size = self.max_train_batch_size() * self.mp_size(
) // (self.exp_num_gpus * self.exp_num_nodes
) # gradient accumulation steps >=1
return min(self.autotuning_config.max_train_micro_batch_size_per_gpu,
max_train_micro_batch_size)
if self.max_train_batch_size(
) and self.max_train_batch_size() > 0: # if the user specifies a max_train_batch_size
max_train_micro_batch_size = self.max_train_batch_size() * self.mp_size() // (
self.exp_num_gpus * self.exp_num_nodes) # gradient accumulation steps >=1
return min(self.autotuning_config.max_train_micro_batch_size_per_gpu, max_train_micro_batch_size)
else:
return self.autotuning_config.max_train_micro_batch_size_per_gpu
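
A worked instance of the arithmetic above with hypothetical numbers: a user-specified max_train_batch_size of 1024, model-parallel size 2, and an experiment size of 2 nodes with 8 GPUs each.

max_train_batch_size = 1024
mp_size = 2
exp_num_gpus, exp_num_nodes = 8, 2

# gradient accumulation steps >= 1, so this is the largest micro batch per GPU
max_train_micro_batch_size = max_train_batch_size * mp_size // (exp_num_gpus * exp_num_nodes)
print(max_train_micro_batch_size)  # 128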
@ -361,19 +346,14 @@ class Autotuner:
if model_info and "hidden_size" in model_info:
hs = model_info["hidden_size"]
template_config[ZERO_OPTIMIZATION]['reduce_bucket_size'] = hs * hs
template_config[ZERO_OPTIMIZATION][
'stage3_prefetch_bucket_size'] = 0.9 * hs * hs
template_config[ZERO_OPTIMIZATION][
'stage3_param_persistence_threshold'] = 10 * hs
template_config[ZERO_OPTIMIZATION]['stage3_prefetch_bucket_size'] = 0.9 * hs * hs
template_config[ZERO_OPTIMIZATION]['stage3_param_persistence_threshold'] = 10 * hs
prefix = "z3_"
else:
return exps
# replace the corresponding parameter values if the user specifies them in the DeepSpeed configuration file
replace_dict(tuning_space,
self.user_config,
[ZERO_OPTIMIZATION,
TRAIN_MICRO_BATCH_SIZE_PER_GPU])
replace_dict(tuning_space, self.user_config, [ZERO_OPTIMIZATION, TRAIN_MICRO_BATCH_SIZE_PER_GPU])
logger.debug(f"tuning_space = {json.dumps(tuning_space)}")
@ -397,11 +377,9 @@ class Autotuner:
# if the config does not use offloading, remove the offloading section
config_zero = config.get(ZERO_OPTIMIZATION, None)
if config_zero:
if OFFLOAD_OPTIMIZER not in config_zero and OFFLOAD_OPTIMIZER in exp_config[
ZERO_OPTIMIZATION]:
if OFFLOAD_OPTIMIZER not in config_zero and OFFLOAD_OPTIMIZER in exp_config[ZERO_OPTIMIZATION]:
del exp_config[ZERO_OPTIMIZATION][OFFLOAD_OPTIMIZER]
if OFFLOAD_PARAM not in config_zero and OFFLOAD_PARAM in exp_config[
ZERO_OPTIMIZATION]:
if OFFLOAD_PARAM not in config_zero and OFFLOAD_PARAM in exp_config[ZERO_OPTIMIZATION]:
del exp_config[ZERO_OPTIMIZATION][OFFLOAD_PARAM]
# set gradient accumulation steps according to max_train_batch_size_per_gpu
mbs = exp_config[TRAIN_MICRO_BATCH_SIZE_PER_GPU]
@ -438,13 +416,10 @@ class Autotuner:
else:
return
logger.info(
f"The model has {number_to_string(self.get_model_num_params())} parameters.")
logger.info(f"The model has {number_to_string(self.get_model_num_params())} parameters.")
self.gpu_mem = self.get_gpu_memory_info()
logger.info(
f"Memory per GPU in the system is {memory_to_string(self.gpu_mem, postfix='B')}."
)
logger.info(f"Memory per GPU in the system is {memory_to_string(self.gpu_mem, postfix='B')}.")
self.activation_mem = self.get_activation_memory_per_gpu()
logger.info(
@ -452,9 +427,7 @@ class Autotuner:
)
#TODO: FIX THIS
stage = self.user_config.get(ZERO_OPTIMIZATION,
{}).get(ZERO_OPTIMIZATION_STAGE,
"all")
stage = self.user_config.get(ZERO_OPTIMIZATION, {}).get(ZERO_OPTIMIZATION_STAGE, "all")
stage = "all"
user_zero_stages = [stage] if not isinstance(stage, list) else stage
logger.info(f"User-defined zero stages are {stage}.")
@ -463,15 +436,13 @@ class Autotuner:
max_mbs = 0
metric_val = 0
required_gpu_mem = self.get_instantiation_memory_required_per_gpu(
ZeroStageEnum.disabled) + self.activation_mem
required_gpu_mem = self.get_instantiation_memory_required_per_gpu(ZeroStageEnum.disabled) + self.activation_mem
if self.gpu_mem > required_gpu_mem:
if "all" in user_zero_stages or ZeroStageEnum.disabled in user_zero_stages:
logger.info(
f"The model might be runable with ZERO 0 (which requires at least {memory_to_string(required_gpu_mem, postfix='B')} memory with mbs = 1), adding DEFAULT_TUNING_SPACE_ZERO_0 to the global tuning space"
)
next_max_mbs, next_mbs, next_metric_val = self.tune_space(
DEFAULT_TUNING_SPACE_ZERO_0)
next_max_mbs, next_mbs, next_metric_val = self.tune_space(DEFAULT_TUNING_SPACE_ZERO_0)
if next_mbs > mbs:
mbs = next_mbs
max_mbs = next_max_mbs
@ -490,8 +461,10 @@ class Autotuner:
logger.info(
f"The model might be runable with ZERO 1 (which requires at least {memory_to_string(required_gpu_mem, postfix='B')} memory), adding DEFAULT_TUNING_SPACE_ZERO_1 to the global tuning space"
)
next_max_mbs, next_mbs, next_metric_val = self.tune_space(
DEFAULT_TUNING_SPACE_ZERO_1, prev_max_mbs = max_mbs, prev_best_mbs=mbs, prev_best_metric_val=metric_val)
next_max_mbs, next_mbs, next_metric_val = self.tune_space(DEFAULT_TUNING_SPACE_ZERO_1,
prev_max_mbs=max_mbs,
prev_best_mbs=mbs,
prev_best_metric_val=metric_val)
if next_mbs > mbs:
mbs = next_mbs
max_mbs = next_max_mbs
@ -510,8 +483,10 @@ class Autotuner:
logger.info(
f"The model might be runable with ZERO 2 (which requires at least {memory_to_string(required_gpu_mem, postfix='B')} memory), adding DEFAULT_TUNING_SPACE_ZERO_2 to the global tuning space"
)
next_max_mbs, next_mbs, next_metric_val = self.tune_space(
DEFAULT_TUNING_SPACE_ZERO_2, prev_max_mbs = max_mbs, prev_best_mbs=mbs, prev_best_metric_val=metric_val)
next_max_mbs, next_mbs, next_metric_val = self.tune_space(DEFAULT_TUNING_SPACE_ZERO_2,
prev_max_mbs=max_mbs,
prev_best_mbs=mbs,
prev_best_metric_val=metric_val)
if next_mbs > mbs:
mbs = next_mbs
max_mbs = next_max_mbs
@ -523,15 +498,16 @@ class Autotuner:
f"The model is not runable with ZERO stage {ZeroStageEnum.gradients} (which requires at least {memory_to_string(required_gpu_mem, postfix='B')} memory with mbs = 1)"
)
required_gpu_mem = self.get_instantiation_memory_required_per_gpu(
ZeroStageEnum.weights) + self.activation_mem
required_gpu_mem = self.get_instantiation_memory_required_per_gpu(ZeroStageEnum.weights) + self.activation_mem
if self.gpu_mem > required_gpu_mem:
if "all" in user_zero_stages or ZeroStageEnum.weights in user_zero_stages:
logger.info(
f"The model might be runable with ZERO 3 (which requires at least {memory_to_string(required_gpu_mem, postfix='B')} memory), adding DEFAULT_TUNING_SPACE_ZERO_3 to the global tuning space"
)
_, _, next_metric_val = self.tune_space(
DEFAULT_TUNING_SPACE_ZERO_3, prev_max_mbs = max_mbs, prev_best_mbs=mbs, prev_best_metric_val=metric_val)
_, _, next_metric_val = self.tune_space(DEFAULT_TUNING_SPACE_ZERO_3,
prev_max_mbs=max_mbs,
prev_best_mbs=mbs,
prev_best_metric_val=metric_val)
if has_mlflow:
mlflow.log_metric(f"z3{self.metric()}", next_metric_val)
else:
@ -542,11 +518,7 @@ class Autotuner:
if has_mlflow:
mlflow.end_run()
def tune_space(self,
tuning_space,
prev_max_mbs=0,
prev_best_mbs=0,
prev_best_metric_val=0):
def tune_space(self, tuning_space, prev_max_mbs=0, prev_best_mbs=0, prev_best_metric_val=0):
config_zero = tuning_space.get(ZERO_OPTIMIZATION, {})
stage = config_zero.get(ZERO_OPTIMIZATION_STAGE, None)
tuning_space_name = TUNING_MICRO_BATCH_SIZE_PREFIX + str(stage)
@ -557,26 +529,20 @@ class Autotuner:
# calculate max micro batch size using gpu memory, model instantiation memory and activation memory
# calculated_max_micro_batch_size = (memory_per_gpu - instantiation_memory) // activation_memory_micro_batch_size_1
calculated_max_micro_batch_size = int(
self.gpu_mem -
self.get_instantiation_memory_required_per_gpu(stage)) // self.activation_mem
self.gpu_mem - self.get_instantiation_memory_required_per_gpu(stage)) // self.activation_mem
logger.info(
f"Start tuning for space {tuning_space_name}, calculated_max_micro_batch_size = {calculated_max_micro_batch_size}"
)
if calculated_max_micro_batch_size < prev_max_mbs:
logger.info(
f"No need to tune Zero stage {stage}. End tuning for space {tuning_space_name}"
)
logger.info(f"No need to tune Zero stage {stage}. End tuning for space {tuning_space_name}")
return 0, 0, 0
if TRAIN_MICRO_BATCH_SIZE_PER_GPU in self.user_config and isinstance(
self.user_config[TRAIN_MICRO_BATCH_SIZE_PER_GPU],
list):
self.user_config[TRAIN_MICRO_BATCH_SIZE_PER_GPU], list):
# user-specified micro batch size per gpu is a list which overwrites the default tuning behavior
tuning_micro_batch_sizes = [
s for s in self.user_config[TRAIN_MICRO_BATCH_SIZE_PER_GPU]
if isinstance(s,
int)
s for s in self.user_config[TRAIN_MICRO_BATCH_SIZE_PER_GPU] if isinstance(s, int)
]
gas = self.get_gas_from_user_config()
min_micro_batch_size = min(tuning_micro_batch_sizes)
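
The commented formula in the hunk above is easiest to read with numbers plugged in; hypothetical values for an 80 GB GPU:

gpu_mem = 80_000_000_000             # hypothetical total memory per GPU, in bytes
instantiation_mem = 30_000_000_000   # memory needed just to instantiate the model at this ZeRO stage
activation_mem = 2_000_000_000       # activation memory measured at micro batch size 1

calculated_max_micro_batch_size = int(gpu_mem - instantiation_mem) // activation_mem
print(calculated_max_micro_batch_size)  # 25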
@ -589,9 +555,7 @@ class Autotuner:
stage, prev_max_mbs, calculated_max_micro_batch_size)
if max_micro_batch_size < prev_max_mbs:
logger.info(
f"No need to tune Zero stage {stage}. End tuning for space {tuning_space_name}"
)
logger.info(f"No need to tune Zero stage {stage}. End tuning for space {tuning_space_name}")
return 0, 0, 0
tuning_micro_batch_sizes, max_train_batch_size_per_gpu = self.get_tuning_micro_batch_size_list(
@ -609,19 +573,15 @@ class Autotuner:
return 0, 0, 0
# tune micro batch sizes and gradient accumulation steps given max_train_batch_size_per_gpu
tuning_micro_batch_sizes = self.run_tuning_micro_batch_sizes(
tuning_micro_batch_sizes,
max_train_batch_size_per_gpu,
min_micro_batch_size,
stage,
tuning_micro_batch_sizes_overwritten)
tuning_micro_batch_sizes = self.run_tuning_micro_batch_sizes(tuning_micro_batch_sizes,
max_train_batch_size_per_gpu,
min_micro_batch_size, stage,
tuning_micro_batch_sizes_overwritten)
fast_best_record = self.get_best_space_record(tuning_space_name)
fast_best_metric_val = fast_best_record[1] if fast_best_record else 0
fast_best_mbs = fast_best_record[0][DS_CONFIG][
TRAIN_MICRO_BATCH_SIZE_PER_GPU] if fast_best_record else 0
logger.info(
f"fast_best_mbs = {fast_best_mbs}, name = {fast_best_record[0]['name']}")
fast_best_mbs = fast_best_record[0][DS_CONFIG][TRAIN_MICRO_BATCH_SIZE_PER_GPU] if fast_best_record else 0
logger.info(f"fast_best_mbs = {fast_best_mbs}, name = {fast_best_record[0]['name']}")
if self.fast_enabled() or stage == 0:
logger.info(f"End tuning for space: {tuning_space_name}")
@ -631,8 +591,7 @@ class Autotuner:
if stage > 0:
if fast_best_mbs <= prev_best_mbs or fast_best_metric_val < prev_best_metric_val:
logger.info(
f"End tuning for space: {tuning_space_name}. No need to tune other Zero configuration parameters."
)
f"End tuning for space: {tuning_space_name}. No need to tune other Zero configuration parameters.")
return max_micro_batch_size, fast_best_mbs, fast_best_metric_val
tuning_space[TRAIN_MICRO_BATCH_SIZE_PER_GPU] = tuning_micro_batch_sizes
@ -654,8 +613,7 @@ class Autotuner:
else:
t = GridSearchTuner(exps, self.rm, self.metric())
sample_size = len(self.rm.nodes) * self.rm.num_gpus_per_node // (
self.exp_num_gpus * self.exp_num_nodes)
sample_size = len(self.rm.nodes) * self.rm.num_gpus_per_node // (self.exp_num_gpus * self.exp_num_nodes)
num_exps = t.tune(sample_size=sample_size,
n_trials=self.autotuning_config.tuner_num_trials,
early_stopping=self.autotuning_config.tuner_early_stopping)
@ -669,8 +627,7 @@ class Autotuner:
if full_best_metric_val > fast_best_metric_val:
best_metric_val = full_best_metric_val
best_mbs = full_best_record[0][DS_CONFIG][
TRAIN_MICRO_BATCH_SIZE_PER_GPU] if full_best_record else -1
best_mbs = full_best_record[0][DS_CONFIG][TRAIN_MICRO_BATCH_SIZE_PER_GPU] if full_best_record else -1
else:
best_metric_val = fast_best_metric_val
best_mbs = fast_best_mbs
@ -682,9 +639,7 @@ class Autotuner:
if tuning_space_name not in self.records:
return 0
space_records = self.records[tuning_space_name]
sorted_space_records = sorted(
space_records,
key=lambda x: x[0][DS_CONFIG][TRAIN_MICRO_BATCH_SIZE_PER_GPU])
sorted_space_records = sorted(space_records, key=lambda x: x[0][DS_CONFIG][TRAIN_MICRO_BATCH_SIZE_PER_GPU])
prev_metric_val = None
prev_micro_batch_size = 0
for (exp, metric_val, _) in sorted_space_records:
@ -692,8 +647,7 @@ class Autotuner:
if metric_val < prev_metric_val:
break
if (metric_val >= prev_metric_val
and (metric_val - prev_metric_val) / prev_metric_val <
METRIC_PERCENT_DIFF_CONST):
and (metric_val - prev_metric_val) / prev_metric_val < METRIC_PERCENT_DIFF_CONST):
break
prev_metric_val = metric_val
prev_micro_batch_size = exp[DS_CONFIG][TRAIN_MICRO_BATCH_SIZE_PER_GPU]
@ -718,16 +672,8 @@ class Autotuner:
ds_config = copy.deepcopy(self.user_config)
replace_dict(ds_config, DEFAULT_MIN_MEM_CONFIG)
model_info_path = os.path.join(self.results_dir,
"profile_model_info",
"model_info.json")
ds_config[AUTOTUNING] = {
"enabled": True,
"model_info_path": model_info_path,
"model_info": {
"profile": True
}
}
model_info_path = os.path.join(self.results_dir, "profile_model_info", "model_info.json")
ds_config[AUTOTUNING] = {"enabled": True, "model_info_path": model_info_path, "model_info": {"profile": True}}
exp_config = {}
exp_name = "profile_model_info"
@ -748,8 +694,7 @@ class Autotuner:
for exp_id, (exp_json, err) in self.rm.finished_experiments.items():
self.rm.clear()
if err:
logger.error(
f"The model is not runnable with DeepSpeed with error = {err}")
logger.error(f"The model is not runnable with DeepSpeed with error = {err}")
return None
if os.path.exists(model_info_path):
@ -790,12 +735,8 @@ class Autotuner:
best_space_records[GLOBAL_TUNING_SPACE] = global_best_record
return best_space_records
def run_tuning_micro_batch_sizes(self,
tuning_micro_batch_sizes,
max_train_batch_size_per_gpu,
min_micro_batch_size,
stage,
tuning_micro_batch_sizes_overwritten):
def run_tuning_micro_batch_sizes(self, tuning_micro_batch_sizes, max_train_batch_size_per_gpu,
min_micro_batch_size, stage, tuning_micro_batch_sizes_overwritten):
assert tuning_micro_batch_sizes, "the tuning micro batch size list is empty"
tuning_micro_batch_sizes.sort()
max_micro_batch_size = tuning_micro_batch_sizes[-1]
@ -838,8 +779,7 @@ class Autotuner:
results = hjson.load(f)
metric_val = results[self.metric()]
self.update_records(tuning_space_name, exp, metric_val, 1)
if max_micro_batch_size == exp[DS_CONFIG][
TRAIN_MICRO_BATCH_SIZE_PER_GPU]:
if max_micro_batch_size == exp[DS_CONFIG][TRAIN_MICRO_BATCH_SIZE_PER_GPU]:
max_micro_batch_size_metric_val = metric_val
if has_mlflow:
os.environ.pop('MLFLOW_RUN_ID')
@ -862,9 +802,8 @@ class Autotuner:
# in an auto-detected tuning_micro_batch_sizes list, max_micro_batch_size might not be performant as the memory consumption is close to max
# try smaller values while gas stays the same
# if a more performant mbs value is found, use it to replace max_micro_batch_size in the list
min_micro_batch_size_with_same_gas = (
tuning_micro_batch_sizes[-2] +
1) if len(tuning_micro_batch_sizes) > 1 else min_micro_batch_size
min_micro_batch_size_with_same_gas = (tuning_micro_batch_sizes[-2] +
1) if len(tuning_micro_batch_sizes) > 1 else min_micro_batch_size
prev_best_metric_val = max_micro_batch_size_metric_val
prev_best_mbs = max_micro_batch_size
@ -872,10 +811,7 @@ class Autotuner:
stride = (max_micro_batch_size - min_micro_batch_size_with_same_gas) // 3
if stride == 0:
stride = 1
for mbs in reversed(
range(min_micro_batch_size_with_same_gas,
max_micro_batch_size,
stride)):
for mbs in reversed(range(min_micro_batch_size_with_same_gas, max_micro_batch_size, stride)):
ds_config[TRAIN_MICRO_BATCH_SIZE_PER_GPU] = mbs
gas = max_train_batch_size_per_gpu // mbs
ds_config[GRADIENT_ACCUMULATION_STEPS] = gas
@ -908,10 +844,7 @@ class Autotuner:
tuning_micro_batch_sizes[-1] = prev_best_mbs
return tuning_micro_batch_sizes
def get_min_max_micro_batch_size(self,
stage,
min_micro_batch_size,
calculated_max_micro_batch_size):
def get_min_max_micro_batch_size(self, stage, min_micro_batch_size, calculated_max_micro_batch_size):
# get min and max micro batch size with gradient accumulation steps = 1
if min_micro_batch_size > calculated_max_micro_batch_size:
return -1, -1
@ -927,8 +860,7 @@ class Autotuner:
# search for the min micro batch size
if min_micro_batch_size < 1:
if TRAIN_MICRO_BATCH_SIZE_PER_GPU in self.user_config and isinstance(
self.user_config[TRAIN_MICRO_BATCH_SIZE_PER_GPU],
int):
self.user_config[TRAIN_MICRO_BATCH_SIZE_PER_GPU], int):
# user specifies train_micro_batch_size_per_gpu as an int
mbs = int(self.user_config[TRAIN_MICRO_BATCH_SIZE_PER_GPU])
else:
@ -951,8 +883,7 @@ class Autotuner:
min_micro_batch_size = mbs
else:
self.update_records(tuning_space_name, exp, 0, 1)
logger.info(
f"User-specified micro batch size per GPU {mbs} does not run")
logger.info(f"User-specified micro batch size per GPU {mbs} does not run")
if self.min_train_micro_batch_size_per_gpu() == mbs:
return -1, -1
mbs = self.min_train_micro_batch_size_per_gpu()
@ -964,8 +895,7 @@ class Autotuner:
exp, metric_val = self.run_ds_config(ds_config, exp_name)
if not metric_val:
self.update_records(tuning_space_name, exp, 0, 1)
logger.info(
f"min_train_micro_batch_size_per_gpu {mbs} is not runnable.")
logger.info(f"min_train_micro_batch_size_per_gpu {mbs} is not runnable.")
return -1, -1
self.update_records(tuning_space_name, exp, metric_val, 1)
min_micro_batch_size = mbs
@ -975,8 +905,7 @@ class Autotuner:
ds_config[GRADIENT_ACCUMULATION_STEPS] = gas
ds_config[TRAIN_BATCH_SIZE] = min_micro_batch_size * gas * \
self.exp_num_gpus * self.exp_num_nodes // self.mp_size()
exp_name = tuning_space_name + "_gas" + str(gas) + "_tmbspg" + str(
min_micro_batch_size)
exp_name = tuning_space_name + "_gas" + str(gas) + "_tmbspg" + str(min_micro_batch_size)
exp, metric_val = self.run_ds_config(ds_config, exp_name)
if metric_val:
self.update_records(tuning_space_name, exp, metric_val, 1)
@ -986,13 +915,8 @@ class Autotuner:
return -1, -1
# search for the max micro batch size
max_micro_batch_size = min(calculated_max_micro_batch_size,
self.max_train_micro_batch_size_per_gpu())
for mbs in [
math.ceil(1.05 * max_micro_batch_size),
max_micro_batch_size,
int(0.95 * max_micro_batch_size)
]:
max_micro_batch_size = min(calculated_max_micro_batch_size, self.max_train_micro_batch_size_per_gpu())
for mbs in [math.ceil(1.05 * max_micro_batch_size), max_micro_batch_size, int(0.95 * max_micro_batch_size)]:
if mbs > self.max_train_micro_batch_size_per_gpu():
continue
if mbs in used_micro_batch_sizes:
@ -1011,12 +935,11 @@ class Autotuner:
else:
self.update_records(tuning_space_name, exp, 0, 1)
space_records = self.records[
tuning_space_name] if tuning_space_name in self.records else []
space_records = self.records[tuning_space_name] if tuning_space_name in self.records else []
if space_records:
prev_idx = min(range(len(space_records)),
key=lambda i: abs(space_records[i][0][DS_CONFIG][
TRAIN_MICRO_BATCH_SIZE_PER_GPU] - min_micro_batch_size))
key=lambda i: abs(space_records[i][0][DS_CONFIG][TRAIN_MICRO_BATCH_SIZE_PER_GPU] -
min_micro_batch_size))
prev_metric_val = space_records[prev_idx][1]
else:
prev_metric_val = None
@ -1037,8 +960,8 @@ class Autotuner:
low = mid + 1
self.update_records(tuning_space_name, exp, metric_val, 1)
used_micro_batch_sizes.append(mid)
if prev_metric_val and ((metric_val - prev_metric_val) /
prev_metric_val) < METRIC_PERCENT_DIFF_CONST:
if prev_metric_val and (
(metric_val - prev_metric_val) / prev_metric_val) < METRIC_PERCENT_DIFF_CONST:
logger.info(f"performance plateaus at mbs = {low}")
break
prev_metric_val = metric_val
@ -1049,9 +972,7 @@ class Autotuner:
low = mid + 1
max_micro_batch_size = low - 1
logger.info(
f"min_micro_batch_size = {min_micro_batch_size}, max_micro_batch_size = {max_micro_batch_size}."
)
logger.info(f"min_micro_batch_size = {min_micro_batch_size}, max_micro_batch_size = {max_micro_batch_size}.")
return min_micro_batch_size, max_micro_batch_size
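For readers tracing the search above, here is a minimal sketch of a plateau-aware binary search in the same spirit. The run() callable is hypothetical (it would launch an experiment and return a throughput metric, or None on failure), and the 0.05 threshold mirrors METRIC_PERCENT_DIFF_CONST; this is an illustration, not the autotuner's exact implementation.

# Sketch: binary-search the largest runnable micro batch size, stopping early
# once the relative metric improvement drops below a plateau threshold.
def find_max_micro_batch_size(run, low, high, plateau=0.05):
    prev_metric = None
    while low <= high:
        mid = (low + high) // 2
        metric = run(mid)      # hypothetical: returns a metric, or None if mid does not fit
        if metric is None:
            high = mid - 1     # mid fails, search smaller values
            continue
        low = mid + 1          # mid runs, try larger values
        if prev_metric is not None and (metric - prev_metric) / prev_metric < plateau:
            break              # performance plateaus, stop searching
        prev_metric = metric
    return low - 1             # largest micro batch size observed to run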
@ -1067,8 +988,7 @@ class Autotuner:
gas = int(val)
elif isinstance(gas_in_config, list):
logger.info(
f"Specifying a list of {GRADIENT_ACCUMULATION_STEPS} to tune is not supported. 1 would be used."
)
f"Specifying a list of {GRADIENT_ACCUMULATION_STEPS} to tune is not supported. 1 would be used.")
assert gas > 0, "Gradient accumulation steps must be positive."
return gas
@ -1083,9 +1003,7 @@ class Autotuner:
return (user_args[idx + 1])
return None
def get_tuning_micro_batch_size_list(self,
min_micro_batch_size,
max_micro_batch_size,
def get_tuning_micro_batch_size_list(self, min_micro_batch_size, max_micro_batch_size,
num_tuning_micro_batch_sizes):
"""Get a list of micro batch sizes to tune based on min and max values, as well as the size of the list.
Args:
@ -1098,17 +1016,16 @@ class Autotuner:
"""
if min_micro_batch_size <= 0 or max_micro_batch_size <= 0:
logger.info(
f"min_micro_batch_size = {min_micro_batch_size}, max_micro_batch_size = {max_micro_batch_size}"
)
f"min_micro_batch_size = {min_micro_batch_size}, max_micro_batch_size = {max_micro_batch_size}")
return [], 0
# NUM_GPUS=$(( ${NUM_WORKERS} * ${NUM_GPUS_PER_WORKER} ))
# DP_SIZE=$(( ${NUM_GPUS} / (${PP_SIZE} * ${MP_SIZE}) ))
# GRAD_ACC_STEPS=$(( ${TARGET_GLOBAL_BATCH_SIZE} / (${BATCH_SIZE} * ${DP_SIZE}) ))
if self.max_train_batch_size() and self.max_train_batch_size(
) > 0: # if the user specifies a max_train_batch_size
max_train_batch_size_per_gpu = self.max_train_batch_size() * self.mp_size(
) // (self.exp_num_gpus * self.exp_num_nodes)
if self.max_train_batch_size(
) and self.max_train_batch_size() > 0: # if the user specifies a max_train_batch_size
max_train_batch_size_per_gpu = self.max_train_batch_size() * self.mp_size() // (self.exp_num_gpus *
self.exp_num_nodes)
else:
gas = self.get_gas_from_user_config()
max_train_batch_size_per_gpu = max_micro_batch_size * gas // self.mp_size()
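The shell-style comments above encode simple arithmetic; a few made-up numbers make the relationship concrete (none of these values are autotuner defaults):

# Illustrative numbers only: how the global batch target maps to per-GPU budgets.
target_global_batch_size = 1024
mp_size, num_gpus_per_node, num_nodes = 2, 8, 2
micro_batch_size = 32

dp_size = (num_gpus_per_node * num_nodes) // mp_size                                                   # 8
max_train_batch_size_per_gpu = target_global_batch_size * mp_size // (num_gpus_per_node * num_nodes)   # 128
grad_acc_steps = target_global_batch_size // (micro_batch_size * dp_size)                              # 4
print(dp_size, max_train_batch_size_per_gpu, grad_acc_steps)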
@ -1117,8 +1034,7 @@ class Autotuner:
min_micro_batch_size = max_micro_batch_size // 2
# constant stride
stride = (max_micro_batch_size -
min_micro_batch_size) // num_tuning_micro_batch_sizes
stride = (max_micro_batch_size - min_micro_batch_size) // num_tuning_micro_batch_sizes
if stride == 0:
stride = 1
ls = []
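A minimal sketch of the constant-stride candidate list being built above (illustrative; the real method also guards against duplicates and values outside the allowed range):

# Sketch: evenly spaced micro-batch-size candidates between min and max.
def candidate_micro_batch_sizes(min_mbs, max_mbs, num_candidates):
    stride = (max_mbs - min_mbs) // num_candidates or 1
    ls = list(range(min_mbs, max_mbs, stride))[:num_candidates]
    if max_mbs not in ls:
        ls.append(max_mbs)    # always keep the detected maximum as a candidate
    return ls

print(candidate_micro_batch_sizes(4, 16, 3))   # [4, 8, 12, 16]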
@ -1187,8 +1103,6 @@ class Autotuner:
result = subprocess.Popen(self.optimal_cmd)
result.wait()
logger.info(
f"Done running with the optimal DeepSpeed configuration using {self.optimal_cmd}"
)
logger.info(f"Done running with the optimal DeepSpeed configuration using {self.optimal_cmd}")
else:
logger.info(f"No optimal DeepSpeed configuration found by autotuning.")


@ -9,6 +9,7 @@ from deepspeed.autotuning.constants import *
class DeepSpeedAutotuningConfig(DeepSpeedConfigObject):
def __init__(self, param_dict):
super(DeepSpeedAutotuningConfig, self).__init__()
@ -31,102 +32,65 @@ class DeepSpeedAutotuningConfig(DeepSpeedConfigObject):
self._initialize(autotuning_dict)
def _initialize(self, autotuning_dict):
self.enabled = get_scalar_param(autotuning_dict,
AUTOTUNING_ENABLED,
AUTOTUNING_ENABLED_DEFAULT)
self.enabled = get_scalar_param(autotuning_dict, AUTOTUNING_ENABLED, AUTOTUNING_ENABLED_DEFAULT)
self.fast = get_scalar_param(autotuning_dict,
AUTOTUNING_FAST,
AUTOTUNING_FAST_DEFAULT)
self.fast = get_scalar_param(autotuning_dict, AUTOTUNING_FAST, AUTOTUNING_FAST_DEFAULT)
self.results_dir = get_scalar_param(autotuning_dict,
AUTOTUNING_RESULTS_DIR,
AUTOTUNING_RESULTS_DIR_DEFAULT)
self.results_dir = get_scalar_param(autotuning_dict, AUTOTUNING_RESULTS_DIR, AUTOTUNING_RESULTS_DIR_DEFAULT)
assert self.results_dir, "results_dir cannot be empty"
self.exps_dir = get_scalar_param(autotuning_dict,
AUTOTUNING_EXPS_DIR,
AUTOTUNING_EXPS_DIR_DEFAULT)
self.exps_dir = get_scalar_param(autotuning_dict, AUTOTUNING_EXPS_DIR, AUTOTUNING_EXPS_DIR_DEFAULT)
assert self.exps_dir, "exps_dir cannot be empty"
self.overwrite = get_scalar_param(autotuning_dict,
AUTOTUNING_OVERWRITE,
AUTOTUNING_OVERWRITE_DEFAULT)
self.overwrite = get_scalar_param(autotuning_dict, AUTOTUNING_OVERWRITE, AUTOTUNING_OVERWRITE_DEFAULT)
self.start_profile_step = get_scalar_param(
autotuning_dict,
AUTOTUNING_START_PROFILE_STEP,
AUTOTUNING_START_PROFILE_STEP_DEFAULT)
self.start_profile_step = get_scalar_param(autotuning_dict, AUTOTUNING_START_PROFILE_STEP,
AUTOTUNING_START_PROFILE_STEP_DEFAULT)
self.end_profile_step = get_scalar_param(autotuning_dict,
AUTOTUNING_END_PROFILE_STEP,
self.end_profile_step = get_scalar_param(autotuning_dict, AUTOTUNING_END_PROFILE_STEP,
AUTOTUNING_END_PROFILE_STEP_DEFAULT)
self.metric = get_scalar_param(autotuning_dict,
AUTOTUNING_METRIC,
AUTOTUNING_METRIC_DEFAULT)
self.metric = get_scalar_param(autotuning_dict, AUTOTUNING_METRIC, AUTOTUNING_METRIC_DEFAULT)
self.metric_path = get_scalar_param(autotuning_dict,
AUTOTUNING_METRIC_PATH,
AUTOTUNING_METRIC_PATH_DEFAULT)
self.metric_path = get_scalar_param(autotuning_dict, AUTOTUNING_METRIC_PATH, AUTOTUNING_METRIC_PATH_DEFAULT)
self.tuner_type = get_scalar_param(autotuning_dict,
AUTOTUNING_TUNER_TYPE,
AUTOTUNING_TUNER_TYPE_DEFAULT)
self.tuner_type = get_scalar_param(autotuning_dict, AUTOTUNING_TUNER_TYPE, AUTOTUNING_TUNER_TYPE_DEFAULT)
self.tuner_early_stopping = get_scalar_param(
autotuning_dict,
AUTOTUNING_TUNER_EARLY_STOPPING,
AUTOTUNING_TUNER_EARLY_STOPPING_DEFAULT)
self.tuner_early_stopping = get_scalar_param(autotuning_dict, AUTOTUNING_TUNER_EARLY_STOPPING,
AUTOTUNING_TUNER_EARLY_STOPPING_DEFAULT)
self.tuner_num_trials = get_scalar_param(autotuning_dict,
AUTOTUNING_TUNER_NUM_TRIALS,
self.tuner_num_trials = get_scalar_param(autotuning_dict, AUTOTUNING_TUNER_NUM_TRIALS,
AUTOTUNING_TUNER_NUM_TRIALS_DEFAULT)
self.arg_mappings = get_dict_param(autotuning_dict,
AUTOTUNING_ARG_MAPPINGS,
AUTOTUNING_ARG_MAPPINGS_DEFAULT)
self.arg_mappings = get_dict_param(autotuning_dict, AUTOTUNING_ARG_MAPPINGS, AUTOTUNING_ARG_MAPPINGS_DEFAULT)
self.model_info = get_model_info_config(autotuning_dict)
self.model_info_path = get_scalar_param(autotuning_dict,
AUTOTUNING_MODEL_INFO_PATH,
self.model_info_path = get_scalar_param(autotuning_dict, AUTOTUNING_MODEL_INFO_PATH,
AUTOTUNING_MODEL_INFO_PATH_DEFAULT)
self.mp_size = get_scalar_param(autotuning_dict,
AUTOTUNING_MP_SIZE,
AUTOTUNING_MP_SIZE_DEFAULT)
self.mp_size = get_scalar_param(autotuning_dict, AUTOTUNING_MP_SIZE, AUTOTUNING_MP_SIZE_DEFAULT)
self.max_train_batch_size = get_dict_param(
autotuning_dict,
AUTOTUNING_MAX_TRAIN_BATCH_SIZE,
AUTOTUNING_MAX_TRAIN_BATCH_SIZE_DEFAULT)
self.max_train_batch_size = get_dict_param(autotuning_dict, AUTOTUNING_MAX_TRAIN_BATCH_SIZE,
AUTOTUNING_MAX_TRAIN_BATCH_SIZE_DEFAULT)
self.min_train_batch_size = get_dict_param(
autotuning_dict,
AUTOTUNING_MIN_TRAIN_BATCH_SIZE,
AUTOTUNING_MIN_TRAIN_BATCH_SIZE_DEFAULT)
self.min_train_batch_size = get_dict_param(autotuning_dict, AUTOTUNING_MIN_TRAIN_BATCH_SIZE,
AUTOTUNING_MIN_TRAIN_BATCH_SIZE_DEFAULT)
self.max_train_micro_batch_size_per_gpu = get_dict_param(
autotuning_dict,
AUTOTUNING_MAX_TRAIN_MICRO_BATCH_SIZE_PER_GPU,
autotuning_dict, AUTOTUNING_MAX_TRAIN_MICRO_BATCH_SIZE_PER_GPU,
AUTOTUNING_MAX_TRAIN_MICRO_BATCH_SIZE_PER_GPU_DEFAULT)
self.min_train_micro_batch_size_per_gpu = get_dict_param(
autotuning_dict,
AUTOTUNING_MIN_TRAIN_MICRO_BATCH_SIZE_PER_GPU,
autotuning_dict, AUTOTUNING_MIN_TRAIN_MICRO_BATCH_SIZE_PER_GPU,
AUTOTUNING_MIN_TRAIN_MICRO_BATCH_SIZE_PER_GPU_DEFAULT)
self.num_tuning_micro_batch_sizes = get_dict_param(
autotuning_dict,
AUTOTUNING_NUM_TUNING_MICRO_BATCH_SIZES,
AUTOTUNING_NUM_TUNING_MICRO_BATCH_SIZES_DEFAULT)
self.num_tuning_micro_batch_sizes = get_dict_param(autotuning_dict, AUTOTUNING_NUM_TUNING_MICRO_BATCH_SIZES,
AUTOTUNING_NUM_TUNING_MICRO_BATCH_SIZES_DEFAULT)
def get_model_info_config(param_dict):
if MODEL_INFO in param_dict and param_dict[MODEL_INFO] is not None:
model_info_config = {}
for key, default_value in MODEL_INFO_KEY_DEFAULT_DICT.items():
model_info_config[key] = get_scalar_param(param_dict[MODEL_INFO],
key,
default_value)
model_info_config[key] = get_scalar_param(param_dict[MODEL_INFO], key, default_value)
return model_info_config
return None


@ -10,17 +10,13 @@ Licensed under the MIT license.
import os
DEFAULT_TEMPLATE_PATH_ZERO_0 = os.path.join(os.path.dirname(os.path.realpath(__file__)),
"config_templates",
DEFAULT_TEMPLATE_PATH_ZERO_0 = os.path.join(os.path.dirname(os.path.realpath(__file__)), "config_templates",
"template_zero0.json")
DEFAULT_TEMPLATE_PATH_ZERO_1 = os.path.join(os.path.dirname(os.path.realpath(__file__)),
"config_templates",
DEFAULT_TEMPLATE_PATH_ZERO_1 = os.path.join(os.path.dirname(os.path.realpath(__file__)), "config_templates",
"template_zero1.json")
DEFAULT_TEMPLATE_PATH_ZERO_2 = os.path.join(os.path.dirname(os.path.realpath(__file__)),
"config_templates",
DEFAULT_TEMPLATE_PATH_ZERO_2 = os.path.join(os.path.dirname(os.path.realpath(__file__)), "config_templates",
"template_zero2.json")
DEFAULT_TEMPLATE_PATH_ZERO_3 = os.path.join(os.path.dirname(os.path.realpath(__file__)),
"config_templates",
DEFAULT_TEMPLATE_PATH_ZERO_3 = os.path.join(os.path.dirname(os.path.realpath(__file__)), "config_templates",
"template_zero3.json")
METRIC_PERCENT_DIFF_CONST = 0.05
@ -157,50 +153,31 @@ DEFAULT_TUNING_SPACE_ZERO_0 = {"zero_optimization": {"stage": 0}}
DEFAULT_TUNING_SPACE_ZERO_1 = {
"zero_optimization": {
"stage": 1,
"reduce_bucket_size": [5e7,
5e8,
1e9],
"allgather_bucket_size": [5e7,
5e8,
1e9],
"reduce_bucket_size": [5e7, 5e8, 1e9],
"allgather_bucket_size": [5e7, 5e8, 1e9],
}
}
DEFAULT_TUNING_SPACE_ZERO_2 = {
"zero_optimization": {
"stage": 2,
"overlap_comm": [True,
False],
"reduce_scatter": [False,
True],
"reduce_bucket_size": [5e7,
5e8,
1e9],
"allgather_bucket_size": [5e7,
5e8,
1e9],
"contiguous_gradients": [False,
True]
"overlap_comm": [True, False],
"reduce_scatter": [False, True],
"reduce_bucket_size": [5e7, 5e8, 1e9],
"allgather_bucket_size": [5e7, 5e8, 1e9],
"contiguous_gradients": [False, True]
},
}
DEFAULT_TUNING_SPACE_ZERO_3 = {
"zero_optimization": {
"stage": 3,
"overlap_comm": [True,
False],
"reduce_scatter": [False,
True],
"reduce_bucket_size": [5e7,
5e8,
1e9],
"allgather_partitions": [True,
False],
"allgather_bucket_size": [5e7,
5e8,
1e9],
"contiguous_gradients": [False,
True]
"overlap_comm": [True, False],
"reduce_scatter": [False, True],
"reduce_bucket_size": [5e7, 5e8, 1e9],
"allgather_partitions": [True, False],
"allgather_bucket_size": [5e7, 5e8, 1e9],
"contiguous_gradients": [False, True]
},
}
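Each default tuning space above expands into the Cartesian product of its list-valued entries (via gen_combinations, shown later in this diff), so the number of candidate configs is just a product of list lengths:

import math

zero2_choices = [2, 2, 3, 3, 2]      # overlap_comm, reduce_scatter, reduce_bucket_size,
                                     # allgather_bucket_size, contiguous_gradients
zero3_choices = [2, 2, 3, 2, 3, 2]   # ZeRO-3 adds allgather_partitions
print(math.prod(zero2_choices))      # 72 candidate ZeRO-2 configs
print(math.prod(zero3_choices))      # 144 candidate ZeRO-3 configs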


@ -28,13 +28,8 @@ TIMEOUT = 5
class ResourceManager:
def __init__(self,
args,
hosts,
num_gpus_per_node,
results_dir,
exps_dir,
arg_mappings):
def __init__(self, args, hosts, num_gpus_per_node, results_dir, exps_dir, arg_mappings):
self.results_dir = results_dir
self.exps_dir = exps_dir
@ -69,13 +64,10 @@ class ResourceManager:
exp["exp_id"] = self.experiment_count
self.experiment_count += 1
result_dir = exp["result_dir"] = os.path.join(
self.results_dir,
exp['name'])
result_dir = exp["result_dir"] = os.path.join(self.results_dir, exp['name'])
if AUTOTUNING in exp["ds_config"]:
metric_file = os.path.join(result_dir, "metrics.json")
exp["ds_config"][AUTOTUNING][
AUTOTUNING_METRIC_PATH] = metric_file
exp["ds_config"][AUTOTUNING][AUTOTUNING_METRIC_PATH] = metric_file
stderr_file = os.path.join(result_dir, "stderr.log")
model_info_file = os.path.join(result_dir, "model_info.json")
metric_file = os.path.join(result_dir, "metrics.json")
@ -86,11 +78,8 @@ class ResourceManager:
err = search_error(stderr_file)
exp_id = exp["exp_id"]
self.finished_experiments[exp_id] = (exp, err)
if err or os.path.exists(metric_file) or os.path.exists(
model_info_file):
logger.info(
f"Skipping exp {exp['name']} whose result already exists"
)
if err or os.path.exists(metric_file) or os.path.exists(model_info_file):
logger.info(f"Skipping exp {exp['name']} whose result already exists")
continue
self.experiment_queue.append(exp)
@ -113,11 +102,7 @@ class ResourceManager:
user_args.append(val)
user_args.append(str(nval))
t = threading.Thread(target=run_experiment,
args=(exp,
reservations,
user_script,
user_args))
t = threading.Thread(target=run_experiment, args=(exp, reservations, user_script, user_args))
t.start()
self.running_experiments[exp_id] = (t, exp, reservations, time.time())
@ -270,6 +255,7 @@ class ResourceManager:
class Node:
def __init__(self, host, max_slots):
self.host = host
self.max_slots = max_slots
@ -284,6 +270,7 @@ class Node:
class Reservation:
def __init__(self, node, slots):
self.node = node
self.slots = slots
@ -389,9 +376,8 @@ def run_experiment(exp: dict, reservations, user_script, user_args):
f"Launching exp_id = {exp['exp_id']}, exp_name = {exp['name']}, with resource = {include_str}, and ds_config = {os.path.abspath(ds_config_path)}"
)
with open(os.path.join(exp_dir, "stdout.log"), "wb") as out, open(
os.path.join(exp_dir, "stderr.log"), "wb"
) as err:
with open(os.path.join(exp_dir, "stdout.log"), "wb") as out, open(os.path.join(exp_dir, "stderr.log"),
"wb") as err:
result = subprocess.Popen(cmd, stdout=out, stderr=err)
result.wait()
out.flush()
@ -401,9 +387,7 @@ def run_experiment(exp: dict, reservations, user_script, user_args):
clean_up(exp, reservations)
logger.info(
f"Done running exp_id = {exp['exp_id']}, exp_name = {exp['name']}, with resource = {include_str}"
)
logger.info(f"Done running exp_id = {exp['exp_id']}, exp_name = {exp['name']}, with resource = {include_str}")
PDSH_MAX_FAN_OUT = 1024
@ -417,8 +401,7 @@ def clean_up(exp: dict, reservations):
for reservation in reservations:
nodes_str += f"{reservation.node.host},"
nodes_str = nodes_str[:-1]
logger.debug(
f"Cleaning up exp_id = {exp['exp_id']} on the following workers: {nodes_str}")
logger.debug(f"Cleaning up exp_id = {exp['exp_id']} on the following workers: {nodes_str}")
# PDSH flags for max node fan out and specific hosts to launch on
# See https://linux.die.net/man/1/pdsh for flag details
@ -441,6 +424,4 @@ def clean_up(exp: dict, reservations):
if result.returncode > 0:
sys.exit(result.returncode)
logger.info(
f"Done cleaning up exp_id = {exp['exp_id']} on the following workers: {nodes_str}"
)
logger.info(f"Done cleaning up exp_id = {exp['exp_id']} on the following workers: {nodes_str}")


@ -8,6 +8,7 @@ from deepspeed.utils import logger
class BaseTuner:
def __init__(self, exps, resource_manager, metric):
self.all_exps = exps
self.rm = resource_manager
@ -42,8 +43,8 @@ class BaseTuner:
self.rm.schedule_experiments(exp_paths)
self.rm.run()
exp, metric_val = self.rm.parse_results(self.metric)
if self.best_exp == None or self.best_metric_val == None or (
metric_val and metric_val > self.best_metric_val):
if self.best_exp == None or self.best_metric_val == None or (metric_val
and metric_val > self.best_metric_val):
# logger.info(f"tuner finds better = {exp}")
self.best_exp = exp
self.best_metric_val = metric_val


@ -9,6 +9,7 @@ except ImportError:
class XGBoostCostModel():
def __init__(self, loss_type, num_threads=None, log_interval=25, upper_model=None):
assert xgb is not None, "missing requirements, please install deepspeed w. 'autotuning_ml' extra."


@ -7,6 +7,7 @@ from .base_tuner import BaseTuner
class RandomTuner(BaseTuner):
"""Explore the search space in random order"""
def __init__(self, exps: list, resource_manager, metric):
super().__init__(exps, resource_manager, metric)
@ -22,6 +23,7 @@ class RandomTuner(BaseTuner):
class GridSearchTuner(BaseTuner):
"""Explore the search space in sequential order"""
def __init__(self, exps: list, resource_manager, metric):
super().__init__(exps, resource_manager, metric)


@ -15,6 +15,7 @@ INIT_NUM = 2
class ModelBasedTuner(BaseTuner):
"""Exploring the search space with a cost model"""
def __init__(self, exps: list, resource_manager, metric, tuning_space):
super().__init__(exps, resource_manager, metric)
self.tuning_space = tuning_space
@ -25,8 +26,7 @@ class ModelBasedTuner(BaseTuner):
self.dims = dict_to_dims(self.tuning_space)
logger.info(
f"Create config dim: {self.dims}, all configs: {self.num_all_configs}")
logger.info(f"Create config dim: {self.dims}, all configs: {self.num_all_configs}")
self.visited = set([])
@ -71,9 +71,7 @@ class ModelBasedTuner(BaseTuner):
n = len(estimates)
top_idx = np.argsort(estimates)
top_idx_ret = top_idx if self.metric == AUTOTUNING_METRIC_LATENCY else top_idx[::
-1][:
n]
top_idx_ret = top_idx if self.metric == AUTOTUNING_METRIC_LATENCY else top_idx[::-1][:n]
# top_configs = [self.all_configs[i] for i in top_idx]
@ -145,9 +143,7 @@ class ModelBasedTuner(BaseTuner):
self.evaluated_configs.append(feature_val)
self.evaluated_perf.append(curr_iter)
logger.debug(
f"**Evaluated configs: {len(self.evaluated_configs)}, evaluated perf: {self.evaluated_perf}"
)
logger.debug(f"**Evaluated configs: {len(self.evaluated_configs)}, evaluated perf: {self.evaluated_perf}")
self.cost_model.fit(self.evaluated_configs, self.evaluated_perf)
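On the argsort direction used above: np.argsort sorts ascending, so for a latency metric (lower is better) the indices already come out best-first, while for throughput-style metrics the order has to be reversed. A two-line check:

import numpy as np

estimates = np.array([3.0, 1.0, 2.0])
print(np.argsort(estimates))         # [1 2 0] -> best-first when lower is better (latency)
print(np.argsort(estimates)[::-1])   # [0 2 1] -> best-first when higher is better (throughput)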


@ -44,9 +44,7 @@ def gen_combinations(d: dict):
for v in values:
if not isinstance(v, list):
v = [v]
values_choices = (gen_combinations(v) if isinstance(v,
dict) else get_list(v)
for v in values)
values_choices = (gen_combinations(v) if isinstance(v, dict) else get_list(v) for v in values)
for comb in itertools.product(*values_choices):
yield dict(zip(keys, comb))
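gen_combinations yields the Cartesian product of every list-valued entry and recurses into nested dicts. A self-contained version with example output follows; get_list is assumed to wrap scalars into one-element lists, matching the helper used above:

import itertools

def get_list(v):
    # assumption: mirrors the helper used above
    return v if isinstance(v, list) else [v]

def gen_combinations(d: dict):
    keys, values = d.keys(), d.values()
    values_choices = (gen_combinations(v) if isinstance(v, dict) else get_list(v) for v in values)
    for comb in itertools.product(*values_choices):
        yield dict(zip(keys, comb))

space = {"zero_optimization": {"stage": 1, "reduce_bucket_size": [5e7, 5e8]}}
print(list(gen_combinations(space)))
# [{'zero_optimization': {'stage': 1, 'reduce_bucket_size': 50000000.0}},
#  {'zero_optimization': {'stage': 1, 'reduce_bucket_size': 500000000.0}}]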


@ -176,6 +176,7 @@ def fetch_hostfile(hostfile_path):
def validate_ds_config(config: dict):
def is_False(config: dict, key):
if config is None:
return False
@ -189,9 +190,7 @@ def validate_ds_config(config: dict):
if stage == 1:
return True
elif stage == 2:
if is_False(config_zero,
"cpu_offload") and is_False(config_zero,
"cpu_offload_params"):
if is_False(config_zero, "cpu_offload") and is_False(config_zero, "cpu_offload_params"):
return False
elif stage == 3:
offload_devices = ["cpu", "nvme"]
@ -289,14 +288,13 @@ def get_all_configs(tuning_space: dict, ignore_keys=None):
Args:
tuning_space (dict): the tuning space where tunable parameters are lists of values.
"""
def gen_combinations(d: dict):
keys, values = d.keys(), d.values()
for v in values:
if not isinstance(v, list):
v = [v]
values_choices = (gen_combinations(v) if isinstance(v,
dict) else get_list(v)
for v in values)
values_choices = (gen_combinations(v) if isinstance(v, dict) else get_list(v) for v in values)
for comb in itertools.product(*values_choices):
yield dict(zip(keys, comb))


@ -4,9 +4,7 @@ from .reshape_meg_2d import reshape_meg_2d_parallel
from .deepspeed_checkpoint import DeepSpeedCheckpoint
from .utils import (get_layer_ckpt_name_for_rank,
get_model_ckpt_name_for_rank,
get_zero_ckpt_name_for_rank)
from .utils import (get_layer_ckpt_name_for_rank, get_model_ckpt_name_for_rank, get_zero_ckpt_name_for_rank)
from .reshape_utils import (merge_state)


@ -5,11 +5,7 @@ from typing import Dict
import torch
from .reshape_3d_utils import model_3d_desc
from .reshape_utils import (basic_folder_validation,
merge_state,
partition_data,
get_files,
get_files_with_prefix)
from .reshape_utils import (basic_folder_validation, merge_state, partition_data, get_files, get_files_with_prefix)
from .constants import (MODEL_FILE_PREFIX, LAYER_FILE_PREFIX)
@ -24,19 +20,15 @@ CHECKPOINT_INFO_KEY = 'checkpoint_info'
ITERATION_KEY = 'iteration'
SEQUENTIAL_LAYERS = [
'input_layernorm.weight',
'input_layernorm.bias',
'self_attention.dense.bias',
'post_attention_layernorm.weight',
'post_attention_layernorm.bias',
'mlp.dense_4h_to_h.bias',
'position_embeddings.weight'
'input_layernorm.weight', 'input_layernorm.bias', 'self_attention.dense.bias', 'post_attention_layernorm.weight',
'post_attention_layernorm.bias', 'mlp.dense_4h_to_h.bias', 'position_embeddings.weight'
]
LAYER_CONCAT_DIM = {'self_attention.dense.weight': 1, 'mlp.dense_4h_to_h.weight': 1}
class DeepSpeedCheckpoint(object):
def __init__(self, dir, tp_degree=None, pp_degree=None, dp_degree=None):
self.dir = dir
self._validate_folder(dir)
@ -50,33 +42,24 @@ class DeepSpeedCheckpoint(object):
self.layer_keys = self._get_layer_keys()
self.layer_count = len(self.layer_keys)
self.tp_degree = self.zero_checkpoint.get_src_tp_degree(
) if tp_degree is None else tp_degree
self.pp_degree = self.zero_checkpoint.get_src_pp_degree(
) if pp_degree is None else pp_degree
self.dp_degree = self.zero_checkpoint.get_src_dp_degree(
) if dp_degree is None else dp_degree
self.tp_degree = self.zero_checkpoint.get_src_tp_degree() if tp_degree is None else tp_degree
self.pp_degree = self.zero_checkpoint.get_src_pp_degree() if pp_degree is None else pp_degree
self.dp_degree = self.zero_checkpoint.get_src_dp_degree() if dp_degree is None else dp_degree
self.original_world_size = self.zero_checkpoint.get_src_tp_degree(
) * self.zero_checkpoint.get_src_pp_degree(
self.original_world_size = self.zero_checkpoint.get_src_tp_degree() * self.zero_checkpoint.get_src_pp_degree(
) * self.zero_checkpoint.get_src_dp_degree()
self.world_size = self.tp_degree * self.pp_degree * self.dp_degree
self.old_2d_map = meg_2d_parallel_map(self.zero_checkpoint.get_src_pp_degree(),
self.zero_checkpoint.get_src_tp_degree())
self.old_2d_map.simple_init()
self.new_2d_map = reshape_meg_2d_parallel(
old_pp_degree=self.zero_checkpoint.get_src_pp_degree(),
old_tp_degree=self.zero_checkpoint.get_src_tp_degree(),
new_pp_degree=self.pp_degree,
new_tp_degree=self.tp_degree)
self.new_2d_map = reshape_meg_2d_parallel(old_pp_degree=self.zero_checkpoint.get_src_pp_degree(),
old_tp_degree=self.zero_checkpoint.get_src_tp_degree(),
new_pp_degree=self.pp_degree,
new_tp_degree=self.tp_degree)
if self.is_change_pp_degree() or self.is_change_tp_degree(
) or self.is_change_dp_degree():
self.zero_checkpoint.reshape(
model_3d_desc(self.pp_degree,
self.tp_degree,
self.dp_degree))
if self.is_change_pp_degree() or self.is_change_tp_degree() or self.is_change_dp_degree():
self.zero_checkpoint.reshape(model_3d_desc(self.pp_degree, self.tp_degree, self.dp_degree))
self.global_state = {}
@ -84,8 +67,7 @@ class DeepSpeedCheckpoint(object):
self.pp_to_transformer_map = self._build_pp_transformer_map()
self.transformer_file_map = self._build_transformer_file_map()
self.tp_to_embedding_map = self._build_tp_other_layer_map(EMBEDDING_LAYER_INDEX)
self.tp_to_final_norm_map = self._build_tp_other_layer_map(
FINAL_LAYER_NORM_INDEX)
self.tp_to_final_norm_map = self._build_tp_other_layer_map(FINAL_LAYER_NORM_INDEX)
self._build_global_state()
def is_change_tp_degree(self):
@ -131,9 +113,7 @@ class DeepSpeedCheckpoint(object):
keys_to_ignore=[PARAM_SHAPES])
def get_zero_files(self, pp_index, tp_index, dp_index) -> list:
return self.zero_checkpoint.get_files_for_rank(pp_index=pp_index,
tp_index=tp_index,
dp_index=dp_index)
return self.zero_checkpoint.get_files_for_rank(pp_index=pp_index, tp_index=tp_index, dp_index=dp_index)
def get_embedding_layer_id(self):
return self.layer_keys[EMBEDDING_LAYER_INDEX]
@ -150,11 +130,7 @@ class DeepSpeedCheckpoint(object):
def get_embedding_state(self, tp_index: int) -> Dict:
assert tp_index in self.tp_to_embedding_map.keys()
sd_list = [
torch.load(fname,
map_location=torch.device('cpu'))
for fname in self.tp_to_embedding_map[tp_index]
]
sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in self.tp_to_embedding_map[tp_index]]
sd = self._merge_state_dicts(sd_list)
return sd
@ -179,10 +155,7 @@ class DeepSpeedCheckpoint(object):
assert tp_index < self.tp_degree
assert pp_index < self.pp_degree
fname_list = self.get_2d_parallel_files(tp_index=tp_index, pp_index=pp_index)
sd_list = [
torch.load(fname,
map_location=torch.device('cpu')) for fname in fname_list
]
sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in fname_list]
merged_sd = None
for sd in sd_list:
@ -198,10 +171,7 @@ class DeepSpeedCheckpoint(object):
assert pp_index < self.pp_degree
t_list = []
for fname_list in self.transformer_file_map[(tp_index, pp_index)]:
sd_list = [
torch.load(fname,
map_location=torch.device('cpu')) for fname in fname_list
]
sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in fname_list]
sd = self._merge_state_dicts(sd_list)
t_list.append(sd)
return t_list
@ -212,8 +182,7 @@ class DeepSpeedCheckpoint(object):
def get_final_norm_state(self, tp_index: int) -> Dict:
assert tp_index in self.tp_to_final_norm_map.keys()
sd = torch.load(self.tp_to_final_norm_map[tp_index][0],
map_location=torch.device('cpu'))
sd = torch.load(self.tp_to_final_norm_map[tp_index][0], map_location=torch.device('cpu'))
return sd
def get_final_norm_files(self, tp_index: int) -> list:
@ -222,8 +191,7 @@ class DeepSpeedCheckpoint(object):
def _build_tp_other_layer_map(self, layer_index: int):
assert layer_index < len(self.layer_files)
layer_files = get_files_with_prefix(self.layer_files,
self.layer_keys[layer_index])
layer_files = get_files_with_prefix(self.layer_files, self.layer_keys[layer_index])
layer_file_partitions = partition_data(layer_files, self.tp_degree)
data_map = {i: flist for i, flist in enumerate(layer_file_partitions)}
return data_map
@ -238,11 +206,7 @@ class DeepSpeedCheckpoint(object):
data_map = {}
transformer_layers = self.layer_keys[1:-1]
layers_per_pp = len(transformer_layers) // self.pp_degree
data_map = {
i: transformer_layers[i * layers_per_pp:(i + 1) * layers_per_pp]
for i in range(0,
self.pp_degree)
}
data_map = {i: transformer_layers[i * layers_per_pp:(i + 1) * layers_per_pp] for i in range(0, self.pp_degree)}
return data_map
def _dump_mapping(self, data_map, map_tag=None):
@ -308,10 +272,8 @@ class DeepSpeedCheckpoint(object):
file_list = get_files(dir)
for file_prefix in [
MODEL_FILE_PREFIX,
LAYER_FILE_PREFIX,
f'{LAYER_FILE_PREFIX}01'
]:
for file_prefix in [MODEL_FILE_PREFIX, LAYER_FILE_PREFIX, f'{LAYER_FILE_PREFIX}01']:
ckpt_files = get_files_with_prefix(file_list, file_prefix)
assert len(ckpt_files) > 0, f'{dir} seems a bogus DeepSpeed checkpoint folder: Cannot find {file_prefix}* files in there.'
assert len(
ckpt_files
) > 0, f'{dir} seems a bogus DeepSpeed checkpoint folder: Cannot find {file_prefix}* files in there.'


@ -1,9 +1,6 @@
'''Copyright The Microsoft DeepSpeed Team'''
from .reshape_utils import (get_files,
get_files_with_prefix,
partition_data,
get_zero_files)
from .reshape_utils import (get_files, get_files_with_prefix, partition_data, get_zero_files)
from .constants import (MODEL_FILE_PREFIX, LAYER_FILE_PREFIX)
@ -15,6 +12,7 @@ DP_DIM = 'DP'
class model_3d_desc(object):
def __init__(self, pp_degree=1, tp_degree=1, dp_degree=1):
self.pp_degree = pp_degree
self.tp_degree = tp_degree
@ -33,8 +31,7 @@ class model_3d_desc(object):
src_2d_size=self.pp_degree * self.tp_degree,
dp_degree=self.dp_degree)
return unflatten_dp_dimension(meg_2d_map=flat_3d_map,
dp_degree=target_3d_desc.dp_degree)
return unflatten_dp_dimension(meg_2d_map=flat_3d_map, dp_degree=target_3d_desc.dp_degree)
def get_desc(self):
return f'{PP_DIM},{TP_DIM},{DP_DIM} = ({self.pp_degree}, {self.tp_degree}, {self.dp_degree})'
@ -45,14 +42,11 @@ class model_3d_desc(object):
def is_valid(self, pp_index, tp_index, dp_index):
err_msg = []
valid = True
for index, degree, dim_name in [
(pp_index, self.pp_degree, PP_DIM),
(tp_index, self.tp_degree, TP_DIM),
(dp_index, self.dp_degree, DP_DIM)]:
for index, degree, dim_name in [(pp_index, self.pp_degree, PP_DIM), (tp_index, self.tp_degree, TP_DIM),
(dp_index, self.dp_degree, DP_DIM)]:
if index >= degree:
valid = False
err_msg.append(
f'{dim_name} indexing error: index {index} >= degree {degree}')
err_msg.append(f'{dim_name} indexing error: index {index} >= degree {degree}')
return valid, err_msg
@ -60,18 +54,15 @@ class model_3d_desc(object):
err_msg = []
if target_3d_desc.pp_degree > self.pp_degree:
err_msg.append(
f'Expansion reshape not supported - {PP_DIM}: {self.pp_degree} ---> {target_3d_desc.pp_degree}'
)
f'Expansion reshape not supported - {PP_DIM}: {self.pp_degree} ---> {target_3d_desc.pp_degree}')
if target_3d_desc.tp_degree > self.tp_degree:
err_msg.append(
f'Expansion reshape not supported - {TP_DIM}: {self.tp_degree} ---> {target_3d_desc.tp_degree}'
)
f'Expansion reshape not supported - {TP_DIM}: {self.tp_degree} ---> {target_3d_desc.tp_degree}')
if target_3d_desc.dp_degree > self.dp_degree:
err_msg.append(
f'Expansion reshape not supported - {DP_DIM}: {self.dp_degree} ---> {target_3d_desc.dp_degree}'
)
f'Expansion reshape not supported - {DP_DIM}: {self.dp_degree} ---> {target_3d_desc.dp_degree}')
return len(err_msg) == 0, err_msg
@ -106,10 +97,7 @@ def flatten_dp_dimension(meg_2d_map, src_2d_size, dp_degree):
def unflatten_dp_dimension(meg_2d_map, dp_degree):
pp_degree = meg_2d_map.pp_degree
tp_degree = meg_2d_map.tp_degree
meg_2d_map_list = [
meg_2d_parallel_map(pp_degree=pp_degree,
tp_degree=tp_degree) for _ in range(dp_degree)
]
meg_2d_map_list = [meg_2d_parallel_map(pp_degree=pp_degree, tp_degree=tp_degree) for _ in range(dp_degree)]
for pp_index in range(pp_degree):
for tp_index in range(tp_degree):
flat_dp_indices = meg_2d_map.get_data(pp_index, tp_index)


@ -4,6 +4,7 @@ from .reshape_utils import partition_data
class meg_2d_parallel_map(object):
def __init__(self, pp_degree, tp_degree):
self.pp_degree = pp_degree
self.tp_degree = tp_degree
@ -11,8 +12,7 @@ class meg_2d_parallel_map(object):
def simple_init(self):
self.map = {
self._make_key(i // self.tp_degree,
i % self.tp_degree): [i]
self._make_key(i // self.tp_degree, i % self.tp_degree): [i]
for i in range(self.pp_degree * self.tp_degree)
}
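simple_init above places global rank i at 2-D coordinates (i // tp_degree, i % tp_degree), i.e. a row-major pp-by-tp grid. The exact key format produced by _make_key is not visible in this hunk, so plain tuples stand in for it below:

pp_degree, tp_degree = 2, 2
mapping = {(i // tp_degree, i % tp_degree): [i] for i in range(pp_degree * tp_degree)}
print(mapping)   # {(0, 0): [0], (0, 1): [1], (1, 0): [2], (1, 1): [3]}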
@ -74,11 +74,7 @@ def _reshape_pp_dimension(old_2d_map, new_pp_degree):
return new_2d_map
def reshape_meg_2d_parallel(old_pp_degree,
old_tp_degree,
new_pp_degree,
new_tp_degree,
verbose=False):
def reshape_meg_2d_parallel(old_pp_degree, old_tp_degree, new_pp_degree, new_tp_degree, verbose=False):
assert new_pp_degree <= old_pp_degree
assert new_tp_degree <= old_tp_degree
@ -137,8 +133,7 @@ def get_mpu_ranks(tp_size=1, pp_size=1, dp_size=1, virtual_pp_size=None):
tensor_model_parallel_size = min(tp_size, world_size)
pipeline_model_parallel_size = min(pp_size, world_size)
data_parallel_size = world_size // (tensor_model_parallel_size *
pipeline_model_parallel_size)
data_parallel_size = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size)
num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
@ -158,10 +153,7 @@ def get_mpu_ranks(tp_size=1, pp_size=1, dp_size=1, virtual_pp_size=None):
# Build the model-parallel groups.
all_pp_group_ranks = []
for i in range(data_parallel_size):
ranks = [
data_parallel_group_ranks[i]
for data_parallel_group_ranks in all_dp_group_ranks
]
ranks = [data_parallel_group_ranks[i] for data_parallel_group_ranks in all_dp_group_ranks]
all_pp_group_ranks.append(list(ranks))
print(f"PP", all_pp_group_ranks)
@ -169,8 +161,7 @@ def get_mpu_ranks(tp_size=1, pp_size=1, dp_size=1, virtual_pp_size=None):
# Build the tensor model-parallel groups.
all_tp_group_ranks = []
for i in range(num_tensor_model_parallel_groups):
ranks = range(i * tensor_model_parallel_size,
(i + 1) * tensor_model_parallel_size)
ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
all_tp_group_ranks.append(list(ranks))
print(f"TP", all_tp_group_ranks)


@ -49,11 +49,7 @@ def partition_data(data_list, num_partitions):
num_elems = len(data_list)
assert num_elems % num_partitions == 0
partition_size = num_elems // num_partitions
partitions_list = [
data_list[i:i + partition_size] for i in range(0,
num_elems,
partition_size)
]
partitions_list = [data_list[i:i + partition_size] for i in range(0, num_elems, partition_size)]
return partitions_list
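partition_data simply slices a list into equal contiguous chunks; a quick usage check (the helper is re-stated so the snippet runs on its own):

def partition_data(data_list, num_partitions):
    # same logic as above: equal-sized contiguous slices
    partition_size = len(data_list) // num_partitions
    return [data_list[i:i + partition_size] for i in range(0, len(data_list), partition_size)]

print(partition_data(list(range(6)), 3))   # [[0, 1], [2, 3], [4, 5]]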
@ -76,9 +72,7 @@ def merge_state_dict(dict_a, dict_b, key_list):
def merge_state_list(list_a, list_b, key_list):
if len(list_a) != len(list_b):
print(f'{_key_list_to_string(key_list)}')
raise ValueError(
f'Cannot merge lists of different lengths, a = {len(list_a)} b = {len(list_b)}'
)
raise ValueError(f'Cannot merge lists of different lengths, a = {len(list_a)} b = {len(list_b)}')
return [merge_state(a, b, key_list) for a, b in zip(list_a, list_b)]
@ -87,8 +81,7 @@ def merge_state(state_a, state_b, key_list=[]):
if type(state_a) != type(state_b):
key_list_string = _key_list_to_string(key_list)
print(f'key_list = {key_list_string}')
raise ValueError(
f'Cannot merge two states of types {type(state_a)} and type {type(state_b)}')
raise ValueError(f'Cannot merge two states of types {type(state_a)} and type {type(state_b)}')
if type(state_a) in (dict, OrderedDict):
return merge_state_dict(state_a, state_b, key_list)


@ -4,10 +4,7 @@ Copyright 2022 The Microsoft DeepSpeed Team
import os
import torch
import types
from .constants import (FP32_WEIGHT_KEY,
PARAM,
VOCAB_DIVISIBILITY_PADDING_TENSOR,
CAT_DIM)
from .constants import (FP32_WEIGHT_KEY, PARAM, VOCAB_DIVISIBILITY_PADDING_TENSOR, CAT_DIM)
def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
@ -44,9 +41,7 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
# the converter to universal currently strips the original padding completely so the saved
# weight is padding-free and we just need to add new padding depending on the target TP
# degree
vocab_divisibility_padding_tensor = ckpt_dict.get(
VOCAB_DIVISIBILITY_PADDING_TENSOR,
None)
vocab_divisibility_padding_tensor = ckpt_dict.get(VOCAB_DIVISIBILITY_PADDING_TENSOR, None)
if vocab_divisibility_padding_tensor is not None:
# In the absence of data passed from the user wrt new padded vocab specific to tp degree
# we can again derive that data by reverse engineering the target shapes like so:
@ -56,13 +51,7 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
padding_size = padded_target_vocab_size - full_hp_param.shape[0]
# Implement the following concat in efficient way using pad
#full_hp_param = torch.cat((full_hp_param, padding_tensor), 0)
full_hp_param = torch.nn.functional.pad(full_hp_param,
(0,
0,
0,
padding_size),
"constant",
0)
full_hp_param = torch.nn.functional.pad(full_hp_param, (0, 0, 0, padding_size), "constant", 0)
full_hp_param[:-padding_size, :] = vocab_divisibility_padding_tensor
else:
# Need to shrink or keep the same
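On the pad-instead-of-cat trick above: for a 2-D tensor, F.pad with (0, 0, 0, padding_size) appends padding_size zero rows, which is exactly what the commented-out torch.cat would produce without allocating a separate padding tensor. A quick equivalence check with made-up shapes:

import torch
import torch.nn.functional as F

x = torch.randn(5, 4)        # e.g. a vocab-sliced weight with 5 rows
padding_size = 3

padded_via_pad = F.pad(x, (0, 0, 0, padding_size), "constant", 0)
padded_via_cat = torch.cat((x, torch.zeros(padding_size, 4)), 0)
print(torch.equal(padded_via_pad, padded_via_cat))   # True
print(padded_via_pad.shape)                          # torch.Size([8, 4])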
@ -76,8 +65,7 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
assert full_param_numel == tp_world_size * tp_slice_numel, \
f'Loading {ckpt_file} full param numel {full_param_numel} != tensor slice numel {tp_slice_numel} * tp_world_size {tp_world_size}'
dst_tensor = hp_mapping.hp_fragment if key == FP32_WEIGHT_KEY else hp_mapping.get_optim_state_fragment(
key)
dst_tensor = hp_mapping.hp_fragment if key == FP32_WEIGHT_KEY else hp_mapping.get_optim_state_fragment(key)
# print(f"{full_hp_param.shape=} {full_param_numel=} {folder=}")
# print(f"{dst_tensor.shape=} {dst_tensor.numel()=}{folder=}")
@ -90,9 +78,7 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
tp_hp_slice = tp_hp_slice.flatten()
lp_frag_address = hp_mapping.lp_fragment_address
tp_hp_fragment = tp_hp_slice.narrow(0,
lp_frag_address.start,
lp_frag_address.numel)
tp_hp_fragment = tp_hp_slice.narrow(0, lp_frag_address.start, lp_frag_address.numel)
assert dst_tensor.numel() == lp_frag_address.numel, \
f'Load checkpoint {key} dst_tensor numel {dst_tensor.numel()} != src numel {lp_frag_address.numel}'
@ -104,5 +90,4 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
def enable_universal_checkpoint(param_list):
for param in param_list:
param.load_hp_checkpoint_state = types.MethodType(load_hp_checkpoint_state,
param)
param.load_hp_checkpoint_state = types.MethodType(load_hp_checkpoint_state, param)


@ -1,10 +1,7 @@
'''Copyright The Microsoft DeepSpeed Team'''
import os
from .constants import (MODEL_FILE_PREFIX,
MODEL_FILE_SUFFIX,
OPTIM_FILE_SUFFIX,
ZERO_FILE_PREFIX)
from .constants import (MODEL_FILE_PREFIX, MODEL_FILE_SUFFIX, OPTIM_FILE_SUFFIX, ZERO_FILE_PREFIX)
def get_model_ckpt_name_for_rank(base_folder, mp_rank_str):


@ -2,10 +2,7 @@
import torch
from .constants import (BASE_OPTIMIZER_STATE,
GROUP_PADDINGS,
OPTIMIZER_STATE_DICT,
PARTITION_COUNT)
from .constants import (BASE_OPTIMIZER_STATE, GROUP_PADDINGS, OPTIMIZER_STATE_DICT, PARTITION_COUNT)
from .reshape_utils import (basic_folder_validation, get_zero_files, merge_state)
@ -15,6 +12,7 @@ GROUP_STATE_KEY = 'state'
class ZeROCheckpoint(object):
def __init__(self, dir):
basic_folder_validation(dir)
self.dir = dir
@ -49,12 +47,7 @@ class ZeROCheckpoint(object):
file_idx_list = self.get_file_indices_for_rank(pp_index, tp_index, dp_index)
return [self.file_list[idx] for idx in file_idx_list]
def get_state_for_rank(self,
pp_index,
tp_index,
dp_index,
keys_to_ignore=[],
strip_tensor_paddings=True):
def get_state_for_rank(self, pp_index, tp_index, dp_index, keys_to_ignore=[], strip_tensor_paddings=True):
state_file_list = self.get_files_for_rank(pp_index, tp_index, dp_index)
merged_sd = None
for state_file in state_file_list:
@ -111,10 +104,7 @@ class ZeROCheckpoint(object):
for state_name, state_value in group_state.items():
if torch.is_tensor(state_value):
raw_length = state_value.numel() - group_paddings[key]
group_state[state_name] = torch.narrow(state_value,
0,
0,
raw_length).clone()
group_state[state_name] = torch.narrow(state_value, 0, 0, raw_length).clone()
def _clear_group_paddings(self, sd):
group_paddings = self._get_optimizer_state(sd, GROUP_PADDINGS)
@ -144,5 +134,4 @@ class ZeROCheckpoint(object):
partition_counts = self._get_optimizer_state(sd, PARTITION_COUNT)
if partition_counts:
num_groups = len(partition_counts)
sd[OPTIMIZER_STATE_DICT][PARTITION_COUNT] = [self.target_3d.dp_degree
] * num_groups
sd[OPTIMIZER_STATE_DICT][PARTITION_COUNT] = [self.target_3d.dp_degree] * num_groups


@ -39,13 +39,8 @@ if older_torch():
input_tensor_lst = list(chunk(input_tensor, get_world_size(group)))
return reduce_scatter(output_tensor, input_tensor_lst, group=group)
def configure(deepspeed_config=None,
enabled=None,
prof_all=None,
prof_ops=None,
verbose=None):
utils.logger.warn(
"Communication logging is not supported in torch versions older than 1.8")
def configure(deepspeed_config=None, enabled=None, prof_all=None, prof_ops=None, verbose=None):
utils.logger.warn("Communication logging is not supported in torch versions older than 1.8")
else:
supported_torch_version = True


@ -20,6 +20,7 @@ In the future, directly use NCCL/MPI/Gloo/etc without requiring torch.distribute
class Backend(object):
def __init__(self, name='backend', rank=0, size=1):
self.name = name
# The world size and rank of the world process group


@ -106,12 +106,13 @@ def configure(
# Logging wrapper for timing ops
def timed_op(func):
def log_wrapper(*args, **kwargs):
# Add enabled flag so that overhead to each comm op is two if conditions at most
if comms_logger.enabled:
if ('prof' in kwargs and kwargs['prof']) or comms_logger.prof_all or (
'log_name' in kwargs
and kwargs['log_name'] in comms_logger.prof_ops):
if ('prof' in kwargs
and kwargs['prof']) or comms_logger.prof_all or ('log_name' in kwargs
and kwargs['log_name'] in comms_logger.prof_ops):
# Need func args for their defaults
func_args = get_default_args(func)
func_args.update(kwargs)
@ -129,8 +130,7 @@ def timed_op(func):
if cdb.using_mpi:
cdb.barrier()
if ('prof' in kwargs and kwargs['prof']) or comms_logger.prof_all or (
'log_name' in kwargs
and kwargs['log_name'] in comms_logger.prof_ops):
'log_name' in kwargs and kwargs['log_name'] in comms_logger.prof_ops):
log_name = get_debug_log_name(func_args, comms_logger.debug)
raw_name = func.__name__
timers(log_name).stop()
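timed_op above wraps every collective so that, when logging is enabled, a named timer starts before the call and stops after it (with an extra barrier under MPI). A stripped-down sketch of the same decorator pattern, using time.perf_counter in place of DeepSpeed's timers and omitting the prof/log_name plumbing:

import functools
import time

def timed_op(func):
    @functools.wraps(func)
    def log_wrapper(*args, **kwargs):
        start = time.perf_counter()
        try:
            return func(*args, **kwargs)
        finally:
            elapsed_ms = (time.perf_counter() - start) * 1000
            print(f"{func.__name__} took {elapsed_ms:.3f} ms")
    return log_wrapper

@timed_op
def fake_all_reduce(tensor):
    return tensor   # stand-in for a real collective

fake_all_reduce([1, 2, 3])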
@ -178,7 +178,8 @@ def destroy_process_group(group=None):
def new_group(ranks):
global cdb
assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
assert cdb is not None and cdb.is_initialized(
), 'DeepSpeed backend not set, please initialize it using init_process_group()'
return cdb.new_group(ranks)
@ -197,9 +198,7 @@ def set_backend(backend):
utils.logger.error(
"DeepSpeed communication backend is required. Please use deepspeed.comm.init_distributed(backend, use_deepspeed=True) to use this functionality"
)
raise RuntimeError(
'Error: Custom DeepSpeed backend called without initializing DeepSpeed distributed.'
)
raise RuntimeError('Error: Custom DeepSpeed backend called without initializing DeepSpeed distributed.')
global cdb
global nccl_backend
@ -217,13 +216,7 @@ def set_backend(backend):
@timed_op
def broadcast(tensor,
src,
group=None,
async_op=False,
prof=False,
log_name='broadcast',
debug=get_caller_func()):
def broadcast(tensor, src, group=None, async_op=False, prof=False, log_name='broadcast', debug=get_caller_func()):
global cdb
return cdb.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)
@ -237,15 +230,13 @@ def all_gather(tensor_list,
log_name='all_gather',
debug=get_caller_func()):
global cdb
return cdb.all_gather(tensor_list=tensor_list,
tensor=tensor,
group=group,
async_op=async_op)
return cdb.all_gather(tensor_list=tensor_list, tensor=tensor, group=group, async_op=async_op)
def has_reduce_scatter_base():
global cdb
assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
assert cdb is not None and cdb.is_initialized(
), 'DeepSpeed backend not set, please initialize it using init_process_group()'
assert cdb.has_reduce_scatter_base is not None, 'has_reduce_scatter_base is not yet defined'
return cdb.has_reduce_scatter_base
@ -258,7 +249,8 @@ def reduce_scatter_fn(output_tensor,
prof=False,
debug=get_caller_func()):
global cdb
assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
assert cdb is not None and cdb.is_initialized(
), 'DeepSpeed backend not set, please initialize it using init_process_group()'
if cdb.has_reduce_scatter_base:
return reduce_scatter_base(output_tensor,
tensor,
@ -268,10 +260,9 @@ def reduce_scatter_fn(output_tensor,
prof=prof,
debug=debug)
else:
utils.logger.warning_once(
"unable to find torch.distributed._reduce_scatter_base. will fall back to "
"torch.distributed.all_gather which will result in suboptimal performance. "
"please consider upgrading your pytorch installation.")
utils.logger.warning_once("unable to find torch.distributed._reduce_scatter_base. will fall back to "
"torch.distributed.all_gather which will result in suboptimal performance. "
"please consider upgrading your pytorch installation.")
input_tensor_lst = list(torch.chunk(tensor, cdb.get_world_size(group)))
return reduce_scatter(output_tensor,
input_tensor_lst,
@ -308,44 +299,30 @@ def all_gather_base(output_tensor,
log_name='all_gather_base',
debug=get_caller_func()):
global cdb
return cdb.all_gather_base(output_tensor=output_tensor,
input_tensor=tensor,
group=group,
async_op=async_op)
return cdb.all_gather_base(output_tensor=output_tensor, input_tensor=tensor, group=group, async_op=async_op)
def has_allgather_base():
global cdb
assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
assert cdb is not None and cdb.is_initialized(
), 'DeepSpeed backend not set, please initialize it using init_process_group()'
assert cdb.has_allgather_base is not None, 'has_allgather_base is not yet defined'
return cdb.has_allgather_base
def allgather_fn(output_tensor,
input_tensor,
group=None,
async_op=False,
debug=get_caller_func()):
def allgather_fn(output_tensor, input_tensor, group=None, async_op=False, debug=get_caller_func()):
global cdb
assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
assert cdb is not None and cdb.is_initialized(
), 'DeepSpeed backend not set, please initialize it using init_process_group()'
if cdb.has_allgather_base:
return all_gather_base(output_tensor,
input_tensor,
group=group,
async_op=async_op,
debug=debug)
return all_gather_base(output_tensor, input_tensor, group=group, async_op=async_op, debug=debug)
else:
if get_rank() == 0:
utils.logger.warning_once(
"unable to find torch.distributed._all_gather_base. will fall back to "
"torch.distributed.all_gather which will result in suboptimal performance. "
"please consider upgrading your pytorch installation.")
utils.logger.warning_once("unable to find torch.distributed._all_gather_base. will fall back to "
"torch.distributed.all_gather which will result in suboptimal performance. "
"please consider upgrading your pytorch installation.")
output_tensors = list(torch.chunk(output_tensor, cdb.get_world_size(group)))
return all_gather(output_tensors,
input_tensor,
group=group,
async_op=async_op,
debug=debug)
return all_gather(output_tensors, input_tensor, group=group, async_op=async_op, debug=debug)
@timed_op
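The all_gather fallback above works because torch.chunk returns views: all_gather writes into each chunk in place, and since the chunks share storage with output_tensor the flat buffer ends up fully populated. A single-process demonstration of that view behaviour (no process group required):

import torch

world_size = 4
output = torch.zeros(8)                    # flat output buffer, 2 elements per rank
chunks = list(torch.chunk(output, world_size))

# Writing into a chunk, as all_gather would, also updates the flat buffer,
# because torch.chunk returns views rather than copies.
for rank, chunk in enumerate(chunks):
    chunk.copy_(torch.full((2,), float(rank)))

print(output)   # tensor([0., 0., 1., 1., 2., 2., 3., 3.])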
@ -368,49 +345,25 @@ def all_to_all_single(output,
@timed_op
def send(tensor,
dst,
group=None,
tag=0,
prof=False,
log_name='send',
debug=get_caller_func()):
def send(tensor, dst, group=None, tag=0, prof=False, log_name='send', debug=get_caller_func()):
global cdb
return cdb.send(tensor=tensor, dst=dst, group=group, tag=tag)
@timed_op
def recv(tensor,
src=None,
group=None,
tag=0,
prof=False,
log_name='recv',
debug=get_caller_func()):
def recv(tensor, src=None, group=None, tag=0, prof=False, log_name='recv', debug=get_caller_func()):
global cdb
return cdb.recv(tensor=tensor, src=src, group=group, tag=tag)
@timed_op
def isend(tensor,
dst,
group=None,
tag=0,
prof=False,
log_name='isend',
debug=get_caller_func()):
def isend(tensor, dst, group=None, tag=0, prof=False, log_name='isend', debug=get_caller_func()):
global cdb
return cdb.send(tensor=tensor, dst=dst, group=group, tag=tag)
@timed_op
def irecv(tensor,
src=None,
group=None,
tag=0,
prof=False,
log_name='irecv',
debug=get_caller_func()):
def irecv(tensor, src=None, group=None, tag=0, prof=False, log_name='irecv', debug=get_caller_func()):
global cdb
return cdb.recv(tensor=tensor, src=src, group=group, tag=tag)
@ -425,11 +378,7 @@ def gather(tensor,
log_name='gather',
debug=get_caller_func()):
global cdb
return cdb.gather(tensor=tensor,
gather_list=gather_list,
dst=dst,
group=group,
async_op=async_op)
return cdb.gather(tensor=tensor, gather_list=gather_list, dst=dst, group=group, async_op=async_op)
@timed_op
@ -442,20 +391,11 @@ def scatter(tensor,
log_name='scatter',
debug=get_caller_func()):
global cdb
return cdb.scatter(tensor=tensor,
scatter_list=scatter_list,
src=src,
group=group,
async_op=async_op)
return cdb.scatter(tensor=tensor, scatter_list=scatter_list, src=src, group=group, async_op=async_op)
@timed_op
def barrier(group=None,
async_op=False,
device_ids=None,
prof=False,
log_name='barrier',
debug=get_caller_func()):
def barrier(group=None, async_op=False, device_ids=None, prof=False, log_name='barrier', debug=get_caller_func()):
global cdb
return cdb.barrier(group=group, async_op=async_op, device_ids=device_ids)
@ -502,11 +442,7 @@ def reduce_scatter(output,
log_name='reduce_scatter',
debug=get_caller_func()):
global cdb
return cdb.reduce_scatter(output=output,
input_list=input_list,
op=op,
group=group,
async_op=async_op)
return cdb.reduce_scatter(output=output, input_list=input_list, op=op, group=group, async_op=async_op)
@timed_op
@ -528,7 +464,8 @@ def all_reduce(tensor,
def get_world_group():
global cdb
assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
assert cdb is not None and cdb.is_initialized(
), 'DeepSpeed backend not set, please initialize it using init_process_group()'
return cdb.get_world_group()
@ -544,7 +481,8 @@ def get_world_size(group=None) -> int:
"""
global cdb
assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
assert cdb is not None and cdb.is_initialized(
), 'DeepSpeed backend not set, please initialize it using init_process_group()'
return cdb.get_world_size(group)
@ -563,7 +501,8 @@ def get_rank(group=None):
-1, if not part of the group
"""
global cdb
assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
assert cdb is not None and cdb.is_initialized(
), 'DeepSpeed backend not set, please initialize it using init_process_group()'
return cdb.get_rank(group)
@ -576,13 +515,15 @@ def get_local_rank():
local rank (= GPU device ID)
"""
global cdb
assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
assert cdb is not None and cdb.is_initialized(
), 'DeepSpeed backend not set, please initialize it using init_process_group()'
return get_local_rank_from_launcher()
def get_global_rank(group=None, group_rank=0):
global cdb
assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
assert cdb is not None and cdb.is_initialized(
), 'DeepSpeed backend not set, please initialize it using init_process_group()'
return cdb.get_global_rank(group, group_rank)
@ -631,9 +572,7 @@ def init_distributed(dist_backend=None,
required_env = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
if auto_mpi_discovery and not all(map(lambda v: v in os.environ, required_env)):
if verbose:
utils.logger.info(
"Not using the DeepSpeed or dist launchers, attempting to detect MPI environment..."
)
utils.logger.info("Not using the DeepSpeed or dist launchers, attempting to detect MPI environment...")
if in_aml() and not in_dlts():
patch_aml_env_for_torch_nccl_backend(verbose=verbose)
elif in_aws_sm():
@ -649,9 +588,7 @@ def init_distributed(dist_backend=None,
if dist_backend == None:
dist_backend = get_accelerator().communication_backend_name()
if int(os.getenv('RANK', '0')) == 0:
utils.logger.info(
'Initializing TorchBackend in DeepSpeed with backend {}'.format(
dist_backend))
utils.logger.info('Initializing TorchBackend in DeepSpeed with backend {}'.format(dist_backend))
# Create a torch backend object, initialize torch distributed, and assign to cdb
cdb = TorchBackend(dist_backend, timeout, init_method, rank, world_size)
@ -686,16 +623,12 @@ def mpi_discovery(distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT, verbose=True)
if verbose:
utils.logger.info(
"Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
.format(os.environ['RANK'],
os.environ['LOCAL_RANK'],
os.environ['WORLD_SIZE'],
os.environ['MASTER_ADDR'],
os.environ['MASTER_PORT']))
"Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}".
format(os.environ['RANK'], os.environ['LOCAL_RANK'], os.environ['WORLD_SIZE'], os.environ['MASTER_ADDR'],
os.environ['MASTER_PORT']))
if cdb is not None and cdb.is_initialized():
assert cdb.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(
rank, cdb.get_rank())
assert cdb.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(rank, cdb.get_rank())
assert cdb.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format(
world_size, cdb.get_world_size())
@ -722,8 +655,7 @@ def patch_aml_env_for_torch_nccl_backend(master_port=6105, verbose=True):
"""
os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
single_node = int(os.environ["OMPI_COMM_WORLD_LOCAL_SIZE"]) == int(
os.environ["WORLD_SIZE"])
single_node = int(os.environ["OMPI_COMM_WORLD_LOCAL_SIZE"]) == int(os.environ["WORLD_SIZE"])
if not single_node:
master_node_params = os.environ["AZ_BATCH_MASTER_NODE"].split(":")
@ -736,8 +668,7 @@ def patch_aml_env_for_torch_nccl_backend(master_port=6105, verbose=True):
os.environ["MASTER_PORT"] = DEFAULT_AML_MASTER_PORT
if verbose:
utils.logger.info("NCCL_SOCKET_IFNAME original value = {}".format(
os.environ["NCCL_SOCKET_IFNAME"]))
utils.logger.info("NCCL_SOCKET_IFNAME original value = {}".format(os.environ["NCCL_SOCKET_IFNAME"]))
os.environ["NCCL_SOCKET_IFNAME"] = DEFAULT_AML_NCCL_SOCKET_IFNAME
os.environ['LOCAL_RANK'] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
@ -745,10 +676,7 @@ def patch_aml_env_for_torch_nccl_backend(master_port=6105, verbose=True):
if verbose:
utils.logger.info(
"Discovered AzureML settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
.format(os.environ['RANK'],
os.environ['LOCAL_RANK'],
os.environ['WORLD_SIZE'],
os.environ['MASTER_ADDR'],
.format(os.environ['RANK'], os.environ['LOCAL_RANK'], os.environ['WORLD_SIZE'], os.environ['MASTER_ADDR'],
os.environ['MASTER_PORT']))
@ -762,8 +690,5 @@ def patch_aws_sm_env_for_torch_nccl_backend(verbose=True):
if verbose:
utils.logger.info(
"Discovered AWS SageMaker settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
.format(os.environ['RANK'],
os.environ['LOCAL_RANK'],
os.environ['WORLD_SIZE'],
os.environ['MASTER_ADDR'],
.format(os.environ['RANK'], os.environ['LOCAL_RANK'], os.environ['WORLD_SIZE'], os.environ['MASTER_ADDR'],
os.environ['MASTER_PORT']))

View File

@ -9,6 +9,7 @@ from .constants import *
class CommsConfig(BaseModel):
class Config:
validate_all = True
validate_assignment = True
@ -25,6 +26,7 @@ class CommsLoggerConfig(CommsConfig):
class DeepSpeedCommsConfig:
def __init__(self, ds_config):
self.comms_logger_enabled = 'comms_logger' in ds_config

View File

@ -16,13 +16,8 @@ class TorchBackend(Backend):
so no need to wrap all the functions. We can keep adding wrappers as
needed.
"""
def __init__(self,
backend,
timeout,
init_method,
rank=-1,
world_size=-1,
name='torch'):
def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name='torch'):
super(TorchBackend, self).__init__()
self.torch_version_before_18 = older_torch()
self.has_allgather_base = has_allgather_base()
@ -44,30 +39,14 @@ class TorchBackend(Backend):
world_size=world_size)
self.using_mpi = torch.distributed.get_backend() == 'mpi'
def all_reduce(self,
tensor,
op=torch.distributed.ReduceOp.SUM,
group=None,
async_op=False):
def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
op = self._reduce_op(op)
return torch.distributed.all_reduce(tensor=tensor,
op=op,
group=group,
async_op=async_op)
return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)
def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
return torch.distributed.reduce(tensor=tensor,
dst=dst,
op=self._reduce_op(op),
group=group,
async_op=async_op)
return torch.distributed.reduce(tensor=tensor, dst=dst, op=self._reduce_op(op), group=group, async_op=async_op)
def reduce_scatter(self,
output,
input_list,
op=ReduceOp.SUM,
group=None,
async_op=False):
def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
return torch.distributed.reduce_scatter(output=output,
input_list=input_list,
op=self._reduce_op(op),
@ -75,37 +54,24 @@ class TorchBackend(Backend):
async_op=async_op)
def broadcast(self, tensor, src, group=None, async_op=False):
return torch.distributed.broadcast(tensor=tensor,
src=src,
group=group,
async_op=async_op)
return torch.distributed.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)
def all_gather(self, tensor_list, tensor, group=None, async_op=False):
return torch.distributed.all_gather(tensor_list=tensor_list,
tensor=tensor,
group=group,
async_op=async_op)
return torch.distributed.all_gather(tensor_list=tensor_list, tensor=tensor, group=group, async_op=async_op)
def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=False):
if self.has_allgather_base:
return torch.distributed.distributed_c10d._all_gather_base(
output_tensor=output_tensor,
input_tensor=input_tensor,
group=group,
async_op=async_op)
return torch.distributed.distributed_c10d._all_gather_base(output_tensor=output_tensor,
input_tensor=input_tensor,
group=group,
async_op=async_op)
else:
utils.logger.warning(
"unable to find torch.distributed._all_gather_base. will fall back to "
"torch.distributed.reduce_scatter which will result in suboptimal performance. "
"please consider upgrading your pytorch installation.")
utils.logger.warning("unable to find torch.distributed._all_gather_base. will fall back to "
"torch.distributed.reduce_scatter which will result in suboptimal performance. "
"please consider upgrading your pytorch installation.")
pass
def reduce_scatter_base(self,
output_tensor,
input_tensor,
op=ReduceOp.SUM,
group=None,
async_op=False):
def reduce_scatter_base(self, output_tensor, input_tensor, op=ReduceOp.SUM, group=None, async_op=False):
if self.has_reduce_scatter_base:
return torch.distributed._reduce_scatter_base(output_tensor,
input_tensor,
@ -113,10 +79,9 @@ class TorchBackend(Backend):
group=group,
async_op=async_op)
else:
utils.logger.warning(
"unable to find torch.distributed._reduce_scatter_base. will fall back to "
"torch.distributed.reduce_scatter which will result in suboptimal performance. "
"please consider upgrading your pytorch installation.")
utils.logger.warning("unable to find torch.distributed._reduce_scatter_base. will fall back to "
"torch.distributed.reduce_scatter which will result in suboptimal performance. "
"please consider upgrading your pytorch installation.")
pass
def all_to_all_single(self,
@ -159,25 +124,15 @@ class TorchBackend(Backend):
group=group,
async_op=async_op)
def barrier(self,
group=torch.distributed.GroupMember.WORLD,
async_op=False,
device_ids=None):
def barrier(self, group=torch.distributed.GroupMember.WORLD, async_op=False, device_ids=None):
if group is None:
group = torch.distributed.GroupMember.WORLD
return torch.distributed.barrier(group=group,
async_op=async_op,
device_ids=device_ids)
return torch.distributed.barrier(group=group, async_op=async_op, device_ids=device_ids)
def monitored_barrier(self,
group=torch.distributed.GroupMember.WORLD,
timeout=None,
wait_all_ranks=False):
def monitored_barrier(self, group=torch.distributed.GroupMember.WORLD, timeout=None, wait_all_ranks=False):
if group is None:
group = torch.distributed.GroupMember.WORLD
return torch.distributed.monitored_barrier(group=group,
timeout=timeout,
wait_all_ranks=wait_all_ranks)
return torch.distributed.monitored_barrier(group=group, timeout=timeout, wait_all_ranks=wait_all_ranks)
def get_rank(self, group=None):
return torch.distributed.get_rank(group=group)
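The docstring above frames TorchBackend as a light-weight pass-through over torch.distributed, with wrappers added only as needed. A minimal sketch of that pattern, using a made-up ExampleBackend class and only standard torch.distributed calls (everything except torch.distributed itself is illustrative):

import torch.distributed as dist


class ExampleBackend:
    """Minimal pass-through wrapper: each method simply forwards to torch.distributed."""

    def __init__(self, backend='gloo', init_method=None, rank=-1, world_size=-1):
        if not dist.is_initialized():
            dist.init_process_group(backend=backend, init_method=init_method, rank=rank, world_size=world_size)

    def all_reduce(self, tensor, op=dist.ReduceOp.SUM, group=None, async_op=False):
        # forward the call unchanged; new wrappers can be added the same way when needed
        return dist.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)

    def get_rank(self, group=None):
        return dist.get_rank(group=group)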

View File

@ -84,11 +84,7 @@ def get_world_size_from_launcher():
def get_default_args(func):
signature = inspect.signature(func)
return {
k: v.default
for k,
v in signature.parameters.items() if v.default is not inspect.Parameter.empty
}
return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}
# We need this hacky function since torch doesn't consistently name or place the input tensor args
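The helper above collects a function's default keyword values via inspect.signature. A self-contained illustration of the same pattern, with a made-up function f purely for demonstration:

import inspect


def get_default_args(func):
    signature = inspect.signature(func)
    return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}


def f(tensor, group=None, async_op=False):  # hypothetical signature, for illustration only
    pass


print(get_default_args(f))  # {'group': None, 'async_op': False}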

View File

@ -21,6 +21,7 @@ class QuantAct(nn.Module):
Momentum for updating the activation quantization range.
quant_mode : str, default 'symmetric'
"""
def __init__(self, act_range_momentum=0.95, quant_mode='symmetric'):
super(QuantAct, self).__init__()
@ -50,10 +51,8 @@ class QuantAct(nn.Module):
self.x_min_max[1] = x_max
# if momentum is not needed, set self.act_range_momentum = 0
self.x_min_max[0] = self.x_min_max[0] * self.act_range_momentum + x_min * (
1 - self.act_range_momentum)
self.x_min_max[1] = self.x_min_max[1] * self.act_range_momentum + x_max * (
1 - self.act_range_momentum)
self.x_min_max[0] = self.x_min_max[0] * self.act_range_momentum + x_min * (1 - self.act_range_momentum)
self.x_min_max[1] = self.x_min_max[1] * self.act_range_momentum + x_max * (1 - self.act_range_momentum)
x_q = self.act_function(x, num_bits, self.x_min_max[0], self.x_min_max[1])
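The momentum update above maintains a running min/max range for activation quantization. A rough numerical check of the same exponential-moving-average formula, with arbitrarily chosen values:

# Standalone illustration of the running-range update used above (values are arbitrary).
act_range_momentum = 0.95
running_min, running_max = -1.0, 1.0   # previously observed range
x_min, x_max = -2.0, 0.5               # range of the current batch

running_min = running_min * act_range_momentum + x_min * (1 - act_range_momentum)
running_max = running_max * act_range_momentum + x_max * (1 - act_range_momentum)
print(running_min, running_max)  # -1.05 0.975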
@ -61,6 +60,7 @@ class QuantAct(nn.Module):
class Embedding_Compress(nn.Embedding):
def __init__(self, *kargs):
super(Embedding_Compress, self).__init__(*kargs)
self.weight.start_bits = None
@ -71,17 +71,10 @@ class Embedding_Compress(nn.Embedding):
def extra_repr(self):
return 'num_embeddings={}, embedding_dim={}, weight_quantization={}'.format(
self.num_embeddings,
self.embedding_dim,
self.weight.target_bits)
self.num_embeddings, self.embedding_dim, self.weight.target_bits)
def enable_weight_quantization(self,
start_bits,
target_bits,
quantization_period,
weight_quantization_enabled_in_forward,
quantization_type,
num_groups):
def enable_weight_quantization(self, start_bits, target_bits, quantization_period,
weight_quantization_enabled_in_forward, quantization_type, num_groups):
self.weight.start_bits = start_bits
self.weight.target_bits = target_bits
self.weight.q_period = quantization_period
@ -105,31 +98,20 @@ class Embedding_Compress(nn.Embedding):
self.weight_quantize_num_groups = self.weight.size(0)
def fix_weight_quantization(self):
self.weight.data = self.weight_quantizer(self.weight,
self.weight.target_bits,
None,
None,
self.weight.data = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
self.weight_quantize_num_groups).data
self.weight_quantization_enabled_in_forward = False
return None
def forward(self, input):
if self.weight_quantization_enabled_in_forward and self.weight_quantization_enabled:
weight = self.weight_quantizer(self.weight,
self.weight.target_bits,
None,
None,
weight = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
self.weight_quantize_num_groups)
else:
weight = self.weight
out = nn.functional.embedding(input,
weight,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse)
out = nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type,
self.scale_grad_by_freq, self.sparse)
return out
@ -137,6 +119,7 @@ class LinearLayer_Compress(nn.Linear):
"""
Linear layer with compression.
"""
def __init__(self, *kargs, bias=True):
super(LinearLayer_Compress, self).__init__(*kargs, bias=bias)
self.sparse_pruning_method = None
@ -169,8 +152,7 @@ class LinearLayer_Compress(nn.Linear):
mask = mask.to(self.weight.device)
elif method == 'topk':
self.sparse_mask_scores = nn.Parameter(torch.Tensor(self.weight.size()))
self.sparse_mask_scores.data = self.sparse_mask_scores.data.to(
self.weight.device)
self.sparse_mask_scores.data = self.sparse_mask_scores.data.to(self.weight.device)
init.kaiming_uniform_(self.sparse_mask_scores, a=math.sqrt(5))
mask = None
else:
@ -209,11 +191,9 @@ class LinearLayer_Compress(nn.Linear):
raise NotImplementedError
else:
self.head_pruning_ratio = ratio
self.head_pruning_scores = nn.Parameter(torch.Tensor(
1,
self.num_heads)) # we apply the pruning to O matrix
self.head_pruning_scores.data = self.head_pruning_scores.data.to(
self.weight.device)
self.head_pruning_scores = nn.Parameter(torch.Tensor(1,
self.num_heads)) # we apply the pruning to O matrix
self.head_pruning_scores.data = self.head_pruning_scores.data.to(self.weight.device)
init.kaiming_uniform_(self.head_pruning_scores, a=math.sqrt(5))
def fix_sparse_pruning_helper(self):
@ -279,18 +259,17 @@ class LinearLayer_Compress(nn.Linear):
start_bits = self.weight.start_bits
target_bits = self.weight.target_bits
q_period = self.weight.q_period
self.weight = nn.Parameter(self.weight.data.t().reshape(num_heads, -1)[mask.view(-1), :].reshape(-1, shape).t())
self.weight = nn.Parameter(self.weight.data.t().reshape(num_heads,
-1)[mask.view(-1), :].reshape(-1,
shape).t())
self.weight.start_bits = start_bits
self.weight.target_bits = target_bits
self.weight.q_period = q_period
else:
shape = self.weight.size()
self.weight.data = (self.weight.data.t().reshape(self.num_heads,
-1) *
mask.view(-1,
1)).reshape(shape[1],
shape[0]).t()
self.weight.data = (self.weight.data.t().reshape(self.num_heads, -1) * mask.view(-1, 1)).reshape(
shape[1], shape[0]).t()
if self.head_pruning_method == 'topk':
del self.head_pruning_scores
@ -316,37 +295,26 @@ class LinearLayer_Compress(nn.Linear):
if self.sparse_pruning_method == 'l1':
return self.sparse_pruning_mask.to(self.weight.device)
elif self.sparse_pruning_method == 'topk':
return TopKBinarizer.apply(self.sparse_mask_scores,
self.sparse_pruning_ratio,
False)
return TopKBinarizer.apply(self.sparse_mask_scores, self.sparse_pruning_ratio, False)
else:
raise NotImplementedError
if pruning_type == 'row':
if self.row_pruning_method == 'l1':
return self.row_pruning_mask.to(self.weight.device)
elif self.row_pruning_method == 'topk':
return TopKBinarizer.apply(self.row_mask_scores,
self.row_pruning_ratio,
False)
return TopKBinarizer.apply(self.row_mask_scores, self.row_pruning_ratio, False)
else:
raise NotImplementedError
elif pruning_type == 'head':
if self.head_pruning_method == 'topk':
return TopKBinarizer.apply(self.head_pruning_scores,
self.head_pruning_ratio,
False)
return TopKBinarizer.apply(self.head_pruning_scores, self.head_pruning_ratio, False)
else:
raise NotImplementedError
else:
raise NotImplementedError
def enable_weight_quantization(self,
start_bits,
target_bits,
quantization_period,
weight_quantization_enabled_in_forward,
quantization_type,
num_groups):
def enable_weight_quantization(self, start_bits, target_bits, quantization_period,
weight_quantization_enabled_in_forward, quantization_type, num_groups):
self.weight.start_bits = start_bits
self.weight.target_bits = target_bits
self.weight.q_period = quantization_period
@ -369,10 +337,7 @@ class LinearLayer_Compress(nn.Linear):
self.weight_quantize_num_groups = num_groups
def fix_weight_quantization(self):
self.weight.data = self.weight_quantizer(self.weight,
self.weight.target_bits,
None,
None,
self.weight.data = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
self.weight_quantize_num_groups).data
self.weight_quantization_enabled_in_forward = False
return None
@ -391,18 +356,12 @@ class LinearLayer_Compress(nn.Linear):
def head_pruning_reshape(self, w, mask):
shape = w.shape
return (w.t().reshape(self.num_heads,
-1) * mask.view(-1,
1)).reshape(shape[1],
shape[0]).t()
return (w.t().reshape(self.num_heads, -1) * mask.view(-1, 1)).reshape(shape[1], shape[0]).t()
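The reshape above groups the weight by attention head so whole head blocks can be zeroed by the mask. A small numerical illustration of that shape arithmetic with arbitrary sizes (torch is the only dependency):

import torch

num_heads = 4
out_features, in_features = 6, 8           # in_features must be divisible by num_heads
w = torch.ones(out_features, in_features)  # stands in for the O-matrix weight
mask = torch.tensor([1., 0., 1., 1.])      # keep heads 0, 2, 3; drop head 1

shape = w.shape
pruned = (w.t().reshape(num_heads, -1) * mask.view(-1, 1)).reshape(shape[1], shape[0]).t()
print(pruned.shape)    # torch.Size([6, 8]) -- same shape, with head 1's columns zeroed
print(pruned[:, 2:4])  # the block belonging to head 1 is now all zeros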
def forward(self, input, skip_bias_add=False):
if self.weight_quantization_enabled_in_forward and self.weight_quantization_enabled:
weight = self.weight_quantizer(self.weight,
self.weight.target_bits,
None,
None,
weight = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
self.weight_quantize_num_groups)
bias = self.bias
else:
@ -428,11 +387,7 @@ class LinearLayer_Compress(nn.Linear):
num_groups = input.numel() // input.size(-1)
else:
num_groups = 1
input = self.activation_quantizer(input,
self.activation_quantization_bits,
None,
None,
num_groups)
input = self.activation_quantizer(input, self.activation_quantization_bits, None, None, num_groups)
if skip_bias_add:
# used for mpu linear layers
@ -447,6 +402,7 @@ class Conv2dLayer_Compress(nn.Conv2d):
"""
Conv2D layer with compression.
"""
def __init__(self, *kargs):
super(Conv2dLayer_Compress, self).__init__(*kargs)
self.sparse_pruning_method = None
@ -478,10 +434,8 @@ class Conv2dLayer_Compress(nn.Conv2d):
output = s.format(**self.__dict__)
return output + ' sparse pruning={}, channel pruning={}, activation quantization={}, weight_quantization={}'.format(
self.sparse_pruning_method is not None,
self.channel_pruning_method is not None,
self.activation_quantization_method is not None,
self.weight.target_bits)
self.sparse_pruning_method is not None, self.channel_pruning_method is not None,
self.activation_quantization_method is not None, self.weight.target_bits)
def enable_sparse_pruning(self, ratio, method):
self.sparse_pruning_ratio = ratio
@ -493,8 +447,7 @@ class Conv2dLayer_Compress(nn.Conv2d):
mask = mask.to(self.weight.device)
elif method == 'topk':
self.sparse_mask_scores = nn.Parameter(torch.Tensor(self.weight.size()))
self.sparse_mask_scores.data = self.sparse_mask_scores.data.to(
self.weight.device)
self.sparse_mask_scores.data = self.sparse_mask_scores.data.to(self.weight.device)
init.kaiming_uniform_(self.sparse_mask_scores, a=math.sqrt(5))
mask = None
else:
@ -514,13 +467,8 @@ class Conv2dLayer_Compress(nn.Conv2d):
mask = mask.view(-1, 1, 1, 1)
mask = mask.to(self.weight.device)
elif method == 'topk':
self.channel_mask_scores = nn.Parameter(
torch.Tensor(self.weight.size(0),
1,
1,
1))
self.channel_mask_scores.data = self.channel_mask_scores.data.to(
self.weight.device)
self.channel_mask_scores = nn.Parameter(torch.Tensor(self.weight.size(0), 1, 1, 1))
self.channel_mask_scores.data = self.channel_mask_scores.data.to(self.weight.device)
init.kaiming_uniform_(self.channel_mask_scores, a=math.sqrt(5))
mask = None
else:
@ -579,39 +527,27 @@ class Conv2dLayer_Compress(nn.Conv2d):
if self.sparse_pruning_method == 'l1':
return self.sparse_pruning_mask.to(self.weight.device)
elif self.sparse_pruning_method == 'topk':
return TopKBinarizer.apply(self.sparse_mask_scores,
self.sparse_pruning_ratio,
False)
return TopKBinarizer.apply(self.sparse_mask_scores, self.sparse_pruning_ratio, False)
else:
raise NotImplementedError
elif pruning_type == 'channel':
if self.channel_pruning_method == 'l1':
return self.channel_pruning_mask.to(self.weight.device)
elif self.channel_pruning_method == 'topk':
return TopKBinarizer.apply(self.channel_mask_scores,
self.channel_pruning_ratio,
False)
return TopKBinarizer.apply(self.channel_mask_scores, self.channel_pruning_ratio, False)
else:
raise NotImplementedError
else:
raise NotImplementedError
def fix_weight_quantization(self):
self.weight.data = self.weight_quantizer(self.weight,
self.weight.target_bits,
None,
None,
self.weight.data = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
self.weight_quantize_num_groups).data
self.weight_quantization_enabled_in_forward = False
return None
def enable_weight_quantization(self,
start_bits,
target_bits,
quantization_period,
weight_quantization_enabled_in_forward,
quantization_type,
num_groups):
def enable_weight_quantization(self, start_bits, target_bits, quantization_period,
weight_quantization_enabled_in_forward, quantization_type, num_groups):
self.weight.start_bits = start_bits
self.weight.target_bits = target_bits
self.weight.q_period = quantization_period
@ -642,10 +578,7 @@ class Conv2dLayer_Compress(nn.Conv2d):
def forward(self, input):
if self.weight_quantization_enabled_in_forward and self.weight_quantization_enabled:
weight = self.weight_quantizer(self.weight,
self.weight.target_bits,
None,
None,
weight = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
self.weight_quantize_num_groups)
bias = self.bias
else:
@ -667,22 +600,13 @@ class Conv2dLayer_Compress(nn.Conv2d):
num_groups = input.numel() // input[0].numel()
else:
num_groups = 1
input = self.activation_quantizer(input,
self.activation_quantization_bits,
None,
None,
num_groups)
input = self.activation_quantizer(input, self.activation_quantization_bits, None, None, num_groups)
return nn.functional.conv2d(input,
weight,
bias,
self.stride,
self.padding,
self.dilation,
self.groups)
return nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
class BNLayer_Compress(nn.BatchNorm2d):
def fix_channel_pruning_helper(self, mask, dim_reduction=True):
self.weight = nn.Parameter(self.weight.data[mask.view(-1)])
self.bias = nn.Parameter(self.bias.data[mask.view(-1)])
@ -770,6 +694,7 @@ def _gather(input_):
class _CopyToModelParallelRegion(torch.autograd.Function):
"""Pass the input to the model parallel region."""
@staticmethod
def forward(ctx, input_):
return input_
@ -781,6 +706,7 @@ class _CopyToModelParallelRegion(torch.autograd.Function):
class _ReduceFromModelParallelRegion(torch.autograd.Function):
"""All-redcue the input from the model parallel region."""
@staticmethod
def forward(ctx, input_):
return _reduce(input_)
@ -792,6 +718,7 @@ class _ReduceFromModelParallelRegion(torch.autograd.Function):
class _ScatterToModelParallelRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chuck to the rank."""
@staticmethod
def forward(ctx, input_):
return _split(input_)
@ -803,6 +730,7 @@ class _ScatterToModelParallelRegion(torch.autograd.Function):
class _GatherFromModelParallelRegion(torch.autograd.Function):
"""Gather the input from model parallel region and concatinate."""
@staticmethod
def forward(ctx, input_):
return _gather(input_)
@ -834,13 +762,8 @@ def gather_from_model_parallel_region(input_):
class ColumnParallelLinear_Compress(LinearLayer_Compress):
def __init__(self,
mpu,
input_size,
output_size,
bias=True,
gather_output=True,
skip_bias_add=False):
def __init__(self, mpu, input_size, output_size, bias=True, gather_output=True, skip_bias_add=False):
# Keep input parameters
global g_mpu
g_mpu = mpu
@ -854,10 +777,7 @@ class ColumnParallelLinear_Compress(LinearLayer_Compress):
assert output_size % world_size == 0
self.output_size_per_partition = output_size // world_size
super(ColumnParallelLinear_Compress,
self).__init__(self.input_size,
self.output_size_per_partition,
bias=bias)
super(ColumnParallelLinear_Compress, self).__init__(self.input_size, self.output_size_per_partition, bias=bias)
def forward(self, input_):
# Set up backprop all-reduce.
@ -877,13 +797,8 @@ class ColumnParallelLinear_Compress(LinearLayer_Compress):
class RowParallelLinear_Compress(LinearLayer_Compress):
def __init__(self,
mpu,
input_size,
output_size,
bias=True,
input_is_parallel=False,
skip_bias_add=False):
def __init__(self, mpu, input_size, output_size, bias=True, input_is_parallel=False, skip_bias_add=False):
# Keep input parameters
global g_mpu
g_mpu = mpu
@ -897,10 +812,7 @@ class RowParallelLinear_Compress(LinearLayer_Compress):
assert input_size % world_size == 0
self.input_size_per_partition = input_size // world_size
super(RowParallelLinear_Compress,
self).__init__(self.input_size_per_partition,
self.output_size,
bias=bias)
super(RowParallelLinear_Compress, self).__init__(self.input_size_per_partition, self.output_size, bias=bias)
def forward(self, input_):
# Set up backprop all-reduce.

View File

@ -13,21 +13,13 @@ def check_deepspeed_config(config):
if isinstance(config, dict):
return config
elif os.path.exists(config):
return json.load(open(config,
"r"),
object_pairs_hook=dict_raise_error_on_duplicate_keys)
return json.load(open(config, "r"), object_pairs_hook=dict_raise_error_on_duplicate_keys)
else:
raise ValueError(
f"Expected a string path to an existing deepspeed config, or a dictionary. Received: {config}"
)
f"Expected a string path to an existing deepspeed config, or a dictionary. Received: {config}")
def get_module_name(group_name,
model,
key_word,
exist_module_name,
mpu=None,
verbose=True):
def get_module_name(group_name, model, key_word, exist_module_name, mpu=None, verbose=True):
'''
get the associated module name from the model based on the key_word provided by users
'''
@ -40,8 +32,7 @@ def get_module_name(group_name,
if name in exist_module_name and verbose:
# logger.warning
raise ValueError(
f"{name} is already added to compression, please check your config file for {group_name}."
)
f"{name} is already added to compression, please check your config file for {group_name}.")
if name not in exist_module_name:
exist_module_name.add(name)
return_module_name.append(name)
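The docstring above describes collecting module names from the model that match a user-provided key word; the matching logic itself sits outside this hunk. A purely illustrative helper (find_matching_module_names is a made-up name, and regex matching over named_modules() is an assumption, not the actual implementation):

import re
import torch.nn as nn


def find_matching_module_names(model, key_word):
    # Hypothetical sketch: return the names of all sub-modules whose qualified
    # name matches a user-provided pattern (regex or plain substring).
    pattern = re.compile(key_word)
    return [name for name, _ in model.named_modules() if pattern.search(name)]


model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
print(find_matching_module_names(model, r'\d'))  # ['0', '1', '2'] for this toy model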
@ -56,8 +47,7 @@ def get_compress_methods(model, compress_methods, mpu=None):
continue
# for loop different methods, i.e., weight quantization, activation quantization etc
exist_module_name = set()
shared_parameters = method_content[
SHARED_PARAMETERS] # get all the shared parameters
shared_parameters = method_content[SHARED_PARAMETERS] # get all the shared parameters
for group_name, method_parameters in method_content[DIFFERENT_GROUPS].items():
# for loop different groups, i.e., weight quantization group 1, weight quantization group 2 etc
module_name_list = []
@ -65,8 +55,13 @@ def get_compress_methods(model, compress_methods, mpu=None):
if method_parameters[DIFFERENT_GROUPS_RELATED_MODULE_SCOPE]:
# this is used for head/row/channel pruning: if users provide the related module scope, we can shrink the layer dim for them;
# otherwise we just mask those as zeros
for key_word, related_key_words in zip(method_parameters[DIFFERENT_GROUPS_MODULE_SCOPE], method_parameters[DIFFERENT_GROUPS_RELATED_MODULE_SCOPE]):
module_name, exist_module_name = get_module_name(group_name, model, key_word, exist_module_name, mpu=mpu)
for key_word, related_key_words in zip(method_parameters[DIFFERENT_GROUPS_MODULE_SCOPE],
method_parameters[DIFFERENT_GROUPS_RELATED_MODULE_SCOPE]):
module_name, exist_module_name = get_module_name(group_name,
model,
key_word,
exist_module_name,
mpu=mpu)
module_name_list.append(module_name)
tmp_related_module_name_list = []
for rkw in related_key_words:
@ -76,7 +71,11 @@ def get_compress_methods(model, compress_methods, mpu=None):
related_module_name_list.append(tmp_related_module_name_list)
else:
for key_word in method_parameters[DIFFERENT_GROUPS_MODULE_SCOPE]:
module_name, exist_module_name = get_module_name(group_name, model, key_word, exist_module_name, mpu=mpu)
module_name, exist_module_name = get_module_name(group_name,
model,
key_word,
exist_module_name,
mpu=mpu)
module_name_list.append(module_name)
if module_name_list:
@ -85,13 +84,7 @@ def get_compress_methods(model, compress_methods, mpu=None):
**(method_parameters.copy().pop(DIFFERENT_GROUPS_PARAMETERS)),
**shared_parameters
}
compression_item = [
module_name_list,
related_module_name_list,
{
method: combined_method_parameters
}
]
compression_item = [module_name_list, related_module_name_list, {method: combined_method_parameters}]
layer_added_compress_methods.append(compression_item)
return layer_added_compress_methods
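Each entry appended above bundles the matched module names, any related module names, and a single method-to-parameters mapping. A toy instance of that triple (every name and value below is invented for illustration, and the real nesting of the name lists follows the config, not this sketch):

# Toy instance of the [module_names, related_module_names, {method: parameters}] triple built above.
compression_item = [
    ['attention.output.dense'],                       # modules the technique applies to
    ['attention.self.query', 'attention.self.key'],   # related modules (may be empty)
    {'head_pruning': {'dense_ratio': 0.5}},
]
module_names, related_names, technique = compression_item
print(list(technique.keys())[0])  # 'head_pruning'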
@ -118,9 +111,7 @@ def init_compression(model, deepspeed_config, teacher_model=None, mpu=None):
assert teacher_model is not None, "Teacher model is required for layer reduction"
student_initialization(c_model, teacher_model, deepspeed_config)
layer_added_compress_methods = get_compress_methods(c_model,
compress_methods,
mpu=mpu)
layer_added_compress_methods = get_compress_methods(c_model, compress_methods, mpu=mpu)
compression_preparation(c_model, layer_added_compress_methods, mpu)
return model
@ -143,31 +134,20 @@ def redundancy_clean(model, deepspeed_config, mpu=None):
else:
c_model = model
layer_added_compress_methods_tmp = get_compress_methods(c_model,
compress_methods,
mpu=mpu)
layer_added_compress_methods_tmp = get_compress_methods(c_model, compress_methods, mpu=mpu)
# sort methods
order_list = [
WEIGHT_QUANTIZATION,
SPARSE_PRUNING,
ROW_PRUNING,
HEAD_PRUNING,
CHANNEL_PRUNING,
ACTIVATION_QUANTIZATION
WEIGHT_QUANTIZATION, SPARSE_PRUNING, ROW_PRUNING, HEAD_PRUNING, CHANNEL_PRUNING, ACTIVATION_QUANTIZATION
]
layer_added_compress_methods = sorted(
layer_added_compress_methods_tmp,
key=lambda x: order_list.index(list(x[2].keys())[0]))
layer_added_compress_methods = sorted(layer_added_compress_methods_tmp,
key=lambda x: order_list.index(list(x[2].keys())[0]))
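The sort above orders the collected compression items by a fixed method precedence. The same idea in isolation, with placeholder strings standing in for the real constants:

order_list = ['weight_quantization', 'sparse_pruning', 'row_pruning',
              'head_pruning', 'channel_pruning', 'activation_quantization']  # placeholder strings

# Each item carries its method name as the single key of its third element, as above.
items = [
    [['m1'], [], {'activation_quantization': {}}],
    [['m2'], [], {'weight_quantization': {}}],
    [['m3'], [], {'head_pruning': {}}],
]
items_sorted = sorted(items, key=lambda x: order_list.index(list(x[2].keys())[0]))
print([list(x[2].keys())[0] for x in items_sorted])
# ['weight_quantization', 'head_pruning', 'activation_quantization']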
for module_name_lists, related_module_name_lists, compression_technique in layer_added_compress_methods:
stored_mask = []
need_mask = True if related_module_name_lists else False
for i, mnl in enumerate(module_name_lists):
for module_name in mnl:
mask = fix_compression(c_model,
module_name,
compression_technique,
dim_reduction=need_mask)
mask = fix_compression(c_model, module_name, compression_technique, dim_reduction=need_mask)
if need_mask:
stored_mask.append(mask)
if need_mask:
@ -219,10 +199,8 @@ def student_initialization(student_model, teacher_model, deepspeed_config):
'''
assert len(student_layer) == len(teacher_layer)
for s_name, t_name in zip(student_layer, teacher_layer):
s_module = recursive_getattr(student_model,
module_name_prefix + '.' + str(s_name))
t_module = recursive_getattr(teacher_model,
module_name_prefix + '.' + str(t_name))
s_module = recursive_getattr(student_model, module_name_prefix + '.' + str(s_name))
t_module = recursive_getattr(teacher_model, module_name_prefix + '.' + str(t_name))
for s_param, t_param in zip(s_module.parameters(), t_module.parameters()):
s_param.data.copy_(t_param.data)
for name in other_module_name:

View File

@ -36,9 +36,7 @@ def get_layer_reduction(param_dict):
def get_layer_reduction_enabled(param_dict):
if LAYER_REDUCTION in param_dict.keys():
return get_scalar_param(param_dict[LAYER_REDUCTION],
LAYER_REDUCTION_ENABLED,
LAYER_REDUCTION_ENABLED_DEFAULT)
return get_scalar_param(param_dict[LAYER_REDUCTION], LAYER_REDUCTION_ENABLED, LAYER_REDUCTION_ENABLED_DEFAULT)
else:
return False
@ -70,7 +68,8 @@ def get_weight_quantization(param_dict):
output[SHARED_PARAMETERS] = get_weight_quantization_shared_parameters(sub_param_dict)
# each sub-groups
if output[SHARED_PARAMETERS][WEIGHT_QUANTIZE_ENABLED]:
assert DIFFERENT_GROUPS in sub_param_dict.keys(), f"Weigh Quantization is enabled, {DIFFERENT_GROUPS} must be specified"
assert DIFFERENT_GROUPS in sub_param_dict.keys(
), f"Weigh Quantization is enabled, {DIFFERENT_GROUPS} must be specified"
output[DIFFERENT_GROUPS] = get_weight_quantization_different_groups(sub_param_dict)
return output
@ -79,51 +78,38 @@ def get_weight_quantization_shared_parameters(param_dict):
output = {}
if SHARED_PARAMETERS in param_dict.keys():
sub_param_dict = param_dict[SHARED_PARAMETERS]
output[WEIGHT_QUANTIZE_ENABLED] = get_scalar_param(
sub_param_dict,
WEIGHT_QUANTIZE_ENABLED,
WEIGHT_QUANTIZE_ENABLED_DEFAULT)
output[WEIGHT_QUANTIZE_KERNEL] = get_scalar_param(
sub_param_dict,
WEIGHT_QUANTIZE_KERNEL,
WEIGHT_QUANTIZE_KERNEL_DEFAULT)
output[WEIGHT_QUANTIZE_SCHEDULE_OFFSET] = get_scalar_param(
sub_param_dict,
WEIGHT_QUANTIZE_SCHEDULE_OFFSET,
WEIGHT_QUANTIZE_SCHEDULE_OFFSET_DEFAULT)
output[WEIGHT_QUANTIZE_GROUPS] = get_scalar_param(
sub_param_dict,
WEIGHT_QUANTIZE_GROUPS,
WEIGHT_QUANTIZE_GROUPS_DEFAULT)
output[WEIGHT_QUANTIZE_VERBOSE] = get_scalar_param(
sub_param_dict,
WEIGHT_QUANTIZE_VERBOSE,
WEIGHT_QUANTIZE_VERBOSE_DEFAULT)
output[WEIGHT_QUANTIZE_TYPE] = get_scalar_param(sub_param_dict,
WEIGHT_QUANTIZE_TYPE,
output[WEIGHT_QUANTIZE_ENABLED] = get_scalar_param(sub_param_dict, WEIGHT_QUANTIZE_ENABLED,
WEIGHT_QUANTIZE_ENABLED_DEFAULT)
output[WEIGHT_QUANTIZE_KERNEL] = get_scalar_param(sub_param_dict, WEIGHT_QUANTIZE_KERNEL,
WEIGHT_QUANTIZE_KERNEL_DEFAULT)
output[WEIGHT_QUANTIZE_SCHEDULE_OFFSET] = get_scalar_param(sub_param_dict, WEIGHT_QUANTIZE_SCHEDULE_OFFSET,
WEIGHT_QUANTIZE_SCHEDULE_OFFSET_DEFAULT)
output[WEIGHT_QUANTIZE_GROUPS] = get_scalar_param(sub_param_dict, WEIGHT_QUANTIZE_GROUPS,
WEIGHT_QUANTIZE_GROUPS_DEFAULT)
output[WEIGHT_QUANTIZE_VERBOSE] = get_scalar_param(sub_param_dict, WEIGHT_QUANTIZE_VERBOSE,
WEIGHT_QUANTIZE_VERBOSE_DEFAULT)
output[WEIGHT_QUANTIZE_TYPE] = get_scalar_param(sub_param_dict, WEIGHT_QUANTIZE_TYPE,
WEIGHT_QUANTIZE_TYPE_DEFAULT)
output[WEIGHT_QUANTIZE_IN_FORWARD_ENABLED] = get_scalar_param(
sub_param_dict,
WEIGHT_QUANTIZE_IN_FORWARD_ENABLED,
WEIGHT_QUANTIZE_IN_FORWARD_ENABLED_DEFAULT)
assert output[WEIGHT_QUANTIZE_TYPE] in [WEIGHT_QUANTIZE_SYMMETRIC, WEIGHT_QUANTIZE_ASYMMETRIC], f"Invalid weight quantize type. Supported types: [{WEIGHT_QUANTIZE_SYMMETRIC}, {WEIGHT_QUANTIZE_ASYMMETRIC}]"
output[WEIGHT_QUANTIZE_ROUNDING] = get_scalar_param(
sub_param_dict,
WEIGHT_QUANTIZE_ROUNDING,
WEIGHT_QUANTIZE_ROUNDING_DEFAULT)
assert output[WEIGHT_QUANTIZE_ROUNDING] in [WEIGHT_QUANTIZE_NEAREST_ROUNDING, WEIGHT_QUANTIZE_STOCHASTIC_ROUNDING], f"Invalid weight quantize rounding. Supported types: [{WEIGHT_QUANTIZE_NEAREST_ROUNDING}, {WEIGHT_QUANTIZE_STOCHASTIC_ROUNDING}]"
output[WEIGHT_QUANTIZE_IN_FORWARD_ENABLED] = get_scalar_param(sub_param_dict,
WEIGHT_QUANTIZE_IN_FORWARD_ENABLED,
WEIGHT_QUANTIZE_IN_FORWARD_ENABLED_DEFAULT)
assert output[WEIGHT_QUANTIZE_TYPE] in [
WEIGHT_QUANTIZE_SYMMETRIC, WEIGHT_QUANTIZE_ASYMMETRIC
], f"Invalid weight quantize type. Supported types: [{WEIGHT_QUANTIZE_SYMMETRIC}, {WEIGHT_QUANTIZE_ASYMMETRIC}]"
output[WEIGHT_QUANTIZE_ROUNDING] = get_scalar_param(sub_param_dict, WEIGHT_QUANTIZE_ROUNDING,
WEIGHT_QUANTIZE_ROUNDING_DEFAULT)
assert output[WEIGHT_QUANTIZE_ROUNDING] in [
WEIGHT_QUANTIZE_NEAREST_ROUNDING, WEIGHT_QUANTIZE_STOCHASTIC_ROUNDING
], f"Invalid weight quantize rounding. Supported types: [{WEIGHT_QUANTIZE_NEAREST_ROUNDING}, {WEIGHT_QUANTIZE_STOCHASTIC_ROUNDING}]"
if WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE in sub_param_dict.keys():
output[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE] = get_scalar_param(
sub_param_dict[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE],
WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED,
sub_param_dict[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE], WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED,
WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED_DEFAULT)
output[WEIGHT_QUANTIZE_CHANGE_RATIO] = get_scalar_param(
sub_param_dict[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE],
WEIGHT_QUANTIZE_CHANGE_RATIO,
sub_param_dict[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE], WEIGHT_QUANTIZE_CHANGE_RATIO,
WEIGHT_QUANTIZE_CHANGE_RATIO_DEFAULT)
else:
output[
WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE] = WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED_DEFAULT
output[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE] = WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED_DEFAULT
output[WEIGHT_QUANTIZE_CHANGE_RATIO] = WEIGHT_QUANTIZE_CHANGE_RATIO_DEFAULT
else:
output[WEIGHT_QUANTIZE_ENABLED] = WEIGHT_QUANTIZE_ENABLED_DEFAULT
@ -133,8 +119,7 @@ def get_weight_quantization_shared_parameters(param_dict):
output[WEIGHT_QUANTIZE_VERBOSE] = WEIGHT_QUANTIZE_VERBOSE_DEFAULT
output[WEIGHT_QUANTIZE_TYPE] = WEIGHT_QUANTIZE_TYPE_DEFAULT
output[WEIGHT_QUANTIZE_ROUNDING] = WEIGHT_QUANTIZE_ROUNDING_DEFAULT
output[
WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE] = WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED_DEFAULT
output[WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE] = WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE_ENABLED_DEFAULT
output[WEIGHT_QUANTIZE_CHANGE_RATIO] = WEIGHT_QUANTIZE_CHANGE_RATIO_DEFAULT
return output
@ -144,27 +129,21 @@ def get_weight_quantization_different_groups(param_dict):
sub_param_dict = param_dict[DIFFERENT_GROUPS]
def get_params(name, group_dict):
assert WEIGHT_QUANTIZE_START_BITS in group_dict.keys(), f"{WEIGHT_QUANTIZE_START_BITS} must be specified for weight quantization group {name}"
assert WEIGHT_QUANTIZE_TARGET_BITS in group_dict.keys(), f"{WEIGHT_QUANTIZE_TARGET_BITS} must be specified for weight quantization group {name}"
group_dict[WEIGHT_QUANTIZATION_PERIOD] = get_scalar_param(
group_dict,
WEIGHT_QUANTIZATION_PERIOD,
WEIGHT_QUANTIZATION_PERIOD_DEFAULT)
assert WEIGHT_QUANTIZE_START_BITS in group_dict.keys(
), f"{WEIGHT_QUANTIZE_START_BITS} must be specified for weight quantization group {name}"
assert WEIGHT_QUANTIZE_TARGET_BITS in group_dict.keys(
), f"{WEIGHT_QUANTIZE_TARGET_BITS} must be specified for weight quantization group {name}"
group_dict[WEIGHT_QUANTIZATION_PERIOD] = get_scalar_param(group_dict, WEIGHT_QUANTIZATION_PERIOD,
WEIGHT_QUANTIZATION_PERIOD_DEFAULT)
return group_dict
for k, v in sub_param_dict.items():
output[k] = {}
output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(
k,
sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(
sub_param_dict[k],
DIFFERENT_GROUPS_MODULE_SCOPE,
DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(k, sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(sub_param_dict[k], DIFFERENT_GROUPS_MODULE_SCOPE,
DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(
sub_param_dict[k],
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE,
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
sub_param_dict[k], DIFFERENT_GROUPS_RELATED_MODULE_SCOPE, DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
return output
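get_scalar_param is used throughout these config readers as "look up a key, fall back to a default"; its implementation is not shown in this diff, but its call sites suggest a guarded dict lookup. A minimal sketch under that assumption (the _sketch name marks it as hypothetical):

def get_scalar_param_sketch(param_dict, key, default):
    # Assumption: behaves like dict.get with an explicit default, as the call sites above suggest.
    return param_dict.get(key, default)


sub_param_dict = {'start_bits': 8, 'target_bits': 4}   # invented example group
period = get_scalar_param_sketch(sub_param_dict, 'quantization_period', 1)
print(period)  # 1 (default used because the key is absent)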
@ -172,19 +151,15 @@ def get_weight_quantization_different_groups(param_dict):
def get_activation_quantization(param_dict):
output = {}
if ACTIVATION_QUANTIZATION not in param_dict.keys():
param_dict[ACTIVATION_QUANTIZATION] = {
SHARED_PARAMETERS: {},
DIFFERENT_GROUPS: {}
}
param_dict[ACTIVATION_QUANTIZATION] = {SHARED_PARAMETERS: {}, DIFFERENT_GROUPS: {}}
sub_param_dict = param_dict[ACTIVATION_QUANTIZATION]
# shared parameters
output[SHARED_PARAMETERS] = get_activation_quantization_shared_parameters(
sub_param_dict)
output[SHARED_PARAMETERS] = get_activation_quantization_shared_parameters(sub_param_dict)
# each sub-groups
if output[SHARED_PARAMETERS][ACTIVATION_QUANTIZATION_ENABLED]:
assert DIFFERENT_GROUPS in sub_param_dict.keys(), f"Activation Quantization is enabled, {DIFFERENT_GROUPS} must be specified"
output[DIFFERENT_GROUPS] = get_activation_quantization_different_groups(
sub_param_dict)
assert DIFFERENT_GROUPS in sub_param_dict.keys(
), f"Activation Quantization is enabled, {DIFFERENT_GROUPS} must be specified"
output[DIFFERENT_GROUPS] = get_activation_quantization_different_groups(sub_param_dict)
return output
@ -192,30 +167,26 @@ def get_activation_quantization_shared_parameters(param_dict):
output = {}
if SHARED_PARAMETERS in param_dict.keys():
sub_param_dict = param_dict[SHARED_PARAMETERS]
output[ACTIVATION_QUANTIZATION_ENABLED] = get_scalar_param(
sub_param_dict,
ACTIVATION_QUANTIZATION_ENABLED,
ACTIVATION_QUANTIZATION_ENABLED_DEFAULT)
output[ACTIVATION_QUANTIZE_TYPE] = get_scalar_param(
sub_param_dict,
ACTIVATION_QUANTIZE_TYPE,
ACTIVATION_QUANTIZE_TYPE_DEFAULT)
assert output[ACTIVATION_QUANTIZE_TYPE] in [ACTIVATION_QUANTIZE_SYMMETRIC, ACTIVATION_QUANTIZE_ASYMMETRIC], f"Invalid activation quantize type. Supported types: [{ACTIVATION_QUANTIZE_SYMMETRIC}, {ACTIVATION_QUANTIZE_ASYMMETRIC}]"
output[ACTIVATION_QUANTIZE_RANGE] = get_scalar_param(
sub_param_dict,
ACTIVATION_QUANTIZE_RANGE,
ACTIVATION_QUANTIZE_RANGE_DEFAULT)
assert output[ACTIVATION_QUANTIZE_RANGE] in [ACTIVATION_QUANTIZE_RANGE_DYNAMIC, ACTIVATION_QUANTIZE_RANGE_STATIC], f"Invalid activation quantize range calibration. Supported types: [{ACTIVATION_QUANTIZE_RANGE_DYNAMIC}, {ACTIVATION_QUANTIZE_RANGE_STATIC}]"
output[ACTIVATION_QUANTIZE_SCHEDULE_OFFSET] = get_scalar_param(
sub_param_dict,
ACTIVATION_QUANTIZE_SCHEDULE_OFFSET,
ACTIVATION_QUANTIZE_SCHEDULE_OFFSET_DEFAULT)
output[ACTIVATION_QUANTIZATION_ENABLED] = get_scalar_param(sub_param_dict, ACTIVATION_QUANTIZATION_ENABLED,
ACTIVATION_QUANTIZATION_ENABLED_DEFAULT)
output[ACTIVATION_QUANTIZE_TYPE] = get_scalar_param(sub_param_dict, ACTIVATION_QUANTIZE_TYPE,
ACTIVATION_QUANTIZE_TYPE_DEFAULT)
assert output[ACTIVATION_QUANTIZE_TYPE] in [
ACTIVATION_QUANTIZE_SYMMETRIC, ACTIVATION_QUANTIZE_ASYMMETRIC
], f"Invalid activation quantize type. Supported types: [{ACTIVATION_QUANTIZE_SYMMETRIC}, {ACTIVATION_QUANTIZE_ASYMMETRIC}]"
output[ACTIVATION_QUANTIZE_RANGE] = get_scalar_param(sub_param_dict, ACTIVATION_QUANTIZE_RANGE,
ACTIVATION_QUANTIZE_RANGE_DEFAULT)
assert output[ACTIVATION_QUANTIZE_RANGE] in [
ACTIVATION_QUANTIZE_RANGE_DYNAMIC, ACTIVATION_QUANTIZE_RANGE_STATIC
], f"Invalid activation quantize range calibration. Supported types: [{ACTIVATION_QUANTIZE_RANGE_DYNAMIC}, {ACTIVATION_QUANTIZE_RANGE_STATIC}]"
output[ACTIVATION_QUANTIZE_SCHEDULE_OFFSET] = get_scalar_param(sub_param_dict,
ACTIVATION_QUANTIZE_SCHEDULE_OFFSET,
ACTIVATION_QUANTIZE_SCHEDULE_OFFSET_DEFAULT)
else:
output[ACTIVATION_QUANTIZATION_ENABLED] = ACTIVATION_QUANTIZATION_ENABLED_DEFAULT
output[ACTIVATION_QUANTIZE_TYPE] = ACTIVATION_QUANTIZE_TYPE_DEFAULT
output[ACTIVATION_QUANTIZE_RANGE] = ACTIVATION_QUANTIZE_RANGE_DEFAULT
output[
ACTIVATION_QUANTIZE_SCHEDULE_OFFSET] = ACTIVATION_QUANTIZE_SCHEDULE_OFFSET_DEFAULT
output[ACTIVATION_QUANTIZE_SCHEDULE_OFFSET] = ACTIVATION_QUANTIZE_SCHEDULE_OFFSET_DEFAULT
return output
@ -224,22 +195,17 @@ def get_activation_quantization_different_groups(param_dict):
sub_param_dict = param_dict[DIFFERENT_GROUPS]
def get_params(name, group_dict):
assert ACTIVATION_QUANTIZE_BITS in group_dict.keys(), f"{ACTIVATION_QUANTIZE_BITS} must be specified for activation quantization group {name}"
assert ACTIVATION_QUANTIZE_BITS in group_dict.keys(
), f"{ACTIVATION_QUANTIZE_BITS} must be specified for activation quantization group {name}"
return group_dict
for k, v in sub_param_dict.items():
output[k] = {}
output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(
k,
sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(
sub_param_dict[k],
DIFFERENT_GROUPS_MODULE_SCOPE,
DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(k, sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(sub_param_dict[k], DIFFERENT_GROUPS_MODULE_SCOPE,
DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(
sub_param_dict[k],
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE,
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
sub_param_dict[k], DIFFERENT_GROUPS_RELATED_MODULE_SCOPE, DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
return output
@ -253,7 +219,8 @@ def get_sparse_pruning(param_dict):
output[SHARED_PARAMETERS] = get_sparse_pruning_shared_parameters(sub_param_dict)
# each sub-groups
if output[SHARED_PARAMETERS][SPARSE_PRUNING_ENABLED]:
assert DIFFERENT_GROUPS in sub_param_dict.keys(), f"Sparse Pruning is enabled, {DIFFERENT_GROUPS} must be specified"
assert DIFFERENT_GROUPS in sub_param_dict.keys(
), f"Sparse Pruning is enabled, {DIFFERENT_GROUPS} must be specified"
output[DIFFERENT_GROUPS] = get_sparse_pruning_different_groups(sub_param_dict)
return output
@ -262,18 +229,15 @@ def get_sparse_pruning_shared_parameters(param_dict):
output = {}
if SHARED_PARAMETERS in param_dict.keys():
sub_param_dict = param_dict[SHARED_PARAMETERS]
output[SPARSE_PRUNING_ENABLED] = get_scalar_param(
sub_param_dict,
SPARSE_PRUNING_ENABLED,
SPARSE_PRUNING_ENABLED_DEFAULT)
output[SPARSE_PRUNING_METHOD] = get_scalar_param(sub_param_dict,
SPARSE_PRUNING_METHOD,
output[SPARSE_PRUNING_ENABLED] = get_scalar_param(sub_param_dict, SPARSE_PRUNING_ENABLED,
SPARSE_PRUNING_ENABLED_DEFAULT)
output[SPARSE_PRUNING_METHOD] = get_scalar_param(sub_param_dict, SPARSE_PRUNING_METHOD,
SPARSE_PRUNING_METHOD_DEFAULT)
assert output[SPARSE_PRUNING_METHOD] in [SPARSE_PRUNING_METHOD_L1, SPARSE_PRUNING_METHOD_TOPK], f"Invalid sparse pruning method. Supported types: [{SPARSE_PRUNING_METHOD_L1}, {SPARSE_PRUNING_METHOD_TOPK}]"
output[SPARSE_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(
sub_param_dict,
SPARSE_PRUNING_SCHEDULE_OFFSET,
SPARSE_PRUNING_SCHEDULE_OFFSET_DEFAULT)
assert output[SPARSE_PRUNING_METHOD] in [
SPARSE_PRUNING_METHOD_L1, SPARSE_PRUNING_METHOD_TOPK
], f"Invalid sparse pruning method. Supported types: [{SPARSE_PRUNING_METHOD_L1}, {SPARSE_PRUNING_METHOD_TOPK}]"
output[SPARSE_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(sub_param_dict, SPARSE_PRUNING_SCHEDULE_OFFSET,
SPARSE_PRUNING_SCHEDULE_OFFSET_DEFAULT)
else:
output[SPARSE_PRUNING_ENABLED] = SPARSE_PRUNING_ENABLED_DEFAULT
output[SPARSE_PRUNING_METHOD] = SPARSE_PRUNING_METHOD_DEFAULT
@ -286,22 +250,17 @@ def get_sparse_pruning_different_groups(param_dict):
sub_param_dict = param_dict[DIFFERENT_GROUPS]
def get_params(name, group_dict):
assert SPARSE_PRUNING_DENSE_RATIO in group_dict.keys(), f"{SPARSE_PRUNING_DENSE_RATIO} must be specified for sparse pruning group {name}"
assert SPARSE_PRUNING_DENSE_RATIO in group_dict.keys(
), f"{SPARSE_PRUNING_DENSE_RATIO} must be specified for sparse pruning group {name}"
return group_dict
for k, v in sub_param_dict.items():
output[k] = {}
output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(
k,
sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(
sub_param_dict[k],
DIFFERENT_GROUPS_MODULE_SCOPE,
DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(k, sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(sub_param_dict[k], DIFFERENT_GROUPS_MODULE_SCOPE,
DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(
sub_param_dict[k],
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE,
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
sub_param_dict[k], DIFFERENT_GROUPS_RELATED_MODULE_SCOPE, DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
return output
@ -315,7 +274,8 @@ def get_row_pruning(param_dict):
output[SHARED_PARAMETERS] = get_row_pruning_shared_parameters(sub_param_dict)
# each sub-groups
if output[SHARED_PARAMETERS][ROW_PRUNING_ENABLED]:
assert DIFFERENT_GROUPS in sub_param_dict.keys(), f"Row Pruning is enabled, {DIFFERENT_GROUPS} must be specified"
assert DIFFERENT_GROUPS in sub_param_dict.keys(
), f"Row Pruning is enabled, {DIFFERENT_GROUPS} must be specified"
output[DIFFERENT_GROUPS] = get_row_pruning_different_groups(sub_param_dict)
return output
@ -324,17 +284,14 @@ def get_row_pruning_shared_parameters(param_dict):
output = {}
if SHARED_PARAMETERS in param_dict.keys():
sub_param_dict = param_dict[SHARED_PARAMETERS]
output[ROW_PRUNING_ENABLED] = get_scalar_param(sub_param_dict,
ROW_PRUNING_ENABLED,
output[ROW_PRUNING_ENABLED] = get_scalar_param(sub_param_dict, ROW_PRUNING_ENABLED,
ROW_PRUNING_ENABLED_DEFAULT)
output[ROW_PRUNING_METHOD] = get_scalar_param(sub_param_dict,
ROW_PRUNING_METHOD,
ROW_PRUNING_METHOD_DEFAULT)
assert output[ROW_PRUNING_METHOD] in [ROW_PRUNING_METHOD_L1, ROW_PRUNING_METHOD_TOPK], f"Invalid row pruning method. Supported types: [{ROW_PRUNING_METHOD_L1}, {ROW_PRUNING_METHOD_TOPK}]"
output[ROW_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(
sub_param_dict,
ROW_PRUNING_SCHEDULE_OFFSET,
ROW_PRUNING_SCHEDULE_OFFSET_DEFAULT)
output[ROW_PRUNING_METHOD] = get_scalar_param(sub_param_dict, ROW_PRUNING_METHOD, ROW_PRUNING_METHOD_DEFAULT)
assert output[ROW_PRUNING_METHOD] in [
ROW_PRUNING_METHOD_L1, ROW_PRUNING_METHOD_TOPK
], f"Invalid row pruning method. Supported types: [{ROW_PRUNING_METHOD_L1}, {ROW_PRUNING_METHOD_TOPK}]"
output[ROW_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(sub_param_dict, ROW_PRUNING_SCHEDULE_OFFSET,
ROW_PRUNING_SCHEDULE_OFFSET_DEFAULT)
else:
output[ROW_PRUNING_ENABLED] = ROW_PRUNING_ENABLED_DEFAULT
output[ROW_PRUNING_METHOD] = ROW_PRUNING_METHOD_DEFAULT
@ -347,22 +304,17 @@ def get_row_pruning_different_groups(param_dict):
sub_param_dict = param_dict[DIFFERENT_GROUPS]
def get_params(name, group_dict):
assert ROW_PRUNING_DENSE_RATIO in group_dict.keys(), f"{ROW_PRUNING_DENSE_RATIO} must be specified for row pruning group {name}"
assert ROW_PRUNING_DENSE_RATIO in group_dict.keys(
), f"{ROW_PRUNING_DENSE_RATIO} must be specified for row pruning group {name}"
return group_dict
for k, v in sub_param_dict.items():
output[k] = {}
output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(
k,
sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(
sub_param_dict[k],
DIFFERENT_GROUPS_MODULE_SCOPE,
DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(k, sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(sub_param_dict[k], DIFFERENT_GROUPS_MODULE_SCOPE,
DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(
sub_param_dict[k],
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE,
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
sub_param_dict[k], DIFFERENT_GROUPS_RELATED_MODULE_SCOPE, DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
return output
@ -375,7 +327,8 @@ def get_head_pruning(param_dict):
output[SHARED_PARAMETERS] = get_head_pruning_shared_parameters(sub_param_dict)
# each sub-groups
if output[SHARED_PARAMETERS][HEAD_PRUNING_ENABLED]:
assert DIFFERENT_GROUPS in sub_param_dict.keys(), f"Head Pruning is enabled, {DIFFERENT_GROUPS} must be specified"
assert DIFFERENT_GROUPS in sub_param_dict.keys(
), f"Head Pruning is enabled, {DIFFERENT_GROUPS} must be specified"
output[DIFFERENT_GROUPS] = get_head_pruning_different_groups(sub_param_dict)
return output
@ -384,19 +337,18 @@ def get_head_pruning_shared_parameters(param_dict):
output = {}
if SHARED_PARAMETERS in param_dict.keys():
sub_param_dict = param_dict[SHARED_PARAMETERS]
output[HEAD_PRUNING_ENABLED] = get_scalar_param(sub_param_dict,
HEAD_PRUNING_ENABLED,
output[HEAD_PRUNING_ENABLED] = get_scalar_param(sub_param_dict, HEAD_PRUNING_ENABLED,
HEAD_PRUNING_ENABLED_DEFAULT)
output[HEAD_PRUNING_METHOD] = get_scalar_param(sub_param_dict,
HEAD_PRUNING_METHOD,
output[HEAD_PRUNING_METHOD] = get_scalar_param(sub_param_dict, HEAD_PRUNING_METHOD,
HEAD_PRUNING_METHOD_DEFAULT)
assert output[HEAD_PRUNING_METHOD] in [HEAD_PRUNING_METHOD_L1, HEAD_PRUNING_METHOD_TOPK], f"Invalid head pruning method. Supported types: [{HEAD_PRUNING_METHOD_L1}, {HEAD_PRUNING_METHOD_TOPK}]"
output[HEAD_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(
sub_param_dict,
HEAD_PRUNING_SCHEDULE_OFFSET,
HEAD_PRUNING_SCHEDULE_OFFSET_DEFAULT)
assert output[HEAD_PRUNING_METHOD] in [
HEAD_PRUNING_METHOD_L1, HEAD_PRUNING_METHOD_TOPK
], f"Invalid head pruning method. Supported types: [{HEAD_PRUNING_METHOD_L1}, {HEAD_PRUNING_METHOD_TOPK}]"
output[HEAD_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(sub_param_dict, HEAD_PRUNING_SCHEDULE_OFFSET,
HEAD_PRUNING_SCHEDULE_OFFSET_DEFAULT)
if output[HEAD_PRUNING_ENABLED]:
assert HEAD_PRUNING_NUM_HEADS in sub_param_dict.keys(), f"{HEAD_PRUNING_NUM_HEADS} must be specified for head pruning"
assert HEAD_PRUNING_NUM_HEADS in sub_param_dict.keys(
), f"{HEAD_PRUNING_NUM_HEADS} must be specified for head pruning"
output[HEAD_PRUNING_NUM_HEADS] = sub_param_dict[HEAD_PRUNING_NUM_HEADS]
else:
output[HEAD_PRUNING_ENABLED] = HEAD_PRUNING_ENABLED_DEFAULT
@ -410,22 +362,17 @@ def get_head_pruning_different_groups(param_dict):
sub_param_dict = param_dict[DIFFERENT_GROUPS]
def get_params(name, group_dict):
assert HEAD_PRUNING_DENSE_RATIO in group_dict.keys(), f"dense_ratio must be specified for head pruning group {name}"
assert HEAD_PRUNING_DENSE_RATIO in group_dict.keys(
), f"dense_ratio must be specified for head pruning group {name}"
return group_dict
for k, v in sub_param_dict.items():
output[k] = {}
output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(
k,
sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(
sub_param_dict[k],
DIFFERENT_GROUPS_MODULE_SCOPE,
DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(k, sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(sub_param_dict[k], DIFFERENT_GROUPS_MODULE_SCOPE,
DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(
sub_param_dict[k],
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE,
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
sub_param_dict[k], DIFFERENT_GROUPS_RELATED_MODULE_SCOPE, DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
return output
@ -438,7 +385,8 @@ def get_channel_pruning(param_dict):
output[SHARED_PARAMETERS] = get_channel_pruning_shared_parameters(sub_param_dict)
# each sub-groups
if output[SHARED_PARAMETERS][CHANNEL_PRUNING_ENABLED]:
assert DIFFERENT_GROUPS in sub_param_dict.keys(), f"Sparse Pruning is enabled, {DIFFERENT_GROUPS} must be specified"
assert DIFFERENT_GROUPS in sub_param_dict.keys(
), f"Sparse Pruning is enabled, {DIFFERENT_GROUPS} must be specified"
output[DIFFERENT_GROUPS] = get_channel_pruning_different_groups(sub_param_dict)
return output
@ -447,19 +395,15 @@ def get_channel_pruning_shared_parameters(param_dict):
output = {}
if SHARED_PARAMETERS in param_dict.keys():
sub_param_dict = param_dict[SHARED_PARAMETERS]
output[CHANNEL_PRUNING_ENABLED] = get_scalar_param(
sub_param_dict,
CHANNEL_PRUNING_ENABLED,
CHANNEL_PRUNING_ENABLED_DEFAULT)
output[CHANNEL_PRUNING_METHOD] = get_scalar_param(
sub_param_dict,
CHANNEL_PRUNING_METHOD,
CHANNEL_PRUNING_METHOD_DEFAULT)
assert output[CHANNEL_PRUNING_METHOD] in [CHANNEL_PRUNING_METHOD_L1, CHANNEL_PRUNING_METHOD_TOPK], f"Invalid channel pruning method. Supported types: [{CHANNEL_PRUNING_METHOD_L1}, {CHANNEL_PRUNING_METHOD_TOPK}]"
output[CHANNEL_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(
sub_param_dict,
CHANNEL_PRUNING_SCHEDULE_OFFSET,
CHANNEL_PRUNING_SCHEDULE_OFFSET_DEFAULT)
output[CHANNEL_PRUNING_ENABLED] = get_scalar_param(sub_param_dict, CHANNEL_PRUNING_ENABLED,
CHANNEL_PRUNING_ENABLED_DEFAULT)
output[CHANNEL_PRUNING_METHOD] = get_scalar_param(sub_param_dict, CHANNEL_PRUNING_METHOD,
CHANNEL_PRUNING_METHOD_DEFAULT)
assert output[CHANNEL_PRUNING_METHOD] in [
CHANNEL_PRUNING_METHOD_L1, CHANNEL_PRUNING_METHOD_TOPK
], f"Invalid channel pruning method. Supported types: [{CHANNEL_PRUNING_METHOD_L1}, {CHANNEL_PRUNING_METHOD_TOPK}]"
output[CHANNEL_PRUNING_SCHEDULE_OFFSET] = get_scalar_param(sub_param_dict, CHANNEL_PRUNING_SCHEDULE_OFFSET,
CHANNEL_PRUNING_SCHEDULE_OFFSET_DEFAULT)
else:
output[CHANNEL_PRUNING_ENABLED] = CHANNEL_PRUNING_ENABLED_DEFAULT
output[CHANNEL_PRUNING_METHOD] = CHANNEL_PRUNING_METHOD_DEFAULT
@ -472,21 +416,16 @@ def get_channel_pruning_different_groups(param_dict):
sub_param_dict = param_dict[DIFFERENT_GROUPS]
def get_params(name, group_dict):
assert CHANNEL_PRUNING_DENSE_RATIO in group_dict.keys(), f"{CHANNEL_PRUNING_DENSE_RATIO} must be specified for channel pruning group {name}"
assert CHANNEL_PRUNING_DENSE_RATIO in group_dict.keys(
), f"{CHANNEL_PRUNING_DENSE_RATIO} must be specified for channel pruning group {name}"
return group_dict
for k, v in sub_param_dict.items():
output[k] = {}
output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(
k,
sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(
sub_param_dict[k],
DIFFERENT_GROUPS_MODULE_SCOPE,
DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
output[k][DIFFERENT_GROUPS_PARAMETERS] = get_params(k, sub_param_dict[k][DIFFERENT_GROUPS_PARAMETERS])
output[k][DIFFERENT_GROUPS_MODULE_SCOPE] = get_scalar_param(sub_param_dict[k], DIFFERENT_GROUPS_MODULE_SCOPE,
DIFFERENT_GROUPS_MODULE_SCOPE_DEFAULT)
output[k][DIFFERENT_GROUPS_RELATED_MODULE_SCOPE] = get_scalar_param(
sub_param_dict[k],
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE,
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
sub_param_dict[k], DIFFERENT_GROUPS_RELATED_MODULE_SCOPE, DIFFERENT_GROUPS_RELATED_MODULE_SCOPE_DEFAULT)
return output

View File

@ -59,23 +59,17 @@ def module_replacement(model, module_name, compression_technique=None, mpu=None)
need_bias = True
# Initialize the new module
if isinstance(old_module,
LinearLayer_Compress) or isinstance(old_module,
torch.nn.Linear):
if isinstance(old_module, LinearLayer_Compress) or isinstance(old_module, torch.nn.Linear):
if isinstance(old_module, LinearLayer_Compress):
new_module = old_module
else:
new_module = LinearLayer_Compress(old_module.in_features,
old_module.out_features,
bias=need_bias).to(
device=old_module.weight.device,
dtype=old_module.weight.dtype)
new_module = LinearLayer_Compress(old_module.in_features, old_module.out_features,
bias=need_bias).to(device=old_module.weight.device,
dtype=old_module.weight.dtype)
new_module.weight.data = old_module.weight.data
if need_bias:
new_module.bias.data = old_module.bias.data
elif isinstance(old_module,
Conv2dLayer_Compress) or isinstance(old_module,
torch.nn.Conv2d):
elif isinstance(old_module, Conv2dLayer_Compress) or isinstance(old_module, torch.nn.Conv2d):
if isinstance(old_module, Conv2dLayer_Compress):
new_module = old_module
else:
@ -86,60 +80,48 @@ def module_replacement(model, module_name, compression_technique=None, mpu=None)
if need_bias:
new_module.bias.data = old_module.bias.data
elif isinstance(old_module, torch.nn.BatchNorm2d):
new_module = BNLayer_Compress(old_module.num_features,
old_module.eps,
old_module.momentum,
old_module.affine,
old_module.track_running_stats).to(
old_module.weight.device,
old_module.weight.dtype)
new_module = BNLayer_Compress(old_module.num_features, old_module.eps, old_module.momentum, old_module.affine,
old_module.track_running_stats).to(old_module.weight.device,
old_module.weight.dtype)
new_module.weight.data = old_module.weight.data
if need_bias:
new_module.bias.data = old_module.bias.data
new_module.running_mean.data = old_module.running_mean.data
new_module.running_var.data = old_module.running_var.data
elif isinstance(old_module,
Embedding_Compress) or isinstance(old_module,
torch.nn.Embedding):
elif isinstance(old_module, Embedding_Compress) or isinstance(old_module, torch.nn.Embedding):
if isinstance(old_module, Embedding_Compress):
new_module = old_module
else:
new_module = Embedding_Compress(old_module.num_embeddings, old_module.embedding_dim, old_module.padding_idx, old_module.max_norm, old_module.norm_type, \
old_module.scale_grad_by_freq, old_module.sparse).to(device=old_module.weight.device, dtype=old_module.weight.dtype)
new_module.weight.data = old_module.weight.data
elif mpu is not None and (isinstance(old_module,
ColumnParallelLinear_Compress)
or isinstance(old_module,
mpu.ColumnParallelLinear)):
elif mpu is not None and (isinstance(old_module, ColumnParallelLinear_Compress)
or isinstance(old_module, mpu.ColumnParallelLinear)):
if isinstance(old_module, ColumnParallelLinear_Compress):
new_module = old_module
else:
new_module = ColumnParallelLinear_Compress(
mpu,
old_module.input_size,
old_module.output_size,
gather_output=old_module.gather_output,
skip_bias_add=old_module.skip_bias_add,
bias=need_bias).to(device=old_module.weight.device,
dtype=old_module.weight.dtype)
new_module = ColumnParallelLinear_Compress(mpu,
old_module.input_size,
old_module.output_size,
gather_output=old_module.gather_output,
skip_bias_add=old_module.skip_bias_add,
bias=need_bias).to(device=old_module.weight.device,
dtype=old_module.weight.dtype)
new_module.weight.data = old_module.weight.data
if need_bias:
new_module.bias.data = old_module.bias.data
elif mpu is not None and (isinstance(old_module,
RowParallelLinear_Compress)
or isinstance(old_module,
mpu.RowParallelLinear)):
elif mpu is not None and (isinstance(old_module, RowParallelLinear_Compress)
or isinstance(old_module, mpu.RowParallelLinear)):
if isinstance(old_module, RowParallelLinear_Compress):
new_module = old_module
else:
new_module = RowParallelLinear_Compress(
mpu,
old_module.input_size,
old_module.output_size,
input_is_parallel=old_module.input_is_parallel,
skip_bias_add=old_module.skip_bias_add,
bias=need_bias).to(device=old_module.weight.device,
dtype=old_module.weight.dtype)
new_module = RowParallelLinear_Compress(mpu,
old_module.input_size,
old_module.output_size,
input_is_parallel=old_module.input_is_parallel,
skip_bias_add=old_module.skip_bias_add,
bias=need_bias).to(device=old_module.weight.device,
dtype=old_module.weight.dtype)
new_module.weight.data = old_module.weight.data
if need_bias:
new_module.bias.data = old_module.bias.data
@ -150,39 +132,30 @@ def module_replacement(model, module_name, compression_technique=None, mpu=None)
for k, v in compression_technique.items():
if k == SPARSE_PRUNING:
if v[SPARSE_PRUNING_ENABLED]:
new_module.enable_sparse_pruning(v[SPARSE_PRUNING_DENSE_RATIO],
v[SPARSE_PRUNING_METHOD])
new_module.enable_sparse_pruning(v[SPARSE_PRUNING_DENSE_RATIO], v[SPARSE_PRUNING_METHOD])
elif k == ROW_PRUNING:
if v[ROW_PRUNING_ENABLED]:
new_module.enable_row_pruning(v[ROW_PRUNING_DENSE_RATIO],
v[ROW_PRUNING_METHOD])
new_module.enable_row_pruning(v[ROW_PRUNING_DENSE_RATIO], v[ROW_PRUNING_METHOD])
elif k == HEAD_PRUNING:
if v[HEAD_PRUNING_ENABLED]:
new_module.enable_head_pruning(v[HEAD_PRUNING_DENSE_RATIO],
v[HEAD_PRUNING_METHOD],
new_module.enable_head_pruning(v[HEAD_PRUNING_DENSE_RATIO], v[HEAD_PRUNING_METHOD],
v[HEAD_PRUNING_NUM_HEADS])
elif k == ACTIVATION_QUANTIZATION:
if v[ACTIVATION_QUANTIZATION_ENABLED]:
new_module.enable_activation_quantization(
v[ACTIVATION_QUANTIZE_BITS],
v[ACTIVATION_QUANTIZE_TYPE],
v[ACTIVATION_QUANTIZE_RANGE])
new_module.enable_activation_quantization(v[ACTIVATION_QUANTIZE_BITS], v[ACTIVATION_QUANTIZE_TYPE],
v[ACTIVATION_QUANTIZE_RANGE])
elif k == WEIGHT_QUANTIZATION:
if v[WEIGHT_QUANTIZE_ENABLED]:
new_module.enable_weight_quantization(
v[WEIGHT_QUANTIZE_START_BITS],
v[WEIGHT_QUANTIZE_TARGET_BITS],
v[WEIGHT_QUANTIZATION_PERIOD],
v[WEIGHT_QUANTIZE_IN_FORWARD_ENABLED],
v[WEIGHT_QUANTIZE_TYPE],
v[WEIGHT_QUANTIZE_GROUPS])
new_module.enable_weight_quantization(v[WEIGHT_QUANTIZE_START_BITS],
v[WEIGHT_QUANTIZE_TARGET_BITS],
v[WEIGHT_QUANTIZATION_PERIOD],
v[WEIGHT_QUANTIZE_IN_FORWARD_ENABLED],
v[WEIGHT_QUANTIZE_TYPE], v[WEIGHT_QUANTIZE_GROUPS])
elif k == CHANNEL_PRUNING:
if v[CHANNEL_PRUNING_ENABLED]:
new_module.enable_channel_pruning(v[CHANNEL_PRUNING_DENSE_RATIO],
v[CHANNEL_PRUNING_METHOD])
new_module.enable_channel_pruning(v[CHANNEL_PRUNING_DENSE_RATIO], v[CHANNEL_PRUNING_METHOD])
else:
raise NotImplementedError(
'Compression technique {} is not implemented'.format(k))
raise NotImplementedError('Compression technique {} is not implemented'.format(k))
# Replace the old module with the new one
recursive_setattr(model, module_name, new_module)
@ -195,10 +168,7 @@ def is_module_compressible(module, mpu=None):
isinstance(module, torch.nn.BatchNorm2d)
if mpu is not None:
ret = ret or isinstance(module,
mpu.RowParallelLinear) or isinstance(
module,
mpu.ColumnParallelLinear)
ret = ret or isinstance(module, mpu.RowParallelLinear) or isinstance(module, mpu.ColumnParallelLinear)
return ret
@ -225,11 +195,7 @@ def compression_preparation(model, compression_techinique_list, mpu):
return model
def fix_compression(model,
module_name,
compression_technique,
mask=None,
dim_reduction=False):
def fix_compression(model, module_name, compression_technique, mask=None, dim_reduction=False):
"""
Fix the compression technique of a module.
Args:
@ -243,17 +209,14 @@ def fix_compression(model,
# Here we can make things much simpler by just replacing the module
module = recursive_getattr(model, module_name)
for k, v in compression_technique.items():
if k == WEIGHT_QUANTIZATION and v[WEIGHT_QUANTIZE_IN_FORWARD_ENABLED] and v[
WEIGHT_QUANTIZE_ENABLED]:
if k == WEIGHT_QUANTIZATION and v[WEIGHT_QUANTIZE_IN_FORWARD_ENABLED] and v[WEIGHT_QUANTIZE_ENABLED]:
return module.fix_weight_quantization()
elif k == SPARSE_PRUNING and v[SPARSE_PRUNING_ENABLED]:
return module.fix_sparse_pruning_helper()
elif k == ROW_PRUNING and (v[ROW_PRUNING_ENABLED] or mask is not None):
return module.fix_row_col_pruning_helper(mask, dim_reduction=dim_reduction)
elif k == HEAD_PRUNING and (v[HEAD_PRUNING_ENABLED] or mask is not None):
return module.fix_head_pruning_helper(mask,
v[HEAD_PRUNING_NUM_HEADS],
dim_reduction=dim_reduction)
return module.fix_head_pruning_helper(mask, v[HEAD_PRUNING_NUM_HEADS], dim_reduction=dim_reduction)
elif k == CHANNEL_PRUNING and (v[CHANNEL_PRUNING_ENABLED] or mask is not None):
return module.fix_channel_pruning_helper(mask, dim_reduction=dim_reduction)
@ -270,10 +233,9 @@ def convert_conv1d_to_linear(model, convert_type):
for name, module in c_model.named_modules():
if isinstance(module, convert_type):
old_module = recursive_getattr(c_model, name)
new_module = torch.nn.Linear(
old_module.weight.data.size(0),
old_module.weight.data.size(1),
bias=True if old_module.bias is not None else False)
new_module = torch.nn.Linear(old_module.weight.data.size(0),
old_module.weight.data.size(1),
bias=True if old_module.bias is not None else False)
new_module.weight.data = old_module.weight.data.t().contiguous()
if new_module.bias is not None:
new_module.bias.data = old_module.bias.data.view(-1)
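
The convert_conv1d_to_linear path above relies on the transpose trick: a Conv1D-style layer stores its weight as (in_features, out_features), so the equivalent nn.Linear takes the transposed weight. A minimal sketch with a stand-in module (not the Hugging Face class):

import torch


class Conv1DLike(torch.nn.Module):
    """Stand-in for a Conv1D-style layer: weight stored as (in_features, out_features)."""

    def __init__(self, in_f, out_f):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(in_f, out_f))
        self.bias = torch.nn.Parameter(torch.randn(out_f))

    def forward(self, x):
        return x @ self.weight + self.bias


old_module = Conv1DLike(8, 16)
new_module = torch.nn.Linear(old_module.weight.data.size(0), old_module.weight.data.size(1), bias=True)
new_module.weight.data = old_module.weight.data.t().contiguous()
new_module.bias.data = old_module.bias.data.view(-1)

x = torch.randn(4, 8)
assert torch.allclose(old_module(x), new_module(x), atol=1e-6)
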

View File

@ -10,6 +10,7 @@ class compression_scheduler():
'''
Used to schedule different compression methods
'''
def __init__(self, model, compression_config):
self.model = model
self.compression_config = compression_config
@ -38,22 +39,22 @@ class compression_scheduler():
}
exist_module_name = set()
shared_parameters = method_content[SHARED_PARAMETERS]
self.different_compression_methods[method][
TECHNIQUE_ENABLED] = shared_parameters[TECHNIQUE_ENABLED]
self.different_compression_methods[method][
SHARED_PARAMETERS] = shared_parameters
self.different_compression_methods[method][TECHNIQUE_ENABLED] = shared_parameters[TECHNIQUE_ENABLED]
self.different_compression_methods[method][SHARED_PARAMETERS] = shared_parameters
for group_name, method_parameters in method_content[DIFFERENT_GROUPS].items():
module_name_list = []
for key_word in method_parameters[DIFFERENT_GROUPS_MODULE_SCOPE]:
module_name, exist_module_name = get_module_name(group_name, self.model, key_word, exist_module_name, verbose=False)
module_name, exist_module_name = get_module_name(group_name,
self.model,
key_word,
exist_module_name,
verbose=False)
module_name_list.extend(module_name)
if module_name_list:
self.different_compression_methods[method][DIFFERENT_GROUPS].append([
group_name,
module_name_list,
method_parameters.copy().pop('params')
])
self.different_compression_methods[method][DIFFERENT_GROUPS].append(
[group_name, module_name_list,
method_parameters.copy().pop('params')])
def check_weight_quantization(self):
# check weight quantization
@ -69,8 +70,7 @@ class compression_scheduler():
module.weight_quantization_enabled = True
if not self.verbose[WEIGHT_QUANTIZATION]:
logger.info(
f'Weight quantization is enabled at step {self.training_steps}')
logger.info(f'Weight quantization is enabled at step {self.training_steps}')
self.weight_quantization_enabled = True
self.verbose[WEIGHT_QUANTIZATION] = True
@ -87,9 +87,7 @@ class compression_scheduler():
module = recursive_getattr(self.model, module_name)
module.activation_quantization_enabled = True
if not self.verbose[ACTIVATION_QUANTIZATION]:
logger.info(
f'Activation quantization is enabled at step {self.training_steps}'
)
logger.info(f'Activation quantization is enabled at step {self.training_steps}')
self.verbose[ACTIVATION_QUANTIZATION] = True
def check_sparse_pruning(self):
@ -105,8 +103,7 @@ class compression_scheduler():
module = recursive_getattr(self.model, module_name)
module.sparse_pruning_enabled = True
if not self.verbose[SPARSE_PRUNING]:
logger.info(
f'Sparse pruning is enabled at step {self.training_steps}')
logger.info(f'Sparse pruning is enabled at step {self.training_steps}')
self.verbose[SPARSE_PRUNING] = True
def check_head_pruning(self):
@ -154,8 +151,7 @@ class compression_scheduler():
module = recursive_getattr(self.model, module_name)
module.channel_pruning_enabled = True
if not self.verbose[CHANNEL_PRUNING]:
logger.info(
f'Channel pruning is enabled at step {self.training_steps}')
logger.info(f'Channel pruning is enabled at step {self.training_steps}')
self.verbose[CHANNEL_PRUNING] = True
def check_all_modules(self):

View File

@ -13,6 +13,7 @@ class TopKBinarizer(autograd.Function):
Implementation is inspired from:
https://github.com/yaozhewei/MLPruning
"""
@staticmethod
def forward(ctx, inputs: torch.tensor, threshold: float, sigmoid: bool):
"""
@ -59,6 +60,7 @@ class SymQuantizer(torch.autograd.Function):
"""
Symmetric quantization
"""
@staticmethod
def forward(ctx, input, num_bits, min_value=None, max_value=None, num_groups=1):
"""
@ -75,9 +77,8 @@ class SymQuantizer(torch.autograd.Function):
quantized_input (`torch.FloatTensor`)
Quantized input
"""
assert (min_value is None
and max_value is None) or (min_value is not None
and max_value is not None and num_groups == 1)
assert (min_value is None and max_value is None) or (min_value is not None and max_value is not None
and num_groups == 1)
q_range = 2**num_bits
input_shape = input.shape
if min_value is None:
@ -101,6 +102,7 @@ class AsymQuantizer(torch.autograd.Function):
"""
Asymmetric quantization
"""
@staticmethod
def forward(ctx, input, num_bits, min_value=None, max_value=None, num_groups=1):
"""
@ -118,9 +120,8 @@ class AsymQuantizer(torch.autograd.Function):
Quantized input
"""
assert (min_value is None
and max_value is None) or (min_value is not None
and max_value is not None and num_groups == 1)
assert (min_value is None and max_value is None) or (min_value is not None and max_value is not None
and num_groups == 1)
q_range = 2**num_bits
input_shape = input.shape
if min_value is None:
@ -131,9 +132,7 @@ class AsymQuantizer(torch.autograd.Function):
scale = (max_value - min_value) / q_range
zero_point = (min_value / scale).round() * scale
output = (
(input - zero_point) / scale).round().clamp(0,
q_range - 1) * scale + zero_point
output = ((input - zero_point) / scale).round().clamp(0, q_range - 1) * scale + zero_point
output = output.reshape(input_shape).contiguous()
return output
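
To make the asymmetric fake-quantization arithmetic above concrete, here is a small numeric sketch (8-bit, single group, arbitrary values):

import torch

num_bits = 8
q_range = 2**num_bits
x = torch.tensor([-1.0, -0.5, 0.0, 0.7, 1.3])

min_value, max_value = x.min(), x.max()
scale = (max_value - min_value) / q_range
zero_point = (min_value / scale).round() * scale

# quantize onto a 256-level grid spanning [min, max], then dequantize
x_q = ((x - zero_point) / scale).round().clamp(0, q_range - 1) * scale + zero_point
print((x - x_q).abs().max())  # small residual rounding error on the order of scale
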
@ -147,6 +146,7 @@ class TernaryQuantizer(torch.autograd.Function):
"""
Ternary quantization
"""
@staticmethod
def forward(ctx, input, num_bits, min_value=None, max_value=None, num_groups=1):
"""
@ -187,6 +187,7 @@ class BinaryQuantizer(torch.autograd.Function):
"""
Binary quantization
"""
@staticmethod
def forward(ctx, input, num_bits, min_value=None, max_value=None, num_groups=1):
"""

View File

@ -43,77 +43,64 @@ class ElasticityConfig:
"version": 0.1
}
"""
def __init__(self, param_dict):
self.enabled = param_dict.get(ENABLED, ENABLED_DEFAULT)
if self.enabled:
if MAX_ACCEPTABLE_BATCH_SIZE in param_dict:
self.max_acceptable_batch_size = param_dict[MAX_ACCEPTABLE_BATCH_SIZE]
else:
raise ElasticityConfigError(
f"Elasticity config missing {MAX_ACCEPTABLE_BATCH_SIZE}")
raise ElasticityConfigError(f"Elasticity config missing {MAX_ACCEPTABLE_BATCH_SIZE}")
if MICRO_BATCHES in param_dict:
self.micro_batches = param_dict[MICRO_BATCHES]
else:
raise ElasticityConfigError(f"Elasticity config missing {MICRO_BATCHES}")
else:
self.max_acceptable_batch_size = param_dict.get(
MAX_ACCEPTABLE_BATCH_SIZE,
MAX_ACCEPTABLE_BATCH_SIZE_DEFAULT)
self.max_acceptable_batch_size = param_dict.get(MAX_ACCEPTABLE_BATCH_SIZE,
MAX_ACCEPTABLE_BATCH_SIZE_DEFAULT)
self.micro_batches = param_dict.get(MICRO_BATCHES, MICRO_BATCHES_DEFAULT)
if not isinstance(self.micro_batches, list):
raise ElasticityConfigError(
f"Elasticity expected value of {MICRO_BATCHES} to be a "
f"list of micro batches, instead is: {type(self.micro_batches)}, containing: {self.micro_batches}"
)
f"list of micro batches, instead is: {type(self.micro_batches)}, containing: {self.micro_batches}")
if not all(map(lambda m: isinstance(m, int), self.micro_batches)):
raise ElasticityConfigError(
f"Elasticity expected {MICRO_BATCHES} to only contain a list of integers, "
f"instead contains: f{self.micro_batches}")
raise ElasticityConfigError(f"Elasticity expected {MICRO_BATCHES} to only contain a list of integers, "
f"instead contains: f{self.micro_batches}")
if not all(map(lambda m: m > 0, self.micro_batches)):
raise ElasticityConfigError(
f"Elasticity expected {MICRO_BATCHES} to only contain positive integers, "
f"instead contains: f{self.micro_batches}")
raise ElasticityConfigError(f"Elasticity expected {MICRO_BATCHES} to only contain positive integers, "
f"instead contains: f{self.micro_batches}")
self.min_gpus = param_dict.get(MIN_GPUS, MIN_GPUS_DEFAULT)
self.max_gpus = param_dict.get(MAX_GPUS, MAX_GPUS_DEFAULT)
if self.min_gpus < 1 or self.max_gpus < 1:
raise ElasticityConfigError(
"Elasticity min/max gpus must be > 0, "
f"given min_gpus: {self.min_gpus}, max_gpus: {self.max_gpus}")
raise ElasticityConfigError("Elasticity min/max gpus must be > 0, "
f"given min_gpus: {self.min_gpus}, max_gpus: {self.max_gpus}")
if self.max_gpus < self.min_gpus:
raise ElasticityConfigError(
"Elasticity min_gpus cannot be greater than max_gpus, "
f"given min_gpus: {self.min_gpus}, max_gpus: {self.max_gpus}")
raise ElasticityConfigError("Elasticity min_gpus cannot be greater than max_gpus, "
f"given min_gpus: {self.min_gpus}, max_gpus: {self.max_gpus}")
self.model_parallel_size = param_dict.get(MODEL_PARLLEL_SIZE,
MODEL_PARLLEL_SIZE_DEFAULT)
self.model_parallel_size = param_dict.get(MODEL_PARLLEL_SIZE, MODEL_PARLLEL_SIZE_DEFAULT)
if self.model_parallel_size < 1:
raise ElasticityConfigError(
"Model-Parallel size cannot be less than 1, "
f"given model-parallel size: {self.model_parallel_size}")
raise ElasticityConfigError("Model-Parallel size cannot be less than 1, "
f"given model-parallel size: {self.model_parallel_size}")
self.num_gpus_per_node = param_dict.get(NUM_GPUS_PER_NODE,
NUM_GPUS_PER_NODE_DEFAULT)
self.num_gpus_per_node = param_dict.get(NUM_GPUS_PER_NODE, NUM_GPUS_PER_NODE_DEFAULT)
if self.num_gpus_per_node < 1:
raise ElasticityConfigError(
"Number of GPUs per node cannot be less than 1, "
f"given number of GPUs per node: {self.num_gpus_per_node}")
raise ElasticityConfigError("Number of GPUs per node cannot be less than 1, "
f"given number of GPUs per node: {self.num_gpus_per_node}")
self.min_time = param_dict.get(MIN_TIME, MIN_TIME_DEFAULT)
if self.min_time < 0:
raise ElasticityConfigError(
f"Elasticity min time needs to be >= 0: given {self.min_time}")
raise ElasticityConfigError(f"Elasticity min time needs to be >= 0: given {self.min_time}")
self.version = param_dict.get(VERSION, VERSION_DEFAULT)
self.prefer_larger_batch_size = param_dict.get(PREFER_LARGER_BATCH,
PREFER_LARGER_BATCH_DEFAULT)
self.ignore_non_elastic_batch_info = param_dict.get(
IGNORE_NON_ELASTIC_BATCH_INFO,
IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT)
self.prefer_larger_batch_size = param_dict.get(PREFER_LARGER_BATCH, PREFER_LARGER_BATCH_DEFAULT)
self.ignore_non_elastic_batch_info = param_dict.get(IGNORE_NON_ELASTIC_BATCH_INFO,
IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT)
def repr(self):
return self.__dict__

View File

@ -23,6 +23,7 @@ import subprocess
class DSElasticAgent(LocalElasticAgent):
def __init__(
self,
spec: WorkerSpec,
@ -35,9 +36,7 @@ class DSElasticAgent(LocalElasticAgent):
self.ds_env = env
@staticmethod
def _set_master_addr_port(store: Store,
master_addr: Optional[str],
master_port: Optional[int]):
def _set_master_addr_port(store: Store, master_addr: Optional[str], master_port: Optional[int]):
if master_port is None:
sock = _get_socket_with_port()
with closing(sock):
@ -82,8 +81,7 @@ class DSElasticAgent(LocalElasticAgent):
"TORCHELASTIC_MAX_RESTARTS": str(spec.max_restarts),
"TORCHELASTIC_RUN_ID": spec.rdzv_handler.get_run_id(),
"TORCHELASTIC_USE_AGENT_STORE": str(use_agent_store),
"NCCL_ASYNC_ERROR_HANDLING": os.getenv("NCCL_ASYNC_ERROR_HANDLING",
str(1)),
"NCCL_ASYNC_ERROR_HANDLING": os.getenv("NCCL_ASYNC_ERROR_HANDLING", str(1)),
}
worker_env_ds.update(worker_env_elastic)
if "OMP_NUM_THREADS" in os.environ:
@ -120,8 +118,7 @@ class DSElasticAgent(LocalElasticAgent):
spec = self._worker_group.spec
role = spec.role
log.info(
f"[{role}] starting workers for entrypoint: {spec.get_entrypoint_name()}")
log.info(f"[{role}] starting workers for entrypoint: {spec.get_entrypoint_name()}")
self._initialize_workers(self._worker_group)
monitor_interval = spec.monitor_interval
@ -136,13 +133,10 @@ class DSElasticAgent(LocalElasticAgent):
state = run_result.state
self._worker_group.state = state
expire_time = datetime.utcnow() - (
rdzv_handler._settings.keep_alive_interval *
rdzv_handler._settings.keep_alive_max_attempt)
expire_time = datetime.utcnow() - (rdzv_handler._settings.keep_alive_interval *
rdzv_handler._settings.keep_alive_max_attempt)
_dead_nodes = [
node for node,
last_heartbeat in
rdzv_handler._state_holder.state.last_heartbeats.items()
node for node, last_heartbeat in rdzv_handler._state_holder.state.last_heartbeats.items()
if last_heartbeat < expire_time
]
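
The heartbeat-expiry check reflowed above is easy to see in isolation; a small sketch with made-up settings (not the real rendezvous handler):

from datetime import datetime, timedelta

keep_alive_interval = timedelta(seconds=5)
keep_alive_max_attempt = 3
last_heartbeats = {
    "node-0": datetime.utcnow(),
    "node-1": datetime.utcnow() - timedelta(seconds=60),  # stale node
}

expire_time = datetime.utcnow() - keep_alive_interval * keep_alive_max_attempt
dead_nodes = [node for node, last_heartbeat in last_heartbeats.items() if last_heartbeat < expire_time]
print(dead_nodes)  # ['node-1']
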
@ -150,21 +144,16 @@ class DSElasticAgent(LocalElasticAgent):
put_metric(f"workers.{role}.{state.name.lower()}", 1)
if state == WorkerState.SUCCEEDED:
log.info(
f"[{role}] worker group successfully finished."
f" Waiting {self._exit_barrier_timeout} seconds for other agents to finish."
)
log.info(f"[{role}] worker group successfully finished."
f" Waiting {self._exit_barrier_timeout} seconds for other agents to finish.")
self._exit_barrier()
return run_result
elif state in {
WorkerState.UNHEALTHY,
WorkerState.FAILED
} or len(participants) > len(rdzv_handler._state_holder.state.participants):
elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED
} or len(participants) > len(rdzv_handler._state_holder.state.participants):
if self._remaining_restarts > 0:
log.info(
f"[{role}] Worker group {state.name}. "
f"{self._remaining_restarts}/{spec.max_restarts} attempts left;"
f" will restart worker group")
log.info(f"[{role}] Worker group {state.name}. "
f"{self._remaining_restarts}/{spec.max_restarts} attempts left;"
f" will restart worker group")
self._remaining_restarts -= 1
# rdzv_handler._state_holder.state.restart = False
self._restart_workers(self._worker_group)

View File

@ -17,44 +17,8 @@ from ..utils import logger
# Thirty eight smallest highly composite numbers. The list should
# be enough to support up to 720K batch size.
HCN_LIST = [
1,
2,
4,
6,
12,
24,
36,
48,
60,
120,
180,
240,
360,
720,
840,
1260,
1680,
2520,
5040,
7560,
10080,
15120,
20160,
25200,
27720,
45360,
50400,
55440,
83160,
110880,
166320,
221760,
277200,
332640,
498960,
554400,
665280,
720720
1, 2, 4, 6, 12, 24, 36, 48, 60, 120, 180, 240, 360, 720, 840, 1260, 1680, 2520, 5040, 7560, 10080, 15120, 20160,
25200, 27720, 45360, 50400, 55440, 83160, 110880, 166320, 221760, 277200, 332640, 498960, 554400, 665280, 720720
]
@ -94,11 +58,7 @@ def get_valid_gpus(batch_size, micro_batches, min_valid_gpus, max_valid_gpus):
return valid_gpus
def get_best_candidates(candidate_batch_sizes,
micro_batches,
min_gpus,
max_gpus,
prefer_larger):
def get_best_candidates(candidate_batch_sizes, micro_batches, min_gpus, max_gpus, prefer_larger):
max_valid_gpus = 0
valid_gpus = None
@ -106,15 +66,11 @@ def get_best_candidates(candidate_batch_sizes,
for batch_size in candidate_batch_sizes:
current_valid_gpus = get_valid_gpus(batch_size,
micro_batches,
min_gpus,
max_gpus)
current_valid_gpus = get_valid_gpus(batch_size, micro_batches, min_gpus, max_gpus)
if (len(current_valid_gpus) > max_valid_gpus
or (len(current_valid_gpus) == max_valid_gpus and
((prefer_larger and batch_size > final_batch_size) or
(not prefer_larger and batch_size < final_batch_size)))):
if (len(current_valid_gpus) > max_valid_gpus or (len(current_valid_gpus) == max_valid_gpus and
((prefer_larger and batch_size > final_batch_size) or
(not prefer_larger and batch_size < final_batch_size)))):
max_valid_gpus = len(current_valid_gpus)
valid_gpus = current_valid_gpus
final_batch_size = batch_size
@ -157,15 +113,10 @@ def _get_compatible_gpus_v01(micro_batches,
base_list.extend(micro_batches)
base_list.append(lcm)
candidate_batch_sizes = get_candidate_batch_sizes(base_list,
max_acceptable_batch_size)
candidate_batch_sizes = get_candidate_batch_sizes(base_list, max_acceptable_batch_size)
final_batch_size, valid_gpus = get_best_candidates(
candidate_batch_sizes,
micro_batches,
min_gpus,
max_gpus,
prefer_larger)
final_batch_size, valid_gpus = get_best_candidates(candidate_batch_sizes, micro_batches, min_gpus, max_gpus,
prefer_larger)
return final_batch_size, valid_gpus
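
Conceptually, the search above picks the candidate batch size whose divisors cover the most GPU counts. The sketch below shows the idea only; it is not the exact get_valid_gpus implementation.

def valid_gpu_counts(batch_size, micro_batches, min_gpus=1, max_gpus=1024):
    """GPU counts g such that batch_size is an integer multiple of micro_batch * g."""
    valid = set()
    for mb in micro_batches:
        if batch_size % mb:
            continue
        dp = batch_size // mb
        valid.update(g for g in range(min_gpus, min(dp, max_gpus) + 1) if dp % g == 0)
    return sorted(valid)

# 720 is a highly composite multiple of lcm(2, 3, 4), so many GPU counts fit evenly
print(valid_gpu_counts(720, [2, 3, 4])[:12])
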
@ -203,11 +154,12 @@ def _get_compatible_gpus_v02(micro_batches,
dp_size_per_node = num_gpus_per_node // model_parallel_size
final_batch_size, valid_world_size = _get_compatible_gpus_v01(micro_batches,
int(max_acceptable_batch_size/dp_size_per_node),
int(min_gpus/num_gpus_per_node),
int(max_gpus/num_gpus_per_node), # Passing number of max nodes as Elasticity v2 works at node level
prefer_larger=prefer_larger)
final_batch_size, valid_world_size = _get_compatible_gpus_v01(
micro_batches,
int(max_acceptable_batch_size / dp_size_per_node),
int(min_gpus / num_gpus_per_node),
int(max_gpus / num_gpus_per_node), # Passing number of max nodes as Elasticity v2 works at node level
prefer_larger=prefer_larger)
final_batch_size = int(final_batch_size) * dp_size_per_node
valid_dp_world_size = [i * dp_size_per_node for i in valid_world_size]
@ -256,38 +208,27 @@ def ensure_immutable_elastic_config(runtime_elastic_config_dict: dict):
Ensure the resource scheduler saw the same elastic config we are using at runtime
"""
if DEEPSPEED_ELASTICITY_CONFIG in os.environ:
scheduler_elastic_config_dict = json.loads(
os.environ[DEEPSPEED_ELASTICITY_CONFIG])
scheduler_elastic_config_dict = json.loads(os.environ[DEEPSPEED_ELASTICITY_CONFIG])
scheduler_elastic_config = ElasticityConfig(scheduler_elastic_config_dict)
runtime_elastic_config = ElasticityConfig(runtime_elastic_config_dict)
err_str = "Elastic config '{}={}' seen by resource scheduler does not match config passed to runtime {}={}"
if runtime_elastic_config.max_acceptable_batch_size != scheduler_elastic_config.max_acceptable_batch_size:
raise ElasticityConfigError(
err_str.format('max_acceptable_batch_size',
scheduler_elastic_config.max_acceptable_batch_size,
'max_acceptable_batch_size',
runtime_elastic_config.max_acceptable_batch_size))
err_str.format('max_acceptable_batch_size', scheduler_elastic_config.max_acceptable_batch_size,
'max_acceptable_batch_size', runtime_elastic_config.max_acceptable_batch_size))
if runtime_elastic_config.micro_batches != scheduler_elastic_config.micro_batches:
raise ElasticityConfigError(
err_str.format('micro_batches',
scheduler_elastic_config.micro_batches,
'micro_batches',
err_str.format('micro_batches', scheduler_elastic_config.micro_batches, 'micro_batches',
runtime_elastic_config.micro_batches))
if runtime_elastic_config.version != scheduler_elastic_config.version:
raise ElasticityConfigError(
err_str.format('version',
scheduler_elastic_config.version,
'version',
runtime_elastic_config.version))
err_str.format('version', scheduler_elastic_config.version, 'version', runtime_elastic_config.version))
else:
logger.warning("Unable to find DEEPSPEED_ELASTICITY_CONFIG environment variable, cannot " \
"guarantee resource scheduler will scale this job using compatible GPU counts.")
def compute_elastic_config(ds_config: dict,
target_deepspeed_version: str,
world_size=0,
return_microbatch=False):
def compute_elastic_config(ds_config: dict, target_deepspeed_version: str, world_size=0, return_microbatch=False):
"""Core deepspeed elasticity API. Given an elastic config (similar to the example below)
DeepSpeed will compute a total train batch size corresponding valid GPU count list that
provides a high level of elasticity. Elasticity in this case means we are safe to scale
@ -397,8 +338,7 @@ def compute_elastic_config(ds_config: dict,
# ensure batch size is int dtype
final_batch_size = int(final_batch_size)
else:
raise NotImplementedError(
f"Unable to find elastic logic for version: {elastic_config.version}")
raise NotImplementedError(f"Unable to find elastic logic for version: {elastic_config.version}")
logger.info(f"Valid World Size (GPUs / Model Parallel Size): {valid_gpus}")

View File

@ -48,8 +48,7 @@ def op_report(verbose=True):
dots = "." * (max_dots - len(op_name))
is_compatible = OKAY if builder.is_compatible(verbose) else no
is_installed = installed if installed_ops[op_name] else no
dots2 = '.' * ((len(h[1]) + (max_dots2 - len(h[1]))) -
(len(is_installed) - color_len))
dots2 = '.' * ((len(h[1]) + (max_dots2 - len(h[1]))) - (len(is_installed) - color_len))
print(op_name, dots, is_installed, dots2, is_compatible)
print("-" * (max_dots + max_dots2 + len(h[0]) + len(h[1])))
@ -68,9 +67,7 @@ def nvcc_version():
if cuda_home is None:
return f"{RED} [FAIL] cannot find CUDA_HOME via torch.utils.cpp_extension.CUDA_HOME={torch.utils.cpp_extension.CUDA_HOME} {END}"
try:
output = subprocess.check_output([cuda_home + "/bin/nvcc",
"-V"],
universal_newlines=True)
output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"], universal_newlines=True)
except FileNotFoundError:
return f"{RED} [FAIL] nvcc missing {END}"
output_split = output.split()
@ -82,32 +79,18 @@ def nvcc_version():
def debug_report():
max_dots = 33
report = [
("torch install path",
torch.__path__),
("torch version",
torch.__version__),
("deepspeed install path",
deepspeed.__path__),
("deepspeed info",
f"{deepspeed.__version__}, {deepspeed.__git_hash__}, {deepspeed.__git_branch__}"
)
]
report = [("torch install path", torch.__path__), ("torch version", torch.__version__),
("deepspeed install path", deepspeed.__path__),
("deepspeed info", f"{deepspeed.__version__}, {deepspeed.__git_hash__}, {deepspeed.__git_branch__}")]
if get_accelerator().device_name() == 'cuda':
hip_version = getattr(torch.version, "hip", None)
report.extend([("torch cuda version",
torch.version.cuda),
("torch hip version",
hip_version),
("nvcc version",
(None if hip_version else nvcc_version())),
("deepspeed wheel compiled w.",
f"torch {torch_info['version']}, " +
(f"hip {torch_info['hip_version']}"
if hip_version else f"cuda {torch_info['cuda_version']}"))])
report.extend([("torch cuda version", torch.version.cuda), ("torch hip version", hip_version),
("nvcc version", (None if hip_version else nvcc_version())),
("deepspeed wheel compiled w.", f"torch {torch_info['version']}, " +
(f"hip {torch_info['hip_version']}" if hip_version else f"cuda {torch_info['cuda_version']}"))
])
else:
report.extend([("deepspeed wheel compiled w.",
f"torch {torch_info['version']} ")])
report.extend([("deepspeed wheel compiled w.", f"torch {torch_info['version']} ")])
print("DeepSpeed general environment info:")
for name, value in report:
@ -116,15 +99,10 @@ def debug_report():
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument(
'--hide_operator_status',
action='store_true',
help=
'Suppress display of installation and compatibility statuses of DeepSpeed operators. '
)
parser.add_argument('--hide_errors_and_warnings',
parser.add_argument('--hide_operator_status',
action='store_true',
help='Suppress warning and error messages.')
help='Suppress display of installation and compatibility statuses of DeepSpeed operators. ')
parser.add_argument('--hide_errors_and_warnings', action='store_true', help='Suppress warning and error messages.')
args = parser.parse_args()
return args
@ -137,8 +115,7 @@ def main(hide_operator_status=False, hide_errors_and_warnings=False):
def cli_main():
args = parse_arguments()
main(hide_operator_status=args.hide_operator_status,
hide_errors_and_warnings=args.hide_errors_and_warnings)
main(hide_operator_status=args.hide_operator_status, hide_errors_and_warnings=args.hide_errors_and_warnings)
if __name__ == "__main__":

View File

@ -224,9 +224,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
replace_method: str = Field(
"auto",
deprecated=True,
deprecated_msg=
"This parameter is no longer needed, please remove from your call to DeepSpeed-inference"
)
deprecated_msg="This parameter is no longer needed, please remove from your call to DeepSpeed-inference")
injection_policy: Dict = Field(None, alias="injection_dict")
"""
@ -237,9 +235,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
injection_policy_tuple: tuple = None
""" TODO: Add docs """
config: Dict = Field(
None,
alias="args") # todo: really no need for this field if we can refactor
config: Dict = Field(None, alias="args") # todo: really no need for this field if we can refactor
max_out_tokens: int = Field(1024, alias="max_tokens")
"""
@ -256,18 +252,10 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
"""
mpu: object = Field(None, deprecated=True, new_param="tensor_parallel.mpu")
ep_size: int = Field(1, deprecated=True, new_param="moe.ep_size")
ep_group: object = Field(None,
alias="expert_group",
deprecated=True,
new_param="moe.ep_group")
ep_mp_group: object = Field(None,
alias="expert_mp_group",
deprecated=True,
new_param="moe.ep_mp_group")
ep_group: object = Field(None, alias="expert_group", deprecated=True, new_param="moe.ep_group")
ep_mp_group: object = Field(None, alias="expert_mp_group", deprecated=True, new_param="moe.ep_mp_group")
moe_experts: list = Field([1], deprecated=True, new_param="moe.moe_experts")
moe_type: MoETypeEnum = Field(MoETypeEnum.standard,
deprecated=True,
new_param="moe.type")
moe_type: MoETypeEnum = Field(MoETypeEnum.standard, deprecated=True, new_param="moe.type")
@validator("moe")
def moe_backward_compat(cls, field_value, values):

View File

@ -80,8 +80,7 @@ class InferenceEngine(Module):
self.checkpoint_engine = TorchCheckpointEngine()
quantization_setting = None
self._init_quantization_setting(
quantization_setting
) # todo: update with the new quant config for weight quant
quantization_setting) # todo: update with the new quant config for weight quant
self.model_profile_enabled = False
self._model_times = []
@ -100,8 +99,7 @@ class InferenceEngine(Module):
self._convert_to_dtype(config)
if self.mpu:
config.tensor_parallel.tp_size = dist.get_world_size(
group=self.mpu.get_model_parallel_group())
config.tensor_parallel.tp_size = dist.get_world_size(group=self.mpu.get_model_parallel_group())
self.mp_group = self.mpu.get_model_parallel_group()
elif config.tensor_parallel.tp_size > 1:
self._create_model_parallel_group(config)
@ -149,8 +147,7 @@ class InferenceEngine(Module):
self.module.to(device)
if config.tensor_parallel.tp_size > 1:
_rng_state = get_accelerator().get_rng_state().to(
get_accelerator().current_device_name())
_rng_state = get_accelerator().get_rng_state().to(get_accelerator().current_device_name())
dist.broadcast(_rng_state, 0)
get_accelerator().set_rng_state(_rng_state.cpu())
@ -172,9 +169,7 @@ class InferenceEngine(Module):
# todo: remove this once all the config dicts are centralized from top level pydantic config
def _get_model_config_generate(self, config):
# this is being passed to replace_transformer_layer(config=self.user_model_config_dict)
self.config = getattr(self.module,
'config',
None) if config.config is None else config.config
self.config = getattr(self.module, 'config', None) if config.config is None else config.config
def remove_mask_prepare_for_bloom(self):
if hasattr(self.module, 'transformer'):
@ -223,8 +218,7 @@ class InferenceEngine(Module):
num_ep_groups = dist.get_world_size() // moe_ep_size
for i in range(num_ep_groups):
ep_cnt = i * moe_ep_size
size = dist.get_world_size(
) if moe_ep_size > dist.get_world_size() else moe_ep_size
size = dist.get_world_size() if moe_ep_size > dist.get_world_size() else moe_ep_size
ranks = list(range(ep_cnt, ep_cnt + size))
_ep_group = dist.new_group(ranks)
if dist.get_rank() in ranks:
@ -234,9 +228,7 @@ class InferenceEngine(Module):
num_expert_mp_groups = dist.get_world_size() // num_ep_groups
expert_mp_size = dist.get_world_size() // moe_ep_size
for i in range(num_expert_mp_groups):
expert_mp_comm_ranks = [
i + nr * moe_ep_size for nr in range(expert_mp_size)
]
expert_mp_comm_ranks = [i + nr * moe_ep_size for nr in range(expert_mp_size)]
_expert_mp_group = dist.new_group(expert_mp_comm_ranks)
if dist.get_rank() in expert_mp_comm_ranks:
self.expert_mp_group.update({moe_ep_size: _expert_mp_group})
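
A worked example of the expert-parallel group arithmetic above, for world_size=8 and moe_ep_size=2 (pure Python, no process groups created):

world_size, moe_ep_size = 8, 2

num_ep_groups = world_size // moe_ep_size
ep_groups = [list(range(i * moe_ep_size, i * moe_ep_size + moe_ep_size))
             for i in range(num_ep_groups)]

num_expert_mp_groups = world_size // num_ep_groups
expert_mp_size = world_size // moe_ep_size
expert_mp_groups = [[i + nr * moe_ep_size for nr in range(expert_mp_size)]
                    for i in range(num_expert_mp_groups)]

print(ep_groups)         # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(expert_mp_groups)  # [[0, 2, 4, 6], [1, 3, 5, 7]]
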
@ -253,65 +245,48 @@ class InferenceEngine(Module):
log_dist(
f"quantize_bits = {self.quantize_bits} "
f"mlp_extra_grouping = {self.mlp_extra_grouping}, "
f"quantize_groups = {self.quantize_groups}",
[0])
f"quantize_groups = {self.quantize_groups}", [0])
# TODO: remove this function and add this functionality to pydantic config checking
def _validate_args(self, mpu, replace_with_kernel_inject):
# TODO: to support SD pipeline we need to avoid this check for now
if replace_with_kernel_inject and not isinstance(self.module, Module):
raise ValueError(f"model must be a torch.nn.Module, got {type(self.module)}")
if not isinstance(self._config.tensor_parallel.tp_size,
int) or self._config.tensor_parallel.tp_size < 1:
raise ValueError(
f"mp_size must be an int >= 1, got {self._config.tensor_parallel.tp_size}"
)
if not isinstance(self._config.tensor_parallel.tp_size, int) or self._config.tensor_parallel.tp_size < 1:
raise ValueError(f"mp_size must be an int >= 1, got {self._config.tensor_parallel.tp_size}")
if mpu:
methods = ["get_model_parallel_group", "get_data_parallel_group"]
for method in methods:
if not hasattr(mpu, method):
raise ValueError(f"mpu is missing {method}")
if self._config.checkpoint is not None and not isinstance(
self._config.checkpoint,
(str,
dict)):
raise ValueError(
f"checkpoint must be None, str or dict, got {type(self._config.checkpoint)}"
)
if self._config.checkpoint is not None and not isinstance(self._config.checkpoint, (str, dict)):
raise ValueError(f"checkpoint must be None, str or dict, got {type(self._config.checkpoint)}")
supported_dtypes = [None, torch.half, torch.int8, torch.float]
if self._config.dtype not in supported_dtypes:
raise ValueError(
f"{self._config.dtype} not supported, valid dtype: {supported_dtypes}")
raise ValueError(f"{self._config.dtype} not supported, valid dtype: {supported_dtypes}")
if self.injection_dict is not None and not isinstance(self.injection_dict, dict):
raise ValueError(
f"injection_dict must be None or a dict, got: {self.injection_dict}")
raise ValueError(f"injection_dict must be None or a dict, got: {self.injection_dict}")
def load_model_with_checkpoint(self, r_module):
self.mp_replace = ReplaceWithTensorSlicing(
mp_group=self.mp_group,
mp_size=self._config.tensor_parallel.tp_size) #, out_dim=0, in_dim=1)
mp_group=self.mp_group, mp_size=self._config.tensor_parallel.tp_size) #, out_dim=0, in_dim=1)
error_msgs = []
def load(module, state_dict, prefix):
args = (state_dict, prefix, {}, True, [], [], error_msgs)
if hasattr(module, 'weight'):
if 'query_key_value' in prefix:
module.weight = self.mp_replace.qkv_copy(
module.weight.data,
state_dict[prefix + 'weight'])
module.weight = self.mp_replace.qkv_copy(module.weight.data, state_dict[prefix + 'weight'])
else:
module.weight = self.mp_replace.copy(module.weight.data,
state_dict[prefix + 'weight'])
module.weight = self.mp_replace.copy(module.weight.data, state_dict[prefix + 'weight'])
else:
module.norm.weight = self.mp_replace.copy(module.norm.weight.data,
state_dict[prefix + 'weight'])
module.norm.weight = self.mp_replace.copy(module.norm.weight.data, state_dict[prefix + 'weight'])
if prefix + 'bias' in self.key_list:
if hasattr(module, 'norm'):
module.norm.bias = self.mp_replace.copy(module.norm.bias,
state_dict[prefix + 'bias'])
module.norm.bias = self.mp_replace.copy(module.norm.bias, state_dict[prefix + 'bias'])
else:
data = state_dict[prefix + 'bias']
data = data.to(get_accelerator().current_device_name())
@ -331,45 +306,32 @@ class InferenceEngine(Module):
checking_key = prefix + name + '.'
if not any(checking_key in item for item in self.key_list):
continue
if len(list(child.parameters())) > 0 and list(
child.parameters())[0].numel() == 0:
if len(list(child.parameters())) > 0 and list(child.parameters())[0].numel() == 0:
if len(child.weight.ds_shape) == 1:
child = Normalize(dim=child.weight.ds_shape[-1],
dtype=child.weight.dtype,
eps=child.eps)
child = Normalize(dim=child.weight.ds_shape[-1], dtype=child.weight.dtype, eps=child.eps)
setattr(module, name, child)
load(child, self.sd, prefix + name + '.')
else:
load_module_recursive(child,
prefix if level == 0 else prefix + name + '.',
level + 1)
load_module_recursive(child, prefix if level == 0 else prefix + name + '.', level + 1)
load_module_recursive(r_module)
def _apply_injection_policy(self, config, client_module=None):
# client_module is only passed when using the injection_dict method.
checkpoint_dir = config.checkpoint
checkpoint = SDLoaderFactory.get_sd_loader_json(
checkpoint_dir,
self.checkpoint_engine) if checkpoint_dir is not None else None
checkpoint = SDLoaderFactory.get_sd_loader_json(checkpoint_dir,
self.checkpoint_engine) if checkpoint_dir is not None else None
generic_injection(self.module,
fp16=(config.dtype == torch.half)
or (config.dtype == torch.int8),
fp16=(config.dtype == torch.half) or (config.dtype == torch.int8),
enable_cuda_graph=config.enable_cuda_graph)
if isinstance(self.module, torch.nn.Module):
# config is our DeepSpeedInferenceConfig and self.config is the HF model config
replace_transformer_layer(client_module,
self.module,
checkpoint,
config,
self.config)
replace_transformer_layer(client_module, self.module, checkpoint, config, self.config)
def _get_all_ckpt_names(self, checkpoints_path, tag):
ckpt_file_pattern = self._get_ckpt_name(checkpoints_path,
tag,
mp_placeholder="*")
ckpt_file_pattern = self._get_ckpt_name(checkpoints_path, tag, mp_placeholder="*")
import glob
ckpt_files = glob.glob(ckpt_file_pattern)
@ -392,8 +354,7 @@ class InferenceEngine(Module):
def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None):
is_pipe_parallel = isinstance(self.module, PipelineModule)
if is_pipe_parallel:
raise RuntimeError(
'pipeline parallelism is currently not supported in inference.')
raise RuntimeError('pipeline parallelism is currently not supported in inference.')
if not isinstance(load_dir, dict) and os.path.isdir(load_dir):
if tag is None:
latest_path = os.path.join(load_dir, "latest")
@ -404,8 +365,7 @@ class InferenceEngine(Module):
ckpt_list = self._get_all_ckpt_names(load_dir, tag)
sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, self.checkpoint_engine)
else:
sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir,
self.checkpoint_engine)
sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir, self.checkpoint_engine)
if type(sd_loader) is list:
self.sd = torch.load(sd_loader[0], map_location='cpu')
@ -416,19 +376,18 @@ class InferenceEngine(Module):
for i in range(1, len(sd_loader)):
if not dist.is_initialized() or dist.get_rank() == 0:
print(f"loading checkpoint ({i})")
self.sd = torch.load(sd_loader[i],
map_location=get_accelerator().device_name())
self.sd = torch.load(sd_loader[i], map_location=get_accelerator().device_name())
self.key_list = list(self.sd.keys())
self.load_model_with_checkpoint(self.module)
else:
mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
load_path, checkpoint, quantize_config = sd_loader.load(self._config.tensor_parallel.tp_size,
mp_rank,
is_pipe_parallel=is_pipe_parallel,
quantize=(self._config.dtype is torch.int8),
quantize_groups=self.quantize_groups,
mlp_extra_grouping=self.mlp_extra_grouping)
mp_rank,
is_pipe_parallel=is_pipe_parallel,
quantize=(self._config.dtype is torch.int8),
quantize_groups=self.quantize_groups,
mlp_extra_grouping=self.mlp_extra_grouping)
self.quantization_scales, self.quantize_merge_count = quantize_config
@ -438,21 +397,20 @@ class InferenceEngine(Module):
old_moe_load = False
if not isinstance(checkpoint['num_experts'], list):
old_moe_load = True
DeepSpeedEngine.load_moe_state_dict(
load_dir,
tag,
state_dict=checkpoint[self._choose_module_key(checkpoint)],
old_moe_load=old_moe_load,
model=self.module,
mpu=self.mpu,
checkpoint_engine=self.checkpoint_engine)
DeepSpeedEngine.load_moe_state_dict(load_dir,
tag,
state_dict=checkpoint[self._choose_module_key(checkpoint)],
old_moe_load=old_moe_load,
model=self.module,
mpu=self.mpu,
checkpoint_engine=self.checkpoint_engine)
self.module.load_state_dict(
state_dict=checkpoint[self._choose_module_key(checkpoint)],
strict=load_module_strict)
self.module.load_state_dict(state_dict=checkpoint[self._choose_module_key(checkpoint)],
strict=load_module_strict)
def _choose_module_key(self, sd):
assert not ('module' in sd and 'model' in sd), "checkpoint has both 'model' and 'module' keys, not sure how to proceed"
assert not ('module' in sd
and 'model' in sd), "checkpoint has both 'model' and 'module' keys, not sure how to proceed"
assert 'module' in sd or 'model' in sd, "checkpoint contains neither 'model' or 'module' keys, not sure how to proceed"
if 'module' in sd:
return 'module'
@ -465,10 +423,8 @@ class InferenceEngine(Module):
if False: #config.dtype is torch.int8 and self.quantization_scales is None:
quantizer = WeightQuantization(mlp_extra_grouping=self.mlp_extra_grouping)
model, self.quantization_scales = quantizer.model_quantize(self.module,
self.injection_dict,
self.quantize_bits,
self.quantize_groups)
model, self.quantization_scales = quantizer.model_quantize(self.module, self.injection_dict,
self.quantize_bits, self.quantize_groups)
elif config.dtype == torch.half:
self.module.half()
elif config.dtype == torch.bfloat16:
@ -509,11 +465,10 @@ class InferenceEngine(Module):
assert self.model_profile_enabled, "model profiling is not enabled"
model_times = self._model_times
if self._config.enable_cuda_graph and len(self._model_times) == 0:
raise ValueError(
"Model times are empty and cuda graph is enabled. If "
"this is a GPT-style model this combo is not supported. If this is a "
"BERT-style model this is a bug, please report it. "
f"Model type is: {type(self.module)}")
raise ValueError("Model times are empty and cuda graph is enabled. If "
"this is a GPT-style model this combo is not supported. If this is a "
"BERT-style model this is a bug, please report it. "
f"Model type is: {type(self.module)}")
self._model_times = []
return model_times
@ -532,8 +487,7 @@ class InferenceEngine(Module):
for name in module.__dict__.keys():
sub_module = getattr(module, name)
if self._module_match(sub_module) and hasattr(sub_module,
"enable_cuda_graph"):
if self._module_match(sub_module) and hasattr(sub_module, "enable_cuda_graph"):
sub_module_cuda_graph = True
return sub_module_cuda_graph
@ -546,13 +500,11 @@ class InferenceEngine(Module):
**kwargs: variable length keyword arguments
"""
start = None
if self.model_profile_enabled and get_accelerator().device_name(
) == 'cuda' and self._config.enable_cuda_graph:
if self.model_profile_enabled and get_accelerator().device_name() == 'cuda' and self._config.enable_cuda_graph:
get_accelerator().synchronize()
start = time.time()
if get_accelerator().device_name(
) == 'cuda' and self._config.enable_cuda_graph and not self.local_cuda_graph:
if get_accelerator().device_name() == 'cuda' and self._config.enable_cuda_graph and not self.local_cuda_graph:
if self.cuda_graph_created:
outputs = self._graph_replay(*inputs, **kwargs)
else:
@ -580,9 +532,7 @@ class InferenceEngine(Module):
num_beams = kwargs["num_beams"]
if num_beams > 1:
raise NotImplementedError(
"DeepSpeed does not support `num_beams` > 1, if this is important to you please "
"add your request to: https://github.com/microsoft/DeepSpeed/issues/2506"
)
raise NotImplementedError("DeepSpeed does not support `num_beams` > 1, if this is important to you please "
"add your request to: https://github.com/microsoft/DeepSpeed/issues/2506")
return self.module.generate(*inputs, **kwargs)

View File

@ -52,10 +52,7 @@ def parse_args():
help="Master node (rank 0)'s free port that needs to "
"be used for communication during distributed "
"training")
parser.add_argument("--world_info",
default="None",
type=str,
help="world info base64 encoded dictionary")
parser.add_argument("--world_info", default="None", type=str, help="world info base64 encoded dictionary")
parser.add_argument("--module",
action="store_true",
@ -68,19 +65,11 @@ def parse_args():
help="Skip prepending the training script with "
"'python' - just execute it directly.")
parser.add_argument("--enable_elastic_training",
action="store_true",
help="Enable elastic training support.")
parser.add_argument("--enable_elastic_training", action="store_true", help="Enable elastic training support.")
parser.add_argument("--min_elastic_nodes",
type=int,
default=-1,
help="Min number of nodes in elastic training.")
parser.add_argument("--min_elastic_nodes", type=int, default=-1, help="Min number of nodes in elastic training.")
parser.add_argument("--max_elastic_nodes",
type=int,
default=-1,
help="Max number of nodes in elastic training.")
parser.add_argument("--max_elastic_nodes", type=int, default=-1, help="Max number of nodes in elastic training.")
parser.add_argument("--no_local_rank",
action="store_true",
@ -92,11 +81,10 @@ def parse_args():
default=0,
help="main launching process pid, for internal pid tracking")
parser.add_argument(
"--enable_each_rank_log",
default="None",
type=str,
help="redirect the stdout and stderr from each rank into different log files")
parser.add_argument("--enable_each_rank_log",
default="None",
type=str,
help="redirect the stdout and stderr from each rank into different log files")
# positional
parser.add_argument("training_script",
@ -145,9 +133,7 @@ def main():
local_node = node_list[args.node_rank]
local_gpu_ids = world_info[local_node]
num_local_procs = len(local_gpu_ids)
logger.info(
f"nnodes={args.nnodes}, num_local_procs={num_local_procs}, node_rank={args.node_rank}"
)
logger.info(f"nnodes={args.nnodes}, num_local_procs={num_local_procs}, node_rank={args.node_rank}")
global_rank_mapping = defaultdict(list)
curr_global_rank = 0
@ -193,8 +179,7 @@ def main():
lines = file.readlines()
lines = [line.rstrip() for line in lines]
for line in lines:
if line.startswith('export FC_TASKROLE_NAME') or line.startswith(
'export FC_TASK_INDEX'):
if line.startswith('export FC_TASKROLE_NAME') or line.startswith('export FC_TASK_INDEX'):
key_val = line.split()[1]
key, val = key_val.split('=')
current_env[key] = val
@ -206,17 +191,13 @@ def main():
if args.enable_each_rank_log != "None":
# prepare the log path and the file name prefix
if os.path.isfile(args.enable_each_rank_log):
raise ValueError(
f"{args.enable_each_rank_log} should not be a file, it should be a directory."
)
raise ValueError(f"{args.enable_each_rank_log} should not be a file, it should be a directory.")
if not os.path.exists(args.enable_each_rank_log):
try:
os.makedirs(args.enable_each_rank_log)
except Exception as e:
print(e)
raise ValueError(
f"unable to create directory {args.enable_each_rank_log} for each rank log."
)
raise ValueError(f"unable to create directory {args.enable_each_rank_log} for each rank log.")
log_name_prefix = time.strftime("%Y%m%d%H%M%S", time.localtime())
for local_rank in range(0, num_local_procs):
@ -242,13 +223,9 @@ def main():
cmd += args.training_script_args
if args.enable_each_rank_log != "None":
log_file = os.path.join(args.enable_each_rank_log,
f"{log_name_prefix}_rank{dist_rank}.log")
log_file = os.path.join(args.enable_each_rank_log, f"{log_name_prefix}_rank{dist_rank}.log")
log_fd = open(log_file, 'w')
process = subprocess.Popen(cmd,
env=current_env,
stdout=log_fd,
stderr=log_fd)
process = subprocess.Popen(cmd, env=current_env, stdout=log_fd, stderr=log_fd)
else:
process = subprocess.Popen(cmd, env=current_env)
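
The per-rank log redirection above follows a standard pattern; here is a self-contained sketch with a placeholder command and directory (not DeepSpeed's real launch command):

import os
import subprocess
import sys
import time

log_dir = "rank_logs"
os.makedirs(log_dir, exist_ok=True)
log_name_prefix = time.strftime("%Y%m%d%H%M%S", time.localtime())

processes = []
for local_rank in range(2):
    env = os.environ.copy()
    env["LOCAL_RANK"] = str(local_rank)
    cmd = [sys.executable, "-c", "import os; print('rank', os.environ['LOCAL_RANK'], 'up')"]
    log_file = os.path.join(log_dir, f"{log_name_prefix}_rank{local_rank}.log")
    log_fd = open(log_file, 'w')
    # each rank's stdout/stderr lands in its own file
    processes.append(subprocess.Popen(cmd, env=env, stdout=log_fd, stderr=log_fd))

for p in processes:
    p.wait()
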
@ -264,7 +241,7 @@ def main():
args.min_elastic_nodes = 1
if args.max_elastic_nodes == -1:
args.max_elastic_nodes = args.nnodes
assert args.max_elastic_nodes > 0 and args.min_elastic_nodes > 0 , "Max and Min nodes should be positive"
assert args.max_elastic_nodes > 0 and args.min_elastic_nodes > 0, "Max and Min nodes should be positive"
current_env["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
@ -287,8 +264,7 @@ def main():
# Creating config for rendezvous class
rdzv_parameters = RendezvousParameters(backend='c10d',
endpoint=args.master_addr + ":" +
str(args.master_port),
endpoint=args.master_addr + ":" + str(args.master_port),
run_id=run_id,
min_nodes=args.min_elastic_nodes,
max_nodes=args.max_elastic_nodes,

View File

@ -13,6 +13,7 @@ from .constants import PDSH_MAX_FAN_OUT, MVAPICH_TMP_HOSTFILE
class MultiNodeRunner(ABC):
def __init__(self, args, world_info_base64):
self.args = args
self.validate_args()
@ -45,6 +46,7 @@ class MultiNodeRunner(ABC):
class PDSHRunner(MultiNodeRunner):
def __init__(self, args, world_info_base64):
super().__init__(args, world_info_base64)
@ -56,9 +58,7 @@ class PDSHRunner(MultiNodeRunner):
return "pdsh"
def parse_user_args(self):
return list(
map(lambda x: x if x.startswith("-") else f"'{x}'",
self.args.user_args))
return list(map(lambda x: x if x.startswith("-") else f"'{x}'", self.args.user_args))
def get_cmd(self, environment, active_resources):
environment['PDSH_RCMD_TYPE'] = 'ssh'
@ -68,14 +68,8 @@ class PDSHRunner(MultiNodeRunner):
# PDSH flags for max node fan out and specific hosts to launch on
# See https://linux.die.net/man/1/pdsh for flag details
pdsh_cmd_args = [
'pdsh',
'-S',
'-f',
str(PDSH_MAX_FAN_OUT),
'-w',
active_workers
] + split(self.args.launcher_args)
pdsh_cmd_args = ['pdsh', '-S', '-f', str(PDSH_MAX_FAN_OUT), '-w', active_workers] + split(
self.args.launcher_args)
exports = ""
for key, val in self.exports.items():
@ -84,15 +78,8 @@ class PDSHRunner(MultiNodeRunner):
# https://linux.die.net/man/1/pdsh
# %n will be replaced by pdsh command
deepspeed_launch = [
exports,
f"cd {os.path.abspath('.')};",
sys.executable,
"-u",
"-m",
"deepspeed.launcher.launch",
f'--world_info={self.world_info_base64}',
"--node_rank=%n",
f"--master_addr={self.args.master_addr}",
exports, f"cd {os.path.abspath('.')};", sys.executable, "-u", "-m", "deepspeed.launcher.launch",
f'--world_info={self.world_info_base64}', "--node_rank=%n", f"--master_addr={self.args.master_addr}",
f"--master_port={self.args.master_port}"
]
if self.args.no_python:
@ -111,11 +98,11 @@ class PDSHRunner(MultiNodeRunner):
cmd_to_search = [i + "\\" for i in deepspeed_launch[2:6]]
kill_command = pdsh_cmd_args + ["pkill -f ", " ".join(cmd_to_search)[:-2]]
return pdsh_cmd_args + deepspeed_launch + [self.user_script
] + self.user_arguments, kill_command
return pdsh_cmd_args + deepspeed_launch + [self.user_script] + self.user_arguments, kill_command
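
To see what the list concatenation above produces, a sketch with placeholder hosts, ports, and script names (illustrative values only, not a real launch command):

import sys

PDSH_MAX_FAN_OUT = 1024
active_workers = "worker-1,worker-2"    # placeholder hosts
exports = "export NCCL_DEBUG=INFO;"     # placeholder exports string

pdsh_cmd_args = ['pdsh', '-S', '-f', str(PDSH_MAX_FAN_OUT), '-w', active_workers]
deepspeed_launch = [
    exports, "cd /job/dir;", sys.executable, "-u", "-m", "deepspeed.launcher.launch",
    "--world_info=<base64>", "--node_rank=%n", "--master_addr=10.0.0.1", "--master_port=29500"
]
user_script, user_arguments = "train.py", ["--deepspeed_config", "ds_config.json"]

cmd = pdsh_cmd_args + deepspeed_launch + [user_script] + user_arguments
cmd_to_search = [i + "\\" for i in deepspeed_launch[2:6]]
kill_command = pdsh_cmd_args + ["pkill -f ", " ".join(cmd_to_search)[:-2]]
print(" ".join(cmd))
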
class OpenMPIRunner(MultiNodeRunner):
def __init__(self, args, world_info_base64, resource_pool):
super().__init__(args, world_info_base64)
self.resource_pool = resource_pool
@ -133,11 +120,9 @@ class OpenMPIRunner(MultiNodeRunner):
super().validate_args()
#TODO: Allow for include/exclude at node-level but not gpu-level
if self.args.include != "" or self.args.exclude != "":
raise ValueError(
f"{self.name} backend does not support worker include/exclusion")
raise ValueError(f"{self.name} backend does not support worker include/exclusion")
if self.args.num_nodes != -1 or self.args.num_gpus != -1:
raise ValueError(
f"{self.name} backend does not support limiting num nodes/gpus")
raise ValueError(f"{self.name} backend does not support limiting num nodes/gpus")
def get_cmd(self, environment, active_resources):
total_process_count = sum(self.resource_pool.values())
@ -166,11 +151,11 @@ class OpenMPIRunner(MultiNodeRunner):
if self.args.module:
python_exec.append("-m")
return mpirun_cmd + export_cmd + python_exec + [self.user_script
] + self.user_arguments
return mpirun_cmd + export_cmd + python_exec + [self.user_script] + self.user_arguments
class MPICHRunner(MultiNodeRunner):
def __init__(self, args, world_info_base64, resource_pool):
super().__init__(args, world_info_base64)
self.resource_pool = resource_pool
@ -187,12 +172,10 @@ class MPICHRunner(MultiNodeRunner):
super().validate_args()
#TODO: Allow for include/exclude at node-level but not gpu-level
if self.args.include != "" or self.args.exclude != "":
raise ValueError(
f"{self.name} backend does not support worker include/exclusion")
raise ValueError(f"{self.name} backend does not support worker include/exclusion")
if self.args.num_nodes != -1 or self.args.num_gpus != -1:
raise ValueError(
f"{self.name} backend does not support limiting num nodes/gpus")
raise ValueError(f"{self.name} backend does not support limiting num nodes/gpus")
def get_cmd(self, environment, active_resources):
devices_per_node = self.resource_pool.values()
@ -220,6 +203,7 @@ class MPICHRunner(MultiNodeRunner):
class SlurmRunner(MultiNodeRunner):
def __init__(self, args, world_info_base64, resource_pool):
super().__init__(args, world_info_base64)
self.resource_pool = resource_pool
@ -232,7 +216,8 @@ class SlurmRunner(MultiNodeRunner):
return 'slurm'
def get_cmd(self, environment, active_resources):
assert not getattr(self.args, 'detect_nvlink_pairs', False), "slurm backend does not support remapping visible devices"
assert not getattr(self.args, 'detect_nvlink_pairs',
False), "slurm backend does not support remapping visible devices"
total_process_count = sum(self.resource_pool.values())
srun_cmd = [
'srun',
@ -261,12 +246,12 @@ class SlurmRunner(MultiNodeRunner):
exports += f",{key}={val}"
python_exec = [sys.executable, "-u"]
command = srun_cmd + [exports] + python_exec + [self.user_script
] + self.user_arguments
command = srun_cmd + [exports] + python_exec + [self.user_script] + self.user_arguments
return command
class MVAPICHRunner(MultiNodeRunner):
def __init__(self, args, world_info_base64, resource_pool):
super().__init__(args, world_info_base64)
self.resource_pool = resource_pool
@ -303,9 +288,7 @@ class MVAPICHRunner(MultiNodeRunner):
if "MVAPICH2-GDR" in mpiname_results:
exists = True
else:
warnings.warn(
f"Expected MVAPICH2-GDR as return for mpiname but received {mpiname_results}"
)
warnings.warn(f"Expected MVAPICH2-GDR as return for mpiname but received {mpiname_results}")
return exists
@property
@ -316,11 +299,9 @@ class MVAPICHRunner(MultiNodeRunner):
super().validate_args()
#TODO: Allow for include/exclude at node-level but not gpu-level
if self.args.include != "" or self.args.exclude != "":
raise ValueError(
f"{self.name} backend does not support worker include/exclusion")
raise ValueError(f"{self.name} backend does not support worker include/exclusion")
if self.args.num_nodes != -1 or self.args.num_gpus != -1:
raise ValueError(
f"{self.name} backend does not support limiting num nodes/gpus")
raise ValueError(f"{self.name} backend does not support limiting num nodes/gpus")
def get_cmd(self, environment, active_resources):
devices_per_node = self.resource_pool.values()
@ -353,5 +334,4 @@ class MVAPICHRunner(MultiNodeRunner):
if self.args.module:
python_exec.append("-m")
return mpirun_cmd + export_cmd + python_exec + [self.user_script
] + self.user_arguments
return mpirun_cmd + export_cmd + python_exec + [self.user_script] + self.user_arguments

View File

@ -36,9 +36,8 @@ PDSH_MAX_FAN_OUT = 1024
def parse_args(args=None):
parser = argparse.ArgumentParser(
description="DeepSpeed runner to help launch distributed "
"multi-node/multi-gpu training jobs.")
parser = argparse.ArgumentParser(description="DeepSpeed runner to help launch distributed "
"multi-node/multi-gpu training jobs.")
parser.add_argument("-H",
"--hostfile",
@ -109,12 +108,11 @@ def parse_args(args=None):
help="(optional) IP address of node 0, will be "
"inferred via 'hostname -I' if not specified.")
parser.add_argument(
"--launcher",
default=PDSH_LAUNCHER,
type=str,
help="(optional) choose launcher backend for multi-node "
"training. Options currently include PDSH, OpenMPI, MVAPICH, SLURM, MPICH.")
parser.add_argument("--launcher",
default=PDSH_LAUNCHER,
type=str,
help="(optional) choose launcher backend for multi-node "
"training. Options currently include PDSH, OpenMPI, MVAPICH, SLURM, MPICH.")
parser.add_argument("--launcher_args",
default="",
@ -147,35 +145,29 @@ def parse_args(args=None):
help="Force multi-node launcher mode, helps in cases where user "
"wants to launch on single remote node.")
parser.add_argument(
"--save_pid",
action="store_true",
help="Save file containing launcher process id (pid) at /tmp/<main-pid>.ds, "
"where <main-pid> is the pid of the first process that invoked `deepspeed`. "
"Useful when launching deepspeed processes programmatically.")
parser.add_argument("--save_pid",
action="store_true",
help="Save file containing launcher process id (pid) at /tmp/<main-pid>.ds, "
"where <main-pid> is the pid of the first process that invoked `deepspeed`. "
"Useful when launching deepspeed processes programmatically.")
parser.add_argument(
"--enable_each_rank_log",
default="None",
type=str,
help="redirect the stdout and stderr from each rank into different log files")
parser.add_argument("--enable_each_rank_log",
default="None",
type=str,
help="redirect the stdout and stderr from each rank into different log files")
parser.add_argument(
"--autotuning",
default="",
choices=["tune",
"run"],
type=str,
help="Run DeepSpeed autotuner to discover optimal configuration parameters "
"before running job.")
parser.add_argument("--autotuning",
default="",
choices=["tune", "run"],
type=str,
help="Run DeepSpeed autotuner to discover optimal configuration parameters "
"before running job.")
parser.add_argument("--elastic_training",
action="store_true",
help="Enable elastic training support in DeepSpeed.")
parser.add_argument("user_script",
type=str,
help="User script to launch, followed by any required "
parser.add_argument("user_script", type=str, help="User script to launch, followed by any required "
"arguments.")
parser.add_argument('user_args', nargs=argparse.REMAINDER)
return parser.parse_args(args=args)
@ -213,21 +205,15 @@ def _parse_hostfile(hostfile_lines):
num_slots = int(match.group(2))
if host in resource_pool:
logger.error(f"Bad hostfile text: {hostfile_lines}")
raise ValueError(
f"Hostfile contains multiple entries for {host}, unable to proceed with launching"
)
raise ValueError(f"Hostfile contains multiple entries for {host}, unable to proceed with launching")
resource_pool[host] = num_slots
else:
logger.error(f"Bad hostfile text: {hostfile_lines}")
raise ValueError(
"Hostfile contains a bad entry: {line}, unable to proceed with launching"
)
            raise ValueError(f"Hostfile contains a bad entry: {line}, unable to proceed with launching")
if len(resource_pool) == 0:
logger.error(f"Bad hostfile text: {hostfile_lines}")
raise ValueError(
"Hostfile is empty or not formatted correctly, unable to proceed with launching."
)
raise ValueError("Hostfile is empty or not formatted correctly, unable to proceed with launching.")
return resource_pool
@ -337,9 +323,7 @@ def parse_inclusion_exclusion(resource_pool, inclusion, exclusion):
for hostname, slots in resource_pool.items():
active_resources[hostname] = list(range(slots))
return parse_resource_filter(active_resources,
include_str=inclusion,
exclude_str=exclusion)
return parse_resource_filter(active_resources, include_str=inclusion, exclude_str=exclusion)
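The resource pool built above is a plain mapping from hostname to slot count, parsed from lines of the form "worker-1 slots=8". A sketch of that parsing under the same assumed line format (parse_hostfile_lines is an illustrative name, not the DeepSpeed function):

import re

def parse_hostfile_lines(lines):
    pool = {}
    for line in lines:
        line = line.strip()
        if not line or line.startswith('#'):
            continue                                   # skip blanks and comments
        match = re.match(r'^(\S+)\s+slots=(\d+)', line)
        if match is None:
            raise ValueError(f"Bad hostfile entry: {line}")
        host, slots = match.group(1), int(match.group(2))
        if host in pool:
            raise ValueError(f"Duplicate hostfile entry for {host}")
        pool[host] = slots
    if not pool:
        raise ValueError("Hostfile is empty or not formatted correctly")
    return pool

print(parse_hostfile_lines(["worker-1 slots=8", "worker-2 slots=8"]))
# {'worker-1': 8, 'worker-2': 8}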
def encode_world_info(world_info):
@ -389,8 +373,7 @@ def main(args=None):
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")
if not resource_pool and len(cuda_visible_devices):
detected_str = f"Detected CUDA_VISIBLE_DEVICES={cuda_visible_devices}"
if len(args.include) or len(
args.exclude) or args.num_nodes > 1 or args.num_gpus > 0:
if len(args.include) or len(args.exclude) or args.num_nodes > 1 or args.num_gpus > 0:
print(
f"{detected_str} but ignoring it because one or several of --include/--exclude/--num_gpus/--num_nodes cl args were used. If you want to use CUDA_VISIBLE_DEVICES don't pass any of these arguments to deepspeed."
)
@ -416,20 +399,17 @@ def main(args=None):
if not multi_node_exec and args.num_nodes > 1:
raise ValueError("Num nodes is >1 but no extra nodes available via hostfile")
active_resources = parse_inclusion_exclusion(resource_pool,
args.include,
args.exclude)
active_resources = parse_inclusion_exclusion(resource_pool, args.include, args.exclude)
env = os.environ.copy()
    # validate that passwordless ssh is working properly with this hostfile
if multi_node_exec and not args.no_ssh_check:
first_host = list(active_resources.keys())[0]
try:
subprocess.check_call(
f'ssh -o PasswordAuthentication=no {first_host} hostname',
stderr=subprocess.DEVNULL,
stdout=subprocess.DEVNULL,
shell=True)
subprocess.check_call(f'ssh -o PasswordAuthentication=no {first_host} hostname',
stderr=subprocess.DEVNULL,
stdout=subprocess.DEVNULL,
shell=True)
except subprocess.CalledProcessError:
raise RuntimeError(
f"Using hostfile at {args.hostfile} but host={first_host} was not reachable via ssh. If you are running with a single node please remove {args.hostfile} or setup passwordless ssh."
@ -481,13 +461,8 @@ def main(args=None):
if not multi_node_exec:
deepspeed_launch = [
sys.executable,
"-u",
"-m",
"deepspeed.launcher.launch",
f"--world_info={world_info_base64}",
f"--master_addr={args.master_addr}",
f"--master_port={args.master_port}"
sys.executable, "-u", "-m", "deepspeed.launcher.launch", f"--world_info={world_info_base64}",
f"--master_addr={args.master_addr}", f"--master_port={args.master_port}"
]
if args.no_python:
deepspeed_launch.append("--no_python")
@ -498,8 +473,7 @@ def main(args=None):
if args.save_pid:
deepspeed_launch += ["--save_pid", f"{os.getpid()}"]
if args.enable_each_rank_log:
deepspeed_launch.append(
f"--enable_each_rank_log={args.enable_each_rank_log}")
deepspeed_launch.append(f"--enable_each_rank_log={args.enable_each_rank_log}")
if args.elastic_training:
deepspeed_launch.append("--enable_elastic_training")
deepspeed_launch.append(f"--max_elastic_nodes={args.max_elastic_nodes}")

View File

@ -6,6 +6,7 @@ from ..features.cuda_graph import CUDAGraph
class DSUNet(CUDAGraph, torch.nn.Module):
def __init__(self, unet, enable_cuda_graph=True):
super().__init__(enable_cuda_graph=enable_cuda_graph)
self.unet = unet

View File

@ -6,6 +6,7 @@ from ..features.cuda_graph import CUDAGraph
class DSVAE(CUDAGraph, torch.nn.Module):
def __init__(self, vae, enable_cuda_graph=True):
super().__init__(enable_cuda_graph=enable_cuda_graph)
self.vae = vae
@ -44,8 +45,7 @@ class DSVAE(CUDAGraph, torch.nn.Module):
self.static_decoder_kwargs = kwargs
with torch.cuda.graph(self._decoder_cuda_graph):
self.static_decoder_output = self._decode(*self.static_decoder_inputs,
**self.static_decoder_kwargs)
self.static_decoder_output = self._decode(*self.static_decoder_inputs, **self.static_decoder_kwargs)
self.decoder_cuda_graph_created = True
@ -88,8 +88,7 @@ class DSVAE(CUDAGraph, torch.nn.Module):
self.static_encoder_kwargs = kwargs
with torch.cuda.graph(self._encoder_cuda_graph):
self.static_encoder_output = self._encode(*self.static_encoder_inputs,
**self.static_encoder_kwargs)
self.static_encoder_output = self._encode(*self.static_encoder_inputs, **self.static_encoder_kwargs)
self.encoder_cuda_graph_created = True
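The capture logic above follows the standard CUDA-graph recipe: keep static input/output tensors, record the forward pass once, then replay the graph after copying new data into the static buffers. A minimal sketch of that pattern, assuming a CUDA device and a toy nn.Linear in place of the VAE encoder/decoder:

import torch

if torch.cuda.is_available():
    model = torch.nn.Linear(16, 16).cuda()
    static_input = torch.zeros(4, 16, device='cuda')

    # Warm up on a side stream before capture, as recommended for CUDA graphs.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            model(static_input)
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = model(static_input)       # recorded, not executed eagerly

    # Replay: copy fresh data into the static input buffer, then launch the graph.
    static_input.copy_(torch.randn(4, 16, device='cuda'))
    graph.replay()
    print(static_output.shape)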

View File

@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
class CUDAGraph(ABC):
def __init__(self, enable_cuda_graph=False):
super().__init__()
self.enable_cuda_graph = enable_cuda_graph

View File

@ -7,6 +7,7 @@ from ..features.cuda_graph import CUDAGraph
class DSClipEncoder(CUDAGraph, torch.nn.Module):
def __init__(self, enc, enable_cuda_graph=False):
super().__init__(enable_cuda_graph=enable_cuda_graph)
enc.text_model._build_causal_attention_mask = self._build_causal_attention_mask
@ -22,11 +23,7 @@ class DSClipEncoder(CUDAGraph, torch.nn.Module):
self.config = self.enc.config
def _build_causal_attention_mask(self, bsz, seq_len, dtype):
mask = torch.empty(bsz,
seq_len,
seq_len,
dtype=dtype,
device=get_accelerator().current_device_name())
mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype, device=get_accelerator().current_device_name())
mask.fill_(torch.tensor(torch.finfo(dtype).min))
mask.triu_(1)
mask = mask.unsqueeze(1)
@ -69,9 +66,8 @@ class DSClipEncoder(CUDAGraph, torch.nn.Module):
self.static_kwargs[self.iter] = kwargs
with torch.cuda.graph(self._cuda_graphs[self.iter]):
self.static_output[self.iter] = self._forward(
*self.static_inputs[self.iter],
**self.static_kwargs[self.iter])
self.static_output[self.iter] = self._forward(*self.static_inputs[self.iter],
**self.static_kwargs[self.iter])
self.cuda_graph_created[self.iter] = True
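The _build_causal_attention_mask above produces an additive mask: zeros on and below the diagonal, the dtype's minimum above it, so future positions are suppressed after the softmax. A CPU-only sketch (the real code places the mask on the accelerator's current device):

import torch

def build_causal_mask(bsz, seq_len, dtype=torch.float32):
    mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
    mask.fill_(torch.finfo(dtype).min)   # start with everything masked out
    mask.triu_(1)                        # zero the diagonal and below (visible positions)
    return mask.unsqueeze(1)             # add a broadcastable head dimension

print(build_causal_mask(1, 4)[0, 0])
# row i keeps positions <= i at 0 and future positions at dtype.min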

View File

@ -4,6 +4,7 @@ import torch.nn as nn
class DeepSpeedTransformerBase(nn.Module):
def __init__(self):
pass

View File

@ -8,6 +8,7 @@ from deepspeed.model_implementations.transformers.ds_transformer import DeepSpee
class DeepSpeedBERTInference(DeepSpeedTransformerInference):
"""Initialize the DeepSpeed BERT Transformer Layer.
"""
def __init__(self,
config,
mp_group=None,
@ -15,9 +16,4 @@ class DeepSpeedBERTInference(DeepSpeedTransformerInference):
quantize_groups=1,
merge_count=1,
mlp_extra_grouping=False):
super().__init__(config,
mp_group,
quantize_scales,
quantize_groups,
merge_count,
mlp_extra_grouping)
super().__init__(config, mp_group, quantize_scales, quantize_groups, merge_count, mlp_extra_grouping)

View File

@ -8,6 +8,7 @@ from deepspeed.model_implementations.transformers.ds_transformer import DeepSpee
class DeepSpeedBloomInference(DeepSpeedTransformerInference):
"""Initialize the DeepSpeed Bloom Transformer Layer.
"""
def __init__(self,
config,
mp_group=None,
@ -15,9 +16,4 @@ class DeepSpeedBloomInference(DeepSpeedTransformerInference):
quantize_groups=1,
merge_count=1,
mlp_extra_grouping=False):
super().__init__(config,
mp_group,
quantize_scales,
quantize_groups,
merge_count,
mlp_extra_grouping)
super().__init__(config, mp_group, quantize_scales, quantize_groups, merge_count, mlp_extra_grouping)

View File

@ -8,6 +8,7 @@ from deepspeed.model_implementations.transformers.ds_transformer import DeepSpee
class DeepSpeedGPTInference(DeepSpeedTransformerInference):
"""Initialize the DeepSpeed GPT Transformer Layer.
"""
def __init__(self,
config,
mp_group=None,
@ -15,9 +16,4 @@ class DeepSpeedGPTInference(DeepSpeedTransformerInference):
quantize_groups=1,
merge_count=1,
mlp_extra_grouping=False):
super().__init__(config,
mp_group,
quantize_scales,
quantize_groups,
merge_count,
mlp_extra_grouping)
super().__init__(config, mp_group, quantize_scales, quantize_groups, merge_count, mlp_extra_grouping)

View File

@ -8,6 +8,7 @@ from deepspeed.model_implementations.transformers.ds_transformer import DeepSpee
class DeepSpeedMegatronGPTInference(DeepSpeedTransformerInference):
"""Initialize the DeepSpeed Megatron GPT Transformer Layer.
"""
def __init__(self,
config,
mp_group=None,
@ -15,9 +16,4 @@ class DeepSpeedMegatronGPTInference(DeepSpeedTransformerInference):
quantize_groups=1,
merge_count=1,
mlp_extra_grouping=False):
super().__init__(config,
mp_group,
quantize_scales,
quantize_groups,
merge_count,
mlp_extra_grouping)
super().__init__(config, mp_group, quantize_scales, quantize_groups, merge_count, mlp_extra_grouping)

View File

@ -8,6 +8,7 @@ from deepspeed.model_implementations.transformers.ds_transformer import DeepSpee
class DeepSpeedOPTInference(DeepSpeedTransformerInference):
"""Initialize the DeepSpeed OPT Transformer Layer.
"""
def __init__(self,
config,
mp_group=None,
@ -15,9 +16,4 @@ class DeepSpeedOPTInference(DeepSpeedTransformerInference):
quantize_groups=1,
merge_count=1,
mlp_extra_grouping=False):
super().__init__(config,
mp_group,
quantize_scales,
quantize_groups,
merge_count,
mlp_extra_grouping)
super().__init__(config, mp_group, quantize_scales, quantize_groups, merge_count, mlp_extra_grouping)

View File

@ -56,33 +56,17 @@ class DeepSpeedTransformerInference(nn.Module):
log_dist(f"DeepSpeed-Inference config: {self.config.__dict__}", [0])
if self.config.bigscience_bloom:
self.attention = BloomSelfAttention(self.config,
mp_group,
quantize_scales,
quantize_groups,
merge_count)
self.attention = BloomSelfAttention(self.config, mp_group, quantize_scales, quantize_groups, merge_count)
else:
self.attention = DeepSpeedSelfAttention(self.config,
mp_group,
quantize_scales,
quantize_groups,
self.attention = DeepSpeedSelfAttention(self.config, mp_group, quantize_scales, quantize_groups,
merge_count)
self.mlp = DeepSpeedMLP(self.config,
mp_group,
quantize_scales,
quantize_groups,
merge_count,
self.mlp = DeepSpeedMLP(self.config, mp_group, quantize_scales, quantize_groups, merge_count,
mlp_extra_grouping)
device = get_accelerator().current_device_name(
) # if config.bigscience_bloom else 'cpu'
self.norm_w = nn.Parameter(torch.empty(self.config.hidden_size,
dtype=data_type,
device=device),
device = get_accelerator().current_device_name() # if config.bigscience_bloom else 'cpu'
self.norm_w = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type, device=device),
requires_grad=False)
self.norm_b = nn.Parameter(torch.empty(self.config.hidden_size,
dtype=data_type,
device=device),
self.norm_b = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type, device=device),
requires_grad=False)
self.layer_past = None
self.allocate_workspace = inference_cuda_module.allocate_workspace_fp32 if (not config.fp16) else \
@ -122,20 +106,15 @@ class DeepSpeedTransformerInference(nn.Module):
if "hidden_states" in kwargs:
input = kwargs["hidden_states"]
input_mask = (input_mask if attn_mask is None else
attn_mask) if attention_mask is None else attention_mask
input_mask = (input_mask if attn_mask is None else attn_mask) if attention_mask is None else attention_mask
# Allocate memory only on first layer forward
if self.config.layer_id == 0:
self.allocate_workspace(self.config.hidden_size,
self.config.heads,
self.allocate_workspace(self.config.hidden_size, self.config.heads,
input.size()[1],
input.size()[0],
DeepSpeedTransformerInference.layer_id,
self.config.mp_size,
input.size()[0], DeepSpeedTransformerInference.layer_id, self.config.mp_size,
self.config.bigscience_bloom,
dist.get_rank() if dist.is_initialized() else 0,
self.config.max_out_tokens)
dist.get_rank() if dist.is_initialized() else 0, self.config.max_out_tokens)
get_present = (get_present or get_key_value or use_cache)
input_mask = input_mask if attention_mask is None else attention_mask
@ -174,10 +153,7 @@ class DeepSpeedTransformerInference(nn.Module):
output = self.mlp(attention_output, input, inp_norm, self.attention.attn_ob)
if not self.config.pre_layer_norm:
output = inference_cuda_module.layer_norm(output,
self.norm_w,
self.norm_b,
self.config.epsilon)
output = inference_cuda_module.layer_norm(output, self.norm_w, self.norm_b, self.config.epsilon)
output = output.to(input_type)
if get_present:

View File

@ -8,6 +8,7 @@ from .replace_policy import replace_policies
class AutoTP():
def in_module_list(module, module_list):
for item in module_list:
if type(item).__name__ == type(module).__name__:
@ -28,18 +29,7 @@ class AutoTP():
return mlist
def supported(model):
unsupported = [
'bloom',
'codegen',
'deberta',
'flaubert',
'fsmt',
'gpt2',
'led',
'longformer',
'xlm',
'xlnet'
]
unsupported = ['bloom', 'codegen', 'deberta', 'flaubert', 'fsmt', 'gpt2', 'led', 'longformer', 'xlm', 'xlnet']
model = str(model)
key = re.search(r": (.*?)Model", model)
if key is None:
@ -56,8 +46,7 @@ class AutoTP():
for key, submodule in module._modules.items():
if isinstance(submodule, nn.Linear):
layer_list = layer_list + [parent + "." + key]
elif isinstance(submodule,
nn.LayerNorm) or key == 'LayerNorm' or key == 'layer_norm':
elif isinstance(submodule, nn.LayerNorm) or key == 'LayerNorm' or key == 'layer_norm':
layer_list = layer_list + ["ln"]
else:
layer_list = layer_list + AutoTP.get_layers(key, submodule)
@ -102,9 +91,7 @@ class AutoTP():
for key, submodule in module._modules.items():
if isinstance(submodule, nn.Linear):
layer_list = layer_list + ["." + key]
elif isinstance(
submodule,
nn.LayerNorm) or key == 'LayerNorm' or key == 'layer_norm':
elif isinstance(submodule, nn.LayerNorm) or key == 'LayerNorm' or key == 'layer_norm':
layer_list = layer_list + ["ln"]
else:
layer_list = layer_list + AutoTP.get_layers(key, submodule)
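supported() above decides tensor-parallel eligibility by pulling the architecture name out of str(model) and checking it against the unsupported list. A sketch of that check applied to plain strings (is_supported is an illustrative name):

import re

UNSUPPORTED = ['bloom', 'codegen', 'deberta', 'flaubert', 'fsmt', 'gpt2', 'led', 'longformer', 'xlm', 'xlnet']

def is_supported(model_repr):
    key = re.search(r": (.*?)Model", model_repr)
    if key is None:
        key = re.search(r"(.*?)Model", model_repr)
    assert key is not None, "Could not infer the architecture name"
    return key.group(1).lower() not in UNSUPPORTED

print(is_supported("OPTForCausalLM(\n  (model): OPTModel(...)"))            # True
print(is_supported("BloomForCausalLM(\n  (transformer): BloomModel(...)"))  # False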

View File

@ -15,6 +15,7 @@ class BaseConvolutionContainer(ABC):
class BaseTransformerContainer(ABC):
def __init__(self, policy, config, model_config, layer_id, child):
self.policy = policy
self.config = config
@ -40,18 +41,14 @@ class BaseTransformerContainer(ABC):
hasattr(self.model_config, 'layernorm_epsilon') else 1.0e-12)
self.return_tuple = self.config.return_tuple
self.triangular_masking = True
self.local_attention = ((self.model_config.attention_layers[self.layer_id]
== "local") if hasattr(self.model_config,
'attention_layers') else False)
self.local_attention = ((self.model_config.attention_layers[self.layer_id] == "local") if hasattr(
self.model_config, 'attention_layers') else False)
self.window_size = getattr(self.model_config, "window_size", 1)
self.mlp_act_func_type = self.policy.mlp_act_func_type
self.training_mp_size = self.config.training_mp_size
self.bigscience_bloom = False
self.max_out_tokens = self.config.max_out_tokens
self.scale_attn_by_inverse_layer_idx = getattr(
self.config,
"scale_attn_by_inverse_layer_idx",
False)
self.scale_attn_by_inverse_layer_idx = getattr(self.config, "scale_attn_by_inverse_layer_idx", False)
self.use_mup = self.policy.use_mup
self.return_single_tuple = False
self.rotary_dim = self.model_config.rotary_dim if hasattr(self.model_config, 'rotary_dim') \
@ -168,10 +165,8 @@ class BaseTransformerContainer(ABC):
self.mlp_quantization()
def attention_quantization(self):
self.module.attention.attn_qkvw = self.quantizer.quantize(
self.module.attention.attn_qkvw)
self.module.attention.attn_ow = self.quantizer.quantize(
self.module.attention.attn_ow)
self.module.attention.attn_qkvw = self.quantizer.quantize(self.module.attention.attn_qkvw)
self.module.attention.attn_ow = self.quantizer.quantize(self.module.attention.attn_ow)
def mlp_quantization(self):
self.module.mlp.inter_w = self.quantizer.quantize(self.module.mlp.inter_w)
@ -190,18 +185,12 @@ class BaseTransformerContainer(ABC):
self.apply_weight_quantization()
def attention_qkv_mp(self, mp_replace):
self.module.attention.attn_qkvw = mp_replace.qkv_copy(
self.module.attention.attn_qkvw,
self.qkvw)
self.module.attention.attn_qkvb = mp_replace.qkv_copy(
self.module.attention.attn_qkvb,
self.qkvb)
self.module.attention.attn_qkvw = mp_replace.qkv_copy(self.module.attention.attn_qkvw, self.qkvw)
self.module.attention.attn_qkvb = mp_replace.qkv_copy(self.module.attention.attn_qkvb, self.qkvb)
def attention_o_mp(self, mp_replace):
self.module.attention.attn_ow = mp_replace.copy(self.module.attention.attn_ow,
self.dense_w)
self.module.attention.attn_ob = mp_replace.copy(self.module.attention.attn_ob,
self.dense_b)
self.module.attention.attn_ow = mp_replace.copy(self.module.attention.attn_ow, self.dense_w)
self.module.attention.attn_ob = mp_replace.copy(self.module.attention.attn_ob, self.dense_b)
def mlp_inter_mp(self, mp_replace):
self.module.mlp.inter_w = mp_replace.copy(self.module.mlp.inter_w, self._h4h_w)
@ -216,15 +205,11 @@ class BaseTransformerContainer(ABC):
self.module.mlp.attn_nw = self.attn_nw
self.module.mlp.attn_nb = self.attn_nb
else:
self.module.mlp.attn_nw.data.copy_(
self.attn_nw.to(get_accelerator().current_device_name()))
self.module.mlp.attn_nb.data.copy_(
self.attn_nb.to(get_accelerator().current_device_name()))
self.module.mlp.attn_nw.data.copy_(self.attn_nw.to(get_accelerator().current_device_name()))
self.module.mlp.attn_nb.data.copy_(self.attn_nb.to(get_accelerator().current_device_name()))
self.module.norm_w.data.copy_(
self.input_nw.to(get_accelerator().current_device_name()))
self.module.norm_b.data.copy_(
self.input_nb.to(get_accelerator().current_device_name()))
self.module.norm_w.data.copy_(self.input_nw.to(get_accelerator().current_device_name()))
self.module.norm_b.data.copy_(self.input_nb.to(get_accelerator().current_device_name()))
def transpose(self):
self.transpose_attention()

View File

@ -8,6 +8,7 @@ from deepspeed.accelerator import get_accelerator
class BaseTransformerMoEContainer(BaseTransformerContainer):
def __init__(self, **kwargs):
# Call the init function of the parent class to initialize the tensors and configs from parent class
super().__init__(**kwargs)
@ -16,9 +17,7 @@ class BaseTransformerMoEContainer(BaseTransformerContainer):
self.ep_world_size = dist.get_world_size()
self.local_ep_size = 1 if self.num_experts < self.ep_world_size else self.num_experts // self.ep_world_size
self.layer_norm_eps = self.config.layer_norm_eps if hasattr(
self.config,
'layer_norm_eps') else 1e-12,
self.layer_norm_eps = self.config.layer_norm_eps if hasattr(self.config, 'layer_norm_eps') else 1e-12,
# MoE models will have a list of mlp related tensors
self._h4h_w = []
@ -102,40 +101,27 @@ class BaseTransformerMoEContainer(BaseTransformerContainer):
gpu_index = dist.get_rank()
for ep_index in range(self.local_ep_size):
# mlp inter
self.module.mlp[ep_index].inter_w.data = self._h4h_w[
gpu_index * self.local_ep_size + ep_index].to(
get_accelerator().current_device_name())
self.module.mlp[ep_index].inter_b.data = self._h4h_b[
gpu_index * self.local_ep_size + ep_index].to(
get_accelerator().current_device_name())
self.module.mlp[ep_index].inter_w.data = self._h4h_w[gpu_index * self.local_ep_size + ep_index].to(
get_accelerator().current_device_name())
self.module.mlp[ep_index].inter_b.data = self._h4h_b[gpu_index * self.local_ep_size + ep_index].to(
get_accelerator().current_device_name())
# mlp output
self.module.mlp[ep_index].output_w.data = self._4hh_w[
gpu_index * self.local_ep_size + ep_index].to(
get_accelerator().current_device_name())
self.module.mlp[ep_index].output_b.data = self._4hh_b[
gpu_index * self.local_ep_size + ep_index].to(
get_accelerator().current_device_name())
self.module.mlp[ep_index].output_w.data = self._4hh_w[gpu_index * self.local_ep_size + ep_index].to(
get_accelerator().current_device_name())
self.module.mlp[ep_index].output_b.data = self._4hh_b[gpu_index * self.local_ep_size + ep_index].to(
get_accelerator().current_device_name())
def copy_data_to_new_module(self):
self.module.attn_nw.data = self.attn_nw.to(
get_accelerator().current_device_name())
self.module.attn_nb.data = self.attn_nb.to(
get_accelerator().current_device_name())
self.module.attn_nw.data = self.attn_nw.to(get_accelerator().current_device_name())
self.module.attn_nb.data = self.attn_nb.to(get_accelerator().current_device_name())
self.module.norm_w.data.copy_(
self.input_nw.to(get_accelerator().current_device_name()))
self.module.norm_b.data.copy_(
self.input_nb.to(get_accelerator().current_device_name()))
self.module.norm_w.data.copy_(self.input_nw.to(get_accelerator().current_device_name()))
self.module.norm_b.data.copy_(self.input_nb.to(get_accelerator().current_device_name()))
if self.config.moe.type == 'residual':
self.module.res_mlp.inter_w.data = self._res_h4h_w.to(
get_accelerator().current_device_name())
self.module.res_mlp.inter_b.data = self._res_h4h_b.to(
get_accelerator().current_device_name())
self.module.res_mlp.output_w.data = self._res_4hh_w.to(
get_accelerator().current_device_name())
self.module.res_mlp.output_b.data = self._res_4hh_b.to(
get_accelerator().current_device_name())
self.module.res_coef.data = self._res_coef.to(
get_accelerator().current_device_name())
self.module.res_mlp.inter_w.data = self._res_h4h_w.to(get_accelerator().current_device_name())
self.module.res_mlp.inter_b.data = self._res_h4h_b.to(get_accelerator().current_device_name())
self.module.res_mlp.output_w.data = self._res_4hh_w.to(get_accelerator().current_device_name())
self.module.res_mlp.output_b.data = self._res_4hh_b.to(get_accelerator().current_device_name())
self.module.res_coef.data = self._res_coef.to(get_accelerator().current_device_name())

View File

@ -8,6 +8,7 @@ from ..policy import TransformerPolicy
class DS_BERTContainer(BaseTransformerContainer):
def __init__(self, **kwargs):
super().__init__(**kwargs)
@ -23,6 +24,7 @@ class DS_BERTContainer(BaseTransformerContainer):
class HFBertLayerPolicy(TransformerPolicy):
def __init__(self, client_module, inference=False):
super().__init__(inference, pre_attn_norm=False)
self.client_module = client_module

View File

@ -11,6 +11,7 @@ supported_models = {None}
class DS_BloomContainer(MetaTensorContainer, BaseTransformerContainer):
def __init__(self, **kwargs):
super().__init__(**kwargs)
@ -25,12 +26,8 @@ class DS_BloomContainer(MetaTensorContainer, BaseTransformerContainer):
return self.module
def attention_qkv_mp(self, mp_replace):
self.module.attention.attn_qkvw = mp_replace.copy(
self.module.attention.attn_qkvw,
self.qkvw)
self.module.attention.attn_qkvb = mp_replace.copy(
self.module.attention.attn_qkvb,
self.qkvb)
self.module.attention.attn_qkvw = mp_replace.copy(self.module.attention.attn_qkvw, self.qkvw)
self.module.attention.attn_qkvb = mp_replace.copy(self.module.attention.attn_qkvb, self.qkvb)
def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
param_names = (
@ -58,51 +55,28 @@ class DS_BloomContainer(MetaTensorContainer, BaseTransformerContainer):
megatron_v2=self.policy.is_megatron_v2,
split_qkv=self.policy.split_qkv)
for i in range(2, 4):
maybe_copy(module.attention,
sd,
weight_quantizer,
mp_replace,
transformer_param_names[i],
maybe_copy(module.attention, sd, weight_quantizer, mp_replace, transformer_param_names[i],
prefix + param_names[i])
for i in range(4, 10):
maybe_copy(module.mlp,
sd,
weight_quantizer,
mp_replace,
transformer_param_names[i],
maybe_copy(module.mlp, sd, weight_quantizer, mp_replace, transformer_param_names[i],
prefix + param_names[i])
for i in range(10, 12):
maybe_copy(module,
sd,
weight_quantizer,
mp_replace,
transformer_param_names[i],
prefix + param_names[i])
maybe_copy(module, sd, weight_quantizer, mp_replace, transformer_param_names[i], prefix + param_names[i])
class BLOOMLayerPolicy(TransformerPolicy):
_orig_layer_class = None
def __init__(self,
client_module,
inference=True,
use_load_prefix=True,
split_qkv=False):
super().__init__(inference,
linear_layer=True,
use_load_prefix=use_load_prefix,
split_qkv=split_qkv)
def __init__(self, client_module, inference=True, use_load_prefix=True, split_qkv=False):
super().__init__(inference, linear_layer=True, use_load_prefix=use_load_prefix, split_qkv=split_qkv)
self.client_module = client_module
try:
import transformers
BLOOMLayerPolicy._orig_layer_class = transformers.models.bloom.modeling_bloom.BloomBlock
global supported_models
supported_models.update(
{transformers.models.bloom.modeling_bloom.BloomModel})
supported_models.update({transformers.models.bloom.modeling_bloom.BloomModel})
except Exception as e:
print(
f"WARNING! Setting BLOOMLayerPolicy._orig_layer_class to None due to Exception: {e}"
)
print(f"WARNING! Setting BLOOMLayerPolicy._orig_layer_class to None due to Exception: {e}")
BLOOMLayerPolicy._orig_layer_class = None
def get_hidden_heads(self):

View File

@ -8,6 +8,7 @@ from ..policy import TransformerPolicy
class DS_CLIPContainer(BaseTransformerContainer):
def __init__(self, **kwargs):
super().__init__(**kwargs)
@ -21,6 +22,7 @@ class DS_CLIPContainer(BaseTransformerContainer):
class HFCLIPLayerPolicy(TransformerPolicy):
def __init__(self, client_module, inference=False):
super().__init__(inference, pre_attn_norm=True, scale_attention=True)
self.client_module = client_module

View File

@ -8,6 +8,7 @@ from ..policy import TransformerPolicy
class DS_DistilBERTContainer(BaseTransformerContainer):
def __init__(self, **kwargs):
super().__init__(**kwargs)

View File

@ -5,6 +5,7 @@ from abc import ABC
class MegatronContainer(ABC):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.megatron_v2 = self.policy.is_megatron_v2
@ -15,23 +16,13 @@ class MegatronContainer(ABC):
x_1 = x.view(*new_x_shape)
(q, k, v) = torch.split(x_1, (x_1.shape[-1] // 3), dim=(x_1.dim() - 1))
if len(q.shape) > 2:
return torch.cat((q.reshape(q.shape[0],
-1),
k.reshape(q.shape[0],
-1),
v.reshape(q.shape[0],
-1)),
return torch.cat((q.reshape(q.shape[0], -1), k.reshape(q.shape[0], -1), v.reshape(q.shape[0], -1)),
dim=-1).reshape(x.shape)
else:
return torch.cat((q.reshape(-1),
k.reshape(-1),
v.reshape(-1)),
dim=-1).reshape(x.shape)
return torch.cat((q.reshape(-1), k.reshape(-1), v.reshape(-1)), dim=-1).reshape(x.shape)
def transpose(self):
super().transpose()
if self.megatron_v2:
self.qkvw = torch.nn.parameter.Parameter(
self.transpose_qkv_alignment(self.qkvw).contiguous())
self.qkvb = torch.nn.parameter.Parameter(
self.transpose_qkv_alignment(self.qkvb).contiguous())
self.qkvw = torch.nn.parameter.Parameter(self.transpose_qkv_alignment(self.qkvw).contiguous())
self.qkvb = torch.nn.parameter.Parameter(self.transpose_qkv_alignment(self.qkvb).contiguous())
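transpose_qkv_alignment above reorders a fused Megatron-v2 QKV tensor from per-head interleaved [q|k|v] blocks into contiguous [all-Q | all-K | all-V] columns. A toy-shaped sketch of that reordering:

import torch

heads, head_dim, hidden = 2, 4, 8                 # hidden = heads * head_dim
x = torch.arange(hidden * 3 * hidden, dtype=torch.float32).reshape(hidden, 3 * hidden)

x_1 = x.view(hidden, heads, 3 * head_dim)         # expose the per-head [q|k|v] layout
q, k, v = torch.split(x_1, x_1.shape[-1] // 3, dim=-1)
realigned = torch.cat((q.reshape(q.shape[0], -1),
                       k.reshape(q.shape[0], -1),
                       v.reshape(q.shape[0], -1)), dim=-1).reshape(x.shape)
print(realigned.shape)                            # same shape as x, columns regrouped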

View File

@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
class MetaTensorContainer(ABC):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.is_meta = False
@ -53,6 +54,5 @@ class MetaTensorContainer(ABC):
        of q, k, and v are stored together and need to be split in the
DeepSpeed-Inference API.
"""
raise NotImplementedError(
"A load_params() function must be defined in the model container \
raise NotImplementedError("A load_params() function must be defined in the model container \
when inheriting the MetaTensorContainer feature")

View File

@ -6,6 +6,7 @@ from ..policy import TransformerPolicy
class DS_GPT2Container(BaseTransformerContainer):
def __init__(self, **kwargs):
super().__init__(**kwargs)

View File

@ -12,6 +12,7 @@ from ..policy import maybe_copy_qkv
class DS_GPTJContainer(MetaTensorContainer, BaseTransformerContainer):
def __init__(self, **kwargs):
super().__init__(**kwargs)
@ -36,36 +37,20 @@ class DS_GPTJContainer(MetaTensorContainer, BaseTransformerContainer):
'ln_1.weight', \
'ln_1.bias'
)
maybe_copy_qkv(
module.attention,
sd,
weight_quantizer,
mp_replace,
'attn_qkvw',
[prefix + param_names[0],
prefix + param_names[1],
prefix + param_names[2]],
split_qkv=self.policy.split_qkv)
for i in range(3, 4):
maybe_copy(module.attention,
maybe_copy_qkv(module.attention,
sd,
weight_quantizer,
mp_replace,
transformer_param_names[i - 1],
'attn_qkvw', [prefix + param_names[0], prefix + param_names[1], prefix + param_names[2]],
split_qkv=self.policy.split_qkv)
for i in range(3, 4):
maybe_copy(module.attention, sd, weight_quantizer, mp_replace, transformer_param_names[i - 1],
prefix + param_names[i])
for i in range(4, 8):
maybe_copy(module.mlp,
sd,
weight_quantizer,
mp_replace,
transformer_param_names[i],
maybe_copy(module.mlp, sd, weight_quantizer, mp_replace, transformer_param_names[i],
prefix + param_names[i])
for i in range(8, 10):
maybe_copy(module,
sd,
weight_quantizer,
mp_replace,
transformer_param_names[i + 2],
maybe_copy(module, sd, weight_quantizer, mp_replace, transformer_param_names[i + 2],
prefix + param_names[i])

View File

@ -12,6 +12,7 @@ from ..policy import maybe_copy_qkv
class DS_GPTNEOContainer(MetaTensorContainer, BaseTransformerContainer):
def __init__(self, **kwargs):
super().__init__(**kwargs)
@ -39,40 +40,25 @@ class DS_GPTNEOContainer(MetaTensorContainer, BaseTransformerContainer):
'ln_1.weight', \
'ln_1.bias'
)
maybe_copy_qkv(
module.attention,
sd,
weight_quantizer,
mp_replace,
'attn_qkvw',
[prefix + param_names[0],
prefix + param_names[1],
prefix + param_names[2]],
split_qkv=self.policy.split_qkv)
for i in range(3, 5):
maybe_copy(module.attention,
maybe_copy_qkv(module.attention,
sd,
weight_quantizer,
mp_replace,
transformer_param_names[i - 1],
'attn_qkvw', [prefix + param_names[0], prefix + param_names[1], prefix + param_names[2]],
split_qkv=self.policy.split_qkv)
for i in range(3, 5):
maybe_copy(module.attention, sd, weight_quantizer, mp_replace, transformer_param_names[i - 1],
prefix + param_names[i])
for i in range(5, 11):
maybe_copy(module.mlp,
sd,
weight_quantizer,
mp_replace,
transformer_param_names[i - 1],
maybe_copy(module.mlp, sd, weight_quantizer, mp_replace, transformer_param_names[i - 1],
prefix + param_names[i])
for i in range(11, 13):
maybe_copy(module,
sd,
weight_quantizer,
mp_replace,
transformer_param_names[i - 1],
maybe_copy(module, sd, weight_quantizer, mp_replace, transformer_param_names[i - 1],
prefix + param_names[i])
class HFGPTNEOLayerPolicy(TransformerPolicy):
def __init__(self, client_module, inference=True):
super().__init__(inference, scale_attention=False)
self.client_module = client_module

View File

@ -11,9 +11,8 @@ from ..policy import maybe_copy
from packaging import version as pkg_version
class DS_GPTNEOXContainer(MetaTensorContainer,
MegatronContainer,
BaseTransformerContainer):
class DS_GPTNEOXContainer(MetaTensorContainer, MegatronContainer, BaseTransformerContainer):
def __init__(self, **kwargs):
super().__init__(**kwargs)
@ -57,26 +56,13 @@ class DS_GPTNEOXContainer(MetaTensorContainer,
split_qkv=self.policy.split_qkv,
heads=self.policy.client_module.attention.num_attention_heads)
for i in range(2, 4):
maybe_copy(module.attention,
sd,
weight_quantizer,
mp_replace,
transformer_param_names[i],
maybe_copy(module.attention, sd, weight_quantizer, mp_replace, transformer_param_names[i],
prefix + param_names[i])
for i in range(4, 10):
maybe_copy(module.mlp,
sd,
weight_quantizer,
mp_replace,
transformer_param_names[i],
maybe_copy(module.mlp, sd, weight_quantizer, mp_replace, transformer_param_names[i],
prefix + param_names[i])
for i in range(10, 12):
maybe_copy(module,
sd,
weight_quantizer,
mp_replace,
transformer_param_names[i],
prefix + param_names[i])
maybe_copy(module, sd, weight_quantizer, mp_replace, transformer_param_names[i], prefix + param_names[i])
class GPTNEOXLayerPolicy(TransformerPolicy):

View File

@ -9,6 +9,7 @@ from packaging import version as pkg_version
class DS_MegatronGPTContainer(MegatronContainer, BaseTransformerContainer):
def __init__(self, **kwargs):
super().__init__(**kwargs)
@ -36,9 +37,7 @@ class MegatronLayerPolicy(TransformerPolicy):
use_mup = False
def __init__(self, client_module, inference=True):
super().__init__(inference,
megatron_v2=MegatronLayerPolicy.megatron_v2,
use_mup=MegatronLayerPolicy.use_mup)
super().__init__(inference, megatron_v2=MegatronLayerPolicy.megatron_v2, use_mup=MegatronLayerPolicy.use_mup)
self.client_module = client_module
# we use megatron version to differentiate between the old and new
# megatron-lm source code

View File

@ -10,6 +10,7 @@ from packaging import version as pkg_version
class DS_MegatronGPTMoEContainer(MegatronContainer, BaseTransformerMoEContainer):
def __init__(self, policy, config, model_config, layer_id):
super().__init__(policy, config, model_config, layer_id)

View File

@ -13,6 +13,7 @@ from deepspeed.utils.types import ActivationFuncType
class DS_OPTContainer(MetaTensorContainer, BaseTransformerContainer):
def __init__(self, **kwargs):
super().__init__(**kwargs)
@ -50,32 +51,16 @@ class DS_OPTContainer(MetaTensorContainer, BaseTransformerContainer):
weight_quantizer,
mp_replace,
transformer_param_names[i // 3],
[
prefix + param_names[i],
prefix + param_names[i + 1],
prefix + param_names[i + 2]
],
[prefix + param_names[i], prefix + param_names[i + 1], prefix + param_names[i + 2]],
split_qkv=self.policy.split_qkv)
for i in range(6, 8):
maybe_copy(module.attention,
sd,
weight_quantizer,
mp_replace,
transformer_param_names[i - 4],
maybe_copy(module.attention, sd, weight_quantizer, mp_replace, transformer_param_names[i - 4],
prefix + param_names[i])
for i in range(8, 14):
maybe_copy(module.mlp,
sd,
weight_quantizer,
mp_replace,
transformer_param_names[i - 4],
maybe_copy(module.mlp, sd, weight_quantizer, mp_replace, transformer_param_names[i - 4],
prefix + param_names[i])
for i in range(14, 16):
maybe_copy(module,
sd,
weight_quantizer,
mp_replace,
transformer_param_names[i - 4],
maybe_copy(module, sd, weight_quantizer, mp_replace, transformer_param_names[i - 4],
prefix + param_names[i])
@ -93,8 +78,7 @@ class HFOPTLayerPolicy(TransformerPolicy):
try:
import transformers
HFOPTLayerPolicy._orig_layer_class = transformers.models.opt.modeling_opt.OPTDecoderLayer
if isinstance(TransformerPolicy.hf_model_config,
transformers.models.opt.configuration_opt.OPTConfig):
if isinstance(TransformerPolicy.hf_model_config, transformers.models.opt.configuration_opt.OPTConfig):
self.pre_attn_norm = TransformerPolicy.hf_model_config.do_layer_norm_before
except:
HFOPTLayerPolicy._orig_layer_class = None

View File

@ -9,6 +9,7 @@ from ...model_implementations.diffusers.unet import DSUNet
class UNetPolicy(DSPolicy):
def __init__(self):
super().__init__()
try:

View File

@ -6,6 +6,7 @@ from ...model_implementations.diffusers.vae import DSVAE
class VAEPolicy(DSPolicy):
def __init__(self):
super().__init__()
try:

View File

@ -5,30 +5,22 @@ import torch
from deepspeed.ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
def module_inject(layer_obj,
model,
config,
micro_batch_size,
max_seq_length,
seed,
preln,
fp16=True):
def module_inject(layer_obj, model, config, micro_batch_size, max_seq_length, seed, preln, fp16=True):
for name, child in model.named_children():
if isinstance(child, layer_obj):
print('REPLACING BertLayer')
cuda_config = DeepSpeedTransformerConfig(
batch_size=micro_batch_size,
max_seq_length=max_seq_length,
hidden_size=config.hidden_size,
heads=config.num_attention_heads,
attn_dropout_ratio=config.attention_probs_dropout_prob,
hidden_dropout_ratio=config.hidden_dropout_prob,
num_hidden_layers=config.num_hidden_layers,
initializer_range=config.initializer_range,
seed=seed,
fp16=fp16,
pre_layer_norm=preln)
cuda_config = DeepSpeedTransformerConfig(batch_size=micro_batch_size,
max_seq_length=max_seq_length,
hidden_size=config.hidden_size,
heads=config.num_attention_heads,
attn_dropout_ratio=config.attention_probs_dropout_prob,
hidden_dropout_ratio=config.hidden_dropout_prob,
num_hidden_layers=config.num_hidden_layers,
initializer_range=config.initializer_range,
seed=seed,
fp16=fp16,
pre_layer_norm=preln)
new_module = DeepSpeedTransformerLayer(cuda_config)
@ -71,14 +63,7 @@ def module_inject(layer_obj,
setattr(model, name, copy.deepcopy(new_module))
else:
module_inject(layer_obj,
child,
config,
micro_batch_size,
max_seq_length,
seed,
preln,
fp16)
module_inject(layer_obj, child, config, micro_batch_size, max_seq_length, seed, preln, fp16)
return model
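module_inject above is an instance of a common recursive replacement pattern: iterate named_children(), swap matching layers with setattr(), and recurse into everything else. A sketch with a trivial replacement standing in for DeepSpeedTransformerLayer:

import torch.nn as nn

def replace_layers(model, target_cls, make_replacement):
    for name, child in model.named_children():
        if isinstance(child, target_cls):
            setattr(model, name, make_replacement(child))
        else:
            replace_layers(child, target_cls, make_replacement)
    return model

net = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.ReLU(), nn.Linear(4, 2)))
replace_layers(net, nn.Linear, lambda old: nn.Identity())
print(net)   # every nn.Linear, at any depth, is now nn.Identity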

View File

@ -10,6 +10,7 @@ from deepspeed.accelerator import get_accelerator
class LinearAllreduce(nn.Module):
def __init__(self, weight, bias=None, mp_group=None):
super(LinearAllreduce, self).__init__()
self.weight = weight
@ -26,6 +27,7 @@ class LinearAllreduce(nn.Module):
class LinearLayer(nn.Module):
def __init__(self, weight_shape=None, dtype=torch.half, weight=None, bias=None):
super(LinearLayer, self).__init__()
if weight is not None:
@ -33,9 +35,7 @@ class LinearLayer(nn.Module):
self.bias = bias
else:
self.weight = Parameter(
torch.empty(weight_shape,
dtype=dtype,
device=get_accelerator().current_device_name()))
torch.empty(weight_shape, dtype=dtype, device=get_accelerator().current_device_name()))
self.bias = Parameter(
torch.empty(weight_shape[0],
@ -51,11 +51,10 @@ class LinearLayer(nn.Module):
class Normalize(nn.Module):
def __init__(self, dim, dtype=torch.float, eps=1e-5):
super(Normalize, self).__init__()
self.norm = nn.LayerNorm(dim,
eps=eps).to(dtype).to(
get_accelerator().current_device_name())
self.norm = nn.LayerNorm(dim, eps=eps).to(dtype).to(get_accelerator().current_device_name())
self.weight = self.norm.weight
self.bias = self.norm.bias
@ -64,13 +63,11 @@ class Normalize(nn.Module):
class EmbeddingLayer(nn.Module):
def __init__(self, weight_shape, dtype=torch.half):
super(EmbeddingLayer, self).__init__()
self.weight = Parameter(
torch.empty(weight_shape[0],
weight_shape[1],
dtype=dtype,
device=get_accelerator().current_device_name()))
torch.empty(weight_shape[0], weight_shape[1], dtype=dtype, device=get_accelerator().current_device_name()))
def forward(self, input):
return F.embedding(input, self.weight)
@ -80,6 +77,7 @@ class OPTEmbedding(EmbeddingLayer):
"""
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, weight_shape):
# OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
# and adjust num_embeddings appropriately. Other models don't have this hack
@ -91,9 +89,7 @@ class OPTEmbedding(EmbeddingLayer):
attention_mask = attention_mask.long()
# create positions depending on attention_mask
positions = (torch.cumsum(attention_mask,
dim=1).type_as(attention_mask) *
attention_mask).long() - 1
positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1
# cut positions if `past_key_values_length` is > 0
positions = positions[:, past_key_values_length:]
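The position ids above come from a cumulative sum over the attention mask, so padded slots never advance the counter and collapse to -1; per the comment above, OPT then offsets the ids by 2 before indexing the embedding table. A small worked example:

import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]]).long()
positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1
print(positions)
# tensor([[ 0,  1,  2, -1, -1],
#         [ 0,  1,  2,  3,  4]])
print(positions + 2)   # what would index the embedding table after the OPT offset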

View File

@ -50,10 +50,8 @@ def load_model_with_checkpoint(r_module,
if prefix + 'bias' in sd[0].keys():
if module.bias.data.is_meta:
                    # meta tensor cannot be cast or copied to, so we need to replace it with a normal tensor here
module.bias = torch.nn.parameter.Parameter(
data=torch.empty_like(module.bias.data,
device="cpu"),
requires_grad=module.bias.data.requires_grad)
module.bias = torch.nn.parameter.Parameter(data=torch.empty_like(module.bias.data, device="cpu"),
requires_grad=module.bias.data.requires_grad)
module.bias = mp_replace.copy(module.bias.data, sd[0][prefix + 'bias'])
args = None
gc.collect()
@ -71,86 +69,62 @@ def load_model_with_checkpoint(r_module,
# set the quantizer number of groups using the checkpoint scale shape
weight_quantizer.num_groups = scale.shape[0]
else:
tmp_data = sd[0][prefix + n].to(
get_accelerator().current_device_name())
tmp_data = sd[0][prefix + n].to(get_accelerator().current_device_name())
scale = None
src_shape = tmp_data.shape
dst_shape = p.shape
inner_dim = 1 if tmp_data.dtype == torch.int8 else 0
outer_dim = 0 if tmp_data.dtype == torch.int8 else 1
if (len(src_shape) == 2 and len(dst_shape) == 2):
if (src_shape[inner_dim] == dst_shape[0]
and src_shape[outer_dim] == dst_shape[1]):
if (src_shape[inner_dim] == dst_shape[0] and src_shape[outer_dim] == dst_shape[1]):
if tmp_data.dtype != torch.int8:
p = weight_quantizer.quantize(
transpose(tmp_data) if weight_quantizer.
q_int8 else tmp_data)
transpose(tmp_data) if weight_quantizer.q_int8 else tmp_data)
else:
p = torch.nn.parameter.Parameter(tmp_data,
requires_grad=False)
p = torch.nn.parameter.Parameter(tmp_data, requires_grad=False)
p.scale = scale
setattr(module, n, p)
else:
dim = inner_dim if src_shape[inner_dim] != dst_shape[
0] else outer_dim
dim = inner_dim if src_shape[inner_dim] != dst_shape[0] else outer_dim
dim1 = 0 if src_shape[inner_dim] != dst_shape[0] else 1
if src_shape[dim] > dst_shape[dim1]:
weight_partition = torch.split(
tmp_data,
dst_shape[dim1],
dim=dim)[rank].to(
get_accelerator().current_device_name())
weight_partition = torch.split(tmp_data, dst_shape[dim1], dim=dim)[rank].to(
get_accelerator().current_device_name())
assert tmp_data.dtype != torch.int8 or scale.numel() > weight_quantizer.num_groups * (rank+1), \
                                    '''ERROR: We require the quantization scales for a larger TP-size when loading an INT8 checkpoint!\
                                        Please use the FP16 checkpoint to generate an INT8 checkpoint with the sharding parameters!'''
scale = scale.view(
-1)[weight_quantizer.num_groups *
(rank + 1):].reshape(
weight_quantizer.num_groups,
-1).contiguous()
scale = scale.view(-1)[weight_quantizer.num_groups * (rank + 1):].reshape(
weight_quantizer.num_groups, -1).contiguous()
else:
assert tmp_data.dtype != torch.int8, \
                                '''Merging of checkpoints is not supported when using an INT8 checkpoint! \
                                    Please use as many GPUs as the TP-size for the checkpoint'''
all_data = [
sd[j][prefix +
n] if type(sd[j][prefix + n]) is list else
sd[j][prefix + n].to(
get_accelerator().current_device_name())
for j in range(len(sd))
sd[j][prefix + n] if type(sd[j][prefix + n]) is list else sd[j][prefix + n].to(
get_accelerator().current_device_name()) for j in range(len(sd))
]
# Check if the weight tensor is for the QKV parameter
if src_shape[1] == (3 *
src_shape[0]) // ckpt_mp_size:
if src_shape[1] == (3 * src_shape[0]) // ckpt_mp_size:
qkv_size = src_shape[outer_dim] // 3
src_split = [
torch.split(src[0].data,
qkv_size,
dim=outer_dim)
for src in all_data
torch.split(src[0].data, qkv_size, dim=outer_dim) for src in all_data
]
weight_partition = torch.cat([
torch.cat([qkv_s[i] for qkv_s in src_split],
axis=outer_dim)
torch.cat([qkv_s[i] for qkv_s in src_split], axis=outer_dim)
for i in range(len(src_split[0]))
],
dim=dim)
else:
weight_partition = torch.cat([
ad[0].to(
get_accelerator().current_device_name())
if type(ad) is list else ad
for ad in all_data
ad[0].to(get_accelerator().current_device_name())
if type(ad) is list else ad for ad in all_data
],
dim=dim)
if tmp_data.dtype == torch.int8:
scale = torch.cat([
ad[1].to(
get_accelerator().current_device_name())
for ad in all_data
],
dim=dim)
scale = torch.cat(
[ad[1].to(get_accelerator().current_device_name()) for ad in all_data],
dim=dim)
if tmp_data.dtype != torch.int8:
weight_partition = weight_quantizer.quantize(
@ -158,9 +132,8 @@ def load_model_with_checkpoint(r_module,
parallel_dim=(0 if dim == 1 else 1)) if weight_quantizer.q_int8 else \
weight_quantizer.quantize(weight_partition)
else:
weight_partition = torch.nn.parameter.Parameter(
weight_partition,
requires_grad=False)
weight_partition = torch.nn.parameter.Parameter(weight_partition,
requires_grad=False)
weight_partition.scale = scale
setattr(module, n, weight_partition)
else:
@ -168,42 +141,27 @@ def load_model_with_checkpoint(r_module,
p.data.copy_(tmp_data)
else:
if src_shape[0] > dst_shape[0]:
bias_split = torch.split(
tmp_data,
dst_shape[-1])[rank].to(get_accelerator(
).current_device_name()).contiguous()
bias_split = torch.split(tmp_data, dst_shape[-1])[rank].to(
get_accelerator().current_device_name()).contiguous()
p.data.copy_(bias_split)
else:
# Check if the weight tensor is for the QKV parameter
if src_shape[0] == (3 * r_module.config.hidden_size
) // ckpt_mp_size:
if src_shape[0] == (3 * r_module.config.hidden_size) // ckpt_mp_size:
qkv_size = src_shape[0] // 3
src_split = [
torch.split(sd[j][prefix + n],
qkv_size,
dim=0) for j in range(len(sd))
torch.split(sd[j][prefix + n], qkv_size, dim=0) for j in range(len(sd))
]
p.data.copy_(
torch.cat(
[
torch.cat([
qkv_s[i] for qkv_s in src_split
],
axis=0)
for i in range(len(src_split[0]))
],
dim=0).to(get_accelerator(
).current_device_name()).contiguous())
torch.cat([
torch.cat([qkv_s[i] for qkv_s in src_split], axis=0)
for i in range(len(src_split[0]))
],
dim=0).to(get_accelerator().current_device_name()).contiguous())
else:
p.data.copy_(
torch.cat(
[
sd[j][prefix + n]
for j in range(len(sd))
],
dim=0).to(get_accelerator(
).current_device_name()).contiguous())
torch.cat([sd[j][prefix + n] for j in range(len(sd))],
dim=0).to(get_accelerator().current_device_name()).contiguous())
load_parameters(module, prefix)
for n, child in module.named_children():
@ -249,20 +207,16 @@ def load_model_with_checkpoint(r_module,
setattr(module, name, child)
continue
child_params = list(child.parameters())
if len(child_params) > 0 and (child_params[0].numel() == 0
or child_params[0].is_meta):
if len(child_params) > 0 and (child_params[0].numel() == 0 or child_params[0].is_meta):
if child.weight.is_meta:
ds_shape = child.weight.shape
else:
ds_shape = child.weight.ds_shape
if child.__class__ is nn.LayerNorm:
child = Normalize(dim=ds_shape[-1],
dtype=child.weight.dtype,
eps=child.eps)
child = Normalize(dim=ds_shape[-1], dtype=child.weight.dtype, eps=child.eps)
setattr(module, name, child)
elif child.__class__ is nn.Linear:
child = LinearLayer(weight_shape=child.weight.shape,
bias=child.bias)
child = LinearLayer(weight_shape=child.weight.shape, bias=child.bias)
setattr(module, name, child)
elif child.__class__ is OPTLearnedPositionalEmbedding:
child = OPTEmbedding(weight_shape=ds_shape)
@ -271,8 +225,7 @@ def load_model_with_checkpoint(r_module,
ds_id = None
if hasattr(child.weight, 'ds_id'):
ds_id = child.weight.ds_id
child = EmbeddingLayer(weight_shape=ds_shape,
dtype=child.weight.dtype)
child = EmbeddingLayer(weight_shape=ds_shape, dtype=child.weight.dtype)
if ds_id is not None:
all_ds_ids[ds_id] = child.weight
setattr(module, name, child)
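When a checkpoint tensor is larger than the destination parameter along some dimension, the loader above splits it into TP-size chunks and keeps only this rank's shard. A toy-shaped sketch of that decision (no accelerator; rank and sizes are illustrative):

import torch

tp_size, rank = 2, 1
full_weight = torch.arange(4 * 8, dtype=torch.float32).reshape(4, 8)   # checkpoint tensor
dst = torch.empty(4, 8 // tp_size)                                     # this rank's parameter

dim = 1 if full_weight.shape[1] > dst.shape[1] else 0                  # which axis is sharded
shard = torch.split(full_weight, dst.shape[dim], dim=dim)[rank].contiguous()
dst.copy_(shard)
print(dst)   # columns 4..7 of the full weight land on rank 1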

View File

@ -18,34 +18,25 @@ def quantize_transformer_layer(orig_layer_impl, model, megatron=False, preln=Fal
Returns:
Updated nn.module with quantized transformer layers
"""
def quantize_weight(weight):
return weight.to(torch.int8)
def megatron_layer_quantize(layer):
layer.attention.query_key_value.weight.data = quantize_weight(
layer.attention.query_key_value.weight.data)
layer.attention.dense.weight.data = quantize_weight(
layer.attention.dense.weight.data)
layer.mlp.dense_h_to_4h.weight.data = quantize_weight(
layer.mlp.dense_h_to_4h.weight.data)
layer.mlp.dense_4h_to_h.weight.data = quantize_weight(
layer.mlp.dense_4h_to_h.weight.data)
layer.attention.query_key_value.weight.data = quantize_weight(layer.attention.query_key_value.weight.data)
layer.attention.dense.weight.data = quantize_weight(layer.attention.dense.weight.data)
layer.mlp.dense_h_to_4h.weight.data = quantize_weight(layer.mlp.dense_h_to_4h.weight.data)
layer.mlp.dense_4h_to_h.weight.data = quantize_weight(layer.mlp.dense_4h_to_h.weight.data)
def bert_layer_quantize(layer):
layer.attention.self.query.weight.data = quantize_weight(
layer.attention.self.query.weight.data)
layer.attention.self.key.weight.data = quantize_weight(
layer.attention.self.key.weight.data)
layer.attention.self.value.weight.data = quantize_weight(
layer.attention.self.value.weight.data)
layer.attention.output.dense.weight.data = quantize_weight(
layer.attention.output.dense.weight.data)
layer.attention.self.query.weight.data = quantize_weight(layer.attention.self.query.weight.data)
layer.attention.self.key.weight.data = quantize_weight(layer.attention.self.key.weight.data)
layer.attention.self.value.weight.data = quantize_weight(layer.attention.self.value.weight.data)
layer.attention.output.dense.weight.data = quantize_weight(layer.attention.output.dense.weight.data)
if preln:
layer.intermediate.dense_act.weight.data = quantize_weight(
layer.intermediate.dense_act.weight.data)
layer.intermediate.dense_act.weight.data = quantize_weight(layer.intermediate.dense_act.weight.data)
else:
layer.intermediate.dense.weight.data = quantize_weight(
layer.intermediate.dense.weight.data)
layer.intermediate.dense.weight.data = quantize_weight(layer.intermediate.dense.weight.data)
layer.output.dense.weight.data = quantize_weight(layer.output.dense.weight.data)
def quantize_fn(child):
@ -58,9 +49,7 @@ def quantize_transformer_layer(orig_layer_impl, model, megatron=False, preln=Fal
return child
return quantize_module(model=model,
orig_class=orig_layer_impl,
quantize_fn=quantize_fn)
return quantize_module(model=model, orig_class=orig_layer_impl, quantize_fn=quantize_fn)
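quantize_transformer_layer composes a per-layer quantize_fn with a recursive module walk. A sketch of that composition, assuming a simplified quantize_module (the real helper is defined below but not fully shown here); note the bare int8 cast mirrors quantize_weight() above and omits the scale bookkeeping a real kernel needs:

import torch
import torch.nn as nn

def quantize_module(model, orig_class, quantize_fn):
    # Simplified stand-in: apply quantize_fn to every instance of orig_class.
    for name, child in model.named_children():
        if isinstance(child, orig_class):
            setattr(model, name, quantize_fn(child))
        else:
            quantize_module(child, orig_class, quantize_fn)
    return model

def quantize_fn(child):
    child.weight.data = child.weight.data.to(torch.int8)   # same cast as quantize_weight()
    return child

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2)).requires_grad_(False)
quantize_module(model, nn.Linear, quantize_fn)
print(model[0].weight.dtype)   # torch.int8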
def quantize_module(model, orig_class, quantize_fn):

View File

@ -124,15 +124,10 @@ def _transpose(x, heads=1, mp_replace=None):
(q, k, v) = torch.split(x_1, (x_1.shape[-1] // 3), dim=-1)
if len(q.shape) > 2:
new_shape = (q.shape[0], ) + (-1, )
return torch.cat((q.reshape(new_shape),
k.reshape(new_shape),
v.reshape(new_shape)),
return torch.cat((q.reshape(new_shape), k.reshape(new_shape), v.reshape(new_shape)),
dim=outer_dim).reshape(x.shape)
else:
return torch.cat((q.reshape(-1),
k.reshape(-1),
v.reshape(-1)),
dim=-1).reshape(x.shape)
return torch.cat((q.reshape(-1), k.reshape(-1), v.reshape(-1)), dim=-1).reshape(x.shape)
# This checks if the parameter exists in the checkpoint file and maybe copies it into the corresponding destination tensor.
@ -156,19 +151,14 @@ def maybe_copy(module,
else:
dst = mp_replace.copy(dst, tmp)
if qkv and megatron_v2:
dst = torch.nn.parameter.Parameter(
_transpose(dst,
heads=heads,
mp_replace=mp_replace).contiguous())
dst = torch.nn.parameter.Parameter(_transpose(dst, heads=heads, mp_replace=mp_replace).contiguous())
else:
if split_qkv:
dst = mp_replace.qkv_copy(dst, weight_quantizer.quantize(tmp if weight_quantizer.q_int8 else \
(transpose(tmp).contiguous())), int8=weight_quantizer.q_int8)
else:
if qkv and megatron_v2:
tmp = _transpose(transpose(tmp),
heads=heads,
mp_replace=mp_replace).contiguous()
tmp = _transpose(transpose(tmp), heads=heads, mp_replace=mp_replace).contiguous()
if weight_quantizer.q_int8:
tmp = transpose(tmp)
dst = mp_replace.copy(dst, weight_quantizer.quantize(tmp if weight_quantizer.q_int8 else \
@ -177,13 +167,7 @@ def maybe_copy(module,
# Extending the maybe_copy function for when the q, k, and v are in separate parameters!
def maybe_copy_qkv(module,
sd,
weight_quantizer,
mp_replace,
dst_name,
src_names,
split_qkv=False):
def maybe_copy_qkv(module, sd, weight_quantizer, mp_replace, dst_name, src_names, split_qkv=False):
if src_names[0] in sd:
q = sd[src_names[0]]
k = sd[src_names[1]]

View File

@ -23,6 +23,7 @@ from .utils import policy_to_ds_container
class ReplaceWithTensorSlicing:
def __init__(self, mp_group=None, mp_size=1, out_dim=1, in_dim=0):
if mp_group is not None:
self.gpu_index = dist.get_rank(group=mp_group)
@ -58,32 +59,22 @@ class ReplaceWithTensorSlicing:
if self.out_dim == 1:
self.merge_assert(src_shape[outer_dim], dst_shape[self.out_dim])
qkv_size = dst_shape[self.out_dim] // 3
qkv_split = [
torch.split(src_s,
qkv_size,
dim=outer_dim) for src_s in src_split
]
qkv_split = [torch.split(src_s, qkv_size, dim=outer_dim) for src_s in src_split]
weight_split = [
torch.cat([qkv_s[i] for qkv_s in qkv_split],
axis=outer_dim) for i in range(len(qkv_split[0]))
torch.cat([qkv_s[i] for qkv_s in qkv_split], axis=outer_dim) for i in range(len(qkv_split[0]))
]
dst = dst.reshape(-1).data.copy_(
weight_split[self.gpu_index].contiguous().reshape(-1)).reshape(
weight_split[self.gpu_index].shape)
dst = dst.reshape(-1).data.copy_(weight_split[self.gpu_index].contiguous().reshape(-1)).reshape(
weight_split[self.gpu_index].shape)
else:
dst.data.copy_(src_split[self.gpu_index].to(
get_accelerator().current_device_name()).contiguous())
dst.data.copy_(src_split[self.gpu_index].to(get_accelerator().current_device_name()).contiguous())
else:
if src_shape[0] == dst_shape[0]:
return torch.nn.parameter.Parameter(src)
if self.out_dim == 1:
qkv_size = dst_shape[0] // 3
qkv_split = [torch.split(src_s, qkv_size, dim=0) for src_s in src_split]
bias_split = [
torch.cat([qkv_s[i] for qkv_s in qkv_split],
axis=0) for i in range(len(qkv_split[0]))
]
bias_split = [torch.cat([qkv_s[i] for qkv_s in qkv_split], axis=0) for i in range(len(qkv_split[0]))]
dst.data.copy_(bias_split[self.gpu_index].contiguous())
else:
dst.data.copy_(src_split[self.gpu_index].contiguous())
@ -103,30 +94,22 @@ class ReplaceWithTensorSlicing:
dst_shape = dst.shape
if (len(src_shape) == 2 and len(dst_shape) == 2):
if src_shape[inner_dim] == dst_shape[
self.in_dim] and src_shape[outer_dim] == dst_shape[self.out_dim]:
if src_shape[inner_dim] == dst_shape[self.in_dim] and src_shape[outer_dim] == dst_shape[self.out_dim]:
dst = dst.reshape(-1).data.copy_(src.data.reshape(-1)).reshape(src.shape)
else:
if src_shape[inner_dim] != dst_shape[self.in_dim]:
self.merge_assert(src_shape[inner_dim], dst_shape[self.in_dim])
weight_split = torch.split(
src,
dst_shape[self.in_dim],
dim=inner_dim)[self.gpu_index].contiguous()
weight_split = torch.split(src, dst_shape[self.in_dim], dim=inner_dim)[self.gpu_index].contiguous()
else:
self.merge_assert(src_shape[outer_dim], dst_shape[self.out_dim])
weight_split = torch.split(
src.data,
dst_shape[self.out_dim],
dim=outer_dim)[self.gpu_index].contiguous()
dst = dst.reshape(-1).data.copy_(weight_split.reshape(-1)).reshape(
weight_split.shape)
weight_split = torch.split(src.data, dst_shape[self.out_dim],
dim=outer_dim)[self.gpu_index].contiguous()
dst = dst.reshape(-1).data.copy_(weight_split.reshape(-1)).reshape(weight_split.shape)
else:
if src_shape[0] == dst_shape[0]:
dst.data.copy_(src)
else:
bias_split = torch.split(src.data,
dst_shape[-1])[self.gpu_index].contiguous()
bias_split = torch.split(src.data, dst_shape[-1])[self.gpu_index].contiguous()
dst.data.copy_(bias_split)
dst = torch.nn.parameter.Parameter(dst, requires_grad=False)
if hasattr(src, 'scale'):
@ -150,6 +133,7 @@ def get_transformer_name(replaced_module):
class GroupQuantizer:
def __init__(self, q_int8=True, group_size=1, num_bits=8, num_groups=0):
self.group_size = group_size
self.num_bits = num_bits
@ -163,8 +147,7 @@ class GroupQuantizer:
inputs.scale = torch.empty(1)
return inputs
q_range = 2**self.num_bits
num_groups = self.num_groups if self.num_groups > 0 else inputs.shape[
0] // self.group_size
num_groups = self.num_groups if self.num_groups > 0 else inputs.shape[0] // self.group_size
inputs = inputs.to(get_accelerator().current_device_name())
input_flat = inputs.reshape(num_groups, -1).contiguous()
input_min = torch.min(input_flat, dim=1, keepdim=True)[0].float()
@ -174,31 +157,14 @@ class GroupQuantizer:
inputs_q = input_flat.reshape(inputs.shape).to(torch.int8).contiguous()
out = torch.nn.Parameter(inputs_q, requires_grad=False)
inputs_split = inputs.split(inputs.shape[parallel_dim] // 2, dim=parallel_dim)
input_flat = [
inputs_split[i].reshape(num_groups,
-1).contiguous() for i in range(2)
]
input_min = [
torch.min(input_flat[i],
dim=1,
keepdim=True)[0].float() for i in range(2)
]
input_max = [
torch.max(input_flat[i],
dim=1,
keepdim=True)[0].float() for i in range(2)
]
scale1 = [
(torch.max(input_min[i].abs(),
input_max[i].abs()) * 2.0 / (q_range)).squeeze().unsqueeze(0)
for i in range(2)
]
input_flat = [inputs_split[i].reshape(num_groups, -1).contiguous() for i in range(2)]
input_min = [torch.min(input_flat[i], dim=1, keepdim=True)[0].float() for i in range(2)]
input_max = [torch.max(input_flat[i], dim=1, keepdim=True)[0].float() for i in range(2)]
scale1 = [(torch.max(input_min[i].abs(), input_max[i].abs()) * 2.0 / (q_range)).squeeze().unsqueeze(0)
for i in range(2)]
out.scale = torch.cat([scale.squeeze().unsqueeze(0),
scale1[0],
scale1[1]],
dim=0).reshape(num_groups,
-1).contiguous()
out.scale = torch.cat([scale.squeeze().unsqueeze(0), scale1[0], scale1[1]], dim=0).reshape(num_groups,
-1).contiguous()
return out
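The quantize path above follows the usual symmetric per-group recipe: every group of values shares one scale equal to max(|group_min|, |group_max|) * 2 / 2**num_bits, and the extra per-half scales computed afterwards appear to cover the tensor-sliced case. A self-contained sketch of the per-group part only, with the helper name quantize_per_group made up for illustration:

import torch

def quantize_per_group(x: torch.Tensor, num_groups: int, num_bits: int = 8):
    q_range = 2 ** num_bits
    flat = x.reshape(num_groups, -1).float()
    gmin = flat.min(dim=1, keepdim=True)[0]
    gmax = flat.max(dim=1, keepdim=True)[0]
    # Symmetric scale per group: the int8 range covers [-max_abs, +max_abs].
    scale = torch.max(gmin.abs(), gmax.abs()) * 2.0 / q_range
    q = (flat / scale).round().clamp(-q_range // 2, q_range // 2 - 1).to(torch.int8)
    return q.reshape(x.shape), scale.squeeze(1)

q, scale = quantize_per_group(torch.randn(4, 16), num_groups=4)
print(q.dtype, scale.shape)    # torch.int8 torch.Size([4])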
@ -211,6 +177,7 @@ def _module_match(module):
def generic_injection(module, fp16=False, enable_cuda_graph=True):
def replace_attn(child, policy):
policy_attn = policy.attention(child)
if policy_attn is None:
@ -246,8 +213,7 @@ def generic_injection(module, fp16=False, enable_cuda_graph=True):
attn_module.attn_qkvb = None
attn_module.attn_ow.data = transpose(attn_ow.data)
attn_module.attn_ob.data.copy_(
attn_ob.data.to(get_accelerator().current_device_name()))
attn_module.attn_ob.data.copy_(attn_ob.data.to(get_accelerator().current_device_name()))
return attn_module
def replace_attn_block(child, policy):
@ -278,8 +244,7 @@ def generic_injection(module, fp16=False, enable_cuda_graph=True):
# triangular_masking=True,
# max_out_tokens=8192)
from ..model_implementations.transformers.clip_encoder import DSClipEncoder
cg_encoder = DSClipEncoder(module.text_encoder,
enable_cuda_graph=enable_cuda_graph)
cg_encoder = DSClipEncoder(module.text_encoder, enable_cuda_graph=enable_cuda_graph)
setattr(module, 'text_encoder', cg_encoder)
for name in module.__dict__.keys():
sub_module = getattr(module, name)
@ -291,13 +256,11 @@ def generic_injection(module, fp16=False, enable_cuda_graph=True):
for name, child in module.named_children():
_replace_module(child, policy)
if child.__class__ in new_policies:
replaced_module = new_policies[child.__class__](child,
policy)
replaced_module = new_policies[child.__class__](child, policy)
setattr(module, name, replaced_module)
_replace_module(sub_module, policy)
new_module = policy.apply(sub_module,
enable_cuda_graph=enable_cuda_graph)
new_module = policy.apply(sub_module, enable_cuda_graph=enable_cuda_graph)
print(f"**** found and replaced {name} w. {type(new_module)}")
setattr(module, name, new_module)
@ -305,11 +268,7 @@ def generic_injection(module, fp16=False, enable_cuda_graph=True):
container_g = None
def replace_transformer_layer(orig_layer_impl,
model,
checkpoint_dict,
config,
model_config):
def replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, config, model_config):
""" Replace bert-style transformer layers with DeepSpeed's transformer layer
Arguments:
orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for,
@ -334,15 +293,10 @@ def replace_transformer_layer(orig_layer_impl,
seed = -1
local_rank = -1
mp_replace = ReplaceWithTensorSlicing(
mp_group=config.tensor_parallel.tp_group,
mp_size=config.tensor_parallel.tp_size) #, out_dim=0, in_dim=1)
mp_replace = ReplaceWithTensorSlicing(mp_group=config.tensor_parallel.tp_group,
mp_size=config.tensor_parallel.tp_size) #, out_dim=0, in_dim=1)
def replace_with_policy(child,
policy_cls,
triangular_masking,
inference=False,
layer_id=0):
def replace_with_policy(child, policy_cls, triangular_masking, inference=False, layer_id=0):
policy = policy_cls(child, inference=inference)
if not policy.cuda_graph_supported:
# policy says cuda graph is not supported raise an error if set
@ -364,8 +318,7 @@ def replace_transformer_layer(orig_layer_impl,
_container.set_moe(moe)
# 2. Set the tensor parallelism config
_container.set_tensor_parallel_config(config.tensor_parallel.tp_size,
config.tensor_parallel.tp_group)
_container.set_tensor_parallel_config(config.tensor_parallel.tp_size, config.tensor_parallel.tp_group)
# 3. Initialize tensors
_container.initialize_tensors()
@ -411,25 +364,21 @@ def replace_transformer_layer(orig_layer_impl,
if name in all_reduce_linears:
new_weight = torch.empty((
weight_shape[1] if conv_linear_layer else weight_shape[0],
(weight_shape[0] if conv_linear_layer else weight_shape[1]) //
mp_size,
(weight_shape[0] if conv_linear_layer else weight_shape[1]) // mp_size,
),
device=child.weight.device,
dtype=child.weight.dtype)
if conv_linear_layer:
child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
data = mp_replace.copy(new_weight, child.weight.data)
new_bias = torch.empty((weight_shape[0]),
device=child.weight.device,
dtype=child.weight.dtype)
new_bias = torch.empty((weight_shape[0]), device=child.weight.device, dtype=child.weight.dtype)
if child.bias is not None:
new_bias.data.copy_(child.bias.data)
return LinearAllreduce(data, child.bias if child.bias is None else \
torch.nn.parameter.Parameter(new_bias.to(get_accelerator().current_device_name())), mp_group)
else:
new_weight = torch.empty((
(weight_shape[1] if conv_linear_layer else weight_shape[0]) //
mp_size,
(weight_shape[1] if conv_linear_layer else weight_shape[0]) // mp_size,
weight_shape[0] // mp_size if conv_linear_layer else weight_shape[1],
),
device=child.weight.device,
@ -441,51 +390,54 @@ def replace_transformer_layer(orig_layer_impl,
new_bias = torch.empty((weight_shape[0] // mp_size),
device=child.weight.device,
dtype=child.weight.dtype)
bias_data = None if child.bias is None else mp_replace.copy(
new_bias,
child.bias.data).to(get_accelerator().current_device_name())
return LinearLayer(weight=data.to(
get_accelerator().current_device_name()),
bias=bias_data)
bias_data = None if child.bias is None else mp_replace.copy(new_bias, child.bias.data).to(
get_accelerator().current_device_name())
return LinearLayer(weight=data.to(get_accelerator().current_device_name()), bias=bias_data)
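The two branches above are the standard tensor-parallel split for linear layers: names listed in all_reduce_linears are sliced along the input dimension (each rank produces a partial output that LinearAllreduce then sums across the group), while every other linear is sliced along the output dimension and wrapped in LinearLayer. A minimal sketch of just the shape arithmetic, with illustrative names and no DeepSpeed dependencies:

import torch

def shard_linear_weight(weight: torch.Tensor, mp_size: int, rank: int, row_parallel: bool) -> torch.Tensor:
    out_features, in_features = weight.shape
    if row_parallel:
        # Split the input dimension; each rank's output is a partial sum that needs an all-reduce.
        cols = in_features // mp_size
        return weight[:, rank * cols:(rank + 1) * cols].contiguous()
    # Column-parallel: split the output dimension; per-rank outputs are concatenated/gathered.
    rows = out_features // mp_size
    return weight[rank * rows:(rank + 1) * rows, :].contiguous()

w = torch.randn(8, 4)
print(shard_linear_weight(w, 2, 0, row_parallel=True).shape)    # torch.Size([8, 2])
print(shard_linear_weight(w, 2, 1, row_parallel=False).shape)   # torch.Size([4, 4])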
def _slice_embedding(child, name, conv_linear_layer):
mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
new_weight = torch.empty((child.weight.shape[0],
child.weight.shape[1] // mp_size),
new_weight = torch.empty((child.weight.shape[0], child.weight.shape[1] // mp_size),
device=child.weight.device,
dtype=child.weight.dtype)
data = mp_replace.copy(new_weight,
child.weight.ds_tensor.data if hasattr(child.weight, 'ds_tensor') else \
child.weight.data)
new_embedding = nn.Embedding(child.weight.shape[0],
child.weight.shape[1] // mp_size)
new_embedding = nn.Embedding(child.weight.shape[0], child.weight.shape[1] // mp_size)
new_embedding.weight.data.copy_(data)
return new_embedding
def update_mp_params(child):
if hasattr(child, 'n_heads'):
assert child.n_heads%mp_size == 0, "n_heads ({}) must be divisible by mp_size ({})".format(child.n_heads, mp_size)
assert child.n_heads % mp_size == 0, "n_heads ({}) must be divisible by mp_size ({})".format(
child.n_heads, mp_size)
child.n_heads = child.n_heads // mp_size
if hasattr(child, 'inner_dim'):
assert child.inner_dim%mp_size == 0, "inner_dim ({}) must be divisible by mp_size ({})".format(child.inner_dim, mp_size)
assert child.inner_dim % mp_size == 0, "inner_dim ({}) must be divisible by mp_size ({})".format(
child.inner_dim, mp_size)
child.inner_dim = child.inner_dim // mp_size
if hasattr(child, 'num_heads'):
assert child.num_heads%mp_size == 0, "num_heads ({}) must be divisible by mp_size ({})".format(child.num_heads, mp_size)
assert child.num_heads % mp_size == 0, "num_heads ({}) must be divisible by mp_size ({})".format(
child.num_heads, mp_size)
child.num_heads = child.num_heads // mp_size
if hasattr(child, 'num_attention_heads'):
assert child.num_attention_heads%mp_size == 0, "num_attention_heads ({}) must be divisible by mp_size ({})".format(child.num_attention_heads, mp_size)
assert child.num_attention_heads % mp_size == 0, "num_attention_heads ({}) must be divisible by mp_size ({})".format(
child.num_attention_heads, mp_size)
child.num_attention_heads = child.num_attention_heads // mp_size
if hasattr(child, 'num_attn_heads'):
assert child.num_attn_heads%mp_size == 0, "num_attn_heads ({}) must be divisible by mp_size ({})".format(child.num_attn_heads, mp_size)
assert child.num_attn_heads % mp_size == 0, "num_attn_heads ({}) must be divisible by mp_size ({})".format(
child.num_attn_heads, mp_size)
child.num_attn_heads = child.num_attn_heads // mp_size
if hasattr(child, 'all_head_size'):
assert child.all_head_size%mp_size == 0, "all_head_size ({}) must be divisible by mp_size ({})".format(child.all_head_size, mp_size)
assert child.all_head_size % mp_size == 0, "all_head_size ({}) must be divisible by mp_size ({})".format(
child.all_head_size, mp_size)
child.all_head_size = child.all_head_size // mp_size
if hasattr(child, 'embed_dim'):
assert child.embed_dim%mp_size == 0, "embed_dim must ({}) be divisible by mp_size ({})".format(child.embed_dim, mp_size)
assert child.embed_dim % mp_size == 0, "embed_dim must ({}) be divisible by mp_size ({})".format(
child.embed_dim, mp_size)
child.embed_dim = child.embed_dim // mp_size
if hasattr(child, 'hidden_size'):
assert child.hidden_size%mp_size == 0, "hidden_size ({}) must be divisible by mp_size ({})".format(child.hidden_size, mp_size)
assert child.hidden_size % mp_size == 0, "hidden_size ({}) must be divisible by mp_size ({})".format(
child.hidden_size, mp_size)
child.hidden_size = child.hidden_size // mp_size
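update_mp_params simply divides every per-layer size attribute by the tensor-parallel degree after checking divisibility: with 16 attention heads and hidden_size 1024 at mp_size=4, each rank is left with 4 heads and a 256-wide hidden slice. A hypothetical miniature of that adjustment:

from types import SimpleNamespace

def shrink_for_tp(cfg, mp_size: int):
    for attr in ("num_attention_heads", "hidden_size"):
        value = getattr(cfg, attr)
        assert value % mp_size == 0, "{} ({}) must be divisible by mp_size ({})".format(attr, value, mp_size)
        setattr(cfg, attr, value // mp_size)
    return cfg

cfg = shrink_for_tp(SimpleNamespace(num_attention_heads=16, hidden_size=1024), mp_size=4)
print(cfg.num_attention_heads, cfg.hidden_size)    # 4 256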
conv_linear_layer = False
@ -507,12 +459,8 @@ def replace_transformer_layer(orig_layer_impl,
def _replace_module(r_module, prev_name=''):
for name, child in r_module.named_children():
if child.__class__ in linear_policies:
setattr(
r_module,
name,
linear_policies[child.__class__](child,
prev_name + '.' + name,
conv_linear_layer))
setattr(r_module, name, linear_policies[child.__class__](child, prev_name + '.' + name,
conv_linear_layer))
else:
update_mp_params(child)
_replace_module(child, name)
@ -559,15 +507,10 @@ def replace_transformer_layer(orig_layer_impl,
base_dir1 = checkpoint_dict.get('base_dir', config.base_dir)
if ckpt_type == 'pp' and type(checkpoint) is list:
pbar = tqdm.tqdm(total=len(checkpoint),
desc=f"Loading {len(checkpoint)} checkpoint shards")
pbar = tqdm.tqdm(total=len(checkpoint), desc=f"Loading {len(checkpoint)} checkpoint shards")
for i in range(len(checkpoint)):
sd = [
torch.load(os.path.join(base_dir1,
checkpoint[i]),
map_location='cpu')
]
sd = [torch.load(os.path.join(base_dir1, checkpoint[i]), map_location='cpu')]
load_model_with_checkpoint(replaced_module,
sd,
mp_replace,
@ -582,22 +525,15 @@ def replace_transformer_layer(orig_layer_impl,
tp_split_size = (world_size / ckpt_mp_size)
sd_offset = int(rank / tp_split_size)
sd_count = int((rank + max(1, tp_split_size)) / tp_split_size) - sd_offset
pbar = tqdm.tqdm(total=num_checkpoints,
desc=f"Loading {num_checkpoints} checkpoint shards")
pbar = tqdm.tqdm(total=num_checkpoints, desc=f"Loading {num_checkpoints} checkpoint shards")
for i in range(num_checkpoints):
pbar.update(1)
ckpt_index = i * ckpt_mp_size + sd_offset
ckpt_files = [
os.path.join(base_dir1,
ckpt_list[ckpt_index +
j]) if base_dir1 else ckpt_list[ckpt_index +
j]
os.path.join(base_dir1, ckpt_list[ckpt_index + j]) if base_dir1 else ckpt_list[ckpt_index + j]
for j in range(sd_count)
]
sds = [
torch.load(ckpt_file,
map_location='cpu') for ckpt_file in ckpt_files
]
sds = [torch.load(ckpt_file, map_location='cpu') for ckpt_file in ckpt_files]
load_model_with_checkpoint(replaced_module,
sds,
mp_replace,
@ -610,15 +546,13 @@ def replace_transformer_layer(orig_layer_impl,
gc.collect()
if "non_tp" in checkpoint:
pbar = tqdm.tqdm(
total=len(checkpoint["non_tp"]),
desc=f"Loading {len(checkpoint['non_tp'])} checkpoint shards")
pbar = tqdm.tqdm(total=len(checkpoint["non_tp"]),
desc=f"Loading {len(checkpoint['non_tp'])} checkpoint shards")
for i in range(len(checkpoint["non_tp"])):
pbar.update(1)
ckpt_file = os.path.join(base_dir1,
checkpoint["non_tp"][i]
) if base_dir1 else checkpoint["non_tp"][i]
checkpoint["non_tp"][i]) if base_dir1 else checkpoint["non_tp"][i]
sds = [torch.load(ckpt_file, map_location='cpu')]
load_model_with_checkpoint(replaced_module,
sds,
@ -657,37 +591,22 @@ def replace_transformer_layer(orig_layer_impl,
if not dist.is_initialized() or dist.get_rank() == 0:
print("Saving tp-sharded checkpoints")
torch.save(
OrderedDict({
k: v
for k,
v in dict(replaced_module.state_dict()).items()
if transformer_name not in k
}),
f'{config.save_mp_checkpoint_path}/{non_tp_ckpt_name}')
OrderedDict({k: v
for k, v in dict(replaced_module.state_dict()).items()
if transformer_name not in k}), f'{config.save_mp_checkpoint_path}/{non_tp_ckpt_name}')
ckpt_config = json.dumps({
'type':
ckpt_name,
'base_dir':
f'{config.save_mp_checkpoint_path}',
'type': ckpt_name,
'base_dir': f'{config.save_mp_checkpoint_path}',
'checkpoints': {
"non_tp":
ckpt_files,
"tp": [
f'tp_{r:0>2d}_{m:0>2d}.pt' for m in range(num_partitions)
for r in range(world_size)
]
"non_tp": ckpt_files,
"tp": [f'tp_{r:0>2d}_{m:0>2d}.pt' for m in range(num_partitions) for r in range(world_size)]
},
'version':
1.0,
'parallelization':
'tp',
'tp_size':
world_size,
'dtype':
'int8' if quantize else ('float16' if fp16 else 'float32')
'version': 1.0,
'parallelization': 'tp',
'tp_size': world_size,
'dtype': 'int8' if quantize else ('float16' if fp16 else 'float32')
})
with open(f"{config.save_mp_checkpoint_path}/ds_inference_config.json",
"w") as cfg:
with open(f"{config.save_mp_checkpoint_path}/ds_inference_config.json", "w") as cfg:
cfg.write(ckpt_config)
rep_sd = replaced_module.state_dict()
@ -699,13 +618,9 @@ def replace_transformer_layer(orig_layer_impl,
for m in range(num_partitions):
torch.save(
OrderedDict({
k: [rep_sd[k],
rep_sd[k].scale] if hasattr(rep_sd[k],
'scale') else rep_sd[k]
for k in keys[m * partition_size:(m + 1) * partition_size]
if transformer_name in k
}),
f'{config.save_mp_checkpoint_path}/tp_{rank:0>2d}_{m:0>2d}.pt')
k: [rep_sd[k], rep_sd[k].scale] if hasattr(rep_sd[k], 'scale') else rep_sd[k]
for k in keys[m * partition_size:(m + 1) * partition_size] if transformer_name in k
}), f'{config.save_mp_checkpoint_path}/tp_{rank:0>2d}_{m:0>2d}.pt')
return replaced_module
@ -720,6 +635,7 @@ def revert_transformer_layer(orig_layer_impl, model, config, preln=False):
Returns:
Updated nn.module with original bert-style transformer layers
"""
def replace_fn(child, _replace_policy, layer_id):
#from turing.nvidia_modelingpreln import BertLayer
orig_module = orig_layer_impl(config)
@ -821,9 +737,7 @@ def _replace_module(model, policies, layer_id=0):
"""
for name, child in model.named_children():
if child.__class__ in policies:
replaced_module = policies[child.__class__][0](child,
policies[child.__class__][-1],
layer_id)
replaced_module = policies[child.__class__][0](child, policies[child.__class__][-1], layer_id)
setattr(model, name, replaced_module)
if isinstance(model, PipelineModule):
assert hasattr(model, 'forward_funcs'),\

View File

@ -16,16 +16,8 @@ from .containers import VAEPolicy
# transformer-based policies
replace_policies = [
HFBertLayerPolicy,
HFGPTNEOLayerPolicy,
GPTNEOXLayerPolicy,
HFGPTJLayerPolicy,
MegatronLayerPolicy,
HFGPT2LayerPolicy,
BLOOMLayerPolicy,
HFOPTLayerPolicy,
HFCLIPLayerPolicy,
HFDistilBertLayerPolicy
HFBertLayerPolicy, HFGPTNEOLayerPolicy, GPTNEOXLayerPolicy, HFGPTJLayerPolicy, MegatronLayerPolicy,
HFGPT2LayerPolicy, BLOOMLayerPolicy, HFOPTLayerPolicy, HFCLIPLayerPolicy, HFDistilBertLayerPolicy
]
# non-transformer-based policies

View File

@ -7,11 +7,11 @@ import copy
class Experts(torch.nn.Module):
def __init__(self, expert, num_local_experts=1, expert_group_name=None):
super(Experts, self).__init__()
self.deepspeed_experts = torch.nn.ModuleList(
[copy.deepcopy(expert) for i in range(num_local_experts)])
self.deepspeed_experts = torch.nn.ModuleList([copy.deepcopy(expert) for i in range(num_local_experts)])
self.num_local_experts = num_local_experts
# TODO: revisit allreduce for moe.gate...

View File

@ -31,6 +31,7 @@ class MoE(torch.nn.Module):
use_tutel (bool, optional): default=False, whether to use Tutel optimizations (if installed).
enable_expert_tensor_parallelism (bool, optional): default=False, whether to use tensor parallelism for experts
"""
def __init__(self,
hidden_size,
expert,
@ -65,15 +66,8 @@ class MoE(torch.nn.Module):
'Unsupported noisy_gate_policy: ' + noisy_gate_policy
experts = Experts(expert, self.num_local_experts, self.expert_group_name)
self.deepspeed_moe = MOELayer(TopKGate(hidden_size,
num_experts,
k,
capacity_factor,
eval_capacity_factor,
min_capacity,
noisy_gate_policy,
drop_tokens,
use_rts),
self.deepspeed_moe = MOELayer(TopKGate(hidden_size, num_experts, k, capacity_factor, eval_capacity_factor,
min_capacity, noisy_gate_policy, drop_tokens, use_rts),
experts,
self.expert_group_name,
self.ep_size,
@ -90,20 +84,16 @@ class MoE(torch.nn.Module):
def _create_process_groups(self):
# Create process group for a layer if needed
if self.expert_group_name not in groups._get_expert_parallel_group_dict():
print(
f"No existing process group found, creating a new group named: {self.expert_group_name}"
)
print(f"No existing process group found, creating a new group named: {self.expert_group_name}")
if (groups.mpu is None) or (not self.enable_expert_tensor_parallelism):
# Condition 1 - no groups.mpu means no tensor parallelism
# Condition 2 - disabling expert tensor parallelism on purpose
groups._create_expert_and_data_parallel(self.ep_size)
else:
# expert tensor parallelism is enabled
groups._create_expert_data_and_model_parallel(self.ep_size,
mpu=groups.mpu)
groups._create_expert_data_and_model_parallel(self.ep_size, mpu=groups.mpu)
# Set the group handle for the MOELayer (deepspeed_moe) object
self.deepspeed_moe._set_ep_group(
groups._get_expert_parallel_group(self.expert_group_name))
self.deepspeed_moe._set_ep_group(groups._get_expert_parallel_group(self.expert_group_name))
def forward(self, hidden_states, used_token=None):
""" MoE forward

View File

@ -32,14 +32,9 @@ def _gather_tokens(input_, dim=0):
# Size and dimension.
rank = mpu.get_tensor_model_parallel_rank()
tensor_list = [
torch.empty_like(input_)
for _ in range(mpu.get_tensor_model_parallel_world_size())
]
tensor_list = [torch.empty_like(input_) for _ in range(mpu.get_tensor_model_parallel_world_size())]
tensor_list[rank] = input_
deepspeed.comm.all_gather(tensor_list,
input_,
group=mpu.get_tensor_model_parallel_group())
deepspeed.comm.all_gather(tensor_list, input_, group=mpu.get_tensor_model_parallel_group())
# Note: torch.cat already creates a contiguous tensor.
output = torch.cat(tensor_list, dim=dim).contiguous()
@ -53,7 +48,8 @@ def _drop_tokens(input_, dim=0):
total_chunks = mpu.get_tensor_model_parallel_world_size()
this_chunk = mpu.get_tensor_model_parallel_rank()
assert input_.shape[dim] % total_chunks == 0, f"input dimension {dim} ({input_.shape[dim]}) is not divisible by tensor parallel world size ({total_chunks})"
assert input_.shape[
dim] % total_chunks == 0, f"input dimension {dim} ({input_.shape[dim]}) is not divisible by tensor parallel world size ({total_chunks})"
chunk_size = input_.shape[dim] // total_chunks
return torch.narrow(input_, dim, this_chunk * chunk_size, chunk_size)
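_drop_tokens and _gather_tokens above are inverses along the chosen dimension: drop keeps only this rank's 1/world_size chunk, while gather all-gathers the chunks and concatenates them back. A single-process sketch of the shape bookkeeping (no deepspeed.comm; the all-gather is simulated by keeping every chunk):

import torch

def drop_tokens(x: torch.Tensor, rank: int, world_size: int, dim: int = 0) -> torch.Tensor:
    assert x.shape[dim] % world_size == 0
    chunk = x.shape[dim] // world_size
    return torch.narrow(x, dim, rank * chunk, chunk)

def gather_tokens(chunks, dim: int = 0) -> torch.Tensor:
    return torch.cat(chunks, dim=dim).contiguous()

x = torch.randn(6, 3)
chunks = [drop_tokens(x, r, world_size=3) for r in range(3)]
assert torch.equal(gather_tokens(chunks), x)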
@ -61,6 +57,7 @@ def _drop_tokens(input_, dim=0):
class _GatherTokens(torch.autograd.Function):
"""All gather tokens among the tensor parallel ranks"""
@staticmethod
def symbolic(graph, input_, dim):
return _gather_tokens(input_, dim)

View File

@ -60,11 +60,9 @@ def multiplicative_jitter(x, device: torch.device, epsilon=1e-2):
return x
uniform = uniform_map.get(device)
if uniform is None:
uniform = torch.distributions.uniform.Uniform(
low=torch.tensor(1.0 - epsilon,
device=device),
high=torch.tensor(1.0 + epsilon,
device=device)).rsample # type: ignore
uniform = torch.distributions.uniform.Uniform(low=torch.tensor(1.0 - epsilon, device=device),
high=torch.tensor(1.0 + epsilon,
device=device)).rsample # type: ignore
uniform_map[device] = uniform
return x * uniform(x.shape)
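multiplicative_jitter above multiplies its input element-wise by noise drawn from U(1 - epsilon, 1 + epsilon), which keeps sign and rough magnitude while breaking ties between near-equal values; the per-device dict is only a cache for the sampler. A tiny equivalent without the cache:

import torch

def multiplicative_jitter(x: torch.Tensor, epsilon: float = 1e-2) -> torch.Tensor:
    # Same-device, same-shape noise in [1 - epsilon, 1 + epsilon].
    noise = torch.empty_like(x).uniform_(1.0 - epsilon, 1.0 + epsilon)
    return x * noise

print(multiplicative_jitter(torch.ones(4)))    # values within about +/-1% of 1.0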
@ -87,6 +85,7 @@ from deepspeed import comm as dist
# Based on https://github.com/pytorch/pytorch/pull/40762
class _AllToAll(torch.autograd.Function):
@staticmethod
def forward(
ctx: Any,
@ -181,25 +180,18 @@ def top1gating(logits: Tensor,
noisy_gate_policy: Optional[str] = None,
drop_tokens: bool = True,
use_rts: bool = True,
use_tutel: bool = False) -> Tuple[Tensor,
Tensor,
Tensor,
Tensor]:
use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Implements Top1Gating on logits."""
if noisy_gate_policy == 'RSample':
logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
# everything is in fp32 in this function
gates = F.softmax(logits, dim=1)
capacity = _capacity(gates,
torch.tensor(capacity_factor),
torch.tensor(min_capacity))
capacity = _capacity(gates, torch.tensor(capacity_factor), torch.tensor(min_capacity))
# Create a mask for 1st's expert per token
# noisy gating
indices1_s = torch.argmax(
logits_w_noise if noisy_gate_policy == 'RSample' else gates,
dim=1)
indices1_s = torch.argmax(logits_w_noise if noisy_gate_policy == 'RSample' else gates, dim=1)
num_experts = int(gates.shape[1])
mask1 = F.one_hot(indices1_s, num_classes=num_experts)
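top1gating sizes each expert's buffer with _capacity (not shown in this hunk); the usual definition, which top2gating also follows with the factor doubled since every token claims two slots, is ceil(tokens / num_experts * capacity_factor), clamped below by min_capacity. A hedged restatement of that formula:

import math

def expert_capacity(num_tokens: int, num_experts: int, capacity_factor: float, min_capacity: int) -> int:
    # Tokens routed to an expert beyond this budget are dropped (or re-routed).
    return max(min_capacity, math.ceil(num_tokens / num_experts * capacity_factor))

print(expert_capacity(num_tokens=4096, num_experts=64, capacity_factor=1.0, min_capacity=4))    # 64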
@ -225,18 +217,16 @@ def top1gating(logits: Tensor,
if use_rts:
uniform = exp_selection_uniform_map.get(logits.device)
if uniform is None:
uniform = torch.distributions.uniform.Uniform(
low=torch.tensor(0.0,
device=logits.device),
high=torch.tensor(1.0,
device=logits.device)).rsample
uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=logits.device),
high=torch.tensor(1.0, device=logits.device)).rsample
exp_selection_uniform_map[logits.device] = uniform
mask1_rand = mask1 * uniform(mask1.shape)
else:
mask1_rand = mask1
assert logits.shape[0] >= min_capacity, "No. of tokens (batch-size) should be greater than min_capacity. Either set min_capacity to 0 or increase your batch size."
assert logits.shape[
0] >= min_capacity, "No. of tokens (batch-size) should be greater than min_capacity. Either set min_capacity to 0 or increase your batch size."
top_idx = _top_idx(mask1_rand, capacity)
@ -258,7 +248,13 @@ def top1gating(logits: Tensor,
if use_tutel:
gates1_s = (gates * mask1).sum(dim=1)
locations1_s = torch.sum(locations1 * mask1, dim=1)
return l_aux, capacity, num_experts, [indices1_s,], [locations1_s,], [gates1_s,], exp_counts
return l_aux, capacity, num_experts, [
indices1_s,
], [
locations1_s,
], [
gates1_s,
], exp_counts
# Store the capacity location for each token
locations1_s = torch.sum(locations1 * mask1, dim=1)
@ -275,19 +271,12 @@ def top1gating(logits: Tensor,
return l_aux, combine_weights, dispatch_mask, exp_counts
def top2gating(logits: Tensor,
capacity_factor: float,
min_capacity: int) -> Tuple[Tensor,
Tensor,
Tensor,
Tensor]:
def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Implements Top2Gating on logits."""
# everything is in fp32 in this function
gates = F.softmax(logits, dim=1)
capacity = _capacity(gates,
torch.tensor(capacity_factor * 2),
torch.tensor(min_capacity))
capacity = _capacity(gates, torch.tensor(capacity_factor * 2), torch.tensor(min_capacity))
# Create a mask for 1st's expert per token
indices1_s = torch.argmax(gates, dim=1)
@ -393,13 +382,10 @@ class TopKGate(Module):
self.drop_tokens = drop_tokens
self.use_rts = use_rts
def forward(
self,
input: torch.Tensor,
used_token: torch.Tensor = None,
use_tutel: bool = False) -> Tuple[Tensor,
Tensor,
Tensor]: # type: ignore
def forward(self,
input: torch.Tensor,
used_token: torch.Tensor = None,
use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor]: # type: ignore
if self.wall_clock_breakdown:
self.timers('TopKGate').start()
@ -413,21 +399,13 @@ class TopKGate(Module):
logits = self.wg(input_fp32)
if self.k == 1:
gate_output = top1gating(
logits,
self.capacity_factor if self.training else self.eval_capacity_factor,
self.min_capacity,
used_token,
self.noisy_gate_policy if self.training else None,
self.drop_tokens,
self.use_rts,
use_tutel)
gate_output = top1gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
self.min_capacity, used_token, self.noisy_gate_policy if self.training else None,
self.drop_tokens, self.use_rts, use_tutel)
else:
gate_output = top2gating(
logits,
self.capacity_factor if self.training else self.eval_capacity_factor,
self.min_capacity)
gate_output = top2gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
self.min_capacity)
if self.wall_clock_breakdown:
self.timers('TopKGate').stop()
@ -453,6 +431,7 @@ class MOELayer(Base):
expert (torch.nn.Module):
expert network
"""
def __init__(self,
gate: Module,
experts: Module,
@ -481,9 +460,8 @@ class MOELayer(Base):
logger.warning("Tutel optimization requested but not installed. "
"Proceeding without Tutel.")
elif use_tutel and TUTEL_INSTALLED and gate.k != 1:
logger.warning(
"To enable Tutel optimization, use top-1 instead of top-2 gate. "
"Proceeding without Tutel.")
logger.warning("To enable Tutel optimization, use top-1 instead of top-2 gate. "
"Proceeding without Tutel.")
def _set_ep_group(self, ep_group):
self.ep_group = ep_group
@ -506,18 +484,12 @@ class MOELayer(Base):
S, M = reshaped_input.size(0), reshaped_input.size(1)
if not hasattr(self, '_tutel_dispatcher'):
self._tutel_dispatcher = tutel_moe.fast_dispatcher(
E,
C,
M,
dispatch_dtype=reshaped_input.dtype)
self._tutel_dispatcher = tutel_moe.fast_dispatcher(E, C, M, dispatch_dtype=reshaped_input.dtype)
self._tutel_dispatcher.update(indices_, locations_, gates_, capacity=C)
dispatched_input = self._tutel_dispatcher.encode(reshaped_input)
else:
self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_input, input[1])
dispatched_input = einsum("sec,sm->ecm",
dispatch_mask.type_as(input[0]),
reshaped_input)
dispatched_input = einsum("sec,sm->ecm", dispatch_mask.type_as(input[0]), reshaped_input)
if self.wall_clock_breakdown:
self.timers('falltoall').start()
@ -538,10 +510,7 @@ class MOELayer(Base):
self.time_falltoall = self.timers('falltoall').elapsed(reset=False)
# Re-shape after all-to-all: ecm -> gecm
dispatched_input = dispatched_input.reshape(self.ep_size,
self.num_local_experts,
-1,
d_model)
dispatched_input = dispatched_input.reshape(self.ep_size, self.num_local_experts, -1, d_model)
expert_output = self.experts(dispatched_input)
@ -555,9 +524,7 @@ class MOELayer(Base):
self.time_salltoall = self.timers('salltoall').elapsed(reset=False)
# Re-shape back: gecm -> ecm
expert_output = expert_output.reshape(self.ep_size * self.num_local_experts,
-1,
d_model)
expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)
if groups._get_expert_model_parallel_world_size() == 1:
# the dropped duplicate tokens need to be gathered on each
@ -568,9 +535,7 @@ class MOELayer(Base):
if self.use_tutel:
combined_output = self._tutel_dispatcher.decode(expert_output.view(E * C, M))
else:
combined_output = einsum("sec,ecm->sm",
combine_weights.type_as(input[0]),
expert_output)
combined_output = einsum("sec,ecm->sm", combine_weights.type_as(input[0]), expert_output)
a = combined_output.reshape(input[0].shape)
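The two einsums in MOELayer.forward are the dispatch and combine halves of MoE routing: "sec,sm->ecm" scatters s tokens into e experts' c capacity slots, and "sec,ecm->sm" weights the expert outputs back into token order. A toy shape check with made-up sizes:

import torch

s, e, c, m = 8, 2, 4, 16                      # tokens, experts, capacity, model dim
tokens = torch.randn(s, m)
combine_weights = torch.rand(s, e, c)         # soft routing weights from the gate
dispatch_mask = (combine_weights > 0.5).type_as(tokens)

dispatched = torch.einsum("sec,sm->ecm", dispatch_mask, tokens)        # (e, c, m)
expert_out = dispatched * 2.0                                          # stand-in for the expert MLPs
combined = torch.einsum("sec,ecm->sm", combine_weights, expert_out)    # (s, m)
print(dispatched.shape, combined.shape)        # torch.Size([2, 4, 16]) torch.Size([8, 16])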
