mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Rename modules in AOTAutograd (#158449)
Fixes https://github.com/pytorch/pytorch/issues/158382 ``` renamed: torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py -> torch/_functorch/_aot_autograd/graph_capture.py renamed: torch/_functorch/_aot_autograd/traced_function_transforms.py -> torch/_functorch/_aot_autograd/graph_capture_wrappers.py renamed: torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py -> torch/_functorch/_aot_autograd/graph_compile.py ``` Everything else is ONLY import changes. I did not rename any functions even if we probably should have. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/158449 Approved by: https://github.com/jamesjwu
This commit is contained in:
committed by
PyTorch MergeBot
parent
1eb6b2089f
commit
979fae761c
@ -1213,7 +1213,7 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
|
||||
|
||||
@torch._functorch.config.patch(donated_buffer=True)
|
||||
def test_donated_buffer1(self):
|
||||
logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"
|
||||
logger_name = "torch._functorch._aot_autograd.graph_compile"
|
||||
|
||||
@torch.compile()
|
||||
def relu(x):
|
||||
@ -1233,7 +1233,7 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
|
||||
|
||||
@torch._functorch.config.patch("donated_buffer", True)
|
||||
def test_donated_buffer2(self):
|
||||
logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"
|
||||
logger_name = "torch._functorch._aot_autograd.graph_compile"
|
||||
|
||||
# we will reuse the graph for g across f1 and f2
|
||||
@torch.compile()
|
||||
@ -1255,7 +1255,7 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
|
||||
|
||||
@torch._functorch.config.patch("donated_buffer", True)
|
||||
def test_donated_buffer3(self):
|
||||
logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"
|
||||
logger_name = "torch._functorch._aot_autograd.graph_compile"
|
||||
|
||||
# we will reuse the graph for g across f1 and f2
|
||||
@torch.compile()
|
||||
@ -1278,7 +1278,7 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
|
||||
|
||||
@torch._functorch.config.patch("donated_buffer", True)
|
||||
def test_donated_buffer4(self):
|
||||
logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"
|
||||
logger_name = "torch._functorch._aot_autograd.graph_compile"
|
||||
|
||||
class Mod(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
@ -1309,7 +1309,7 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
|
||||
|
||||
@torch._functorch.config.patch("donated_buffer", True)
|
||||
def test_donated_buffer5(self):
|
||||
logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"
|
||||
logger_name = "torch._functorch._aot_autograd.graph_compile"
|
||||
|
||||
@torch.compile()
|
||||
def f(x, z):
|
||||
@ -1346,7 +1346,7 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
|
||||
# SymNodeVariable() is not a constant
|
||||
return
|
||||
|
||||
logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"
|
||||
logger_name = "torch._functorch._aot_autograd.graph_compile"
|
||||
|
||||
def fn(x):
|
||||
p = torch.nn.Parameter(x + 123)
|
||||
|
@ -7469,7 +7469,7 @@ metadata incorrectly.
|
||||
"pack_hash",
|
||||
"unpack_hash",
|
||||
)
|
||||
logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"
|
||||
logger_name = "torch._functorch._aot_autograd.graph_compile"
|
||||
|
||||
class SAF(torch.autograd.Function):
|
||||
@staticmethod
|
||||
|
@ -433,7 +433,7 @@ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3
|
||||
aot_eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
|
||||
|
||||
log_stream, ctx = logs_to_string(
|
||||
"torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs"
|
||||
"torch._functorch._aot_autograd.graph_capture", "aot_graphs"
|
||||
)
|
||||
|
||||
result = None
|
||||
|
@ -3211,7 +3211,7 @@ class TestUbackedOps(TestCase):
|
||||
self.assertEqual(compiled_result, eager_result)
|
||||
|
||||
log_stream, ctx = logs_to_string(
|
||||
"torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs"
|
||||
"torch._functorch._aot_autograd.graph_capture", "aot_graphs"
|
||||
)
|
||||
with ctx():
|
||||
make_non_contiguous_tensor_and_test(4)
|
||||
@ -3246,7 +3246,7 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "Sym(s7)",
|
||||
torch._dynamo.decorators.mark_unbacked(x, 0)
|
||||
|
||||
log_stream, ctx = logs_to_string(
|
||||
"torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs"
|
||||
"torch._functorch._aot_autograd.graph_capture", "aot_graphs"
|
||||
)
|
||||
with ctx():
|
||||
compiled_result = compiled_func(x, torch.tensor([10]))
|
||||
@ -3305,7 +3305,7 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]
|
||||
torch._dynamo.decorators.mark_unbacked(x, 1)
|
||||
|
||||
log_stream, ctx = logs_to_string(
|
||||
"torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs"
|
||||
"torch._functorch._aot_autograd.graph_capture", "aot_graphs"
|
||||
)
|
||||
with ctx():
|
||||
result_eager = func(x, torch.tensor([5, 20]))
|
||||
@ -3355,7 +3355,7 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
|
||||
|
||||
# Pass a contiguous tensor. A recompilation will happen due to 0/1 speciialization on stride.
|
||||
log_stream, ctx = logs_to_string(
|
||||
"torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs"
|
||||
"torch._functorch._aot_autograd.graph_capture", "aot_graphs"
|
||||
)
|
||||
with ctx():
|
||||
# This used to hit could guard on data-dependent expression Eq(10, u3) x.stride[0]==10. and x.size()=[u2, u3].
|
||||
|
@ -23,8 +23,7 @@ from .functional_utils import (
|
||||
assert_functional_graph,
|
||||
propagate_input_mutation_stacktraces,
|
||||
)
|
||||
from .schemas import AOTConfig, SubclassMeta, ViewAndMutationMeta
|
||||
from .traced_function_transforms import (
|
||||
from .graph_capture_wrappers import (
|
||||
aot_dispatch_subclass,
|
||||
create_functionalized_fn,
|
||||
create_joint,
|
||||
@ -32,6 +31,7 @@ from .traced_function_transforms import (
|
||||
fn_prepped_for_autograd,
|
||||
handle_effect_tokens_fn,
|
||||
)
|
||||
from .schemas import AOTConfig, SubclassMeta, ViewAndMutationMeta
|
||||
from .utils import (
|
||||
copy_fwd_metadata_to_bw_nodes,
|
||||
register_buffer_assignment_hook,
|
@ -51,10 +51,7 @@ from .autograd_cache import (
|
||||
should_bundle_autograd_cache,
|
||||
should_use_remote_autograd_cache,
|
||||
)
|
||||
from .dispatch_and_compile_graph import (
|
||||
aot_dispatch_autograd_graph,
|
||||
aot_dispatch_base_graph,
|
||||
)
|
||||
from .graph_capture import aot_dispatch_autograd_graph, aot_dispatch_base_graph
|
||||
from .logging_utils import track_graph_compiling
|
||||
from .runtime_wrappers import (
|
||||
AOTDedupeWrapper,
|
||||
@ -896,7 +893,7 @@ def create_wrap_fn(fn, args):
|
||||
|
||||
|
||||
def prepare_hook_gm(aot_config, fn, args):
|
||||
from torch._functorch._aot_autograd.dispatch_and_compile_graph import _create_graph
|
||||
from torch._functorch._aot_autograd.graph_capture import _create_graph
|
||||
|
||||
fn, args = create_wrap_fn(fn, args)
|
||||
gm = _create_graph(fn, args, aot_config=aot_config)
|
@ -41,6 +41,7 @@ from torch.utils._python_dispatch import is_traceable_wrapper_subclass
|
||||
from .. import config
|
||||
from .collect_metadata_analysis import run_functionalized_fw_and_collect_metadata
|
||||
from .functional_utils import gen_alias_from_base
|
||||
from .graph_capture_wrappers import aot_dispatch_subclass
|
||||
from .input_output_analysis import (
|
||||
compute_overlapping_inputs,
|
||||
create_synthetic_base_metadata,
|
||||
@ -65,7 +66,6 @@ from .subclass_utils import (
|
||||
runtime_unwrap_tensor_subclasses,
|
||||
wrap_tensor_subclasses,
|
||||
)
|
||||
from .traced_function_transforms import aot_dispatch_subclass
|
||||
from .utils import (
|
||||
call_func_at_runtime_with_args,
|
||||
make_boxed_func,
|
||||
|
@ -62,17 +62,26 @@ from ._aot_autograd.functional_utils import ( # noqa: F401
|
||||
sync_functional_tensor,
|
||||
to_fun,
|
||||
)
|
||||
from ._aot_autograd.graph_capture_wrappers import ( # noqa: F401
|
||||
aot_dispatch_subclass,
|
||||
create_functional_call,
|
||||
create_functionalized_fn,
|
||||
create_functionalized_rng_ops_wrapper,
|
||||
create_joint,
|
||||
fn_input_mutations_to_outputs,
|
||||
fn_prepped_for_autograd,
|
||||
)
|
||||
from ._aot_autograd.graph_compile import ( # noqa: F401
|
||||
aot_stage1_graph_capture,
|
||||
aot_stage2_compile,
|
||||
aot_stage2_export,
|
||||
)
|
||||
from ._aot_autograd.input_output_analysis import ( # noqa: F401
|
||||
compute_overlapping_inputs,
|
||||
create_graph_signature,
|
||||
create_synthetic_base_metadata,
|
||||
remove_dupe_metadata,
|
||||
)
|
||||
from ._aot_autograd.jit_compile_runtime_wrappers import ( # noqa: F401
|
||||
aot_stage1_graph_capture,
|
||||
aot_stage2_compile,
|
||||
aot_stage2_export,
|
||||
)
|
||||
from ._aot_autograd.logging_utils import ( # noqa: F401
|
||||
callback_set,
|
||||
describe_input,
|
||||
@ -118,15 +127,6 @@ from ._aot_autograd.subclass_utils import ( # noqa: F401
|
||||
wrap_tensor_subclasses,
|
||||
wrap_tensor_subclasses_maybe_joint,
|
||||
)
|
||||
from ._aot_autograd.traced_function_transforms import ( # noqa: F401
|
||||
aot_dispatch_subclass,
|
||||
create_functional_call,
|
||||
create_functionalized_fn,
|
||||
create_functionalized_rng_ops_wrapper,
|
||||
create_joint,
|
||||
fn_input_mutations_to_outputs,
|
||||
fn_prepped_for_autograd,
|
||||
)
|
||||
from ._aot_autograd.utils import ( # noqa: F401
|
||||
_get_autocast_states,
|
||||
_get_symint_hints,
|
||||
|
@ -49,15 +49,13 @@ from torch._export.utils import (
|
||||
)
|
||||
from torch._export.verifier import SpecViolationError
|
||||
from torch._export.wrappers import _wrap_submodules
|
||||
from torch._functorch._aot_autograd.graph_capture_wrappers import create_functional_call
|
||||
from torch._functorch._aot_autograd.input_output_analysis import (
|
||||
_graph_input_names,
|
||||
_graph_output_names,
|
||||
)
|
||||
from torch._functorch._aot_autograd.schemas import GraphSignature
|
||||
from torch._functorch._aot_autograd.subclass_utils import get_subclass_typing_container
|
||||
from torch._functorch._aot_autograd.traced_function_transforms import (
|
||||
create_functional_call,
|
||||
)
|
||||
from torch._functorch._aot_autograd.utils import (
|
||||
create_tree_flattened_fn,
|
||||
register_buffer_assignment_hook,
|
||||
|
Reference in New Issue
Block a user