mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Summary: as title `export._trace._WrapperModule` is used to wrap functions into a Module so we can export the function. We add `export._wrapper_utils` to `dynamo`'s `MOD_INLINELIST` so dynamo traces into `_WrapperModule` Fixes https://github.com/pytorch/pytorch/issues/146867 Test Plan: ``` buck run fbcode//mode/dev-nosan //caffe2/test:test_export -- -r wrapper_module ``` Differential Revision: D72986826 Pull Request resolved: https://github.com/pytorch/pytorch/pull/151264 Approved by: https://github.com/angelayi
This commit is contained in:
committed by
PyTorch MergeBot
parent
61f127aac5
commit
83d88d128d
@ -13592,6 +13592,30 @@ class GraphModule(torch.nn.Module):
|
||||
)
|
||||
FileCheck().check_count(op_name, 1, exactly=True).run(ep.graph_module.code)
|
||||
|
||||
def test_wrapper_module(self):
|
||||
def f(x):
|
||||
return torch.abs(x)
|
||||
|
||||
from torch.export import _wrapper_utils
|
||||
|
||||
model = _wrapper_utils._WrapperModule(f)
|
||||
ep = export(
|
||||
model,
|
||||
(
|
||||
torch.randn(
|
||||
8,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
self.assertExpectedInline(
|
||||
str(ep.graph_module.code).strip(),
|
||||
"""\
|
||||
def forward(self, args_0):
|
||||
abs_1 = torch.ops.aten.abs.default(args_0); args_0 = None
|
||||
return (abs_1,)""",
|
||||
)
|
||||
|
||||
|
||||
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
|
||||
class TestOneOffModelExportResult(TestCase):
|
||||
|
@ -3297,6 +3297,7 @@ MOD_INLINELIST = [
|
||||
"torch.cuda.amp.autocast_mode",
|
||||
"torch.distributions",
|
||||
"torch.export._tree_utils",
|
||||
"torch.export._wrapper_utils",
|
||||
"torch.fx._pytree",
|
||||
"torch.fx._symbolic_trace",
|
||||
"torch.fx.experimental.proxy_tensor",
|
||||
|
@ -90,6 +90,7 @@ from torch.utils._pytree import TreeSpec
|
||||
from torch.utils._sympy.value_ranges import ValueRangeError
|
||||
|
||||
from ._safeguard import AutogradStateOpsFailSafeguard
|
||||
from ._wrapper_utils import _WrapperModule
|
||||
from .exported_program import (
|
||||
_disable_prexisiting_fake_mode,
|
||||
ExportedProgram,
|
||||
@ -1299,15 +1300,6 @@ def _temp_disable_texpr_fuser():
|
||||
torch._C._jit_set_texpr_fuser_enabled(original_state)
|
||||
|
||||
|
||||
class _WrapperModule(torch.nn.Module):
|
||||
def __init__(self, f):
|
||||
super().__init__()
|
||||
self.f = f
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.f(*args, **kwargs)
|
||||
|
||||
|
||||
def _convert_ts_to_export_experimental(traced_callable, args, kwargs=None):
|
||||
with _temp_disable_texpr_fuser():
|
||||
from torch.jit._trace import TopLevelTracedModule
|
||||
|
10
torch/export/_wrapper_utils.py
Normal file
10
torch/export/_wrapper_utils.py
Normal file
@ -0,0 +1,10 @@
|
||||
import torch
|
||||
|
||||
|
||||
class _WrapperModule(torch.nn.Module):
|
||||
def __init__(self, f): # type: ignore[no-untyped-def]
|
||||
super().__init__()
|
||||
self.f = f
|
||||
|
||||
def forward(self, *args, **kwargs): # type: ignore[no-untyped-def]
|
||||
return self.f(*args, **kwargs)
|
Reference in New Issue
Block a user