diff --git a/test/export/test_export.py b/test/export/test_export.py
index 543325b55346..838d758090c6 100755
--- a/test/export/test_export.py
+++ b/test/export/test_export.py
@@ -13592,6 +13592,30 @@ class GraphModule(torch.nn.Module):
         )
         FileCheck().check_count(op_name, 1, exactly=True).run(ep.graph_module.code)
 
+    def test_wrapper_module(self):
+        def f(x):
+            return torch.abs(x)
+
+        from torch.export import _wrapper_utils
+
+        model = _wrapper_utils._WrapperModule(f)
+        ep = export(
+            model,
+            (
+                torch.randn(
+                    8,
+                ),
+            ),
+        )
+
+        self.assertExpectedInline(
+            str(ep.graph_module.code).strip(),
+            """\
+def forward(self, args_0):
+    abs_1 = torch.ops.aten.abs.default(args_0);  args_0 = None
+    return (abs_1,)""",
+        )
+
 
 @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
 class TestOneOffModelExportResult(TestCase):
diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py
index beb4a0c71836..b22a372179a7 100644
--- a/torch/_dynamo/trace_rules.py
+++ b/torch/_dynamo/trace_rules.py
@@ -3297,6 +3297,7 @@ MOD_INLINELIST = [
     "torch.cuda.amp.autocast_mode",
     "torch.distributions",
     "torch.export._tree_utils",
+    "torch.export._wrapper_utils",
     "torch.fx._pytree",
     "torch.fx._symbolic_trace",
     "torch.fx.experimental.proxy_tensor",
diff --git a/torch/export/_trace.py b/torch/export/_trace.py
index 72269acc2625..ab92d7a575ed 100644
--- a/torch/export/_trace.py
+++ b/torch/export/_trace.py
@@ -90,6 +90,7 @@ from torch.utils._pytree import TreeSpec
 from torch.utils._sympy.value_ranges import ValueRangeError
 
 from ._safeguard import AutogradStateOpsFailSafeguard
+from ._wrapper_utils import _WrapperModule
 from .exported_program import (
     _disable_prexisiting_fake_mode,
     ExportedProgram,
@@ -1299,15 +1300,6 @@ def _temp_disable_texpr_fuser():
         torch._C._jit_set_texpr_fuser_enabled(original_state)
 
 
-class _WrapperModule(torch.nn.Module):
-    def __init__(self, f):
-        super().__init__()
-        self.f = f
-
-    def forward(self, *args, **kwargs):
-        return self.f(*args, **kwargs)
-
-
 def _convert_ts_to_export_experimental(traced_callable, args, kwargs=None):
     with _temp_disable_texpr_fuser():
         from torch.jit._trace import TopLevelTracedModule
diff --git a/torch/export/_wrapper_utils.py b/torch/export/_wrapper_utils.py
new file mode 100644
index 000000000000..bc27a8575a0a
--- /dev/null
+++ b/torch/export/_wrapper_utils.py
@@ -0,0 +1,10 @@
+import torch
+
+
+class _WrapperModule(torch.nn.Module):
+    def __init__(self, f):  # type: ignore[no-untyped-def]
+        super().__init__()
+        self.f = f
+
+    def forward(self, *args, **kwargs):  # type: ignore[no-untyped-def]
+        return self.f(*args, **kwargs)
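
Taken together, the patch moves _WrapperModule out of torch/export/_trace.py into a new torch/export/_wrapper_utils.py module and registers that module in Dynamo's MOD_INLINELIST, so Dynamo inlines the wrapper's forward during tracing rather than treating it as an opaque call. A minimal sketch of the pattern the new test exercises follows; the function name and input shape are illustrative, not part of the patch:

    import torch
    from torch.export import export
    from torch.export._wrapper_utils import _WrapperModule

    # A plain function is not an nn.Module, so export() cannot take it
    # directly; _WrapperModule adapts it by forwarding *args/**kwargs to f.
    def f(x):
        return torch.abs(x)

    ep = export(_WrapperModule(f), (torch.randn(8),))
    # The captured graph calls torch.ops.aten.abs.default, matching the
    # assertExpectedInline output in test_wrapper_module above.
    print(ep.graph_module.code)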