Files
pytorch/torch/_dynamo/precompile_context.py
James Wu f55c5d085e [Precompile] Various small bugfixes, add CachingPrecompile to torchbench (#158847)
This PR addresses a few small bugfixes needed to make NanoGPT inference work, and also adds a new `--caching-precompile` argument to torchbench. With `--caching-precompile`, after every benchmark we save precompile artifacts to DynamoCache, allowing us to test caching precompile on all existing benchmarks.

The following bugfixes are in this PR to make all of this work:
- Fix global variables being pruned with DUPLICATE_INPUT guards. DUPLICATE_INPUT guards have additional vars from the second input, which we track with additional_local_vars, but we never tracked additional global variables. This fixes the issue. (See torch/_dynamo/guards.py changes)
- Return None from PrecompileContext.serialize() if no new dynamo compiles occurred. There's no reason to save artifacts (e.g. autotuning artifacts, etc.) if no dynamo_compile occurred, so we return None early. We may later want to support editing existing dynamo artifacts as a TODO, but that's upcoming.
- log `dynamo_start` on CompilePackage.load: This is only needed so that tlparse doesn't ignore TORCH_TRACE logs generated when caching precompile hits. If there are no actual compiles, we never log a "dynamo_start" entry, which makes internal tlparse ignore the TORCH_TRACE file.

## Test Plan

After this PR, the following now works:
```
TORCH_LOGS=dynamo tlp python benchmarks/dynamo/torchbench.py --only nanogpt --performance  --inference --backend inductor  --caching-precompile --warm-start-latency
```
tlparse result (internal):
Cold Start (6 seconds):
https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpAWe0zD/dedicated_log_torch_trace_vk9nkp4m.log/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000

Warm Start (~1 s):
https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpAWe0zD/dedicated_log_torch_trace_5l4iwrpm.log/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000

The 1 second of warm start here can be improved: the costs here are mostly in starting up workers and triton and initializing CUDA, a lot of which should not be included in the compile time cost in real world scenarios where these are already loaded before training begins.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158847
Approved by: https://github.com/zhxchen17
2025-07-24 14:09:54 +00:00

186 lines
7.1 KiB
Python

from abc import abstractmethod
from collections import defaultdict
from itertools import chain
from typing import Any, Generic, Optional, TypeVar
from typing_extensions import override
from torch.compiler._cache import (
_serialize_single_cache,
CacheArtifact,
CacheArtifactFactory,
CacheArtifactManager,
CacheArtifactsResult,
CacheInfo,
)
from torch.utils._appending_byte_serializer import AppendingByteSerializer
from torch.utils._ordered_set import OrderedSet
"""
Classes and implementations related to precompile
"""
T = TypeVar("T")
class PrecompileCacheArtifact(CacheArtifact, Generic[T]):
    """
    Base class for cache artifacts that are serialized and deserialized by
    PrecompileContext instead of CacheArtifactManager.

    T is the deserialized type of the artifact, i.e. the type returned by
    after_deserialization.

    PrecompileCacheArtifact is a frozen dataclass - subclasses may add their own
    serializable fields and metadata as needed, and read them in
    after_deserialization.

    Example implementation:

        class MyPrecompileCacheArtifact(PrecompileCacheArtifact[MySerializableType]):
            my_field: int

            def after_deserialization(self) -> MySerializableType:
                result = pickle.loads(self.content)
                # Do some extra work post deserialization
                result.my_post_deserialization_function(self.my_field)
                return result
    """

    @override
    def precompile_compatible(self) -> bool:
        """Always report True for precompile artifacts."""
        return True

    @override
    def populate_cache(self) -> None:
        """
        Unsupported for precompile artifacts.

        Raises:
            RuntimeError: always; these artifacts do not populate caches.
        """
        raise RuntimeError("Precompile cache artifacts do not populate caches")

    @abstractmethod
    def after_deserialization(self) -> T:
        """
        Code to be run after the raw byte contents have been read from disk.

        Implementations generally turn self.content (raw bytes) back into the
        object it was originally created from, and return that object.
        """
class PrecompileContext(CacheArtifactManager):
    """
    PrecompileContext is a special CacheArtifactManager for handling precompilation
    It uses the same interface as CacheArtifactManager, but handles deserialization differently: instead
    of placing each artifact into respective caches, it will stitch all the cache artifacts for a single key
    together and place it into a global Precompile Cache.

    The following artifact types are supported by PrecompileContext:
     - BundledAOTAutogradCacheArtifact
     - DynamoCodeStateArtifact
     - AutotuneCacheArtifact (regular autotune results, same as Megacache)
    """

    # Protected by the compile_lock

    # _new_cache_artifacts_by_key organizes results by the key of each artifact.
    # This allows us to implement serialize_artifact_by_key easily.
    # On call to `serialize()`, all cache artifacts in _new_cache_artifacts_by_key
    # are transferred to _new_cache_artifacts before serialization.
    _new_cache_artifacts_by_key: dict[str, CacheArtifact] = {}
    _new_cache_artifacts: CacheArtifactsResult = defaultdict(list)
    # Keep a separate seen artifacts list to avoid unnecessary duplicates
    # This list will not be cleared between serialize() calls
    _seen_artifacts: OrderedSet[CacheArtifact] = OrderedSet()
    # When serialize() is called, artifacts are transferred from _cache_artifacts to
    # internal data structure of the _serializer
    # This allows us to only pay the cost of serialization if serialize() is called
    _serializer: AppendingByteSerializer[tuple[str, list[CacheArtifact]]] = (
        AppendingByteSerializer(serialize_fn=_serialize_single_cache)
    )
    _cache_info: CacheInfo = CacheInfo()

    @override
    @classmethod
    def clear(cls) -> None:
        """
        Drop all pending by-key artifacts, then run CacheArtifactManager.clear().
        """
        cls._new_cache_artifacts_by_key.clear()
        super().clear()

    @override
    @classmethod
    def record_artifact(
        cls,
        artifact_type: str,
        key: str,
        content: Any,
    ) -> None:
        """
        Called from each caching operation to record the artifact in this
        "mega" list

        Args:
            artifact_type: type name handed to CacheArtifactFactory.encode_create
            key: key the artifact is stored under; a later artifact recorded
                under the same key replaces the pending one
            content: payload handed to CacheArtifactFactory.encode_create
        """
        artifact = CacheArtifactFactory.encode_create(artifact_type, key, content)
        # TODO: although this covers completely same artifacts, it's possible
        # with AOTAutogradCacheEntries to have multiple artifacts whose keys
        # (i.e. backend_ids) are different, but whose contents are equal.
        # In those cases, it would be much better if we only serialize once instead
        # of N times.
        if artifact in cls._seen_artifacts:
            return

        cls._new_cache_artifacts_by_key[key] = artifact
        cls._seen_artifacts.add(artifact)

    @classmethod
    def _save_artifacts_by_type(cls) -> None:
        """
        We normally record artifacts by key, but serialization expects them to be organized
        by artifact type. This function transfers artifacts from _new_cache_artifacts_by_key to _new_cache_artifacts
        """
        for artifact in cls._new_cache_artifacts_by_key.values():
            cls._new_cache_artifacts[artifact.__class__.type()].append(artifact)
        # Everything has been moved over; the by-key view starts empty again.
        cls._new_cache_artifacts_by_key.clear()

    @classmethod
    def serialize_artifact_by_key(cls, key: str) -> Optional[CacheArtifact]:
        """
        Look up the pending artifact recorded under the given key.

        Returns:
            The single artifact currently stored in _new_cache_artifacts_by_key
            for `key`, or None if there is none (for example because it was never
            recorded, or because serialize() already moved it out).
        """
        return cls._new_cache_artifacts_by_key.get(key)

    @override
    @classmethod
    def serialize(cls) -> Optional[tuple[bytes, CacheInfo]]:
        """
        Move pending artifacts into _new_cache_artifacts and serialize them.

        Returns:
            None if no "precompile_dynamo" artifacts are present after the move,
            otherwise whatever CacheArtifactManager.serialize() returns.
        """
        cls._save_artifacts_by_type()
        # No need to serialize if there are no new dynamo compiles
        if "precompile_dynamo" not in cls._new_cache_artifacts:
            return None
        return super().serialize()

    @staticmethod
    def populate_caches(artifacts: CacheArtifactsResult) -> CacheInfo:
        """
        Load deserialized artifacts: autotune artifacts populate their own cache,
        and every "precompile_dynamo" entry is written to DynamoCache together
        with the backend artifacts it references.

        Args:
            artifacts: mapping from artifact type to the artifacts of that type.
                A mapping without a "precompile_dynamo" entry is accepted; in
                that case nothing is written to DynamoCache.

        Returns:
            A CacheInfo to which every artifact in `artifacts` has been added.

        Raises:
            AssertionError: if a dynamo entry or one of its backend artifacts is
                not a PrecompileCacheArtifact, or if a backend id referenced by
                a dynamo entry has no artifact.
        """
        PrecompileContext._ensure_cache_artifacts_registered()

        artifacts_by_key = {}
        cache_info = CacheInfo()
        for artifact in chain.from_iterable(artifacts.values()):
            if artifact.type() == "autotune":
                # Populate autotune cache artifacts
                artifact.populate_cache()
            else:
                artifacts_by_key[artifact.key] = artifact
            cache_info.add(artifact)

        from torch._dynamo.package import _BackendId, DynamoCache

        # .get() rather than indexing: a plain dict without dynamo entries must
        # not raise KeyError, and a defaultdict must not grow a new empty key.
        for dynamo_entry in artifacts.get("precompile_dynamo", []):
            assert isinstance(dynamo_entry, PrecompileCacheArtifact)
            cache_entry = dynamo_entry.after_deserialization()
            # Grab backends from the dynamo cache entry
            backends = cache_entry.backend_ids
            backend_content: dict[_BackendId, PrecompileCacheArtifact[Any]] = {}
            for id_ in backends:
                assert id_ in artifacts_by_key, f"Backend {id_} not found in artifacts"
                backend_artifact = artifacts_by_key[id_]
                assert isinstance(backend_artifact, PrecompileCacheArtifact)
                backend_content[id_] = backend_artifact
            DynamoCache.write(cache_entry, backend_content, dynamo_entry.key)

        return cache_info

    @classmethod
    def _ensure_cache_artifacts_registered(cls) -> None:
        """
        Import the modules that define the precompile artifact classes.

        The imported names are unused here (hence the F401 suppressions); the
        imports exist for their side effect, presumably registering the artifact
        classes with CacheArtifactFactory - confirm against those modules.
        """
        from torch._dynamo.package import _DynamoCacheArtifact  # noqa: F401
        from torch._functorch._aot_autograd.autograd_cache import (  # noqa: F401
            BundledAOTAutogradCacheArtifact,
        )