AOTI util deprecated flow using the new tracer (#165582)

Reapply of https://github.com/pytorch/pytorch/pull/163260

AOTI utils expect free function sometimes so adjust export API to handle that, haven't seen any methods getting exported. Some AOTI flows also require we populate dynamo_flat_name_to_original_fqn so i just copy how it is done in eval_frame.py. I also cleaned up how we get rid of export_root and fixed some overcomplicated nn_module_stack handling in export code. The logic is simpler now thanks to @anijain2305 .

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165582
Approved by: https://github.com/anijain2305
This commit is contained in:
Tugsbayasgalan Manlaibaatar
2025-10-17 10:07:13 -07:00
committed by PyTorch MergeBot
parent 1b121d636e
commit 22ae059d32
4 changed files with 84 additions and 54 deletions

View File

@ -50,7 +50,22 @@ def post_process_error_msg(
return constraint_violation_error
def clean_nn_module_stack(
EXPORT_ROOT_REPLACEMENTS = [
("__export_root_", "_"),
("_export_root.", ""),
("._export_root", ""),
]
def clean_export_root_string(text: str) -> str:
"""Generic utility to clean export_root patterns from strings."""
result = text
for pattern, replacement in EXPORT_ROOT_REPLACEMENTS:
result = result.replace(pattern, replacement)
return result
def clean_nn_module_stack_and_source_fn(
graph_module: torch.fx.GraphModule, is_inline_builtin=False
) -> torch.fx.GraphModule:
"""
@ -77,12 +92,8 @@ def clean_nn_module_stack(
Returns:
The cleaned GraphModule (modified in-place)
"""
for node in graph_module.graph.nodes:
if "nn_module_stack" not in node.meta:
continue
nn_module_stack = node.meta["nn_module_stack"].copy()
def _process_nn_module_stack(nn_module_stack):
if "L__self____export_root" in nn_module_stack:
del nn_module_stack["L__self____export_root"]
@ -90,22 +101,54 @@ def clean_nn_module_stack(
cleaned_stack = {}
for key, (child_name, child_class) in nn_module_stack.items():
# Clean key by removing export_root patterns
clean_key = key.replace("__modules['_export_root']_", "").replace(
"__export_root_", ""
)
clean_key = clean_export_root_string(key)
# Clean child_name by removing export_root patterns
clean_name = child_name.replace("._modules['_export_root']", "").replace(
"._export_root", ""
)
clean_name = clean_export_root_string(child_name)
# Skip self reference for inline builtin case
if is_inline_builtin and clean_name == "L['self']":
continue
cleaned_stack[clean_key] = (clean_name, child_class)
return cleaned_stack
node.meta["nn_module_stack"] = cleaned_stack
def _process_source_fn(source_fn_stack):
cleaned_stack = []
for item in source_fn_stack:
if isinstance(item, tuple) and len(item) == 2:
name, cls = item
if isinstance(name, str):
clean_name = clean_export_root_string(name)
cleaned_stack.append((clean_name, cls))
else:
cleaned_stack.append(item)
else:
cleaned_stack.append(item)
return cleaned_stack
for node in graph_module.graph.nodes:
if "nn_module_stack" in node.meta:
node.meta["nn_module_stack"] = _process_nn_module_stack(
node.meta["nn_module_stack"].copy()
)
if "source_fn_stack" in node.meta:
node.meta["source_fn_stack"] = _process_source_fn(
node.meta["source_fn_stack"].copy()
)
if "dynamo_flat_name_to_original_fqn" in graph_module.meta:
# Clean up flat name to original fqn mapping
clean_name_to_original_fqn = {}
for flat_name, original_fqn in graph_module.meta[
"dynamo_flat_name_to_original_fqn"
].items():
clean_name_to_original_fqn[clean_export_root_string(flat_name)] = (
clean_export_root_string(original_fqn)
)
graph_module.meta["dynamo_flat_name_to_original_fqn"] = (
clean_name_to_original_fqn
)
return graph_module
@ -113,14 +156,6 @@ def clean_nn_module_stack(
def clean_export_root(graph_module: torch.fx.GraphModule) -> None:
"""Remove export_root artifacts from FX graph in-place"""
# Clean parameter names: L__self____export_root_param -> L__self___param
def clean_name(name) -> str:
if "____modules___export_root_" in name:
return name.replace("____modules___export_root_", "_")
if "__export_root_" in name:
return name.replace("__export_root_", "_")
return name
# Unlike getattr node, call_module can be invoked multiple times
# In those cases, we should fix all invocations of call_module
clean_named_module_map: dict[str, str] = {}
@ -129,7 +164,7 @@ def clean_export_root(graph_module: torch.fx.GraphModule) -> None:
for node in graph_module.graph.nodes:
if node.op == "get_attr":
old_target = node.target
new_target = clean_name(old_target)
new_target = clean_export_root_string(old_target)
if new_target != old_target:
node.target = new_target
assert hasattr(graph_module, old_target)
@ -140,8 +175,10 @@ def clean_export_root(graph_module: torch.fx.GraphModule) -> None:
# Dynamo will only have one nested level
if node.op == "call_module":
old_target = node.target
new_target = clean_name(old_target)
new_name = clean_name(node.name)
assert isinstance(old_target, str)
new_target = clean_export_root_string(old_target)
assert isinstance(new_target, str)
new_name = clean_export_root_string(node.name)
if new_target == old_target:
continue
@ -150,8 +187,6 @@ def clean_export_root(graph_module: torch.fx.GraphModule) -> None:
node.target = clean_named_module_map[old_target]
node.name = new_name
continue
assert isinstance(old_target, str)
assert isinstance(new_target, str)
target = graph_module.get_submodule(old_target)
graph_module.delete_submodule(old_target)
graph_module.add_submodule(new_target, target)
@ -559,7 +594,7 @@ def _dynamo_graph_capture_for_export(
)
transformed_graph.recompile()
clean_nn_module_stack(
clean_nn_module_stack_and_source_fn(
transformed_graph, torch._dynamo.config.inline_inbuilt_nn_modules
)
clean_export_root(transformed_graph)

View File

@ -26,6 +26,7 @@ of module state.
import functools
import inspect
import itertools
import re
import types
from contextlib import contextmanager, nullcontext
from typing import TYPE_CHECKING
@ -113,6 +114,12 @@ def initialize_lazy_module(tx: "InstructionTranslator", mod, args, kwargs):
@contextmanager
def record_nn_module_stack(module_key: str, source, tx, mod: torch.nn.Module):
fully_qualified_name = source.name()
# Remove redundant namings
fully_qualified_name = re.sub(
r"\._(?:modules|parameters|buffers)\[(['\"])([^'\"\]]+)\1\]",
r".\2",
fully_qualified_name,
)
num_calls = tx.num_calls.get(fully_qualified_name, 0)
module_key = f"{module_key}@{num_calls}" if num_calls > 0 else module_key
try:

View File

@ -357,24 +357,11 @@ def _normalize_nn_module_stack(gm_torch_level, root_cls):
if add_root:
def normalize_path(path):
try:
parts = []
class Path:
def __getattr__(self, name):
if name != "_modules":
parts.append(name)
return self
def __getitem__(self, idx):
# pyrefly: ignore # bad-argument-type
parts.append(str(idx))
return self
eval(path, {"L": {"self": Path()}})
return ".".join(parts)
except Exception: # TODO(zhxchen17) Remove this.
return path
if path == "L['self']":
return ""
if path.startswith("L['self']."):
return path[len("L['self'].") :]
return path
nn_module_stack = {
root_key: (root, root_cls.__module__ + "." + root_cls.__qualname__),