Remove ts to export retracer (#156857)
Summary: This is probably not used anymore.

Test Plan: CI

Rollback Plan:

Reviewed By: SherlockNoMad

Differential Revision: D77318582

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156857
Approved by: https://github.com/SherlockNoMad
committed by PyTorch MergeBot
parent a4b59498c5
commit 8e8bbfc803
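For context, the deleted helper `_convert_ts_to_export_experimental` retraced a `torch.jit.trace`'d module into an `ExportedProgram`-backed module. A minimal sketch of the old call pattern (taken from the deleted test below) next to the direct `torch.export` path that remains; treating the direct path as equivalent is an assumption that holds for simple modules like this one, not something this commit asserts:

```python
import torch


class M(torch.nn.Module):
    def forward(self, x):
        return x.cos() + x.sin()


inps = (torch.randn(4, 4),)

# Old path (deleted by this commit): TorchScript trace, then retrace to export.
# from torch.export._trace import _convert_ts_to_export_experimental
# traced = torch.jit.trace(M(), example_inputs=inps)
# exported = _convert_ts_to_export_experimental(traced, inps)

# Remaining path: export the eager module directly.
exported = torch.export.export(M(), inps, strict=False).module()
torch.testing.assert_close(exported(*inps), M()(*inps))
```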
test/export/test_experiment.py

@@ -9,7 +9,6 @@ import torch._dynamo
 from torch._dynamo.test_case import run_tests, TestCase
 from torch._functorch.aot_autograd import aot_export_module
 from torch.export import export, export_for_training
-from torch.export._trace import _convert_ts_to_export_experimental
 from torch.export.experimental import _export_forward_backward, _sticky_export
 from torch.export.graph_signature import OutputKind
 from torch.testing import FileCheck
@@ -17,93 +16,6 @@ from torch.testing import FileCheck
 
 @unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported")
 class TestExperiment(TestCase):
-    def test_torchscript_module_export(self):
-        class M(torch.nn.Module):
-            def forward(self, x):
-                return x.cos() + x.sin()
-
-        model_to_trace = M()
-        inps = (torch.randn(4, 4),)
-        traced_module_by_torchscript = torch.jit.trace(M(), example_inputs=inps)
-
-        exported_module = _convert_ts_to_export_experimental(
-            traced_module_by_torchscript, inps
-        )
-
-        self.assertTrue(torch.allclose(exported_module(*inps), model_to_trace(*inps)))
-
-    def test_torchscript_module_export_single_input(self):
-        class M(torch.nn.Module):
-            def forward(self, x):
-                return x.cos() + x.sin()
-
-        model_to_trace = M()
-        inps = torch.randn(4, 4)
-        traced_module_by_torchscript = torch.jit.trace(M(), example_inputs=inps)
-
-        exported_module = _convert_ts_to_export_experimental(
-            traced_module_by_torchscript, inps
-        )
-
-        self.assertTrue(torch.allclose(exported_module(inps), model_to_trace(inps)))
-
-    def test_torchscript_module_export_various_inputs_with_annotated_input_names(self):
-        def _check_equality_and_annotations(m_func, inps):
-            # Original module.
-            model_to_trace = m_func()
-
-            # ExportedProgram from TorchScript module.
-            traced_module_by_torchscript = torch.jit.trace(
-                m_func(), example_inputs=inps
-            )
-            exported_module = _convert_ts_to_export_experimental(
-                traced_module_by_torchscript, inps
-            )
-
-            # ExportedProgram from original module.
-            original_exported_module = torch.export.export_for_training(
-                m_func(), inps, strict=True
-            )
-
-            # Check whether input annotations are the same as tracing the original module.
-            orig_ph_name_list = [
-                n.name
-                for n in original_exported_module.graph.nodes
-                if n.op == "placeholder"
-            ]
-            ph_name_list = [
-                n.name for n in exported_module.graph.nodes if n.op == "placeholder"
-            ]
-            self.assertEqual(orig_ph_name_list, ph_name_list)
-
-            # Check results equality.
-            self.assertTrue(
-                torch.allclose(exported_module(*inps), model_to_trace(*inps))
-            )
-
-        # Tuple
-        class MTuple(torch.nn.Module):
-            def forward(self, x: Tuple[torch.Tensor]):
-                return x[0] + x[1]
-
-        _check_equality_and_annotations(MTuple, ((torch.randn(4), torch.randn(4)),))
-
-        # List
-        class MList(torch.nn.Module):
-            def forward(self, x: List[torch.Tensor]):
-                return x[0] + x[1]
-
-        _check_equality_and_annotations(MList, ([torch.randn(4), torch.randn(4)],))
-
-        # Dict
-        class MDict(torch.nn.Module):
-            def forward(self, x: Dict[str, torch.Tensor]):
-                return x["0"] + x["1"]
-
-        _check_equality_and_annotations(
-            MDict, ({"0": torch.randn(4), "1": torch.randn(4)},)
-        )
-
     def test_joint_basic(self) -> None:
         class Module(torch.nn.Module):
             def __init__(self) -> None:
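The deleted annotation test compared placeholder names between the two export paths. A standalone sketch of that check against the `export_for_training` path, which survives this commit; `MDict` and the inputs are reused from the deleted test above, and the printed names are only illustrative:

```python
from typing import Dict

import torch


class MDict(torch.nn.Module):
    def forward(self, x: Dict[str, torch.Tensor]):
        return x["0"] + x["1"]


inps = ({"0": torch.randn(4), "1": torch.randn(4)},)
ep = torch.export.export_for_training(MDict(), inps, strict=True)

# Placeholder names reflect the pytree-flattened input structure.
ph_names = [n.name for n in ep.graph.nodes if n.op == "placeholder"]
print(ph_names)  # exact names depend on the pytree flattening
```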
torch/export/_trace.py

@@ -98,7 +98,6 @@ from torch.utils._pytree import TreeSpec
 from torch.utils._sympy.value_ranges import ValueRangeError
 
 from ._safeguard import AutogradStateOpsFailSafeguard
-from ._wrapper_utils import _WrapperModule
 from .exported_program import (
     _disable_prexisiting_fake_mode,
     ExportedProgram,
@@ -1408,44 +1407,6 @@ def _temp_disable_texpr_fuser():
         torch._C._jit_set_texpr_fuser_enabled(original_state)
 
 
-def _convert_ts_to_export_experimental(traced_callable, args, kwargs=None):
-    with _temp_disable_texpr_fuser():
-        from torch.jit._trace import TopLevelTracedModule
-
-        export_args, export_kwargs = _process_jit_trace_inputs_for_export(args, kwargs)
-
-        if isinstance(traced_callable, (TopLevelTracedModule, torch._C.ScriptModule)):  # type: ignore[operator]
-            return _export(
-                traced_callable,
-                export_args,
-                export_kwargs,
-                strict=False,
-                _is_torch_jit_trace=True,
-            ).module()
-
-        elif isinstance(traced_callable, torch.ScriptMethod) and isinstance(
-            traced_callable.owner(),  # type: ignore[operator]
-            (torch._C.ScriptModule, torch.nn.Module),
-        ):
-            with patch_forward(traced_callable.owner(), traced_callable):  # type: ignore[operator]
-                return _export(
-                    traced_callable.owner(),  # type: ignore[operator]
-                    export_args,
-                    export_kwargs,
-                    strict=False,
-                    _is_torch_jit_trace=True,
-                ).module()
-
-        else:
-            return _export(
-                _WrapperModule(traced_callable),
-                export_args,
-                export_kwargs,
-                strict=False,
-                _is_torch_jit_trace=True,
-            ).module()
-
-
 def _strict_export(
     mod: torch.nn.Module,
     args: tuple[Any, ...],
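The deleted function dispatched on the traced callable's type and, for bare callables, wrapped them in `_WrapperModule` (from `torch/export/_wrapper_utils.py`, whose import this commit also drops) so `_export` could treat them like a regular `nn.Module`. A minimal sketch of that wrapper pattern; the class name here is a hypothetical stand-in, not the real implementation:

```python
import torch


class WrapperModule(torch.nn.Module):  # hypothetical stand-in for _WrapperModule
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, *args, **kwargs):
        # Delegate to the wrapped callable so export sees a normal Module.
        return self.fn(*args, **kwargs)


def f(x):
    return x.cos() + x.sin()


ep = torch.export.export(WrapperModule(f), (torch.randn(4, 4),), strict=False)
```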
torch/jit/_trace.py

@@ -993,11 +993,7 @@ def trace(
             stacklevel=2,
         )
 
-    from torch._utils_internal import (
-        check_if_torch_exportable,
-        log_torch_jit_trace_exportability,
-        log_torchscript_usage,
-    )
+    from torch._utils_internal import log_torchscript_usage
 
     traced_func = _trace_impl(
         func,
@@ -1014,103 +1010,6 @@ def trace(
         _store_inputs,
     )
     log_torchscript_usage("trace", model_id=_get_model_id(traced_func))
 
-    if check_if_torch_exportable():
-        from torch._export.converter import TS2EPConverter
-        from torch.export._trace import (
-            _convert_ts_to_export_experimental,
-            _process_jit_trace_inputs_for_export,
-        )
-
-        traced_func_for_export = _trace_impl(
-            func,
-            example_inputs=example_inputs,
-            optimize=optimize,
-            check_trace=False,
-            check_inputs=check_inputs,
-            check_tolerance=check_tolerance,
-            strict=strict,
-            _force_outplace=_force_outplace,
-            _module_class=_module_class,
-            _compilation_unit=_compilation_unit,
-            example_kwarg_inputs=example_kwarg_inputs,
-            _store_inputs=_store_inputs,
-        )
-
-        export_args, _ = _process_jit_trace_inputs_for_export(
-            example_inputs, example_kwarg_inputs
-        )
-
-        def _log_exportability(func_to_export, export_func, export_args, export_type):
-            try:
-                traced_result = func_to_export(*export_args)
-            except Exception as e:
-                _ = e
-                log_torch_jit_trace_exportability(
-                    "trace", str(export_type), str(_ExportOutcome.SUCCESS), "succeeded"
-                )
-                return
-
-            try:
-                ep_module = export_func(func_to_export, export_args)
-            except Exception as e:
-                log_torch_jit_trace_exportability(
-                    "trace",
-                    str(export_type),
-                    str(_ExportOutcome.FAILED_TO_EXPORT),
-                    str(e),
-                )
-                return
-
-            try:
-                export = ep_module(*export_args)
-            except Exception as e:
-                log_torch_jit_trace_exportability(
-                    "trace", str(export_type), str(_ExportOutcome.FAILED_TO_RUN), str(e)
-                )
-                return
-
-            if not analyze_ts_result_with_export_result(export, traced_result):
-                log_torch_jit_trace_exportability(
-                    "trace",
-                    str(export_type),
-                    str(_ExportOutcome.ACCURACY_ERROR),
-                    "accuracy error",
-                )
-                return
-
-            log_torch_jit_trace_exportability(
-                "trace", str(export_type), str(_ExportOutcome.SUCCESS), "succeeded"
-            )
-
-        def _direct_export_and_lower(func, export_args):
-            return torch.export.export(func, export_args, strict=False).module()
-
-        def _convert_ts_to_export_source_to_source(func, export_args):
-            return TS2EPConverter(func, export_args).convert().module()
-
-        # torch.jit.trace is noop when the original module is torch.jit.ScriptModule
-        if not isinstance(traced_func_for_export, torch.jit.ScriptModule):
-            _log_exportability(
-                traced_func_for_export,
-                _direct_export_and_lower,
-                export_args,
-                _ExportType.DIRECT_EXPORT,
-            )
-
-        _log_exportability(
-            traced_func_for_export,
-            _convert_ts_to_export_experimental,
-            export_args,
-            _ExportType.TRACE_AND_EXPORT,
-        )
-        _log_exportability(
-            traced_func_for_export,
-            _convert_ts_to_export_source_to_source,
-            export_args,
-            _ExportType.SOURCE_TO_SOURCE,
-        )
-
     return traced_func
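The block deleted above ran a shadow export after every `torch.jit.trace` call purely for telemetry: it tried up to three export strategies and logged one outcome per strategy (`SUCCESS`, `FAILED_TO_EXPORT`, `FAILED_TO_RUN`, or `ACCURACY_ERROR`). A condensed sketch of that try/except ladder, assuming a generic `log` callback in place of the internal `log_torch_jit_trace_exportability` and a single-tensor output in place of the pytree-aware `analyze_ts_result_with_export_result`:

```python
import torch


def check_exportability(traced, export_fn, export_args, log):
    """Condensed sketch of the deleted _log_exportability ladder."""
    try:
        eager_out = traced(*export_args)
    except Exception:
        # If the traced module itself fails on these inputs there is
        # nothing to compare against; the deleted code logged success.
        log("SUCCESS", "succeeded")
        return
    try:
        ep_module = export_fn(traced, export_args)
    except Exception as e:
        log("FAILED_TO_EXPORT", str(e))
        return
    try:
        export_out = ep_module(*export_args)
    except Exception as e:
        log("FAILED_TO_RUN", str(e))
        return
    # Single-tensor comparison; the real helper handled arbitrary output pytrees.
    if not torch.allclose(export_out, eager_out):
        log("ACCURACY_ERROR", "accuracy error")
        return
    log("SUCCESS", "succeeded")
```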