Support aot_export torchbind op (#123370)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123370
Approved by: https://github.com/zou3519
ghstack dependencies: #123367
ydwu4
2024-04-18 13:15:03 -07:00
committed by PyTorch MergeBot
parent e62169a8fa
commit 293f756cdc
3 changed files with 90 additions and 6 deletions
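
For orientation, here is a minimal sketch (not part of the diff) of the flow this change enables, mirroring the new test below; `mod` and `real_obj` are assumed stand-ins for a module that calls torchbind ops and a torchbind instance whose class has a registered fake class:

    import torch
    from torch._functorch.aot_autograd import aot_export_module
    from torch._library.fake_class_registry import to_fake_obj

    fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()
    fake_obj = to_fake_obj(fake_mode, real_obj)  # fake stand-in for the script object
    fake_x = fake_mode.from_tensor(torch.ones(2, 3))
    # Torchbind calls inside `mod` are traced as effectful with_effects nodes.
    gm = aot_export_module(mod, (fake_obj, fake_x), trace_joint=False)[0]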


@@ -4,6 +4,7 @@ import unittest
import torch
import torch.utils._pytree as pytree
from torch._functorch.aot_autograd import aot_export_module
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
from torch._library.fake_class_registry import FakeScriptObject
from torch.export import export
@@ -743,6 +744,63 @@ def forward(self, arg0_1, arg1_1):
        )
        self._assertEqualSkipScriptObject(gm(tq1, x), mod(tq2, x))

    def test_aot_export_tensor_queue_operators(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, tq, x):
                torch.ops._TorchScriptTesting.queue_push(tq, x.cos())
                torch.ops._TorchScriptTesting.queue_push(tq, x.sin())
                x_sin = torch.ops._TorchScriptTesting.queue_pop(
                    tq
                ) - torch.ops._TorchScriptTesting.queue_size(tq)
                x_cos = torch.ops._TorchScriptTesting.queue_pop(
                    tq
                ) + torch.ops._TorchScriptTesting.queue_size(tq)
                return x_sin, x_cos, tq

        mod = Model()
        tq1 = torch.classes._TorchScriptTesting._TensorQueue(
            torch.empty(
                0,
            ).fill_(-1)
        )
        x = torch.ones(2, 3)

        fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()
        fake_tq1 = torch._library.fake_class_registry.to_fake_obj(fake_mode, tq1)
        fake_x = fake_mode.from_tensor(x)
        gm = aot_export_module(mod, (fake_tq1, fake_x), trace_joint=False)[0]

        # inputs: token, tq, x
        # return: token, x_sin, x_cos, tq
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, arg0_1, arg1_1, arg2_1):
    cos = torch.ops.aten.cos.default(arg2_1)
    with_effects = torch._higher_order_ops.effects.with_effects(arg0_1, torch.ops._TorchScriptTesting.queue_push.default, arg1_1, cos); arg0_1 = cos = None
    getitem = with_effects[0]; with_effects = None
    sin = torch.ops.aten.sin.default(arg2_1); arg2_1 = None
    with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops._TorchScriptTesting.queue_push.default, arg1_1, sin); getitem = sin = None
    getitem_2 = with_effects_1[0]; with_effects_1 = None
    with_effects_2 = torch._higher_order_ops.effects.with_effects(getitem_2, torch.ops._TorchScriptTesting.queue_pop.default, arg1_1); getitem_2 = None
    getitem_4 = with_effects_2[0]
    getitem_5 = with_effects_2[1]; with_effects_2 = None
    with_effects_3 = torch._higher_order_ops.effects.with_effects(getitem_4, torch.ops._TorchScriptTesting.queue_size.default, arg1_1); getitem_4 = None
    getitem_6 = with_effects_3[0]; with_effects_3 = None
    sub = torch.ops.aten.sub.Tensor(getitem_5, 1); getitem_5 = None
    with_effects_4 = torch._higher_order_ops.effects.with_effects(getitem_6, torch.ops._TorchScriptTesting.queue_pop.default, arg1_1); getitem_6 = None
    getitem_8 = with_effects_4[0]
    getitem_9 = with_effects_4[1]; with_effects_4 = None
    with_effects_5 = torch._higher_order_ops.effects.with_effects(getitem_8, torch.ops._TorchScriptTesting.queue_size.default, arg1_1); getitem_8 = None
    getitem_10 = with_effects_5[0]; with_effects_5 = None
    add = torch.ops.aten.add.Tensor(getitem_9, 0); getitem_9 = None
    return (getitem_10, sub, add, arg1_1)""",  # noqa: B950
        )


@skipIfTorchDynamo("torchbind not supported with dynamo yet")
class TestRegisterFakeClass(TestCase):
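
The expected graph above threads an effect token through every torchbind call: with_effects(token, op, *args) returns a tuple whose first element is the new token and whose remaining elements are the op's outputs. An illustration (not part of the diff) of that convention, inferred from the expected graph; `token`, `tq`, and `x` are assumed to be the current token, a queue object, and a tensor:

    from torch._higher_order_ops.effects import with_effects

    out = with_effects(token, torch.ops._TorchScriptTesting.queue_push.default, tq, x)
    token = out[0]  # queue_push returns no data, so only the new token comes back

    out = with_effects(token, torch.ops._TorchScriptTesting.queue_pop.default, tq)
    token, popped = out[0], out[1]  # new token first, then the popped tensor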


@@ -11,6 +11,7 @@ from typing import Any, Callable, List, Optional, Tuple, Union
import torch
import torch.utils._pytree as pytree
from torch._library.fake_class_registry import FakeScriptObject
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.proxy_tensor import py_sym_types
@@ -23,6 +24,7 @@ KNOWN_TYPES = [
    bool,
    type(None),
    *py_sym_types,
    FakeScriptObject,
]
original_zip = zip
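
KNOWN_TYPES is the list of leaf types aot_autograd accepts in flattened inputs and outputs, so adding FakeScriptObject lets fake script objects pass through as export inputs. A hedged sketch of the kind of validation this list feeds (the actual call site is outside this hunk and may differ in detail):

    # Hypothetical helper; the real check lives elsewhere in torch/_functorch.
    def _check_supported_leaves(flat_args):
        for i, arg in enumerate(flat_args):
            if not isinstance(arg, tuple(KNOWN_TYPES)):
                raise RuntimeError(
                    f"Unsupported input type at position {i}: {type(arg)}"
                )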


@@ -799,9 +799,6 @@ class OpOverload(OperatorBase):
# TorchBindOpOverload will skip the C++ dispatcher and be dispatched purely in Python
# when its inputs contain a FakeScriptObject, in a similar way to higher-order ops.
class TorchBindOpOverload(OpOverload):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _fallthrough_keys(self) -> List[DispatchKey]:
        # TODO: we should be calling the fallback for these, but a fallthrough is almost
        # close enough to the fallback in most cases that we care about.
@@ -811,6 +808,7 @@ class TorchBindOpOverload(OpOverload):
            DispatchKey.AutogradCUDA,
            DispatchKey.ADInplaceOrView,
            DispatchKey.PythonTLSSnapshot,
            DispatchKey.PythonDispatcher,
        ]

        def _may_use_fallthrough_instead_of_fallback(key: DispatchKey):
@@ -830,14 +828,40 @@ class TorchBindOpOverload(OpOverload):
            if _may_use_fallthrough_instead_of_fallback(key)
        ]

    @contextlib.contextmanager
    def _register_as_effectful_op_temporarily(self):
        from torch._higher_order_ops.effects import (
            _EffectType,
            _register_effectful_op,
            SIDE_EFFECTS,
        )

        try:
            if self not in SIDE_EFFECTS:
                _register_effectful_op(self, _EffectType.ORDERED)
            yield
        finally:
            if self in SIDE_EFFECTS:
                del SIDE_EFFECTS[self]

    # Use `self_` to avoid a name collision with op arguments that are
    # named "self"; this way such arguments can still be passed as kwargs.
    def __call__(self_, *args, **kwargs):  # noqa: B902
        if _must_dispatch_in_python(args, kwargs):
            # When any inputs are FakeScriptObject, we need to skip the C++
            # dispatcher and dispatch in Python through _get_dispatch of the
            # python_dispatcher, because the C++ dispatcher checks the schema
            # and cannot recognize FakeScriptObject.
            #
            # Note:
            # 1. We register the torchbind op as effectful only temporarily because we only
            #    want the effect-token functionalization logic applied during tracing.
            #    Otherwise, the behavior of eagerly executing the op might change after tracing.
            # 2. We don't want to register every torchbind op as effectful in the constructor,
            #    because this might cause unexpected behavior for some autograd.profiler ops,
            #    e.g. profiler._record_function_exit._RecordFunction.
            with self_._register_as_effectful_op_temporarily():
                return self_._dispatch_in_python(
                    args, kwargs, self_._fallthrough_keys()
                )
        return self_._op(*args, **kwargs)

    def _dispatch_in_python(self, args, kwargs, fallthrough_keys):
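
Taken together, __call__ registers a torchbind op as an ORDERED effect only while dispatching in Python with fake inputs, so eager semantics are untouched outside tracing. A small sketch of that contract (hypothetical driver code, assuming the op starts unregistered), using the SIDE_EFFECTS registry shown above:

    from torch._higher_order_ops.effects import SIDE_EFFECTS

    op = torch.ops._TorchScriptTesting.queue_push.default  # a TorchBindOpOverload
    assert op not in SIDE_EFFECTS            # eager calls are not token-functionalized
    with op._register_as_effectful_op_temporarily():
        assert op in SIDE_EFFECTS            # tracing in here emits with_effects nodes
    assert op not in SIDE_EFFECTS            # eager behavior restored on exit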