Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
Avoid AOTAutogradCache.load in stack trace on cache miss path (#158149)
The general context for the upcoming stack of commits is that I am attempting to "pipeline" AOTAutograd. Instead of having function f call function g, the next "stage" of compilation, f should return its outputs, which are then piped to g for the next stage. This will make it easier to implement an early-exit / resume pipeline without forcing a callback structure, which is good for export-style use cases. It also reduces the size of our stack traces, which makes tools like Perfetto happy.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158149
Approved by: https://github.com/jamesjwu
committed by PyTorch MergeBot
parent 3beb915004
commit 148789ddd8
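To make the pipelining idea in the commit message concrete, here is a minimal Python sketch of the two shapes being contrasted; the stage functions and the run_pipeline driver are hypothetical illustrations, not code from this PR.

# Nested ("callback") style: stage_f calls the next stage itself, so every later
# stage runs underneath stage_f's frame, and early exit has to be threaded
# through as callbacks.
def stage_f_nested(inp, next_stage):
    out = inp + 1  # stand-in for real work
    return next_stage(out)


# Pipelined style: each stage returns its outputs and a driver pipes them to the
# next stage. The driver can stop after any stage (early exit) or resume later,
# and stack traces stay one stage deep.
def stage_f(inp):
    return inp + 1


def stage_g(inp):
    return inp * 2


def run_pipeline(inp, stages=(stage_f, stage_g), stop_after=None):
    for stage in stages:
        inp = stage(inp)
        if stage is stop_after:
            break  # early exit without unwinding nested calls
    return inp


print(stage_f_nested(3, stage_g))            # 8, with stage_g running under stage_f_nested
print(run_pipeline(3))                       # 8, stages run side by side
print(run_pipeline(3, stop_after=stage_f))   # 4, early exit after the first stage

The same idea is what lets the call site in the second hunk below recover from a cache miss by simply falling through to dispatch_and_compile(), instead of having the cache call it.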
@@ -1080,8 +1080,7 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
         pass

     @staticmethod
-    def load(
-        dispatch_and_compile: Callable,
+    def try_load(
         mod: Union[torch.fx.GraphModule, torch._dynamo.utils.GmWrapper],
         args,
         aot_config: AOTConfig,
@@ -1089,7 +1088,7 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
         boxed_forward_device_index: Optional[BoxedDeviceIndex],
         local: bool,
         remote: bool,
-    ) -> Callable:
+    ) -> Optional[Callable]:
         """
         Load a result from the cache, and reconstruct a runtime wrapper around the object
         """
@@ -1198,7 +1197,6 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
                 time.time_ns(),
                 forward_symints=symints,
             )
-            compiled_fn = dispatch_and_compile()

         cache_info.update(
             {
@@ -1232,6 +1230,7 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
             },
             payload_fn=lambda: json.dumps(cache_info),
         )
+
         return compiled_fn

     @classmethod
@@ -1190,26 +1190,27 @@ def aot_module_simplified(
             )
         return compiled_fn

-    # We only care if the forward will return an OutputCode.
-    if isinstance(fw_compiler, SerializableAOTDispatchCompiler):
-        local = should_use_local_autograd_cache()
-        remote = should_use_remote_autograd_cache()
-        if local or remote:
-            set_feature_use("aot_autograd_remote_cache", remote)
-            compiled_fn = AOTAutogradCache.load(
-                dispatch_and_compile,
-                mod,
-                fake_flat_args,
-                aot_config,
-                cudagraphs,
-                boxed_forward_device_index,
-                local,
-                remote,
-            )
-        else:
-            compiled_fn = dispatch_and_compile()
-    else:
-        compiled_fn = dispatch_and_compile()
+    while True:
+        # We only care if the forward will return an OutputCode.
+        if isinstance(fw_compiler, SerializableAOTDispatchCompiler):
+            local = should_use_local_autograd_cache()
+            remote = should_use_remote_autograd_cache()
+            if local or remote:
+                set_feature_use("aot_autograd_remote_cache", remote)
+                compiled_fn = AOTAutogradCache.try_load(
+                    mod,
+                    fake_flat_args,
+                    aot_config,
+                    cudagraphs,
+                    boxed_forward_device_index,
+                    local,
+                    remote,
+                )
+                if compiled_fn is not None:
+                    break
+
+        compiled_fn = dispatch_and_compile()
+        break

     if isinstance(mod, torch._dynamo.utils.GmWrapper):
         # This function is called by the flatten_graph_inputs wrapper, which boxes
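The stack-trace effect named in the title can be seen with a stripped-down analogue of the old and new call shapes. Everything below (Cache, compile_model, the printed frame lists) is a hypothetical stand-in for illustration, not the PyTorch implementation.

import traceback
from typing import Callable, Optional


class Cache:
    def __init__(self):
        self._store = {}

    # Old shape: the cache drives compilation on a miss, so the compiler runs
    # underneath Cache.load and shows up beneath it in stack traces.
    def load(self, key, compile_fn: Callable) -> Callable:
        if key not in self._store:
            self._store[key] = compile_fn()
        return self._store[key]

    # New shape: a miss just returns None; the caller compiles, so the cache
    # frame is already gone by the time compilation happens.
    def try_load(self, key) -> Optional[Callable]:
        return self._store.get(key)


def compile_model() -> Callable:
    # Show which frames are active while "compiling".
    print([frame.name for frame in traceback.extract_stack()])
    return lambda x: x


cache = Cache()
cache.load("k1", compile_model)   # prints [..., 'load', 'compile_model']
fn = cache.try_load("k2")
if fn is None:                    # cache miss: compile in the caller's frame
    fn = compile_model()          # prints [..., 'compile_model'] with no 'try_load'

The while True / break structure in the diff above serves the same purpose as the if fn is None fallback here: the cache lookup returns before any compilation starts, so dispatch_and_compile runs in the caller's frame on a miss.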