#!/usr/bin/env python3
"""TorchBench integration for the TorchDynamo benchmark harness (BenchmarkRunner/main in common.py)."""
import gc
import importlib
import logging
import os
import re
import sys
import warnings
from os.path import abspath, exists

import torch

from common import BenchmarkRunner, main
from torch._dynamo.testing import collect_results, reduce_to_scalar_loss
from torch._dynamo.utils import clone_inputs

# We are primarily interested in the tf32 datatype
torch.backends.cuda.matmul.allow_tf32 = True

os.environ["KALDI_ROOT"] = "/tmp"  # avoids some spam

for torchbench_dir in (
    "./torchbenchmark",
    "../torchbenchmark",
    "../torchbench",
    "../benchmark",
    "../../torchbenchmark",
    "../../torchbench",
    "../../benchmark",
):
    if exists(torchbench_dir):
        break
assert exists(torchbench_dir), "../../torchbenchmark does not exist"

original_dir = abspath(os.getcwd())
torchbench_dir = abspath(torchbench_dir)

os.chdir(torchbench_dir)
sys.path.append(torchbench_dir)

# Some models have datasets too large to fit in memory. Lower the batch
# size when checking accuracy.
USE_SMALL_BATCH_SIZE = {
    "demucs": 4,
    "densenet121": 4,
    "hf_Reformer": 4,
    "timm_efficientdet": 1,
}

DETECTRON2_MODELS = {
    "detectron2_fasterrcnn_r_101_c4",
    "detectron2_fasterrcnn_r_101_dc5",
    "detectron2_fasterrcnn_r_101_fpn",
    "detectron2_fasterrcnn_r_50_c4",
    "detectron2_fasterrcnn_r_50_dc5",
    "detectron2_fasterrcnn_r_50_fpn",
    "detectron2_maskrcnn_r_101_c4",
    "detectron2_maskrcnn_r_101_fpn",
    "detectron2_maskrcnn_r_50_fpn",
}

SKIP = {
    # https://github.com/pytorch/torchdynamo/issues/101
    "detectron2_maskrcnn",
    # https://github.com/pytorch/torchdynamo/issues/145
    "fambench_xlmr",
}

# Additional models that are skipped in training
SKIP_TRAIN = {
    # not designed for training
    "pyhpc_equation_of_state",
    "pyhpc_isoneutral_mixing",
    "pyhpc_turbulent_kinetic_energy",
    # unusual training setup
    "opacus_cifar10",
    "maml",
}
SKIP_TRAIN.update(DETECTRON2_MODELS)

# These models support only train mode, so accuracy checking can't be done in
# eval mode.
ONLY_TRAINING_MODE = {
    "tts_angular",
    "tacotron2",
    "demucs",
    "hf_Reformer",
    "pytorch_struct",
    "yolov3",
}
ONLY_TRAINING_MODE.update(DETECTRON2_MODELS)

# These models need a looser tolerance on GPU because their kernels are
# non-deterministic there.
REQUIRE_HIGHER_TOLERANCE = {
    "alexnet",
    "attention_is_all_you_need_pytorch",
    "densenet121",
    "hf_Albert",
    "vgg16",
    "mobilenet_v3_large",
    "nvidia_deeprecommender",
    "timm_efficientdet",
    "vision_maskrcnn",
}

# These models need >1e-3 tolerance
REQUIRE_EVEN_HIGHER_TOLERANCE = {
    "soft_actor_critic",
    "tacotron2",
}

REQUIRE_COSINE_TOLERANCE = {
    # https://github.com/pytorch/torchdynamo/issues/556
    "resnet50_quantized_qat",
}

# non-deterministic output / can't check correctness
NONDETERMINISTIC = set()

# These benchmarks took >600s on an i9-11900K CPU
VERY_SLOW_BENCHMARKS = {
    "hf_BigBird",  # 3339s
    "hf_Longformer",  # 3062s
    "hf_T5",  # 930s
}

# These benchmarks took >60s on an i9-11900K CPU
SLOW_BENCHMARKS = {
    *VERY_SLOW_BENCHMARKS,
    "BERT_pytorch",  # 137s
    "demucs",  # 116s
    "fastNLP_Bert",  # 242s
    "hf_Albert",  # 221s
    "hf_Bart",  # 400s
    "hf_Bert",  # 334s
    "hf_DistilBert",  # 187s
    "hf_GPT2",  # 470s
    "hf_Reformer",  # 141s
    "speech_transformer",  # 317s
    "vision_maskrcnn",  # 99s
}

TRT_NOT_YET_WORKING = {
    "alexnet",
    "resnet18",
    "resnet50",
    "mobilenet_v2",
    "mnasnet1_0",
    "squeezenet1_1",
    "shufflenetv2_x1_0",
    "vgg16",
    "resnext50_32x4d",
}

DYNAMIC_SHAPES_NOT_YET_WORKING = {
    "demucs",
    "timm_nfnet",
}

DONT_CHANGE_BATCH_SIZE = {
    "demucs",
    "pytorch_struct",
    "pyhpc_turbulent_kinetic_energy",
}

SKIP_ACCURACY_CHECK_MODELS = {
    # Models too large to have eager, dynamo, and fp64_numbers simultaneously,
    # even on a 40 GB machine. We have tested accuracy for smaller versions of
    # these models.
    "hf_GPT2_large",
    "hf_T5_large",
    "timm_vision_transformer_large",
}


class TorchBenchmarkRunner(BenchmarkRunner):
    def __init__(self):
        super().__init__()
        self.suite_name = "torchbench"

    @property
    def skip_models(self):
        return SKIP

    @property
    def slow_models(self):
        return SLOW_BENCHMARKS

    @property
    def very_slow_models(self):
        return VERY_SLOW_BENCHMARKS

    @property
    def non_deterministic_models(self):
        return NONDETERMINISTIC

    @property
    def skip_not_suitable_for_training_models(self):
        return SKIP_TRAIN

    @property
    def failing_fx2trt_models(self):
        return TRT_NOT_YET_WORKING

    @property
    def failing_dynamic_shape_models(self):
        return DYNAMIC_SHAPES_NOT_YET_WORKING

    @property
    def skip_accuracy_checks_large_models_dashboard(self):
        if self.args.dashboard:
            return SKIP_ACCURACY_CHECK_MODELS
        return set()

    def load_model(
        self,
        device,
        model_name,
        batch_size=None,
    ):
        is_training = self.args.training
        use_eval_mode = self.args.use_eval_mode
        dynamic_shapes = self.args.dynamic_shapes
        module = importlib.import_module(f"torchbenchmark.models.{model_name}")
        benchmark_cls = getattr(module, "Model", None)
        if not hasattr(benchmark_cls, "name"):
            benchmark_cls.name = model_name

        cant_change_batch_size = (
            not getattr(benchmark_cls, "ALLOW_CUSTOMIZE_BSIZE", True)
            or model_name in DONT_CHANGE_BATCH_SIZE
        )
        if cant_change_batch_size:
            batch_size = None
        if batch_size is None and is_training and model_name in USE_SMALL_BATCH_SIZE:
            batch_size = USE_SMALL_BATCH_SIZE[model_name]

        if is_training:
            benchmark = benchmark_cls(
                test="train", device=device, jit=False, batch_size=batch_size
            )
        else:
            benchmark = benchmark_cls(
                test="eval", device=device, jit=False, batch_size=batch_size
            )

        if dynamic_shapes:
            if not hasattr(benchmark, "get_dynamic_shapes_module"):
                raise NotImplementedError("Dynamic Shapes not supported")
            model, example_inputs = benchmark.get_dynamic_shapes_module()
        else:
            model, example_inputs = benchmark.get_module()

        # Models that must be in train mode while training
        if is_training and (not use_eval_mode or model_name in ONLY_TRAINING_MODE):
            model.train()
        else:
            model.eval()
        gc.collect()
        batch_size = benchmark.batch_size
        self.init_optimizer(device, model.parameters())

        # Torchbench has quite a different setup for yolov3, so directly pass
        # the right example_inputs.
        if model_name == "yolov3":
            example_inputs = (torch.rand(batch_size, 3, 384, 512).to(device),)
        # global current_name, current_device
        # current_device = device
        # current_name = benchmark.name
        self.validate_model(model, example_inputs)
        return device, benchmark.name, model, example_inputs, batch_size

    def iter_model_names(self, args):
        from torchbenchmark import _list_model_paths

        models = _list_model_paths()
        start, end = self.get_benchmark_indices(len(models))
        for index, model_path in enumerate(models):
            if index < start or index >= end:
                continue
            model_name = os.path.basename(model_path)
            if (
                not re.search("|".join(args.filter), model_name, re.I)
                or re.search("|".join(args.exclude), model_name, re.I)
                or model_name in SKIP
            ):
                continue
            yield model_name

    def pick_grad(self, name, is_training):
        if is_training or name in ("maml",):
            return torch.enable_grad()
        else:
            return torch.no_grad()

    def get_tolerance_and_cosine_flag(self, is_training, current_device, name):
        tolerance = 1e-4
        cosine = self.args.cosine
        # Increase the tolerance for torch.allclose
        if self.args.float16:
            return 1e-3, cosine
        if is_training and current_device == "cuda":
            if name in REQUIRE_COSINE_TOLERANCE:
                cosine = True
            elif name in REQUIRE_HIGHER_TOLERANCE:
                tolerance = 1e-3
            elif name in REQUIRE_EVEN_HIGHER_TOLERANCE:
                tolerance = 8 * 1e-2
        return tolerance, cosine

    def compute_loss(self, pred):
        return reduce_to_scalar_loss(pred)

    def forward_pass(self, mod, inputs, collect_outputs=True):
        return mod(*inputs)

    def forward_and_backward_pass(self, mod, inputs, collect_outputs=True):
        cloned_inputs = clone_inputs(inputs)
        mod.zero_grad(True)
        with self.autocast():
            pred = mod(*cloned_inputs)
            loss = self.compute_loss(pred)
        self.grad_scaler.scale(loss).backward()
        self.optimizer_step()
        if collect_outputs:
            return collect_results(mod, pred, loss, cloned_inputs)
        return None


if __name__ == "__main__":
    logging.basicConfig(level=logging.WARNING)
    warnings.filterwarnings("ignore")
    main(TorchBenchmarkRunner(), original_dir)
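
# Example invocations (a sketch only, not part of the upstream harness): the
# actual command-line flags are defined by the argument parser in common.py
# and may differ across versions; the names below are inferred from the
# self.args attributes referenced above (training, float16, cosine, filter,
# exclude, dashboard).
#
#   python torchbench.py --training --float16 --filter resnet50
#   python torchbench.py --cosine --exclude detectron2
#
# main() (imported from common.py) is expected to parse those arguments,
# attach them to the runner, and drive it over the models yielded by
# iter_model_names().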