Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-24 07:27:32 +08:00)

Compare commits (1 commit): DynamoFixG ... malfet-pat

Commit: ff5b19523e
@ -256,7 +256,7 @@ test_torchbench_smoketest() {
  local device=mps
  local dtypes=(undefined float16 bfloat16 notset)
  local dtype=${dtypes[$1]}
  local models=(llama BERT_pytorch dcgan yolov3 resnet152 sam sam_fast pytorch_unet stable_diffusion_text_encoder speech_transformer Super_SloMo doctr_det_predictor doctr_reco_predictor vgg16)
  local models=(hf_T5 llama BERT_pytorch dcgan hf_GPT2 yolov3 resnet152 sam sam_fast pytorch_unet stable_diffusion_text_encoder speech_transformer Super_SloMo doctr_det_predictor doctr_reco_predictor timm_resnet timm_vovnet vgg16)

  for backend in eager inductor; do

@ -319,7 +319,7 @@ test_aoti_torchbench_smoketest() {
  local device=mps
  local dtypes=(undefined float16 bfloat16 notset)
  local dtype=${dtypes[$1]}
  local models=(llama BERT_pytorch dcgan yolov3 resnet152 sam sam_fast pytorch_unet stable_diffusion_text_encoder speech_transformer Super_SloMo doctr_det_predictor doctr_reco_predictor vgg16)
  local models=(hf_T5 llama BERT_pytorch dcgan hf_GPT2 yolov3 resnet152 sam sam_fast pytorch_unet stable_diffusion_text_encoder speech_transformer Super_SloMo doctr_det_predictor doctr_reco_predictor timm_resnet timm_vovnet vgg16)

  echo "Launching torchbench inference performance run for AOT Inductor and dtype ${dtype}"
  local dtype_arg="--${dtype}"

@ -838,7 +838,7 @@ test_dynamo_benchmark() {
  elif [[ "${suite}" == "timm_models" ]]; then
    export TORCHBENCH_ONLY_MODELS="inception_v3"
  elif [[ "${suite}" == "torchbench" ]]; then
    export TORCHBENCH_ONLY_MODELS="BERT_pytorch"
    export TORCHBENCH_ONLY_MODELS="hf_Bert"
  fi
  fi
  test_single_dynamo_benchmark "dashboard" "$suite" "$shard_id" "$@"

@ -869,13 +869,13 @@ test_inductor_torchbench_smoketest_perf() {
  mkdir -p "$TEST_REPORTS_DIR"

  python benchmarks/dynamo/torchbench.py --device cuda --performance --backend inductor --float16 --training \
    --batch-size-file "$(realpath benchmarks/dynamo/torchbench_models_list.txt)" --only BERT_pytorch \
    --batch-size-file "$(realpath benchmarks/dynamo/torchbench_models_list.txt)" --only hf_Bert \
    --output "$TEST_REPORTS_DIR/inductor_training_smoketest.csv"
  # The threshold value needs to be actively maintained to make this check useful
  python benchmarks/dynamo/check_perf_csv.py -f "$TEST_REPORTS_DIR/inductor_training_smoketest.csv" -t 1.4

  # Check memory compression ratio for a few models
  for test in BERT_pytorch yolov3; do
  for test in hf_Albert timm_vision_transformer; do
    python benchmarks/dynamo/torchbench.py --device cuda --performance --backend inductor --amp --training \
      --disable-cudagraphs --batch-size-file "$(realpath benchmarks/dynamo/torchbench_models_list.txt)" \
      --only $test --output "$TEST_REPORTS_DIR/inductor_training_smoketest_$test.csv"
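The smoketest above gates the run on check_perf_csv.py with a speedup threshold of 1.4. As a rough illustration only (the "speedup" column name and the per-row pass criterion are assumptions here, not the script's exact logic), such a check amounts to:

  import sys
  import pandas as pd

  def check_perf(csv_path: str, threshold: float) -> None:
      df = pd.read_csv(csv_path)
      # Flag any benchmarked model whose measured speedup fell below the threshold.
      failures = df[df["speedup"] < threshold]
      if not failures.empty:
          print(failures.to_string(index=False))
          sys.exit(1)

  check_perf(sys.argv[1], float(sys.argv[2]))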
@ -71,7 +71,14 @@ export PYTORCH_BUILD_NUMBER=1

# Set triton version as part of PYTORCH_EXTRA_INSTALL_REQUIREMENTS
TRITON_VERSION=$(cat $PYTORCH_ROOT/.ci/docker/triton_version.txt)
TRITON_CONSTRAINT="platform_system == 'Linux'"

# Here PYTORCH_EXTRA_INSTALL_REQUIREMENTS is already set for the all the wheel builds hence append TRITON_CONSTRAINT
TRITON_CONSTRAINT="platform_system == 'Linux' and platform_machine == 'x86_64'"

# CUDA 12.9/13.0 builds have triton for Linux and Linux aarch64 binaries.
if [[ "$DESIRED_CUDA" == "cu129" ]] || [[ "$DESIRED_CUDA" == "cu130" ]]; then
  TRITON_CONSTRAINT="platform_system == 'Linux'"
fi

if [[ "$PACKAGE_TYPE" =~ .*wheel.* && -n "${PYTORCH_EXTRA_INSTALL_REQUIREMENTS:-}" && ! "$PYTORCH_BUILD_VERSION" =~ .*xpu.* ]]; then
  TRITON_REQUIREMENT="triton==${TRITON_VERSION}; ${TRITON_CONSTRAINT}"
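The TRITON_CONSTRAINT strings above are PEP 508 environment markers, so the appended requirement only applies on matching platforms. For illustration, with a hypothetical triton_version.txt of 3.1.0, the x86_64 branch would produce a requirement string of

  triton==3.1.0; platform_system == 'Linux' and platform_machine == 'x86_64'

while the cu129/cu130 branch drops the machine marker so aarch64 Linux wheels also pull triton.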
35  .github/actions/setup-linux/action.yml (vendored)

@ -28,10 +28,6 @@ runs:
      echo "instance-type: $(get_ec2_metadata instance-type)"
      echo "system info $(uname -a)"

  - name: Print GPU info (if present)
    shell: bash
    run: if [ -f /usr/bin/nvidia-smi ]; then nvidia-smi; fi

  - name: Check if in a container runner
    shell: bash
    id: check_container_runner

@ -86,6 +82,37 @@ runs:
      # Prune all of the docker images
      docker system prune -af

  - name: Manually resolve download.pytorch.org
    shell: bash
    continue-on-error: true
    run: |
      set +e
      set -x

      PT_DOMAIN=download.pytorch.org
      # TODO: Flaky access to download.pytorch.org https://github.com/pytorch/pytorch/issues/100400,
      # cleaning this up once the issue is fixed. There are more than one resolved IP here, the last
      # one is returned at random
      RESOLVED_IP=$(dig -4 +short "${PT_DOMAIN}" | tail -n1)

      if [ -z "${RESOLVED_IP}" ]; then
        echo "Couldn't resolve ${PT_DOMAIN}, retrying with Google DNS..."
        RESOLVED_IP=$(dig -4 +short "${PT_DOMAIN}" @8.8.8.8 | tail -n1)

        if [ -z "${RESOLVED_IP}" ]; then
          echo "Couldn't resolve ${PT_DOMAIN}, exiting..."
          exit 1
        fi
      fi

      if grep -r "${PT_DOMAIN}" /etc/hosts; then
        # Clean up any old records first
        sudo sed -i "/${PT_DOMAIN}/d" /etc/hosts
      fi

      echo "${RESOLVED_IP} ${PT_DOMAIN}" | sudo tee -a /etc/hosts
      cat /etc/hosts

  - name: Check that the docker daemon is running
    shell: bash
    continue-on-error: true
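For illustration of what that step leaves behind: assuming dig returned 13.107.42.16 (a made-up address), the appended /etc/hosts entry would read

  13.107.42.16 download.pytorch.org

so later downloads from that domain in the job bypass the flaky DNS lookup.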
@ -2,7 +2,7 @@ name: inductor-perf-nightly-h100

on:
  schedule:
    - cron: 15 0 * * 1-6
    - cron: 15 0,12 * * 1-6
    - cron: 0 7 * * 0
  # NB: GitHub has an upper limit of 10 inputs here, so before we can sort it
  # out, let try to run torchao cudagraphs_low_precision as part of cudagraphs
@ -6693,12 +6693,12 @@

- func: native_norm(Tensor self, Scalar p=2) -> Tensor
  dispatch:
    SparseCPU, SparseCUDA, SparseMPS: norm_sparse
    SparseCPU, SparseCUDA: norm_sparse
  autogen: native_norm.out

- func: native_norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, ScalarType? dtype) -> Tensor
  dispatch:
    SparseCPU, SparseCUDA, SparseMPS: norm_sparse
    SparseCPU, SparseCUDA: norm_sparse
  autogen: native_norm.ScalarOpt_dim_dtype_out

- func: _batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)

@ -6824,14 +6824,14 @@
  device_check: NoCheck # TensorIterator
  variants: function, method
  dispatch:
    SparseCPU, SparseCUDA, SparseMPS: sparse_dtype_norm
    SparseCPU, SparseCUDA: sparse_dtype_norm

- func: norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> Tensor
  structured_delegate: norm.out
  device_check: NoCheck # TensorIterator
  variants: function, method
  dispatch:
    SparseCPU, SparseCUDA, SparseMPS: sparse_norm
    SparseCPU, SparseCUDA: sparse_norm

- func: norm.dtype_out(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)
  structured: True
@ -25,6 +25,15 @@ drq
fambench_dlrm
fambench_xlmr
fastNLP_Bert
hf_Albert
hf_Bart
hf_Bert
hf_BigBird
hf_DistilBert
hf_GPT2
hf_Longformer
hf_Reformer
hf_T5
maml
maml_omniglot
mnasnet1_0

@ -51,6 +60,13 @@ soft_actor_critic
speech_transformer
squeezenet1_1
tacotron2
timm_efficientdet
timm_efficientnet
timm_nfnet
timm_regnet
timm_resnest
timm_vision_transformer
timm_vovnet
tts_angular
vgg16
vision_maskrcnn
@ -23,6 +23,7 @@ TORCHBENCH_MODELS: list[str] = [
    "resnet50",
    "moco",
    "llama",
    "hf_T5",
]
HUGGINGFACE_MODELS: list[str] = [
    "AllenaiLongformerBase",
@ -11,6 +11,7 @@ import pandas as pd

flaky_models = {
    "yolov3",
    "detectron2_maskrcnn_r_101_c4",
    "timm_efficientnet", # see https://github.com/pytorch/pytorch/issues/148699
    "XGLMForCausalLM", # discovered in https://github.com/pytorch/pytorch/pull/128148
    "moondream", # discovered in https://github.com/pytorch/pytorch/pull/159291
    # discovered in https://github.com/pytorch/pytorch/issues/161419. Its not flaky but really hard to repro, so skipping it

@ -39,9 +40,13 @@ def check_accuracy(actual_csv, expected_csv, expected_filename):
    "detectron2_fcos_r_50_fpn",
    "doctr_det_predictor",
    "doctr_reco_predictor",
    "dpn107",
    "fbnetv3_b",
    "levit_128",
    "hf_BigBird",
    "hf_Longformer",
    "hf_Reformer",
    "hf_Roberta_base",
    "hf_T5",
    "hf_T5_base",
    "hf_T5_generate",
    "llava",
    "microbench_unbacked_tolist_sum",
    "mnasnet1_0",

@ -58,7 +63,12 @@ def check_accuracy(actual_csv, expected_csv, expected_filename):
    "squeezenet1_1",
    "stable_diffusion_text_encoder",
    "stable_diffusion_unet",
    "swsl_resnext101_32x16d",
    "timm_efficientdet",
    "timm_efficientnet",
    "timm_nfnet",
    "timm_regnet",
    "timm_resnest",
    "timm_vovnet",
    "torchrec_dlrm",
    "vgg16",
    # LLM

@ -36,7 +36,12 @@ def check_graph_breaks(actual_csv, expected_csv, expected_filename):
    "detectron2_fcos_r_50_fpn",
    "doctr_det_predictor",
    "doctr_reco_predictor",
    "levit_128",
    "hf_BigBird",
    "hf_Longformer",
    "hf_Reformer",
    "hf_Roberta_base",
    "hf_T5",
    "hf_T5_base",
    "llava",
    "microbench_unbacked_tolist_sum",
    "resnet50",

@ -46,6 +51,7 @@ def check_graph_breaks(actual_csv, expected_csv, expected_filename):
    "stable_diffusion_text_encoder",
    "stable_diffusion_unet",
    "timm_efficientdet",
    "timm_nfnet",
    "torchrec_dlrm",
    "vgg16",
    # LLM
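The expected-result CSVs that follow pair a model name with a status (pass, fail_accuracy, model_fail_to_load, ...) and a count. A minimal sketch of how a check_accuracy-style script compares a CI run against such a file (the column names and pass criterion here are assumptions, not the upstream implementation):

  import pandas as pd

  def compare(actual_csv: str, expected_csv: str) -> list[str]:
      actual = pd.read_csv(actual_csv).set_index("name")["accuracy"]
      expected = pd.read_csv(expected_csv).set_index("name")["accuracy"]
      # Report models whose recorded status drifted from the checked-in expectation.
      common = expected.index.intersection(actual.index)
      return [m for m in common if actual[m] != expected[m]]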
@ -130,6 +130,70 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,0
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Reformer,pass,5
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,eager_fail_to_run,0
hf_T5_generate,pass,7
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,0
hf_distil_whisper,pass,0
lennard_jones,pass,0

@ -278,6 +342,30 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0
@ -78,6 +78,62 @@ functorch_maml_omniglot,pass,7
hf_Albert,pass,6
hf_Bart,pass,6
hf_Bert,pass,6
hf_Bert_large,pass,6
hf_BigBird,pass,6
hf_DistilBert,pass,6
hf_GPT2,pass,8
hf_GPT2_large,pass_due_to_skip,0
hf_Reformer,pass,20
hf_Roberta_base,pass,6
hf_T5_base,eager_2nd_run_OOM,0
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,6
hf_distil_whisper,model_fail_to_load,0
lennard_jones,pass,7

@ -194,6 +250,30 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientnet,fail_accuracy,7
timm_regnet,pass,7
timm_resnest,pass,6
timm_vision_transformer,pass,6
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,6
torch_multimodal_clip,pass,7
@ -118,6 +118,62 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,fail_accuracy,0
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,eager_fail_to_run,0
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,0
hf_distil_whisper,pass,0
lennard_jones,pass,0

@ -258,6 +314,30 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0
@ -114,6 +114,58 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,0
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,pass,0
hf_T5_large,pass_due_to_skip,0
hf_distil_whisper,pass,0
lennard_jones,pass,0

@ -226,6 +278,38 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientdet,model_fail_to_load,0
timm_efficientnet,pass,0
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0
@ -114,6 +114,58 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,0
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,pass,0
hf_T5_large,pass_due_to_skip,0
hf_distil_whisper,pass,0
lennard_jones,pass,0

@ -226,6 +278,38 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientdet,model_fail_to_load,0
timm_efficientnet,pass,0
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0
@ -122,6 +122,66 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,27
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Longformer,pass,4
hf_Reformer,pass,5
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,pass,0
hf_T5_large,pass_due_to_skip,0
hf_distil_whisper,pass,0
lennard_jones,pass,0

@ -242,6 +302,38 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientdet,model_fail_to_load,0
timm_efficientnet,pass,0
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0
@ -122,6 +122,66 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,27
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Longformer,pass,4
hf_Reformer,pass,5
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,pass,0
hf_T5_large,pass_due_to_skip,0
hf_distil_whisper,pass,0
lennard_jones,pass,0

@ -242,6 +302,38 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientdet,model_fail_to_load,0
timm_efficientnet,pass,0
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0
@ -122,6 +122,66 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,27
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Longformer,pass,4
hf_Reformer,pass,5
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,pass,0
hf_T5_large,pass_due_to_skip,0
hf_distil_whisper,pass,0
lennard_jones,pass,0

@ -242,6 +302,38 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientdet,model_fail_to_load,0
timm_efficientnet,pass,0
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0
@ -130,6 +130,70 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,0
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Reformer,pass,5
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,eager_fail_to_run,0
hf_T5_generate,pass,7
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,0
hf_distil_whisper,pass,0
lennard_jones,pass,0

@ -278,6 +342,30 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0
@ -78,6 +78,62 @@ functorch_maml_omniglot,pass,7
hf_Albert,pass,6
hf_Bart,pass,6
hf_Bert,pass,6
hf_Bert_large,pass,6
hf_BigBird,pass,6
hf_DistilBert,pass,6
hf_GPT2,pass,8
hf_GPT2_large,pass_due_to_skip,0
hf_Reformer,pass,20
hf_Roberta_base,pass,6
hf_T5_base,eager_2nd_run_OOM,0
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,6
hf_distil_whisper,model_fail_to_load,0
lennard_jones,pass,7

@ -190,6 +246,30 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientnet,pass,7
timm_regnet,pass,7
timm_resnest,pass,6
timm_vision_transformer,pass,6
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,6
torch_multimodal_clip,pass,7
@ -98,6 +98,58 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,0
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,pass,0
hf_T5_large,pass_due_to_skip,0
hf_distil_whisper,pass,0
lennard_jones,pass,0

@ -210,6 +262,38 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientdet,model_fail_to_load,0
timm_efficientnet,pass,0
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0
@ -98,6 +98,58 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,0
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,pass,0
hf_T5_large,pass_due_to_skip,0
hf_distil_whisper,pass,0
lennard_jones,pass,0

@ -210,6 +262,38 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientdet,model_fail_to_load,0
timm_efficientnet,pass,0
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0
@ -106,6 +106,66 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,27
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Longformer,pass,4
hf_Reformer,pass,5
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,pass,0
hf_T5_large,pass_due_to_skip,0
hf_distil_whisper,pass,0
lennard_jones,pass,0

@ -226,6 +286,38 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientdet,model_fail_to_load,0
timm_efficientnet,pass,0
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0
@ -122,6 +122,66 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,25
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Longformer,pass,4
hf_Reformer,pass,8
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,pass,0
hf_T5_large,pass_due_to_skip,0
hf_distil_whisper,pass,0
lennard_jones,pass,0

@ -242,6 +302,38 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientdet,model_fail_to_load,0
timm_efficientnet,pass,0
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,3
@ -130,6 +130,70 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,fail_accuracy,0
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Reformer,pass,5
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,eager_fail_to_run,0
hf_T5_generate,pass,7
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,0
hf_distil_whisper,pass,0
lennard_jones,pass,0

@ -278,6 +342,30 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0
@ -78,6 +78,62 @@ functorch_maml_omniglot,pass,7
hf_Albert,pass,6
hf_Bart,pass,6
hf_Bert,pass,6
hf_Bert_large,pass,6
hf_BigBird,pass,6
hf_DistilBert,pass,6
hf_GPT2,pass,8
hf_GPT2_large,pass_due_to_skip,0
hf_Reformer,pass,20
hf_Roberta_base,pass,6
hf_T5_base,eager_2nd_run_OOM,0
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,6
hf_distil_whisper,model_fail_to_load,0
lennard_jones,pass,7

@ -190,6 +246,30 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientnet,fail_accuracy,7
timm_regnet,pass,7
timm_resnest,pass,6
timm_vision_transformer,pass,6
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,6
torch_multimodal_clip,pass,7
@ -130,6 +130,70 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,0
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Reformer,pass,5
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,eager_fail_to_run,0
hf_T5_generate,pass,7
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,0
hf_distil_whisper,pass,0
lennard_jones,pass,0

@ -278,6 +342,30 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0
@ -78,6 +78,62 @@ functorch_maml_omniglot,pass,7
hf_Albert,pass,6
hf_Bart,pass,6
hf_Bert,pass,6
hf_Bert_large,pass,6
hf_BigBird,pass,6
hf_DistilBert,pass,6
hf_GPT2,pass,8
hf_GPT2_large,pass_due_to_skip,0
hf_Reformer,pass,20
hf_Roberta_base,pass,6
hf_T5_base,eager_2nd_run_OOM,0
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,6
hf_distil_whisper,model_fail_to_load,0
lennard_jones,pass,7

@ -194,6 +250,30 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientnet,pass,7
timm_regnet,pass,7
timm_resnest,pass,6
timm_vision_transformer,pass,6
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,6
torch_multimodal_clip,pass,7
@ -130,6 +130,70 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,fail_accuracy,0
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Reformer,pass,5
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,eager_fail_to_run,0
hf_T5_generate,pass,7
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,0
hf_distil_whisper,pass,0
lennard_jones,pass,0

@ -278,6 +342,30 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0
@ -78,6 +78,62 @@ functorch_maml_omniglot,pass,7
hf_Albert,pass,6
hf_Bart,pass,6
hf_Bert,pass,6
hf_Bert_large,pass,6
hf_BigBird,pass,6
hf_DistilBert,pass,6
hf_GPT2,pass,8
hf_GPT2_large,pass_due_to_skip,0
hf_Reformer,pass,20
hf_Roberta_base,pass,6
hf_T5_base,eager_2nd_run_OOM,0
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,6
hf_distil_whisper,model_fail_to_load,0
lennard_jones,pass,7

@ -194,6 +250,30 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientnet,fail_accuracy,7
timm_regnet,pass,7
timm_resnest,pass,6
timm_vision_transformer,pass,6
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,6
torch_multimodal_clip,pass,7
@ -130,6 +130,73 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,9
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Longformer,pass,4
hf_Reformer,pass,8
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,pass,0
hf_T5_generate,pass,7
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,0
hf_distil_whisper,pass,0
lennard_jones,pass,0

@ -278,6 +345,38 @@ stable_diffusion_unet,model_fail_to_load,0
timm_efficientdet,pass,2
timm_efficientnet,pass,0
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0
@ -78,6 +78,70 @@ functorch_maml_omniglot,pass,7
hf_Albert,pass,6
hf_Bart,pass,6
hf_Bert,pass,6
hf_Bert_large,pass,6
hf_BigBird,pass,6
hf_DistilBert,pass,6
hf_GPT2,pass,8
hf_GPT2_large,pass_due_to_skip,0
hf_Longformer,pass,4
hf_Reformer,pass,25
hf_Roberta_base,pass,6
hf_T5,pass,0
hf_T5_base,eager_2nd_run_OOM,0
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,6
hf_distil_whisper,model_fail_to_load,0
lennard_jones,pass,7

@ -194,6 +258,38 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientdet,pass,2
timm_efficientnet,pass,7
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,6
timm_vision_transformer,pass,6
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,6
torch_multimodal_clip,pass,7
@ -118,6 +118,62 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,fail_accuracy,0
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,eager_fail_to_run,0
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,0
hf_distil_whisper,pass,0
lennard_jones,pass,0

@ -258,6 +314,34 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientnet,pass,0
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0
@ -130,6 +130,73 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,9
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Longformer,pass,4
hf_Reformer,pass,8
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,pass,0
hf_T5_generate,pass,7
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,0
hf_distil_whisper,pass,0
lennard_jones,pass,0

@ -278,6 +345,38 @@ stable_diffusion_unet,model_fail_to_load,0
timm_efficientdet,pass,2
timm_efficientnet,pass,0
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0
@ -78,6 +78,70 @@ functorch_maml_omniglot,pass,7
hf_Albert,pass,6
hf_Bart,pass,6
hf_Bert,pass,6
hf_Bert_large,pass,6
hf_BigBird,fail_to_run,3
hf_DistilBert,pass,6
hf_GPT2,pass,8
hf_GPT2_large,pass_due_to_skip,0
hf_Longformer,pass,4
hf_Reformer,pass,25
hf_Roberta_base,pass,6
hf_T5,pass,0
hf_T5_base,eager_2nd_run_OOM,0
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,6
hf_distil_whisper,model_fail_to_load,0
lennard_jones,pass,7

@ -190,6 +254,38 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientdet,pass,2
timm_efficientnet,pass,7
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,6
timm_vision_transformer,pass,6
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,6
torch_multimodal_clip,pass,7
@ -130,6 +130,74 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,fail_to_run,0
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Longformer,pass,4
hf_Reformer,pass,5
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,eager_fail_to_run,0
hf_T5_generate,pass,7
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,0
hf_distil_whisper,pass,0
lennard_jones,pass,0

@ -278,6 +346,38 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientdet,pass,2
timm_efficientnet,pass,0
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0
@ -78,6 +78,70 @@ functorch_maml_omniglot,pass,7
hf_Albert,pass,6
hf_Bart,pass,6
hf_Bert,pass,6
hf_Bert_large,pass,6
hf_BigBird,fail_to_run,3
hf_DistilBert,pass,6
hf_GPT2,pass,8
hf_GPT2_large,pass_due_to_skip,0
hf_Longformer,pass,10
hf_Reformer,pass,20
hf_Roberta_base,pass,6
hf_T5,pass,5
hf_T5_base,eager_2nd_run_OOM,0
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,6
hf_distil_whisper,model_fail_to_load,0
lennard_jones,pass,7

@ -190,6 +254,38 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientdet,pass,8
timm_efficientnet,pass,7
timm_nfnet,pass,6
timm_regnet,pass,0
timm_resnest,pass,6
timm_vision_transformer,pass,6
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,6
torch_multimodal_clip,pass,7
@ -130,6 +130,73 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,pass,9
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Longformer,pass,4
hf_Reformer,pass,8
hf_Roberta_base,pass,0
hf_T5,pass,0
hf_T5_base,pass,0
hf_T5_generate,pass,7
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,0
hf_distil_whisper,pass,0
lennard_jones,pass,0

@ -278,6 +345,38 @@ stable_diffusion_unet,model_fail_to_load,0
timm_efficientdet,pass,2
timm_efficientnet,pass,0
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0
@ -78,6 +78,70 @@ functorch_maml_omniglot,pass,7
hf_Albert,pass,6
hf_Bart,pass,6
hf_Bert,pass,6
hf_Bert_large,pass,6
hf_BigBird,pass,15
hf_DistilBert,pass,6
hf_GPT2,pass,8
hf_GPT2_large,pass_due_to_skip,0
hf_Longformer,pass,4
hf_Reformer,pass,25
hf_Roberta_base,pass,6
hf_T5,pass,0
hf_T5_base,eager_2nd_run_OOM,0
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,6
hf_distil_whisper,model_fail_to_load,0
lennard_jones,pass,7

@ -194,6 +258,38 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientdet,pass,2
timm_efficientnet,pass,7
timm_nfnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,6
timm_vision_transformer,pass,6
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,6
torch_multimodal_clip,pass,7
@ -130,6 +130,66 @@ functorch_maml_omniglot,pass,0
hf_Albert,pass,0
hf_Bart,pass,0
hf_Bert,pass,0
hf_Bert_large,pass,0
hf_BigBird,fail_accuracy,0
hf_DistilBert,pass,0
hf_GPT2,pass,0
hf_GPT2_large,pass_due_to_skip,0
hf_Reformer,pass,8
hf_T5,pass,0
hf_T5_base,eager_fail_to_run,0
hf_T5_generate,pass,11
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,0
hf_distil_whisper,pass,0
lennard_jones,pass,0

@ -274,6 +334,30 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientnet,pass,0
timm_regnet,pass,0
timm_resnest,pass,0
timm_vision_transformer,pass,0
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,0
torch_multimodal_clip,pass,0
@ -78,6 +78,58 @@ functorch_maml_omniglot,pass,7
hf_Albert,pass,6
hf_Bart,pass,6
hf_Bert,pass,6
hf_Bert_large,pass,6
hf_BigBird,pass,6
hf_DistilBert,pass,6
hf_GPT2,pass,8
hf_GPT2_large,pass_due_to_skip,0
hf_Reformer,pass,25
hf_T5_base,eager_2nd_run_OOM,0
hf_T5_large,pass_due_to_skip,0
hf_Whisper,pass,6
hf_distil_whisper,model_fail_to_load,0
lennard_jones,pass,7

@ -194,6 +246,30 @@ stable_diffusion_unet,pass_due_to_skip,0
timm_efficientnet,pass,7
timm_regnet,pass,7
timm_resnest,pass,6
timm_vision_transformer,pass,6
timm_vision_transformer_large,pass_due_to_skip,0
timm_vovnet,pass,6
torch_multimodal_clip,pass,7
@ -149,6 +149,7 @@ CI_SKIP_DYNAMIC_BATCH_ONLY = {
    "detectron2_fasterrcnn_r_50_c4",
    "detectron2_fasterrcnn_r_50_dc5",
    "detectron2_fasterrcnn_r_50_fpn",
    "hf_T5_generate",
    "Reformer",
    "llama",
}.union(INTERNAL_CI_SKIP_DYNAMIC_BATCH_ONLY)

@ -175,7 +176,13 @@ BENCHMARK_USE_SGD = {
    "speech_transformer",
    "squeezenet1_1",
    "stable_diffusion_text_encoder",
    "timm_efficientdet",
    "timm_nfnet",
    "timm_resnest",
    "timm_vision_transformer",
    "timm_vovnet",
    "vgg16",
    "hf_T5",  # Fails dynamic https://github.com/pytorch/pytorch/issues/115968
    # HF
    "AlbertForMaskedLM",
    "BartForCausalLM",

@ -209,6 +216,8 @@ CI_USE_SGD = {
    "detectron2_maskrcnn_r_101_fpn",
    "detectron2_maskrcnn_r_50_c4",
    "detectron2_maskrcnn_r_50_fpn",
    "hf_T5_base",
    "hf_clip",
    "llama_v2_7b_16h",
    "mobilenet_v2_quantized_qat",
    "phi_1_5 resnet50_quantized_qat",

@ -2022,6 +2031,8 @@ class BenchmarkRunner:
        from diffusers.models.transformer_2d import Transformer2DModel
        from torchbenchmark.models.nanogpt.model import Block
        from transformers.models.llama.modeling_llama import LlamaDecoderLayer
        from transformers.models.t5.modeling_t5 import T5Block
        from transformers.models.whisper.modeling_whisper import WhisperEncoderLayer

        from torch.distributed.fsdp.wrap import (
            ModuleWrapPolicy,

@ -2031,6 +2042,10 @@ class BenchmarkRunner:
        # handcrafted wrap policy
        MODEL_FSDP_WRAP = {
            "stable_diffusion_unet": (Transformer2DModel,),
            "hf_T5": (T5Block,),
            "hf_T5_base": (T5Block,),
            "hf_T5_large": (T5Block,),
            "hf_Whisper": (WhisperEncoderLayer,),
            "llama_v2_7b_16h": (LlamaDecoderLayer,),
            "nanogpt": (Block,),
        }

@ -3795,6 +3810,22 @@ def run(runner, args, original_dir=None):
    global synchronize
    synchronize = torch.cuda.synchronize if HAS_CUDA else torch.xpu.synchronize

    if (
        args.devices == ["cuda"]
        and torch.cuda.get_device_properties(0).total_memory < 25 * 2**30
    ):
        # OOM errors on an RTX 3090 with 24gb RAM
        runner.skip_models.update(
            {
                # torchbench
                "hf_Longformer",
                "timm_nfnet",
                "timm_efficientdet",
            }
        )
        if args.training:
            runner.skip_models.add("hf_T5")

    if args.nnc:
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_override_can_fuse_on_gpu(True)
@ -21,6 +21,9 @@ try:
except ImportError:
    from torchbench import setup_torchbench_cwd

from transformers.models.bert.modeling_bert import BertLayer, BertLMPredictionHead
from transformers.models.t5.modeling_t5 import T5Block


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = os.getenv("MASTER_ADDR", "localhost")

@ -125,6 +128,8 @@ def fsdp_checkpointing_base(model, blocks):

MODEL_FSDP_WRAP = {
    "toy_model": (MyModule,),
    "hf_Bert": (BertLayer, BertLMPredictionHead),
    "hf_T5": (T5Block,),
}
@ -158,7 +158,7 @@ if __name__ == "__main__":
    model_arg.add_argument(
        "--torchbench-model",
        "--torchbench_model",
        help="name of torchbench model, e.g. BERT_pytorch",
        help="name of torchbench model, e.g. hf_Bert",
    )
    model_arg.add_argument(
        "--toy-model", "--toy_model", action="store_true", help="use toy model instead"
@ -12,6 +12,17 @@ cuda,dlrm,1024,1.3421,3.2177,4.9493,1.0009
cuda,drq,1,1.0820,3.8157,8.0732,0.9687
cuda,fastNLP_Bert,6,1.4839,37.9050,32.7583,1.1563
cuda,functorch_dp_cifar10,64,1.5014,6.9596,14.1516,0.4432
cuda,hf_Albert,8,2.2452,30.6134,25.9036,1.3098
cuda,hf_Bart,4,1.7012,34.3999,37.9975,1.0128
cuda,hf_Bert,4,1.9003,23.3435,34.8196,1.0273
cuda,hf_Bert_large,4,1.6346,52.8525,62.3112,1.0726
cuda,hf_BigBird,2,1.9208,105.2672,101.4787,1.1415
cuda,hf_DistilBert,8,1.3988,22.5793,20.2386,1.0232
cuda,hf_GPT2,4,1.8075,27.5184,25.3428,1.1562
cuda,hf_GPT2_large,4,1.7716,118.7404,68.1618,1.1725
cuda,hf_Reformer,4,1.1744,70.4228,15.1152,0.9266
cuda,hf_T5,8,1.8778,93.3134,37.0046,1.2279
cuda,hf_T5_large,2,2.3623,101.5518,143.7982,1.1674
cuda,lennard_jones,1000,1.0649,1.5233,4.1119,0.9998
cuda,mnasnet1_0,32,1.1957,19.1993,27.2302,0.7758
cuda,mobilenet_v2,96,1.4876,32.3311,27.4719,1.1729
@ -31,6 +42,14 @@ cuda,shufflenet_v2_x1_0,128,1.3027,25.7017,27.9875,1.1015
cuda,soft_actor_critic,256,0.9965,2.2580,4.6661,0.9995
cuda,speech_transformer,32,1.8405,35.1645,33.3422,1.0888
cuda,squeezenet1_1,32,1.4191,7.3454,9.4751,1.1148
cuda,timm_efficientdet,1,1.6630,78.2697,150.9620,0.9904
cuda,timm_efficientnet,32,1.2689,28.5348,66.3911,0.9428
cuda,timm_nfnet,128,1.5319,79.5429,32.9961,1.1070
cuda,timm_regnet,32,1.0564,56.9897,53.0027,0.9500
cuda,timm_resnest,32,1.6485,14.3908,56.7240,0.9515
cuda,timm_vision_transformer,8,1.6100,18.7736,36.9495,0.7301
cuda,timm_vision_transformer_large,8,1.0842,170.9849,72.0604,0.9762
cuda,timm_vovnet,32,1.0472,25.4676,24.8428,0.8843
cuda,tts_angular,64,1.0366,6.9889,4.2683,0.9973
cuda,vgg16,64,1.2560,52.7072,7.3733,0.9884
cuda,yolov3,16,1.2600,54.2350,42.4711,1.0108
@ -1,16 +1,29 @@
#name,backend,data_type,shape,wrapper,perf_speedup_target_c7i_metal_24xl
#timm_vision_transformer,inductor,float32,static,default,1.039510755
phlippe_densenet,inductor,float32,static,default,1.46474287
basic_gnn_edgecnn,inductor,float32,dynamic,default,1.30092957
llama_v2_7b_16h,inductor,float32,dynamic,default,1.23234331
resnet50,inductor,float32,dynamic,default,1.67742767
#timm_efficientnet,inductor,float32,static,cpp,
mobilenet_v3_large,inductor,float32,static,cpp,2.63311706
timm_resnest,inductor,float32,dynamic,cpp,1.7321529
functorch_maml_omniglot,inductor,float32,dynamic,cpp,1.126799
#hf_GPT2,inductor,float32,dynamic,cpp,
yolov3,export-aot-inductor,float32,static,default,1.40687424
mobilenet_v2,export-aot-inductor,float32,static,default,2.90375357
resnext50_32x4d,export-aot-inductor,float32,dynamic,default,1.49299689
hf_Albert,export-aot-inductor,float32,dynamic,default,1.261471
resnext50_32x4d,inductor,amp,static,default,1.47023111
vgg16,inductor,amp,static,default,1.2692454
hf_Longformer,inductor,amp,dynamic,default,1.22015225
hf_Bert_large,inductor,amp,dynamic,default,1.18572179
llama,inductor,amp,static,default,1.33157028
timm_regnet,inductor,amp,static,cpp,1.12734073
mnasnet1_0,inductor,amp,static,cpp,2.1296814
#hf_T5_generate,inductor,amp,dynamic,cpp,
timm_vovnet,inductor,amp,dynamic,cpp,1.10851009
#mobilenet_v2,inductor,amp,dynamic,cpp,2.27774577 # https://github.com/pytorch/pytorch/issues/131693
hf_GPT2,export-aot-inductor,amp,static,default,1.4432794
densenet121,export-aot-inductor,amp,static,default,1.25591385
hf_DistilBert,export-aot-inductor,amp,dynamic,default,1.2926442
hf_Bart,export-aot-inductor,amp,dynamic,default,1.19515416
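Each uncommented row pairs a model/backend/dtype/shape/wrapper configuration with the minimum speedup expected on the reference c7i.metal-24xl runner; commented rows are temporarily untracked. Below is a minimal sketch of how such a file could be checked, assuming a measured mapping keyed by the same five fields and a small slack factor to absorb run-to-run noise (both are illustrative, not the CI's actual logic).

import csv

def find_regressions(path, measured, slack=0.95):
    # measured: (name, backend, dtype, shape, wrapper) -> observed speedup
    regressions = []
    with open(path, newline="") as f:
        for row in csv.reader(f):
            if not row or row[0].startswith("#"):
                continue  # header line and disabled entries
            name, backend, dtype, shape, wrapper, target = row
            got = measured.get((name, backend, dtype, shape, wrapper))
            if got is not None and got < float(target) * slack:
                regressions.append((name, backend, got, float(target)))
    return regressions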
@ -75,7 +75,29 @@ def setup_torchbench_cwd():
    return original_dir


process_train_model_output = {}
def process_hf_reformer_output(out):
    assert isinstance(out, list)
    # second output is unstable
    return [elem for i, elem in enumerate(out) if i != 1]


def process_hf_whisper_output(out):
    out_ret = []
    for i, elem in enumerate(out):
        if i == 0:
            if elem is not None:
                assert isinstance(elem, dict)
                out_ret.append({k: v for k, v in elem.items() if k != "logits"})
        elif i != 1:
            out_ret.append(elem)

    return out_ret


process_train_model_output = {
    "hf_Reformer": process_hf_reformer_output,
    "hf_Whisper": process_hf_whisper_output,
}


class TorchBenchmarkRunner(BenchmarkRunner):
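These per-model hooks strip outputs that are numerically unstable (Reformer's second output, Whisper's logits) before training outputs are compared. A short sketch of how they might be applied to both sides of an eager-vs-compiled comparison; the comparison helper at the end is a placeholder, not an actual API.

def filtered_pair(model_name, eager_out, compiled_out):
    # Apply the model-specific filter, if any, to both outputs so the
    # unstable fields never take part in the accuracy check.
    fn = process_train_model_output.get(model_name)
    if fn is not None:
        return fn(eager_out), fn(compiled_out)
    return eager_out, compiled_out

# Hypothetical usage:
#   eager_out, compiled_out = filtered_pair("hf_Whisper", eager_out, compiled_out)
#   assert_outputs_close(eager_out, compiled_out)  # placeholder comparison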
@ -205,10 +227,12 @@ class TorchBenchmarkRunner(BenchmarkRunner):
            "drq",
            "hf_Reformer",
            "DALLE2_pytorch",
            "hf_BigBird",
            "detectron2_maskrcnn_r_50_fpn",
            "detectron2_maskrcnn_r_101_fpn",
            "vision_maskrcnn",
            "doctr_reco_predictor",
            "hf_T5_generate",
        }

    def load_model(
@ -371,6 +395,8 @@ class TorchBenchmarkRunner(BenchmarkRunner):
            and hasattr(model.config, "use_cache")
        ):
            model.config.use_cache = False
        if model_name == "hf_T5_generate":
            model.model.config.use_cache = False

        self.validate_model(model, example_inputs)
        return device, benchmark.name, model, example_inputs, batch_size
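Turning use_cache off keeps HuggingFace models from building KV-cache tensors during the benchmarked forward pass, which would otherwise add memory traffic and can differ between eager and compiled runs. The same guard written as a free-standing sketch (model and model_name are assumed inputs; generation wrappers such as hf_T5_generate keep the real model one level down):

def disable_kv_cache(model, model_name):
    # Plain HF models expose the flag on model.config; generate() wrappers
    # hold the underlying model on model.model instead.
    if hasattr(model, "config") and hasattr(model.config, "use_cache"):
        model.config.use_cache = False
    if model_name == "hf_T5_generate" and hasattr(model, "model"):
        model.model.config.use_cache = False
    return model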
@ -5,6 +5,8 @@ batch_size:
|
||||
demucs: 4
|
||||
dlrm: 1024
|
||||
densenet121: 4
|
||||
hf_Reformer: 4
|
||||
hf_T5_base: 4
|
||||
timm_efficientdet: 1
|
||||
llama_v2_7b_16h: 1
|
||||
# reduced from 16 due to cudagraphs OOM in TorchInductor dashboard
|
||||
@ -28,6 +30,7 @@ tolerance:
|
||||
- alexnet
|
||||
- attention_is_all_you_need_pytorch
|
||||
- densenet121
|
||||
- hf_Albert
|
||||
- vgg16
|
||||
- mobilenet_v3_large
|
||||
- nvidia_deeprecommender
|
||||
@ -37,16 +40,20 @@ tolerance:
|
||||
- soft_actor_critic
|
||||
- tacotron2
|
||||
- yolov3
|
||||
- timm_efficientdet
|
||||
- timm_efficientnet
|
||||
- squeezenet1_1
|
||||
|
||||
higher_fp16:
|
||||
- doctr_reco_predictor
|
||||
- drq
|
||||
- hf_Whisper
|
||||
- phlippe_resnet
|
||||
|
||||
higher_bf16:
|
||||
- doctr_reco_predictor
|
||||
- drq
|
||||
- hf_Whisper
|
||||
|
||||
# These models need higher tolerance for xpu devices with bf16
|
||||
higher_bf16_xpu:
|
||||
@ -64,9 +71,16 @@ tolerance:
|
||||
|
||||
require_larger_multiplier_for_smaller_tensor:
|
||||
- yolov3
|
||||
- timm_efficientnet
|
||||
|
||||
# These benchmarks took >600s on an i9-11900K CPU
|
||||
very_slow: &VERY_SLOW_MODELS
|
||||
# 3339s
|
||||
- hf_BigBird
|
||||
# 3062s
|
||||
- hf_Longformer
|
||||
# 930s
|
||||
- hf_T5
|
||||
|
||||
|
||||
# These benchmarks took >60s on an i9-11900K CPU
|
||||
@ -78,6 +92,18 @@ slow:
|
||||
- demucs
|
||||
# 242s
|
||||
- fastNLP_Bert
|
||||
# 221s
|
||||
- hf_Albert
|
||||
# 400s
|
||||
- hf_Bart
|
||||
# 334s
|
||||
- hf_Bert
|
||||
# 187s
|
||||
- hf_DistilBert
|
||||
# 470s
|
||||
- hf_GPT2
|
||||
# 141s
|
||||
- hf_Reformer
|
||||
# 317s
|
||||
- speech_transformer
|
||||
# 99s
|
||||
@ -161,36 +187,11 @@ skip:
|
||||
- hf_clip
|
||||
# multi gpu not always available in benchmark runners
|
||||
- simple_gpt_tp_manual
|
||||
# skip hf and timm models in torchbench since
|
||||
# there are already separate benchmarks for them
|
||||
- hf_Albert
|
||||
- hf_Bart
|
||||
- hf_Bert
|
||||
- hf_BigBird
|
||||
- hf_DistilBert
|
||||
- hf_GPT2
|
||||
- hf_Longformer
|
||||
- hf_Reformer
|
||||
- hf_T5
|
||||
- timm_efficientdet
|
||||
- timm_efficientnet
|
||||
- timm_nfnet
|
||||
- timm_regnet
|
||||
- timm_resnest
|
||||
- timm_vision_transformer
|
||||
- timm_vovnet
|
||||
- hf_Bert_large
|
||||
- hf_GPT2_large
|
||||
- hf_Roberta_base
|
||||
- hf_T5_base
|
||||
- hf_T5_generate
|
||||
- hf_T5_large
|
||||
- hf_Whisper
|
||||
- hf_distil_whisper
|
||||
- timm_vision_transformer_large
|
||||
|
||||
device:
|
||||
cpu:
|
||||
# OOMs
|
||||
- hf_T5_generate
|
||||
# model is CUDA only
|
||||
- cm3leon_generate
|
||||
# timeout
|
||||
@ -207,12 +208,16 @@ skip:
|
||||
- torchrec_dlrm
|
||||
- simple_gpt
|
||||
# works on cuda, accuracy failure on cpu
|
||||
- hf_Whisper
|
||||
- stable_diffusion_text_encoder
|
||||
- llava
|
||||
- moco
|
||||
|
||||
# Skip these additional models when running on aarch64
|
||||
cpu_aarch64: []
|
||||
cpu_aarch64:
|
||||
# timeout on aarch64
|
||||
- timm_regnet
|
||||
- timm_nfnet
|
||||
|
||||
cuda: []
|
||||
|
||||
@ -230,6 +235,7 @@ skip:
|
||||
- sam_fast
|
||||
# Model's DEFAULT_TRAIN_BSIZE is not implemented
|
||||
- cm3leon_generate
|
||||
- hf_T5_generate
|
||||
- doctr_det_predictor
|
||||
- doctr_reco_predictor
|
||||
- moondream
|
||||
@ -241,6 +247,9 @@ skip:
|
||||
- cm3leon_generate
|
||||
- detectron2_fcos_r_50_fpn
|
||||
- fastNLP_Bert
|
||||
- hf_Longformer
|
||||
- hf_Reformer
|
||||
- hf_T5_generate
|
||||
- opacus_cifar10
|
||||
- speech_transformer
|
||||
|
||||
@ -277,6 +286,9 @@ accuracy:
|
||||
# Models too large to have eager, dynamo and fp64_numbers simultaneously
|
||||
# even for 40 GB machine. We have tested accuracy for smaller version of
|
||||
# these models
|
||||
- hf_GPT2_large
|
||||
- hf_T5_large
|
||||
- timm_vision_transformer_large
|
||||
# accuracy https://github.com/pytorch/pytorch/issues/93847
|
||||
- maml
|
||||
- llama_v2_7b_16h
|
||||
@ -288,4 +300,5 @@ accuracy:
|
||||
- pytorch_unet
|
||||
|
||||
max_batch_size:
|
||||
hf_GPT2: 2
|
||||
pytorch_unet: 2
|
||||
|
||||
@ -4,6 +4,11 @@ LearningToPaint,1024
alexnet,1024
dcgan,1024
densenet121,64
hf_Albert,32
hf_Bart,16
hf_Bert,16
hf_GPT2,16
hf_T5,4
mnasnet1_0,256
mobilenet_v2,128
mobilenet_v3_large,256
@ -14,4 +19,10 @@ resnet50,128
resnext50_32x4d,128
shufflenet_v2_x1_0,512
squeezenet1_1,512
timm_nfnet,256
timm_efficientnet,128
timm_regnet,128
timm_resnest,256
timm_vision_transformer,256
timm_vovnet,128
vgg16,128
@ -6,6 +6,18 @@ densenet121,512
|
||||
dlrm,2048
|
||||
fastNLP_Bert,8
|
||||
functorch_dp_cifar10,1024
|
||||
hf_Albert,8
|
||||
hf_Bart,8
|
||||
hf_Bert,8
|
||||
hf_Bert_large,8
|
||||
hf_DistilBert,8
|
||||
hf_GPT2,8
|
||||
hf_GPT2_large,1
|
||||
hf_Longformer,4
|
||||
hf_Reformer,8
|
||||
hf_T5,4
|
||||
hf_T5_base,1
|
||||
hf_T5_large,1
|
||||
LearningToPaint,96
|
||||
lennard_jones,1024
|
||||
mnasnet1_0,32
|
||||
@ -23,6 +35,13 @@ shufflenet_v2_x1_0,64
|
||||
speech_transformer,1024
|
||||
squeezenet1_1,16
|
||||
Super_SloMo,1024
|
||||
timm_efficientnet,64
|
||||
timm_nfnet,128
|
||||
timm_regnet,32
|
||||
timm_resnest,32
|
||||
timm_vision_transformer,16
|
||||
timm_vision_transformer_large,8
|
||||
timm_vovnet,32
|
||||
tts_angular,1024
|
||||
vgg16,64
|
||||
vision_maskrcnn,1
|
||||
|
||||
@ -369,7 +369,7 @@ class ComposabilityTest(MultiProcessTestCase):
|
||||
torch.distributed.destroy_process_group()
|
||||
|
||||
@requires_accelerator_dist_backend(["nccl", "xccl"])
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@skip_if_lt_x_gpu(8)
|
||||
@skip_but_pass_in_sandcastle_if(
|
||||
not TEST_MULTIGPU and not TEST_XPU, "Test requires 8+ GPUs"
|
||||
)
|
||||
@ -391,6 +391,7 @@ class ComposabilityTest(MultiProcessTestCase):
|
||||
],
|
||||
)
|
||||
def test_replicate_pp(self, ScheduleClass, MixedPrecisionParam):
|
||||
_device_raii = torch.device(device_type, self.device)
|
||||
torch.accelerator.set_device_index(self.device)
|
||||
store = torch.distributed.FileStore(self.file_name, self.world_size)
|
||||
torch.distributed.init_process_group(
|
||||
@ -603,281 +604,6 @@ class ComposabilityTest(MultiProcessTestCase):
|
||||
|
||||
torch.distributed.destroy_process_group()
|
||||
|
||||
@requires_accelerator_dist_backend(["nccl", "xccl"])
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@skip_but_pass_in_sandcastle_if(
|
||||
not TEST_MULTIGPU and not TEST_XPU, "Test requires 8+ GPUs"
|
||||
)
|
||||
@parametrize(
|
||||
"ScheduleClass",
|
||||
[
|
||||
ScheduleGPipe,
|
||||
Schedule1F1B,
|
||||
ScheduleInterleaved1F1B,
|
||||
ScheduleLoopedBFS,
|
||||
ScheduleInterleavedZeroBubble,
|
||||
],
|
||||
)
|
||||
def test_replicate_pp_grads(self, ScheduleClass):
|
||||
torch.accelerator.set_device_index(self.device)
|
||||
store = torch.distributed.FileStore(self.file_name, self.world_size)
|
||||
torch.distributed.init_process_group(
|
||||
backend=backend,
|
||||
store=store,
|
||||
rank=self.rank,
|
||||
world_size=self.world_size,
|
||||
)
|
||||
dim = 8
|
||||
pp_size = 2
|
||||
num_microbatches = 8
|
||||
replicate_size = self.world_size // (pp_size)
|
||||
device_mesh = init_device_mesh(
|
||||
device_type,
|
||||
mesh_shape=(replicate_size, 1, pp_size),
|
||||
mesh_dim_names=("replicate", "shard", "pp"),
|
||||
)
|
||||
torch.manual_seed(42)
|
||||
dp_mesh = device_mesh["replicate", "shard"]
|
||||
pp_mesh = device_mesh["pp"]
|
||||
pp_group = device_mesh["pp"].get_group()
|
||||
dp_group = device_mesh["replicate"].get_group()
|
||||
|
||||
# create "entire model"
|
||||
total_layers = 8
|
||||
full_model = nn.ModuleList([MLPModule(dim) for _ in range(total_layers)])
|
||||
ref_model = nn.Sequential(*copy.deepcopy(full_model)).to(self.device)
|
||||
|
||||
# dummy loss needed just to force backwards to run in schedule step
|
||||
def loss_fn(y, target):
|
||||
return y.sum()
|
||||
|
||||
# Simulate microbatch processing for reference model
|
||||
def simulate_stage_forward_backward(model, inputs, labels):
|
||||
"""Simulate forward and backward passes through stages for microbatch processing"""
|
||||
batch_size, _ = inputs.shape
|
||||
total_loss = 0
|
||||
|
||||
# Split inputs into microbatches
|
||||
microbatch_size = batch_size // num_microbatches
|
||||
|
||||
for mb_idx in range(num_microbatches):
|
||||
start_idx = mb_idx * microbatch_size
|
||||
end_idx = start_idx + microbatch_size
|
||||
mb_input = inputs[start_idx:end_idx]
|
||||
mb_label = labels[start_idx:end_idx] if labels is not None else None
|
||||
|
||||
# Simulate stage-by-stage processing
|
||||
if issubclass(ScheduleClass, PipelineScheduleSingle):
|
||||
num_stages = pp_group.size()
|
||||
layers_per_stage = total_layers // pp_group.size() # 8 // 2 = 4
|
||||
else:
|
||||
n_virtual = 2
|
||||
num_stages = pp_group.size() * n_virtual
|
||||
layers_per_stage = total_layers // num_stages
|
||||
|
||||
# Forward pass through all stages
|
||||
x = mb_input
|
||||
|
||||
for stage in range(num_stages):
|
||||
start_layer = stage * layers_per_stage
|
||||
end_layer = start_layer + layers_per_stage
|
||||
|
||||
# Process layers for this stage
|
||||
for layer_idx in range(start_layer, min(end_layer, len(model))):
|
||||
x = model[layer_idx](x)
|
||||
|
||||
mb_loss = loss_fn(x, mb_label)
|
||||
total_loss += mb_loss
|
||||
|
||||
# Backward pass
|
||||
mb_loss.backward()
|
||||
|
||||
return total_loss / num_microbatches
|
||||
|
||||
# Apply replicate to stage module
|
||||
def apply_replicate(partial_model):
|
||||
for layer_id in range(len(partial_model)):
|
||||
replicate(
|
||||
partial_model[layer_id],
|
||||
device_mesh=dp_mesh,
|
||||
reshard_after_forward=False,
|
||||
)
|
||||
dp_model = replicate(partial_model, device_mesh=dp_mesh)
|
||||
return dp_model
|
||||
|
||||
def pipelined_models_parameters(start_layer, model):
|
||||
layer_idx = start_layer
|
||||
|
||||
for layer in model.children():
|
||||
for name, param in layer.named_parameters():
|
||||
updated_param_name = f"{layer_idx}.{name}"
|
||||
pipeline_model_parameter_dict[updated_param_name] = param
|
||||
layer_idx += 1
|
||||
|
||||
def check_gradient_parity(
|
||||
pipeline_model_parameter_dict, ref_model_parameter_dict
|
||||
):
|
||||
for parameter in pipeline_model_parameter_dict:
|
||||
assert parameter in ref_model_parameter_dict
|
||||
|
||||
pipeline_parameter = pipeline_model_parameter_dict[parameter]
|
||||
if pipeline_parameter.grad is not None:
|
||||
pipeline_parameter_grad = pipeline_parameter.grad.to_local()
|
||||
ref_parameter = ref_model_parameter_dict[parameter]
|
||||
if ref_parameter.grad is not None:
|
||||
torch.testing.assert_close(
|
||||
pipeline_parameter_grad,
|
||||
ref_parameter.grad,
|
||||
rtol=1e-4,
|
||||
atol=1e-5,
|
||||
)
|
||||
else:
|
||||
assert pipeline_parameter.grad is None
|
||||
|
||||
pipeline_model_parameter_dict = {}
|
||||
|
||||
# Attach to a schedule
|
||||
if issubclass(ScheduleClass, PipelineScheduleSingle):
|
||||
stage_idx = pp_group.rank()
|
||||
# Calculate layers per stage correctly
|
||||
layers_per_stage = total_layers // pp_group.size() # 8 // 2 = 4
|
||||
start_layer = stage_idx * layers_per_stage
|
||||
end_layer = start_layer + layers_per_stage
|
||||
|
||||
partial_model = nn.Sequential(*full_model[start_layer:end_layer])
|
||||
partial_model.to(self.device)
|
||||
|
||||
dp_model = apply_replicate(partial_model)
|
||||
pipelined_models_parameters(start_layer, dp_model)
|
||||
|
||||
pipeline_stage = PipelineStage(
|
||||
dp_model,
|
||||
stage_idx,
|
||||
pp_group.size(),
|
||||
self.device,
|
||||
group=pp_group,
|
||||
)
|
||||
partial_models = [pipeline_stage.submod]
|
||||
pipeline_schedule = ScheduleClass(
|
||||
pipeline_stage,
|
||||
n_microbatches=num_microbatches,
|
||||
loss_fn=loss_fn,
|
||||
scale_grads=False,
|
||||
)
|
||||
|
||||
else:
|
||||
n_virtual = 2
|
||||
num_stages = pp_group.size() * n_virtual
|
||||
layers_per_stage = total_layers // num_stages
|
||||
stages = []
|
||||
for i in range(n_virtual):
|
||||
stage_idx = pp_group.rank() + pp_group.size() * i
|
||||
start_layer = stage_idx * layers_per_stage
|
||||
end_layer = start_layer + layers_per_stage
|
||||
# divide the model layers by the number of stages
|
||||
partial_model = nn.Sequential(*full_model[start_layer:end_layer])
|
||||
partial_model.to(self.device)
|
||||
|
||||
dp_model = apply_replicate(partial_model)
|
||||
pipelined_models_parameters(start_layer, dp_model)
|
||||
stage = PipelineStage(
|
||||
dp_model,
|
||||
stage_idx,
|
||||
num_stages,
|
||||
self.device,
|
||||
group=pp_group,
|
||||
)
|
||||
|
||||
stages.append(stage)
|
||||
partial_models = [pipeline_stage.submod for pipeline_stage in stages]
|
||||
|
||||
pipeline_schedule = ScheduleClass(
|
||||
stages,
|
||||
n_microbatches=num_microbatches,
|
||||
loss_fn=loss_fn,
|
||||
scale_grads=False,
|
||||
)
|
||||
|
||||
optimizer_kwargs = {
|
||||
"lr": 0.01,
|
||||
"betas": (0.9, 0.95),
|
||||
"weight_decay": 0.1,
|
||||
"fused": False,
|
||||
"foreach": True,
|
||||
}
|
||||
|
||||
optimizers = [
|
||||
torch.optim.AdamW(model.parameters(), **optimizer_kwargs)
|
||||
for model in partial_models
|
||||
]
|
||||
|
||||
ref_optimizer = torch.optim.AdamW(ref_model.parameters(), **optimizer_kwargs)
|
||||
|
||||
# Helper function to simulate all-reduce for reference model gradients
|
||||
def simulate_all_reduce_grads(model, group):
|
||||
"""Simulate all-reduce operation on gradients like replicate does"""
|
||||
for param in model.parameters():
|
||||
if param.grad is not None:
|
||||
# Scale by the number of replicas (like replicate does)
|
||||
param.grad.div_(group.size())
|
||||
# Simulate all-reduce
|
||||
torch.distributed.all_reduce(param.grad, group=group)
|
||||
|
||||
ref_model_parameter_dict = {}
|
||||
ref_model_parameter_dict = dict(ref_model.named_parameters())
|
||||
|
||||
torch.manual_seed(42 + self.rank)
|
||||
for _ in range(5):
|
||||
for optimizer in optimizers:
|
||||
optimizer.zero_grad()
|
||||
ref_optimizer.zero_grad()
|
||||
|
||||
inputs = torch.rand((num_microbatches, dim), device=self.device)
|
||||
labels = torch.rand((num_microbatches, dim), device=self.device)
|
||||
|
||||
# Ensure all ranks use the same inputs/labels for comparison
|
||||
torch.distributed.broadcast(inputs, 0)
|
||||
torch.distributed.broadcast(labels, 0)
|
||||
|
||||
is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
|
||||
|
||||
# Run pipeline schedule
|
||||
if pp_mesh.get_local_rank() == 0:
|
||||
pipeline_schedule.step(inputs)
|
||||
elif is_last_stage:
|
||||
losses = []
|
||||
pipeline_schedule.step(target=labels, losses=losses)
|
||||
else:
|
||||
pipeline_schedule.step()
|
||||
|
||||
# Run reference model simulation
|
||||
if is_last_stage:
|
||||
ref_loss = simulate_stage_forward_backward(ref_model, inputs, labels)
|
||||
# Simulate all-reduce on reference model gradients
|
||||
simulate_all_reduce_grads(ref_model, dp_group)
|
||||
|
||||
# Compare losses - only check on last stage where we have losses
|
||||
if "losses" in locals() and len(losses) > 0:
|
||||
# Average the microbatch losses to match ref_loss
|
||||
avg_pipeline_loss = sum(losses) / len(losses)
|
||||
torch.testing.assert_close(
|
||||
avg_pipeline_loss, ref_loss, rtol=1e-4, atol=1e-5
|
||||
)
|
||||
else:
|
||||
# For non-last stages, still run ref model to generate gradients
|
||||
simulate_stage_forward_backward(ref_model, inputs, None)
|
||||
simulate_all_reduce_grads(ref_model, dp_group)
|
||||
|
||||
# Step optimizers
|
||||
for optimizer in optimizers:
|
||||
optimizer.step()
|
||||
ref_optimizer.step()
|
||||
|
||||
check_gradient_parity(
|
||||
pipeline_model_parameter_dict, ref_model_parameter_dict
|
||||
)
|
||||
torch.distributed.destroy_process_group()
|
||||
|
||||
|
||||
instantiate_parametrized_tests(ComposabilityTest)
|
||||
|
||||
|
||||
@ -54,9 +54,6 @@ from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
|
||||
from torch.testing._internal.distributed.common_state_dict import VerifyStateDictMixin
|
||||
|
||||
|
||||
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
|
||||
|
||||
|
||||
# Simple and boring model
|
||||
class TestDummyModel(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
@ -75,12 +72,12 @@ class TestDummyModel(torch.nn.Module):
|
||||
return x
|
||||
|
||||
def get_input(self):
|
||||
return torch.rand(8, 8, device=device_type)
|
||||
return torch.rand(8, 8, device="cuda")
|
||||
|
||||
|
||||
class TestStatefulObj:
|
||||
def __init__(self) -> None:
|
||||
self.data = torch.rand(10, 10, device=device_type)
|
||||
self.data = torch.rand(10, 10, device="cuda")
|
||||
|
||||
def state_dict(self):
|
||||
return {"data": self.data}
|
||||
@ -154,11 +151,10 @@ def _train(model, optim, train_steps=1):
|
||||
class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
|
||||
@property
|
||||
def backend(self):
|
||||
curr_backend = dist.get_default_backend_for_device(self.device_type)
|
||||
return f"cpu:gloo,{self.device_type}:{curr_backend}"
|
||||
return "cpu:gloo,cuda:nccl"
|
||||
|
||||
def _create_model(self, compile, model_type, state_dict_options=None):
|
||||
dummy_model = TestDummyModel().to(self.device_type)
|
||||
dummy_model = TestDummyModel().cuda()
|
||||
|
||||
assert model_type in ModelType, f"{model_type} is not supported."
|
||||
if model_type == ModelType.FSDP:
|
||||
@ -211,8 +207,8 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
|
||||
def _optim(self, model):
|
||||
return torch.optim.Adam(model.parameters(), lr=0.1)
|
||||
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@with_temp_dir
|
||||
@parametrize("compile", [True, False])
|
||||
# TODO: Previously PairwiseParallel does not shard properly, passing ModelType.FSDP_TP test where it
|
||||
@ -221,8 +217,8 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
|
||||
def test_e2e(self, compile, model_type):
|
||||
self._run_e2e_test(compile, model_type)
|
||||
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@with_temp_dir
|
||||
@parametrize(
|
||||
"cache_staged_state_dict, async_checkpointer_type, zoc",
|
||||
@ -382,9 +378,9 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
|
||||
# Validate that the non-stateful state dict was replaced with the loaded state dict
|
||||
self.assertTrue(sd.set_sd_item_called)
|
||||
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@with_comms
|
||||
@with_temp_dir
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_different_ordered_state_dict_keys(self):
|
||||
"""Tests that the order of keys in the state dict does not matter when loading
|
||||
If order was not accounted for, the following test would cause a deadlock.
|
||||
@ -398,11 +394,11 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
tl = [
|
||||
torch.ones(2, dtype=torch.int64, device=device_type)
|
||||
torch.ones(2, dtype=torch.int64, device="cuda")
|
||||
for _ in range(world_size)
|
||||
]
|
||||
t = (
|
||||
torch.arange(2, dtype=torch.int64, device=device_type)
|
||||
torch.arange(2, dtype=torch.int64, device="cuda")
|
||||
+ 1
|
||||
+ 2 * dist.get_rank()
|
||||
)
|
||||
@ -414,7 +410,7 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
tensor = (
|
||||
torch.arange(2, dtype=torch.int64, device=device_type)
|
||||
torch.arange(2, dtype=torch.int64, device="cuda")
|
||||
+ 1
|
||||
+ 2 * dist.get_rank()
|
||||
)
|
||||
@ -441,8 +437,8 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
|
||||
DCP.save({}, checkpoint_id=self.temp_dir)
|
||||
DCP.load({}, checkpoint_id=self.temp_dir)
|
||||
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@with_temp_dir
|
||||
def test_partial_load(self):
|
||||
model, optim = self._create_model(compile=False, model_type=ModelType.NONE)
|
||||
@ -480,8 +476,8 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
|
||||
loaded_optim_state[k][optim_key], v[optim_key], offload_to_cpu=True
|
||||
)
|
||||
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@with_temp_dir
|
||||
def test_overwrite(self):
|
||||
t1, t2 = torch.randn(10), torch.randn(10)
|
||||
|
||||
@ -82,23 +82,22 @@ class FineTuningModel(nn.Module):
|
||||
class TestFineTuning(DTensorTestBase):
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return min(4, torch.accelerator.device_count())
|
||||
return min(4, torch.cuda.device_count())
|
||||
|
||||
@property
|
||||
def backend(self):
|
||||
curr_backend = dist.get_default_backend_for_device(self.device_type)
|
||||
return f"cpu:gloo,{self.device_type}:{curr_backend}"
|
||||
return "cpu:gloo,cuda:nccl"
|
||||
|
||||
def pretrain(self, pretrain_dir: str) -> None:
|
||||
device_mesh = init_device_mesh(self.device_type, (self.world_size,))
|
||||
|
||||
model = PreTrainedModel().to(self.device_type)
|
||||
model = PreTrainedModel().cuda()
|
||||
model = FSDP(model, device_mesh=device_mesh)
|
||||
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
|
||||
|
||||
# Training
|
||||
for _ in range(3):
|
||||
batch = torch.rand(32, DIM, device=self.device_type)
|
||||
batch = torch.rand(32, DIM, device="cuda")
|
||||
loss = model(batch).sum()
|
||||
loss.backward()
|
||||
optim.step()
|
||||
@ -115,7 +114,7 @@ class TestFineTuning(DTensorTestBase):
|
||||
def finetune(self, pretrain_dir: str, finetune_dir: str) -> None:
|
||||
device_mesh = init_device_mesh(self.device_type, (self.world_size,))
|
||||
|
||||
model = FineTuningModel().to(self.device_type)
|
||||
model = FineTuningModel().cuda()
|
||||
# TODO: make the parallelism more complicated, e.g., using 2D + DDP.
|
||||
model = FSDP(model, use_orig_params=True, device_mesh=device_mesh)
|
||||
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
|
||||
@ -163,7 +162,7 @@ class TestFineTuning(DTensorTestBase):
|
||||
|
||||
# Training
|
||||
for _ in range(3):
|
||||
batch = torch.rand(32, DIM, device=self.device_type)
|
||||
batch = torch.rand(32, DIM, device="cuda")
|
||||
loss = model(batch).sum()
|
||||
loss.backward()
|
||||
optim.step()
|
||||
|
||||
@ -61,13 +61,13 @@ class TopModel(nn.Module):
|
||||
class TestFSDPWithEP(DTensorTestBase, VerifyStateDictMixin):
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return min(8, torch.accelerator.device_count())
|
||||
return min(8, torch.cuda.device_count())
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(8)
|
||||
@with_temp_dir
|
||||
def test_e2e(self):
|
||||
model = TopModel(self.rank).to(self.device_type)
|
||||
model = TopModel(self.rank).cuda()
|
||||
|
||||
mesh_fsdp_tp = init_device_mesh(
|
||||
self.device_type, (2, 4), mesh_dim_names=("dp", "tp")
|
||||
|
||||
@ -32,13 +32,10 @@ from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
|
||||
from torch.utils._pytree import tree_all_only
|
||||
|
||||
|
||||
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
|
||||
|
||||
|
||||
class TestFullyShardWithDistributedStateDict(FSDPTest):
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return min(4, torch.accelerator.device_count())
|
||||
return min(4, torch.cuda.device_count())
|
||||
|
||||
def _get_base_model(self, mlp_dim: int = 2):
|
||||
base_model = nn.Sequential(
|
||||
@ -76,7 +73,7 @@ class TestFullyShardWithDistributedStateDict(FSDPTest):
|
||||
for module in model2:
|
||||
fully_shard(module, reshard_after_forward=False)
|
||||
fully_shard(model2, reshard_after_forward=False)
|
||||
inp = torch.randn((2, mlp_dim), device=device_type)
|
||||
inp = torch.randn((2, mlp_dim), device="cuda")
|
||||
model2(inp) # parameters are not resharded after this forward
|
||||
# Check that state dict hooks reshard
|
||||
osd_2 = model2.state_dict()
|
||||
@ -134,7 +131,7 @@ class TestFullyShardWithDistributedStateDict(FSDPTest):
|
||||
|
||||
# Save state dict with model wrapped with FSDP1
|
||||
fsdp1_model = FSDP(
|
||||
self._get_base_model().to(device_type),
|
||||
self._get_base_model().cuda(),
|
||||
use_orig_params=True,
|
||||
auto_wrap_policy=always_wrap_policy,
|
||||
)
|
||||
@ -210,14 +207,14 @@ class TestFullyShardWithDistributedStateDict(FSDPTest):
|
||||
# init device mesh
|
||||
dp_size = 2
|
||||
global_mesh = init_device_mesh(
|
||||
device_type,
|
||||
"cuda",
|
||||
(dp_size, self.world_size // dp_size),
|
||||
mesh_dim_names=("dp", "tp"),
|
||||
)
|
||||
dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
|
||||
|
||||
# Save state dict with original model
|
||||
base_model = _get_base_model().to(device_type)
|
||||
base_model = _get_base_model().cuda()
|
||||
base_optim = torch.optim.AdamW(base_model.parameters(), lr=0.1)
|
||||
|
||||
# Save state dict with model wrapped with FSDP1
|
||||
@ -344,17 +341,15 @@ class TestFullyShardWithDistributedStateDict(FSDPTest):
|
||||
# init device mesh
|
||||
dp_size = 2
|
||||
global_mesh_1d = init_device_mesh(
|
||||
device_type, (self.world_size,), mesh_dim_names=("tp",)
|
||||
"cuda", (self.world_size,), mesh_dim_names=("tp",)
|
||||
)
|
||||
global_mesh_2d = init_device_mesh(
|
||||
device_type,
|
||||
(dp_size, self.world_size // dp_size),
|
||||
mesh_dim_names=("dp", "tp"),
|
||||
"cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
|
||||
)
|
||||
dp_mesh, tp_mesh = global_mesh_2d["dp"], global_mesh_2d["tp"]
|
||||
|
||||
# Save state dict with original model
|
||||
base_model = _get_base_model().to(device_type)
|
||||
base_model = _get_base_model().cuda()
|
||||
base_optim = torch.optim.AdamW(base_model.parameters(), lr=0.1)
|
||||
|
||||
# Save state dict with TP model
|
||||
@ -500,10 +495,10 @@ class TestFullyShardWithDistributedStateDict(FSDPTest):
|
||||
# init device mesh
|
||||
dp_size = 2
|
||||
global_mesh_1d = init_device_mesh(
|
||||
device_type, (self.world_size,), mesh_dim_names=("tp",)
|
||||
"cuda", (self.world_size,), mesh_dim_names=("tp",)
|
||||
)
|
||||
global_mesh_2d = init_device_mesh(
|
||||
device_type,
|
||||
"cuda",
|
||||
(dp_size, self.world_size // dp_size),
|
||||
mesh_dim_names=("dp", "tp"),
|
||||
)
|
||||
@ -511,7 +506,7 @@ class TestFullyShardWithDistributedStateDict(FSDPTest):
|
||||
|
||||
for save_full_state_dict in [True, False]:
|
||||
# Save state dict with original model
|
||||
base_model = _get_base_model(mlp_dim).to(device_type)
|
||||
base_model = _get_base_model(mlp_dim).cuda()
|
||||
base_optim = torch.optim.AdamW(base_model.parameters(), lr=0.1)
|
||||
|
||||
# Save state dict with FSDP2 + TP model
|
||||
|
||||
@ -32,10 +32,7 @@ from torch.distributed.checkpoint.planner import (
|
||||
)
|
||||
from torch.distributed.checkpoint.storage import WriteResult
|
||||
from torch.futures import Future
|
||||
from torch.testing._internal.common_distributed import (
|
||||
requires_accelerator_dist_backend,
|
||||
skip_if_lt_x_gpu,
|
||||
)
|
||||
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
|
||||
from torch.testing._internal.distributed._shard.sharded_tensor import (
|
||||
ShardedTensorTestBase,
|
||||
@ -43,9 +40,6 @@ from torch.testing._internal.distributed._shard.sharded_tensor import (
|
||||
)
|
||||
|
||||
|
||||
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
|
||||
|
||||
|
||||
if TEST_WITH_DEV_DBG_ASAN:
|
||||
print(
|
||||
"Skip dev-asan as torch + multiprocessing spawn have known issues",
|
||||
@ -68,8 +62,8 @@ class TestModule(torch.nn.Module):
|
||||
return ChunkShardingSpec(
|
||||
dim=0,
|
||||
placements=[
|
||||
f"rank:0/{device_type}:0",
|
||||
f"rank:1/{device_type}:1",
|
||||
"rank:0/cuda:0",
|
||||
"rank:1/cuda:1",
|
||||
],
|
||||
)
|
||||
|
||||
@ -81,12 +75,12 @@ class TestDistributedCheckpointing(ShardedTensorTestBase):
|
||||
|
||||
@with_comms(init_rpc=False)
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@requires_accelerator_dist_backend()
|
||||
@requires_nccl()
|
||||
def test_tensor_metadata_with_missing_rank_spec(self) -> None:
|
||||
spec = ChunkShardingSpec(
|
||||
dim=0,
|
||||
placements=[
|
||||
f"rank:1/{device_type}:1",
|
||||
"rank:1/cuda:1",
|
||||
],
|
||||
)
|
||||
|
||||
@ -98,14 +92,14 @@ class TestDistributedCheckpointing(ShardedTensorTestBase):
|
||||
|
||||
@with_comms(init_rpc=False)
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@requires_accelerator_dist_backend()
|
||||
@requires_nccl()
|
||||
def test_default_metadata(self) -> None:
|
||||
device = f"{device_type}:{dist.get_rank()}"
|
||||
device = f"cuda:{dist.get_rank()}"
|
||||
spec = ChunkShardingSpec(
|
||||
dim=0,
|
||||
placements=[
|
||||
f"rank:0/{device_type}:0",
|
||||
f"rank:1/{device_type}:1",
|
||||
"rank:0/cuda:0",
|
||||
"rank:1/cuda:1",
|
||||
],
|
||||
)
|
||||
|
||||
@ -239,14 +233,12 @@ class TestDistributedFailure(ShardedTensorTestBase):
|
||||
def get_spec(self):
|
||||
return ChunkShardingSpec(
|
||||
dim=0,
|
||||
placements=[
|
||||
f"rank:{r}/{device_type}:{r}" for r in range(dist.get_world_size())
|
||||
],
|
||||
placements=[f"rank:{r}/cuda:{r}" for r in range(dist.get_world_size())],
|
||||
)
|
||||
|
||||
@with_comms(init_rpc=False)
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@requires_accelerator_dist_backend()
|
||||
@requires_nccl()
|
||||
def test_dummy_writer_works(self) -> None:
|
||||
state_dict = {
|
||||
"sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
|
||||
@ -258,7 +250,7 @@ class TestDistributedFailure(ShardedTensorTestBase):
|
||||
|
||||
@with_comms(init_rpc=False)
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@requires_accelerator_dist_backend()
|
||||
@requires_nccl()
|
||||
def test_dummy_reader_works(self) -> None:
|
||||
state_dict = {
|
||||
"sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
|
||||
@ -321,7 +313,7 @@ class TestDistributedFailure(ShardedTensorTestBase):
|
||||
|
||||
@with_comms(init_rpc=False)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@requires_accelerator_dist_backend()
|
||||
@requires_nccl()
|
||||
def test_save_error_handling(self) -> None:
|
||||
state_dict = {
|
||||
"sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
|
||||
@ -355,7 +347,7 @@ class TestDistributedFailure(ShardedTensorTestBase):
|
||||
|
||||
@with_comms(init_rpc=False)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@requires_accelerator_dist_backend()
|
||||
@requires_nccl()
|
||||
def test_load_error_handling(self) -> None:
|
||||
state_dict = {
|
||||
"sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
|
||||
|
||||
@ -106,7 +106,7 @@ class DTensorPlanner(DTensorTestBase):
|
||||
replicated_dt,
|
||||
submesh_sharded_dt,
|
||||
submesh_replicated_dt,
|
||||
).to(self.device_type)
|
||||
).cuda()
|
||||
|
||||
return (
|
||||
model,
|
||||
@ -135,7 +135,7 @@ class DTensorPlanner(DTensorTestBase):
|
||||
(
|
||||
'rdt',
|
||||
DTensor(
|
||||
local_tensor=tensor([4., 5., 6., 7.], device=f'{self.device_type}:0'),
|
||||
local_tensor=tensor([4., 5., 6., 7.], device='cuda:0'),
|
||||
device_mesh=DeviceMesh:([0, 1, 2, 3]),
|
||||
placements=[Replicate()]
|
||||
)
|
||||
@ -143,7 +143,7 @@ class DTensorPlanner(DTensorTestBase):
|
||||
(
|
||||
'sdt',
|
||||
DTensor(
|
||||
local_tensor=tensor([0.], device=f'{self.device_type}:0'),
|
||||
local_tensor=tensor([0.], device='cuda:0'),
|
||||
device_mesh=DeviceMesh:([0, 1, 2, 3]),
|
||||
placements=[Shard(dim=0)])
|
||||
),
|
||||
@ -151,7 +151,7 @@ class DTensorPlanner(DTensorTestBase):
|
||||
(
|
||||
'submesh_sdt',
|
||||
DTensor(
|
||||
local_tensor=tensor([8., 9.], device=f'{self.device_type}:0'),
|
||||
local_tensor=tensor([8., 9.], device='cuda:0'),
|
||||
device_mesh=DeviceMesh:([0, 2]),
|
||||
placements=[Shard(dim=0)]
|
||||
),
|
||||
@ -159,7 +159,7 @@ class DTensorPlanner(DTensorTestBase):
|
||||
(
|
||||
'submesh_rdt',
|
||||
DTensor(
|
||||
local_tensor=tensor([12., 13., 14., 15.], device=f'{self.device_type}:0'),
|
||||
local_tensor=tensor([12., 13., 14., 15.], device='cuda:0'),
|
||||
device_mesh=DeviceMesh:([0, 2]),
|
||||
placements=[Replicate()]
|
||||
)
|
||||
@ -189,7 +189,7 @@ class DTensorPlanner(DTensorTestBase):
|
||||
(
|
||||
'rdt',
|
||||
DTensor(
|
||||
local_tensor=tensor([40., 50., 60., 70.], device=f'{self.device_type}:0'),
|
||||
local_tensor=tensor([40., 50., 60., 70.], device='cuda:0'),
|
||||
device_mesh=DeviceMesh:([0, 1, 2, 3]),
|
||||
placements=[Replicate()],
|
||||
)
|
||||
@ -197,7 +197,7 @@ class DTensorPlanner(DTensorTestBase):
|
||||
(
|
||||
'sdt',
|
||||
DTensor(
|
||||
local_tensor=tensor([0.], device=f'{self.device_type}:0'),
|
||||
local_tensor=tensor([0.], device='cuda:0'),
|
||||
device_mesh=DeviceMesh:([0, 1, 2, 3]),
|
||||
placements=[Shard(dim=0)],
|
||||
)
|
||||
@ -205,14 +205,14 @@ class DTensorPlanner(DTensorTestBase):
|
||||
(
|
||||
'submesh_sdt',
|
||||
DTensor(
|
||||
local_tensor=tensor([80., 90.], device=f'{self.device_type}:0'),
|
||||
local_tensor=tensor([80., 90.], device='cuda:0'),
|
||||
device_mesh=DeviceMesh:([0, 2]),
|
||||
placements=[Shard(dim=0)]
|
||||
)
|
||||
),
|
||||
('submesh_rdt',
|
||||
DTensor(
|
||||
local_tensor=tensor([120., 130., 140., 150.], device=f'{self.device_type}:0'),
|
||||
local_tensor=tensor([120., 130., 140., 150.], device='cuda:0'),
|
||||
device_mesh=DeviceMesh:([0, 2]),
|
||||
placements=[Replicate()]
|
||||
)
|
||||
|
||||
@ -278,7 +278,7 @@ class TestDTensorReshardMeshChange(DTensorTestBase):
|
||||
"""
|
||||
Test dtensor checkpoint resharding with dtensor containing empty shards.
|
||||
"""
|
||||
tensor = torch.rand(1).to(self.device_type)
|
||||
tensor = torch.rand(1).cuda()
|
||||
mesh = init_device_mesh(self.device_type, (self.world_size,))
|
||||
dtensor = distribute_tensor(tensor, mesh, [Shard(0)])
|
||||
ref_state_dict = {"dtensor": dtensor}
|
||||
@ -288,7 +288,7 @@ class TestDTensorReshardMeshChange(DTensorTestBase):
|
||||
storage_writer=dist_cp.FileSystemWriter(path=self.temp_dir),
|
||||
)
|
||||
|
||||
tensor = torch.rand(1).to(self.device_type)
|
||||
tensor = torch.rand(1).cuda()
|
||||
mesh_2 = init_device_mesh(self.device_type, (2, self.world_size // 2))
|
||||
dtensor = distribute_tensor(tensor, mesh_2, [Shard(0), Shard(0)])
|
||||
state_dict = {"dtensor": dtensor}
|
||||
|
||||
@ -23,10 +23,7 @@ from torch.distributed.checkpoint import (
|
||||
)
|
||||
from torch.distributed.checkpoint._extension import ZStandard
|
||||
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
|
||||
from torch.testing._internal.common_distributed import (
|
||||
requires_accelerator_dist_backend,
|
||||
skip_if_lt_x_gpu,
|
||||
)
|
||||
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
@ -48,9 +45,6 @@ from torch.testing._internal.distributed.checkpoint_utils import (
|
||||
)
|
||||
|
||||
|
||||
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
|
||||
|
||||
|
||||
if TEST_WITH_DEV_DBG_ASAN:
|
||||
print(
|
||||
"Skip dev-asan as torch + multiprocessing spawn have known issues",
|
||||
@ -172,7 +166,7 @@ class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
|
||||
|
||||
@with_comms(init_rpc=False)
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@requires_accelerator_dist_backend()
|
||||
@requires_nccl()
|
||||
@parametrize("extensions", [None, [Rot13Example()], [ZStandard()]])
|
||||
def test_read_write_shard_tensor(self, extensions) -> None:
|
||||
paths = [tempfile.mkdtemp()]
|
||||
@ -184,8 +178,8 @@ class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
|
||||
spec = ChunkShardingSpec(
|
||||
dim=0,
|
||||
placements=[
|
||||
f"rank:0/{device_type}:0",
|
||||
f"rank:1/{device_type}:1",
|
||||
"rank:0/cuda:0",
|
||||
"rank:1/cuda:1",
|
||||
],
|
||||
)
|
||||
|
||||
@ -234,16 +228,14 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
|
||||
|
||||
def load_tensor(self, tensor: ShardedTensor) -> torch.Tensor:
|
||||
res = (
|
||||
torch.zeros(tensor.shape, device=f"{device_type}:0")
|
||||
if dist.get_rank() == 0
|
||||
else None
|
||||
torch.zeros(tensor.shape, device="cuda:0") if dist.get_rank() == 0 else None
|
||||
)
|
||||
tensor.gather(out=res)
|
||||
return res
|
||||
|
||||
@with_comms(init_rpc=False)
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@requires_accelerator_dist_backend()
|
||||
@requires_nccl()
|
||||
def test_load_with_different_shard_plan(self) -> None:
|
||||
path = self.get_file_path()
|
||||
|
||||
@ -255,18 +247,18 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
|
||||
ChunkShardingSpec(
|
||||
dim=0,
|
||||
placements=[
|
||||
f"rank:0/{device_type}:0",
|
||||
f"rank:1/{device_type}:1",
|
||||
"rank:0/cuda:0",
|
||||
"rank:1/cuda:1",
|
||||
],
|
||||
),
|
||||
# pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
|
||||
ChunkShardingSpec(
|
||||
dim=0,
|
||||
placements=[
|
||||
f"rank:0/{device_type}:0",
|
||||
f"rank:1/{device_type}:1",
|
||||
f"rank:1/{device_type}:1",
|
||||
f"rank:0/{device_type}:0",
|
||||
"rank:0/cuda:0",
|
||||
"rank:1/cuda:1",
|
||||
"rank:1/cuda:1",
|
||||
"rank:0/cuda:0",
|
||||
],
|
||||
),
|
||||
# This requires the tensors to be [10, 20]
|
||||
@ -275,27 +267,27 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
|
||||
ShardMetadata(
|
||||
shard_offsets=[0, 0],
|
||||
shard_sizes=[2, 20],
|
||||
placement=f"rank:0/{device_type}:0",
|
||||
placement="rank:0/cuda:0",
|
||||
),
|
||||
ShardMetadata(
|
||||
shard_offsets=[2, 0],
|
||||
shard_sizes=[1, 20],
|
||||
placement=f"rank:1/{device_type}:1",
|
||||
placement="rank:1/cuda:1",
|
||||
),
|
||||
ShardMetadata(
|
||||
shard_offsets=[3, 0],
|
||||
shard_sizes=[3, 20],
|
||||
placement=f"rank:0/{device_type}:0",
|
||||
placement="rank:0/cuda:0",
|
||||
),
|
||||
ShardMetadata(
|
||||
shard_offsets=[6, 0],
|
||||
shard_sizes=[3, 20],
|
||||
placement=f"rank:1/{device_type}:1",
|
||||
placement="rank:1/cuda:1",
|
||||
),
|
||||
ShardMetadata(
|
||||
shard_offsets=[9, 0],
|
||||
shard_sizes=[1, 20],
|
||||
placement=f"rank:0/{device_type}:0",
|
||||
placement="rank:0/cuda:0",
|
||||
),
|
||||
]
|
||||
),
|
||||
@ -305,12 +297,12 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
|
||||
ShardMetadata(
|
||||
shard_offsets=[0, 0],
|
||||
shard_sizes=[8, 20],
|
||||
placement=f"rank:1/{device_type}:1",
|
||||
placement="rank:1/cuda:1",
|
||||
),
|
||||
ShardMetadata(
|
||||
shard_offsets=[8, 0],
|
||||
shard_sizes=[2, 20],
|
||||
placement=f"rank:0/{device_type}:0",
|
||||
placement="rank:0/cuda:0",
|
||||
),
|
||||
]
|
||||
),
|
||||
@ -358,7 +350,7 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
|
||||
|
||||
@with_comms(init_rpc=False)
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@requires_accelerator_dist_backend()
|
||||
@requires_nccl()
|
||||
def test_load_rowwise_to_colwise(self) -> None:
|
||||
path = self.get_file_path()
|
||||
self.assertEqual(self.world_size, dist.get_world_size())
|
||||
@ -367,8 +359,8 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
|
||||
src_spec = ChunkShardingSpec(
|
||||
dim=0,
|
||||
placements=[
|
||||
f"rank:0/{device_type}:0",
|
||||
f"rank:1/{device_type}:1",
|
||||
"rank:0/cuda:0",
|
||||
"rank:1/cuda:1",
|
||||
],
|
||||
)
|
||||
|
||||
@ -376,8 +368,8 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
|
||||
dst_spec = ChunkShardingSpec(
|
||||
dim=1,
|
||||
placements=[
|
||||
f"rank:0/{device_type}:0",
|
||||
f"rank:1/{device_type}:1",
|
||||
"rank:0/cuda:0",
|
||||
"rank:1/cuda:1",
|
||||
],
|
||||
)
|
||||
|
||||
@ -385,14 +377,14 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
|
||||
shutil.rmtree(path, ignore_errors=True)
|
||||
os.makedirs(path)
|
||||
|
||||
model_to_save = MyShardedModel3(src_spec).to(dist.get_rank())
|
||||
model_to_save = MyShardedModel3(src_spec).cuda(dist.get_rank())
|
||||
model_to_save._register_state_dict_hook(state_dict_hook)
|
||||
state_dict_to_save = model_to_save.state_dict()
|
||||
|
||||
fs_writer = FileSystemWriter(path=path)
|
||||
save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer)
|
||||
|
||||
model_to_load = MyShardedModel3(dst_spec).to(dist.get_rank())
|
||||
model_to_load = MyShardedModel3(dst_spec).cuda(dist.get_rank())
|
||||
model_to_load._register_state_dict_hook(state_dict_hook)
|
||||
state_dict_to_load_to = model_to_load.state_dict()
|
||||
|
||||
@ -409,7 +401,7 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
|
||||
|
||||
@with_comms(init_rpc=False)
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@requires_accelerator_dist_backend()
|
||||
@requires_nccl()
|
||||
def test_save_load_bytes(self) -> None:
|
||||
path = self.get_file_path()
|
||||
|
||||
@ -428,7 +420,7 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
|
||||
|
||||
@with_comms(init_rpc=False)
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@requires_accelerator_dist_backend()
|
||||
@requires_nccl()
|
||||
def test_switch_between_sharded_tensor_to_tensor(self) -> None:
|
||||
path = self.get_file_path()
|
||||
tensor_size = 32
|
||||
@ -437,17 +429,17 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
|
||||
ChunkShardingSpec(
|
||||
dim=0,
|
||||
placements=[
|
||||
f"rank:0/{device_type}:0",
|
||||
f"rank:1/{device_type}:1",
|
||||
"rank:0/cuda:0",
|
||||
"rank:1/cuda:1",
|
||||
],
|
||||
),
|
||||
ChunkShardingSpec(
|
||||
dim=0,
|
||||
placements=[
|
||||
f"rank:0/{device_type}:0",
|
||||
f"rank:1/{device_type}:1",
|
||||
f"rank:1/{device_type}:1",
|
||||
f"rank:0/{device_type}:0",
|
||||
"rank:0/cuda:0",
|
||||
"rank:1/cuda:1",
|
||||
"rank:1/cuda:1",
|
||||
"rank:0/cuda:0",
|
||||
],
|
||||
),
|
||||
EnumerableShardingSpec(
|
||||
@ -455,12 +447,12 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
|
||||
ShardMetadata(
|
||||
shard_offsets=[0],
|
||||
shard_sizes=[8],
|
||||
placement=f"rank:1/{device_type}:1",
|
||||
placement="rank:1/cuda:1",
|
||||
),
|
||||
ShardMetadata(
|
||||
shard_offsets=[8],
|
||||
shard_sizes=[tensor_size - 8],
|
||||
placement=f"rank:0/{device_type}:0",
|
||||
placement="rank:0/cuda:0",
|
||||
),
|
||||
]
|
||||
),
|
||||
@ -469,12 +461,12 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
|
||||
ShardMetadata(
|
||||
shard_offsets=[0],
|
||||
shard_sizes=[10],
|
||||
placement=f"rank:0/{device_type}:0",
|
||||
placement="rank:0/cuda:0",
|
||||
),
|
||||
ShardMetadata(
|
||||
shard_offsets=[10],
|
||||
shard_sizes=[tensor_size - 10],
|
||||
placement=f"rank:1/{device_type}:1",
|
||||
placement="rank:1/cuda:1",
|
||||
),
|
||||
]
|
||||
),
|
||||
@ -520,15 +512,15 @@ class TestDistributedStateDictSaveLoadWithCaching(ShardedTensorTestBase):
|
||||
|
||||
@with_comms(init_rpc=False)
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@requires_accelerator_dist_backend()
|
||||
@requires_nccl()
|
||||
@with_temp_dir
|
||||
def test_read_write_shard_tensor(self) -> None:
|
||||
# pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
|
||||
spec = ChunkShardingSpec(
|
||||
dim=0,
|
||||
placements=[
|
||||
f"rank:0/{device_type}:0",
|
||||
f"rank:1/{device_type}:1",
|
||||
"rank:0/cuda:0",
|
||||
"rank:1/cuda:1",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@ -22,9 +22,6 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
|
||||
|
||||
|
||||
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
|
||||
|
||||
|
||||
class SimpleModelUneven(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
@ -43,7 +40,7 @@ class SimpleModelUneven(nn.Module):
|
||||
return x
|
||||
|
||||
def get_input(self):
|
||||
return torch.rand(4, 5, device=device_type)
|
||||
return torch.rand(4, 5, device="cuda")
|
||||
|
||||
|
||||
class TestFormatUtils(DTensorTestBase):
|
||||
@ -90,7 +87,7 @@ class TestFormatUtils(DTensorTestBase):
|
||||
|
||||
# Load into a sharded model
|
||||
device_mesh = init_device_mesh(self.device_type, (self.world_size,))
|
||||
model = SimpleModelUneven().to(self.device_type)
|
||||
model = SimpleModelUneven().cuda()
|
||||
model = FSDP(
|
||||
model,
|
||||
device_mesh=device_mesh,
|
||||
|
||||
@ -21,8 +21,7 @@ from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
|
||||
class FsdpModelStateCheckpoint(DTensorTestBase):
|
||||
@property
|
||||
def backend(self):
|
||||
curr_backend = dist.get_default_backend_for_device(self.device_type)
|
||||
return f"cpu:gloo,{self.device_type}:{curr_backend}"
|
||||
return "cpu:gloo,cuda:nccl"
|
||||
|
||||
def _test_fsdp_model_state(self, process_group) -> None:
|
||||
CHECKPOINT_DIR = self.temp_dir
|
||||
@ -68,8 +67,8 @@ class FsdpModelStateCheckpoint(DTensorTestBase):
|
||||
self.assertEqual(model.weight, model_2.weight)
|
||||
self.assertEqual(model.bias, model_2.bias)
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@with_temp_dir
|
||||
def test_fsdp_model_state_no_resharding(self):
|
||||
self._test_fsdp_model_state(process_group=None)
|
||||
@ -89,8 +88,8 @@ class FsdpModelStateCheckpoint(DTensorTestBase):
|
||||
|
||||
return my_fsdp
|
||||
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@with_temp_dir
|
||||
def test_fsdp_model_state_with_resharding(self):
|
||||
self._test_fsdp_model_state(process_group=self._create_new_dist_group())
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.distributed.checkpoint as dcp
|
||||
import torch.nn as nn
|
||||
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
|
||||
@ -29,9 +28,8 @@ class FsdpOptimStateCheckpoint(DTensorTestBase):
|
||||
layer3_weight_dim = self.world_size * 3
|
||||
|
||||
class TestDummyModel(torch.nn.Module):
|
||||
def __init__(self, device_type) -> None:
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.device_type = device_type
|
||||
self.net1 = nn.Sequential(nn.Linear(8, layer1_weight_dim), nn.ReLU())
|
||||
self.net2 = nn.Sequential(
|
||||
nn.Linear(layer1_weight_dim, layer2_weight_dim), nn.ReLU()
|
||||
@ -44,18 +42,17 @@ class FsdpOptimStateCheckpoint(DTensorTestBase):
|
||||
return self.net3(self.net2(self.net1(x)))
|
||||
|
||||
def get_input(self):
|
||||
return torch.rand(8, 8, device=self.device_type)
|
||||
return torch.rand(8, 8, device="cuda")
|
||||
|
||||
model = TestDummyModel(self.device_type).to(self.device_type)
|
||||
model = TestDummyModel().cuda()
|
||||
return model
|
||||
|
||||
@property
|
||||
def backend(self):
|
||||
curr_backend = dist.get_default_backend_for_device(self.device_type)
|
||||
return f"cpu:gloo,{self.device_type}:{curr_backend}"
|
||||
return "cpu:gloo,cuda:nccl"
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@with_temp_dir
|
||||
@parametrize("pass_planner", [True, False])
|
||||
def test_load_sharded_optimizer_state_dict(self, pass_planner) -> None:
|
||||
|
||||
@ -30,7 +30,7 @@ class TestFsdpTpCheckpointConversion(DTensorTestBase):
|
||||
def test_fsdp_to_tp(self):
|
||||
CHECKPOINT_DIR = self.temp_dir
|
||||
|
||||
model = MLPModule(self.device_type).to(self.rank)
|
||||
model = MLPModule(self.device_type).cuda(self.rank)
|
||||
# create a FSDP wrapped model
|
||||
fsdp_model = FSDP(model, use_orig_params=True)
|
||||
|
||||
@ -49,7 +49,7 @@ class TestFsdpTpCheckpointConversion(DTensorTestBase):
|
||||
# create a TP wrapped model
|
||||
mesh_shape = (self.world_size,)
|
||||
device_mesh = init_device_mesh(self.device_type, mesh_shape)
|
||||
model = MLPModule(self.device_type).to(self.rank)
|
||||
model = MLPModule(self.device_type).cuda(self.rank)
|
||||
# Parallelize the module based on the given Parallel Style.
|
||||
parallelize_plan = {
|
||||
"net1": ColwiseParallel(),
|
||||
@ -60,7 +60,7 @@ class TestFsdpTpCheckpointConversion(DTensorTestBase):
|
||||
|
||||
# Update the parameters so tp_model.state_dict() will be different from fsdp_model.state_dict().
|
||||
torch.manual_seed(0)
|
||||
inp = torch.rand(20, 10).to(self.rank)
|
||||
inp = torch.rand(20, 10).cuda(self.rank)
|
||||
output = tp_model(inp)
|
||||
output.sum().backward()
|
||||
optimizer.step()
|
||||
|
||||
@ -587,7 +587,7 @@ class TestDTensorReshardMeshChange(DTensorTestBase):
|
||||
print("safetensors not installed")
|
||||
return
|
||||
|
||||
tensor = torch.rand(1).to(self.device_type)
|
||||
tensor = torch.rand(1).cuda()
|
||||
mesh = init_device_mesh(self.device_type, (self.world_size,))
|
||||
dtensor = distribute_tensor(tensor, mesh, [Shard(0)])
|
||||
ref_state_dict = {"dtensor": dtensor}
|
||||
@ -599,7 +599,7 @@ class TestDTensorReshardMeshChange(DTensorTestBase):
|
||||
),
|
||||
)
|
||||
|
||||
tensor = torch.rand(1).to(self.device_type)
|
||||
tensor = torch.rand(1).cuda()
|
||||
mesh_2 = init_device_mesh(self.device_type, (2, self.world_size // 2))
|
||||
dtensor = distribute_tensor(tensor, mesh_2, [Shard(0), Shard(0)])
|
||||
state_dict = {"dtensor": dtensor}
|
||||
|
||||
@ -2,7 +2,6 @@
from copy import deepcopy

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
import torch.nn as nn
import torch.nn.functional as F
@ -30,9 +29,6 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir


device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"


class SimpleModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
@ -48,7 +44,7 @@ class SimpleModel(torch.nn.Module):
return x

def get_input(self):
return torch.rand(4, 5, device=device_type)
return torch.rand(4, 5, device="cuda")


class SimpleModelUneven(torch.nn.Module):
@ -68,17 +64,16 @@ class SimpleModelUneven(torch.nn.Module):
return x

def get_input(self):
return torch.rand(4, 5, device=device_type)
return torch.rand(4, 5, device="cuda")


class TestHSDPCheckpoint(DTensorTestBase):
@property
def backend(self):
curr_backend = dist.get_default_backend_for_device(self.device_type)
return f"cpu:gloo,{self.device_type}:{curr_backend}"
return "cpu:gloo,cuda:nccl"

@skip_if_lt_x_gpu(4)
@with_comms
@skip_if_lt_x_gpu(4)
@with_temp_dir
@parametrize("is_even_sharded_model", [True, False])
def test_hsdp_checkpoint(self, is_even_sharded_model) -> None:
@ -87,7 +82,7 @@ class TestHSDPCheckpoint(DTensorTestBase):

mesh_2d = init_device_mesh(self.device_type, (2, self.world_size // 2))
model = FSDP(
simple_model().to(self.device_type),
simple_model().cuda(),
sharding_strategy=ShardingStrategy.HYBRID_SHARD,
device_mesh=mesh_2d,
)
@ -135,8 +130,8 @@ class TestHSDPCheckpoint(DTensorTestBase):
self.assertEqual(v1.placements, v2.placements)
self.assertEqual(v1.to_local(), v2.to_local())

@skip_if_lt_x_gpu(4)
@with_comms
@skip_if_lt_x_gpu(4)
@with_temp_dir
@parametrize("is_even_sharded_model", [True, False])
def test_hsdp_fsdp_checkpoint_conversion(self, is_even_sharded_model) -> None:
@ -146,7 +141,7 @@ class TestHSDPCheckpoint(DTensorTestBase):
# save the hsdp model state_dict
mesh_2d = init_device_mesh(self.device_type, (2, self.world_size // 2))
hsdp_model = FSDP(
simple_model().to(self.device_type),
simple_model().cuda(),
sharding_strategy=ShardingStrategy.HYBRID_SHARD,
device_mesh=mesh_2d,
)
@ -164,7 +159,7 @@ class TestHSDPCheckpoint(DTensorTestBase):
# initialize a fsdp model to load checkpoint into
mesh_1d = init_device_mesh(self.device_type, (self.world_size,))
fsdp_model = FSDP(
simple_model().to(self.device_type),
simple_model().cuda(),
device_mesh=mesh_1d,
)
FSDP.set_state_dict_type(

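Note: both sides of this file need a module-level device string for constructing inputs (`torch.rand(4, 5, device=...)`); the new side derives it once from the active accelerator instead of hard-coding "cuda". A small sketch of that pattern as used by the SimpleModel-style helpers (only the constant and the get_input line come from the hunks; the rest of the class body is trimmed):

    import torch

    device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"

    class SimpleModel(torch.nn.Module):
        def get_input(self):
            # was: torch.rand(4, 5, device="cuda")
            return torch.rand(4, 5, device=device_type)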
@ -1,13 +1,11 @@
# Owner(s): ["oncall: distributed"]

import logging
import unittest
from datetime import timedelta
from typing import Optional
from unittest.mock import MagicMock, patch

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._shard.sharded_tensor import (
init_from_local_shards,
@ -25,11 +23,10 @@ from torch.distributed.checkpoint._pg_transport import (
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.distributed_c10d import _get_default_group
from torch.distributed.tensor import DTensor
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
at_least_x_gpu,
HAS_ACCELERATOR,
MultiProcContinuousTest,
requires_accelerator_dist_backend,
requires_nccl,
)
from torch.testing._internal.common_utils import (
run_tests,
@ -38,8 +35,6 @@ from torch.testing._internal.common_utils import (
)


device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"

logger = logging.getLogger(__name__)


@ -165,9 +160,9 @@ def _test_pg_transport_with_mixed_content(self, device) -> None:


def _test_pg_transport_with_sharded_tensor(self, device) -> None:
# Set current accelerator device for NCCL/XCCL
if device.type == "cuda" or device.type == "xpu":
torch.accelerator.set_device_index(device)
# Set current CUDA device for NCCL
if device.type == "cuda":
torch.cuda.set_device(device)

state_dict = _create_sharded_tensor_state_dict(self.rank, self.world_size, device)
transport = PGTransport(_get_default_group(), timedelta(seconds=10), device)
@ -232,36 +227,34 @@ class PgTransportCPU(MultiProcContinuousTest):
_test_pg_transport_with_sharded_tensor(self, self.device)


class PgTransportGPU(MultiProcContinuousTest):
class PgTransportCUDA(MultiProcContinuousTest):
world_size = 2
timeout: timedelta = timedelta(seconds=20)

@classmethod
def backend_str(cls) -> Optional[str]:
return dist.get_default_backend_for_device(cls.device_type())
return "nccl"

@classmethod
def device_type(cls) -> str:
return "cuda"

@property
def device(self) -> torch.device:
return torch.device(f"{self.device_type()}:{self.rank}")

@requires_accelerator_dist_backend()
@skip_but_pass_in_sandcastle_if(
not at_least_x_gpu(2), "test requires 2+ accelerators"
)
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_pg_transport(self) -> None:
_test_pg_transport(self, self.device)

@requires_accelerator_dist_backend()
@skip_but_pass_in_sandcastle_if(
not at_least_x_gpu(2), "test requires 2+ accelerators"
)
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_pg_transport_with_mixed_content(self) -> None:
_test_pg_transport_with_mixed_content(self, self.device)

@requires_accelerator_dist_backend()
@skip_but_pass_in_sandcastle_if(
not at_least_x_gpu(2), "test requires 2+ accelerators"
)
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_pg_transport_with_sharded_tensor(self) -> None:
_test_pg_transport_with_sharded_tensor(self, self.device)

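Note: besides swapping the @requires_nccl()/TEST_MULTIGPU gates for @requires_accelerator_dist_backend()/at_least_x_gpu, the transport helper now selects the per-rank device through torch.accelerator instead of torch.cuda. A hedged sketch of the device setup performed before building the transport (the call arguments mirror the hunk above; the helper name is illustrative):

    from datetime import timedelta

    import torch
    from torch.distributed.checkpoint._pg_transport import PGTransport
    from torch.distributed.distributed_c10d import _get_default_group

    def make_transport(device: torch.device) -> PGTransport:
        # NCCL/XCCL expect the per-rank device to be selected up front.
        if device.type == "cuda" or device.type == "xpu":
            torch.accelerator.set_device_index(device)
        return PGTransport(_get_default_group(), timedelta(seconds=10), device)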
@ -585,10 +578,13 @@ class TestPGTransportEdgeCases(TestCase):
self.pg.send = MagicMock(return_value=self.mock_work)
self.pg.recv = MagicMock(return_value=self.mock_work)

@unittest.skipIf(not HAS_ACCELERATOR, "No accelerator")
def test_send_checkpoint_with_cpu_tensors(self):
"""Test send_checkpoint with CPU tensors when device is accelerator."""
device = torch.device(f"{device_type}:0")
"""Test send_checkpoint with CPU tensors when device is CUDA."""
# Skip if CUDA is not available
if not torch.cuda.is_available():
self.skipTest("CUDA not available")

device = torch.device("cuda:0")

# Create a state dict with CPU tensors
state_dict = {
@ -596,7 +592,7 @@ class TestPGTransportEdgeCases(TestCase):
"cpu_tensor2": torch.randn(3, 4),
}

# Create transport with accelerator device
# Create transport with CUDA device
transport = PGTransport(self.pg, self.timeout, device)

# Call send_checkpoint

@ -37,7 +37,7 @@ class TestSaveAndLoadAPI(DTensorTestBase):
@skip_if_lt_x_gpu(4)
@with_temp_dir
def test_auto_detect(self):
model = FSDP(MyTestModule().to(self.device_type))
model = FSDP(MyTestModule().cuda())
device_mesh = init_device_mesh(self.device_type, (self.world_size,))
model = FSDP(model, device_mesh=device_mesh)
dcp.save(model.state_dict(), checkpoint_id=os.path.join(self.temp_dir, "first"))

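Note: test_auto_detect appears to exercise DCP's checkpoint_id-based auto-detection — the model is saved with only a path and loaded back the same way, letting dcp pick the storage writer/reader. A bare-bones sketch of that round trip (the model and directory are assumed; only the dcp.save call with checkpoint_id is taken from the hunk above):

    import os

    import torch.distributed.checkpoint as dcp

    def save_and_reload(model, temp_dir: str) -> None:
        ckpt = os.path.join(temp_dir, "first")
        dcp.save(model.state_dict(), checkpoint_id=ckpt)   # writer inferred from the path
        state_dict = model.state_dict()
        dcp.load(state_dict, checkpoint_id=ckpt)           # reader inferred the same way
        model.load_state_dict(state_dict)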
@ -3,7 +3,6 @@
|
||||
import dataclasses
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from datetime import timedelta
|
||||
|
||||
import torch
|
||||
@ -19,21 +18,14 @@ from torch.distributed._tensor.placement_types import Replicate, Shard
|
||||
from torch.distributed.checkpoint._state_dict_stager import StateDictStager
|
||||
from torch.distributed.checkpoint.staging import _ReplicationStager
|
||||
from torch.distributed.tensor import DeviceMesh, distribute_tensor
|
||||
from torch.testing._internal.common_distributed import (
|
||||
HAS_ACCELERATOR,
|
||||
requires_accelerator_dist_backend,
|
||||
skip_if_lt_x_gpu,
|
||||
)
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_utils import requires_cuda, run_tests, TestCase
|
||||
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
DTensorTestBase,
|
||||
with_comms,
|
||||
)
|
||||
|
||||
|
||||
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
|
||||
|
||||
|
||||
def create_cpu_state_dict(state_dict):
|
||||
cpu_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
@ -41,16 +33,16 @@ def create_cpu_state_dict(state_dict):
|
||||
return cpu_state_dict
|
||||
|
||||
|
||||
def compare_state_dicts(gpu_state_dict, cpu_state_dict, rtol=1e-5, atol=1e-8):
|
||||
def compare_state_dicts(cuda_state_dict, cpu_state_dict, rtol=1e-5, atol=1e-8):
|
||||
"""
|
||||
Compare if two state dictionaries (one on GPU, one on CPU) are otherwise the same.
|
||||
Compare if two state dictionaries (one on CUDA, one on CPU) are otherwise the same.
|
||||
|
||||
This function checks if the tensors in both state dictionaries have the same values,
|
||||
shapes, dtypes, etc., ignoring the device difference. It also checks if tensors that
|
||||
share storage in one state dict also share storage in the other.
|
||||
|
||||
Args:
|
||||
gpu_state_dict: The state dictionary with tensors on GPU
|
||||
cuda_state_dict: The state dictionary with tensors on CUDA
|
||||
cpu_state_dict: The state dictionary with tensors on CPU
|
||||
rtol: Relative tolerance for comparing tensor values
|
||||
atol: Absolute tolerance for comparing tensor values
|
||||
@ -60,65 +52,65 @@ def compare_state_dicts(gpu_state_dict, cpu_state_dict, rtol=1e-5, atol=1e-8):
|
||||
str: Error message if the state dictionaries are not equivalent, empty string otherwise
|
||||
"""
|
||||
# Track storage data pointers to check storage sharing
|
||||
gpu_storage_ptrs = {}
|
||||
cuda_storage_ptrs = {}
|
||||
cpu_storage_ptrs = {}
|
||||
|
||||
def compare_objects(gpu_obj, cpu_obj, path=""):
|
||||
def compare_objects(cuda_obj, cpu_obj, path=""):
|
||||
# If objects are tensors, compare them
|
||||
if isinstance(gpu_obj, torch.Tensor) and isinstance(cpu_obj, torch.Tensor):
|
||||
if isinstance(cuda_obj, torch.Tensor) and isinstance(cpu_obj, torch.Tensor):
|
||||
# Check if devices are as expected
|
||||
if gpu_obj.device.type != device_type:
|
||||
if cuda_obj.device.type != "cuda":
|
||||
return (
|
||||
False,
|
||||
f"Expected accelerator tensor, got {gpu_obj.device.type} tensor at {path}",
|
||||
f"Expected CUDA tensor, got {cuda_obj.device.type} tensor at {path}",
|
||||
)
|
||||
if cpu_obj.device.type != "cpu":
|
||||
return (
|
||||
False,
|
||||
f"Expected CPU tensor, got {cpu_obj.device.type} tensor at {path}",
|
||||
)
|
||||
if gpu_obj.storage_offset() != cpu_obj.storage_offset():
|
||||
if cuda_obj.storage_offset() != cpu_obj.storage_offset():
|
||||
return (
|
||||
False,
|
||||
f"Storage offset mismatch at {path}: {gpu_obj.storage_offset()} vs {cpu_obj.storage_offset()}",
|
||||
f"Storage offset mismatch at {path}: {cuda_obj.storage_offset()} vs {cpu_obj.storage_offset()}",
|
||||
)
|
||||
|
||||
if not torch.equal(gpu_obj.cpu(), cpu_obj):
|
||||
if not torch.equal(cuda_obj.cpu(), cpu_obj):
|
||||
return (
|
||||
False,
|
||||
f"Tensors are not same at {path}",
|
||||
)
|
||||
|
||||
# Track storage sharing
|
||||
gpu_storage_ptr = gpu_obj.storage().data_ptr()
|
||||
cuda_storage_ptr = cuda_obj.storage().data_ptr()
|
||||
cpu_storage_ptr = cpu_obj.storage().data_ptr()
|
||||
|
||||
if gpu_storage_ptr in gpu_storage_ptrs:
|
||||
# This GPU tensor shares storage with another tensor
|
||||
if cuda_storage_ptr in cuda_storage_ptrs:
|
||||
# This CUDA tensor shares storage with another tensor
|
||||
# Check if the corresponding CPU tensors also share storage
|
||||
if cpu_storage_ptr != gpu_storage_ptrs[gpu_storage_ptr]:
|
||||
if cpu_storage_ptr != cuda_storage_ptrs[cuda_storage_ptr]:
|
||||
return (
|
||||
False,
|
||||
f"Storage sharing mismatch: GPU tensors share storage but CPU tensors don't at {path}",
|
||||
f"Storage sharing mismatch: CUDA tensors share storage but CPU tensors don't at {path}",
|
||||
)
|
||||
else:
|
||||
# First time seeing this storage
|
||||
gpu_storage_ptrs[gpu_storage_ptr] = cpu_storage_ptr
|
||||
cpu_storage_ptrs[cpu_storage_ptr] = gpu_storage_ptr
|
||||
cuda_storage_ptrs[cuda_storage_ptr] = cpu_storage_ptr
|
||||
cpu_storage_ptrs[cpu_storage_ptr] = cuda_storage_ptr
|
||||
|
||||
return True, ""
|
||||
|
||||
# If objects are dictionaries, compare them recursively
|
||||
elif isinstance(gpu_obj, dict) and isinstance(cpu_obj, dict):
|
||||
if gpu_obj.keys() != cpu_obj.keys():
|
||||
elif isinstance(cuda_obj, dict) and isinstance(cpu_obj, dict):
|
||||
if cuda_obj.keys() != cpu_obj.keys():
|
||||
return (
|
||||
False,
|
||||
f"Dictionary keys mismatch at {path}: {gpu_obj.keys()} vs {cpu_obj.keys()}",
|
||||
f"Dictionary keys mismatch at {path}: {cuda_obj.keys()} vs {cpu_obj.keys()}",
|
||||
)
|
||||
|
||||
for key in gpu_obj:
|
||||
for key in cuda_obj:
|
||||
result, error = compare_objects(
|
||||
gpu_obj[key], cpu_obj[key], f"{path}.{key}" if path else key
|
||||
cuda_obj[key], cpu_obj[key], f"{path}.{key}" if path else key
|
||||
)
|
||||
if not result:
|
||||
return False, error
|
||||
@ -126,37 +118,37 @@ def compare_state_dicts(gpu_state_dict, cpu_state_dict, rtol=1e-5, atol=1e-8):
|
||||
return True, ""
|
||||
|
||||
# If objects are lists, tuples, or sets, compare them recursively
|
||||
elif isinstance(gpu_obj, (list, tuple, set)) and isinstance(
|
||||
elif isinstance(cuda_obj, (list, tuple, set)) and isinstance(
|
||||
cpu_obj, (list, tuple, set)
|
||||
):
|
||||
if len(gpu_obj) != len(cpu_obj):
|
||||
if len(cuda_obj) != len(cpu_obj):
|
||||
return (
|
||||
False,
|
||||
f"Collection length mismatch at {path}: {len(gpu_obj)} vs {len(cpu_obj)}",
|
||||
f"Collection length mismatch at {path}: {len(cuda_obj)} vs {len(cpu_obj)}",
|
||||
)
|
||||
if type(gpu_obj) != type(cpu_obj):
|
||||
if type(cuda_obj) != type(cpu_obj):
|
||||
return (
|
||||
False,
|
||||
f"Collection type mismatch at {path}: {type(gpu_obj)} vs {type(cpu_obj)}",
|
||||
f"Collection type mismatch at {path}: {type(cuda_obj)} vs {type(cpu_obj)}",
|
||||
)
|
||||
|
||||
for i, (gpu_item, cpu_item) in enumerate(zip(gpu_obj, cpu_obj)):
|
||||
result, error = compare_objects(gpu_item, cpu_item, f"{path}[{i}]")
|
||||
for i, (cuda_item, cpu_item) in enumerate(zip(cuda_obj, cpu_obj)):
|
||||
result, error = compare_objects(cuda_item, cpu_item, f"{path}[{i}]")
|
||||
if not result:
|
||||
return False, error
|
||||
|
||||
return True, ""
|
||||
|
||||
# If objects are custom classes, compare their attributes
|
||||
elif hasattr(gpu_obj, "__dict__") and hasattr(cpu_obj, "__dict__"):
|
||||
if type(gpu_obj) != type(cpu_obj):
|
||||
elif hasattr(cuda_obj, "__dict__") and hasattr(cpu_obj, "__dict__"):
|
||||
if type(cuda_obj) != type(cpu_obj):
|
||||
return (
|
||||
False,
|
||||
f"Object type mismatch at {path}: {type(gpu_obj)} vs {type(cpu_obj)}",
|
||||
f"Object type mismatch at {path}: {type(cuda_obj)} vs {type(cpu_obj)}",
|
||||
)
|
||||
|
||||
result, error = compare_objects(
|
||||
gpu_obj.__dict__, cpu_obj.__dict__, f"{path}.__dict__"
|
||||
cuda_obj.__dict__, cpu_obj.__dict__, f"{path}.__dict__"
|
||||
)
|
||||
if not result:
|
||||
return False, error
|
||||
@ -165,18 +157,18 @@ def compare_state_dicts(gpu_state_dict, cpu_state_dict, rtol=1e-5, atol=1e-8):
|
||||
|
||||
# For other types, use direct equality comparison
|
||||
else:
|
||||
if type(gpu_obj) != type(cpu_obj):
|
||||
if type(cuda_obj) != type(cpu_obj):
|
||||
return (
|
||||
False,
|
||||
f"Type mismatch at {path}: {type(gpu_obj)} vs {type(cpu_obj)}",
|
||||
f"Type mismatch at {path}: {type(cuda_obj)} vs {type(cpu_obj)}",
|
||||
)
|
||||
if gpu_obj != cpu_obj:
|
||||
return False, f"Value mismatch at {path}: {gpu_obj} vs {cpu_obj}"
|
||||
if cuda_obj != cpu_obj:
|
||||
return False, f"Value mismatch at {path}: {cuda_obj} vs {cpu_obj}"
|
||||
|
||||
return True, ""
|
||||
|
||||
# Start the recursive comparison
|
||||
result, error = compare_objects(gpu_state_dict, cpu_state_dict)
|
||||
result, error = compare_objects(cuda_state_dict, cpu_state_dict)
|
||||
return result, error
|
||||
|
||||
|
||||
@ -206,7 +198,7 @@ class FrozenDataClass:
|
||||
|
||||
|
||||
class TestStateDictStager(TestCase):
|
||||
@unittest.skipIf(not HAS_ACCELERATOR, "No accelerator")
|
||||
@requires_cuda
|
||||
def test_views(self):
|
||||
test_configs = [
|
||||
(False, False), # pin_memory=False, share_memory=False,
|
||||
@ -216,9 +208,9 @@ class TestStateDictStager(TestCase):
|
||||
]
|
||||
for pin_memory, share_memory in test_configs:
|
||||
with self.subTest(pin_memory=pin_memory, share_memory=share_memory):
|
||||
tensor1 = torch.randn(4, 4).to(device_type)
|
||||
tensor1 = torch.randn(4, 4).cuda()
|
||||
tensor2 = tensor1.view(16)
|
||||
tensor3 = torch.randn(4, 4).to(device_type)
|
||||
tensor3 = torch.randn(4, 4).cuda()
|
||||
state_dict = {
|
||||
"tensor1": tensor1,
|
||||
"tensor2": tensor2,
|
||||
@ -261,7 +253,7 @@ class TestStateDictStager(TestCase):
|
||||
assert num_bytes == expected_bytes, (
|
||||
f"Expected {expected_bytes} bytes, got {num_bytes}"
|
||||
)
|
||||
# Verify that the CPU state dict is equivalent to the original GPU state dict
|
||||
# Verify that the CPU state dict is equivalent to the original CUDA state dict
|
||||
result, error = compare_state_dicts(state_dict, cpu_state_dict)
|
||||
assert result, f"State dicts are not equivalent: {error}"
|
||||
|
||||
@ -281,7 +273,7 @@ class TestStateDictStager(TestCase):
|
||||
== recursive["type"].tensor1.storage().data_ptr()
|
||||
)
|
||||
|
||||
@unittest.skipIf(not HAS_ACCELERATOR, "No accelerator")
|
||||
@requires_cuda
|
||||
def test_caching(self):
|
||||
"""
|
||||
Test that the StateDictStager correctly caches and reuses storages.
|
||||
@ -295,9 +287,9 @@ class TestStateDictStager(TestCase):
|
||||
for pin_memory, share_memory in test_configs:
|
||||
with self.subTest(pin_memory=pin_memory, share_memory=share_memory):
|
||||
# Create test tensors and state dict
|
||||
tensor1 = torch.randn(4, 4).to(device_type)
|
||||
tensor1 = torch.randn(4, 4).cuda()
|
||||
tensor2 = tensor1.view(16)
|
||||
tensor3 = torch.randn(4, 4).to(device_type)
|
||||
tensor3 = torch.randn(4, 4).cuda()
|
||||
state_dict = {
|
||||
"tensor1": tensor1,
|
||||
"tensor2": tensor2,
|
||||
@ -373,14 +365,14 @@ class TestStateDictStager(TestCase):
|
||||
"Updated values should be reflected in the cached state dict"
|
||||
)
|
||||
|
||||
@unittest.skipIf(not HAS_ACCELERATOR, "No accelerator")
|
||||
@requires_cuda
|
||||
def test_tensor_attrs(self):
|
||||
"""
|
||||
Test that tensor attributes are preserved during stage with StateDictStager.
|
||||
"""
|
||||
tensor1 = torch.randn(4, 4).to(device_type)
|
||||
tensor1 = torch.randn(4, 4).cuda()
|
||||
tensor2 = tensor1.view(16)
|
||||
tensor3 = torch.randn(4, 4).to(device_type)
|
||||
tensor3 = torch.randn(4, 4).cuda()
|
||||
|
||||
# Add custom attributes to tensors
|
||||
tensor1.a = 42
|
||||
@ -419,22 +411,18 @@ class TestStateDictStager(TestCase):
|
||||
"Tensor attribute 'c' has incorrect value"
|
||||
)
|
||||
|
||||
@unittest.skipIf(not HAS_ACCELERATOR, "No accelerator")
|
||||
@requires_cuda
|
||||
def test_different_dtypes(self):
|
||||
"""
|
||||
Test that StateDictStager works correctly with tensors of different data types.
|
||||
"""
|
||||
# Create tensors with different dtypes
|
||||
tensors = {
|
||||
"float32": torch.randn(4, 4, dtype=torch.float32).to(device_type),
|
||||
"float64": torch.randn(4, 4, dtype=torch.float64).to(device_type),
|
||||
"int32": torch.randint(-100, 100, (4, 4), dtype=torch.int32).to(
|
||||
device_type
|
||||
),
|
||||
"int64": torch.randint(-100, 100, (4, 4), dtype=torch.int64).to(
|
||||
device_type
|
||||
),
|
||||
"bool": torch.randint(0, 2, (4, 4), dtype=torch.bool).to(device_type),
|
||||
"float32": torch.randn(4, 4, dtype=torch.float32).cuda(),
|
||||
"float64": torch.randn(4, 4, dtype=torch.float64).cuda(),
|
||||
"int32": torch.randint(-100, 100, (4, 4), dtype=torch.int32).cuda(),
|
||||
"int64": torch.randint(-100, 100, (4, 4), dtype=torch.int64).cuda(),
|
||||
"bool": torch.randint(0, 2, (4, 4), dtype=torch.bool).cuda(),
|
||||
}
|
||||
|
||||
# Create a state dict with these tensors
|
||||
@ -459,7 +447,7 @@ class TestStateDictStager(TestCase):
|
||||
f"Tensor {dtype_name} has incorrect values",
|
||||
)
|
||||
|
||||
@unittest.skipIf(not HAS_ACCELERATOR, "No accelerator")
|
||||
@requires_cuda
|
||||
def test_empty_tensors(self):
|
||||
"""
|
||||
Test that StateDictStager works correctly with empty tensors.
|
||||
@ -474,17 +462,15 @@ class TestStateDictStager(TestCase):
|
||||
with self.subTest(pin_memory=pin_memory, share_memory=share_memory):
|
||||
# Create empty tensors with different shapes
|
||||
tensors = {
|
||||
"empty_0d": torch.tensor([], dtype=torch.float32).to(device_type),
|
||||
"empty_1d": torch.tensor([], dtype=torch.float32)
|
||||
.reshape(0)
|
||||
.to(device_type),
|
||||
"empty_0d": torch.tensor([], dtype=torch.float32).cuda(),
|
||||
"empty_1d": torch.tensor([], dtype=torch.float32).reshape(0).cuda(),
|
||||
"empty_2d": torch.tensor([], dtype=torch.float32)
|
||||
.reshape(0, 0)
|
||||
.to(device_type),
|
||||
.cuda(),
|
||||
"empty_3d": torch.tensor([], dtype=torch.float32)
|
||||
.reshape(0, 0, 0)
|
||||
.to(device_type),
|
||||
"zero_dim": torch.tensor(0.0).to(device_type), # scalar tensor
|
||||
.cuda(),
|
||||
"zero_dim": torch.tensor(0.0).cuda(), # scalar tensor
|
||||
}
|
||||
|
||||
# Create a state dict with these tensors
|
||||
@ -514,13 +500,13 @@ class TestStateDictStager(TestCase):
|
||||
f"Tensor {tensor_name} has incorrect dtype",
|
||||
)
|
||||
|
||||
@unittest.skipIf(not HAS_ACCELERATOR, "No accelerator")
|
||||
@requires_cuda
|
||||
def test_complex_storage_sharing(self):
|
||||
"""
|
||||
Test that StateDictStager correctly handles complex storage sharing scenarios.
|
||||
"""
|
||||
# Create a base tensor
|
||||
base_tensor = torch.randn(10, 10).to(device_type)
|
||||
base_tensor = torch.randn(10, 10).cuda()
|
||||
|
||||
# Create various views and slices that share storage
|
||||
view1 = base_tensor.view(100)
|
||||
@ -596,13 +582,13 @@ class TestStateDictStager(TestCase):
|
||||
"slice3 should reflect changes to base",
|
||||
)
|
||||
|
||||
@unittest.skipIf(not HAS_ACCELERATOR, "No accelerator")
|
||||
@requires_cuda
|
||||
def test_dataclasses(self):
|
||||
# Create tensors
|
||||
tensor1 = torch.randn(4, 4).to(device_type)
|
||||
tensor2 = torch.randn(8, 8).to(device_type)
|
||||
tensor3 = torch.randn(2, 6).to(device_type)
|
||||
tensor4 = torch.randn(3, 5).to(device_type)
|
||||
tensor1 = torch.randn(4, 4).cuda()
|
||||
tensor2 = torch.randn(8, 8).cuda()
|
||||
tensor3 = torch.randn(2, 6).cuda()
|
||||
tensor4 = torch.randn(3, 5).cuda()
|
||||
|
||||
# Create dataclass instances
|
||||
nested = NestedTensorStruct(tensor=tensor3)
|
||||
@ -709,14 +695,14 @@ class TestStateDictStager(TestCase):
|
||||
"CPU tensor should have the same values as the original tensor",
|
||||
)
|
||||
|
||||
@unittest.skipIf(not HAS_ACCELERATOR, "No accelerator")
|
||||
@requires_cuda
|
||||
def test_tensor_pinned_and_shared(self):
|
||||
"""
|
||||
Test that verifies tensors are actually pinned and shared using tensor.is_pinned() and tensor.is_shared() methods.
|
||||
"""
|
||||
# Create test tensors
|
||||
tensor1 = torch.randn(4, 4).to(device_type)
|
||||
tensor2 = torch.randn(8, 8).to(device_type)
|
||||
tensor1 = torch.randn(4, 4).cuda()
|
||||
tensor2 = torch.randn(8, 8).cuda()
|
||||
|
||||
# Create a state dict with these tensors
|
||||
state_dict = {
|
||||
@ -811,17 +797,15 @@ class TestStateDictStager(TestCase):
|
||||
|
||||
class TestDTensorStateDictStager(DTensorTestBase):
|
||||
@with_comms
|
||||
@requires_accelerator_dist_backend()
|
||||
@requires_nccl()
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_dtensor(self):
|
||||
"""
|
||||
Test that StateDictStager works correctly with DTensors.
|
||||
"""
|
||||
# Create a DTensor
|
||||
device_mesh = dist.DeviceMesh(
|
||||
self.device_type, list(range(dist.get_world_size()))
|
||||
)
|
||||
tensor = torch.randn(3, 3, device=self.device_type)
|
||||
device_mesh = dist.DeviceMesh("cuda", list(range(dist.get_world_size())))
|
||||
tensor = torch.randn(3, 3, device="cuda")
|
||||
dtensor = DTensor.from_local(tensor, device_mesh, [Shard(0)])
|
||||
|
||||
dtensor = dtensor + 1
|
||||
|
||||
@ -47,7 +47,7 @@ class TestTpCheckpoint(DTensorTestBase):
|
||||
tp_mesh = init_device_mesh(self.device_type, mesh_shpe)
|
||||
|
||||
# create model and move it to GPU with id rank
|
||||
model = MLPModule(self.device_type).to(self.rank)
|
||||
model = MLPModule(self.device_type).cuda(self.rank)
|
||||
# Parallelize the module based on the given Parallel Style.
|
||||
parallelize_plan = {
|
||||
"net1": ColwiseParallel(),
|
||||
@ -65,7 +65,7 @@ class TestTpCheckpoint(DTensorTestBase):
|
||||
|
||||
# Update the parameters so model.state_dict() will be different from original_state_dict.
|
||||
torch.manual_seed(0)
|
||||
inp = torch.rand(20, 10).to(self.rank)
|
||||
inp = torch.rand(20, 10).cuda(self.rank)
|
||||
output = model(inp)
|
||||
output.sum().backward()
|
||||
optimizer.step()
|
||||
@ -94,7 +94,7 @@ class TestTpCheckpoint(DTensorTestBase):
|
||||
tp_mesh = init_device_mesh(self.device_type, mesh_shpe)
|
||||
|
||||
# create model and move it to GPU with id rank
|
||||
model = UnevenShardedModel(self.device_type).to(self.rank)
|
||||
model = UnevenShardedModel(self.device_type).cuda(self.rank)
|
||||
# Parallelize the module based on the given Parallel Style.
|
||||
parallelize_plan = {
|
||||
"net1": ColwiseParallel(),
|
||||
|
||||
@ -199,7 +199,7 @@ class TestReaderView(TestCase):
|
||||
class TestDistWrapper(DTensorTestBase):
|
||||
@property
|
||||
def world_size(self):
|
||||
return min(4, torch.accelerator.device_count())
|
||||
return min(4, torch.cuda.device_count())
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
import copy
|
||||
import logging
|
||||
import tempfile
|
||||
from dataclasses import dataclass
|
||||
|
||||
from model_registry import ModelWithKwargs, MultiMLP, MultiMLPKwargs, MultiMLPWithDw
|
||||
@ -26,15 +27,7 @@ from torch.distributed.pipelining import (
|
||||
ScheduleLoopedBFS,
|
||||
ScheduleZBVZeroBubble,
|
||||
)
|
||||
from torch.distributed.pipelining.schedules import (
|
||||
_Action,
|
||||
_PipelineContext,
|
||||
_PipelineScheduleRuntime,
|
||||
_wait_batch_p2p,
|
||||
FORWARD,
|
||||
OVERLAP_F_B,
|
||||
)
|
||||
from torch.distributed.pipelining.stage import _PipelineStageBase # noqa: TC002
|
||||
from torch.distributed.pipelining.schedules import _PipelineScheduleRuntime
|
||||
from torch.nn.modules.loss import MSELoss
|
||||
from torch.testing._internal.common_distributed import (
|
||||
MultiProcContinuousTest,
|
||||
@ -522,7 +515,8 @@ class ScheduleTest(MultiProcContinuousTest):
|
||||
ScheduleInterleavedZeroBubble,
|
||||
],
|
||||
)
|
||||
def test_grad_with_manual_interleaved(self, ScheduleClass):
|
||||
@parametrize("use_new_runtime", [False, True])
|
||||
def test_grad_with_manual_interleaved(self, ScheduleClass, use_new_runtime):
|
||||
stages_per_rank = 2
|
||||
n_stages = stages_per_rank * self.world_size
|
||||
mod, ref_mod, x, target, loss_fn = setup_models_and_data(
|
||||
@ -549,6 +543,46 @@ class ScheduleTest(MultiProcContinuousTest):
|
||||
stages, num_microbatches, loss_fn=loss_fn, scale_grads=False
|
||||
)
|
||||
|
||||
# Handle new runtime testing
|
||||
if use_new_runtime:
|
||||
old_schedule = schedule
|
||||
tmp_schedule = _PipelineScheduleRuntime(
|
||||
stages, num_microbatches, loss_fn=loss_fn, scale_grads=False
|
||||
)
|
||||
tmp_schedule._prepare_schedule_with_comms(old_schedule.pipeline_order)
|
||||
|
||||
# Test CSV round-trip for compute_comms schedule
|
||||
schedule = _PipelineScheduleRuntime(
|
||||
stages, num_microbatches, loss_fn=loss_fn, scale_grads=False
|
||||
)
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
tmp_schedule._dump_csv(f.name)
|
||||
f.seek(0)
|
||||
schedule._load_csv(f.name, format="compute_comms")
|
||||
|
||||
one_more_schedule = _PipelineScheduleRuntime(
|
||||
stages, num_microbatches, loss_fn=loss_fn, scale_grads=False
|
||||
)
|
||||
one_more_schedule._prepare_schedule_with_comms(
|
||||
schedule.pipeline_order_with_comms, format="compute_comms"
|
||||
)
|
||||
|
||||
# Verify schedule consistency
|
||||
self.assertEqual(
|
||||
len(schedule.pipeline_order_with_comms),
|
||||
len(one_more_schedule.pipeline_order_with_comms),
|
||||
)
|
||||
for rank in schedule.pipeline_order_with_comms:
|
||||
self.assertEqual(
|
||||
len(schedule.pipeline_order_with_comms[rank]),
|
||||
len(one_more_schedule.pipeline_order_with_comms[rank]),
|
||||
)
|
||||
for a, b in zip(
|
||||
schedule.pipeline_order_with_comms[rank],
|
||||
one_more_schedule.pipeline_order_with_comms[rank],
|
||||
):
|
||||
self.assertEqual(a, b)
|
||||
|
||||
# Run pipeline with tensor leak checking
|
||||
out = None
|
||||
losses = []
|
||||
@ -716,201 +750,6 @@ class ScheduleTest(MultiProcContinuousTest):
|
||||
@skip_but_pass_in_sandcastle_if(
|
||||
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
|
||||
)
|
||||
def test_custom_function_callback(self):
|
||||
"""Test the custom function callback functionality with _PipelineScheduleRuntime."""
|
||||
n_stages = 8
|
||||
rank_stages = {0: [0, 7], 1: [1, 6], 2: [2, 5], 3: [3, 4]}
|
||||
mod, ref_mod, x, target, loss_fn = setup_models_and_data(
|
||||
self.config, n_layers=n_stages
|
||||
)
|
||||
|
||||
# Run reference
|
||||
ref_out, ref_loss = run_reference_model(ref_mod, x, target, loss_fn)
|
||||
|
||||
# Create multi-stage pipeline with custom stage indices
|
||||
num_microbatches = 8
|
||||
stage_indices = rank_stages[self.rank]
|
||||
stages, stage_modules, submod_names = create_multi_stage_pipeline(
|
||||
self.config, mod, len(stage_indices), n_stages, stage_indices
|
||||
)
|
||||
|
||||
# Use DualPipeV schedule as the base schedule
|
||||
base_schedule = ScheduleDualPipeV(
|
||||
stages, num_microbatches, loss_fn=loss_fn, scale_grads=False
|
||||
)
|
||||
base_schedule._prepare_schedule_with_comms(base_schedule.pipeline_order)
|
||||
|
||||
# Track both types of callbacks separately
|
||||
forward_calls = []
|
||||
overlap_calls = []
|
||||
|
||||
def forward_callback(action: _Action, ctx: _PipelineContext):
|
||||
"""Custom callback for FORWARD computation that mimics the original implementation."""
|
||||
schedule = ctx.schedule_ref
|
||||
assert isinstance(schedule, _PipelineScheduleRuntime)
|
||||
stage_index_to_stage: dict[int, _PipelineStageBase] = {
|
||||
stage.stage_index: stage for stage in schedule._stages
|
||||
}
|
||||
stage = stage_index_to_stage[action.stage_index]
|
||||
stage_index = stage.stage_index
|
||||
mb_index = action.microbatch_index
|
||||
assert mb_index is not None
|
||||
fwd_recv_ops = schedule.fwd_recv_ops
|
||||
arg_mbs = ctx.arg_mbs
|
||||
kwarg_mbs = ctx.kwarg_mbs
|
||||
|
||||
is_next_stage_on_this_rank = stage_index + 1 in stage_index_to_stage
|
||||
is_prev_stage_on_this_rank = stage_index - 1 in stage_index_to_stage
|
||||
|
||||
# used in verification at the end
|
||||
forward_calls.append((stage_index, mb_index))
|
||||
|
||||
if (
|
||||
not stage.is_first
|
||||
# no recv op expected for V-schedule special case (see [Note: V-schedule special case])
|
||||
and not is_prev_stage_on_this_rank
|
||||
):
|
||||
assert (
|
||||
stage_index,
|
||||
mb_index,
|
||||
) in fwd_recv_ops, f"Computing {action=} before receiving input"
|
||||
from torch.distributed.pipelining.schedules import _wait_batch_p2p
|
||||
|
||||
_wait_batch_p2p(fwd_recv_ops.pop((stage_index, mb_index)))
|
||||
|
||||
output = stage.forward_one_chunk(
|
||||
mb_index,
|
||||
arg_mbs[mb_index], # type: ignore[index]
|
||||
kwarg_mbs[mb_index], # type: ignore[index]
|
||||
)
|
||||
schedule._maybe_compute_loss(stage, output, ctx.target_mbs, mb_index)
|
||||
|
||||
# SEND/RECV op are avoided for special case with 2 adjacent stages on same rank
|
||||
# see [Note: V-schedule special case]
|
||||
if is_next_stage_on_this_rank:
|
||||
stage_index_to_stage[stage_index + 1].set_local_fwd_input(
|
||||
output, mb_index
|
||||
)
|
||||
|
||||
def overlap_callback(action: _Action, ctx: _PipelineContext):
|
||||
"""Custom callback for OVERLAP_F_B computation that mimics the original implementation."""
|
||||
schedule = ctx.schedule_ref
|
||||
assert isinstance(schedule, _PipelineScheduleRuntime)
|
||||
stage_index_to_stage: dict[int, _PipelineStageBase] = {
|
||||
stage.stage_index: stage for stage in schedule._stages
|
||||
}
|
||||
assert action.sub_actions is not None
|
||||
fwd_action = action.sub_actions[0]
|
||||
bwd_action = action.sub_actions[1]
|
||||
|
||||
# Forward ========================================================
|
||||
forward_callback(fwd_action, ctx)
|
||||
overlap_calls.append(
|
||||
(
|
||||
fwd_action.stage_index,
|
||||
fwd_action.microbatch_index,
|
||||
bwd_action.stage_index,
|
||||
bwd_action.microbatch_index,
|
||||
)
|
||||
)
|
||||
|
||||
# Backward ========================================================
|
||||
backward_stage_index = bwd_action.stage_index
|
||||
backward_stage = stage_index_to_stage[backward_stage_index]
|
||||
backward_mb_index = bwd_action.microbatch_index
|
||||
assert backward_mb_index is not None
|
||||
bwd_recv_ops = schedule.bwd_recv_ops
|
||||
is_next_stage_on_this_rank = (
|
||||
backward_stage.stage_index + 1 in stage_index_to_stage
|
||||
)
|
||||
is_prev_stage_on_this_rank = (
|
||||
backward_stage.stage_index - 1 in stage_index_to_stage
|
||||
)
|
||||
if (
|
||||
not backward_stage.is_last
|
||||
# no recv op expected for V-schedule special case (see [Note: V-schedule special case])
|
||||
and not is_next_stage_on_this_rank
|
||||
):
|
||||
assert (
|
||||
backward_stage_index,
|
||||
backward_mb_index,
|
||||
) in bwd_recv_ops, (
|
||||
f"Attempted to run compute {action=} before receiving input"
|
||||
)
|
||||
_wait_batch_p2p(
|
||||
bwd_recv_ops.pop((backward_stage_index, backward_mb_index))
|
||||
)
|
||||
loss = schedule._maybe_get_loss(backward_stage, backward_mb_index)
|
||||
schedule.backward_counter[backward_stage_index] += 1
|
||||
last_backward = (
|
||||
schedule.backward_counter[backward_stage_index]
|
||||
== schedule._n_microbatches
|
||||
)
|
||||
grad_scale_factor = schedule._n_microbatches if schedule.scale_grads else 1
|
||||
backward_stage.backward_one_chunk(
|
||||
backward_mb_index,
|
||||
loss=loss,
|
||||
full_backward=True,
|
||||
last_backward=last_backward,
|
||||
)
|
||||
if last_backward:
|
||||
backward_stage.scale_grads(grad_scale_factor)
|
||||
# SEND/RECV op are avoided for special case with 2 adjacent stages on same rank
|
||||
# see [Note: V-schedule special case]
|
||||
if is_prev_stage_on_this_rank:
|
||||
stage_index_to_stage[backward_stage_index - 1].set_local_bwd_input(
|
||||
backward_stage.get_local_bwd_output(backward_mb_index),
|
||||
backward_mb_index,
|
||||
)
|
||||
|
||||
# Add the callback for FORWARD computation type
|
||||
|
||||
base_schedule.register_custom_function(FORWARD, forward_callback)
|
||||
base_schedule.register_custom_function(OVERLAP_F_B, overlap_callback)
|
||||
|
||||
# Run pipeline - special case where first and last stage are on rank 0
|
||||
out = None
|
||||
losses = []
|
||||
num_loops = 2
|
||||
for _ in range(num_loops):
|
||||
zero_gradients(stage_modules)
|
||||
if self.rank == 0:
|
||||
out = base_schedule.step(x, target=target, losses=losses)
|
||||
else:
|
||||
base_schedule.step()
|
||||
|
||||
dist.barrier()
|
||||
|
||||
# Verify results (rank 0 has both first and last stages)
|
||||
if self.rank == 0:
|
||||
torch.testing.assert_close(out, ref_out)
|
||||
pipe_loss = sum(losses)
|
||||
torch.testing.assert_close(pipe_loss, ref_loss)
|
||||
|
||||
# Verify overlap callbacks were called
|
||||
self.assertGreater(
|
||||
len(overlap_calls), 0, "OVERLAP_F_B callback should have been called"
|
||||
)
|
||||
|
||||
# In a V-schedule with 8 microbatches and 2 stages per rank,
|
||||
# rank 0 should have 32 calls (8 microbatches * 2 stages * 2 loops)
|
||||
expected_count = num_microbatches * 2 * num_loops
|
||||
self.assertEqual(len(forward_calls), expected_count)
|
||||
|
||||
# Verify all callback calls are for stages on this rank
|
||||
for stage_idx, _ in forward_calls:
|
||||
self.assertIn(
|
||||
stage_idx,
|
||||
stage_indices,
|
||||
f"Callback called for stage {stage_idx} not on rank {self.rank}",
|
||||
)
|
||||
|
||||
# Check gradients using helper method
|
||||
check_gradients(self.config, stage_modules, ref_mod, submod_names)
|
||||
|
||||
@skip_but_pass_in_sandcastle_if(
|
||||
not TEST_MULTIACCELERATOR, "NCCL test requires 2+ GPUs"
|
||||
)
|
||||
@parametrize(
|
||||
"ScheduleClass",
|
||||
[ScheduleInterleavedZeroBubble, ScheduleInterleaved1F1B],
|
||||
@ -1008,7 +847,8 @@ class CustomSchedulesTest(MultiProcContinuousTest):
|
||||
"schedule_class",
|
||||
[ScheduleVShaped, ScheduleUnbalanced],
|
||||
)
|
||||
def test_non_symmetric_stage_ids(self, schedule_class):
|
||||
@parametrize("use_new_runtime", [False, True])
|
||||
def test_non_symmetric_stage_ids(self, schedule_class, use_new_runtime):
|
||||
n_stages = schedule_class.n_stages
|
||||
rank_stages = schedule_class.rank_stages
|
||||
|
||||
@ -1031,6 +871,13 @@ class CustomSchedulesTest(MultiProcContinuousTest):
|
||||
stages, num_microbatches, loss_fn=loss_fn, scale_grads=False
|
||||
)
|
||||
|
||||
if use_new_runtime:
|
||||
old_schedule = schedule
|
||||
schedule = _PipelineScheduleRuntime(
|
||||
stages, num_microbatches, loss_fn=loss_fn
|
||||
)
|
||||
schedule._prepare_schedule_with_comms(old_schedule.pipeline_order)
|
||||
|
||||
# Run pipeline - special case where first and last stage are on rank 0
|
||||
out = None
|
||||
losses = []
|
||||
|
||||
@ -336,6 +336,20 @@ class DeviceMeshTest(DTensorTestBase):
|
||||
f"{device_type}:0", mesh_shape=mesh_shape, mesh_dim_names=("dp", "tp")
|
||||
)
|
||||
|
||||
@with_comms
|
||||
def test_set_mesh_dim_group_options(self):
|
||||
device_type = (
|
||||
torch.accelerator.current_accelerator().type
|
||||
if torch.accelerator.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
_mesh_resources._set_mesh_dim_group_options(1, "fake", None)
|
||||
|
||||
mesh_tensor = torch.arange(4).reshape(2, 2)
|
||||
mesh = DeviceMesh(device_type, mesh_tensor)
|
||||
# Fake pg only have BackendType as BackendType::CUSTOM.
|
||||
self.assertEqual(mesh.get_group(1)._get_backend_name(), "custom")
|
||||
|
||||
@with_comms
|
||||
def test_get_root_mesh_multiple_independent_meshes(self):
|
||||
# regression test for issue #163330
|
||||
|
||||
@ -893,29 +893,6 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
|
||||
self.assertEqual(gn(inp), inp + 3)
|
||||
self.assertEqual(cnts.frame_count, 1)
|
||||
|
||||
def test_step_unsupported(self):
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
|
||||
@torch.compile(backend=cnts)
|
||||
def fn(x):
|
||||
x = x + 1 + 2
|
||||
torch._dynamo.step_unsupported()
|
||||
return x + 4
|
||||
|
||||
inp = torch.ones(3)
|
||||
self.assertEqual(fn(inp), inp + 7)
|
||||
self.assertEqual(cnts.frame_count, 1)
|
||||
self.assertEqual(cnts.op_count, 2)
|
||||
|
||||
def test_step_unsupported_empty_checkpoint(self):
|
||||
@torch.compile(backend="eager")
|
||||
def fn(x):
|
||||
torch._dynamo.step_unsupported()
|
||||
return x + 1
|
||||
|
||||
inp = torch.ones(3)
|
||||
self.assertEqual(fn(inp), inp + 1)
|
||||
|
||||
@skipIfWindows(
|
||||
msg="TODO: (xuhancn), confirm if torch.compiler.disable work on Windows."
|
||||
)
|
||||
|
||||
@ -14,7 +14,7 @@ import torch._dynamo.config
|
||||
import torch._dynamo.test_case
|
||||
import torch.utils._pytree as python_pytree
|
||||
from torch._dynamo.exc import ResumePrologueTracingError, Unsupported
|
||||
from torch._dynamo.testing import skipIfNotPy312, skipIfOnlyNotPy312
|
||||
from torch._dynamo.testing import skipIfNotPy312
|
||||
from torch._dynamo.utils import counters
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_FBCODE,
|
||||
@ -1015,7 +1015,6 @@ Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especiall
|
||||
"<Internal traceback>\n",
|
||||
msg,
|
||||
)
|
||||
|
||||
self.assertExpectedInline(
|
||||
msg,
|
||||
"""\
|
||||
@ -1052,6 +1051,7 @@ from user code:
|
||||
|
||||
torch.compile(fn, backend="eager")(torch.randn(3))
|
||||
|
||||
# check the log for the 2nd torch._dynamo.graph_break()
|
||||
self.assertExpectedInline(
|
||||
munge_exc(records[-1].getMessage(), skip=0),
|
||||
"""\
|
||||
@ -1075,104 +1075,6 @@ User code traceback:
|
||||
""",
|
||||
)
|
||||
|
||||
@torch._dynamo.config.patch(verbose=True)
|
||||
@make_logging_test(graph_breaks=True)
|
||||
def test_latest_bytecode_to_graph_break_fullgraph(self, records):
|
||||
def fn(x):
|
||||
y = x + 1
|
||||
z = x + y
|
||||
torch._dynamo.graph_break()
|
||||
return z
|
||||
|
||||
self.assertExpectedInlineMunged(
|
||||
Unsupported,
|
||||
lambda: torch.compile(fn, backend="eager", fullgraph=True)(torch.randn(3)),
|
||||
"""\
|
||||
Call to `torch._dynamo.graph_break()`
|
||||
Explanation: User-inserted graph break. Message: None
|
||||
Hint: Remove the `torch._dynamo.graph_break()` call.
|
||||
|
||||
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
|
||||
|
||||
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html
|
||||
|
||||
from user code:
|
||||
File "test_error_messages.py", line N, in fn
|
||||
torch._dynamo.graph_break()
|
||||
""",
|
||||
)
|
||||
|
||||
@skipIfOnlyNotPy312
|
||||
@torch._dynamo.config.patch(verbose=True)
|
||||
@make_logging_test(graph_breaks=True)
|
||||
def test_latest_bytecode_to_graph_break_python_versioning(self, records):
|
||||
@torch.compile(backend="eager")
|
||||
def fn(x):
|
||||
y = x + 1
|
||||
z = x + y
|
||||
torch._dynamo.graph_break()
|
||||
return z
|
||||
|
||||
fn(torch.ones(3))
|
||||
|
||||
s = munge_exc(records[0].getMessage(), skip=0)
|
||||
|
||||
self.assertExpectedInline(
|
||||
s,
|
||||
"""\
|
||||
Graph break in user code at test_error_messages.py:N
|
||||
Graph Break Reason: Call to `torch._dynamo.graph_break()`
|
||||
Explanation: User-inserted graph break. Message: None
|
||||
Hint: Remove the `torch._dynamo.graph_break()` call.
|
||||
|
||||
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
|
||||
|
||||
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html
|
||||
User code traceback:
|
||||
File "test_error_messages.py", line N, in test_latest_bytecode_to_graph_break_python_versioning
|
||||
fn(torch.ones(3))
|
||||
|
||||
========== most recent `torch.compile` tracing attempt started here ==========
|
||||
|
||||
File "test_error_messages.py", line N, in fn
|
||||
torch._dynamo.graph_break()
|
||||
|
||||
NOTE: the most recent `torch.compile` tracing attempt might not be where you applied `torch.compile`! This is due to how graph breaks are implemented - the optimized code object returned by Dynamo will call another Dynamo-generated resume function and tracing is re-enabled by calling the resume function as a normal Python function, which Dynamo intercepts as a top-level frame.
|
||||
Most recent bytecode instructions traced (max 20):
|
||||
TRACE RESUME 0 []
|
||||
TRACE LOAD_FAST 'x' []
|
||||
TRACE LOAD_CONST 1 [LazyVariableTracker()]
|
||||
TRACE BINARY_OP 0 [LazyVariableTracker(), ConstantVariable(int: 1)]
|
||||
TRACE STORE_FAST 'y' [TensorVariable()]
|
||||
TRACE LOAD_FAST 'x' []
|
||||
TRACE LOAD_FAST 'y' [TensorVariable()]
|
||||
TRACE BINARY_OP 0 [TensorVariable(), TensorVariable()]
|
||||
TRACE STORE_FAST 'z' [TensorVariable()]
|
||||
TRACE LOAD_GLOBAL 'torch' []
|
||||
TRACE LOAD_ATTR '_dynamo' [LazyVariableTracker()]
|
||||
TRACE LOAD_ATTR 'graph_break' [LazyVariableTracker()]
|
||||
TRACE CALL 0 [NullVariable, LazyVariableTracker()]""",
|
||||
)
|
||||
|
||||
@torch._dynamo.config.patch(verbose=True)
|
||||
@make_logging_test(graph_breaks=True)
|
||||
def test_latest_bytecode_to_graph_break(self, records):
|
||||
@torch.compile(backend="eager")
|
||||
def fn(x):
|
||||
y = x + 1
|
||||
z = x + y
|
||||
torch._dynamo.graph_break()
|
||||
return z
|
||||
|
||||
fn(torch.ones(3))
|
||||
|
||||
pattern = r"TRACE.*"
|
||||
s = munge_exc(records[0].getMessage(), skip=0)
|
||||
matches = re.findall(pattern, s)
|
||||
self.assertEqual((len(matches) > 10), True)
|
||||
self.assertEqual((len(matches) <= 20), True)
|
||||
self.assertIn("Most recent bytecode instructions traced (max 20):", s)
|
||||
|
||||
@torch._dynamo.config.patch(verbose=True)
|
||||
@make_logging_test(graph_breaks=True)
|
||||
def test_graph_break_traceback_above_dynamo_shows_user_code(self, records):
|
||||
|
||||
@ -1,270 +0,0 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
|
||||
import torch
|
||||
import torch._dynamo.test_case
|
||||
import torch.fx.traceback as fx_traceback
|
||||
import torch.utils.checkpoint
|
||||
from torch._dynamo.test_case import run_tests
|
||||
from torch._dynamo.testing import AotEagerAndRecordGraphs
|
||||
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
|
||||
from torch.testing._internal.triton_utils import requires_cuda_and_triton
|
||||
|
||||
|
||||
def checkpoint_wrapper(fn):
|
||||
def inner(*args):
|
||||
return torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=True)
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
class AnnotateTests(torch._dynamo.test_case.TestCase):
|
||||
# TODO - should not need this because we should turn this on in Dynamo but
|
||||
# for some reasons, test fail.
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.cm = torch.fx.traceback.preserve_node_meta()
|
||||
self.cm.__enter__()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
self.cm.__exit__(None, None, None)
|
||||
|
||||
def get_custom_metadata(self, gm):
|
||||
def helper(gm):
|
||||
custom_metadata = []
|
||||
for node in gm.graph.nodes:
|
||||
if hasattr(node, "meta") and node.meta.get("custom", None):
|
||||
custom_metadata.append((node.op, node.name, node.meta["custom"]))
|
||||
if node.op == "get_attr" and isinstance(
|
||||
getattr(gm, node.target), torch.fx.GraphModule
|
||||
):
|
||||
custom_metadata.append(helper(getattr(gm, node.target)))
|
||||
return custom_metadata
|
||||
|
||||
return "\n".join(str(x) for x in helper(gm))
|
||||
|
||||
def test_annotations(self):
|
||||
class Mod(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
with fx_traceback.annotate({"pp_stage": 0}):
|
||||
with fx_traceback.annotate({"fdsp_bucket": 0}):
|
||||
sin = torch.sin(x)
|
||||
sub = sin - 2
|
||||
with fx_traceback.annotate({"cuda_stream": 2, "fsdp_bucket": 1}):
|
||||
mul = sub * 2
|
||||
div = mul / 3
|
||||
return div
|
||||
|
||||
m = Mod()
|
||||
backend = AotEagerAndRecordGraphs()
|
||||
opt_m = torch.compile(m, backend=backend, fullgraph=True)
|
||||
x = torch.randn(10, requires_grad=True)
|
||||
opt_m(x).sum().backward()
|
||||
|
||||
self.assertEqual(len(backend.fw_graphs), 1)
|
||||
self.assertEqual(len(backend.bw_graphs), 1)
|
||||
|
||||
dynamo_metadata = self.get_custom_metadata(backend.graphs[0])
|
||||
fw_metadata = self.get_custom_metadata(backend.fw_graphs[0])
|
||||
bw_metadata = self.get_custom_metadata(backend.bw_graphs[0])
|
||||
self.assertExpectedInline(
|
||||
str(dynamo_metadata),
|
||||
"""\
|
||||
('placeholder', 'l_x_', {'pp_stage': 0, 'fdsp_bucket': 0})
|
||||
('call_function', 'sin', {'pp_stage': 0, 'fdsp_bucket': 0})
|
||||
('call_function', 'sub', {'pp_stage': 0})
|
||||
('call_function', 'mul', {'pp_stage': 0, 'cuda_stream': 2, 'fsdp_bucket': 1})""", # noqa: B950
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
str(fw_metadata),
|
||||
"""\
|
||||
('call_function', 'sin', {'pp_stage': 0, 'fdsp_bucket': 0})
|
||||
('call_function', 'sub', {'pp_stage': 0})
|
||||
('call_function', 'mul', {'pp_stage': 0, 'cuda_stream': 2, 'fsdp_bucket': 1})""", # noqa: B950
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
str(bw_metadata),
|
||||
"""\
|
||||
('call_function', 'mul_1', {'pp_stage': 0, 'cuda_stream': 2, 'fsdp_bucket': 1})
|
||||
('call_function', 'cos', {'pp_stage': 0, 'fdsp_bucket': 0})
|
||||
('call_function', 'mul_2', {'pp_stage': 0, 'fdsp_bucket': 0})""", # noqa: B950
|
||||
)
|
||||
|
||||
def test_activation_checkpointing(self):
|
||||
@checkpoint_wrapper
|
||||
def gn(x):
|
||||
return torch.sin(x)
|
||||
|
||||
def fn(x):
|
||||
with fx_traceback.annotate({"ac_sin": 0}):
|
||||
ac = gn(x)
|
||||
return torch.sigmoid(ac)
|
||||
|
||||
backend = AotEagerAndRecordGraphs()
|
||||
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
|
||||
x = torch.randn(10, requires_grad=True)
|
||||
opt_fn(x).sum().backward()
|
||||
|
||||
self.assertEqual(len(backend.fw_graphs), 1)
|
||||
self.assertEqual(len(backend.bw_graphs), 1)
|
||||
|
||||
dynamo_metadata = self.get_custom_metadata(backend.graphs[0])
|
||||
fw_metadata = self.get_custom_metadata(backend.fw_graphs[0])
|
||||
bw_metadata = self.get_custom_metadata(backend.bw_graphs[0])
|
||||
self.assertExpectedInline(
|
||||
str(dynamo_metadata),
|
||||
"""\
|
||||
('placeholder', 'l_x_', {'ac_sin': 0})
|
||||
('get_attr', 'wrap_body_0', {'ac_sin': 0})
|
||||
[('placeholder', 'l_x_', {'ac_sin': 0}), ('call_function', 'sin', {'ac_sin': 0}), ('output', 'output', {'ac_sin': 0})]
|
||||
('call_function', 'tag_activation_checkpoint', {'ac_sin': 0})
|
||||
('call_function', 'ac', {'ac_sin': 0})""", # noqa: B950
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
str(fw_metadata),
|
||||
"""('call_function', 'sin', {'ac_sin': 0})""", # noqa: B950
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
str(bw_metadata),
|
||||
"""\
|
||||
('call_function', 'cos', {'ac_sin': 0})
|
||||
('call_function', 'mul', {'ac_sin': 0})""", # noqa: B950
|
||||
)
|
||||
|
||||
def test_activation_checkpointing_annotation_inside(self):
|
||||
@checkpoint_wrapper
|
||||
def gn(x):
|
||||
x = x + 1
|
||||
with fx_traceback.annotate({"stage": 0}):
|
||||
p = torch.sin(x)
|
||||
return p + 1
|
||||
|
||||
def fn(x):
|
||||
ac = gn(x)
|
||||
return torch.sigmoid(ac)
|
||||
|
||||
backend = AotEagerAndRecordGraphs()
|
||||
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
|
||||
x = torch.randn(10, requires_grad=True)
|
||||
opt_fn(x).sum().backward()
|
||||
|
||||
self.assertEqual(len(backend.fw_graphs), 1)
|
||||
self.assertEqual(len(backend.bw_graphs), 1)
|
||||
|
||||
dynamo_metadata = self.get_custom_metadata(backend.graphs[0])
|
||||
fw_metadata = self.get_custom_metadata(backend.fw_graphs[0])
|
||||
bw_metadata = self.get_custom_metadata(backend.bw_graphs[0])
|
||||
self.assertExpectedInline(
|
||||
str(dynamo_metadata),
|
||||
"""[('call_function', 'p', {'stage': 0})]""", # noqa: B950
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
str(fw_metadata),
|
||||
"""('call_function', 'sin', {'stage': 0})""", # noqa: B950
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
str(bw_metadata),
|
||||
"""\
|
||||
('call_function', 'cos', {'stage': 0})
|
||||
('call_function', 'mul', {'stage': 0})""", # noqa: B950
|
||||
)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
def test_ac_flex_attention(self):
|
||||
def _squared(score, b, h, m, n):
|
||||
return score * score
|
||||
|
||||
def mask_mod(b, h, q, k):
|
||||
return q >= 0
|
||||
|
||||
a = 12
|
||||
b = 64
|
||||
block_mask = create_block_mask(mask_mod, None, None, a * b, a * b)
|
||||
|
||||
def gn(x: torch.Tensor):
|
||||
with fx_traceback.annotate({"compile_inductor": 0}):
|
||||
return flex_attention(
|
||||
x, x, x, block_mask=block_mask, score_mod=_squared
|
||||
)
|
||||
|
||||
def fn(x):
|
||||
x = torch.sin(x)
|
||||
x = gn(x)
|
||||
return torch.cos(x)
|
||||
|
||||
x = torch.randn(
|
||||
1,
|
||||
1,
|
||||
a * b,
|
||||
b,
|
||||
dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
requires_grad=True,
|
||||
)
|
||||
|
||||
backend = AotEagerAndRecordGraphs()
|
||||
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
|
||||
opt_fn(x).sum().backward()
|
||||
|
||||
self.assertEqual(len(backend.fw_graphs), 1)
|
||||
self.assertEqual(len(backend.bw_graphs), 1)
|
||||
|
||||
dynamo_metadata = self.get_custom_metadata(backend.graphs[0])
|
||||
fw_metadata = self.get_custom_metadata(backend.fw_graphs[0])
|
||||
bw_metadata = self.get_custom_metadata(backend.bw_graphs[0])
|
||||
self.assertExpectedInline(
|
||||
str(dynamo_metadata),
|
||||
"""\
|
||||
('placeholder', 'l_gn_closure_1_cell_contents_kv_indices', {'compile_inductor': 0})
|
||||
('placeholder', 'l_gn_closure_1_cell_contents_kv_num_blocks', {'compile_inductor': 0})
|
||||
('placeholder', 'l_gn_closure_1_cell_contents_full_kv_num_blocks', {'compile_inductor': 0})
|
||||
('placeholder', 'l_gn_closure_1_cell_contents_full_kv_indices', {'compile_inductor': 0})
|
||||
('placeholder', 'l_gn_closure_1_cell_contents_q_num_blocks', {'compile_inductor': 0})
|
||||
('placeholder', 'l_gn_closure_1_cell_contents_q_indices', {'compile_inductor': 0})
|
||||
('placeholder', 'l_gn_closure_1_cell_contents_full_q_num_blocks', {'compile_inductor': 0})
|
||||
('placeholder', 'l_gn_closure_1_cell_contents_full_q_indices', {'compile_inductor': 0})
|
||||
('get_attr', 'score_mod_0', {'compile_inductor': 0})
|
||||
[('placeholder', 'child', {'compile_inductor': 0}), ('placeholder', 'child_1', {'compile_inductor': 0}), ('placeholder', 'child_2', {'compile_inductor': 0}), ('placeholder', 'child_3', {'compile_inductor': 0}), ('placeholder', 'child_4', {'compile_inductor': 0}), ('call_function', 'mul', {'compile_inductor': 0}), ('output', 'output', {'compile_inductor': 0})]
|
||||
('get_attr', 'mask_fn_0', {'compile_inductor': 0})
|
||||
[('placeholder', 'child', {'compile_inductor': 0}), ('placeholder', 'child_1', {'compile_inductor': 0}), ('placeholder', 'child_2', {'compile_inductor': 0}), ('placeholder', 'child_3', {'compile_inductor': 0}), ('call_function', 'ge', {'compile_inductor': 0}), ('output', 'output', {'compile_inductor': 0})]
|
||||
('call_function', 'flex_attention', {'compile_inductor': 0})
|
||||
('call_function', 'out', {'compile_inductor': 0})""", # noqa: B950
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
str(fw_metadata),
|
||||
"""\
|
||||
('get_attr', 'sdpa_score0', {'compile_inductor': 0})
|
||||
[('placeholder', 'arg0_1', {'compile_inductor': 0}), ('placeholder', 'arg1_1', {'compile_inductor': 0}), ('placeholder', 'arg2_1', {'compile_inductor': 0}), ('placeholder', 'arg3_1', {'compile_inductor': 0}), ('placeholder', 'arg4_1', {'compile_inductor': 0}), ('call_function', 'mul', {'compile_inductor': 0}), ('output', 'output', {'compile_inductor': 0})]
|
||||
('get_attr', 'sdpa_mask0', {'compile_inductor': 0})
|
||||
[('placeholder', 'arg0_1', {'compile_inductor': 0}), ('placeholder', 'arg1_1', {'compile_inductor': 0}), ('placeholder', 'arg2_1', {'compile_inductor': 0}), ('placeholder', 'arg3_1', {'compile_inductor': 0}), ('call_function', 'ge', {'compile_inductor': 0}), ('output', 'output', {'compile_inductor': 0})]
|
||||
('call_function', 'flex_attention', {'compile_inductor': 0})
|
||||
('call_function', 'getitem', {'compile_inductor': 0})
|
||||
('call_function', 'getitem_1', {'compile_inductor': 0})
|
||||
('call_function', 'detach_1', {'compile_inductor': 0})
|
||||
('call_function', 'detach_4', {'compile_inductor': 0})
|
||||
('call_function', 'detach_5', {'compile_inductor': 0})""", # noqa: B950
|
||||
)
|
||||
self.assertExpectedInline(
str(bw_metadata),
"""\
('placeholder', 'getitem', {'compile_inductor': 0})
('placeholder', 'detach_5', {'compile_inductor': 0})
('call_function', 'zeros', {'compile_inductor': 0})
('call_function', 'detach', {'compile_inductor': 0})
('call_function', 'detach_2', {'compile_inductor': 0})
('call_function', 'detach_3', {'compile_inductor': 0})
('get_attr', 'fw_graph0', {'compile_inductor': 0})
[]
('get_attr', 'joint_graph0', {'compile_inductor': 0})
[]
('get_attr', 'mask_graph0', {'compile_inductor': 0})
[('call_function', 'ge', {'compile_inductor': 0})]
('call_function', 'flex_attention_backward', {'compile_inductor': 0})
('call_function', 'getitem_3', {'compile_inductor': 0})
('call_function', 'getitem_4', {'compile_inductor': 0})
('call_function', 'getitem_5', {'compile_inductor': 0})""", # noqa: B950
)


if __name__ == "__main__":
run_tests()
@ -363,31 +363,6 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
|
||||
self.assertEqual(cnts.frame_count, 2)
|
||||
self.assertEqual(cnts.op_count, 13)
|
||||
|
||||
def test_cells_double_graph_break(self):
|
||||
def f1(x1):
|
||||
cell1 = x1 + 1
|
||||
|
||||
def f2(x2):
|
||||
nonlocal cell1
|
||||
cell1 += 2
|
||||
torch._dynamo.graph_break()
|
||||
torch._dynamo.graph_break()
|
||||
return x2 + cell1
|
||||
|
||||
return f2(x1 + 4), cell1
|
||||
|
||||
def outer(x):
|
||||
return f1(x)
|
||||
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
opt_fn = torch._dynamo.optimize(backend=cnts)(outer)
|
||||
x = torch.zeros(3)
|
||||
res = outer(x)
|
||||
ref = opt_fn(x)
|
||||
self.assertEqual(ref, res)
|
||||
self.assertEqual(cnts.frame_count, 2)
|
||||
self.assertEqual(cnts.op_count, 4)
|
||||
|
||||
def test_side_effects_cells(self):
|
||||
cell1, cell2, cell3, cell4 = (torch.zeros(3),) * 4
|
||||
|
||||
@ -536,7 +511,6 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
|
||||
self.assertEqual(cnts.frame_count, 5)
|
||||
# 4 additions from f5+f4, 2 x 4 additions from f2+f1 (i == 5, i != 5)
|
||||
self.assertEqual(cnts.op_count, 12)
|
||||
self.assertEqual(torch._dynamo.utils.counters["frames"]["total"], 6)
|
||||
|
||||
def test_nested_graph_break_in_try_block(self):
|
||||
# NOTE: this also tests nested step_graph_break
|
||||
@ -577,40 +551,13 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
|
||||
x = torch.zeros(3)
|
||||
res = f5(x)
|
||||
ref = opt_fn(x)
|
||||
print(ref, res)
|
||||
self.assertEqual(ref, res)
|
||||
# skip frame due to graph break in try block
|
||||
# 2 frames from f5+f4+(first part of f3), 2 frames from f2+f1
|
||||
self.assertEqual(cnts.frame_count, 4)
|
||||
# 5 additions from f5+f4+(first part of f3), 4 additions from f2+f1
|
||||
self.assertEqual(cnts.op_count, 9)
|
||||
self.assertEqual(torch._dynamo.utils.counters["frames"]["total"], 4)
|
||||
|
||||
def test_nested_step_unsupported(self):
|
||||
global f1, f2, f3
|
||||
|
||||
def f1(x):
|
||||
return x + 1
|
||||
|
||||
def f2(x):
|
||||
x = x + 2
|
||||
torch._dynamo.step_unsupported()
|
||||
return f1(x) + 4
|
||||
|
||||
def f3(x):
|
||||
x = x + 8
|
||||
return f2(x) + 16
|
||||
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
|
||||
x = torch.zeros(3)
|
||||
res = f3(x)
|
||||
ref = opt_fn(x)
|
||||
self.assertEqual(ref, res)
|
||||
# 1 frame from start of f3 + start of f2, 1 frame from f1, 1 frame from the end of f3
|
||||
self.assertEqual(cnts.frame_count, 3)
|
||||
# all ops except + 4
|
||||
self.assertEqual(cnts.op_count, 4)
|
||||
self.assertEqual(torch._dynamo.utils.counters["frames"]["total"], 3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -7256,26 +7256,6 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
flag = False
self.assertEqual(fn(inp), opt_fn(inp))

def test_cells_unsupported_step_exception(self):
# This error happened because:
# - we were generating cells into a list on the stack
# - we encountered an unsupported step, resulting in a step graph break
# - we encounter an exception, which pops the stack until it reaches a certain length;
# the presence of the list of cells then messes things up.

cell = 0

@torch.compile(backend="eager")
def fn(x):
x = x + 1 + 2
torch._dynamo.step_unsupported()
with contextlib.nullcontext():
print(cell)
raise AssertionError

with self.assertRaises(AssertionError):
fn(torch.ones(3))

def test_unbind_copy_out(self):
def f(eye, out):
torch.unbind_copy(eye, out=out)

@ -15660,6 +15660,11 @@ def forward(self, x):
test_serdes=True,
)

@testing.expectedFailureTrainingIRToRunDecomp
@testing.expectedFailureRetraceability
@testing.expectedFailureStrictV2
@testing.expectedFailureStrict # annotation needs to be handled in dynamo
@testing.expectedFailureSerDer
def test_preserve_annotation(self):
class M(torch.nn.Module):
def forward(self, x):

@ -5,7 +5,6 @@ import torch
|
||||
import torch.fx as fx
|
||||
from torch._inductor.augmented_graph_helper import AugmentedGraphHelper
|
||||
from torch.testing._internal.common_utils import TestCase
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
|
||||
class TestAugmentedGraphHelper(TestCase):
|
||||
@ -62,29 +61,9 @@ class TestAugmentedGraphHelper(TestCase):
|
||||
]:
|
||||
self.nodes[node.name] = node
|
||||
|
||||
# Get all nodes and compute ancestors
|
||||
# Get all nodes and create tracker
|
||||
self.all_nodes = list(self.graph.nodes)
|
||||
self.node_ancestors = self._collect_node_ancestors(self.graph)
|
||||
|
||||
# Create tracker with ancestors
|
||||
self.tracker = AugmentedGraphHelper(
|
||||
self.graph, node_ancestors=self.node_ancestors
|
||||
)
|
||||
|
||||
def _collect_node_ancestors(
|
||||
self, graph: fx.Graph
|
||||
) -> dict[fx.Node, OrderedSet[fx.Node]]:
|
||||
"""Collect all ancestors for each node."""
|
||||
from collections import defaultdict
|
||||
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
ancestors: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
|
||||
for node in graph.nodes:
|
||||
for input_node in node.all_input_nodes:
|
||||
ancestors[node].add(input_node)
|
||||
ancestors[node] |= ancestors[input_node]
|
||||
return ancestors
|
||||
self.tracker = AugmentedGraphHelper(self.graph)
|
||||
|
||||
def get_deps(self, node):
|
||||
"""Helper to get dependencies for a node."""
|
||||
|
||||
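Reviewer note (illustration only, not part of this change): the `_collect_node_ancestors` helper shown above accumulates, for every FX node, the transitive set of producer nodes by unioning each input's ancestor set. A minimal sketch of the same idea on a freshly traced toy function (the function `f` below is made up for illustration):

import torch
import torch.fx as fx
from collections import defaultdict

def f(x):
    a = x + 1
    b = a * 2
    return a + b

gm = fx.symbolic_trace(f)
ancestors = defaultdict(set)
for node in gm.graph.nodes:
    for inp in node.all_input_nodes:
        ancestors[node].add(inp)
        # inputs are visited before their users in a topologically ordered
        # fx.Graph, so their ancestor sets are already complete here
        ancestors[node] |= ancestors[inp]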
@ -1,173 +0,0 @@
|
||||
# Owner(s): ["module: inductor"]
|
||||
|
||||
import functools
|
||||
import weakref
|
||||
from collections import Counter
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch._inductor.fx_passes.memory_estimator import build_memory_profile
|
||||
from torch._inductor.test_case import run_tests, TestCase as InductorTestCase
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.testing._internal.common_utils import IS_LINUX
|
||||
from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
from torch.utils._pytree import tree_map_only
|
||||
from torch.utils.weak import WeakIdKeyDictionary
|
||||
|
||||
|
||||
def tensor_storage_id(tensor):
|
||||
return tensor._typed_storage()._cdata
|
||||
|
||||
|
||||
def device_filter(device):
|
||||
return device.type == "cuda"
|
||||
|
||||
|
||||
class FakeTensorMemoryProfilerMode(TorchDispatchMode):
|
||||
def __init__(self, device_filter: Optional[Callable[torch.device, bool]] = None):
|
||||
# counter of storage ids to live references
|
||||
self.storage_count: dict[int, int] = Counter()
|
||||
# live fake tensors
|
||||
self.live_tensors = WeakIdKeyDictionary()
|
||||
self.memory_use = 0
|
||||
self.max_memory = 0
|
||||
self.device_filter = device_filter
|
||||
|
||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||
kwargs = kwargs if kwargs is not None else {}
|
||||
rs = func(*args, **kwargs)
|
||||
tree_map_only(torch._subclasses.FakeTensor, self.increase_memory_use, rs)
|
||||
return rs
|
||||
|
||||
def increase_memory_use(self, tensor):
|
||||
# already accounted for
|
||||
if tensor in self.live_tensors:
|
||||
return
|
||||
|
||||
if self.device_filter is not None and not self.device_filter(tensor.device):
|
||||
return
|
||||
|
||||
self.live_tensors[tensor] = True
|
||||
nbytes = tensor.untyped_storage().nbytes()
|
||||
|
||||
storage_id = tensor_storage_id(tensor)
|
||||
|
||||
# new storage, add to memory
|
||||
if storage_id not in self.storage_count:
|
||||
self.change_memory(nbytes)
|
||||
|
||||
self.storage_count[storage_id] += 1
|
||||
|
||||
# when this tensor dies, we need to adjust memory
|
||||
weakref.finalize(
|
||||
tensor, functools.partial(self.tensor_cleanup, storage_id, nbytes)
|
||||
)
|
||||
|
||||
def tensor_cleanup(self, storage_id, nbytes):
|
||||
self.storage_count[storage_id] -= 1
|
||||
if self.storage_count[storage_id] == 0:
|
||||
del self.storage_count[storage_id]
|
||||
self.change_memory(-nbytes)
|
||||
|
||||
def change_memory(self, delta):
|
||||
self.memory_use += delta
|
||||
self.max_memory = max(self.memory_use, self.max_memory)
|
||||
|
||||
|
||||
class TestMemoryProfilingResNet(InductorTestCase):
|
||||
def test_simple_linear_layers(self):
|
||||
"""Test with a simple sequential model with explicit weights on CUDA."""
|
||||
|
||||
def create_inputs_and_weights():
|
||||
"""Create inputs and weights on CUDA."""
|
||||
x = torch.randn(32, 1000, device="cuda")
|
||||
w1 = torch.randn(500, 1000, device="cuda")
|
||||
w2 = torch.randn(100, 500, device="cuda")
|
||||
w3 = torch.randn(10, 100, device="cuda")
|
||||
return x, w1, w2, w3
|
||||
|
||||
def fn(x, w1, w2, w3):
|
||||
h1 = torch.nn.functional.linear(x, w1)
|
||||
h1 = torch.nn.functional.relu(h1)
|
||||
h2 = torch.nn.functional.linear(h1, w2)
|
||||
h2 = torch.nn.functional.relu(h2)
|
||||
out = torch.nn.functional.linear(h2, w3)
|
||||
return out
|
||||
|
||||
with FakeTensorMode():
|
||||
# Trace with make_fx
|
||||
x, w1, w2, w3 = create_inputs_and_weights()
|
||||
fx_graph = make_fx(fn)(x, w1, w2, w3)
|
||||
|
||||
# Static analysis
|
||||
def is_releasable(node):
|
||||
return node.op not in ("placeholder", "get_attr")
|
||||
|
||||
fx_memory_profile = build_memory_profile(fx_graph.graph, is_releasable)
|
||||
fx_peak = max(fx_memory_profile)
|
||||
|
||||
# Runtime profiling
|
||||
profiler = FakeTensorMemoryProfilerMode()
|
||||
|
||||
with profiler:
|
||||
x_runtime, w1_runtime, w2_runtime, w3_runtime = (
|
||||
create_inputs_and_weights()
|
||||
)
|
||||
result = fn(x_runtime, w1_runtime, w2_runtime, w3_runtime)
|
||||
del result
|
||||
|
||||
runtime_peak = profiler.max_memory
|
||||
|
||||
self.assertEqual(fx_peak, runtime_peak)
|
||||
|
||||
def test_conv_network(self):
|
||||
"""Test with a convolutional network."""
|
||||
|
||||
def create_inputs_and_weights():
|
||||
"""Create inputs and weights on CUDA."""
|
||||
x = torch.randn(8, 3, 224, 224, device="cuda")
|
||||
conv1_weight = torch.randn(64, 3, 3, 3, device="cuda")
|
||||
conv2_weight = torch.randn(128, 64, 3, 3, device="cuda")
|
||||
linear_weight = torch.randn(10, 128 * 56 * 56, device="cuda")
|
||||
return x, conv1_weight, conv2_weight, linear_weight
|
||||
|
||||
def fn(x, conv1_weight, conv2_weight, linear_weight):
|
||||
h = torch.nn.functional.conv2d(x, conv1_weight, padding=1)
|
||||
h = torch.nn.functional.relu(h)
|
||||
h = torch.nn.functional.max_pool2d(h, 2)
|
||||
h = torch.nn.functional.conv2d(h, conv2_weight, padding=1)
|
||||
h = torch.nn.functional.relu(h)
|
||||
h = torch.nn.functional.max_pool2d(h, 2)
|
||||
h = torch.flatten(h, 1)
|
||||
out = torch.nn.functional.linear(h, linear_weight)
|
||||
return out
|
||||
|
||||
with FakeTensorMode():
|
||||
# Trace with make_fx
|
||||
x, conv1_weight, conv2_weight, linear_weight = create_inputs_and_weights()
|
||||
fx_graph = make_fx(fn)(x, conv1_weight, conv2_weight, linear_weight)
|
||||
|
||||
def is_releasable(node):
|
||||
return node.op not in ("placeholder", "get_attr")
|
||||
|
||||
fx_memory_profile = build_memory_profile(fx_graph.graph, is_releasable)
|
||||
fx_peak = max(fx_memory_profile)
|
||||
|
||||
# Runtime profiling
|
||||
profiler = FakeTensorMemoryProfilerMode()
|
||||
|
||||
with profiler:
|
||||
x_runtime, conv1_w, conv2_w, linear_w = create_inputs_and_weights()
|
||||
result = fn(x_runtime, conv1_w, conv2_w, linear_w)
|
||||
del result
|
||||
|
||||
runtime_peak = profiler.max_memory

self.assertEqual(fx_peak, runtime_peak)


if __name__ == "__main__":
if IS_LINUX and HAS_CUDA_AND_TRITON:
run_tests(needs="filelock")
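Reviewer note (illustration only): the deleted test file above compared the static estimate from `build_memory_profile` against what a fake-tensor dispatch mode observed at runtime. A minimal sketch of the static half, with a made-up function `g`, assuming the same `build_memory_profile(graph, is_releasable)` call pattern used in the deleted file:

import torch
from torch._inductor.fx_passes.memory_estimator import build_memory_profile
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx

def g(a, b):
    h = torch.relu(a @ b)  # intermediate that can be freed once consumed
    return h @ b

with FakeTensorMode():
    # fake tensors: tracing allocates no real CUDA memory
    a = torch.randn(64, 64, device="cuda")
    b = torch.randn(64, 64, device="cuda")
    gm = make_fx(g)(a, b)

# placeholders and get_attr nodes hold inputs/parameters, so they are never released
profile = build_memory_profile(
    gm.graph, lambda node: node.op not in ("placeholder", "get_attr")
)
print(max(profile))  # estimated peak bytes along the graph's execution order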
@ -22,8 +22,7 @@ from torch.testing._internal.common_cuda import \
(SM53OrLater, SM80OrLater, TEST_MULTIGPU)
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, ops, dtypes, dtypesIfCUDA, dtypesIfMPS, onlyCPU, onlyCUDA, precisionOverride,
deviceCountAtLeast, OpDTypes, onlyNativeDeviceTypes, skipCUDAIf, expectedFailureMPS,
expectedFailureMPSComplex, largeTensorTest)
deviceCountAtLeast, OpDTypes, onlyNativeDeviceTypes, skipCUDAIf, expectedFailureMPS, largeTensorTest)
from torch.testing._internal.common_methods_invocations import \
(op_db, reduction_ops, sparse_unary_ufuncs, sparse_masked_reduction_ops, binary_ufuncs)
from torch.testing._internal.common_dtype import (
@ -1854,7 +1853,7 @@ class TestSparse(TestSparseBase):
self.assertEqual(res_fp32, res_bf16, atol=1e-2, rtol=0)

@coalescedonoff
@expectedFailureMPSComplex
@expectedFailureMPS
@dtypes(torch.double, torch.cdouble)
@dtypesIfMPS(torch.float32, torch.complex64)
def test_norm(self, device, dtype, coalesced):

@ -40,7 +40,6 @@ from .decorators import (
run,
set_stance,
skip_frame,
step_unsupported,
substitute_in_graph,
)
from .eval_frame import (
@ -103,7 +102,6 @@ __all__ = [
"error_on_graph_break",
"set_stance",
"skip_frame",
"step_unsupported",
"substitute_in_graph",
]


@ -397,37 +397,19 @@ def create_call_function(nargs: int, push_null: bool) -> list[Instruction]:
return [create_instruction("CALL_FUNCTION", arg=nargs)]


def create_call_function_ex(
has_kwargs: bool, push_null: bool, ignore_314_kwargs_push: bool = False
) -> list[Instruction]:
def create_call_function_ex(has_kwargs: bool) -> list[Instruction]:
"""
Assumes that in 3.14+, if has_kwargs=False, there is NOT a NULL
on the TOS for the kwargs. This utility function will add a PUSH_NULL.

If the caller has already pushed a NULL for the kwargs, then set ignore_314_kwargs_push=True
so we don't push another NULL for the kwargs.
If the caller has already pushed a NULL, then do not call this function -
just use create_instruction("CALL_FUNCTION_EX", arg=...).
"""
if sys.version_info >= (3, 11):
output = []
if (
sys.version_info >= (3, 14)
and not has_kwargs
and not ignore_314_kwargs_push
):
output.append(create_instruction("PUSH_NULL"))
if push_null:
output.append(create_instruction("PUSH_NULL"))
# 3.13 swapped NULL and callable
# if flags == 1, 2 values popped - otherwise if flags == 0, 1 value
rots = (
int(has_kwargs) + 2
if sys.version_info >= (3, 13)
else int(has_kwargs) + 3
)
output.extend(create_rot_n(rots))
output.append(create_instruction("CALL_FUNCTION_EX", arg=int(has_kwargs)))
return output
return [create_instruction("CALL_FUNCTION_EX", arg=int(has_kwargs))]
insts = []
if sys.version_info >= (3, 14) and not has_kwargs:
insts.append(create_instruction("PUSH_NULL"))
insts.append(create_instruction("CALL_FUNCTION_EX", arg=int(has_kwargs)))
return insts


def create_call_method(nargs: int) -> list[Instruction]:
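Reviewer note (illustration only): the simplification above leans on the CPython calling convention for CALL_FUNCTION_EX, where 3.14 always expects a kwargs slot on the stack, and the PUSH_NULL/rotation handling for the callable now lives with the callers. One hedged way to see the convention being mirrored is to disassemble plain Python calls (the exact output differs across CPython versions):

import dis

def call_no_kwargs(f, args):
    return f(*args)

def call_with_kwargs(f, args, kwargs):
    return f(*args, **kwargs)

# On recent CPython, look for CALL_FUNCTION_EX with arg 0 (no kwargs mapping on
# the stack) vs. arg 1 (kwargs dict on the stack) in the two listings.
dis.dis(call_no_kwargs)
dis.dis(call_with_kwargs)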
@ -533,8 +515,6 @@ def create_binary_slice(
|
||||
def create_copy(i: int) -> list[Instruction]:
|
||||
if sys.version_info >= (3, 11):
|
||||
return [create_instruction("COPY", arg=i)]
|
||||
if i == 1:
|
||||
return [create_instruction("DUP_TOP")]
|
||||
# COPY 4
|
||||
# 0 1 2 3
|
||||
# 3 1 2 0
|
||||
|
||||
@ -519,7 +519,7 @@ class PyCodegen:
|
||||
create_build_tuple(n),
|
||||
self.create_load_const_unchecked(rot_n_helper(n)),
|
||||
*create_rot_n(2),
|
||||
*create_call_function_ex(False, False),
|
||||
*create_call_function_ex(False),
|
||||
create_instruction("UNPACK_SEQUENCE", arg=n),
|
||||
]
|
||||
|
||||
@ -540,33 +540,51 @@ class PyCodegen:
|
||||
|
||||
def make_function_with_closure(
|
||||
self,
|
||||
tx: "InstructionTranslatorBase",
|
||||
fn_name: str,
|
||||
code: types.CodeType,
|
||||
push_null: bool,
|
||||
num_on_stack: int = 0,
|
||||
) -> None:
|
||||
"""Creates a closure with code object `code`.
|
||||
|
||||
Expects the TOS to be the tuple of cells to use for this closure.
|
||||
TOS will be popped to create the closure.
|
||||
Args:
|
||||
- fn_name: name of the function
|
||||
- code: code object of the function
|
||||
(does not include the tuple of cells on the TOS)
|
||||
"""
|
||||
freevars = code.co_freevars
|
||||
assert freevars
|
||||
output = self._output
|
||||
|
||||
output.append(self.create_load_const(code))
|
||||
if sys.version_info < (3, 11):
|
||||
output.append(self.create_load_const(fn_name))
|
||||
if sys.version_info >= (3, 13):
|
||||
output.extend(
|
||||
[
|
||||
create_instruction("MAKE_FUNCTION"),
|
||||
create_instruction("SET_FUNCTION_ATTRIBUTE", arg=0x08),
|
||||
]
|
||||
)
|
||||
else:
|
||||
output.append(create_instruction("MAKE_FUNCTION", arg=0x08))
|
||||
def gen_fn() -> None:
|
||||
self.clear_tos()
|
||||
# Emitting `LOAD_FAST/LOAD_CLOSURE` with names in `co_freevars`
|
||||
# requires that in the generated bytecode, these cells would keep
|
||||
# their original local names, which we ensure via
|
||||
# `CellVariable.local_name`.
|
||||
for var in freevars:
|
||||
if tx is self.tx: # root frame
|
||||
assert var in self.cell_and_freevars()
|
||||
output.append(self.create_load_closure(var))
|
||||
else: # nested frame
|
||||
assert var in tx.cell_and_freevars()
|
||||
assert tx.post_prune_cell_and_freevars
|
||||
self(tx.post_prune_cell_and_freevars[var])
|
||||
output.append(create_build_tuple(len(freevars)))
|
||||
output.append(self.create_load_const(code))
|
||||
if sys.version_info < (3, 11):
|
||||
output.append(self.create_load_const(fn_name))
|
||||
if sys.version_info >= (3, 13):
|
||||
output.extend(
|
||||
[
|
||||
create_instruction("MAKE_FUNCTION"),
|
||||
create_instruction("SET_FUNCTION_ATTRIBUTE", arg=0x08),
|
||||
]
|
||||
)
|
||||
else:
|
||||
output.append(create_instruction("MAKE_FUNCTION", arg=0x08))
|
||||
|
||||
if push_null and sys.version_info >= (3, 11):
|
||||
self.add_push_null(gen_fn)
|
||||
output.extend(self.rot_n(num_on_stack + 2))
|
||||
output.extend(self.rot_n(num_on_stack + 2))
|
||||
else:
|
||||
gen_fn()
|
||||
output.extend(self.rot_n(num_on_stack + 1))
|
||||
self.clear_tos()
|
||||
|
||||
def create_load_python_module(self, mod: types.ModuleType) -> Instruction:
|
||||
|
||||
@ -750,9 +750,6 @@ def register_bytecode_hook(hook: BytecodeHook) -> RemovableHandle:
|
||||
return handle
|
||||
|
||||
|
||||
# TODO - We want to run preserve_node_meta context manager here, but the CI
|
||||
# fails (its unclear if the failures were flaky)
|
||||
# @torch.fx.traceback.preserve_node_meta()
|
||||
@preserve_global_state
|
||||
def trace_frame(
|
||||
code: types.CodeType,
|
||||
|
||||
@ -296,14 +296,6 @@ def skip_frame(msg: str = "") -> None:
"""Force a skipped frame"""


@_disallow_in_graph_helper(throw_if_not_allowed=False)
def step_unsupported(msg: str = "") -> None:
"""Force a step unsupported graph break, which results in compiling
the traced FX graph so far, then skipping the rest of the frame.
In order to get expected behavior, there should be at least 2 ops
and a part of the code not contained in any try/with blocks."""


def forbid_in_graph(fn: Any) -> Any:
"""
Customize which functions TorchDynamo will assert are not present while tracing.

@ -263,11 +263,6 @@ class RecompileLimitExceeded(Unsupported):
pass


# debug exception thrown when tracing torch._dynamo.step_unsupported()
class StepUnsupported(TorchDynamoException):
pass


class UnsafeScriptObjectError(TorchDynamoException):
pass


@ -2763,18 +2763,5 @@
"This is likely to be a Dynamo bug. Please report an issue to PyTorch."
]
}
],
"GB0275": [
{
"Gb_type": "torch._dynamo.step_unsupported() with empty checkpoint",
"Context": "",
"Explanation": "traced torch._dynamo.step_unsupported(), but there is no checkpoint to step_graph_break from. This graph break is used for debugging only.",
"Hints": [
"Remove the torch._dynamo.step_unsupported() call.",
"Include at least one checkpoint: (1) include at least 2 ops and (2) make sure there is some ",
"line of code that is not in a try/with block, and has an empty Python stack.",
"This is likely to be a Dynamo bug. Please report an issue to PyTorch."
]
}
]
}

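Reviewer note (illustration only): the removed GB0275 registry entry documented the debug-only graph break raised when `torch._dynamo.step_unsupported()` (also removed in this PR) is traced before any checkpoint exists. A rough sketch of the kind of program the entry described; whether it actually fires depends on where Dynamo places checkpoints:

import torch

@torch.compile(backend="eager")
def fn(x):
    # fewer than 2 traced ops before the call -> no checkpoint to step_graph_break from
    torch._dynamo.step_unsupported()
    return x + 1

fn(torch.ones(3))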
@ -79,7 +79,6 @@ from .backends.registry import CompiledFn, CompilerFn
|
||||
from .bytecode_transformation import (
|
||||
create_binary_slice,
|
||||
create_binary_subscr,
|
||||
create_build_tuple,
|
||||
create_call_function,
|
||||
create_dup_top,
|
||||
create_instruction,
|
||||
@ -1535,9 +1534,8 @@ class OutputGraph(OutputGraphCommon):
|
||||
|
||||
# Codegen stack convention before the unsupported instruction
|
||||
# NOTE: in these comment blocks, "locals" EXCLUDE free and cell vars.
|
||||
# NOTE: stack/locals/cells must be codegen'd BEFORE the unsupported instruction, since the latter
|
||||
# NOTE: stack and locals must be codegen'd BEFORE the unsupported instruction, since the latter
|
||||
# can arbitrarily mutate the former.
|
||||
# [frame N cells, .., frame 1 cells],
|
||||
# [
|
||||
# frame N locals,
|
||||
# frame N-1 stack + locals,
|
||||
@ -1547,7 +1545,7 @@ class OutputGraph(OutputGraphCommon):
|
||||
|
||||
# see symbolic_convert.py for
|
||||
# codegen stack convention after the unsupported instruction
|
||||
# NOTE: cells will be loaded into continuation functions directly by symbolic_convert
|
||||
# NOTE: cells are loaded into continuation functions directly
|
||||
|
||||
# this determines the order that values are codegen'd to the stack
|
||||
stack_values_flat = [val for vals in all_stack_values for val in vals]
|
||||
@ -1579,19 +1577,12 @@ class OutputGraph(OutputGraphCommon):
|
||||
and not all_stack_locals_metas[-1].locals_null_keys
|
||||
):
|
||||
# optimization to generate better code in a common case
|
||||
|
||||
# codegen cells
|
||||
# no side effects, so no new cells created - no need to call side_effects.codegen_save_tempvars
|
||||
cell_cg = PyCodegen(self.root_tx)
|
||||
self.codegen_cells(tx, cell_cg)
|
||||
self.add_output_instructions(
|
||||
[
|
||||
# load in reverse since UNPACK_SEQUENCE will reverse
|
||||
*self.compile_and_call_fx_graph(
|
||||
tx, list(reversed(stack_values_flat)), root
|
||||
),
|
||||
*cell_cg.get_instructions(),
|
||||
*create_swap(2),
|
||||
create_instruction("UNPACK_SEQUENCE", arg=len(stack_values_flat)),
|
||||
]
|
||||
)
|
||||
@ -1693,7 +1684,6 @@ class OutputGraph(OutputGraphCommon):
|
||||
|
||||
# store all stack and locals for each frame
|
||||
# current state of the stack:
|
||||
# all cells,
|
||||
# *(frame N stack), *(frame N locals),
|
||||
# ...,
|
||||
# *(frame 1 stack), *(frame 1 locals)
|
||||
@ -1708,7 +1698,6 @@ class OutputGraph(OutputGraphCommon):
|
||||
)
|
||||
|
||||
# current state of the stack:
|
||||
# all cells,
|
||||
# *(frame N stack), [
|
||||
# *(frame N locals),
|
||||
# *(frame N-1 stack), *(frame N-1 locals),
|
||||
@ -1769,8 +1758,7 @@ class OutputGraph(OutputGraphCommon):
|
||||
# *(frame N stack), metas[0] stack + locals, ..., metas[i] stack + locals, stack_values_flat
|
||||
|
||||
# current state of the stack:
|
||||
# all cells,
|
||||
# *(frame N stack),
|
||||
# *(frame N stack)
|
||||
# frame N locals,
|
||||
# frame N-1 stack, frame N-1 locals,
|
||||
# ...
|
||||
@ -1787,7 +1775,6 @@ class OutputGraph(OutputGraphCommon):
|
||||
)
|
||||
|
||||
# final state of the stack before running the unsupported bytecode:
|
||||
# all cells,
|
||||
# [
|
||||
# [frame N locals],
|
||||
# [frame N-1 stack + locals],
|
||||
@ -1844,31 +1831,6 @@ class OutputGraph(OutputGraphCommon):
|
||||
|
||||
return all_stack_locals_metas
|
||||
|
||||
def codegen_cells(self, tx: "InstructionTranslatorBase", cg: PyCodegen) -> None:
|
||||
# no need to codegen if reason.graph_break is False (since we won't resume)
|
||||
if self.compile_subgraph_reason.graph_break:
|
||||
tx_cnt = 0
|
||||
cur_tx: Optional[InstructionTranslatorBase] = tx
|
||||
while cur_tx is not None:
|
||||
# NOTE: we generate cells in the same order as resume_execution.py: sorted freevars + cellvars
|
||||
# Emitting `LOAD_FAST/LOAD_CLOSURE` with names in `co_freevars`
|
||||
# requires that in the generated bytecode, these cells would keep
|
||||
# their original local names, which we ensure via
|
||||
# `CellVariable.local_name`.
|
||||
freevars = tuple(sorted(cur_tx.cell_and_freevars()))
|
||||
for cell in freevars:
|
||||
if cur_tx is self.root_tx: # root frame
|
||||
cg.append_output(cg.create_load_closure(cell))
|
||||
else: # nested frame
|
||||
assert cur_tx.post_prune_cell_and_freevars
|
||||
cg(cur_tx.post_prune_cell_and_freevars[cell])
|
||||
cg.append_output(create_build_tuple(len(freevars)))
|
||||
cur_tx = cur_tx.parent
|
||||
tx_cnt += 1
|
||||
cg.append_output(create_instruction("BUILD_LIST", arg=tx_cnt))
|
||||
else:
|
||||
cg.append_output(create_instruction("BUILD_LIST", arg=0))
|
||||
|
||||
def codegen_suffix(
|
||||
self,
|
||||
tx: "InstructionTranslatorBase",
|
||||
@ -1888,7 +1850,6 @@ class OutputGraph(OutputGraphCommon):
|
||||
cg.store_attr(name)
|
||||
self.side_effects.codegen_hooks(cg)
|
||||
|
||||
# TODO get debug_locals working for nested graph breaks
|
||||
# Return variables used for logging at the end
|
||||
for debug_var, args in tx.debug_locals:
|
||||
cg.add_push_null(lambda: cg(debug_var))
|
||||
@ -1897,9 +1858,6 @@ class OutputGraph(OutputGraphCommon):
|
||||
cg.extend_output(create_call_function(len(args), False))
|
||||
cg.extend_output([create_instruction("POP_TOP")])
|
||||
|
||||
# codegen cells before we apply side effects
|
||||
self.codegen_cells(tx, cg)
|
||||
|
||||
cg.restore_stack(stack_values, value_from_source=not tx.export)
|
||||
self.side_effects.codegen_update_mutated(cg)
|
||||
|
||||
|
||||
@ -318,7 +318,6 @@ class ContinueExecutionCache:
|
||||
argnames: tuple[str, ...],
|
||||
argnames_null: tuple[str, ...],
|
||||
setup_fns: tuple[ReenterWith, ...],
|
||||
handle_inactive_ctx: bool,
|
||||
stack_ctx_vars: tuple[tuple[int, tuple[Any, ...]], ...],
|
||||
argnames_ctx_vars: tuple[tuple[str, tuple[Any, ...]], ...],
|
||||
null_idxes: tuple[int, ...],
|
||||
@ -342,7 +341,6 @@ class ContinueExecutionCache:
|
||||
argnames,
|
||||
argnames_null,
|
||||
setup_fns,
|
||||
handle_inactive_ctx,
|
||||
stack_ctx_vars,
|
||||
argnames_ctx_vars,
|
||||
null_idxes,
|
||||
@ -434,7 +432,7 @@ class ContinueExecutionCache:
|
||||
prefix.append(
|
||||
create_instruction("LOAD_FAST", argval=f"___stack{stack_i}")
|
||||
)
|
||||
if handle_inactive_ctx and stack_i in stack_ctx_vars_d:
|
||||
if stack_i in stack_ctx_vars_d:
|
||||
# NOTE: we assume that current stack var is a context manager CLASS!
|
||||
# Load args for context variable and construct it
|
||||
prefix.extend(_load_tuple_and_call(stack_ctx_vars_d[stack_i]))
|
||||
@ -461,11 +459,10 @@ class ContinueExecutionCache:
|
||||
|
||||
# NOTE: we assume that local var is a context manager CLASS!
|
||||
# initialize inactive context vars in argnames
|
||||
if handle_inactive_ctx:
|
||||
for name, vals in argnames_ctx_vars:
|
||||
prefix.append(create_instruction("LOAD_FAST", argval=name))
|
||||
prefix.extend(_load_tuple_and_call(vals))
|
||||
prefix.append(create_instruction("STORE_FAST", argval=name))
|
||||
for name, vals in argnames_ctx_vars:
|
||||
prefix.append(create_instruction("LOAD_FAST", argval=name))
|
||||
prefix.extend(_load_tuple_and_call(vals))
|
||||
prefix.append(create_instruction("STORE_FAST", argval=name))
|
||||
|
||||
# 3.12+: store NULL into variables that were NULL
|
||||
if argnames_null:
|
||||
@ -527,7 +524,7 @@ class ContinueExecutionCache:
|
||||
"STORE_FAST", argval=IS_TRACING_RESUME_PROLOGUE_VARNAME
|
||||
),
|
||||
# finish the call
|
||||
*create_call_function_ex(False, False),
|
||||
*create_call_function_ex(False),
|
||||
]
|
||||
)
|
||||
else:
|
||||
|
||||
@ -43,7 +43,6 @@ import threading
|
||||
import traceback
|
||||
import types
|
||||
import weakref
|
||||
from collections import deque
|
||||
from traceback import StackSummary
|
||||
from typing import Any, Callable, cast, NoReturn, Optional, TYPE_CHECKING, Union
|
||||
from typing_extensions import TypeAlias, TypeIs
|
||||
@ -80,7 +79,6 @@ from .bytecode_transformation import (
|
||||
create_dup_top,
|
||||
create_instruction,
|
||||
create_jump_absolute,
|
||||
create_load_const,
|
||||
create_rot_n,
|
||||
create_swap,
|
||||
get_code_keys,
|
||||
@ -98,7 +96,6 @@ from .exc import (
|
||||
format_graph_break_message,
|
||||
get_stack_above_dynamo,
|
||||
ResumePrologueTracingError,
|
||||
StepUnsupported,
|
||||
unimplemented_v2,
|
||||
Unsupported,
|
||||
)
|
||||
@ -545,7 +542,6 @@ def log_graph_break(
|
||||
reason: str = "",
|
||||
exc_info: bool = False,
|
||||
user_stack: Optional[StackSummary] = None,
|
||||
latest_bytecode_log: Optional[str] = None,
|
||||
) -> None:
|
||||
if user_stack is None:
|
||||
user_stack = torch._guards.TracingContext.extract_stack()
|
||||
@ -608,10 +604,6 @@ def log_graph_break(
|
||||
# This log line MUST contain the string "Graph break in user code",
|
||||
# This log line is exercised from
|
||||
# python test/dynamo/test_exc.py -k test_graph_break_log
|
||||
if latest_bytecode_log and config.verbose:
|
||||
user_stack_trace += "Most recent bytecode instructions traced (max 20):\n"
|
||||
user_stack_trace += latest_bytecode_log
|
||||
|
||||
graph_break_log.debug(
|
||||
user_stack_trace,
|
||||
)
|
||||
@ -677,20 +669,14 @@ def generic_jump(
|
||||
)
|
||||
self.pop()
|
||||
|
||||
if_next = self.codegen_fix_leaf_stack(
|
||||
all_stack_locals_metadata[0], self.next_instruction
|
||||
) + self.create_call_resume_at(
|
||||
self.next_instruction,
|
||||
all_stack_locals_metadata,
|
||||
if_next = self.create_call_resume_at(
|
||||
self.next_instruction, all_stack_locals_metadata, False
|
||||
)
|
||||
if push:
|
||||
self.push(value)
|
||||
assert inst.target is not None
|
||||
if_jump = self.codegen_fix_leaf_stack(
|
||||
all_stack_locals_metadata[0], inst.target
|
||||
) + self.create_call_resume_at(
|
||||
inst.target,
|
||||
all_stack_locals_metadata,
|
||||
if_jump = self.create_call_resume_at(
|
||||
inst.target, all_stack_locals_metadata, False
|
||||
)
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
@ -939,7 +925,6 @@ def break_graph_if_unsupported(
|
||||
exc_info=True,
|
||||
reason=str(excp),
|
||||
user_stack=excp.real_stack,
|
||||
latest_bytecode_log="\n".join(self.latest_bytecode_queue),
|
||||
)
|
||||
|
||||
if self.maybe_has_backedge():
|
||||
@ -975,7 +960,7 @@ def break_graph_if_unsupported(
|
||||
all_stack_locals_metadata = self.output.compile_subgraph(
|
||||
self, reason=reason, stack_pops=push - stack_effect
|
||||
)
|
||||
cg = PyCodegen(self.output.root_tx)
|
||||
cg = PyCodegen(self)
|
||||
cleanup: list[Instruction] = []
|
||||
# Reconstruct the context variable CLASS in the block stack
|
||||
for b in self.block_stack:
|
||||
@ -1024,12 +1009,8 @@ def break_graph_if_unsupported(
|
||||
for _ in range(push):
|
||||
self.push(UnknownVariable())
|
||||
self.output.add_output_instructions(
|
||||
self.codegen_fix_leaf_stack(
|
||||
all_stack_locals_metadata[0], self.next_instruction
|
||||
)
|
||||
+ self.create_call_resume_at(
|
||||
self.next_instruction,
|
||||
all_stack_locals_metadata,
|
||||
self.create_call_resume_at(
|
||||
self.next_instruction, all_stack_locals_metadata, False
|
||||
)
|
||||
)
|
||||
|
||||
@ -1191,8 +1172,6 @@ class InstructionTranslatorBase(
|
||||
parent: Optional[InstructionTranslatorBase]
|
||||
debug_locals: list[tuple[VariableTracker, list[VariableTracker]]]
|
||||
package: Optional[CompilePackage]
|
||||
latest_bytecode_queue: deque[str]
|
||||
# Store the latest bytecode before graph_break() call by user
|
||||
|
||||
def mark_inconsistent_side_effects(self) -> None:
|
||||
"""
|
||||
@ -1360,17 +1339,6 @@ class InstructionTranslatorBase(
|
||||
"TRACE %s %s %s", inst.opname, inst.argval, self.stack
|
||||
)
|
||||
|
||||
# Store the latest 20 bytecode execution for the process,
|
||||
# Used repr for byte processing and limiting the length to 2048
|
||||
try:
|
||||
stack_repr = repr(self.stack)
|
||||
except ValueError:
|
||||
# Handle large integers that exceed sys.int_info.str_digits_check_threshold
|
||||
stack_repr = "<self.stack repr truncated due to large integer>"
|
||||
self.latest_bytecode_queue.append(
|
||||
f"TRACE {inst.opname} {repr(inst.argval)} {stack_repr}"
|
||||
)
|
||||
|
||||
self.update_block_stack(inst)
|
||||
|
||||
try:
|
||||
@ -1383,22 +1351,9 @@ class InstructionTranslatorBase(
|
||||
return True
|
||||
except (ReturnValueOp, YieldValueOp):
|
||||
return False
|
||||
except (Unsupported, StepUnsupported) as e:
|
||||
except Unsupported:
|
||||
if self.current_speculation is None:
|
||||
log.debug("empty checkpoint")
|
||||
if isinstance(e, StepUnsupported):
|
||||
unimplemented_v2(
|
||||
gb_type="torch._dynamo.step_unsupported() with empty checkpoint",
|
||||
context="",
|
||||
explanation="traced torch._dynamo.step_unsupported(), but there is no checkpoint "
|
||||
"to step_graph_break from. This graph break is used for debugging only.",
|
||||
hints=[
|
||||
"Remove the torch._dynamo.step_unsupported() call.",
|
||||
"Include at least one checkpoint: (1) include at least 2 ops and (2) make sure there is some "
|
||||
"line of code that is not in a try/with block, and has an empty Python stack.",
|
||||
*graph_break_hints.DYNAMO_BUG,
|
||||
],
|
||||
)
|
||||
raise
|
||||
log.debug("step triggered compile", exc_info=True)
|
||||
|
||||
@ -1472,110 +1427,24 @@ class InstructionTranslatorBase(
|
||||
partial_convert=True,
|
||||
reason=GraphCompileReason("step_unsupported", [self.frame_summary()]),
|
||||
)
|
||||
# current frame state
|
||||
# cells,
|
||||
# [
|
||||
# frame N locals,
|
||||
# frame N-1 stack + locals,
|
||||
# ...,
|
||||
# frame 1 stack + locals,
|
||||
# ],
|
||||
if self.parent:
|
||||
from .eval_frame import skip_code
|
||||
|
||||
# nested graph break
|
||||
assert config.nested_graph_breaks
|
||||
cg = PyCodegen(self.output.root_tx)
|
||||
|
||||
# codegen cells and frame values only for frame N
|
||||
cg.extend_output(
|
||||
[
|
||||
*create_copy(2),
|
||||
cg.create_load_const(0),
|
||||
cg.create_binary_subscr(),
|
||||
create_instruction("BUILD_LIST", arg=1),
|
||||
*create_copy(2),
|
||||
cg.create_load_const(0),
|
||||
cg.create_binary_subscr(),
|
||||
create_instruction("BUILD_LIST", arg=1),
|
||||
]
|
||||
)
|
||||
# No need to fix stack, since stack is assumed to be empty here.
|
||||
# Do NOT handle_inactive_ctx because we will be skipping this resume code.
|
||||
leaf_resume_code, leaf_resume_name = self.create_resume(
|
||||
0, continue_inst, all_stack_locals_metadata[0], [], cg, True, False
|
||||
)
|
||||
skip_code(leaf_resume_code)
|
||||
|
||||
# current frame state
|
||||
# cells,
|
||||
# [
|
||||
# frame N locals,
|
||||
# frame N-1 stack + locals,
|
||||
# ...,
|
||||
# frame 1 stack + locals,
|
||||
# ], [frame N cells], [frame N locals],
|
||||
self.codegen_call_resume([leaf_resume_code], [leaf_resume_name], cg)
|
||||
|
||||
# current frame state
|
||||
# cells,
|
||||
# [
|
||||
# frame N locals,
|
||||
# frame N-1 stack + locals,
|
||||
# ...,
|
||||
# frame 1 stack + locals,
|
||||
# ], leaf_resume result
|
||||
|
||||
# add the leaf_resume result to frame N-1 stack
|
||||
num_stack = all_stack_locals_metadata[1].num_stack
|
||||
cg.extend_output(
|
||||
[
|
||||
create_instruction("BUILD_LIST", arg=1),
|
||||
*create_copy(2),
|
||||
cg.create_load_const(1),
|
||||
cg.create_binary_subscr(),
|
||||
*create_binary_slice(num_stack, num_stack, True),
|
||||
]
|
||||
)
|
||||
|
||||
# pop frame N cells and locals
|
||||
cg.extend_output(
|
||||
[
|
||||
*create_copy(1),
|
||||
cg.create_load_const(0),
|
||||
create_instruction("DELETE_SUBSCR"),
|
||||
*create_copy(2),
|
||||
cg.create_load_const(0),
|
||||
create_instruction("DELETE_SUBSCR"),
|
||||
]
|
||||
)
|
||||
|
||||
# call the remaining resume functions
|
||||
# current frame state
|
||||
# [frame N-1 cells, ..., frame 1 cells],
|
||||
# [
|
||||
# frame N-1 stack (including leaf_resume result) + locals,
|
||||
# ...,
|
||||
# frame 1 stack + locals,
|
||||
# ],
|
||||
self.parent.push(UnknownVariable())
|
||||
all_stack_locals_metadata[1].num_stack += 1
|
||||
self.output.add_output_instructions(
|
||||
cg.get_instructions()
|
||||
+ self.parent.create_call_resume_at(
|
||||
self.parent.next_instruction, all_stack_locals_metadata[1:]
|
||||
self.create_call_resume_at(
|
||||
continue_inst, all_stack_locals_metadata, True
|
||||
)
|
||||
)
|
||||
else:
|
||||
# pop cells
|
||||
self.output.add_output_instructions(
|
||||
[
|
||||
*create_swap(2),
|
||||
create_instruction("POP_TOP"),
|
||||
]
|
||||
)
|
||||
# load locals from frame values
|
||||
cg = PyCodegen(self.output.root_tx)
|
||||
# current frame state
|
||||
# [
|
||||
# frame N locals,
|
||||
# frame N-1 stack + locals,
|
||||
# ...,
|
||||
# frame 1 stack + locals,
|
||||
# ],
|
||||
cg = PyCodegen(self)
|
||||
self.output.add_output_instructions(
|
||||
[
|
||||
cg.create_load_const(-1),
|
||||
@ -2640,12 +2509,8 @@ class InstructionTranslatorBase(
|
||||
self.output.add_output_instructions([copy.copy(inst)])
|
||||
self.popn(2)
|
||||
self.output.add_output_instructions(
|
||||
self.codegen_fix_leaf_stack(
|
||||
all_stack_locals_metadata[0], self.next_instruction
|
||||
)
|
||||
+ self.create_call_resume_at(
|
||||
self.next_instruction,
|
||||
all_stack_locals_metadata,
|
||||
self.create_call_resume_at(
|
||||
self.next_instruction, all_stack_locals_metadata, False
|
||||
)
|
||||
)
|
||||
|
||||
@ -2658,292 +2523,48 @@ class InstructionTranslatorBase(
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def codegen_return_with_pops(
|
||||
inst: Instruction, num_stack: int
|
||||
def codegen_return_after_compile_subgraph(
|
||||
inst: Instruction, meta: StackLocalsMetadata
|
||||
) -> list[Instruction]:
|
||||
"""
|
||||
Debug CPython expects the stack to be empty after the return.
|
||||
Calling compile_subgraph will push cells and frame values to TOS.
|
||||
This function will pop those 2 values from the stack before actually returning.
|
||||
|
||||
Expects the stack to be:
|
||||
cells, frame values, current frame stack (0 or 1 values)
|
||||
|
||||
Pops cells and frame values, leaving the current frame stack as TOS.
|
||||
A return instruction is included.
|
||||
"""
|
||||
insts = []
|
||||
# NOTE: Debug CPython expects the stack to be empty after the return.
|
||||
# Expect the current stack to be in the state
|
||||
# cells, frame values, current frame stack (0 or 1 values)
|
||||
assert num_stack <= 1
|
||||
if num_stack == 1:
|
||||
insts.extend(create_swap(3))
|
||||
# [[]] (empty frame values), current frame stack (0 or 1 values)
|
||||
assert meta.num_stack <= 1
|
||||
if meta.num_stack == 1:
|
||||
insts.extend(create_swap(2))
|
||||
return_inst = (
|
||||
create_instruction("RETURN_VALUE")
|
||||
if inst.opname == "RETURN_VALUE"
|
||||
else create_instruction("RETURN_CONST", argval=inst.argval)
|
||||
)
|
||||
insts.extend(
|
||||
[create_instruction("POP_TOP"), create_instruction("POP_TOP"), return_inst]
|
||||
)
|
||||
insts.extend([create_instruction("POP_TOP"), return_inst])
|
||||
return insts
|
||||
|
||||
def codegen_fix_leaf_stack(
|
||||
self, meta: StackLocalsMetadata, resume_inst: Instruction
|
||||
def create_call_resume_at(
|
||||
self,
|
||||
inst: Instruction,
|
||||
all_stack_locals_metadata: Any,
|
||||
disable_current_frame_resume: bool,
|
||||
) -> list[Instruction]:
|
||||
"""
|
||||
Fixes the stack values of the current/leaf frame (self).
|
||||
Codegen resume function(s) and call it.
|
||||
Assumes that the unsupported instruction has already been run.
|
||||
|
||||
Expects the TOS to be:
|
||||
Expects the stack to be in the state:
|
||||
[
|
||||
frame N locals,
|
||||
frame N-1 stack + locals,
|
||||
...,
|
||||
frame 1 stack + locals
|
||||
], *(frame N stack (post-unsupported instruction))
|
||||
|
||||
Rearranges the TOS to become:
|
||||
[
|
||||
frame N stack + locals,
|
||||
...,
|
||||
frame 1 stack + locals
|
||||
]
|
||||
|
||||
Args:
|
||||
- meta: metadata for the leaf frame returned from OutputGraph.compile_subgraph
|
||||
- resume_inst: if the resume instruction is a return instruction, then don't return any instructions
|
||||
"""
|
||||
if resume_inst.opname in ("RETURN_VALUE", "RETURN_CONST"):
|
||||
return []
|
||||
# move frame N stack to the frame values list
|
||||
current_num_stack = len(self.stack) - len(meta.stack_null_idxes)
|
||||
meta.num_stack = current_num_stack
|
||||
return [
|
||||
create_instruction("BUILD_LIST", arg=current_num_stack),
|
||||
*create_copy(2),
|
||||
# frame_values, frame N stack, frame_values
|
||||
create_load_const(0),
|
||||
create_instruction("BINARY_SUBSCR"),
|
||||
*create_binary_slice(0, 0, True),
|
||||
# frame_values[0][0:0] = frame N stack
|
||||
# frame_values left on top of stack
|
||||
]
|
||||
|
||||
def create_resume(
|
||||
self,
|
||||
idx: int,
|
||||
resume_inst: Instruction,
|
||||
meta: StackLocalsMetadata,
|
||||
resume_codes: list[types.CodeType],
|
||||
cg: PyCodegen,
|
||||
is_leaf: bool,
|
||||
handle_inactive_ctx: bool,
|
||||
) -> tuple[types.CodeType, str]:
|
||||
"""
|
||||
Creates the resume function for the frame corresponding to `self`.
|
||||
|
||||
Expects the TOS to be:
|
||||
[frame N cells, ..., frame 1 cells],
|
||||
[
|
||||
frame N stack + locals,
|
||||
...,
|
||||
frame 1 stack + locals
|
||||
]
|
||||
|
||||
Some additional codegen may happen to prepare the frame stack + locals values for the generated resume function:
|
||||
- inactive context variables in the stack and locals will be replaced by their types
|
||||
- if the frame is a leaf frame, prune dead locals
|
||||
|
||||
Regardless of codegen, the stack will be left in the same state as before.
|
||||
|
||||
Args:
|
||||
- idx: depth of this frame: 0 corresponds to the leaf frame (frame N), N-1 to the root frame (frame 1).
|
||||
- resume_inst: the instruction that this frame should resume at
|
||||
- meta: metadata for this frame returned from OutputGraph.compile_subgraph
|
||||
- resume_codes: nested resume code objects generated from previous create_resume calls.
|
||||
- cg: codegen object to output to
|
||||
- is_leaf: True if `self` corresponds to the leaf frame.
|
||||
- handle_inactive_ctx: If True, handles inactive context variables as described above. This is necessary
|
||||
iff the resume function is traced
|
||||
"""
|
||||
# Handle inactive context variables.
|
||||
# The resume function assumes that context variables are the class, NOT the object.
|
||||
# e.g. torch.set_grad_enabled(True) will be reconstructed as torch.set_grad_enabled
|
||||
# NOTE: if the unsupported instruction modifies the inactive context variable, it may
|
||||
# result in silent incorrectness!
|
||||
if handle_inactive_ctx:
|
||||
for (j, _), j_orig in zip(meta.stack_ctx_args, meta.stack_ctx_idxes_orig):
|
||||
# Replace the stack var with the context class
|
||||
ctx = cast(ContextWrappingVariable, self.stack[j_orig])
|
||||
# frames[idx][j] = reconstructed_ctx
|
||||
cg.append_output(create_dup_top())
|
||||
ctx.reconstruct_type(cg)
|
||||
cg.extend_output(
|
||||
[
|
||||
*create_swap(2),
|
||||
cg.create_load_const(idx),
|
||||
cg.create_binary_subscr(),
|
||||
cg.create_load_const(j),
|
||||
create_instruction("STORE_SUBSCR"),
|
||||
]
|
||||
)
|
||||
|
||||
for name, _ in meta.locals_ctx_args:
|
||||
# Replace the local with the context class
|
||||
ctx = cast(ContextWrappingVariable, self.symbolic_locals[name])
|
||||
# frames[idx][meta.num_stack +meta.locals_names[name]] = reconstructed_ctx
|
||||
cg.append_output(create_dup_top())
|
||||
ctx.reconstruct_type(cg)
|
||||
cg.extend_output(
|
||||
[
|
||||
*create_swap(2),
|
||||
cg.create_load_const(idx),
|
||||
cg.create_binary_subscr(),
|
||||
cg.create_load_const(meta.num_stack + meta.locals_names[name]),
|
||||
create_instruction("STORE_SUBSCR"),
|
||||
]
|
||||
)
|
||||
|
||||
# If the resume instruction is a jump absolute, then resume
|
||||
# at the target instead. This handles the case where we
|
||||
# graph break again in a nested function before jump-resuming
|
||||
# this frame.
|
||||
if is_jump_absolute(resume_inst):
|
||||
assert resume_inst.target
|
||||
resume_inst = resume_inst.target
|
||||
|
||||
resume_name = unique_id(f"__resume_at_{resume_inst.offset}")
|
||||
|
||||
# More locals may have been pruned in the current/leaf frame
|
||||
# after the unsupported instruction (e.g. branch).
|
||||
# There should not be any pruning in the other frames since
|
||||
# the current instruction there should be a CALL.
|
||||
if is_leaf:
|
||||
reads = livevars_analysis(self.instructions, resume_inst)
|
||||
all_argnames = tuple(
|
||||
k
|
||||
for k in self.symbolic_locals.keys()
|
||||
if k in reads and k not in self.cell_and_freevars()
|
||||
)
|
||||
argnames_null_set = set(meta.locals_null_keys)
|
||||
argnames = tuple(k for k in all_argnames if k not in argnames_null_set)
|
||||
argnames_null = tuple(k for k in all_argnames if k in argnames_null_set)
|
||||
|
||||
# codegen filter for current frame's locals
|
||||
# current stack state: frames
|
||||
cg.extend_output(
|
||||
[
|
||||
create_dup_top(),
|
||||
cg.create_load_const(idx),
|
||||
cg.create_binary_subscr(),
|
||||
create_dup_top(),
|
||||
]
|
||||
)
|
||||
for arg in argnames:
|
||||
# current stack state: frames, frames[i], *(prev locals), frames[i]
|
||||
cg.extend_output(
|
||||
[
|
||||
create_dup_top(),
|
||||
cg.create_load_const(meta.num_stack + meta.locals_names[arg]),
|
||||
cg.create_binary_subscr(),
|
||||
*create_swap(2),
|
||||
],
|
||||
)
|
||||
# current stack state: frames, frames[i], *(frame i live locals), frames[i]
|
||||
cg.extend_output(
|
||||
[
|
||||
create_instruction("POP_TOP"),
|
||||
create_instruction("BUILD_LIST", arg=len(argnames)),
|
||||
*create_swap(2),
|
||||
# frames, frames i live locals, frames[i]
|
||||
*create_binary_slice(meta.num_stack, None, True),
|
||||
# frames[i][num_stack:] = frame i live locals
|
||||
]
|
||||
)
|
||||
# current stack state: frames
|
||||
else:
|
||||
argnames = tuple(meta.locals_names.keys())
|
||||
argnames_null = tuple(meta.locals_null_keys)
|
||||
|
||||
if sys.version_info < (3, 12):
|
||||
assert len(argnames_null) == 0, "variables should not be NULL in < 3.12"
|
||||
|
||||
# compile_subgraph did not codegen any NULLs,
|
||||
# so we should not count NullVariables
|
||||
stack_len = len(self.stack) - len(meta.stack_null_idxes)
|
||||
|
||||
new_code: types.CodeType = ContinueExecutionCache.lookup(
|
||||
self.f_code,
|
||||
self.lineno,
|
||||
resume_inst.offset,
|
||||
tuple(b.target.offset for b in self.block_stack),
|
||||
stack_len,
|
||||
argnames,
|
||||
argnames_null,
|
||||
tuple(b.resume_fn() for b in self.block_stack),
|
||||
handle_inactive_ctx,
|
||||
tuple(meta.stack_ctx_args),
|
||||
tuple(meta.locals_ctx_args),
|
||||
tuple(meta.stack_null_idxes),
|
||||
tuple(resume_codes),
|
||||
)
|
||||
|
||||
# Add original GraphModule context to the resume function to handle
|
||||
# the case of a graph break while tracing a GraphModule
|
||||
orig_graphmodule_maybe = code_context.get_context(self.f_code).get(
|
||||
"orig_graphmodule", lambda: None
|
||||
)()
|
||||
if orig_graphmodule_maybe is not None:
|
||||
code_context.get_context(new_code)["orig_graphmodule"] = weakref.ref(
|
||||
orig_graphmodule_maybe
|
||||
)
|
||||
|
||||
# add resume function to the global scope
|
||||
if new_code.co_freevars:
|
||||
# expose code object for debugging purposes
|
||||
self.output.install_global_unsafe(resume_name, new_code)
|
||||
package_name = None
|
||||
else:
|
||||
# This is safe: we pre-generate a unique name
|
||||
self.output.install_global_unsafe(
|
||||
resume_name,
|
||||
types.FunctionType(new_code, self.f_globals, resume_name),
|
||||
)
|
||||
package_name = resume_name
|
||||
|
||||
if self.package is not None:
|
||||
self.package.add_resume_function(
|
||||
new_code, self.f_globals["__name__"], package_name
|
||||
)
|
||||
|
||||
return new_code, resume_name
|
||||
|
||||
def create_call_resume_at(
|
||||
self,
|
||||
inst: Instruction,
|
||||
all_stack_locals_metadata: list[StackLocalsMetadata],
|
||||
) -> list[Instruction]:
|
||||
"""
|
||||
Codegen all resume function(s) from the frame stack starting at `self` and call them.
|
||||
Assumes that the unsupported instruction has already been run.
|
||||
|
||||
Expects the stack to be in the state:
|
||||
[frame N cells, ..., frame 1 cells],
|
||||
[
|
||||
frame N stack + locals,
|
||||
frame N-1 stack + locals,
|
||||
...,
|
||||
frame 1 stack + locals
|
||||
]
|
||||
|
||||
Pops the cells and frame values list from the stack.
|
||||
Also includes a return instruction (stack expected to be empty after return).
|
||||
], frame N stack (post-instruction)
|
||||
|
||||
Args:
|
||||
- inst: the instruction of the current (deepest) frame to resume at
|
||||
- all_stack_locals_metadata: metadata returned from OutputGraph.compile_subgraph - contains
|
||||
metadata such as local names, NULL positions, stack length, etc.
|
||||
- disable_current_frame_resume: If True, disable tracing on the current frame's resume function.
|
||||
Used for implementing nested step_graph_break.
|
||||
"""
|
||||
|
||||
self.instruction_pointer = None
|
||||
@ -2954,115 +2575,234 @@ class InstructionTranslatorBase(
|
||||
all_stack_locals_metadata[0].num_stack = current_num_stack
|
||||
|
||||
if inst.opname in ("RETURN_VALUE", "RETURN_CONST"):
|
||||
return self.codegen_return_with_pops(
|
||||
inst, all_stack_locals_metadata[0].num_stack
|
||||
return self.codegen_return_after_compile_subgraph(
|
||||
inst, all_stack_locals_metadata[0]
|
||||
)
|
||||
|
||||
cg = PyCodegen(self.output.root_tx)
|
||||
|
||||
# move frame N stack to the frame values list
|
||||
cg.extend_output(
|
||||
[
|
||||
create_instruction("BUILD_LIST", arg=current_num_stack),
|
||||
*create_copy(2),
|
||||
# frame_values, frame N stack, frame_values
|
||||
cg.create_load_const(0),
|
||||
cg.create_binary_subscr(),
|
||||
*create_binary_slice(0, 0, True),
|
||||
# frame_values[0][0:0] = frame N stack
|
||||
# frame_values left on top of stack
|
||||
]
|
||||
)
|
||||
|
||||
# current frame state
|
||||
# [
|
||||
# [frame N stack (fixed) + locals]
|
||||
# ...,
|
||||
# [frame 1 stack + locals]
|
||||
# ],
|
||||
|
||||
#
|
||||
txes = []
|
||||
cur_tx: Optional[InstructionTranslatorBase] = self
|
||||
idx = 0
|
||||
resume_codes: list[types.CodeType] = []
|
||||
resume_names = []
|
||||
while cur_tx is not None:
|
||||
txes.append(cur_tx)
|
||||
cur_tx = cur_tx.parent
|
||||
assert len(txes) == len(all_stack_locals_metadata)
|
||||
|
||||
# Handle inactive context variables.
|
||||
# The resume function assumes that context variables are the class, NOT the object.
|
||||
# e.g. torch.set_grad_enabled(True) will be reconstructed as torch.set_grad_enabled
|
||||
# NOTE: if the unsupported instruction modifies the inactive context variable, it may
|
||||
# result in silent incorrectness!
|
||||
for i, meta in enumerate(all_stack_locals_metadata):
|
||||
if i == 0 and disable_current_frame_resume:
|
||||
continue
|
||||
|
||||
for (j, _), j_orig in zip(meta.stack_ctx_args, meta.stack_ctx_idxes_orig):
|
||||
# Replace the stack var with the context class
|
||||
ctx = cast(ContextWrappingVariable, txes[i].stack[j_orig])
|
||||
# frames[i][j] = reconstructed_ctx
|
||||
cg.append_output(create_dup_top())
|
||||
ctx.reconstruct_type(cg)
|
||||
cg.extend_output(
|
||||
[
|
||||
*create_swap(2),
|
||||
cg.create_load_const(i),
|
||||
cg.create_binary_subscr(),
|
||||
cg.create_load_const(j),
|
||||
create_instruction("STORE_SUBSCR"),
|
||||
]
|
||||
)
|
||||
|
||||
for name, _ in meta.locals_ctx_args:
|
||||
# Replace the local with the context class
|
||||
ctx = cast(ContextWrappingVariable, txes[i].symbolic_locals[name])
|
||||
# frames[i][meta.num_stack +meta.locals_names[name]] = reconstructed_ctx
|
||||
cg.append_output(create_dup_top())
|
||||
ctx.reconstruct_type(cg)
|
||||
cg.extend_output(
|
||||
[
|
||||
*create_swap(2),
|
||||
cg.create_load_const(i),
|
||||
cg.create_binary_subscr(),
|
||||
cg.create_load_const(meta.num_stack + meta.locals_names[name]),
|
||||
create_instruction("STORE_SUBSCR"),
|
||||
]
|
||||
)
|
||||
|
||||
# build the resume function for each frame
|
||||
resume_names = []
|
||||
resume_codes: list[types.CodeType] = []
|
||||
for i, meta in enumerate(all_stack_locals_metadata):
|
||||
cur_tx = txes[i]
|
||||
if cur_tx is self:
|
||||
resume_inst = inst
|
||||
else:
|
||||
resume_inst = cur_tx.next_instruction
|
||||
resume_code, resume_name = cur_tx.create_resume(
|
||||
idx,
|
||||
resume_inst,
|
||||
all_stack_locals_metadata[idx],
|
||||
resume_codes,
|
||||
cg,
|
||||
cur_tx is self,
|
||||
True,
|
||||
)
|
||||
resume_codes.append(resume_code)
|
||||
# If the resume instruction is a jump absolute, then resume
|
||||
# at the target instead. This handles the case where we
|
||||
# graph break again in a nested function before jump-resuming
|
||||
# this frame.
|
||||
if is_jump_absolute(resume_inst):
|
||||
assert resume_inst.target
|
||||
resume_inst = resume_inst.target
|
||||
resume_name = unique_id(f"__resume_at_{resume_inst.offset}")
|
||||
resume_names.append(resume_name)
|
||||
|
||||
cur_tx = cur_tx.parent
|
||||
idx += 1
|
||||
# More locals may have been pruned in the current frame
|
||||
# after the unsupported instruction (e.g. branch).
|
||||
# There should not be any pruning in the other frames since
|
||||
# the current instruction is a CALL.
|
||||
if cur_tx is self:
|
||||
reads = livevars_analysis(cur_tx.instructions, resume_inst)
|
||||
all_argnames = tuple(
|
||||
k
|
||||
for k in cur_tx.symbolic_locals.keys()
|
||||
if k in reads and k not in cur_tx.cell_and_freevars()
|
||||
)
|
||||
argnames_null_set = set(meta.locals_null_keys)
|
||||
argnames = tuple(k for k in all_argnames if k not in argnames_null_set)
|
||||
argnames_null = tuple(k for k in all_argnames if k in argnames_null_set)
|
||||
|
||||
self.codegen_call_resume(resume_codes, resume_names, cg)
|
||||
return cg.get_instructions() + [create_instruction("RETURN_VALUE")]
|
||||
|
||||
@staticmethod
|
||||
def codegen_call_resume(
|
||||
resume_codes: list[types.CodeType], resume_names: list[str], cg: PyCodegen
|
||||
) -> None:
|
||||
"""
|
||||
Calls the provided resume functions.
|
||||
|
||||
Expects the TOS to be in the state:
|
||||
[frame N cells, ..., frame 1 cells],
|
||||
[
|
||||
frame N stack + locals,
|
||||
frame N-1 stack + locals,
|
||||
...,
|
||||
frame 1 stack + locals
|
||||
]
|
||||
|
||||
Pops the cells and frame values, leaving the result of calling the resume functions on TOS.
|
||||
|
||||
Args:
|
||||
- resume_codes: list of resume function code objects to call
|
||||
- resume_names: list of the corresponding names of the resume functions
|
||||
- cg: PyCodegen object to output instructions to
|
||||
"""
# NOTE: We will load cells as we load resume functions

# load resume functions except the root's
cg.extend_output(create_copy(2))
for i, (name, code) in enumerate(zip(resume_names, resume_codes)):
if i == len(resume_names) - 1:
break
# stack: cells, frames, *(resume 1, ...), cells
if code.co_freevars:
# codegen filter for current frame's locals
# current stack state: frames
cg.extend_output(
[
create_dup_top(),
cg.create_load_const(i),
cg.create_binary_subscr(),
create_dup_top(),
]
)
cg.make_function_with_closure(name, code)
for arg in argnames:
# current stack state: frames, frames[i], *(prev locals), frames[i]
cg.extend_output(
[
create_dup_top(),
cg.create_load_const(
meta.num_stack + meta.locals_names[arg]
),
cg.create_binary_subscr(),
*create_swap(2),
],
)
# current stack state: frames, frames[i], *(frame i live locals), frames[i]
cg.extend_output(
[
create_instruction("POP_TOP"),
create_instruction("BUILD_LIST", arg=len(argnames)),
*create_swap(2),
# frames, frames i live locals, frames[i]
*create_binary_slice(meta.num_stack, None, True),
# frames[i][num_stack:] = frame i live locals
]
)
# current stack state: frames
else:
argnames = tuple(meta.locals_names.keys())
argnames_null = tuple(meta.locals_null_keys)

if sys.version_info < (3, 12):
assert len(argnames_null) == 0, "variables should not be NULL in < 3.12"

# compile_subgraph did not codegen any NULLs,
# so we should not count NullVariables
stack_len = len(cur_tx.stack) - len(meta.stack_null_idxes)

new_code: types.CodeType = ContinueExecutionCache.lookup(
cur_tx.f_code,
cur_tx.lineno,
resume_inst.offset,
tuple(b.target.offset for b in cur_tx.block_stack),
stack_len,
argnames,
argnames_null,
tuple(b.resume_fn() for b in cur_tx.block_stack),
tuple(meta.stack_ctx_args),
tuple(meta.locals_ctx_args),
tuple(meta.stack_null_idxes),
tuple(resume_codes),
)
resume_codes.append(new_code)

# Add original GraphModule context to the resume function to handle
# the case of a graph break while tracing a GraphModule
orig_graphmodule_maybe = code_context.get_context(cur_tx.f_code).get(
"orig_graphmodule", lambda: None
)()
if orig_graphmodule_maybe is not None:
code_context.get_context(new_code)["orig_graphmodule"] = weakref.ref(
orig_graphmodule_maybe
)

# add resume function to the global scope
if new_code.co_freevars:
# expose code object for debugging purposes
cur_tx.output.install_global_unsafe(resume_name, new_code)
package_name = None
else:
# This is safe: we pre-generate a unique name
cur_tx.output.install_global_unsafe(
resume_name,
types.FunctionType(new_code, cur_tx.f_globals, resume_name),
)
package_name = resume_name

if cur_tx.package is not None:
cur_tx.package.add_resume_function(
new_code, cur_tx.f_globals["__name__"], package_name
)

if disable_current_frame_resume:
from .eval_frame import skip_code

skip_code(resume_codes[0])

# load first resume function (to be called this frame)
if resume_codes[-1].co_freevars:
cg.make_function_with_closure(
txes[-1], resume_names[-1], resume_codes[-1], True, 1
)
else:
cg.extend_output(cg.load_function_name(resume_names[-1], True, 1))

# load all other resume functions (to be called later)
resume_names.pop()
resume_codes.pop()
for tx, name, code in zip(txes, resume_names, resume_codes):
if code.co_freevars:
cg.make_function_with_closure(tx, name, code, False, 0)
else:
cg.extend_output(cg.load_function_name(name, False, 0))
cg.extend_output(create_swap(2))
cg.extend_output(
[
create_instruction("POP_TOP"),
create_instruction("BUILD_LIST", arg=len(resume_codes) - 1),
create_instruction("BUILD_LIST", arg=len(resume_codes)),
*create_swap(2),
]
)

# stack: cells, frames, [resume 1, ..., resume N - 1]
# load root resume function
cg.extend_output(create_swap(3))
if resume_codes[-1].co_freevars:
cg.extend_output(
[
cg.create_load_const(-1),
cg.create_binary_subscr(),
]
)
cg.make_function_with_closure(resume_names[-1], resume_codes[-1])
cg.extend_output(
[
*create_rot_n(3),
]
)
else:
cg.extend_output(
[
create_instruction("POP_TOP"),
*cg.load_function_name(resume_names[-1], False),
*create_rot_n(3),
]
)

# resume 1, [resume N, ..., resume 2], frames
# resume 1 (+ NULL), [resume N, ..., resume 2], frames

# load top level-frame; final stack state should be:
# first resume function (+ NULL),
@ -3103,9 +2843,11 @@ class InstructionTranslatorBase(
# TOS: [resumes, frames, *(frame 1 stack + locals)]
cg.extend_output(
[
*create_call_function_ex(False, True),
*create_call_function_ex(False),
create_instruction("RETURN_VALUE"),
]
)
return cg.get_instructions()

def should_compile_partial_graph(self) -> bool:
if sys.version_info >= (3, 11):
@ -3748,7 +3490,7 @@ class InstructionTranslatorBase(
self.active_generic_context_managers.append(ctx)

if sys.version_info >= (3, 11):
# See update_block_stack/create_resume for block stack details.
# See create_call_resume_at for block stack details.
# Only push a block if the current instruction's block is a
# with block that is not nested in a try block - that is, the current
# instruction's block target is the same as the top block's target.
@ -4103,7 +3845,6 @@ class InstructionTranslatorBase(
self.accept_prefix_inst = True
self.prefix_insts = []
self.exn_vt_stack = exn_vt_stack
self.latest_bytecode_queue = deque(maxlen=20)

# Properties of the input/output code
self.instructions: list[Instruction] = instructions
@ -4449,7 +4190,9 @@ class InstructionTranslator(InstructionTranslatorBase):
assert len(all_stack_locals_metadata) == 1
assert not all_stack_locals_metadata[0].stack_null_idxes
self.output.add_output_instructions(
self.codegen_return_with_pops(inst, all_stack_locals_metadata[0].num_stack)
self.codegen_return_after_compile_subgraph(
inst, all_stack_locals_metadata[0]
)
)
raise ReturnValueOp

@ -4834,10 +4577,13 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
def create_call_resume_at(
self,
inst: Instruction,
all_stack_locals_metadata: list[StackLocalsMetadata],
all_stack_locals_metadata: Any,
disable_current_frame_resume: bool,
) -> list[Instruction]:
if config.nested_graph_breaks:
return super().create_call_resume_at(inst, all_stack_locals_metadata)
return super().create_call_resume_at(
inst, all_stack_locals_metadata, disable_current_frame_resume
)
unimplemented_v2(
gb_type="Graph break in inlined function",
context="",

@ -506,12 +506,6 @@ def skipIfNotPy312(fn: Callable[_P, _T]) -> Callable[_P, _T]:
return unittest.skip("Requires Python 3.12+")(fn)


def skipIfOnlyNotPy312(fn: Callable[_P, _T]) -> Callable[_P, _T]:
if sys.version_info >= (3, 13) or sys.version_info < (3, 12):
return unittest.skip("Requires Python 3.12")(fn)
return fn
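# A minimal usage sketch (hypothetical test name, assuming a unittest-style
# test class in this file): the decorator skips the test on every interpreter
# except CPython 3.12.
#
#     @skipIfOnlyNotPy312
#     def test_only_on_py312(self):
#         ...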
def xfailIfPy312(fn: Callable[_P, _T]) -> Callable[_P, _T]:
if sys.version_info >= (3, 12):
return unittest.expectedFailure(fn)

@ -51,6 +51,7 @@ from .resume_execution import TORCH_DYNAMO_RESUME_IN_PREFIX
from .utils import (
getfile,
hashable,
is_annotate_wrapped_function,
is_lru_cache_wrapped_function,
NP_SUPPORTED_MODULES,
unwrap_if_wrapper,
@ -154,6 +155,7 @@ manual_torch_name_rule_map: dict[
type[UserFunctionVariable],
],
] = {
"torch.fx.traceback.annotate": UserFunctionVariable,
"torch.onnx.is_in_onnx_export": TorchInGraphFunctionVariable,
"torch.onnx.operators.shape_as_tensor": TorchInGraphFunctionVariable,
"torch.overrides.is_tensor_like": TorchInGraphFunctionVariable,
@ -2994,6 +2996,9 @@ def get_torch_obj_rule_map() -> dict[Any, type["VariableTracker"]]:
continue
obj = torch_dir + k[len("torch/") :]
if obj is not None:
if is_annotate_wrapped_function(obj):
# pyrefly: ignore # missing-attribute
obj = obj.__wrapped__
if is_lru_cache_wrapped_function(obj):
obj = obj.__wrapped__
if obj in d and d[obj] != v:
@ -3425,7 +3430,6 @@ MOD_INLINELIST = [
"torch.fx._symbolic_trace",
"torch.fx.experimental.proxy_tensor",
"torch.fx.passes.shape_prop",
"torch.fx.traceback",
"torch.nn",
"torch.overrides",
"torch.random",

@ -1111,6 +1111,14 @@ def is_lru_cache_wrapped_function(
)


def is_annotate_wrapped_function(
value: Any,
) -> bool:
return value == torch.fx.traceback.annotate and is_function(
inspect.getattr_static(value, "__wrapped__")
)
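# A minimal sketch of how this predicate is meant to be used, mirroring the
# unwrapping in get_torch_obj_rule_map above (`some_obj` is a placeholder):
#
#     if is_annotate_wrapped_function(some_obj):
#         some_obj = some_obj.__wrapped__  # rule lookups key off the unwrapped function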
_FuncTypes: TypeAlias = Union[
types.FunctionType,
types.BuiltinFunctionType,

@ -29,7 +29,6 @@ from .ctx_manager import (
DynamoConfigPatchVariable,
ErrorOnGraphBreakVariable,
FSDPParamGroupUseTrainingStateVariable,
FxTracebackAnnotateVariable,
GradIncrementNestingCtxManagerVariable,
GradInplaceRequiresGradCtxManagerVariable,
GradModeVariable,

@ -1262,34 +1262,6 @@ class SDPAKernelVariable(ContextWrappingVariable):
return "_sdpa_kernel_variadic"


class FxTracebackAnnotateVariable(ContextWrappingVariable):
"""
fx.traceback.annotate is a context manager that allows users to annotate the
fx graph nodes with custom metadata. In the context of Dynamo, we don't have
to trace the body of the context manager. Instead we want to directly run
the body of the context manager, so the Dynamo created Fx graphs have the
right custom metadata. This variable tracker just runs __enter__ and
__exit__ method (instead of tracing).
"""

def __init__(self, target_values, initial_values=None, **kwargs) -> None:
super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs
)

def enter(self, tx, *args):
cm = torch.fx.traceback.annotate(self.target_values)
cm.__enter__()
self.set_cleanup_hook(tx, lambda: cm.__exit__(None, None, None))
return variables.ConstantVariable.create(None)

def module_name(self):
return "torch.fx.traceback"

def fn_name(self):
return "annotate"
class StreamVariable(VariableTracker):
def __init__(self, proxy, value, device, **kwargs) -> None:
if proxy is not None and "example_value" in proxy.node.meta:

@ -52,7 +52,6 @@ from ..exc import (
ObservedUserStopIteration,
raise_observed_exception,
SkipFrame,
StepUnsupported,
unimplemented_v2,
Unsupported,
)
@ -1528,8 +1527,6 @@ class SkipFunctionVariable(VariableTracker):
raise SkipFrame(
f"Skip frame due to `torch._dynamo.skip_frame()`. Message: {skip_frame_msg}"
)
elif self.value is torch._dynamo.step_unsupported:
raise StepUnsupported
else:
if config.dont_skip_tracing:
from .builder import SourcelessBuilder

@ -449,7 +449,7 @@ class ZipVariable(IteratorVariable):
codegen.create_load_const("strict"),
codegen.create_load_const(self.strict),
create_instruction("BUILD_MAP", arg=1),
*create_call_function_ex(True, False),
*create_call_function_ex(True),
]
)

@ -487,7 +487,7 @@ class MapVariable(ZipVariable):
codegen.extend_output(
[
create_build_tuple(len(self.iterables) + 1),
*create_call_function_ex(False, False),
*create_call_function_ex(False),
]
)


@ -1579,7 +1579,7 @@ class StringFormatVariable(VariableTracker):
variables.ConstantVariable.create(k): v for k, v in self.sym_kwargs.items()
}
codegen(variables.ConstDictVariable(kwargs))
codegen.extend_output(create_call_function_ex(True, False))
codegen.extend_output(create_call_function_ex(True))


class DebuggingVariable(VariableTracker):

@ -125,7 +125,6 @@ supported_ctx_manager_classes = dict.fromkeys(
torch.autograd.graph.disable_saved_tensors_hooks,
torch.cpu.amp.autocast_mode.autocast,
torch.cuda.amp.autocast_mode.autocast,
torch.fx.traceback.annotate,
# We'll let Dynamo inline into the contextlib part of these context
# manager instances, all the way till it invokes the wrapped function
# itself (at which point we wrap it back to special context manager
@ -326,7 +325,6 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
DisabledSavedTensorsHooksVariable,
DualLevelContextManager,
FSDPParamGroupUseTrainingStateVariable,
FxTracebackAnnotateVariable,
GradIncrementNestingCtxManagerVariable,
GradInplaceRequiresGradCtxManagerVariable,
GradModeVariable,
@ -361,11 +359,6 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
assert len(args) <= 1 and len(kwargs) == 0
inf_mode = args[0].as_python_constant() if len(args) == 1 else True
return InferenceModeVariable.create(tx, inf_mode)
elif self.value is torch.fx.traceback.annotate:
assert len(args) <= 1 and len(kwargs) == 0
return FxTracebackAnnotateVariable(
args[0].as_python_constant(), source=self.source
)
elif inspect.isclass(self.value) and issubclass(self.value, torch.Stream):
from torch._dynamo.variables.builder import wrap_fx_proxy_cls
Some files were not shown because too many files have changed in this diff.