Rename modules in AOTAutograd (#158449)

Fixes https://github.com/pytorch/pytorch/issues/158382

```
renamed:    torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py -> torch/_functorch/_aot_autograd/graph_capture.py
renamed:    torch/_functorch/_aot_autograd/traced_function_transforms.py -> torch/_functorch/_aot_autograd/graph_capture_wrappers.py
renamed:    torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py -> torch/_functorch/_aot_autograd/graph_compile.py
```

Everything else is ONLY import changes. I did not rename any functions
even if we probably should have.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158449
Approved by: https://github.com/jamesjwu
This commit is contained in:
Edward Z. Yang
2025-07-20 21:27:45 -07:00
committed by PyTorch MergeBot
parent 1eb6b2089f
commit 979fae761c
10 changed files with 32 additions and 37 deletions

View File

@ -3211,7 +3211,7 @@ class TestUbackedOps(TestCase):
self.assertEqual(compiled_result, eager_result)
log_stream, ctx = logs_to_string(
"torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs"
"torch._functorch._aot_autograd.graph_capture", "aot_graphs"
)
with ctx():
make_non_contiguous_tensor_and_test(4)
@ -3246,7 +3246,7 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "Sym(s7)",
torch._dynamo.decorators.mark_unbacked(x, 0)
log_stream, ctx = logs_to_string(
"torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs"
"torch._functorch._aot_autograd.graph_capture", "aot_graphs"
)
with ctx():
compiled_result = compiled_func(x, torch.tensor([10]))
@ -3305,7 +3305,7 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]
torch._dynamo.decorators.mark_unbacked(x, 1)
log_stream, ctx = logs_to_string(
"torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs"
"torch._functorch._aot_autograd.graph_capture", "aot_graphs"
)
with ctx():
result_eager = func(x, torch.tensor([5, 20]))
@ -3355,7 +3355,7 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
# Pass a contiguous tensor. A recompilation will happen due to 0/1 speciialization on stride.
log_stream, ctx = logs_to_string(
"torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs"
"torch._functorch._aot_autograd.graph_capture", "aot_graphs"
)
with ctx():
# This used to hit could guard on data-dependent expression Eq(10, u3) x.stride[0]==10. and x.size()=[u2, u3].