mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Provide access to the cudaGraph_t underlying a CUDAGraph. (#155164)
There are a few considerations here: 1. A user might want to modify the cudaGraph_t either during the stream capture or after the stream capture (but before instantiation). This draft implements modification after stream capture only, though support could be added for modification during stream capture by applying https://github.com/pytorch/pytorch/pull/140979/files#diff-d7302d133bb5e0890fc94de9aeea4d9d442555a3b40772c9db10edb5cf36a35cR391-R404 2. Previously, the cudaGraph_t would be destroyed before the end of capture_end() unless the user had previously called enable_debug_mode(). There is no way to implement this correctly without removing this restriction, or forcing the user to always call enable_debug_mode(). However, enable_debug_mode() is a confusing API (despite being an instance method, it would modify a static global variable; thus, putting one CUDAGraph object into debug mode puts all of them into debug mode, which is not acceptable in my opinion). Therefore, I made enable_debug_mode() into a no-op. This means that the CPU memory usage will increase after this change. I think this is likely to be fine. 3. No python bindings yet. These should be easy to add. It is probably worthwhile to take some time to make sure that the returned cudaGraph_t can be converted into the cuda-python cudaGraph_t in a reasonable, hopefully type-safe, manner (but without making cuda-python a dependency of pytorch), since I imagine most users will use the pip cuda-python package to make modifications. 4. There are two foot guns: a. The cudaGraph_t returned by raw_cuda_graph() is not owned by the user, so it will be destroyed once the owning CUDAGraph is destroyed (or calls reset()). b. The following seuquence won't work as intended: ``` g = torch.cuda.CUDAGraph() with torch.cuda.graph(g): foo() g.replay() raw_graph = g.raw_cuda_graph() modify(raw_graph) g.replay() ``` This won't work because the user must call instantiate() again after modifying cudaGraph_t. You could add a "safety" mechanism by traversing the cudaGraph_t to create a hash and seeing if the hash changes between calls to replay(), but this is likely way too expensive. I think these two foot guns are probably okay given that this a bit of an experts' API. Fixes #155106 Pull Request resolved: https://github.com/pytorch/pytorch/pull/155164 Approved by: https://github.com/ngimel
This commit is contained in:
committed by
PyTorch MergeBot
parent
17b38b850e
commit
9ed0060225
@ -1505,6 +1505,26 @@ TEST_CUDA_GRAPH = TEST_CUDA and (not TEST_SKIP_CUDAGRAPH) and (
|
||||
|
||||
TEST_CUDA_CUDSS = TEST_CUDA and (torch.version.cuda and int(torch.version.cuda.split(".")[0]) >= 12)
|
||||
|
||||
TEST_CUDA_PYTHON_BINDINGS = _check_module_exists("cuda.bindings") and (
|
||||
torch.version.cuda and int(torch.version.cuda.split(".")[0]) >= 12
|
||||
)
|
||||
|
||||
if TEST_CUDA_PYTHON_BINDINGS:
|
||||
def cuda_python_error_check(function_call_output):
|
||||
"""Makes calls to cuda-python's cuda runtime functions more
|
||||
pythonic by throwing an exception if they return a status
|
||||
which is not cudaSuccess
|
||||
"""
|
||||
import cuda.bindings # type: ignore[import]
|
||||
|
||||
error, *others = function_call_output
|
||||
if error != cuda.bindings.runtime.cudaError_t.cudaSuccess:
|
||||
raise ValueError(f"CUDA failure! {error}")
|
||||
else:
|
||||
return tuple(others)
|
||||
else:
|
||||
cuda_python_error_check = None # type: ignore[assignment]
|
||||
|
||||
def allocator_option_enabled_fn(allocator_config, _, option):
|
||||
if allocator_config is None:
|
||||
return False
|
||||
|
Reference in New Issue
Block a user