mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
AOTI util deprecated flow using the new tracer (#165582)
Reapply of https://github.com/pytorch/pytorch/pull/163260 AOTI utils expect free function sometimes so adjust export API to handle that, haven't seen any methods getting exported. Some AOTI flows also require we populate dynamo_flat_name_to_original_fqn so i just copy how it is done in eval_frame.py. I also cleaned up how we get rid of export_root and fixed some overcomplicated nn_module_stack handling in export code. The logic is simpler now thanks to @anijain2305 . Pull Request resolved: https://github.com/pytorch/pytorch/pull/165582 Approved by: https://github.com/anijain2305
This commit is contained in:
committed by
PyTorch MergeBot
parent
1b121d636e
commit
22ae059d32
@ -52,15 +52,16 @@ class AOTIRunnerUtil:
|
||||
)
|
||||
gm = ep.module()
|
||||
else:
|
||||
gm = torch.export._trace._export_to_torch_ir(
|
||||
model,
|
||||
example_inputs,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
disable_constraint_solver=disable_constraint_solver,
|
||||
# Disabling this flag, because instead we can rely on the mapping
|
||||
# dynamo_flat_name_to_original_fqn which is coming from Dynamo.
|
||||
restore_fqn=False,
|
||||
)
|
||||
with torch._export.config.patch(use_new_tracer_experimental=True):
|
||||
gm = torch.export._trace._export_to_torch_ir(
|
||||
model,
|
||||
example_inputs,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
disable_constraint_solver=disable_constraint_solver,
|
||||
# Disabling this flag, because instead we can rely on the mapping
|
||||
# dynamo_flat_name_to_original_fqn which is coming from Dynamo.
|
||||
restore_fqn=False,
|
||||
)
|
||||
|
||||
if IS_FBCODE:
|
||||
from deeplearning.aot_inductor.extern_node_thrift_serializer import (
|
||||
|
@ -50,7 +50,22 @@ def post_process_error_msg(
|
||||
return constraint_violation_error
|
||||
|
||||
|
||||
def clean_nn_module_stack(
|
||||
EXPORT_ROOT_REPLACEMENTS = [
|
||||
("__export_root_", "_"),
|
||||
("_export_root.", ""),
|
||||
("._export_root", ""),
|
||||
]
|
||||
|
||||
|
||||
def clean_export_root_string(text: str) -> str:
|
||||
"""Generic utility to clean export_root patterns from strings."""
|
||||
result = text
|
||||
for pattern, replacement in EXPORT_ROOT_REPLACEMENTS:
|
||||
result = result.replace(pattern, replacement)
|
||||
return result
|
||||
|
||||
|
||||
def clean_nn_module_stack_and_source_fn(
|
||||
graph_module: torch.fx.GraphModule, is_inline_builtin=False
|
||||
) -> torch.fx.GraphModule:
|
||||
"""
|
||||
@ -77,12 +92,8 @@ def clean_nn_module_stack(
|
||||
Returns:
|
||||
The cleaned GraphModule (modified in-place)
|
||||
"""
|
||||
for node in graph_module.graph.nodes:
|
||||
if "nn_module_stack" not in node.meta:
|
||||
continue
|
||||
|
||||
nn_module_stack = node.meta["nn_module_stack"].copy()
|
||||
|
||||
def _process_nn_module_stack(nn_module_stack):
|
||||
if "L__self____export_root" in nn_module_stack:
|
||||
del nn_module_stack["L__self____export_root"]
|
||||
|
||||
@ -90,22 +101,54 @@ def clean_nn_module_stack(
|
||||
cleaned_stack = {}
|
||||
for key, (child_name, child_class) in nn_module_stack.items():
|
||||
# Clean key by removing export_root patterns
|
||||
clean_key = key.replace("__modules['_export_root']_", "").replace(
|
||||
"__export_root_", ""
|
||||
)
|
||||
clean_key = clean_export_root_string(key)
|
||||
|
||||
# Clean child_name by removing export_root patterns
|
||||
clean_name = child_name.replace("._modules['_export_root']", "").replace(
|
||||
"._export_root", ""
|
||||
)
|
||||
clean_name = clean_export_root_string(child_name)
|
||||
|
||||
# Skip self reference for inline builtin case
|
||||
if is_inline_builtin and clean_name == "L['self']":
|
||||
continue
|
||||
|
||||
cleaned_stack[clean_key] = (clean_name, child_class)
|
||||
return cleaned_stack
|
||||
|
||||
node.meta["nn_module_stack"] = cleaned_stack
|
||||
def _process_source_fn(source_fn_stack):
|
||||
cleaned_stack = []
|
||||
for item in source_fn_stack:
|
||||
if isinstance(item, tuple) and len(item) == 2:
|
||||
name, cls = item
|
||||
if isinstance(name, str):
|
||||
clean_name = clean_export_root_string(name)
|
||||
cleaned_stack.append((clean_name, cls))
|
||||
else:
|
||||
cleaned_stack.append(item)
|
||||
else:
|
||||
cleaned_stack.append(item)
|
||||
return cleaned_stack
|
||||
|
||||
for node in graph_module.graph.nodes:
|
||||
if "nn_module_stack" in node.meta:
|
||||
node.meta["nn_module_stack"] = _process_nn_module_stack(
|
||||
node.meta["nn_module_stack"].copy()
|
||||
)
|
||||
if "source_fn_stack" in node.meta:
|
||||
node.meta["source_fn_stack"] = _process_source_fn(
|
||||
node.meta["source_fn_stack"].copy()
|
||||
)
|
||||
|
||||
if "dynamo_flat_name_to_original_fqn" in graph_module.meta:
|
||||
# Clean up flat name to original fqn mapping
|
||||
clean_name_to_original_fqn = {}
|
||||
for flat_name, original_fqn in graph_module.meta[
|
||||
"dynamo_flat_name_to_original_fqn"
|
||||
].items():
|
||||
clean_name_to_original_fqn[clean_export_root_string(flat_name)] = (
|
||||
clean_export_root_string(original_fqn)
|
||||
)
|
||||
graph_module.meta["dynamo_flat_name_to_original_fqn"] = (
|
||||
clean_name_to_original_fqn
|
||||
)
|
||||
|
||||
return graph_module
|
||||
|
||||
@ -113,14 +156,6 @@ def clean_nn_module_stack(
|
||||
def clean_export_root(graph_module: torch.fx.GraphModule) -> None:
|
||||
"""Remove export_root artifacts from FX graph in-place"""
|
||||
|
||||
# Clean parameter names: L__self____export_root_param -> L__self___param
|
||||
def clean_name(name) -> str:
|
||||
if "____modules___export_root_" in name:
|
||||
return name.replace("____modules___export_root_", "_")
|
||||
if "__export_root_" in name:
|
||||
return name.replace("__export_root_", "_")
|
||||
return name
|
||||
|
||||
# Unlike getattr node, call_module can be invoked multiple times
|
||||
# In those cases, we should fix all invocations of call_module
|
||||
clean_named_module_map: dict[str, str] = {}
|
||||
@ -129,7 +164,7 @@ def clean_export_root(graph_module: torch.fx.GraphModule) -> None:
|
||||
for node in graph_module.graph.nodes:
|
||||
if node.op == "get_attr":
|
||||
old_target = node.target
|
||||
new_target = clean_name(old_target)
|
||||
new_target = clean_export_root_string(old_target)
|
||||
if new_target != old_target:
|
||||
node.target = new_target
|
||||
assert hasattr(graph_module, old_target)
|
||||
@ -140,8 +175,10 @@ def clean_export_root(graph_module: torch.fx.GraphModule) -> None:
|
||||
# Dynamo will only have one nested level
|
||||
if node.op == "call_module":
|
||||
old_target = node.target
|
||||
new_target = clean_name(old_target)
|
||||
new_name = clean_name(node.name)
|
||||
assert isinstance(old_target, str)
|
||||
new_target = clean_export_root_string(old_target)
|
||||
assert isinstance(new_target, str)
|
||||
new_name = clean_export_root_string(node.name)
|
||||
if new_target == old_target:
|
||||
continue
|
||||
|
||||
@ -150,8 +187,6 @@ def clean_export_root(graph_module: torch.fx.GraphModule) -> None:
|
||||
node.target = clean_named_module_map[old_target]
|
||||
node.name = new_name
|
||||
continue
|
||||
assert isinstance(old_target, str)
|
||||
assert isinstance(new_target, str)
|
||||
target = graph_module.get_submodule(old_target)
|
||||
graph_module.delete_submodule(old_target)
|
||||
graph_module.add_submodule(new_target, target)
|
||||
@ -559,7 +594,7 @@ def _dynamo_graph_capture_for_export(
|
||||
)
|
||||
transformed_graph.recompile()
|
||||
|
||||
clean_nn_module_stack(
|
||||
clean_nn_module_stack_and_source_fn(
|
||||
transformed_graph, torch._dynamo.config.inline_inbuilt_nn_modules
|
||||
)
|
||||
clean_export_root(transformed_graph)
|
||||
|
@ -26,6 +26,7 @@ of module state.
|
||||
import functools
|
||||
import inspect
|
||||
import itertools
|
||||
import re
|
||||
import types
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from typing import TYPE_CHECKING
|
||||
@ -113,6 +114,12 @@ def initialize_lazy_module(tx: "InstructionTranslator", mod, args, kwargs):
|
||||
@contextmanager
|
||||
def record_nn_module_stack(module_key: str, source, tx, mod: torch.nn.Module):
|
||||
fully_qualified_name = source.name()
|
||||
# Remove redundant namings
|
||||
fully_qualified_name = re.sub(
|
||||
r"\._(?:modules|parameters|buffers)\[(['\"])([^'\"\]]+)\1\]",
|
||||
r".\2",
|
||||
fully_qualified_name,
|
||||
)
|
||||
num_calls = tx.num_calls.get(fully_qualified_name, 0)
|
||||
module_key = f"{module_key}@{num_calls}" if num_calls > 0 else module_key
|
||||
try:
|
||||
|
@ -357,24 +357,11 @@ def _normalize_nn_module_stack(gm_torch_level, root_cls):
|
||||
if add_root:
|
||||
|
||||
def normalize_path(path):
|
||||
try:
|
||||
parts = []
|
||||
|
||||
class Path:
|
||||
def __getattr__(self, name):
|
||||
if name != "_modules":
|
||||
parts.append(name)
|
||||
return self
|
||||
|
||||
def __getitem__(self, idx):
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
parts.append(str(idx))
|
||||
return self
|
||||
|
||||
eval(path, {"L": {"self": Path()}})
|
||||
return ".".join(parts)
|
||||
except Exception: # TODO(zhxchen17) Remove this.
|
||||
return path
|
||||
if path == "L['self']":
|
||||
return ""
|
||||
if path.startswith("L['self']."):
|
||||
return path[len("L['self'].") :]
|
||||
return path
|
||||
|
||||
nn_module_stack = {
|
||||
root_key: (root, root_cls.__module__ + "." + root_cls.__qualname__),
|
||||
|
Reference in New Issue
Block a user