Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Support partial _DynamoCacheEntries when not all backends available (#163521)
Differential Revision: [D82735769](https://our.internmc.facebook.com/intern/diff/D82735769/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163521
Approved by: https://github.com/zhxchen17
Committed by: PyTorch MergeBot
Parent: 5656d45c8f
Commit: fa5306b4f5
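In effect, a saved precompile bundle no longer has to be all-or-nothing: when a
cache entry references a backend artifact that is missing at load time, only the
affected code object is bypassed (and recompiled lazily) while the remaining
frames still load from the cache. A minimal sketch of the scenario this enables,
assuming the module path torch._dynamo.precompile_context and illustrative
inputs (not part of the diff):

    import torch
    from torch._dynamo.precompile_context import PrecompileContext

    def fn(x):
        y = x.sin()
        torch._dynamo.graph_break()  # splits fn into two compiled frames
        return x.sin() + y

    with torch._dynamo.config.patch(caching_precompile=True):
        torch.compile(fn)(torch.randn(3, 2))
        # Drop one backend artifact to simulate a partially available cache;
        # after this change the other frame's entry still loads (see the test below).
        some_backend = next(iter(PrecompileContext._backend_artifacts_by_key))
        del PrecompileContext._backend_artifacts_by_key[some_backend]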
@@ -531,6 +531,53 @@ def add(x, y):
         self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)
 
+    @parametrize("device", ("cpu", "cuda", "xpu"))
+    @torch._dynamo.config.patch(caching_precompile=True)
+    def test_graph_break_partial_backend(self, device):
+        if device == "cuda" and not HAS_CUDA_AND_TRITON:
+            raise unittest.SkipTest("Requires CUDA/Triton")
+        if device == "xpu" and not HAS_XPU_AND_TRITON:
+            raise unittest.SkipTest("Requires XPU/Triton")
+
+        def fn(x):
+            y = x.sin()
+            torch._dynamo.graph_break()
+            return x.sin() + y
+
+        arg1 = torch.randn(3, 2, device=device, requires_grad=True)
+        arg2 = arg1.clone().detach_().requires_grad_(True)
+        compiled_fn = torch.compile(fn)
+        expected1 = compiled_fn(arg1)
+        expected1.sum().backward()
+        total_frames = torch._dynamo.convert_frame.FRAME_COUNTER
+
+        # Remove backends related to resume functions
+        dynamo_entry = next(iter(PrecompileContext._dynamo_cache_entries.values()))
+        for code in dynamo_entry.codes:
+            module = sys.modules[code.python_module]
+            if code.install_to_global:
+                # Clear the fn_names from global scope, to simulate a new environment
+                for fn_name in code.function_names:
+                    module.__dict__.pop(fn_name)
+            for fn_name in code.function_names:
+                if "resume" in fn_name:
+                    self.assertEqual(len(code.backend_ids), 1)
+                    backend = code.backend_ids[0]
+                    # Delete the backend associated with the resume function
+                    del PrecompileContext._backend_artifacts_by_key[backend]
+
+        self._save_and_reload(expected_backends=1, expected_dynamo=1)
+
+        compiled_fn = torch.compile(fn)
+        # Run it again. There will be a recompile because one of the backends is
+        # deleted, but it should still work.
+        expected2 = compiled_fn(arg2)
+        expected2.sum().backward()
+        self.assertEqual(expected1, expected2)
+        # One recompile on a new frame, so total_frames should increase by 1
+        self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames + 1)
+
     @parametrize("device", ("cpu", "cuda", "xpu"))
     @torch._dynamo.config.patch(caching_precompile=True)
     def test_call_function_from_resume(self, device):

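For context on why the test expects one backend per side of the break: a
torch._dynamo.graph_break() splits fn into two compiled frames (the second via a
generated resume function), each with its own backend artifact. A standalone
illustration using the stock CompileCounter test backend:

    import torch
    from torch._dynamo.testing import CompileCounter

    def fn(x):
        y = x.sin()
        torch._dynamo.graph_break()
        return x.sin() + y

    cnt = CompileCounter()
    torch.compile(fn, backend=cnt)(torch.randn(3, 2))
    assert cnt.frame_count == 2  # one frame per side of the graph break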
@@ -16,6 +16,7 @@ import functools
 import hashlib
 import importlib
 import inspect
+import json
 import logging
 import os
 import pickle

@@ -419,6 +420,38 @@ class PrecompileCacheEntry:
     dynamo: _DynamoCacheEntry
     backends: dict[_BackendId, Any]
 
+    @staticmethod
+    def from_cache_entry(
+        cache_entry: _DynamoCacheEntry, backends: dict[_BackendId, Any]
+    ) -> Optional["PrecompileCacheEntry"]:
+        backend_content: dict[_BackendId, Any] = {}
+
+        for code in cache_entry.codes:
+            for backend_id in code.backend_ids:
+                if backend_id not in backends:
+                    logger.warning("Backend not found")
+                    debug_str = json.dumps(
+                        {
+                            "entry": cache_entry.debug_info(),
+                            "missing_backend": backend_id,
+                        }
+                    )
+                    torch._logging.trace_structured(
+                        "artifact",
+                        metadata_fn=lambda: {
+                            "name": "dynamo_cache_bypass",
+                            "encoding": "json",
+                        },
+                        payload_fn=lambda: debug_str,
+                        expect_trace_id=False,
+                    )
+                    code.bypassed = True
+                    break
+                else:
+                    backend_content[backend_id] = backends[backend_id]
+
+        return PrecompileCacheEntry(dynamo=cache_entry, backends=backend_content)
+
 
 def _hash_source(source: str) -> str:
     sha256_hash = hashlib.sha256()

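The control flow above is the heart of the change: a missing backend marks only
its own code object as bypassed and moves on to the next code, rather than
failing the whole entry. A self-contained sketch of that filtering logic with
hypothetical stand-in types:

    from dataclasses import dataclass

    @dataclass
    class Code:
        backend_ids: list[str]
        bypassed: bool = False

    def filter_backends(codes: list[Code], available: dict[str, object]) -> dict[str, object]:
        kept: dict[str, object] = {}
        for code in codes:
            for backend_id in code.backend_ids:
                if backend_id not in available:
                    code.bypassed = True  # this frame recompiles on the next run
                    break
                kept[backend_id] = available[backend_id]
        return kept

    codes = [Code(["b0"]), Code(["b1"])]
    assert filter_backends(codes, {"b0": "art0"}) == {"b0": "art0"}
    assert codes[1].bypassed  # b1 was missing, so its code is bypassed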
@@ -612,10 +645,6 @@ class CompilePackage:
         try:
             yield
         finally:
-            if (
-                entry.bypassed
-            ):  # Remove the code from the cache entry if it's been bypassed
-                del self._codes[code]
             entry.has_compile_id = True
             self._current_entry = None

@@ -729,6 +758,11 @@ class CompilePackage:
             if entry.code_source:
                 target_code = _lookup_code(entry)
 
+            if entry.bypassed:
+                # If the entry is bypassed, do not install backends
+                # or guarded codes.
+                continue
+
             for backend_id in entry.backend_ids:
                 if backend_id not in backends:
                     raise RuntimeError(

@@ -88,7 +88,7 @@ class PrecompileContext:
     # Protected by the compile_lock
     # _backend_artifacts_by_key organizes results by the key of each artifact.
     # Each object here must be serializable
-    _backend_artifacts_by_key: dict[str, BackendCacheArtifact[Any]] = {}
+    _backend_artifacts_by_key: dict[_BackendId, BackendCacheArtifact[Any]] = {}
 
     # On call to `serialize()`, all cache artifacts in _dynamo_cache_entries are converted
     # into DynamoCacheArtifacts and added to _new_cache_artifacts for serialization

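The typing change above (str to _BackendId keys) threads the dedicated ID type
through the context. Assuming _BackendId is a typing.NewType over str (as its
usage here suggests), plain strings must be wrapped at the API boundary, which
is exactly what the wrappers in the following hunks do:

    from typing import NewType

    _BackendId = NewType("_BackendId", str)  # zero-cost at runtime, checked statically

    artifacts: dict[_BackendId, object] = {}
    artifacts[_BackendId("__compiled_fn_1")] = object()  # wrap at the API edge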
@@ -107,7 +107,9 @@ class PrecompileContext:
         """
         Records a backend artifact to be used with dynamo cache entries
         """
-        cls._backend_artifacts_by_key[artifact.key] = copy.deepcopy(artifact)
+        cls._backend_artifacts_by_key[_BackendId(artifact.key)] = copy.deepcopy(
+            artifact
+        )
 
     @classmethod
     def record_dynamo_cache_entry(

@@ -121,7 +123,7 @@ class PrecompileContext:
         Edit the content of an existing artifact
         """
         assert key in cls._backend_artifacts_by_key, f"Key {key} not found in artifacts"
-        artifact = cls._backend_artifacts_by_key[key]
+        artifact = cls._backend_artifacts_by_key[_BackendId(key)]
         artifact.edit_contents(edit_fn)
 
     @classmethod

@@ -129,12 +131,12 @@ class PrecompileContext:
         """
         Return the backend cache artifact with the associated key
         """
-        return cls._backend_artifacts_by_key.get(key, None)
+        return cls._backend_artifacts_by_key.get(_BackendId(key), None)
 
     @staticmethod
     def dump_debug_info(
         dynamo_entries: dict[str, _DynamoCacheEntry],
-        backend_artifacts: dict[str, BackendCacheArtifact[Any]],
+        backend_artifacts: dict[_BackendId, BackendCacheArtifact[Any]],
     ) -> dict[str, Any]:
         """
         Return a JSON serializable debug dump of all entries in the precompile context

@@ -195,33 +197,11 @@ class PrecompileContext:
 
         for key, cache_entry in dynamo_entries.items():
             try:
-                backends = cache_entry.backend_ids
-                backend_content: dict[_BackendId, BackendCacheArtifact[Any]] = {}
-                for id_ in backends:
-                    if id_ not in backend_artifacts:
-                        debug_str = json.dumps(
-                            {
-                                "entry": cache_entry.debug_info,
-                                "key": key,
-                            }
-                        )
-                        logger.warning("Backend not found")
-                        torch._logging.trace_structured(
-                            "artifact",
-                            metadata_fn=lambda: {
-                                "name": "dynamo_cache_bypass",
-                                "encoding": "json",
-                            },
-                            payload_fn=lambda: debug_str,
-                            expect_trace_id=False,
-                        )
-                        continue
-                    artifact = backend_artifacts[id_]
-                    assert isinstance(artifact, BackendCacheArtifact)
-                    backend_content[id_] = artifact
-                precompile_cache_entries[key] = PrecompileCacheEntry(
-                    dynamo=cache_entry, backends=backend_content
+                result = PrecompileCacheEntry.from_cache_entry(
+                    cache_entry, backend_artifacts
                 )
+                if result is not None:
+                    precompile_cache_entries[key] = result
             except Exception as e:
                 logger.warning("Failed to create cache entry %s: %s", key, str(e))

@@ -1198,7 +1198,7 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
                 cache_state = "miss"
                 if (
                     config.strict_autograd_cache
-                    or torch._dynamo.config.caching_precompile
+                    or torch._dynamo.config.strict_precompile
                 ):
                     raise e
                 # Most often this is BypassAOTAutogradCache, but

@@ -1231,7 +1231,7 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
                 log_cache_bypass("bypass_aot_autograd", str(e))
                 if (
                     config.strict_autograd_cache
-                    or torch._dynamo.config.caching_precompile
+                    or torch._dynamo.config.strict_precompile
                 ):
                     raise e
             if compiled_fn is None:

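The last two hunks relax failure handling to match: with caching_precompile
alone, an AOTAutogradCache bypass now degrades to a cache miss instead of
raising; only strict_autograd_cache or the strict_precompile flag re-raises.
A hedged usage sketch of the resulting behavior:

    import torch

    # Under caching_precompile, a bypassed AOTAutograd cache entry now logs and
    # falls through to normal compilation rather than surfacing an exception.
    with torch._dynamo.config.patch(caching_precompile=True):
        out = torch.compile(lambda x: x + 1)(torch.randn(2))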