Provide access to the cudaGraph_t underlying a CUDAGraph. (#155164)

There are a few considerations here:

1. A user might want to modify the cudaGraph_t either during stream capture or after stream capture (but before instantiation). This draft implements modification after stream capture only; support for modification during stream capture could be added by applying
https://github.com/pytorch/pytorch/pull/140979/files#diff-d7302d133bb5e0890fc94de9aeea4d9d442555a3b40772c9db10edb5cf36a35cR391-R404

2. Previously, the cudaGraph_t would be destroyed before the end of capture_end() unless the user had previously called enable_debug_mode(). There is no way to implement this feature correctly without removing that restriction or forcing the user to always call enable_debug_mode(). However, enable_debug_mode() is a confusing API: despite being an instance method, it modifies a static global variable, so putting one CUDAGraph object into debug mode puts all of them into debug mode, which is not acceptable in my opinion. Therefore, I made enable_debug_mode() a no-op. Since every CUDAGraph now keeps its cudaGraph_t alive, CPU memory usage will increase after this change. I think this is likely to be fine.

3. No python bindings yet. These should be easy to add. It is probably worthwhile to take some time to make sure that the returned cudaGraph_t can be converted into the cuda-python cudaGraph_t in a reasonable, hopefully type-safe, manner (but without making cuda-python a dependency of pytorch), since I imagine most users will use the pip cuda-python package to make their modifications. A hedged sketch of what that interop might look like appears after these notes.

4. There are two foot guns:

   a. The cudaGraph_t returned by raw_cuda_graph() is not owned by the user, so it will be destroyed once the owning CUDAGraph is destroyed (or calls reset()).

   b. The following sequence won't work as intended:

```
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    foo()
g.replay()
raw_graph = g.raw_cuda_graph()
modify(raw_graph)
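# foot gun (b): without re-instantiation, the next replay still runs
# the graph as it was at capture time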
g.replay()
```

This won't work because the user must call instantiate() again after modifying the cudaGraph_t. You could add a "safety" mechanism by traversing the cudaGraph_t to compute a hash and checking whether the hash changes between calls to replay(), but this is likely far too expensive.
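For completeness, a corrected version of the sequence (assuming a Python instantiate() binding that mirrors the C++ API) would be:

```
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    foo()
g.replay()
raw_graph = g.raw_cuda_graph()  # non-owning; see foot gun (a)
modify(raw_graph)
g.instantiate()  # re-instantiate so the modification takes effect
g.replay()
```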

I think these two foot guns are probably okay given that this is a bit of an experts' API.
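To make point 3 concrete, here is a rough sketch (not part of this PR) of how the raw handle could be wrapped into a cuda-python handle. It assumes raw_cuda_graph() exposes the underlying pointer as an integer, and the exact cuda-python constructor keyword may differ:

```
import cuda.bindings.runtime as cudart

raw_handle = g.raw_cuda_graph()  # assumed: integer value of the cudaGraph_t pointer
cuda_graph = cudart.cudaGraph_t(init_value=raw_handle)

# e.g., dump the graph structure to a DOT file for inspection
err, = cudart.cudaGraphDebugDotPrint(cuda_graph, b"graph.dot", 0)
assert err == cudart.cudaError_t.cudaSuccess
```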

Fixes #155106

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155164
Approved by: https://github.com/ngimel
Author: Daniel Galvez
Date: 2025-06-18 03:39:28 +00:00
Committed by: PyTorch MergeBot
Parent: 17b38b850e
Commit: 9ed0060225

8 changed files with 235 additions and 39 deletions


```
@@ -1505,6 +1505,26 @@ TEST_CUDA_GRAPH = TEST_CUDA and (not TEST_SKIP_CUDAGRAPH) and (
TEST_CUDA_CUDSS = TEST_CUDA and (torch.version.cuda and int(torch.version.cuda.split(".")[0]) >= 12)

TEST_CUDA_PYTHON_BINDINGS = _check_module_exists("cuda.bindings") and (
    torch.version.cuda and int(torch.version.cuda.split(".")[0]) >= 12
)

if TEST_CUDA_PYTHON_BINDINGS:
    def cuda_python_error_check(function_call_output):
        """Makes calls to cuda-python's cuda runtime functions more
        pythonic by throwing an exception if they return a status
        which is not cudaSuccess
        """
        import cuda.bindings  # type: ignore[import]
        error, *others = function_call_output
        if error != cuda.bindings.runtime.cudaError_t.cudaSuccess:
            raise ValueError(f"CUDA failure! {error}")
        else:
            return tuple(others)
else:
    cuda_python_error_check = None  # type: ignore[assignment]


def allocator_option_enabled_fn(allocator_config, _, option):
    if allocator_config is None:
        return False
```
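For illustration, a hypothetical use of the helper above, which unpacks cuda-python's (error, *outputs) return convention:

```
import cuda.bindings.runtime as cudart

# raises ValueError unless the call returned cudaSuccess
version, = cuda_python_error_check(cudart.cudaRuntimeGetVersion())
```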