Move export_db to use new tracer, remove restriction on optional inputs (#162993)

Differential Revision: [D82478644](https://our.internmc.facebook.com/intern/diff/D82478644)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162993
Approved by: https://github.com/zhxchen17
ghstack dependencies: #162557, #162558, #162559, #162682, #162992
This commit is contained in:
Tugsbayasgalan Manlaibaatar
2025-09-16 14:42:35 -07:00
committed by PyTorch MergeBot
parent b26d4c9a7a
commit 72fedf0575
3 changed files with 20 additions and 16 deletions

View File

@ -4,6 +4,7 @@ import copy
import unittest
import torch._dynamo as torchdynamo
from torch._export import config
from torch._export.db.case import ExportCase, SupportLevel
from torch._export.db.examples import (
filter_examples_by_support_level,
@ -35,13 +36,14 @@ class ExampleTests(TestCase):
kwargs_export = case.example_kwargs
args_model = copy.deepcopy(args_export)
kwargs_model = copy.deepcopy(kwargs_export)
exported_program = export(
model,
args_export,
kwargs_export,
dynamic_shapes=case.dynamic_shapes,
strict=True,
)
with config.patch(use_new_tracer_experimental=True):
exported_program = export(
model,
case.example_args,
case.example_kwargs,
dynamic_shapes=case.dynamic_shapes,
strict=True,
)
exported_program.graph_module.print_readable()
self.assertEqual(
@ -68,13 +70,14 @@ class ExampleTests(TestCase):
with self.assertRaises(
(torchdynamo.exc.Unsupported, AssertionError, RuntimeError)
):
export(
model,
case.example_args,
case.example_kwargs,
dynamic_shapes=case.dynamic_shapes,
strict=True,
)
with config.patch(use_new_tracer_experimental=True):
_ = export(
model,
case.example_args,
case.example_kwargs,
dynamic_shapes=case.dynamic_shapes,
strict=True,
)
exportdb_not_supported_rewrite_cases = [
(name, rewrite_case)

View File

@ -1612,7 +1612,8 @@ def forward(self, x):
def test_exportdb_supported(self, name: str, case: ExportCase) -> None:
model = case.model
_check_meta = "map" not in name
self.check_graph(model, case.example_args, _check_meta=_check_meta)
with torch._export.config.patch(use_new_tracer_experimental=True):
self.check_graph(model, case.example_args, _check_meta=_check_meta)
def test_constraints(self):
class Module(torch.nn.Module):

View File

@ -16,5 +16,5 @@ class OptionalInput(torch.nn.Module):
example_args = (torch.randn(2, 3),)
tags = {"python.object-model"}
support_level = SupportLevel.NOT_SUPPORTED_YET
support_level = SupportLevel.SUPPORTED
model = OptionalInput()