From 1f2d00e537ac25d4038a5541101daca3da04115f Mon Sep 17 00:00:00 2001
From: Brian Hirsh
Date: Thu, 27 Apr 2023 14:50:59 +0000
Subject: [PATCH] move SchemaCheckMode to torch/_subclasses (#99743)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99743
Approved by: https://github.com/albanD
---
 .lintrunner.toml          |  1 +
 test/test_schema_check.py |  2 +-
 .../schema_check_mode.py  | 93 ++++++++++++-------
 3 files changed, 64 insertions(+), 32 deletions(-)
 rename torch/{testing/_internal => _subclasses}/schema_check_mode.py (60%)

diff --git a/.lintrunner.toml b/.lintrunner.toml
index af86a3b68666..a3f64339c213 100644
--- a/.lintrunner.toml
+++ b/.lintrunner.toml
@@ -116,6 +116,7 @@ exclude_patterns = [
     'torch/_functorch/partitioners.py',
     'torch/_functorch/top_operators_github_usage.py',
     'torch/_functorch/vmap.py',
+    'torch/_subclasses/schema_check_mode.py',
     'torch/distributed/elastic/agent/server/api.py',
     'torch/testing/_internal/**',
     'torch/distributed/fsdp/fully_sharded_data_parallel.py',
diff --git a/test/test_schema_check.py b/test/test_schema_check.py
index 7191c6b58a5f..c6c50a387368 100644
--- a/test/test_schema_check.py
+++ b/test/test_schema_check.py
@@ -8,7 +8,7 @@
 import unittest
 from torch.testing._internal.common_utils import run_tests
 from torch.fx.operator_schemas import normalize_function
-from torch.testing._internal.schema_check_mode import SchemaCheckMode
+from torch._subclasses.schema_check_mode import SchemaCheckMode
 from torch.utils._python_dispatch import TorchDispatchMode
 from torch.testing._internal.common_methods_invocations import op_db
 from torch.testing._internal.jit_utils import JitTestCase
diff --git a/torch/testing/_internal/schema_check_mode.py b/torch/_subclasses/schema_check_mode.py
similarity index 60%
rename from torch/testing/_internal/schema_check_mode.py
rename to torch/_subclasses/schema_check_mode.py
index e67afff1443d..1535a43e1763 100644
--- a/torch/testing/_internal/schema_check_mode.py
+++ b/torch/_subclasses/schema_check_mode.py
@@ -1,15 +1,16 @@
+from collections import namedtuple
+from copy import deepcopy
+from itertools import combinations
+
 import torch
-from torch.utils._pytree import tree_flatten, tree_map
 from torch.fx.operator_schemas import normalize_function
 from torch.testing._internal.jit_utils import clone_inputs
 from torch.utils._python_dispatch import TorchDispatchMode
-from itertools import combinations
-from collections import namedtuple
-from copy import deepcopy
+from torch.utils._pytree import tree_flatten, tree_map
 
 # Named Tuples used within SchemaCheckMode
-Mutation = namedtuple('Mutation', ['op_name', 'arg_name'])
-Aliasing = namedtuple('Aliasing', ['op_name', 'arg_name', 'output_number'])
+Mutation = namedtuple("Mutation", ["op_name", "arg_name"])
+Aliasing = namedtuple("Aliasing", ["op_name", "arg_name", "output_number"])
 
 # Simplified naming for C++ classes
 SchemaArgument = torch._C._SchemaArgument
@@ -22,6 +23,7 @@ SchemaInfo = torch._C._SchemaInfo
 #  - Checks for mutations on all inputs
 #  - Checks for aliasing on all inputs
 
+
 class SchemaCheckMode(TorchDispatchMode):
     def __init__(self):
         # Information recorded for testing purposes. For example:
@@ -42,12 +44,16 @@ class SchemaCheckMode(TorchDispatchMode):
     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         def has_mutated(before, after, md):
             are_tensors = type(before) == torch.Tensor and type(after) == torch.Tensor
-            if are_tensors and before.layout != torch.sparse_csr and after.layout != torch.sparse_csr:
+            if (
+                are_tensors
+                and before.layout != torch.sparse_csr
+                and after.layout != torch.sparse_csr
+            ):
                 return not (
-                    before.size() == after.size() and
-                    torch.allclose(before, after, equal_nan=True) and
-                    md[0] == after.stride() and
-                    md[1] == after._typed_storage()._cdata
+                    before.size() == after.size()
+                    and torch.allclose(before, after, equal_nan=True)
+                    and md[0] == after.stride()
+                    and md[1] == after._typed_storage()._cdata
                 )
             return False
 
@@ -76,11 +82,14 @@ class SchemaCheckMode(TorchDispatchMode):
                 if not type(e) == torch.Tensor:
                     try:
                         current = e.elem
-                        return (deepcopy(current.stride()), current._typed_storage()._cdata)
+                        return (
+                            deepcopy(current.stride()),
+                            current._typed_storage()._cdata,
+                        )
                     except AttributeError as t:
                         return None
                 # Sparse CSR tensors do not have strides or storage
-                elif (e.layout != torch.sparse_csr):
+                elif e.layout != torch.sparse_csr:
                     return (deepcopy(e.stride()), e._typed_storage()._cdata)
             return None
 
@@ -88,19 +97,23 @@ class SchemaCheckMode(TorchDispatchMode):
 
         # Clone and process arguments and outputs
         pre_arguments = normalize_function(
-            func,
-            args,
-            kwargs,
-            normalize_to_only_use_kwargs=True
+            func, args, kwargs, normalize_to_only_use_kwargs=True
         ).kwargs
 
         c_p_args = dict(zip(pre_arguments.keys(), clone_inputs(pre_arguments.values())))
-        cloned_arguments = {name : tree_map(unwrap, c_p_args.get(name)) for name in c_p_args}
-        cloned_metadata = {name : tree_map(parse_metadata, tree_flatten(pre_arguments.get(name))[0]) for name in pre_arguments}
+        cloned_arguments = {
+            name: tree_map(unwrap, c_p_args.get(name)) for name in c_p_args
+        }
+        cloned_metadata = {
+            name: tree_map(parse_metadata, tree_flatten(pre_arguments.get(name))[0])
+            for name in pre_arguments
+        }
 
         out = func(*args, **kwargs)
-        arguments = {name : tree_map(unwrap, pre_arguments.get(name)) for name in pre_arguments}
-        tuple_out = out if isinstance(out, tuple) else (out, )
+        arguments = {
+            name: tree_map(unwrap, pre_arguments.get(name)) for name in pre_arguments
+        }
+        tuple_out = out if isinstance(out, tuple) else (out,)
         tuple_out = tree_map(unwrap, tuple_out)
 
         schema_info = SchemaInfo(func._schema)
@@ -116,17 +129,34 @@ class SchemaCheckMode(TorchDispatchMode):
                 after = arguments.get(name)
                 for j in range(len(tuple_out)):
                     # aten::_unsafe_view is intended to have incorrect aliasing notation (hence unsafe)
-                    unsafe_ops = ('aten::_unsafe_view', 'aten::unsafe_split')
-                    if has_aliased(tuple_out[j], after) and func._schema.name not in unsafe_ops:
+                    unsafe_ops = ("aten::_unsafe_view", "aten::unsafe_split")
+                    if (
+                        has_aliased(tuple_out[j], after)
+                        and func._schema.name not in unsafe_ops
+                    ):
                         if not schema_info.may_contain_alias(
                             SchemaArgument(SchemaArgType.output, j),
-                            SchemaArgument(SchemaArgType.input, i)):
-                            raise RuntimeError(f'Argument {name} is not defined to alias output but was aliasing')
+                            SchemaArgument(SchemaArgType.input, i),
+                        ):
+                            raise RuntimeError(
+                                f"Argument {name} is not defined to alias output but was aliasing"
+                            )
                         else:
-                            self.aliasing.append(Aliasing(func._schema.name, name, f"output_{j}"))
-                if any(has_mutated(a, b, c) for a, b, c in zip(tree_flatten(before)[0], tree_flatten(after)[0], md)):
-                    if not schema_info.is_mutable(SchemaArgument(SchemaArgType.input, i)):
-                        raise RuntimeError(f"Argument {name} is not defined as mutable but was mutated")
+                            self.aliasing.append(
+                                Aliasing(func._schema.name, name, f"output_{j}")
+                            )
+                if any(
+                    has_mutated(a, b, c)
+                    for a, b, c in zip(
+                        tree_flatten(before)[0], tree_flatten(after)[0], md
+                    )
+                ):
+                    if not schema_info.is_mutable(
+                        SchemaArgument(SchemaArgType.input, i)
+                    ):
+                        raise RuntimeError(
+                            f"Argument {name} is not defined as mutable but was mutated"
+                        )
                     else:
                         self.mutated.append(Mutation(func._schema.name, name))
 
@@ -135,7 +165,8 @@ class SchemaCheckMode(TorchDispatchMode):
             if has_aliased(tuple_out[i], tuple_out[j]):
                 if not schema_info.may_contain_alias(
                     SchemaArgument(SchemaArgType.output, i),
-                    SchemaArgument(SchemaArgType.output, j)):
-                    raise RuntimeError(f'Outputs {i} and {j} alias unexpectedly')
+                    SchemaArgument(SchemaArgType.output, j),
+                ):
+                    raise RuntimeError(f"Outputs {i} and {j} alias unexpectedly")
 
         return out
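
---
Usage sketch (illustrative only, not part of the patch): the mode's behavior is unchanged by this move; only the import path differs. Assuming the post-move path, a minimal exercise of SchemaCheckMode looks like the following. The tensor, the op, and the printed result are hypothetical examples chosen for demonstration, not taken from the diff:

    import torch
    from torch._subclasses.schema_check_mode import SchemaCheckMode

    x = torch.rand(3, 3)
    # TorchDispatchMode context managers return self, so the mode's
    # recorded state can be inspected after the block exits.
    with SchemaCheckMode() as mode:
        x.add_(x)  # in-place op: aten::add_'s schema declares the mutation,
                   # so it is recorded in mode.mutated rather than raising
    print(mode.mutated)  # e.g. [Mutation(op_name='aten::add_', arg_name='input')]

An op whose schema failed to declare a mutation or output alias that actually occurred would instead raise a RuntimeError, per the checks in __torch_dispatch__ shown in the diff above.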