[opaque_obj] Add make_fx tracing support (#163278)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163278
Approved by: https://github.com/zou3519
ghstack dependencies: #163279, #163277
angelayi
2025-10-07 15:36:01 -07:00
committed by PyTorch MergeBot
parent 2bb4e6876c
commit 322091d8d8
3 changed files with 224 additions and 46 deletions
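At a high level, this change lets an opaque object flow through make_fx as an ordinary graph input. A minimal sketch of the new behavior, assuming the test-only _TestOpaqueObject ops and the OpaqueQueue fixture defined in the test file below have been registered:

import torch
from torch._library.opaque_object import make_opaque
from torch.fx.experimental.proxy_tensor import make_fx

def f(obj, x):
    # Opaque objects are passed straight through to the custom ops;
    # make_fx records the calls without ever inspecting the payload.
    torch.ops._TestOpaqueObject.queue_push(obj, x.sin())
    return torch.ops._TestOpaqueObject.queue_pop(obj)

queue = OpaqueQueue([], torch.zeros(1))  # test fixture, defined below
gm = make_fx(f, tracing_mode="fake")(make_opaque(queue), torch.ones(2, 3))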


@@ -3,12 +3,19 @@ import copy
import torch
from torch._dynamo.test_case import run_tests, TestCase
from torch._library.fake_class_registry import maybe_to_fake_obj
from torch._library.opaque_object import (
get_payload,
make_opaque,
OpaqueType,
set_payload,
)
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
)
class OpaqueQueue:
@@ -17,15 +24,23 @@ class OpaqueQueue:
self.queue = queue
self.init_tensor_ = init_tensor_
# For testing purposes
self._push_counter = 0
self._pop_counter = 0
self._size_counter = 0
def push(self, tensor: torch.Tensor) -> None:
self._push_counter += 1
self.queue.append(tensor)
def pop(self) -> torch.Tensor:
self._pop_counter += 1
if len(self.queue) > 0:
return self.queue.pop(0)
return self.init_tensor_
def size(self) -> int:
self._size_counter += 1
return len(self.queue)
def __eq__(self, other):
@@ -56,6 +71,10 @@ class TestOpaqueObject(TestCase):
assert isinstance(queue, OpaqueQueue)
queue.push(b)
@torch.library.register_fake("_TestOpaqueObject::queue_push", lib=self.lib)
def push_impl_fake(q: torch._C.ScriptObject, b: torch.Tensor) -> None:
pass
self.lib.define(
"queue_pop(__torch__.torch.classes.aten.OpaqueObject a) -> Tensor",
)
@@ -67,6 +86,15 @@ class TestOpaqueObject(TestCase):
self.lib.impl("queue_pop", pop_impl, "CompositeExplicitAutograd")
def pop_impl_fake(q: torch._C.ScriptObject) -> torch.Tensor:
# This is not accurate since the queue could have tensors that are
# not rank 1
ctx = torch._custom_op.impl.get_ctx()
u0 = ctx.create_unbacked_symint()
return torch.empty(u0)
self.lib._register_fake("queue_pop", pop_impl_fake)
@torch.library.custom_op(
"_TestOpaqueObject::queue_size",
mutates_args=[],
@@ -76,6 +104,13 @@ class TestOpaqueObject(TestCase):
assert isinstance(queue, OpaqueQueue)
return queue.size()
@size_impl.register_fake
def size_impl_fake(q: torch._C.ScriptObject) -> int:
ctx = torch._custom_op.impl.get_ctx()
u0 = ctx.create_unbacked_symint()
torch._check_is_size(u0)
return u0
super().setUp()
def tearDown(self):
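Note on the fake kernels above: queue_pop and queue_size are data-dependent, so their fake impls cannot return concrete values; they allocate fresh unbacked SymInts instead. The core pattern, shown in isolation with the same calls the test uses:

import torch

def data_dependent_size_fake(q) -> int:
    # Ask the current fake-kernel context for a brand-new symbolic int
    # that carries no concrete value.
    ctx = torch._custom_op.impl.get_ctx()
    u0 = ctx.create_unbacked_symint()
    # Mark u0 as a valid size (>= 0) so it can participate in later
    # shape computations and arithmetic without guarding on its value.
    torch._check_is_size(u0)
    return u0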
@@ -130,6 +165,105 @@ class TestOpaqueObject(TestCase):
self.assertTrue(q1 is not q2)
self.assertTrue(q1 == q2)
def test_bad_fake(self):
torch.library.define(
"_TestOpaqueObject::bad_fake",
"(__torch__.torch.classes.aten.OpaqueObject q, Tensor x) -> Tensor",
lib=self.lib,
)
def f(q, x):
torch.ops._TestOpaqueObject.bad_fake(q, x)
return x.cos()
def bad_fake1(q: torch._C.ScriptObject, b: torch.Tensor) -> torch.Tensor:
payload = get_payload(q)
return b * payload
torch.library.register_fake(
"_TestOpaqueObject::bad_fake", bad_fake1, lib=self.lib
)
with FakeTensorMode() as fake_mode:
obj = make_opaque(1)
fake_obj = maybe_to_fake_obj(fake_mode, obj)
x = torch.ones(3)
with self.assertRaisesRegex(
ValueError,
"get_payload: this function was called with a FakeScriptObject",
):
torch.ops._TestOpaqueObject.bad_fake(fake_obj, x)
def bad_fake2(q: torch._C.ScriptObject, b: torch.Tensor) -> torch.Tensor:
set_payload(q, 2)
return torch.empty_like(b)
torch.library.register_fake(
"_TestOpaqueObject::bad_fake", bad_fake2, lib=self.lib, allow_override=True
)
with FakeTensorMode() as fake_mode:
obj = make_opaque(1)
fake_obj = maybe_to_fake_obj(fake_mode, obj)
x = torch.ones(3)
with self.assertRaisesRegex(
ValueError,
"set_payload: this function was called with a FakeScriptObject",
):
torch.ops._TestOpaqueObject.bad_fake(fake_obj, x)
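For contrast with bad_fake1/bad_fake2 above, a fake kernel that keeps the object truly opaque never touches the payload; its output is described purely from the tensor arguments. A minimal sketch:

def good_fake(q: torch._C.ScriptObject, b: torch.Tensor) -> torch.Tensor:
    # No get_payload/set_payload here: the output shape and dtype come
    # only from the tensor input, so the OpaqueObject stays opaque.
    return torch.empty_like(b)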
@parametrize("make_fx_tracing_mode", ["fake", "symbolic"])
def test_make_fx(self, make_fx_tracing_mode):
class M(torch.nn.Module):
def forward(self, queue, x):
torch.ops._TestOpaqueObject.queue_push(queue, x.tan())
torch.ops._TestOpaqueObject.queue_push(queue, x.cos())
torch.ops._TestOpaqueObject.queue_push(queue, x.sin())
pop1 = torch.ops._TestOpaqueObject.queue_pop(queue)
size1 = torch.ops._TestOpaqueObject.queue_size(queue)
pop2 = torch.ops._TestOpaqueObject.queue_pop(queue)
size2 = torch.ops._TestOpaqueObject.queue_size(queue)
x_cos = pop1 + size1
x_sin = pop2 - size2
return x_sin + x_cos
q1 = OpaqueQueue([], torch.empty(0).fill_(-1))
obj1 = make_opaque(q1)
q2 = OpaqueQueue([], torch.empty(0).fill_(-1))
obj2 = make_opaque(q2)
x = torch.ones(2, 3)
gm = make_fx(M(), tracing_mode=make_fx_tracing_mode)(obj1, x)
self.assertTrue(torch.allclose(gm(obj1, x), M()(obj2, x)))
self.assertEqual(q1._push_counter, 3)
self.assertEqual(q1._pop_counter, 2)
self.assertEqual(q1._size_counter, 2)
self.assertEqual(q1.size(), 1)
self.assertExpectedInline(
gm.code.strip("\n"),
"""\
def forward(self, arg0_1, arg1_1):
tan = torch.ops.aten.tan.default(arg1_1)
queue_push = torch.ops._TestOpaqueObject.queue_push.default(arg0_1, tan); tan = queue_push = None
cos = torch.ops.aten.cos.default(arg1_1)
queue_push_1 = torch.ops._TestOpaqueObject.queue_push.default(arg0_1, cos); cos = queue_push_1 = None
sin = torch.ops.aten.sin.default(arg1_1); arg1_1 = None
queue_push_2 = torch.ops._TestOpaqueObject.queue_push.default(arg0_1, sin); sin = queue_push_2 = None
queue_pop = torch.ops._TestOpaqueObject.queue_pop.default(arg0_1)
queue_size = torch.ops._TestOpaqueObject.queue_size.default(arg0_1)
queue_pop_1 = torch.ops._TestOpaqueObject.queue_pop.default(arg0_1)
queue_size_1 = torch.ops._TestOpaqueObject.queue_size.default(arg0_1); arg0_1 = None
add = torch.ops.aten.add.Tensor(queue_pop, queue_size); queue_pop = queue_size = None
sub = torch.ops.aten.sub.Tensor(queue_pop_1, queue_size_1); queue_pop_1 = queue_size_1 = None
add_1 = torch.ops.aten.add.Tensor(sub, add); sub = add = None
return add_1
""",
)
instantiate_parametrized_tests(TestOpaqueObject)
if __name__ == "__main__":
run_tests()
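A note on the expected graph in test_make_fx: the opaque object threads through as the plain placeholder arg0_1, and the side-effecting queue_push calls are kept in program order even though their results are dropped. Re-running the traced module replays those effects on a real queue, which is what the counter assertions check; roughly:

q = OpaqueQueue([], torch.empty(0).fill_(-1))
out = gm(make_opaque(q), torch.ones(2, 3))  # pushes/pops hit the real queue
assert q.size() == 1  # three pushes, two pops, matching the test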


@@ -20,13 +20,14 @@ class FakeScriptObject:
try:
with _disable_current_modes():
self.real_obj = copy.deepcopy(x)
except RuntimeError:
except RuntimeError as e:
log.warning(
"Unable to deepcopy the custom object %s. "
"Unable to deepcopy the custom object %s due to %s. "
"Defaulting to the user given object. This might be "
"dangerous as side effects may be directly applied "
"to the object.",
script_class_name,
str(e),
)
self.real_obj = x
@@ -134,6 +135,14 @@ def maybe_to_fake_obj(
if tracing_with_real(x):
return x
from torch._library.opaque_object import FakeOpaqueObject, OpaqueTypeStr
if str(x._type()) == OpaqueTypeStr:
# In order to make OpaqueObjects truly opaque, the fake kernel should
# not depend on the contents of the OpaqueObject at all.
fake_x = FakeOpaqueObject()
else:
# x.__obj_flatten__() could be calling some tensor operations inside but we don't
# want to call these ops in surrounding dispatch modes when executing it.
# Otherwise, for example, the fake tensor modes will error out when the tensors inside
@@ -154,7 +163,8 @@ def maybe_to_fake_obj(
alias_map = {
i: storage_map[_tensor_storage(inp)]
for i, inp in enumerate(flat_x)
if isinstance(inp, torch.Tensor) and storage_map[_tensor_storage(inp)] != i
if isinstance(inp, torch.Tensor)
and storage_map[_tensor_storage(inp)] != i
}
if len(alias_map) > 0:
log.warning(
@@ -205,7 +215,7 @@ def maybe_to_fake_obj(
FakeScriptMethod(fake_x_wrapped, name, method_schema),
)
else:
override_skip_list = {"__obj_flatten__", "__get_state__", "__set_state__"}
override_skip_list = {"__obj_flatten__", "__getstate__", "__setstate__"}
if name not in override_skip_list:
log.warning("fake object of %s doesn't implement method %s.", x, name)
return fake_x_wrapped
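The net effect of this hunk on opaque objects, as a usage sketch (it assumes FakeScriptObject exposes the fake payload via its wrapped_obj attribute):

import torch
from torch._library.fake_class_registry import FakeScriptObject, maybe_to_fake_obj
from torch._library.opaque_object import FakeOpaqueObject, make_opaque
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode() as fake_mode:
    obj = make_opaque({"any": "payload"})  # payload is never inspected
    fake_obj = maybe_to_fake_obj(fake_mode, obj)
    assert isinstance(fake_obj, FakeScriptObject)
    # By design the fake carries no payload at all.
    assert isinstance(fake_obj.wrapped_obj, FakeOpaqueObject)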


@@ -2,6 +2,21 @@ from typing import Any, NewType
import torch
from .fake_class_registry import FakeScriptObject, register_fake_class
@register_fake_class("aten::OpaqueObject")
class FakeOpaqueObject:
def __init__(self) -> None:
pass
@classmethod
def __obj_unflatten__(cls, flattened_ctx: dict[str, Any]) -> None:
raise RuntimeError(
"FakeOpaqueObject should not be created through __obj_unflatten__ "
"and should be handled specially. Please file an issue on GitHub."
)
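For contrast, an ordinary TorchBind fake class rebuilds its state through __obj_unflatten__; a typical registration looks roughly like this (the _MyNs::_Foo class name is hypothetical):

from torch._library.fake_class_registry import register_fake_class

@register_fake_class("_MyNs::_Foo")  # hypothetical TorchBind class
class FakeFoo:
    def __init__(self, x):
        self.x = x

    @classmethod
    def __obj_unflatten__(cls, flattened_ctx):
        # Normal fake classes reconstruct themselves from the real
        # object's flattened state; FakeOpaqueObject refuses to, since
        # an opaque payload must never leak into tracing.
        return cls(**dict(flattened_ctx))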
OpaqueTypeStr = "__torch__.torch.classes.aten.OpaqueObject"
@@ -80,6 +95,15 @@ def get_payload(opaque_object: torch._C.ScriptObject) -> Any:
payload (Any): The Python object stored in the opaque object. This can
be set with `set_payload()`.
"""
if isinstance(opaque_object, FakeScriptObject):
raise ValueError(
"get_payload: this function was called with a FakeScriptObject, "
"implying that you are calling get_payload inside of a fake kernel. "
"The fake kernel should not depend on the contents of the "
"OpaqueObject at all, so we're erroring out. If you need this "
"functionality, consider creating a custom TorchBind Object instead "
"(but note that this is more difficult)."
)
if not (
isinstance(opaque_object, torch._C.ScriptObject)
and opaque_object._type().qualified_name() == OpaqueTypeStr
@@ -103,6 +127,16 @@ def set_payload(opaque_object: torch._C.ScriptObject, payload: Any) -> None:
torch._C.ScriptObject: The opaque object that stores the given Python object.
payload (Any): The Python object to store in the opaque object.
"""
if isinstance(opaque_object, FakeScriptObject):
raise ValueError(
"set_payload: this function was called with a FakeScriptObject, "
"implying that you are calling set_payload inside of a fake kernel. "
"The fake kernel should not depend on the contents of the "
"OpaqueObject at all, so we're erroring out. If you need this "
"functionality, consider creating a custom TorchBind Object instead "
"(but note that this is more difficult)."
)
if not (
isinstance(opaque_object, torch._C.ScriptObject)
and opaque_object._type().qualified_name() == OpaqueTypeStr
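For reference, the intended (non-fake) use of the payload accessors guarded above, as a short sketch:

from torch._library.opaque_object import get_payload, make_opaque, set_payload

obj = make_opaque({"hits": 0})  # wrap an arbitrary Python object
payload = get_payload(obj)      # legal on real opaque objects...
payload["hits"] += 1
set_payload(obj, {"hits": 2})   # ...but never inside a fake kernel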