[compile] Show breakdown of graph break (#6601)
This PR extends https://github.com/microsoft/DeepSpeed/pull/6570 by showing a breakdown of graph breaks, so we can see how graph breaks are distributed among the different reasons. An example of the graph-break output can be seen in the following workflow run: https://github.com/microsoft/DeepSpeed/actions/runs/11199157962
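For context, here is a minimal, standalone sketch (not code from this PR) of the approach: torch._dynamo keeps a per-reason counter of graph breaks in torch._dynamo.utils.counters["graph_break"], which can be snapshotted before a compiled run, diffed afterwards, and rendered as a markdown table with one row per break reason. The toy function and the print-induced graph break below are purely illustrative.

import collections

import torch
import torch._dynamo.utils


def graph_break_counts():
    # counters["graph_break"] maps each graph-break reason string to its occurrence count
    return collections.Counter(torch._dynamo.utils.counters["graph_break"])


def fn(x):
    y = x.sin()
    print("side effect")  # print() is not traceable, so Dynamo splits the graph here
    return y.cos()


before = graph_break_counts()
torch.compile(fn)(torch.randn(8))
delta = graph_break_counts()
delta.subtract(before)  # keep only breaks that happened during this run

# Render the breakdown as a markdown table: one row per reason, plus a total.
print("| Reason | Count |")
print("| ------ | ----- |")
for reason, count in delta.items():
    reason = reason.replace('|', r'\|')  # escape pipes so the table layout survives
    print(f"| {reason} | {count} |")
print(f"| Total | {sum(delta.values())} |")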
.github/workflows/xpu-compile.yml
@@ -51,9 +51,15 @@ jobs:
      - name: Compile Status
        shell: bash
        run: |
          echo "# torch.compile graph breaks" >> $GITHUB_STEP_SUMMARY
          export FI_HMEM=system
          ulimit -n 1048575
          cd tests/torch_compile
          export ZE_AFFINITY_MASK=0,1
          deepspeed test_compile.py --deepspeed_config ds_config.json 2>&1 | tee log.txt
          cat log.txt | grep "'graph_breaks'" | sed 's/,/ /g' | awk '{print $2}' >> $GITHUB_STEP_SUMMARY
          echo "## ZeRO stage 3" >> $GITHUB_STEP_SUMMARY
          deepspeed test_compile.py --deepspeed_config ds_config_z3.json 2>&1 | tee log_z3.txt
          # for each line start with 'dynamo_output', extract the second field and following fields and append to GITHUB_STEP_SUMMARY using awk
          cat log_z3.txt | awk '/^dynamo_output/ {$1=""; print $0}' >> $GITHUB_STEP_SUMMARY
          echo "## ZeRO stage 2" >> $GITHUB_STEP_SUMMARY
          deepspeed test_compile.py --deepspeed_config ds_config_z2.json 2>&1 | tee log_z2.txt
          cat log_z2.txt | awk '/^dynamo_output/ {$1=""; print $0}' >> $GITHUB_STEP_SUMMARY
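The awk filters above keep only the lines that test_compile.py tags with 'dynamo_output', drop the tag, and append the remaining markdown table rows to the job's step summary. A rough Python equivalent of that post-processing step (illustrative only; extract_dynamo_output is a hypothetical helper, while log_z3.txt and GITHUB_STEP_SUMMARY are the names used in the workflow step above):

import os


def extract_dynamo_output(log_path: str, summary_path: str) -> None:
    # Keep only lines tagged 'dynamo_output', strip the tag, and append the
    # remaining markdown table rows to the step summary file.
    with open(log_path) as log, open(summary_path, "a") as summary:
        for line in log:
            if line.startswith("dynamo_output"):
                summary.write(line[len("dynamo_output"):].strip() + "\n")


if __name__ == "__main__":
    # GITHUB_STEP_SUMMARY is provided by GitHub Actions.
    extract_dynamo_output("log_z3.txt", os.environ["GITHUB_STEP_SUMMARY"])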
tests/torch_compile/ds_config_z2.json
@@ -0,0 +1,40 @@
{
    "train_batch_size": 8,
    "steps_per_print": 2000,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 0.001,
            "betas": [
                0.8,
                0.999
            ],
            "eps": 1e-8,
            "weight_decay": 3e-7
        }
    },
    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": 0,
            "warmup_max_lr": 0.001,
            "warmup_num_steps": 1000
        }
    },
    "gradient_clipping": 1.0,
    "prescale_gradients": false,
    "bf16": {
        "enabled": true,
        "loss_scale": 0,
        "loss_scale_window": 500,
        "hysteresis": 2,
        "min_loss_scale": 1,
        "initial_scale_power": 15
    },
    "wall_clock_breakdown": false,
    "zero_optimization": {
        "stage": 2,
        "overlap_comm": false,
        "contiguous_gradients": false
    }
}
tests/torch_compile/test_compile.py
@@ -14,22 +14,9 @@ from torch.utils.data import Dataset, DataLoader

torch._dynamo.config.cache_size_limit = 100

import collections


def get_dynamo_stats():
    # TODO: consider deepcopy'ing the entire counters struct and
    # adding a helper to do subtraction on it
    return collections.Counter({
        "calls_captured": torch._dynamo.utils.counters["stats"]["calls_captured"],
        "unique_graphs": torch._dynamo.utils.counters["stats"]["unique_graphs"],
        "graph_breaks": sum(torch._dynamo.utils.counters["graph_break"].values()),
        # NB: The plus removes zero counts
        "unique_graph_breaks": len(+torch._dynamo.utils.counters["graph_break"]),
        "autograd_captures": torch._dynamo.utils.counters["compiled_autograd"]["captures"],
        "autograd_compiles": torch._dynamo.utils.counters["compiled_autograd"]["compiles"],
        "cudagraph_skips": torch._dynamo.utils.counters["inductor"]["cudagraph_skips"],
    })
    return torch._dynamo.utils.counters["graph_break"]


class RandomDataset(Dataset):
@@ -70,7 +57,7 @@ parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1, help='local rank passed from distributed launcher')
parser.add_argument('--deepspeed_config',
                    type=str,
                    default='ds_config.json',
                    default='ds_config_z3.json',
                    help='path to DeepSpeed configuration file')
cmd_args = parser.parse_args()

@@ -82,6 +69,11 @@ residual = torch.rand(256, 256, dtype=torch.float).to(get_accelerator().current_

start_stats = get_dynamo_stats()

if comm.get_rank() == 0:
    #print(dynamo_stats['graph_breaks'])
    for item in start_stats.items():
        print(item)

for step, batch in enumerate(rand_loader):
    if step % 10 == 0 and comm.get_rank() == 0:
        print(f'step={step}')
@@ -93,7 +85,14 @@ for step, batch in enumerate(rand_loader):
    model_engine.step()

dynamo_stats = get_dynamo_stats()
dynamo_stats.subtract(start_stats)

if comm.get_rank() == 0:
    print(dynamo_stats)
    # print break down of graph break stats with markdown, print in table format, start with reason, then count
    # print a tag 'dynamo_output' before each line to allow post processing
    print("dynamo_output | Reason | Count |")
    print("dynamo_output | ------ | ----- |")
    for item in dynamo_stats.items():
        # replace '|' in item[0] with a literal '\|' to avoid mess with table format
        item = (item[0].replace('|', r'\|'), item[1])
        print(f"dynamo_output | {item[0]} | {item[1]} |")
    print(f"dynamo_output | Total | {sum(dynamo_stats.values())} |")