# Mirror of https://github.com/deepspeedai/DeepSpeed.git
# Synced 2025-10-20 15:33:51 +08:00
# 41 lines, 1.8 KiB, Python, Executable File
#!/usr/bin/env python3
|
|
|
|
import argparse
|
|
import json
|
|
|
|
import deepspeed
|
|
from deepspeed.elasticity import compute_elastic_config
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of the elasticity helper script.

    Args:
        argv: Sequence of argument strings to parse. ``None`` lets argparse
            fall back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``config`` (path of the DeepSpeed config json)
        and ``world_size`` (int, ``0`` when the option is not given).

    Raises:
        SystemExit: raised by argparse on invalid or missing arguments.
    """
    parser = argparse.ArgumentParser()
    # --config is mandatory: when it was optional, omitting it only surfaced
    # later as a TypeError from open(None) instead of a usage message.
    parser.add_argument('-c', '--config', type=str, required=True, help="DeepSpeed config json")
    parser.add_argument('-w', '--world-size', type=int, default=0, help="Intended/current world size")
    return parser.parse_args(argv)


def load_ds_config(path):
    """Read the json file at *path* and return the parsed object.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid json.
    """
    # The context manager closes the handle even when parsing fails.
    # JSON text is read as UTF-8 rather than the locale default encoding.
    with open(path, 'r', encoding='utf-8') as config_file:
        return json.load(config_file)


def main(argv=None):
    """Print the 'elasticity' section of a DeepSpeed config and the computed results.

    The config is loaded from ``--config``, its ``'elasticity'`` entry is
    pretty-printed, and ``compute_elastic_config`` is called with the installed
    DeepSpeed version. When ``--world-size`` is greater than zero it is passed
    along and a third result (the micro batch size) is printed as well.

    Args:
        argv: Optional argument list forwarded to ``parse_args``.

    Raises:
        KeyError: if the loaded config has no ``'elasticity'`` entry.
    """
    args = parse_args(argv)
    ds_config = load_ds_config(args.config)

    ds_version = deepspeed.__version__

    rule = '------------------------------------------'

    elastic_config = ds_config['elasticity']
    print(rule)
    print("Elasticity config:")
    print(rule)
    print(json.dumps(elastic_config, indent=4, sort_keys=True))

    if args.world_size > 0:
        # With an explicit world size the call is unpacked into three values.
        final_batch_size, valid_gpus, micro_batch_size = compute_elastic_config(ds_config=ds_config,
                                                                                target_deepspeed_version=ds_version,
                                                                                world_size=args.world_size)
        print(rule)
        print(f"Calculated results for world size {args.world_size}:")
        print(rule)
        print(f'final_batch_size .... {final_batch_size}')
        print(f'valid_gpus .......... {valid_gpus}')
        print(f'micro_batch_size .... {micro_batch_size}')
    else:
        # Without a world size only the batch size and valid GPU counts are unpacked.
        final_batch_size, valid_gpus = compute_elastic_config(ds_config=ds_config, target_deepspeed_version=ds_version)
        print(rule)
        print("Calculated results:")
        print(rule)
        print(f'final_batch_size .... {final_batch_size}')
        print(f'valid_gpus .......... {valid_gpus}')


if __name__ == '__main__':
    main()