pytorch/benchmarks/functional_autograd_benchmark/functional_autograd_benchmark.py
soulitzer 21c6de9fdc Extend autograd functional benchmarking to run vectorized tasks (#67045)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67045

To run: `python benchmarks/functional_autograd_benchmark/functional_autograd_benchmark.py --gpu -1 --model-filter=ppl_robust_reg --num-iters 100`

```
Results for model ppl_robust_reg on task vjp: 0.0012262486852705479s (var: 2.2107682351446556e-10)
Results for model ppl_robust_reg on task vhp: 0.002099371049553156s (var: 6.906406557760647e-10)
Results for model ppl_robust_reg on task jvp: 0.001860950025729835s (var: 1.1251884146634694e-10)
Results for model ppl_robust_reg on task hvp: 0.003481731517240405s (var: 2.2713633751614282e-10)
Results for model ppl_robust_reg on task jacobian: 0.0012128615053370595s (var: 1.3687526667638394e-09)
Results for model ppl_robust_reg on task hessian: 0.009885427542030811s (var: 9.366265096844018e-09)
Results for model ppl_robust_reg on task hessian_fwdrev: 0.005268776323646307s (var: 2.4293791422991262e-09)
Results for model ppl_robust_reg on task hessian_revrev: 0.002561321249231696s (var: 7.557877101938004e-10)
Results for model ppl_robust_reg on task jacfwd: 0.002619938924908638s (var: 5.109343503839625e-10)
Results for model ppl_robust_reg on task jacrev: 0.0013469004770740867s (var: 3.1857563254078514e-09)
```
Notes:
 - We go through the batched fallback for both
 - ppl_robust_reg takes 3 tensor inputs and returns a single scalar output
   - this means that jacobian is equivalent to doing a single vjp, so vmap would not help us (see the sketch below)
   - we expect jacfwd to be slower than jacrev
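
For a scalar-output model, the Jacobian with respect to each input has the same shape as that input, so computing it is the same work as a single vjp with v = 1 and there is nothing for vmap to batch over. A minimal standalone sketch (the toy function below is made up for illustration; it is not one of the benchmark models):

```
import torch
from torch.autograd import functional

# Toy stand-in for a scalar-output model: three tensor inputs, one scalar output.
def scalar_model(a, b, c):
    return (a * b).sum() + (c ** 2).sum()

inp = tuple(torch.rand(5) for _ in range(3))

# The Jacobian of a scalar output w.r.t. each input has the same shape as that input...
jac = functional.jacobian(scalar_model, inp)

# ...and matches what a single vjp with v = 1 computes, so vectorization cannot help here.
_, grads = functional.vjp(scalar_model, inp, torch.ones(()))

print(all(torch.allclose(j, g) for j, g in zip(jac, grads)))  # True
```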

Test Plan: Imported from OSS

Reviewed By: malfet

Differential Revision: D33265947

Pulled By: soulitzer

fbshipit-source-id: 14f537a1376dea7e5afbe0c8e97f94731479b018
2021-12-21 17:20:29 -08:00


import torch
from torch.autograd import functional

import time
from argparse import ArgumentParser
from collections import defaultdict
from typing import NamedTuple, Callable, List, Any

import ppl_models
import vision_models
import audio_text_models

from utils import to_markdown_table, TimingResultType, InputsType, GetterType, VType


def get_task_func(task: str) -> Callable:
    def hessian_fwdrev(model, inp, strict=None):
        return functional.hessian(model, inp, strict=False, vectorize=True, outer_jacobian_strategy="forward-mode")

    def hessian_revrev(model, inp, strict=None):
        return functional.hessian(model, inp, strict=False, vectorize=True)

    def jacfwd(model, inp, strict=None):
        return functional.jacobian(model, inp, strict=False, vectorize=True, strategy="forward-mode")

    def jacrev(model, inp, strict=None):
        return functional.jacobian(model, inp, strict=False, vectorize=True)

    if task == "hessian_fwdrev":
        return hessian_fwdrev
    elif task == "hessian_revrev":
        return hessian_revrev
    elif task == "jacfwd":
        return jacfwd
    elif task == "jacrev":
        return jacrev
    else:
        return getattr(functional, task)


# Listing of the different tasks
FAST_TASKS_NO_DOUBLE_BACK = [
    "vjp",
]

FAST_TASKS = FAST_TASKS_NO_DOUBLE_BACK + [
    "vhp",
    "jvp",
]

ALL_TASKS_NON_VECTORIZED = FAST_TASKS + [
    "hvp",
    "jacobian",
    "hessian"
]

DOUBLE_BACKWARD_TASKS = ["jvp", "hvp", "vhp", "hessian"]

VECTORIZED_TASKS = ["hessian_fwdrev", "hessian_revrev", "jacfwd", "jacrev"]

ALL_TASKS = ALL_TASKS_NON_VECTORIZED + VECTORIZED_TASKS


# Model definition which contains:
# - name: a string with the model name.
# - getter: a function to get the model. It takes as input the device on which the model
#     will run. It should return the forward function and the parameters (Tensors) used as
#     input for the forward function. Note that the forward must *not* have any side effect.
#     (An illustrative toy getter is sketched below, after the MODELS list.)
# - tasks: the list of recommended tasks that can run in a reasonable amount of time with this model.
# - unsupported: the list of tasks that this model cannot run.
class ModelDef(NamedTuple):
    name: str
    getter: GetterType
    tasks: List[str]
    unsupported: List[str]

MODELS = [
    ModelDef("resnet18", vision_models.get_resnet18, FAST_TASKS, []),
    ModelDef("fcn_resnet", vision_models.get_fcn_resnet, FAST_TASKS, []),
    ModelDef("detr", vision_models.get_detr, FAST_TASKS, []),
    ModelDef("ppl_simple_reg", ppl_models.get_simple_regression, ALL_TASKS, []),
    ModelDef("ppl_robust_reg", ppl_models.get_robust_regression, ALL_TASKS, []),
    ModelDef("wav2letter", audio_text_models.get_wav2letter, FAST_TASKS, []),
    ModelDef("deepspeech", audio_text_models.get_deepspeech, FAST_TASKS_NO_DOUBLE_BACK, DOUBLE_BACKWARD_TASKS),
    ModelDef("transformer", audio_text_models.get_transformer, FAST_TASKS, []),
    ModelDef("multiheadattn", audio_text_models.get_multiheadattn, FAST_TASKS, []),
]
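

# Illustrative sketch only (not one of the benchmarked models and not referenced by MODELS):
# a minimal getter following the ModelDef contract described above. It takes the target
# device and returns a side-effect-free forward function together with the tuple of input
# Tensors that run_model passes to it. The name and the toy computation are made up for
# this example.
def get_toy_linear_example(device):
    weight = torch.rand(4, 4, device=device)

    def forward(x):
        # Pure function of its input and the captured weight; no in-place updates.
        return (x @ weight).pow(2).sum()

    inp = (torch.rand(4, device=device),)
    return forward, inp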


def get_v_for(model: Callable, inp: InputsType, task: str) -> VType:
    v: VType

    if task in ["vjp"]:
        out = model(*inp)
        v = torch.rand_like(out)
    elif task in ["jvp", "hvp", "vhp"]:
        if isinstance(inp, tuple):
            v = tuple(torch.rand_like(i) for i in inp)
        else:
            v = torch.rand_like(inp)
    else:
        v = None

    return v


def run_once(model: Callable, inp: InputsType, task: str, v: VType) -> None:
    func = get_task_func(task)

    if v is not None:
        res = func(model, inp, v=v, strict=True)
    else:
        res = func(model, inp, strict=True)


def run_model(model_getter: GetterType, args: Any, task: str) -> List[float]:
    if args.gpu == -1:
        device = torch.device("cpu")

        def noop():
            pass
        do_sync = noop
    else:
        device = torch.device("cuda:{}".format(args.gpu))
        do_sync = torch.cuda.synchronize

    model, inp = model_getter(device)
    v = get_v_for(model, inp, task)

    # Warmup
    run_once(model, inp, task, v)

    elapsed = []
    for it in range(args.num_iters):
        do_sync()
        start = time.time()
        run_once(model, inp, task, v)
        do_sync()
        elapsed.append(time.time() - start)

    return elapsed


def main():
    parser = ArgumentParser("Main script to benchmark the functional autograd API.")
    parser.add_argument("--output", type=str, default="", help="Text file where to write the output")
    parser.add_argument("--num-iters", type=int, default=10)
    parser.add_argument("--gpu", type=int, default=-2, help="GPU to use, -1 for CPU and -2 for auto-detect")
    parser.add_argument("--run-slow-tasks", action="store_true", help="Run even the slow tasks")
    parser.add_argument("--model-filter", type=str, default="", help="Only run the models in this filter")
    parser.add_argument("--task-filter", type=str, default="", help="Only run the tasks in this filter")
    parser.add_argument("--num-threads", type=int, default=10,
                        help="Number of concurrent threads to use when running on cpu")
    parser.add_argument("--seed", type=int, default=0, help="The random seed to use.")
    args = parser.parse_args()

    results: TimingResultType = defaultdict(defaultdict)
    torch.set_num_threads(args.num_threads)
    torch.set_num_interop_threads(args.num_threads)

    # This automatically seeds cuda if it is available
    torch.manual_seed(args.seed)
    if args.gpu == -2:
        args.gpu = 0 if torch.cuda.is_available() else -1

    for name, model_getter, recommended_tasks, unsupported_tasks in MODELS:
        if args.model_filter and name not in args.model_filter:
            continue
        tasks = ALL_TASKS if args.run_slow_tasks else recommended_tasks
        for task in tasks:
            if task in unsupported_tasks:
                continue
            if args.task_filter and task not in args.task_filter:
                continue

            runtimes = run_model(model_getter, args, task)
            runtimes = torch.tensor(runtimes)
            mean, var = runtimes.mean(), runtimes.var()
            results[name][task] = (mean.item(), var.item())
            print("Results for model {} on task {}: {}s (var: {})".format(name, task, mean, var))

    if args.output:
        with open(args.output, "w") as f:
            f.write(to_markdown_table(results))


if __name__ == "__main__":
    main()