# Owner(s): ["module: custom-operators"]

import copy

import torch
from torch._dynamo.test_case import run_tests, TestCase
from torch._library.fake_class_registry import maybe_to_fake_obj
from torch._library.opaque_object import (
    get_payload,
    make_opaque,
    OpaqueType,
    set_payload,
)
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
)
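

# OpaqueQueue is an ordinary Python class with no TorchScript binding. The
# tests below wrap instances of it in an opaque object (a torch._C.ScriptObject)
# so custom ops can carry the queue through the dispatcher and through tracing.
# The counters only exist so tests can check how often the real impls ran.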
class OpaqueQueue:
    def __init__(self, queue: list[torch.Tensor], init_tensor_: torch.Tensor) -> None:
        super().__init__()
        self.queue = queue
        self.init_tensor_ = init_tensor_

        # For testing purposes
        self._push_counter = 0
        self._pop_counter = 0
        self._size_counter = 0

    def push(self, tensor: torch.Tensor) -> None:
        self._push_counter += 1
        self.queue.append(tensor)

    def pop(self) -> torch.Tensor:
        self._pop_counter += 1
        if len(self.queue) > 0:
            return self.queue.pop(0)
        return self.init_tensor_

    def size(self) -> int:
        self._size_counter += 1
        return len(self.queue)

    def __eq__(self, other):
        if len(self.queue) != len(other.queue):
            return False
        for q1, q2 in zip(self.queue, other.queue):
            if not torch.allclose(q1, q2):
                return False
        return torch.allclose(self.init_tensor_, other.init_tensor_)
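

# setUp registers three custom ops on a "_TestOpaqueObject" library fragment,
# each through a different API: queue_push via torch.library.define, queue_pop
# via Library.define, and queue_size via torch.library.custom_op. Each op also
# gets a fake impl so it can be traced.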
class TestOpaqueObject(TestCase):
    def setUp(self):
        self.lib = torch.library.Library("_TestOpaqueObject", "FRAGMENT")  # noqa: TOR901
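
        # The schema declares the first argument as
        # "__torch__.torch.classes.aten.OpaqueObject", the script-object type
        # that every opaque object reports regardless of its payload.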
        torch.library.define(
            "_TestOpaqueObject::queue_push",
            "(__torch__.torch.classes.aten.OpaqueObject a, Tensor b) -> ()",
            tags=torch.Tag.pt2_compliant_tag,
            lib=self.lib,
        )

        @torch.library.impl(
            "_TestOpaqueObject::queue_push", "CompositeExplicitAutograd", lib=self.lib
        )
        def push_impl(q: torch._C.ScriptObject, b: torch.Tensor) -> None:
            queue = get_payload(q)
            assert isinstance(queue, OpaqueQueue)
            queue.push(b)
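
        # queue_push returns nothing and only mutates the payload, so its fake
        # impl is a no-op.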
        @torch.library.register_fake("_TestOpaqueObject::queue_push", lib=self.lib)
        def push_impl_fake(q: torch._C.ScriptObject, b: torch.Tensor) -> None:
            pass
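
        # queue_pop is registered through the Library object directly; this is
        # another route into the same op namespace as torch.library.define above.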
        self.lib.define(
            "queue_pop(__torch__.torch.classes.aten.OpaqueObject a) -> Tensor",
        )

        def pop_impl(q: torch._C.ScriptObject) -> torch.Tensor:
            queue = get_payload(q)
            assert isinstance(queue, OpaqueQueue)
            return queue.pop()

        self.lib.impl("queue_pop", pop_impl, "CompositeExplicitAutograd")

        def pop_impl_fake(q: torch._C.ScriptObject) -> torch.Tensor:
            # This is not accurate since the queue could have tensors that are
            # not rank 1
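            # create_unbacked_symint() allocates a fresh data-dependent
            # symbolic size, since the fake impl cannot know the shape of the
            # tensor the real pop() would return.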
            ctx = torch._custom_op.impl.get_ctx()
            u0 = ctx.create_unbacked_symint()
            return torch.empty(u0)

        self.lib._register_fake("queue_pop", pop_impl_fake)
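
        # queue_size uses the torch.library.custom_op decorator; the OpaqueType
        # annotation marks the argument as an opaque object in the inferred
        # schema.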
        @torch.library.custom_op(
            "_TestOpaqueObject::queue_size",
            mutates_args=[],
        )
        def size_impl(q: OpaqueType) -> int:
            queue = get_payload(q)
            assert isinstance(queue, OpaqueQueue)
            return queue.size()

        @size_impl.register_fake
        def size_impl_fake(q: torch._C.ScriptObject) -> int:
            ctx = torch._custom_op.impl.get_ctx()
            u0 = ctx.create_unbacked_symint()
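            # Constrain the unbacked symint to be a valid (non-negative) size
            # so it can be used as a shape value downstream.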
            torch._check_is_size(u0)
            return u0

        super().setUp()

    def tearDown(self):
        self.lib._destroy()

        super().tearDown()

    def test_creation(self):
        queue = OpaqueQueue([], torch.zeros(3))
        obj = make_opaque(queue)
        self.assertTrue(isinstance(obj, torch._C.ScriptObject))
        self.assertEqual(str(obj._type()), "__torch__.torch.classes.aten.OpaqueObject")

        # obj.payload stores a direct reference to this Python queue object
        payload = get_payload(obj)
        self.assertEqual(payload, queue)
        queue.push(torch.ones(3))
        self.assertEqual(payload.size(), 1)

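    # make_opaque() can also be called without a payload; set_payload attaches
    # one after construction.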
    def test_ops(self):
        queue = OpaqueQueue([], torch.zeros(3))
        obj = make_opaque()
        set_payload(obj, queue)

        torch.ops._TestOpaqueObject.queue_push(obj, torch.ones(3) + 1)
        self.assertEqual(queue.size(), 1)
        size = torch.ops._TestOpaqueObject.queue_size(obj)
        self.assertEqual(size, queue.size())
        popped = torch.ops._TestOpaqueObject.queue_pop(obj)
        self.assertEqual(popped, torch.ones(3) + 1)
        self.assertEqual(queue.size(), 0)

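    # Opaque objects compare equal when their payloads compare equal.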
    def test_eq(self):
        self.assertTrue(make_opaque("moo") == make_opaque("moo"))
        self.assertFalse(make_opaque("moo") == make_opaque("mop"))

        q1 = OpaqueQueue([torch.ones(3)], torch.zeros(3))
        q2 = OpaqueQueue([torch.ones(3)], torch.zeros(3))
        obj1 = make_opaque(q1)
        obj2 = make_opaque(q2)
        self.assertTrue(obj1 == obj1)
        self.assertTrue(q1 == q2)
        self.assertTrue(obj1 == obj2)

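    # deepcopy of an opaque object deep-copies the payload: the copy is a
    # distinct queue that compares equal to the original.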
    def test_deepcopy(self):
        q1 = OpaqueQueue([torch.ones(3), torch.ones(3) * 2], torch.zeros(3))
        obj1 = make_opaque(q1)

        obj2 = copy.deepcopy(obj1)
        q2 = get_payload(obj2)

        self.assertTrue(q1 is not q2)
        self.assertTrue(q1 == q2)

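    # Fake impls receive FakeScriptObjects, which carry no real payload, so
    # calling get_payload or set_payload from inside a fake impl raises.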
    def test_bad_fake(self):
        torch.library.define(
            "_TestOpaqueObject::bad_fake",
            "(__torch__.torch.classes.aten.OpaqueObject q, Tensor x) -> Tensor",
            lib=self.lib,
        )

        def f(q, x):
            torch.ops._TestOpaqueObject.bad_fake(q, x)
            return x.cos()

        def bad_fake1(q: torch._C.ScriptObject, b: torch.Tensor) -> torch.Tensor:
            payload = get_payload(q)
            return b * payload

        torch.library.register_fake(
            "_TestOpaqueObject::bad_fake", bad_fake1, lib=self.lib
        )

        with FakeTensorMode() as fake_mode:
            obj = make_opaque(1)
            fake_obj = maybe_to_fake_obj(fake_mode, obj)
            x = torch.ones(3)

            with self.assertRaisesRegex(
                ValueError,
                "get_payload: this function was called with a FakeScriptObject",
            ):
                torch.ops._TestOpaqueObject.bad_fake(fake_obj, x)

        def bad_fake2(q: torch._C.ScriptObject, b: torch.Tensor) -> torch.Tensor:
            set_payload(q, 2)
            return torch.empty_like(b)

        torch.library.register_fake(
            "_TestOpaqueObject::bad_fake", bad_fake2, lib=self.lib, allow_override=True
        )

        with FakeTensorMode() as fake_mode:
            obj = make_opaque(1)
            fake_obj = maybe_to_fake_obj(fake_mode, obj)
            x = torch.ones(3)

            with self.assertRaisesRegex(
                ValueError,
                "set_payload: this function was called with a FakeScriptObject",
            ):
                torch.ops._TestOpaqueObject.bad_fake(fake_obj, x)

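    # Tracing runs the fake impls, so it leaves the real queue untouched; the
    # counters below reflect only the single real execution, gm(obj1, x).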
    @parametrize("make_fx_tracing_mode", ["fake", "symbolic"])
    def test_make_fx(self, make_fx_tracing_mode):
        class M(torch.nn.Module):
            def forward(self, queue, x):
                torch.ops._TestOpaqueObject.queue_push(queue, x.tan())
                torch.ops._TestOpaqueObject.queue_push(queue, x.cos())
                torch.ops._TestOpaqueObject.queue_push(queue, x.sin())
                pop1 = torch.ops._TestOpaqueObject.queue_pop(queue)
                size1 = torch.ops._TestOpaqueObject.queue_size(queue)
                pop2 = torch.ops._TestOpaqueObject.queue_pop(queue)
                size2 = torch.ops._TestOpaqueObject.queue_size(queue)
                x_cos = pop1 + size1
                x_sin = pop2 - size2
                return x_sin + x_cos

        q1 = OpaqueQueue([], torch.empty(0).fill_(-1))
        obj1 = make_opaque(q1)
        q2 = OpaqueQueue([], torch.empty(0).fill_(-1))
        obj2 = make_opaque(q2)

        x = torch.ones(2, 3)
        gm = make_fx(M(), tracing_mode=make_fx_tracing_mode)(obj1, x)
        self.assertTrue(torch.allclose(gm(obj1, x), M()(obj2, x)))
        self.assertEqual(q1._push_counter, 3)
        self.assertEqual(q1._pop_counter, 2)
        self.assertEqual(q1._size_counter, 2)
        self.assertEqual(q1.size(), 1)
        self.assertExpectedInline(
            gm.code.strip("\n"),
            """\
def forward(self, arg0_1, arg1_1):
    tan = torch.ops.aten.tan.default(arg1_1)
    queue_push = torch.ops._TestOpaqueObject.queue_push.default(arg0_1, tan); tan = queue_push = None
    cos = torch.ops.aten.cos.default(arg1_1)
    queue_push_1 = torch.ops._TestOpaqueObject.queue_push.default(arg0_1, cos); cos = queue_push_1 = None
    sin = torch.ops.aten.sin.default(arg1_1); arg1_1 = None
    queue_push_2 = torch.ops._TestOpaqueObject.queue_push.default(arg0_1, sin); sin = queue_push_2 = None
    queue_pop = torch.ops._TestOpaqueObject.queue_pop.default(arg0_1)
    queue_size = torch.ops._TestOpaqueObject.queue_size.default(arg0_1)
    queue_pop_1 = torch.ops._TestOpaqueObject.queue_pop.default(arg0_1)
    queue_size_1 = torch.ops._TestOpaqueObject.queue_size.default(arg0_1); arg0_1 = None
    add = torch.ops.aten.add.Tensor(queue_pop, queue_size); queue_pop = queue_size = None
    sub = torch.ops.aten.sub.Tensor(queue_pop_1, queue_size_1); queue_pop_1 = queue_size_1 = None
    add_1 = torch.ops.aten.add.Tensor(sub, add); sub = add = None
    return add_1
""",
        )


instantiate_parametrized_tests(TestOpaqueObject)


if __name__ == "__main__":
    run_tests()