mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[opaque_obj] Add make_fx tracing support (#163278)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163278 Approved by: https://github.com/zou3519 ghstack dependencies: #163279, #163277
This commit is contained in:
committed by
PyTorch MergeBot
parent
2bb4e6876c
commit
322091d8d8
@ -3,12 +3,19 @@ import copy
|
||||
|
||||
import torch
|
||||
from torch._dynamo.test_case import run_tests, TestCase
|
||||
from torch._library.fake_class_registry import maybe_to_fake_obj
|
||||
from torch._library.opaque_object import (
|
||||
get_payload,
|
||||
make_opaque,
|
||||
OpaqueType,
|
||||
set_payload,
|
||||
)
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
)
|
||||
|
||||
|
||||
class OpaqueQueue:
|
||||
@ -17,15 +24,23 @@ class OpaqueQueue:
|
||||
self.queue = queue
|
||||
self.init_tensor_ = init_tensor_
|
||||
|
||||
# For testing purposes
|
||||
self._push_counter = 0
|
||||
self._pop_counter = 0
|
||||
self._size_counter = 0
|
||||
|
||||
def push(self, tensor: torch.Tensor) -> None:
|
||||
self._push_counter += 1
|
||||
self.queue.append(tensor)
|
||||
|
||||
def pop(self) -> torch.Tensor:
|
||||
self._pop_counter += 1
|
||||
if len(self.queue) > 0:
|
||||
return self.queue.pop(0)
|
||||
return self.init_tensor_
|
||||
|
||||
def size(self) -> int:
|
||||
self._size_counter += 1
|
||||
return len(self.queue)
|
||||
|
||||
def __eq__(self, other):
|
||||
@ -56,6 +71,10 @@ class TestOpaqueObject(TestCase):
|
||||
assert isinstance(queue, OpaqueQueue)
|
||||
queue.push(b)
|
||||
|
||||
@torch.library.register_fake("_TestOpaqueObject::queue_push", lib=self.lib)
|
||||
def push_impl_fake(q: torch._C.ScriptObject, b: torch.Tensor) -> None:
|
||||
pass
|
||||
|
||||
self.lib.define(
|
||||
"queue_pop(__torch__.torch.classes.aten.OpaqueObject a) -> Tensor",
|
||||
)
|
||||
@ -67,6 +86,15 @@ class TestOpaqueObject(TestCase):
|
||||
|
||||
self.lib.impl("queue_pop", pop_impl, "CompositeExplicitAutograd")
|
||||
|
||||
def pop_impl_fake(q: torch._C.ScriptObject) -> torch.Tensor:
|
||||
# This is not accurate since the queue could have tensors that are
|
||||
# not rank 1
|
||||
ctx = torch._custom_op.impl.get_ctx()
|
||||
u0 = ctx.create_unbacked_symint()
|
||||
return torch.empty(u0)
|
||||
|
||||
self.lib._register_fake("queue_pop", pop_impl_fake)
|
||||
|
||||
@torch.library.custom_op(
|
||||
"_TestOpaqueObject::queue_size",
|
||||
mutates_args=[],
|
||||
@ -76,6 +104,13 @@ class TestOpaqueObject(TestCase):
|
||||
assert isinstance(queue, OpaqueQueue)
|
||||
return queue.size()
|
||||
|
||||
@size_impl.register_fake
|
||||
def size_impl_fake(q: torch._C.ScriptObject) -> int:
|
||||
ctx = torch._custom_op.impl.get_ctx()
|
||||
u0 = ctx.create_unbacked_symint()
|
||||
torch._check_is_size(u0)
|
||||
return u0
|
||||
|
||||
super().setUp()
|
||||
|
||||
def tearDown(self):
|
||||
@ -130,6 +165,105 @@ class TestOpaqueObject(TestCase):
|
||||
self.assertTrue(q1 is not q2)
|
||||
self.assertTrue(q1 == q2)
|
||||
|
||||
def test_bad_fake(self):
|
||||
torch.library.define(
|
||||
"_TestOpaqueObject::bad_fake",
|
||||
"(__torch__.torch.classes.aten.OpaqueObject q, Tensor x) -> Tensor",
|
||||
lib=self.lib,
|
||||
)
|
||||
|
||||
def f(q, x):
|
||||
torch.ops._TestOpaqueObject.bad_fake(q, x)
|
||||
return x.cos()
|
||||
|
||||
def bad_fake1(q: torch._C.ScriptObject, b: torch.Tensor) -> torch.Tensor:
|
||||
payload = get_payload(q)
|
||||
return b * payload
|
||||
|
||||
torch.library.register_fake(
|
||||
"_TestOpaqueObject::bad_fake", bad_fake1, lib=self.lib
|
||||
)
|
||||
|
||||
with FakeTensorMode() as fake_mode:
|
||||
obj = make_opaque(1)
|
||||
fake_obj = maybe_to_fake_obj(fake_mode, obj)
|
||||
x = torch.ones(3)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"get_payload: this function was called with a FakeScriptObject",
|
||||
):
|
||||
torch.ops._TestOpaqueObject.bad_fake(fake_obj, x)
|
||||
|
||||
def bad_fake2(q: torch._C.ScriptObject, b: torch.Tensor) -> torch.Tensor:
|
||||
set_payload(q, 2)
|
||||
return torch.empty_like(b)
|
||||
|
||||
torch.library.register_fake(
|
||||
"_TestOpaqueObject::bad_fake", bad_fake2, lib=self.lib, allow_override=True
|
||||
)
|
||||
|
||||
with FakeTensorMode() as fake_mode:
|
||||
obj = make_opaque(1)
|
||||
fake_obj = maybe_to_fake_obj(fake_mode, obj)
|
||||
x = torch.ones(3)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"set_payload: this function was called with a FakeScriptObject",
|
||||
):
|
||||
torch.ops._TestOpaqueObject.bad_fake(fake_obj, x)
|
||||
|
||||
@parametrize("make_fx_tracing_mode", ["fake", "symbolic"])
|
||||
def test_make_fx(self, make_fx_tracing_mode):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, queue, x):
|
||||
torch.ops._TestOpaqueObject.queue_push(queue, x.tan())
|
||||
torch.ops._TestOpaqueObject.queue_push(queue, x.cos())
|
||||
torch.ops._TestOpaqueObject.queue_push(queue, x.sin())
|
||||
pop1 = torch.ops._TestOpaqueObject.queue_pop(queue)
|
||||
size1 = torch.ops._TestOpaqueObject.queue_size(queue)
|
||||
pop2 = torch.ops._TestOpaqueObject.queue_pop(queue)
|
||||
size2 = torch.ops._TestOpaqueObject.queue_size(queue)
|
||||
x_cos = pop1 + size1
|
||||
x_sin = pop2 - size2
|
||||
return x_sin + x_cos
|
||||
|
||||
q1 = OpaqueQueue([], torch.empty(0).fill_(-1))
|
||||
obj1 = make_opaque(q1)
|
||||
q2 = OpaqueQueue([], torch.empty(0).fill_(-1))
|
||||
obj2 = make_opaque(q2)
|
||||
|
||||
x = torch.ones(2, 3)
|
||||
gm = make_fx(M(), tracing_mode=make_fx_tracing_mode)(obj1, x)
|
||||
self.assertTrue(torch.allclose(gm(obj1, x), M()(obj2, x)))
|
||||
self.assertEqual(q1._push_counter, 3)
|
||||
self.assertEqual(q1._pop_counter, 2)
|
||||
self.assertEqual(q1._size_counter, 2)
|
||||
self.assertEqual(q1.size(), 1)
|
||||
self.assertExpectedInline(
|
||||
gm.code.strip("\n"),
|
||||
"""\
|
||||
def forward(self, arg0_1, arg1_1):
|
||||
tan = torch.ops.aten.tan.default(arg1_1)
|
||||
queue_push = torch.ops._TestOpaqueObject.queue_push.default(arg0_1, tan); tan = queue_push = None
|
||||
cos = torch.ops.aten.cos.default(arg1_1)
|
||||
queue_push_1 = torch.ops._TestOpaqueObject.queue_push.default(arg0_1, cos); cos = queue_push_1 = None
|
||||
sin = torch.ops.aten.sin.default(arg1_1); arg1_1 = None
|
||||
queue_push_2 = torch.ops._TestOpaqueObject.queue_push.default(arg0_1, sin); sin = queue_push_2 = None
|
||||
queue_pop = torch.ops._TestOpaqueObject.queue_pop.default(arg0_1)
|
||||
queue_size = torch.ops._TestOpaqueObject.queue_size.default(arg0_1)
|
||||
queue_pop_1 = torch.ops._TestOpaqueObject.queue_pop.default(arg0_1)
|
||||
queue_size_1 = torch.ops._TestOpaqueObject.queue_size.default(arg0_1); arg0_1 = None
|
||||
add = torch.ops.aten.add.Tensor(queue_pop, queue_size); queue_pop = queue_size = None
|
||||
sub = torch.ops.aten.sub.Tensor(queue_pop_1, queue_size_1); queue_pop_1 = queue_size_1 = None
|
||||
add_1 = torch.ops.aten.add.Tensor(sub, add); sub = add = None
|
||||
return add_1
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestOpaqueObject)
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
@ -20,13 +20,14 @@ class FakeScriptObject:
|
||||
try:
|
||||
with _disable_current_modes():
|
||||
self.real_obj = copy.deepcopy(x)
|
||||
except RuntimeError:
|
||||
except RuntimeError as e:
|
||||
log.warning(
|
||||
"Unable to deepcopy the custom object %s. "
|
||||
"Unable to deepcopy the custom object %s due to %s. "
|
||||
"Defaulting to the user given object. This might be "
|
||||
"dangerous as side effects may be directly applied "
|
||||
"to the object.",
|
||||
script_class_name,
|
||||
str(e),
|
||||
)
|
||||
self.real_obj = x
|
||||
|
||||
@ -134,55 +135,64 @@ def maybe_to_fake_obj(
|
||||
if tracing_with_real(x):
|
||||
return x
|
||||
|
||||
# x.__obj_flatten__() could be calling some tensor operations inside but we don't
|
||||
# want to call these ops in surrounding dispatch modes when executing it.
|
||||
# Otherwise, for example, the fake tensor modes will error out when the tensors inside
|
||||
# script object execute some operations like clone if allow_non_fake_input flag is set.
|
||||
with _disable_current_modes():
|
||||
flat_x = x.__obj_flatten__() # type: ignore[attr-defined]
|
||||
from torch._library.opaque_object import FakeOpaqueObject, OpaqueTypeStr
|
||||
|
||||
_check_valid_flat_script_obj(flat_x)
|
||||
if str(x._type()) == OpaqueTypeStr:
|
||||
# In order to make OpaqueObjects truly opaque, the fake kernel should
|
||||
# not depend on the contents of the OpaqueObject at all.
|
||||
fake_x = FakeOpaqueObject()
|
||||
|
||||
with fake_mode:
|
||||
from torch._higher_order_ops.utils import _tensor_storage
|
||||
else:
|
||||
# x.__obj_flatten__() could be calling some tensor operations inside but we don't
|
||||
# want to call these ops in surrounding dispatch modes when executing it.
|
||||
# Otherwise, for example, the fake tensor modes will error out when the tensors inside
|
||||
# script object execute some operations like clone if allow_non_fake_input flag is set.
|
||||
with _disable_current_modes():
|
||||
flat_x = x.__obj_flatten__() # type: ignore[attr-defined]
|
||||
|
||||
storage_map = {
|
||||
_tensor_storage(inp): i
|
||||
for i, inp in enumerate(flat_x)
|
||||
if isinstance(inp, torch.Tensor)
|
||||
}
|
||||
alias_map = {
|
||||
i: storage_map[_tensor_storage(inp)]
|
||||
for i, inp in enumerate(flat_x)
|
||||
if isinstance(inp, torch.Tensor) and storage_map[_tensor_storage(inp)] != i
|
||||
}
|
||||
if len(alias_map) > 0:
|
||||
log.warning(
|
||||
"Detected script object %s has aliasing relationship among its tensors. "
|
||||
"Flattened obj: %s. Aliasing tensor indices: %s. "
|
||||
"This is not supported and may cause unexpected behavior.",
|
||||
x,
|
||||
_check_valid_flat_script_obj(flat_x)
|
||||
|
||||
with fake_mode:
|
||||
from torch._higher_order_ops.utils import _tensor_storage
|
||||
|
||||
storage_map = {
|
||||
_tensor_storage(inp): i
|
||||
for i, inp in enumerate(flat_x)
|
||||
if isinstance(inp, torch.Tensor)
|
||||
}
|
||||
alias_map = {
|
||||
i: storage_map[_tensor_storage(inp)]
|
||||
for i, inp in enumerate(flat_x)
|
||||
if isinstance(inp, torch.Tensor)
|
||||
and storage_map[_tensor_storage(inp)] != i
|
||||
}
|
||||
if len(alias_map) > 0:
|
||||
log.warning(
|
||||
"Detected script object %s has aliasing relationship among its tensors. "
|
||||
"Flattened obj: %s. Aliasing tensor indices: %s. "
|
||||
"This is not supported and may cause unexpected behavior.",
|
||||
x,
|
||||
flat_x,
|
||||
alias_map,
|
||||
)
|
||||
|
||||
# This breaks the aliasing relationship among the tensors inside the torchbind object
|
||||
# This is bad but since we don't need to preserve the aliasing relationship anyway and
|
||||
# we state clearly that aliasing relationship is not preserved in the doc so this might be OK.
|
||||
fake_flattened = pytree.tree_map_only(
|
||||
torch.Tensor,
|
||||
lambda t: torch.empty_strided(
|
||||
t.size(),
|
||||
t.stride(),
|
||||
device=t.device,
|
||||
dtype=t.dtype,
|
||||
requires_grad=t.requires_grad,
|
||||
layout=t.layout,
|
||||
),
|
||||
flat_x,
|
||||
alias_map,
|
||||
)
|
||||
|
||||
# This breaks the aliasing relationship among the tensors inside the torchbind object
|
||||
# This is bad but since we don't need to preserve the aliasing relationship anyway and
|
||||
# we state clearly that aliasing relationship is not preserved in the doc so this might be OK.
|
||||
fake_flattened = pytree.tree_map_only(
|
||||
torch.Tensor,
|
||||
lambda t: torch.empty_strided(
|
||||
t.size(),
|
||||
t.stride(),
|
||||
device=t.device,
|
||||
dtype=t.dtype,
|
||||
requires_grad=t.requires_grad,
|
||||
layout=t.layout,
|
||||
),
|
||||
flat_x,
|
||||
)
|
||||
|
||||
fake_x = _find_fake_class_for_script_object(x).__obj_unflatten__(fake_flattened)
|
||||
fake_x = _find_fake_class_for_script_object(x).__obj_unflatten__(fake_flattened)
|
||||
|
||||
fake_x_wrapped = FakeScriptObject(fake_x, x._type().qualified_name(), x) # type: ignore[attr-defined]
|
||||
|
||||
@ -205,7 +215,7 @@ def maybe_to_fake_obj(
|
||||
FakeScriptMethod(fake_x_wrapped, name, method_schema),
|
||||
)
|
||||
else:
|
||||
override_skip_list = {"__obj_flatten__", "__get_state__", "__set_state__"}
|
||||
override_skip_list = {"__obj_flatten__", "__getstate__", "__setstate__"}
|
||||
if name not in override_skip_list:
|
||||
log.warning("fake object of %s doesn't implement method %s.", x, name)
|
||||
return fake_x_wrapped
|
||||
|
@ -2,6 +2,21 @@ from typing import Any, NewType
|
||||
|
||||
import torch
|
||||
|
||||
from .fake_class_registry import FakeScriptObject, register_fake_class
|
||||
|
||||
|
||||
@register_fake_class("aten::OpaqueObject")
|
||||
class FakeOpaqueObject:
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def __obj_unflatten__(cls, flattened_ctx: dict[str, Any]) -> None:
|
||||
raise RuntimeError(
|
||||
"FakeOpaqueObject should not be created through __obj_unflatten__ "
|
||||
"and should be special handled. Please file an issue to Github."
|
||||
)
|
||||
|
||||
|
||||
OpaqueTypeStr = "__torch__.torch.classes.aten.OpaqueObject"
|
||||
|
||||
@ -80,6 +95,15 @@ def get_payload(opaque_object: torch._C.ScriptObject) -> Any:
|
||||
payload (Any): The Python object stored in the opaque object. This can
|
||||
be set with `set_payload()`.
|
||||
"""
|
||||
if isinstance(opaque_object, FakeScriptObject):
|
||||
raise ValueError(
|
||||
"get_payload: this function was called with a FakeScriptObject "
|
||||
"implying that you are calling get_payload inside of a fake kernel."
|
||||
"The fake kernel should not depend on the contents of the "
|
||||
"OpaqueObject at all, so we're erroring out. If you need this"
|
||||
"functionality, consider creating a custom TorchBind Object instead"
|
||||
"(but note that this is more difficult)."
|
||||
)
|
||||
if not (
|
||||
isinstance(opaque_object, torch._C.ScriptObject)
|
||||
and opaque_object._type().qualified_name() == OpaqueTypeStr
|
||||
@ -103,6 +127,16 @@ def set_payload(opaque_object: torch._C.ScriptObject, payload: Any) -> None:
|
||||
torch._C.ScriptObject: The opaque object that stores the given Python object.
|
||||
payload (Any): The Python object to store in the opaque object.
|
||||
"""
|
||||
if isinstance(opaque_object, FakeScriptObject):
|
||||
raise ValueError(
|
||||
"set_payload: this function was called with a FakeScriptObject "
|
||||
"implying that you are calling get_payload inside of a fake kernel."
|
||||
"The fake kernel should not depend on the contents of the "
|
||||
"OpaqueObject at all, so we're erroring out. If you need this"
|
||||
"functionality, consider creating a custom TorchBind Object instead"
|
||||
"(but note that this is more difficult)."
|
||||
)
|
||||
|
||||
if not (
|
||||
isinstance(opaque_object, torch._C.ScriptObject)
|
||||
and opaque_object._type().qualified_name() == OpaqueTypeStr
|
||||
|
Reference in New Issue
Block a user