mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This is one of a series of PRs to update us to PEP585 (changing Dict -> dict, List -> list, etc). Most of the PRs were completely automated with RUFF as follows: Since RUFF UP006 is considered an "unsafe" fix first we need to enable unsafe fixes: ``` --- a/tools/linter/adapters/ruff_linter.py +++ b/tools/linter/adapters/ruff_linter.py @@ -313,6 +313,7 @@ "ruff", "check", "--fix-only", + "--unsafe-fixes", "--exit-zero", *([f"--config={config}"] if config else []), "--stdin-filename", ``` Then we need to tell RUFF to allow UP006 (as a final PR once all of these have landed this will be made permanent): ``` --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ [tool.ruff] -target-version = "py38" +target-version = "py39" line-length = 88 src = ["caffe2", "torch", "torchgen", "functorch", "test"] @@ -87,7 +87,6 @@ "SIM116", # Disable Use a dictionary instead of consecutive `if` statements "SIM117", "SIM118", - "UP006", # keep-runtime-typing "UP007", # keep-runtime-typing ] select = [ ``` Finally running `lintrunner -a --take RUFF` will fix up the deprecated uses. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145101 Approved by: https://github.com/bobrenjc93
121 lines
4.0 KiB
Python
from collections import defaultdict
from typing import Callable, Optional, Union

import torch
from torch import nn, Tensor


# Type helpers

# A single Tensor, or a tuple of Tensors for multi-input callables.
InputsType = Union[Tensor, tuple[Tensor, ...]]
# A Getter takes in a device and returns a callable and the inputs to that callable
GetterReturnType = tuple[Callable[..., Tensor], InputsType]
GetterType = Callable[[torch.device], GetterReturnType]
# V here refers to the v in either vjp, jvp, vhp or hvp
VType = Union[None, Tensor, tuple[Tensor, ...]]
# Type used to store timing results. The first key is the model name, the second key
# is the task name, the result is a Tuple of: speedup, mean_before, var_before, mean_after, var_after.
TimingResultType = dict[str, dict[str, tuple[float, ...]]]
# Utilities to make nn.Module "functional"
|
|
# In particular the goal is to be able to provide a function that takes as input
|
|
# the parameters and evaluate the nn.Module using fixed inputs.
|
|
def _del_nested_attr(obj: nn.Module, names: list[str]) -> None:
|
|
"""
|
|
Deletes the attribute specified by the given list of names.
|
|
For example, to delete the attribute obj.conv.weight,
|
|
use _del_nested_attr(obj, ['conv', 'weight'])
|
|
"""
|
|
if len(names) == 1:
|
|
delattr(obj, names[0])
|
|
else:
|
|
_del_nested_attr(getattr(obj, names[0]), names[1:])
|
|
|
|
|
|
def _set_nested_attr(obj: nn.Module, names: list[str], value: Tensor) -> None:
|
|
"""
|
|
Set the attribute specified by the given list of names to value.
|
|
For example, to set the attribute obj.conv.weight,
|
|
use _del_nested_attr(obj, ['conv', 'weight'], value)
|
|
"""
|
|
if len(names) == 1:
|
|
setattr(obj, names[0], value)
|
|
else:
|
|
_set_nested_attr(getattr(obj, names[0]), names[1:], value)
|
|
|
|
|
|
def extract_weights(mod: nn.Module) -> tuple[tuple[Tensor, ...], list[str]]:
    """
    This function removes all the Parameters from the model and
    return them as a tuple as well as their original attribute names.
    The weights must be re-loaded with `load_weights` before the model
    can be used again.
    Note that this function modifies the model in place and after this
    call, mod.parameters() will be empty.
    """
    orig_params = tuple(mod.parameters())
    # Snapshot the names first: we mutate the module while deleting, so we
    # must not iterate named_parameters() lazily during the deletions.
    names = [name for name, _ in mod.named_parameters()]
    # Remove all the parameters in the model
    for name in names:
        _del_nested_attr(mod, name.split("."))

    # Make params regular Tensors instead of nn.Parameter
    params = tuple(p.detach().requires_grad_() for p in orig_params)
    return params, names
def load_weights(mod: nn.Module, names: list[str], params: tuple[Tensor, ...]) -> None:
    """
    Reload a set of weights so that `mod` can be used again to perform a forward pass.
    Note that the `params` are regular Tensors (that can have history) and so are left
    as Tensors. This means that mod.parameters() will still be empty after this call.
    """
    # Pair each dotted attribute name with its tensor and write it back in place.
    for name, p in zip(names, params):
        _set_nested_attr(mod, name.split("."), p)
# Utilities to read/write markdown table-like content.
def to_markdown_table(
    res: TimingResultType, header: Optional[tuple[str, ...]] = None
) -> str:
    """Render ``res`` as a markdown table, one "| .. |" line per row.

    Args:
        res: nested mapping of model name -> task name -> tuple of numbers.
        header: column titles; defaults to ("model", "task", "mean", "var").

    Returns:
        The full table as a single string, each line newline-terminated.
    """
    if header is None:
        header = ("model", "task", "mean", "var")
    # Accumulate rows in a list and join once at the end (avoids quadratic
    # string concatenation for large result sets).
    rows: list[str] = []

    def write_line(*cells):
        rows.append(f"| {' | '.join(str(c) for c in cells)} |\n")

    # Make it a markdown table: header, separator, then one row per task.
    write_line(*header)
    write_line(*["--"] * len(header))
    for model, tasks in res.items():
        for task, timings in tasks.items():
            write_line(model, task, *timings)

    return "".join(rows)
def from_markdown_table(data: str) -> TimingResultType:
    """Parse a table produced by `to_markdown_table` back into nested dicts.

    The first two lines (header and "--" separator) are discarded; every
    remaining row must contain exactly four cells: model, task, mean, var.
    """
    rows = data.strip().split("\n")[2:]  # Ignore the header lines

    # defaultdict(dict) so the first task seen for a model auto-creates the
    # inner mapping; missing keys on read still raise KeyError as before.
    res: TimingResultType = defaultdict(dict)

    for row in rows:
        model, task, mean, var = (cell.strip() for cell in row.strip().split("|") if cell)
        res[model][task] = (float(mean), float(var))

    return res
def check_for_functorch() -> bool:
    """Return True if the `functorch` package can be imported, else False."""
    try:
        import functorch  # noqa: F401
    except ImportError:
        return False
    return True