Revert "Implement cuda graphs implementation of torch.cond and torch.while_loop (#140979)"

This reverts commit c7515da7b00de40942c83dc5856b6daec727e280.

Reverted https://github.com/pytorch/pytorch/pull/140979 on behalf of https://github.com/huydhn due to This change has been reported to break internal code ([comment](https://github.com/pytorch/pytorch/pull/140979#issuecomment-2657361940))
This commit is contained in:
PyTorch MergeBot
2025-02-13 18:04:26 +00:00
parent 65e8862b9a
commit 9a883007a2
22 changed files with 29 additions and 1145 deletions

View File

@ -347,7 +347,7 @@ c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
*/
void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
at::cuda::assertNotCapturing(
"Please ensure to utilize the CUDAGeneratorImpl::graphsafe_set_state method during capturing.");
"Please ensure to utilize the CUDAGeneratorImpl::set_state_index method during capturing.");
static const size_t seed_size = sizeof(uint64_t);
static const size_t offset_size = sizeof(int64_t);
static const size_t total_size = seed_size + offset_size;

View File

@ -11,50 +11,9 @@
namespace at::cuda {
void external_stream_deleter(cudaStream_t* stream) {
if (stream != nullptr) {
cudaStreamDestroy(*stream);
delete stream;
}
}
namespace {
UniquePtrExternalCudaStream create_external_stream() {
// From:
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g793d7d4e474388ddfda531603dc34aa3
// "Capture must be ended on the same stream in which it was initiated, and it
// may only be initiated if the stream is not already in capture mode."
// Since pytorch uses a pool of 32 pre-allocated cuda streams,
// should a user nest 32 conditional nodes, there would be an error
// for the 32nd node, since that node's stream would already be in
// capture mode. The easiest solution is to handle stream creation
// and deletion ourselves.
// we use cudaStreamNonBlocking because every default cuda stream in
// pytorch uses that flag for all streams used for stream capture
// (see kDefaultFlags in CUDAStream.cpp). This would need to be kept
// in sync, should that ever change. Or kDefaultFlags needs to be
// exposed in a header file.
auto stream_ptr = std::make_unique<cudaStream_t>();
AT_CUDA_CHECK(
cudaStreamCreateWithFlags(stream_ptr.get(), cudaStreamNonBlocking));
return UniquePtrExternalCudaStream(
stream_ptr.release(), external_stream_deleter);
}
} // anonymous namespace
static bool _cuda_graphs_debug = false;
constexpr int kSynchronizeBusyWaitMillis = 10;
// To support stream capture across multiple threads, we use a global
// hashmap mapping cuda stream capture IDs to CUDAGraph objects. This
// was originally a thread_local std::stack<CUDAGraph*>, but that was
// not acceptable since stream capture does span threads in certain
// circumstances (in particular, during autograd).
static std::mutex _currently_capturing_graphs_mutex;
static ska::flat_hash_map<CaptureId_t, CUDAGraph*> _currently_capturing_graphs;
MempoolId_t graph_pool_handle() {
// Sets just the second value, to distinguish it from MempoolId_ts created from
// cudaStreamGetCaptureInfo id_s in capture_begin.
@ -118,13 +77,11 @@ void CUDAGraph::register_generator_state(const at::Generator& generator) {
cuda_gen->register_graph(this);
}
void CUDAGraph::capture_begin(MempoolId_t pool/*={0,0}*/, cudaStreamCaptureMode capture_mode) {
void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/, cudaStreamCaptureMode capture_mode) {
TORCH_CHECK(!has_graph_exec_,
"This CUDAGraph instance already owns a captured graph. "
"To capture a new graph, create a new instance.");
capture_mode_ = capture_mode;
// default generator is always registered
auto* gen = get_generator_or_default<CUDAGeneratorImpl>(
std::nullopt, cuda::detail::getDefaultCUDAGenerator());
@ -162,7 +119,12 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*={0,0}*/, cudaStreamCaptureMode
// Addendum: beginAllocateStreamToPool is now called before cudaStreamBeginCapture to prevent an
// autograd thread's free() call triggering an invalid cudaEventRecord in the caching allocator
// due to the capture status being updated _after_ a capture had already started.
c10::cuda::CUDACachingAllocator::beginAllocateToPool(capture_dev_, mempool_id_, create_allocate_filter());
c10::cuda::CUDACachingAllocator::beginAllocateToPool(capture_dev_, mempool_id_, [this](cudaStream_t stream) {
cudaStreamCaptureStatus status{};
CaptureId_t stream_capture_id = 0;
AT_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &stream_capture_id));
return status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive && stream_capture_id == capture_id_;
});
// At this point, any NCCL watchdogs should be aware that we are in capture mode
// and therefore should not enqueue any additional work that could be event-queried.
@ -182,10 +144,6 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*={0,0}*/, cudaStreamCaptureMode
AT_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &capture_id_));
TORCH_INTERNAL_ASSERT(status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive);
{
std::unique_lock<std::mutex> lock(_currently_capturing_graphs_mutex);
_currently_capturing_graphs.emplace(capture_id_, this);
}
}
void CUDAGraph::capture_end() {
@ -196,14 +154,6 @@ void CUDAGraph::capture_end() {
AT_CUDA_CHECK(cudaStreamEndCapture(capture_stream_, &graph_));
{
std::unique_lock<std::mutex> lock(_currently_capturing_graphs_mutex);
TORCH_CHECK(
_currently_capturing_graphs.count(capture_id_),
"capture_end() called before capture_begin().");
_currently_capturing_graphs.erase(capture_id_);
}
c10::cuda::CUDACachingAllocator::endAllocateToPool(capture_dev_, mempool_id_);
TORCH_CHECK(graph_ != nullptr, "Invalid capture.");
@ -250,19 +200,6 @@ void CUDAGraph::capture_end() {
wholegraph_increments = generator_state->capture_epilogue();
}
#if !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
bool any_wholegraph_increments_nonzero = false;
for (auto& [generator_state, wholegraph_increments] : captured_generator_states_) {
if (wholegraph_increments != 0) {
any_wholegraph_increments_nonzero = true;
}
}
if (any_wholegraph_increments_nonzero && !descendent_graphs_.empty()) {
TORCH_WARN("You used random numbers in a cuda graph that uses conditional nodes. The previous design assumed that all RNG operations would execute only once, unconditionally, but this is no longer guaranteed with data-dependent control flow. Running with the cuda graph repeatedly may not match running without the cuda graph.");
}
#endif // !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12040
size_t numCUDAGraphNodes = 0;
AT_CUDA_CHECK(cudaGraphGetNodes(graph_, nullptr, &numCUDAGraphNodes));
if (numCUDAGraphNodes == 0) {
@ -363,7 +300,7 @@ void CUDAGraph::reset() {
// Returns an id another graph's capture_begin can use to share the same memory pool as this graph.
MempoolId_t CUDAGraph::pool() {
TORCH_CHECK(has_graph_exec_,
TORCH_CHECK(has_graph_exec_,
"Called CUDAGraph::pool() without a preceding successful capture.");
return mempool_id_;
}
@ -388,250 +325,4 @@ CUDAGraph::~CUDAGraph() {
#endif
}
CUDAGraph* CUDAGraph::get_currently_capturing_graph() {
std::unique_lock<std::mutex> lock(_currently_capturing_graphs_mutex);
cudaStreamCaptureStatus status{};
CaptureId_t current_capture_id = -1;
auto stream = at::cuda::getCurrentCUDAStream();
AT_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &current_capture_id));
TORCH_CHECK(
status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive,
"The current stream is not currently capturing.");
TORCH_CHECK(
_currently_capturing_graphs.count(current_capture_id),
"get_currently_capturing_graph() can be used only between capture_begin() and capture_end(). Did you use a stream without making it depend upon the original stream used for capture?");
return _currently_capturing_graphs.at(current_capture_id);
}
void CUDAGraph::begin_capture_to_if_node(
const at::Tensor& scalar_cuda_pred_tensor) {
#if !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
TORCH_CHECK(
!has_graph_exec_,
"begin_capture_to_if_node() must be called before capture_begin()");
cudaStreamCaptureStatus status{};
cudaGraph_t currently_capturing_graph{};
AT_CUDA_CHECK(cudaStreamGetCaptureInfo(
getCurrentCUDAStream(), &status, nullptr, &currently_capturing_graph));
TORCH_CHECK(
status == cudaStreamCaptureStatusActive,
"capture_begin() must be called before begin_capture_to_if_node()");
cudaGraphConditionalHandle handle{};
AT_CUDA_CHECK(cudaGraphConditionalHandleCreate(
&handle, currently_capturing_graph, 0, 0));
set_conditional_handle(handle, scalar_cuda_pred_tensor);
const cudaGraphNode_t* dependencies{};
size_t num_dependencies = 0;
AT_CUDA_CHECK(cudaStreamGetCaptureInfo(
getCurrentCUDAStream(),
&status,
nullptr,
&currently_capturing_graph,
&dependencies,
&num_dependencies));
TORCH_CHECK(status == cudaStreamCaptureStatusActive);
cudaGraphNodeParams params{};
params.type = cudaGraphNodeTypeConditional;
params.conditional.handle = handle;
params.conditional.type = cudaGraphCondTypeIf;
params.conditional.size = 1;
cudaGraphNode_t cond_node{};
AT_CUDA_CHECK(cudaGraphAddNode(
&cond_node,
currently_capturing_graph,
dependencies,
num_dependencies,
&params));
cudaGraph_t if_node_child_graph = params.conditional.phGraph_out[0];
AT_CUDA_CHECK(cudaStreamUpdateCaptureDependencies(
getCurrentCUDAStream(), &cond_node, 1, cudaStreamSetCaptureDependencies));
UniquePtrExternalCudaStream child_stream = create_external_stream();
conditional_graph_capture_streams_ids_.push(-1);
c10::cuda::CUDACachingAllocator::endAllocateToPool(capture_dev_, mempool_id_);
c10::cuda::CUDACachingAllocator::beginAllocateToPool(
capture_dev_, mempool_id_, create_child_allocate_filter());
AT_CUDA_CHECK(cudaStreamBeginCaptureToGraph(
*child_stream, if_node_child_graph, nullptr, nullptr, 0, capture_mode_));
AT_CUDA_CHECK(cudaStreamGetCaptureInfo(
*child_stream, &status, &conditional_graph_capture_streams_ids_.top()));
TORCH_INTERNAL_ASSERT(
status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive);
// We need to get the raw_stream here before emplace() to prevent
// std::move(child_stream) from potentially executing before
// *child_stream.
cudaStream_t raw_stream = *child_stream;
conditional_node_streams_.emplace(
getStreamFromExternal(raw_stream, getCurrentCUDAStream().device_index()),
std::move(child_stream));
{
std::unique_lock<std::mutex> lock(_currently_capturing_graphs_mutex);
_currently_capturing_graphs.emplace(
conditional_graph_capture_streams_ids_.top(), this);
}
#else // !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
AT_ERROR(
__func__,
" CUDA Graphs conditional nodes are not supported for cuda version < 12.4");
return;
#endif
}
cudaGraphConditionalHandle CUDAGraph::begin_capture_to_while_loop_node(
const at::Tensor& scalar_cuda_pred_tensor) {
#if !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
cudaStreamCaptureStatus status{};
cudaGraph_t currently_capturing_graph{};
AT_CUDA_CHECK(cudaStreamGetCaptureInfo(
getCurrentCUDAStream(), &status, nullptr, &currently_capturing_graph));
TORCH_CHECK(
status == cudaStreamCaptureStatusActive,
"capture_begin() must be called before begin_capture_to_while_loop_node()");
cudaGraphConditionalHandle handle{};
AT_CUDA_CHECK(cudaGraphConditionalHandleCreate(
&handle, currently_capturing_graph, 0, 0));
set_conditional_handle(handle, scalar_cuda_pred_tensor);
const cudaGraphNode_t* dependencies{};
size_t num_dependencies = 0;
AT_CUDA_CHECK(cudaStreamGetCaptureInfo(
getCurrentCUDAStream(),
&status,
nullptr,
&currently_capturing_graph,
&dependencies,
&num_dependencies));
TORCH_CHECK(status == cudaStreamCaptureStatusActive);
cudaGraphNodeParams params{};
params.type = cudaGraphNodeTypeConditional;
params.conditional.handle = handle;
params.conditional.type = cudaGraphCondTypeWhile;
params.conditional.size = 1;
cudaGraphNode_t cond_node{};
AT_CUDA_CHECK(cudaGraphAddNode(
&cond_node,
currently_capturing_graph,
dependencies,
num_dependencies,
&params));
cudaGraph_t while_node_child_graph = params.conditional.phGraph_out[0];
AT_CUDA_CHECK(cudaStreamUpdateCaptureDependencies(
getCurrentCUDAStream(), &cond_node, 1, cudaStreamSetCaptureDependencies));
UniquePtrExternalCudaStream child_stream = create_external_stream();
conditional_graph_capture_streams_ids_.push(-1);
c10::cuda::CUDACachingAllocator::endAllocateToPool(capture_dev_, mempool_id_);
c10::cuda::CUDACachingAllocator::beginAllocateToPool(
capture_dev_, mempool_id_, create_child_allocate_filter());
AT_CUDA_CHECK(cudaStreamBeginCaptureToGraph(
*child_stream,
while_node_child_graph,
nullptr,
nullptr,
0,
capture_mode_));
AT_CUDA_CHECK(cudaStreamGetCaptureInfo(
*child_stream, &status, &conditional_graph_capture_streams_ids_.top()));
TORCH_INTERNAL_ASSERT(
status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive);
// We need to get the raw_stream here before emplace() to prevent
// std::move(child_stream) from potentially executing before
// *child_stream.
cudaStream_t raw_stream = *child_stream;
conditional_node_streams_.emplace(
getStreamFromExternal(raw_stream, getCurrentCUDAStream().device_index()),
std::move(child_stream));
{
std::unique_lock<std::mutex> lock(_currently_capturing_graphs_mutex);
_currently_capturing_graphs.emplace(
conditional_graph_capture_streams_ids_.top(), this);
}
return handle;
#else // !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
AT_ERROR(
__func__,
" CUDA Graphs conditional nodes are not supported for cuda version < 12.4");
return cudaGraphConditionalHandle{};
#endif
}
void CUDAGraph::end_capture_to_conditional_node() {
#if !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
{
std::unique_lock<std::mutex> lock(_currently_capturing_graphs_mutex);
CaptureId_t capture_id = conditional_graph_capture_streams_ids_.top();
TORCH_CHECK(
_currently_capturing_graphs.count(capture_id),
"capture_end() called before capture_begin().");
_currently_capturing_graphs.erase(capture_id);
}
CUDAStream stream = conditional_node_streams_.top().first.current_stream();
cudaGraph_t graph{};
AT_CUDA_CHECK(cudaStreamEndCapture(stream.stream(), &graph));
descendent_graphs_.push_back(graph);
conditional_node_streams_.pop();
conditional_graph_capture_streams_ids_.pop();
c10::cuda::CUDACachingAllocator::endAllocateToPool(capture_dev_, mempool_id_);
if (conditional_graph_capture_streams_ids_.empty()) {
c10::cuda::CUDACachingAllocator::beginAllocateToPool(
capture_dev_, mempool_id_, create_allocate_filter());
} else {
c10::cuda::CUDACachingAllocator::beginAllocateToPool(
capture_dev_, mempool_id_, create_child_allocate_filter());
}
#else // !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
AT_ERROR(
__func__,
" CUDA Graphs conditional nodes are not supported for cuda version < 12.4");
#endif
}
std::function<bool(cudaStream_t)> CUDAGraph::create_allocate_filter() {
return [this](cudaStream_t stream) {
cudaStreamCaptureStatus status{};
CaptureId_t stream_capture_id = 0;
AT_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &stream_capture_id));
return status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive && stream_capture_id == capture_id_;
};
}
std::function<bool(cudaStream_t)> CUDAGraph::create_child_allocate_filter() {
#if !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
return [&current_capture_id = conditional_graph_capture_streams_ids_.top()](cudaStream_t stream) {
cudaStreamCaptureStatus status{};
CaptureId_t stream_capture_id{};
AT_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &stream_capture_id));
return status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive && stream_capture_id == current_capture_id;
};
#else // !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
AT_ERROR(
__func__,
" CUDA Graphs conditional nodes are not supported for cuda version < 12.4");
return std::function<bool(cudaStream_t)>();
#endif
}
} // namespace at::cuda

View File

@ -1,30 +0,0 @@
#include <ATen/cuda/CUDAGraph.h>
#include <ATen/cuda/Exceptions.h>
namespace at::cuda {
namespace {
#if !(defined(USE_ROCM)) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
__global__ void set_conditional_handle_kernel(
cudaGraphConditionalHandle handle,
const bool* value) {
cudaGraphSetConditional(handle, *value);
}
#endif
}
void CUDAGraph::set_conditional_handle(
cudaGraphConditionalHandle handle,
const Tensor& scalar_cuda_pred_tensor) {
#if !(defined(USE_ROCM)) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
set_conditional_handle_kernel<<<1, 1, 0, getCurrentCUDAStream()>>>(
handle, scalar_cuda_pred_tensor.const_data_ptr<bool>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
#else
AT_ERROR("not allowed");
return;
#endif
}
} // namespace at::cuda

View File

@ -3,20 +3,9 @@
#include <ATen/Tensor.h>
#include <c10/core/Device.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/util/flat_hash_map.h>
#include <limits>
#include <stack>
#if defined(USE_ROCM) || !(defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
// this type is not defined until CUDA 12.4, but we use it as a
// parameter type and return type in some of the functions below, so we give
// it the same definition as in CUDA 12.4.
typedef unsigned long long cudaGraphConditionalHandle;
#endif // defined(USE_ROCM) || !(defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
namespace at {
struct Generator;
@ -25,9 +14,6 @@ struct CUDAGeneratorState;
namespace cuda {
using UniquePtrExternalCudaStream =
std::unique_ptr<cudaStream_t, void (*)(cudaStream_t*)>;
// Standalone way to get a unique mempool id usable as a pool=... argument
// to CUDAGraph::capture_begin
TORCH_CUDA_CPP_API MempoolId_t graph_pool_handle();
@ -36,26 +22,6 @@ struct TORCH_CUDA_CPP_API CUDAGraph {
CUDAGraph();
~CUDAGraph();
// Copy and move constructors and assignments are disabled. These
// were disabled because pybind11 believed that CUDAGraph was copy
// constructible (pybind11::is_copy_constructible<CUDAGraph>::value
// originally evaluated to true). However, pybind11 cannot actually
// generate a copy constructor, because CUDAGeneratorState, one of
// CUDAGraph's members, is an incomplete type unless
// CUDAGeneratorImpl.h is included, and including it would create a
// circular dependency between
// CUDAGeneratorImpl.h and CUDAGraph.h. Disabling the copy and move
// constructors is the most straightforward way to prevent pybind11
// from trying to generate default implementations of them.
//
// We needed pybind11 to return a reference to a CUDAGraph as part
// of wrapping CUDAGraph::get_currently_capturing_graph, which
// unearthed the above problem.
CUDAGraph(const CUDAGraph&) = delete;
CUDAGraph& operator=(const CUDAGraph&) = delete;
CUDAGraph(CUDAGraph&& other) = delete;
CUDAGraph& operator=(CUDAGraph&& other) = delete;
static void inc_pending_event_queries();
static void dec_pending_event_queries();
static int num_pending_event_queries();
@ -72,19 +38,6 @@ struct TORCH_CUDA_CPP_API CUDAGraph {
void enable_debug_mode();
void debug_dump(const std::string& debug_path);
static CUDAGraph* get_currently_capturing_graph();
void begin_capture_to_if_node(const Tensor& scalar_cuda_pred_tensor);
cudaGraphConditionalHandle begin_capture_to_while_loop_node(
const Tensor& scalar_cuda_pred_tensor);
void end_capture_to_conditional_node();
static void set_conditional_handle(
cudaGraphConditionalHandle handle,
const Tensor& scalar_cuda_pred_tensor);
private:
std::function<bool(cudaStream_t)> create_allocate_filter();
std::function<bool(cudaStream_t)> create_child_allocate_filter();
protected:
cudaGraph_t graph_ = nullptr;
cudaGraphExec_t graph_exec_ = nullptr;
@ -101,7 +54,7 @@ struct TORCH_CUDA_CPP_API CUDAGraph {
// the ID assigned by cuda during graph capture,
// used to identify when a stream is participating in capture
CaptureId_t capture_id_ = std::numeric_limits<CaptureId_t>::max();
CaptureId_t capture_id_ = -1;
// uuid used to request a particular private mempool from CUDACachingAllocator.
// By default, this will be set to {id_, 0}.
@ -132,15 +85,6 @@ struct TORCH_CUDA_CPP_API CUDAGraph {
// init capture_dev_ as UNDEFINED_DEVICE to check that it stores the real device id in the destructor
static constexpr c10::DeviceIndex UNDEFINED_DEVICE = -1;
c10::DeviceIndex capture_dev_{UNDEFINED_DEVICE};
cudaStreamCaptureMode capture_mode_{};
#if !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
std::stack<std::pair<at::cuda::CUDAStreamGuard, UniquePtrExternalCudaStream>>
conditional_node_streams_;
std::stack<CaptureId_t> conditional_graph_capture_streams_ids_;
std::vector<cudaGraph_t> descendent_graphs_;
#endif // !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12040
};
} // namespace cuda

View File

@ -929,10 +929,6 @@ and you suspect its runtime is at least somewhat CPU-limited.
https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture
.. _cudaGraphLaunch:
https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1g1accfe1da0c605a577c22d9751a09597
.. _issue 144787:
https://github.com/pytorch/pytorch/issues/144787#issuecomment-2606480564
.. _conditional nodes:
https://developer.nvidia.com/blog/dynamic-control-flow-in-cuda-graphs-with-conditional-nodes/
PyTorch API
^^^^^^^^^^^
@ -1021,9 +1017,6 @@ Violating any of these will likely cause a runtime error:
Avoid using :meth:`Generator.get_state<torch.get_state>` and :meth:`Generator.set_state<torch.set_state>` during capture;
instead, utilize :meth:`Generator.graphsafe_set_state<torch.Generator.graphsafe_set_state>` and :meth:`Generator.graphsafe_get_state<torch.Generator.graphsafe_get_state>`
for managing generator states safely within the graph context. This ensures proper RNG operation and generator management within CUDA graphs.
* Dynamic control flow (based on CPU or GPU data) is prohibited, unless it is based on GPU data and implemented via higher order operators
torch.cond() and torch.while_loop(). See :ref:`Data Dependent Control Flow<graph-data-dependent-control-flow>`.
Violating any of these will likely cause silent numerical errors or undefined behavior:
@ -1032,6 +1025,7 @@ Violating any of these will likely cause silent numerical errors or undefined be
* No non-captured CUDA work may run in this process (on any thread) while capture is underway.
* CPU work is not captured. If the captured ops include CPU work, that work will be elided during replay.
* Every replay reads from and writes to the same (virtual) memory addresses (see the sketch after this list).
* Dynamic control flow (based on CPU or GPU data) is prohibited.
* Dynamic shapes are prohibited. The graph assumes every tensor in the captured op sequence
has the same size and layout in every replay.
* Using multiple streams in a capture is allowed, but there are :ref:`restrictions<multistream-capture>`.
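For reference, here is a minimal sketch of the fixed-address pattern implied by the replay constraint above. It is illustrative only: ``static_input`` and ``static_output`` are names made up for this example::

    import torch

    g = torch.cuda.CUDAGraph()
    static_input = torch.empty(4, device="cuda")

    # Warmup on a side stream so lazy kernel initialization happens outside
    # of capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        static_input * 2
    torch.cuda.current_stream().wait_stream(s)

    # Capture: the addresses of static_input and static_output are baked
    # into the graph.
    with torch.cuda.graph(g):
        static_output = static_input * 2

    # Replay: copy fresh data into the same input address, replay, then read
    # the result out of the same output address.
    static_input.copy_(torch.randn(4, device="cuda"))
    g.replay()
    result = static_output.clone()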
@ -1340,45 +1334,3 @@ If, in the live workload, your callables will run in an order that occasionally
or if they'll run concurrently, passing them as a tuple to a single invocation of
:func:`~torch.cuda.make_graphed_callables` is not allowed. Instead, you must call
:func:`~torch.cuda.make_graphed_callables` separately for each one.
.. _graph-data-dependent-control-flow:
Data Dependent Control Flow
^^^^^^^^^^^^^^^^^^^^^^^^^^^
Data-dependent control flow can work with cuda graphs in limited cases if
the control flow is implemented using torch.cond() or
torch.while_loop(). If your function uses these operators, compiling
it with the "cudagraphs" backend will enable control flow in the
resulting cuda graph via `conditional nodes`_.
Unfortunately, eager mode execution does not work, for the reasons
described in `issue 144787`_.
Support for the inductor backend to torch.compile is not available yet, but there is no fundamental blocker.
An example of using the cudagraphs backend to torch.compile on code
using torch.cond is demonstrated below::
import torch
def true_fn(x):
return x.sin()
def false_fn(x):
return x.cos()
x = torch.randn(4, device="cuda", requires_grad=False)
pred = torch.tensor(False, device="cuda", requires_grad=False)
def foo(pred, x):
with torch.inference_mode():
return torch.cond(pred, true_fn, false_fn, [x])
# First call will run eager for warmup, second call will do graph
# capture followed by graph replay, third call and beyond will do
# just graph replay.
compiled_foo = torch.compile(foo, backend="cudagraphs")
for i in range(3):
y = compiled_foo(pred, x)
# will output x.sin()
y = compiled_foo(~pred, x)
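A similar minimal sketch for torch.while_loop with the same "cudagraphs" backend follows. It is illustrative only (``cond_fn``, ``body_fn``, and the tensor shapes are made up for this example); note that ``cond_fn`` must return a scalar boolean cuda tensor::

    import torch

    def cond_fn(i, x):
        return i.sum() < 5

    def body_fn(i, x):
        return i + 1, x.sin()

    i = torch.zeros(1, device="cuda")
    x = torch.randn(4, device="cuda")

    def foo(i, x):
        with torch.inference_mode():
            return torch.while_loop(cond_fn, body_fn, (i, x))

    # Same execution pattern as above: eager warmup on the first call,
    # stream capture plus replay on the second, replay only afterwards.
    compiled_foo = torch.compile(foo, backend="cudagraphs")
    for _ in range(3):
        out_i, out_x = compiled_foo(i, x)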

View File

@ -13,7 +13,7 @@ For a longer background on CUDAGraphs, read `accelerating pytorch with CUDAGraph
CUDA Graphs can give large speedups, especially for models with high CPU overhead or small compute. A number of limitations follow from requiring that the same kernels be run with the same arguments, dependencies, and memory addresses.
- Arbitrary Control Flow is not possible (However, control flow expressed via torch.cond() and torch.while_loop() can be captured in a CUDA Graph. See :ref:`Data Dependent Control Flow<graph-data-dependent-control-flow>`.)
- Control Flow is not possible
- Kernels which trigger device-to-host syncs (such as .item()) will error (see the sketch after this list)
- All input arguments to kernels are fixed to what they were when recorded
- CUDA memory addresses are fixed, however the values of the memory at those addresses can change
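As a hedged illustration of the sync limitation above (the function and tensor are made up for this example; depending on configuration the offending region is skipped by cudagraphs or an error is raised)::

    import torch

    @torch.compile(mode="reduce-overhead")  # enables CUDA Graph trees
    def f(x):
        # .item() forces a device-to-host sync, which cannot be captured
        # in a CUDA Graph, so this pattern hits the limitation above.
        if x.sum().item() > 0:
            return x.sin()
        return x.cos()

    x = torch.randn(8, device="cuda")
    y = f(x)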

View File

@ -1,16 +1,13 @@
# Owner(s): ["module: functorch"]
import contextlib
import copy
import functools
import unittest
import warnings
import torch
import torch.utils._pytree as pytree
from functorch.experimental import control_flow
from functorch.experimental.control_flow import cond, UnsupportedAliasMutationException
from torch._dynamo.testing import normalize_gm
from torch._dynamo.utils import counters
from torch._higher_order_ops.associative_scan import (
_fake_associative_scan,
associative_scan,
@ -36,7 +33,6 @@ from torch.testing._internal.common_utils import (
skipIfCrossRef,
skipIfRocm,
skipIfTorchDynamo,
TEST_CUDA_GRAPH_CONDITIONAL_NODES,
TEST_WITH_CROSSREF,
TEST_WITH_TORCHDYNAMO,
TestCase,
@ -44,43 +40,6 @@ from torch.testing._internal.common_utils import (
)
@contextlib.contextmanager
def check_cudagraphs_not_skipped(test_case):
old_cudagraph_skips = counters["inductor"]["cudagraph_skips"]
try:
yield
finally:
# reset before the assert, because otherwise the reset is skipped
new_cudagraph_skips = counters["inductor"]["cudagraph_skips"]
counters["inductor"]["cudagraph_skips"] = old_cudagraph_skips
test_case.assertEqual(
counters["inductor"]["cudagraph_skips"], new_cudagraph_skips
)
def _check_compile_cudagraph(test_case, fn, args):
# test cudagraphs backend
cudagraphs_compiled_fn = torch.compile(fn, backend="cudagraphs")
# We run 3 times.
# This is what cuda graph trees does for the first 3 runs:
# 1) run in eager mode, for warmup.
# 2) do stream capture followed by graph replay.
# 3 and beyond) do graph replay
# So we need to get to iteration 3 to test all ways of running.
outputs = []
for i in range(3):
with check_cudagraphs_not_skipped(test_case):
outputs.append(
pytree.tree_map(
lambda x: x.clone() if isinstance(x, torch.Tensor) else x,
cudagraphs_compiled_fn(*args),
)
)
eager_res = fn(*args)
for output in outputs:
test_case.assertEqual(eager_res, output)
# TODO: pull these helpers from AOTAutograd later
def to_fun(t):
if isinstance(t, torch.Tensor):
@ -274,9 +233,7 @@ def _while_loop_tests():
return i2.clone(), j2 - 1, x2 + 3.14, y2 - 2.71
i1, j1, x1, y1 = while_loop(
cond_fn_nested,
body_fn_nested,
(i1, j1, x1, y1),
cond_fn_nested, body_fn_nested, [i1, j1, x1, y1]
)
return i1 - 1, j1.clone(), x1 * 2, y1 / 2
@ -440,7 +397,6 @@ def _while_loop_tests():
const_and_symint_output,
(torch.randn(2, 3, requires_grad=True),),
),
# I need to add a test here that uses just a dictionary and see what happens.
}
@ -4386,24 +4342,6 @@ def forward(self, L_ctx_saved_tensors_0_ : torch.Tensor, L_ctx_pred : torch.Tens
torch.randn(2, 3),
)
@unittest.skipIf(
not TEST_CUDA_GRAPH_CONDITIONAL_NODES,
"CUDA 12.4 or greater is required for CUDA Graphs with conditional nodes",
)
def test_cond_traced_not_nested_cudagraphs(self):
def true_fn(x):
return x.sin()
def false_fn(x):
return x.cos()
def f(x, y):
return cond(y, true_fn, false_fn, [x])
x = torch.randn(4)
_check_compile_cudagraph(self, f, [x.cuda(), torch.tensor(True).cuda()])
_check_compile_cudagraph(self, f, [x.cuda(), torch.tensor(False).cuda()])
def test_while_loop_nested_traced(self):
fn, inp = WHILE_LOOP_TESTS["nested"]
graphs = self._check_tracing(fn, inp)
@ -4650,43 +4588,6 @@ def forward(self, arg0_1):
fn, inp = WHILE_LOOP_TESTS[while_loop_test]
self._check_compile(fn, inp, backend=backend)
@parametrize(
"while_loop_test",
set(WHILE_LOOP_TESTS.keys())
- {"const_and_symint_output", "int_carry", "pytree_int_carry"},
)
@unittest.skipIf(
not TEST_CUDA_GRAPH_CONDITIONAL_NODES,
"CUDA 12.4 or greater is required for CUDA Graphs with conditional nodes",
)
def test_while_loop_cuda_stream_capture(self, while_loop_test):
fn, inp = WHILE_LOOP_TESTS[while_loop_test]
if isinstance(fn, torch.nn.Module):
fn = copy.deepcopy(fn)
fn.cuda()
inp = pytree.tree_map(lambda x: x.cuda(), inp)
_check_compile_cudagraph(self, fn, inp)
@unittest.expectedFailure
@parametrize(
"while_loop_test", {"const_and_symint_output", "int_carry", "pytree_int_carry"}
)
@unittest.skipIf(
not TEST_CUDA_GRAPH_CONDITIONAL_NODES,
"CUDA 12.4 or greater is required for CUDA Graphs with conditional nodes",
)
def test_while_loop_cuda_stream_capture_fails(self, while_loop_test):
fn, inp = WHILE_LOOP_TESTS[while_loop_test]
if isinstance(fn, torch.nn.Module):
fn = copy.deepcopy(fn)
fn.cuda()
inp = pytree.tree_map(lambda x: x.cuda(), inp)
_check_compile_cudagraph(self, fn, inp)
@skipIfTorchDynamo("Graph is not captured by backend if test with dynamo")
@skipIfCrossRef # Arg order changes with cross ref
def test_while_loop_simple_with_linear_compile_check_graph(self):
@ -4896,13 +4797,6 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1
f(x, torch.tensor(True), torch.tensor(True)),
)
if TEST_CUDA_GRAPH_CONDITIONAL_NODES:
_check_compile_cudagraph(
self,
f,
[x.cuda(), torch.tensor(True).cuda(), torch.tensor(True).cuda()],
)
def test_cond_functionalized(self):
def true_fn(x):
y = x.sin()
@ -4916,9 +4810,6 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1
pred = x.shape[0] == 1
return cond(pred, true_fn, false_fn, [x])
def f_(x, y):
return cond(y, true_fn, false_fn, [x])
example_inputs = (torch.ones(4, 5),)
functional_f = torch.func.functionalize(f)
self.assertEqual(functional_f(*example_inputs), f(*example_inputs))
@ -4937,10 +4828,6 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1
self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
if TEST_CUDA_GRAPH_CONDITIONAL_NODES:
pred = torch.tensor(example_inputs[0].shape[0] == 1, device="cuda")
_check_compile_cudagraph(self, f_, [torch.ones(4, 5).cuda(), pred])
def test_cond_accepts_torch_function_as_inputs(self):
a = torch.randn(3, 4)
b = torch.randn(3, 4)
@ -5460,13 +5347,6 @@ def forward(self, arg0_1):
return (mul,)""",
)
if TEST_CUDA_GRAPH_CONDITIONAL_NODES:
_check_compile_cudagraph(
self,
f,
[x.cuda(), torch.tensor(False).cuda(), torch.tensor(False).cuda()],
)
def test_raise_error_on_mismatch_type_size(self):
def true_fn(x):
return x.sin()
@ -7783,99 +7663,11 @@ class TestHopSchema(TestCase):
self.assertEqual(schema.parse(str(schema)), schema)
class DynamicCondModel(torch.nn.Module):
def __init__(self, input_size=16, hidden_size=64, output_size=10):
super().__init__()
self.fc1_0 = torch.nn.Linear(input_size, hidden_size)
self.fc1_1 = torch.nn.Linear(input_size, 32)
self.fc1_2 = torch.nn.Linear(32, hidden_size)
self.relu = torch.nn.ReLU()
self.fc2 = torch.nn.Linear(hidden_size, output_size)
def forward(self, x):
def true_fn(x):
return self.fc1_0(x)
def false_fn(x):
x = self.fc1_1(x)
return self.fc1_2(x)
# use PyTorch control flow API
pred = torch.tensor(x.sum() > 0, device="cuda")
x = cond(pred, true_fn, false_fn, [x])
x = self.relu(x)
x = self.fc2(x)
return x
@unittest.skipIf(
not TEST_CUDA_GRAPH_CONDITIONAL_NODES,
"CUDA 12.4 or greater is required for CUDA Graphs with conditional nodes",
)
class TestControlFlowNN(TestCase):
def test_cond_in_NN(self):
model = DynamicCondModel().cuda()
x = torch.randn(16, device="cuda")
_check_compile_cudagraph(self, model, [x])
@unittest.skipIf(
not TEST_CUDA_GRAPH_CONDITIONAL_NODES,
"CUDA 12.4 or greater is required for CUDA Graphs with conditional nodes",
)
class TestControlFlowAndRNG(TestCase):
@parametrize("rng_func", ["custom_generator", "default_generator"])
def test_rng_with_conditional_nodes_warns(self, rng_func):
pred = torch.tensor(True, device="cuda")
x = torch.ones(10, dtype=torch.float32, device="cuda")
if rng_func == "custom_generator":
self.skipTest(
"randn() currently does not work with a generator argument in dynamo."
)
generator = torch.Generator("cuda")
def custom_generator(x):
return x + torch.randn(
*x.shape, generator=generator, dtype=x.dtype, device=x.device
)
rng_func = custom_generator
elif rng_func == "default_generator":
def default_generator(x):
return x + torch.randn(*x.shape, dtype=x.dtype, device=x.device)
rng_func = default_generator
def func(pred, x):
return torch.cond(pred, rng_func, lambda x: 2 * x, [x])
compiled_func = torch.compile(func, backend="cudagraphs")
with warnings.catch_warnings(record=True) as warning_objs:
for i in range(3):
compiled_func(pred, x)
# Warn first for conditional node warmup, warn second for the
# graph capture that we will actually use.
self.assertEqual(len(warning_objs), 2)
warning_message = "You used random numbers in a cuda graph that uses conditional nodes. The previous design assumed that all RNG operations would execute only once, unconditionally, but this is no longer guaranteed with data-dependent control flow. Running with the cuda graph repeatedly may not match running without the cuda graph." # noqa: B950
for warning in warning_objs:
self.assertTrue(warning_message in str(warning.message))
instantiate_parametrized_tests(TestHopSchema)
instantiate_parametrized_tests(TestControlFlowTraced)
instantiate_parametrized_tests(TestControlFlow)
instantiate_parametrized_tests(AssociativeScanTests)
instantiate_parametrized_tests(TestControlFlowAndRNG)
if __name__ == "__main__":
run_tests()

View File

@ -1,90 +0,0 @@
# Owner(s): ["module: functorch"]
import unittest
import torch
from torch.testing._internal.common_utils import (
run_tests,
TEST_CUDA_GRAPH_CONDITIONAL_NODES,
TestCase,
)
@unittest.skipIf(
not TEST_CUDA_GRAPH_CONDITIONAL_NODES,
"CUDA 12.4 or greater is required for CUDA Graphs with conditional nodes",
)
class TestControlFlowInCUDAGraphInitialization(TestCase):
# Duplicated from test_cuda_primary_ctx.py
CTX_ALREADY_CREATED_ERR_MSG = (
"Tests defined in TestControlFlowInCUDAGraphInitialization must be run in a process "
"where CUDA contexts are never created. Use either run_test.py or add "
"--subprocess to run each test in a different subprocess."
)
def setUp(self):
# Ensure context has not been created beforehand
self.assertFalse(
torch._C._cuda_hasPrimaryContext(0),
TestControlFlowInCUDAGraphInitialization.CTX_ALREADY_CREATED_ERR_MSG,
)
def _check_compile_cudagraphs(self, f, pred, *other_args):
f = torch.compile(f, backend="cudagraphs")
outputs = []
for p in [pred, torch.logical_not(pred)]:
for i in range(3):
outputs.append(f(pred, *other_args).clone())
# We compute the eager output only after running the cudagraphs
# backend compiled function, in order to make sure that
# cudagraph trees warms up the conditional part of the code
# properly.
eager_output = f(pred, *other_args)
for output in outputs:
self.assertEqual(output, eager_output)
def test_cond_cudnn(self):
# Tests that cublasCreate() does not break stream capture
def f(pred, x, filters):
return torch.cond(
pred,
lambda y: torch.sum(y),
lambda y: torch.sum(torch.nn.functional.conv1d(y, filters)),
[x],
)
self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
pred = torch.tensor(True, device="cuda")
x = torch.randn(33, 16, 30, device="cuda")
filters = torch.randn(20, 16, 5, device="cuda")
self._check_compile_cudagraphs(f, pred, x, filters)
self.assertTrue(torch._C._cuda_hasPrimaryContext(0))
def test_cond_stft(self):
# Tests that cufft plan creation does not break stream capture
def f(pred, x):
return torch.cond(
pred,
lambda y: torch.sum(y),
lambda y: torch.sum(torch.stft(y, 512, return_complex=False)),
[x],
)
self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
pred = torch.tensor(True, device="cuda")
x = torch.ones(1024 * 1024, device="cuda")
self._check_compile_cudagraphs(f, pred, x)
self.assertTrue(torch._C._cuda_hasPrimaryContext(0))
if __name__ == "__main__":
run_tests()

View File

@ -551,7 +551,6 @@ RUN_PARALLEL_BLOCKLIST = [
# temporarily sets a global config
"test_autograd_fallback",
"inductor/test_compiler_bisector",
"functorch/test_control_flow_cuda_initialization",
] + FSDP_TEST
# Test files that should always be run serially with other test files,
@ -1479,7 +1478,6 @@ CUSTOM_HANDLERS = {
"distributed/rpc/test_tensorpipe_agent": run_test_with_subprocess,
"distributed/rpc/test_share_memory": run_test_with_subprocess,
"distributed/rpc/cuda/test_tensorpipe_agent": run_test_with_subprocess,
"functorch/test_control_flow_cuda_initialization": run_test_with_subprocess,
"doctests": run_doctests,
"test_ci_sanity_check_fail": run_ci_sanity_check,
"test_autoload_enable": test_autoload_enable,

View File

@ -2125,13 +2125,6 @@ class _CUDAGraph:
def pool(self) -> Tuple[_int, _int]: ...
def enable_debug_mode(self) -> None: ...
def debug_dump(self, debug_path: str) -> None: ...
@staticmethod
def get_currently_capturing_graph() -> _CUDAGraph: ...
def begin_capture_to_if_node(self, scalar_cuda_pred_tensor): ...
def begin_capture_to_while_loop_node(self, scalar_cuda_pred_tensor) -> _int: ...
def end_capture_to_conditional_node(self): ...
@staticmethod
def set_conditional_handle(handle, scalar_cuda_pred_tensor): ...
# Defined in torch/csrc/cuda/MemPool.cpp
class _MemPool:

View File

@ -166,7 +166,7 @@ def cudagraphs(dynamo_model, dynamo_inputs):
range(fixed),
device_index=boxed_device_index.value,
is_backward=False,
is_inference=is_inference,
is_inference=False,
stack_traces=get_stack_traces(aot_model),
placeholders=get_placeholder_info(aot_model.graph),
mutated_input_idxs=find_input_mutations(aot_model.graph),

View File

@ -56,7 +56,7 @@ def make_eager_backend_with_torch_function_mode(mode):
def make_eager_backend_with_torch_function_modes(modes):
"""Used to trace HOPs (cond and while) for eager execution, the metadata
"""Used to trace HOPs (cond and while) for eager exectution, the metadata
TF mode mutates vars outside of the scope of the HOP, and we can't have graph breaks
in the HOP, so we need to externally run this mode and not trace it."""
from contextlib import ExitStack

View File

@ -151,10 +151,10 @@ def populate_builtin_to_tensor_fn_map():
most_recent_func = func
return func(*args, **kwargs)
inp0 = torch.ones(1, device="cpu")
inp1 = torch.ones(1, device="cpu")
inp0_int = torch.ones(1, dtype=torch.int32, device="cpu")
inp1_int = torch.ones(1, dtype=torch.int32, device="cpu")
inp0 = torch.ones(1)
inp1 = torch.ones(1)
inp0_int = torch.ones(1, dtype=torch.int32)
inp1_int = torch.ones(1, dtype=torch.int32)
with GetMethodMode():
setups_and_oplists = [
(lambda o: o(inp0), un_ops),

View File

@ -18,11 +18,6 @@ from torch._C._functorch import (
from torch._dispatch.python import suspend_functionalization
from torch._functorch.utils import exposed_in
from torch._guards import detect_fake_mode
from torch._higher_order_ops.cudagraph_conditional_nodes import (
ControlFlowOpWarmupDispatchMode,
CUDAGraphCaptureControlFlowOpDispatchMode,
if_else_node,
)
from torch._higher_order_ops.utils import (
_has_potential_branch_input_alias,
_has_potential_branch_input_mutation,
@ -38,7 +33,6 @@ from torch._higher_order_ops.utils import (
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._subclasses.functional_tensor import disable_functional_mode
from torch.cuda.graphs import _graph_no_gc
from torch.fx.experimental.proxy_tensor import (
_temp_remove_metadata_torch_function_mode,
_temp_remove_pre_dispatch_torch_function_mode,
@ -377,41 +371,6 @@ def cond_op_dense(pred, true_fn, false_fn, operands):
return false_fn(*operands)
# WAR for https://github.com/pytorch/pytorch/issues/140322
@cond_op.py_impl(CUDAGraphCaptureControlFlowOpDispatchMode)
def cond_op_cudagraph(mode, pred, true_fn, false_fn, operands):
assert torch.cuda.is_available() and torch.cuda.is_current_stream_capturing()
# Re-enter this mode because additional torch.cond() and
# torch.while_loop() calls may be nested inside true_fn or
# false_fn
with mode:
return if_else_node(pred, true_fn, false_fn, operands)
# WAR for https://github.com/pytorch/pytorch/issues/140322
@cond_op.py_impl(ControlFlowOpWarmupDispatchMode)
def cond_op_warmup(mode, pred, true_fn, false_fn, operands):
if torch.cuda.is_current_stream_capturing():
# This is a call to torch.cond() nested within either
# torch.while_loop() or another torch.cond() function.
with mode:
return if_else_node(pred, true_fn, false_fn, operands)
else:
with _graph_no_gc(
torch.cuda.CUDAGraph(),
pool=None,
stream=mode.capture_stream,
capture_error_mode="relaxed",
), mode:
if_else_node(pred, true_fn, false_fn, operands)
# Since ControlFlowOpWarmupDispatchMode has been popped, this call
# will fall back to cond_op_dense
return cond_op_dense(pred, true_fn, false_fn, operands)
# return torch.cond(pred, true_fn, false_fn, operands)
class CondAutogradOp(torch.autograd.Function):
@staticmethod
def forward(

View File

@ -1,132 +0,0 @@
# mypy: allow-untyped-defs
from contextlib import contextmanager
from typing import Any, Generator
import torch
import torch.utils._pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode
class CUDAGraphCaptureControlFlowOpDispatchMode(TorchDispatchMode):
def __init__(
self,
) -> None:
super().__init__()
def __torch_dispatch__(
self,
func,
types,
args=(),
kwargs=None,
):
kwargs = {} if kwargs is None else kwargs
return func(*args, **kwargs)
class ControlFlowOpWarmupDispatchMode(TorchDispatchMode):
def __init__(
self,
) -> None:
super().__init__()
self.capture_stream = torch.cuda.Stream()
def __torch_dispatch__(
self,
func,
types,
args=(),
kwargs=None,
):
kwargs = {} if kwargs is None else kwargs
with torch.cuda.graphs.thread_cuda_stream_capture_mode(
torch.cuda.cudart().cudaStreamCaptureMode.Relaxed
):
return func(*args, **kwargs)
def _is_boolean_scalar_cuda_tensor(pred: Any) -> bool:
return (
isinstance(pred, torch.Tensor)
and pred.size() == torch.Size([])
and pred.dtype == torch.bool
and pred.is_cuda
)
@contextmanager
def _if_body(pred: torch.Tensor) -> Generator[None, None, None]:
current_cuda_graph = torch.cuda.CUDAGraph.get_currently_capturing_graph()
current_cuda_graph.begin_capture_to_if_node(pred)
try:
yield
finally:
current_cuda_graph.end_capture_to_conditional_node()
def if_else_node(pred: torch.Tensor, true_fn, false_fn, operands):
if not pred.is_cuda:
raise ValueError(
"Conditions must be on a cuda device to use conditional node in cuda graphs"
)
# if-else is not supported yet in CUDA 12.4. Therefore, we use two if conditions, where one evaluates !pred
outs = []
for lazy_pred, fn in [
(lambda: pred, true_fn),
(lambda: torch.logical_not(pred), false_fn),
]:
with _if_body(lazy_pred()):
outs.append(fn(*operands))
# Copy these two outputs into a new output buffer. Well,
# actually, what we would like is to be able to merge these two
# tensors into the same tensor... Is there an obvious way to do
# that?
if len(outs) == 2:
for if_out, else_out in zip(
pytree.tree_iter(outs[0]), pytree.tree_iter(outs[1])
):
if_out.copy_(else_out)
assert len(outs) == 2
return outs[0]
@contextmanager
def _while_loop_body(pred: torch.Tensor) -> Generator[int, None, None]:
current_cuda_graph = torch.cuda.CUDAGraph.get_currently_capturing_graph()
conditional_handle = current_cuda_graph.begin_capture_to_while_loop_node(pred)
try:
yield conditional_handle
finally:
current_cuda_graph.end_capture_to_conditional_node()
def while_loop_node(cond_fn, body_fn, carried_inputs, additional_inputs):
carried_vals = carried_inputs
pred = cond_fn(*carried_vals, *additional_inputs)
if not _is_boolean_scalar_cuda_tensor(pred):
raise RuntimeError(
f"cond_fn must return a boolean scalar cuda tensor but got {pred}"
)
with _while_loop_body(pred) as conditional_handle:
out = body_fn(*carried_vals, *additional_inputs)
assert len(out) == len(
carried_inputs
), "body_fn should return the same number of elements as carried_inputs"
# out is not being flattened, for whatever reason.
for c, o in zip(carried_vals, out):
# TODO: Consider skipping the copy_ if the data_ptr is the
# same.
c.copy_(o)
# call the cond_fn again to update the pred.
pred = cond_fn(*carried_vals, *additional_inputs)
if not _is_boolean_scalar_cuda_tensor(pred):
raise RuntimeError(
f"cond_fn must return a boolean scalar tensor but got {pred}"
)
torch.cuda.CUDAGraph.set_conditional_handle(conditional_handle, pred)
return carried_vals

View File

@ -5,11 +5,6 @@ from typing import Callable, Union
import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._higher_order_ops.cudagraph_conditional_nodes import (
ControlFlowOpWarmupDispatchMode,
CUDAGraphCaptureControlFlowOpDispatchMode,
while_loop_node,
)
from torch._higher_order_ops.utils import (
_has_potential_branch_input_alias,
_has_potential_branch_input_mutation,
@ -23,7 +18,6 @@ from torch._higher_order_ops.utils import (
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.cuda.graphs import _graph_no_gc
from torch.fx.experimental.proxy_tensor import (
_temp_remove_metadata_torch_function_mode,
ProxyTorchDispatchMode,
@ -222,37 +216,6 @@ def while_loop_dense(cond_fn, body_fn, carried_inputs, additional_inputs):
return carried_vals
# WAR for https://github.com/pytorch/pytorch/issues/140322
@while_loop_op.py_impl(CUDAGraphCaptureControlFlowOpDispatchMode)
def while_loop_cudagraph(mode, cond_fn, body_fn, carried_inputs, additional_inputs):
assert torch.cuda.is_available() and torch.cuda.is_current_stream_capturing()
# Re-enter this mode because additional torch.cond() and
# torch.while_loop() calls may be nested inside cond_fn or body_fn
with mode:
return while_loop_node(cond_fn, body_fn, carried_inputs, additional_inputs)
# WAR for https://github.com/pytorch/pytorch/issues/140322
@while_loop_op.py_impl(ControlFlowOpWarmupDispatchMode)
def while_loop_warmup(mode, cond_fn, body_fn, carried_inputs, additional_inputs):
if torch.cuda.is_current_stream_capturing():
# This is a call to torch.while_loop() nested within either
# torch.while_loop() or another torch.cond() function.
with mode:
return while_loop_node(cond_fn, body_fn, carried_inputs, additional_inputs)
else:
with _graph_no_gc(
torch.cuda.CUDAGraph(),
pool=None,
stream=mode.capture_stream,
capture_error_mode="relaxed",
), mode:
while_loop_node(cond_fn, body_fn, carried_inputs, additional_inputs)
# Since ControlFlowOpWarmupDispatchMode has been popped, this call
# will fall back to while_loop_dense
return while_loop_dense(cond_fn, body_fn, carried_inputs, additional_inputs)
while_loop_op.py_impl(DispatchKey.Autograd)(
autograd_not_implemented(while_loop_op, deferred_error=True)
)

View File

@ -64,10 +64,6 @@ import torch.fx
from torch import Tensor
from torch._dynamo.mutation_guard import GenerationTracker
from torch._dynamo.utils import counters, dynamo_timed, preserve_rng_state
from torch._higher_order_ops.cudagraph_conditional_nodes import (
ControlFlowOpWarmupDispatchMode,
CUDAGraphCaptureControlFlowOpDispatchMode,
)
from torch._inductor.compile_fx import (
align_inputs_from_check_idxs,
copy_misaligned_inputs,
@ -638,7 +634,7 @@ class CUDAWarmupNode:
self.device_index
), disable_conv_cache_emptying(), clear_cublas_manager(), _use_cuda_memory_pool_manager(
self.device_index, self.cuda_graphs_pool, self.stream
), ControlFlowOpWarmupDispatchMode(), get_history_recording():
), get_history_recording():
out = self.wrapped_function.model(new_inputs)
# We need to know which outputs are allocated within the cudagraph pool
@ -1213,7 +1209,7 @@ class CUDAGraphNode:
stream=self.stream,
pool=self.cuda_graphs_pool,
capture_error_mode="thread_local",
), CUDAGraphCaptureControlFlowOpDispatchMode(), get_history_recording():
), get_history_recording():
static_outputs = model(inputs)
# running model should reclaim memory

View File

@ -32,7 +32,6 @@ from torch._subclasses.meta_utils import (
MetaConverter,
)
from torch._utils import render_call
from torch.cuda.graphs import thread_cuda_stream_capture_mode
from torch.fx.immutable_collections import immutable_dict
from torch.fx.operator_schemas import normalize_function
from torch.multiprocessing.reductions import StorageWeakRef
@ -2653,31 +2652,10 @@ def run_fallback_kernel(
return out
return e
has_cuda_tensor = any(
isinstance(a, FakeTensor) and a.fake_device.type == "cuda"
for a in flat_args
)
flat_args = [to_real_tensor(a) for a in flat_args]
args, kwargs = pytree.tree_unflatten(flat_args, args_spec)
# If one of the inputs is a CUDA tensor, it is possible that
# running the fallback kernel will perform an unsafe
# action. Unfortunately, there are scenarios where pytorch can
# be capturing on the current stream while fake tensors are in
# use (in particular, for shape inference in
# higher order operators). We need to prevent stream capture
# from breaking in this case. This is basically always safe
# because the unsafe actions tend to be lazy initialization of
# things like cuFFT plans, which won't be destroyed.
maybe_relaxed: typing.ContextManager = contextlib.nullcontext()
if has_cuda_tensor:
cudart = torch.cuda.cudart()
maybe_relaxed = thread_cuda_stream_capture_mode(
cudart.cudaStreamCaptureMode.Relaxed
)
with maybe_relaxed:
r = func(*args, **kwargs)
r = func(*args, **kwargs)
storages: set[_StoragePointer] = set()

View File

@ -87,30 +87,5 @@ void THCPGraph_init(PyObject* module) {
"debug_dump",
torch::wrap_pybind_function_no_gil(
&::at::cuda::CUDAGraph::debug_dump),
py::arg("debug_path"))
.def_static(
"get_currently_capturing_graph",
torch::wrap_pybind_function_no_gil(
&::at::cuda::CUDAGraph::get_currently_capturing_graph),
py::return_value_policy::reference)
.def(
"begin_capture_to_if_node",
torch::wrap_pybind_function_no_gil(
&::at::cuda::CUDAGraph::begin_capture_to_if_node),
py::arg("scalar_cuda_pred_tensor"))
.def(
"begin_capture_to_while_loop_node",
torch::wrap_pybind_function_no_gil(
&::at::cuda::CUDAGraph::begin_capture_to_while_loop_node),
py::arg("scalar_cuda_pred_tensor"))
.def(
"end_capture_to_conditional_node",
torch::wrap_pybind_function_no_gil(
&::at::cuda::CUDAGraph::end_capture_to_conditional_node))
.def_static(
"set_conditional_handle",
torch::wrap_pybind_function_no_gil(
&::at::cuda::CUDAGraph::set_conditional_handle),
py::arg("handle"),
py::arg("scalar_cuda_pred_tensor"));
py::arg("debug_path"));
}

View File

@ -92,20 +92,6 @@ void initCudartBindings(PyObject* module) {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
return C10_CUDA_ERROR_HANDLED(cudaStreamCreate((cudaStream_t*)ptr));
});
cudart.attr(
"cuda"
"StreamDefault") = cudaStreamDefault;
cudart.attr(
"cuda"
"StreamNonBlocking") = cudaStreamNonBlocking;
cudart.def(
"cuda"
"StreamCreateWithFlags",
[](uintptr_t ptr, unsigned int flags) -> cudaError_t {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
return C10_CUDA_ERROR_HANDLED(
cudaStreamCreateWithFlags((cudaStream_t*)ptr, flags));
});
cudart.def(
"cuda"
"StreamDestroy",
@ -134,22 +120,6 @@ void initCudartBindings(PyObject* module) {
C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
return {device_free, device_total};
});
py::enum_<cudaStreamCaptureMode>(
cudart,
"cuda"
"StreamCaptureMode")
.value("Global", cudaStreamCaptureModeGlobal)
.value("ThreadLocal", cudaStreamCaptureModeThreadLocal)
.value("Relaxed", cudaStreamCaptureModeRelaxed);
cudart.def(
"cuda"
"ThreadExchangeStreamCaptureMode",
[](cudaStreamCaptureMode mode) -> cudaStreamCaptureMode {
C10_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
return mode;
});
}
} // namespace torch::cuda::shared

View File

@ -1,5 +1,4 @@
# mypy: allow-untyped-defs
import contextlib
import gc
import typing
@ -133,11 +132,6 @@ class graph:
may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
actions in the current thread, and "relaxed" will not error on actions. Do NOT change this setting
unless you're familiar with `cudaStreamCaptureMode <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85>`_
collect_garbage (bool, optional): If True, call torch.cuda.synchronize() followed by gc.collect() to free
memory before starting graph capture. Users almost always want this to be True, but since the introduction of
conditional nodes in cuda graphs, it is possible that more than one stream may be capturing at once.
Since cudaDeviceSynchronize() synchronizes all streams, including capturing streams, previously started
stream captures would be invalidated. This is not desirable.
.. note::
For effective memory sharing, if you pass a ``pool`` used by a previous capture and the previous capture
@ -153,7 +147,11 @@ class graph:
default_capture_stream: typing.Optional["torch.cuda.Stream"] = None
def __init__(
self, cuda_graph, pool=None, stream=None, capture_error_mode: str = "global"
self,
cuda_graph,
pool=None,
stream=None,
capture_error_mode: str = "global",
):
# Lazy-init of default_capture_stream helps avoid circular-import errors.
# Not thread safe, but graphs already have the general (explicitly documented)
@ -171,7 +169,7 @@ class graph:
self.capture_error_mode = capture_error_mode
def __enter__(self):
# Free as much memory as we can for the graph.
# Free as much memory as we can for the graph
torch.cuda.synchronize()
gc.collect()
torch.cuda.empty_cache()
@ -190,30 +188,6 @@ class graph:
# returning None should propagate exceptions from either capture_end or stream_ctx.__exit__()
@contextlib.contextmanager
def _graph_no_gc(cuda_graph, pool, stream, capture_error_mode):
"""This is an internal function used to do stream capture without
calling torch.cuda.synchronize(), gc.collect(), and
torch.cuda.empty_cache(). Unfortunately, cudagraph trees runs its
eager warmup inside of the context manager
_use_cuda_memory_pool_manager(), which makes captures_underway in
CUDACachingAllocator.cpp non-empty. We need this in order to
warm up conditional higher order operators, like torch.cond() and
torch.while_loop(). torch.cuda.empty_cache() will fail if
captures_underway is non-empty. Removing torch.cuda.synchronize()
and gc.collect() is not strictly speaking required, but they are
expensive and unnecessary operations.
"""
stream_ctx = torch.cuda.stream(stream)
pool = () if pool is None else (pool,)
with stream_ctx:
cuda_graph.capture_begin(*pool, capture_error_mode=capture_error_mode)
try:
yield
finally:
cuda_graph.capture_end()
def make_graphed_callables(
callables, sample_args, num_warmup_iters=3, allow_unused_input=False, pool=None
):
@ -514,46 +488,3 @@ def make_graphed_callables(
return ret[0]
return tuple(ret)
@contextlib.contextmanager
def thread_cuda_stream_capture_mode(new_mode):
r"""Changes current thread's stream capture mode to `new_mode` upon __enter__ and resets the mode upon __exit__.
The only documentation on a thread's stream capture mode is here:
https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85
However, it is a little bit inadequate, so here is a more in-depth description.
Both CPU threads and capturing cuda streams have a capture mode. A
cuda stream's capture mode is set at cudaStreamBeginCapture() and
can never be changed. Meanwhile all CPU threads start with a capture
mode of cudaStreamCaptureModeGlobal, which can be changed at any
time.
Whenever a thread executes an unsafe CUDA action while CUDA
streams are capturing, the runtime applies the following logic to determine
whether to invalidate those streams:
if capture_mode_this_thread == cudaStreamCaptureModeRelaxed:
never invalidate any capturing cuda streams whatsoever.
elif capture_mode_this_thread == cudaStreamCaptureModeThreadLocal:
invalidate any cuda streams for which cudaStreamBeginCapture() was called by this
thread, except for streams whose capture mode is cudaStreamCaptureModeRelaxed.
elif capture_mode_this_thread == cudaStreamCaptureModeGlobal:
invalidate all cuda streams that are currently capturing on any thread,
except for streams whose capture mode is cudaStreamCaptureModeRelaxed and for
streams for which cudaStreamBeginCapture() was called with
cudaStreamCaptureModeThreadLocal on a thread other than this one.
In practice, changing the current capture mode to
cudaStreamCaptureModeRelaxed in particular is helpful for enabling
developers to do "unsafe" things that we know are safe in our
case.
"""
cudart = torch.cuda.cudart()
old_mode = cudart.cudaThreadExchangeStreamCaptureMode(new_mode)
try:
yield
finally:
cudart.cudaThreadExchangeStreamCaptureMode(old_mode)
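A minimal usage sketch of this context manager (hedged and illustrative only; run_lazy_initialization is a hypothetical placeholder, and cudaStreamCaptureMode is the cudart enum binding added elsewhere in this change):

import torch
from torch.cuda.graphs import thread_cuda_stream_capture_mode

def run_lazy_initialization():
    # Hypothetical stand-in for work that is "unsafe" under capture but
    # known to be harmless, e.g. lazy creation of a cuFFT plan.
    pass

cudart = torch.cuda.cudart()
# Relax this thread's capture mode so the work below does not invalidate
# any stream capture that is currently in progress.
with thread_cuda_stream_capture_mode(cudart.cudaStreamCaptureMode.Relaxed):
    run_lazy_initialization()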

View File

@ -1507,12 +1507,6 @@ TEST_CUDA_GRAPH = TEST_CUDA and (not TEST_SKIP_CUDAGRAPH) and (
)
TEST_CUDA_CUDSS = TEST_CUDA and (torch.version.cuda and int(torch.version.cuda.split(".")[0]) >= 12)
TEST_CUDA_GRAPH_CONDITIONAL_NODES = TEST_CUDA_GRAPH and (
torch.version.cuda and (
(int(torch.version.cuda.split(".")[0]) >= 12 and int(torch.version.cuda.split(".")[1]) >= 4) or
(int(torch.version.cuda.split(".")[0]) >= 13)
)
)
def allocator_option_enabled_fn(allocator_config, _, option):
if allocator_config is None: