mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Revert "Implement cuda graphs implementation of torch.cond and torch.while_loop (#140979)"
This reverts commit c7515da7b00de40942c83dc5856b6daec727e280. Reverted https://github.com/pytorch/pytorch/pull/140979 on behalf of https://github.com/huydhn due to This change has been reported to break internal code ([comment](https://github.com/pytorch/pytorch/pull/140979#issuecomment-2657361940))
@ -347,7 +347,7 @@ c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
*/
void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
at::cuda::assertNotCapturing(
"Please ensure to utilize the CUDAGeneratorImpl::graphsafe_set_state method during capturing.");
"Please ensure to utilize the CUDAGeneratorImpl::set_state_index method during capturing.");
static const size_t seed_size = sizeof(uint64_t);
static const size_t offset_size = sizeof(int64_t);
static const size_t total_size = seed_size + offset_size;
@ -11,50 +11,9 @@

namespace at::cuda {

void external_stream_deleter(cudaStream_t* stream) {
if (stream != nullptr) {
cudaStreamDestroy(*stream);
delete stream;
}
}

namespace {
UniquePtrExternalCudaStream create_external_stream() {
// From:
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g793d7d4e474388ddfda531603dc34aa3
// "Capture must be ended on the same stream in which it was initiated, and it
// may only be initiated if the stream is not already in capture mode."

// Since pytorch uses a pool of 32 pre-allocated cuda streams,
// should a user nest 32 conditional nodes, there would be an error
// for the 32nd node, since that node's stream would already be in
// capture mode. The easiest solution is to handle stream creation
// and deletion ourselves.

// we use cudaStreamNonBlocking because every default cuda stream in
// pytorch uses that flag for all streams used for stream capture
// (see kDefaultFlags in CUDAStream.cpp). This would need to be kept
// in sync, should that ever change. Or kDefaultFlags needs to be
// exposed in a header file.
auto stream_ptr = std::make_unique<cudaStream_t>();
AT_CUDA_CHECK(
cudaStreamCreateWithFlags(stream_ptr.get(), cudaStreamNonBlocking));
return UniquePtrExternalCudaStream(
stream_ptr.release(), external_stream_deleter);
}
} // anonymous namespace

static bool _cuda_graphs_debug = false;
constexpr int kSynchronizeBusyWaitMillis = 10;

// To support stream capture across multiple threads, we use a global
// hashmap mapping cuda stream capture IDs to CUDAGraph objects. This
// was originally a thread_local std::stack<CUDAGraph*>, but that was
// not acceptable since stream capture does span threads in certain
// circumstances (in particular, during autograd).
static std::mutex _currently_capturing_graphs_mutex;
static ska::flat_hash_map<CaptureId_t, CUDAGraph*> _currently_capturing_graphs;

MempoolId_t graph_pool_handle() {
// Sets just the second value, to distinguish it from MempoolId_ts created from
// cudaStreamGetCaptureInfo id_s in capture_begin.
@ -118,13 +77,11 @@ void CUDAGraph::register_generator_state(const at::Generator& generator) {
cuda_gen->register_graph(this);
}

void CUDAGraph::capture_begin(MempoolId_t pool/*={0,0}*/, cudaStreamCaptureMode capture_mode) {
void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/, cudaStreamCaptureMode capture_mode) {
TORCH_CHECK(!has_graph_exec_,
"This CUDAGraph instance already owns a captured graph. "
"To capture a new graph, create a new instance.");

capture_mode_ = capture_mode;

// default generator is always registered
auto* gen = get_generator_or_default<CUDAGeneratorImpl>(
std::nullopt, cuda::detail::getDefaultCUDAGenerator());
@ -162,7 +119,12 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*={0,0}*/, cudaStreamCaptureMode
// Addendum: beginAllocateStreamToPool is now called before cudaStreamBeginCapture to prevent an
// autograd thread's free() call triggering an invalid cudaEventRecord in the caching allocator
// due to the capture status being updated _after_ a capture had already started.
c10::cuda::CUDACachingAllocator::beginAllocateToPool(capture_dev_, mempool_id_, create_allocate_filter());
c10::cuda::CUDACachingAllocator::beginAllocateToPool(capture_dev_, mempool_id_, [this](cudaStream_t stream) {
cudaStreamCaptureStatus status{};
CaptureId_t stream_capture_id = 0;
AT_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &stream_capture_id));
return status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive && stream_capture_id == capture_id_;
});

// At this point, any NCCL watchdogs should be aware that we are in capture mode
// and therefore should not enqueue any additional work that could be event-queried.
@ -182,10 +144,6 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*={0,0}*/, cudaStreamCaptureMode
AT_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &capture_id_));
TORCH_INTERNAL_ASSERT(status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive);

{
std::unique_lock<std::mutex> lock(_currently_capturing_graphs_mutex);
_currently_capturing_graphs.emplace(capture_id_, this);
}
}

void CUDAGraph::capture_end() {
@ -196,14 +154,6 @@ void CUDAGraph::capture_end() {

AT_CUDA_CHECK(cudaStreamEndCapture(capture_stream_, &graph_));

{
std::unique_lock<std::mutex> lock(_currently_capturing_graphs_mutex);
TORCH_CHECK(
_currently_capturing_graphs.count(capture_id_),
"capture_end() called before capture_begin().");
_currently_capturing_graphs.erase(capture_id_);
}

c10::cuda::CUDACachingAllocator::endAllocateToPool(capture_dev_, mempool_id_);

TORCH_CHECK(graph_ != nullptr, "Invalid capture.");
@ -250,19 +200,6 @@ void CUDAGraph::capture_end() {
wholegraph_increments = generator_state->capture_epilogue();
}

#if !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
bool any_wholegraph_increments_nonzero = false;
for (auto& [generator_state, wholegraph_increments] : captured_generator_states_) {
if (wholegraph_increments != 0) {
any_wholegraph_increments_nonzero = true;
}
}

if (any_wholegraph_increments_nonzero && !descendent_graphs_.empty()) {
TORCH_WARN("You used random numbers in a cuda graph that uses conditional nodes. The previous design assumed that all RNG operations would execute only once, unconditionally, but this is no longer guaranteed with data-dependent control flow. Running with the cuda graph repeatedly may not match running without the cuda graph.");
}
#endif // !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12040

size_t numCUDAGraphNodes = 0;
AT_CUDA_CHECK(cudaGraphGetNodes(graph_, nullptr, &numCUDAGraphNodes));
if (numCUDAGraphNodes == 0) {
@ -363,7 +300,7 @@ void CUDAGraph::reset() {

// Returns an id another graph's capture_begin can use to share the same memory pool as this graph.
MempoolId_t CUDAGraph::pool() {
TORCH_CHECK(has_graph_exec_,
TORCH_CHECK(has_graph_exec_,
"Called CUDAGraph::pool() without a preceding successful capture.");
return mempool_id_;
}
@ -388,250 +325,4 @@ CUDAGraph::~CUDAGraph() {
#endif
}

CUDAGraph* CUDAGraph::get_currently_capturing_graph() {
std::unique_lock<std::mutex> lock(_currently_capturing_graphs_mutex);
cudaStreamCaptureStatus status{};
CaptureId_t current_capture_id = -1;
auto stream = at::cuda::getCurrentCUDAStream();
AT_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &current_capture_id));
TORCH_CHECK(
status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive,
"The current stream is not currently capturing.");
TORCH_CHECK(
_currently_capturing_graphs.count(current_capture_id),
"get_currently_capturing_graph() can be used only between capture_begin() and capture_end(). Did you use a stream without making it depend upon the original stream used for capture?");
return _currently_capturing_graphs.at(current_capture_id);
}

void CUDAGraph::begin_capture_to_if_node(
const at::Tensor& scalar_cuda_pred_tensor) {
#if !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
TORCH_CHECK(
!has_graph_exec_,
"begin_capture_to_if_node() must be called before capture_begin()");

cudaStreamCaptureStatus status{};
cudaGraph_t currently_capturing_graph{};
AT_CUDA_CHECK(cudaStreamGetCaptureInfo(
getCurrentCUDAStream(), &status, nullptr, &currently_capturing_graph));
TORCH_CHECK(
status == cudaStreamCaptureStatusActive,
"capture_begin() must be called before begin_capture_to_if_node()");
cudaGraphConditionalHandle handle{};
AT_CUDA_CHECK(cudaGraphConditionalHandleCreate(
&handle, currently_capturing_graph, 0, 0));

set_conditional_handle(handle, scalar_cuda_pred_tensor);

const cudaGraphNode_t* dependencies{};
size_t num_dependencies = 0;
AT_CUDA_CHECK(cudaStreamGetCaptureInfo(
getCurrentCUDAStream(),
&status,
nullptr,
&currently_capturing_graph,
&dependencies,
&num_dependencies));
TORCH_CHECK(status == cudaStreamCaptureStatusActive);

cudaGraphNodeParams params{};
params.type = cudaGraphNodeTypeConditional;
params.conditional.handle = handle;
params.conditional.type = cudaGraphCondTypeIf;
params.conditional.size = 1;

cudaGraphNode_t cond_node{};
AT_CUDA_CHECK(cudaGraphAddNode(
&cond_node,
currently_capturing_graph,
dependencies,
num_dependencies,
&params));

cudaGraph_t if_node_child_graph = params.conditional.phGraph_out[0];

AT_CUDA_CHECK(cudaStreamUpdateCaptureDependencies(
getCurrentCUDAStream(), &cond_node, 1, cudaStreamSetCaptureDependencies));

UniquePtrExternalCudaStream child_stream = create_external_stream();
conditional_graph_capture_streams_ids_.push(-1);
c10::cuda::CUDACachingAllocator::endAllocateToPool(capture_dev_, mempool_id_);
c10::cuda::CUDACachingAllocator::beginAllocateToPool(
capture_dev_, mempool_id_, create_child_allocate_filter());
AT_CUDA_CHECK(cudaStreamBeginCaptureToGraph(
*child_stream, if_node_child_graph, nullptr, nullptr, 0, capture_mode_));

AT_CUDA_CHECK(cudaStreamGetCaptureInfo(
*child_stream, &status, &conditional_graph_capture_streams_ids_.top()));
TORCH_INTERNAL_ASSERT(
status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive);

// We need to get the raw_stream here before emplace() to prevent
// std::move(child_stream) from potentially executing before
// *child_stream.
cudaStream_t raw_stream = *child_stream;
conditional_node_streams_.emplace(
getStreamFromExternal(raw_stream, getCurrentCUDAStream().device_index()),
std::move(child_stream));

{
std::unique_lock<std::mutex> lock(_currently_capturing_graphs_mutex);
_currently_capturing_graphs.emplace(
conditional_graph_capture_streams_ids_.top(), this);
}

#else // !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
AT_ERROR(
__func__,
" CUDA Graphs conditional nodes are not supported for cuda version < 12.4");
return;
#endif
}

cudaGraphConditionalHandle CUDAGraph::begin_capture_to_while_loop_node(
const at::Tensor& scalar_cuda_pred_tensor) {
#if !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
cudaStreamCaptureStatus status{};
cudaGraph_t currently_capturing_graph{};
AT_CUDA_CHECK(cudaStreamGetCaptureInfo(
getCurrentCUDAStream(), &status, nullptr, &currently_capturing_graph));
TORCH_CHECK(
status == cudaStreamCaptureStatusActive,
"capture_begin() must be called before begin_capture_to_while_loop_node()");
cudaGraphConditionalHandle handle{};
AT_CUDA_CHECK(cudaGraphConditionalHandleCreate(
&handle, currently_capturing_graph, 0, 0));

set_conditional_handle(handle, scalar_cuda_pred_tensor);

const cudaGraphNode_t* dependencies{};
size_t num_dependencies = 0;
AT_CUDA_CHECK(cudaStreamGetCaptureInfo(
getCurrentCUDAStream(),
&status,
nullptr,
&currently_capturing_graph,
&dependencies,
&num_dependencies));
TORCH_CHECK(status == cudaStreamCaptureStatusActive);

cudaGraphNodeParams params{};
params.type = cudaGraphNodeTypeConditional;
params.conditional.handle = handle;
params.conditional.type = cudaGraphCondTypeWhile;
params.conditional.size = 1;

cudaGraphNode_t cond_node{};
AT_CUDA_CHECK(cudaGraphAddNode(
&cond_node,
currently_capturing_graph,
dependencies,
num_dependencies,
&params));

cudaGraph_t while_node_child_graph = params.conditional.phGraph_out[0];

AT_CUDA_CHECK(cudaStreamUpdateCaptureDependencies(
getCurrentCUDAStream(), &cond_node, 1, cudaStreamSetCaptureDependencies));

UniquePtrExternalCudaStream child_stream = create_external_stream();
conditional_graph_capture_streams_ids_.push(-1);
c10::cuda::CUDACachingAllocator::endAllocateToPool(capture_dev_, mempool_id_);
c10::cuda::CUDACachingAllocator::beginAllocateToPool(
capture_dev_, mempool_id_, create_child_allocate_filter());
AT_CUDA_CHECK(cudaStreamBeginCaptureToGraph(
*child_stream,
while_node_child_graph,
nullptr,
nullptr,
0,
capture_mode_));

AT_CUDA_CHECK(cudaStreamGetCaptureInfo(
*child_stream, &status, &conditional_graph_capture_streams_ids_.top()));
TORCH_INTERNAL_ASSERT(
status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive);

// We need to get the raw_stream here before emplace() to prevent
// std::move(child_stream) from potentially executing before
// *child_stream.
cudaStream_t raw_stream = *child_stream;
conditional_node_streams_.emplace(
getStreamFromExternal(raw_stream, getCurrentCUDAStream().device_index()),
std::move(child_stream));

{
std::unique_lock<std::mutex> lock(_currently_capturing_graphs_mutex);
_currently_capturing_graphs.emplace(
conditional_graph_capture_streams_ids_.top(), this);
}

return handle;
#else // !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
AT_ERROR(
__func__,
" CUDA Graphs conditional nodes are not supported for cuda version < 12.4");
return cudaGraphConditionalHandle{};
#endif
}

void CUDAGraph::end_capture_to_conditional_node() {
#if !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
{
std::unique_lock<std::mutex> lock(_currently_capturing_graphs_mutex);
CaptureId_t capture_id = conditional_graph_capture_streams_ids_.top();
TORCH_CHECK(
_currently_capturing_graphs.count(capture_id),
"capture_end() called before capture_begin().");
_currently_capturing_graphs.erase(capture_id);
}

CUDAStream stream = conditional_node_streams_.top().first.current_stream();
cudaGraph_t graph{};
AT_CUDA_CHECK(cudaStreamEndCapture(stream.stream(), &graph));
descendent_graphs_.push_back(graph);
conditional_node_streams_.pop();
conditional_graph_capture_streams_ids_.pop();

c10::cuda::CUDACachingAllocator::endAllocateToPool(capture_dev_, mempool_id_);
if (conditional_graph_capture_streams_ids_.empty()) {
c10::cuda::CUDACachingAllocator::beginAllocateToPool(
capture_dev_, mempool_id_, create_allocate_filter());
} else {
c10::cuda::CUDACachingAllocator::beginAllocateToPool(
capture_dev_, mempool_id_, create_child_allocate_filter());
}
#else // !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
AT_ERROR(
__func__,
" CUDA Graphs conditional nodes are not supported for cuda version < 12.4");
#endif
}

std::function<bool(cudaStream_t)> CUDAGraph::create_allocate_filter() {
return [this](cudaStream_t stream) {
cudaStreamCaptureStatus status{};
CaptureId_t stream_capture_id = 0;
AT_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &stream_capture_id));
return status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive && stream_capture_id == capture_id_;
};
}

std::function<bool(cudaStream_t)> CUDAGraph::create_child_allocate_filter() {
#if !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
return [&current_capture_id = conditional_graph_capture_streams_ids_.top()](cudaStream_t stream) {
cudaStreamCaptureStatus status{};
CaptureId_t stream_capture_id{};
AT_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &stream_capture_id));
return status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive && stream_capture_id == current_capture_id;
};
#else // !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
AT_ERROR(
__func__,
" CUDA Graphs conditional nodes are not supported for cuda version < 12.4");
return std::function<bool(cudaStream_t)>();
#endif
}

} // namespace at::cuda
@ -1,30 +0,0 @@
#include <ATen/cuda/CUDAGraph.h>
#include <ATen/cuda/Exceptions.h>

namespace at::cuda {

namespace {

#if !(defined(USE_ROCM)) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
__global__ void set_conditional_handle_kernel(
cudaGraphConditionalHandle handle,
const bool* value) {
cudaGraphSetConditional(handle, *value);
}
#endif
}

void CUDAGraph::set_conditional_handle(
cudaGraphConditionalHandle handle,
const Tensor& scalar_cuda_pred_tensor) {
#if !(defined(USE_ROCM)) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
set_conditional_handle_kernel<<<1, 1, 0, getCurrentCUDAStream()>>>(
handle, scalar_cuda_pred_tensor.const_data_ptr<bool>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
#else
AT_ERROR("not allowed");
return;
#endif
}

} // namespace at::cuda
@ -3,20 +3,9 @@
#include <ATen/Tensor.h>
#include <c10/core/Device.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/util/flat_hash_map.h>

#include <limits>
#include <stack>

#if defined(USE_ROCM) || !(defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
// this type is not defined until CUDA 12.4, but we use it as a
// parameter type and return type in some below functions, so we give
// it the same definition as in CUDA 12.4.
typedef unsigned long long cudaGraphConditionalHandle;
#endif // defined(USE_ROCM) || !(defined(CUDA_VERSION) && CUDA_VERSION >= 12040)

namespace at {

struct Generator;
@ -25,9 +14,6 @@ struct CUDAGeneratorState;

namespace cuda {

using UniquePtrExternalCudaStream =
std::unique_ptr<cudaStream_t, void (*)(cudaStream_t*)>;

// Standalone way to get a unique mempool id usable as a pool=... argument
// to CUDAGraph::capture_begin
TORCH_CUDA_CPP_API MempoolId_t graph_pool_handle();
@ -36,26 +22,6 @@ struct TORCH_CUDA_CPP_API CUDAGraph {
CUDAGraph();
~CUDAGraph();

// Copy and move constructors and assignments are disabled. These
// were disabled because pybind11 believed that CUDAGraph was copy
// constructable because
// pybind11::is_copy_constructible<CUDAGraph>::value originally
// evaluated to true. However, it cannot generate a copy constructor
// because CUDAGeneratorState, one of CUDAGraph's members, is an
// incomplete type unless CUDAGeneratorImpl.h is included. However,
// that would create a circular dependency between
// CUDAGeneratorImpl.h and CUDAGraph.h. Disabling the copy and move
// constructors is the most straightforward way to prevent pybind11
// from trying to generate default implementations of them.
//
// We needed pybind11 to return a reference to a CUDAGraph as part
// of wrapping CUDAGraph::get_currently_capturing_graph, which
// unearthed the above problem.
CUDAGraph(const CUDAGraph&) = delete;
CUDAGraph& operator=(const CUDAGraph&) = delete;
CUDAGraph(CUDAGraph&& other) = delete;
CUDAGraph& operator=(CUDAGraph&& other) = delete;

static void inc_pending_event_queries();
static void dec_pending_event_queries();
static int num_pending_event_queries();
@ -72,19 +38,6 @@ struct TORCH_CUDA_CPP_API CUDAGraph {
void enable_debug_mode();
void debug_dump(const std::string& debug_path);

static CUDAGraph* get_currently_capturing_graph();
void begin_capture_to_if_node(const Tensor& scalar_cuda_pred_tensor);
cudaGraphConditionalHandle begin_capture_to_while_loop_node(
const Tensor& scalar_cuda_pred_tensor);
void end_capture_to_conditional_node();
static void set_conditional_handle(
cudaGraphConditionalHandle handle,
const Tensor& scalar_cuda_pred_tensor);

private:
std::function<bool(cudaStream_t)> create_allocate_filter();
std::function<bool(cudaStream_t)> create_child_allocate_filter();

protected:
cudaGraph_t graph_ = nullptr;
cudaGraphExec_t graph_exec_ = nullptr;
@ -101,7 +54,7 @@ struct TORCH_CUDA_CPP_API CUDAGraph {

// the ID assigned by cuda during graph capture,
// used to identify when a stream is participating in capture
CaptureId_t capture_id_ = std::numeric_limits<CaptureId_t>::max();
CaptureId_t capture_id_ = -1;

// uuid used to request a particular private mempool from CUDACachingAllocator.
// By default, this will be set to {id_, 0}.
@ -132,15 +85,6 @@ struct TORCH_CUDA_CPP_API CUDAGraph {
// init capture_dev_ as UNDEFINED_DEVICE to check that it stores the real device id in the destructor
static constexpr c10::DeviceIndex UNDEFINED_DEVICE = -1;
c10::DeviceIndex capture_dev_{UNDEFINED_DEVICE};

cudaStreamCaptureMode capture_mode_{};

#if !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040)
std::stack<std::pair<at::cuda::CUDAStreamGuard, UniquePtrExternalCudaStream>>
conditional_node_streams_;
std::stack<CaptureId_t> conditional_graph_capture_streams_ids_;
std::vector<cudaGraph_t> descendent_graphs_;
#endif // !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12040
};

} // namespace cuda
@ -929,10 +929,6 @@ and you suspect its runtime is at least somewhat CPU-limited.
https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture
.. _cudaGraphLaunch:
https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1g1accfe1da0c605a577c22d9751a09597
.. _issue 144787:
https://github.com/pytorch/pytorch/issues/144787#issuecomment-2606480564
.. _conditional nodes:
https://developer.nvidia.com/blog/dynamic-control-flow-in-cuda-graphs-with-conditional-nodes/

PyTorch API
^^^^^^^^^^^
@ -1021,9 +1017,6 @@ Violating any of these will likely cause a runtime error:
Avoid using :meth:`Generator.get_state<torch.get_state>` and :meth:`Generator.set_state<torch.set_state>` during capture;
instead, utilize :meth:`Generator.graphsafe_set_state<torch.Generator.graphsafe_set_state>` and :meth:`Generator.graphsafe_get_state<torch.Generator.graphsafe_get_state>`
for managing generator states safely within the graph context. This ensures proper RNG operation and generator management within CUDA graphs.
* Dynamic control flow (based on CPU or GPU data) is prohibited, unless it is based on GPU data and implemented via higher order operators
torch.cond() and torch.while_loop(). See :ref:`Data Dependent Control Flow<graph-data-dependent-control-flow>`.


Violating any of these will likely cause silent numerical errors or undefined behavior:
@ -1032,6 +1025,7 @@ Violating any of these will likely cause silent numerical errors or undefined be
* No non-captured CUDA work may run in this process (on any thread) while capture is underway.
* CPU work is not captured. If the captured ops include CPU work, that work will be elided during replay.
* Every replay reads from and writes to the same (virtual) memory addresses.
* Dynamic control flow (based on CPU or GPU data) is prohibited.
* Dynamic shapes are prohibited. The graph assumes every tensor in the captured op sequence
has the same size and layout in every replay.
* Using multiple streams in a capture is allowed, but there are :ref:`restrictions<multistream-capture>`.
@ -1340,45 +1334,3 @@ If, in the live workload, your callables will run in an order that occasionally
or if they'll run concurrently, passing them as a tuple to a single invocation of
:func:`~torch.cuda.make_graphed_callables` is not allowed. Instead, you must call
:func:`~torch.cuda.make_graphed_callables` separately for each one.

.. _graph-data-dependent-control-flow:

Data Dependent Control Flow
^^^^^^^^^^^^^^^^^^^^^^^^^^^

Data-dependent control flow can work with cuda graphs in limited cases if
the control flow is implemented using torch.cond() or
torch.while_loop(). If your function uses these functions, compiling
it with the "cudagraphs" backend will enable control flow in the
resulting cuda graph via `conditional nodes`_.

Unfortunately, eager mode execution does not work, for the reasons
described in `issue 144787`_.
Support for the inductor backend of torch.compile is not available yet, but there is no fundamental blocker.

An example of using the cudagraphs backend of torch.compile on code
using torch.cond is demonstrated below, followed by a similar sketch for torch.while_loop::

    import torch

    def true_fn(x):
        return x.sin()

    def false_fn(x):
        return x.cos()

    x = torch.randn(4, device="cuda", requires_grad=False)
    pred = torch.tensor(False, device="cuda", requires_grad=False)
    def foo(pred, x):
        with torch.inference_mode():
            return torch.cond(pred, true_fn, false_fn, [x])

    # First call will run eager for warmup, second call will do graph
    # capture followed by graph replay, third call and beyond will do
    # just graph replay.
    compiled_foo = torch.compile(foo, backend="cudagraphs")
    for i in range(3):
        y = compiled_foo(pred, x)

    # will output x.sin()
    y = compiled_foo(~pred, x)
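A sketch along the same lines for torch.while_loop; the loop bound, cond_fn, and body_fn below are illustrative, not taken from the PyTorch tests::

    import torch

    def cond_fn(i, x):
        return i < 5          # must evaluate to a boolean scalar cuda tensor

    def body_fn(i, x):
        return i + 1, x.sin() # carried values keep their shapes and dtypes

    i0 = torch.zeros((), dtype=torch.int64, device="cuda")
    x = torch.randn(4, device="cuda", requires_grad=False)

    def bar(i, x):
        with torch.inference_mode():
            return torch.while_loop(cond_fn, body_fn, (i, x))

    # Same pattern as above: eager warmup first, then capture and replay
    # of a cuda graph containing a while-loop conditional node.
    compiled_bar = torch.compile(bar, backend="cudagraphs")
    for _ in range(3):
        final_i, final_x = compiled_bar(i0, x)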
@ -13,7 +13,7 @@ For a longer background on CUDAGraphs, read `accelerating pytorch with CUDAGraph

CUDA Graphs can give large speedups, especially for models with high CPU overhead or small compute. There are a number of limitations, which stem from requiring the same kernels to be run with the same arguments, dependencies, and memory addresses (a minimal usage sketch follows the list below).

- Arbitrary Control Flow is not possible (However, control flow expressed via torch.cond() and torch.while_loop() can be captured in a CUDA Graph. See :ref:`Data Dependent Control Flow<graph-data-dependent-control-flow>`.)
- Control Flow is not possible
- Kernels which trigger host-to-device syncs (such as .item()) will error
- All input arguments to kernels are fixed to what they were when recorded
- CUDA Memory addresses are fixed, however the values of the memory at those addresses can change
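A minimal sketch of how CUDA Graph trees are typically engaged through torch.compile (the toy model and sizes below are illustrative):

```python
import torch

model = torch.nn.Linear(16, 16).cuda()
# mode="reduce-overhead" runs the compiled kernels under CUDA Graph trees;
# backend="cudagraphs" (used in the docs and tests above) is the graphs-only backend.
compiled = torch.compile(model, mode="reduce-overhead")

x = torch.randn(8, 16, device="cuda")
for _ in range(3):  # early calls warm up; later calls replay the captured graph
    y = compiled(x)
```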
@ -1,16 +1,13 @@
# Owner(s): ["module: functorch"]
import contextlib
import copy
import functools
import unittest
import warnings

import torch
import torch.utils._pytree as pytree
from functorch.experimental import control_flow
from functorch.experimental.control_flow import cond, UnsupportedAliasMutationException
from torch._dynamo.testing import normalize_gm
from torch._dynamo.utils import counters
from torch._higher_order_ops.associative_scan import (
_fake_associative_scan,
associative_scan,
@ -36,7 +33,6 @@ from torch.testing._internal.common_utils import (
skipIfCrossRef,
skipIfRocm,
skipIfTorchDynamo,
TEST_CUDA_GRAPH_CONDITIONAL_NODES,
TEST_WITH_CROSSREF,
TEST_WITH_TORCHDYNAMO,
TestCase,
@ -44,43 +40,6 @@ from torch.testing._internal.common_utils import (
)


@contextlib.contextmanager
def check_cudagraphs_not_skipped(test_case):
old_cudagraph_skips = counters["inductor"]["cudagraph_skips"]
try:
yield
finally:
# reset before the assert, because otherwise the reset is skipped
new_cudagraph_skips = counters["inductor"]["cudagraph_skips"]
counters["inductor"]["cudagraph_skips"] = old_cudagraph_skips
test_case.assertEqual(
counters["inductor"]["cudagraph_skips"], new_cudagraph_skips
)


def _check_compile_cudagraph(test_case, fn, args):
# test cudagraphs backend
cudagraphs_compiled_fn = torch.compile(fn, backend="cudagraphs")
# We run 3 times.
# This is what cuda graph trees does for the first 3 runs:
# 1) run in eager mode, for warmup.
# 2) do stream capture followed by graph replay.
# 3 and beyond) do graph replay
# So we need to get to iteration 3 to test all ways of running.
outputs = []
for i in range(3):
with check_cudagraphs_not_skipped(test_case):
outputs.append(
pytree.tree_map(
lambda x: x.clone() if isinstance(x, torch.Tensor) else x,
cudagraphs_compiled_fn(*args),
)
)
eager_res = fn(*args)
for output in outputs:
test_case.assertEqual(eager_res, output)


# TODO: pull these helpers from AOTAutograd later
def to_fun(t):
if isinstance(t, torch.Tensor):
@ -274,9 +233,7 @@ def _while_loop_tests():
return i2.clone(), j2 - 1, x2 + 3.14, y2 - 2.71

i1, j1, x1, y1 = while_loop(
cond_fn_nested,
body_fn_nested,
(i1, j1, x1, y1),
cond_fn_nested, body_fn_nested, [i1, j1, x1, y1]
)
return i1 - 1, j1.clone(), x1 * 2, y1 / 2

@ -440,7 +397,6 @@ def _while_loop_tests():
const_and_symint_output,
(torch.randn(2, 3, requires_grad=True),),
),
# I need to add a test here that uses just a dictionary and see what happens.
}


@ -4386,24 +4342,6 @@ def forward(self, L_ctx_saved_tensors_0_ : torch.Tensor, L_ctx_pred : torch.Tens
torch.randn(2, 3),
)

@unittest.skipIf(
not TEST_CUDA_GRAPH_CONDITIONAL_NODES,
"CUDA 12.4 or greater is required for CUDA Graphs with conditional nodes",
)
def test_cond_traced_not_nested_cudagraphs(self):
def true_fn(x):
return x.sin()

def false_fn(x):
return x.cos()

def f(x, y):
return cond(y, true_fn, false_fn, [x])

x = torch.randn(4)
_check_compile_cudagraph(self, f, [x.cuda(), torch.tensor(True).cuda()])
_check_compile_cudagraph(self, f, [x.cuda(), torch.tensor(False).cuda()])

def test_while_loop_nested_traced(self):
fn, inp = WHILE_LOOP_TESTS["nested"]
graphs = self._check_tracing(fn, inp)
@ -4650,43 +4588,6 @@ def forward(self, arg0_1):
fn, inp = WHILE_LOOP_TESTS[while_loop_test]
self._check_compile(fn, inp, backend=backend)

@parametrize(
"while_loop_test",
set(WHILE_LOOP_TESTS.keys())
- {"const_and_symint_output", "int_carry", "pytree_int_carry"},
)
@unittest.skipIf(
not TEST_CUDA_GRAPH_CONDITIONAL_NODES,
"CUDA 12.4 or greater is required for CUDA Graphs with conditional nodes",
)
def test_while_loop_cuda_stream_capture(self, while_loop_test):
fn, inp = WHILE_LOOP_TESTS[while_loop_test]

if isinstance(fn, torch.nn.Module):
fn = copy.deepcopy(fn)
fn.cuda()
inp = pytree.tree_map(lambda x: x.cuda(), inp)

_check_compile_cudagraph(self, fn, inp)

@unittest.expectedFailure
@parametrize(
"while_loop_test", {"const_and_symint_output", "int_carry", "pytree_int_carry"}
)
@unittest.skipIf(
not TEST_CUDA_GRAPH_CONDITIONAL_NODES,
"CUDA 12.4 or greater is required for CUDA Graphs with conditional nodes",
)
def test_while_loop_cuda_stream_capture_fails(self, while_loop_test):
fn, inp = WHILE_LOOP_TESTS[while_loop_test]

if isinstance(fn, torch.nn.Module):
fn = copy.deepcopy(fn)
fn.cuda()
inp = pytree.tree_map(lambda x: x.cuda(), inp)

_check_compile_cudagraph(self, fn, inp)

@skipIfTorchDynamo("Graph is not captured by backend if test with dynamo")
@skipIfCrossRef # Arg order changes with cross ref
def test_while_loop_simple_with_linear_compile_check_graph(self):
@ -4896,13 +4797,6 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1
f(x, torch.tensor(True), torch.tensor(True)),
)

if TEST_CUDA_GRAPH_CONDITIONAL_NODES:
_check_compile_cudagraph(
self,
f,
[x.cuda(), torch.tensor(True).cuda(), torch.tensor(True).cuda()],
)

def test_cond_functionalized(self):
def true_fn(x):
y = x.sin()
@ -4916,9 +4810,6 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1
pred = x.shape[0] == 1
return cond(pred, true_fn, false_fn, [x])

def f_(x, y):
return cond(y, true_fn, false_fn, [x])

example_inputs = (torch.ones(4, 5),)
functional_f = torch.func.functionalize(f)
self.assertEqual(functional_f(*example_inputs), f(*example_inputs))
@ -4937,10 +4828,6 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1

self.assertEqual(graph_module(*example_inputs), f(*example_inputs))

if TEST_CUDA_GRAPH_CONDITIONAL_NODES:
pred = torch.tensor(example_inputs[0].shape[0] == 1, device="cuda")
_check_compile_cudagraph(self, f_, [torch.ones(4, 5).cuda(), pred])

def test_cond_accepts_torch_function_as_inputs(self):
a = torch.randn(3, 4)
b = torch.randn(3, 4)
@ -5460,13 +5347,6 @@ def forward(self, arg0_1):
return (mul,)""",
)

if TEST_CUDA_GRAPH_CONDITIONAL_NODES:
_check_compile_cudagraph(
self,
f,
[x.cuda(), torch.tensor(False).cuda(), torch.tensor(False).cuda()],
)

def test_raise_error_on_mismatch_type_size(self):
def true_fn(x):
return x.sin()
@ -7783,99 +7663,11 @@ class TestHopSchema(TestCase):
self.assertEqual(schema.parse(str(schema)), schema)


class DynamicCondModel(torch.nn.Module):
def __init__(self, input_size=16, hidden_size=64, output_size=10):
super().__init__()
self.fc1_0 = torch.nn.Linear(input_size, hidden_size)
self.fc1_1 = torch.nn.Linear(input_size, 32)
self.fc1_2 = torch.nn.Linear(32, hidden_size)
self.relu = torch.nn.ReLU()
self.fc2 = torch.nn.Linear(hidden_size, output_size)

def forward(self, x):
def true_fn(x):
return self.fc1_0(x)

def false_fn(x):
x = self.fc1_1(x)
return self.fc1_2(x)

# use PyTorch control flow API
pred = torch.tensor(x.sum() > 0, device="cuda")
x = cond(pred, true_fn, false_fn, [x])

x = self.relu(x)
x = self.fc2(x)

return x


@unittest.skipIf(
not TEST_CUDA_GRAPH_CONDITIONAL_NODES,
"CUDA 12.4 or greater is required for CUDA Graphs with conditional nodes",
)
class TestControlFlowNN(TestCase):
def test_cond_in_NN(self):
model = DynamicCondModel().cuda()

x = torch.randn(16, device="cuda")
_check_compile_cudagraph(self, model, [x])


@unittest.skipIf(
not TEST_CUDA_GRAPH_CONDITIONAL_NODES,
"CUDA 12.4 or greater is required for CUDA Graphs with conditional nodes",
)
class TestControlFlowAndRNG(TestCase):
@parametrize("rng_func", ["custom_generator", "default_generator"])
def test_rng_with_conditional_nodes_warns(self, rng_func):
pred = torch.tensor(True, device="cuda")
x = torch.ones(10, dtype=torch.float32, device="cuda")

if rng_func == "custom_generator":
self.skipTest(
"randn() currently does not work with a generator argument in dynamo."
)
generator = torch.Generator("cuda")

def custom_generator(x):
return x + torch.randn(
*x.shape, generator=generator, dtype=x.dtype, device=x.device
)

rng_func = custom_generator
elif rng_func == "default_generator":

def default_generator(x):
return x + torch.randn(*x.shape, dtype=x.dtype, device=x.device)

rng_func = default_generator

def func(pred, x):
return torch.cond(pred, rng_func, lambda x: 2 * x, [x])

compiled_func = torch.compile(func, backend="cudagraphs")

with warnings.catch_warnings(record=True) as warning_objs:
for i in range(3):
compiled_func(pred, x)

# Warn first for conditional node warmup, warn second for the
# graph capture that we will actually use.
self.assertEqual(len(warning_objs), 2)

warning_message = "You used random numbers in a cuda graph that uses conditional nodes. The previous design assumed that all RNG operations would execute only once, unconditionally, but this is no longer guaranteed with data-dependent control flow. Running with the cuda graph repeatedly may not match running without the cuda graph."  # noqa: B950
for warning in warning_objs:
self.assertTrue(warning_message in str(warning.message))


instantiate_parametrized_tests(TestHopSchema)
instantiate_parametrized_tests(TestControlFlowTraced)

instantiate_parametrized_tests(TestControlFlow)
instantiate_parametrized_tests(AssociativeScanTests)

instantiate_parametrized_tests(TestControlFlowAndRNG)

if __name__ == "__main__":
run_tests()
@ -1,90 +0,0 @@
# Owner(s): ["module: functorch"]
import unittest

import torch
from torch.testing._internal.common_utils import (
run_tests,
TEST_CUDA_GRAPH_CONDITIONAL_NODES,
TestCase,
)


@unittest.skipIf(
not TEST_CUDA_GRAPH_CONDITIONAL_NODES,
"CUDA 12.4 or greater is required for CUDA Graphs with conditional nodes",
)
class TestControlFlowInCUDAGraphInitialization(TestCase):
# Duplicated from test_cuda_primary_ctx.py
CTX_ALREADY_CREATED_ERR_MSG = (
"Tests defined in TestControlFlowInCUDAGraphInitialization must be run in a process "
"where CUDA contexts are never created. Use either run_test.py or add "
"--subprocess to run each test in a different subprocess."
)

def setUp(self):
# Ensure context has not been created beforehand
self.assertFalse(
torch._C._cuda_hasPrimaryContext(0),
TestControlFlowInCUDAGraphInitialization.CTX_ALREADY_CREATED_ERR_MSG,
)

def _check_compile_cudagraphs(self, f, pred, *other_args):
f = torch.compile(f, backend="cudagraphs")

outputs = []

for p in [pred, torch.logical_not(pred)]:
for i in range(3):
outputs.append(f(pred, *other_args).clone())

# We compute the eager output only after running cudagraphs
# backend compiled function, in order to make sure that
# cudagraph trees warms up the conditional part of the code
# properly.
eager_output = f(pred, *other_args)

for output in outputs:
self.assertEqual(output, eager_output)

def test_cond_cudnn(self):
# Tests that cublasCreate() does not break stream capture
def f(pred, x, filters):
return torch.cond(
pred,
lambda y: torch.sum(y),
lambda y: torch.sum(torch.nn.functional.conv1d(y, filters)),
[x],
)

self.assertFalse(torch._C._cuda_hasPrimaryContext(0))

pred = torch.tensor(True, device="cuda")
x = torch.randn(33, 16, 30, device="cuda")
filters = torch.randn(20, 16, 5, device="cuda")

self._check_compile_cudagraphs(f, pred, x, filters)

self.assertTrue(torch._C._cuda_hasPrimaryContext(0))

def test_cond_stft(self):
# Tests that cufft plan creation does not break stream capture
def f(pred, x):
return torch.cond(
pred,
lambda y: torch.sum(y),
lambda y: torch.sum(torch.stft(y, 512, return_complex=False)),
[x],
)

self.assertFalse(torch._C._cuda_hasPrimaryContext(0))

pred = torch.tensor(True, device="cuda")
x = torch.ones(1024 * 1024, device="cuda")

self._check_compile_cudagraphs(f, pred, x)

self.assertTrue(torch._C._cuda_hasPrimaryContext(0))


if __name__ == "__main__":
run_tests()
@ -551,7 +551,6 @@ RUN_PARALLEL_BLOCKLIST = [
# temporarily sets a global config
"test_autograd_fallback",
"inductor/test_compiler_bisector",
"functorch/test_control_flow_cuda_initialization",
] + FSDP_TEST

# Test files that should always be run serially with other test files,
@ -1479,7 +1478,6 @@ CUSTOM_HANDLERS = {
"distributed/rpc/test_tensorpipe_agent": run_test_with_subprocess,
"distributed/rpc/test_share_memory": run_test_with_subprocess,
"distributed/rpc/cuda/test_tensorpipe_agent": run_test_with_subprocess,
"functorch/test_control_flow_cuda_initialization": run_test_with_subprocess,
"doctests": run_doctests,
"test_ci_sanity_check_fail": run_ci_sanity_check,
"test_autoload_enable": test_autoload_enable,
@ -2125,13 +2125,6 @@ class _CUDAGraph:
def pool(self) -> Tuple[_int, _int]: ...
def enable_debug_mode(self) -> None: ...
def debug_dump(self, debug_path: str) -> None: ...
@staticmethod
def get_currently_capturing_graph() -> _CUDAGraph: ...
def begin_capture_to_if_node(self, scalar_cuda_pred_tensor): ...
def begin_capture_to_while_loop_node(self, scalar_cuda_pred_tensor) -> _int: ...
def end_capture_to_conditional_node(self): ...
@staticmethod
def set_conditional_handle(handle, scalar_cuda_pred_tensor): ...

# Defined in torch/csrc/cuda/MemPool.cpp
class _MemPool:
@ -166,7 +166,7 @@ def cudagraphs(dynamo_model, dynamo_inputs):
range(fixed),
device_index=boxed_device_index.value,
is_backward=False,
is_inference=is_inference,
is_inference=False,
stack_traces=get_stack_traces(aot_model),
placeholders=get_placeholder_info(aot_model.graph),
mutated_input_idxs=find_input_mutations(aot_model.graph),
@ -56,7 +56,7 @@ def make_eager_backend_with_torch_function_mode(mode):


def make_eager_backend_with_torch_function_modes(modes):
"""Used to trace HOPs (cond and while) for eager execution, the metadata
"""Used to trace HOPs (cond and while) for eager exectution, the metadata
TF mode mutates vars outside of the scope of the HOP, and we can't have graph breaks
in the HOP, so we need to externally run this mode and not trace it."""
from contextlib import ExitStack
@ -151,10 +151,10 @@ def populate_builtin_to_tensor_fn_map():
most_recent_func = func
return func(*args, **kwargs)

inp0 = torch.ones(1, device="cpu")
inp1 = torch.ones(1, device="cpu")
inp0_int = torch.ones(1, dtype=torch.int32, device="cpu")
inp1_int = torch.ones(1, dtype=torch.int32, device="cpu")
inp0 = torch.ones(1)
inp1 = torch.ones(1)
inp0_int = torch.ones(1, dtype=torch.int32)
inp1_int = torch.ones(1, dtype=torch.int32)
with GetMethodMode():
setups_and_oplists = [
(lambda o: o(inp0), un_ops),
@ -18,11 +18,6 @@ from torch._C._functorch import (
from torch._dispatch.python import suspend_functionalization
from torch._functorch.utils import exposed_in
from torch._guards import detect_fake_mode
from torch._higher_order_ops.cudagraph_conditional_nodes import (
ControlFlowOpWarmupDispatchMode,
CUDAGraphCaptureControlFlowOpDispatchMode,
if_else_node,
)
from torch._higher_order_ops.utils import (
_has_potential_branch_input_alias,
_has_potential_branch_input_mutation,
@ -38,7 +33,6 @@ from torch._higher_order_ops.utils import (
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._subclasses.functional_tensor import disable_functional_mode
from torch.cuda.graphs import _graph_no_gc
from torch.fx.experimental.proxy_tensor import (
_temp_remove_metadata_torch_function_mode,
_temp_remove_pre_dispatch_torch_function_mode,
@ -377,41 +371,6 @@ def cond_op_dense(pred, true_fn, false_fn, operands):
return false_fn(*operands)


# WAR for https://github.com/pytorch/pytorch/issues/140322
@cond_op.py_impl(CUDAGraphCaptureControlFlowOpDispatchMode)
def cond_op_cudagraph(mode, pred, true_fn, false_fn, operands):
assert torch.cuda.is_available() and torch.cuda.is_current_stream_capturing()
# Re-enter this mode because addition torch.cond() and
# torch.while_loop() calls may be nested inside true_fn or
# false_fn
with mode:
return if_else_node(pred, true_fn, false_fn, operands)


# WAR for https://github.com/pytorch/pytorch/issues/140322
@cond_op.py_impl(ControlFlowOpWarmupDispatchMode)
def cond_op_warmup(mode, pred, true_fn, false_fn, operands):
if torch.cuda.is_current_stream_capturing():
# This is a call to torch.cond() nested within either
# torch.while_loop() or another torch.cond() function.
with mode:
return if_else_node(pred, true_fn, false_fn, operands)
else:
with _graph_no_gc(
torch.cuda.CUDAGraph(),
pool=None,
stream=mode.capture_stream,
capture_error_mode="relaxed",
), mode:
if_else_node(pred, true_fn, false_fn, operands)
# Since ControlFlowOpWarmupDispatchMode has been popped, this call
# will fall back to cond_op_dense
return cond_op_dense(pred, true_fn, false_fn, operands)


# return torch.cond(pred, true_fn, false_fn, operands)


class CondAutogradOp(torch.autograd.Function):
@staticmethod
def forward(
@ -1,132 +0,0 @@
# mypy: allow-untyped-defs
from contextlib import contextmanager
from typing import Any, Generator

import torch
import torch.utils._pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode


class CUDAGraphCaptureControlFlowOpDispatchMode(TorchDispatchMode):
def __init__(
self,
) -> None:
super().__init__()

def __torch_dispatch__(
self,
func,
types,
args=(),
kwargs=None,
):
kwargs = {} if kwargs is None else kwargs
return func(*args, **kwargs)


class ControlFlowOpWarmupDispatchMode(TorchDispatchMode):
def __init__(
self,
) -> None:
super().__init__()
self.capture_stream = torch.cuda.Stream()

def __torch_dispatch__(
self,
func,
types,
args=(),
kwargs=None,
):
kwargs = {} if kwargs is None else kwargs
with torch.cuda.graphs.thread_cuda_stream_capture_mode(
torch.cuda.cudart().cudaStreamCaptureMode.Relaxed
):
return func(*args, **kwargs)


def _is_boolean_scalar_cuda_tensor(pred: Any) -> bool:
return (
isinstance(pred, torch.Tensor)
and pred.size() == torch.Size([])
and pred.dtype == torch.bool
and pred.is_cuda
)


@contextmanager
def _if_body(pred: torch.Tensor) -> Generator[None, None, None]:
current_cuda_graph = torch.cuda.CUDAGraph.get_currently_capturing_graph()
current_cuda_graph.begin_capture_to_if_node(pred)
try:
yield
finally:
current_cuda_graph.end_capture_to_conditional_node()


def if_else_node(pred: torch.Tensor, true_fn, false_fn, operands):
if not pred.is_cuda:
raise ValueError(
"Conditions must be on a cuda device to use conditional node in cuda graphs"
)
# if-else is not supported yet in CUDA 12.4. Therefore, we use two if conditions, where one evaluates !pred
outs = []

for lazy_pred, fn in [
(lambda: pred, true_fn),
(lambda: torch.logical_not(pred), false_fn),
]:
with _if_body(lazy_pred()):
outs.append(fn(*operands))
# Copy these two outputs into a new output buffer. Well,
# actually, what we would like is to be able to merge these two
# tensors into the same tensor... Is there an obvious way to do
# that?
if len(outs) == 2:
for if_out, else_out in zip(
pytree.tree_iter(outs[0]), pytree.tree_iter(outs[1])
):
if_out.copy_(else_out)
assert len(outs) == 2
return outs[0]


@contextmanager
def _while_loop_body(pred: torch.Tensor) -> Generator[int, None, None]:
current_cuda_graph = torch.cuda.CUDAGraph.get_currently_capturing_graph()
conditional_handle = current_cuda_graph.begin_capture_to_while_loop_node(pred)
try:
yield conditional_handle
finally:
current_cuda_graph.end_capture_to_conditional_node()


def while_loop_node(cond_fn, body_fn, carried_inputs, additional_inputs):
carried_vals = carried_inputs
pred = cond_fn(*carried_vals, *additional_inputs)
if not _is_boolean_scalar_cuda_tensor(pred):
raise RuntimeError(
f"cond_fn must return a boolean scalar cuda tensor but got {pred}"
)

with _while_loop_body(pred) as conditional_handle:
out = body_fn(*carried_vals, *additional_inputs)
assert len(out) == len(
carried_inputs
), "body_fn should return the same number of elements as carried_inputs"

# out is not being flattened, for whatever reason.
for c, o in zip(carried_vals, out):
# TODO: Consider skipping the copy_ if the data_ptr is the
# same.
c.copy_(o)

# call the cond_fn again to update the pred.
pred = cond_fn(*carried_vals, *additional_inputs)
if not _is_boolean_scalar_cuda_tensor(pred):
raise RuntimeError(
f"cond_fn must return a boolean scalar tensor but got {pred}"
)
torch.cuda.CUDAGraph.set_conditional_handle(conditional_handle, pred)

return carried_vals
@ -5,11 +5,6 @@ from typing import Callable, Union
import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._higher_order_ops.cudagraph_conditional_nodes import (
ControlFlowOpWarmupDispatchMode,
CUDAGraphCaptureControlFlowOpDispatchMode,
while_loop_node,
)
from torch._higher_order_ops.utils import (
_has_potential_branch_input_alias,
_has_potential_branch_input_mutation,
@ -23,7 +18,6 @@ from torch._higher_order_ops.utils import (
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.cuda.graphs import _graph_no_gc
from torch.fx.experimental.proxy_tensor import (
_temp_remove_metadata_torch_function_mode,
ProxyTorchDispatchMode,
@ -222,37 +216,6 @@ def while_loop_dense(cond_fn, body_fn, carried_inputs, additional_inputs):
return carried_vals


# WAR for https://github.com/pytorch/pytorch/issues/140322
@while_loop_op.py_impl(CUDAGraphCaptureControlFlowOpDispatchMode)
def while_loop_cudagraph(mode, cond_fn, body_fn, carried_inputs, additional_inputs):
assert torch.cuda.is_available() and torch.cuda.is_current_stream_capturing()
# Re-enter this mode because addition torch.cond() and
# torch.while_loop() calls may be nested inside cond_fn or body_fn
with mode:
return while_loop_node(cond_fn, body_fn, carried_inputs, additional_inputs)


# WAR for https://github.com/pytorch/pytorch/issues/140322
@while_loop_op.py_impl(ControlFlowOpWarmupDispatchMode)
def while_loop_warmup(mode, cond_fn, body_fn, carried_inputs, additional_inputs):
if torch.cuda.is_current_stream_capturing():
# This is a call to torch.while_loop() nested within either
# torch.while_loop() or another torch.cond() function.
with mode:
return while_loop_node(cond_fn, body_fn, carried_inputs, additional_inputs)
else:
with _graph_no_gc(
torch.cuda.CUDAGraph(),
pool=None,
stream=mode.capture_stream,
capture_error_mode="relaxed",
), mode:
while_loop_node(cond_fn, body_fn, carried_inputs, additional_inputs)
# Since ControlFlowOpWarmupDispatchMode has been popped, this call
# will fall back to while_loop_dense
return while_loop_dense(cond_fn, body_fn, carried_inputs, additional_inputs)


while_loop_op.py_impl(DispatchKey.Autograd)(
autograd_not_implemented(while_loop_op, deferred_error=True)
)
@ -64,10 +64,6 @@ import torch.fx
from torch import Tensor
from torch._dynamo.mutation_guard import GenerationTracker
from torch._dynamo.utils import counters, dynamo_timed, preserve_rng_state
from torch._higher_order_ops.cudagraph_conditional_nodes import (
ControlFlowOpWarmupDispatchMode,
CUDAGraphCaptureControlFlowOpDispatchMode,
)
from torch._inductor.compile_fx import (
align_inputs_from_check_idxs,
copy_misaligned_inputs,
@ -638,7 +634,7 @@ class CUDAWarmupNode:
self.device_index
), disable_conv_cache_emptying(), clear_cublas_manager(), _use_cuda_memory_pool_manager(
self.device_index, self.cuda_graphs_pool, self.stream
), ControlFlowOpWarmupDispatchMode(), get_history_recording():
), get_history_recording():
out = self.wrapped_function.model(new_inputs)

# We need to know which outputs are allocated within the cudagraph pool
@ -1213,7 +1209,7 @@ class CUDAGraphNode:
stream=self.stream,
pool=self.cuda_graphs_pool,
capture_error_mode="thread_local",
), CUDAGraphCaptureControlFlowOpDispatchMode(), get_history_recording():
), get_history_recording():
static_outputs = model(inputs)

# running model should reclaim memory
@ -32,7 +32,6 @@ from torch._subclasses.meta_utils import (
    MetaConverter,
)
from torch._utils import render_call
from torch.cuda.graphs import thread_cuda_stream_capture_mode
from torch.fx.immutable_collections import immutable_dict
from torch.fx.operator_schemas import normalize_function
from torch.multiprocessing.reductions import StorageWeakRef
@ -2653,31 +2652,10 @@ def run_fallback_kernel(
                return out
            return e

        has_cuda_tensor = any(
            isinstance(a, FakeTensor) and a.fake_device.type == "cuda"
            for a in flat_args
        )

        flat_args = [to_real_tensor(a) for a in flat_args]
        args, kwargs = pytree.tree_unflatten(flat_args, args_spec)

        # If one of the inputs is a CUDA tensor, it is possible that
        # running the fallback kernel will perform an unsafe action.
        # Unfortunately, there are scenarios where pytorch can have a
        # stream capture underway on the current stream while fake
        # tensors are in use (in particular, for shape inference in
        # higher order operators). We need to prevent stream capture
        # from breaking in this case. This is basically always safe
        # because the unsafe actions tend to be lazy initialization of
        # things like CUFFT plans, which won't be destroyed.
        maybe_relaxed: typing.ContextManager = contextlib.nullcontext()
        if has_cuda_tensor:
            cudart = torch.cuda.cudart()
            maybe_relaxed = thread_cuda_stream_capture_mode(
                cudart.cudaStreamCaptureMode.Relaxed
            )
        with maybe_relaxed:
            r = func(*args, **kwargs)
        r = func(*args, **kwargs)

        storages: set[_StoragePointer] = set()

@ -87,30 +87,5 @@ void THCPGraph_init(PyObject* module) {
          "debug_dump",
          torch::wrap_pybind_function_no_gil(
              &::at::cuda::CUDAGraph::debug_dump),
          py::arg("debug_path"))
      .def_static(
          "get_currently_capturing_graph",
          torch::wrap_pybind_function_no_gil(
              &::at::cuda::CUDAGraph::get_currently_capturing_graph),
          py::return_value_policy::reference)
      .def(
          "begin_capture_to_if_node",
          torch::wrap_pybind_function_no_gil(
              &::at::cuda::CUDAGraph::begin_capture_to_if_node),
          py::arg("scalar_cuda_pred_tensor"))
      .def(
          "begin_capture_to_while_loop_node",
          torch::wrap_pybind_function_no_gil(
              &::at::cuda::CUDAGraph::begin_capture_to_while_loop_node),
          py::arg("scalar_cuda_pred_tensor"))
      .def(
          "end_capture_to_conditional_node",
          torch::wrap_pybind_function_no_gil(
              &::at::cuda::CUDAGraph::end_capture_to_conditional_node))
      .def_static(
          "set_conditional_handle",
          torch::wrap_pybind_function_no_gil(
              &::at::cuda::CUDAGraph::set_conditional_handle),
          py::arg("handle"),
          py::arg("scalar_cuda_pred_tensor"));
          py::arg("debug_path"));
}
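The Python surface of the bindings removed above can be reconstructed from the registrations in this hunk and from the set_conditional_handle call earlier in this diff. The following is a hedged sketch of the pre-revert low-level flow, not a documented API, and it assumes begin_capture_to_while_loop_node returns the handle that is later updated.

# Only meaningful while a stream capture is already underway.
graph = torch.cuda.CUDAGraph.get_currently_capturing_graph()

# Split the capture into a WHILE conditional node guarded by `pred`,
# a boolean scalar CUDA tensor.
handle = graph.begin_capture_to_while_loop_node(pred)

# ... capture one loop-body iteration here ...

# Refresh the handle from a freshly computed predicate, then close the node.
torch.cuda.CUDAGraph.set_conditional_handle(handle, pred)
graph.end_capture_to_conditional_node()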
@ -92,20 +92,6 @@ void initCudartBindings(PyObject* module) {
        // NOLINTNEXTLINE(performance-no-int-to-ptr)
        return C10_CUDA_ERROR_HANDLED(cudaStreamCreate((cudaStream_t*)ptr));
      });
  cudart.attr(
      "cuda"
      "StreamDefault") = cudaStreamDefault;
  cudart.attr(
      "cuda"
      "StreamNonBlocking") = cudaStreamNonBlocking;
  cudart.def(
      "cuda"
      "StreamCreateWithFlags",
      [](uintptr_t ptr, unsigned int flags) -> cudaError_t {
        // NOLINTNEXTLINE(performance-no-int-to-ptr)
        return C10_CUDA_ERROR_HANDLED(
            cudaStreamCreateWithFlags((cudaStream_t*)ptr, flags));
      });
  cudart.def(
      "cuda"
      "StreamDestroy",
@ -134,22 +120,6 @@ void initCudartBindings(PyObject* module) {
    C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
    return {device_free, device_total};
  });

  py::enum_<cudaStreamCaptureMode>(
      cudart,
      "cuda"
      "StreamCaptureMode")
      .value("Global", cudaStreamCaptureModeGlobal)
      .value("ThreadLocal", cudaStreamCaptureModeThreadLocal)
      .value("Relaxed", cudaStreamCaptureModeRelaxed);

  cudart.def(
      "cuda"
      "ThreadExchangeStreamCaptureMode",
      [](cudaStreamCaptureMode mode) -> cudaStreamCaptureMode {
        C10_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
        return mode;
      });
}

} // namespace torch::cuda::shared

@ -1,5 +1,4 @@
# mypy: allow-untyped-defs
import contextlib
import gc
import typing

@ -133,11 +132,6 @@ class graph:
            may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
            actions in the current thread, and "relaxed" will not error on actions. Do NOT change this setting
            unless you're familiar with `cudaStreamCaptureMode <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85>`_
        collect_garbage (bool, optional): If True, call torch.cuda.synchronize() followed by gc.collect() to free
            memory before starting graph capture. Users almost always want this to be True, but since the introduction of
            conditional nodes in cuda graphs, it is possible that more than one stream may be capturing at once.
            Since cudaDeviceSynchronize() synchronizes all streams, including capturing streams, previously started
            stream captures will be invalidated. This is not desirable.

    .. note::
        For effective memory sharing, if you pass a ``pool`` used by a previous capture and the previous capture
@ -153,7 +147,11 @@ class graph:
    default_capture_stream: typing.Optional["torch.cuda.Stream"] = None

    def __init__(
        self, cuda_graph, pool=None, stream=None, capture_error_mode: str = "global"
        self,
        cuda_graph,
        pool=None,
        stream=None,
        capture_error_mode: str = "global",
    ):
        # Lazy-init of default_capture_stream helps avoid circular-import errors.
        # Not thread safe, but graphs already have the general (explicitly documented)
@ -171,7 +169,7 @@ class graph:
        self.capture_error_mode = capture_error_mode

    def __enter__(self):
        # Free as much memory as we can for the graph.
        # Free as much memory as we can for the graph
        torch.cuda.synchronize()
        gc.collect()
        torch.cuda.empty_cache()
@ -190,30 +188,6 @@ class graph:
        # returning None should propagate exceptions from either capture_end or stream_ctx.__exit__()


@contextlib.contextmanager
def _graph_no_gc(cuda_graph, pool, stream, capture_error_mode):
    """This is an internal function used to do stream capture without
    calling torch.cuda.synchronize(), gc.collect(), and
    torch.cuda.empty_cache(). Unfortunately, cudagraph trees runs its
    eager warmup inside of the context manager
    _use_cuda_memory_pool_manager(), which makes captures_underway in
    CUDACachingAllocator.cpp non-empty. We need this in order to
    warm up conditional higher order operators, like torch.cond() and
    torch.while_loop(). torch.cuda.empty_cache() will fail if
    captures_underway is non-empty. Removing torch.cuda.synchronize()
    and gc.collect() is not strictly speaking required, but they are
    expensive and unnecessary operations.
    """
    stream_ctx = torch.cuda.stream(stream)
    pool = () if pool is None else (pool,)
    with stream_ctx:
        cuda_graph.capture_begin(*pool, capture_error_mode=capture_error_mode)
        try:
            yield
        finally:
            cuda_graph.capture_end()


def make_graphed_callables(
    callables, sample_args, num_warmup_iters=3, allow_unused_input=False, pool=None
):
@ -514,46 +488,3 @@ def make_graphed_callables(
        return ret[0]

    return tuple(ret)


@contextlib.contextmanager
def thread_cuda_stream_capture_mode(new_mode):
    r"""Changes the current thread's stream capture mode to `new_mode` upon __enter__ and resets the mode upon __exit__.

    The only documentation on a thread's stream capture mode is here:
    https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85

    However, it is a little bit inadequate, so here is a more in-depth description.

    Both CPU threads and capturing cuda streams have a capture mode. A
    cuda stream's capture mode is set at cudaStreamBeginCapture() and
    can never be changed. Meanwhile, all CPU threads start with a capture
    mode of cudaStreamCaptureModeGlobal, which can be changed at any
    time.

    Whenever a thread executes an unsafe CUDA action while CUDA
    streams are capturing, it uses the following logic to determine
    whether to invalidate those streams:

    if capture_mode_this_thread == cudaStreamCaptureModeRelaxed:
        never invalidate any capturing cuda streams whatsoever.
    elif capture_mode_this_thread == cudaStreamCaptureModeThreadLocal:
        invalidate any cuda streams for which cudaStreamBeginCapture() was called by this
        thread, except for streams whose capture mode is cudaStreamCaptureModeRelaxed.
    elif capture_mode_this_thread == cudaStreamCaptureModeGlobal:
        invalidate all cuda streams that are currently capturing on any thread,
        except for streams whose capture mode is cudaStreamCaptureModeRelaxed and for
        streams for which cudaStreamBeginCapture() was called with
        cudaStreamCaptureModeThreadLocal on a thread other than this one.

    In practice, changing the current thread's capture mode to
    cudaStreamCaptureModeRelaxed is particularly helpful for letting
    developers do "unsafe" things that we know are safe in our case.
    """
    cudart = torch.cuda.cudart()
    old_mode = cudart.cudaThreadExchangeStreamCaptureMode(new_mode)
    try:
        yield
    finally:
        cudart.cudaThreadExchangeStreamCaptureMode(old_mode)
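A minimal usage sketch of the helper above as it existed before this revert; the Relaxed value comes from the cudaStreamCaptureMode enum binding that is also removed in this patch, and run_known_safe_cuda_call is a hypothetical placeholder.

cudart = torch.cuda.cudart()
# Temporarily mark this thread's CUDA work as capture-safe so a known-benign
# call (e.g. lazy cuFFT plan creation) does not invalidate an ongoing capture.
with thread_cuda_stream_capture_mode(cudart.cudaStreamCaptureMode.Relaxed):
    run_known_safe_cuda_call()  # placeholder for the "unsafe but actually safe" action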
@ -1507,12 +1507,6 @@ TEST_CUDA_GRAPH = TEST_CUDA and (not TEST_SKIP_CUDAGRAPH) and (
)

TEST_CUDA_CUDSS = TEST_CUDA and (torch.version.cuda and int(torch.version.cuda.split(".")[0]) >= 12)
TEST_CUDA_GRAPH_CONDITIONAL_NODES = TEST_CUDA_GRAPH and (
    torch.version.cuda and (
        (int(torch.version.cuda.split(".")[0]) >= 12 and int(torch.version.cuda.split(".")[1]) >= 4) or
        (int(torch.version.cuda.split(".")[0]) >= 13)
    )
)

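For reference, a test gated on this flag would be skipped in the usual common_utils style. A hedged sketch follows: the skipIf pattern is standard, while the flag itself only existed before this revert and the test body is purely illustrative.

import unittest

from torch.testing._internal.common_utils import (
    TEST_CUDA_GRAPH_CONDITIONAL_NODES,
    TestCase,
    run_tests,
)

class TestConditionalNodes(TestCase):
    @unittest.skipIf(
        not TEST_CUDA_GRAPH_CONDITIONAL_NODES,
        "CUDA graph conditional nodes require CUDA >= 12.4",
    )
    def test_while_loop_capture(self):
        ...  # exercise torch.while_loop under torch.cuda.graph capture

if __name__ == "__main__":
    run_tests()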
def allocator_option_enabled_fn(allocator_config, _, option):
    if allocator_config is None: