mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Build a __torch_dispatch__ class that records torch operator names
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78835 Approved by: https://github.com/Gamrix
This commit is contained in:
committed by
PyTorch MergeBot
parent
ee933c3346
commit
1f53d036d2
38
test/jit/test_schema_check.py
Normal file
38
test/jit/test_schema_check.py
Normal file
@ -0,0 +1,38 @@
|
||||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
|
||||
from torch.testing._internal import schema_check_tensor
|
||||
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
|
||||
# Tests various schema checking functionalities.
|
||||
class TestSchemaCheck(JitTestCase):
|
||||
def setUp(self):
|
||||
schema_check_tensor.reset_cache()
|
||||
|
||||
# Tests that SchemaCheckTensor records operator order with grad
|
||||
def test_schema_check_tensor_operator_order_grad(self):
|
||||
x = torch.rand((3, 3), requires_grad=True)
|
||||
schema_check_tensor.SchemaCheckTensor(x).relu().sin()
|
||||
self.assertEqual(["relu.default", "detach.default", "sin.default"], schema_check_tensor.schema_check_recorded_ops)
|
||||
|
||||
# Tests that SchemaCheckTensor records operator order without grad
|
||||
def test_schema_check_tensor_operator_order_no_grad(self):
|
||||
x = torch.rand((3, 3), requires_grad=False)
|
||||
schema_check_tensor.SchemaCheckTensor(x).relu().sin()
|
||||
self.assertEqual(["relu.default", "sin.default"], schema_check_tensor.schema_check_recorded_ops)
|
||||
|
||||
# Tests that SchemaCheckTensor wraps torch.Tensor
|
||||
def test_schema_check_tensor_functionality(self):
|
||||
x = torch.rand((3, 3), requires_grad=True)
|
||||
self.assertEqual(x.relu().sin(), schema_check_tensor.SchemaCheckTensor(x).relu().sin().elem)
|
@ -56,6 +56,7 @@ from jit.test_typing import TestTyping # noqa: F401
|
||||
from jit.test_hash import TestHash # noqa: F401
|
||||
from jit.test_complex import TestComplex # noqa: F401
|
||||
from jit.test_jit_utils import TestJitUtils # noqa: F401
|
||||
from jit.test_schema_check import TestSchemaCheck # noqa: F401
|
||||
from jit.test_scriptmod_ann import TestScriptModuleInstanceAttributeTypeAnnotation # noqa: F401
|
||||
from jit.test_types import TestTypesAndAnnotation # noqa: F401
|
||||
from jit.test_misc import TestMisc # noqa: F401
|
||||
|
56
torch/testing/_internal/schema_check_tensor.py
Normal file
56
torch/testing/_internal/schema_check_tensor.py
Normal file
@ -0,0 +1,56 @@
|
||||
import torch
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
schema_check_recorded_ops = []
|
||||
|
||||
# This Tensor Subclass is used to verify op schemas
|
||||
# This Tensor currently:
|
||||
# - Records the called ops and appends to schema_check_records_ops
|
||||
|
||||
class SchemaCheckTensor(torch.Tensor):
|
||||
elem: torch.Tensor
|
||||
|
||||
__slots__ = ['elem']
|
||||
|
||||
__torch_function__ = torch._C._disabled_torch_function_impl
|
||||
|
||||
@staticmethod
|
||||
def __new__(cls, elem):
|
||||
# The wrapping tensor (SchemaCheckTensor) shouldn't hold any
|
||||
# memory for the class in question, but it should still
|
||||
# advertise the same device as before
|
||||
r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
|
||||
cls, elem.size(),
|
||||
strides=elem.stride(), storage_offset=elem.storage_offset(),
|
||||
# TODO: clone storage aliasing
|
||||
dtype=elem.dtype, layout=elem.layout,
|
||||
device=elem.device, requires_grad=elem.requires_grad
|
||||
)
|
||||
# ...the real tensor is held as an element on the tensor.
|
||||
r.elem = elem
|
||||
return r
|
||||
|
||||
def __repr__(self):
|
||||
if self.grad_fn:
|
||||
return f"SchemaCheckTensor({self.elem}, grad_fn={self.grad_fn})"
|
||||
return f"SchemaCheckTensor({self.elem})"
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
def unwrap(e):
|
||||
return e.elem if isinstance(e, cls) else e
|
||||
|
||||
def wrap(e):
|
||||
return cls(e) if isinstance(e, torch.Tensor) else e
|
||||
|
||||
global schema_check_recorded_ops
|
||||
schema_check_recorded_ops.append(func.__name__)
|
||||
out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
|
||||
return tree_map(wrap, out)
|
||||
|
||||
def reset_cache():
|
||||
global schema_check_recorded_ops
|
||||
schema_check_recorded_ops.clear()
|
||||
|
||||
def display_ops():
|
||||
print(*schema_check_recorded_ops, sep=",")
|
Reference in New Issue
Block a user