Build a __torch_dispatch__ class that records torch operator names

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78835

Approved by: https://github.com/Gamrix
Authored by: goldenxuett, 2022-06-03 16:29:10 -07:00
Committed by: PyTorch MergeBot
Parent: ee933c3346
Commit: 1f53d036d2
3 changed files with 95 additions and 0 deletions
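
The pattern the commit introduces, condensed into a self-contained sketch (class and attribute names such as _RecordingTensor and recorded are illustrative, not from the diff): a wrapper tensor subclass overrides __torch_dispatch__, so every ATen call made on it reaches the hook, where func.__name__ carries the overload name (for example "relu.default") that gets logged before the call is forwarded to the unwrapped tensor.

import torch
from torch.utils._pytree import tree_map


class _RecordingTensor(torch.Tensor):  # illustrative name, not part of the diff
    recorded = []  # class-level log of op overload names

    __torch_function__ = torch._C._disabled_torch_function_impl

    @staticmethod
    def __new__(cls, elem):
        # The wrapper holds no storage of its own; it only mirrors the metadata.
        r = torch.Tensor._make_wrapper_subclass(
            cls, elem.size(), dtype=elem.dtype, device=elem.device)
        r.elem = elem
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        cls.recorded.append(func.__name__)  # e.g. "relu.default"

        def unwrap(e):
            return e.elem if isinstance(e, cls) else e

        def wrap(e):
            return cls(e) if isinstance(e, torch.Tensor) else e

        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
        return tree_map(wrap, out)

Wrapping a tensor and calling ops on it then fills recorded with names like "relu.default", which is the behavior the real module below asserts on in its tests.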

test/jit/test_schema_check.py

@@ -0,0 +1,38 @@
# Owner(s): ["oncall: jit"]

import os
import sys

import torch
from torch.testing._internal import schema_check_tensor

pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase

if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_jit.py TESTNAME\n\n"
                       "instead.")

# Tests various schema checking functionalities.
class TestSchemaCheck(JitTestCase):
    def setUp(self):
        schema_check_tensor.reset_cache()

    # Tests that SchemaCheckTensor records operator order with grad
    def test_schema_check_tensor_operator_order_grad(self):
        x = torch.rand((3, 3), requires_grad=True)
        schema_check_tensor.SchemaCheckTensor(x).relu().sin()
        self.assertEqual(["relu.default", "detach.default", "sin.default"], schema_check_tensor.schema_check_recorded_ops)

    # Tests that SchemaCheckTensor records operator order without grad
    def test_schema_check_tensor_operator_order_no_grad(self):
        x = torch.rand((3, 3), requires_grad=False)
        schema_check_tensor.SchemaCheckTensor(x).relu().sin()
        self.assertEqual(["relu.default", "sin.default"], schema_check_tensor.schema_check_recorded_ops)

    # Tests that SchemaCheckTensor wraps torch.Tensor
    def test_schema_check_tensor_functionality(self):
        x = torch.rand((3, 3), requires_grad=True)
        self.assertEqual(x.relu().sin(), schema_check_tensor.SchemaCheckTensor(x).relu().sin().elem)

test/test_jit.py

@@ -56,6 +56,7 @@ from jit.test_typing import TestTyping # noqa: F401
from jit.test_hash import TestHash # noqa: F401
from jit.test_complex import TestComplex # noqa: F401
from jit.test_jit_utils import TestJitUtils # noqa: F401
from jit.test_schema_check import TestSchemaCheck # noqa: F401
from jit.test_scriptmod_ann import TestScriptModuleInstanceAttributeTypeAnnotation # noqa: F401
from jit.test_types import TestTypesAndAnnotation # noqa: F401
from jit.test_misc import TestMisc # noqa: F401

torch/testing/_internal/schema_check_tensor.py

@@ -0,0 +1,56 @@
import torch
from torch.utils._pytree import tree_map

schema_check_recorded_ops = []

# This Tensor Subclass is used to verify op schemas
# This Tensor currently:
#  - Records the called ops and appends to schema_check_recorded_ops
class SchemaCheckTensor(torch.Tensor):
    elem: torch.Tensor

    __slots__ = ['elem']

    __torch_function__ = torch._C._disabled_torch_function_impl

    @staticmethod
    def __new__(cls, elem):
        # The wrapping tensor (SchemaCheckTensor) shouldn't hold any
        # memory for the class in question, but it should still
        # advertise the same device as before
        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls, elem.size(),
            strides=elem.stride(), storage_offset=elem.storage_offset(),
            # TODO: clone storage aliasing
            dtype=elem.dtype, layout=elem.layout,
            device=elem.device, requires_grad=elem.requires_grad
        )
        # ...the real tensor is held as an element on the tensor.
        r.elem = elem
        return r

    def __repr__(self):
        if self.grad_fn:
            return f"SchemaCheckTensor({self.elem}, grad_fn={self.grad_fn})"
        return f"SchemaCheckTensor({self.elem})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(e):
            return e.elem if isinstance(e, cls) else e

        def wrap(e):
            return cls(e) if isinstance(e, torch.Tensor) else e

        global schema_check_recorded_ops
        schema_check_recorded_ops.append(func.__name__)
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
        return tree_map(wrap, out)


def reset_cache():
    global schema_check_recorded_ops
    schema_check_recorded_ops.clear()

def display_ops():
    print(*schema_check_recorded_ops, sep=",")
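
A minimal usage sketch (not part of the diff), exercising the module the same way the new no-grad test does; the import path follows the test file's "from torch.testing._internal import schema_check_tensor":

import torch
from torch.testing._internal import schema_check_tensor

schema_check_tensor.reset_cache()
x = torch.rand((3, 3))  # requires_grad defaults to False
y = schema_check_tensor.SchemaCheckTensor(x).relu().sin()
print(schema_check_tensor.schema_check_recorded_ops)  # ['relu.default', 'sin.default']
print(torch.equal(y.elem, x.relu().sin()))            # True: wrapper result matches the plain tensor result
schema_check_tensor.display_ops()                     # relu.default,sin.default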