# mypy: ignore-errors

from collections import namedtuple
from copy import deepcopy
from itertools import combinations

import torch
from torch.fx.operator_schemas import normalize_function
from torch.utils import _pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map


# Named Tuples used within SchemaCheckMode
Mutation = namedtuple("Mutation", ["op_name", "arg_name"])
Aliasing = namedtuple("Aliasing", ["op_name", "arg_name", "output_number"])

# Simplified naming for C++ classes
SchemaArgument = torch._C._SchemaArgument
SchemaArgType = torch._C._SchemaArgType
SchemaInfo = torch._C._SchemaInfo

# This TorchDispatchMode subclass is used to verify op schemas. It currently:
# - Records the called ops
# - Checks for mutations on all inputs
# - Checks for aliasing on all inputs
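#
# A minimal usage sketch (hedged: the exact recorded names depend on which ops
# are dispatched). Like any TorchDispatchMode, it is used as a context manager,
# and every op dispatched inside the block is checked against its schema:
#
#     x = torch.rand(2, 2)
#     with SchemaCheckMode() as mode:
#         x.add_(x)  # an in-place op: recorded as an op and as a mutation
#     # mode.ops contains "aten::add_"; mode.mutated records
#     # Mutation(op_name="aten::add_", arg_name="input")

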
# These two functions were moved here to avoid the numpy dependency in
# testing/_internal/common_utils.py.


def is_iterable_of_tensors(iterable):
    # Tensor itself is iterable so we check this first
    if isinstance(iterable, torch.Tensor):
        return False
    try:
        if len(iterable) == 0:
            return False
        for t in iter(iterable):
            if not isinstance(t, torch.Tensor):
                return False
    except TypeError:
        return False
    return True
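

# For example (a sketch of the intended behavior): a non-empty container of
# tensors returns True, while a bare tensor, an empty container, or a mixed
# container returns False:
#
#     is_iterable_of_tensors([torch.rand(2), torch.rand(2)])  # True
#     is_iterable_of_tensors(torch.rand(2))                   # False
#     is_iterable_of_tensors([torch.rand(2), 1.0])            # False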


def clone_inputs(args):
    inputs = []

    for arg in args:
        if isinstance(arg, torch.Tensor):
            inputs.append(arg.detach().clone())
        elif is_iterable_of_tensors(arg):
            inputs.append([t.detach().clone() for t in arg])
        else:
            inputs.append(arg)

    return inputs
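

# Why clone? SchemaCheckMode compares each argument before and after the op
# runs, so it needs an independent snapshot. detach().clone() copies the values
# into fresh storage (and drops autograd history), e.g.:
#
#     x = torch.rand(3)
#     saved = clone_inputs([x])[0]
#     x.add_(1)
#     torch.equal(saved, x)  # False: the snapshot kept the old values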


class SchemaCheckMode(TorchDispatchMode):
    def __init__(self) -> None:
        # Information recorded for testing purposes. For example:
        # - incorrect schemas
        # - overly conservative schemas
        self.ops = []
        self.mutated = []
        self.aliasing = []

    def reset_cache(self):
        self.ops.clear()
        self.mutated.clear()
        self.aliasing.clear()

    def display_ops(self):
        print(*self.ops, sep=",")

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        def bitwise_equal(lhs, rhs):
            if lhs.is_quantized:
                # TODO: This is only OK if quantized tensors can't hold NaN;
                # it is unclear whether that is actually true
                return torch.equal(lhs, rhs)
            else:
                return torch.allclose(lhs, rhs, equal_nan=True)

        def has_mutated(before, after, md):
            are_tensors = type(before) == torch.Tensor and type(after) == torch.Tensor
            if (
                are_tensors
                and before.layout != torch.sparse_csr
                and after.layout != torch.sparse_csr
            ):
                # An argument counts as mutated if its size, values, strides,
                # or underlying storage changed across the call
                return not (
                    before.size() == after.size()
                    and bitwise_equal(before, after)
                    and md[0] == after.stride()
                    and md[1] == after._typed_storage()._cdata
                )
            return False

        def has_aliased(lhs, rhs):
            try:
                return torch._C._overlaps(lhs, rhs)
            except Exception as exception:
                if str(exception).startswith("Cannot inspect value of type "):
                    return False
                else:
                    raise exception

        def standardize_name(name):
            # The schema names the first argument "self", but normalize_function
            # exposes it as "input"
            return name if name != "self" else "input"

        def unwrap(e):
            if isinstance(e, torch.Tensor) and not type(e) == torch.Tensor:
                try:
                    return e.elem
                except AttributeError:
                    return e
            return e

        def parse_metadata(e):
            if isinstance(e, torch.Tensor):
                if not type(e) == torch.Tensor:
                    try:
                        current = e.elem
                        return (
                            deepcopy(current.stride()),
                            current._typed_storage()._cdata,
                        )
                    except AttributeError:
                        return None
                # Sparse CSR tensors do not have strides or storage
                elif e.layout != torch.sparse_csr:
                    return (deepcopy(e.stride()), e._typed_storage()._cdata)
            return None

        self.ops.append(func._schema.name)

        # Clone and process arguments and outputs
        pre_arguments = normalize_function(
            func, args, kwargs, normalize_to_only_use_kwargs=True
        ).kwargs
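        # For example, for torch.add(x, y) the normalized kwargs look roughly
        # like {"input": x, "other": y, "alpha": 1} (a sketch; the exact keys
        # come from the op's schema)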

        c_p_args = dict(zip(pre_arguments.keys(), clone_inputs(pre_arguments.values())))
        cloned_arguments = {
            name: tree_map(unwrap, c_p_args.get(name)) for name in c_p_args
        }
        cloned_metadata = {
            name: [
                parse_metadata(a) for a in pytree.tree_leaves(pre_arguments.get(name))
            ]
            for name in pre_arguments
        }

        out = func(*args, **kwargs)
        arguments = {
            name: tree_map(unwrap, pre_arguments.get(name)) for name in pre_arguments
        }
        tuple_out = out if isinstance(out, tuple) else (out,)
        tuple_out = tree_map(unwrap, tuple_out)

        schema_info = SchemaInfo(func._schema)
        schema_info.add_argument_values(pre_arguments)
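        # Presumably this lets SchemaInfo answer the may_contain_alias /
        # is_mutable queries below with the actual runtime values in hand
        # (an assumption about the C++ side; see torch._C._SchemaInfo)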

        # Process arguments with outputs
        for i in range(len(func._schema.arguments)):
            arg = func._schema.arguments[i]
            name = standardize_name(arg.name)
            if arguments.get(name) is not None:
                before = cloned_arguments.get(name)
                md = cloned_metadata.get(name)
                after = arguments.get(name)
                for j in range(len(tuple_out)):
                    # aten::_unsafe_view is intended to have incorrect aliasing notation (hence unsafe)
                    unsafe_ops = ("aten::_unsafe_view", "aten::unsafe_split")
                    if (
                        has_aliased(tuple_out[j], after)
                        and func._schema.name not in unsafe_ops
                    ):
                        if not schema_info.may_contain_alias(
                            SchemaArgument(SchemaArgType.output, j),
                            SchemaArgument(SchemaArgType.input, i),
                        ):
                            raise RuntimeError(
                                f"Argument {name} is not defined to alias output but was aliasing"
                            )
                        else:
                            self.aliasing.append(
                                Aliasing(func._schema.name, name, f"output_{j}")
                            )
                    if after is tuple_out[j] and isinstance(after, torch.Tensor):
                        # Only mutable ops e.g. (add_, add.out) are allowed to directly return inputs.
                        if not schema_info.is_mutable(
                            SchemaArgument(SchemaArgType.input, i)
                        ) and func not in [
                            torch.ops.aten.lift.default,
                            torch.ops.aten.lift_fresh.default,
                        ]:
                            raise RuntimeError(
                                f"""\
Dispatcher operators below autograd are not allowed to directly return inputs.
However, we found that `outputs[{j}] is {name}`"""
                            )
                if any(
                    has_mutated(a, b, c)
                    for a, b, c in zip(
                        pytree.tree_leaves(before), pytree.tree_leaves(after), md
                    )
                ):
                    if not schema_info.is_mutable(
                        SchemaArgument(SchemaArgType.input, i)
                    ):
                        raise RuntimeError(
                            f"Argument {name} is not defined as mutable but was mutated"
                        )
                    else:
                        self.mutated.append(Mutation(func._schema.name, name))

        # Aliasing between outputs
        for i, j in combinations(range(len(func._schema.returns)), 2):
            if has_aliased(tuple_out[i], tuple_out[j]):
                if not schema_info.may_contain_alias(
                    SchemaArgument(SchemaArgType.output, i),
                    SchemaArgument(SchemaArgType.output, j),
                ):
                    raise RuntimeError(f"Outputs {i} and {j} alias unexpectedly")

        return out
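

# A sketch of how a violation would surface (hypothetical op name): if an op
# whose schema omits a mutation annotation writes to its input, dispatching it
# under the mode raises, e.g.:
#
#     with SchemaCheckMode():
#         some_incorrectly_annotated_op(x)
#     # RuntimeError: Argument input is not defined as mutable but was mutated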