Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 05:34:18 +08:00
Summary: Reland of the benchmark code that broke the slow tests because the GPUs were running out of memory.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43428
Reviewed By: ngimel
Differential Revision: D23296136
Pulled By: albanD
fbshipit-source-id: 0002ae23dc82f401604e33d0905d6b9eedebc851
104 lines
3.9 KiB
Python
import torch
from collections import defaultdict

from torch import nn, Tensor
from typing import List, Tuple, Dict, Union, Callable, Optional

# Type helpers
InputsType = Union[Tensor, Tuple[Tensor, ...]]
# A Getter takes in a device and returns a callable and the inputs to that callable
GetterReturnType = Tuple[Callable[..., Tensor], InputsType]
GetterType = Callable[[torch.device], GetterReturnType]
# V here refers to the v in either vjp, jvp, vhp or hvp
VType = Union[None, Tensor, Tuple[Tensor, ...]]
# Type used to store timing results. The first key is the model name, the second key
# is the task name, the result is a Tuple of: speedup, mean_before, var_before, mean_after, var_after.
TimingResultType = Dict[str, Dict[str, Tuple[float, ...]]]
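
# Illustrative sketch (not part of the original utilities): a minimal getter matching
# GetterType, to show the expected (callable, inputs) shape of GetterReturnType.
# The linear model and random input below are placeholders chosen only for this example.
def _example_getter(device: torch.device) -> GetterReturnType:
    model = nn.Linear(4, 1).to(device)
    inp = torch.rand(8, 4, device=device)

    def forward(x: Tensor) -> Tensor:
        # Reduce to a scalar output so the same callable also works for vhp/hvp-style tasks.
        return model(x).sum()

    return forward, inp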

# Utilities to make nn.Module "functional"
# In particular the goal is to be able to provide a function that takes as input
# the parameters and evaluates the nn.Module using fixed inputs.
# (An illustrative usage sketch appears after `load_weights` below.)
def _del_nested_attr(obj: nn.Module, names: List[str]) -> None:
    """
    Deletes the attribute specified by the given list of names.
    For example, to delete the attribute obj.conv.weight,
    use _del_nested_attr(obj, ['conv', 'weight'])
    """
    if len(names) == 1:
        delattr(obj, names[0])
    else:
        _del_nested_attr(getattr(obj, names[0]), names[1:])

def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None:
    """
    Set the attribute specified by the given list of names to value.
    For example, to set the attribute obj.conv.weight,
    use _set_nested_attr(obj, ['conv', 'weight'], value)
    """
    if len(names) == 1:
        setattr(obj, names[0], value)
    else:
        _set_nested_attr(getattr(obj, names[0]), names[1:], value)
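
# Illustrative sketch (not part of the original utilities): how a dotted parameter name
# such as "0.weight" maps onto the nested-attribute helpers above. The small Sequential
# module below is a placeholder used only for this example.
def _example_nested_attr() -> None:
    net = nn.Sequential(nn.Linear(2, 2))
    weight = net[0].weight.detach()
    _del_nested_attr(net, "0.weight".split("."))  # removes the nn.Parameter net[0].weight
    _set_nested_attr(net, "0.weight".split("."), weight)  # re-attaches it as a plain Tensor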

def extract_weights(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]:
    """
    This function removes all the Parameters from the model and
    returns them as a tuple, as well as their original attribute names.
    The weights must be re-loaded with `load_weights` before the model
    can be used again.
    Note that this function modifies the model in place and that, after this
    call, mod.parameters() will be empty.
    """
    orig_params = tuple(mod.parameters())
    # Remove all the parameters in the model
    names = []
    for name, p in list(mod.named_parameters()):
        _del_nested_attr(mod, name.split("."))
        names.append(name)

    # Make params regular Tensors instead of nn.Parameter
    params = tuple(p.detach().requires_grad_() for p in orig_params)
    return params, names

def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...]) -> None:
    """
    Reload a set of weights so that `mod` can be used again to perform a forward pass.
    Note that the `params` are regular Tensors (that can have history) and so are left
    as Tensors. This means that mod.parameters() will still be empty after this call.
    """
    for name, p in zip(names, params):
        _set_nested_attr(mod, name.split("."), p)
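
# Illustrative sketch (not part of the original utilities): combining `extract_weights`
# and `load_weights` to evaluate a module as a function of its parameters, e.g. for
# torch.autograd.functional-style transforms. The helper name and the fixed `inputs`
# argument are placeholders introduced here for illustration only.
def _example_functional_forward(
    mod: nn.Module, inputs: Tensor
) -> Tuple[Callable[..., Tensor], Tuple[Tensor, ...]]:
    params, names = extract_weights(mod)

    def forward(*new_params: Tensor) -> Tensor:
        # Re-attach the (possibly autograd-tracked) Tensors before running the module.
        load_weights(mod, names, new_params)
        return mod(inputs)

    return forward, params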

# Utilities to read/write markdown table-like content.
def to_markdown_table(res: TimingResultType, header: Optional[Tuple[str, ...]] = None) -> str:
    if header is None:
        header = ("model", "task", "mean", "var")
    out = ""

    def write_line(*args):
        nonlocal out
        out += "| {} |\n".format(" | ".join(str(a) for a in args))

    # Make it a markdown table
    write_line(*header)
    write_line(*["--"] * len(header))
    for model, tasks in res.items():
        for task, line in tasks.items():
            write_line(*(model, task) + line)

    return out

def from_markdown_table(data: str) -> TimingResultType:
    out = data.strip().split("\n")
    out = out[2:]  # Ignore the header lines

    res: TimingResultType
    res = defaultdict(defaultdict)

    for line in out:
        model, task, mean, var = [f.strip() for f in line.strip().split("|") if f]
        res[model][task] = (float(mean), float(var))

    return res
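
# Illustrative sketch (not part of the original utilities): round-tripping a small
# timing dictionary through the markdown helpers. The model/task names and numbers
# below are made up purely to show the expected table layout.
def _example_markdown_roundtrip() -> TimingResultType:
    results: TimingResultType = {"resnet": {"vjp": (0.12, 0.001)}}
    table = to_markdown_table(results)
    # `table` now contains:
    # | model | task | mean | var |
    # | -- | -- | -- | -- |
    # | resnet | vjp | 0.12 | 0.001 |
    return from_markdown_table(table)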