mirror of https://github.com/pytorch/pytorch.git
synced 2025-11-19 01:54:54 +08:00

Compare commits (1 commit): ciflow/tru... ... galv/cudag...
Commit: f28f0c7031
@ -8,8 +8,49 @@

namespace at::cuda {

void external_stream_deleter(cudaStream_t* stream) {
  if (stream != nullptr) {
    AT_CUDA_CHECK(cudaStreamDestroy(*stream));
    delete stream;
  }
}

namespace {
UniquePtrExternalCudaStream create_external_stream() {
  // From:
  // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g793d7d4e474388ddfda531603dc34aa3
  // "Capture must be ended on the same stream in which it was initiated, and it
  // may only be initiated if the stream is not already in capture mode."

  // Since pytorch uses a pool of 32 pre-allocated cuda streams,
  // should a user nest 32 conditional nodes, there would be an error
  // for the 32nd node, since that node's stream would already be in
  // capture mode. The easiest solution is to handle stream creation
  // and deletion ourselves.

  // we use cudaStreamNonBlocking because every default cuda stream in
  // pytorch uses that flag for all streams used for stream capture
  // (see kDefaultFlags in CUDAStream.cpp). This would need to be kept
  // in sync, should that ever change. Or kDefaultFlags needs to be
  // exposed in a header file.
  auto stream_ptr = std::make_unique<cudaStream_t>();
  AT_CUDA_CHECK(
      cudaStreamCreateWithFlags(stream_ptr.get(), cudaStreamNonBlocking));
  return UniquePtrExternalCudaStream(
      stream_ptr.release(), external_stream_deleter);
}
} // anonymous namespace

static bool _cuda_graphs_debug = false;

// To support stream capture across multiple threads, we use a global
// hashmap mapping cuda stream capture IDs to CUDAGraph objects. This
// was originally a thread_local std::stack<CUDAGraph*>, but that was
// not acceptable since stream capture does span threads in certain
// circumstances (in particular, during autograd).
static std::mutex _currently_capturing_graphs_mutex;
static ska::flat_hash_map<CaptureId_t, CUDAGraph*> _currently_capturing_graphs;

MempoolId_t graph_pool_handle() {
  // Sets just the second value, to distinguish it from MempoolId_ts created from
  // cudaStreamGetCaptureInfo id_s in capture_begin.
@ -55,11 +96,13 @@ void CUDAGraph::register_generator_state(const at::Generator& generator) {
  cuda_gen->register_graph(this);
}

void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/, cudaStreamCaptureMode capture_mode) {
void CUDAGraph::capture_begin(MempoolId_t pool/*={0,0}*/, cudaStreamCaptureMode capture_mode) {
  TORCH_CHECK(!has_graph_exec_,
              "This CUDAGraph instance already owns a captured graph. "
              "To capture a new graph, create a new instance.");

  capture_mode_ = capture_mode;

  // default generator is always registered
  auto* gen = get_generator_or_default<CUDAGeneratorImpl>(
      std::nullopt, cuda::detail::getDefaultCUDAGenerator());
@ -97,12 +140,7 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/, cudaStreamCaptureMode capt
  // Addendum: beginAllocateStreamToPool is now called before cudaStreamBeginCapture to prevent an
  // autograd thread's free() call triggering an invalid cudaEventRecord in the caching allocator
  // due to the capture status being updated _after_ a capture had already started.
  c10::cuda::CUDACachingAllocator::beginAllocateToPool(capture_dev_, mempool_id_, [this](cudaStream_t stream) {
      cudaStreamCaptureStatus status{};
      CaptureId_t stream_capture_id = 0;
      AT_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &stream_capture_id));
      return status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive && stream_capture_id == capture_id_;
  });
  c10::cuda::CUDACachingAllocator::beginAllocateToPool(capture_dev_, mempool_id_, create_allocate_filter());

  // cudaStreamCaptureModeGlobal is the most conservative option to
  // prevent potentially unsafe CUDA API calls during capture. See
@ -113,6 +151,10 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/, cudaStreamCaptureMode capt
  AT_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &capture_id_));
  TORCH_INTERNAL_ASSERT(status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive);

  {
    std::unique_lock<std::mutex> lock(_currently_capturing_graphs_mutex);
    _currently_capturing_graphs.emplace(capture_id_, this);
  }
}

void CUDAGraph::capture_end() {
@ -123,6 +165,14 @@ void CUDAGraph::capture_end() {

  AT_CUDA_CHECK(cudaStreamEndCapture(capture_stream_, &graph_));

  {
    std::unique_lock<std::mutex> lock(_currently_capturing_graphs_mutex);
    TORCH_CHECK(
        _currently_capturing_graphs.count(capture_id_),
        "capture_end() called before capture_begin().");
    _currently_capturing_graphs.erase(capture_id_);
  }

  c10::cuda::CUDACachingAllocator::endAllocateToPool(capture_dev_, mempool_id_);

  TORCH_CHECK(graph_ != nullptr, "Invalid capture.");
@ -132,6 +182,19 @@ void CUDAGraph::capture_end() {
    wholegraph_increments = generator_state->capture_epilogue();
  }

#if !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
  bool any_wholegraph_increments_nonzero = false;
  for (auto& [generator_state, wholegraph_increments] : captured_generator_states_) {
    if (wholegraph_increments != 0) {
      any_wholegraph_increments_nonzero = true;
    }
  }

  if (any_wholegraph_increments_nonzero && !descendent_graphs_.empty()) {
    TORCH_WARN("You used random numbers in a cuda graph that uses conditional nodes. The previous design assumed that all RNG operations would execute only once, unconditionally, but this is no longer guaranteed with data-dependent control flow. Running with the cuda graph repeatedly may not match running without the cuda graph.");
  }
#endif // !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12040

  size_t numCUDAGraphNodes = 0;
  AT_CUDA_CHECK(cudaGraphGetNodes(graph_, nullptr, &numCUDAGraphNodes));
  if (numCUDAGraphNodes == 0) {
@ -292,7 +355,7 @@ void CUDAGraph::reset() {

// Returns an id another graph's capture_begin can use to share the same memory pool as this graph.
MempoolId_t CUDAGraph::pool() {
  TORCH_CHECK(capture_ended_,
  TORCH_CHECK(capture_ended_,
              "Called CUDAGraph::pool() without a preceding successful capture.");
  return mempool_id_;
}
@ -317,4 +380,163 @@ CUDAGraph::~CUDAGraph() {
#endif
}

CUDAGraph* CUDAGraph::get_currently_capturing_graph() {
  std::unique_lock<std::mutex> lock(_currently_capturing_graphs_mutex);
  cudaStreamCaptureStatus status{};
  CaptureId_t current_capture_id = -1;
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &current_capture_id));
  TORCH_CHECK(
      status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive,
      "The current stream is not currently capturing.");
  TORCH_CHECK(
      _currently_capturing_graphs.count(current_capture_id),
      "get_currently_capturing_graph() can be used only between capture_begin() and capture_end(). Did you use a stream without making it depend upon the original stream used for capture?");
  return _currently_capturing_graphs.at(current_capture_id);
}

void CUDAGraph::begin_capture_to_if_node(
    const at::Tensor& scalar_cuda_pred_tensor) {
#if !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
  TORCH_CHECK(
      !has_graph_exec_,
      "begin_capture_to_if_node() must be called before capture_begin()");

  cudaStreamCaptureStatus status{};
  cudaGraph_t currently_capturing_graph{};
  AT_CUDA_CHECK(cudaStreamGetCaptureInfo(
      getCurrentCUDAStream(), &status, nullptr, &currently_capturing_graph));
  TORCH_CHECK(
      status == cudaStreamCaptureStatusActive,
      "capture_begin() must be called before begin_capture_to_if_node()");
  cudaGraphConditionalHandle handle{};
  AT_CUDA_CHECK(cudaGraphConditionalHandleCreate(
      &handle, currently_capturing_graph, 0, 0));

  set_conditional_handle(handle, scalar_cuda_pred_tensor);

  const cudaGraphNode_t* dependencies{};
  size_t num_dependencies = 0;
  AT_CUDA_CHECK(cudaStreamGetCaptureInfo(
      getCurrentCUDAStream(),
      &status,
      nullptr,
      &currently_capturing_graph,
      &dependencies,
      &num_dependencies));
  TORCH_CHECK(status == cudaStreamCaptureStatusActive);

  cudaGraphNodeParams params{};
  params.type = cudaGraphNodeTypeConditional;
  params.conditional.handle = handle;
  params.conditional.type = cudaGraphCondTypeIf;
  params.conditional.size = 1;

  cudaGraphNode_t cond_node{};
  AT_CUDA_CHECK(cudaGraphAddNode(
      &cond_node,
      currently_capturing_graph,
      dependencies,
      num_dependencies,
      &params));

  cudaGraph_t if_node_child_graph = params.conditional.phGraph_out[0];

  AT_CUDA_CHECK(cudaStreamUpdateCaptureDependencies(
      getCurrentCUDAStream(), &cond_node, 1, cudaStreamSetCaptureDependencies));

  UniquePtrExternalCudaStream child_stream = create_external_stream();
  conditional_graph_capture_streams_ids_.push(-1);
  c10::cuda::CUDACachingAllocator::endAllocateToPool(capture_dev_, mempool_id_);
  c10::cuda::CUDACachingAllocator::beginAllocateToPool(
      capture_dev_, mempool_id_, create_child_allocate_filter());
  AT_CUDA_CHECK(cudaStreamBeginCaptureToGraph(
      *child_stream, if_node_child_graph, nullptr, nullptr, 0, capture_mode_));

  AT_CUDA_CHECK(cudaStreamGetCaptureInfo(
      *child_stream, &status, &conditional_graph_capture_streams_ids_.top()));
  TORCH_INTERNAL_ASSERT(
      status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive);

  // We need to get the raw_stream here before emplace() to prevent
  // std::move(child_stream) from potentially executing before
  // *child_stream.
  cudaStream_t raw_stream = *child_stream;
  conditional_node_streams_.emplace(
      getStreamFromExternal(raw_stream, getCurrentCUDAStream().device_index()),
      std::move(child_stream));

  {
    std::unique_lock<std::mutex> lock(_currently_capturing_graphs_mutex);
    _currently_capturing_graphs.emplace(
        conditional_graph_capture_streams_ids_.top(), this);
  }

#else // !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
  AT_ERROR(
      __func__,
      " CUDA Graphs conditional nodes are not supported for cuda version < 12.4");
  return;
#endif
}

void CUDAGraph::end_capture_to_conditional_node() {
#if !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
  {
    std::unique_lock<std::mutex> lock(_currently_capturing_graphs_mutex);
    CaptureId_t capture_id = conditional_graph_capture_streams_ids_.top();
    TORCH_CHECK(
        _currently_capturing_graphs.count(capture_id),
        "capture_end() called before capture_begin().");
    _currently_capturing_graphs.erase(capture_id);
  }

  CUDAStream stream = conditional_node_streams_.top().first.current_stream();
  cudaGraph_t graph{};
  AT_CUDA_CHECK(cudaStreamEndCapture(stream.stream(), &graph));
  descendent_graphs_.push_back(graph);
  conditional_node_streams_.pop();
  conditional_graph_capture_streams_ids_.pop();

  c10::cuda::CUDACachingAllocator::endAllocateToPool(capture_dev_, mempool_id_);
  if (conditional_graph_capture_streams_ids_.empty()) {
    c10::cuda::CUDACachingAllocator::beginAllocateToPool(
        capture_dev_, mempool_id_, create_allocate_filter());
  } else {
    c10::cuda::CUDACachingAllocator::beginAllocateToPool(
        capture_dev_, mempool_id_, create_child_allocate_filter());
  }
#else // !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
  AT_ERROR(
      __func__,
      " CUDA Graphs conditional nodes are not supported for cuda version < 12.4");
#endif
}

std::function<bool(cudaStream_t)> CUDAGraph::create_allocate_filter() {
  return [this](cudaStream_t stream) {
    cudaStreamCaptureStatus status{};
    CaptureId_t stream_capture_id = 0;
    AT_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &stream_capture_id));
    return status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive && stream_capture_id == capture_id_;
  };
}

std::function<bool(cudaStream_t)> CUDAGraph::create_child_allocate_filter() {
#if !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
  return [&current_capture_id = conditional_graph_capture_streams_ids_.top()](cudaStream_t stream) {
    cudaStreamCaptureStatus status{};
    CaptureId_t stream_capture_id{};
    AT_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &stream_capture_id));
    return status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive && stream_capture_id == current_capture_id;
  };
#else // !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
  AT_ERROR(
      __func__,
      " CUDA Graphs conditional nodes are not supported for cuda version < 12.4");
  return std::function<bool(cudaStream_t)>();
#endif
}


} // namespace at::cuda
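
Note: the following sketch is not part of the patch. It only illustrates how the methods added above are meant to be driven from Python, via the bindings added later in this commit; the supported entry point is torch.compile(backend="cudagraphs") with torch.cond(), and the predicate must be a scalar bool CUDA tensor. Whether the hand-driven form below works outside of cudagraph trees is an assumption:

    import torch

    g = torch.cuda.CUDAGraph()
    pred = torch.tensor(True, device="cuda")   # scalar bool CUDA tensor
    x = torch.zeros(4, device="cuda")

    with torch.cuda.graph(g):
        # Look up the CUDAGraph registered for the current capture id;
        # this is what _if_body() in
        # torch/_higher_order_ops/cudagraph_conditional_nodes.py does.
        current = torch.cuda.CUDAGraph.get_currently_capturing_graph()
        current.begin_capture_to_if_node(pred)
        try:
            x.add_(1)  # captured into the IF node's child graph
        finally:
            current.end_capture_to_conditional_node()

    g.replay()  # the add only executes if pred is True at replay time

Because set_conditional_handle() reads the predicate on the device at graph-execution time, overwriting pred in place between replays switches the branch without re-capturing.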

aten/src/ATen/cuda/CUDAGraph.cu (new file, 30 lines)
@ -0,0 +1,30 @@
#include <ATen/cuda/CUDAGraph.h>
#include <ATen/cuda/Exceptions.h>

namespace at::cuda {

namespace {

#if !(defined(USE_ROCM)) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
__global__ void set_conditional_handle_kernel(
    cudaGraphConditionalHandle handle,
    const bool* value) {
  cudaGraphSetConditional(handle, *value);
}
#endif
}

void CUDAGraph::set_conditional_handle(
    cudaGraphConditionalHandle handle,
    const Tensor& scalar_cuda_pred_tensor) {
#if !(defined(USE_ROCM)) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
  set_conditional_handle_kernel<<<1, 1, 0, getCurrentCUDAStream()>>>(
      handle, scalar_cuda_pred_tensor.const_data_ptr<bool>());
  C10_CUDA_KERNEL_LAUNCH_CHECK();
#else
  AT_ERROR("not allowed");
  return;
#endif
}

} // namespace at::cuda
@ -4,9 +4,20 @@
#include <c10/core/Device.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/util/flat_hash_map.h>

#include <limits>
#include <stack>

#if defined(USE_ROCM) || !(defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
// this type is not defined until CUDA 12.4, but we use it as a
// parameter type and return type in some below functions, so we give
// it the same definition as in CUDA 12.4.
typedef unsigned long long cudaGraphConditionalHandle;
#endif // defined(USE_ROCM) || !(defined(CUDA_VERSION) && CUDA_VERSION >= 12040)

namespace at {

struct Generator;
@ -15,6 +26,9 @@ struct CUDAGeneratorState;

namespace cuda {

using UniquePtrExternalCudaStream =
    std::unique_ptr<cudaStream_t, void (*)(cudaStream_t*)>;

// Standalone way to get a unique mempool id usable as a pool=... argument
// to CUDAGraph::capture_begin
TORCH_CUDA_CPP_API MempoolId_t graph_pool_handle();
@ -23,6 +37,26 @@ struct TORCH_CUDA_CPP_API CUDAGraph {
  CUDAGraph(bool keep_graph=false);
  ~CUDAGraph();

  // Copy and move constructors and assignments are disabled. These
  // were disabled because pybind11 believed that CUDAGraph was copy
  // constructible because
  // pybind11::is_copy_constructible<CUDAGraph>::value originally
  // evaluated to true. However, it cannot generate a copy constructor
  // because CUDAGeneratorState, one of CUDAGraph's members, is an
  // incomplete type unless CUDAGeneratorImpl.h is included. However,
  // that would create a circular dependency between
  // CUDAGeneratorImpl.h and CUDAGraph.h. Disabling the copy and move
  // constructors is the most straightforward way to prevent pybind11
  // from trying to generate default implementations of them.
  //
  // We needed pybind11 to return a reference to a CUDAGraph as part
  // of wrapping CUDAGraph::get_currently_capturing_graph, which
  // unearthed the above problem.
  CUDAGraph(const CUDAGraph&) = delete;
  CUDAGraph& operator=(const CUDAGraph&) = delete;
  CUDAGraph(CUDAGraph&& other) = delete;
  CUDAGraph& operator=(CUDAGraph&& other) = delete;

  // See Note [Explicit Registration of Generators to the CUDA Graph]
  void register_generator_state(c10::intrusive_ptr<at::CUDAGeneratorState> state);
  void register_generator_state(const at::Generator& generator);
@ -39,6 +73,17 @@ struct TORCH_CUDA_CPP_API CUDAGraph {
  cudaGraph_t raw_cuda_graph();
  cudaGraphExec_t raw_cuda_graph_exec();

  static CUDAGraph* get_currently_capturing_graph();
  void begin_capture_to_if_node(const Tensor& scalar_cuda_pred_tensor);
  void end_capture_to_conditional_node();
  static void set_conditional_handle(
      cudaGraphConditionalHandle handle,
      const Tensor& scalar_cuda_pred_tensor);

 private:
  std::function<bool(cudaStream_t)> create_allocate_filter();
  std::function<bool(cudaStream_t)> create_child_allocate_filter();

 protected:
  cudaGraph_t graph_ = nullptr;
  cudaGraphExec_t graph_exec_ = nullptr;
@ -89,6 +134,14 @@ struct TORCH_CUDA_CPP_API CUDAGraph {
  c10::DeviceIndex capture_dev_{UNDEFINED_DEVICE};

  bool keep_graph_;
  cudaStreamCaptureMode capture_mode_{};

#if !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
  std::stack<std::pair<at::cuda::CUDAStreamGuard, UniquePtrExternalCudaStream>>
      conditional_node_streams_;
  std::stack<CaptureId_t> conditional_graph_capture_streams_ids_;
  std::vector<cudaGraph_t> descendent_graphs_;
#endif // !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12040
};

} // namespace cuda

@ -1325,6 +1325,10 @@ and you suspect its runtime is at least somewhat CPU-limited.
    https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture
.. _cudaGraphLaunch:
    https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1g1accfe1da0c605a577c22d9751a09597
.. _issue 144787:
    https://github.com/pytorch/pytorch/issues/144787#issuecomment-2606480564
.. _conditional nodes:
    https://developer.nvidia.com/blog/dynamic-control-flow-in-cuda-graphs-with-conditional-nodes/

PyTorch API
^^^^^^^^^^^
@ -1413,6 +1417,9 @@ Violating any of these will likely cause a runtime error:
  Avoid using :meth:`Generator.get_state<torch.get_state>` and :meth:`Generator.set_state<torch.set_state>` during capture;
  instead, utilize :meth:`Generator.graphsafe_set_state<torch.Generator.graphsafe_set_state>` and :meth:`Generator.graphsafe_get_state<torch.Generator.graphsafe_get_state>`
  for managing generator states safely within the graph context. This ensures proper RNG operation and generator management within CUDA graphs.
* Dynamic control flow (based on CPU or GPU data) is prohibited, unless it is based on GPU data and implemented via the higher-order operator
  torch.cond(). See :ref:`Data Dependent Control Flow<graph-data-dependent-control-flow>`.


Violating any of these will likely cause silent numerical errors or undefined behavior:
@ -1421,7 +1428,6 @@ Violating any of these will likely cause silent numerical errors or undefined be
* No non-captured CUDA work may run in this process (on any thread) while capture is underway.
* CPU work is not captured. If the captured ops include CPU work, that work will be elided during replay.
* Every replay reads from and writes to the same (virtual) memory addresses.
* Dynamic control flow (based on CPU or GPU data) is prohibited.
* Dynamic shapes are prohibited. The graph assumes every tensor in the captured op sequence
  has the same size and layout in every replay.
* Using multiple streams in a capture is allowed, but there are :ref:`restrictions<multistream-capture>`.
@ -1730,3 +1736,45 @@ If, in the live workload, your callables will run in an order that occasionally
or if they'll run concurrently, passing them as a tuple to a single invocation of
:func:`~torch.cuda.make_graphed_callables` is not allowed. Instead, you must call
:func:`~torch.cuda.make_graphed_callables` separately for each one.

.. _graph-data-dependent-control-flow:

Data Dependent Control Flow
^^^^^^^^^^^^^^^^^^^^^^^^^^^

Data-dependent control flow can work with cuda graphs if the control flow
is implemented using torch.cond(). If your function uses this
operator, compiling it with the "cudagraphs" backend will enable
control flow in the resulting cuda graph via `conditional nodes`_.

Unfortunately, eager mode execution does not work, for reasons
described in `issue 144787`_. Support for the inductor backend of
torch.compile is not available yet, but there is no fundamental
blocker.

An example of using the cudagraphs backend of torch.compile on code
using torch.cond is demonstrated below::

    import torch

    def true_fn(x):
        return x.sin()

    def false_fn(x):
        return x.cos()

    x = torch.randn(4, device="cuda", requires_grad=False)
    pred = torch.tensor(False, device="cuda", requires_grad=False)
    def foo(pred, x):
        with torch.inference_mode():
            return torch.cond(pred, true_fn, false_fn, [x])

    # First call will run eager for warmup, second call will do graph
    # capture followed by graph replay, third call and beyond will do
    # just graph replay.
    compiled_foo = torch.compile(foo, backend="cudagraphs")
    for i in range(3):
        y = compiled_foo(pred, x)

    # will output x.sin()
    y = compiled_foo(~pred, x)
@ -10,7 +10,7 @@ For a longer background on CUDAGraphs, read [accelerating pytorch with CUDAGraph

CUDA Graphs can give large speedups, especially for models with high CPU overhead or small compute. There are a number of limitations from requiring the same kernels to be run with the same arguments and dependencies, and memory addresses.

- Control Flow is not possible
- Arbitrary Control Flow is not possible (However, control flow expressed via torch.cond() can be captured in a CUDA Graph. See :ref:`Data Dependent Control Flow<graph-data-dependent-control-flow>`.)
- Kernels which trigger host to device syncs (such as .item()) errors
- All input arguments to kernels are fixed to what they were recorded
- CUDA Memory addresses are fixed, however the values of the memory at those addresses can change
@ -2,12 +2,14 @@
import contextlib
import functools
import unittest
import warnings

import torch
import torch.utils._pytree as pytree
from functorch.experimental import control_flow
from functorch.experimental.control_flow import cond
from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm
from torch._dynamo.utils import counters
from torch._higher_order_ops.associative_scan import (
    _fake_associative_scan,
    associative_scan,
@ -34,12 +36,50 @@ from torch.testing._internal.common_utils import (
    run_tests,
    skipIfCrossRef,
    skipIfTorchDynamo,
    TEST_CUDA_GRAPH_CONDITIONAL_NODES,
    TEST_WITH_CROSSREF,
    TEST_WITH_TORCHDYNAMO,
    TestCase,
)


@contextlib.contextmanager
def check_cudagraphs_not_skipped(test_case):
    old_cudagraph_skips = counters["inductor"]["cudagraph_skips"]
    try:
        yield
    finally:
        # reset before the assert, because otherwise the reset is skipped
        new_cudagraph_skips = counters["inductor"]["cudagraph_skips"]
        counters["inductor"]["cudagraph_skips"] = old_cudagraph_skips
        test_case.assertEqual(
            counters["inductor"]["cudagraph_skips"], new_cudagraph_skips
        )


def _check_compile_cudagraph(test_case, fn, args):
    # test cudagraphs backend
    cudagraphs_compiled_fn = torch.compile(fn, backend="cudagraphs")
    # We run 3 times.
    # This is what cuda graph trees does for the first 3 runs:
    # 1) run in eager mode, for warmup.
    # 2) do stream capture followed by graph replay.
    # 3 and beyond) do graph replay
    # So we need to get to iteration 3 to test all ways of running.
    outputs = []
    for i in range(3):
        with check_cudagraphs_not_skipped(test_case):
            outputs.append(
                pytree.tree_map(
                    lambda x: x.clone() if isinstance(x, torch.Tensor) else x,
                    cudagraphs_compiled_fn(*args),
                )
            )
    eager_res = fn(*args)
    for output in outputs:
        test_case.assertEqual(eager_res, output)


# TODO: pull these helpers from AOTAutograd later
def to_fun(t):
    if isinstance(t, torch.Tensor):
@ -5473,6 +5513,24 @@ class GraphModule(torch.nn.Module):
            torch.randn(2, 3),
        )

    @unittest.skipIf(
        not TEST_CUDA_GRAPH_CONDITIONAL_NODES,
        "CUDA 12.4 or greater is required for CUDA Graphs with conditional nodes",
    )
    def test_cond_traced_not_nested_cudagraphs(self):
        def true_fn(x):
            return x.sin()

        def false_fn(x):
            return x.cos()

        def f(x, y):
            return cond(y, true_fn, false_fn, [x])

        x = torch.randn(4)
        _check_compile_cudagraph(self, f, [x.cuda(), torch.tensor(True).cuda()])
        _check_compile_cudagraph(self, f, [x.cuda(), torch.tensor(False).cuda()])

    def test_while_loop_nested_traced(self):
        fn, inp = WHILE_LOOP_TESTS["nested"]
        graphs = self._check_tracing(fn, inp)
@ -5889,6 +5947,13 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1
            f(x, torch.tensor(True), torch.tensor(True)),
        )

        if TEST_CUDA_GRAPH_CONDITIONAL_NODES:
            _check_compile_cudagraph(
                self,
                f,
                [x.cuda(), torch.tensor(True).cuda(), torch.tensor(True).cuda()],
            )

    def test_cond_functionalized(self):
        def true_fn(x):
            y = x.sin()
@ -5902,6 +5967,9 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1
            pred = x.shape[0] == 1
            return cond(pred, true_fn, false_fn, [x])

        def f_(x, y):
            return cond(y, true_fn, false_fn, [x])

        example_inputs = (torch.ones(4, 5),)
        functional_f = torch.func.functionalize(f)
        self.assertEqual(functional_f(*example_inputs), f(*example_inputs))
@ -5920,6 +5988,10 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1

        self.assertEqual(graph_module(*example_inputs), f(*example_inputs))

        if TEST_CUDA_GRAPH_CONDITIONAL_NODES:
            pred = torch.tensor(example_inputs[0].shape[0] == 1, device="cuda")
            _check_compile_cudagraph(self, f_, [torch.ones(4, 5).cuda(), pred])

    def test_cond_accepts_torch_function_as_inputs(self):
        a = torch.randn(3, 4)
        b = torch.randn(3, 4)
@ -6453,6 +6525,13 @@ def forward(self, arg0_1):
    return (mul,)""",
        )

        if TEST_CUDA_GRAPH_CONDITIONAL_NODES:
            _check_compile_cudagraph(
                self,
                f,
                [x.cuda(), torch.tensor(False).cuda(), torch.tensor(False).cuda()],
            )

    def test_raise_error_on_mismatch_type_size(self):
        def true_fn(x):
            return x.sin()
@ -9748,11 +9827,98 @@ class TestHopSchema(TestCase):
        )


class DynamicCondModel(torch.nn.Module):
    def __init__(self, input_size=16, hidden_size=64, output_size=10):
        super().__init__()
        self.fc1_0 = torch.nn.Linear(input_size, hidden_size)
        self.fc1_1 = torch.nn.Linear(input_size, 32)
        self.fc1_2 = torch.nn.Linear(32, hidden_size)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(hidden_size, output_size)

    def forward(self, x):
        def true_fn(x):
            return self.fc1_0(x)

        def false_fn(x):
            x = self.fc1_1(x)
            return self.fc1_2(x)

        pred = torch.tensor(x.sum() > 0, device="cuda")
        x = cond(pred, true_fn, false_fn, [x])

        x = self.relu(x)
        x = self.fc2(x)

        return x


@unittest.skipIf(
    not TEST_CUDA_GRAPH_CONDITIONAL_NODES,
    "CUDA 12.4 or greater is required for CUDA Graphs with conditional nodes",
)
class TestControlFlowNN(TestCase):
    def test_cond_in_NN(self):
        model = DynamicCondModel().cuda()

        x = torch.randn(16, device="cuda")
        _check_compile_cudagraph(self, model, [x])


@unittest.skipIf(
    not TEST_CUDA_GRAPH_CONDITIONAL_NODES,
    "CUDA 12.4 or greater is required for CUDA Graphs with conditional nodes",
)
class TestControlFlowAndRNG(TestCase):
    @parametrize("rng_func", ["custom_generator", "default_generator"])
    def test_rng_with_conditional_nodes_warns(self, rng_func):
        pred = torch.tensor(True, device="cuda")
        x = torch.ones(10, dtype=torch.float32, device="cuda")

        if rng_func == "custom_generator":
            self.skipTest(
                "randn() currently does not work with a generator argument in dynamo."
            )
            generator = torch.Generator("cuda")

            def custom_generator(x):
                return x + torch.randn(
                    *x.shape, generator=generator, dtype=x.dtype, device=x.device
                )

            rng_func = custom_generator
        elif rng_func == "default_generator":

            def default_generator(x):
                return x + torch.randn(*x.shape, dtype=x.dtype, device=x.device)

            rng_func = default_generator

        def func(pred, x):
            return torch.cond(pred, rng_func, lambda x: 2 * x, [x])

        compiled_func = torch.compile(func, backend="cudagraphs")

        with warnings.catch_warnings(record=True) as warning_objs:
            for i in range(3):
                compiled_func(pred, x)

        # Warn first for conditional node warmup, warn second for the
        # graph capture that we will actually use.
        self.assertEqual(len(warning_objs), 2)

        warning_message = "You used random numbers in a cuda graph that uses conditional nodes. The previous design assumed that all RNG operations would execute only once, unconditionally, but this is no longer guaranteed with data-dependent control flow. Running with the cuda graph repeatedly may not match running without the cuda graph."  # noqa: B950
        for warning in warning_objs:
            self.assertTrue(warning_message in str(warning.message))


instantiate_parametrized_tests(TestHopSchema)
instantiate_parametrized_tests(TestControlFlowTraced)

instantiate_parametrized_tests(TestControlFlow)
instantiate_parametrized_tests(AssociativeScanTests)

instantiate_parametrized_tests(TestControlFlowAndRNG)

if __name__ == "__main__":
    run_tests()

test/functorch/test_control_flow_cuda_initialization.py (new file, 90 lines)
@ -0,0 +1,90 @@
# Owner(s): ["module: functorch"]
import unittest

import torch
from torch.testing._internal.common_utils import (
    run_tests,
    TEST_CUDA_GRAPH_CONDITIONAL_NODES,
    TestCase,
)


@unittest.skipIf(
    not TEST_CUDA_GRAPH_CONDITIONAL_NODES,
    "CUDA 12.4 or greater is required for CUDA Graphs with conditional nodes",
)
class TestControlFlowInCUDAGraphInitialization(TestCase):
    # Duplicated from test_cuda_primary_ctx.py
    CTX_ALREADY_CREATED_ERR_MSG = (
        "Tests defined in TestControlFlowInCUDAGraphInitialization must be run in a process "
        "where CUDA contexts are never created. Use either run_test.py or add "
        "--subprocess to run each test in a different subprocess."
    )

    def setUp(self):
        # Ensure context has not been created beforehand
        self.assertFalse(
            torch._C._cuda_hasPrimaryContext(0),
            TestControlFlowInCUDAGraphInitialization.CTX_ALREADY_CREATED_ERR_MSG,
        )

    def _check_compile_cudagraphs(self, f, pred, *other_args):
        f = torch.compile(f, backend="cudagraphs")

        outputs = []

        for p in [pred, torch.logical_not(pred)]:
            for i in range(3):
                outputs.append(f(pred, *other_args).clone())

        # We compute the eager output only after running cudagraphs
        # backend compiled function, in order to make sure that
        # cudagraph trees warms up the conditional part of the code
        # properly.
        eager_output = f(pred, *other_args)

        for output in outputs:
            self.assertEqual(output, eager_output)

    def test_cond_cudnn(self):
        # Tests that cublasCreate() does not break stream capture
        def f(pred, x, filters):
            return torch.cond(
                pred,
                lambda y: torch.sum(y),
                lambda y: torch.sum(torch.nn.functional.conv1d(y, filters)),
                [x],
            )

        self.assertFalse(torch._C._cuda_hasPrimaryContext(0))

        pred = torch.tensor(True, device="cuda")
        x = torch.randn(33, 16, 30, device="cuda")
        filters = torch.randn(20, 16, 5, device="cuda")

        self._check_compile_cudagraphs(f, pred, x, filters)

        self.assertTrue(torch._C._cuda_hasPrimaryContext(0))

    def test_cond_stft(self):
        # Tests that cufft plan creation does not break stream capture
        def f(pred, x):
            return torch.cond(
                pred,
                lambda y: torch.sum(y),
                lambda y: torch.sum(torch.stft(y, 512, return_complex=False)),
                [x],
            )

        self.assertFalse(torch._C._cuda_hasPrimaryContext(0))

        pred = torch.tensor(True, device="cuda")
        x = torch.ones(1024 * 1024, device="cuda")

        self._check_compile_cudagraphs(f, pred, x)

        self.assertTrue(torch._C._cuda_hasPrimaryContext(0))


if __name__ == "__main__":
    run_tests()
@ -285,6 +285,7 @@ RUN_PARALLEL_BLOCKLIST = [
    "test_autograd_fallback",
    "inductor/test_compiler_bisector",
    "test_privateuseone_python_backend",
    "functorch/test_control_flow_cuda_initialization",
] + FSDP_TEST

# Test files that should always be run serially with other test files,
@ -1277,6 +1278,7 @@ CUSTOM_HANDLERS = {
    "distributed/rpc/test_tensorpipe_agent": run_test_with_subprocess,
    "distributed/rpc/test_share_memory": run_test_with_subprocess,
    "distributed/rpc/cuda/test_tensorpipe_agent": run_test_with_subprocess,
    "functorch/test_control_flow_cuda_initialization": run_test_with_subprocess,
    "doctests": run_doctests,
    "test_ci_sanity_check_fail": run_ci_sanity_check,
    "test_autoload_enable": test_autoload_enable,
@ -2362,6 +2362,12 @@ class _CUDAGraph:
    def debug_dump(self, debug_path: str) -> None: ...
    def raw_cuda_graph(self) -> _int: ...
    def raw_cuda_graph_exec(self) -> _int: ...
    @staticmethod
    def get_currently_capturing_graph() -> _CUDAGraph: ...
    def begin_capture_to_if_node(self, scalar_cuda_pred_tensor): ...
    def end_capture_to_conditional_node(self): ...
    @staticmethod
    def set_conditional_handle(handle, scalar_cuda_pred_tensor): ...

# Defined in torch/csrc/cuda/MemPool.cpp
class _MemPool:

torch/_higher_order_ops/cudagraph_conditional_nodes.py (new file, 100 lines)
@ -0,0 +1,100 @@
# mypy: allow-untyped-defs
from contextlib import contextmanager
from typing import Any, Generator

import torch
import torch.utils._pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode


# TODO: Move this into torch/cuda/graphs.py


class CUDAGraphCaptureControlFlowOpDispatchMode(TorchDispatchMode):
    def __init__(
        self,
    ) -> None:
        self.supports_higher_order_operators = True
        super().__init__()

    def __torch_dispatch__(
        self,
        func,
        types,
        args=(),
        kwargs=None,
    ):
        if func is torch.ops.higher_order.cond:
            # Re-enter the mode to support nested conditionals
            with self:
                return if_else_node(*args)
        kwargs = {} if kwargs is None else kwargs
        return func(*args, **kwargs)


class ControlFlowOpWarmupDispatchMode(TorchDispatchMode):
    def __init__(
        self,
    ) -> None:
        super().__init__()
        self.supports_higher_order_operators = True
        self.capture_stream = torch.cuda.Stream()

    def __torch_dispatch__(
        self,
        func,
        types,
        args=(),
        kwargs=None,
    ):
        kwargs = {} if kwargs is None else kwargs
        with torch.cuda.graphs.thread_cuda_stream_capture_mode(
            torch.cuda.cudart().cudaStreamCaptureMode.Relaxed
        ):
            return func(*args, **kwargs)


def _is_boolean_scalar_cuda_tensor(pred: Any) -> bool:
    return (
        isinstance(pred, torch.Tensor)
        and pred.size() == torch.Size([])
        and pred.dtype == torch.bool
        and pred.is_cuda
    )


@contextmanager
def _if_body(pred: torch.Tensor) -> Generator[None, None, None]:
    current_cuda_graph = torch.cuda.CUDAGraph.get_currently_capturing_graph()
    current_cuda_graph.begin_capture_to_if_node(pred)
    try:
        yield
    finally:
        current_cuda_graph.end_capture_to_conditional_node()


def if_else_node(pred: torch.Tensor, true_fn, false_fn, operands):
    if not pred.is_cuda:
        raise ValueError(
            "Conditions must be on a cuda device to use conditional node in cuda graphs"
        )
    # if-else is not supported yet in CUDA 12.4. Therefore, we use two if conditions, where one evaluates !pred
    outs = []

    for lazy_pred, fn in [
        (lambda: pred, true_fn),
        (lambda: torch.logical_not(pred), false_fn),
    ]:
        with _if_body(lazy_pred()):
            outs.append(fn(*operands))
            # Copy these two outputs into a new output buffer. Well,
            # actually, what we would like is to be able to merge these two
            # tensors into the same tensor... Is there an obvious way to do
            # that?
            if len(outs) == 2:
                for if_out, else_out in zip(
                    pytree.tree_iter(outs[0]), pytree.tree_iter(outs[1])
                ):
                    if_out.copy_(else_out)
    assert len(outs) == 2
    return outs[0]
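
Note: the following sketch is not part of the patch. It only illustrates how the pieces above combine: cudagraph trees enters CUDAGraphCaptureControlFlowOpDispatchMode alongside its capture context (see the cudagraph_trees.py hunks below), so that torch.ops.higher_order.cond calls emitted by the compiled graph are rerouted to if_else_node() and captured as IF nodes. In the real flow the cond call comes from the dynamo-produced graph, not from handwritten eager code, so treat this as an assumption-laden illustration:

    import torch
    from torch._higher_order_ops.cudagraph_conditional_nodes import (
        CUDAGraphCaptureControlFlowOpDispatchMode,
    )

    pred = torch.tensor(True, device="cuda")
    x = torch.randn(4, device="cuda")

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g), CUDAGraphCaptureControlFlowOpDispatchMode():
        # With the mode active, this higher-order op call is intercepted by
        # __torch_dispatch__ above and lowered to two IF nodes (pred / !pred).
        out = torch.ops.higher_order.cond(
            pred, lambda t: t.sin(), lambda t: t.cos(), [x]
        )

    g.replay()  # out holds the branch selected by pred's value at replay
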
@ -57,6 +57,10 @@ from torch import Tensor
from torch._dynamo.callback import CallbackTrigger
from torch._dynamo.mutation_guard import GenerationTracker
from torch._dynamo.utils import counters, dynamo_timed, preserve_rng_state
from torch._higher_order_ops.cudagraph_conditional_nodes import (
    ControlFlowOpWarmupDispatchMode,
    CUDAGraphCaptureControlFlowOpDispatchMode,
)
from torch._inductor.compile_fx import (
    align_inputs_from_check_idxs,
    copy_misaligned_inputs,
@ -682,6 +686,7 @@ class CUDAWarmupNode:
            _use_cuda_memory_pool_manager(
                self.device_index, self.cuda_graphs_pool, self.stream
            ),
            ControlFlowOpWarmupDispatchMode(),
            get_history_recording(),
        ):
            out = self.wrapped_function.model(new_inputs)
@ -1275,6 +1280,7 @@ class CUDAGraphNode:
                pool=self.cuda_graphs_pool,
                capture_error_mode="thread_local",
            ),
            CUDAGraphCaptureControlFlowOpDispatchMode(),
            get_history_recording(),
        ):
            static_outputs = model(inputs)
@ -112,5 +112,25 @@ void THCPGraph_init(PyObject* module) {
            // compile error.
            return reinterpret_cast<uintptr_t>(graph_exec);
          },
          py::call_guard<py::gil_scoped_release>());
          py::call_guard<py::gil_scoped_release>())
      .def_static(
          "get_currently_capturing_graph",
          torch::wrap_pybind_function_no_gil(
              &::at::cuda::CUDAGraph::get_currently_capturing_graph),
          py::return_value_policy::reference)
      .def(
          "begin_capture_to_if_node",
          torch::wrap_pybind_function_no_gil(
              &::at::cuda::CUDAGraph::begin_capture_to_if_node),
          py::arg("scalar_cuda_pred_tensor"))
      .def(
          "end_capture_to_conditional_node",
          torch::wrap_pybind_function_no_gil(
              &::at::cuda::CUDAGraph::end_capture_to_conditional_node))
      .def_static(
          "set_conditional_handle",
          torch::wrap_pybind_function_no_gil(
              &::at::cuda::CUDAGraph::set_conditional_handle),
          py::arg("handle"),
          py::arg("scalar_cuda_pred_tensor"));
}
@ -120,6 +120,22 @@ void initCudartBindings(PyObject* module) {
    C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
    return {device_free, device_total};
  });

  py::enum_<cudaStreamCaptureMode>(
      cudart,
      "cuda"
      "StreamCaptureMode")
      .value("Global", cudaStreamCaptureModeGlobal)
      .value("ThreadLocal", cudaStreamCaptureModeThreadLocal)
      .value("Relaxed", cudaStreamCaptureModeRelaxed);

  cudart.def(
      "cuda"
      "ThreadExchangeStreamCaptureMode",
      [](cudaStreamCaptureMode mode) -> cudaStreamCaptureMode {
        C10_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
        return mode;
      });
}

} // namespace torch::cuda::shared
@ -1,5 +1,6 @@
from __future__ import annotations

import contextlib
import gc
import typing
from collections.abc import Callable
@ -272,6 +273,30 @@ class graph:
_ModuleOrCallable: TypeAlias = Union["torch.nn.Module", Callable[..., object]]


@contextlib.contextmanager
def _graph_no_gc(cuda_graph, pool, stream, capture_error_mode):
    """This is an internal function used to do stream capture without
    calling torch.cuda.synchronize(), gc.collect(), and
    torch.cuda.empty_cache(). Unfortunately, cudagraph trees runs its
    eager warmup inside of the context manager
    _use_cuda_memory_pool_manager(), which makes captures_underway in
    CUDACachingAllocator.cpp non-empty. We need this in order to
    warmup conditional higher order operators, like torch.cond() and
    torch.while_loop(). torch.cuda.empty_cache() will fail if
    captures_underway is non-empty. Removing torch.cuda.synchronize()
    and gc.collect() is not strictly speaking required, but they are
    expensive and unnecessary operations.
    """
    stream_ctx = torch.cuda.stream(stream)
    pool = () if pool is None else (pool,)
    with stream_ctx:
        cuda_graph.capture_begin(*pool, capture_error_mode=capture_error_mode)
        try:
            yield
        finally:
            cuda_graph.capture_end()
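
Note: the following usage sketch is not part of the patch; the real caller is cudagraph trees, which supplies its shared memory pool and a dedicated capture stream. The standalone call below is a hypothetical illustration of the helper's shape:

    import torch

    g = torch.cuda.CUDAGraph()
    pool = torch.cuda.graph_pool_handle()
    stream = torch.cuda.Stream()
    static_x = torch.zeros(8, device="cuda")

    # Unlike torch.cuda.graph(), nothing here calls synchronize(),
    # gc.collect() or empty_cache(), so it stays usable while
    # captures_underway is already non-empty.
    with torch.cuda.graphs._graph_no_gc(g, pool, stream, "thread_local"):
        static_x.add_(1)

    g.replay()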


@overload
def make_graphed_callables(
    callables: _ModuleOrCallable,
@ -610,3 +635,46 @@ def make_graphed_callables(
        return ret[0]

    return tuple(ret)


@contextlib.contextmanager
def thread_cuda_stream_capture_mode(new_mode):
    r"""Changes current thread's stream capture mode to `new_mode` upon __enter__ and resets the mode upon __exit__.

    The only documentation on a thread's stream capture mode is here:
    https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85

    However, it is a little bit inadequate, so here is a more in-depth description.

    Both CPU threads and capturing cuda streams have a capture mode. A
    cuda stream's capture mode is set at cudaStreamBeginCapture() and
    can never be changed. Meanwhile all CPU threads start with a capture
    mode of cudaStreamCaptureModeGlobal, which can be changed at any
    time.

    Whenever a thread executes an unsafe CUDA action while CUDA
    streams are capturing, it follows the following logic to determine
    whether to invalidate those streams:

    if capture_mode_this_thread == cudaStreamCaptureModeRelaxed:
        never invalidate any capturing cuda streams whatsoever.
    elif capture_mode_this_thread == cudaStreamCaptureModeThreadLocal:
        invalidate any cuda streams for which cudaStreamBeginCapture() was called by this
        thread, except for streams whose capture mode is cudaStreamCaptureModeRelaxed.
    elif capture_mode_this_thread == cudaStreamCaptureModeGlobal:
        invalidate all cuda streams that are currently capturing on any thread,
        except for streams whose capture mode is cudaStreamCaptureModeRelaxed and for
        streams for which cudaStreamBeginCapture() was called with
        cudaStreamCaptureModeThreadLocal on a thread other than this one.

    In practice, changing the current capture mode to
    cudaStreamCaptureModeRelaxed in particular is helpful for enabling
    developers to do "unsafe" things that we know are safe in our
    case.
    """
    cudart = torch.cuda.cudart()
    old_mode = cudart.cudaThreadExchangeStreamCaptureMode(new_mode)
    try:
        yield
    finally:
        cudart.cudaThreadExchangeStreamCaptureMode(old_mode)
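
Note: the following sketch is not part of the patch. It shows the kind of direct use the docstring describes, and mirrors what ControlFlowOpWarmupDispatchMode (earlier in this commit) does around every op so that lazy library initialization (cuBLAS/cuDNN/cuFFT handle creation, the situations exercised by test_control_flow_cuda_initialization.py) cannot invalidate an ongoing capture. It assumes the guarded call is known to be capture-safe:

    import torch

    g = torch.cuda.CUDAGraph()
    x = torch.randn(16, 16, device="cuda")

    with torch.cuda.graph(g):
        # Run this thread in Relaxed mode so a normally "unsafe" call
        # (e.g. one that lazily creates a cuBLAS handle) does not
        # invalidate the capture; the previous mode is restored on exit.
        with torch.cuda.graphs.thread_cuda_stream_capture_mode(
            torch.cuda.cudart().cudaStreamCaptureMode.Relaxed
        ):
            y = x @ x
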
@ -1567,6 +1567,12 @@ TEST_CUDA_GRAPH = TEST_CUDA and (not TEST_SKIP_CUDAGRAPH) and (
)

TEST_CUDA_CUDSS = TEST_CUDA and (torch.version.cuda and int(torch.version.cuda.split(".")[0]) >= 12)
TEST_CUDA_GRAPH_CONDITIONAL_NODES = TEST_CUDA_GRAPH and (
    torch.version.cuda and (
        (int(torch.version.cuda.split(".")[0]) >= 12 and int(torch.version.cuda.split(".")[1]) >= 4) or
        (int(torch.version.cuda.split(".")[0]) >= 13)
    )
)

TEST_CUDA_PYTHON_BINDINGS = _check_module_exists("cuda.bindings") and (
    torch.version.cuda and int(torch.version.cuda.split(".")[0]) >= 12