TODO:
- [x] Add handling for when forward is invoked multiple times without invoking backward, leaving the fwd/backward states out of sync
- [x] Update rng state initialization to take from correct device
- [x] Tests
- [x] Handling of retain_graph
- [x] Respect `fallback_random`
Fix for https://github.com/pytorch/pytorch/issues/130123.
Updates the aot_eager and cudagraph compilation of `run_and_save_rng_state` to use the new mechanism added by https://github.com/pytorch/pytorch/pull/114068 for CUDAGraph safe rng states.
We keep a pair of rng states for the forward and backward respectively. In both forward and backward, the rng op gets run with `graphsafe_run_with_rng_state`, which takes in an RNG state and hooks it onto the current RNG generator before running the operator. The fwd/backward rng states are initialized with the same value, and we ensure that for any given run of the forward, the corresponding backward run sees the same rng states for the op as were observed in the forward.
```
===== Forward graph 1 =====
# /data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", fwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)
        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = fwd_rng_state_0); fwd_rng_state_0 = None
        ...
===== Backward graph 1 =====
def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", tangents_1: "f32[4, 4][4, 1]cuda:0", bwd_rng_state_0):
    sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)
    # No stacktrace found for following nodes
    graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = bwd_rng_state_0); bwd_rng_state_0 = None
```
There is some extra complication when a user either calls backward with retain_graph, or calls backwards in a different order than the corresponding forwards. If a user has states fwd_rng_state0 and bwd_rng_state0 and calls:
- fwd0: fwd_rng_state0 -> fwd_rng_state1
- fwd1: fwd_rng_state1 -> fwd_rng_state2
- bwd1
- bwd0
Then, naively, when bwd1 is invoked the bwd rng states would not equal the states that were observed in fwd1. I added handling for this in the aot runtime wrappers, which detect pending backward invocations and the current position of the bwd rng states, and update them when necessary (see the sketch below).
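As a rough sketch of that bookkeeping (hypothetical names like `PerOpRngStates` and `pending`, using plain `torch.Generator` state save/restore rather than the actual graphsafe generator plumbing from #114068):
```
import torch

class PerOpRngStates:
    # Hypothetical sketch: keep the fwd/bwd generators for one rng op in
    # sync even when backwards run out of order or with retain_graph.
    def __init__(self, seed, device="cuda"):
        # fwd and bwd must start from the same state.
        self.fwd = torch.Generator(device=device)
        self.fwd.manual_seed(seed)
        self.bwd = torch.Generator(device=device)
        self.bwd.manual_seed(seed)
        self.fwd_runs = 0   # how many forwards have consumed self.fwd
        self.pending = {}   # fwd run index -> fwd state snapshot before that run

    def on_forward(self):
        # Snapshot the state so whichever backward corresponds to this
        # forward can replay it, even if other forwards run in between.
        self.pending[self.fwd_runs] = self.fwd.get_state()
        self.fwd_runs += 1
        return self.fwd  # the rng op advances this generator as it runs

    def on_backward(self, fwd_index):
        # Restore the state observed at the matching forward (retain_graph
        # may replay the same fwd_index more than once, so keep the entry).
        self.bwd.set_state(self.pending[fwd_index])
        return self.bwd
```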
Other notes:
Because nodes which appear later in the forward appear earlier in the backward, we need a separate rng state for each operator. If we reused one rng state across ops, the forward and backward would run the ops against different rng states, since the state would not be advanced in the same order.
Questions for reviewers:
This does change numerics, because the rng of the op is now taken from the input rng state instead of whatever the rng would be midway through running the graph. Technically we only need this for cudagraphs, but I'd prefer not to have an rng divergence just for cudagraphs. I am making it respect `fallback_random`.
Edit: decided to apply this to non-cudagraph runs as well, so long as `fallback_random` is not set
I was initializing the rng states by cloning the current state. But if you had something like 5 different rands in the model with the same shape, they'd all get the same value, which doesn't seem great. I could use some other initialization scheme, like taking the seed from graph position. Let me know your thoughts.
Edit: updated so the rng states are now initialized from torch.randint(); a sketch follows.
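A sketch of that initialization, assuming each op gets a fresh generator seeded from the ambient RNG so distinct rand calls diverge:
```
import torch

def fresh_rng_state(device="cuda"):
    # Seed from the ambient RNG rather than cloning the current state, so
    # five same-shaped rand() calls don't all produce identical values.
    seed = int(torch.randint(0, 2**62, (1,)).item())
    gen = torch.Generator(device=device)
    gen.manual_seed(seed)
    return gen.get_state()
```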
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146878
Approved by: https://github.com/anijain2305, https://github.com/bdhirsh
* Automatically applies ruff rule 401, turning loops into equivalent list comprehensions, which are faster and do not leak the loop variables into the enclosing scope (see the sketch below).
* List comprehensions not only often type-check better, but are also 50+% faster than equivalent for loops in terms of loop overhead. They preserve length information and are easier for the interpreter to optimize.
* Manually went back and made mypy happy after the change.
* Also fixed style lints in files covered by flake8 but not by pyfmt
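A representative before/after of the kind of rewrite this rule performs (illustrative only; not a specific hunk from this PR):
```
from dataclasses import dataclass

@dataclass
class Record:
    name: str

records = [Record("a"), Record("b")]

# Before: manual append loop; `item` leaks into the enclosing scope.
names = []
for item in records:
    names.append(item.name)

# After: equivalent list comprehension, with less per-iteration overhead
# and no leaked loop variable.
names = [item.name for item in records]
assert names == ["a", "b"]
```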
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140980
Approved by: https://github.com/justinchuby, https://github.com/malfet
Summary: Add timing logs to cudagraphs, covering `create deferred_cudagraphify wrapper`, `warmup`, `record`, and `checkpoint`.
Test Plan:
1. buck2 run fbcode//mode/opt //pytorch/benchmark:run -- resnet50 -d cuda -t train --inductor --pt2-triton-cudagraph
2. Found the result in [scuba table](https://fburl.com/scuba/pt2_compile_events/0oik8nu9).
Differential Revision: D65505659
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139818
Approved by: https://github.com/eellison
Type annotations for compile_fx.
- Some of the stuff here is pretty complicated (functions which return functions that take functions), so I bailed on those and used `Any` just to get the rest landed.
- There are also changes to type signatures in other files which I did just to let mypy know more about the types in compile_fx.py.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138033
Approved by: https://github.com/Skylion007
This PR mostly refactors, moving code into utils files so that it can be shared between codecache.py and compile_fx.py. It then changes compile_fx so that:
- When saving to FXGraphCache, we save onto the CompiledFXGraph all the metadata necessary for running post-compile steps (realigning inputs, cudagraphification).
- When loading from FXGraphCache, we use the saved information directly instead of recalculating it from scratch.
This makes `FXGraphCache.load()` a perfect cache over compile_fx_inner, in that it **returns exactly what compile_fx_inner returns**. It also makes it possible for AOTAutogradCache, given a key into the FX graph cache and example inputs, to get back the full return value of compile_fx_inner.
## What's a post compile step?
We define a **post-compile** to be the set of actions that need to run after FXGraphCache either loads from the cache or misses and runs compilation. These steps include:
- Setting the tracing context's output strides
- Running cudagraphs if enabled
- Maybe realign inputs if cudagraphs didn't run
To run these steps, we save all the necessary metadata on CompiledFxGraph and use it on a cache hit to reconstruct the returned object, as in the sketch below.
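A minimal sketch of the shape of this, with hypothetical names (`CompiledGraphArtifact`, `post_compile`, and the stub helpers); the real metadata lives on CompiledFxGraph:
```
from dataclasses import dataclass, field
from typing import Any, Callable

def cudagraphify(fn, example_inputs):  # stub standing in for the real wrapper
    return fn

def align_inputs(fn, example_inputs):  # stub standing in for input realignment
    return fn

@dataclass
class CompiledGraphArtifact:
    # The compiled callable plus the metadata needed to rerun the
    # post-compile steps on a cache hit, without the input graph module.
    compiled_fn: Callable[..., Any]
    output_strides: list
    cudagraph_fail_reasons: list = field(default_factory=list)
    fx_kwargs: dict = field(default_factory=dict)

def post_compile(artifact, example_inputs, cudagraphs_enabled):
    # 1. Set the tracing context's output strides from the saved metadata
    #    (elided here).
    # 2. Run cudagraphs if enabled and nothing disqualified the graph at
    #    compile time; the fail reasons were computed on the cache miss.
    if cudagraphs_enabled and not artifact.cudagraph_fail_reasons:
        artifact.compiled_fn = cudagraphify(artifact.compiled_fn, example_inputs)
    else:
        # 3. Otherwise, realign inputs before calling the compiled function.
        artifact.compiled_fn = align_inputs(artifact.compiled_fn, example_inputs)
    return artifact
```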
## Splitting cudagraphs work into pre/post compile
Cudagraphs does a lot of work on the input graph module to determine if cudagraphs can be enabled. This is the code that involves cudagraph_tests and stack traces. This will work in a world where we have access to the input graph module, but with AOTAutograd warm start, we won't have access to that information anymore. Therefore we can split cudagraphs work into two parts: on a cache miss (and therefore a full compile), we do the cudagraphs testing work, and save cudagraph_fail_reasons into the cache. Then on a cache hit, we know whether or not we can run cudagraphs, and if we can't, we can emit the correct error messages.
Implementation notes:
- We save `fx_kwargs` directly onto the CompiledFXGraph. `fx_kwargs` is already, by definition, part of the cache key, so this is safe to do when it comes to cache correctness.
- ^ Why do we do the above even though FXGraphCache.load takes fx_kwargs as an argument? Because AOTAutogradCache **doesn't** have access to fx_kwargs: they're annoyingly encoded in the functools.partial() of the fw_compiler, so *only* inductor knows about these options. They're fully captured by the AOTAutogradCache key (since every key in fx_kwargs is either a global config or a field that's deterministic given the input graph module), but their values are still needed to run cudagraphs/postprocessing. Therefore, it's easier/safer to store them on the cached result.
- Willing to hear other approaches here if we think saving these extra fields is not reasonable, though I can't think of another way to do this that's less complicated to explain.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130572
Approved by: https://github.com/eellison
This PR refactors placeholders in cudagraphs to be serializable. We define a new PlaceholderInfo object which has only the parts of placeholders necessary for logging/debugging, and use it instead of `torch.fx.Node` directly. This then allows us to save PlaceholderInfo into the FXGraphCache/AOTAutogradCache later.
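A minimal sketch of such an object (the field names here are assumptions, not necessarily the PR's exact ones):
```
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class PlaceholderInfo:
    # Only the pieces of a placeholder torch.fx.Node that logging and
    # debugging need; plain data, so it pickles cleanly into a cache entry.
    name: str
    stack_trace: Optional[str]
    user_names: tuple

def from_fx_node(node):  # node: a torch.fx.Node with op == "placeholder"
    return PlaceholderInfo(
        name=node.name,
        stack_trace=node.meta.get("stack_trace"),
        user_names=tuple(u.name for u in node.users),
    )
```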
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130252
Approved by: https://github.com/eellison, https://github.com/masnesral
ghstack dependencies: #129384
This PR adds support for using expandable segments with private memory pools, which should unblock using them with CUDA graphs and CUDA graph trees. Currently, the allocator silently avoids using expandable segments when allocating in a private pool, because checkpoint saving/restoring does not mesh well with how we keep track of unmapped blocks.
The PR itself is pretty short; most of the logic for checkpointing and reapplying state for non-expandable segments transfers over without much work.
Expandable segments reserve a virtual address space of size equal to the amount of physical memory on the GPU. Every time we want to `malloc()` or `free()` memory in a memory pool with expandable segments turned on, we map/unmap pages of physical GPU memory under the hood to create a new block that we return to the caller. This is beneficial because each memory pool functions as a single segment of memory with a contiguous range of addresses that can grow and shrink as needed, avoiding the fragmentation that comes from allocating multiple non-contiguous segments that may never be merged together.
The caching allocator handles this by creating, at init, an unmapped block for the entire reserved virtual address space, which is treated similarly to an unallocated block in a free pool. When callers call `malloc()`, it is split and mapped to create allocated blocks, and `free()` similarly caches and merges free blocks in a free pool for later use. Expandable blocks are unmapped and returned to CUDA when they are cleaned up, or when we hit an OOM and the allocator attempts to remap cached free blocks. The code paths that map, free, and unmap blocks in expandable segments are similar to those for normal blocks and do all the same work of updating memory-usage stats, moving blocks between active and free pools, and returning memory to CUDA.
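For reference, expandable segments are an opt-in allocator setting today, independent of this PR:
```
import os

# Must be set before the process makes its first CUDA allocation.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch

x = torch.empty(1 << 20, device="cuda")  # served from an expandable segment
print(torch.cuda.memory_stats()["segment.all.current"])
```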
With CUDA Graph Trees and private memory pools, we need to checkpoint the state of the memory allocator after each graph capture, and to reapply that state before capturing a new graph after replaying a captured one. That way, the new capture sees the allocator exactly as it stood after the replay, and can reuse empty blocks and allocate new ones.
As mentioned in a comment below, memory in a private pool is cached until the private pool is destroyed, and allocations can only grow from extra graph captures; any freeing of memory would result in invalid memory addresses and would break CUDA graphs.
One implementation detail to note for unmapped blocks with expandable segments is that unmapped blocks are tracked in a member variable `unmapped` of a `BlockPool`. Since we never free/unmap memory back to CUDA, `unmapped` is *not* part of the checkpointed state of the caching allocator, isn't restored when reapplying checkpoints, and persists across graph captures / replays.
Checkpointing the current state of the memory allocator works as expected with expandable segments. Checkpointing grabs the first block of every segment in the active and free pools of the private pool and traverses the segment's linked list of blocks to capture its state, which is then saved until it needs to be reapplied. For expandable segments, the last block in every segment is an unallocated, unmapped block holding the remaining unmapped memory at capture time, and this too is saved in the checkpoint.
Reapplying the checkpoints works by freeing all allocated blocks and merging them into a single block per segment; then, for each segment, we manually split and allocate all blocks from the checkpoint, and finally free the blocks marked as unallocated in the checkpoint state. For expandable segments, we make some modifications so that we never split unmapped blocks and avoid manually mapping and then freeing them.
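In toy form, the reapply path looks roughly like this (the data layout and names are illustrative; the real logic lives in the C++ CUDACachingAllocator):
```
def reapply_checkpoint(segment_states):
    # Rebuild each segment's block layout from a checkpoint. A segment is
    # modeled as a list of {size, allocated} dicts plus an optional
    # unmapped expandable tail.
    segments = []
    for state in segment_states:
        # All live blocks were conceptually freed and merged into one block
        # first; here we simply rebuild the recorded layout from scratch.
        blocks = [{"size": size, "allocated": allocated}
                  for size, allocated in state["blocks"]]
        if state.get("unmapped_tail"):
            # Expandable segments: the trailing unmapped block is left
            # alone, never split, and never mapped just to be freed.
            blocks.append({"size": state["unmapped_tail"],
                           "allocated": False, "unmapped": True})
        segments.append(blocks)
    return segments

# Example: a segment with a 256 KiB allocation, a 256 KiB free block, and a
# 512 KiB unmapped expandable tail.
print(reapply_checkpoint([{"blocks": [(256 << 10, True), (256 << 10, False)],
                           "unmapped_tail": 512 << 10}]))
```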
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128068
Approved by: https://github.com/eqy, https://github.com/eellison
Summary:
CUDAGraph Trees previously relied on the assumption that static inputs (parameters and buffers) do not change tensor addresses across multiple function invocations. This assumption is used to reduce the number of tensor copies and improve performance. We use `check_static_inputs_are_stable()` to check whether it holds at runtime.
While this assumption is true in most cases, we have recently observed a few cases where it does not hold:
- [Inline inbuilt nn modules](https://github.com/pytorch/pytorch/pull/126822): the same function (an nn module) is used in multiple places, and different parameters and buffers are passed to it with different tensor addresses
- Some user code changes tensor addresses of parameters/buffers. See [internal example]( https://www.internalfb.com/mlhub/pipelines/runs/mast/sw-935450288-OfflineTraining_08ba1cf0?job_attempt=1&version=0&env=PRODUCTION)
- Compiled Autograd may also pass parameters/buffers with different tensor addresses across runs.
Previous PR [#126822](https://github.com/pytorch/pytorch/pull/126822) (by @mlazos) allows detecting static tensor address changes at runtime and re-recording a cudagraph when that happens. However, if the same function is re-recorded too many times, the overhead may become large and hurt performance. This PR adds `torch._inductor.config.triton.cudagraph_max_recording` (default 5) to fall back to eager if a function has been recorded more than `cudagraph_max_recording` times for a specific node in the CUDAGraph Trees.
A summary of how static tensor address changes are handled now:
- For each child node, check the assumption via `check_invariants`. If it holds, execute the node under the assumption.
- If the assumption holds for no child node, re-record, provided the function_id has not been recorded too many times for the current node.
- If the function_id has already been re-recorded too many times, fall back to the eager function and warn (see the usage sketch below).
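For example, raising the limit would look like this (using the config name exactly as stated above):
```
import torch

# Allow up to 8 re-records per function per node before falling back to eager.
torch._inductor.config.triton.cudagraph_max_recording = 8

@torch.compile(mode="reduce-overhead")  # cudagraphs via CUDAGraph Trees
def f(x, w):
    return x @ w
```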
Test Plan: CI
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129349
Approved by: https://github.com/eellison
### Introduction/Problem
Today, when dynamo traces a builtin nn module (nn.Linear, for example), it specially handles the parameters of that module by storing them as constant attributes of the graph. This requires dynamo to guard on the ID of the nn module, because if the instance of the module changes, we need to retrace and recollect the new parameters as graph attributes. This creates a 1:1 compiled-graph-to-cudagraph relationship.
With hierarchical compilation, dynamo will treat builtin nn modules like any other code. This reduces complexity and, critically, means that if there are multiple identical layers in a model, we only need to compile one of them once and can reuse the same compiled artifact for each. This breaks the current approach to parameter handling: since the parameters can now change across calls to the compiled artifact, they need to be inputs to the graph instead of attributes. That in turn is a problem for cudagraphs. Previously, cudagraphs was guaranteed that the parameters of builtin nn modules would be constant across calls, but now, since the compiled artifact must be agnostic to the actual nn module instance in use, these parameter memory locations may vary. Cudagraphs ordinarily copies varying inputs into cudagraph-owned memory, but since parameters are quite large, doing that here would be catastrophic for performance.
### Solution
To avoid this performance cliff, this PR allows cudagraphs to re-record a new cudagraph if only parameters change. Metadata about which arguments are parameters is propagated from AOT Autograd to compile_fx, and these indices are passed to cudagraphs. If these memory locations change, a new graph is recorded, whereas previously this was an error (because it was assumed not to happen). This enables a 1:many compiled-graph-to-cudagraph relationship: across similar modules we re-record cudagraphs, and when the cudagraph is executed we dispatch to the graph whose parameter pointers match, as sketched below.
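A toy model of that dispatch, with hypothetical names (`GraphDispatcher`, `record_cudagraph`), keyed on the static inputs' data pointers:
```
import torch

def record_cudagraph(fn, args):  # stub standing in for actual graph capture
    return lambda *a: fn(*a)

class GraphDispatcher:
    # Hypothetical sketch: keep one recording per distinct set of parameter
    # addresses, re-recording on a mismatch instead of erroring.
    def __init__(self, fn, static_input_idxs, max_recordings=5):
        self.fn = fn
        self.static_input_idxs = static_input_idxs
        self.recordings = {}  # parameter-pointer key -> recorded graph
        self.max_recordings = max_recordings

    def __call__(self, *args):
        key = tuple(args[i].data_ptr() for i in self.static_input_idxs)
        if key not in self.recordings:
            if len(self.recordings) >= self.max_recordings:
                return self.fn(*args)  # too many variants: stay in eager
            self.recordings[key] = record_cudagraph(self.fn, args)
        return self.recordings[key](*args)
```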
### Next steps (if needed)
It is theoretically possible that a user passes Parameters that change frequently as inputs to model code. If that becomes a common issue, this design allows dynamo to pass metadata indicating which parameters were created in a builtin nn module context, so that only those parameters get the multi-cudagraph behavior; this PR does not implement that.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126822
Approved by: https://github.com/eellison
ghstack dependencies: #126820, #126821
Summary: This diff adds more logging for cudagraphs when a static input tensor address mutates. For each placeholder whose static input tensor address mutates, we log its name, the changed data pointer address, and the input stack trace. Since some placeholders may have an empty stack trace, we find the first user with a non-empty stack trace and print that stack trace instead.
Test Plan: buck2 run fbcode//caffe2/test/inductor:cudagraph_trees -- --r test_static_inputs_address_mutation_log
Differential Revision: D57805118
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127145
Approved by: https://github.com/eellison
Fixes
> ERROR: expected to be in states [<TrainingState.FORWARD_BACKWARD: 2>] but current state is TrainingState.IDLE
This error would occur when composing pt2 FSDP and cudagraphs. Cudagraphs caches output tensor impls on the fast path, so we were inadvertently accumulating multiple hooks on what should have been fresh allocations.
From the code comment:
```
# this output represents a fresh allocated tensor.
# We return the same TensorImpl from run to run to avoid overhead.
# autograd.Function will reset the Autograd meta of output tensors
# as part of aot_autograd, but _backward_hooks are stored on tensors separately,
# so we need to manually reset hooks.
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126914
Approved by: https://github.com/awgu, https://github.com/xmfan
When we do cudagraph warmup, we record which outputs are in the cudagraph pool, so that when we subsequently invoke a cudagraph and need to reclaim its memory, we can free the prior run's outputs and make them error on access.
In warmup, we detect this by ignoring outputs which are an alias of an input that is not a prior output. We did this by checking data pointers. In very rare situations, the data pointer of a non-cudagraph input might get reallocated into a cudagraph pool, causing us to ignore it.
This was happening in gpt-fast with gemma 2, when coordinate_descent_tuning was set to False.
This PR updates the check so that we detect aliasing with non-cudagraph inputs by looking at the storage pointer (see the example below).
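For illustration, a storage-pointer comparison catches view-style aliasing that a data-pointer comparison can miss:
```
import torch

base = torch.rand(8, device="cuda")
alias = base[4:]  # a view into the second half of base

# The data pointers differ, since the view starts at a byte offset...
print(base.data_ptr() == alias.data_ptr())  # False
# ...but both tensors share one storage, so the storage pointers match.
print(base.untyped_storage().data_ptr()
      == alias.untyped_storage().data_ptr())  # True
```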
Unrelated: I saw very weird behavior where an output had the same data pointer as a supposedly live input, but not the same cdata 🤔 I would think that is not possible:
```
out[0]._cdata in [ref()._cdata for ref in non_cudagraph_inps_storage_refs]  # False
out[0].data_ptr() in [ref().data_ptr() for ref in non_cudagraph_inps_storage_refs]  # True
```
Differential Revision: [D56607721](https://our.internmc.facebook.com/intern/diff/D56607721)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124981
Approved by: https://github.com/ezyang