[opaque obj] Allow non-effectful scriptobjs (#163714)
Fixes functionalization so that ops taking ScriptObjects can run without being registered as effectful. Previously, running functionalization on a TorchBindOpOverload would raise an error.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163714
Approved by: https://github.com/zou3519
ghstack dependencies: #163284
Committed by: PyTorch MergeBot
Parent: 35571fe94b
Commit: c9b09a31e8
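For orientation, here is the gist of the change, condensed from the test_aot_export test added below. This is a hedged sketch, not a standalone script: RNGState and the _TestOpaqueObject::noisy_inject op are fixtures registered elsewhere in the test file. The point is that tracing a graph that calls a TorchBind/opaque-object op no longer requires marking the op as effectful, while opting in to effects still threads with_effects tokens through the graph.

import torch
from torch._functorch.aot_autograd import aot_export_module
from torch._library.effects import EffectType

class Model(torch.nn.Module):
    def forward(self, rng_state, x):
        # noisy_inject takes a script/opaque object argument (rng_state).
        x = torch.ops._TestOpaqueObject.noisy_inject(x, rng_state)
        return (x * x,)

fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()
fake_rng = torch._library.fake_class_registry.maybe_to_fake_obj(fake_mode, RNGState(0))
fake_x = fake_mode.from_tensor(torch.ones(2, 3))

# After this change: no effect registration is needed; functionalization handles
# the TorchBind op directly and no with_effects tokens appear in the graph.
gm = aot_export_module(Model(), (fake_rng, fake_x), trace_joint=False)[0]

# Opting in to ordered effects is still possible and threads tokens as before.
torch.library._register_effectful_op(
    "_TestOpaqueObject::noisy_inject", EffectType.ORDERED
)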
@@ -90,7 +90,7 @@ class TestOpaqueObject(TestCase):
                 # This is not accurate since the queue could have tensors that are
                 # not rank 1
                 ctx = torch._custom_op.impl.get_ctx()
-                u0 = ctx.create_unbacked_symint()
+                u0 = ctx.new_dynamic_size()
                 return torch.empty(u0)
 
         self.lib._register_fake("queue_pop", pop_impl_fake)
@@ -107,8 +107,7 @@ class TestOpaqueObject(TestCase):
         @size_impl.register_fake
         def size_impl_fake(q: torch._C.ScriptObject) -> int:
             ctx = torch._custom_op.impl.get_ctx()
-            u0 = ctx.create_unbacked_symint()
-            torch._check_is_size(u0)
+            u0 = ctx.new_dynamic_size()
             return u0
 
         super().setUp()
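For context on the two test hunks above: the fake kernels now allocate data-dependent output sizes with ctx.new_dynamic_size() instead of the older create_unbacked_symint() spelling. A minimal sketch of that pattern (names mirror the test; the queue_pop op and self.lib registration live in the surrounding test file):

import torch

def pop_impl_fake(q) -> torch.Tensor:
    # Inside a registered fake impl, get_ctx() returns the fake-kernel context;
    # new_dynamic_size() allocates an unbacked symbolic size, since the real
    # output length depends on the queue's runtime contents.
    ctx = torch._custom_op.impl.get_ctx()
    u0 = ctx.new_dynamic_size()
    return torch.empty(u0)

# Registered against the test's library elsewhere, e.g.:
# self.lib._register_fake("queue_pop", pop_impl_fake)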
@@ -4,6 +4,8 @@ import random
 
 import torch
 from torch._dynamo.test_case import run_tests, TestCase
+from torch._functorch.aot_autograd import aot_export_module
+from torch._library.effects import EffectType
 from torch._library.fake_class_registry import FakeScriptObject
 from torch._library.opaque_object import register_opaque_type
 from torch.fx.experimental.proxy_tensor import make_fx
@@ -233,6 +235,65 @@ def forward(self, arg0_1, arg1_1):
             ):
                 make_fx(f, tracing_mode=make_fx_tracing_mode)(RNGState(0), torch.ones(3))
 
+    def test_aot_export(self):
+        class Model(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+
+            def forward(self, rng_state, x):
+                x = torch.ops._TestOpaqueObject.noisy_inject(x, rng_state)
+                x = x * x
+                x = torch.ops._TestOpaqueObject.noisy_inject(x, rng_state)
+                x = x + x
+                return (x,)
+
+        mod = Model()
+        rng = RNGState(0)
+        x = torch.ones(2, 3)
+
+        fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()
+        fake_rng = torch._library.fake_class_registry.maybe_to_fake_obj(fake_mode, rng)
+        fake_x = fake_mode.from_tensor(x)
+        gm = aot_export_module(mod, (fake_rng, fake_x), trace_joint=False)[0]
+
+        # By default we don't register ops containing PyObjs as being effectful
+        self.assertExpectedInline(
+            gm.code.strip(),
+            """\
+def forward(self, arg0_1, arg1_1):
+    noisy_inject = torch.ops._TestOpaqueObject.noisy_inject.default(arg1_1, arg0_1); arg1_1 = None
+    mul = torch.ops.aten.mul.Tensor(noisy_inject, noisy_inject); noisy_inject = None
+    noisy_inject_1 = torch.ops._TestOpaqueObject.noisy_inject.default(mul, arg0_1); mul = arg0_1 = None
+    add = torch.ops.aten.add.Tensor(noisy_inject_1, noisy_inject_1); noisy_inject_1 = None
+    return (add,)""", # noqa: B950
+        )
+
+        torch.library._register_effectful_op(
+            "_TestOpaqueObject::noisy_inject", EffectType.ORDERED
+        )
+        try:
+            gm = aot_export_module(mod, (rng, fake_x), trace_joint=False)[0]
+            # inputs: token, rng, x
+            # return: token, res
+            self.assertExpectedInline(
+                gm.code.strip(),
+                """\
+def forward(self, arg0_1, arg1_1, arg2_1):
+    with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops._TestOpaqueObject.noisy_inject.default, arg2_1, arg1_1); arg0_1 = arg2_1 = None
+    getitem = with_effects[0]
+    getitem_1 = with_effects[1]; with_effects = None
+    mul = torch.ops.aten.mul.Tensor(getitem_1, getitem_1); getitem_1 = None
+    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops._TestOpaqueObject.noisy_inject.default, mul, arg1_1); getitem = mul = arg1_1 = None
+    getitem_2 = with_effects_1[0]
+    getitem_3 = with_effects_1[1]; with_effects_1 = None
+    add = torch.ops.aten.add.Tensor(getitem_3, getitem_3); getitem_3 = None
+    return (getitem_2, add)""", # noqa: B950
+            )
+        finally:
+            torch.library._register_effectful_op(
+                "_TestOpaqueObject::noisy_inject", None
+            )
+
 
 instantiate_parametrized_tests(TestOpaqueObject)
 
 
@@ -24,6 +24,7 @@ from torch._export.passes.lift_constants_pass import ConstantAttrMap
 from torch._export.utils import _fakify_params_buffers
 from torch._guards import Source
 from torch._library.fake_class_registry import FakeScriptObject
+from torch._library.opaque_object import is_opaque_type
 from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.export import Constraint
 from torch.export.dynamic_shapes import (
@@ -946,7 +947,9 @@ def _fakify_script_objects(
 
     try:
         for obj, fqns in constant_attrs.items():
-            if torch._library.fake_class_registry._is_script_object(obj):
+            if torch._library.fake_class_registry._is_script_object(
+                obj
+            ) or is_opaque_type(obj):
                 fake_script_obj = _maybe_fakify_obj(obj)
                 for fqn in fqns:
                     cur_mod, attr = _leaf_mod_and_attr(mod, fqn)
@@ -8,6 +8,7 @@ from typing import Any, Optional
 import torch
 import torch.utils._pytree as pytree
 from torch._guards import detect_fake_mode
+from torch._library.opaque_object import is_opaque_type
 from torch._subclasses import FakeTensor, FakeTensorMode
 from torch.fx.experimental.proxy_tensor import _pytree_subclasses_that_lose_info
 from torch.fx.experimental.symbolic_shapes import ShapeEnv
@@ -46,7 +47,7 @@ def process_inputs(
                 hint=x,
                 source=source,
             )
-        if isinstance(x, torch.ScriptObject):
+        if isinstance(x, torch.ScriptObject) or is_opaque_type(type(x)):
             return torch._library.fake_class_registry.maybe_to_fake_obj(
                 fake_mode, x
             )
@@ -534,6 +534,7 @@ def create_aot_state(
     stack.enter_context(autograd_fallback_mode("error"))
 
     from torch._library.fake_class_registry import FakeScriptObject, maybe_to_fake_obj
+    from torch._library.opaque_object import is_opaque_type
 
     # Tracing may mutate the states the fake script object,
     # so we need to duplicate the fake script objects so that subsequent tracing
@@ -541,7 +542,7 @@ def create_aot_state(
     def _dup_fake_script_obj(fake_flat_args):
         return [
             maybe_to_fake_obj(detect_fake_mode(fake_flat_args), arg.real_obj)
-            if isinstance(arg, FakeScriptObject)
+            if isinstance(arg, FakeScriptObject) or is_opaque_type(type(arg))
            else arg
            for arg in fake_flat_args
        ]
@@ -91,6 +91,7 @@ from torch._inductor.utils import (
     tensor_is_aligned,
 )
 from torch._library.fake_class_registry import FakeScriptObject
+from torch._library.opaque_object import is_opaque_type
 from torch._logging import trace_structured
 from torch._utils_internal import compile_time_strobelight_meta
 from torch.fx import GraphModule
@@ -2747,7 +2748,9 @@ def _compile_fx_main(
                     node.meta["val"] = fake_mode.from_tensor(
                         target, static_shapes=True
                     )
-                elif isinstance(target, torch.ScriptObject):
+                elif isinstance(target, torch.ScriptObject) or is_opaque_type(
+                    type(target)
+                ):
                     node.meta["val"] = (
                         torch._library.fake_class_registry.maybe_to_fake_obj(
                             fake_mode, target
@@ -1023,6 +1023,7 @@ class TorchBindOpOverload(OpOverload[_P, _T]):
         DispatchKey.BackendSelect,
         DispatchKey.PythonTLSSnapshot,
         DispatchKey.PythonDispatcher,
+        DispatchKey.Functionalize,
     ]
 
     def _may_use_fallthrough_instead_of_fallback(key: DispatchKey):
@@ -11,7 +11,7 @@ import torch
 import torch.fx.traceback as fx_traceback
 import torch.utils._pytree as pytree
 from torch._C import _functionalization_reapply_views_tls as _reapply_views
-from torch._ops import _get_dispatch_mode_pre_dispatch
+from torch._ops import _get_dispatch_mode_pre_dispatch, TorchBindOpOverload
 from torch._subclasses.meta_utils import is_sparse_any
 from torch.utils._python_dispatch import (
     _detect_infra_mode,
@@ -504,65 +504,81 @@ class FunctionalTensorMode(TorchDispatchMode):
             - FunctionalTensor._extra_dispatch_keys
         )
 
-        # All we want to do here is reuse the existing C++ functionalization logic.
-        # This requires swizzling our TLS dispatch keys so that the Functionalize key is active.
-        with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
-            try:
-                # By default for python functionalization (for AOTAutograd), we reapply views.
-                old_apply_views = torch._functionalize_enable_reapply_views(True)  # type: ignore[attr-defined]
+        if isinstance(func, TorchBindOpOverload):
+            # When the function is a TorchBindOpOverload, meaning some of the
+            # inputs are FakeScriptObjects, we need to skip c++ dispatcher and
+            # dispatch in python because C++ dispatcher will check the schema
+            # and cannot recognize FakeScriptObject.
+            ctx = PythonFunctionalizeAPI()
+            fully_unwrapped_args = ctx.unwrap_tensors(args)
+            fully_unwrapped_kwargs = ctx.unwrap_tensors(
+                kwargs  # pyrefly: ignore[bad-argument-type]
+            )
+            outs_unwrapped = func(
+                *fully_unwrapped_args,
+                **fully_unwrapped_kwargs,
+            )
+            outs_wrapped = ctx.wrap_tensors(outs_unwrapped)
+        else:
+            # All we want to do here is reuse the existing C++ functionalization logic.
+            # This requires swizzling our TLS dispatch keys so that the Functionalize key is active.
+            with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
+                try:
+                    # By default for python functionalization (for AOTAutograd), we reapply views.
+                    old_apply_views = torch._functionalize_enable_reapply_views(True)  # type: ignore[attr-defined]
 
-                # Sometimes these functions cannot be directly dispatched to functionalize key
-                # because args are sometimes not functional tensors for some reason?
-                if func in FunctionalTensor.metadata_fns:
-                    outs_unwrapped = func(*args_unwrapped, **kwargs_unwrapped)
-                    outs_wrapped = pytree.tree_map_only(
-                        torch.Tensor, wrap, outs_unwrapped
-                    )
-                else:
-                    # Note: [Functionalization View Replay Annotation]
-                    # When functionalization encounters a mutation, it handles aliases by lazily regenerating the aliases
-                    # at the first time they are next used.
-                    # This is a problem when plumbing user annotations during tracing. We want the view ops from view replay
-                    # to have the same annotation that the user specified on the original views. But view replay in
-                    # functionalization happens the next time the alias is used (e.g. second_op(alias_with_pending_mutation)),
-                    # so when we regenerate views before calling into second_op, those views will end up getting the metadata
-                    # for second_op!
-                    #
-                    # Instead, we need to remember the node metadata from the original views, and ensure that this node metadata
-                    # is globally set when we lazily perform view replay.
-                    # The globally set metadata will be used to populate the fx node created for the replayed operation.
-                    if m := torch._C._get_dispatch_mode(
-                        torch._C._TorchDispatchModeKey.PROXY
-                    ):
-                        for a in pytree.tree_leaves([args, kwargs]):
-                            if not isinstance(a, FunctionalTensor):
-                                continue
-                            curr_node = m.tracer.tensor_tracker[
-                                torch._from_functional_tensor(a.elem)
-                            ].proxy.node
-                            with fx_traceback.set_current_replay_node(curr_node):
-                                torch._sync(a)
+                    # Sometimes these functions cannot be directly dispatched to functionalize key
+                    # because args are sometimes not functional tensors for some reason?
+                    if func in FunctionalTensor.metadata_fns:
+                        outs_unwrapped = func(*args_unwrapped, **kwargs_unwrapped)
+                        outs_wrapped = pytree.tree_map_only(
+                            torch.Tensor, wrap, outs_unwrapped
+                        )
+                    else:
+                        # Note: [Functionalization View Replay Annotation]
+                        # When functionalization encounters a mutation, it handles aliases by lazily regenerating the aliases
+                        # at the first time they are next used.
+                        # This is a problem when plumbing user annotations during tracing. We want the view ops from view replay
+                        # to have the same annotation that the user specified on the original views. But view replay in
+                        # functionalization happens the next time the alias is used (e.g. second_op(alias_with_pending_mutation)),
+                        # so when we regenerate views before calling into second_op, those views will end up getting the metadata
+                        # for second_op!
+                        #
+                        # Instead, we need to remember the node metadata from the original views, and ensure that this node metadata
+                        # is globally set when we lazily perform view replay.
+                        # The globally set metadata will be used to populate the fx node created for the replayed operation.
+                        if m := torch._C._get_dispatch_mode(
+                            torch._C._TorchDispatchModeKey.PROXY
+                        ):
+                            for a in pytree.tree_leaves([args, kwargs]):
+                                if not isinstance(a, FunctionalTensor):
+                                    continue
+                                curr_node = m.tracer.tensor_tracker[
+                                    torch._from_functional_tensor(a.elem)
+                                ].proxy.node
+                                with fx_traceback.set_current_replay_node(curr_node):
+                                    torch._sync(a)
 
-                    # When we dispatch to the C++ functionalization kernel, we might need to jump back to the
-                    # PreDispatch mode stack afterwards, to handle any other PreDispatch modes underneath
-                    # FunctionalTensorMode. If we call func() directly, we would need to exclude PreDispatch
-                    # from the TLS in order to avoid infinite looping, but this would prevent us from coming
-                    # back to PreDispatch later
-                    outs_unwrapped = func._op_dk(
-                        torch._C.DispatchKey.Functionalize,
-                        *args_unwrapped,
-                        **kwargs_unwrapped,
-                    )
+                        # When we dispatch to the C++ functionalization kernel, we might need to jump back to the
+                        # PreDispatch mode stack afterwards, to handle any other PreDispatch modes underneath
+                        # FunctionalTensorMode. If we call func() directly, we would need to exclude PreDispatch
+                        # from the TLS in order to avoid infinite looping, but this would prevent us from coming
+                        # back to PreDispatch later
+                        outs_unwrapped = func._op_dk(
+                            torch._C.DispatchKey.Functionalize,
+                            *args_unwrapped,
+                            **kwargs_unwrapped,
+                        )
 
-                if self.export:
-                    if func is torch.ops.aten.dropout.default:
-                        torch._freeze_functional_tensor(outs_unwrapped)  # type: ignore[attr-defined]
-                    outs_wrapped = pytree.tree_map_only(
-                        torch.Tensor, wrap, outs_unwrapped
-                    )
-            finally:
-                torch._disable_functionalization()
-                torch._functionalize_enable_reapply_views(old_apply_views)  # type: ignore[attr-defined]
+                    if self.export:
+                        if func is torch.ops.aten.dropout.default:
+                            torch._freeze_functional_tensor(outs_unwrapped)  # type: ignore[attr-defined]
+                        outs_wrapped = pytree.tree_map_only(
+                            torch.Tensor, wrap, outs_unwrapped
+                        )
+                finally:
+                    torch._disable_functionalization()
+                    torch._functionalize_enable_reapply_views(old_apply_views)  # type: ignore[attr-defined]
 
         is_included = torch._C._dispatch_tls_is_dispatch_key_included(
             torch._C.DispatchKey.Functionalize
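The heart of the change is the FunctionalTensorMode hunk above: TorchBind ops carrying FakeScriptObject inputs cannot go through the C++ Functionalize kernel, because the C++ dispatcher would try to schema-check the FakeScriptObject, so they are now unwrapped, invoked, and re-wrapped in Python. A simplified restatement of that control flow (a sketch assuming PythonFunctionalizeAPI is importable from torch._subclasses.functional_tensor; the real method also handles view replay, export freezing, and TLS key swizzling as shown in the diff):

from torch._ops import TorchBindOpOverload
from torch._subclasses.functional_tensor import PythonFunctionalizeAPI

def functionalize_dispatch_sketch(func, args, kwargs):
    if isinstance(func, TorchBindOpOverload):
        # Skip the C++ dispatcher: unwrap FunctionalTensors in Python, call the
        # op (FakeScriptObject args pass through untouched), then re-wrap.
        ctx = PythonFunctionalizeAPI()
        outs = func(*ctx.unwrap_tensors(args), **ctx.unwrap_tensors(kwargs))
        return ctx.wrap_tensors(outs)
    # All other ops keep using the existing C++ Functionalize kernel via
    # func._op_dk(torch._C.DispatchKey.Functionalize, ...), unchanged by this PR.
    raise NotImplementedError("non-TorchBind path: see the diff above")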