Avoid AOTAutogradCache.load in stack trace on cache miss path (#158149)

The general context for the upcoming stack of commits is that I am
attempting to "pipeline" AOTAutograd.  Instead of having function f call
function g, the next "stage" of compilation, f should return its
outputs, which are then piped to g for the next stage.  This will make
it easier to implement early exit / resume of the pipeline without
forcing a callback structure, which is good for export-style use cases.
It also reduces the size of our stack traces, which makes tools like
Perfetto happy.
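
As a minimal sketch of the shape of the change (the stage names and
helpers below are hypothetical, not the actual AOTAutograd functions):

    # Callback style: stage_f calls directly into stage_g, so g's frames
    # stack on top of f's in every trace and f cannot return early.
    def do_f_work(x):  # placeholder for the real first-stage work
        return x + 1

    def stage_g(intermediate):  # placeholder for the next stage
        return intermediate * 2

    def stage_f_callback(inp):
        intermediate = do_f_work(inp)
        return stage_g(intermediate)

    # Pipelined style: stage_f just returns its outputs, and a driver
    # pipes them into stage_g.  The driver can exit or resume between
    # stages, and f's frames are off the stack by the time g runs.
    def stage_f(inp):
        return do_f_work(inp)

    def run_pipeline(inp):
        intermediate = stage_f(inp)
        return stage_g(intermediate)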

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158149
Approved by: https://github.com/jamesjwu
Author:       Edward Z. Yang <ezyang@meta.com>
Date:         2025-07-15 06:18:20 -07:00
Committed by: PyTorch MergeBot
Parent:       3beb915004
Commit:       148789ddd8

2 changed files with 23 additions and 23 deletions


@@ -1080,8 +1080,7 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
         pass

     @staticmethod
-    def load(
-        dispatch_and_compile: Callable,
+    def try_load(
         mod: Union[torch.fx.GraphModule, torch._dynamo.utils.GmWrapper],
         args,
         aot_config: AOTConfig,
@@ -1089,7 +1088,7 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
         boxed_forward_device_index: Optional[BoxedDeviceIndex],
         local: bool,
         remote: bool,
-    ) -> Callable:
+    ) -> Optional[Callable]:
         """
         Load a result from the cache, and reconstruct a runtime wrapper around the object
         """
@@ -1198,7 +1197,6 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
                 time.time_ns(),
                 forward_symints=symints,
             )
-            compiled_fn = dispatch_and_compile()

         cache_info.update(
             {
@@ -1232,6 +1230,7 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
             },
             payload_fn=lambda: json.dumps(cache_info),
         )
+        return compiled_fn

     @classmethod


@@ -1190,26 +1190,27 @@ def aot_module_simplified(
             )
         return compiled_fn

-    # We only care if the forward will return an OutputCode.
-    if isinstance(fw_compiler, SerializableAOTDispatchCompiler):
-        local = should_use_local_autograd_cache()
-        remote = should_use_remote_autograd_cache()
-        if local or remote:
-            set_feature_use("aot_autograd_remote_cache", remote)
-            compiled_fn = AOTAutogradCache.load(
-                dispatch_and_compile,
-                mod,
-                fake_flat_args,
-                aot_config,
-                cudagraphs,
-                boxed_forward_device_index,
-                local,
-                remote,
-            )
-        else:
-            compiled_fn = dispatch_and_compile()
-    else:
-        compiled_fn = dispatch_and_compile()
+    while True:
+        # We only care if the forward will return an OutputCode.
+        if isinstance(fw_compiler, SerializableAOTDispatchCompiler):
+            local = should_use_local_autograd_cache()
+            remote = should_use_remote_autograd_cache()
+            if local or remote:
+                set_feature_use("aot_autograd_remote_cache", remote)
+                compiled_fn = AOTAutogradCache.try_load(
+                    mod,
+                    fake_flat_args,
+                    aot_config,
+                    cudagraphs,
+                    boxed_forward_device_index,
+                    local,
+                    remote,
+                )
+                if compiled_fn is not None:
+                    break
+        compiled_fn = dispatch_and_compile()
+        break

     if isinstance(mod, torch._dynamo.utils.GmWrapper):
         # This function is called by the flatten_graph_inputs wrapper, which boxes
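
The caller-side pattern above, distilled into a self-contained sketch (the
names compile_with_cache, lookup_cache, and compile_from_scratch are
hypothetical stand-ins, not the actual PyTorch APIs): the cache lookup now
returns None on a miss instead of invoking compilation itself, so the
fallback compile no longer runs underneath the cache-lookup frames.

    from typing import Callable, Optional

    _CACHE: dict[str, Callable] = {}

    def lookup_cache(key: str) -> Optional[Callable]:
        # Analogue of AOTAutogradCache.try_load: return the hit, or None on a miss.
        return _CACHE.get(key)

    def compile_from_scratch(key: str) -> Callable:
        # Analogue of dispatch_and_compile: the slow path, invoked by the
        # caller rather than from inside the cache lookup.
        def compiled(x):
            return x
        _CACHE[key] = compiled
        return compiled

    def compile_with_cache(key: str, cache_enabled: bool) -> Callable:
        while True:
            if cache_enabled:
                compiled_fn = lookup_cache(key)
                if compiled_fn is not None:
                    break
            compiled_fn = compile_from_scratch(key)
            break
        return compiled_fn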