mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-13 18:25:33 +08:00
Update cudaGraphInstantiate to use node priority flag
This commit is contained in:
@ -175,7 +175,7 @@ void CUDAGraph::instantiate() {
|
||||
// who prefer not to report error message through these arguments moving forward
|
||||
// (they prefer return value, or errors on api calls internal to the capture)
|
||||
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000)
|
||||
AT_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph_, 0));
|
||||
AT_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph_, cudaGraphInstantiateFlagUseNodePriority));
|
||||
#else
|
||||
AT_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph_, NULL, NULL, 0));
|
||||
#endif
|
||||
@ -184,7 +184,7 @@ void CUDAGraph::instantiate() {
|
||||
} else {
|
||||
AT_CUDA_CHECK(cudaGraphInstantiateWithFlags(&graph_exec_,
|
||||
graph_,
|
||||
cudaGraphInstantiateFlagAutoFreeOnLaunch));
|
||||
cudaGraphInstantiateFlagAutoFreeOnLaunch | cudaGraphInstantiateFlagUseNodePriority));
|
||||
}
|
||||
has_graph_exec_ = true;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user