mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/67045 To run: `python benchmarks/functional_autograd_benchmark/functional_autograd_benchmark.py --gpu -1 --model-filter=ppl_robust_reg --num-iter 100` ``` Results for model ppl_robust_reg on task vjp: 0.0012262486852705479s (var: 2.2107682351446556e-10) Results for model ppl_robust_reg on task vhp: 0.002099371049553156s (var: 6.906406557760647e-10) Results for model ppl_robust_reg on task jvp: 0.001860950025729835s (var: 1.1251884146634694e-10) Results for model ppl_robust_reg on task hvp: 0.003481731517240405s (var: 2.2713633751614282e-10) Results for model ppl_robust_reg on task jacobian: 0.0012128615053370595s (var: 1.3687526667638394e-09) Results for model ppl_robust_reg on task hessian: 0.009885427542030811s (var: 9.366265096844018e-09) Results for model ppl_robust_reg on task hessian_fwdrev: 0.005268776323646307s (var: 2.4293791422991262e-09) Results for model ppl_robust_reg on task hessian_revrev: 0.002561321249231696s (var: 7.557877101938004e-10) Results for model ppl_robust_reg on task jacfwd: 0.002619938924908638s (var: 5.109343503839625e-10) Results for model ppl_robust_reg on task jacrev: 0.0013469004770740867s (var: 3.1857563254078514e-09) ``` Notes: - We go through batched fallback for both - ppl_robust_reg takes 3 tensor inputs and returns a single scalar output - this means that jacobian is equivalent to doing vjp and vmap would not help us - we expect jacfwd to be slower than jacrev Test Plan: Imported from OSS Reviewed By: malfet Differential Revision: D33265947 Pulled By: soulitzer fbshipit-source-id: 14f537a1376dea7e5afbe0c8e97f94731479b018
182 lines
6.3 KiB
Python
182 lines
6.3 KiB
Python
import torch
|
|
from torch.autograd import functional
|
|
|
|
import time
|
|
from argparse import ArgumentParser
|
|
from collections import defaultdict
|
|
from typing import NamedTuple, Callable, List, Any
|
|
|
|
import ppl_models
|
|
import vision_models
|
|
import audio_text_models
|
|
|
|
from utils import to_markdown_table, TimingResultType, InputsType, GetterType, VType
|
|
|
|
def get_task_func(task: str) -> Callable:
    """Return the callable that implements the benchmark task named ``task``.

    The four vectorized task names resolve to local wrappers around
    ``functional.jacobian`` / ``functional.hessian`` that always pass
    ``vectorize=True`` and ``strict=False``; the ``strict`` argument they
    accept is ignored. Any other name is looked up as an attribute of
    ``torch.autograd.functional``.
    """
    def jacrev(model, inp, strict=None):
        return functional.jacobian(model, inp, strict=False, vectorize=True)

    def jacfwd(model, inp, strict=None):
        return functional.jacobian(model, inp, strict=False, vectorize=True, strategy="forward-mode")

    def hessian_revrev(model, inp, strict=None):
        return functional.hessian(model, inp, strict=False, vectorize=True)

    def hessian_fwdrev(model, inp, strict=None):
        return functional.hessian(model, inp, strict=False, vectorize=True, outer_jacobian_strategy="forward-mode")

    # Name -> wrapper table for the tasks that are not plain `functional` attributes.
    vectorized = {
        "hessian_fwdrev": hessian_fwdrev,
        "hessian_revrev": hessian_revrev,
        "jacfwd": jacfwd,
        "jacrev": jacrev,
    }
    if task in vectorized:
        return vectorized[task]
    return getattr(functional, task)
|
|
|
|
# Task catalogue. Each list is built from the previous one so a task name is
# only spelled once per tier.

# Cheapest tier.
FAST_TASKS_NO_DOUBLE_BACK = ["vjp"]

FAST_TASKS = [*FAST_TASKS_NO_DOUBLE_BACK, "vhp", "jvp"]

# Every task served directly by `torch.autograd.functional`.
ALL_TASKS_NON_VECTORIZED = [*FAST_TASKS, "hvp", "jacobian", "hessian"]

# Used as the "unsupported" list for models that cannot run these tasks.
DOUBLE_BACKWARD_TASKS = ["jvp", "hvp", "vhp", "hessian"]

# Tasks implemented by the vectorized wrappers in `get_task_func`.
VECTORIZED_TASKS = ["hessian_fwdrev", "hessian_revrev", "jacfwd", "jacrev"]

ALL_TASKS = [*ALL_TASKS_NON_VECTORIZED, *VECTORIZED_TASKS]
|
|
|
|
class ModelDef(NamedTuple):
    """Description of one benchmarked model.

    Attributes:
        name: The model name, as a string.
        getter: Function that builds the model. It receives the device the
            model will run on and returns the forward function together with
            the parameters (Tensors) used as input for that forward function.
            The forward must *not* have any side effect.
        tasks: Recommended tasks, i.e. those that run in a reasonable amount
            of time with this model.
        unsupported: Tasks that this model cannot run.
    """

    name: str
    getter: GetterType
    tasks: List[str]
    unsupported: List[str]
|
|
|
|
# Registry of every model the benchmark knows about, in run order.
MODELS = [
    ModelDef(name="resnet18", getter=vision_models.get_resnet18, tasks=FAST_TASKS, unsupported=[]),
    ModelDef(name="fcn_resnet", getter=vision_models.get_fcn_resnet, tasks=FAST_TASKS, unsupported=[]),
    ModelDef(name="detr", getter=vision_models.get_detr, tasks=FAST_TASKS, unsupported=[]),
    ModelDef(name="ppl_simple_reg", getter=ppl_models.get_simple_regression, tasks=ALL_TASKS, unsupported=[]),
    ModelDef(name="ppl_robust_reg", getter=ppl_models.get_robust_regression, tasks=ALL_TASKS, unsupported=[]),
    ModelDef(name="wav2letter", getter=audio_text_models.get_wav2letter, tasks=FAST_TASKS, unsupported=[]),
    ModelDef(
        name="deepspeech",
        getter=audio_text_models.get_deepspeech,
        tasks=FAST_TASKS_NO_DOUBLE_BACK,
        unsupported=DOUBLE_BACKWARD_TASKS,
    ),
    ModelDef(name="transformer", getter=audio_text_models.get_transformer, tasks=FAST_TASKS, unsupported=[]),
    ModelDef(name="multiheadattn", getter=audio_text_models.get_multiheadattn, tasks=FAST_TASKS, unsupported=[]),
]
|
|
|
|
def get_v_for(model: Callable, inp: InputsType, task: str) -> VType:
    """Build the random ``v`` argument required by ``task``, if it needs one.

    Args:
        model: Forward function; only called for the ``"vjp"`` task.
        inp: Either a single Tensor or a tuple of Tensors.
        task: Name of the benchmark task.

    Returns:
        * ``"vjp"``: a random Tensor shaped like the model output.
        * ``"jvp"`` / ``"hvp"`` / ``"vhp"``: random Tensor(s) shaped like
          ``inp`` (a tuple when ``inp`` is a tuple).
        * any other task: ``None``.
    """
    if task == "vjp":
        # v has to match the output, so run one forward pass. Only splat `inp`
        # when it really is a tuple: unpacking a bare Tensor would feed its
        # slices along dim 0 as separate positional arguments.
        out = model(*inp) if isinstance(inp, tuple) else model(inp)
        return torch.rand_like(out)

    if task in ("jvp", "hvp", "vhp"):
        # v has to match the input structure.
        if isinstance(inp, tuple):
            return tuple(torch.rand_like(i) for i in inp)
        return torch.rand_like(inp)

    return None
|
|
|
|
def run_once(model: Callable, inp: InputsType, task: str, v: VType) -> None:
    """Execute ``task`` on ``model``/``inp`` a single time and discard the result.

    ``strict=True`` is always passed to the task function. ``v`` is forwarded
    as a keyword argument only when it is not ``None``.
    """
    # Tasks that take no `v` must not receive the keyword at all.
    optional_kwargs = {} if v is None else {"v": v}
    get_task_func(task)(model, inp, strict=True, **optional_kwargs)
|
|
|
|
def run_model(model_getter: GetterType, args: Any, task: str) -> List[float]:
    """Time ``args.num_iters`` executions of ``task`` on one model.

    Args:
        model_getter: Called with the target ``torch.device``; must return a
            ``(model, inp)`` pair.
        args: Parsed command-line namespace. Reads ``args.gpu`` (``-1`` selects
            the CPU, any other value is used as a CUDA device index) and
            ``args.num_iters``.
        task: Name of the benchmark task to run.

    Returns:
        Elapsed wall-clock seconds of each timed iteration. One untimed
        warm-up run is performed first and is not part of the result.
    """
    if args.gpu == -1:
        device = torch.device("cpu")

        def do_sync() -> None:
            # No device queue to drain when running on the CPU.
            pass
    else:
        device = torch.device("cuda:{}".format(args.gpu))

        def do_sync() -> None:
            # Synchronize the device the model runs on. Without an argument
            # torch.cuda.synchronize() waits on the *current* device, which is
            # not necessarily cuda:{args.gpu}.
            torch.cuda.synchronize(device)

    model, inp = model_getter(device)

    v = get_v_for(model, inp, task)
    # Warmup
    run_once(model, inp, task, v)

    elapsed = []
    for _ in range(args.num_iters):
        do_sync()
        # perf_counter is monotonic and high resolution, unlike time.time().
        start = time.perf_counter()
        run_once(model, inp, task, v)
        do_sync()
        elapsed.append(time.perf_counter() - start)

    return elapsed
|
|
|
|
def _split_filter(spec: str) -> List[str]:
    """Split a ``--model-filter`` / ``--task-filter`` value into the names it lists.

    Any run of non-word characters (commas, spaces, ...) separates two names.
    """
    import re  # Local import: `re` is not imported at file level.

    return re.findall(r"\w+", spec)


def main():
    """Command-line entry point of the benchmark.

    Parses the arguments, runs every selected (model, task) pair, prints the
    mean and variance of the measured run times and, when ``--output`` is
    given, writes all results to that file as a markdown table.
    """
    parser = ArgumentParser("Main script to benchmark functional API of the autograd.")
    parser.add_argument("--output", type=str, default="", help="Text file where to write the output")
    parser.add_argument("--num-iters", type=int, default=10)
    parser.add_argument("--gpu", type=int, default=-2, help="GPU to use, -1 for CPU and -2 for auto-detect")
    parser.add_argument("--run-slow-tasks", action="store_true", help="Run even the slow tasks")
    parser.add_argument("--model-filter", type=str, default="", help="Only run the models in this filter")
    parser.add_argument("--task-filter", type=str, default="", help="Only run the tasks in this filter")
    parser.add_argument("--num-threads", type=int, default=10,
                        help="Number of concurrent threads to use when running on cpu")
    parser.add_argument("--seed", type=int, default=0, help="The random seed to use.")
    args = parser.parse_args()

    results: TimingResultType = defaultdict(defaultdict)
    torch.set_num_threads(args.num_threads)
    torch.set_num_interop_threads(args.num_threads)

    # This automatically seeds cuda if it is available
    torch.manual_seed(args.seed)

    if args.gpu == -2:
        args.gpu = 0 if torch.cuda.is_available() else -1

    # Match filter entries as whole names. A plain substring test against the
    # raw filter string would let e.g. "hessian_revrev" also select "hessian".
    model_filter = set(_split_filter(args.model_filter))
    task_filter = set(_split_filter(args.task_filter))

    for name, model_getter, recommended_tasks, unsupported_tasks in MODELS:
        if args.model_filter and name not in model_filter:
            continue
        tasks = ALL_TASKS if args.run_slow_tasks else recommended_tasks
        for task in tasks:
            if task in unsupported_tasks:
                continue
            if args.task_filter and task not in task_filter:
                continue
            runtimes = run_model(model_getter, args, task)

            runtimes = torch.tensor(runtimes)
            mean, var = runtimes.mean(), runtimes.var()
            results[name][task] = (mean.item(), var.item())
            print("Results for model {} on task {}: {}s (var: {})".format(name, task, mean, var))

    if args.output:
        with open(args.output, "w") as f:
            f.write(to_markdown_table(results))
|
|
|
|
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|