mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Fixes https://github.com/pytorch/pytorch/issues/144775. See details on the problem: https://github.com/pytorch/pytorch/issues/144775#issuecomment-2611699385

We fixed some silent incorrectness, but the fix results in fewer nodes being DCE'd. The benchmark iteration loop contained dead code that can include side-effecting ops, which are not safe to DCE, so the regression is expected. This PR removes the compile-time benchmarking of that dead code, which should reduce benchmark noise and aligns with the benchmarking used by the performance tests.

New benchmark results:

```
dev,name,batch_size,accuracy,calls_captured,unique_graphs,graph_breaks,unique_graph_breaks,autograd_captures,autograd_compiles,cudagraph_skips,compilation_latency
cuda,BartForConditionalGeneration,1,pass,897,1,0,0,0,0,0,39.322364  # after https://github.com/pytorch/pytorch/pull/144319
cuda,BartForConditionalGeneration,1,pass,897,1,0,0,0,0,0,38.972257  # before https://github.com/pytorch/pytorch/pull/144319
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145590
Approved by: https://github.com/jansel
ghstack dependencies: #145447
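For context on the DCE issue described above, here is a minimal, hypothetical sketch (not code from this PR or from the benchmark harness) of the kind of "dead" code that is unsafe to eliminate: the computed value is never used, but the op mutates one of its inputs, so dropping it would silently change program state.

```python
import torch


def step(x, buf):
    # The result of add_ is never used ("dead" code), but add_ mutates `buf`
    # in place, so the op has a visible side effect and must not be DCE'd.
    buf.add_(1)
    return x * 2


compiled = torch.compile(step)
x, buf = torch.ones(4), torch.zeros(4)
compiled(x, buf)
print(buf)  # tensor([1., 1., 1., 1.]) -- the mutation is preserved
```

Eager and compiled runs only stay in agreement if the compiler keeps such side-effecting ops even when their outputs are dead, which is why fewer nodes are now eliminated.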
442 lines
12 KiB
Python
Executable File
#!/usr/bin/env python3

import importlib
import logging
import os
import re
import subprocess
import sys
import warnings


try:
    from .common import BenchmarkRunner, download_retry_decorator, load_yaml_file, main
except ImportError:
    from common import BenchmarkRunner, download_retry_decorator, load_yaml_file, main

import torch
from torch._dynamo.testing import collect_results, reduce_to_scalar_loss
from torch._dynamo.utils import clone_inputs


# Enable FX graph caching
if "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ:
    torch._inductor.config.fx_graph_cache = True


def pip_install(package):
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])


try:
    importlib.import_module("timm")
except ModuleNotFoundError:
    print("Installing PyTorch Image Models...")
    pip_install("git+https://github.com/rwightman/pytorch-image-models")
finally:
    from timm import __version__ as timmversion
    from timm.data import resolve_data_config
    from timm.models import create_model

TIMM_MODELS = {}
filename = os.path.join(os.path.dirname(__file__), "timm_models_list.txt")

with open(filename) as fh:
    lines = fh.readlines()
    lines = [line.rstrip() for line in lines]
    for line in lines:
        model_name, batch_size = line.split(" ")
        TIMM_MODELS[model_name] = int(batch_size)


# TODO - Figure out the reason for the cold start memory spike

BATCH_SIZE_DIVISORS = {
    "beit_base_patch16_224": 2,
    "convit_base": 2,
    "convmixer_768_32": 2,
    "convnext_base": 2,
    "cspdarknet53": 2,
    "deit_base_distilled_patch16_224": 2,
    "gluon_xception65": 2,
    "mobilevit_s": 2,
    "pnasnet5large": 2,
    "poolformer_m36": 2,
    "resnest101e": 2,
    "swin_base_patch4_window7_224": 2,
    "swsl_resnext101_32x16d": 2,
    "vit_base_patch16_224": 2,
    "volo_d1_224": 2,
    "jx_nest_base": 4,
}

REQUIRE_HIGHER_TOLERANCE = {
    "fbnetv3_b",
    "gmixer_24_224",
    "hrnet_w18",
    "inception_v3",
    "mixer_b16_224",
    "mobilenetv3_large_100",
    "sebotnet33ts_256",
    "selecsls42b",
    "convnext_base",
}

REQUIRE_HIGHER_TOLERANCE_AMP = {
    "poolformer_m36",
}

REQUIRE_EVEN_HIGHER_TOLERANCE = {
    "levit_128",
    "sebotnet33ts_256",
    "beit_base_patch16_224",
    "cspdarknet53",
}

# These models need higher tolerance in MaxAutotune mode
REQUIRE_EVEN_HIGHER_TOLERANCE_MAX_AUTOTUNE = {
    "gluon_inception_v3",
}

REQUIRE_HIGHER_TOLERANCE_FOR_FREEZING = {
    "adv_inception_v3",
    "botnet26t_256",
    "gluon_inception_v3",
    "selecsls42b",
    "swsl_resnext101_32x16d",
}

SCALED_COMPUTE_LOSS = {
    "ese_vovnet19b_dw",
    "fbnetc_100",
    "mnasnet_100",
    "mobilevit_s",
    "sebotnet33ts_256",
}

FORCE_AMP_FOR_FP16_BF16_MODELS = {
    "convit_base",
    "xcit_large_24_p8_224",
}

SKIP_ACCURACY_CHECK_AS_EAGER_NON_DETERMINISTIC_MODELS = {
    "xcit_large_24_p8_224",
}

REQUIRE_LARGER_MULTIPLIER_FOR_SMALLER_TENSOR = {
    "inception_v3",
    "mobilenetv3_large_100",
    "cspdarknet53",
}


def refresh_model_names():
    import glob

    from timm.models import list_models

    def read_models_from_docs():
        models = set()
        # TODO - set the path to pytorch-image-models repo
        for fn in glob.glob("../pytorch-image-models/docs/models/*.md"):
            with open(fn) as f:
                while True:
                    line = f.readline()
                    if not line:
                        break
                    if not line.startswith("model = timm.create_model("):
                        continue

                    model = line.split("'")[1]
                    # print(model)
                    models.add(model)
        return models

    def get_family_name(name):
        known_families = [
            "darknet",
            "densenet",
            "dla",
            "dpn",
            "ecaresnet",
            "halo",
            "regnet",
            "efficientnet",
            "deit",
            "mobilevit",
            "mnasnet",
            "convnext",
            "resnet",
            "resnest",
            "resnext",
            "selecsls",
            "vgg",
            "xception",
        ]

        for known_family in known_families:
            if known_family in name:
                return known_family

        if name.startswith("gluon_"):
            return "gluon_" + name.split("_")[1]
        return name.split("_")[0]

    def populate_family(models):
        family = {}
        for model_name in models:
            family_name = get_family_name(model_name)
            if family_name not in family:
                family[family_name] = []
            family[family_name].append(model_name)
        return family

    docs_models = read_models_from_docs()
    all_models = list_models(pretrained=True, exclude_filters=["*in21k"])

    all_models_family = populate_family(all_models)
    docs_models_family = populate_family(docs_models)

    for key in docs_models_family:
        del all_models_family[key]

    chosen_models = set()
    chosen_models.update(value[0] for value in docs_models_family.values())
    chosen_models.update(value[0] for key, value in all_models_family.items())

    filename = "timm_models_list.txt"
    if os.path.exists("benchmarks"):
        filename = "benchmarks/" + filename
    with open(filename, "w") as fw:
        for model_name in sorted(chosen_models):
            fw.write(model_name + "\n")


class TimmRunner(BenchmarkRunner):
    def __init__(self):
        super().__init__()
        self.suite_name = "timm_models"

    @property
    def _config(self):
        return load_yaml_file("timm_models.yaml")

    @property
    def _skip(self):
        return self._config["skip"]

    @property
    def skip_models(self):
        return self._skip["all"]

    @property
    def force_amp_for_fp16_bf16_models(self):
        return FORCE_AMP_FOR_FP16_BF16_MODELS

    @property
    def force_fp16_for_bf16_models(self):
        return set()

    @property
    def get_output_amp_train_process_func(self):
        return {}

    @property
    def skip_accuracy_check_as_eager_non_deterministic(self):
        if self.args.accuracy and self.args.training:
            return SKIP_ACCURACY_CHECK_AS_EAGER_NON_DETERMINISTIC_MODELS
        return set()

    @property
    def guard_on_nn_module_models(self):
        return {
            "convit_base",
        }

    @property
    def inline_inbuilt_nn_modules_models(self):
        return {
            "lcnet_050",
        }

    @download_retry_decorator
    def _download_model(self, model_name):
        model = create_model(
            model_name,
            in_chans=3,
            scriptable=False,
            num_classes=None,
            drop_rate=0.0,
            drop_path_rate=None,
            drop_block_rate=None,
            pretrained=True,
        )
        return model

    def load_model(
        self,
        device,
        model_name,
        batch_size=None,
        extra_args=None,
    ):
        if self.args.enable_activation_checkpointing:
            raise NotImplementedError(
                "Activation checkpointing not implemented for Timm models"
            )

        is_training = self.args.training
        use_eval_mode = self.args.use_eval_mode

        channels_last = self._args.channels_last
        model = self._download_model(model_name)

        if model is None:
            raise RuntimeError(f"Failed to load model '{model_name}'")
        model.to(
            device=device,
            memory_format=torch.channels_last if channels_last else None,
        )

        self.num_classes = model.num_classes

        data_config = resolve_data_config(
            vars(self._args) if timmversion >= "0.8.0" else self._args,
            model=model,
            use_test_size=not is_training,
        )
        input_size = data_config["input_size"]
        recorded_batch_size = TIMM_MODELS[model_name]

        if model_name in BATCH_SIZE_DIVISORS:
            recorded_batch_size = max(
                int(recorded_batch_size / BATCH_SIZE_DIVISORS[model_name]), 1
            )
        batch_size = batch_size or recorded_batch_size

        torch.manual_seed(1337)
        input_tensor = torch.randint(
            256, size=(batch_size,) + input_size, device=device
        ).to(dtype=torch.float32)
        mean = torch.mean(input_tensor)
        std_dev = torch.std(input_tensor)
        example_inputs = (input_tensor - mean) / std_dev

        if channels_last:
            example_inputs = example_inputs.contiguous(
                memory_format=torch.channels_last
            )
        example_inputs = [
            example_inputs,
        ]
        self.target = self._gen_target(batch_size, device)

        self.loss = torch.nn.CrossEntropyLoss().to(device)

        if model_name in SCALED_COMPUTE_LOSS:
            self.compute_loss = self.scaled_compute_loss

        if is_training and not use_eval_mode:
            model.train()
        else:
            model.eval()

        self.validate_model(model, example_inputs)

        return device, model_name, model, example_inputs, batch_size

    def iter_model_names(self, args):
        # for model_name in list_models(pretrained=True, exclude_filters=["*in21k"]):
        model_names = sorted(TIMM_MODELS.keys())
        start, end = self.get_benchmark_indices(len(model_names))
        for index, model_name in enumerate(model_names):
            if index < start or index >= end:
                continue
            if (
                not re.search("|".join(args.filter), model_name, re.IGNORECASE)
                or re.search("|".join(args.exclude), model_name, re.IGNORECASE)
                or model_name in args.exclude_exact
                or model_name in self.skip_models
            ):
                continue

            yield model_name

    def pick_grad(self, name, is_training):
        if is_training:
            return torch.enable_grad()
        else:
            return torch.no_grad()

    def use_larger_multiplier_for_smaller_tensor(self, name):
        return name in REQUIRE_LARGER_MULTIPLIER_FOR_SMALLER_TENSOR

    def get_tolerance_and_cosine_flag(self, is_training, current_device, name):
        cosine = self.args.cosine
        tolerance = 1e-3

        if self.args.freezing and name in REQUIRE_HIGHER_TOLERANCE_FOR_FREEZING:
            # The conv-batchnorm fusion used under freezing may cause relatively
            # large numerical differences, so we need a larger tolerance.
            # Check https://github.com/pytorch/pytorch/issues/120545 for context
            tolerance = 8 * 1e-2

        if is_training:
            from torch._inductor import config as inductor_config

            if name in REQUIRE_EVEN_HIGHER_TOLERANCE or (
                inductor_config.max_autotune
                and name in REQUIRE_EVEN_HIGHER_TOLERANCE_MAX_AUTOTUNE
            ):
                tolerance = 8 * 1e-2
            elif name in REQUIRE_HIGHER_TOLERANCE or (
                self.args.amp and name in REQUIRE_HIGHER_TOLERANCE_AMP
            ):
                tolerance = 4 * 1e-2
            else:
                tolerance = 1e-2
        return tolerance, cosine

    def _gen_target(self, batch_size, device):
        return torch.empty((batch_size,) + (), device=device, dtype=torch.long).random_(
            self.num_classes
        )

    def compute_loss(self, pred):
        # High loss values make gradient checking harder, as small changes in
        # accumulation order upset accuracy checks.
        return reduce_to_scalar_loss(pred)

    def scaled_compute_loss(self, pred):
        # Loss values need to be scaled down further.
        return reduce_to_scalar_loss(pred) / 1000.0

    def forward_pass(self, mod, inputs, collect_outputs=True):
        with self.autocast(**self.autocast_arg):
            return mod(*inputs)

    def forward_and_backward_pass(self, mod, inputs, collect_outputs=True):
        cloned_inputs = clone_inputs(inputs)
        self.optimizer_zero_grad(mod)
        with self.autocast(**self.autocast_arg):
            pred = mod(*cloned_inputs)
            if isinstance(pred, tuple):
                pred = pred[0]
            loss = self.compute_loss(pred)
        self.grad_scaler.scale(loss).backward()
        self.optimizer_step()
        if collect_outputs:
            return collect_results(mod, pred, loss, cloned_inputs)
        return None


def timm_main():
    logging.basicConfig(level=logging.WARNING)
    warnings.filterwarnings("ignore")
    main(TimmRunner())


if __name__ == "__main__":
    timm_main()