Compare commits

...

1 Commits

Author SHA1 Message Date
f28f0c7031 Add support for conditional nodes in cuda graphs.
This allows torch.cond and torch.while_loop to be captured in a single
cuda graph. This is done by manually inserting conditional IF nodes
and conditional WHILE nodes during stream capture. This approach is
discussed here:

https://developer.nvidia.com/blog/dynamic-control-flow-in-cuda-graphs-with-conditional-nodes
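For illustration, here is a hand-written sketch of the mechanism (not part of this PR's tests). It assumes CUDA 12.4+ and that the new bindings added here (begin_capture_to_if_node / end_capture_to_conditional_node) can be driven manually the same way the cond dispatch helper in this PR drives them:

    import torch

    pred = torch.tensor(True, device="cuda")
    x = torch.zeros(4, device="cuda")

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        # Everything between these two calls is captured into the body
        # graph of a conditional IF node; at replay it runs only when
        # `pred` is True at that point in the graph.
        g.begin_capture_to_if_node(pred)
        x.add_(1)
        g.end_capture_to_conditional_node()

    g.replay()          # pred is True, so x becomes 1
    pred.fill_(False)
    g.replay()          # IF body is skipped, x stays 1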

Previously, data-dependent control flow would force the usage of cuda
graph trees, since data-dependent control flow was done on the
CPU. Now, data-dependent control flow can be done on the GPU.

This work depends upon CUDA 12.4, since cuda graph conditional nodes
were introduced in CUDA 12.4.

This works only with torch.compile(..., backend="eager") and
torch.compile(..., backend="cudagraphs") backends right now. Notably,
there is no inductor support at this time!

Conditional nodes for cuda graphs were first experimented with in
https://arxiv.org/abs/2406.03791 . While this paper showed strong
improvements in data-dependent workloads that were very CPU-overhead
bound, the next place to look for improvement is horizontal and
vertical kernel fusion, which can eventually be enabled automatically
once conditional nodes work with backends like inductor. This PR is
the first step towards that. I also expect this work to benefit more
sophisticated workloads like autoregressive decoding of LLMs, at least
for users that are using static-shape kv-caches.

This is work done by @tingyangk and me.

We have a sophisticated example of RNN-T greedy decoding (the
algorithm discussed in the paper) working with this new feature here:
975a80673e (diff-2c2a72c9a5392d4a6ea5149fea3ce7900b9fd2c630e460bbab94547379553ceaR376)
2025-11-03 21:52:32 -08:00
15 changed files with 844 additions and 11 deletions

View File

@ -8,8 +8,49 @@
namespace at::cuda {
void external_stream_deleter(cudaStream_t* stream) {
if (stream != nullptr) {
AT_CUDA_CHECK(cudaStreamDestroy(*stream));
delete stream;
}
}
namespace {
UniquePtrExternalCudaStream create_external_stream() {
// From:
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g793d7d4e474388ddfda531603dc34aa3
// "Capture must be ended on the same stream in which it was initiated, and it
// may only be initiated if the stream is not already in capture mode."
// Since pytorch uses a pool of 32 pre-allocated cuda streams,
// should a user nest 32 conditional nodes, there would be an error
// for the 32nd node, since that node's stream would already be in
// capture mode. The easiest solution is to handle stream creation
// and deletion ourselves.
// We use cudaStreamNonBlocking because pytorch uses that flag for
// every stream it uses for stream capture
// (see kDefaultFlags in CUDAStream.cpp). This would need to be kept
// in sync, should that ever change. Or kDefaultFlags needs to be
// exposed in a header file.
auto stream_ptr = std::make_unique<cudaStream_t>();
AT_CUDA_CHECK(
cudaStreamCreateWithFlags(stream_ptr.get(), cudaStreamNonBlocking));
return UniquePtrExternalCudaStream(
stream_ptr.release(), external_stream_deleter);
}
} // anonymous namespace
static bool _cuda_graphs_debug = false;
// To support stream capture across multiple threads, we use a global
// hashmap mapping cuda stream capture IDs to CUDAGraph objects. This
// was originally a thread_local std::stack<CUDAGraph*>, but that was
// not acceptable since stream capture does span threads in certain
// circumstances (in particular, during autograd).
static std::mutex _currently_capturing_graphs_mutex;
static ska::flat_hash_map<CaptureId_t, CUDAGraph*> _currently_capturing_graphs;
MempoolId_t graph_pool_handle() {
// Sets just the second value, to distinguish it from MempoolId_ts created from
// cudaStreamGetCaptureInfo id_s in capture_begin.
@ -55,11 +96,13 @@ void CUDAGraph::register_generator_state(const at::Generator& generator) {
cuda_gen->register_graph(this);
}
void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/, cudaStreamCaptureMode capture_mode) {
void CUDAGraph::capture_begin(MempoolId_t pool/*={0,0}*/, cudaStreamCaptureMode capture_mode) {
TORCH_CHECK(!has_graph_exec_,
"This CUDAGraph instance already owns a captured graph. "
"To capture a new graph, create a new instance.");
capture_mode_ = capture_mode;
// default generator is always registered
auto* gen = get_generator_or_default<CUDAGeneratorImpl>(
std::nullopt, cuda::detail::getDefaultCUDAGenerator());
@ -97,12 +140,7 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/, cudaStreamCaptureMode capt
// Addendum: beginAllocateStreamToPool is now called before cudaStreamBeginCapture to prevent an
// autograd thread's free() call triggering an invalid cudaEventRecord in the caching allocator
// due to the capture status being updated _after_ a capture had already started.
c10::cuda::CUDACachingAllocator::beginAllocateToPool(capture_dev_, mempool_id_, [this](cudaStream_t stream) {
cudaStreamCaptureStatus status{};
CaptureId_t stream_capture_id = 0;
AT_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &stream_capture_id));
return status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive && stream_capture_id == capture_id_;
});
c10::cuda::CUDACachingAllocator::beginAllocateToPool(capture_dev_, mempool_id_, create_allocate_filter());
// cudaStreamCaptureModeGlobal is the most conservative option to
// prevent potentially unsafe CUDA API calls during capture. See
@ -113,6 +151,10 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/, cudaStreamCaptureMode capt
AT_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &capture_id_));
TORCH_INTERNAL_ASSERT(status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive);
{
std::unique_lock<std::mutex> lock(_currently_capturing_graphs_mutex);
_currently_capturing_graphs.emplace(capture_id_, this);
}
}
void CUDAGraph::capture_end() {
@ -123,6 +165,14 @@ void CUDAGraph::capture_end() {
AT_CUDA_CHECK(cudaStreamEndCapture(capture_stream_, &graph_));
{
std::unique_lock<std::mutex> lock(_currently_capturing_graphs_mutex);
TORCH_CHECK(
_currently_capturing_graphs.count(capture_id_),
"capture_end() called before capture_begin().");
_currently_capturing_graphs.erase(capture_id_);
}
c10::cuda::CUDACachingAllocator::endAllocateToPool(capture_dev_, mempool_id_);
TORCH_CHECK(graph_ != nullptr, "Invalid capture.");
@ -132,6 +182,19 @@ void CUDAGraph::capture_end() {
wholegraph_increments = generator_state->capture_epilogue();
}
#if !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
bool any_wholegraph_increments_nonzero = false;
for (auto& [generator_state, wholegraph_increments] : captured_generator_states_) {
if (wholegraph_increments != 0) {
any_wholegraph_increments_nonzero = true;
}
}
if (any_wholegraph_increments_nonzero && !descendent_graphs_.empty()) {
TORCH_WARN("You used random numbers in a cuda graph that uses conditional nodes. The previous design assumed that all RNG operations would execute only once, unconditionally, but this is no longer guaranteed with data-dependent control flow. Running with the cuda graph repeatedly may not match running without the cuda graph.");
}
#endif // !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12040
size_t numCUDAGraphNodes = 0;
AT_CUDA_CHECK(cudaGraphGetNodes(graph_, nullptr, &numCUDAGraphNodes));
if (numCUDAGraphNodes == 0) {
@ -292,7 +355,7 @@ void CUDAGraph::reset() {
// Returns an id another graph's capture_begin can use to share the same memory pool as this graph.
MempoolId_t CUDAGraph::pool() {
TORCH_CHECK(capture_ended_,
TORCH_CHECK(capture_ended_,
"Called CUDAGraph::pool() without a preceding successful capture.");
return mempool_id_;
}
@ -317,4 +380,163 @@ CUDAGraph::~CUDAGraph() {
#endif
}
CUDAGraph* CUDAGraph::get_currently_capturing_graph() {
std::unique_lock<std::mutex> lock(_currently_capturing_graphs_mutex);
cudaStreamCaptureStatus status{};
CaptureId_t current_capture_id = -1;
auto stream = at::cuda::getCurrentCUDAStream();
AT_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &current_capture_id));
TORCH_CHECK(
status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive,
"The current stream is not currently capturing.");
TORCH_CHECK(
_currently_capturing_graphs.count(current_capture_id),
"get_currently_capturing_graph() can be used only between capture_begin() and capture_end(). Did you use a stream without making it depend upon the original stream used for capture?");
return _currently_capturing_graphs.at(current_capture_id);
}
void CUDAGraph::begin_capture_to_if_node(
const at::Tensor& scalar_cuda_pred_tensor) {
#if !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
TORCH_CHECK(
!has_graph_exec_,
"begin_capture_to_if_node() must be called before capture_begin()");
cudaStreamCaptureStatus status{};
cudaGraph_t currently_capturing_graph{};
AT_CUDA_CHECK(cudaStreamGetCaptureInfo(
getCurrentCUDAStream(), &status, nullptr, &currently_capturing_graph));
TORCH_CHECK(
status == cudaStreamCaptureStatusActive,
"capture_begin() must be called before begin_capture_to_if_node()");
cudaGraphConditionalHandle handle{};
AT_CUDA_CHECK(cudaGraphConditionalHandleCreate(
&handle, currently_capturing_graph, 0, 0));
set_conditional_handle(handle, scalar_cuda_pred_tensor);
const cudaGraphNode_t* dependencies{};
size_t num_dependencies = 0;
AT_CUDA_CHECK(cudaStreamGetCaptureInfo(
getCurrentCUDAStream(),
&status,
nullptr,
&currently_capturing_graph,
&dependencies,
&num_dependencies));
TORCH_CHECK(status == cudaStreamCaptureStatusActive);
cudaGraphNodeParams params{};
params.type = cudaGraphNodeTypeConditional;
params.conditional.handle = handle;
params.conditional.type = cudaGraphCondTypeIf;
params.conditional.size = 1;
cudaGraphNode_t cond_node{};
AT_CUDA_CHECK(cudaGraphAddNode(
&cond_node,
currently_capturing_graph,
dependencies,
num_dependencies,
&params));
cudaGraph_t if_node_child_graph = params.conditional.phGraph_out[0];
AT_CUDA_CHECK(cudaStreamUpdateCaptureDependencies(
getCurrentCUDAStream(), &cond_node, 1, cudaStreamSetCaptureDependencies));
UniquePtrExternalCudaStream child_stream = create_external_stream();
conditional_graph_capture_streams_ids_.push(-1);
c10::cuda::CUDACachingAllocator::endAllocateToPool(capture_dev_, mempool_id_);
c10::cuda::CUDACachingAllocator::beginAllocateToPool(
capture_dev_, mempool_id_, create_child_allocate_filter());
AT_CUDA_CHECK(cudaStreamBeginCaptureToGraph(
*child_stream, if_node_child_graph, nullptr, nullptr, 0, capture_mode_));
AT_CUDA_CHECK(cudaStreamGetCaptureInfo(
*child_stream, &status, &conditional_graph_capture_streams_ids_.top()));
TORCH_INTERNAL_ASSERT(
status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive);
// We need to get the raw_stream here before emplace() to prevent
// std::move(child_stream) from potentially executing before
// *child_stream.
cudaStream_t raw_stream = *child_stream;
conditional_node_streams_.emplace(
getStreamFromExternal(raw_stream, getCurrentCUDAStream().device_index()),
std::move(child_stream));
{
std::unique_lock<std::mutex> lock(_currently_capturing_graphs_mutex);
_currently_capturing_graphs.emplace(
conditional_graph_capture_streams_ids_.top(), this);
}
#else // !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
AT_ERROR(
__func__,
" CUDA Graphs conditional nodes are not supported for cuda version < 12.4");
return;
#endif
}
void CUDAGraph::end_capture_to_conditional_node() {
#if !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
{
std::unique_lock<std::mutex> lock(_currently_capturing_graphs_mutex);
CaptureId_t capture_id = conditional_graph_capture_streams_ids_.top();
TORCH_CHECK(
_currently_capturing_graphs.count(capture_id),
"capture_end() called before capture_begin().");
_currently_capturing_graphs.erase(capture_id);
}
CUDAStream stream = conditional_node_streams_.top().first.current_stream();
cudaGraph_t graph{};
AT_CUDA_CHECK(cudaStreamEndCapture(stream.stream(), &graph));
descendent_graphs_.push_back(graph);
conditional_node_streams_.pop();
conditional_graph_capture_streams_ids_.pop();
c10::cuda::CUDACachingAllocator::endAllocateToPool(capture_dev_, mempool_id_);
if (conditional_graph_capture_streams_ids_.empty()) {
c10::cuda::CUDACachingAllocator::beginAllocateToPool(
capture_dev_, mempool_id_, create_allocate_filter());
} else {
c10::cuda::CUDACachingAllocator::beginAllocateToPool(
capture_dev_, mempool_id_, create_child_allocate_filter());
}
#else // !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
AT_ERROR(
__func__,
" CUDA Graphs conditional nodes are not supported for cuda version < 12.4");
#endif
}
std::function<bool(cudaStream_t)> CUDAGraph::create_allocate_filter() {
return [this](cudaStream_t stream) {
cudaStreamCaptureStatus status{};
CaptureId_t stream_capture_id = 0;
AT_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &stream_capture_id));
return status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive && stream_capture_id == capture_id_;
};
}
std::function<bool(cudaStream_t)> CUDAGraph::create_child_allocate_filter() {
#if !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
return [&current_capture_id = conditional_graph_capture_streams_ids_.top()](cudaStream_t stream) {
cudaStreamCaptureStatus status{};
CaptureId_t stream_capture_id{};
AT_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &stream_capture_id));
return status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive && stream_capture_id == current_capture_id;
};
#else // !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
AT_ERROR(
__func__,
" CUDA Graphs conditional nodes are not supported for cuda version < 12.4");
return std::function<bool(cudaStream_t)>();
#endif
}
} // namespace at::cuda

View File

@ -0,0 +1,30 @@
#include <ATen/cuda/CUDAGraph.h>
#include <ATen/cuda/Exceptions.h>
namespace at::cuda {
namespace {
#if !(defined(USE_ROCM)) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
__global__ void set_conditional_handle_kernel(
cudaGraphConditionalHandle handle,
const bool* value) {
cudaGraphSetConditional(handle, *value);
}
#endif
}
void CUDAGraph::set_conditional_handle(
cudaGraphConditionalHandle handle,
const Tensor& scalar_cuda_pred_tensor) {
#if !(defined(USE_ROCM)) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
set_conditional_handle_kernel<<<1, 1, 0, getCurrentCUDAStream()>>>(
handle, scalar_cuda_pred_tensor.const_data_ptr<bool>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
#else
AT_ERROR("not allowed");
return;
#endif
}
} // namespace at::cuda

View File

@ -4,9 +4,20 @@
#include <c10/core/Device.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/util/flat_hash_map.h>
#include <limits>
#include <stack>
#if defined(USE_ROCM) || !(defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
// this type is not defined until CUDA 12.4, but we use it as a
// parameter type and return type in some below functions, so we give
// it the same definition as in CUDA 12.4.
typedef unsigned long long cudaGraphConditionalHandle;
#endif // defined(USE_ROCM) || !(defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
namespace at {
struct Generator;
@ -15,6 +26,9 @@ struct CUDAGeneratorState;
namespace cuda {
using UniquePtrExternalCudaStream =
std::unique_ptr<cudaStream_t, void (*)(cudaStream_t*)>;
// Standalone way to get a unique mempool id usable as a pool=... argument
// to CUDAGraph::capture_begin
TORCH_CUDA_CPP_API MempoolId_t graph_pool_handle();
@ -23,6 +37,26 @@ struct TORCH_CUDA_CPP_API CUDAGraph {
CUDAGraph(bool keep_graph=false);
~CUDAGraph();
// Copy and move constructors and assignments are disabled. These
// were disabled because pybind11 believed that CUDAGraph was copy
// constructable because
// pybind11::is_copy_constructible<CUDAGraph>::value originally
// evaluated to true. However, it cannot generate a copy constructor
// because CUDAGeneratorState, one of CUDAGraph's members, is an
// incomplete type unless CUDAGeneratorImpl.h is included. However,
// that would create a circular dependency between
// CUDAGeneratorImpl.h and CUDAGraph.h. Disabling the copy and move
// constructors is the most straightforward way to prevent pybind11
// from trying to generate default implementations of them.
//
// We needed pybind11 to return a reference to a CUDAGraph as part
// of wrapping CUDAGraph::get_currently_capturing_graph, which
// unearthed the above problem.
CUDAGraph(const CUDAGraph&) = delete;
CUDAGraph& operator=(const CUDAGraph&) = delete;
CUDAGraph(CUDAGraph&& other) = delete;
CUDAGraph& operator=(CUDAGraph&& other) = delete;
// See Note [Explicit Registration of Generators to the CUDA Graph]
void register_generator_state(c10::intrusive_ptr<at::CUDAGeneratorState> state);
void register_generator_state(const at::Generator& generator);
@ -39,6 +73,17 @@ struct TORCH_CUDA_CPP_API CUDAGraph {
cudaGraph_t raw_cuda_graph();
cudaGraphExec_t raw_cuda_graph_exec();
static CUDAGraph* get_currently_capturing_graph();
void begin_capture_to_if_node(const Tensor& scalar_cuda_pred_tensor);
void end_capture_to_conditional_node();
static void set_conditional_handle(
cudaGraphConditionalHandle handle,
const Tensor& scalar_cuda_pred_tensor);
private:
std::function<bool(cudaStream_t)> create_allocate_filter();
std::function<bool(cudaStream_t)> create_child_allocate_filter();
protected:
cudaGraph_t graph_ = nullptr;
cudaGraphExec_t graph_exec_ = nullptr;
@ -89,6 +134,14 @@ struct TORCH_CUDA_CPP_API CUDAGraph {
c10::DeviceIndex capture_dev_{UNDEFINED_DEVICE};
bool keep_graph_;
cudaStreamCaptureMode capture_mode_{};
#if !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
std::stack<std::pair<at::cuda::CUDAStreamGuard, UniquePtrExternalCudaStream>>
conditional_node_streams_;
std::stack<CaptureId_t> conditional_graph_capture_streams_ids_;
std::vector<cudaGraph_t> descendent_graphs_;
#endif // !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12040
};
} // namespace cuda

View File

@ -1325,6 +1325,10 @@ and you suspect its runtime is at least somewhat CPU-limited.
https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture
.. _cudaGraphLaunch:
https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1g1accfe1da0c605a577c22d9751a09597
.. _issue 144787:
https://github.com/pytorch/pytorch/issues/144787#issuecomment-2606480564
.. _conditional nodes:
https://developer.nvidia.com/blog/dynamic-control-flow-in-cuda-graphs-with-conditional-nodes/
PyTorch API
^^^^^^^^^^^
@ -1413,6 +1417,9 @@ Violating any of these will likely cause a runtime error:
Avoid using :meth:`Generator.get_state<torch.get_state>` and :meth:`Generator.set_state<torch.set_state>` during capture;
instead, utilize :meth:`Generator.graphsafe_set_state<torch.Generator.graphsafe_set_state>` and :meth:`Generator.graphsafe_get_state<torch.Generator.graphsafe_get_state>`
for managing generator states safely within the graph context. This ensures proper RNG operation and generator management within CUDA graphs.
* Dynamic control flow (based on CPU or GPU data) is prohibited, unless it is based on GPU data and implemented via the higher order
operator torch.cond(). See :ref:`Data Dependent Control Flow<graph-data-dependent-control-flow>`.
Violating any of these will likely cause silent numerical errors or undefined behavior:
@ -1421,7 +1428,6 @@ Violating any of these will likely cause silent numerical errors or undefined be
* No non-captured CUDA work may run in this process (on any thread) while capture is underway.
* CPU work is not captured. If the captured ops include CPU work, that work will be elided during replay.
* Every replay reads from and writes to the same (virtual) memory addresses.
* Dynamic control flow (based on CPU or GPU data) is prohibited.
* Dynamic shapes are prohibited. The graph assumes every tensor in the captured op sequence
has the same size and layout in every replay.
* Using multiple streams in a capture is allowed, but there are :ref:`restrictions<multistream-capture>`.
@ -1730,3 +1736,45 @@ If, in the live workload, your callables will run in an order that occasionally
or if they'll run concurrently, passing them as a tuple to a single invocation of
:func:`~torch.cuda.make_graphed_callables` is not allowed. Instead, you must call
:func:`~torch.cuda.make_graphed_callables` separately for each one.
.. _graph-data-dependent-control-flow:
Data Dependent Control Flow
^^^^^^^^^^^^^^^^^^^^^^^^^^^
Data-dependent control flow can work with cuda graphs if the control
flow is implemented using torch.cond(). If your function uses
torch.cond(), compiling it with the "cudagraphs" backend will enable
control flow in the resulting cuda graph via `conditional nodes`_.
Unfortunately, eager mode execution does not work, for the reasons
described in `issue 144787`_. Support for the inductor backend of
torch.compile is not available yet, but there is no fundamental
blocker.
An example of using the cudagraphs backend to torch.compile on code
using torch.cond is demonstrated below::
import torch
def true_fn(x):
return x.sin()
def false_fn(x):
return x.cos()
x = torch.randn(4, device="cuda", requires_grad=False)
pred = torch.tensor(False, device="cuda", requires_grad=False)
def foo(pred, x):
with torch.inference_mode():
return torch.cond(pred, true_fn, false_fn, [x])
# First call will run eager for warmup, second call will do graph
# capture followed by graph replay, third call and beyond will do
# just graph replay.
compiled_foo = torch.compile(foo, backend="cudagraphs")
for i in range(3):
y = compiled_foo(pred, x)
# will output x.sin()
y = compiled_foo(~pred, x)

View File

@ -10,7 +10,7 @@ For a longer background on CUDAGraphs, read [accelerating pytorch with CUDAGraph
CUDA Graphs can give large speedups, especially for models with high CPU overhead or small compute. There are a number of limitations from requiring the same kernels to be run with the same arguments and dependencies, and memory addresses.
- Control Flow is not possible
- Arbitrary Control Flow is not possible (However, control flow expressed via torch.cond() can be captured in a CUDA Graph. See :ref:`Data Dependent Control Flow<graph-data-dependent-control-flow>`.)
- Kernels which trigger host to device syncs (such as .item()) errors
- All input arguments to kernels are fixed to what they were recorded
- CUDA Memory addresses are fixed, however the values of the memory at those addresses can change

View File

@ -2,12 +2,14 @@
import contextlib
import functools
import unittest
import warnings
import torch
import torch.utils._pytree as pytree
from functorch.experimental import control_flow
from functorch.experimental.control_flow import cond
from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm
from torch._dynamo.utils import counters
from torch._higher_order_ops.associative_scan import (
_fake_associative_scan,
associative_scan,
@ -34,12 +36,50 @@ from torch.testing._internal.common_utils import (
run_tests,
skipIfCrossRef,
skipIfTorchDynamo,
TEST_CUDA_GRAPH_CONDITIONAL_NODES,
TEST_WITH_CROSSREF,
TEST_WITH_TORCHDYNAMO,
TestCase,
)
@contextlib.contextmanager
def check_cudagraphs_not_skipped(test_case):
old_cudagraph_skips = counters["inductor"]["cudagraph_skips"]
try:
yield
finally:
# reset before the assert, because otherwise the reset is skipped
new_cudagraph_skips = counters["inductor"]["cudagraph_skips"]
counters["inductor"]["cudagraph_skips"] = old_cudagraph_skips
test_case.assertEqual(
counters["inductor"]["cudagraph_skips"], new_cudagraph_skips
)
def _check_compile_cudagraph(test_case, fn, args):
# test cudagraphs backend
cudagraphs_compiled_fn = torch.compile(fn, backend="cudagraphs")
# We run 3 times.
# This is what cuda graph trees does for the first 3 runs:
# 1) run in eager mode, for warmup.
# 2) do stream capture followed by graph replay.
# 3 and beyond) do graph replay
# So we need to get to iteration 3 to test all ways of running.
outputs = []
for i in range(3):
with check_cudagraphs_not_skipped(test_case):
outputs.append(
pytree.tree_map(
lambda x: x.clone() if isinstance(x, torch.Tensor) else x,
cudagraphs_compiled_fn(*args),
)
)
eager_res = fn(*args)
for output in outputs:
test_case.assertEqual(eager_res, output)
# TODO: pull these helpers from AOTAutograd later
def to_fun(t):
if isinstance(t, torch.Tensor):
@ -5473,6 +5513,24 @@ class GraphModule(torch.nn.Module):
torch.randn(2, 3),
)
@unittest.skipIf(
not TEST_CUDA_GRAPH_CONDITIONAL_NODES,
"CUDA 12.4 or greater is required for CUDA Graphs with conditional nodes",
)
def test_cond_traced_not_nested_cudagraphs(self):
def true_fn(x):
return x.sin()
def false_fn(x):
return x.cos()
def f(x, y):
return cond(y, true_fn, false_fn, [x])
x = torch.randn(4)
_check_compile_cudagraph(self, f, [x.cuda(), torch.tensor(True).cuda()])
_check_compile_cudagraph(self, f, [x.cuda(), torch.tensor(False).cuda()])
def test_while_loop_nested_traced(self):
fn, inp = WHILE_LOOP_TESTS["nested"]
graphs = self._check_tracing(fn, inp)
@ -5889,6 +5947,13 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1
f(x, torch.tensor(True), torch.tensor(True)),
)
if TEST_CUDA_GRAPH_CONDITIONAL_NODES:
_check_compile_cudagraph(
self,
f,
[x.cuda(), torch.tensor(True).cuda(), torch.tensor(True).cuda()],
)
def test_cond_functionalized(self):
def true_fn(x):
y = x.sin()
@ -5902,6 +5967,9 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1
pred = x.shape[0] == 1
return cond(pred, true_fn, false_fn, [x])
def f_(x, y):
return cond(y, true_fn, false_fn, [x])
example_inputs = (torch.ones(4, 5),)
functional_f = torch.func.functionalize(f)
self.assertEqual(functional_f(*example_inputs), f(*example_inputs))
@ -5920,6 +5988,10 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
if TEST_CUDA_GRAPH_CONDITIONAL_NODES:
pred = torch.tensor(example_inputs[0].shape[0] == 1, device="cuda")
_check_compile_cudagraph(self, f_, [torch.ones(4, 5).cuda(), pred])
def test_cond_accepts_torch_function_as_inputs(self):
a = torch.randn(3, 4)
b = torch.randn(3, 4)
@ -6453,6 +6525,13 @@ def forward(self, arg0_1):
return (mul,)""",
)
if TEST_CUDA_GRAPH_CONDITIONAL_NODES:
_check_compile_cudagraph(
self,
f,
[x.cuda(), torch.tensor(False).cuda(), torch.tensor(False).cuda()],
)
def test_raise_error_on_mismatch_type_size(self):
def true_fn(x):
return x.sin()
@ -9748,11 +9827,98 @@ class TestHopSchema(TestCase):
)
class DynamicCondModel(torch.nn.Module):
def __init__(self, input_size=16, hidden_size=64, output_size=10):
super().__init__()
self.fc1_0 = torch.nn.Linear(input_size, hidden_size)
self.fc1_1 = torch.nn.Linear(input_size, 32)
self.fc1_2 = torch.nn.Linear(32, hidden_size)
self.relu = torch.nn.ReLU()
self.fc2 = torch.nn.Linear(hidden_size, output_size)
def forward(self, x):
def true_fn(x):
return self.fc1_0(x)
def false_fn(x):
x = self.fc1_1(x)
return self.fc1_2(x)
pred = torch.tensor(x.sum() > 0, device="cuda")
x = cond(pred, true_fn, false_fn, [x])
x = self.relu(x)
x = self.fc2(x)
return x
@unittest.skipIf(
not TEST_CUDA_GRAPH_CONDITIONAL_NODES,
"CUDA 12.4 or greater is required for CUDA Graphs with conditional nodes",
)
class TestControlFlowNN(TestCase):
def test_cond_in_NN(self):
model = DynamicCondModel().cuda()
x = torch.randn(16, device="cuda")
_check_compile_cudagraph(self, model, [x])
@unittest.skipIf(
not TEST_CUDA_GRAPH_CONDITIONAL_NODES,
"CUDA 12.4 or greater is required for CUDA Graphs with conditional nodes",
)
class TestControlFlowAndRNG(TestCase):
@parametrize("rng_func", ["custom_generator", "default_generator"])
def test_rng_with_conditional_nodes_warns(self, rng_func):
pred = torch.tensor(True, device="cuda")
x = torch.ones(10, dtype=torch.float32, device="cuda")
if rng_func == "custom_generator":
self.skipTest(
"randn() currently does not work with a generator argument in dynamo."
)
generator = torch.Generator("cuda")
def custom_generator(x):
return x + torch.randn(
*x.shape, generator=generator, dtype=x.dtype, device=x.device
)
rng_func = custom_generator
elif rng_func == "default_generator":
def default_generator(x):
return x + torch.randn(*x.shape, dtype=x.dtype, device=x.device)
rng_func = default_generator
def func(pred, x):
return torch.cond(pred, rng_func, lambda x: 2 * x, [x])
compiled_func = torch.compile(func, backend="cudagraphs")
with warnings.catch_warnings(record=True) as warning_objs:
for i in range(3):
compiled_func(pred, x)
# Warn first for conditional node warmup, warn second for the
# graph capture that we will actually use.
self.assertEqual(len(warning_objs), 2)
warning_message = "You used random numbers in a cuda graph that uses conditional nodes. The previous design assumed that all RNG operations would execute only once, unconditionally, but this is no longer guaranteed with data-dependent control flow. Running with the cuda graph repeatedly may not match running without the cuda graph." # noqa: B950
for warning in warning_objs:
self.assertTrue(warning_message in str(warning.message))
instantiate_parametrized_tests(TestHopSchema)
instantiate_parametrized_tests(TestControlFlowTraced)
instantiate_parametrized_tests(TestControlFlow)
instantiate_parametrized_tests(AssociativeScanTests)
instantiate_parametrized_tests(TestControlFlowAndRNG)
if __name__ == "__main__":
run_tests()

View File

@ -0,0 +1,90 @@
# Owner(s): ["module: functorch"]
import unittest
import torch
from torch.testing._internal.common_utils import (
run_tests,
TEST_CUDA_GRAPH_CONDITIONAL_NODES,
TestCase,
)
@unittest.skipIf(
not TEST_CUDA_GRAPH_CONDITIONAL_NODES,
"CUDA 12.4 or greater is required for CUDA Graphs with conditional nodes",
)
class TestControlFlowInCUDAGraphInitialization(TestCase):
# Duplicated from test_cuda_primary_ctx.py
CTX_ALREADY_CREATED_ERR_MSG = (
"Tests defined in TestControlFlowInCUDAGraphInitialization must be run in a process "
"where CUDA contexts are never created. Use either run_test.py or add "
"--subprocess to run each test in a different subprocess."
)
def setUp(self):
# Ensure context has not been created beforehand
self.assertFalse(
torch._C._cuda_hasPrimaryContext(0),
TestControlFlowInCUDAGraphInitialization.CTX_ALREADY_CREATED_ERR_MSG,
)
def _check_compile_cudagraphs(self, f, pred, *other_args):
f = torch.compile(f, backend="cudagraphs")
outputs = []
for p in [pred, torch.logical_not(pred)]:
for i in range(3):
outputs.append(f(pred, *other_args).clone())
# We compute the eager output only after running cudagraphs
# backend compiled function, in order to make sure that
# cudagraph trees warms up the conditional part of the code
# properly.
eager_output = f(pred, *other_args)
for output in outputs:
self.assertEqual(output, eager_output)
def test_cond_cudnn(self):
# Tests that cublasCreate() does not break stream capture
def f(pred, x, filters):
return torch.cond(
pred,
lambda y: torch.sum(y),
lambda y: torch.sum(torch.nn.functional.conv1d(y, filters)),
[x],
)
self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
pred = torch.tensor(True, device="cuda")
x = torch.randn(33, 16, 30, device="cuda")
filters = torch.randn(20, 16, 5, device="cuda")
self._check_compile_cudagraphs(f, pred, x, filters)
self.assertTrue(torch._C._cuda_hasPrimaryContext(0))
def test_cond_stft(self):
# Tests that cufft plan creation does not break stream capture
def f(pred, x):
return torch.cond(
pred,
lambda y: torch.sum(y),
lambda y: torch.sum(torch.stft(y, 512, return_complex=False)),
[x],
)
self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
pred = torch.tensor(True, device="cuda")
x = torch.ones(1024 * 1024, device="cuda")
self._check_compile_cudagraphs(f, pred, x)
self.assertTrue(torch._C._cuda_hasPrimaryContext(0))
if __name__ == "__main__":
run_tests()

View File

@ -285,6 +285,7 @@ RUN_PARALLEL_BLOCKLIST = [
"test_autograd_fallback",
"inductor/test_compiler_bisector",
"test_privateuseone_python_backend",
"functorch/test_control_flow_cuda_initialization",
] + FSDP_TEST
# Test files that should always be run serially with other test files,
@ -1277,6 +1278,7 @@ CUSTOM_HANDLERS = {
"distributed/rpc/test_tensorpipe_agent": run_test_with_subprocess,
"distributed/rpc/test_share_memory": run_test_with_subprocess,
"distributed/rpc/cuda/test_tensorpipe_agent": run_test_with_subprocess,
"functorch/test_control_flow_cuda_initialization": run_test_with_subprocess,
"doctests": run_doctests,
"test_ci_sanity_check_fail": run_ci_sanity_check,
"test_autoload_enable": test_autoload_enable,

View File

@ -2362,6 +2362,12 @@ class _CUDAGraph:
def debug_dump(self, debug_path: str) -> None: ...
def raw_cuda_graph(self) -> _int: ...
def raw_cuda_graph_exec(self) -> _int: ...
@staticmethod
def get_currently_capturing_graph() -> _CUDAGraph: ...
def begin_capture_to_if_node(self, scalar_cuda_pred_tensor): ...
def end_capture_to_conditional_node(self): ...
@staticmethod
def set_conditional_handle(handle, scalar_cuda_pred_tensor): ...
# Defined in torch/csrc/cuda/MemPool.cpp
class _MemPool:

View File

@ -0,0 +1,100 @@
# mypy: allow-untyped-defs
from contextlib import contextmanager
from typing import Any, Generator
import torch
import torch.utils._pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode
# TODO: Move this into torch/cuda/graphs.py
class CUDAGraphCaptureControlFlowOpDispatchMode(TorchDispatchMode):
def __init__(
self,
) -> None:
self.supports_higher_order_operators = True
super().__init__()
def __torch_dispatch__(
self,
func,
types,
args=(),
kwargs=None,
):
if func is torch.ops.higher_order.cond:
# Re-enter the mode to support nested conditionals
with self:
return if_else_node(*args)
kwargs = {} if kwargs is None else kwargs
return func(*args, **kwargs)
class ControlFlowOpWarmupDispatchMode(TorchDispatchMode):
def __init__(
self,
) -> None:
super().__init__()
self.supports_higher_order_operators = True
self.capture_stream = torch.cuda.Stream()
def __torch_dispatch__(
self,
func,
types,
args=(),
kwargs=None,
):
kwargs = {} if kwargs is None else kwargs
with torch.cuda.graphs.thread_cuda_stream_capture_mode(
torch.cuda.cudart().cudaStreamCaptureMode.Relaxed
):
return func(*args, **kwargs)
def _is_boolean_scalar_cuda_tensor(pred: Any) -> bool:
return (
isinstance(pred, torch.Tensor)
and pred.size() == torch.Size([])
and pred.dtype == torch.bool
and pred.is_cuda
)
@contextmanager
def _if_body(pred: torch.Tensor) -> Generator[None, None, None]:
current_cuda_graph = torch.cuda.CUDAGraph.get_currently_capturing_graph()
current_cuda_graph.begin_capture_to_if_node(pred)
try:
yield
finally:
current_cuda_graph.end_capture_to_conditional_node()
def if_else_node(pred: torch.Tensor, true_fn, false_fn, operands):
if not pred.is_cuda:
raise ValueError(
"Conditions must be on a cuda device to use conditional node in cuda graphs"
)
# if-else is not supported yet in CUDA 12.4. Therefore, we use two if conditions, where one evaluates !pred
outs = []
for lazy_pred, fn in [
(lambda: pred, true_fn),
(lambda: torch.logical_not(pred), false_fn),
]:
with _if_body(lazy_pred()):
outs.append(fn(*operands))
# Copy these two outputs into a new output buffer. Well,
# actually, what we would like is to be able to merge these two
# tensors into the same tensor... Is there an obvious way to do
# that?
if len(outs) == 2:
for if_out, else_out in zip(
pytree.tree_iter(outs[0]), pytree.tree_iter(outs[1])
):
if_out.copy_(else_out)
assert len(outs) == 2
return outs[0]

View File

@ -57,6 +57,10 @@ from torch import Tensor
from torch._dynamo.callback import CallbackTrigger
from torch._dynamo.mutation_guard import GenerationTracker
from torch._dynamo.utils import counters, dynamo_timed, preserve_rng_state
from torch._higher_order_ops.cudagraph_conditional_nodes import (
ControlFlowOpWarmupDispatchMode,
CUDAGraphCaptureControlFlowOpDispatchMode,
)
from torch._inductor.compile_fx import (
align_inputs_from_check_idxs,
copy_misaligned_inputs,
@ -682,6 +686,7 @@ class CUDAWarmupNode:
_use_cuda_memory_pool_manager(
self.device_index, self.cuda_graphs_pool, self.stream
),
ControlFlowOpWarmupDispatchMode(),
get_history_recording(),
):
out = self.wrapped_function.model(new_inputs)
@ -1275,6 +1280,7 @@ class CUDAGraphNode:
pool=self.cuda_graphs_pool,
capture_error_mode="thread_local",
),
CUDAGraphCaptureControlFlowOpDispatchMode(),
get_history_recording(),
):
static_outputs = model(inputs)

View File

@ -112,5 +112,25 @@ void THCPGraph_init(PyObject* module) {
// compile error.
return reinterpret_cast<uintptr_t>(graph_exec);
},
py::call_guard<py::gil_scoped_release>());
py::call_guard<py::gil_scoped_release>())
.def_static(
"get_currently_capturing_graph",
torch::wrap_pybind_function_no_gil(
&::at::cuda::CUDAGraph::get_currently_capturing_graph),
py::return_value_policy::reference)
.def(
"begin_capture_to_if_node",
torch::wrap_pybind_function_no_gil(
&::at::cuda::CUDAGraph::begin_capture_to_if_node),
py::arg("scalar_cuda_pred_tensor"))
.def(
"end_capture_to_conditional_node",
torch::wrap_pybind_function_no_gil(
&::at::cuda::CUDAGraph::end_capture_to_conditional_node))
.def_static(
"set_conditional_handle",
torch::wrap_pybind_function_no_gil(
&::at::cuda::CUDAGraph::set_conditional_handle),
py::arg("handle"),
py::arg("scalar_cuda_pred_tensor"));
}

View File

@ -120,6 +120,22 @@ void initCudartBindings(PyObject* module) {
C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
return {device_free, device_total};
});
py::enum_<cudaStreamCaptureMode>(
cudart,
"cuda"
"StreamCaptureMode")
.value("Global", cudaStreamCaptureModeGlobal)
.value("ThreadLocal", cudaStreamCaptureModeThreadLocal)
.value("Relaxed", cudaStreamCaptureModeRelaxed);
cudart.def(
"cuda"
"ThreadExchangeStreamCaptureMode",
[](cudaStreamCaptureMode mode) -> cudaStreamCaptureMode {
C10_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
return mode;
});
}
} // namespace torch::cuda::shared

View File

@ -1,5 +1,6 @@
from __future__ import annotations
import contextlib
import gc
import typing
from collections.abc import Callable
@ -272,6 +273,30 @@ class graph:
_ModuleOrCallable: TypeAlias = Union["torch.nn.Module", Callable[..., object]]
@contextlib.contextmanager
def _graph_no_gc(cuda_graph, pool, stream, capture_error_mode):
"""This is an internal function used to do stream capture without
calling torch.cuda.synchronize(), gc.collect(), and
torch.cuda.empty_cache(). Unfortunately, cudagraph trees runs its
eager warmup inside of the context manager
_use_cuda_memory_pool_manager(), which makes captures_underway in
CUDACachingAllocator.cpp non-empty. We need this in order to
warmup conditional higher order operators, like torch.cond() and
torch.while_loop(). torch.cuda.empty_cache() will fail if
captures_underway is non-empty. Removing torch.cuda.synchronize()
and gc.collect() is not strictly speaking required, but they are
expensive and unnecessary operations.
"""
stream_ctx = torch.cuda.stream(stream)
pool = () if pool is None else (pool,)
with stream_ctx:
cuda_graph.capture_begin(*pool, capture_error_mode=capture_error_mode)
try:
yield
finally:
cuda_graph.capture_end()
@overload
def make_graphed_callables(
callables: _ModuleOrCallable,
@ -610,3 +635,46 @@ def make_graphed_callables(
return ret[0]
return tuple(ret)
@contextlib.contextmanager
def thread_cuda_stream_capture_mode(new_mode):
r"""Changes current thread's stream capture mode to `new_mode` upon __enter__ and resets the mode upon __exit__.
The only documentation on a thread's stream capture mode is here:
https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85
However, it is a little bit inadequate, so here is a more in-depth description.
Both CPU threads and capturing cuda streams have a capture mode. A
cuda stream's capture mode is set at cudaStreamBeginCapture() and
can never be changed. Meanwhile all CPU threads start with a capture
mode of cudaStreamCaptureModeGlobal, which can be changed at any
time.
Whenever a thread executes an unsafe CUDA action while CUDA
streams are capturing, it follows the following logic to determine
whether to invalidate those streams:
if capture_mode_this_thread == cudaStreamCaptureModeRelaxed:
never invalidate any capturing cuda streams whatsoever.
elif capture_mode_this_thread == cudaStreamCaptureModeThreadLocal:
invalidate any cuda streams for which cudaStreamBeginCapture() was called by this
thread, except for streams whose capture mode is cudaStreamCaptureModeRelaxed.
elif capture_mode_this_thread == cudaStreamCaptureModeGlobal:
invalidate all cuda streams that are currently capturing on any thread,
except for streams whose capture mode is cudaStreamCaptureModeRelaxed and for
streams for which cudaStreamCaptureBegin() was called with
cudaStreamCaptureModeThreadLocal on a thread other than this one.
In practice, changing the current thread's capture mode to
cudaStreamCaptureModeRelaxed is helpful for enabling developers to
do "unsafe" things that we know are safe in our particular case.
"""
cudart = torch.cuda.cudart()
old_mode = cudart.cudaThreadExchangeStreamCaptureMode(new_mode)
try:
yield
finally:
cudart.cudaThreadExchangeStreamCaptureMode(old_mode)
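As a point of reference (not part of the diff), a minimal usage sketch mirroring how ControlFlowOpWarmupDispatchMode above uses this helper; run_known_safe_cuda_work() is a hypothetical placeholder for CUDA work that is normally flagged as unsafe during capture but is known to be safe in this context:

    import torch

    # Temporarily put this thread into the Relaxed capture mode so the
    # wrapped CUDA calls do not invalidate any ongoing stream capture.
    with torch.cuda.graphs.thread_cuda_stream_capture_mode(
        torch.cuda.cudart().cudaStreamCaptureMode.Relaxed
    ):
        run_known_safe_cuda_work()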

View File

@ -1567,6 +1567,12 @@ TEST_CUDA_GRAPH = TEST_CUDA and (not TEST_SKIP_CUDAGRAPH) and (
)
TEST_CUDA_CUDSS = TEST_CUDA and (torch.version.cuda and int(torch.version.cuda.split(".")[0]) >= 12)
TEST_CUDA_GRAPH_CONDITIONAL_NODES = TEST_CUDA_GRAPH and (
torch.version.cuda and (
(int(torch.version.cuda.split(".")[0]) >= 12 and int(torch.version.cuda.split(".")[1]) >= 4) or
(int(torch.version.cuda.split(".")[0]) >= 13)
)
)
TEST_CUDA_PYTHON_BINDINGS = _check_module_exists("cuda.bindings") and (
torch.version.cuda and int(torch.version.cuda.split(".")[0]) >= 12