Compare commits


1 commit

ff5b19523e  [BE] Use torch check the way it's intended  (2025-10-08 13:26:59 -07:00)
Replace `if (!foo) TORCH_CHECK(false, "bar");` with `TORCH_CHECK(foo, "bar");`
118 changed files with 4117 additions and 2742 deletions
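The replacement pattern is purely mechanical. As a minimal, hypothetical illustration (the function and message below are invented, not taken from the changed files), `TORCH_CHECK(cond, msg...)` raises a `c10::Error` with the given message when `cond` is false, so wrapping a constant-false check in an explicit `if` is redundant:

```cpp
#include <c10/util/Exception.h>  // defines TORCH_CHECK

// Before: an explicit branch whose only job is to call TORCH_CHECK
// with a constant-false condition.
void check_step_before(int64_t step) {
  if (!(step >= 0)) {
    TORCH_CHECK(false, "step must be non-negative, got ", step);
  }
}

// After: TORCH_CHECK evaluates the condition itself and raises the
// same error with the same message when the condition is false.
void check_step_after(int64_t step) {
  TORCH_CHECK(step >= 0, "step must be non-negative, got ", step);
}
```

Both forms fail in the same way; the second is shorter and reads as a precondition rather than control flow.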

View File

@ -256,7 +256,7 @@ test_torchbench_smoketest() {
local device=mps
local dtypes=(undefined float16 bfloat16 notset)
local dtype=${dtypes[$1]}
local models=(llama BERT_pytorch dcgan yolov3 resnet152 sam sam_fast pytorch_unet stable_diffusion_text_encoder speech_transformer Super_SloMo doctr_det_predictor doctr_reco_predictor vgg16)
local models=(hf_T5 llama BERT_pytorch dcgan hf_GPT2 yolov3 resnet152 sam sam_fast pytorch_unet stable_diffusion_text_encoder speech_transformer Super_SloMo doctr_det_predictor doctr_reco_predictor timm_resnet timm_vovnet vgg16)
for backend in eager inductor; do
@ -319,7 +319,7 @@ test_aoti_torchbench_smoketest() {
local device=mps
local dtypes=(undefined float16 bfloat16 notset)
local dtype=${dtypes[$1]}
local models=(llama BERT_pytorch dcgan yolov3 resnet152 sam sam_fast pytorch_unet stable_diffusion_text_encoder speech_transformer Super_SloMo doctr_det_predictor doctr_reco_predictor vgg16)
local models=(hf_T5 llama BERT_pytorch dcgan hf_GPT2 yolov3 resnet152 sam sam_fast pytorch_unet stable_diffusion_text_encoder speech_transformer Super_SloMo doctr_det_predictor doctr_reco_predictor timm_resnet timm_vovnet vgg16)
echo "Launching torchbench inference performance run for AOT Inductor and dtype ${dtype}"
local dtype_arg="--${dtype}"

View File

@ -838,7 +838,7 @@ test_dynamo_benchmark() {
elif [[ "${suite}" == "timm_models" ]]; then
export TORCHBENCH_ONLY_MODELS="inception_v3"
elif [[ "${suite}" == "torchbench" ]]; then
export TORCHBENCH_ONLY_MODELS="BERT_pytorch"
export TORCHBENCH_ONLY_MODELS="hf_Bert"
fi
fi
test_single_dynamo_benchmark "dashboard" "$suite" "$shard_id" "$@"
@ -869,13 +869,13 @@ test_inductor_torchbench_smoketest_perf() {
mkdir -p "$TEST_REPORTS_DIR"
python benchmarks/dynamo/torchbench.py --device cuda --performance --backend inductor --float16 --training \
--batch-size-file "$(realpath benchmarks/dynamo/torchbench_models_list.txt)" --only BERT_pytorch \
--batch-size-file "$(realpath benchmarks/dynamo/torchbench_models_list.txt)" --only hf_Bert \
--output "$TEST_REPORTS_DIR/inductor_training_smoketest.csv"
# The threshold value needs to be actively maintained to make this check useful
python benchmarks/dynamo/check_perf_csv.py -f "$TEST_REPORTS_DIR/inductor_training_smoketest.csv" -t 1.4
# Check memory compression ratio for a few models
for test in BERT_pytorch yolov3; do
for test in hf_Albert timm_vision_transformer; do
python benchmarks/dynamo/torchbench.py --device cuda --performance --backend inductor --amp --training \
--disable-cudagraphs --batch-size-file "$(realpath benchmarks/dynamo/torchbench_models_list.txt)" \
--only $test --output "$TEST_REPORTS_DIR/inductor_training_smoketest_$test.csv"

View File

@ -71,7 +71,14 @@ export PYTORCH_BUILD_NUMBER=1
# Set triton version as part of PYTORCH_EXTRA_INSTALL_REQUIREMENTS
TRITON_VERSION=$(cat $PYTORCH_ROOT/.ci/docker/triton_version.txt)
TRITON_CONSTRAINT="platform_system == 'Linux'"
# Here PYTORCH_EXTRA_INSTALL_REQUIREMENTS is already set for all the wheel builds, hence append TRITON_CONSTRAINT
TRITON_CONSTRAINT="platform_system == 'Linux' and platform_machine == 'x86_64'"
# CUDA 12.9/13.0 builds have triton for Linux and Linux aarch64 binaries.
if [[ "$DESIRED_CUDA" == "cu129" ]] || [[ "$DESIRED_CUDA" == "cu130" ]]; then
TRITON_CONSTRAINT="platform_system == 'Linux'"
fi
if [[ "$PACKAGE_TYPE" =~ .*wheel.* && -n "${PYTORCH_EXTRA_INSTALL_REQUIREMENTS:-}" && ! "$PYTORCH_BUILD_VERSION" =~ .*xpu.* ]]; then
TRITON_REQUIREMENT="triton==${TRITON_VERSION}; ${TRITON_CONSTRAINT}"

View File

@ -28,10 +28,6 @@ runs:
echo "instance-type: $(get_ec2_metadata instance-type)"
echo "system info $(uname -a)"
- name: Print GPU info (if present)
shell: bash
run: if [ -f /usr/bin/nvidia-smi ]; then nvidia-smi; fi
- name: Check if in a container runner
shell: bash
id: check_container_runner
@ -86,6 +82,37 @@ runs:
# Prune all of the docker images
docker system prune -af
- name: Manually resolve download.pytorch.org
shell: bash
continue-on-error: true
run: |
set +e
set -x
PT_DOMAIN=download.pytorch.org
# TODO: Flaky access to download.pytorch.org https://github.com/pytorch/pytorch/issues/100400;
# clean this up once the issue is fixed. There is more than one resolved IP here, and the last
# one (which we take) varies at random
RESOLVED_IP=$(dig -4 +short "${PT_DOMAIN}" | tail -n1)
if [ -z "${RESOLVED_IP}" ]; then
echo "Couldn't resolve ${PT_DOMAIN}, retrying with Google DNS..."
RESOLVED_IP=$(dig -4 +short "${PT_DOMAIN}" @8.8.8.8 | tail -n1)
if [ -z "${RESOLVED_IP}" ]; then
echo "Couldn't resolve ${PT_DOMAIN}, exiting..."
exit 1
fi
fi
if grep -r "${PT_DOMAIN}" /etc/hosts; then
# Clean up any old records first
sudo sed -i "/${PT_DOMAIN}/d" /etc/hosts
fi
echo "${RESOLVED_IP} ${PT_DOMAIN}" | sudo tee -a /etc/hosts
cat /etc/hosts
- name: Check that the docker daemon is running
shell: bash
continue-on-error: true

View File

@ -2,7 +2,7 @@ name: inductor-perf-nightly-h100
on:
schedule:
- cron: 15 0 * * 1-6
- cron: 15 0,12 * * 1-6
- cron: 0 7 * * 0
# NB: GitHub has an upper limit of 10 inputs here, so before we can sort it
# out, let's try to run torchao cudagraphs_low_precision as part of cudagraphs

View File

@ -6693,12 +6693,12 @@
- func: native_norm(Tensor self, Scalar p=2) -> Tensor
dispatch:
SparseCPU, SparseCUDA, SparseMPS: norm_sparse
SparseCPU, SparseCUDA: norm_sparse
autogen: native_norm.out
- func: native_norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, ScalarType? dtype) -> Tensor
dispatch:
SparseCPU, SparseCUDA, SparseMPS: norm_sparse
SparseCPU, SparseCUDA: norm_sparse
autogen: native_norm.ScalarOpt_dim_dtype_out
- func: _batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)
@ -6824,14 +6824,14 @@
device_check: NoCheck # TensorIterator
variants: function, method
dispatch:
SparseCPU, SparseCUDA, SparseMPS: sparse_dtype_norm
SparseCPU, SparseCUDA: sparse_dtype_norm
- func: norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> Tensor
structured_delegate: norm.out
device_check: NoCheck # TensorIterator
variants: function, method
dispatch:
SparseCPU, SparseCUDA, SparseMPS: sparse_norm
SparseCPU, SparseCUDA: sparse_norm
- func: norm.dtype_out(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)
structured: True

View File

@ -25,6 +25,15 @@ drq
fambench_dlrm
fambench_xlmr
fastNLP_Bert
hf_Albert
hf_Bart
hf_Bert
hf_BigBird
hf_DistilBert
hf_GPT2
hf_Longformer
hf_Reformer
hf_T5
maml
maml_omniglot
mnasnet1_0
@ -51,6 +60,13 @@ soft_actor_critic
speech_transformer
squeezenet1_1
tacotron2
timm_efficientdet
timm_efficientnet
timm_nfnet
timm_regnet
timm_resnest
timm_vision_transformer
timm_vovnet
tts_angular
vgg16
vision_maskrcnn

View File

@ -23,6 +23,7 @@ TORCHBENCH_MODELS: list[str] = [
"resnet50",
"moco",
"llama",
"hf_T5",
]
HUGGINGFACE_MODELS: list[str] = [
"AllenaiLongformerBase",

View File

@ -11,6 +11,7 @@ import pandas as pd
flaky_models = {
"yolov3",
"detectron2_maskrcnn_r_101_c4",
"timm_efficientnet", # see https://github.com/pytorch/pytorch/issues/148699
"XGLMForCausalLM", # discovered in https://github.com/pytorch/pytorch/pull/128148
"moondream", # discovered in https://github.com/pytorch/pytorch/pull/159291
# discovered in https://github.com/pytorch/pytorch/issues/161419. It's not flaky but really hard to repro, so skipping it
@ -39,9 +40,13 @@ def check_accuracy(actual_csv, expected_csv, expected_filename):
"detectron2_fcos_r_50_fpn",
"doctr_det_predictor",
"doctr_reco_predictor",
"dpn107",
"fbnetv3_b",
"levit_128",
"hf_BigBird",
"hf_Longformer",
"hf_Reformer",
"hf_Roberta_base",
"hf_T5",
"hf_T5_base",
"hf_T5_generate",
"llava",
"microbench_unbacked_tolist_sum",
"mnasnet1_0",
@ -58,7 +63,12 @@ def check_accuracy(actual_csv, expected_csv, expected_filename):
"squeezenet1_1",
"stable_diffusion_text_encoder",
"stable_diffusion_unet",
"swsl_resnext101_32x16d",
"timm_efficientdet",
"timm_efficientnet",
"timm_nfnet",
"timm_regnet",
"timm_resnest",
"timm_vovnet",
"torchrec_dlrm",
"vgg16",
# LLM

View File

@ -36,7 +36,12 @@ def check_graph_breaks(actual_csv, expected_csv, expected_filename):
"detectron2_fcos_r_50_fpn",
"doctr_det_predictor",
"doctr_reco_predictor",
"levit_128",
"hf_BigBird",
"hf_Longformer",
"hf_Reformer",
"hf_Roberta_base",
"hf_T5",
"hf_T5_base",
"llava",
"microbench_unbacked_tolist_sum",
"resnet50",
@ -46,6 +51,7 @@ def check_graph_breaks(actual_csv, expected_csv, expected_filename):
"stable_diffusion_text_encoder",
"stable_diffusion_unet",
"timm_efficientdet",
"timm_nfnet",
"torchrec_dlrm",
"vgg16",
# LLM

View File

@ -130,6 +130,70 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,0
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Reformer,pass,5
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,eager_fail_to_run,0
hf_T5_generate,pass,7
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,0
hf_distil_whisper,pass,0
lennard_jones,pass,0
@ -278,6 +342,30 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0


View File

@ -78,6 +78,62 @@ functorch_maml_omniglot,pass,7
hf_Albert,pass,6
hf_Bart,pass,6
hf_Bert,pass,6
hf_Bert_large,pass,6
hf_BigBird,pass,6
hf_DistilBert,pass,6
hf_GPT2,pass,8
hf_GPT2_large,pass_due_to_skip,0
hf_Reformer,pass,20
hf_Roberta_base,pass,6
hf_T5_base,eager_2nd_run_OOM,0
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,6
hf_distil_whisper,model_fail_to_load,0
lennard_jones,pass,7
@ -194,6 +250,30 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientnet,fail_accuracy,7
timm_regnet,pass,7
timm_resnest,pass,6
timm_vision_transformer,pass,6
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,6
torch_multimodal_clip,pass,7


View File

@ -118,6 +118,62 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,fail_accuracy,0
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,eager_fail_to_run,0
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,0
hf_distil_whisper,pass,0
lennard_jones,pass,0
@ -258,6 +314,30 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0


View File

@ -114,6 +114,58 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,0
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,pass,0
hf_T5_large,pass_due_to_skip,0
hf_distil_whisper,pass,0
lennard_jones,pass,0
@ -226,6 +278,38 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientdet,model_fail_to_load,0
timm_efficientnet,pass,0
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0


View File

@ -114,6 +114,58 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,0
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,pass,0
hf_T5_large,pass_due_to_skip,0
hf_distil_whisper,pass,0
lennard_jones,pass,0
@ -226,6 +278,38 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientdet,model_fail_to_load,0
timm_efficientnet,pass,0
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0


View File

@ -122,6 +122,66 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,27
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Longformer,pass,4
hf_Reformer,pass,5
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,pass,0
hf_T5_large,pass_due_to_skip,0
hf_distil_whisper,pass,0
lennard_jones,pass,0
@ -242,6 +302,38 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientdet,model_fail_to_load,0
timm_efficientnet,pass,0
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0


View File

@ -122,6 +122,66 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,27
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Longformer,pass,4
hf_Reformer,pass,5
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,pass,0
hf_T5_large,pass_due_to_skip,0
hf_distil_whisper,pass,0
lennard_jones,pass,0
@ -242,6 +302,38 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientdet,model_fail_to_load,0
timm_efficientnet,pass,0
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0


View File

@ -122,6 +122,66 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,27
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Longformer,pass,4
hf_Reformer,pass,5
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,pass,0
hf_T5_large,pass_due_to_skip,0
hf_distil_whisper,pass,0
lennard_jones,pass,0
@ -242,6 +302,38 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientdet,model_fail_to_load,0
timm_efficientnet,pass,0
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0


View File

@ -130,6 +130,70 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,0
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Reformer,pass,5
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,eager_fail_to_run,0
hf_T5_generate,pass,7
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,0
hf_distil_whisper,pass,0
lennard_jones,pass,0
@ -278,6 +342,30 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0


View File

@ -78,6 +78,62 @@ functorch_maml_omniglot,pass,7
hf_Albert,pass,6
hf_Bart,pass,6
hf_Bert,pass,6
hf_Bert_large,pass,6
hf_BigBird,pass,6
hf_DistilBert,pass,6
hf_GPT2,pass,8
hf_GPT2_large,pass_due_to_skip,0
hf_Reformer,pass,20
hf_Roberta_base,pass,6
hf_T5_base,eager_2nd_run_OOM,0
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,6
hf_distil_whisper,model_fail_to_load,0
lennard_jones,pass,7
@ -190,6 +246,30 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientnet,pass,7
timm_regnet,pass,7
timm_resnest,pass,6
timm_vision_transformer,pass,6
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,6
torch_multimodal_clip,pass,7


View File

@ -98,6 +98,58 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,0
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,pass,0
hf_T5_large,pass_due_to_skip,0
hf_distil_whisper,pass,0
lennard_jones,pass,0
@ -210,6 +262,38 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientdet,model_fail_to_load,0
timm_efficientnet,pass,0
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0


View File

@ -98,6 +98,58 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,0
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,pass,0
hf_T5_large,pass_due_to_skip,0
hf_distil_whisper,pass,0
lennard_jones,pass,0
@ -210,6 +262,38 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientdet,model_fail_to_load,0
timm_efficientnet,pass,0
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0


View File

@ -106,6 +106,66 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,27
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Longformer,pass,4
hf_Reformer,pass,5
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,pass,0
hf_T5_large,pass_due_to_skip,0
hf_distil_whisper,pass,0
lennard_jones,pass,0
@ -226,6 +286,38 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientdet,model_fail_to_load,0
timm_efficientnet,pass,0
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0


View File

@ -122,6 +122,66 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,25
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Longformer,pass,4
hf_Reformer,pass,8
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,pass,0
hf_T5_large,pass_due_to_skip,0
hf_distil_whisper,pass,0
lennard_jones,pass,0
@ -242,6 +302,38 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientdet,model_fail_to_load,0
timm_efficientnet,pass,0
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,3


View File

@ -130,6 +130,70 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,fail_accuracy,0
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Reformer,pass,5
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,eager_fail_to_run,0
hf_T5_generate,pass,7
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,0
hf_distil_whisper,pass,0
lennard_jones,pass,0
@ -278,6 +342,30 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0


View File

@ -78,6 +78,62 @@ functorch_maml_omniglot,pass,7
hf_Albert,pass,6
hf_Bart,pass,6
hf_Bert,pass,6
hf_Bert_large,pass,6
hf_BigBird,pass,6
hf_DistilBert,pass,6
hf_GPT2,pass,8
hf_GPT2_large,pass_due_to_skip,0
hf_Reformer,pass,20
hf_Roberta_base,pass,6
hf_T5_base,eager_2nd_run_OOM,0
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,6
hf_distil_whisper,model_fail_to_load,0
lennard_jones,pass,7
@ -190,6 +246,30 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientnet,fail_accuracy,7
timm_regnet,pass,7
timm_resnest,pass,6
timm_vision_transformer,pass,6
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,6
torch_multimodal_clip,pass,7


View File

@ -130,6 +130,70 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,0
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Reformer,pass,5
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,eager_fail_to_run,0
hf_T5_generate,pass,7
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,0
hf_distil_whisper,pass,0
lennard_jones,pass,0
@ -278,6 +342,30 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0


View File

@ -78,6 +78,62 @@ functorch_maml_omniglot,pass,7
hf_Albert,pass,6
hf_Bart,pass,6
hf_Bert,pass,6
hf_Bert_large,pass,6
hf_BigBird,pass,6
hf_DistilBert,pass,6
hf_GPT2,pass,8
hf_GPT2_large,pass_due_to_skip,0
hf_Reformer,pass,20
hf_Roberta_base,pass,6
hf_T5_base,eager_2nd_run_OOM,0
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,6
hf_distil_whisper,model_fail_to_load,0
lennard_jones,pass,7
@ -194,6 +250,30 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientnet,pass,7
timm_regnet,pass,7
timm_resnest,pass,6
timm_vision_transformer,pass,6
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,6
torch_multimodal_clip,pass,7


View File

@ -130,6 +130,70 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,fail_accuracy,0
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Reformer,pass,5
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,eager_fail_to_run,0
hf_T5_generate,pass,7
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,0
hf_distil_whisper,pass,0
lennard_jones,pass,0
@ -278,6 +342,30 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0


View File

@ -78,6 +78,62 @@ functorch_maml_omniglot,pass,7
hf_Albert,pass,6
hf_Bart,pass,6
hf_Bert,pass,6
hf_Bert_large,pass,6
hf_BigBird,pass,6
hf_DistilBert,pass,6
hf_GPT2,pass,8
hf_GPT2_large,pass_due_to_skip,0
hf_Reformer,pass,20
hf_Roberta_base,pass,6
hf_T5_base,eager_2nd_run_OOM,0
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,6
hf_distil_whisper,model_fail_to_load,0
lennard_jones,pass,7
@ -194,6 +250,30 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientnet,fail_accuracy,7
timm_regnet,pass,7
timm_resnest,pass,6
timm_vision_transformer,pass,6
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,6
torch_multimodal_clip,pass,7


View File

@ -130,6 +130,73 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,9
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Longformer,pass,4
hf_Reformer,pass,8
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,pass,0
hf_T5_generate,pass,7
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,0
hf_distil_whisper,pass,0
lennard_jones,pass,0
@ -278,6 +345,38 @@ stable_diffusion_unet,model_fail_to_load,0
timm_efficientdet,pass,2
timm_efficientnet,pass,0
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0


View File

@ -78,6 +78,70 @@ functorch_maml_omniglot,pass,7
hf_Albert,pass,6
hf_Bart,pass,6
hf_Bert,pass,6
hf_Bert_large,pass,6
hf_BigBird,pass,6
hf_DistilBert,pass,6
hf_GPT2,pass,8
hf_GPT2_large,pass_due_to_skip,0
hf_Longformer,pass,4
hf_Reformer,pass,25
hf_Roberta_base,pass,6
hf_T5,pass,0
hf_T5_base,eager_2nd_run_OOM,0
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,6
hf_distil_whisper,model_fail_to_load,0
lennard_jones,pass,7
@ -194,6 +258,38 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientdet,pass,2
timm_efficientnet,pass,7
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,6
timm_vision_transformer,pass,6
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,6
torch_multimodal_clip,pass,7


View File

@ -118,6 +118,62 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,fail_accuracy,0
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,eager_fail_to_run,0
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,0
hf_distil_whisper,pass,0
lennard_jones,pass,0
@ -258,6 +314,34 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientnet,pass,0
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0


View File

@ -130,6 +130,73 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,9
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Longformer,pass,4
hf_Reformer,pass,8
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,pass,0
hf_T5_generate,pass,7
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,0
hf_distil_whisper,pass,0
lennard_jones,pass,0
@ -278,6 +345,38 @@ stable_diffusion_unet,model_fail_to_load,0
timm_efficientdet,pass,2
timm_efficientnet,pass,0
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0


View File

@ -78,6 +78,70 @@ functorch_maml_omniglot,pass,7
hf_Albert,pass,6
hf_Bart,pass,6
hf_Bert,pass,6
hf_Bert_large,pass,6
hf_BigBird,fail_to_run,3
hf_DistilBert,pass,6
hf_GPT2,pass,8
hf_GPT2_large,pass_due_to_skip,0
hf_Longformer,pass,4
hf_Reformer,pass,25
hf_Roberta_base,pass,6
hf_T5,pass,0
hf_T5_base,eager_2nd_run_OOM,0
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,6
hf_distil_whisper,model_fail_to_load,0
lennard_jones,pass,7
@ -190,6 +254,38 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientdet,pass,2
timm_efficientnet,pass,7
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,6
timm_vision_transformer,pass,6
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,6
torch_multimodal_clip,pass,7


View File

@ -130,6 +130,74 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,fail_to_run,0
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Longformer,pass,4
hf_Reformer,pass,5
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,eager_fail_to_run,0
hf_T5_generate,pass,7
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,0
hf_distil_whisper,pass,0
lennard_jones,pass,0
@ -278,6 +346,38 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientdet,pass,2
timm_efficientnet,pass,0
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0


View File

@ -78,6 +78,70 @@ functorch_maml_omniglot,pass,7
hf_Albert,pass,6
hf_Bart,pass,6
hf_Bert,pass,6
hf_Bert_large,pass,6
hf_BigBird,fail_to_run,3
hf_DistilBert,pass,6
hf_GPT2,pass,8
hf_GPT2_large,pass_due_to_skip,0
hf_Longformer,pass,10
hf_Reformer,pass,20
hf_Roberta_base,pass,6
hf_T5,pass,5
hf_T5_base,eager_2nd_run_OOM,0
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,6
hf_distil_whisper,model_fail_to_load,0
lennard_jones,pass,7
@ -190,6 +254,38 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientdet,pass,8
timm_efficientnet,pass,7
timm_nfnet,pass,6
timm_regnet,pass,0
timm_resnest,pass,6
timm_vision_transformer,pass,6
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,6
torch_multimodal_clip,pass,7


View File

@ -130,6 +130,73 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,9
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Longformer,pass,4
hf_Reformer,pass,8
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,pass,0
hf_T5_generate,pass,7
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,0
hf_distil_whisper,pass,0
lennard_jones,pass,0
@ -278,6 +345,38 @@ stable_diffusion_unet,model_fail_to_load,0
timm_efficientdet,pass,2
timm_efficientnet,pass,0
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0


View File

@ -78,6 +78,70 @@ functorch_maml_omniglot,pass,7
hf_Albert,pass,6
hf_Bart,pass,6
hf_Bert,pass,6
hf_Bert_large,pass,6
hf_BigBird,pass,15
hf_DistilBert,pass,6
hf_GPT2,pass,8
hf_GPT2_large,pass_due_to_skip,0
hf_Longformer,pass,4
hf_Reformer,pass,25
hf_Roberta_base,pass,6
hf_T5,pass,0
hf_T5_base,eager_2nd_run_OOM,0
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,6
hf_distil_whisper,model_fail_to_load,0
lennard_jones,pass,7
@ -194,6 +258,38 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientdet,pass,2
timm_efficientnet,pass,7
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,6
timm_vision_transformer,pass,6
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,6
torch_multimodal_clip,pass,7


View File

@ -130,6 +130,66 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,fail_accuracy,0
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Reformer,pass,8
hf_T5,pass,0
hf_T5_base,eager_fail_to_run,0
hf_T5_generate,pass,11
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,0
hf_distil_whisper,pass,0
lennard_jones,pass,0
@ -274,6 +334,30 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0


View File

@ -78,6 +78,58 @@ functorch_maml_omniglot,pass,7
hf_Albert,pass,6
hf_Bart,pass,6
hf_Bert,pass,6
hf_Bert_large,pass,6
hf_BigBird,pass,6
hf_DistilBert,pass,6
hf_GPT2,pass,8
hf_GPT2_large,pass_due_to_skip,0
hf_Reformer,pass,25
hf_T5_base,eager_2nd_run_OOM,0
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,6
hf_distil_whisper,model_fail_to_load,0
lennard_jones,pass,7
@ -194,6 +246,30 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientnet,pass,7
timm_regnet,pass,7
timm_resnest,pass,6
timm_vision_transformer,pass,6
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,6
torch_multimodal_clip,pass,7


View File

@ -149,6 +149,7 @@ CI_SKIP_DYNAMIC_BATCH_ONLY = {
"detectron2_fasterrcnn_r_50_c4",
"detectron2_fasterrcnn_r_50_dc5",
"detectron2_fasterrcnn_r_50_fpn",
"hf_T5_generate",
"Reformer",
"llama",
}.union(INTERNAL_CI_SKIP_DYNAMIC_BATCH_ONLY)
@ -175,7 +176,13 @@ BENCHMARK_USE_SGD = {
"speech_transformer",
"squeezenet1_1",
"stable_diffusion_text_encoder",
"timm_efficientdet",
"timm_nfnet",
"timm_resnest",
"timm_vision_transformer",
"timm_vovnet",
"vgg16",
"hf_T5", # Fails dynamic https://github.com/pytorch/pytorch/issues/115968
# HF
"AlbertForMaskedLM",
"BartForCausalLM",
@ -209,6 +216,8 @@ CI_USE_SGD = {
"detectron2_maskrcnn_r_101_fpn",
"detectron2_maskrcnn_r_50_c4",
"detectron2_maskrcnn_r_50_fpn",
"hf_T5_base",
"hf_clip",
"llama_v2_7b_16h",
"mobilenet_v2_quantized_qat",
"phi_1_5 resnet50_quantized_qat",
@ -2022,6 +2031,8 @@ class BenchmarkRunner:
from diffusers.models.transformer_2d import Transformer2DModel
from torchbenchmark.models.nanogpt.model import Block
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.t5.modeling_t5 import T5Block
from transformers.models.whisper.modeling_whisper import WhisperEncoderLayer
from torch.distributed.fsdp.wrap import (
ModuleWrapPolicy,
@ -2031,6 +2042,10 @@ class BenchmarkRunner:
# handcrafted wrap policy
MODEL_FSDP_WRAP = {
"stable_diffusion_unet": (Transformer2DModel,),
"hf_T5": (T5Block,),
"hf_T5_base": (T5Block,),
"hf_T5_large": (T5Block,),
"hf_Whisper": (WhisperEncoderLayer,),
"llama_v2_7b_16h": (LlamaDecoderLayer,),
"nanogpt": (Block,),
}
@ -3795,6 +3810,22 @@ def run(runner, args, original_dir=None):
global synchronize
synchronize = torch.cuda.synchronize if HAS_CUDA else torch.xpu.synchronize
if (
args.devices == ["cuda"]
and torch.cuda.get_device_properties(0).total_memory < 25 * 2**30
):
# OOM errors on an RTX 3090 with 24gb RAM
runner.skip_models.update(
{
# torchbench
"hf_Longformer",
"timm_nfnet",
"timm_efficientdet",
}
)
if args.training:
runner.skip_models.add("hf_T5")
if args.nnc:
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)

View File

@ -21,6 +21,9 @@ try:
except ImportError:
from torchbench import setup_torchbench_cwd
from transformers.models.bert.modeling_bert import BertLayer, BertLMPredictionHead
from transformers.models.t5.modeling_t5 import T5Block
def setup(rank, world_size):
os.environ["MASTER_ADDR"] = os.getenv("MASTER_ADDR", "localhost")
@ -125,6 +128,8 @@ def fsdp_checkpointing_base(model, blocks):
MODEL_FSDP_WRAP = {
"toy_model": (MyModule,),
"hf_Bert": (BertLayer, BertLMPredictionHead),
"hf_T5": (T5Block,),
}

View File

@ -158,7 +158,7 @@ if __name__ == "__main__":
model_arg.add_argument(
"--torchbench-model",
"--torchbench_model",
help="name of torchbench model, e.g. BERT_pytorch",
help="name of torchbench model, e.g. hf_Bert",
)
model_arg.add_argument(
"--toy-model", "--toy_model", action="store_true", help="use toy model instead"

View File

@ -12,6 +12,17 @@ cuda,dlrm,1024,1.3421,3.2177,4.9493,1.0009
cuda,drq,1,1.0820,3.8157,8.0732,0.9687
cuda,fastNLP_Bert,6,1.4839,37.9050,32.7583,1.1563
cuda,functorch_dp_cifar10,64,1.5014,6.9596,14.1516,0.4432
cuda,hf_Albert,8,2.2452,30.6134,25.9036,1.3098
cuda,hf_Bart,4,1.7012,34.3999,37.9975,1.0128
cuda,hf_Bert,4,1.9003,23.3435,34.8196,1.0273
cuda,hf_Bert_large,4,1.6346,52.8525,62.3112,1.0726
cuda,hf_BigBird,2,1.9208,105.2672,101.4787,1.1415
cuda,hf_DistilBert,8,1.3988,22.5793,20.2386,1.0232
cuda,hf_GPT2,4,1.8075,27.5184,25.3428,1.1562
cuda,hf_GPT2_large,4,1.7716,118.7404,68.1618,1.1725
cuda,hf_Reformer,4,1.1744,70.4228,15.1152,0.9266
cuda,hf_T5,8,1.8778,93.3134,37.0046,1.2279
cuda,hf_T5_large,2,2.3623,101.5518,143.7982,1.1674
cuda,lennard_jones,1000,1.0649,1.5233,4.1119,0.9998
cuda,mnasnet1_0,32,1.1957,19.1993,27.2302,0.7758
cuda,mobilenet_v2,96,1.4876,32.3311,27.4719,1.1729
@ -31,6 +42,14 @@ cuda,shufflenet_v2_x1_0,128,1.3027,25.7017,27.9875,1.1015
cuda,soft_actor_critic,256,0.9965,2.2580,4.6661,0.9995
cuda,speech_transformer,32,1.8405,35.1645,33.3422,1.0888
cuda,squeezenet1_1,32,1.4191,7.3454,9.4751,1.1148
cuda,timm_efficientdet,1,1.6630,78.2697,150.9620,0.9904
cuda,timm_efficientnet,32,1.2689,28.5348,66.3911,0.9428
cuda,timm_nfnet,128,1.5319,79.5429,32.9961,1.1070
cuda,timm_regnet,32,1.0564,56.9897,53.0027,0.9500
cuda,timm_resnest,32,1.6485,14.3908,56.7240,0.9515
cuda,timm_vision_transformer,8,1.6100,18.7736,36.9495,0.7301
cuda,timm_vision_transformer_large,8,1.0842,170.9849,72.0604,0.9762
cuda,timm_vovnet,32,1.0472,25.4676,24.8428,0.8843
cuda,tts_angular,64,1.0366,6.9889,4.2683,0.9973
cuda,vgg16,64,1.2560,52.7072,7.3733,0.9884
cuda,yolov3,16,1.2600,54.2350,42.4711,1.0108


View File

@ -1,16 +1,29 @@
#name,backend,data_type,shape,wrapper,perf_speedup_target_c7i_metal_24xl
#timm_vision_transformer,inductor,float32,static,default,1.039510755
phlippe_densenet,inductor,float32,static,default,1.46474287
basic_gnn_edgecnn,inductor,float32,dynamic,default,1.30092957
llama_v2_7b_16h,inductor,float32,dynamic,default,1.23234331
resnet50,inductor,float32,dynamic,default,1.67742767
#timm_efficientnet,inductor,float32,static,cpp,
mobilenet_v3_large,inductor,float32,static,cpp,2.63311706
timm_resnest,inductor,float32,dynamic,cpp,1.7321529
functorch_maml_omniglot,inductor,float32,dynamic,cpp,1.126799
#hf_GPT2,inductor,float32,dynamic,cpp,
yolov3,export-aot-inductor,float32,static,default,1.40687424
mobilenet_v2,export-aot-inductor,float32,static,default,2.90375357
resnext50_32x4d,export-aot-inductor,float32,dynamic,default,1.49299689
hf_Albert,export-aot-inductor,float32,dynamic,default,1.261471
resnext50_32x4d,inductor,amp,static,default,1.47023111
vgg16,inductor,amp,static,default,1.2692454
hf_Longformer,inductor,amp,dynamic,default,1.22015225
hf_Bert_large,inductor,amp,dynamic,default,1.18572179
llama,inductor,amp,static,default,1.33157028
timm_regnet,inductor,amp,static,cpp,1.12734073
mnasnet1_0,inductor,amp,static,cpp,2.1296814
#hf_T5_generate,inductor,amp,dynamic,cpp,
timm_vovnet,inductor,amp,dynamic,cpp,1.10851009
#mobilenet_v2,inductor,amp,dynamic,cpp,2.27774577 # https://github.com/pytorch/pytorch/issues/131693
hf_GPT2,export-aot-inductor,amp,static,default,1.4432794
densenet121,export-aot-inductor,amp,static,default,1.25591385
hf_DistilBert,export-aot-inductor,amp,dynamic,default,1.2926442
hf_Bart,export-aot-inductor,amp,dynamic,default,1.19515416


View File

@ -75,7 +75,29 @@ def setup_torchbench_cwd():
return original_dir
process_train_model_output = {}
def process_hf_reformer_output(out):
assert isinstance(out, list)
# second output is unstable
return [elem for i, elem in enumerate(out) if i != 1]
def process_hf_whisper_output(out):
out_ret = []
for i, elem in enumerate(out):
if i == 0:
if elem is not None:
assert isinstance(elem, dict)
out_ret.append({k: v for k, v in elem.items() if k != "logits"})
elif i != 1:
out_ret.append(elem)
return out_ret
process_train_model_output = {
"hf_Reformer": process_hf_reformer_output,
"hf_Whisper": process_hf_whisper_output,
}
class TorchBenchmarkRunner(BenchmarkRunner):
@ -205,10 +227,12 @@ class TorchBenchmarkRunner(BenchmarkRunner):
"drq",
"hf_Reformer",
"DALLE2_pytorch",
"hf_BigBird",
"detectron2_maskrcnn_r_50_fpn",
"detectron2_maskrcnn_r_101_fpn",
"vision_maskrcnn",
"doctr_reco_predictor",
"hf_T5_generate",
}
def load_model(
@ -371,6 +395,8 @@ class TorchBenchmarkRunner(BenchmarkRunner):
and hasattr(model.config, "use_cache")
):
model.config.use_cache = False
if model_name == "hf_T5_generate":
model.model.config.use_cache = False
self.validate_model(model, example_inputs)
return device, benchmark.name, model, example_inputs, batch_size

View File

@ -5,6 +5,8 @@ batch_size:
demucs: 4
dlrm: 1024
densenet121: 4
hf_Reformer: 4
hf_T5_base: 4
timm_efficientdet: 1
llama_v2_7b_16h: 1
# reduced from 16 due to cudagraphs OOM in TorchInductor dashboard
@ -28,6 +30,7 @@ tolerance:
- alexnet
- attention_is_all_you_need_pytorch
- densenet121
- hf_Albert
- vgg16
- mobilenet_v3_large
- nvidia_deeprecommender
@ -37,16 +40,20 @@ tolerance:
- soft_actor_critic
- tacotron2
- yolov3
- timm_efficientdet
- timm_efficientnet
- squeezenet1_1
higher_fp16:
- doctr_reco_predictor
- drq
- hf_Whisper
- phlippe_resnet
higher_bf16:
- doctr_reco_predictor
- drq
- hf_Whisper
# These models need higher tolerance for xpu devices with bf16
higher_bf16_xpu:
@ -64,9 +71,16 @@ tolerance:
require_larger_multiplier_for_smaller_tensor:
- yolov3
- timm_efficientnet
# These benchmarks took >600s on an i9-11900K CPU
very_slow: &VERY_SLOW_MODELS
# 3339s
- hf_BigBird
# 3062s
- hf_Longformer
# 930s
- hf_T5
# These benchmarks took >60s on an i9-11900K CPU
@ -78,6 +92,18 @@ slow:
- demucs
# 242s
- fastNLP_Bert
# 221s
- hf_Albert
# 400s
- hf_Bart
# 334s
- hf_Bert
# 187s
- hf_DistilBert
# 470s
- hf_GPT2
# 141s
- hf_Reformer
# 317s
- speech_transformer
# 99s
@ -161,36 +187,11 @@ skip:
- hf_clip
# multi gpu not always available in benchmark runners
- simple_gpt_tp_manual
# skip hf and timm models in torchbench since
# there are already separate benchmarks for them
- hf_Albert
- hf_Bart
- hf_Bert
- hf_BigBird
- hf_DistilBert
- hf_GPT2
- hf_Longformer
- hf_Reformer
- hf_T5
- timm_efficientdet
- timm_efficientnet
- timm_nfnet
- timm_regnet
- timm_resnest
- timm_vision_transformer
- timm_vovnet
- hf_Bert_large
- hf_GPT2_large
- hf_Roberta_base
- hf_T5_base
- hf_T5_generate
- hf_T5_large
- hf_Whisper
- hf_distil_whisper
- timm_vision_transformer_large
device:
cpu:
# OOMs
- hf_T5_generate
# model is CUDA only
- cm3leon_generate
# timeout
@ -207,12 +208,16 @@ skip:
- torchrec_dlrm
- simple_gpt
# works on cuda, accuracy failure on cpu
- hf_Whisper
- stable_diffusion_text_encoder
- llava
- moco
# Skip these additional models when running on aarch64
cpu_aarch64: []
cpu_aarch64:
# timeout on aarch64
- timm_regnet
- timm_nfnet
cuda: []
@ -230,6 +235,7 @@ skip:
- sam_fast
# Model's DEFAULT_TRAIN_BSIZE is not implemented
- cm3leon_generate
- hf_T5_generate
- doctr_det_predictor
- doctr_reco_predictor
- moondream
@ -241,6 +247,9 @@ skip:
- cm3leon_generate
- detectron2_fcos_r_50_fpn
- fastNLP_Bert
- hf_Longformer
- hf_Reformer
- hf_T5_generate
- opacus_cifar10
- speech_transformer
@ -277,6 +286,9 @@ accuracy:
# Models too large to have eager, dynamo and fp64_numbers simultaneously
# even for a 40 GB machine. We have tested accuracy for smaller versions of
# these models
- hf_GPT2_large
- hf_T5_large
- timm_vision_transformer_large
# accuracy https://github.com/pytorch/pytorch/issues/93847
- maml
- llama_v2_7b_16h
@ -288,4 +300,5 @@ accuracy:
- pytorch_unet
max_batch_size:
hf_GPT2: 2
pytorch_unet: 2

View File

@ -4,6 +4,11 @@ LearningToPaint,1024
alexnet,1024
dcgan,1024
densenet121,64
hf_Albert,32
hf_Bart,16
hf_Bert,16
hf_GPT2,16
hf_T5,4
mnasnet1_0,256
mobilenet_v2,128
mobilenet_v3_large,256
@ -14,4 +19,10 @@ resnet50,128
resnext50_32x4d,128
shufflenet_v2_x1_0,512
squeezenet1_1,512
timm_nfnet,256
timm_efficientnet,128
timm_regnet,128
timm_resnest,256
timm_vision_transformer,256
timm_vovnet,128
vgg16,128

View File

@ -6,6 +6,18 @@ densenet121,512
dlrm,2048
fastNLP_Bert,8
functorch_dp_cifar10,1024
hf_Albert,8
hf_Bart,8
hf_Bert,8
hf_Bert_large,8
hf_DistilBert,8
hf_GPT2,8
hf_GPT2_large,1
hf_Longformer,4
hf_Reformer,8
hf_T5,4
hf_T5_base,1
hf_T5_large,1
LearningToPaint,96
lennard_jones,1024
mnasnet1_0,32
@ -23,6 +35,13 @@ shufflenet_v2_x1_0,64
speech_transformer,1024
squeezenet1_1,16
Super_SloMo,1024
timm_efficientnet,64
timm_nfnet,128
timm_regnet,32
timm_resnest,32
timm_vision_transformer,16
timm_vision_transformer_large,8
timm_vovnet,32
tts_angular,1024
vgg16,64
vision_maskrcnn,1

View File

@ -369,7 +369,7 @@ class ComposabilityTest(MultiProcessTestCase):
torch.distributed.destroy_process_group()
@requires_accelerator_dist_backend(["nccl", "xccl"])
@skip_if_lt_x_gpu(4)
@skip_if_lt_x_gpu(8)
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU and not TEST_XPU, "Test requires 8+ GPUs"
)
@ -391,6 +391,7 @@ class ComposabilityTest(MultiProcessTestCase):
],
)
def test_replicate_pp(self, ScheduleClass, MixedPrecisionParam):
_device_raii = torch.device(device_type, self.device)
torch.accelerator.set_device_index(self.device)
store = torch.distributed.FileStore(self.file_name, self.world_size)
torch.distributed.init_process_group(
@ -603,281 +604,6 @@ class ComposabilityTest(MultiProcessTestCase):
torch.distributed.destroy_process_group()
@requires_accelerator_dist_backend(["nccl", "xccl"])
@skip_if_lt_x_gpu(4)
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU and not TEST_XPU, "Test requires 8+ GPUs"
)
@parametrize(
"ScheduleClass",
[
ScheduleGPipe,
Schedule1F1B,
ScheduleInterleaved1F1B,
ScheduleLoopedBFS,
ScheduleInterleavedZeroBubble,
],
)
def test_replicate_pp_grads(self, ScheduleClass):
torch.accelerator.set_device_index(self.device)
store = torch.distributed.FileStore(self.file_name, self.world_size)
torch.distributed.init_process_group(
backend=backend,
store=store,
rank=self.rank,
world_size=self.world_size,
)
dim = 8
pp_size = 2
num_microbatches = 8
replicate_size = self.world_size // (pp_size)
device_mesh = init_device_mesh(
device_type,
mesh_shape=(replicate_size, 1, pp_size),
mesh_dim_names=("replicate", "shard", "pp"),
)
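# 3D mesh: data-parallel replicas x a size-1 shard dim x pipeline-parallel stages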
torch.manual_seed(42)
dp_mesh = device_mesh["replicate", "shard"]
pp_mesh = device_mesh["pp"]
pp_group = device_mesh["pp"].get_group()
dp_group = device_mesh["replicate"].get_group()
# create "entire model"
total_layers = 8
full_model = nn.ModuleList([MLPModule(dim) for _ in range(total_layers)])
ref_model = nn.Sequential(*copy.deepcopy(full_model)).to(self.device)
# dummy loss needed just to force backwards to run in schedule step
def loss_fn(y, target):
return y.sum()
# Simulate microbatch processing for reference model
def simulate_stage_forward_backward(model, inputs, labels):
"""Simulate forward and backward passes through stages for microbatch processing"""
batch_size, _ = inputs.shape
total_loss = 0
# Split inputs into microbatches
microbatch_size = batch_size // num_microbatches
for mb_idx in range(num_microbatches):
start_idx = mb_idx * microbatch_size
end_idx = start_idx + microbatch_size
mb_input = inputs[start_idx:end_idx]
mb_label = labels[start_idx:end_idx] if labels is not None else None
# Simulate stage-by-stage processing
if issubclass(ScheduleClass, PipelineScheduleSingle):
num_stages = pp_group.size()
layers_per_stage = total_layers // pp_group.size() # 8 // 2 = 4
else:
n_virtual = 2
num_stages = pp_group.size() * n_virtual
layers_per_stage = total_layers // num_stages
# Forward pass through all stages
x = mb_input
for stage in range(num_stages):
start_layer = stage * layers_per_stage
end_layer = start_layer + layers_per_stage
# Process layers for this stage
for layer_idx in range(start_layer, min(end_layer, len(model))):
x = model[layer_idx](x)
mb_loss = loss_fn(x, mb_label)
total_loss += mb_loss
# Backward pass
mb_loss.backward()
return total_loss / num_microbatches
# Apply replicate to stage module
def apply_replicate(partial_model):
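# wrap each layer and then the whole stage module with replicate on the data-parallel mesh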
for layer_id in range(len(partial_model)):
replicate(
partial_model[layer_id],
device_mesh=dp_mesh,
reshard_after_forward=False,
)
dp_model = replicate(partial_model, device_mesh=dp_mesh)
return dp_model
def pipelined_models_parameters(start_layer, model):
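# record every stage parameter under its global layer index so it can be matched against the reference model's named_parameters()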
layer_idx = start_layer
for layer in model.children():
for name, param in layer.named_parameters():
updated_param_name = f"{layer_idx}.{name}"
pipeline_model_parameter_dict[updated_param_name] = param
layer_idx += 1
def check_gradient_parity(
pipeline_model_parameter_dict, ref_model_parameter_dict
):
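# compare the local gradient of each pipelined (replicated) parameter against the reference model's gradient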
for parameter in pipeline_model_parameter_dict:
assert parameter in ref_model_parameter_dict
pipeline_parameter = pipeline_model_parameter_dict[parameter]
if pipeline_parameter.grad is not None:
pipeline_parameter_grad = pipeline_parameter.grad.to_local()
ref_parameter = ref_model_parameter_dict[parameter]
if ref_parameter.grad is not None:
torch.testing.assert_close(
pipeline_parameter_grad,
ref_parameter.grad,
rtol=1e-4,
atol=1e-5,
)
else:
assert pipeline_parameter.grad is None
pipeline_model_parameter_dict = {}
# Attach to a schedule
if issubclass(ScheduleClass, PipelineScheduleSingle):
stage_idx = pp_group.rank()
# Calculate layers per stage correctly
layers_per_stage = total_layers // pp_group.size() # 8 // 2 = 4
start_layer = stage_idx * layers_per_stage
end_layer = start_layer + layers_per_stage
partial_model = nn.Sequential(*full_model[start_layer:end_layer])
partial_model.to(self.device)
dp_model = apply_replicate(partial_model)
pipelined_models_parameters(start_layer, dp_model)
pipeline_stage = PipelineStage(
dp_model,
stage_idx,
pp_group.size(),
self.device,
group=pp_group,
)
partial_models = [pipeline_stage.submod]
pipeline_schedule = ScheduleClass(
pipeline_stage,
n_microbatches=num_microbatches,
loss_fn=loss_fn,
scale_grads=False,
)
else:
n_virtual = 2
num_stages = pp_group.size() * n_virtual
layers_per_stage = total_layers // num_stages
stages = []
for i in range(n_virtual):
stage_idx = pp_group.rank() + pp_group.size() * i
start_layer = stage_idx * layers_per_stage
end_layer = start_layer + layers_per_stage
# divide the model layers by the number of stages
partial_model = nn.Sequential(*full_model[start_layer:end_layer])
partial_model.to(self.device)
dp_model = apply_replicate(partial_model)
pipelined_models_parameters(start_layer, dp_model)
stage = PipelineStage(
dp_model,
stage_idx,
num_stages,
self.device,
group=pp_group,
)
stages.append(stage)
partial_models = [pipeline_stage.submod for pipeline_stage in stages]
pipeline_schedule = ScheduleClass(
stages,
n_microbatches=num_microbatches,
loss_fn=loss_fn,
scale_grads=False,
)
optimizer_kwargs = {
"lr": 0.01,
"betas": (0.9, 0.95),
"weight_decay": 0.1,
"fused": False,
"foreach": True,
}
optimizers = [
torch.optim.AdamW(model.parameters(), **optimizer_kwargs)
for model in partial_models
]
ref_optimizer = torch.optim.AdamW(ref_model.parameters(), **optimizer_kwargs)
# Helper function to simulate all-reduce for reference model gradients
def simulate_all_reduce_grads(model, group):
"""Simulate all-reduce operation on gradients like replicate does"""
for param in model.parameters():
if param.grad is not None:
# Scale by the number of replicas (like replicate does)
param.grad.div_(group.size())
# Simulate all-reduce
torch.distributed.all_reduce(param.grad, group=group)
ref_model_parameter_dict = {}
ref_model_parameter_dict = dict(ref_model.named_parameters())
torch.manual_seed(42 + self.rank)
for _ in range(5):
for optimizer in optimizers:
optimizer.zero_grad()
ref_optimizer.zero_grad()
inputs = torch.rand((num_microbatches, dim), device=self.device)
labels = torch.rand((num_microbatches, dim), device=self.device)
# Ensure all ranks use the same inputs/labels for comparison
torch.distributed.broadcast(inputs, 0)
torch.distributed.broadcast(labels, 0)
is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
# Run pipeline schedule
if pp_mesh.get_local_rank() == 0:
pipeline_schedule.step(inputs)
elif is_last_stage:
losses = []
pipeline_schedule.step(target=labels, losses=losses)
else:
pipeline_schedule.step()
# Run reference model simulation
if is_last_stage:
ref_loss = simulate_stage_forward_backward(ref_model, inputs, labels)
# Simulate all-reduce on reference model gradients
simulate_all_reduce_grads(ref_model, dp_group)
# Compare losses - only check on last stage where we have losses
if "losses" in locals() and len(losses) > 0:
# Average the microbatch losses to match ref_loss
avg_pipeline_loss = sum(losses) / len(losses)
torch.testing.assert_close(
avg_pipeline_loss, ref_loss, rtol=1e-4, atol=1e-5
)
else:
# For non-last stages, still run ref model to generate gradients
simulate_stage_forward_backward(ref_model, inputs, None)
simulate_all_reduce_grads(ref_model, dp_group)
# Step optimizers
for optimizer in optimizers:
optimizer.step()
ref_optimizer.step()
check_gradient_parity(
pipeline_model_parameter_dict, ref_model_parameter_dict
)
torch.distributed.destroy_process_group()
instantiate_parametrized_tests(ComposabilityTest)

View File

@ -54,9 +54,6 @@ from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
from torch.testing._internal.distributed.common_state_dict import VerifyStateDictMixin
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
# Simple and boring model
class TestDummyModel(torch.nn.Module):
def __init__(self) -> None:
@ -75,12 +72,12 @@ class TestDummyModel(torch.nn.Module):
return x
def get_input(self):
return torch.rand(8, 8, device=device_type)
return torch.rand(8, 8, device="cuda")
class TestStatefulObj:
def __init__(self) -> None:
self.data = torch.rand(10, 10, device=device_type)
self.data = torch.rand(10, 10, device="cuda")
def state_dict(self):
return {"data": self.data}
@ -154,11 +151,10 @@ def _train(model, optim, train_steps=1):
class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
@property
def backend(self):
curr_backend = dist.get_default_backend_for_device(self.device_type)
return f"cpu:gloo,{self.device_type}:{curr_backend}"
return "cpu:gloo,cuda:nccl"
def _create_model(self, compile, model_type, state_dict_options=None):
dummy_model = TestDummyModel().to(self.device_type)
dummy_model = TestDummyModel().cuda()
assert model_type in ModelType, f"{model_type} is not supported."
if model_type == ModelType.FSDP:
@ -211,8 +207,8 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
def _optim(self, model):
return torch.optim.Adam(model.parameters(), lr=0.1)
@skip_if_lt_x_gpu(4)
@with_comms
@skip_if_lt_x_gpu(4)
@with_temp_dir
@parametrize("compile", [True, False])
# TODO: Previously PairwiseParallel does not shard properly, passing ModelType.FSDP_TP test where it
@ -221,8 +217,8 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
def test_e2e(self, compile, model_type):
self._run_e2e_test(compile, model_type)
@skip_if_lt_x_gpu(4)
@with_comms
@skip_if_lt_x_gpu(4)
@with_temp_dir
@parametrize(
"cache_staged_state_dict, async_checkpointer_type, zoc",
@ -382,9 +378,9 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
# Validate that the non-stateful state dict was replaced with the loaded state dict
self.assertTrue(sd.set_sd_item_called)
@skip_if_lt_x_gpu(4)
@with_comms
@with_temp_dir
@skip_if_lt_x_gpu(4)
def test_different_ordered_state_dict_keys(self):
"""Tests that the order of keys in the state dict does not matter when loading
If order was not accounted for, the following test would cause a deadlock.
@ -398,11 +394,11 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
def load_state_dict(self, state_dict):
tl = [
torch.ones(2, dtype=torch.int64, device=device_type)
torch.ones(2, dtype=torch.int64, device="cuda")
for _ in range(world_size)
]
t = (
torch.arange(2, dtype=torch.int64, device=device_type)
torch.arange(2, dtype=torch.int64, device="cuda")
+ 1
+ 2 * dist.get_rank()
)
@ -414,7 +410,7 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
def load_state_dict(self, state_dict):
tensor = (
torch.arange(2, dtype=torch.int64, device=device_type)
torch.arange(2, dtype=torch.int64, device="cuda")
+ 1
+ 2 * dist.get_rank()
)
@ -441,8 +437,8 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
DCP.save({}, checkpoint_id=self.temp_dir)
DCP.load({}, checkpoint_id=self.temp_dir)
@skip_if_lt_x_gpu(4)
@with_comms
@skip_if_lt_x_gpu(4)
@with_temp_dir
def test_partial_load(self):
model, optim = self._create_model(compile=False, model_type=ModelType.NONE)
@ -480,8 +476,8 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
loaded_optim_state[k][optim_key], v[optim_key], offload_to_cpu=True
)
@skip_if_lt_x_gpu(4)
@with_comms
@skip_if_lt_x_gpu(4)
@with_temp_dir
def test_overwrite(self):
t1, t2 = torch.randn(10), torch.randn(10)

View File

@ -82,23 +82,22 @@ class FineTuningModel(nn.Module):
class TestFineTuning(DTensorTestBase):
@property
def world_size(self) -> int:
return min(4, torch.accelerator.device_count())
return min(4, torch.cuda.device_count())
@property
def backend(self):
curr_backend = dist.get_default_backend_for_device(self.device_type)
return f"cpu:gloo,{self.device_type}:{curr_backend}"
return "cpu:gloo,cuda:nccl"
def pretrain(self, pretrain_dir: str) -> None:
device_mesh = init_device_mesh(self.device_type, (self.world_size,))
model = PreTrainedModel().to(self.device_type)
model = PreTrainedModel().cuda()
model = FSDP(model, device_mesh=device_mesh)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
# Training
for _ in range(3):
batch = torch.rand(32, DIM, device=self.device_type)
batch = torch.rand(32, DIM, device="cuda")
loss = model(batch).sum()
loss.backward()
optim.step()
@ -115,7 +114,7 @@ class TestFineTuning(DTensorTestBase):
def finetune(self, pretrain_dir: str, finetune_dir: str) -> None:
device_mesh = init_device_mesh(self.device_type, (self.world_size,))
model = FineTuningModel().to(self.device_type)
model = FineTuningModel().cuda()
# TODO: make the parallelism more complicated, e.g., using 2D + DDP.
model = FSDP(model, use_orig_params=True, device_mesh=device_mesh)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
@ -163,7 +162,7 @@ class TestFineTuning(DTensorTestBase):
# Training
for _ in range(3):
batch = torch.rand(32, DIM, device=self.device_type)
batch = torch.rand(32, DIM, device="cuda")
loss = model(batch).sum()
loss.backward()
optim.step()

View File

@ -61,13 +61,13 @@ class TopModel(nn.Module):
class TestFSDPWithEP(DTensorTestBase, VerifyStateDictMixin):
@property
def world_size(self) -> int:
return min(8, torch.accelerator.device_count())
return min(8, torch.cuda.device_count())
@with_comms
@skip_if_lt_x_gpu(8)
@with_temp_dir
def test_e2e(self):
model = TopModel(self.rank).to(self.device_type)
model = TopModel(self.rank).cuda()
mesh_fsdp_tp = init_device_mesh(
self.device_type, (2, 4), mesh_dim_names=("dp", "tp")

View File

@ -32,13 +32,10 @@ from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
from torch.utils._pytree import tree_all_only
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
class TestFullyShardWithDistributedStateDict(FSDPTest):
@property
def world_size(self) -> int:
return min(4, torch.accelerator.device_count())
return min(4, torch.cuda.device_count())
def _get_base_model(self, mlp_dim: int = 2):
base_model = nn.Sequential(
@ -76,7 +73,7 @@ class TestFullyShardWithDistributedStateDict(FSDPTest):
for module in model2:
fully_shard(module, reshard_after_forward=False)
fully_shard(model2, reshard_after_forward=False)
inp = torch.randn((2, mlp_dim), device=device_type)
inp = torch.randn((2, mlp_dim), device="cuda")
model2(inp) # parameters are not resharded after this forward
# Check that state dict hooks reshard
osd_2 = model2.state_dict()
@ -134,7 +131,7 @@ class TestFullyShardWithDistributedStateDict(FSDPTest):
# Save state dict with model wrapped with FSDP1
fsdp1_model = FSDP(
self._get_base_model().to(device_type),
self._get_base_model().cuda(),
use_orig_params=True,
auto_wrap_policy=always_wrap_policy,
)
@ -210,14 +207,14 @@ class TestFullyShardWithDistributedStateDict(FSDPTest):
# init device mesh
dp_size = 2
global_mesh = init_device_mesh(
device_type,
"cuda",
(dp_size, self.world_size // dp_size),
mesh_dim_names=("dp", "tp"),
)
dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
# Save state dict with original model
base_model = _get_base_model().to(device_type)
base_model = _get_base_model().cuda()
base_optim = torch.optim.AdamW(base_model.parameters(), lr=0.1)
# Save state dict with model wrapped with FSDP1
@ -344,17 +341,15 @@ class TestFullyShardWithDistributedStateDict(FSDPTest):
# init device mesh
dp_size = 2
global_mesh_1d = init_device_mesh(
device_type, (self.world_size,), mesh_dim_names=("tp",)
"cuda", (self.world_size,), mesh_dim_names=("tp",)
)
global_mesh_2d = init_device_mesh(
device_type,
(dp_size, self.world_size // dp_size),
mesh_dim_names=("dp", "tp"),
"cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
)
dp_mesh, tp_mesh = global_mesh_2d["dp"], global_mesh_2d["tp"]
# Save state dict with original model
base_model = _get_base_model().to(device_type)
base_model = _get_base_model().cuda()
base_optim = torch.optim.AdamW(base_model.parameters(), lr=0.1)
# Save state dict with TP model
@ -500,10 +495,10 @@ class TestFullyShardWithDistributedStateDict(FSDPTest):
# init device mesh
dp_size = 2
global_mesh_1d = init_device_mesh(
device_type, (self.world_size,), mesh_dim_names=("tp",)
"cuda", (self.world_size,), mesh_dim_names=("tp",)
)
global_mesh_2d = init_device_mesh(
device_type,
"cuda",
(dp_size, self.world_size // dp_size),
mesh_dim_names=("dp", "tp"),
)
@ -511,7 +506,7 @@ class TestFullyShardWithDistributedStateDict(FSDPTest):
for save_full_state_dict in [True, False]:
# Save state dict with original model
base_model = _get_base_model(mlp_dim).to(device_type)
base_model = _get_base_model(mlp_dim).cuda()
base_optim = torch.optim.AdamW(base_model.parameters(), lr=0.1)
# Save state dict with FSDP2 + TP model

View File

@ -32,10 +32,7 @@ from torch.distributed.checkpoint.planner import (
)
from torch.distributed.checkpoint.storage import WriteResult
from torch.futures import Future
from torch.testing._internal.common_distributed import (
requires_accelerator_dist_backend,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.distributed._shard.sharded_tensor import (
ShardedTensorTestBase,
@ -43,9 +40,6 @@ from torch.testing._internal.distributed._shard.sharded_tensor import (
)
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
if TEST_WITH_DEV_DBG_ASAN:
print(
"Skip dev-asan as torch + multiprocessing spawn have known issues",
@ -68,8 +62,8 @@ class TestModule(torch.nn.Module):
return ChunkShardingSpec(
dim=0,
placements=[
f"rank:0/{device_type}:0",
f"rank:1/{device_type}:1",
"rank:0/cuda:0",
"rank:1/cuda:1",
],
)
@ -81,12 +75,12 @@ class TestDistributedCheckpointing(ShardedTensorTestBase):
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(2)
@requires_accelerator_dist_backend()
@requires_nccl()
def test_tensor_metadata_with_missing_rank_spec(self) -> None:
spec = ChunkShardingSpec(
dim=0,
placements=[
f"rank:1/{device_type}:1",
"rank:1/cuda:1",
],
)
@ -98,14 +92,14 @@ class TestDistributedCheckpointing(ShardedTensorTestBase):
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(2)
@requires_accelerator_dist_backend()
@requires_nccl()
def test_default_metadata(self) -> None:
device = f"{device_type}:{dist.get_rank()}"
device = f"cuda:{dist.get_rank()}"
spec = ChunkShardingSpec(
dim=0,
placements=[
f"rank:0/{device_type}:0",
f"rank:1/{device_type}:1",
"rank:0/cuda:0",
"rank:1/cuda:1",
],
)
@ -239,14 +233,12 @@ class TestDistributedFailure(ShardedTensorTestBase):
def get_spec(self):
return ChunkShardingSpec(
dim=0,
placements=[
f"rank:{r}/{device_type}:{r}" for r in range(dist.get_world_size())
],
placements=[f"rank:{r}/cuda:{r}" for r in range(dist.get_world_size())],
)
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(2)
@requires_accelerator_dist_backend()
@requires_nccl()
def test_dummy_writer_works(self) -> None:
state_dict = {
"sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
@ -258,7 +250,7 @@ class TestDistributedFailure(ShardedTensorTestBase):
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(2)
@requires_accelerator_dist_backend()
@requires_nccl()
def test_dummy_reader_works(self) -> None:
state_dict = {
"sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
@ -321,7 +313,7 @@ class TestDistributedFailure(ShardedTensorTestBase):
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(4)
@requires_accelerator_dist_backend()
@requires_nccl()
def test_save_error_handling(self) -> None:
state_dict = {
"sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
@ -355,7 +347,7 @@ class TestDistributedFailure(ShardedTensorTestBase):
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(4)
@requires_accelerator_dist_backend()
@requires_nccl()
def test_load_error_handling(self) -> None:
state_dict = {
"sharded": sharded_tensor.rand(self.get_spec(), 20, 20),

View File

@ -106,7 +106,7 @@ class DTensorPlanner(DTensorTestBase):
replicated_dt,
submesh_sharded_dt,
submesh_replicated_dt,
).to(self.device_type)
).cuda()
return (
model,
@ -135,7 +135,7 @@ class DTensorPlanner(DTensorTestBase):
(
'rdt',
DTensor(
local_tensor=tensor([4., 5., 6., 7.], device=f'{self.device_type}:0'),
local_tensor=tensor([4., 5., 6., 7.], device='cuda:0'),
device_mesh=DeviceMesh:([0, 1, 2, 3]),
placements=[Replicate()]
)
@ -143,7 +143,7 @@ class DTensorPlanner(DTensorTestBase):
(
'sdt',
DTensor(
local_tensor=tensor([0.], device=f'{self.device_type}:0'),
local_tensor=tensor([0.], device='cuda:0'),
device_mesh=DeviceMesh:([0, 1, 2, 3]),
placements=[Shard(dim=0)])
),
@ -151,7 +151,7 @@ class DTensorPlanner(DTensorTestBase):
(
'submesh_sdt',
DTensor(
local_tensor=tensor([8., 9.], device=f'{self.device_type}:0'),
local_tensor=tensor([8., 9.], device='cuda:0'),
device_mesh=DeviceMesh:([0, 2]),
placements=[Shard(dim=0)]
),
@ -159,7 +159,7 @@ class DTensorPlanner(DTensorTestBase):
(
'submesh_rdt',
DTensor(
local_tensor=tensor([12., 13., 14., 15.], device=f'{self.device_type}:0'),
local_tensor=tensor([12., 13., 14., 15.], device='cuda:0'),
device_mesh=DeviceMesh:([0, 2]),
placements=[Replicate()]
)
@ -189,7 +189,7 @@ class DTensorPlanner(DTensorTestBase):
(
'rdt',
DTensor(
local_tensor=tensor([40., 50., 60., 70.], device=f'{self.device_type}:0'),
local_tensor=tensor([40., 50., 60., 70.], device='cuda:0'),
device_mesh=DeviceMesh:([0, 1, 2, 3]),
placements=[Replicate()],
)
@ -197,7 +197,7 @@ class DTensorPlanner(DTensorTestBase):
(
'sdt',
DTensor(
local_tensor=tensor([0.], device=f'{self.device_type}:0'),
local_tensor=tensor([0.], device='cuda:0'),
device_mesh=DeviceMesh:([0, 1, 2, 3]),
placements=[Shard(dim=0)],
)
@ -205,14 +205,14 @@ class DTensorPlanner(DTensorTestBase):
(
'submesh_sdt',
DTensor(
local_tensor=tensor([80., 90.], device=f'{self.device_type}:0'),
local_tensor=tensor([80., 90.], device='cuda:0'),
device_mesh=DeviceMesh:([0, 2]),
placements=[Shard(dim=0)]
)
),
('submesh_rdt',
DTensor(
local_tensor=tensor([120., 130., 140., 150.], device=f'{self.device_type}:0'),
local_tensor=tensor([120., 130., 140., 150.], device='cuda:0'),
device_mesh=DeviceMesh:([0, 2]),
placements=[Replicate()]
)

View File

@ -278,7 +278,7 @@ class TestDTensorReshardMeshChange(DTensorTestBase):
"""
Test dtensor checkpoint resharding with dtensor containing empty shards.
"""
tensor = torch.rand(1).to(self.device_type)
tensor = torch.rand(1).cuda()
mesh = init_device_mesh(self.device_type, (self.world_size,))
dtensor = distribute_tensor(tensor, mesh, [Shard(0)])
ref_state_dict = {"dtensor": dtensor}
@ -288,7 +288,7 @@ class TestDTensorReshardMeshChange(DTensorTestBase):
storage_writer=dist_cp.FileSystemWriter(path=self.temp_dir),
)
tensor = torch.rand(1).to(self.device_type)
tensor = torch.rand(1).cuda()
mesh_2 = init_device_mesh(self.device_type, (2, self.world_size // 2))
dtensor = distribute_tensor(tensor, mesh_2, [Shard(0), Shard(0)])
state_dict = {"dtensor": dtensor}

View File

@ -23,10 +23,7 @@ from torch.distributed.checkpoint import (
)
from torch.distributed.checkpoint._extension import ZStandard
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
from torch.testing._internal.common_distributed import (
requires_accelerator_dist_backend,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
@ -48,9 +45,6 @@ from torch.testing._internal.distributed.checkpoint_utils import (
)
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
if TEST_WITH_DEV_DBG_ASAN:
print(
"Skip dev-asan as torch + multiprocessing spawn have known issues",
@ -172,7 +166,7 @@ class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(2)
@requires_accelerator_dist_backend()
@requires_nccl()
@parametrize("extensions", [None, [Rot13Example()], [ZStandard()]])
def test_read_write_shard_tensor(self, extensions) -> None:
paths = [tempfile.mkdtemp()]
@ -184,8 +178,8 @@ class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
spec = ChunkShardingSpec(
dim=0,
placements=[
f"rank:0/{device_type}:0",
f"rank:1/{device_type}:1",
"rank:0/cuda:0",
"rank:1/cuda:1",
],
)
@ -234,16 +228,14 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
def load_tensor(self, tensor: ShardedTensor) -> torch.Tensor:
res = (
torch.zeros(tensor.shape, device=f"{device_type}:0")
if dist.get_rank() == 0
else None
torch.zeros(tensor.shape, device="cuda:0") if dist.get_rank() == 0 else None
)
tensor.gather(out=res)
return res
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(2)
@requires_accelerator_dist_backend()
@requires_nccl()
def test_load_with_different_shard_plan(self) -> None:
path = self.get_file_path()
@ -255,18 +247,18 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
ChunkShardingSpec(
dim=0,
placements=[
f"rank:0/{device_type}:0",
f"rank:1/{device_type}:1",
"rank:0/cuda:0",
"rank:1/cuda:1",
],
),
# pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
ChunkShardingSpec(
dim=0,
placements=[
f"rank:0/{device_type}:0",
f"rank:1/{device_type}:1",
f"rank:1/{device_type}:1",
f"rank:0/{device_type}:0",
"rank:0/cuda:0",
"rank:1/cuda:1",
"rank:1/cuda:1",
"rank:0/cuda:0",
],
),
# This requires the tensors to be [10, 20]
@ -275,27 +267,27 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
ShardMetadata(
shard_offsets=[0, 0],
shard_sizes=[2, 20],
placement=f"rank:0/{device_type}:0",
placement="rank:0/cuda:0",
),
ShardMetadata(
shard_offsets=[2, 0],
shard_sizes=[1, 20],
placement=f"rank:1/{device_type}:1",
placement="rank:1/cuda:1",
),
ShardMetadata(
shard_offsets=[3, 0],
shard_sizes=[3, 20],
placement=f"rank:0/{device_type}:0",
placement="rank:0/cuda:0",
),
ShardMetadata(
shard_offsets=[6, 0],
shard_sizes=[3, 20],
placement=f"rank:1/{device_type}:1",
placement="rank:1/cuda:1",
),
ShardMetadata(
shard_offsets=[9, 0],
shard_sizes=[1, 20],
placement=f"rank:0/{device_type}:0",
placement="rank:0/cuda:0",
),
]
),
@ -305,12 +297,12 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
ShardMetadata(
shard_offsets=[0, 0],
shard_sizes=[8, 20],
placement=f"rank:1/{device_type}:1",
placement="rank:1/cuda:1",
),
ShardMetadata(
shard_offsets=[8, 0],
shard_sizes=[2, 20],
placement=f"rank:0/{device_type}:0",
placement="rank:0/cuda:0",
),
]
),
@ -358,7 +350,7 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(2)
@requires_accelerator_dist_backend()
@requires_nccl()
def test_load_rowwise_to_colwise(self) -> None:
path = self.get_file_path()
self.assertEqual(self.world_size, dist.get_world_size())
@ -367,8 +359,8 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
src_spec = ChunkShardingSpec(
dim=0,
placements=[
f"rank:0/{device_type}:0",
f"rank:1/{device_type}:1",
"rank:0/cuda:0",
"rank:1/cuda:1",
],
)
@ -376,8 +368,8 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
dst_spec = ChunkShardingSpec(
dim=1,
placements=[
f"rank:0/{device_type}:0",
f"rank:1/{device_type}:1",
"rank:0/cuda:0",
"rank:1/cuda:1",
],
)
@ -385,14 +377,14 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
shutil.rmtree(path, ignore_errors=True)
os.makedirs(path)
model_to_save = MyShardedModel3(src_spec).to(dist.get_rank())
model_to_save = MyShardedModel3(src_spec).cuda(dist.get_rank())
model_to_save._register_state_dict_hook(state_dict_hook)
state_dict_to_save = model_to_save.state_dict()
fs_writer = FileSystemWriter(path=path)
save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer)
model_to_load = MyShardedModel3(dst_spec).to(dist.get_rank())
model_to_load = MyShardedModel3(dst_spec).cuda(dist.get_rank())
model_to_load._register_state_dict_hook(state_dict_hook)
state_dict_to_load_to = model_to_load.state_dict()
@ -409,7 +401,7 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(2)
@requires_accelerator_dist_backend()
@requires_nccl()
def test_save_load_bytes(self) -> None:
path = self.get_file_path()
@ -428,7 +420,7 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(2)
@requires_accelerator_dist_backend()
@requires_nccl()
def test_switch_between_sharded_tensor_to_tensor(self) -> None:
path = self.get_file_path()
tensor_size = 32
@ -437,17 +429,17 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
ChunkShardingSpec(
dim=0,
placements=[
f"rank:0/{device_type}:0",
f"rank:1/{device_type}:1",
"rank:0/cuda:0",
"rank:1/cuda:1",
],
),
ChunkShardingSpec(
dim=0,
placements=[
f"rank:0/{device_type}:0",
f"rank:1/{device_type}:1",
f"rank:1/{device_type}:1",
f"rank:0/{device_type}:0",
"rank:0/cuda:0",
"rank:1/cuda:1",
"rank:1/cuda:1",
"rank:0/cuda:0",
],
),
EnumerableShardingSpec(
@ -455,12 +447,12 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
ShardMetadata(
shard_offsets=[0],
shard_sizes=[8],
placement=f"rank:1/{device_type}:1",
placement="rank:1/cuda:1",
),
ShardMetadata(
shard_offsets=[8],
shard_sizes=[tensor_size - 8],
placement=f"rank:0/{device_type}:0",
placement="rank:0/cuda:0",
),
]
),
@ -469,12 +461,12 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
ShardMetadata(
shard_offsets=[0],
shard_sizes=[10],
placement=f"rank:0/{device_type}:0",
placement="rank:0/cuda:0",
),
ShardMetadata(
shard_offsets=[10],
shard_sizes=[tensor_size - 10],
placement=f"rank:1/{device_type}:1",
placement="rank:1/cuda:1",
),
]
),
@ -520,15 +512,15 @@ class TestDistributedStateDictSaveLoadWithCaching(ShardedTensorTestBase):
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(2)
@requires_accelerator_dist_backend()
@requires_nccl()
@with_temp_dir
def test_read_write_shard_tensor(self) -> None:
# pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
spec = ChunkShardingSpec(
dim=0,
placements=[
f"rank:0/{device_type}:0",
f"rank:1/{device_type}:1",
"rank:0/cuda:0",
"rank:1/cuda:1",
],
)

View File

@ -22,9 +22,6 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
class SimpleModelUneven(nn.Module):
def __init__(self) -> None:
super().__init__()
@ -43,7 +40,7 @@ class SimpleModelUneven(nn.Module):
return x
def get_input(self):
return torch.rand(4, 5, device=device_type)
return torch.rand(4, 5, device="cuda")
class TestFormatUtils(DTensorTestBase):
@ -90,7 +87,7 @@ class TestFormatUtils(DTensorTestBase):
# Load into a sharded model
device_mesh = init_device_mesh(self.device_type, (self.world_size,))
model = SimpleModelUneven().to(self.device_type)
model = SimpleModelUneven().cuda()
model = FSDP(
model,
device_mesh=device_mesh,

View File

@ -21,8 +21,7 @@ from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
class FsdpModelStateCheckpoint(DTensorTestBase):
@property
def backend(self):
curr_backend = dist.get_default_backend_for_device(self.device_type)
return f"cpu:gloo,{self.device_type}:{curr_backend}"
return "cpu:gloo,cuda:nccl"
def _test_fsdp_model_state(self, process_group) -> None:
CHECKPOINT_DIR = self.temp_dir
@ -68,8 +67,8 @@ class FsdpModelStateCheckpoint(DTensorTestBase):
self.assertEqual(model.weight, model_2.weight)
self.assertEqual(model.bias, model_2.bias)
@skip_if_lt_x_gpu(2)
@with_comms
@skip_if_lt_x_gpu(2)
@with_temp_dir
def test_fsdp_model_state_no_resharding(self):
self._test_fsdp_model_state(process_group=None)
@ -89,8 +88,8 @@ class FsdpModelStateCheckpoint(DTensorTestBase):
return my_fsdp
@skip_if_lt_x_gpu(4)
@with_comms
@skip_if_lt_x_gpu(4)
@with_temp_dir
def test_fsdp_model_state_with_resharding(self):
self._test_fsdp_model_state(process_group=self._create_new_dist_group())

View File

@ -1,7 +1,6 @@
# Owner(s): ["oncall: distributed"]
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
@ -29,9 +28,8 @@ class FsdpOptimStateCheckpoint(DTensorTestBase):
layer3_weight_dim = self.world_size * 3
class TestDummyModel(torch.nn.Module):
def __init__(self, device_type) -> None:
def __init__(self) -> None:
super().__init__()
self.device_type = device_type
self.net1 = nn.Sequential(nn.Linear(8, layer1_weight_dim), nn.ReLU())
self.net2 = nn.Sequential(
nn.Linear(layer1_weight_dim, layer2_weight_dim), nn.ReLU()
@ -44,18 +42,17 @@ class FsdpOptimStateCheckpoint(DTensorTestBase):
return self.net3(self.net2(self.net1(x)))
def get_input(self):
return torch.rand(8, 8, device=self.device_type)
return torch.rand(8, 8, device="cuda")
model = TestDummyModel(self.device_type).to(self.device_type)
model = TestDummyModel().cuda()
return model
@property
def backend(self):
curr_backend = dist.get_default_backend_for_device(self.device_type)
return f"cpu:gloo,{self.device_type}:{curr_backend}"
return "cpu:gloo,cuda:nccl"
@skip_if_lt_x_gpu(2)
@with_comms
@skip_if_lt_x_gpu(2)
@with_temp_dir
@parametrize("pass_planner", [True, False])
def test_load_sharded_optimizer_state_dict(self, pass_planner) -> None:

View File

@ -30,7 +30,7 @@ class TestFsdpTpCheckpointConversion(DTensorTestBase):
def test_fsdp_to_tp(self):
CHECKPOINT_DIR = self.temp_dir
model = MLPModule(self.device_type).to(self.rank)
model = MLPModule(self.device_type).cuda(self.rank)
# create a FSDP wrapped model
fsdp_model = FSDP(model, use_orig_params=True)
@ -49,7 +49,7 @@ class TestFsdpTpCheckpointConversion(DTensorTestBase):
# create a TP wrapped model
mesh_shape = (self.world_size,)
device_mesh = init_device_mesh(self.device_type, mesh_shape)
model = MLPModule(self.device_type).to(self.rank)
model = MLPModule(self.device_type).cuda(self.rank)
# Parallelize the module based on the given Parallel Style.
parallelize_plan = {
"net1": ColwiseParallel(),
@ -60,7 +60,7 @@ class TestFsdpTpCheckpointConversion(DTensorTestBase):
# Update the parameters so tp_model.state_dict() will be different from fsdp_model.state_dict().
torch.manual_seed(0)
inp = torch.rand(20, 10).to(self.rank)
inp = torch.rand(20, 10).cuda(self.rank)
output = tp_model(inp)
output.sum().backward()
optimizer.step()

View File

@ -587,7 +587,7 @@ class TestDTensorReshardMeshChange(DTensorTestBase):
print("safetensors not installed")
return
tensor = torch.rand(1).to(self.device_type)
tensor = torch.rand(1).cuda()
mesh = init_device_mesh(self.device_type, (self.world_size,))
dtensor = distribute_tensor(tensor, mesh, [Shard(0)])
ref_state_dict = {"dtensor": dtensor}
@ -599,7 +599,7 @@ class TestDTensorReshardMeshChange(DTensorTestBase):
),
)
tensor = torch.rand(1).to(self.device_type)
tensor = torch.rand(1).cuda()
mesh_2 = init_device_mesh(self.device_type, (2, self.world_size // 2))
dtensor = distribute_tensor(tensor, mesh_2, [Shard(0), Shard(0)])
state_dict = {"dtensor": dtensor}

View File

@ -2,7 +2,6 @@
from copy import deepcopy
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
import torch.nn as nn
import torch.nn.functional as F
@ -30,9 +29,6 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
class SimpleModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
@ -48,7 +44,7 @@ class SimpleModel(torch.nn.Module):
return x
def get_input(self):
return torch.rand(4, 5, device=device_type)
return torch.rand(4, 5, device="cuda")
class SimpleModelUneven(torch.nn.Module):
@ -68,17 +64,16 @@ class SimpleModelUneven(torch.nn.Module):
return x
def get_input(self):
return torch.rand(4, 5, device=device_type)
return torch.rand(4, 5, device="cuda")
class TestHSDPCheckpoint(DTensorTestBase):
@property
def backend(self):
curr_backend = dist.get_default_backend_for_device(self.device_type)
return f"cpu:gloo,{self.device_type}:{curr_backend}"
return "cpu:gloo,cuda:nccl"
@skip_if_lt_x_gpu(4)
@with_comms
@skip_if_lt_x_gpu(4)
@with_temp_dir
@parametrize("is_even_sharded_model", [True, False])
def test_hsdp_checkpoint(self, is_even_sharded_model) -> None:
@ -87,7 +82,7 @@ class TestHSDPCheckpoint(DTensorTestBase):
mesh_2d = init_device_mesh(self.device_type, (2, self.world_size // 2))
model = FSDP(
simple_model().to(self.device_type),
simple_model().cuda(),
sharding_strategy=ShardingStrategy.HYBRID_SHARD,
device_mesh=mesh_2d,
)
@ -135,8 +130,8 @@ class TestHSDPCheckpoint(DTensorTestBase):
self.assertEqual(v1.placements, v2.placements)
self.assertEqual(v1.to_local(), v2.to_local())
@skip_if_lt_x_gpu(4)
@with_comms
@skip_if_lt_x_gpu(4)
@with_temp_dir
@parametrize("is_even_sharded_model", [True, False])
def test_hsdp_fsdp_checkpoint_conversion(self, is_even_sharded_model) -> None:
@ -146,7 +141,7 @@ class TestHSDPCheckpoint(DTensorTestBase):
# save the hsdp model state_dict
mesh_2d = init_device_mesh(self.device_type, (2, self.world_size // 2))
hsdp_model = FSDP(
simple_model().to(self.device_type),
simple_model().cuda(),
sharding_strategy=ShardingStrategy.HYBRID_SHARD,
device_mesh=mesh_2d,
)
@ -164,7 +159,7 @@ class TestHSDPCheckpoint(DTensorTestBase):
# initialize a fsdp model to load checkpoint into
mesh_1d = init_device_mesh(self.device_type, (self.world_size,))
fsdp_model = FSDP(
simple_model().to(self.device_type),
simple_model().cuda(),
device_mesh=mesh_1d,
)
FSDP.set_state_dict_type(

View File

@ -1,13 +1,11 @@
# Owner(s): ["oncall: distributed"]
import logging
import unittest
from datetime import timedelta
from typing import Optional
from unittest.mock import MagicMock, patch
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._shard.sharded_tensor import (
init_from_local_shards,
@ -25,11 +23,10 @@ from torch.distributed.checkpoint._pg_transport import (
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.distributed_c10d import _get_default_group
from torch.distributed.tensor import DTensor
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
at_least_x_gpu,
HAS_ACCELERATOR,
MultiProcContinuousTest,
requires_accelerator_dist_backend,
requires_nccl,
)
from torch.testing._internal.common_utils import (
run_tests,
@ -38,8 +35,6 @@ from torch.testing._internal.common_utils import (
)
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
logger = logging.getLogger(__name__)
@ -165,9 +160,9 @@ def _test_pg_transport_with_mixed_content(self, device) -> None:
def _test_pg_transport_with_sharded_tensor(self, device) -> None:
# Set current accelerator device for NCCL/XCCL
if device.type == "cuda" or device.type == "xpu":
torch.accelerator.set_device_index(device)
# Set current CUDA device for NCCL
if device.type == "cuda":
torch.cuda.set_device(device)
state_dict = _create_sharded_tensor_state_dict(self.rank, self.world_size, device)
transport = PGTransport(_get_default_group(), timedelta(seconds=10), device)
@ -232,36 +227,34 @@ class PgTransportCPU(MultiProcContinuousTest):
_test_pg_transport_with_sharded_tensor(self, self.device)
class PgTransportGPU(MultiProcContinuousTest):
class PgTransportCUDA(MultiProcContinuousTest):
world_size = 2
timeout: timedelta = timedelta(seconds=20)
@classmethod
def backend_str(cls) -> Optional[str]:
return dist.get_default_backend_for_device(cls.device_type())
return "nccl"
@classmethod
def device_type(cls) -> str:
return "cuda"
@property
def device(self) -> torch.device:
return torch.device(f"{self.device_type()}:{self.rank}")
@requires_accelerator_dist_backend()
@skip_but_pass_in_sandcastle_if(
not at_least_x_gpu(2), "test requires 2+ accelerators"
)
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_pg_transport(self) -> None:
_test_pg_transport(self, self.device)
@requires_accelerator_dist_backend()
@skip_but_pass_in_sandcastle_if(
not at_least_x_gpu(2), "test requires 2+ accelerators"
)
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_pg_transport_with_mixed_content(self) -> None:
_test_pg_transport_with_mixed_content(self, self.device)
@requires_accelerator_dist_backend()
@skip_but_pass_in_sandcastle_if(
not at_least_x_gpu(2), "test requires 2+ accelerators"
)
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_pg_transport_with_sharded_tensor(self) -> None:
_test_pg_transport_with_sharded_tensor(self, self.device)
@ -585,10 +578,13 @@ class TestPGTransportEdgeCases(TestCase):
self.pg.send = MagicMock(return_value=self.mock_work)
self.pg.recv = MagicMock(return_value=self.mock_work)
@unittest.skipIf(not HAS_ACCELERATOR, "No accelerator")
def test_send_checkpoint_with_cpu_tensors(self):
"""Test send_checkpoint with CPU tensors when device is accelerator."""
device = torch.device(f"{device_type}:0")
"""Test send_checkpoint with CPU tensors when device is CUDA."""
# Skip if CUDA is not available
if not torch.cuda.is_available():
self.skipTest("CUDA not available")
device = torch.device("cuda:0")
# Create a state dict with CPU tensors
state_dict = {
@ -596,7 +592,7 @@ class TestPGTransportEdgeCases(TestCase):
"cpu_tensor2": torch.randn(3, 4),
}
# Create transport with accelerator device
# Create transport with CUDA device
transport = PGTransport(self.pg, self.timeout, device)
# Call send_checkpoint

View File

@ -37,7 +37,7 @@ class TestSaveAndLoadAPI(DTensorTestBase):
@skip_if_lt_x_gpu(4)
@with_temp_dir
def test_auto_detect(self):
model = FSDP(MyTestModule().to(self.device_type))
model = FSDP(MyTestModule().cuda())
device_mesh = init_device_mesh(self.device_type, (self.world_size,))
model = FSDP(model, device_mesh=device_mesh)
dcp.save(model.state_dict(), checkpoint_id=os.path.join(self.temp_dir, "first"))

View File

@ -3,7 +3,6 @@
import dataclasses
import os
import tempfile
import unittest
from datetime import timedelta
import torch
@ -19,21 +18,14 @@ from torch.distributed._tensor.placement_types import Replicate, Shard
from torch.distributed.checkpoint._state_dict_stager import StateDictStager
from torch.distributed.checkpoint.staging import _ReplicationStager
from torch.distributed.tensor import DeviceMesh, distribute_tensor
from torch.testing._internal.common_distributed import (
HAS_ACCELERATOR,
requires_accelerator_dist_backend,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
from torch.testing._internal.common_utils import requires_cuda, run_tests, TestCase
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms,
)
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
def create_cpu_state_dict(state_dict):
cpu_state_dict = {}
for key, value in state_dict.items():
@ -41,16 +33,16 @@ def create_cpu_state_dict(state_dict):
return cpu_state_dict
def compare_state_dicts(gpu_state_dict, cpu_state_dict, rtol=1e-5, atol=1e-8):
def compare_state_dicts(cuda_state_dict, cpu_state_dict, rtol=1e-5, atol=1e-8):
"""
Compare if two state dictionaries (one on GPU, one on CPU) are otherwise the same.
Compare if two state dictionaries (one on CUDA, one on CPU) are otherwise the same.
This function checks if the tensors in both state dictionaries have the same values,
shapes, dtypes, etc., ignoring the device difference. It also checks if tensors that
share storage in one state dict also share storage in the other.
Args:
gpu_state_dict: The state dictionary with tensors on GPU
cuda_state_dict: The state dictionary with tensors on CUDA
cpu_state_dict: The state dictionary with tensors on CPU
rtol: Relative tolerance for comparing tensor values
atol: Absolute tolerance for comparing tensor values
@ -60,65 +52,65 @@ def compare_state_dicts(gpu_state_dict, cpu_state_dict, rtol=1e-5, atol=1e-8):
str: Error message if the state dictionaries are not equivalent, empty string otherwise
"""
# Track storage data pointers to check storage sharing
gpu_storage_ptrs = {}
cuda_storage_ptrs = {}
cpu_storage_ptrs = {}
def compare_objects(gpu_obj, cpu_obj, path=""):
def compare_objects(cuda_obj, cpu_obj, path=""):
# If objects are tensors, compare them
if isinstance(gpu_obj, torch.Tensor) and isinstance(cpu_obj, torch.Tensor):
if isinstance(cuda_obj, torch.Tensor) and isinstance(cpu_obj, torch.Tensor):
# Check if devices are as expected
if gpu_obj.device.type != device_type:
if cuda_obj.device.type != "cuda":
return (
False,
f"Expected accelerator tensor, got {gpu_obj.device.type} tensor at {path}",
f"Expected CUDA tensor, got {cuda_obj.device.type} tensor at {path}",
)
if cpu_obj.device.type != "cpu":
return (
False,
f"Expected CPU tensor, got {cpu_obj.device.type} tensor at {path}",
)
if gpu_obj.storage_offset() != cpu_obj.storage_offset():
if cuda_obj.storage_offset() != cpu_obj.storage_offset():
return (
False,
f"Storage offset mismatch at {path}: {gpu_obj.storage_offset()} vs {cpu_obj.storage_offset()}",
f"Storage offset mismatch at {path}: {cuda_obj.storage_offset()} vs {cpu_obj.storage_offset()}",
)
if not torch.equal(gpu_obj.cpu(), cpu_obj):
if not torch.equal(cuda_obj.cpu(), cpu_obj):
return (
False,
f"Tensors are not same at {path}",
)
# Track storage sharing
gpu_storage_ptr = gpu_obj.storage().data_ptr()
cuda_storage_ptr = cuda_obj.storage().data_ptr()
cpu_storage_ptr = cpu_obj.storage().data_ptr()
if gpu_storage_ptr in gpu_storage_ptrs:
# This GPU tensor shares storage with another tensor
if cuda_storage_ptr in cuda_storage_ptrs:
# This CUDA tensor shares storage with another tensor
# Check if the corresponding CPU tensors also share storage
if cpu_storage_ptr != gpu_storage_ptrs[gpu_storage_ptr]:
if cpu_storage_ptr != cuda_storage_ptrs[cuda_storage_ptr]:
return (
False,
f"Storage sharing mismatch: GPU tensors share storage but CPU tensors don't at {path}",
f"Storage sharing mismatch: CUDA tensors share storage but CPU tensors don't at {path}",
)
else:
# First time seeing this storage
gpu_storage_ptrs[gpu_storage_ptr] = cpu_storage_ptr
cpu_storage_ptrs[cpu_storage_ptr] = gpu_storage_ptr
cuda_storage_ptrs[cuda_storage_ptr] = cpu_storage_ptr
cpu_storage_ptrs[cpu_storage_ptr] = cuda_storage_ptr
return True, ""
# If objects are dictionaries, compare them recursively
elif isinstance(gpu_obj, dict) and isinstance(cpu_obj, dict):
if gpu_obj.keys() != cpu_obj.keys():
elif isinstance(cuda_obj, dict) and isinstance(cpu_obj, dict):
if cuda_obj.keys() != cpu_obj.keys():
return (
False,
f"Dictionary keys mismatch at {path}: {gpu_obj.keys()} vs {cpu_obj.keys()}",
f"Dictionary keys mismatch at {path}: {cuda_obj.keys()} vs {cpu_obj.keys()}",
)
for key in gpu_obj:
for key in cuda_obj:
result, error = compare_objects(
gpu_obj[key], cpu_obj[key], f"{path}.{key}" if path else key
cuda_obj[key], cpu_obj[key], f"{path}.{key}" if path else key
)
if not result:
return False, error
@ -126,37 +118,37 @@ def compare_state_dicts(gpu_state_dict, cpu_state_dict, rtol=1e-5, atol=1e-8):
return True, ""
# If objects are lists, tuples, or sets, compare them recursively
elif isinstance(gpu_obj, (list, tuple, set)) and isinstance(
elif isinstance(cuda_obj, (list, tuple, set)) and isinstance(
cpu_obj, (list, tuple, set)
):
if len(gpu_obj) != len(cpu_obj):
if len(cuda_obj) != len(cpu_obj):
return (
False,
f"Collection length mismatch at {path}: {len(gpu_obj)} vs {len(cpu_obj)}",
f"Collection length mismatch at {path}: {len(cuda_obj)} vs {len(cpu_obj)}",
)
if type(gpu_obj) != type(cpu_obj):
if type(cuda_obj) != type(cpu_obj):
return (
False,
f"Collection type mismatch at {path}: {type(gpu_obj)} vs {type(cpu_obj)}",
f"Collection type mismatch at {path}: {type(cuda_obj)} vs {type(cpu_obj)}",
)
for i, (gpu_item, cpu_item) in enumerate(zip(gpu_obj, cpu_obj)):
result, error = compare_objects(gpu_item, cpu_item, f"{path}[{i}]")
for i, (cuda_item, cpu_item) in enumerate(zip(cuda_obj, cpu_obj)):
result, error = compare_objects(cuda_item, cpu_item, f"{path}[{i}]")
if not result:
return False, error
return True, ""
# If objects are custom classes, compare their attributes
elif hasattr(gpu_obj, "__dict__") and hasattr(cpu_obj, "__dict__"):
if type(gpu_obj) != type(cpu_obj):
elif hasattr(cuda_obj, "__dict__") and hasattr(cpu_obj, "__dict__"):
if type(cuda_obj) != type(cpu_obj):
return (
False,
f"Object type mismatch at {path}: {type(gpu_obj)} vs {type(cpu_obj)}",
f"Object type mismatch at {path}: {type(cuda_obj)} vs {type(cpu_obj)}",
)
result, error = compare_objects(
gpu_obj.__dict__, cpu_obj.__dict__, f"{path}.__dict__"
cuda_obj.__dict__, cpu_obj.__dict__, f"{path}.__dict__"
)
if not result:
return False, error
@ -165,18 +157,18 @@ def compare_state_dicts(gpu_state_dict, cpu_state_dict, rtol=1e-5, atol=1e-8):
# For other types, use direct equality comparison
else:
if type(gpu_obj) != type(cpu_obj):
if type(cuda_obj) != type(cpu_obj):
return (
False,
f"Type mismatch at {path}: {type(gpu_obj)} vs {type(cpu_obj)}",
f"Type mismatch at {path}: {type(cuda_obj)} vs {type(cpu_obj)}",
)
if gpu_obj != cpu_obj:
return False, f"Value mismatch at {path}: {gpu_obj} vs {cpu_obj}"
if cuda_obj != cpu_obj:
return False, f"Value mismatch at {path}: {cuda_obj} vs {cpu_obj}"
return True, ""
# Start the recursive comparison
result, error = compare_objects(gpu_state_dict, cpu_state_dict)
result, error = compare_objects(cuda_state_dict, cpu_state_dict)
return result, error
@ -206,7 +198,7 @@ class FrozenDataClass:
class TestStateDictStager(TestCase):
@unittest.skipIf(not HAS_ACCELERATOR, "No accelerator")
@requires_cuda
def test_views(self):
test_configs = [
(False, False), # pin_memory=False, share_memory=False,
@ -216,9 +208,9 @@ class TestStateDictStager(TestCase):
]
for pin_memory, share_memory in test_configs:
with self.subTest(pin_memory=pin_memory, share_memory=share_memory):
tensor1 = torch.randn(4, 4).to(device_type)
tensor1 = torch.randn(4, 4).cuda()
tensor2 = tensor1.view(16)
tensor3 = torch.randn(4, 4).to(device_type)
tensor3 = torch.randn(4, 4).cuda()
state_dict = {
"tensor1": tensor1,
"tensor2": tensor2,
@ -261,7 +253,7 @@ class TestStateDictStager(TestCase):
assert num_bytes == expected_bytes, (
f"Expected {expected_bytes} bytes, got {num_bytes}"
)
# Verify that the CPU state dict is equivalent to the original GPU state dict
# Verify that the CPU state dict is equivalent to the original CUDA state dict
result, error = compare_state_dicts(state_dict, cpu_state_dict)
assert result, f"State dicts are not equivalent: {error}"
@ -281,7 +273,7 @@ class TestStateDictStager(TestCase):
== recursive["type"].tensor1.storage().data_ptr()
)
@unittest.skipIf(not HAS_ACCELERATOR, "No accelerator")
@requires_cuda
def test_caching(self):
"""
Test that the StateDictStager correctly caches and reuses storages.
@ -295,9 +287,9 @@ class TestStateDictStager(TestCase):
for pin_memory, share_memory in test_configs:
with self.subTest(pin_memory=pin_memory, share_memory=share_memory):
# Create test tensors and state dict
tensor1 = torch.randn(4, 4).to(device_type)
tensor1 = torch.randn(4, 4).cuda()
tensor2 = tensor1.view(16)
tensor3 = torch.randn(4, 4).to(device_type)
tensor3 = torch.randn(4, 4).cuda()
state_dict = {
"tensor1": tensor1,
"tensor2": tensor2,
@ -373,14 +365,14 @@ class TestStateDictStager(TestCase):
"Updated values should be reflected in the cached state dict"
)
@unittest.skipIf(not HAS_ACCELERATOR, "No accelerator")
@requires_cuda
def test_tensor_attrs(self):
"""
Test that tensor attributes are preserved during stage with StateDictStager.
"""
tensor1 = torch.randn(4, 4).to(device_type)
tensor1 = torch.randn(4, 4).cuda()
tensor2 = tensor1.view(16)
tensor3 = torch.randn(4, 4).to(device_type)
tensor3 = torch.randn(4, 4).cuda()
# Add custom attributes to tensors
tensor1.a = 42
@ -419,22 +411,18 @@ class TestStateDictStager(TestCase):
"Tensor attribute 'c' has incorrect value"
)
@unittest.skipIf(not HAS_ACCELERATOR, "No accelerator")
@requires_cuda
def test_different_dtypes(self):
"""
Test that StateDictStager works correctly with tensors of different data types.
"""
# Create tensors with different dtypes
tensors = {
"float32": torch.randn(4, 4, dtype=torch.float32).to(device_type),
"float64": torch.randn(4, 4, dtype=torch.float64).to(device_type),
"int32": torch.randint(-100, 100, (4, 4), dtype=torch.int32).to(
device_type
),
"int64": torch.randint(-100, 100, (4, 4), dtype=torch.int64).to(
device_type
),
"bool": torch.randint(0, 2, (4, 4), dtype=torch.bool).to(device_type),
"float32": torch.randn(4, 4, dtype=torch.float32).cuda(),
"float64": torch.randn(4, 4, dtype=torch.float64).cuda(),
"int32": torch.randint(-100, 100, (4, 4), dtype=torch.int32).cuda(),
"int64": torch.randint(-100, 100, (4, 4), dtype=torch.int64).cuda(),
"bool": torch.randint(0, 2, (4, 4), dtype=torch.bool).cuda(),
}
# Create a state dict with these tensors
@ -459,7 +447,7 @@ class TestStateDictStager(TestCase):
f"Tensor {dtype_name} has incorrect values",
)
@unittest.skipIf(not HAS_ACCELERATOR, "No accelerator")
@requires_cuda
def test_empty_tensors(self):
"""
Test that StateDictStager works correctly with empty tensors.
@ -474,17 +462,15 @@ class TestStateDictStager(TestCase):
with self.subTest(pin_memory=pin_memory, share_memory=share_memory):
# Create empty tensors with different shapes
tensors = {
"empty_0d": torch.tensor([], dtype=torch.float32).to(device_type),
"empty_1d": torch.tensor([], dtype=torch.float32)
.reshape(0)
.to(device_type),
"empty_0d": torch.tensor([], dtype=torch.float32).cuda(),
"empty_1d": torch.tensor([], dtype=torch.float32).reshape(0).cuda(),
"empty_2d": torch.tensor([], dtype=torch.float32)
.reshape(0, 0)
.to(device_type),
.cuda(),
"empty_3d": torch.tensor([], dtype=torch.float32)
.reshape(0, 0, 0)
.to(device_type),
"zero_dim": torch.tensor(0.0).to(device_type), # scalar tensor
.cuda(),
"zero_dim": torch.tensor(0.0).cuda(), # scalar tensor
}
# Create a state dict with these tensors
@ -514,13 +500,13 @@ class TestStateDictStager(TestCase):
f"Tensor {tensor_name} has incorrect dtype",
)
@unittest.skipIf(not HAS_ACCELERATOR, "No accelerator")
@requires_cuda
def test_complex_storage_sharing(self):
"""
Test that StateDictStager correctly handles complex storage sharing scenarios.
"""
# Create a base tensor
base_tensor = torch.randn(10, 10).to(device_type)
base_tensor = torch.randn(10, 10).cuda()
# Create various views and slices that share storage
view1 = base_tensor.view(100)
@ -596,13 +582,13 @@ class TestStateDictStager(TestCase):
"slice3 should reflect changes to base",
)
@unittest.skipIf(not HAS_ACCELERATOR, "No accelerator")
@requires_cuda
def test_dataclasses(self):
# Create tensors
tensor1 = torch.randn(4, 4).to(device_type)
tensor2 = torch.randn(8, 8).to(device_type)
tensor3 = torch.randn(2, 6).to(device_type)
tensor4 = torch.randn(3, 5).to(device_type)
tensor1 = torch.randn(4, 4).cuda()
tensor2 = torch.randn(8, 8).cuda()
tensor3 = torch.randn(2, 6).cuda()
tensor4 = torch.randn(3, 5).cuda()
# Create dataclass instances
nested = NestedTensorStruct(tensor=tensor3)
@ -709,14 +695,14 @@ class TestStateDictStager(TestCase):
"CPU tensor should have the same values as the original tensor",
)
@unittest.skipIf(not HAS_ACCELERATOR, "No accelerator")
@requires_cuda
def test_tensor_pinned_and_shared(self):
"""
Test that verifies tensors are actually pinned and shared using tensor.is_pinned() and tensor.is_shared() methods.
"""
# Create test tensors
tensor1 = torch.randn(4, 4).to(device_type)
tensor2 = torch.randn(8, 8).to(device_type)
tensor1 = torch.randn(4, 4).cuda()
tensor2 = torch.randn(8, 8).cuda()
# Create a state dict with these tensors
state_dict = {
@ -811,17 +797,15 @@ class TestStateDictStager(TestCase):
class TestDTensorStateDictStager(DTensorTestBase):
@with_comms
@requires_accelerator_dist_backend()
@requires_nccl()
@skip_if_lt_x_gpu(2)
def test_dtensor(self):
"""
Test that StateDictStager works correctly with DTensors.
"""
# Create a DTensor
device_mesh = dist.DeviceMesh(
self.device_type, list(range(dist.get_world_size()))
)
tensor = torch.randn(3, 3, device=self.device_type)
device_mesh = dist.DeviceMesh("cuda", list(range(dist.get_world_size())))
tensor = torch.randn(3, 3, device="cuda")
dtensor = DTensor.from_local(tensor, device_mesh, [Shard(0)])
dtensor = dtensor + 1

View File

@ -47,7 +47,7 @@ class TestTpCheckpoint(DTensorTestBase):
tp_mesh = init_device_mesh(self.device_type, mesh_shpe)
# create model and move it to GPU with id rank
model = MLPModule(self.device_type).to(self.rank)
model = MLPModule(self.device_type).cuda(self.rank)
# Parallelize the module based on the given Parallel Style.
parallelize_plan = {
"net1": ColwiseParallel(),
@ -65,7 +65,7 @@ class TestTpCheckpoint(DTensorTestBase):
# Update the parameters so model.state_dict() will be different from original_state_dict.
torch.manual_seed(0)
inp = torch.rand(20, 10).to(self.rank)
inp = torch.rand(20, 10).cuda(self.rank)
output = model(inp)
output.sum().backward()
optimizer.step()
@ -94,7 +94,7 @@ class TestTpCheckpoint(DTensorTestBase):
tp_mesh = init_device_mesh(self.device_type, mesh_shpe)
# create model and move it to GPU with id rank
model = UnevenShardedModel(self.device_type).to(self.rank)
model = UnevenShardedModel(self.device_type).cuda(self.rank)
# Parallelize the module based on the given Parallel Style.
parallelize_plan = {
"net1": ColwiseParallel(),

View File

@ -199,7 +199,7 @@ class TestReaderView(TestCase):
class TestDistWrapper(DTensorTestBase):
@property
def world_size(self):
return min(4, torch.accelerator.device_count())
return min(4, torch.cuda.device_count())
@with_comms
@skip_if_lt_x_gpu(4)

View File

@ -2,6 +2,7 @@
# Owner(s): ["oncall: distributed"]
import copy
import logging
import tempfile
from dataclasses import dataclass
from model_registry import ModelWithKwargs, MultiMLP, MultiMLPKwargs, MultiMLPWithDw
@ -26,15 +27,7 @@ from torch.distributed.pipelining import (
ScheduleLoopedBFS,
ScheduleZBVZeroBubble,
)
from torch.distributed.pipelining.schedules import (
_Action,
_PipelineContext,
_PipelineScheduleRuntime,
_wait_batch_p2p,
FORWARD,
OVERLAP_F_B,
)
from torch.distributed.pipelining.stage import _PipelineStageBase # noqa: TC002
from torch.distributed.pipelining.schedules import _PipelineScheduleRuntime
from torch.nn.modules.loss import MSELoss
from torch.testing._internal.common_distributed import (
MultiProcContinuousTest,
@ -522,7 +515,8 @@ class ScheduleTest(MultiProcContinuousTest):
ScheduleInterleavedZeroBubble,
],
)
def test_grad_with_manual_interleaved(self, ScheduleClass):
@parametrize("use_new_runtime", [False, True])
def test_grad_with_manual_interleaved(self, ScheduleClass, use_new_runtime):
stages_per_rank = 2
n_stages = stages_per_rank * self.world_size
mod, ref_mod, x, target, loss_fn = setup_models_and_data(
@ -549,6 +543,46 @@ class ScheduleTest(MultiProcContinuousTest):
stages, num_microbatches, loss_fn=loss_fn, scale_grads=False
)
# Handle new runtime testing
if use_new_runtime:
old_schedule = schedule
tmp_schedule = _PipelineScheduleRuntime(
stages, num_microbatches, loss_fn=loss_fn, scale_grads=False
)
tmp_schedule._prepare_schedule_with_comms(old_schedule.pipeline_order)
# Test CSV round-trip for compute_comms schedule
schedule = _PipelineScheduleRuntime(
stages, num_microbatches, loss_fn=loss_fn, scale_grads=False
)
with tempfile.NamedTemporaryFile() as f:
tmp_schedule._dump_csv(f.name)
f.seek(0)
schedule._load_csv(f.name, format="compute_comms")
one_more_schedule = _PipelineScheduleRuntime(
stages, num_microbatches, loss_fn=loss_fn, scale_grads=False
)
one_more_schedule._prepare_schedule_with_comms(
schedule.pipeline_order_with_comms, format="compute_comms"
)
# Verify schedule consistency
self.assertEqual(
len(schedule.pipeline_order_with_comms),
len(one_more_schedule.pipeline_order_with_comms),
)
for rank in schedule.pipeline_order_with_comms:
self.assertEqual(
len(schedule.pipeline_order_with_comms[rank]),
len(one_more_schedule.pipeline_order_with_comms[rank]),
)
for a, b in zip(
schedule.pipeline_order_with_comms[rank],
one_more_schedule.pipeline_order_with_comms[rank],
):
self.assertEqual(a, b)
# Run pipeline with tensor leak checking
out = None
losses = []
@ -716,201 +750,6 @@ class ScheduleTest(MultiProcContinuousTest):
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
def test_custom_function_callback(self):
"""Test the custom function callback functionality with _PipelineScheduleRuntime."""
n_stages = 8
rank_stages = {0: [0, 7], 1: [1, 6], 2: [2, 5], 3: [3, 4]}
mod, ref_mod, x, target, loss_fn = setup_models_and_data(
self.config, n_layers=n_stages
)
# Run reference
ref_out, ref_loss = run_reference_model(ref_mod, x, target, loss_fn)
# Create multi-stage pipeline with custom stage indices
num_microbatches = 8
stage_indices = rank_stages[self.rank]
stages, stage_modules, submod_names = create_multi_stage_pipeline(
self.config, mod, len(stage_indices), n_stages, stage_indices
)
# Use DualPipeV schedule as the base schedule
base_schedule = ScheduleDualPipeV(
stages, num_microbatches, loss_fn=loss_fn, scale_grads=False
)
base_schedule._prepare_schedule_with_comms(base_schedule.pipeline_order)
# Track both types of callbacks separately
forward_calls = []
overlap_calls = []
def forward_callback(action: _Action, ctx: _PipelineContext):
"""Custom callback for FORWARD computation that mimics the original implementation."""
schedule = ctx.schedule_ref
assert isinstance(schedule, _PipelineScheduleRuntime)
stage_index_to_stage: dict[int, _PipelineStageBase] = {
stage.stage_index: stage for stage in schedule._stages
}
stage = stage_index_to_stage[action.stage_index]
stage_index = stage.stage_index
mb_index = action.microbatch_index
assert mb_index is not None
fwd_recv_ops = schedule.fwd_recv_ops
arg_mbs = ctx.arg_mbs
kwarg_mbs = ctx.kwarg_mbs
is_next_stage_on_this_rank = stage_index + 1 in stage_index_to_stage
is_prev_stage_on_this_rank = stage_index - 1 in stage_index_to_stage
# used in verification at the end
forward_calls.append((stage_index, mb_index))
if (
not stage.is_first
# no recv op expected for V-schedule special case (see [Note: V-schedule special case])
and not is_prev_stage_on_this_rank
):
assert (
stage_index,
mb_index,
) in fwd_recv_ops, f"Computing {action=} before receiving input"
from torch.distributed.pipelining.schedules import _wait_batch_p2p
_wait_batch_p2p(fwd_recv_ops.pop((stage_index, mb_index)))
output = stage.forward_one_chunk(
mb_index,
arg_mbs[mb_index], # type: ignore[index]
kwarg_mbs[mb_index], # type: ignore[index]
)
schedule._maybe_compute_loss(stage, output, ctx.target_mbs, mb_index)
# SEND/RECV ops are avoided for the special case with 2 adjacent stages on the same rank
# see [Note: V-schedule special case]
if is_next_stage_on_this_rank:
stage_index_to_stage[stage_index + 1].set_local_fwd_input(
output, mb_index
)
def overlap_callback(action: _Action, ctx: _PipelineContext):
"""Custom callback for OVERLAP_F_B computation that mimics the original implementation."""
schedule = ctx.schedule_ref
assert isinstance(schedule, _PipelineScheduleRuntime)
stage_index_to_stage: dict[int, _PipelineStageBase] = {
stage.stage_index: stage for stage in schedule._stages
}
assert action.sub_actions is not None
fwd_action = action.sub_actions[0]
bwd_action = action.sub_actions[1]
# Forward ========================================================
forward_callback(fwd_action, ctx)
overlap_calls.append(
(
fwd_action.stage_index,
fwd_action.microbatch_index,
bwd_action.stage_index,
bwd_action.microbatch_index,
)
)
# Backward ========================================================
backward_stage_index = bwd_action.stage_index
backward_stage = stage_index_to_stage[backward_stage_index]
backward_mb_index = bwd_action.microbatch_index
assert backward_mb_index is not None
bwd_recv_ops = schedule.bwd_recv_ops
is_next_stage_on_this_rank = (
backward_stage.stage_index + 1 in stage_index_to_stage
)
is_prev_stage_on_this_rank = (
backward_stage.stage_index - 1 in stage_index_to_stage
)
if (
not backward_stage.is_last
# no recv op expected for V-schedule special case (see [Note: V-schedule special case])
and not is_next_stage_on_this_rank
):
assert (
backward_stage_index,
backward_mb_index,
) in bwd_recv_ops, (
f"Attempted to run compute {action=} before receiving input"
)
_wait_batch_p2p(
bwd_recv_ops.pop((backward_stage_index, backward_mb_index))
)
loss = schedule._maybe_get_loss(backward_stage, backward_mb_index)
schedule.backward_counter[backward_stage_index] += 1
last_backward = (
schedule.backward_counter[backward_stage_index]
== schedule._n_microbatches
)
grad_scale_factor = schedule._n_microbatches if schedule.scale_grads else 1
backward_stage.backward_one_chunk(
backward_mb_index,
loss=loss,
full_backward=True,
last_backward=last_backward,
)
if last_backward:
backward_stage.scale_grads(grad_scale_factor)
# SEND/RECV ops are avoided for the special case with 2 adjacent stages on the same rank
# see [Note: V-schedule special case]
if is_prev_stage_on_this_rank:
stage_index_to_stage[backward_stage_index - 1].set_local_bwd_input(
backward_stage.get_local_bwd_output(backward_mb_index),
backward_mb_index,
)
# Add the callback for FORWARD computation type
base_schedule.register_custom_function(FORWARD, forward_callback)
base_schedule.register_custom_function(OVERLAP_F_B, overlap_callback)
# Run pipeline - special case where first and last stage are on rank 0
out = None
losses = []
num_loops = 2
for _ in range(num_loops):
zero_gradients(stage_modules)
if self.rank == 0:
out = base_schedule.step(x, target=target, losses=losses)
else:
base_schedule.step()
dist.barrier()
# Verify results (rank 0 has both first and last stages)
if self.rank == 0:
torch.testing.assert_close(out, ref_out)
pipe_loss = sum(losses)
torch.testing.assert_close(pipe_loss, ref_loss)
# Verify overlap callbacks were called
self.assertGreater(
len(overlap_calls), 0, "OVERLAP_F_B callback should have been called"
)
# In a V-schedule with 8 microbatches and 2 stages per rank,
# rank 0 should have 32 calls (8 microbatches * 2 stages * 2 loops)
expected_count = num_microbatches * 2 * num_loops
self.assertEqual(len(forward_calls), expected_count)
# Verify all callback calls are for stages on this rank
for stage_idx, _ in forward_calls:
self.assertIn(
stage_idx,
stage_indices,
f"Callback called for stage {stage_idx} not on rank {self.rank}",
)
# Check gradients using helper method
check_gradients(self.config, stage_modules, ref_mod, submod_names)
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIACCELERATOR, "NCCL test requires 2+ GPUs"
)
@parametrize(
"ScheduleClass",
[ScheduleInterleavedZeroBubble, ScheduleInterleaved1F1B],
@ -1008,7 +847,8 @@ class CustomSchedulesTest(MultiProcContinuousTest):
"schedule_class",
[ScheduleVShaped, ScheduleUnbalanced],
)
def test_non_symmetric_stage_ids(self, schedule_class):
@parametrize("use_new_runtime", [False, True])
def test_non_symmetric_stage_ids(self, schedule_class, use_new_runtime):
n_stages = schedule_class.n_stages
rank_stages = schedule_class.rank_stages
@ -1031,6 +871,13 @@ class CustomSchedulesTest(MultiProcContinuousTest):
stages, num_microbatches, loss_fn=loss_fn, scale_grads=False
)
if use_new_runtime:
old_schedule = schedule
schedule = _PipelineScheduleRuntime(
stages, num_microbatches, loss_fn=loss_fn
)
schedule._prepare_schedule_with_comms(old_schedule.pipeline_order)
# Run pipeline - special case where first and last stage are on rank 0
out = None
losses = []

View File

@ -336,6 +336,20 @@ class DeviceMeshTest(DTensorTestBase):
f"{device_type}:0", mesh_shape=mesh_shape, mesh_dim_names=("dp", "tp")
)
@with_comms
def test_set_mesh_dim_group_options(self):
device_type = (
torch.accelerator.current_accelerator().type
if torch.accelerator.is_available()
else "cpu"
)
_mesh_resources._set_mesh_dim_group_options(1, "fake", None)
mesh_tensor = torch.arange(4).reshape(2, 2)
mesh = DeviceMesh(device_type, mesh_tensor)
# Fake pg only has BackendType::CUSTOM as its backend type.
self.assertEqual(mesh.get_group(1)._get_backend_name(), "custom")
@with_comms
def test_get_root_mesh_multiple_independent_meshes(self):
# regression test for issue #163330

View File

@ -893,29 +893,6 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
self.assertEqual(gn(inp), inp + 3)
self.assertEqual(cnts.frame_count, 1)
def test_step_unsupported(self):
cnts = torch._dynamo.testing.CompileCounter()
@torch.compile(backend=cnts)
def fn(x):
x = x + 1 + 2
torch._dynamo.step_unsupported()
return x + 4
inp = torch.ones(3)
self.assertEqual(fn(inp), inp + 7)
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(cnts.op_count, 2)
def test_step_unsupported_empty_checkpoint(self):
@torch.compile(backend="eager")
def fn(x):
torch._dynamo.step_unsupported()
return x + 1
inp = torch.ones(3)
self.assertEqual(fn(inp), inp + 1)
@skipIfWindows(
msg="TODO: (xuhancn), confirm if torch.compiler.disable work on Windows."
)

View File

@ -14,7 +14,7 @@ import torch._dynamo.config
import torch._dynamo.test_case
import torch.utils._pytree as python_pytree
from torch._dynamo.exc import ResumePrologueTracingError, Unsupported
from torch._dynamo.testing import skipIfNotPy312, skipIfOnlyNotPy312
from torch._dynamo.testing import skipIfNotPy312
from torch._dynamo.utils import counters
from torch.testing._internal.common_utils import (
IS_FBCODE,
@ -1015,7 +1015,6 @@ Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especiall
"<Internal traceback>\n",
msg,
)
self.assertExpectedInline(
msg,
"""\
@ -1052,6 +1051,7 @@ from user code:
torch.compile(fn, backend="eager")(torch.randn(3))
# check the log for the 2nd torch._dynamo.graph_break()
self.assertExpectedInline(
munge_exc(records[-1].getMessage(), skip=0),
"""\
@ -1075,104 +1075,6 @@ User code traceback:
""",
)
@torch._dynamo.config.patch(verbose=True)
@make_logging_test(graph_breaks=True)
def test_latest_bytecode_to_graph_break_fullgraph(self, records):
def fn(x):
y = x + 1
z = x + y
torch._dynamo.graph_break()
return z
self.assertExpectedInlineMunged(
Unsupported,
lambda: torch.compile(fn, backend="eager", fullgraph=True)(torch.randn(3)),
"""\
Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: None
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html
from user code:
File "test_error_messages.py", line N, in fn
torch._dynamo.graph_break()
""",
)
@skipIfOnlyNotPy312
@torch._dynamo.config.patch(verbose=True)
@make_logging_test(graph_breaks=True)
def test_latest_bytecode_to_graph_break_python_versioning(self, records):
@torch.compile(backend="eager")
def fn(x):
y = x + 1
z = x + y
torch._dynamo.graph_break()
return z
fn(torch.ones(3))
s = munge_exc(records[0].getMessage(), skip=0)
self.assertExpectedInline(
s,
"""\
Graph break in user code at test_error_messages.py:N
Graph Break Reason: Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: None
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html
User code traceback:
File "test_error_messages.py", line N, in test_latest_bytecode_to_graph_break_python_versioning
fn(torch.ones(3))
========== most recent `torch.compile` tracing attempt started here ==========
File "test_error_messages.py", line N, in fn
torch._dynamo.graph_break()
NOTE: the most recent `torch.compile` tracing attempt might not be where you applied `torch.compile`! This is due to how graph breaks are implemented - the optimized code object returned by Dynamo will call another Dynamo-generated resume function and tracing is re-enabled by calling the resume function as a normal Python function, which Dynamo intercepts as a top-level frame.
Most recent bytecode instructions traced (max 20):
TRACE RESUME 0 []
TRACE LOAD_FAST 'x' []
TRACE LOAD_CONST 1 [LazyVariableTracker()]
TRACE BINARY_OP 0 [LazyVariableTracker(), ConstantVariable(int: 1)]
TRACE STORE_FAST 'y' [TensorVariable()]
TRACE LOAD_FAST 'x' []
TRACE LOAD_FAST 'y' [TensorVariable()]
TRACE BINARY_OP 0 [TensorVariable(), TensorVariable()]
TRACE STORE_FAST 'z' [TensorVariable()]
TRACE LOAD_GLOBAL 'torch' []
TRACE LOAD_ATTR '_dynamo' [LazyVariableTracker()]
TRACE LOAD_ATTR 'graph_break' [LazyVariableTracker()]
TRACE CALL 0 [NullVariable, LazyVariableTracker()]""",
)
@torch._dynamo.config.patch(verbose=True)
@make_logging_test(graph_breaks=True)
def test_latest_bytecode_to_graph_break(self, records):
@torch.compile(backend="eager")
def fn(x):
y = x + 1
z = x + y
torch._dynamo.graph_break()
return z
fn(torch.ones(3))
pattern = r"TRACE.*"
s = munge_exc(records[0].getMessage(), skip=0)
matches = re.findall(pattern, s)
self.assertEqual((len(matches) > 10), True)
self.assertEqual((len(matches) <= 20), True)
self.assertIn("Most recent bytecode instructions traced (max 20):", s)
@torch._dynamo.config.patch(verbose=True)
@make_logging_test(graph_breaks=True)
def test_graph_break_traceback_above_dynamo_shows_user_code(self, records):

View File

@ -1,270 +0,0 @@
# Owner(s): ["module: dynamo"]
import torch
import torch._dynamo.test_case
import torch.fx.traceback as fx_traceback
import torch.utils.checkpoint
from torch._dynamo.test_case import run_tests
from torch._dynamo.testing import AotEagerAndRecordGraphs
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
from torch.testing._internal.triton_utils import requires_cuda_and_triton
def checkpoint_wrapper(fn):
def inner(*args):
return torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=True)
return inner
class AnnotateTests(torch._dynamo.test_case.TestCase):
# TODO - should not need this because we should turn this on in Dynamo, but
# for some reason, the test fails.
def setUp(self):
super().setUp()
self.cm = torch.fx.traceback.preserve_node_meta()
self.cm.__enter__()
def tearDown(self):
super().tearDown()
self.cm.__exit__(None, None, None)
def get_custom_metadata(self, gm):
def helper(gm):
custom_metadata = []
for node in gm.graph.nodes:
if hasattr(node, "meta") and node.meta.get("custom", None):
custom_metadata.append((node.op, node.name, node.meta["custom"]))
if node.op == "get_attr" and isinstance(
getattr(gm, node.target), torch.fx.GraphModule
):
custom_metadata.append(helper(getattr(gm, node.target)))
return custom_metadata
return "\n".join(str(x) for x in helper(gm))
def test_annotations(self):
class Mod(torch.nn.Module):
def forward(self, x):
with fx_traceback.annotate({"pp_stage": 0}):
with fx_traceback.annotate({"fdsp_bucket": 0}):
sin = torch.sin(x)
sub = sin - 2
with fx_traceback.annotate({"cuda_stream": 2, "fsdp_bucket": 1}):
mul = sub * 2
div = mul / 3
return div
m = Mod()
backend = AotEagerAndRecordGraphs()
opt_m = torch.compile(m, backend=backend, fullgraph=True)
x = torch.randn(10, requires_grad=True)
opt_m(x).sum().backward()
self.assertEqual(len(backend.fw_graphs), 1)
self.assertEqual(len(backend.bw_graphs), 1)
dynamo_metadata = self.get_custom_metadata(backend.graphs[0])
fw_metadata = self.get_custom_metadata(backend.fw_graphs[0])
bw_metadata = self.get_custom_metadata(backend.bw_graphs[0])
self.assertExpectedInline(
str(dynamo_metadata),
"""\
('placeholder', 'l_x_', {'pp_stage': 0, 'fdsp_bucket': 0})
('call_function', 'sin', {'pp_stage': 0, 'fdsp_bucket': 0})
('call_function', 'sub', {'pp_stage': 0})
('call_function', 'mul', {'pp_stage': 0, 'cuda_stream': 2, 'fsdp_bucket': 1})""", # noqa: B950
)
self.assertExpectedInline(
str(fw_metadata),
"""\
('call_function', 'sin', {'pp_stage': 0, 'fdsp_bucket': 0})
('call_function', 'sub', {'pp_stage': 0})
('call_function', 'mul', {'pp_stage': 0, 'cuda_stream': 2, 'fsdp_bucket': 1})""", # noqa: B950
)
self.assertExpectedInline(
str(bw_metadata),
"""\
('call_function', 'mul_1', {'pp_stage': 0, 'cuda_stream': 2, 'fsdp_bucket': 1})
('call_function', 'cos', {'pp_stage': 0, 'fdsp_bucket': 0})
('call_function', 'mul_2', {'pp_stage': 0, 'fdsp_bucket': 0})""", # noqa: B950
)
def test_activation_checkpointing(self):
@checkpoint_wrapper
def gn(x):
return torch.sin(x)
def fn(x):
with fx_traceback.annotate({"ac_sin": 0}):
ac = gn(x)
return torch.sigmoid(ac)
backend = AotEagerAndRecordGraphs()
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
x = torch.randn(10, requires_grad=True)
opt_fn(x).sum().backward()
self.assertEqual(len(backend.fw_graphs), 1)
self.assertEqual(len(backend.bw_graphs), 1)
dynamo_metadata = self.get_custom_metadata(backend.graphs[0])
fw_metadata = self.get_custom_metadata(backend.fw_graphs[0])
bw_metadata = self.get_custom_metadata(backend.bw_graphs[0])
self.assertExpectedInline(
str(dynamo_metadata),
"""\
('placeholder', 'l_x_', {'ac_sin': 0})
('get_attr', 'wrap_body_0', {'ac_sin': 0})
[('placeholder', 'l_x_', {'ac_sin': 0}), ('call_function', 'sin', {'ac_sin': 0}), ('output', 'output', {'ac_sin': 0})]
('call_function', 'tag_activation_checkpoint', {'ac_sin': 0})
('call_function', 'ac', {'ac_sin': 0})""", # noqa: B950
)
self.assertExpectedInline(
str(fw_metadata),
"""('call_function', 'sin', {'ac_sin': 0})""", # noqa: B950
)
self.assertExpectedInline(
str(bw_metadata),
"""\
('call_function', 'cos', {'ac_sin': 0})
('call_function', 'mul', {'ac_sin': 0})""", # noqa: B950
)
def test_activation_checkpointing_annotation_inside(self):
@checkpoint_wrapper
def gn(x):
x = x + 1
with fx_traceback.annotate({"stage": 0}):
p = torch.sin(x)
return p + 1
def fn(x):
ac = gn(x)
return torch.sigmoid(ac)
backend = AotEagerAndRecordGraphs()
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
x = torch.randn(10, requires_grad=True)
opt_fn(x).sum().backward()
self.assertEqual(len(backend.fw_graphs), 1)
self.assertEqual(len(backend.bw_graphs), 1)
dynamo_metadata = self.get_custom_metadata(backend.graphs[0])
fw_metadata = self.get_custom_metadata(backend.fw_graphs[0])
bw_metadata = self.get_custom_metadata(backend.bw_graphs[0])
self.assertExpectedInline(
str(dynamo_metadata),
"""[('call_function', 'p', {'stage': 0})]""", # noqa: B950
)
self.assertExpectedInline(
str(fw_metadata),
"""('call_function', 'sin', {'stage': 0})""", # noqa: B950
)
self.assertExpectedInline(
str(bw_metadata),
"""\
('call_function', 'cos', {'stage': 0})
('call_function', 'mul', {'stage': 0})""", # noqa: B950
)
@requires_cuda_and_triton
def test_ac_flex_attention(self):
def _squared(score, b, h, m, n):
return score * score
def mask_mod(b, h, q, k):
return q >= 0
a = 12
b = 64
block_mask = create_block_mask(mask_mod, None, None, a * b, a * b)
def gn(x: torch.Tensor):
with fx_traceback.annotate({"compile_inductor": 0}):
return flex_attention(
x, x, x, block_mask=block_mask, score_mod=_squared
)
def fn(x):
x = torch.sin(x)
x = gn(x)
return torch.cos(x)
x = torch.randn(
1,
1,
a * b,
b,
dtype=torch.bfloat16,
device="cuda",
requires_grad=True,
)
backend = AotEagerAndRecordGraphs()
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
opt_fn(x).sum().backward()
self.assertEqual(len(backend.fw_graphs), 1)
self.assertEqual(len(backend.bw_graphs), 1)
dynamo_metadata = self.get_custom_metadata(backend.graphs[0])
fw_metadata = self.get_custom_metadata(backend.fw_graphs[0])
bw_metadata = self.get_custom_metadata(backend.bw_graphs[0])
self.assertExpectedInline(
str(dynamo_metadata),
"""\
('placeholder', 'l_gn_closure_1_cell_contents_kv_indices', {'compile_inductor': 0})
('placeholder', 'l_gn_closure_1_cell_contents_kv_num_blocks', {'compile_inductor': 0})
('placeholder', 'l_gn_closure_1_cell_contents_full_kv_num_blocks', {'compile_inductor': 0})
('placeholder', 'l_gn_closure_1_cell_contents_full_kv_indices', {'compile_inductor': 0})
('placeholder', 'l_gn_closure_1_cell_contents_q_num_blocks', {'compile_inductor': 0})
('placeholder', 'l_gn_closure_1_cell_contents_q_indices', {'compile_inductor': 0})
('placeholder', 'l_gn_closure_1_cell_contents_full_q_num_blocks', {'compile_inductor': 0})
('placeholder', 'l_gn_closure_1_cell_contents_full_q_indices', {'compile_inductor': 0})
('get_attr', 'score_mod_0', {'compile_inductor': 0})
[('placeholder', 'child', {'compile_inductor': 0}), ('placeholder', 'child_1', {'compile_inductor': 0}), ('placeholder', 'child_2', {'compile_inductor': 0}), ('placeholder', 'child_3', {'compile_inductor': 0}), ('placeholder', 'child_4', {'compile_inductor': 0}), ('call_function', 'mul', {'compile_inductor': 0}), ('output', 'output', {'compile_inductor': 0})]
('get_attr', 'mask_fn_0', {'compile_inductor': 0})
[('placeholder', 'child', {'compile_inductor': 0}), ('placeholder', 'child_1', {'compile_inductor': 0}), ('placeholder', 'child_2', {'compile_inductor': 0}), ('placeholder', 'child_3', {'compile_inductor': 0}), ('call_function', 'ge', {'compile_inductor': 0}), ('output', 'output', {'compile_inductor': 0})]
('call_function', 'flex_attention', {'compile_inductor': 0})
('call_function', 'out', {'compile_inductor': 0})""", # noqa: B950
)
self.assertExpectedInline(
str(fw_metadata),
"""\
('get_attr', 'sdpa_score0', {'compile_inductor': 0})
[('placeholder', 'arg0_1', {'compile_inductor': 0}), ('placeholder', 'arg1_1', {'compile_inductor': 0}), ('placeholder', 'arg2_1', {'compile_inductor': 0}), ('placeholder', 'arg3_1', {'compile_inductor': 0}), ('placeholder', 'arg4_1', {'compile_inductor': 0}), ('call_function', 'mul', {'compile_inductor': 0}), ('output', 'output', {'compile_inductor': 0})]
('get_attr', 'sdpa_mask0', {'compile_inductor': 0})
[('placeholder', 'arg0_1', {'compile_inductor': 0}), ('placeholder', 'arg1_1', {'compile_inductor': 0}), ('placeholder', 'arg2_1', {'compile_inductor': 0}), ('placeholder', 'arg3_1', {'compile_inductor': 0}), ('call_function', 'ge', {'compile_inductor': 0}), ('output', 'output', {'compile_inductor': 0})]
('call_function', 'flex_attention', {'compile_inductor': 0})
('call_function', 'getitem', {'compile_inductor': 0})
('call_function', 'getitem_1', {'compile_inductor': 0})
('call_function', 'detach_1', {'compile_inductor': 0})
('call_function', 'detach_4', {'compile_inductor': 0})
('call_function', 'detach_5', {'compile_inductor': 0})""", # noqa: B950
)
self.assertExpectedInline(
str(bw_metadata),
"""\
('placeholder', 'getitem', {'compile_inductor': 0})
('placeholder', 'detach_5', {'compile_inductor': 0})
('call_function', 'zeros', {'compile_inductor': 0})
('call_function', 'detach', {'compile_inductor': 0})
('call_function', 'detach_2', {'compile_inductor': 0})
('call_function', 'detach_3', {'compile_inductor': 0})
('get_attr', 'fw_graph0', {'compile_inductor': 0})
[]
('get_attr', 'joint_graph0', {'compile_inductor': 0})
[]
('get_attr', 'mask_graph0', {'compile_inductor': 0})
[('call_function', 'ge', {'compile_inductor': 0})]
('call_function', 'flex_attention_backward', {'compile_inductor': 0})
('call_function', 'getitem_3', {'compile_inductor': 0})
('call_function', 'getitem_4', {'compile_inductor': 0})
('call_function', 'getitem_5', {'compile_inductor': 0})""", # noqa: B950
)
if __name__ == "__main__":
run_tests()

View File

@ -363,31 +363,6 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(cnts.op_count, 13)
def test_cells_double_graph_break(self):
def f1(x1):
cell1 = x1 + 1
def f2(x2):
nonlocal cell1
cell1 += 2
torch._dynamo.graph_break()
torch._dynamo.graph_break()
return x2 + cell1
return f2(x1 + 4), cell1
def outer(x):
return f1(x)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(backend=cnts)(outer)
x = torch.zeros(3)
res = outer(x)
ref = opt_fn(x)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(cnts.op_count, 4)
def test_side_effects_cells(self):
cell1, cell2, cell3, cell4 = (torch.zeros(3),) * 4
@ -536,7 +511,6 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
self.assertEqual(cnts.frame_count, 5)
# 4 additions from f5+f4, 2 x 4 additions from f2+f1 (i == 5, i != 5)
self.assertEqual(cnts.op_count, 12)
self.assertEqual(torch._dynamo.utils.counters["frames"]["total"], 6)
def test_nested_graph_break_in_try_block(self):
# NOTE: this also tests nested step_graph_break
@ -577,40 +551,13 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
x = torch.zeros(3)
res = f5(x)
ref = opt_fn(x)
print(ref, res)
self.assertEqual(ref, res)
# skip frame due to graph break in try block
# 2 frames from f5+f4+(first part of f3), 2 frames from f2+f1
self.assertEqual(cnts.frame_count, 4)
# 5 additions from f5+f4+(first part of f3), 4 additions from f2+f1
self.assertEqual(cnts.op_count, 9)
self.assertEqual(torch._dynamo.utils.counters["frames"]["total"], 4)
def test_nested_step_unsupported(self):
global f1, f2, f3
def f1(x):
return x + 1
def f2(x):
x = x + 2
torch._dynamo.step_unsupported()
return f1(x) + 4
def f3(x):
x = x + 8
return f2(x) + 16
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
x = torch.zeros(3)
res = f3(x)
ref = opt_fn(x)
self.assertEqual(ref, res)
# 1 frame from start of f3 + start of f2, 1 frame from f1, 1 frame from the end of f3
self.assertEqual(cnts.frame_count, 3)
# all ops except + 4
self.assertEqual(cnts.op_count, 4)
self.assertEqual(torch._dynamo.utils.counters["frames"]["total"], 3)
if __name__ == "__main__":

View File

@ -7256,26 +7256,6 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
flag = False
self.assertEqual(fn(inp), opt_fn(inp))
def test_cells_unsupported_step_exception(self):
# This error happened because:
# - we were generating cells into a list on the stack
# - we encountered an unsupported step, resulting in a step graph break
# - we encountered an exception, which pops the stack until it reaches a certain length;
# the presence of the list of cells then messes things up.
cell = 0
@torch.compile(backend="eager")
def fn(x):
x = x + 1 + 2
torch._dynamo.step_unsupported()
with contextlib.nullcontext():
print(cell)
raise AssertionError
with self.assertRaises(AssertionError):
fn(torch.ones(3))
def test_unbind_copy_out(self):
def f(eye, out):
torch.unbind_copy(eye, out=out)

View File

@ -15660,6 +15660,11 @@ def forward(self, x):
test_serdes=True,
)
@testing.expectedFailureTrainingIRToRunDecomp
@testing.expectedFailureRetraceability
@testing.expectedFailureStrictV2
@testing.expectedFailureStrict # annotation needs to be handled in dynamo
@testing.expectedFailureSerDer
def test_preserve_annotation(self):
class M(torch.nn.Module):
def forward(self, x):

View File

@ -5,7 +5,6 @@ import torch
import torch.fx as fx
from torch._inductor.augmented_graph_helper import AugmentedGraphHelper
from torch.testing._internal.common_utils import TestCase
from torch.utils._ordered_set import OrderedSet
class TestAugmentedGraphHelper(TestCase):
@ -62,29 +61,9 @@ class TestAugmentedGraphHelper(TestCase):
]:
self.nodes[node.name] = node
# Get all nodes and compute ancestors
# Get all nodes and create tracker
self.all_nodes = list(self.graph.nodes)
self.node_ancestors = self._collect_node_ancestors(self.graph)
# Create tracker with ancestors
self.tracker = AugmentedGraphHelper(
self.graph, node_ancestors=self.node_ancestors
)
def _collect_node_ancestors(
self, graph: fx.Graph
) -> dict[fx.Node, OrderedSet[fx.Node]]:
"""Collect all ancestors for each node."""
from collections import defaultdict
from torch.utils._ordered_set import OrderedSet
ancestors: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
for node in graph.nodes:
for input_node in node.all_input_nodes:
ancestors[node].add(input_node)
ancestors[node] |= ancestors[input_node]
return ancestors
self.tracker = AugmentedGraphHelper(self.graph)
def get_deps(self, node):
"""Helper to get dependencies for a node."""

View File

@ -1,173 +0,0 @@
# Owner(s): ["module: inductor"]
import functools
import weakref
from collections import Counter
from typing import Callable, Optional
import torch
from torch._inductor.fx_passes.memory_estimator import build_memory_profile
from torch._inductor.test_case import run_tests, TestCase as InductorTestCase
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map_only
from torch.utils.weak import WeakIdKeyDictionary
def tensor_storage_id(tensor):
return tensor._typed_storage()._cdata
def device_filter(device):
return device.type == "cuda"
class FakeTensorMemoryProfilerMode(TorchDispatchMode):
def __init__(self, device_filter: Optional[Callable[torch.device, bool]] = None):
# counter of storage ids to live references
self.storage_count: dict[int, int] = Counter()
# live fake tensors
self.live_tensors = WeakIdKeyDictionary()
self.memory_use = 0
self.max_memory = 0
self.device_filter = device_filter
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
kwargs = kwargs if kwargs is not None else {}
rs = func(*args, **kwargs)
tree_map_only(torch._subclasses.FakeTensor, self.increase_memory_use, rs)
return rs
def increase_memory_use(self, tensor):
# already accounted for
if tensor in self.live_tensors:
return
if self.device_filter is not None and not self.device_filter(tensor.device):
return
self.live_tensors[tensor] = True
nbytes = tensor.untyped_storage().nbytes()
storage_id = tensor_storage_id(tensor)
# new storage, add to memory
if storage_id not in self.storage_count:
self.change_memory(nbytes)
self.storage_count[storage_id] += 1
# when this tensor dies, we need to adjust memory
weakref.finalize(
tensor, functools.partial(self.tensor_cleanup, storage_id, nbytes)
)
def tensor_cleanup(self, storage_id, nbytes):
self.storage_count[storage_id] -= 1
if self.storage_count[storage_id] == 0:
del self.storage_count[storage_id]
self.change_memory(-nbytes)
def change_memory(self, delta):
self.memory_use += delta
self.max_memory = max(self.memory_use, self.max_memory)
class TestMemoryProfilingResNet(InductorTestCase):
def test_simple_linear_layers(self):
"""Test with a simple sequential model with explicit weights on CUDA."""
def create_inputs_and_weights():
"""Create inputs and weights on CUDA."""
x = torch.randn(32, 1000, device="cuda")
w1 = torch.randn(500, 1000, device="cuda")
w2 = torch.randn(100, 500, device="cuda")
w3 = torch.randn(10, 100, device="cuda")
return x, w1, w2, w3
def fn(x, w1, w2, w3):
h1 = torch.nn.functional.linear(x, w1)
h1 = torch.nn.functional.relu(h1)
h2 = torch.nn.functional.linear(h1, w2)
h2 = torch.nn.functional.relu(h2)
out = torch.nn.functional.linear(h2, w3)
return out
with FakeTensorMode():
# Trace with make_fx
x, w1, w2, w3 = create_inputs_and_weights()
fx_graph = make_fx(fn)(x, w1, w2, w3)
# Static analysis
def is_releasable(node):
return node.op not in ("placeholder", "get_attr")
fx_memory_profile = build_memory_profile(fx_graph.graph, is_releasable)
fx_peak = max(fx_memory_profile)
# Runtime profiling
profiler = FakeTensorMemoryProfilerMode()
with profiler:
x_runtime, w1_runtime, w2_runtime, w3_runtime = (
create_inputs_and_weights()
)
result = fn(x_runtime, w1_runtime, w2_runtime, w3_runtime)
del result
runtime_peak = profiler.max_memory
self.assertEqual(fx_peak, runtime_peak)
def test_conv_network(self):
"""Test with a convolutional network."""
def create_inputs_and_weights():
"""Create inputs and weights on CUDA."""
x = torch.randn(8, 3, 224, 224, device="cuda")
conv1_weight = torch.randn(64, 3, 3, 3, device="cuda")
conv2_weight = torch.randn(128, 64, 3, 3, device="cuda")
linear_weight = torch.randn(10, 128 * 56 * 56, device="cuda")
return x, conv1_weight, conv2_weight, linear_weight
def fn(x, conv1_weight, conv2_weight, linear_weight):
h = torch.nn.functional.conv2d(x, conv1_weight, padding=1)
h = torch.nn.functional.relu(h)
h = torch.nn.functional.max_pool2d(h, 2)
h = torch.nn.functional.conv2d(h, conv2_weight, padding=1)
h = torch.nn.functional.relu(h)
h = torch.nn.functional.max_pool2d(h, 2)
h = torch.flatten(h, 1)
out = torch.nn.functional.linear(h, linear_weight)
return out
with FakeTensorMode():
# Trace with make_fx
x, conv1_weight, conv2_weight, linear_weight = create_inputs_and_weights()
fx_graph = make_fx(fn)(x, conv1_weight, conv2_weight, linear_weight)
def is_releasable(node):
return node.op not in ("placeholder", "get_attr")
fx_memory_profile = build_memory_profile(fx_graph.graph, is_releasable)
fx_peak = max(fx_memory_profile)
# Runtime profiling
profiler = FakeTensorMemoryProfilerMode()
with profiler:
x_runtime, conv1_w, conv2_w, linear_w = create_inputs_and_weights()
result = fn(x_runtime, conv1_w, conv2_w, linear_w)
del result
runtime_peak = profiler.max_memory
self.assertEqual(fx_peak, runtime_peak)
if __name__ == "__main__":
if IS_LINUX and HAS_CUDA_AND_TRITON:
run_tests(needs="filelock")

View File

@ -22,8 +22,7 @@ from torch.testing._internal.common_cuda import \
(SM53OrLater, SM80OrLater, TEST_MULTIGPU)
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, ops, dtypes, dtypesIfCUDA, dtypesIfMPS, onlyCPU, onlyCUDA, precisionOverride,
deviceCountAtLeast, OpDTypes, onlyNativeDeviceTypes, skipCUDAIf, expectedFailureMPS,
expectedFailureMPSComplex, largeTensorTest)
deviceCountAtLeast, OpDTypes, onlyNativeDeviceTypes, skipCUDAIf, expectedFailureMPS, largeTensorTest)
from torch.testing._internal.common_methods_invocations import \
(op_db, reduction_ops, sparse_unary_ufuncs, sparse_masked_reduction_ops, binary_ufuncs)
from torch.testing._internal.common_dtype import (
@ -1854,7 +1853,7 @@ class TestSparse(TestSparseBase):
self.assertEqual(res_fp32, res_bf16, atol=1e-2, rtol=0)
@coalescedonoff
@expectedFailureMPSComplex
@expectedFailureMPS
@dtypes(torch.double, torch.cdouble)
@dtypesIfMPS(torch.float32, torch.complex64)
def test_norm(self, device, dtype, coalesced):

View File

@ -40,7 +40,6 @@ from .decorators import (
run,
set_stance,
skip_frame,
step_unsupported,
substitute_in_graph,
)
from .eval_frame import (
@ -103,7 +102,6 @@ __all__ = [
"error_on_graph_break",
"set_stance",
"skip_frame",
"step_unsupported",
"substitute_in_graph",
]

View File

@ -397,37 +397,19 @@ def create_call_function(nargs: int, push_null: bool) -> list[Instruction]:
return [create_instruction("CALL_FUNCTION", arg=nargs)]
def create_call_function_ex(
has_kwargs: bool, push_null: bool, ignore_314_kwargs_push: bool = False
) -> list[Instruction]:
def create_call_function_ex(has_kwargs: bool) -> list[Instruction]:
"""
Assumes that in 3.14+, if has_kwargs=False, there is NOT a NULL
on the TOS for the kwargs. This utility function will add a PUSH_NULL.
If the caller has already pushed a NULL for the kwargs, then set ignore_314_kwargs_push=True
so we don't push another NULL for the kwargs.
If the caller has already pushed a NULL, then do not call this function -
just use create_instruction("CALL_FUNCTION_EX", arg=...).
"""
if sys.version_info >= (3, 11):
output = []
if (
sys.version_info >= (3, 14)
and not has_kwargs
and not ignore_314_kwargs_push
):
output.append(create_instruction("PUSH_NULL"))
if push_null:
output.append(create_instruction("PUSH_NULL"))
# 3.13 swapped NULL and callable
# if flags == 1, 2 values popped - otherwise if flags == 0, 1 value
rots = (
int(has_kwargs) + 2
if sys.version_info >= (3, 13)
else int(has_kwargs) + 3
)
output.extend(create_rot_n(rots))
output.append(create_instruction("CALL_FUNCTION_EX", arg=int(has_kwargs)))
return output
return [create_instruction("CALL_FUNCTION_EX", arg=int(has_kwargs))]
insts = []
if sys.version_info >= (3, 14) and not has_kwargs:
insts.append(create_instruction("PUSH_NULL"))
insts.append(create_instruction("CALL_FUNCTION_EX", arg=int(has_kwargs)))
return insts
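To make the simplified helper above concrete, this is what it emits (derived from the new implementation in this hunk):

    # Python 3.14+ with has_kwargs=False: CALL_FUNCTION_EX still expects a kwargs
    # slot on the stack, so the helper pushes the missing NULL itself.
    create_call_function_ex(has_kwargs=False)  # -> [PUSH_NULL, CALL_FUNCTION_EX(arg=0)]
    # Every other case (older Python, or has_kwargs=True):
    create_call_function_ex(has_kwargs=True)   # -> [CALL_FUNCTION_EX(arg=1)]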
def create_call_method(nargs: int) -> list[Instruction]:
@ -533,8 +515,6 @@ def create_binary_slice(
def create_copy(i: int) -> list[Instruction]:
if sys.version_info >= (3, 11):
return [create_instruction("COPY", arg=i)]
if i == 1:
return [create_instruction("DUP_TOP")]
# COPY 4
# 0 1 2 3
# 3 1 2 0

View File

@ -519,7 +519,7 @@ class PyCodegen:
create_build_tuple(n),
self.create_load_const_unchecked(rot_n_helper(n)),
*create_rot_n(2),
*create_call_function_ex(False, False),
*create_call_function_ex(False),
create_instruction("UNPACK_SEQUENCE", arg=n),
]
@ -540,33 +540,51 @@ class PyCodegen:
def make_function_with_closure(
self,
tx: "InstructionTranslatorBase",
fn_name: str,
code: types.CodeType,
push_null: bool,
num_on_stack: int = 0,
) -> None:
"""Creates a closure with code object `code`.
Expects the TOS to be the tuple of cells to use for this closure.
TOS will be popped to create the closure.
Args:
- fn_name: name of the function
- code: code object of the function
(does not include the tuple of cells on the TOS)
"""
freevars = code.co_freevars
assert freevars
output = self._output
output.append(self.create_load_const(code))
if sys.version_info < (3, 11):
output.append(self.create_load_const(fn_name))
if sys.version_info >= (3, 13):
output.extend(
[
create_instruction("MAKE_FUNCTION"),
create_instruction("SET_FUNCTION_ATTRIBUTE", arg=0x08),
]
)
else:
output.append(create_instruction("MAKE_FUNCTION", arg=0x08))
def gen_fn() -> None:
self.clear_tos()
# Emitting `LOAD_FAST/LOAD_CLOSURE` with names in `co_freevars`
# requires that in the generated bytecode, these cells would keep
# their original local names, which we ensure via
# `CellVariable.local_name`.
for var in freevars:
if tx is self.tx: # root frame
assert var in self.cell_and_freevars()
output.append(self.create_load_closure(var))
else: # nested frame
assert var in tx.cell_and_freevars()
assert tx.post_prune_cell_and_freevars
self(tx.post_prune_cell_and_freevars[var])
output.append(create_build_tuple(len(freevars)))
output.append(self.create_load_const(code))
if sys.version_info < (3, 11):
output.append(self.create_load_const(fn_name))
if sys.version_info >= (3, 13):
output.extend(
[
create_instruction("MAKE_FUNCTION"),
create_instruction("SET_FUNCTION_ATTRIBUTE", arg=0x08),
]
)
else:
output.append(create_instruction("MAKE_FUNCTION", arg=0x08))
if push_null and sys.version_info >= (3, 11):
self.add_push_null(gen_fn)
output.extend(self.rot_n(num_on_stack + 2))
output.extend(self.rot_n(num_on_stack + 2))
else:
gen_fn()
output.extend(self.rot_n(num_on_stack + 1))
self.clear_tos()
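As a reading aid for the rewritten method above: it assumes the tuple of cells is already on top of the stack and, per the code in this hunk, emits roughly the following sequence on 3.13+ (older versions collapse this into a single MAKE_FUNCTION with arg=0x08, preceded by a LOAD_CONST of the function name before 3.11):

    # stack before: ..., (cell_1, ..., cell_k)
    LOAD_CONST <code>
    MAKE_FUNCTION
    SET_FUNCTION_ATTRIBUTE 0x08  # attaches the cell tuple as the function's closure
    # stack after:  ..., <function object>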
def create_load_python_module(self, mod: types.ModuleType) -> Instruction:

View File

@ -750,9 +750,6 @@ def register_bytecode_hook(hook: BytecodeHook) -> RemovableHandle:
return handle
# TODO - We want to run preserve_node_meta context manager here, but the CI
# fails (it's unclear if the failures were flaky)
# @torch.fx.traceback.preserve_node_meta()
@preserve_global_state
def trace_frame(
code: types.CodeType,

View File

@ -296,14 +296,6 @@ def skip_frame(msg: str = "") -> None:
"""Force a skipped frame"""
@_disallow_in_graph_helper(throw_if_not_allowed=False)
def step_unsupported(msg: str = "") -> None:
"""Force a step unsupported graph break, which results in compiling
the traced FX graph so far, then skipping the rest of the frame.
In order to get expected behavior, there should be at least 2 ops
and a part of the code not contained in any try/with blocks."""
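(The function body is intentionally empty; Dynamo intercepts the call while tracing. The tests removed earlier in this diff exercised it roughly as follows:)

    @torch.compile(backend="eager")
    def fn(x):
        x = x + 1 + 2                      # at least 2 ops before the break
        torch._dynamo.step_unsupported()   # compile the graph traced so far
        return x + 4                       # the rest of the frame runs eagerly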
def forbid_in_graph(fn: Any) -> Any:
"""
Customize which functions TorchDynamo will assert are not present while tracing.

View File

@ -263,11 +263,6 @@ class RecompileLimitExceeded(Unsupported):
pass
# debug exception thrown when tracing torch._dynamo.step_unsupported()
class StepUnsupported(TorchDynamoException):
pass
class UnsafeScriptObjectError(TorchDynamoException):
pass

View File

@ -2763,18 +2763,5 @@
"This is likely to be a Dynamo bug. Please report an issue to PyTorch."
]
}
],
"GB0275": [
{
"Gb_type": "torch._dynamo.step_unsupported() with empty checkpoint",
"Context": "",
"Explanation": "traced torch._dynamo.step_unsupported(), but there is no checkpoint to step_graph_break from. This graph break is used for debugging only.",
"Hints": [
"Remove the torch._dynamo.step_unsupported() call.",
"Include at least one checkpoint: (1) include at least 2 ops and (2) make sure there is some ",
"line of code that is not in a try/with block, and has an empty Python stack.",
"This is likely to be a Dynamo bug. Please report an issue to PyTorch."
]
}
]
}

View File

@ -79,7 +79,6 @@ from .backends.registry import CompiledFn, CompilerFn
from .bytecode_transformation import (
create_binary_slice,
create_binary_subscr,
create_build_tuple,
create_call_function,
create_dup_top,
create_instruction,
@ -1535,9 +1534,8 @@ class OutputGraph(OutputGraphCommon):
# Codegen stack convention before the unsupported instruction
# NOTE: in these comment blocks, "locals" EXCLUDE free and cell vars.
# NOTE: stack/locals/cells must be codegen'd BEFORE the unsupported instruction, since the latter
# NOTE: stack and locals must be codegen'd BEFORE the unsupported instruction, since the latter
# can arbitrarily mutate the former.
# [frame N cells, .., frame 1 cells],
# [
# frame N locals,
# frame N-1 stack + locals,
@ -1547,7 +1545,7 @@ class OutputGraph(OutputGraphCommon):
# see symbolic_convert.py for
# codegen stack convention after the unsupported instruction
# NOTE: cells will be loaded into continuation functions directly by symbolic_convert
# NOTE: cells are loaded into continuation functions directly
# this determines the order that values are codegen'd to the stack
stack_values_flat = [val for vals in all_stack_values for val in vals]
@ -1579,19 +1577,12 @@ class OutputGraph(OutputGraphCommon):
and not all_stack_locals_metas[-1].locals_null_keys
):
# optimization to generate better code in a common case
# codegen cells
# no side effects, so no new cells created - no need to call side_effects.codegen_save_tempvars
cell_cg = PyCodegen(self.root_tx)
self.codegen_cells(tx, cell_cg)
self.add_output_instructions(
[
# load in reverse since UNPACK_SEQUENCE will reverse
*self.compile_and_call_fx_graph(
tx, list(reversed(stack_values_flat)), root
),
*cell_cg.get_instructions(),
*create_swap(2),
create_instruction("UNPACK_SEQUENCE", arg=len(stack_values_flat)),
]
)
@ -1693,7 +1684,6 @@ class OutputGraph(OutputGraphCommon):
# store all stack and locals for each frame
# current state of the stack:
# all cells,
# *(frame N stack), *(frame N locals),
# ...,
# *(frame 1 stack), *(frame 1 locals)
@ -1708,7 +1698,6 @@ class OutputGraph(OutputGraphCommon):
)
# current state of the stack:
# all cells,
# *(frame N stack), [
# *(frame N locals),
# *(frame N-1 stack), *(frame N-1 locals),
@ -1769,8 +1758,7 @@ class OutputGraph(OutputGraphCommon):
# *(frame N stack), metas[0] stack + locals, ..., metas[i] stack + locals, stack_values_flat
# current state of the stack:
# all cells,
# *(frame N stack),
# *(frame N stack)
# frame N locals,
# frame N-1 stack, frame N-1 locals,
# ...
@ -1787,7 +1775,6 @@ class OutputGraph(OutputGraphCommon):
)
# final state of the stack before running the unsupported bytecode:
# all cells,
# [
# [frame N locals],
# [frame N-1 stack + locals],
@ -1844,31 +1831,6 @@ class OutputGraph(OutputGraphCommon):
return all_stack_locals_metas
def codegen_cells(self, tx: "InstructionTranslatorBase", cg: PyCodegen) -> None:
# no need to codegen if reason.graph_break is False (since we won't resume)
if self.compile_subgraph_reason.graph_break:
tx_cnt = 0
cur_tx: Optional[InstructionTranslatorBase] = tx
while cur_tx is not None:
# NOTE: we generate cells in the same order as resume_execution.py: sorted freevars + cellvars
# Emitting `LOAD_FAST/LOAD_CLOSURE` with names in `co_freevars`
# requires that in the generated bytecode, these cells would keep
# their original local names, which we ensure via
# `CellVariable.local_name`.
freevars = tuple(sorted(cur_tx.cell_and_freevars()))
for cell in freevars:
if cur_tx is self.root_tx: # root frame
cg.append_output(cg.create_load_closure(cell))
else: # nested frame
assert cur_tx.post_prune_cell_and_freevars
cg(cur_tx.post_prune_cell_and_freevars[cell])
cg.append_output(create_build_tuple(len(freevars)))
cur_tx = cur_tx.parent
tx_cnt += 1
cg.append_output(create_instruction("BUILD_LIST", arg=tx_cnt))
else:
cg.append_output(create_instruction("BUILD_LIST", arg=0))
def codegen_suffix(
self,
tx: "InstructionTranslatorBase",
@ -1888,7 +1850,6 @@ class OutputGraph(OutputGraphCommon):
cg.store_attr(name)
self.side_effects.codegen_hooks(cg)
# TODO get debug_locals working for nested graph breaks
# Return variables used for logging at the end
for debug_var, args in tx.debug_locals:
cg.add_push_null(lambda: cg(debug_var))
@ -1897,9 +1858,6 @@ class OutputGraph(OutputGraphCommon):
cg.extend_output(create_call_function(len(args), False))
cg.extend_output([create_instruction("POP_TOP")])
# codegen cells before we apply side effects
self.codegen_cells(tx, cg)
cg.restore_stack(stack_values, value_from_source=not tx.export)
self.side_effects.codegen_update_mutated(cg)

View File

@ -318,7 +318,6 @@ class ContinueExecutionCache:
argnames: tuple[str, ...],
argnames_null: tuple[str, ...],
setup_fns: tuple[ReenterWith, ...],
handle_inactive_ctx: bool,
stack_ctx_vars: tuple[tuple[int, tuple[Any, ...]], ...],
argnames_ctx_vars: tuple[tuple[str, tuple[Any, ...]], ...],
null_idxes: tuple[int, ...],
@ -342,7 +341,6 @@ class ContinueExecutionCache:
argnames,
argnames_null,
setup_fns,
handle_inactive_ctx,
stack_ctx_vars,
argnames_ctx_vars,
null_idxes,
@ -434,7 +432,7 @@ class ContinueExecutionCache:
prefix.append(
create_instruction("LOAD_FAST", argval=f"___stack{stack_i}")
)
if handle_inactive_ctx and stack_i in stack_ctx_vars_d:
if stack_i in stack_ctx_vars_d:
# NOTE: we assume that current stack var is a context manager CLASS!
# Load args for context variable and construct it
prefix.extend(_load_tuple_and_call(stack_ctx_vars_d[stack_i]))
@ -461,11 +459,10 @@ class ContinueExecutionCache:
# NOTE: we assume that local var is a context manager CLASS!
# initialize inactive context vars in argnames
if handle_inactive_ctx:
for name, vals in argnames_ctx_vars:
prefix.append(create_instruction("LOAD_FAST", argval=name))
prefix.extend(_load_tuple_and_call(vals))
prefix.append(create_instruction("STORE_FAST", argval=name))
for name, vals in argnames_ctx_vars:
prefix.append(create_instruction("LOAD_FAST", argval=name))
prefix.extend(_load_tuple_and_call(vals))
prefix.append(create_instruction("STORE_FAST", argval=name))
# 3.12+: store NULL into variables that were NULL
if argnames_null:
@ -527,7 +524,7 @@ class ContinueExecutionCache:
"STORE_FAST", argval=IS_TRACING_RESUME_PROLOGUE_VARNAME
),
# finish the call
*create_call_function_ex(False, False),
*create_call_function_ex(False),
]
)
else:

View File

@ -43,7 +43,6 @@ import threading
import traceback
import types
import weakref
from collections import deque
from traceback import StackSummary
from typing import Any, Callable, cast, NoReturn, Optional, TYPE_CHECKING, Union
from typing_extensions import TypeAlias, TypeIs
@ -80,7 +79,6 @@ from .bytecode_transformation import (
create_dup_top,
create_instruction,
create_jump_absolute,
create_load_const,
create_rot_n,
create_swap,
get_code_keys,
@ -98,7 +96,6 @@ from .exc import (
format_graph_break_message,
get_stack_above_dynamo,
ResumePrologueTracingError,
StepUnsupported,
unimplemented_v2,
Unsupported,
)
@ -545,7 +542,6 @@ def log_graph_break(
reason: str = "",
exc_info: bool = False,
user_stack: Optional[StackSummary] = None,
latest_bytecode_log: Optional[str] = None,
) -> None:
if user_stack is None:
user_stack = torch._guards.TracingContext.extract_stack()
@ -608,10 +604,6 @@ def log_graph_break(
# This log line MUST contain the string "Graph break in user code",
# This log line is exercised from
# python test/dynamo/test_exc.py -k test_graph_break_log
if latest_bytecode_log and config.verbose:
user_stack_trace += "Most recent bytecode instructions traced (max 20):\n"
user_stack_trace += latest_bytecode_log
graph_break_log.debug(
user_stack_trace,
)
@ -677,20 +669,14 @@ def generic_jump(
)
self.pop()
if_next = self.codegen_fix_leaf_stack(
all_stack_locals_metadata[0], self.next_instruction
) + self.create_call_resume_at(
self.next_instruction,
all_stack_locals_metadata,
if_next = self.create_call_resume_at(
self.next_instruction, all_stack_locals_metadata, False
)
if push:
self.push(value)
assert inst.target is not None
if_jump = self.codegen_fix_leaf_stack(
all_stack_locals_metadata[0], inst.target
) + self.create_call_resume_at(
inst.target,
all_stack_locals_metadata,
if_jump = self.create_call_resume_at(
inst.target, all_stack_locals_metadata, False
)
if sys.version_info >= (3, 13):
@ -939,7 +925,6 @@ def break_graph_if_unsupported(
exc_info=True,
reason=str(excp),
user_stack=excp.real_stack,
latest_bytecode_log="\n".join(self.latest_bytecode_queue),
)
if self.maybe_has_backedge():
@ -975,7 +960,7 @@ def break_graph_if_unsupported(
all_stack_locals_metadata = self.output.compile_subgraph(
self, reason=reason, stack_pops=push - stack_effect
)
cg = PyCodegen(self.output.root_tx)
cg = PyCodegen(self)
cleanup: list[Instruction] = []
# Reconstruct the context variable CLASS in the block stack
for b in self.block_stack:
@ -1024,12 +1009,8 @@ def break_graph_if_unsupported(
for _ in range(push):
self.push(UnknownVariable())
self.output.add_output_instructions(
self.codegen_fix_leaf_stack(
all_stack_locals_metadata[0], self.next_instruction
)
+ self.create_call_resume_at(
self.next_instruction,
all_stack_locals_metadata,
self.create_call_resume_at(
self.next_instruction, all_stack_locals_metadata, False
)
)
@ -1191,8 +1172,6 @@ class InstructionTranslatorBase(
parent: Optional[InstructionTranslatorBase]
debug_locals: list[tuple[VariableTracker, list[VariableTracker]]]
package: Optional[CompilePackage]
latest_bytecode_queue: deque[str]
# Stores the latest bytecode traced before a user graph_break() call
def mark_inconsistent_side_effects(self) -> None:
"""
@ -1360,17 +1339,6 @@ class InstructionTranslatorBase(
"TRACE %s %s %s", inst.opname, inst.argval, self.stack
)
# Store the latest 20 traced bytecode instructions.
# repr() is used for the stack, with the length limited to 2048.
try:
stack_repr = repr(self.stack)
except ValueError:
# Handle large integers that exceed sys.int_info.str_digits_check_threshold
stack_repr = "<self.stack repr truncated due to large integer>"
self.latest_bytecode_queue.append(
f"TRACE {inst.opname} {repr(inst.argval)} {stack_repr}"
)
self.update_block_stack(inst)
try:
@ -1383,22 +1351,9 @@ class InstructionTranslatorBase(
return True
except (ReturnValueOp, YieldValueOp):
return False
except (Unsupported, StepUnsupported) as e:
except Unsupported:
if self.current_speculation is None:
log.debug("empty checkpoint")
if isinstance(e, StepUnsupported):
unimplemented_v2(
gb_type="torch._dynamo.step_unsupported() with empty checkpoint",
context="",
explanation="traced torch._dynamo.step_unsupported(), but there is no checkpoint "
"to step_graph_break from. This graph break is used for debugging only.",
hints=[
"Remove the torch._dynamo.step_unsupported() call.",
"Include at least one checkpoint: (1) include at least 2 ops and (2) make sure there is some "
"line of code that is not in a try/with block, and has an empty Python stack.",
*graph_break_hints.DYNAMO_BUG,
],
)
raise
log.debug("step triggered compile", exc_info=True)
@ -1472,110 +1427,24 @@ class InstructionTranslatorBase(
partial_convert=True,
reason=GraphCompileReason("step_unsupported", [self.frame_summary()]),
)
# current frame state
# cells,
# [
# frame N locals,
# frame N-1 stack + locals,
# ...,
# frame 1 stack + locals,
# ],
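# Illustrative example of this state with two frames (hypothetical values):
#   cells:        [[cell_a], [cell_b]]    # frame N cells, frame N-1 cells
#   frame values: [
#       [x, y],          # frame N locals (its stack is assumed empty here)
#       [tmp, a, b],     # frame N-1 stack + locals
#   ]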
if self.parent:
from .eval_frame import skip_code
# nested graph break
assert config.nested_graph_breaks
cg = PyCodegen(self.output.root_tx)
# codegen cells and frame values only for frame N
cg.extend_output(
[
*create_copy(2),
cg.create_load_const(0),
cg.create_binary_subscr(),
create_instruction("BUILD_LIST", arg=1),
*create_copy(2),
cg.create_load_const(0),
cg.create_binary_subscr(),
create_instruction("BUILD_LIST", arg=1),
]
)
# No need to fix stack, since stack is assumed to be empty here.
# Do NOT handle_inactive_ctx because we will be skipping this resume code.
leaf_resume_code, leaf_resume_name = self.create_resume(
0, continue_inst, all_stack_locals_metadata[0], [], cg, True, False
)
skip_code(leaf_resume_code)
# current frame state
# cells,
# [
# frame N locals,
# frame N-1 stack + locals,
# ...,
# frame 1 stack + locals,
# ], [frame N cells], [frame N locals],
self.codegen_call_resume([leaf_resume_code], [leaf_resume_name], cg)
# current frame state
# cells,
# [
# frame N locals,
# frame N-1 stack + locals,
# ...,
# frame 1 stack + locals,
# ], leaf_resume result
# add the leaf_resume result to frame N-1 stack
num_stack = all_stack_locals_metadata[1].num_stack
cg.extend_output(
[
create_instruction("BUILD_LIST", arg=1),
*create_copy(2),
cg.create_load_const(1),
cg.create_binary_subscr(),
*create_binary_slice(num_stack, num_stack, True),
]
)
# pop frame N cells and locals
cg.extend_output(
[
*create_copy(1),
cg.create_load_const(0),
create_instruction("DELETE_SUBSCR"),
*create_copy(2),
cg.create_load_const(0),
create_instruction("DELETE_SUBSCR"),
]
)
# call the remaining resume functions
# current frame state
# [frame N-1 cells, ..., frame 1 cells],
# [
# frame N-1 stack (including leaf_resume result) + locals,
# ...,
# frame 1 stack + locals,
# ],
self.parent.push(UnknownVariable())
all_stack_locals_metadata[1].num_stack += 1
self.output.add_output_instructions(
cg.get_instructions()
+ self.parent.create_call_resume_at(
self.parent.next_instruction, all_stack_locals_metadata[1:]
self.create_call_resume_at(
continue_inst, all_stack_locals_metadata, True
)
)
else:
# pop cells
self.output.add_output_instructions(
[
*create_swap(2),
create_instruction("POP_TOP"),
]
)
# load locals from frame values
cg = PyCodegen(self.output.root_tx)
# current frame state
# [
# frame N locals,
# frame N-1 stack + locals,
# ...,
# frame 1 stack + locals,
# ],
cg = PyCodegen(self)
self.output.add_output_instructions(
[
cg.create_load_const(-1),
@ -2640,12 +2509,8 @@ class InstructionTranslatorBase(
self.output.add_output_instructions([copy.copy(inst)])
self.popn(2)
self.output.add_output_instructions(
self.codegen_fix_leaf_stack(
all_stack_locals_metadata[0], self.next_instruction
)
+ self.create_call_resume_at(
self.next_instruction,
all_stack_locals_metadata,
self.create_call_resume_at(
self.next_instruction, all_stack_locals_metadata, False
)
)
@ -2658,292 +2523,48 @@ class InstructionTranslatorBase(
)
@staticmethod
def codegen_return_with_pops(
inst: Instruction, num_stack: int
def codegen_return_after_compile_subgraph(
inst: Instruction, meta: StackLocalsMetadata
) -> list[Instruction]:
"""
Debug CPython expects the stack to be empty after the return.
Calling compile_subgraph will push cells and frame values to TOS.
This function will pop those 2 values from the stack before actually returning.
Expects the stack to be:
cells, frame values, current frame stack (0 or 1 values)
Pops cells and frame values, leaving the current frame stack as TOS.
A return instruction is included.
"""
insts = []
# NOTE: Debug CPython expects the stack to be empty after the return.
# Expect the current stack to be in the state
# cells, frame values, current frame stack (0 or 1 values)
assert num_stack <= 1
if num_stack == 1:
insts.extend(create_swap(3))
# [[]] (empty frame values), current frame stack (0 or 1 values)
assert meta.num_stack <= 1
if meta.num_stack == 1:
insts.extend(create_swap(2))
return_inst = (
create_instruction("RETURN_VALUE")
if inst.opname == "RETURN_VALUE"
else create_instruction("RETURN_CONST", argval=inst.argval)
)
insts.extend(
[create_instruction("POP_TOP"), create_instruction("POP_TOP"), return_inst]
)
insts.extend([create_instruction("POP_TOP"), return_inst])
return insts
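As a rough plain-Python model (not the real bytecode emitter), the pop-then-return sequence produced here behaves like the following, assuming the stack state described in the docstring:

# Stack modeled bottom-to-top as a Python list:
#   [frame_values] or [frame_values, return_value]
def model_return(stack: list, return_const=None):
    assert len(stack) <= 2
    if len(stack) == 2:
        stack[-1], stack[-2] = stack[-2], stack[-1]   # SWAP(2)
    stack.pop()                                        # POP_TOP drops frame values
    return stack.pop() if stack else return_const      # RETURN_VALUE / RETURN_CONST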
def codegen_fix_leaf_stack(
self, meta: StackLocalsMetadata, resume_inst: Instruction
def create_call_resume_at(
self,
inst: Instruction,
all_stack_locals_metadata: Any,
disable_current_frame_resume: bool,
) -> list[Instruction]:
"""
Fixes the stack values of the current/leaf frame (self).
Codegen resume function(s) and call them.
Assumes that the unsupported instruction has already been run.
Expects the TOS to be:
Expects the stack to be in the state:
[
frame N locals,
frame N-1 stack + locals,
...,
frame 1 stack + locals
], *(frame N stack (post-unsupported instruction))
Rearranges the TOS to become:
[
frame N stack + locals,
...,
frame 1 stack + locals
]
Args:
- meta: metadata for the leaf frame returned from OutputGraph.compile_subgraph
- resume_inst: if the resume instruction is a return instruction, then don't return any instructions
"""
if resume_inst.opname in ("RETURN_VALUE", "RETURN_CONST"):
return []
# move frame N stack to the frame values list
current_num_stack = len(self.stack) - len(meta.stack_null_idxes)
meta.num_stack = current_num_stack
return [
create_instruction("BUILD_LIST", arg=current_num_stack),
*create_copy(2),
# frame_values, frame N stack, frame_values
create_load_const(0),
create_instruction("BINARY_SUBSCR"),
*create_binary_slice(0, 0, True),
# frame_values[0][0:0] = frame N stack
# frame_values left on top of stack
]
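A list-based model of the splice these instructions perform, with frame_values and leaf_stack as hypothetical stand-ins for the runtime values:

def model_fix_leaf_stack(frame_values: list, leaf_stack: list) -> None:
    # Same effect as the emitted BUILD_LIST / BINARY_SUBSCR / slice-store:
    # prepend the leaf frame's post-instruction stack to its locals entry.
    frame_values[0][0:0] = leaf_stack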
def create_resume(
self,
idx: int,
resume_inst: Instruction,
meta: StackLocalsMetadata,
resume_codes: list[types.CodeType],
cg: PyCodegen,
is_leaf: bool,
handle_inactive_ctx: bool,
) -> tuple[types.CodeType, str]:
"""
Creates the resume function for the frame corresponding to `self`.
Expects the TOS to be:
[frame N cells, ..., frame 1 cells],
[
frame N stack + locals,
...,
frame 1 stack + locals
]
Some additional codegen may happen to prepare the frame stack + locals values for the generated resume function:
- inactive context variables in the stack and locals will be replaced by their types
- if the frame is a leaf frame, prune dead locals
Regardless of codegen, the stack will be left in the same state as before.
Args:
- idx: depth of this frame: 0 corresponds to the leaf frame (frame N), N-1 to the root frame (frame 1).
- resume_inst: the instruction that this frame should resume at
- meta: metadata for this frame returned from OutputGraph.compile_subgraph
- resume_codes: nested resume code objects generated from previous create_resume calls.
- cg: codegen object to output to
- is_leaf: True if `self` corresponds to the leaf frame.
- handle_inactive_ctx: If True, handles inactive context variables as described above. This is necessary
iff the resume function is traced
"""
# Handle inactive context variables.
# The resume function assumes that context variables are the class, NOT the object.
# e.g. torch.set_grad_enabled(True) will be reconstructed as torch.set_grad_enabled
# NOTE: if the unsupported instruction modifies the inactive context variable, it may
# result in silent incorrectness!
if handle_inactive_ctx:
for (j, _), j_orig in zip(meta.stack_ctx_args, meta.stack_ctx_idxes_orig):
# Replace the stack var with the context class
ctx = cast(ContextWrappingVariable, self.stack[j_orig])
# frames[idx][j] = reconstructed_ctx
cg.append_output(create_dup_top())
ctx.reconstruct_type(cg)
cg.extend_output(
[
*create_swap(2),
cg.create_load_const(idx),
cg.create_binary_subscr(),
cg.create_load_const(j),
create_instruction("STORE_SUBSCR"),
]
)
for name, _ in meta.locals_ctx_args:
# Replace the local with the context class
ctx = cast(ContextWrappingVariable, self.symbolic_locals[name])
# frames[idx][meta.num_stack + meta.locals_names[name]] = reconstructed_ctx
cg.append_output(create_dup_top())
ctx.reconstruct_type(cg)
cg.extend_output(
[
*create_swap(2),
cg.create_load_const(idx),
cg.create_binary_subscr(),
cg.create_load_const(meta.num_stack + meta.locals_names[name]),
create_instruction("STORE_SUBSCR"),
]
)
# If the resume instruction is a jump absolute, then resume
# at the target instead. This handles the case where we
# graph break again in a nested function before jump-resuming
# this frame.
if is_jump_absolute(resume_inst):
assert resume_inst.target
resume_inst = resume_inst.target
resume_name = unique_id(f"__resume_at_{resume_inst.offset}")
# More locals may have been pruned in the current/leaf frame
# after the unsupported instruction (e.g. branch).
# There should not be any pruning in the other frames since
# the current instruction there should be a CALL.
if is_leaf:
reads = livevars_analysis(self.instructions, resume_inst)
all_argnames = tuple(
k
for k in self.symbolic_locals.keys()
if k in reads and k not in self.cell_and_freevars()
)
argnames_null_set = set(meta.locals_null_keys)
argnames = tuple(k for k in all_argnames if k not in argnames_null_set)
argnames_null = tuple(k for k in all_argnames if k in argnames_null_set)
# codegen filter for current frame's locals
# current stack state: frames
cg.extend_output(
[
create_dup_top(),
cg.create_load_const(idx),
cg.create_binary_subscr(),
create_dup_top(),
]
)
for arg in argnames:
# current stack state: frames, frames[i], *(prev locals), frames[i]
cg.extend_output(
[
create_dup_top(),
cg.create_load_const(meta.num_stack + meta.locals_names[arg]),
cg.create_binary_subscr(),
*create_swap(2),
],
)
# current stack state: frames, frames[i], *(frame i live locals), frames[i]
cg.extend_output(
[
create_instruction("POP_TOP"),
create_instruction("BUILD_LIST", arg=len(argnames)),
*create_swap(2),
# frames, frame i live locals, frames[i]
*create_binary_slice(meta.num_stack, None, True),
# frames[i][num_stack:] = frame i live locals
]
)
# current stack state: frames
else:
argnames = tuple(meta.locals_names.keys())
argnames_null = tuple(meta.locals_null_keys)
if sys.version_info < (3, 12):
assert len(argnames_null) == 0, "variables should not be NULL in < 3.12"
# compile_subgraph did not codegen any NULLs,
# so we should not count NullVariables
stack_len = len(self.stack) - len(meta.stack_null_idxes)
new_code: types.CodeType = ContinueExecutionCache.lookup(
self.f_code,
self.lineno,
resume_inst.offset,
tuple(b.target.offset for b in self.block_stack),
stack_len,
argnames,
argnames_null,
tuple(b.resume_fn() for b in self.block_stack),
handle_inactive_ctx,
tuple(meta.stack_ctx_args),
tuple(meta.locals_ctx_args),
tuple(meta.stack_null_idxes),
tuple(resume_codes),
)
# Add original GraphModule context to the resume function to handle
# the case of a graph break while tracing a GraphModule
orig_graphmodule_maybe = code_context.get_context(self.f_code).get(
"orig_graphmodule", lambda: None
)()
if orig_graphmodule_maybe is not None:
code_context.get_context(new_code)["orig_graphmodule"] = weakref.ref(
orig_graphmodule_maybe
)
# add resume function to the global scope
if new_code.co_freevars:
# expose code object for debugging purposes
self.output.install_global_unsafe(resume_name, new_code)
package_name = None
else:
# This is safe: we pre-generate a unique name
self.output.install_global_unsafe(
resume_name,
types.FunctionType(new_code, self.f_globals, resume_name),
)
package_name = resume_name
if self.package is not None:
self.package.add_resume_function(
new_code, self.f_globals["__name__"], package_name
)
return new_code, resume_name
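The live-locals filtering above only forwards names that the resume point actually reads; a standalone sketch of that filter (helper and argument names are hypothetical):

def filter_live_locals(symbolic_locals, reads, cell_and_freevars, null_keys):
    # Keep locals the resume instruction reads, excluding cells/freevars,
    # and split off names known to hold NULL.
    all_argnames = tuple(
        k for k in symbolic_locals if k in reads and k not in cell_and_freevars
    )
    argnames = tuple(k for k in all_argnames if k not in null_keys)
    argnames_null = tuple(k for k in all_argnames if k in null_keys)
    return argnames, argnames_null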
def create_call_resume_at(
self,
inst: Instruction,
all_stack_locals_metadata: list[StackLocalsMetadata],
) -> list[Instruction]:
"""
Codegen all resume function(s) from the frame stack starting at `self` and call them.
Assumes that the unsupported instruction has already been run.
Expects the stack to be in the state:
[frame N cells, ..., frame 1 cells],
[
frame N stack + locals,
frame N-1 stack + locals,
...,
frame 1 stack + locals
]
Pops the cells and frame values list from the stack.
Also includes a return instruction (stack expected to be empty after return).
], frame N stack (post-instruction)
Args:
- inst: the instruction of the current (deepest) frame to resume at
- all_stack_locals_metadata: metadata returned from OutputGraph.compile_subgraph - contains
metadata such as local names, NULL positions, stack length, etc.
- disable_current_frame_resume: If True, disable tracing on the current frame's resume function.
Used for implementing nested step_graph_break.
"""
self.instruction_pointer = None
@ -2954,115 +2575,234 @@ class InstructionTranslatorBase(
all_stack_locals_metadata[0].num_stack = current_num_stack
if inst.opname in ("RETURN_VALUE", "RETURN_CONST"):
return self.codegen_return_with_pops(
inst, all_stack_locals_metadata[0].num_stack
return self.codegen_return_after_compile_subgraph(
inst, all_stack_locals_metadata[0]
)
cg = PyCodegen(self.output.root_tx)
# move frame N stack to the frame values list
cg.extend_output(
[
create_instruction("BUILD_LIST", arg=current_num_stack),
*create_copy(2),
# frame_values, frame N stack, frame_values
cg.create_load_const(0),
cg.create_binary_subscr(),
*create_binary_slice(0, 0, True),
# frame_values[0][0:0] = frame N stack
# frame_values left on top of stack
]
)
# current frame state
# [
# [frame N stack (fixed) + locals]
# ...,
# [frame 1 stack + locals]
# ],
#
txes = []
cur_tx: Optional[InstructionTranslatorBase] = self
idx = 0
resume_codes: list[types.CodeType] = []
resume_names = []
while cur_tx is not None:
txes.append(cur_tx)
cur_tx = cur_tx.parent
assert len(txes) == len(all_stack_locals_metadata)
# Handle inactive context variables.
# The resume function assumes that context variables are the class, NOT the object.
# e.g. torch.set_grad_enabled(True) will be reconstructed as torch.set_grad_enabled
# NOTE: if the unsupported instruction modifies the inactive context variable, it may
# result in silent incorrectness!
for i, meta in enumerate(all_stack_locals_metadata):
if i == 0 and disable_current_frame_resume:
continue
for (j, _), j_orig in zip(meta.stack_ctx_args, meta.stack_ctx_idxes_orig):
# Replace the stack var with the context class
ctx = cast(ContextWrappingVariable, txes[i].stack[j_orig])
# frames[i][j] = reconstructed_ctx
cg.append_output(create_dup_top())
ctx.reconstruct_type(cg)
cg.extend_output(
[
*create_swap(2),
cg.create_load_const(i),
cg.create_binary_subscr(),
cg.create_load_const(j),
create_instruction("STORE_SUBSCR"),
]
)
for name, _ in meta.locals_ctx_args:
# Replace the local with the context class
ctx = cast(ContextWrappingVariable, txes[i].symbolic_locals[name])
# frames[i][meta.num_stack + meta.locals_names[name]] = reconstructed_ctx
cg.append_output(create_dup_top())
ctx.reconstruct_type(cg)
cg.extend_output(
[
*create_swap(2),
cg.create_load_const(i),
cg.create_binary_subscr(),
cg.create_load_const(meta.num_stack + meta.locals_names[name]),
create_instruction("STORE_SUBSCR"),
]
)
# build the resume function for each frame
resume_names = []
resume_codes: list[types.CodeType] = []
for i, meta in enumerate(all_stack_locals_metadata):
cur_tx = txes[i]
if cur_tx is self:
resume_inst = inst
else:
resume_inst = cur_tx.next_instruction
resume_code, resume_name = cur_tx.create_resume(
idx,
resume_inst,
all_stack_locals_metadata[idx],
resume_codes,
cg,
cur_tx is self,
True,
)
resume_codes.append(resume_code)
# If the resume instruction is a jump absolute, then resume
# at the target instead. This handles the case where we
# graph break again in a nested function before jump-resuming
# this frame.
if is_jump_absolute(resume_inst):
assert resume_inst.target
resume_inst = resume_inst.target
resume_name = unique_id(f"__resume_at_{resume_inst.offset}")
resume_names.append(resume_name)
cur_tx = cur_tx.parent
idx += 1
# More locals may have been pruned in the current frame
# after the unsupported instruction (e.g. branch).
# There should not be any pruning in the other frames since
# the current instruction is a CALL.
if cur_tx is self:
reads = livevars_analysis(cur_tx.instructions, resume_inst)
all_argnames = tuple(
k
for k in cur_tx.symbolic_locals.keys()
if k in reads and k not in cur_tx.cell_and_freevars()
)
argnames_null_set = set(meta.locals_null_keys)
argnames = tuple(k for k in all_argnames if k not in argnames_null_set)
argnames_null = tuple(k for k in all_argnames if k in argnames_null_set)
self.codegen_call_resume(resume_codes, resume_names, cg)
return cg.get_instructions() + [create_instruction("RETURN_VALUE")]
@staticmethod
def codegen_call_resume(
resume_codes: list[types.CodeType], resume_names: list[str], cg: PyCodegen
) -> None:
"""
Calls the provided resume functions.
Expects the TOS to be in the state:
[frame N cells, ..., frame 1 cells],
[
frame N stack + locals,
frame N-1 stack + locals,
...,
frame 1 stack + locals
]
Pops the cells and frame values, leaving the result of calling the resume functions on TOS.
Args:
- resume_codes: list of resume function code objects to call
- resume_names: list of the corresponding names of the resume functions
- cg: PyCodegen object to output instructions to
"""
# NOTE: We will load cells as we load resume functions
# load resume functions except the root's
cg.extend_output(create_copy(2))
for i, (name, code) in enumerate(zip(resume_names, resume_codes)):
if i == len(resume_names) - 1:
break
# stack: cells, frames, *(resume 1, ...), cells
if code.co_freevars:
# codegen filter for current frame's locals
# current stack state: frames
cg.extend_output(
[
create_dup_top(),
cg.create_load_const(i),
cg.create_binary_subscr(),
create_dup_top(),
]
)
cg.make_function_with_closure(name, code)
for arg in argnames:
# current stack state: frames, frames[i], *(prev locals), frames[i]
cg.extend_output(
[
create_dup_top(),
cg.create_load_const(
meta.num_stack + meta.locals_names[arg]
),
cg.create_binary_subscr(),
*create_swap(2),
],
)
# current stack state: frames, frames[i], *(frame i live locals), frames[i]
cg.extend_output(
[
create_instruction("POP_TOP"),
create_instruction("BUILD_LIST", arg=len(argnames)),
*create_swap(2),
# frames, frame i live locals, frames[i]
*create_binary_slice(meta.num_stack, None, True),
# frames[i][num_stack:] = frame i live locals
]
)
# current stack state: frames
else:
argnames = tuple(meta.locals_names.keys())
argnames_null = tuple(meta.locals_null_keys)
if sys.version_info < (3, 12):
assert len(argnames_null) == 0, "variables should not be NULL in < 3.12"
# compile_subgraph did not codegen any NULLs,
# so we should not count NullVariables
stack_len = len(cur_tx.stack) - len(meta.stack_null_idxes)
new_code: types.CodeType = ContinueExecutionCache.lookup(
cur_tx.f_code,
cur_tx.lineno,
resume_inst.offset,
tuple(b.target.offset for b in cur_tx.block_stack),
stack_len,
argnames,
argnames_null,
tuple(b.resume_fn() for b in cur_tx.block_stack),
tuple(meta.stack_ctx_args),
tuple(meta.locals_ctx_args),
tuple(meta.stack_null_idxes),
tuple(resume_codes),
)
resume_codes.append(new_code)
# Add original GraphModule context to the resume function to handle
# the case of a graph break while tracing a GraphModule
orig_graphmodule_maybe = code_context.get_context(cur_tx.f_code).get(
"orig_graphmodule", lambda: None
)()
if orig_graphmodule_maybe is not None:
code_context.get_context(new_code)["orig_graphmodule"] = weakref.ref(
orig_graphmodule_maybe
)
# add resume function to the global scope
if new_code.co_freevars:
# expose code object for debugging purposes
cur_tx.output.install_global_unsafe(resume_name, new_code)
package_name = None
else:
# This is safe: we pre-generate a unique name
cur_tx.output.install_global_unsafe(
resume_name,
types.FunctionType(new_code, cur_tx.f_globals, resume_name),
)
package_name = resume_name
if cur_tx.package is not None:
cur_tx.package.add_resume_function(
new_code, cur_tx.f_globals["__name__"], package_name
)
if disable_current_frame_resume:
from .eval_frame import skip_code
skip_code(resume_codes[0])
# load first resume function (to be called this frame)
if resume_codes[-1].co_freevars:
cg.make_function_with_closure(
txes[-1], resume_names[-1], resume_codes[-1], True, 1
)
else:
cg.extend_output(cg.load_function_name(resume_names[-1], True, 1))
# load all other resume functions (to be called later)
resume_names.pop()
resume_codes.pop()
for tx, name, code in zip(txes, resume_names, resume_codes):
if code.co_freevars:
cg.make_function_with_closure(tx, name, code, False, 0)
else:
cg.extend_output(cg.load_function_name(name, False, 0))
cg.extend_output(create_swap(2))
cg.extend_output(
[
create_instruction("POP_TOP"),
create_instruction("BUILD_LIST", arg=len(resume_codes) - 1),
create_instruction("BUILD_LIST", arg=len(resume_codes)),
*create_swap(2),
]
)
# stack: cells, frames, [resume 1, ..., resume N - 1]
# load root resume function
cg.extend_output(create_swap(3))
if resume_codes[-1].co_freevars:
cg.extend_output(
[
cg.create_load_const(-1),
cg.create_binary_subscr(),
]
)
cg.make_function_with_closure(resume_names[-1], resume_codes[-1])
cg.extend_output(
[
*create_rot_n(3),
]
)
else:
cg.extend_output(
[
create_instruction("POP_TOP"),
*cg.load_function_name(resume_names[-1], False),
*create_rot_n(3),
]
)
# resume 1, [resume N, ..., resume 2], frames
# resume 1 (+ NULL), [resume N, ..., resume 2], frames
# load top level-frame; final stack state should be:
# first resume function (+ NULL),
@ -3103,9 +2843,11 @@ class InstructionTranslatorBase(
# TOS: [resumes, frames, *(frame 1 stack + locals)]
cg.extend_output(
[
*create_call_function_ex(False, True),
*create_call_function_ex(False),
create_instruction("RETURN_VALUE"),
]
)
return cg.get_instructions()
def should_compile_partial_graph(self) -> bool:
if sys.version_info >= (3, 11):
@ -3748,7 +3490,7 @@ class InstructionTranslatorBase(
self.active_generic_context_managers.append(ctx)
if sys.version_info >= (3, 11):
# See update_block_stack/create_resume for block stack details.
# See create_call_resume_at for block stack details.
# Only push a block if the current instruction's block is a
# with block that is not nested in a try block - that is, the current
# instruction's block target is the same as the top block's target.
@ -4103,7 +3845,6 @@ class InstructionTranslatorBase(
self.accept_prefix_inst = True
self.prefix_insts = []
self.exn_vt_stack = exn_vt_stack
self.latest_bytecode_queue = deque(maxlen=20)
# Properties of the input/output code
self.instructions: list[Instruction] = instructions
@ -4449,7 +4190,9 @@ class InstructionTranslator(InstructionTranslatorBase):
assert len(all_stack_locals_metadata) == 1
assert not all_stack_locals_metadata[0].stack_null_idxes
self.output.add_output_instructions(
self.codegen_return_with_pops(inst, all_stack_locals_metadata[0].num_stack)
self.codegen_return_after_compile_subgraph(
inst, all_stack_locals_metadata[0]
)
)
raise ReturnValueOp
@ -4834,10 +4577,13 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
def create_call_resume_at(
self,
inst: Instruction,
all_stack_locals_metadata: list[StackLocalsMetadata],
all_stack_locals_metadata: Any,
disable_current_frame_resume: bool,
) -> list[Instruction]:
if config.nested_graph_breaks:
return super().create_call_resume_at(inst, all_stack_locals_metadata)
return super().create_call_resume_at(
inst, all_stack_locals_metadata, disable_current_frame_resume
)
unimplemented_v2(
gb_type="Graph break in inlined function",
context="",

View File

@ -506,12 +506,6 @@ def skipIfNotPy312(fn: Callable[_P, _T]) -> Callable[_P, _T]:
return unittest.skip("Requires Python 3.12+")(fn)
def skipIfOnlyNotPy312(fn: Callable[_P, _T]) -> Callable[_P, _T]:
if sys.version_info >= (3, 13) or sys.version_info < (3, 12):
return unittest.skip("Requires Python 3.12")(fn)
return fn
def xfailIfPy312(fn: Callable[_P, _T]) -> Callable[_P, _T]:
if sys.version_info >= (3, 12):
return unittest.expectedFailure(fn)
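The removed helper follows the usual version-gated skip pattern; a minimal standalone sketch (decorator name hypothetical):

import sys
import unittest
from typing import Callable, TypeVar

_T = TypeVar("_T")

def skip_unless_py312(fn: Callable[..., _T]) -> Callable[..., _T]:
    # Skip on anything other than CPython 3.12.x, mirroring skipIfOnlyNotPy312.
    if sys.version_info >= (3, 13) or sys.version_info < (3, 12):
        return unittest.skip("Requires Python 3.12")(fn)
    return fn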

View File

@ -51,6 +51,7 @@ from .resume_execution import TORCH_DYNAMO_RESUME_IN_PREFIX
from .utils import (
getfile,
hashable,
is_annotate_wrapped_function,
is_lru_cache_wrapped_function,
NP_SUPPORTED_MODULES,
unwrap_if_wrapper,
@ -154,6 +155,7 @@ manual_torch_name_rule_map: dict[
type[UserFunctionVariable],
],
] = {
"torch.fx.traceback.annotate": UserFunctionVariable,
"torch.onnx.is_in_onnx_export": TorchInGraphFunctionVariable,
"torch.onnx.operators.shape_as_tensor": TorchInGraphFunctionVariable,
"torch.overrides.is_tensor_like": TorchInGraphFunctionVariable,
@ -2994,6 +2996,9 @@ def get_torch_obj_rule_map() -> dict[Any, type["VariableTracker"]]:
continue
obj = torch_dir + k[len("torch/") :]
if obj is not None:
if is_annotate_wrapped_function(obj):
# pyrefly: ignore # missing-attribute
obj = obj.__wrapped__
if is_lru_cache_wrapped_function(obj):
obj = obj.__wrapped__
if obj in d and d[obj] != v:
@ -3425,7 +3430,6 @@ MOD_INLINELIST = [
"torch.fx._symbolic_trace",
"torch.fx.experimental.proxy_tensor",
"torch.fx.passes.shape_prop",
"torch.fx.traceback",
"torch.nn",
"torch.overrides",
"torch.random",

View File

@ -1111,6 +1111,14 @@ def is_lru_cache_wrapped_function(
)
def is_annotate_wrapped_function(
value: Any,
) -> bool:
return value == torch.fx.traceback.annotate and is_function(
inspect.getattr_static(value, "__wrapped__")
)
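The new predicate lets the rule map key on the function underlying a functools-wrapped API. A small sketch of the __wrapped__ convention it relies on (decorator and function names are hypothetical):

import functools

def annotate_like(fn):
    # functools.wraps records the original function on the wrapper as __wrapped__.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper

@annotate_like
def example():
    pass

assert example.__wrapped__.__name__ == "example"  # rule lookup keys on the original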
_FuncTypes: TypeAlias = Union[
types.FunctionType,
types.BuiltinFunctionType,

View File

@ -29,7 +29,6 @@ from .ctx_manager import (
DynamoConfigPatchVariable,
ErrorOnGraphBreakVariable,
FSDPParamGroupUseTrainingStateVariable,
FxTracebackAnnotateVariable,
GradIncrementNestingCtxManagerVariable,
GradInplaceRequiresGradCtxManagerVariable,
GradModeVariable,

View File

@ -1262,34 +1262,6 @@ class SDPAKernelVariable(ContextWrappingVariable):
return "_sdpa_kernel_variadic"
class FxTracebackAnnotateVariable(ContextWrappingVariable):
"""
fx.traceback.annotate is a context manager that allows users to annotate the
fx graph nodes with custom metadata. In the context of Dynamo, we don't have
to trace the body of the context manager; instead we want to run it directly,
so the Dynamo-created FX graphs carry the right custom metadata. This variable
tracker just runs the __enter__ and __exit__ methods (instead of tracing them).
"""
def __init__(self, target_values, initial_values=None, **kwargs) -> None:
super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs
)
def enter(self, tx, *args):
cm = torch.fx.traceback.annotate(self.target_values)
cm.__enter__()
self.set_cleanup_hook(tx, lambda: cm.__exit__(None, None, None))
return variables.ConstantVariable.create(None)
def module_name(self):
return "torch.fx.traceback"
def fn_name(self):
return "annotate"
class StreamVariable(VariableTracker):
def __init__(self, proxy, value, device, **kwargs) -> None:
if proxy is not None and "example_value" in proxy.node.meta:
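For context, the context manager whose special-cased variable tracker is removed in this file is used roughly as follows (the metadata key shown is illustrative):

import torch
import torch.fx.traceback as fx_traceback

def fn(x):
    # Attaches custom metadata to FX nodes created while the context is active.
    with fx_traceback.annotate({"stage": "attention"}):
        return torch.relu(x) + 1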

View File

@ -52,7 +52,6 @@ from ..exc import (
ObservedUserStopIteration,
raise_observed_exception,
SkipFrame,
StepUnsupported,
unimplemented_v2,
Unsupported,
)
@ -1528,8 +1527,6 @@ class SkipFunctionVariable(VariableTracker):
raise SkipFrame(
f"Skip frame due to `torch._dynamo.skip_frame()`. Message: {skip_frame_msg}"
)
elif self.value is torch._dynamo.step_unsupported:
raise StepUnsupported
else:
if config.dont_skip_tracing:
from .builder import SourcelessBuilder

View File

@ -449,7 +449,7 @@ class ZipVariable(IteratorVariable):
codegen.create_load_const("strict"),
codegen.create_load_const(self.strict),
create_instruction("BUILD_MAP", arg=1),
*create_call_function_ex(True, False),
*create_call_function_ex(True),
]
)
@ -487,7 +487,7 @@ class MapVariable(ZipVariable):
codegen.extend_output(
[
create_build_tuple(len(self.iterables) + 1),
*create_call_function_ex(False, False),
*create_call_function_ex(False),
]
)
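The create_call_function_ex helper updated here emits a CALL_FUNCTION_EX sequence; a plain-Python model of the call shapes involved (sketch only):

def call_ex(fn, args_tuple, kwargs_dict=None):
    # With the kwargs flag, CALL_FUNCTION_EX performs fn(*args, **kwargs);
    # without it, just fn(*args).
    if kwargs_dict is not None:
        return fn(*args_tuple, **kwargs_dict)
    return fn(*args_tuple)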

View File

@ -1579,7 +1579,7 @@ class StringFormatVariable(VariableTracker):
variables.ConstantVariable.create(k): v for k, v in self.sym_kwargs.items()
}
codegen(variables.ConstDictVariable(kwargs))
codegen.extend_output(create_call_function_ex(True, False))
codegen.extend_output(create_call_function_ex(True))
class DebuggingVariable(VariableTracker):

View File

@ -125,7 +125,6 @@ supported_ctx_manager_classes = dict.fromkeys(
torch.autograd.graph.disable_saved_tensors_hooks,
torch.cpu.amp.autocast_mode.autocast,
torch.cuda.amp.autocast_mode.autocast,
torch.fx.traceback.annotate,
# We'll let Dynamo inline into the contextlib part of these context
# manager instances, all the way till it invokes the wrapped function
# itself (at which point we wrap it back to special context manager
@ -326,7 +325,6 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
DisabledSavedTensorsHooksVariable,
DualLevelContextManager,
FSDPParamGroupUseTrainingStateVariable,
FxTracebackAnnotateVariable,
GradIncrementNestingCtxManagerVariable,
GradInplaceRequiresGradCtxManagerVariable,
GradModeVariable,
@ -361,11 +359,6 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
assert len(args) <= 1 and len(kwargs) == 0
inf_mode = args[0].as_python_constant() if len(args) == 1 else True
return InferenceModeVariable.create(tx, inf_mode)
elif self.value is torch.fx.traceback.annotate:
assert len(args) <= 1 and len(kwargs) == 0
return FxTracebackAnnotateVariable(
args[0].as_python_constant(), source=self.source
)
elif inspect.isclass(self.value) and issubclass(self.value, torch.Stream):
from torch._dynamo.variables.builder import wrap_fx_proxy_cls

Some files were not shown because too many files have changed in this diff.