Remove ts to export retracer (#156857)

Summary: This is probably not used anymore

Test Plan:
CI

Rollback Plan:

Reviewed By: SherlockNoMad

Differential Revision: D77318582

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156857
Approved by: https://github.com/SherlockNoMad
This commit is contained in:
Tugsbayasgalan (Tugsuu) Manlaibaatar
2025-06-27 01:54:24 +00:00
committed by PyTorch MergeBot
parent a4b59498c5
commit 8e8bbfc803
3 changed files with 1 additions and 229 deletions

View File

@ -9,7 +9,6 @@ import torch._dynamo
from torch._dynamo.test_case import run_tests, TestCase
from torch._functorch.aot_autograd import aot_export_module
from torch.export import export, export_for_training
from torch.export._trace import _convert_ts_to_export_experimental
from torch.export.experimental import _export_forward_backward, _sticky_export
from torch.export.graph_signature import OutputKind
from torch.testing import FileCheck
@ -17,93 +16,6 @@ from torch.testing import FileCheck
@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestExperiment(TestCase):
def test_torchscript_module_export(self):
class M(torch.nn.Module):
def forward(self, x):
return x.cos() + x.sin()
model_to_trace = M()
inps = (torch.randn(4, 4),)
traced_module_by_torchscript = torch.jit.trace(M(), example_inputs=inps)
exported_module = _convert_ts_to_export_experimental(
traced_module_by_torchscript, inps
)
self.assertTrue(torch.allclose(exported_module(*inps), model_to_trace(*inps)))
def test_torchscript_module_export_single_input(self):
class M(torch.nn.Module):
def forward(self, x):
return x.cos() + x.sin()
model_to_trace = M()
inps = torch.randn(4, 4)
traced_module_by_torchscript = torch.jit.trace(M(), example_inputs=inps)
exported_module = _convert_ts_to_export_experimental(
traced_module_by_torchscript, inps
)
self.assertTrue(torch.allclose(exported_module(inps), model_to_trace(inps)))
def test_torchscript_module_export_various_inputs_with_annotated_input_names(self):
def _check_equality_and_annotations(m_func, inps):
# Original module.
model_to_trace = m_func()
# ExportedProgram from TorchScript module.
traced_module_by_torchscript = torch.jit.trace(
m_func(), example_inputs=inps
)
exported_module = _convert_ts_to_export_experimental(
traced_module_by_torchscript, inps
)
# ExportedProgram from original module.
original_exported_module = torch.export.export_for_training(
m_func(), inps, strict=True
)
# Check whether input annotations are the same as tracing the original module.
orig_ph_name_list = [
n.name
for n in original_exported_module.graph.nodes
if n.op == "placeholder"
]
ph_name_list = [
n.name for n in exported_module.graph.nodes if n.op == "placeholder"
]
self.assertEqual(orig_ph_name_list, ph_name_list)
# Check results equality.
self.assertTrue(
torch.allclose(exported_module(*inps), model_to_trace(*inps))
)
# Tuple
class MTuple(torch.nn.Module):
def forward(self, x: Tuple[torch.Tensor]):
return x[0] + x[1]
_check_equality_and_annotations(MTuple, ((torch.randn(4), torch.randn(4)),))
# List
class MList(torch.nn.Module):
def forward(self, x: List[torch.Tensor]):
return x[0] + x[1]
_check_equality_and_annotations(MList, ([torch.randn(4), torch.randn(4)],))
# Dict
class MDict(torch.nn.Module):
def forward(self, x: Dict[str, torch.Tensor]):
return x["0"] + x["1"]
_check_equality_and_annotations(
MDict, ({"0": torch.randn(4), "1": torch.randn(4)},)
)
def test_joint_basic(self) -> None:
class Module(torch.nn.Module):
def __init__(self) -> None:

View File

@ -98,7 +98,6 @@ from torch.utils._pytree import TreeSpec
from torch.utils._sympy.value_ranges import ValueRangeError
from ._safeguard import AutogradStateOpsFailSafeguard
from ._wrapper_utils import _WrapperModule
from .exported_program import (
_disable_prexisiting_fake_mode,
ExportedProgram,
@ -1408,44 +1407,6 @@ def _temp_disable_texpr_fuser():
torch._C._jit_set_texpr_fuser_enabled(original_state)
def _convert_ts_to_export_experimental(traced_callable, args, kwargs=None):
with _temp_disable_texpr_fuser():
from torch.jit._trace import TopLevelTracedModule
export_args, export_kwargs = _process_jit_trace_inputs_for_export(args, kwargs)
if isinstance(traced_callable, (TopLevelTracedModule, torch._C.ScriptModule)): # type: ignore[operator]
return _export(
traced_callable,
export_args,
export_kwargs,
strict=False,
_is_torch_jit_trace=True,
).module()
elif isinstance(traced_callable, torch.ScriptMethod) and isinstance(
traced_callable.owner(), # type: ignore[operator]
(torch._C.ScriptModule, torch.nn.Module),
):
with patch_forward(traced_callable.owner(), traced_callable): # type: ignore[operator]
return _export(
traced_callable.owner(), # type: ignore[operator]
export_args,
export_kwargs,
strict=False,
_is_torch_jit_trace=True,
).module()
else:
return _export(
_WrapperModule(traced_callable),
export_args,
export_kwargs,
strict=False,
_is_torch_jit_trace=True,
).module()
def _strict_export(
mod: torch.nn.Module,
args: tuple[Any, ...],

View File

@ -993,11 +993,7 @@ def trace(
stacklevel=2,
)
from torch._utils_internal import (
check_if_torch_exportable,
log_torch_jit_trace_exportability,
log_torchscript_usage,
)
from torch._utils_internal import log_torchscript_usage
traced_func = _trace_impl(
func,
@ -1014,103 +1010,6 @@ def trace(
_store_inputs,
)
log_torchscript_usage("trace", model_id=_get_model_id(traced_func))
if check_if_torch_exportable():
from torch._export.converter import TS2EPConverter
from torch.export._trace import (
_convert_ts_to_export_experimental,
_process_jit_trace_inputs_for_export,
)
traced_func_for_export = _trace_impl(
func,
example_inputs=example_inputs,
optimize=optimize,
check_trace=False,
check_inputs=check_inputs,
check_tolerance=check_tolerance,
strict=strict,
_force_outplace=_force_outplace,
_module_class=_module_class,
_compilation_unit=_compilation_unit,
example_kwarg_inputs=example_kwarg_inputs,
_store_inputs=_store_inputs,
)
export_args, _ = _process_jit_trace_inputs_for_export(
example_inputs, example_kwarg_inputs
)
def _log_exportability(func_to_export, export_func, export_args, export_type):
try:
traced_result = func_to_export(*export_args)
except Exception as e:
_ = e
log_torch_jit_trace_exportability(
"trace", str(export_type), str(_ExportOutcome.SUCCESS), "succeeded"
)
return
try:
ep_module = export_func(func_to_export, export_args)
except Exception as e:
log_torch_jit_trace_exportability(
"trace",
str(export_type),
str(_ExportOutcome.FAILED_TO_EXPORT),
str(e),
)
return
try:
export = ep_module(*export_args)
except Exception as e:
log_torch_jit_trace_exportability(
"trace", str(export_type), str(_ExportOutcome.FAILED_TO_RUN), str(e)
)
return
if not analyze_ts_result_with_export_result(export, traced_result):
log_torch_jit_trace_exportability(
"trace",
str(export_type),
str(_ExportOutcome.ACCURACY_ERROR),
"accuracy error",
)
return
log_torch_jit_trace_exportability(
"trace", str(export_type), str(_ExportOutcome.SUCCESS), "succeeded"
)
def _direct_export_and_lower(func, export_args):
return torch.export.export(func, export_args, strict=False).module()
def _convert_ts_to_export_source_to_source(func, export_args):
return TS2EPConverter(func, export_args).convert().module()
# torch.jit.trace is noop when the original module is torch.jit.ScriptModule
if not isinstance(traced_func_for_export, torch.jit.ScriptModule):
_log_exportability(
traced_func_for_export,
_direct_export_and_lower,
export_args,
_ExportType.DIRECT_EXPORT,
)
_log_exportability(
traced_func_for_export,
_convert_ts_to_export_experimental,
export_args,
_ExportType.TRACE_AND_EXPORT,
)
_log_exportability(
traced_func_for_export,
_convert_ts_to_export_source_to_source,
export_args,
_ExportType.SOURCE_TO_SOURCE,
)
return traced_func