Support partial _DynamoCacheEntries when not all backends available (#163521)

Differential Revision: [D82735769](https://our.internmc.facebook.com/intern/diff/D82735769/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163521
Approved by: https://github.com/zhxchen17
This commit is contained in:
James Wu
2025-10-02 12:12:01 -07:00
committed by PyTorch MergeBot
parent 5656d45c8f
commit fa5306b4f5
4 changed files with 98 additions and 37 deletions

View File

@@ -531,6 +531,53 @@ def add(x, y):
self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)
@parametrize("device", ("cpu", "cuda", "xpu"))
@torch._dynamo.config.patch(caching_precompile=True)
def test_graph_break_partial_backend(self, device):
    """A dynamo cache entry whose resume-function backend was deleted must
    still load: the surviving frame is reused and only the frame with the
    missing backend recompiles.
    """
    if device == "cuda" and not HAS_CUDA_AND_TRITON:
        raise unittest.SkipTest("Requires CUDA/Triton")
    if device == "xpu" and not HAS_XPU_AND_TRITON:
        raise unittest.SkipTest("Requires XPU/Triton")

    def fn(x):
        y = x.sin()
        torch._dynamo.graph_break()
        return x.sin() + y

    first_input = torch.randn(3, 2, device=device, requires_grad=True)
    second_input = first_input.clone().detach_().requires_grad_(True)

    result_before = torch.compile(fn)(first_input)
    result_before.sum().backward()
    frames_before = torch._dynamo.convert_frame.FRAME_COUNTER

    # Drop the backend artifacts that belong to resume functions so the
    # saved cache entry is only partially backed.
    cache_entry = next(iter(PrecompileContext._dynamo_cache_entries.values()))
    for code in cache_entry.codes:
        mod = sys.modules[code.python_module]
        if code.install_to_global:
            # Scrub the installed names from the module's global scope to
            # simulate loading the cache in a fresh environment.
            for name in code.function_names:
                mod.__dict__.pop(name)
        for name in code.function_names:
            if "resume" in name:
                self.assertEqual(len(code.backend_ids), 1)
                # Delete the backend associated with the resume function.
                del PrecompileContext._backend_artifacts_by_key[
                    code.backend_ids[0]
                ]

    self._save_and_reload(expected_backends=1, expected_dynamo=1)

    # Run again: one backend was deleted, so exactly one new frame is
    # recompiled, but the result is unchanged.
    result_after = torch.compile(fn)(second_input)
    result_after.sum().backward()
    self.assertEqual(result_before, result_after)
    # One recompile on a new frame, so the counter advances by exactly 1.
    self.assertEqual(
        torch._dynamo.convert_frame.FRAME_COUNTER, frames_before + 1
    )
@parametrize("device", ("cpu", "cuda", "xpu"))
@torch._dynamo.config.patch(caching_precompile=True)
def test_call_function_from_resume(self, device):

View File

@@ -16,6 +16,7 @@ import functools
import hashlib
import importlib
import inspect
import json
import logging
import os
import pickle
@@ -419,6 +420,38 @@ class PrecompileCacheEntry:
dynamo: _DynamoCacheEntry
backends: dict[_BackendId, Any]
@staticmethod
def from_cache_entry(
    cache_entry: _DynamoCacheEntry, backends: dict[_BackendId, Any]
) -> Optional["PrecompileCacheEntry"]:
    """Assemble a PrecompileCacheEntry from a dynamo cache entry and the
    backend artifacts that are actually available.

    Codes whose backends are all present contribute their backends to the
    result; a code with any missing backend is flagged ``bypassed`` (and a
    structured-trace artifact is emitted for debugging) rather than
    failing the whole entry.
    """
    collected: dict[_BackendId, Any] = {}
    for code in cache_entry.codes:
        for backend_id in code.backend_ids:
            if backend_id in backends:
                collected[backend_id] = backends[backend_id]
                continue
            # Missing backend: log it, record a bypass artifact, and stop
            # gathering backends for this code object.
            logger.warning("Backend not found")
            debug_str = json.dumps(
                {
                    "entry": cache_entry.debug_info(),
                    "missing_backend": backend_id,
                }
            )
            torch._logging.trace_structured(
                "artifact",
                metadata_fn=lambda: {
                    "name": "dynamo_cache_bypass",
                    "encoding": "json",
                },
                payload_fn=lambda: debug_str,
                expect_trace_id=False,
            )
            code.bypassed = True
            break
    return PrecompileCacheEntry(dynamo=cache_entry, backends=collected)
def _hash_source(source: str) -> str:
sha256_hash = hashlib.sha256()
@@ -612,10 +645,6 @@ class CompilePackage:
try:
yield
finally:
if (
entry.bypassed
): # Remove the code from the cache entry if it's been bypassed
del self._codes[code]
entry.has_compile_id = True
self._current_entry = None
@@ -729,6 +758,11 @@ class CompilePackage:
if entry.code_source:
target_code = _lookup_code(entry)
if entry.bypassed:
# If the entry is bypassed, do not install backends
# or guarded codes.
continue
for backend_id in entry.backend_ids:
if backend_id not in backends:
raise RuntimeError(

View File

@@ -88,7 +88,7 @@ class PrecompileContext:
# Protected by the compile_lock
# _backend_artifacts_by_key organizes results by the key of each artifact.
# Each object here must be serializable
_backend_artifacts_by_key: dict[str, BackendCacheArtifact[Any]] = {}
_backend_artifacts_by_key: dict[_BackendId, BackendCacheArtifact[Any]] = {}
# On call to `serialize()`, all cache artifacts in _dynamo_cache_entries are converted
# into DynamoCacheArtifacts and added to _new_cache_artifacts for serialization
@@ -107,7 +109,9 @@
"""
Records a backend artifact to be used with dynamo cache entries
"""
cls._backend_artifacts_by_key[artifact.key] = copy.deepcopy(artifact)
cls._backend_artifacts_by_key[_BackendId(artifact.key)] = copy.deepcopy(
artifact
)
@classmethod
def record_dynamo_cache_entry(
@@ -121,7 +123,7 @@
Edit the content of an existing artifact
"""
assert key in cls._backend_artifacts_by_key, f"Key {key} not found in artifacts"
artifact = cls._backend_artifacts_by_key[key]
artifact = cls._backend_artifacts_by_key[_BackendId(key)]
artifact.edit_contents(edit_fn)
@classmethod
@@ -129,12 +131,12 @@
"""
Return the backend cache artifact with the associated key
"""
return cls._backend_artifacts_by_key.get(key, None)
return cls._backend_artifacts_by_key.get(_BackendId(key), None)
@staticmethod
def dump_debug_info(
dynamo_entries: dict[str, _DynamoCacheEntry],
backend_artifacts: dict[str, BackendCacheArtifact[Any]],
backend_artifacts: dict[_BackendId, BackendCacheArtifact[Any]],
) -> dict[str, Any]:
"""
Return a JSON serializable debug dump of all entries in the precompile context
@@ -195,33 +197,11 @@
for key, cache_entry in dynamo_entries.items():
try:
backends = cache_entry.backend_ids
backend_content: dict[_BackendId, BackendCacheArtifact[Any]] = {}
for id_ in backends:
if id_ not in backend_artifacts:
debug_str = json.dumps(
{
"entry": cache_entry.debug_info,
"key": key,
}
)
logger.warning("Backend not found")
torch._logging.trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "dynamo_cache_bypass",
"encoding": "json",
},
payload_fn=lambda: debug_str,
expect_trace_id=False,
)
continue
artifact = backend_artifacts[id_]
assert isinstance(artifact, BackendCacheArtifact)
backend_content[id_] = artifact
precompile_cache_entries[key] = PrecompileCacheEntry(
dynamo=cache_entry, backends=backend_content
result = PrecompileCacheEntry.from_cache_entry(
cache_entry, backend_artifacts
)
if result is not None:
precompile_cache_entries[key] = result
except Exception as e:
logger.warning("Failed to create cache entry %s: %s", key, str(e))

View File

@@ -1198,7 +1198,7 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
cache_state = "miss"
if (
config.strict_autograd_cache
or torch._dynamo.config.caching_precompile
or torch._dynamo.config.strict_precompile
):
raise e
# Most often this is BypassAOTAutogradCache, but
@@ -1231,7 +1231,7 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
log_cache_bypass("bypass_aot_autograd", str(e))
if (
config.strict_autograd_cache
or torch._dynamo.config.caching_precompile
or torch._dynamo.config.strict_precompile
):
raise e
if compiled_fn is None: