Summary:
This diff does a big refactor of PrecompileContext to make it considerably simpler: instead of being a CacheArtifactManager and managing a bunch of bytes, it simply stores two things: dynamo cache entries and backend cache entries. When asked, it stitches them together into PrecompileCacheEntries, which are stored by DynamoCache.
This structure then allows us to register DynamoCache to the regular Megacache API, instead of having two separate APIs that are confusing. It also lets us remove the autotune cache integration, since MegaCache API will automatically store autotune cache entries.
The intent here is that users who want to use caching precompile will simply be able to use torch.compiler.save_cache_artifacts as before, just with `torch.dynamo.config.caching_precompile` set to True. They can also directly interact with PrecompileContext if they wish to specifically only load Precompile entries, using PrecompileContext.create_cache_entries().
Saving single entries and such with DynamoCache still works normally.
Test Plan:
All existing unit tests pass.
Rollback Plan:
Differential Revision: D82380307
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162886
Approved by: https://github.com/zhxchen17
The goal of this PR stack is to be able to implement `aot_compile_module`, which AOT precompiles a torch.nn.Module.
Step 1 is a simple refactor to make CompileArtifacts itself the callable, which makes it easier to use directly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162169
Approved by: https://github.com/zhxchen17
Adding a new feature to torch.compile(fullgraph=True) which "aot_compile" a function with given example inputs.
On user side it should look like:
```
def foo(x, y):
return x + y
compiled_fn = torch.compile(fullgraph=True).aot_compile(((torch.randn(3, 4), torch.randn(3, 4)), {}))
```
This is different from the traditional `torch.compile` workflow where compiled object will be a drop-in replacement for the original eager model:
```
tensor input -> torch.compile() -> tensor output (and populates the cache entry)
```
`aot_compile` will instead return a compiled function as result, and it's purely functional and doesn't populate the compile cache entry in dynamo:
```
tensor input -> aot_compile() -> compiled function
```
The aot compiled function will be savable and loadable on disk as well:
```
torch.compile(fullgraph=True).aot_compile(...).save_compiled_function('my/path')
compiled_fn = torch.compiler.load_compiled_function("my/path")
```
Right now we treat compiler backend as a blackbox and it needs to implement the following interface to make compile artifacts serialzable:
```
class SerializableCallable:
def save_compile_artifacts(): ....
def load_compile_artifacts(): ....
```
We haven't implemented this for inductor yet, but this shouldn't be an issue since we gate this feature through `torch._dynamo.config.aot_compile` (which defaults to False), and this will be left as follow up PR to the current PR.
Differential Revision: [D80914270](https://our.internmc.facebook.com/intern/diff/D80914270/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161383
Approved by: https://github.com/tugsbayasgalan
TL;DR: Cuts vLLM cudagraph collection from 80s -> 24s
Stop garbage collecting by default on every cudagraph recording. The old behavior can be re-enabled by setting `TORCH_CUDAGRAPH_GC=1` or the config `force_cudagraph_gc`.
We were previously garbage collecting at the beginning of each cudagraph
capture. vLLM collects 5427 graphs and most of those garbage collections weren't
actually collecting any memory (CPU or GPU). This changes it to not collect more
than every 10s so if we're capturing in a loop we don't burn all our cycles
looking for garbage.
(These number have a lot of variance from run to run but give the correct
general scale)
```
| calls | total | synchronize | gcs | collect | empty cache | sys freed | cuda freed |
-------+-------+-------+-------------+------+---------+-------------+-----------+------------+
before | 5427 | 78s | 1.48s | 5427 | 53.22s | 1.21s | 145855 | 1539309568 |
-------+-------+-------+-------------+------+---------+-------------+-----------+------------+
after | 5427 | 24s | 0s | 3 | 1.53s | 0.84s | 592 | 1539309568 |
-------+-------+-------+-------------+------+---------+-------------+-----------+------------+
```
total - this is the total time reported by vLLM's "Graph capturing finished" log.
The rest of these are measured in torch.cuda.graphs.graph.__enter__():
calls - number of times torch.cuda.graphs.graph.__enter__ was called
synchronize - this is the duration taken by the cuda.synchronize call
gcs - number of times gc.collect was called
collect - this is the duration taken by the gc.collect call
empty cache - this is the duration taken by the torch.cuda.empty_cache call
sys freed - the number of bytes reported freed by gc.collect
cuda freed - the number of bytes reported freed by torch.cuda.memory_reserved
So it seems like the heavy lifting is done by torch.cuda.empty_cache() which is
fairly quick.
Cudagraph results from the TorchInductor Performance DashBoard (this is from the original version using the GC clock so the real results will be slightly better than this):
<img width="1494" height="382" alt="image" src="https://github.com/user-attachments/assets/69b705ef-47ce-4b6e-9733-1ec941cad93d" />
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158193
Approved by: https://github.com/ngimel
As part of better engineering week, we would like to improve out type support to improve dev experience in dynamo
This PR adds strict typing support to an important file in dynamo, `decorators.py`
NOTE: Untyped fns are because there is a conflict with `__init__.py` in compiler so we can't type these at this time
Running
```
mypy torch/_dynamo/decorators.py --linecount-report /tmp/coverage_log
```
| -------- | Lines Unannotated | Lines Total | % lines covered | Funcs Unannotated | Funcs Total | % funcs covered |
| -------- | ------- | -------- | ------- | ------- | ------- | ------- |
| Main | 209 | 908 | 23.02% | 9 | 39 | 23.08% |
| This PR | 870 | 943 | 100.00% | 36 | 39 | 100.00% |
| Delta | +661 | +35 | +76.98% | +27 | 0 | +76.92% |
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158509
Approved by: https://github.com/williamwen42
This PR adds a new config option, `caching_precompile`, and a `DynamoCache`, which loads and saves Dynamo Cache entries automatically. It also hooks up DynamoCache to PrecompileContext, so that we can save multiple cache entries.
When this configuration is turned on, we:
- Automatically create and initialize a CompilePackage on every torch.compile
- Automatically use BundledAutogradcache
- Automatically save the CompilePackage entry to DynamoCache after every compile
You can also use PrecompileContext.serialize() to manually serialize a full object.
I've added unit tests to exhibit this behavior.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155913
Approved by: https://github.com/zhxchen17
This PR adds a new config option, `caching_precompile`, and a `DynamoCache`, which loads and saves Dynamo Cache entries automatically. It also hooks up DynamoCache to PrecompileContext, so that we can save multiple cache entries.
When this configuration is turned on, we:
- Automatically create and initialize a CompilePackage on every torch.compile
- Automatically use BundledAutogradcache
- Automatically save the CompilePackage entry to DynamoCache after every compile
You can also use PrecompileContext.serialize() to manually serialize a full object.
I've added unit tests to exhibit this behavior.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155913
Approved by: https://github.com/zhxchen17
This PR implements a basic interface and test for PrecompileContext, a special CacheArtifactManager specifically designed for precompile. The job of a PrecompileContext is to record things precompile needs as torch is compiling, dump it all into bytes, and then stitch it back together into a cache of callables.
## Why use CacheArtifactManager?
Precompile needs a way to record various serializable data as torch is compiling. CacheArtifactManager already does this today pretty well, handling a lot of serialization and cache information. So we're reusing a bunch of that infrastructure directly.
## How is it different from CacheArtifactManager?
Unlike regular CacheArtifactManager, PrecompileContext needs to be able to take the recorded artifacts and stitch them together after deserialization, to create a single working callable.
Since PrecompileContext doesn't need the cache keys, the "key" field of PrecompileArtifacts can be used for metadata relating to how to stitch the individual functions being compiled together into a full callable. For example, on a given dynamo compile, if there are multiple functions (via graph breaks or recompiles) being compiled, MegaCache would organize it like so:

Whereas we'd visualize PrecompileContext's result like so:

For now, we just handle eager mode; in the diff above, I'll hook up the other backend artifacts from PrecompileContext.
After this PR, precompile consists of three main interfaces:
### CompilePackage
- Everything needed to run one torch.compile'd function (including graph breaks)
- `__init__(fn, cache_entry)` Initializes with a DynamoCacheEntry
- `install(backends)` load precompile artifacts into function's dynamo state with a dictionary of backends
- `cache_entry()` return a serializable cache entry to save
### DynamoStore
- Responsible for tracking CompilePackages on disk (and/or in memory)
- `load_package(path)`: load a package given a torch compiled function and a path to the cache artifact
- `save_package(package, path): Save a CompiledPackage to a path. Calls PrecompileContext to grab backend data
- `record_package(package)`: Record a package to PrecompileContext (for global serialization/deserialization)
### PrecompileContext
- Overarching context for serializing and deserializing precompile artifacts. Supports **global** and **local** setups.
- `serialize()`: (Global) serializes all artifacts in PrecompileContext into bytes
- `populate_caches(bytes)`: (Global) takes serialized bytes and puts them into DynamoStore (TODO)
- `serialize_artifact_by_key(key)`: (Local) serialize a single artifact by its cache key
<img width="1455" alt="image" src="https://github.com/user-attachments/assets/99b61330-7607-4763-bdbc-85b366e82cdd" />
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154415
Approved by: https://github.com/zhxchen17
ghstack dependencies: #155118
This PR adds standalone_compile API that does precompilation via caching to support vLLM use case in the short term while we work on the longer term precompilation solution.
```
standalone_compile(gm, example_inputs, options) -> CompiledArtifact
CompiledArtifact.save(path, format: binary|unpacked = binary)
CompiledArtifact.load(path, format: binary|unpacked = binary)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150670
Approved by: https://github.com/jamesjwu, https://github.com/zou3519
This PR adds standalone_compile API that does precompilation via caching to support vLLM use case in the short term while we work on the longer term precompilation solution.
```
standalone_compile(gm, example_inputs, options) -> CompiledArtifact
CompiledArtifact.save(path, format: binary|unpacked = binary)
CompiledArtifact.load(path, format: binary|unpacked = binary)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150670
Approved by: https://github.com/jamesjwu, https://github.com/zou3519
This PR adds standalone_compile API that does precompilation via caching to support vLLM use case in the short term while we work on the longer term precompilation solution.
```
standalone_compile(gm, example_inputs, options) -> CompiledArtifact
CompiledArtifact.save(path, format: binary|unpacked = binary)
CompiledArtifact.load(path, format: binary|unpacked = binary)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150670
Approved by: https://github.com/jamesjwu, https://github.com/zou3519
Right now we are susceptive to a race condition where if the torch.compiler.config is not implicitly import via dynamo/builder.py, we will throw an error when trying to set compiler configs. This fixes it by including config in `__all__`.
Previous
```
>>> import torch
>>> torch.compiler.config.dynamic_sources = "L['kwargs']['float_features']"
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
AttributeError: module 'torch.compiler' has no attribute 'config'
>>> torch.compiler.config.dynamic_sources =
"L['kwargs']['float_features']"
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
AttributeError: module 'torch.compiler' has no attribute 'config'
```
Now
```
>>> import torch
>>> torch.compiler.config.dynamic_sources = "L['kwargs']['float_features']"
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148978
Approved by: https://github.com/bdhirsh, https://github.com/laithsakka
While using save_cache_artifacts on internal workloads, we have noticed that repeatedly calling this function after every batch is incredibly expensive. This PR significantly speeds up this function call by opting out of pickle and redesigning serialization algorithm.
Essentially what we want is to be able to call serialize many times without incurring costs from scratch.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148227
Approved by: https://github.com/jamesjwu
ghstack dependencies: #148226
This PR introduces the ability to whitelist sources as dynamic. This is particularly useful for large models with graph breaks, as you can keep the dynamism across graph breaks since source names stay consistent. Additionally you can use this to mark ints as dynamic.
NB: I intentionally didn't complicate the interface by supporting specification of per dimension dynamism. There is virtue in keeping true to the standard way of representing sources (eg. L['x']). If we find in practice that we need more more fine grained control, we can explore further affordances at that time.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147979
Approved by: https://github.com/Mingming-Ding
This PR essentially introduces two new APIs
* torch.compiler.save_cache_artifacts
* torch.compiler.load_cache_artifacts
which aim to create a mega cache experience where the user can start collecting cache artifacts, and later call the save API to fetch them. In the next attempt, the user can "hot load" the cache artifacts via the load function.
This bundling approach reduces the need to rely on porting individual files one by one, or relying on many network requests.
Note that these APIs CANNOT log to structured logging as these functions will be called before and after compilation, as opposed to during compilation. Due to this limitation, the API returns a struct that the user can log with.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143341
Approved by: https://github.com/jansel
We added an is_export flag under torch.compiler.is_exporting. This comes handy when we try to do some special logic in user-level and system-level (e.g. in upper of the stack).
In increasing-scope:
- `_is_fx_tracing` is set to True when we use under symbolic_trace or make_fx.
- `is_exporting` is set to True when we're doing strict or non-strict export, which internally has a step that calls make_fx and set _is_fx_tracing to be True.
- `is_compiling` is set to True when we're either doing strict, non-strict export or torch.compile.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142425
Approved by: https://github.com/avikchaudhuri