mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Summary: It is quite a lot of code because I pulled some code from torchaudio and torchvision into this benchmark: using those libraries directly would have required their latest versions to match a PyTorch built from source, and I can't build those libraries from source (a dependency is missing for torchaudio). The compare script generates a table as follows: | model | task | speedup | mean (before) | var (before) | mean (after) | var (after) | | -- | -- | -- | -- | -- | -- | -- | | resnet18 | vjp | 1.021151844124464 | 1.5627719163894653 | 0.005164200905710459 | 1.5304011106491089 | 0.003979875706136227 | | resnet18 | vhp | 0.9919114430761606 | 6.8089728355407715 | 0.019538333639502525 | 6.86449670791626 | 0.014775685034692287 | | resnet18 | jvp | 0.9715963084255123 | 5.720699310302734 | 0.08197150379419327 | 5.887938499450684 | 0.018408503383398056 | | ppl_simple_reg | vjp | 0.9529183269165618 | 0.000362396240234375 | 7.526952949810095e-10 | 0.00038030146970413625 | 7.726220357939795e-11 | | ppl_simple_reg | vhp | 0.9317708619586977 | 0.00048058031825348735 | 5.035701855504726e-10 | 0.0005157709238119423 | 3.250243477137538e-11 | | ppl_simple_reg | jvp | 0.8609755877018406 | 0.00045447348384186625 | 9.646707044286273e-11 | 0.0005278587341308594 | 1.4493808930815533e-10 | | ppl_simple_reg | hvp | 0.9764100147808232 | 0.0005881547695025802 | 7.618464747949361e-10 | 0.0006023645401000977 | 6.370915461850757e-10 | | ppl_simple_reg | jacobian | 1.0019173715134297 | 0.0003612995205912739 | 2.2979899233499523e-11 | 0.0003606081008911133 | 1.2609764794835332e-11 | | ppl_simple_reg | hessian | 1.0358429970264393 | 0.00206911563873291 | 2.590938796842579e-09 | 0.0019975185859948397 | 2.8916853356264482e-09 | | ppl_robust_reg | vjp | 1.0669910916521521 | 0.0017304659122601151 | 3.1047047155396967e-09 | 0.0016218185191974044 | 4.926861585374809e-09 | | ppl_robust_reg | vhp | 1.0181130455462972 | 0.0029563189018517733 | 2.6359153082466946e-08 | 0.0029037236236035824 | 1.020585038702393e-08 | | ppl_robust_reg | jvp | 0.9818360373406179 | 
0.0026934861671179533 | 6.981357714153091e-09 | 0.00274331565015018 | 3.589908459389335e-08 | | ppl_robust_reg | hvp | 1.0270848910527002 | 0.005576515104621649 | 3.2798087801211295e-08 | 0.005429458804428577 | 6.438724398094564e-08 | | ppl_robust_reg | jacobian | 1.0543611284155785 | 0.00167675013653934 | 2.3236829349571053e-08 | 0.001590299652889371 | 1.2011492245278532e-08 | | ppl_robust_reg | hessian | 1.0535378727082656 | 0.01643357239663601 | 1.8450685956850066e-06 | 0.015598463825881481 | 2.1876705602608126e-07 | | wav2letter | vjp | 1.0060408105086573 | 0.3516994118690491 | 1.4463969819189515e-05 | 0.349587619304657 | 9.897866402752697e-05 | | wav2letter | vhp | 0.9873655295086051 | 1.1196287870407104 | 0.00474404776468873 | 1.133955717086792 | 0.009759620763361454 | | wav2letter | jvp | 0.9741820317882822 | 0.7888165712356567 | 0.0017476462526246905 | 0.8097219467163086 | 0.0018235758179798722 | | transfo | vjp | 0.9883954031921641 | 2.8865864276885986 | 0.008410997688770294 | 2.9204773902893066 | 0.006901870481669903 | | transfo | vhp | 1.0111290842971339 | 8.374398231506348 | 0.014904373325407505 | 8.282224655151367 | 0.04449500888586044 | | transfo | jvp | 1.0080534543381963 | 6.293097972869873 | 0.03796082362532616 | 6.24282169342041 | 0.010179692879319191 | Pull Request resolved: https://github.com/pytorch/pytorch/pull/40586 Reviewed By: pbelevich Differential Revision: D23242101 Pulled By: albanD fbshipit-source-id: a2b92d5a4341fe1472711a685ca425ec257d6384
154 lines
5.3 KiB
Python
154 lines
5.3 KiB
Python
import torch
|
|
from torch.autograd import functional
|
|
|
|
import time
|
|
from argparse import ArgumentParser
|
|
from collections import defaultdict
|
|
from typing import NamedTuple, Callable, List, Any
|
|
|
|
import ppl_models
|
|
import vision_models
|
|
import audio_text_models
|
|
|
|
from utils import to_markdown_table, TimingResultType, InputsType, GetterType, VType
|
|
|
|
# Listing of the different tasks.

# Tasks that do not need double backward support from the model.
FAST_TASKS_NO_DOUBLE_BACK = ["vjp"]

# Tasks that are cheap enough to run by default for most models.
FAST_TASKS = [*FAST_TASKS_NO_DOUBLE_BACK, "vhp", "jvp"]

# Every task; the ones beyond FAST_TASKS only run when recommended for a model
# or when --run-slow-tasks is passed.
ALL_TASKS = [*FAST_TASKS, "hvp", "jacobian", "hessian"]

# Tasks that rely on double backward; a model that cannot double-backward lists
# these as unsupported.
DOUBLE_BACKWARD_TASKS = ["jvp", "hvp", "vhp", "hessian"]
|
|
|
|
class ModelDef(NamedTuple):
    """Description of one model that the benchmark can run.

    Fields:
        name: the model name; it is used as the row key of the results table.
        getter: a function that takes the device on which the model will run and
            returns the forward function together with the parameters (Tensors)
            used as input for that forward function. The forward function must
            *not* have any side effect.
        tasks: the recommended tasks, i.e. those that run in a reasonable amount
            of time with this model.
        unsupported: the tasks that this model cannot run.
    """

    name: str
    getter: GetterType
    tasks: List[str]
    unsupported: List[str]
|
|
|
|
# Every model known to the benchmark, in the order in which they are run.
# A model that cannot double-backward only recommends the tasks that avoid it
# and marks the double-backward tasks as unsupported.
MODELS = [
    # Models from vision_models.
    ModelDef(name="resnet18", getter=vision_models.get_resnet18,
             tasks=FAST_TASKS, unsupported=[]),
    ModelDef(name="fcn_resnet", getter=vision_models.get_fcn_resnet,
             tasks=FAST_TASKS, unsupported=[]),
    ModelDef(name="detr", getter=vision_models.get_detr,
             tasks=FAST_TASKS, unsupported=[]),
    # Models from ppl_models; small enough to run every task by default.
    ModelDef(name="ppl_simple_reg", getter=ppl_models.get_simple_regression,
             tasks=ALL_TASKS, unsupported=[]),
    ModelDef(name="ppl_robust_reg", getter=ppl_models.get_robust_regression,
             tasks=ALL_TASKS, unsupported=[]),
    # Models from audio_text_models.
    ModelDef(name="wav2letter", getter=audio_text_models.get_wav2letter,
             tasks=FAST_TASKS, unsupported=[]),
    ModelDef(name="deepspeech", getter=audio_text_models.get_deepspeech,
             tasks=FAST_TASKS_NO_DOUBLE_BACK, unsupported=DOUBLE_BACKWARD_TASKS),
    ModelDef(name="transformer", getter=audio_text_models.get_transformer,
             tasks=FAST_TASKS, unsupported=[]),
    ModelDef(name="multiheadattn", getter=audio_text_models.get_multiheadattn,
             tasks=FAST_TASKS, unsupported=[]),
]
|
|
|
|
def get_v_for(model: Callable, inp: InputsType, task: str) -> VType:
    """Build the random ``v`` argument that ``task`` needs, or ``None``.

    - ``"vjp"``: ``v`` has the shape of the model output, so the model is
      evaluated once on ``inp`` (unpacked as positional arguments).
    - ``"jvp"``, ``"hvp"``, ``"vhp"``: ``v`` mirrors the inputs; a tuple of
      inputs yields a tuple of random tensors, in the same order.
    - any other task: no ``v`` is needed and ``None`` is returned.
    """
    if task == "vjp":
        return torch.rand_like(model(*inp))

    if task not in ("jvp", "hvp", "vhp"):
        return None

    if not isinstance(inp, tuple):
        return torch.rand_like(inp)
    return tuple(map(torch.rand_like, inp))
|
|
|
|
def run_once(model: Callable, inp: InputsType, task: str, v: VType) -> None:
    """Run ``torch.autograd.functional.<task>`` once on ``model`` and ``inp``.

    Args:
        model: the forward function to differentiate.
        inp: the inputs of ``model``.
        task: name of the ``torch.autograd.functional`` function to call
            (e.g. ``"vjp"``, ``"hessian"``).
        v: the vector argument for tasks that take one; when ``None`` the
            function is called without ``v``.

    ``strict=True`` is always passed (see the ``torch.autograd.functional``
    documentation for its exact semantics).
    """
    func = getattr(functional, task)

    # The result is intentionally discarded: only the time spent computing it
    # is of interest to the caller.
    if v is not None:
        func(model, inp, v=v, strict=True)
    else:
        func(model, inp, strict=True)
|
|
|
|
def run_model(model_getter: GetterType, args: Any, task: str) -> List[float]:
    """Time ``task`` on the model built by ``model_getter``.

    Args:
        model_getter: called with the target ``torch.device``; returns the
            forward function and its inputs.
        args: parsed CLI arguments. ``args.gpu`` (-1 for CPU, otherwise a CUDA
            device index) and ``args.num_iters`` are read.
        task: name of the ``torch.autograd.functional`` function to benchmark.

    Returns:
        The elapsed time in seconds of each of the ``args.num_iters`` timed
        runs. One untimed warmup run is performed before them.
    """
    if args.gpu == -1:
        device = torch.device("cpu")

        def do_sync() -> None:
            # No device synchronization is done when running on CPU.
            pass
    else:
        device = torch.device("cuda:{}".format(args.gpu))

        def do_sync() -> None:
            # Synchronize the device the model actually runs on. A bare
            # torch.cuda.synchronize() only waits on the *current* device,
            # which is not cuda:args.gpu unless it happens to be current, so
            # pending kernels would leak out of (or into) the timed region.
            torch.cuda.synchronize(device)

    model, inp = model_getter(device)

    v = get_v_for(model, inp, task)
    # Warmup
    run_once(model, inp, task, v)

    elapsed = []
    for _ in range(args.num_iters):
        do_sync()
        # perf_counter is monotonic and uses the highest-resolution clock
        # available; time.time() can jump and is coarse on some platforms,
        # which matters for sub-millisecond tasks.
        start = time.perf_counter()
        run_once(model, inp, task, v)
        do_sync()
        elapsed.append(time.perf_counter() - start)

    return elapsed
|
|
|
|
def main():
    """Benchmark every selected (model, task) pair and report the timings.

    The mean and variance of each pair's runtimes are printed as soon as they
    are measured. When ``--output`` is given, the full results are also
    written to that file as a markdown table.
    """
    parser = ArgumentParser("Main script to benchmark functional API of the autograd.")
    parser.add_argument("--output", type=str, default="", help="Text file where to write the output")
    parser.add_argument("--num-iters", type=int, default=10)
    parser.add_argument("--gpu", type=int, default=-2, help="GPU to use, -1 for CPU and -2 for auto-detect")
    parser.add_argument("--run-slow-tasks", action="store_true", help="Run even the slow tasks")
    parser.add_argument("--model-filter", type=str, default="", help="Only run the models in this filter")
    parser.add_argument("--task-filter", type=str, default="", help="Only run the tasks in this filter")
    parser.add_argument("--num-threads", type=int, default=10,
                        help="Number of concurrent threads to use when running on cpu")
    parser.add_argument("--seed", type=int, default=0, help="The random seed to use.")
    args = parser.parse_args()

    results: TimingResultType = defaultdict(defaultdict)
    torch.set_num_threads(args.num_threads)
    torch.set_num_interop_threads(args.num_threads)

    # This also seeds CUDA when it is available.
    torch.manual_seed(args.seed)

    if args.gpu == -2:
        args.gpu = 0 if torch.cuda.is_available() else -1

    for name, model_getter, recommended_tasks, unsupported_tasks in MODELS:
        # NOTE: filters match by substring: a model (or task) is kept when its
        # name appears anywhere inside the filter string, e.g. "resnet18,detr".
        if args.model_filter and name not in args.model_filter:
            continue
        tasks = ALL_TASKS if args.run_slow_tasks else recommended_tasks
        for task in tasks:
            if task in unsupported_tasks:
                continue
            if args.task_filter and task not in args.task_filter:
                continue
            runtimes = run_model(model_getter, args, task)

            runtimes = torch.tensor(runtimes)
            # Convert to Python floats once, so the printed message shows full
            # precision numbers instead of truncated "tensor(...)" reprs, and
            # matches exactly what is stored in the results table.
            mean, var = runtimes.mean().item(), runtimes.var().item()
            results[name][task] = (mean, var)
            print("Results for model {} on task {}: {}s (var: {})".format(name, task, mean, var))

    if args.output:
        with open(args.output, "w") as f:
            f.write(to_markdown_table(results))


if __name__ == "__main__":
    main()
|