[reland] Make export._trace._WrapperModule work in strict mode (#146919) (#151264)

Summary:

As titled.

`export._trace._WrapperModule` wraps a plain function in an `nn.Module` so that the function can be passed to `torch.export`.

The class now lives in a new module, `torch.export._wrapper_utils`, which we add to dynamo's `MOD_INLINELIST` so that dynamo traces into `_WrapperModule.forward` rather than skipping it in strict mode.

Fixes https://github.com/pytorch/pytorch/issues/146867
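For context, the intended usage looks roughly like this (a minimal sketch mirroring the new test added below):

```python
import torch
from torch.export import export, _wrapper_utils


def f(x):
    return torch.abs(x)


# Wrap the plain function in a Module so torch.export can take it.
model = _wrapper_utils._WrapperModule(f)

# With this change, strict-mode export traces into _WrapperModule.forward
# instead of failing on the wrapper.
ep = export(model, (torch.randn(8),))
print(ep.graph_module.code)
```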

Test Plan:
```
buck run fbcode//mode/dev-nosan //caffe2/test:test_export -- -r wrapper_module
```

Differential Revision: D72986826

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151264
Approved by: https://github.com/angelayi
Author: Shangdi Yu
Date: 2025-04-15 18:35:34 +00:00
Committed by: PyTorch MergeBot
Commit: 83d88d128d (parent 61f127aac5)
4 changed files with 36 additions and 9 deletions


```diff
@@ -13592,6 +13592,30 @@ class GraphModule(torch.nn.Module):
         )
         FileCheck().check_count(op_name, 1, exactly=True).run(ep.graph_module.code)
 
+    def test_wrapper_module(self):
+        def f(x):
+            return torch.abs(x)
+
+        from torch.export import _wrapper_utils
+
+        model = _wrapper_utils._WrapperModule(f)
+        ep = export(
+            model,
+            (
+                torch.randn(
+                    8,
+                ),
+            ),
+        )
+
+        self.assertExpectedInline(
+            str(ep.graph_module.code).strip(),
+            """\
+def forward(self, args_0):
+    abs_1 = torch.ops.aten.abs.default(args_0); args_0 = None
+    return (abs_1,)""",
+        )
+
 
 @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
 class TestOneOffModelExportResult(TestCase):
```


```diff
@@ -3297,6 +3297,7 @@ MOD_INLINELIST = [
     "torch.cuda.amp.autocast_mode",
     "torch.distributions",
     "torch.export._tree_utils",
+    "torch.export._wrapper_utils",
     "torch.fx._pytree",
     "torch.fx._symbolic_trace",
     "torch.fx.experimental.proxy_tensor",
```


```diff
@@ -90,6 +90,7 @@ from torch.utils._pytree import TreeSpec
 from torch.utils._sympy.value_ranges import ValueRangeError
 
 from ._safeguard import AutogradStateOpsFailSafeguard
+from ._wrapper_utils import _WrapperModule
 from .exported_program import (
     _disable_prexisiting_fake_mode,
     ExportedProgram,
@@ -1299,15 +1300,6 @@ def _temp_disable_texpr_fuser():
         torch._C._jit_set_texpr_fuser_enabled(original_state)
 
-
-class _WrapperModule(torch.nn.Module):
-    def __init__(self, f):
-        super().__init__()
-        self.f = f
-
-    def forward(self, *args, **kwargs):
-        return self.f(*args, **kwargs)
-
 def _convert_ts_to_export_experimental(traced_callable, args, kwargs=None):
     with _temp_disable_texpr_fuser():
         from torch.jit._trace import TopLevelTracedModule
```


```diff
@@ -0,0 +1,10 @@
+import torch
+
+
+class _WrapperModule(torch.nn.Module):
+    def __init__(self, f):  # type: ignore[no-untyped-def]
+        super().__init__()
+        self.f = f
+
+    def forward(self, *args, **kwargs):  # type: ignore[no-untyped-def]
+        return self.f(*args, **kwargs)
```