From 322091d8d8542a0cbff524306029bef4d7338747 Mon Sep 17 00:00:00 2001
From: angelayi
Date: Tue, 7 Oct 2025 15:36:01 -0700
Subject: [PATCH] [opaque_obj] Add make_fx tracing support (#163278)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163278
Approved by: https://github.com/zou3519
ghstack dependencies: #163279, #163277
---
 test/test_opaque_obj.py               | 134 ++++++++++++++++++++++++++
 torch/_library/fake_class_registry.py | 102 +++++++++++---------
 torch/_library/opaque_object.py       |  34 +++++++
 3 files changed, 224 insertions(+), 46 deletions(-)

diff --git a/test/test_opaque_obj.py b/test/test_opaque_obj.py
index 106cf2815741..f78ab4faef8f 100644
--- a/test/test_opaque_obj.py
+++ b/test/test_opaque_obj.py
@@ -3,12 +3,19 @@
 import copy
 
 import torch
 from torch._dynamo.test_case import run_tests, TestCase
+from torch._library.fake_class_registry import maybe_to_fake_obj
 from torch._library.opaque_object import (
     get_payload,
     make_opaque,
     OpaqueType,
     set_payload,
 )
+from torch._subclasses.fake_tensor import FakeTensorMode
+from torch.fx.experimental.proxy_tensor import make_fx
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    parametrize,
+)
 
 
 class OpaqueQueue:
@@ -17,15 +24,23 @@ class OpaqueQueue:
         self.queue = queue
         self.init_tensor_ = init_tensor_
 
+        # For testing purposes
+        self._push_counter = 0
+        self._pop_counter = 0
+        self._size_counter = 0
+
     def push(self, tensor: torch.Tensor) -> None:
+        self._push_counter += 1
         self.queue.append(tensor)
 
     def pop(self) -> torch.Tensor:
+        self._pop_counter += 1
         if len(self.queue) > 0:
             return self.queue.pop(0)
         return self.init_tensor_
 
     def size(self) -> int:
+        self._size_counter += 1
         return len(self.queue)
 
     def __eq__(self, other):
@@ -56,6 +71,10 @@ class TestOpaqueObject(TestCase):
             assert isinstance(queue, OpaqueQueue)
             queue.push(b)
 
+        @torch.library.register_fake("_TestOpaqueObject::queue_push", lib=self.lib)
+        def push_impl_fake(q: torch._C.ScriptObject, b: torch.Tensor) -> None:
+            pass
+
         self.lib.define(
             "queue_pop(__torch__.torch.classes.aten.OpaqueObject a) -> Tensor",
         )
@@ -67,6 +86,15 @@ class TestOpaqueObject(TestCase):
 
         self.lib.impl("queue_pop", pop_impl, "CompositeExplicitAutograd")
 
+        def pop_impl_fake(q: torch._C.ScriptObject) -> torch.Tensor:
+            # This is not accurate since the queue could have tensors that are
+            # not rank 1
+            ctx = torch._custom_op.impl.get_ctx()
+            u0 = ctx.create_unbacked_symint()
+            return torch.empty(u0)
+
+        self.lib._register_fake("queue_pop", pop_impl_fake)
+
         @torch.library.custom_op(
             "_TestOpaqueObject::queue_size",
             mutates_args=[],
@@ -76,6 +104,13 @@ class TestOpaqueObject(TestCase):
             assert isinstance(queue, OpaqueQueue)
             return queue.size()
 
+        @size_impl.register_fake
+        def size_impl_fake(q: torch._C.ScriptObject) -> int:
+            ctx = torch._custom_op.impl.get_ctx()
+            u0 = ctx.create_unbacked_symint()
+            torch._check_is_size(u0)
+            return u0
+
         super().setUp()
 
     def tearDown(self):
@@ -130,6 +165,105 @@ class TestOpaqueObject(TestCase):
         self.assertTrue(q1 is not q2)
         self.assertTrue(q1 == q2)
 
+    def test_bad_fake(self):
+        torch.library.define(
+            "_TestOpaqueObject::bad_fake",
+            "(__torch__.torch.classes.aten.OpaqueObject q, Tensor x) -> Tensor",
+            lib=self.lib,
+        )
+
+        def f(q, x):
+            torch.ops._TestOpaqueObject.bad_fake(q, x)
+            return x.cos()
+
+        def bad_fake1(q: torch._C.ScriptObject, b: torch.Tensor) -> torch.Tensor:
+            payload = get_payload(q)
+            return b * payload
+
+        torch.library.register_fake(
+            "_TestOpaqueObject::bad_fake", bad_fake1, lib=self.lib
+        )
+
+        with FakeTensorMode() as fake_mode:
+            obj = make_opaque(1)
+            fake_obj = maybe_to_fake_obj(fake_mode, obj)
+            x = torch.ones(3)
+
+            with self.assertRaisesRegex(
+                ValueError,
+                "get_payload: this function was called with a FakeScriptObject",
+            ):
+                torch.ops._TestOpaqueObject.bad_fake(fake_obj, x)
+
+        def bad_fake2(q: torch._C.ScriptObject, b: torch.Tensor) -> torch.Tensor:
+            set_payload(q, 2)
+            return torch.empty_like(b)
+
+        torch.library.register_fake(
+            "_TestOpaqueObject::bad_fake", bad_fake2, lib=self.lib, allow_override=True
+        )
+
+        with FakeTensorMode() as fake_mode:
+            obj = make_opaque(1)
+            fake_obj = maybe_to_fake_obj(fake_mode, obj)
+            x = torch.ones(3)
+
+            with self.assertRaisesRegex(
+                ValueError,
+                "set_payload: this function was called with a FakeScriptObject",
+            ):
+                torch.ops._TestOpaqueObject.bad_fake(fake_obj, x)
+
+    @parametrize("make_fx_tracing_mode", ["fake", "symbolic"])
+    def test_make_fx(self, make_fx_tracing_mode):
+        class M(torch.nn.Module):
+            def forward(self, queue, x):
+                torch.ops._TestOpaqueObject.queue_push(queue, x.tan())
+                torch.ops._TestOpaqueObject.queue_push(queue, x.cos())
+                torch.ops._TestOpaqueObject.queue_push(queue, x.sin())
+                pop1 = torch.ops._TestOpaqueObject.queue_pop(queue)
+                size1 = torch.ops._TestOpaqueObject.queue_size(queue)
+                pop2 = torch.ops._TestOpaqueObject.queue_pop(queue)
+                size2 = torch.ops._TestOpaqueObject.queue_size(queue)
+                x_cos = pop1 + size1
+                x_sin = pop2 - size2
+                return x_sin + x_cos
+
+        q1 = OpaqueQueue([], torch.empty(0).fill_(-1))
+        obj1 = make_opaque(q1)
+        q2 = OpaqueQueue([], torch.empty(0).fill_(-1))
+        obj2 = make_opaque(q2)
+
+        x = torch.ones(2, 3)
+        gm = make_fx(M(), tracing_mode=make_fx_tracing_mode)(obj1, x)
+        self.assertTrue(torch.allclose(gm(obj1, x), M()(obj2, x)))
+        self.assertEqual(q1._push_counter, 3)
+        self.assertEqual(q1._pop_counter, 2)
+        self.assertEqual(q1._size_counter, 2)
+        self.assertEqual(q1.size(), 1)
+        self.assertExpectedInline(
+            gm.code.strip("\n"),
+            """\
+def forward(self, arg0_1, arg1_1):
+    tan = torch.ops.aten.tan.default(arg1_1)
+    queue_push = torch.ops._TestOpaqueObject.queue_push.default(arg0_1, tan); tan = queue_push = None
+    cos = torch.ops.aten.cos.default(arg1_1)
+    queue_push_1 = torch.ops._TestOpaqueObject.queue_push.default(arg0_1, cos); cos = queue_push_1 = None
+    sin = torch.ops.aten.sin.default(arg1_1); arg1_1 = None
+    queue_push_2 = torch.ops._TestOpaqueObject.queue_push.default(arg0_1, sin); sin = queue_push_2 = None
+    queue_pop = torch.ops._TestOpaqueObject.queue_pop.default(arg0_1)
+    queue_size = torch.ops._TestOpaqueObject.queue_size.default(arg0_1)
+    queue_pop_1 = torch.ops._TestOpaqueObject.queue_pop.default(arg0_1)
+    queue_size_1 = torch.ops._TestOpaqueObject.queue_size.default(arg0_1); arg0_1 = None
+    add = torch.ops.aten.add.Tensor(queue_pop, queue_size); queue_pop = queue_size = None
+    sub = torch.ops.aten.sub.Tensor(queue_pop_1, queue_size_1); queue_pop_1 = queue_size_1 = None
+    add_1 = torch.ops.aten.add.Tensor(sub, add); sub = add = None
+    return add_1
+""",
+        )
+
+
+instantiate_parametrized_tests(TestOpaqueObject)
 
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/_library/fake_class_registry.py b/torch/_library/fake_class_registry.py
index 68208d0be4a8..1902eafc0a48 100644
--- a/torch/_library/fake_class_registry.py
+++ b/torch/_library/fake_class_registry.py
@@ -20,13 +20,14 @@ class FakeScriptObject:
         try:
             with _disable_current_modes():
                 self.real_obj = copy.deepcopy(x)
-        except RuntimeError:
+        except RuntimeError as e:
             log.warning(
-                "Unable to deepcopy the custom object %s. "
+                "Unable to deepcopy the custom object %s due to %s. "
                 "Defaulting to the user given object. This might be "
                 "dangerous as side effects may be directly applied "
                 "to the object.",
                 script_class_name,
+                str(e),
             )
             self.real_obj = x
 
@@ -134,55 +135,64 @@ def maybe_to_fake_obj(
     if tracing_with_real(x):
         return x
 
-    # x.__obj_flatten__() could be calling some tensor operations inside but we don't
-    # want to call these ops in surrounding dispatch modes when executing it.
-    # Otherwise, for example, the fake tensor modes will error out when the tensors inside
-    # script object execute some operations like clone if allow_non_fake_input flag is set.
-    with _disable_current_modes():
-        flat_x = x.__obj_flatten__()  # type: ignore[attr-defined]
+    from torch._library.opaque_object import FakeOpaqueObject, OpaqueTypeStr
 
-    _check_valid_flat_script_obj(flat_x)
+    if str(x._type()) == OpaqueTypeStr:
+        # In order to make OpaqueObjects truly opaque, the fake kernel should
+        # not depend on the contents of the OpaqueObject at all.
+        fake_x = FakeOpaqueObject()
 
-    with fake_mode:
-        from torch._higher_order_ops.utils import _tensor_storage
+    else:
+        # x.__obj_flatten__() could be calling some tensor operations inside but we don't
+        # want to call these ops in surrounding dispatch modes when executing it.
+        # Otherwise, for example, the fake tensor modes will error out when the tensors inside
+        # script object execute some operations like clone if allow_non_fake_input flag is set.
+        with _disable_current_modes():
+            flat_x = x.__obj_flatten__()  # type: ignore[attr-defined]
 
-        storage_map = {
-            _tensor_storage(inp): i
-            for i, inp in enumerate(flat_x)
-            if isinstance(inp, torch.Tensor)
-        }
-        alias_map = {
-            i: storage_map[_tensor_storage(inp)]
-            for i, inp in enumerate(flat_x)
-            if isinstance(inp, torch.Tensor) and storage_map[_tensor_storage(inp)] != i
-        }
-        if len(alias_map) > 0:
-            log.warning(
-                "Detected script object %s has aliasing relationship among its tensors. "
-                "Flattened obj: %s. Aliasing tensor indices: %s. "
-                "This is not supported and may cause unexpected behavior.",
-                x,
+        _check_valid_flat_script_obj(flat_x)
+
+        with fake_mode:
+            from torch._higher_order_ops.utils import _tensor_storage
+
+            storage_map = {
+                _tensor_storage(inp): i
+                for i, inp in enumerate(flat_x)
+                if isinstance(inp, torch.Tensor)
+            }
+            alias_map = {
+                i: storage_map[_tensor_storage(inp)]
+                for i, inp in enumerate(flat_x)
+                if isinstance(inp, torch.Tensor)
+                and storage_map[_tensor_storage(inp)] != i
+            }
+            if len(alias_map) > 0:
+                log.warning(
+                    "Detected script object %s has aliasing relationship among its tensors. "
+                    "Flattened obj: %s. Aliasing tensor indices: %s. "
+                    "This is not supported and may cause unexpected behavior.",
+                    x,
+                    flat_x,
+                    alias_map,
+                )
+
+            # This breaks the aliasing relationship among the tensors inside the torchbind object
+            # This is bad but since we don't need to preserve the aliasing relationship anyway and
+            # we state clearly that aliasing relationship is not preserved in the doc so this might be OK.
+            fake_flattened = pytree.tree_map_only(
+                torch.Tensor,
+                lambda t: torch.empty_strided(
+                    t.size(),
+                    t.stride(),
+                    device=t.device,
+                    dtype=t.dtype,
+                    requires_grad=t.requires_grad,
+                    layout=t.layout,
+                ),
                 flat_x,
-                alias_map,
             )
 
-        # This breaks the aliasing relationship among the tensors inside the torchbind object
-        # This is bad but since we don't need to preserve the aliasing relationship anyway and
-        # we state clearly that aliasing relationship is not preserved in the doc so this might be OK.
-        fake_flattened = pytree.tree_map_only(
-            torch.Tensor,
-            lambda t: torch.empty_strided(
-                t.size(),
-                t.stride(),
-                device=t.device,
-                dtype=t.dtype,
-                requires_grad=t.requires_grad,
-                layout=t.layout,
-            ),
-            flat_x,
-        )
-
-    fake_x = _find_fake_class_for_script_object(x).__obj_unflatten__(fake_flattened)
+        fake_x = _find_fake_class_for_script_object(x).__obj_unflatten__(fake_flattened)
 
     fake_x_wrapped = FakeScriptObject(fake_x, x._type().qualified_name(), x)  # type: ignore[attr-defined]
 
@@ -205,7 +215,7 @@ def maybe_to_fake_obj(
                     FakeScriptMethod(fake_x_wrapped, name, method_schema),
                 )
             else:
-                override_skip_list = {"__obj_flatten__", "__get_state__", "__set_state__"}
+                override_skip_list = {"__obj_flatten__", "__getstate__", "__setstate__"}
                 if name not in override_skip_list:
                     log.warning("fake object of %s doesn't implement method %s.", x, name)
     return fake_x_wrapped
diff --git a/torch/_library/opaque_object.py b/torch/_library/opaque_object.py
index db223c38616b..ba02970d5504 100644
--- a/torch/_library/opaque_object.py
+++ b/torch/_library/opaque_object.py
@@ -2,6 +2,21 @@ from typing import Any, NewType
 
 import torch
 
+from .fake_class_registry import FakeScriptObject, register_fake_class
+
+
+@register_fake_class("aten::OpaqueObject")
+class FakeOpaqueObject:
+    def __init__(self) -> None:
+        pass
+
+    @classmethod
+    def __obj_unflatten__(cls, flattened_ctx: dict[str, Any]) -> None:
+        raise RuntimeError(
+            "FakeOpaqueObject should not be created through __obj_unflatten__ "
+            "and should be handled specially. Please file an issue on GitHub."
+        )
+
 
 OpaqueTypeStr = "__torch__.torch.classes.aten.OpaqueObject"
 
@@ -80,6 +95,15 @@ def get_payload(opaque_object: torch._C.ScriptObject) -> Any:
         payload (Any): The Python object stored in the opaque object. This can be
             set with `set_payload()`.
     """
+    if isinstance(opaque_object, FakeScriptObject):
+        raise ValueError(
+            "get_payload: this function was called with a FakeScriptObject, "
+            "implying that you are calling get_payload inside of a fake kernel. "
+            "The fake kernel should not depend on the contents of the "
+            "OpaqueObject at all, so we're erroring out. If you need this "
+            "functionality, consider creating a custom TorchBind Object instead "
+            "(but note that this is more difficult)."
+        )
     if not (
         isinstance(opaque_object, torch._C.ScriptObject)
         and opaque_object._type().qualified_name() == OpaqueTypeStr
@@ -103,6 +127,16 @@ def set_payload(opaque_object: torch._C.ScriptObject, payload: Any) -> None:
         torch._C.ScriptObject: The opaque object that stores the given Python object.
         payload (Any): The Python object to store in the opaque object.
     """
+    if isinstance(opaque_object, FakeScriptObject):
+        raise ValueError(
+            "set_payload: this function was called with a FakeScriptObject, "
+            "implying that you are calling set_payload inside of a fake kernel. "
+            "The fake kernel should not depend on the contents of the "
+            "OpaqueObject at all, so we're erroring out. If you need this "
+            "functionality, consider creating a custom TorchBind Object instead "
+            "(but note that this is more difficult)."
+        )
+
     if not (
         isinstance(opaque_object, torch._C.ScriptObject)
         and opaque_object._type().qualified_name() == OpaqueTypeStr