Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
Apply UFMT to all files in benchmarks/ (#105928)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105928
Approved by: https://github.com/albanD
Committed by: PyTorch MergeBot
Parent: a361fceef3
Commit: dd3a77bc96
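UFMT chains usort (import sorting) and black (code formatting), which is why every change in the diff below either reorders imports or rewraps long lines; no behavior changes. As a minimal sketch of the line-wrapping half, assuming the black package is installed (the pinned version PyTorch's lint uses may differ), the hypothetical snippet below reproduces the kind of rewrite applied to the functional.hessian call in this file:

import black

# A statement that exceeds black's default 88-column limit, shaped like the
# long functional.hessian(...) call rewrapped in this diff. black.format_str
# formats a source string in memory; names need not resolve, only parse.
src = (
    "result = functional.hessian(model, inp, strict=False, vectorize=True, "
    'outer_jacobian_strategy="forward-mode")\n'
)

print(black.format_str(src, mode=black.Mode()))
# result = functional.hessian(
#     model, inp, strict=False, vectorize=True, outer_jacobian_strategy="forward-mode"
# )

At top level the wrapped arguments still fit on one 88-column line, so black groups them; inside the doubly nested function in the diff the extra indentation pushes them past the limit, which is why the diff shows one argument per line with a trailing comma.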
@@ -1,33 +1,43 @@
-import torch
-from torch.autograd import functional
-
 import time
 from argparse import ArgumentParser
 from collections import defaultdict
-from typing import NamedTuple, Callable, List, Any
+from typing import Any, Callable, List, NamedTuple
+
+import torch
+from torch.autograd import functional
 
 try:
     import functorch as ft
+
     has_functorch = True
     print(f"Found functorch: {ft.__version__}")
 except ImportError:
     has_functorch = False
 
+import audio_text_models
 import ppl_models
 import vision_models
-import audio_text_models
 
-from utils import to_markdown_table, TimingResultType, InputsType, GetterType, VType
+from utils import GetterType, InputsType, TimingResultType, to_markdown_table, VType
 
+
 def get_task_func(task: str) -> Callable:
     def hessian_fwdrev(model, inp, strict=None):
-        return functional.hessian(model, inp, strict=False, vectorize=True, outer_jacobian_strategy="forward-mode")
+        return functional.hessian(
+            model,
+            inp,
+            strict=False,
+            vectorize=True,
+            outer_jacobian_strategy="forward-mode",
+        )
 
     def hessian_revrev(model, inp, strict=None):
         return functional.hessian(model, inp, strict=False, vectorize=True)
 
     def jacfwd(model, inp, strict=None):
-        return functional.jacobian(model, inp, strict=False, vectorize=True, strategy="forward-mode")
+        return functional.jacobian(
+            model, inp, strict=False, vectorize=True, strategy="forward-mode"
+        )
 
     def jacrev(model, inp, strict=None):
         return functional.jacobian(model, inp, strict=False, vectorize=True)
@@ -43,8 +53,8 @@ def get_task_func(task: str) -> Callable:
     else:
         return getattr(functional, task)
 
-def get_task_functorch(task: str) -> Callable:
 
+def get_task_functorch(task: str) -> Callable:
     @torch.no_grad()
     def vjp(model, inp, v=None, strict=None):
         assert v is not None
@@ -67,7 +77,9 @@ def get_task_functorch(task: str) -> Callable:
     def hvp(model, inp, v=None, strict=None):
         assert v is not None
         argnums = tuple(range(len(inp)))
-        _, hvp_out, aux = ft.jvp(ft.grad_and_value(model, argnums), inp, v, has_aux=True)
+        _, hvp_out, aux = ft.jvp(
+            ft.grad_and_value(model, argnums), inp, v, has_aux=True
+        )
         return aux, hvp_out
 
     @torch.no_grad()
@@ -98,10 +110,13 @@ def get_task_functorch(task: str) -> Callable:
     if task in locals():
         return locals()[task]
     elif task == "jacobian":
-        raise RuntimeError("functorch has no equivalent of autograd.functional.jacobian with vectorize=False yet")
+        raise RuntimeError(
+            "functorch has no equivalent of autograd.functional.jacobian with vectorize=False yet"
+        )
     else:
         raise RuntimeError(f"Unsupported task: {task}")
 
+
 # Listing of the different tasks
 FAST_TASKS_NO_DOUBLE_BACK = [
     "vjp",
@@ -112,11 +127,7 @@ FAST_TASKS = FAST_TASKS_NO_DOUBLE_BACK + [
     "jvp",
 ]
 
-ALL_TASKS_NON_VECTORIZED = FAST_TASKS + [
-    "hvp",
-    "jacobian",
-    "hessian"
-]
+ALL_TASKS_NON_VECTORIZED = FAST_TASKS + ["hvp", "jacobian", "hessian"]
 
 DOUBLE_BACKWARD_TASKS = ["jvp", "hvp", "vhp", "hessian"]
 
@@ -124,6 +135,7 @@ VECTORIZED_TASKS = ["hessian_fwdrev", "hessian_revrev", "jacfwd", "jacrev"]
 
 ALL_TASKS = ALL_TASKS_NON_VECTORIZED + VECTORIZED_TASKS
 
+
 # Model definition which contains:
 # - name: a string with the model name.
 # - getter: a function to get the model. It takes as input the device on which the model
@@ -137,6 +149,7 @@ class ModelDef(NamedTuple):
     tasks: List[str]
     unsupported: List[str]
 
+
 MODELS = [
     ModelDef("resnet18", vision_models.get_resnet18, FAST_TASKS, []),
     ModelDef("fcn_resnet", vision_models.get_fcn_resnet, FAST_TASKS, []),
@@ -144,11 +157,17 @@ MODELS = [
     ModelDef("ppl_simple_reg", ppl_models.get_simple_regression, ALL_TASKS, []),
     ModelDef("ppl_robust_reg", ppl_models.get_robust_regression, ALL_TASKS, []),
     ModelDef("wav2letter", audio_text_models.get_wav2letter, FAST_TASKS, []),
-    ModelDef("deepspeech", audio_text_models.get_deepspeech, FAST_TASKS_NO_DOUBLE_BACK, DOUBLE_BACKWARD_TASKS),
+    ModelDef(
+        "deepspeech",
+        audio_text_models.get_deepspeech,
+        FAST_TASKS_NO_DOUBLE_BACK,
+        DOUBLE_BACKWARD_TASKS,
+    ),
     ModelDef("transformer", audio_text_models.get_transformer, FAST_TASKS, []),
     ModelDef("multiheadattn", audio_text_models.get_multiheadattn, FAST_TASKS, []),
 ]
 
+
 def get_v_for(model: Callable, inp: InputsType, task: str) -> VType:
     v: VType
 
@@ -165,6 +184,7 @@ def get_v_for(model: Callable, inp: InputsType, task: str) -> VType:
 
     return v
 
+
 def run_once(model: Callable, inp: InputsType, task: str, v: VType, **kwargs) -> None:
     func = get_task_func(task)
 
@@ -173,7 +193,10 @@ def run_once(model: Callable, inp: InputsType, task: str, v: VType, **kwargs) ->
     else:
         res = func(model, inp, strict=True)
 
-def run_once_functorch(model: Callable, inp: InputsType, task: str, v: VType, maybe_check_consistency=False) -> None:
+
+def run_once_functorch(
+    model: Callable, inp: InputsType, task: str, v: VType, maybe_check_consistency=False
+) -> None:
     func = get_task_functorch(task)
 
     if v is not None:
@@ -188,14 +211,24 @@ def run_once_functorch(model: Callable, inp: InputsType, task: str, v: VType, ma
         else:
             expected = af_func(model, inp, strict=True)
         atol = 1e-2 if task == "vhp" else 5e-3
-        torch.testing.assert_close(res, expected, rtol=1e-5, atol=atol, msg=f"Consistency fail for task '{task}'")
+        torch.testing.assert_close(
+            res,
+            expected,
+            rtol=1e-5,
+            atol=atol,
+            msg=f"Consistency fail for task '{task}'",
+        )
 
-def run_model(model_getter: GetterType, args: Any, task: str, run_once_fn: Callable = run_once) -> List[float]:
+
+def run_model(
+    model_getter: GetterType, args: Any, task: str, run_once_fn: Callable = run_once
+) -> List[float]:
     if args.gpu == -1:
         device = torch.device("cpu")
 
         def noop():
             pass
 
         do_sync = noop
     else:
         device = torch.device(f"cuda:{args.gpu}")
@@ -220,16 +253,37 @@ def run_model(model_getter: GetterType, args: Any, task: str, run_once_fn: Calla
 
     return elapsed
 
 
 def main():
     parser = ArgumentParser("Main script to benchmark functional API of the autograd.")
-    parser.add_argument("--output", type=str, default="", help="Text file where to write the output")
+    parser.add_argument(
+        "--output", type=str, default="", help="Text file where to write the output"
+    )
     parser.add_argument("--num-iters", type=int, default=10)
-    parser.add_argument("--gpu", type=int, default=-2, help="GPU to use, -1 for CPU and -2 for auto-detect")
-    parser.add_argument("--run-slow-tasks", action="store_true", help="Run even the slow tasks")
-    parser.add_argument("--model-filter", type=str, default="", help="Only run the models in this filter")
-    parser.add_argument("--task-filter", type=str, default="", help="Only run the tasks in this filter")
-    parser.add_argument("--num-threads", type=int, default=10,
-                        help="Number of concurrent threads to use when running on cpu")
+    parser.add_argument(
+        "--gpu",
+        type=int,
+        default=-2,
+        help="GPU to use, -1 for CPU and -2 for auto-detect",
+    )
+    parser.add_argument(
+        "--run-slow-tasks", action="store_true", help="Run even the slow tasks"
+    )
+    parser.add_argument(
+        "--model-filter",
+        type=str,
+        default="",
+        help="Only run the models in this filter",
+    )
+    parser.add_argument(
+        "--task-filter", type=str, default="", help="Only run the tasks in this filter"
+    )
+    parser.add_argument(
+        "--num-threads",
+        type=int,
+        default=10,
+        help="Number of concurrent threads to use when running on cpu",
+    )
    parser.add_argument("--seed", type=int, default=0, help="The random seed to use.")
    args = parser.parse_args()
 
@@ -261,19 +315,27 @@ def main():
 
         if has_functorch:
             try:
-                runtimes = run_model(model_getter, args, task, run_once_fn=run_once_functorch)
+                runtimes = run_model(
+                    model_getter, args, task, run_once_fn=run_once_functorch
+                )
             except RuntimeError as e:
-                print(f"Failed model using Functorch: {name}, task: {task}, Error message: \n\t", e)
+                print(
+                    f"Failed model using Functorch: {name}, task: {task}, Error message: \n\t",
+                    e,
+                )
                 continue
 
             runtimes = torch.tensor(runtimes)
             mean, var = runtimes.mean(), runtimes.var()
             results[name][f"functorch {task}"] = (mean.item(), var.item())
-            print(f"Results for model {name} on task {task} using Functorch: {mean}s (var: {var})")
+            print(
+                f"Results for model {name} on task {task} using Functorch: {mean}s (var: {var})"
+            )
 
     if args.output:
         with open(args.output, "w") as f:
             f.write(to_markdown_table(results))
 
 
 if __name__ == "__main__":
     main()
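For readers skimming the rewrapped calls: hessian_fwdrev above computes the Hessian with a forward-over-reverse strategy. A self-contained usage sketch with the same keyword arguments, using a toy function and input invented for illustration:

import torch
from torch.autograd import functional


def f(x):
    # Toy scalar function: f(x) = sum(x_i^3), whose Hessian is diag(6 * x).
    return (x**3).sum()


x = torch.randn(3)

# Same call shape as hessian_fwdrev in the diff: vectorize the outer
# Jacobian and compute it in forward mode over the reverse-mode gradient.
h = functional.hessian(
    f,
    x,
    strict=False,
    vectorize=True,
    outer_jacobian_strategy="forward-mode",
)
print(torch.allclose(h, torch.diag(6 * x)))  # True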
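The functorch hvp helper reformatted above uses the same forward-over-reverse idea to get a Hessian-vector product without materializing the Hessian. A standalone sketch of that trick written against torch.func, the API that absorbed functorch in recent PyTorch releases (toy f, x, v are again invented):

import torch
from torch.func import grad, jvp


def f(x):
    return (x**3).sum()


x = torch.randn(3)
v = torch.randn(3)

# Differentiate the gradient function in direction v: a jvp over grad
# yields H @ v directly, mirroring ft.jvp(ft.grad_and_value(...), ...).
_, hvp_out = jvp(grad(f), (x,), (v,))
print(torch.allclose(hvp_out, 6 * x * v))  # True, since H = diag(6 * x)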