Rename modules in AOTAutograd (#158449)

Fixes https://github.com/pytorch/pytorch/issues/158382

```
renamed:    torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py -> torch/_functorch/_aot_autograd/graph_capture.py
renamed:    torch/_functorch/_aot_autograd/traced_function_transforms.py -> torch/_functorch/_aot_autograd/graph_capture_wrappers.py
renamed:    torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py -> torch/_functorch/_aot_autograd/graph_compile.py
```
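For downstream code the rename is a pure path substitution, since the symbols themselves keep their names. A minimal sketch of what updated imports look like (the functions shown are taken from the diff below, purely as an illustration):

```python
# Old path, kept only for comparison:
#   from torch._functorch._aot_autograd.jit_compile_runtime_wrappers import aot_stage2_compile
# New paths after this PR; the function names are unchanged:
from torch._functorch._aot_autograd.graph_capture import aot_dispatch_base_graph
from torch._functorch._aot_autograd.graph_capture_wrappers import create_functional_call
from torch._functorch._aot_autograd.graph_compile import aot_stage2_compile
```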

Everything else is ONLY import changes. I did not rename any functions,
even though we probably should have.
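Note that the module paths also show up as plain strings, e.g. the logger names patched in the tests below, so any logging configuration keyed on the old names needs the same substitution. A small sketch using the standard `logging` module (illustrative only):

```python
import logging

# The per-module loggers now live under the renamed modules;
# "jit_compile_runtime_wrappers" and friends no longer name a module.
logging.getLogger("torch._functorch._aot_autograd.graph_compile").setLevel(logging.DEBUG)
logging.getLogger("torch._functorch._aot_autograd.graph_capture").setLevel(logging.DEBUG)
```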

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158449
Approved by: https://github.com/jamesjwu
Author: Edward Z. Yang
Date: 2025-07-20 21:27:45 -07:00
Committed by: PyTorch MergeBot
Commit: 979fae761c (parent: 1eb6b2089f)
10 changed files with 32 additions and 37 deletions


@@ -1213,7 +1213,7 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
@torch._functorch.config.patch(donated_buffer=True)
def test_donated_buffer1(self):
logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"
logger_name = "torch._functorch._aot_autograd.graph_compile"
@torch.compile()
def relu(x):
@@ -1233,7 +1233,7 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
@torch._functorch.config.patch("donated_buffer", True)
def test_donated_buffer2(self):
logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"
logger_name = "torch._functorch._aot_autograd.graph_compile"
# we will reuse the graph for g across f1 and f2
@torch.compile()
@@ -1255,7 +1255,7 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
@torch._functorch.config.patch("donated_buffer", True)
def test_donated_buffer3(self):
logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"
logger_name = "torch._functorch._aot_autograd.graph_compile"
# we will reuse the graph for g across f1 and f2
@torch.compile()
@@ -1278,7 +1278,7 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
@torch._functorch.config.patch("donated_buffer", True)
def test_donated_buffer4(self):
logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"
logger_name = "torch._functorch._aot_autograd.graph_compile"
class Mod(torch.nn.Module):
def __init__(self) -> None:
@@ -1309,7 +1309,7 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
@torch._functorch.config.patch("donated_buffer", True)
def test_donated_buffer5(self):
logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"
logger_name = "torch._functorch._aot_autograd.graph_compile"
@torch.compile()
def f(x, z):
@@ -1346,7 +1346,7 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
# SymNodeVariable() is not a constant
return
logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"
logger_name = "torch._functorch._aot_autograd.graph_compile"
def fn(x):
p = torch.nn.Parameter(x + 123)


@@ -7469,7 +7469,7 @@ metadata incorrectly.
"pack_hash",
"unpack_hash",
)
logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"
logger_name = "torch._functorch._aot_autograd.graph_compile"
class SAF(torch.autograd.Function):
@staticmethod


@@ -433,7 +433,7 @@ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3
aot_eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
log_stream, ctx = logs_to_string(
"torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs"
"torch._functorch._aot_autograd.graph_capture", "aot_graphs"
)
result = None


@@ -3211,7 +3211,7 @@ class TestUbackedOps(TestCase):
self.assertEqual(compiled_result, eager_result)
log_stream, ctx = logs_to_string(
"torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs"
"torch._functorch._aot_autograd.graph_capture", "aot_graphs"
)
with ctx():
make_non_contiguous_tensor_and_test(4)
@@ -3246,7 +3246,7 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "Sym(s7)",
torch._dynamo.decorators.mark_unbacked(x, 0)
log_stream, ctx = logs_to_string(
"torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs"
"torch._functorch._aot_autograd.graph_capture", "aot_graphs"
)
with ctx():
compiled_result = compiled_func(x, torch.tensor([10]))
@@ -3305,7 +3305,7 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]
torch._dynamo.decorators.mark_unbacked(x, 1)
log_stream, ctx = logs_to_string(
"torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs"
"torch._functorch._aot_autograd.graph_capture", "aot_graphs"
)
with ctx():
result_eager = func(x, torch.tensor([5, 20]))
@@ -3355,7 +3355,7 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
# Pass a contiguous tensor. A recompilation will happen due to 0/1 speciialization on stride.
log_stream, ctx = logs_to_string(
"torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs"
"torch._functorch._aot_autograd.graph_capture", "aot_graphs"
)
with ctx():
# This used to hit could guard on data-dependent expression Eq(10, u3) x.stride[0]==10. and x.size()=[u2, u3].


@@ -23,8 +23,7 @@ from .functional_utils import (
assert_functional_graph,
propagate_input_mutation_stacktraces,
)
-from .schemas import AOTConfig, SubclassMeta, ViewAndMutationMeta
-from .traced_function_transforms import (
+from .graph_capture_wrappers import (
aot_dispatch_subclass,
create_functionalized_fn,
create_joint,
@@ -32,6 +31,7 @@ from .traced_function_transforms import (
fn_prepped_for_autograd,
handle_effect_tokens_fn,
)
+from .schemas import AOTConfig, SubclassMeta, ViewAndMutationMeta
from .utils import (
copy_fwd_metadata_to_bw_nodes,
register_buffer_assignment_hook,


@@ -51,10 +51,7 @@ from .autograd_cache import (
should_bundle_autograd_cache,
should_use_remote_autograd_cache,
)
-from .dispatch_and_compile_graph import (
-    aot_dispatch_autograd_graph,
-    aot_dispatch_base_graph,
-)
+from .graph_capture import aot_dispatch_autograd_graph, aot_dispatch_base_graph
from .logging_utils import track_graph_compiling
from .runtime_wrappers import (
AOTDedupeWrapper,
@@ -896,7 +893,7 @@ def create_wrap_fn(fn, args):
def prepare_hook_gm(aot_config, fn, args):
-from torch._functorch._aot_autograd.dispatch_and_compile_graph import _create_graph
+from torch._functorch._aot_autograd.graph_capture import _create_graph
fn, args = create_wrap_fn(fn, args)
gm = _create_graph(fn, args, aot_config=aot_config)


@@ -41,6 +41,7 @@ from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from .. import config
from .collect_metadata_analysis import run_functionalized_fw_and_collect_metadata
from .functional_utils import gen_alias_from_base
+from .graph_capture_wrappers import aot_dispatch_subclass
from .input_output_analysis import (
compute_overlapping_inputs,
create_synthetic_base_metadata,
@@ -65,7 +66,6 @@ from .subclass_utils import (
runtime_unwrap_tensor_subclasses,
wrap_tensor_subclasses,
)
-from .traced_function_transforms import aot_dispatch_subclass
from .utils import (
call_func_at_runtime_with_args,
make_boxed_func,


@@ -62,17 +62,26 @@ from ._aot_autograd.functional_utils import ( # noqa: F401
sync_functional_tensor,
to_fun,
)
+from ._aot_autograd.graph_capture_wrappers import ( # noqa: F401
+    aot_dispatch_subclass,
+    create_functional_call,
+    create_functionalized_fn,
+    create_functionalized_rng_ops_wrapper,
+    create_joint,
+    fn_input_mutations_to_outputs,
+    fn_prepped_for_autograd,
+)
+from ._aot_autograd.graph_compile import ( # noqa: F401
+    aot_stage1_graph_capture,
+    aot_stage2_compile,
+    aot_stage2_export,
+)
from ._aot_autograd.input_output_analysis import ( # noqa: F401
compute_overlapping_inputs,
create_graph_signature,
create_synthetic_base_metadata,
remove_dupe_metadata,
)
-from ._aot_autograd.jit_compile_runtime_wrappers import ( # noqa: F401
-    aot_stage1_graph_capture,
-    aot_stage2_compile,
-    aot_stage2_export,
-)
from ._aot_autograd.logging_utils import ( # noqa: F401
callback_set,
describe_input,
@@ -118,15 +127,6 @@ from ._aot_autograd.subclass_utils import ( # noqa: F401
wrap_tensor_subclasses,
wrap_tensor_subclasses_maybe_joint,
)
-from ._aot_autograd.traced_function_transforms import ( # noqa: F401
-    aot_dispatch_subclass,
-    create_functional_call,
-    create_functionalized_fn,
-    create_functionalized_rng_ops_wrapper,
-    create_joint,
-    fn_input_mutations_to_outputs,
-    fn_prepped_for_autograd,
-)
from ._aot_autograd.utils import ( # noqa: F401
_get_autocast_states,
_get_symint_hints,


@@ -49,15 +49,13 @@ from torch._export.utils import (
)
from torch._export.verifier import SpecViolationError
from torch._export.wrappers import _wrap_submodules
+from torch._functorch._aot_autograd.graph_capture_wrappers import create_functional_call
from torch._functorch._aot_autograd.input_output_analysis import (
_graph_input_names,
_graph_output_names,
)
from torch._functorch._aot_autograd.schemas import GraphSignature
from torch._functorch._aot_autograd.subclass_utils import get_subclass_typing_container
-from torch._functorch._aot_autograd.traced_function_transforms import (
-    create_functional_call,
-)
from torch._functorch._aot_autograd.utils import (
create_tree_flattened_fn,
register_buffer_assignment_hook,