This is one of a series of PRs to update us to PEP 585 (changing Dict -> dict, List -> list, etc.). Most of the PRs were completely automated with RUFF as follows:

Since RUFF UP006 is considered an "unsafe" fix, first we need to enable unsafe fixes:

```
--- a/tools/linter/adapters/ruff_linter.py
+++ b/tools/linter/adapters/ruff_linter.py
@@ -313,6 +313,7 @@
     "ruff",
     "check",
     "--fix-only",
+    "--unsafe-fixes",
     "--exit-zero",
     *([f"--config={config}"] if config else []),
     "--stdin-filename",
```

Then we need to tell RUFF to allow UP006 (as a final PR, once all of these have landed, this will be made permanent):

```
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,7 +40,7 @@

 [tool.ruff]
-target-version = "py38"
+target-version = "py39"
 line-length = 88
 src = ["caffe2", "torch", "torchgen", "functorch", "test"]

@@ -87,7 +87,6 @@
     "SIM116", # Disable Use a dictionary instead of consecutive `if` statements
     "SIM117",
     "SIM118",
-    "UP006", # keep-runtime-typing
     "UP007", # keep-runtime-typing
 ]
 select = [
```

Finally, running `lintrunner -a --take RUFF` will fix up the deprecated uses.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145101
Approved by: https://github.com/bobrenjc93
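For illustration, here is a minimal before/after sketch of the kind of rewrite UP006 performs (a hypothetical snippet, not taken from this PR):

```
# Before (typing generics, deprecated by PEP 585):
from typing import Dict, List

def bucket_old(xs: List[int]) -> Dict[str, List[int]]:
    return {"even": [x for x in xs if x % 2 == 0]}

# After UP006 (builtin generics, Python 3.9+):
def bucket_new(xs: list[int]) -> dict[str, list[int]]:
    return {"even": [x for x in xs if x % 2 == 0]}
```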
344 lines
10 KiB
Python
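"""Utilities for capturing and replaying operator inputs: OperatorInputsMode
logs the (args, kwargs) signatures of ops observed under a TorchDispatchMode,
and OperatorInputsLoader deserializes those logs back into runnable inputs
for the timm, Hugging Face, and TorchBench training suites."""
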
import functools
import logging
import math
import os
from collections import Counter, defaultdict
from collections.abc import Generator, Iterable
from functools import partial
from typing import Any

import torch
from torch.testing import make_tensor
from torch.utils import _pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map


log = logging.getLogger(__name__)

OP_INP_DIRECTORY = os.path.join(os.path.dirname(__file__), "operator_inp_logs")

TIMM_DIR = os.path.join(OP_INP_DIRECTORY, "timm_train")
HF_DIR = os.path.join(OP_INP_DIRECTORY, "hf_train")
TORCHBENCH_DIR = os.path.join(OP_INP_DIRECTORY, "torchbench_train")

aten = torch.ops.aten
tensor_type = torch._C.TensorType.get()

dtype_abbrs = {
    torch.bfloat16: "bf16",
    torch.float64: "f64",
    torch.float32: "f32",
    torch.float16: "f16",
    torch.complex32: "c32",
    torch.complex64: "c64",
    torch.complex128: "c128",
    torch.int8: "i8",
    torch.int16: "i16",
    torch.int32: "i32",
    torch.int64: "i64",
    torch.bool: "b8",
    torch.uint8: "u8",
}

dtype_abbrs_parsing = {value: key for key, value in dtype_abbrs.items()}


def truncate_inp(arg):
    if arg in dtype_abbrs:
        return dtype_abbrs[arg]
    elif isinstance(arg, torch.device):
        return arg.type
    else:
        return arg


# Serialize function calls
class FuncCallWrapper:
    def __init__(self, call, *args, **kwargs):
        self.call = call
        self.args = tree_map(truncate_inp, args)
        self.kwargs = tree_map(truncate_inp, kwargs) if kwargs is not None else {}

    def __repr__(self):
        args = ", ".join([repr(arg) for arg in self.args])
        kwargs = "".join(
            [f", {str(key)}={value}" for key, value in self.kwargs.items()]
        )
        out = f"{self.call}({args}{kwargs})".strip('"')
        # f-strings introduce quotations we don't want
        for key in dtype_abbrs_parsing:
            out = out.replace(f"'{key}'", key)
        return out
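

# For reference, a logged entry for an operator looks roughly like the
# following (illustrative shapes and dtypes, not taken from a real log):
#
#   Operator: aten.add.Tensor
#   cnt: 3, ((T([32, 64], f16), T([32, 64], f16)), {})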


def serialize_sparse_tensor(e):
    if isinstance(e, torch._subclasses.FakeTensor):
        return FuncCallWrapper("ST", list(e.shape), e.dtype, e.layout, e.is_coalesced())
    else:
        return FuncCallWrapper(
            "ST", list(e.shape), e.dtype, e.layout, e.is_coalesced(), e._nnz()
        )


def deserialize_sparse_tensor(size, dtype, layout, is_coalesced, nnz=None):
    raise NotImplementedError


def deserialize_tensor(size, dtype, stride=None):
    if stride is not None:
        out = torch.empty_strided(size, stride, dtype=dtype)
    else:
        out = torch.empty(size, dtype=dtype)
    try:
        out.copy_(make_tensor(size, dtype=dtype, device="cpu"))
    except Exception as e:
        print(e)
        return out
    return out


def serialize_tensor(e):
    if not e.is_contiguous():
        return FuncCallWrapper("T", list(e.shape), e.dtype, stride=e.stride())
    else:
        return FuncCallWrapper("T", list(e.shape), e.dtype)


def serialize_torch_args(e):
    if isinstance(e, torch.Tensor):
        if e.is_sparse:
            return serialize_sparse_tensor(e)
        return serialize_tensor(e)
    else:
        return truncate_inp(e)


def contains_tensor(elems):
    for elem in pytree.tree_leaves(elems):
        if isinstance(elem, torch.Tensor):
            return True
    return False


def skip_args(elems):
    for i in pytree.tree_leaves(elems):
        # only shows up in constructors and ops like that
        if isinstance(i, (torch.memory_format, torch.storage.UntypedStorage)):
            return True
    return False


def contains_tensor_types(type):
    return type.isSubtypeOf(tensor_type) or any(
        contains_tensor_types(e) for e in type.containedTypes()
    )


@functools.lru_cache(None)
def non_compute_operator(op):
    schema = op._schema

    # skip constructors
    if not any(contains_tensor_types(arg.type) for arg in schema.arguments):
        return True
    if "_like" in op.name():
        return True

    # allow in-place writes
    if schema.is_mutable:
        return False

    tensor_inps = [arg for arg in schema.arguments if arg.type is tensor_type]
    tensor_outputs = [ret for ret in schema.returns if ret.type is tensor_type]

    # skip aliasing unless there are multiple outputs
    if len(tensor_outputs) != 1:
        return False

    for inp in tensor_inps:
        if inp.alias_info and tensor_outputs[0].alias_info:
            if inp.alias_info.before_set.intersection(
                tensor_outputs[0].alias_info.after_set
            ):
                return True

    return False


class OperatorInputsMode(TorchDispatchMode):
    def __init__(self, func_db=None):
        self.func_db = defaultdict(Counter) if func_db is None else func_db

    def __torch_dispatch__(self, func_overload, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs else {}
        arg_meta, kwarg_meta = tree_map(serialize_torch_args, (args, kwargs))

        out = func_overload(*args, **kwargs)

        inps = (args, kwargs)
        if contains_tensor(inps) and not skip_args(inps) and contains_tensor(out):
            serialized_str = repr((arg_meta, kwarg_meta))
            self.func_db[str(func_overload)][serialized_str] += 1

        return out

    def log_to_file(self, output_filename, *, skip_non_compute_operators=True):
        sorted_operators = sorted(self.func_db.keys())
        with open(output_filename, "w") as f:
            for operator in sorted_operators:
                if skip_non_compute_operators and non_compute_operator(eval(operator)):
                    continue
                f.write(f"Operator: {operator}\n")
                operator_inputs = self.func_db[operator]
                for inps, count in operator_inputs.items():
                    f.write(f"cnt: {count}, ")
                    # repr will add quotation marks around the dtype strings
                    for dtype_abbr in dtype_abbrs.values():
                        inps = inps.replace("'" + dtype_abbr + "'", dtype_abbr)
                    f.write(inps)
                    f.write("\n")
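

# Usage sketch for OperatorInputsMode (a minimal example; `fn` and
# `example_inputs` are hypothetical stand-ins, not part of the original file):
#
#   mode = OperatorInputsMode()
#   with mode:
#       fn(*example_inputs)
#   mode.log_to_file("fn_ops.txt")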


def map_to_device(e, device):
    if isinstance(e, torch.Tensor):
        return e.to(device)
    elif isinstance(e, torch.device):
        return device
    elif isinstance(e, str):
        # remap recorded device strings; pass other strings through unchanged
        if e == "cuda" or e == "cpu":
            return device.type
        return e
    else:
        return e


def map_to_dtype(e, dtype):
    if isinstance(e, torch.Tensor) and e.is_floating_point():
        return e.to(dtype)
    elif isinstance(e, torch.dtype):
        return dtype
    else:
        return e


def deserialize_args(inps):
    inps = inps.strip().strip("'")
    global_vals = {
        "T": deserialize_tensor,
        "ST": deserialize_sparse_tensor,
        "th": torch,
        "inf": math.inf,
        "torch": torch,
        **dtype_abbrs_parsing,
    }
    # f-strings introduce quotations we don't want
    for key in dtype_abbrs_parsing:
        inps = inps.replace(f"'{key}'", key)
    return eval(inps.strip().strip("'").strip('"'), global_vals)


class OperatorInputsLoader:
    def __init__(self, json_file_path):
        self.operator_db = defaultdict(Counter)

        with open(json_file_path) as f:
            lines = f.readlines()

        i = 0
        while i < len(lines):
            op_line = lines[i].strip("\n")
            assert "Operator: " in op_line, op_line
            operator = op_line[len("Operator: ") :]
            operator = (
                operator if operator != "aten.sum.SymInt" else "aten.sum.dim_IntList"
            )
            op_inps = Counter()
            i += 1
            while i < len(lines) and "Operator: " not in lines[i]:
                line = lines[i]
                cnt = eval(line[len("cnt: ") : line.find(",")])
                inps = line[line.find(",") + 2 :].strip("'")
                op_inps[inps] += cnt
                i += 1
            self.operator_db[operator] = op_inps

    def get_inputs_for_operator(
        self, operator, dtype=None, device="cuda"
    ) -> Generator[tuple[Iterable[Any], dict[str, Any]], None, None]:
        assert (
            str(operator) in self.operator_db
        ), f"Could not find {operator}, must provide overload"

        if "embedding" in str(operator):
            log.warning("Embedding inputs NYI, input data cannot be randomized")
            yield
            return

        # line[1] is the number of times these inputs occurred, ignored for now
        for line in self.operator_db[str(operator)].items():
            inps = line[0]

            args, kwargs = deserialize_args(inps)

            # Backward passes require some inputs to be float16 and some to be
            # float32, so we record in half precision and upcast to float when
            # specified
            if dtype and dtype != torch.float16:
                to_dtype = partial(map_to_dtype, dtype=dtype)
                args, kwargs = tree_map(to_dtype, (args, kwargs))

            if device:
                to_device = partial(map_to_device, device=torch.device(device))
                args, kwargs = tree_map(to_device, (args, kwargs))

            yield args, kwargs

    def get_all_ops(self):
        for key in self.operator_db.keys():
            try:
                op = eval(key)
            except AttributeError as ae:
                log.warning("Evaluating an op name into an OpOverload: %s", ae)
                continue
            yield op

    def get_call_frequency(self, op):
        assert (
            str(op) in self.operator_db
        ), f"Could not find {op}, must provide overload"

        count = 0
        for counter in self.operator_db[str(op)].values():
            count += counter
        return count

    def merge(self, other):
        for operator, counter_dict in other.operator_db.items():
            for inps, cnt in counter_dict.items():
                self.operator_db[operator][inps] += cnt

    @staticmethod
    def get_timm_loader():
        return OperatorInputsLoader._load_directory(TIMM_DIR)

    @staticmethod
    def get_huggingface_loader():
        return OperatorInputsLoader._load_directory(HF_DIR)

    @staticmethod
    def get_torchbench_loader():
        return OperatorInputsLoader._load_directory(TORCHBENCH_DIR)

    @staticmethod
    def _load_directory(inp_dir):
        assert os.path.isdir(inp_dir), inp_dir
        union = None
        for inp in os.listdir(inp_dir):
            if inp[-4:] != ".txt":
                continue
            path = os.path.join(inp_dir, inp)
            if union is None:
                union = OperatorInputsLoader(path)
            else:
                union.merge(OperatorInputsLoader(path))
        return union
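

# Usage sketch for OperatorInputsLoader (assumes the timm logs exist under
# operator_inp_logs/timm_train and a CUDA device is available; not part of
# the original file):
#
#   loader = OperatorInputsLoader.get_timm_loader()
#   op = torch.ops.aten.convolution.default
#   for args, kwargs in loader.get_inputs_for_operator(op, device="cuda"):
#       op(*args, **kwargs)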