Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Support aot_export torchbind op (#123370)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123370
Approved by: https://github.com/zou3519
ghstack dependencies: #123367
@@ -4,6 +4,7 @@ import unittest
 
 import torch
 import torch.utils._pytree as pytree
+from torch._functorch.aot_autograd import aot_export_module
 from torch._higher_order_ops.torchbind import enable_torchbind_tracing
 from torch._library.fake_class_registry import FakeScriptObject
 from torch.export import export
@@ -743,6 +744,63 @@ def forward(self, arg0_1, arg1_1):
         )
         self._assertEqualSkipScriptObject(gm(tq1, x), mod(tq2, x))
 
+    def test_aot_export_tensor_queue_operators(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, tq, x):
+                torch.ops._TorchScriptTesting.queue_push(tq, x.cos())
+                torch.ops._TorchScriptTesting.queue_push(tq, x.sin())
+                x_sin = torch.ops._TorchScriptTesting.queue_pop(
+                    tq
+                ) - torch.ops._TorchScriptTesting.queue_size(tq)
+                x_cos = torch.ops._TorchScriptTesting.queue_pop(
+                    tq
+                ) + torch.ops._TorchScriptTesting.queue_size(tq)
+                return x_sin, x_cos, tq
+
+        mod = Model()
+
+        tq1 = torch.classes._TorchScriptTesting._TensorQueue(
+            torch.empty(
+                0,
+            ).fill_(-1)
+        )
+        x = torch.ones(2, 3)
+
+        fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()
+        fake_tq1 = torch._library.fake_class_registry.to_fake_obj(fake_mode, tq1)
+        fake_x = fake_mode.from_tensor(x)
+        gm = aot_export_module(mod, (fake_tq1, fake_x), trace_joint=False)[0]
+
+        # inputs: token, tq, x
+        # return: token, x_sin, x_cos, tq
+        self.assertExpectedInline(
+            gm.code.strip(),
+            """\
+def forward(self, arg0_1, arg1_1, arg2_1):
+    cos = torch.ops.aten.cos.default(arg2_1)
+    with_effects = torch._higher_order_ops.effects.with_effects(arg0_1, torch.ops._TorchScriptTesting.queue_push.default, arg1_1, cos); arg0_1 = cos = None
+    getitem = with_effects[0]; with_effects = None
+    sin = torch.ops.aten.sin.default(arg2_1); arg2_1 = None
+    with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops._TorchScriptTesting.queue_push.default, arg1_1, sin); getitem = sin = None
+    getitem_2 = with_effects_1[0]; with_effects_1 = None
+    with_effects_2 = torch._higher_order_ops.effects.with_effects(getitem_2, torch.ops._TorchScriptTesting.queue_pop.default, arg1_1); getitem_2 = None
+    getitem_4 = with_effects_2[0]
+    getitem_5 = with_effects_2[1]; with_effects_2 = None
+    with_effects_3 = torch._higher_order_ops.effects.with_effects(getitem_4, torch.ops._TorchScriptTesting.queue_size.default, arg1_1); getitem_4 = None
+    getitem_6 = with_effects_3[0]; with_effects_3 = None
+    sub = torch.ops.aten.sub.Tensor(getitem_5, 1); getitem_5 = None
+    with_effects_4 = torch._higher_order_ops.effects.with_effects(getitem_6, torch.ops._TorchScriptTesting.queue_pop.default, arg1_1); getitem_6 = None
+    getitem_8 = with_effects_4[0]
+    getitem_9 = with_effects_4[1]; with_effects_4 = None
+    with_effects_5 = torch._higher_order_ops.effects.with_effects(getitem_8, torch.ops._TorchScriptTesting.queue_size.default, arg1_1); getitem_8 = None
+    getitem_10 = with_effects_5[0]; with_effects_5 = None
+    add = torch.ops.aten.add.Tensor(getitem_9, 0); getitem_9 = None
+    return (getitem_10, sub, add, arg1_1)""", # noqa: B950
+        )
+
+
 
 @skipIfTorchDynamo("torchbind not supported with dynamo yet")
 class TestRegisterFakeClass(TestCase):
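The expected graph above is easier to read with the with_effects calling convention in mind: every effectful torchbind call consumes the current effect token and produces a fresh one, and the data dependency between tokens keeps the queue_push/queue_pop/queue_size calls in program order. The following is a rough mental-model sketch of that contract, not PyTorch's implementation; the helper name is made up for illustration.

from typing import Any, Tuple

def with_effects_sketch(token: Any, op, *args) -> Tuple[Any, ...]:
    # Conceptually: run the effectful op and thread a new opaque token through,
    # so later effectful calls must depend on (and thus run after) this one.
    result = op(*args)
    new_token = object()  # stand-in for the opaque effect token
    if result is None:
        # e.g. queue_push returns nothing; only the new token is produced
        return (new_token,)
    return (new_token, result)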
@@ -11,6 +11,7 @@ from typing import Any, Callable, List, Optional, Tuple, Union
 
 import torch
 import torch.utils._pytree as pytree
+from torch._library.fake_class_registry import FakeScriptObject
 from torch.fx.experimental._backward_state import BackwardState
 from torch.fx.experimental.proxy_tensor import py_sym_types
 
@@ -23,6 +24,7 @@ KNOWN_TYPES = [
     bool,
     type(None),
     *py_sym_types,
+    FakeScriptObject,
 ]
 
 original_zip = zip
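Adding FakeScriptObject to KNOWN_TYPES lets AOTAutograd accept fake script objects as flattened graph inputs when exporting. Below is a simplified, hypothetical sketch of the kind of allowlist check this list feeds (not the actual AOTAutograd code):

def _reject_unknown_leaves(flat_args, known_types):
    # Hypothetical helper: every flattened input must be one of the allowlisted
    # leaf types; FakeScriptObject now passes instead of being rejected.
    for arg in flat_args:
        if not isinstance(arg, tuple(known_types)):
            raise RuntimeError(f"Unsupported input type for aot_export: {type(arg)}")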
@@ -799,9 +799,6 @@ class OpOverload(OperatorBase):
 # TorchBindOpOverload will skip C++ dispatcher and purely dispatched in python
 # when its inputs contain FakeScriptObject in a similar way as higher order ops.
 class TorchBindOpOverload(OpOverload):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
     def _fallthrough_keys(self) -> List[DispatchKey]:
         # TODO: we should be calling the fallback for these, but a fallthrough is almost close
         # enough to the fallback in most cases that we care about.
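The class comment above refers to the check that decides when a torchbind op must bypass the C++ dispatcher. A minimal sketch of that check, assuming it simply looks for a FakeScriptObject anywhere in the arguments (the real _must_dispatch_in_python is defined alongside this class):

import torch.utils._pytree as pytree
from torch._library.fake_class_registry import FakeScriptObject

def _must_dispatch_in_python_sketch(args, kwargs):
    # Dispatch in Python whenever any (possibly nested) argument is a fake
    # script object, since the C++ dispatcher cannot schema-check those.
    return pytree.tree_any(
        lambda obj: isinstance(obj, FakeScriptObject), (args, kwargs)
    )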
@@ -811,6 +808,7 @@ class TorchBindOpOverload(OpOverload):
             DispatchKey.AutogradCUDA,
             DispatchKey.ADInplaceOrView,
             DispatchKey.PythonTLSSnapshot,
+            DispatchKey.PythonDispatcher,
         ]
 
         def _may_use_fallthrough_instead_of_fallback(key: DispatchKey):
@@ -830,14 +828,40 @@ class TorchBindOpOverload(OpOverload):
             if _may_use_fallthrough_instead_of_fallback(key)
         ]
 
+    @contextlib.contextmanager
+    def _register_as_effectful_op_temporarily(self):
+        from torch._higher_order_ops.effects import (
+            _EffectType,
+            _register_effectful_op,
+            SIDE_EFFECTS,
+        )
+
+        try:
+            if self not in SIDE_EFFECTS:
+                _register_effectful_op(self, _EffectType.ORDERED)
+            yield
+        finally:
+            if self in SIDE_EFFECTS:
+                del SIDE_EFFECTS[self]
+
     # use `self_` to avoid naming collide with arguments that
     # are named "self". This way, they can be called by kwargs.
     def __call__(self_, *args, **kwargs):  # noqa: B902
         if _must_dispatch_in_python(args, kwargs):
             # When any inputs are FakeScriptObject, we need to
-            # skip c++ dispatcher and dispatch in python through _get_dispatch of python_dispatcher.
-            return self_._dispatch_in_python(args, kwargs, self_._fallthrough_keys())
-
+            # skip c++ dispatcher and dispatch in python through _get_dispatch of python_dispatcher
+            # because C++ dispatcher will check the schema and cannot recognize FakeScriptObject.
+            #
+            # Note:
+            # 1. We only register the torchbind op temporarily as effectful op because we only want
+            # the effect token functionalization logic to be applied during tracing. Otherwise, the behavior
+            # of the eagerly executing the op might change after tracing.
+            # 2. We don't want to register the op as effectful for all torchbind ops in ctor because this might
+            # cause unexpected behavior for some autograd.profiler ops e.g. profiler._record_function_exit._RecordFunction.
+            with self_._register_as_effectful_op_temporarily():
+                return self_._dispatch_in_python(
+                    args, kwargs, self_._fallthrough_keys()
+                )
         return self_._op(*args, **kwargs)
 
     def _dispatch_in_python(self, args, kwargs, fallthrough_keys):
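Taken together, the changes enable the flow exercised by the test above: convert the real script object into a FakeScriptObject, fakify the tensors, and call aot_export_module; each torchbind call is then traced as a with_effects node. A condensed sketch of that flow, assuming the _TorchScriptTesting extension and its fake-class registrations are loaded:

import torch
from torch._functorch.aot_autograd import aot_export_module
from torch._library.fake_class_registry import to_fake_obj
from torch._subclasses.fake_tensor import FakeTensorMode

def aot_export_with_torchbind(mod, tq, x):
    fake_mode = FakeTensorMode()
    fake_tq = to_fake_obj(fake_mode, tq)  # FakeScriptObject stand-in for tq
    fake_x = fake_mode.from_tensor(x)
    # The exported graph threads an effect token through its inputs/outputs,
    # e.g. inputs (token, tq, x) -> outputs (token, x_sin, x_cos, tq) above.
    gm = aot_export_module(mod, (fake_tq, fake_x), trace_joint=False)[0]
    return gm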