[opaque obj] Allow non-effectful scriptobjs (#163714)

Fixes functionalization so that we can run ops that take ScriptObjects without registering them as effectful. Previously, running functionalization on TorchBindOpOverloads would raise an error.
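As a minimal sketch of the behavior change (using the `_TestOpaqueObject::noisy_inject` op that the new test below defines; any TorchBind/opaque-object op should behave the same way): by default such an op now functionalizes and traces as a plain call, and effect tokens are only introduced if the op is explicitly registered as effectful.

import torch
from torch._library.effects import EffectType

OP = "_TestOpaqueObject::noisy_inject"  # assumed: defined by the test suite below

# Opt in: exported graphs then thread a token through torch.ops.higher_order.with_effects for this op.
torch.library._register_effectful_op(OP, EffectType.ORDERED)

# Opt back out (the new default for script-object ops): the op traces as a plain call.
torch.library._register_effectful_op(OP, None)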

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163714
Approved by: https://github.com/zou3519
ghstack dependencies: #163284
angelayi
2025-11-12 09:25:51 -08:00
committed by PyTorch MergeBot
parent 35571fe94b
commit c9b09a31e8
8 changed files with 149 additions and 64 deletions

View File

@@ -90,7 +90,7 @@ class TestOpaqueObject(TestCase):
             # This is not accurate since the queue could have tensors that are
             # not rank 1
             ctx = torch._custom_op.impl.get_ctx()
-            u0 = ctx.create_unbacked_symint()
+            u0 = ctx.new_dynamic_size()
             return torch.empty(u0)

         self.lib._register_fake("queue_pop", pop_impl_fake)
@@ -107,8 +107,7 @@ class TestOpaqueObject(TestCase):
         @size_impl.register_fake
         def size_impl_fake(q: torch._C.ScriptObject) -> int:
             ctx = torch._custom_op.impl.get_ctx()
-            u0 = ctx.create_unbacked_symint()
-            torch._check_is_size(u0)
+            u0 = ctx.new_dynamic_size()
             return u0

         super().setUp()
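For context on the API swap in the two hunks above: new_dynamic_size() is the supported way for a fake kernel to allocate a data-dependent output size, and the SymInt it returns is already constrained to be a valid (non-negative) size, which is presumably why the separate torch._check_is_size call is dropped. A minimal sketch with a hypothetical mylib::pop op (the namespace, schema, and names are illustrative, not part of this PR):

import torch

lib = torch.library.Library("mylib", "FRAGMENT")  # hypothetical namespace
lib.define("pop(Tensor queue_state) -> Tensor")   # hypothetical schema

@torch.library.register_fake("mylib::pop", lib=lib)
def pop_fake(queue_state):
    ctx = torch.library.get_ctx()
    # Fresh data-dependent SymInt for the output length; the real kernel
    # determines the actual value at runtime.
    u0 = ctx.new_dynamic_size()
    return torch.empty(u0)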

View File

@@ -4,6 +4,8 @@ import random

 import torch
 from torch._dynamo.test_case import run_tests, TestCase
 from torch._functorch.aot_autograd import aot_export_module
+from torch._library.effects import EffectType
+from torch._library.fake_class_registry import FakeScriptObject
 from torch._library.opaque_object import register_opaque_type
 from torch.fx.experimental.proxy_tensor import make_fx
@@ -233,6 +235,65 @@ def forward(self, arg0_1, arg1_1):
         ):
             make_fx(f, tracing_mode=make_fx_tracing_mode)(RNGState(0), torch.ones(3))

+    def test_aot_export(self):
+        class Model(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+
+            def forward(self, rng_state, x):
+                x = torch.ops._TestOpaqueObject.noisy_inject(x, rng_state)
+                x = x * x
+                x = torch.ops._TestOpaqueObject.noisy_inject(x, rng_state)
+                x = x + x
+                return (x,)
+
+        mod = Model()
+        rng = RNGState(0)
+        x = torch.ones(2, 3)
+
+        fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()
+        fake_rng = torch._library.fake_class_registry.maybe_to_fake_obj(fake_mode, rng)
+        fake_x = fake_mode.from_tensor(x)
+
+        gm = aot_export_module(mod, (fake_rng, fake_x), trace_joint=False)[0]
+        # By default we don't register ops containing PyObjs as being effectful
+        self.assertExpectedInline(
+            gm.code.strip(),
+            """\
+def forward(self, arg0_1, arg1_1):
+    noisy_inject = torch.ops._TestOpaqueObject.noisy_inject.default(arg1_1, arg0_1); arg1_1 = None
+    mul = torch.ops.aten.mul.Tensor(noisy_inject, noisy_inject); noisy_inject = None
+    noisy_inject_1 = torch.ops._TestOpaqueObject.noisy_inject.default(mul, arg0_1); mul = arg0_1 = None
+    add = torch.ops.aten.add.Tensor(noisy_inject_1, noisy_inject_1); noisy_inject_1 = None
+    return (add,)""", # noqa: B950
+        )
+
+        torch.library._register_effectful_op(
+            "_TestOpaqueObject::noisy_inject", EffectType.ORDERED
+        )
+        try:
+            gm = aot_export_module(mod, (rng, fake_x), trace_joint=False)[0]
+
+            # inputs: token, rng, x
+            # return: token, res
+            self.assertExpectedInline(
+                gm.code.strip(),
+                """\
+def forward(self, arg0_1, arg1_1, arg2_1):
+    with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops._TestOpaqueObject.noisy_inject.default, arg2_1, arg1_1); arg0_1 = arg2_1 = None
+    getitem = with_effects[0]
+    getitem_1 = with_effects[1]; with_effects = None
+    mul = torch.ops.aten.mul.Tensor(getitem_1, getitem_1); getitem_1 = None
+    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops._TestOpaqueObject.noisy_inject.default, mul, arg1_1); getitem = mul = arg1_1 = None
+    getitem_2 = with_effects_1[0]
+    getitem_3 = with_effects_1[1]; with_effects_1 = None
+    add = torch.ops.aten.add.Tensor(getitem_3, getitem_3); getitem_3 = None
+    return (getitem_2, add)""", # noqa: B950
+            )
+        finally:
+            torch.library._register_effectful_op(
+                "_TestOpaqueObject::noisy_inject", None
+            )


 instantiate_parametrized_tests(TestOpaqueObject)
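A note on reading the second expected graph: with_effects returns a tuple whose element 0 is the new effect token and whose element 1 is the wrapped op's output, and the exported module gains a leading token input (arg0_1) and returns the final token as its first output (getitem_2). That is what the "# inputs: token, rng, x" / "# return: token, res" comments refer to. In pseudocode (a paraphrase of the printed graph, not a runnable snippet):

# out = torch.ops.higher_order.with_effects(token, op, *op_args)
# token, result = out[0], out[1]
# Each effectful call consumes the previous token and produces a new one, which keeps the
# two noisy_inject calls ordered relative to each other.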

View File

@@ -24,6 +24,7 @@ from torch._export.passes.lift_constants_pass import ConstantAttrMap
 from torch._export.utils import _fakify_params_buffers
 from torch._guards import Source
 from torch._library.fake_class_registry import FakeScriptObject
+from torch._library.opaque_object import is_opaque_type
 from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.export import Constraint
 from torch.export.dynamic_shapes import (
@@ -946,7 +947,9 @@ def _fakify_script_objects(

     try:
         for obj, fqns in constant_attrs.items():
-            if torch._library.fake_class_registry._is_script_object(obj):
+            if torch._library.fake_class_registry._is_script_object(
+                obj
+            ) or is_opaque_type(obj):
                 fake_script_obj = _maybe_fakify_obj(obj)
                 for fqn in fqns:
                     cur_mod, attr = _leaf_mod_and_attr(mod, fqn)
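This hunk and the next several apply the same pattern: wherever a torch.ScriptObject input or constant is converted to a FakeScriptObject for tracing, instances of registered opaque types now take the same path. A minimal sketch of the shared check (the helper name is mine; in the actual diffs the check is inlined at each call site):

import torch
from torch._library.fake_class_registry import maybe_to_fake_obj
from torch._library.opaque_object import is_opaque_type
from torch._subclasses.fake_tensor import FakeTensorMode

def fakify_if_needed(fake_mode: FakeTensorMode, obj):
    # ScriptObjects and opaque-typed Python objects both become FakeScriptObjects for tracing.
    if isinstance(obj, torch.ScriptObject) or is_opaque_type(type(obj)):
        return maybe_to_fake_obj(fake_mode, obj)
    return obj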

View File

@@ -8,6 +8,7 @@ from typing import Any, Optional
 import torch
 import torch.utils._pytree as pytree
 from torch._guards import detect_fake_mode
+from torch._library.opaque_object import is_opaque_type
 from torch._subclasses import FakeTensor, FakeTensorMode
 from torch.fx.experimental.proxy_tensor import _pytree_subclasses_that_lose_info
 from torch.fx.experimental.symbolic_shapes import ShapeEnv
@@ -46,7 +47,7 @@ def process_inputs(
                    hint=x,
                    source=source,
                )
-            if isinstance(x, torch.ScriptObject):
+            if isinstance(x, torch.ScriptObject) or is_opaque_type(type(x)):
                return torch._library.fake_class_registry.maybe_to_fake_obj(
                    fake_mode, x
                )

View File

@@ -534,6 +534,7 @@ def create_aot_state(
         stack.enter_context(autograd_fallback_mode("error"))

     from torch._library.fake_class_registry import FakeScriptObject, maybe_to_fake_obj
+    from torch._library.opaque_object import is_opaque_type

     # Tracing may mutate the states the fake script object,
     # so we need to duplicate the fake script objects so that subsequent tracing
@@ -541,7 +542,7 @@ def create_aot_state(
     def _dup_fake_script_obj(fake_flat_args):
         return [
             maybe_to_fake_obj(detect_fake_mode(fake_flat_args), arg.real_obj)
-            if isinstance(arg, FakeScriptObject)
+            if isinstance(arg, FakeScriptObject) or is_opaque_type(type(arg))
             else arg
             for arg in fake_flat_args
         ]

View File

@@ -91,6 +91,7 @@ from torch._inductor.utils import (
     tensor_is_aligned,
 )
 from torch._library.fake_class_registry import FakeScriptObject
+from torch._library.opaque_object import is_opaque_type
 from torch._logging import trace_structured
 from torch._utils_internal import compile_time_strobelight_meta
 from torch.fx import GraphModule
@@ -2747,7 +2748,9 @@ def _compile_fx_main(
                 node.meta["val"] = fake_mode.from_tensor(
                     target, static_shapes=True
                 )
-            elif isinstance(target, torch.ScriptObject):
+            elif isinstance(target, torch.ScriptObject) or is_opaque_type(
+                type(target)
+            ):
                 node.meta["val"] = (
                     torch._library.fake_class_registry.maybe_to_fake_obj(
                         fake_mode, target

View File

@@ -1023,6 +1023,7 @@ class TorchBindOpOverload(OpOverload[_P, _T]):
             DispatchKey.BackendSelect,
             DispatchKey.PythonTLSSnapshot,
             DispatchKey.PythonDispatcher,
+            DispatchKey.Functionalize,
         ]

         def _may_use_fallthrough_instead_of_fallback(key: DispatchKey):
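This list appears to be the set of dispatch keys that a TorchBindOpOverload falls through by default, since ops whose inputs are FakeScriptObjects cannot go through the C++ dispatcher (it would reject the schema). With DispatchKey.Functionalize added, the C++ Functionalize kernel is skipped for these ops and functionalization instead happens in Python inside FunctionalTensorMode, as the final hunk below shows. A simplified paraphrase of that Python path (not a public API, just the unwrap-call-wrap shape of the new branch):

from torch._subclasses.functional_tensor import PythonFunctionalizeAPI

def functionalize_torchbind_op(func, args, kwargs):
    # Unwrap FunctionalTensors, run the op in Python, and re-wrap the outputs,
    # bypassing the C++ Functionalize kernel entirely.
    ctx = PythonFunctionalizeAPI()
    out = func(*ctx.unwrap_tensors(args), **ctx.unwrap_tensors(kwargs))
    return ctx.wrap_tensors(out)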

View File

@@ -11,7 +11,7 @@ import torch
 import torch.fx.traceback as fx_traceback
 import torch.utils._pytree as pytree
 from torch._C import _functionalization_reapply_views_tls as _reapply_views
-from torch._ops import _get_dispatch_mode_pre_dispatch
+from torch._ops import _get_dispatch_mode_pre_dispatch, TorchBindOpOverload
 from torch._subclasses.meta_utils import is_sparse_any
 from torch.utils._python_dispatch import (
     _detect_infra_mode,
@@ -504,65 +504,81 @@ class FunctionalTensorMode(TorchDispatchMode):
             - FunctionalTensor._extra_dispatch_keys
         )
-        # All we want to do here is reuse the existing C++ functionalization logic.
-        # This requires swizzling our TLS dispatch keys so that the Functionalize key is active.
-        with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
-            try:
-                # By default for python functionalization (for AOTAutograd), we reapply views.
-                old_apply_views = torch._functionalize_enable_reapply_views(True) # type: ignore[attr-defined]
+        if isinstance(func, TorchBindOpOverload):
+            # When the function is a TorchBindOpOverload, meaning some of the
+            # inputs are FakeScriptObjects, we need to skip c++ dispatcher and
+            # dispatch in python because C++ dispatcher will check the schema
+            # and cannot recognize FakeScriptObject.
+            ctx = PythonFunctionalizeAPI()
+            fully_unwrapped_args = ctx.unwrap_tensors(args)
+            fully_unwrapped_kwargs = ctx.unwrap_tensors(
+                kwargs # pyrefly: ignore[bad-argument-type]
+            )
+            outs_unwrapped = func(
+                *fully_unwrapped_args,
+                **fully_unwrapped_kwargs,
+            )
+            outs_wrapped = ctx.wrap_tensors(outs_unwrapped)
+        else:
+            # All we want to do here is reuse the existing C++ functionalization logic.
+            # This requires swizzling our TLS dispatch keys so that the Functionalize key is active.
+            with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
+                try:
+                    # By default for python functionalization (for AOTAutograd), we reapply views.
+                    old_apply_views = torch._functionalize_enable_reapply_views(True) # type: ignore[attr-defined]

-                # Sometimes these functions cannot be directly dispatched to functionalize key
-                # because args are sometimes not functional tensors for some reason?
-                if func in FunctionalTensor.metadata_fns:
-                    outs_unwrapped = func(*args_unwrapped, **kwargs_unwrapped)
-                    outs_wrapped = pytree.tree_map_only(
-                        torch.Tensor, wrap, outs_unwrapped
-                    )
-                else:
-                    # Note: [Functionalization View Replay Annotation]
-                    # When functionalization encounters a mutation, it handles aliases by lazily regenerating the aliases
-                    # at the first time they are next used.
-                    # This is a problem when plumbing user annotations during tracing. We want the view ops from view replay
-                    # to have the same annotation that the user specified on the original views. But view replay in
-                    # functionalization happens the next time the alias is used (e.g. second_op(alias_with_pending_mutation)),
-                    # so when we regenerate views before calling into second_op, those views will end up getting the metadata
-                    # for second_op!
-                    #
-                    # Instead, we need to remember the node metadata from the original views, and ensure that this node metadata
-                    # is globally set when we lazily perform view replay.
-                    # The globally set metadata will be used to populate the fx node created for the replayed operation.
-                    if m := torch._C._get_dispatch_mode(
-                        torch._C._TorchDispatchModeKey.PROXY
-                    ):
-                        for a in pytree.tree_leaves([args, kwargs]):
-                            if not isinstance(a, FunctionalTensor):
-                                continue
-                            curr_node = m.tracer.tensor_tracker[
-                                torch._from_functional_tensor(a.elem)
-                            ].proxy.node
-                            with fx_traceback.set_current_replay_node(curr_node):
-                                torch._sync(a)
+                    # Sometimes these functions cannot be directly dispatched to functionalize key
+                    # because args are sometimes not functional tensors for some reason?
+                    if func in FunctionalTensor.metadata_fns:
+                        outs_unwrapped = func(*args_unwrapped, **kwargs_unwrapped)
+                        outs_wrapped = pytree.tree_map_only(
+                            torch.Tensor, wrap, outs_unwrapped
+                        )
+                    else:
+                        # Note: [Functionalization View Replay Annotation]
+                        # When functionalization encounters a mutation, it handles aliases by lazily regenerating the aliases
+                        # at the first time they are next used.
+                        # This is a problem when plumbing user annotations during tracing. We want the view ops from view replay
+                        # to have the same annotation that the user specified on the original views. But view replay in
+                        # functionalization happens the next time the alias is used (e.g. second_op(alias_with_pending_mutation)),
+                        # so when we regenerate views before calling into second_op, those views will end up getting the metadata
+                        # for second_op!
+                        #
+                        # Instead, we need to remember the node metadata from the original views, and ensure that this node metadata
+                        # is globally set when we lazily perform view replay.
+                        # The globally set metadata will be used to populate the fx node created for the replayed operation.
+                        if m := torch._C._get_dispatch_mode(
+                            torch._C._TorchDispatchModeKey.PROXY
+                        ):
+                            for a in pytree.tree_leaves([args, kwargs]):
+                                if not isinstance(a, FunctionalTensor):
+                                    continue
+                                curr_node = m.tracer.tensor_tracker[
+                                    torch._from_functional_tensor(a.elem)
+                                ].proxy.node
+                                with fx_traceback.set_current_replay_node(curr_node):
+                                    torch._sync(a)

-                    # When we dispatch to the C++ functionalization kernel, we might need to jump back to the
-                    # PreDispatch mode stack afterwards, to handle any other PreDispatch modes underneath
-                    # FunctionalTensorMode. If we call func() directly, we would need to exclude PreDispatch
-                    # from the TLS in order to avoid infinite looping, but this would prevent us from coming
-                    # back to PreDispatch later
-                    outs_unwrapped = func._op_dk(
-                        torch._C.DispatchKey.Functionalize,
-                        *args_unwrapped,
-                        **kwargs_unwrapped,
-                    )
+                        # When we dispatch to the C++ functionalization kernel, we might need to jump back to the
+                        # PreDispatch mode stack afterwards, to handle any other PreDispatch modes underneath
+                        # FunctionalTensorMode. If we call func() directly, we would need to exclude PreDispatch
+                        # from the TLS in order to avoid infinite looping, but this would prevent us from coming
+                        # back to PreDispatch later
+                        outs_unwrapped = func._op_dk(
+                            torch._C.DispatchKey.Functionalize,
+                            *args_unwrapped,
+                            **kwargs_unwrapped,
+                        )

-                    if self.export:
-                        if func is torch.ops.aten.dropout.default:
-                            torch._freeze_functional_tensor(outs_unwrapped) # type: ignore[attr-defined]
-                    outs_wrapped = pytree.tree_map_only(
-                        torch.Tensor, wrap, outs_unwrapped
-                    )
-            finally:
-                torch._disable_functionalization()
-                torch._functionalize_enable_reapply_views(old_apply_views) # type: ignore[attr-defined]
+                        if self.export:
+                            if func is torch.ops.aten.dropout.default:
+                                torch._freeze_functional_tensor(outs_unwrapped) # type: ignore[attr-defined]
+                        outs_wrapped = pytree.tree_map_only(
+                            torch.Tensor, wrap, outs_unwrapped
+                        )
+                finally:
+                    torch._disable_functionalization()
+                    torch._functionalize_enable_reapply_views(old_apply_views) # type: ignore[attr-defined]

         is_included = torch._C._dispatch_tls_is_dispatch_key_included(
             torch._C.DispatchKey.Functionalize