mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Move export_db to use new tracer, remove restriction on optional inputs (#162993)
Differential Revision: [D82478644](https://our.internmc.facebook.com/intern/diff/D82478644) Pull Request resolved: https://github.com/pytorch/pytorch/pull/162993 Approved by: https://github.com/zhxchen17 ghstack dependencies: #162557, #162558, #162559, #162682, #162992
This commit is contained in:
committed by
PyTorch MergeBot
parent
b26d4c9a7a
commit
72fedf0575
@ -4,6 +4,7 @@ import copy
|
||||
import unittest
|
||||
|
||||
import torch._dynamo as torchdynamo
|
||||
from torch._export import config
|
||||
from torch._export.db.case import ExportCase, SupportLevel
|
||||
from torch._export.db.examples import (
|
||||
filter_examples_by_support_level,
|
||||
@ -35,13 +36,14 @@ class ExampleTests(TestCase):
|
||||
kwargs_export = case.example_kwargs
|
||||
args_model = copy.deepcopy(args_export)
|
||||
kwargs_model = copy.deepcopy(kwargs_export)
|
||||
exported_program = export(
|
||||
model,
|
||||
args_export,
|
||||
kwargs_export,
|
||||
dynamic_shapes=case.dynamic_shapes,
|
||||
strict=True,
|
||||
)
|
||||
with config.patch(use_new_tracer_experimental=True):
|
||||
exported_program = export(
|
||||
model,
|
||||
case.example_args,
|
||||
case.example_kwargs,
|
||||
dynamic_shapes=case.dynamic_shapes,
|
||||
strict=True,
|
||||
)
|
||||
exported_program.graph_module.print_readable()
|
||||
|
||||
self.assertEqual(
|
||||
@ -68,13 +70,14 @@ class ExampleTests(TestCase):
|
||||
with self.assertRaises(
|
||||
(torchdynamo.exc.Unsupported, AssertionError, RuntimeError)
|
||||
):
|
||||
export(
|
||||
model,
|
||||
case.example_args,
|
||||
case.example_kwargs,
|
||||
dynamic_shapes=case.dynamic_shapes,
|
||||
strict=True,
|
||||
)
|
||||
with config.patch(use_new_tracer_experimental=True):
|
||||
_ = export(
|
||||
model,
|
||||
case.example_args,
|
||||
case.example_kwargs,
|
||||
dynamic_shapes=case.dynamic_shapes,
|
||||
strict=True,
|
||||
)
|
||||
|
||||
exportdb_not_supported_rewrite_cases = [
|
||||
(name, rewrite_case)
|
||||
|
@ -1612,7 +1612,8 @@ def forward(self, x):
|
||||
def test_exportdb_supported(self, name: str, case: ExportCase) -> None:
|
||||
model = case.model
|
||||
_check_meta = "map" not in name
|
||||
self.check_graph(model, case.example_args, _check_meta=_check_meta)
|
||||
with torch._export.config.patch(use_new_tracer_experimental=True):
|
||||
self.check_graph(model, case.example_args, _check_meta=_check_meta)
|
||||
|
||||
def test_constraints(self):
|
||||
class Module(torch.nn.Module):
|
||||
|
@ -16,5 +16,5 @@ class OptionalInput(torch.nn.Module):
|
||||
|
||||
example_args = (torch.randn(2, 3),)
|
||||
tags = {"python.object-model"}
|
||||
support_level = SupportLevel.NOT_SUPPORTED_YET
|
||||
support_level = SupportLevel.SUPPORTED
|
||||
model = OptionalInput()
|
||||
|
Reference in New Issue
Block a user