# pytorch/torch/_subclasses/schema_check_mode.py
# mypy: ignore-errors
from collections import namedtuple
from copy import deepcopy
from itertools import combinations
import torch
from torch.fx.operator_schemas import normalize_function
from torch.utils import _pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map
# Named Tuples used within SchemaCheckMode
Mutation = namedtuple("Mutation", ["op_name", "arg_name"])
Aliasing = namedtuple("Aliasing", ["op_name", "arg_name", "output_number"])
# Simplified naming for C++ classes
SchemaArgument = torch._C._SchemaArgument
SchemaArgType = torch._C._SchemaArgType
SchemaInfo = torch._C._SchemaInfo
# This TorchDispatchMode subclass is used to verify op schemas.
# This TorchDispatchMode subclass currently:
# - Records the called ops
# - Checks for mutations on all inputs
# - Checks for aliasing on all inputs
# (A usage sketch is included at the bottom of this file.)
# These two functions are defined here to avoid a numpy dependency in testing/_internal/common_utils.py
def is_iterable_of_tensors(iterable):
    # Tensor itself is iterable so we check this first
    if isinstance(iterable, torch.Tensor):
        return False
    try:
        if len(iterable) == 0:
            return False
        for t in iter(iterable):
            if not isinstance(t, torch.Tensor):
                return False
    except TypeError:
        return False
    return True
def clone_inputs(args):
    inputs = []
    for arg in args:
        if isinstance(arg, torch.Tensor):
            inputs.append(arg.detach().clone())
        elif is_iterable_of_tensors(arg):
            inputs.append([t.detach().clone() for t in arg])
        else:
            inputs.append(arg)
    return inputs
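
# Illustrative note (not part of the original file): clone_inputs makes detached
# copies of tensor and tensor-list arguments so the dispatch hook below can
# compare values before and after an op runs; non-tensor arguments pass through
# unchanged, e.g. clone_inputs([torch.ones(2), 3]) yields a cloned tensor
# followed by the untouched 3.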
class SchemaCheckMode(TorchDispatchMode):
    def __init__(self) -> None:
        # Information recorded for testing purposes. For example:
        # - incorrect schemas
        # - overly conservative schemas
        self.ops = []
        self.mutated = []
        self.aliasing = []

    def reset_cache(self):
        self.ops.clear()
        self.mutated.clear()
        self.aliasing.clear()

    def display_ops(self):
        print(*self.ops, sep=",")
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        def bitwise_equal(lhs, rhs):
            if lhs.is_quantized:
                # TODO: This is only OK if we can't have NaN quantized; idk if
                # this is actually true
                return torch.equal(lhs, rhs)
            else:
                return torch.allclose(lhs, rhs, equal_nan=True)

        def has_mutated(before, after, md):
            are_tensors = type(before) == torch.Tensor and type(after) == torch.Tensor
            if (
                are_tensors
                and before.layout != torch.sparse_csr
                and after.layout != torch.sparse_csr
            ):
                return not (
                    before.size() == after.size()
                    and bitwise_equal(before, after)
                    and md[0] == after.stride()
                    and md[1] == after._typed_storage()._cdata
                )
            return False

        def has_aliased(lhs, rhs):
            try:
                return torch._C._overlaps(lhs, rhs)
            except Exception as exception:
                if str(exception).startswith("Cannot inspect value of type "):
                    return False
                else:
                    raise exception

        def standardize_name(name):
            return name if name != "self" else "input"

        def unwrap(e):
            if isinstance(e, torch.Tensor) and not type(e) == torch.Tensor:
                try:
                    return e.elem
                except AttributeError:
                    return e
            return e

        def parse_metadata(e):
            if isinstance(e, torch.Tensor):
                if not type(e) == torch.Tensor:
                    try:
                        current = e.elem
                        return (
                            deepcopy(current.stride()),
                            current._typed_storage()._cdata,
                        )
                    except AttributeError:
                        return None
                # Sparse CSR tensors do not have strides or storage
                elif e.layout != torch.sparse_csr:
                    return (deepcopy(e.stride()), e._typed_storage()._cdata)
            return None
        self.ops.append(func._schema.name)

        # Clone and process arguments and outputs
        pre_arguments = normalize_function(
            func, args, kwargs, normalize_to_only_use_kwargs=True
        ).kwargs

        c_p_args = dict(
            zip(pre_arguments.keys(), clone_inputs(pre_arguments.values()))
        )
        cloned_arguments = {
            name: tree_map(unwrap, c_p_args.get(name)) for name in c_p_args
        }
        cloned_metadata = {
            name: [
                parse_metadata(a) for a in pytree.tree_leaves(pre_arguments.get(name))
            ]
            for name in pre_arguments
        }

        out = func(*args, **kwargs)
        arguments = {
            name: tree_map(unwrap, pre_arguments.get(name)) for name in pre_arguments
        }
        tuple_out = out if isinstance(out, tuple) else (out,)
        tuple_out = tree_map(unwrap, tuple_out)

        schema_info = SchemaInfo(func._schema)
        schema_info.add_argument_values(pre_arguments)
        # Process arguments with outputs
        for i in range(len(func._schema.arguments)):
            arg = func._schema.arguments[i]
            name = standardize_name(arg.name)
            if arguments.get(name) is not None:
                before = cloned_arguments.get(name)
                md = cloned_metadata.get(name)
                after = arguments.get(name)
                for j in range(len(tuple_out)):
                    # aten::_unsafe_view is intended to have incorrect aliasing notation (hence unsafe)
                    unsafe_ops = ("aten::_unsafe_view", "aten::unsafe_split")
                    if (
                        has_aliased(tuple_out[j], after)
                        and func._schema.name not in unsafe_ops
                    ):
                        if not schema_info.may_contain_alias(
                            SchemaArgument(SchemaArgType.output, j),
                            SchemaArgument(SchemaArgType.input, i),
                        ):
                            raise RuntimeError(
                                f"Argument {name} is not defined to alias output but was aliasing"
                            )
                        else:
                            self.aliasing.append(
                                Aliasing(func._schema.name, name, f"output_{j}")
                            )
                    if after is tuple_out[j] and isinstance(after, torch.Tensor):
                        # Only mutable ops e.g. (add_, add.out) are allowed to directly return inputs.
                        if not schema_info.is_mutable(
                            SchemaArgument(SchemaArgType.input, i)
                        ) and func not in [
                            torch.ops.aten.lift.default,
                            torch.ops.aten.lift_fresh.default,
                        ]:
                            raise RuntimeError(
                                f"""\
Dispatcher operators below autograd are not allowed to directly return inputs.
However, we found that `outputs[{str(j)}] is {name}"""
                            )
                if any(
                    has_mutated(a, b, c)
                    for a, b, c in zip(
                        pytree.tree_leaves(before), pytree.tree_leaves(after), md
                    )
                ):
                    if not schema_info.is_mutable(
                        SchemaArgument(SchemaArgType.input, i)
                    ):
                        raise RuntimeError(
                            f"Argument {name} is not defined as mutable but was mutated"
                        )
                    else:
                        self.mutated.append(Mutation(func._schema.name, name))
        # Aliasing between outputs
        for i, j in combinations(range(len(func._schema.returns)), 2):
            if has_aliased(tuple_out[i], tuple_out[j]):
                if not schema_info.may_contain_alias(
                    SchemaArgument(SchemaArgType.output, i),
                    SchemaArgument(SchemaArgType.output, j),
                ):
                    raise RuntimeError(f"Outputs {i} and {j} alias unexpectedly")

        return out
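
# A minimal usage sketch (added for illustration; not part of the original
# module). It assumes plain CPU tensors and shows how the mode records ops,
# mutations, and aliasing while checking them against each op's schema.
if __name__ == "__main__":
    x = torch.rand(3, 3)
    y = torch.rand(3, 3)
    with SchemaCheckMode() as mode:
        torch.add(x, y)  # functional op: nothing mutated, nothing aliased
        x.sin_()  # in-place op: mutates (and returns) its "input" argument
    mode.display_ops()  # e.g. prints: aten::add,aten::sin_
    print(mode.mutated)  # e.g. [Mutation(op_name='aten::sin_', arg_name='input')]
    print(mode.aliasing)  # e.g. [Aliasing(op_name='aten::sin_', arg_name='input', output_number='output_0')]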