Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-10-20 21:14:14 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94146
Approved by: https://github.com/ezyang
348 lines, 10 KiB, Python, executable file
#!/usr/bin/env python3
|
|
import importlib
|
|
import logging
|
|
import os
|
|
import re
|
|
import subprocess
|
|
import sys
|
|
import time
|
|
import warnings
|
|
|
|
import torch
|
|
from common import BenchmarkRunner, main
|
|
|
|
from torch._dynamo.testing import collect_results
|
|
from torch._dynamo.utils import clone_inputs
|
|
|
|
|
|
def pip_install(package):
    """Install ``package`` with pip, using the interpreter running this script.

    Raises ``subprocess.CalledProcessError`` if pip exits with a non-zero status.
    """
    pip_command = [sys.executable, "-m", "pip", "install"]
    pip_command.append(package)
    subprocess.check_call(pip_command)
|
|
|
|
|
|
# Make sure timm is importable, installing it from GitHub on first use.
try:
    importlib.import_module("timm")
except ModuleNotFoundError:
    print("Installing Pytorch Image Models...")
    pip_install("git+https://github.com/rwightman/pytorch-image-models")
    # The package appeared after the interpreter started; drop cached
    # path-finder state so the imports below can see it.
    importlib.invalidate_caches()

# These imports are deliberately not in a ``finally`` clause. That way a failed
# ``pip install`` propagates its own CalledProcessError instead of being
# replaced by an ImportError raised here.
from timm import __version__ as timmversion
from timm.data import resolve_data_config
from timm.models import create_model
|
|
|
|
# Maps each benchmarked timm model name to its recorded batch size.
# timm_models_list.txt sits next to this script. Each non-blank line is
# "<model_name> <batch_size>".
TIMM_MODELS = {}
filename = os.path.join(os.path.dirname(__file__), "timm_models_list.txt")

with open(filename, "r") as fh:
    for line in fh:
        fields = line.split()
        if not fields:
            # Tolerate blank lines, e.g. a trailing empty line at end of file.
            continue
        model_name, batch_size = fields
        TIMM_MODELS[model_name] = int(batch_size)
|
|
|
|
|
|
# TODO - Figure out the reason of cold start memory spike
# TimmRunnner.load_model divides a model's recorded batch size (from
# TIMM_MODELS) by the factor listed here, never going below 1.
BATCH_SIZE_DIVISORS = dict.fromkeys(
    (
        "beit_base_patch16_224",
        "cait_m36_384",
        "convit_base",
        "convmixer_768_32",
        "convnext_base",
        "cspdarknet53",
        "deit_base_distilled_patch16_224",
        "dpn107",
        "gluon_xception65",
        "mobilevit_s",
        "pit_b_224",
        "pnasnet5large",
        "poolformer_m36",
        "res2net101_26w_4s",
        "resnest101e",
        "sebotnet33ts_256",
        "swin_base_patch4_window7_224",
        "swsl_resnext101_32x16d",
        "twins_pcpvt_base",
        "vit_base_patch16_224",
        "volo_d1_224",
    ),
    2,
)
# A couple of models need their batch size quartered instead of halved.
BATCH_SIZE_DIVISORS.update({"jx_nest_base": 4, "xcit_large_24_p8_224": 4})
|
|
|
|
# Models whose training accuracy check needs a looser tolerance. This must be
# a set literal: set("botnet26t_256") would be a set of the string's
# individual characters, so no model name could ever be a member.
REQUIRE_HIGHER_TOLERANCE = {"botnet26t_256"}
|
|
|
|
# Models excluded from this benchmark suite.
SKIP = {"levit_128"}  # levit_128: unusual training setup
|
|
|
|
|
|
# Upper bound on the batch size used in accuracy runs. It keeps the memory
# footprint of a few large models down.
MAX_BATCH_SIZE_FOR_ACCURACY_CHECK = dict(cait_m36_384=4)
|
|
|
|
|
|
def refresh_model_names():
    """Regenerate timm_models_list.txt with one representative model per family.

    Family names come from ``get_family_name`` below. The function chooses:

    * one model from each family documented in a local pytorch-image-models
      checkout, and
    * one model from each remaining family of pretrained models reported by
      ``timm.models.list_models`` (``*in21k`` variants excluded).

    The chosen names are written, sorted, one per line.

    NOTE(review): this writes model names only. The module-level loader
    expects "<model_name> <batch_size>" per line, so batch sizes presumably
    have to be added by hand afterwards — confirm.
    NOTE(review): both the docs glob and the output path are relative to the
    current working directory, not to this script's location.
    """
    import glob

    from timm.models import list_models

    def read_models_from_docs():
        """Return the set of model names used in ``timm.create_model('<name>'...)`` doc lines."""
        models = set()
        # TODO - set the path to pytorch-image-models repo
        for fn in glob.glob("../pytorch-image-models/docs/models/*.md"):
            with open(fn, "r") as f:
                for line in f:
                    if not line.startswith("model = timm.create_model("):
                        continue
                    # The model name is the first single-quoted string on the line.
                    model = line.split("'")[1]
                    models.add(model)
        return models

    def get_family_name(name):
        """Map a model name to a coarse family name used for de-duplication."""
        known_families = [
            "darknet",
            "densenet",
            "dla",
            "dpn",
            "ecaresnet",
            "halo",
            "regnet",
            "efficientnet",
            "deit",
            "mobilevit",
            "mnasnet",
            "convnext",
            "resnet",
            "resnest",
            "resnext",
            "selecsls",
            "vgg",
            "xception",
        ]

        # The first known family that is a substring of the name wins, so
        # list order matters.
        for known_family in known_families:
            if known_family in name:
                return known_family

        if name.startswith("gluon_"):
            return "gluon_" + name.split("_")[1]
        return name.split("_")[0]

    def populate_family(models):
        """Group model names into a dict of family name -> list of models (input order)."""
        family = {}
        for model_name in models:
            family.setdefault(get_family_name(model_name), []).append(model_name)
        return family

    docs_models = read_models_from_docs()
    all_models = list_models(pretrained=True, exclude_filters=["*in21k"])

    all_models_family = populate_family(all_models)
    # Sort the docs models so that the representative picked per family
    # (the first entry) does not depend on set iteration order.
    docs_models_family = populate_family(sorted(docs_models))

    # Documented families get their representative from the docs below, so
    # remove them from the generic pool. A documented family may have no
    # pretrained (non-in21k) model at all, hence pop() rather than del.
    for key in docs_models_family:
        all_models_family.pop(key, None)

    chosen_models = {value[0] for value in docs_models_family.values()}
    chosen_models.update(value[0] for value in all_models_family.values())

    filename = "timm_models_list.txt"
    if os.path.exists("benchmarks"):
        filename = os.path.join("benchmarks", filename)
    with open(filename, "w") as fw:
        fw.writelines(model_name + "\n" for model_name in sorted(chosen_models))
|
|
|
|
|
|
class TimmRunnner(BenchmarkRunner):
    """BenchmarkRunner for the timm (pytorch-image-models) model suite."""

    def __init__(self):
        super().__init__()
        self.suite_name = "timm_models"

    @staticmethod
    def _version_tuple(version):
        """Return the first three numeric components of ``version`` as ints.

        Missing components are padded with zeros, e.g. "0.8.13dev0" -> (0, 8, 13)
        and "0.8" -> (0, 8, 0). Tuples are compared instead of the raw strings,
        because string comparison is lexicographic ("0.10.0" < "0.8.0").
        """
        parts = [int(part) for part in re.findall(r"\d+", version)[:3]]
        parts += [0] * (3 - len(parts))
        return tuple(parts)

    def load_model(
        self,
        device,
        model_name,
        batch_size=None,
    ):
        """Create timm model ``model_name`` on ``device`` with synthetic inputs.

        Model creation (``pretrained=True``) is attempted up to 5 times, waiting
        ``tries * 30`` seconds between attempts. If every attempt fails, a
        RuntimeError chained to the last error is raised.

        Batch size:

        * ``batch_size`` defaults to the size recorded in TIMM_MODELS, divided
          by the model's BATCH_SIZE_DIVISORS entry if it has one (minimum 1).
        * For accuracy runs it is then capped by
          MAX_BATCH_SIZE_FOR_ACCURACY_CHECK.

        Sets ``self.num_classes``, ``self.target`` and ``self.loss``.

        Returns:
            ``(device, model_name, model, example_inputs, batch_size)``, where
            ``example_inputs`` is a one-element list holding the input tensor.
        """
        is_training = self.args.training
        use_eval_mode = self.args.use_eval_mode
        channels_last = self._args.channels_last

        tries = 1
        success = False
        model = None
        last_error = None
        total_allowed_tries = 5
        while not success and tries <= total_allowed_tries:
            try:
                model = create_model(
                    model_name,
                    in_chans=3,
                    scriptable=False,
                    num_classes=None,
                    drop_rate=0.0,
                    drop_path_rate=None,
                    drop_block_rate=None,
                    pretrained=True,
                )
                success = True
            except Exception as e:
                # Deliberately broad: pretrained-weight downloads can fail
                # transiently, so any error triggers another attempt.
                # ``e`` is unbound after the except block, so keep a reference.
                last_error = e
                tries += 1
                if tries <= total_allowed_tries:
                    wait = tries * 30
                    print(
                        f"Failed to load model: {e}. Trying again ({tries}/{total_allowed_tries}) after {wait}s"
                    )
                    time.sleep(wait)

        if model is None:
            raise RuntimeError(f"Failed to load model '{model_name}'") from last_error

        model.to(
            device=device,
            memory_format=torch.channels_last if channels_last else None,
        )

        self.num_classes = model.num_classes

        # Pass the args as a dict to timm >= 0.8.0 and as the raw args object
        # to older releases.
        if self._version_tuple(timmversion) >= (0, 8, 0):
            timm_args = vars(self._args)
        else:
            timm_args = self._args
        data_config = resolve_data_config(
            timm_args,
            model=model,
            use_test_size=not is_training,
        )
        input_size = data_config["input_size"]
        recorded_batch_size = TIMM_MODELS[model_name]

        if model_name in BATCH_SIZE_DIVISORS:
            recorded_batch_size = max(
                recorded_batch_size // BATCH_SIZE_DIVISORS[model_name], 1
            )
        batch_size = batch_size or recorded_batch_size

        # Control the memory footprint for few models
        if self.args.accuracy and model_name in MAX_BATCH_SIZE_FOR_ACCURACY_CHECK:
            batch_size = min(batch_size, MAX_BATCH_SIZE_FOR_ACCURACY_CHECK[model_name])

        # Deterministic synthetic batch: random integers in [0, 256), then
        # standardised by the batch's own mean and standard deviation.
        torch.manual_seed(1337)
        input_tensor = torch.randint(
            256, size=(batch_size,) + input_size, device=device
        ).to(dtype=torch.float32)
        mean = torch.mean(input_tensor)
        std_dev = torch.std(input_tensor)
        example_inputs = (input_tensor - mean) / std_dev

        if channels_last:
            example_inputs = example_inputs.contiguous(
                memory_format=torch.channels_last
            )
        example_inputs = [example_inputs]
        self.target = self._gen_target(batch_size, device)

        self.loss = torch.nn.CrossEntropyLoss().to(device)
        if is_training and not use_eval_mode:
            model.train()
        else:
            model.eval()

        self.validate_model(model, example_inputs)

        return device, model_name, model, example_inputs, batch_size

    def iter_model_names(self, args):
        """Yield the names from TIMM_MODELS that ``args`` selects, in sorted order.

        Only indices in the ``[start, end)`` range returned by
        ``get_benchmark_indices`` are considered. A name is yielded when all of
        the following hold (regex matches are case-insensitive):

        * it matches one of the ``args.filter`` patterns;
        * it matches none of the ``args.exclude`` patterns;
        * it is not in ``args.exclude_exact``;
        * it is not in ``self.skip_models``.
        """
        model_names = sorted(TIMM_MODELS.keys())
        start, end = self.get_benchmark_indices(len(model_names))
        # Compile the combined alternations once instead of once per model.
        include_re = re.compile("|".join(args.filter), re.I)
        exclude_re = re.compile("|".join(args.exclude), re.I)
        for index, model_name in enumerate(model_names):
            if index < start or index >= end:
                continue
            if (
                not include_re.search(model_name)
                or exclude_re.search(model_name)
                or model_name in args.exclude_exact
                or model_name in self.skip_models
            ):
                continue

            yield model_name

    def pick_grad(self, name, is_training):
        """Return a grad-enabled context for training, ``torch.no_grad()`` otherwise."""
        if is_training:
            return torch.enable_grad()
        else:
            return torch.no_grad()

    def get_tolerance_and_cosine_flag(self, is_training, current_device, name):
        """Return ``(tolerance, cosine)`` for the accuracy check of model ``name``.

        Tolerance is:

        * 1e-3 for inference;
        * 2e-2 for training models listed in REQUIRE_HIGHER_TOLERANCE;
        * 1e-2 for all other training models.

        ``cosine`` is taken from ``self.args.cosine``.
        """
        cosine = self.args.cosine
        tolerance = 1e-3
        if is_training:
            # Check membership: testing the set's truthiness would give the
            # looser bound to every model.
            if name in REQUIRE_HIGHER_TOLERANCE:
                tolerance = 2 * 1e-2
            else:
                tolerance = 1e-2
        return tolerance, cosine

    def _gen_target(self, batch_size, device):
        """Return ``batch_size`` random int64 labels drawn from ``[0, self.num_classes)``."""
        return torch.empty((batch_size,), device=device, dtype=torch.long).random_(
            self.num_classes
        )

    def compute_loss(self, pred):
        """Cross-entropy of ``pred`` against ``self.target``, scaled by 1/10."""
        # High loss values make gradient checking harder, as small changes in
        # accumulation order upsets accuracy checks.
        return self.loss(pred, self.target) / 10.0

    def forward_pass(self, mod, inputs, collect_outputs=True):
        """Run ``mod(*inputs)`` under the runner's autocast context and return the output."""
        with self.autocast():
            return mod(*inputs)

    def forward_and_backward_pass(self, mod, inputs, collect_outputs=True):
        """Run one training step on a clone of ``inputs``.

        The step zeroes the gradients, runs the forward pass and loss under
        autocast, calls ``backward()`` on the grad-scaled loss, then steps the
        optimizer. If ``collect_outputs`` is set it returns
        ``collect_results(mod, pred, loss, cloned_inputs)``; otherwise it
        returns None. Tuple outputs are reduced to their first element before
        the loss is computed.
        """
        cloned_inputs = clone_inputs(inputs)
        self.optimizer_zero_grad(mod)
        with self.autocast():
            pred = mod(*cloned_inputs)
            if isinstance(pred, tuple):
                pred = pred[0]
            loss = self.compute_loss(pred)
        self.grad_scaler.scale(loss).backward()
        self.optimizer_step()
        if collect_outputs:
            return collect_results(mod, pred, loss, cloned_inputs)
        return None
|
|
|
|
|
|
def timm_main():
    """Command-line entry point for the timm benchmark suite.

    Sets the root logging level to WARNING and ignores Python warnings, then
    hands a fresh TimmRunnner to the shared ``main`` driver.
    """
    logging.basicConfig(level=logging.WARNING)
    warnings.filterwarnings("ignore")
    runner = TimmRunnner()
    main(runner)


# Run the suite only when executed as a script, not when imported.
if __name__ == "__main__":
    timm_main()
|