[compile] Show breakdown of graph break (#6601)

This PR extends https://github.com/microsoft/DeepSpeed/pull/6570 by
showing a breakdown of graph breaks, so we can see how graph breaks are
distributed among the different reasons. An example of the graph break
output can be seen in the following workflow run:
https://github.com/microsoft/DeepSpeed/actions/runs/11199157962
This commit is contained in:
Ma, Guokai
2024-10-15 01:31:34 +08:00
committed by GitHub
parent 7a5bc4fdf9
commit cf41e8c4e8
4 changed files with 64 additions and 19 deletions

View File

@ -51,9 +51,15 @@ jobs:
- name: Compile Status
shell: bash
run: |
echo "# torch.compile graph breaks" >> $GITHUB_STEP_SUMMARY
export FI_HMEM=system
ulimit -n 1048575
cd tests/torch_compile
export ZE_AFFINITY_MASK=0,1
deepspeed test_compile.py --deepspeed_config ds_config.json 2>&1 | tee log.txt
cat log.txt | grep "'graph_breaks'" | sed 's/,/ /g' | awk '{print $2}' >> $GITHUB_STEP_SUMMARY
echo "## ZeRO stage 3" >> $GITHUB_STEP_SUMMARY
deepspeed test_compile.py --deepspeed_config ds_config_z3.json 2>&1 | tee log_z3.txt
# For each line starting with 'dynamo_output', blank out the first field (the tag) and append the remaining fields to GITHUB_STEP_SUMMARY using awk
cat log_z3.txt | awk '/^dynamo_output/ {$1=""; print $0}' >> $GITHUB_STEP_SUMMARY
echo "## ZeRO stage 2" >> $GITHUB_STEP_SUMMARY
deepspeed test_compile.py --deepspeed_config ds_config_z2.json 2>&1 | tee log_z2.txt
cat log_z2.txt | awk '/^dynamo_output/ {$1=""; print $0}' >> $GITHUB_STEP_SUMMARY

View File

@ -0,0 +1,40 @@
{
"train_batch_size": 8,
"steps_per_print": 2000,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.001,
"betas": [
0.8,
0.999
],
"eps": 1e-8,
"weight_decay": 3e-7
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": 0.001,
"warmup_num_steps": 1000
}
},
"gradient_clipping": 1.0,
"prescale_gradients": false,
"bf16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 500,
"hysteresis": 2,
"min_loss_scale": 1,
"initial_scale_power": 15
},
"wall_clock_breakdown": false,
"zero_optimization": {
"stage": 2,
"overlap_comm": false,
"contiguous_gradients": false
}
}

View File

@ -14,22 +14,9 @@ from torch.utils.data import Dataset, DataLoader
torch._dynamo.config.cache_size_limit = 100
import collections
def get_dynamo_stats():
# TODO: consider deepcopy'ing the entire counters struct and
# adding a helper to do subtraction on it
return collections.Counter({
"calls_captured": torch._dynamo.utils.counters["stats"]["calls_captured"],
"unique_graphs": torch._dynamo.utils.counters["stats"]["unique_graphs"],
"graph_breaks": sum(torch._dynamo.utils.counters["graph_break"].values()),
# NB: The plus removes zero counts
"unique_graph_breaks": len(+torch._dynamo.utils.counters["graph_break"]),
"autograd_captures": torch._dynamo.utils.counters["compiled_autograd"]["captures"],
"autograd_compiles": torch._dynamo.utils.counters["compiled_autograd"]["compiles"],
"cudagraph_skips": torch._dynamo.utils.counters["inductor"]["cudagraph_skips"],
})
return torch._dynamo.utils.counters["graph_break"]
class RandomDataset(Dataset):
@ -70,7 +57,7 @@ parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1, help='local rank passed from distributed launcher')
parser.add_argument('--deepspeed_config',
type=str,
default='ds_config.json',
default='ds_config_z3.json',
help='path to DeepSpeed configuration file')
cmd_args = parser.parse_args()
@ -82,6 +69,11 @@ residual = torch.rand(256, 256, dtype=torch.float).to(get_accelerator().current_
start_stats = get_dynamo_stats()
if comm.get_rank() == 0:
#print(dynamo_stats['graph_breaks'])
for item in start_stats.items():
print(item)
for step, batch in enumerate(rand_loader):
if step % 10 == 0 and comm.get_rank() == 0:
print(f'step={step}')
@ -93,7 +85,14 @@ for step, batch in enumerate(rand_loader):
model_engine.step()
dynamo_stats = get_dynamo_stats()
dynamo_stats.subtract(start_stats)
if comm.get_rank() == 0:
print(dynamo_stats)
# Print the breakdown of graph break stats as a markdown table: reason first, then count.
# Each line is prefixed with the tag 'dynamo_output' to allow post-processing.
print("dynamo_output | Reason | Count |")
print("dynamo_output | ------ | ----- |")
for item in dynamo_stats.items():
# escape each '|' in item[0] as '\|' so it does not break the markdown table format
item = (item[0].replace('|', r'\|'), item[1])
print(f"dynamo_output | {item[0]} | {item[1]} |")
print(f"dynamo_output | Total | {sum(dynamo_stats.values())} |")