Graph-Safe RNG State Exchange for Tensor Parallelism (#114068)

See #113541

The PR allows multiple RNG generator states to be registered with a CUDA graph and exchanged in a cudagraph-safe way, and includes both C++ and Python API changes to support this functionality.
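In outline, the new flow is: take graph-safe handles to one or more generator states, register them with a CUDAGraph before capture, and switch between them with graphsafe_set_state inside the captured region. Below is a minimal sketch distilled from the test added in this PR (test_graphsafe_set_get_rng_state); the tensor shapes and the use of the default generator are illustrative:

import torch

gen = torch.cuda.default_generators[0]
torch.rand(1, device="cuda")            # ensure the CUDA generator is initialized

main_state = gen.graphsafe_get_state()  # handle to the generator's live state
side_state = gen.clone_state()          # independent cloned state

g = torch.cuda.CUDAGraph()
g.register_generator_state(main_state)  # registration must precede capture
g.register_generator_state(side_state)

s = torch.cuda.Stream()
with torch.cuda.stream(s):
    g.capture_begin()
    gen.graphsafe_set_state(side_state)  # graph-safe state switch inside capture
    x = torch.rand(5, device="cuda", generator=gen)
    gen.graphsafe_set_state(main_state)
    y = torch.rand(5, device="cuda", generator=gen)
    g.capture_end()
torch.cuda.current_stream().wait_stream(s)

g.replay()  # each registered state is advanced by its whole-graph increment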

cc  @eellison @anijain2305 @jansel @ezyang @ptrblck @csarofeen @mcarilli
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114068
Approved by: https://github.com/ezyang, https://github.com/eqy, https://github.com/xuzhao9
Frank Lin
2024-03-27 01:14:38 +00:00
committed by PyTorch MergeBot
parent fe41ba4765
commit 249e65b92d
15 changed files with 644 additions and 139 deletions

View File

@ -13,4 +13,12 @@ at::Tensor Generator::get_state() const {
return at::Tensor::wrap_tensor_impl(this->impl_->get_state());
}
void Generator::graphsafe_set_state(const Generator& new_state) {
this->impl_->graphsafe_set_state(new_state.getIntrusivePtr());
}
Generator Generator::graphsafe_get_state() const {
return Generator(this->impl_->graphsafe_get_state());
}
} // namespace at

View File

@ -107,6 +107,10 @@ struct TORCH_API Generator {
at::Tensor get_state() const;
void graphsafe_set_state(const Generator& new_state);
Generator graphsafe_get_state() const;
std::mutex& mutex() {
return impl_->mutex_;
}

View File

@ -1,5 +1,8 @@
#include <ATen/Functions.h>
#include <ATen/Tensor.h>
#include <ATen/Utils.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/cuda/CUDAGraph.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <c10/core/StreamGuard.h>
#include <c10/cuda/CUDAFunctions.h>
@ -24,10 +27,10 @@ static std::deque<c10::once_flag> cuda_gens_init_flag;
static std::vector<Generator> default_gens_cuda;
/*
* Populates the global variables related to CUDA generators
* Warning: this function must only be called once!
*/
static void initCUDAGenVector(){
* Populates the global variables related to CUDA generators
* Warning: this function must only be called once!
*/
static void initCUDAGenVector() {
num_gpus = c10::cuda::device_count();
cuda_gens_init_flag.resize(num_gpus);
default_gens_cuda.resize(num_gpus);
@ -77,6 +80,150 @@ Generator createCUDAGenerator(DeviceIndex device_index) {
} // namespace cuda::detail
/**
* Creates a clone of this CUDA Generator State.
*/
c10::intrusive_ptr<CUDAGeneratorState> CUDAGeneratorState::clone() {
return make_intrusive<CUDAGeneratorState>(
seed_, philox_offset_per_thread_, offset_intragraph_);
}
/**
* Function to increase the internal offset based on the specified increment.
*/
void CUDAGeneratorState::increase(uint64_t increment) {
// Rounds increment up to the nearest multiple of 4 to meet alignment
// requirements.
// see Note [Why enforce RNG offset % 4 == 0?]
increment = ((increment + 3) / 4) * 4;
// Handling different behaviors based on whether capturing is active.
if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) {
// Ensures that the state is actually capturing.
TORCH_CHECK(
capturing_,
"Attempt to increase offset for a CUDA generator not in capture mode.");
// Ensures the offset is a multiple of 4
// see Note [Why enforce RNG offset % 4 == 0?]
TORCH_INTERNAL_ASSERT(
offset_intragraph_ % 4 == 0, "RNG offset must be a multiple of 4.");
// Ensures the increment does not cause overflow.
TORCH_INTERNAL_ASSERT(
offset_intragraph_ <= std::numeric_limits<uint32_t>::max() - increment,
"Increment causes overflow in the offset value.");
offset_intragraph_ += increment;
} else {
// Checks that the increment is expected outside graph capturing.
TORCH_CHECK(
!capturing_,
"Offset increment outside graph capture encountered unexpectedly.");
// Ensures the offset is a multiple of 4
// see Note [Why enforce RNG offset % 4 == 0?]
TORCH_INTERNAL_ASSERT(
philox_offset_per_thread_ % 4 == 0,
"RNG offset must be a multiple of 4.");
philox_offset_per_thread_ += increment;
}
}
/**
* Registers this state with a CUDA graph so that it is managed within the graph.
*/
void CUDAGeneratorState::register_graph(cuda::CUDAGraph* graph) {
// Ensures that the RNG state is not currently being captured.
at::cuda::assertNotCapturing(
"Cannot register the state during capturing stage.");
// If this is the first graph to be registered, allocate memory for the seed
// and offset on the GPU.
if (registered_graphs_.empty()) {
auto options = at::TensorOptions().device(at::kCUDA).dtype(at::kLong);
seed_extragraph_ = at::empty({1}, options);
offset_extragraph_ = at::empty({1}, options);
}
// Insert the graph into the set of registered graphs if it's not already
// registered.
if (registered_graphs_.find(graph) == registered_graphs_.end()) {
registered_graphs_.insert(graph);
}
}
/**
* Unregisters a CUDA graph from the RNG state.
*/
void CUDAGeneratorState::unregister_graph(cuda::CUDAGraph* graph) {
// Ensures that the RNG state is not currently being captured.
at::cuda::assertNotCapturing(
"Cannot unregister the state during capturing stage.");
// Verify the graph was previously registered.
TORCH_CHECK(
registered_graphs_.find(graph) != registered_graphs_.end(),
"The graph should be registered to the state");
// Remove the graph from the set of registered graphs.
registered_graphs_.erase(graph);
// If no more graphs are registered, deallocate the GPU memory for the seed
// and offset.
if (registered_graphs_.empty()) {
seed_extragraph_.reset();
offset_extragraph_.reset();
}
}
/**
* Note [Explicit Registration of Generators to the CUDA Graph]
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
*
* Ideally, it would be more user-friendly if the state could be exchanged and generators
* could be registered with the CUDA graph implicitly. However, resetting GPU tensors during
* the capture stage causes these reset operations to be recorded within the CUDA graph.
* This behavior is undesirable because we do not want these tensors to be reset during
* the replay stage of the graph.
*
* As of now, there is no available method to perform a CUDA operation during the graph's
* recording phase without having that operation be included in the CUDA graph.
* This limitation necessitates explicit user action to register generators with the graph.
* By requiring users to manually register their generators, we can ensure that state resets
* (capture_prologue) only occur before the graph capture begins, thus avoiding unintended
* resets during the replay of the graph. See https://github.com/pytorch/pytorch/pull/114068.
*/
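/*
* Illustrative ordering (Python-flavored pseudocode sketch; see
* test_graphsafe_set_get_rng_state for a complete example):
*
*   state = gen.graphsafe_get_state()
*   graph.register_generator_state(state)   # before capture, so the tensor
*                                           # resets in capture_prologue are
*                                           # not recorded into the graph
*   graph.capture_begin(); ...; graph.capture_end()
*   graph.replay()                          # replay_prologue first refreshes
*                                           # the seed/offset tensors
*/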
/**
* Performs the prologue steps for capturing a CUDA graph state.
* This method is intended to reset graph-related state variables before capturing begins.
*/
void CUDAGeneratorState::capture_prologue() {
capturing_ = true;
offset_intragraph_ = 0;
seed_extragraph_.fill_(int64_t(seed_));
offset_extragraph_.fill_(int64_t(0));
}
/**
* Ends the capturing phase and resets related variables, returning the whole
* graph increment.
*/
uint64_t CUDAGeneratorState::capture_epilogue() {
capturing_ = false;
return offset_intragraph_;
}
/**
* Prepares the state for replay by setting initial state tensors and applying
* total increment.
*/
void CUDAGeneratorState::replay_prologue(uint64_t wholegraph_increment) {
// Ensures the generator is not in capturing mode.
at::cuda::assertNotCapturing(
"Cannot prepare for replay during capturing stage.");
seed_extragraph_.fill_(int64_t(seed_));
offset_extragraph_.fill_(int64_t(philox_offset_per_thread_));
// Applies the total increment achieved during previous captures to update the
// offset.
increase(wholegraph_increment);
}
/**
* Note [Why enforce RNG offset % 4 == 0?]
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -97,8 +244,18 @@ Generator createCUDAGenerator(DeviceIndex device_index) {
*/
CUDAGeneratorImpl::CUDAGeneratorImpl(DeviceIndex device_index)
: c10::GeneratorImpl{Device(DeviceType::CUDA, device_index),
DispatchKeySet(c10::DispatchKey::CUDA)} {
DispatchKeySet(c10::DispatchKey::CUDA)} {
at::cuda::assertNotCapturing("Cannot construct a new CUDAGeneratorImpl");
state_ = make_intrusive<CUDAGeneratorState>();
no_reset_rnn_state_.clear();
}
CUDAGeneratorImpl::CUDAGeneratorImpl(
DeviceIndex device_index,
c10::intrusive_ptr<CUDAGeneratorState> state)
: c10::
GeneratorImpl{Device(DeviceType::CUDA, device_index), DispatchKeySet(c10::DispatchKey::CUDA)},
state_(std::move(state)) {
no_reset_rnn_state_.clear();
}
@ -109,9 +266,10 @@ CUDAGeneratorImpl::CUDAGeneratorImpl(DeviceIndex device_index)
* See Note [Acquire lock when using random generators]
*/
void CUDAGeneratorImpl::set_current_seed(uint64_t seed) {
at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::set_current_seed");
seed_ = seed;
philox_offset_per_thread_ = 0;
at::cuda::assertNotCapturing(
"Cannot call CUDAGeneratorImpl::set_current_seed");
state_->seed_ = seed;
state_->philox_offset_per_thread_ = 0;
no_reset_rnn_state_.clear();
}
@ -134,15 +292,9 @@ uint64_t CUDAGeneratorImpl::get_offset() const {
// Debatable if get_offset() should be allowed in captured regions.
// Conservatively disallow it for now.
at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::get_offset");
return philox_offset_per_thread_;
return state_->philox_offset_per_thread_;
}
#define CAPTURE_DEFAULT_GENS_MSG \
"In regions captured by CUDA graphs, you may only use the default CUDA RNG " \
"generator on the device that's current when capture begins. " \
"If you need a non-default (user-supplied) generator, or a generator on another " \
"device, please file an issue."
/**
* Gets the current seed of CUDAGeneratorImpl.
*/
@ -150,7 +302,7 @@ uint64_t CUDAGeneratorImpl::current_seed() const {
// Debatable if current_seed() should be allowed in captured regions.
// Conservatively disallow it for now.
at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::current_seed");
return seed_;
return state_->seed_;
}
/**
@ -194,6 +346,8 @@ c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
* and size of the internal state.
*/
void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
at::cuda::assertNotCapturing(
"Please ensure to utilize the CUDAGeneratorImpl::set_state_index method during capturing.");
static const size_t seed_size = sizeof(uint64_t);
static const size_t offset_size = sizeof(int64_t);
static const size_t total_size = seed_size + offset_size;
@ -208,7 +362,7 @@ void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
TORCH_CHECK(new_state_size == total_size, "RNG state is wrong size");
}
uint64_t input_seed;
uint64_t input_seed = 0;
auto new_rng_state = new_state.data_dtype_initialized<uint8_t>();
memcpy(&input_seed, new_rng_state, seed_size);
this->set_current_seed(input_seed);
@ -219,44 +373,59 @@ void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
this->set_philox_offset_per_thread(static_cast<uint64_t>(philox_offset));
}
/**
* Sets the generator's current state to the one provided.
* This function allows switching between different registered states of
* the generator.
*/
void CUDAGeneratorImpl::graphsafe_set_state(
const c10::intrusive_ptr<GeneratorImpl>& gen) {
c10::intrusive_ptr<CUDAGeneratorImpl> cuda_gen =
dynamic_intrusive_pointer_cast<CUDAGeneratorImpl>(gen);
TORCH_CHECK(cuda_gen, "Expected a CUDA Generator");
state_ = cuda_gen->state_;
}
/**
* Gets a GeneratorImpl that points to the current state_
*/
c10::intrusive_ptr<c10::GeneratorImpl> CUDAGeneratorImpl::graphsafe_get_state()
const {
auto gen = make_intrusive<CUDAGeneratorImpl>(device().index(), state_);
return gen;
}
/**
* Sets the philox_offset_per_thread_ to be used by curandStatePhilox4_32_10
*
* See Note [Acquire lock when using random generators]
*/
void CUDAGeneratorImpl::set_philox_offset_per_thread(uint64_t offset) {
at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::set_philox_offset_per_thread");
// see Note [Why enforce RNG offset % 4 == 0?]
TORCH_CHECK(offset % 4 == 0, "offset must be a multiple of 4");
philox_offset_per_thread_ = offset;
state_->philox_offset_per_thread_ = offset;
}
/**
* Gets the current philox_offset_per_thread_ of CUDAGeneratorImpl.
*/
uint64_t CUDAGeneratorImpl::philox_offset_per_thread() const {
at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::philox_offset_per_thread");
return philox_offset_per_thread_;
return state_->philox_offset_per_thread_;
}
/**
* Called by CUDAGraph to prepare this instance for a graph capture region.
* offset_extragraph is the initial offset at the start of the graphed region.
* offset_intragraph tracks the offset in the graphed region.
* Registers this state with a CUDA graph so that it is managed within the graph.
*/
void CUDAGeneratorImpl::capture_prologue(int64_t* seed_extragraph, int64_t* offset_extragraph) {
seed_extragraph_ = seed_extragraph;
offset_extragraph_ = offset_extragraph;
offset_intragraph_ = 0;
graph_expects_this_gen_ = true;
void CUDAGeneratorImpl::register_graph(cuda::CUDAGraph* graph) {
graph->register_generator_state(state_);
state_->register_graph(graph);
}
/**
* Called by CUDAGraph to finalize a graph capture region for this instance.
* Unregisters a CUDA graph from the RNG state.
*/
uint64_t CUDAGeneratorImpl::capture_epilogue() {
graph_expects_this_gen_ = false;
return offset_intragraph_;
void CUDAGeneratorImpl::unregister_graph(cuda::CUDAGraph* graph) {
state_->unregister_graph(graph);
}
/**
@ -281,30 +450,17 @@ uint64_t CUDAGeneratorImpl::capture_epilogue() {
* See Note [Acquire lock when using random generators]
*/
PhiloxCudaState CUDAGeneratorImpl::philox_cuda_state(uint64_t increment) {
// rounds increment up to the nearest multiple of 4
increment = ((increment + 3) / 4) * 4;
if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) {
TORCH_CHECK(graph_expects_this_gen_,
"philox_cuda_state for an unexpected CUDA generator used during capture. "
CAPTURE_DEFAULT_GENS_MSG);
// see Note [Why enforce RNG offset % 4 == 0?]
TORCH_INTERNAL_ASSERT(this->offset_intragraph_ % 4 == 0);
uint32_t offset = this->offset_intragraph_;
TORCH_INTERNAL_ASSERT(this->offset_intragraph_ <=
std::numeric_limits<uint32_t>::max() - increment);
this->offset_intragraph_ += increment;
return PhiloxCudaState(this->seed_extragraph_,
this->offset_extragraph_,
offset);
uint32_t offset = state_->offset_intragraph_;
state_->increase(increment);
return PhiloxCudaState(
state_->seed_extragraph_.data_ptr<int64_t>(),
state_->offset_extragraph_.data_ptr<int64_t>(),
offset);
} else {
TORCH_CHECK(!graph_expects_this_gen_,
"CUDA generator expects graph capture to be underway, "
"but the current stream is not capturing.");
// see Note [Why enforce RNG offset % 4 == 0?]
TORCH_INTERNAL_ASSERT(this->philox_offset_per_thread_ % 4 == 0);
uint64_t offset = this->philox_offset_per_thread_;
this->philox_offset_per_thread_ += increment;
return PhiloxCudaState(this->seed_, offset);
uint64_t offset = state_->philox_offset_per_thread_;
state_->increase(increment);
return PhiloxCudaState(state_->seed_, offset);
}
}
@ -312,16 +468,13 @@ PhiloxCudaState CUDAGeneratorImpl::philox_cuda_state(uint64_t increment) {
* Temporarily accommodates call sites that use philox_engine_inputs.
* Allows incremental refactor of call sites to use philox_cuda_state.
*/
std::pair<uint64_t, uint64_t> CUDAGeneratorImpl::philox_engine_inputs(uint64_t increment) {
at::cuda::assertNotCapturing("Refactor this op to use CUDAGeneratorImpl::philox_cuda_state. "
"Cannot call CUDAGeneratorImpl::philox_engine_inputs");
// rounds increment up to the nearest multiple of 4
increment = ((increment + 3) / 4) * 4;
// see Note [Why enforce RNG offset % 4 == 0?]
TORCH_INTERNAL_ASSERT(this->philox_offset_per_thread_ % 4 == 0);
uint64_t offset = this->philox_offset_per_thread_;
this->philox_offset_per_thread_ += increment;
return std::make_pair(this->seed_, offset);
std::pair<uint64_t, uint64_t> CUDAGeneratorImpl::philox_engine_inputs(
uint64_t increment) {
at::cuda::assertNotCapturing(
"Refactor this op to use CUDAGeneratorImpl::philox_cuda_state. Cannot call CUDAGeneratorImpl::philox_engine_inputs");
uint64_t offset = state_->philox_offset_per_thread_;
state_->increase(increment);
return std::make_pair(state_->seed_, offset);
}
/*
@ -348,9 +501,7 @@ std::shared_ptr<CUDAGeneratorImpl> CUDAGeneratorImpl::clone() const {
*/
CUDAGeneratorImpl* CUDAGeneratorImpl::clone_impl() const {
at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::clone_impl");
auto gen = new CUDAGeneratorImpl(this->device().index());
gen->set_current_seed(this->seed_);
gen->set_philox_offset_per_thread(this->philox_offset_per_thread_);
auto gen = new CUDAGeneratorImpl(this->device().index(), state_->clone());
return gen;
}

View File

@ -1,12 +1,19 @@
#pragma once
#include <ATen/core/Generator.h>
#include <ATen/cuda/PhiloxCudaState.h>
#include <ATen/Context.h>
#include <limits>
#include <ATen/core/Generator.h>
#include <ATen/core/TensorBase.h>
#include <ATen/cuda/PhiloxCudaState.h>
#include <atomic>
#include <limits>
#include <memory>
#include <unordered_set>
namespace at {
namespace cuda {
struct CUDAGraph;
}
/**
* Note [CUDA Graph-safe RNG states]
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -87,9 +94,41 @@ namespace at {
*
*/
struct CUDAGeneratorState : public c10::intrusive_ptr_target {
uint64_t seed_;
uint64_t philox_offset_per_thread_;
uint32_t offset_intragraph_;
bool capturing_{};
std::unordered_set<cuda::CUDAGraph*> registered_graphs_;
at::TensorBase seed_extragraph_{};
at::TensorBase offset_extragraph_{};
CUDAGeneratorState(
uint64_t seed = default_rng_seed_val,
uint64_t philox_offset_per_thread = 0,
uint32_t offset_intragraph = 0)
: seed_(seed),
philox_offset_per_thread_(philox_offset_per_thread),
offset_intragraph_(offset_intragraph) {}
void increase(uint64_t increment);
void register_graph(cuda::CUDAGraph* graph);
void unregister_graph(cuda::CUDAGraph* graph);
void capture_prologue();
// capture_epilogue returns the wholegraph_increment
uint64_t capture_epilogue();
void replay_prologue(uint64_t wholegraph_increment);
c10::intrusive_ptr<CUDAGeneratorState> clone();
};
struct TORCH_CUDA_CPP_API CUDAGeneratorImpl : public c10::GeneratorImpl {
// Constructors
CUDAGeneratorImpl(DeviceIndex device_index = -1);
CUDAGeneratorImpl(
DeviceIndex device_index,
c10::intrusive_ptr<CUDAGeneratorState> state_);
~CUDAGeneratorImpl() override = default;
// CUDAGeneratorImpl methods
@ -101,10 +140,18 @@ struct TORCH_CUDA_CPP_API CUDAGeneratorImpl : public c10::GeneratorImpl {
uint64_t seed() override;
void set_state(const c10::TensorImpl& new_state) override;
c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
void graphsafe_set_state(
const c10::intrusive_ptr<GeneratorImpl>& state) override;
c10::intrusive_ptr<c10::GeneratorImpl> graphsafe_get_state() const override;
void set_philox_offset_per_thread(uint64_t offset);
uint64_t philox_offset_per_thread() const;
void capture_prologue(int64_t* seed_extragraph, int64_t* offset_extragraph);
uint64_t capture_epilogue();
void register_graph(cuda::CUDAGraph* graph);
void unregister_graph(cuda::CUDAGraph* graph);
// Generates a PhiloxCudaState with a specified increment, and increments
// the current state
PhiloxCudaState philox_cuda_state(uint64_t increment);
bool reset_rnn_state() {
@ -117,14 +164,10 @@ struct TORCH_CUDA_CPP_API CUDAGeneratorImpl : public c10::GeneratorImpl {
static c10::DeviceType device_type();
private:
private:
CUDAGeneratorImpl* clone_impl() const override;
uint64_t seed_ = default_rng_seed_val;
uint64_t philox_offset_per_thread_ = 0;
int64_t* seed_extragraph_{};
int64_t* offset_extragraph_{};
uint32_t offset_intragraph_ = 0;
bool graph_expects_this_gen_ = false;
c10::intrusive_ptr<CUDAGeneratorState> state_;
std::atomic_flag no_reset_rnn_state_;
};

View File

@ -6,7 +6,10 @@
#include <c10/cuda/CUDAFunctions.h>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <thread>
#include <vector>
namespace at::cuda {
@ -86,26 +89,33 @@ CUDAGraph::CUDAGraph()
#endif
}
void CUDAGraph::register_generator_state(
c10::intrusive_ptr<at::CUDAGeneratorState> state) {
captured_generator_states_[std::move(state)] = 0;
}
void CUDAGraph::register_generator_state(const at::Generator& generator) {
c10::intrusive_ptr<CUDAGeneratorImpl> cuda_gen =
dynamic_intrusive_pointer_cast<CUDAGeneratorImpl>(
generator.getIntrusivePtr());
cuda_gen->register_graph(this);
}
void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/, cudaStreamCaptureMode capture_mode) {
#if !defined(USE_ROCM) || ROCM_VERSION >= 50300
TORCH_CHECK(!has_graph_exec_,
"This CUDAGraph instance already owns a captured graph. "
"To capture a new graph, create a new instance.");
// For now, a CUDAGraph instance only accommodates the default generator on the device that's
// current when capture begins. If any op in the captured region uses a non-default generator,
// or a generator on another device, the offending generator will throw an error.
// These restrictions simplify CUDAGraph, but could be relaxed in the future:
// in principle, the underlying Cuda calls do permit cross-device ops to be captured.
// default generator is always registered
auto* gen = get_generator_or_default<CUDAGeneratorImpl>(
c10::nullopt, cuda::detail::getDefaultCUDAGenerator());
gen->register_graph(this);
auto options = TensorOptions().device(at::kCUDA).dtype(at::kLong);
seed_extragraph_ = at::empty({1}, options);
offset_extragraph_ = at::empty({1}, options);
seed_extragraph_.fill_(int64_t(gen->current_seed()));
gen->capture_prologue(seed_extragraph_.data_ptr<int64_t>(), offset_extragraph_.mutable_data_ptr<int64_t>());
for (auto& [generator_state, wholegraph_increments] :
captured_generator_states_) {
generator_state->capture_prologue();
}
auto stream = at::cuda::getCurrentCUDAStream();
@ -115,7 +125,6 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/, cudaStreamCaptureMode capt
"default stream.)");
capture_stream_ = stream;
capture_gen_ = gen;
capture_dev_ = c10::cuda::current_device();
id_ = capture_sequence_id();
@ -215,13 +224,10 @@ void CUDAGraph::capture_end() {
has_graph_exec_ = true;
auto* gen = get_generator_or_default<CUDAGeneratorImpl>(
c10::nullopt, cuda::detail::getDefaultCUDAGenerator());
TORCH_CHECK(gen == capture_gen_,
"Default CUDA RNG generator on current device at capture end "
"is different from default generator on current device "
"when capture began");
wholegraph_increment_ = gen->capture_epilogue();
for (auto& [generator_state, wholegraph_increments] :
captured_generator_states_) {
wholegraph_increments = generator_state->capture_epilogue();
}
size_t numCUDAGraphNodes = 0;
AT_CUDA_CHECK(cudaGraphGetNodes(graph_, NULL, &numCUDAGraphNodes));
@ -251,17 +257,10 @@ void CUDAGraph::replay() {
c10::OptionalDeviceGuard device_guard{capture_stream_.device()};
// Just like any RNG consumer kernel!
auto* gen = get_generator_or_default<CUDAGeneratorImpl>(
c10::nullopt, cuda::detail::getDefaultCUDAGenerator());
PhiloxCudaState rng_engine_inputs;
{
std::lock_guard<std::mutex> lock(gen->mutex_);
rng_engine_inputs = gen->philox_cuda_state(wholegraph_increment_);
for (auto& [generator_state, wholegraph_increments] :
captured_generator_states_) {
generator_state->replay_prologue(wholegraph_increments);
}
seed_extragraph_.fill_(int64_t(gen->current_seed()));
offset_extragraph_.fill_(int64_t(rng_engine_inputs.offset_.val));
// graph_exec_ may be replayed in any stream.
AT_CUDA_CHECK(cudaGraphLaunch(graph_exec_, at::cuda::getCurrentCUDAStream()));
@ -355,6 +354,10 @@ TORCH_CHECK(has_graph_exec_,
}
CUDAGraph::~CUDAGraph() {
for (auto& [generator_state, wholegraph_increments] :
captured_generator_states_) {
generator_state->unregister_graph(this);
}
reset();
}

View File

@ -4,12 +4,13 @@
#include <c10/core/Device.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/cuda/CUDAStream.h>
#include <mutex>
#include <c10/util/flat_hash_map.h>
namespace at {
struct Generator;
struct CUDAGeneratorImpl;
struct CUDAGeneratorState;
namespace cuda {
@ -24,7 +25,12 @@ struct TORCH_CUDA_CPP_API CUDAGraph {
static void inc_pending_event_queries();
static void dec_pending_event_queries();
static int num_pending_event_queries();
void capture_begin(MempoolId_t pool={0, 0}, cudaStreamCaptureMode capture_mode = cudaStreamCaptureModeGlobal);
// See Note [Explicit Registration of Generators to the CUDA Graph]
void register_generator_state(c10::intrusive_ptr<at::CUDAGeneratorState> state);
void register_generator_state(const at::Generator& generator);
void capture_begin(
MempoolId_t pool = {0, 0},
cudaStreamCaptureMode capture_mode = cudaStreamCaptureModeGlobal);
void capture_end();
void replay();
void reset();
@ -32,7 +38,7 @@ struct TORCH_CUDA_CPP_API CUDAGraph {
void enable_debug_mode();
void debug_dump(const std::string& debug_path);
protected:
protected:
#if !defined(USE_ROCM) || ROCM_VERSION >= 50300
cudaGraph_t graph_ = NULL;
cudaGraphExec_t graph_exec_ = NULL;
@ -73,19 +79,16 @@ struct TORCH_CUDA_CPP_API CUDAGraph {
// Stream on which capture began
at::cuda::CUDAStream capture_stream_;
// Default generator on device where capture began
at::CUDAGeneratorImpl* capture_gen_;
// multiple generator states and their wholegraph_increments in this graph
// that are managed by the CUDA Graph
ska::flat_hash_map<c10::intrusive_ptr<at::CUDAGeneratorState>, uint64_t>
captured_generator_states_;
// Device where capture occurred. Right now, for simplicity, we require all ops
// in a capture to run on the same device, but this is a limitation of CUDAGraph,
// not CUDA itself. We can straightforwardly modify CUDAGraph to support multi-device
// captures if needed.
int capture_dev_;
// RNG state trackers
at::Tensor seed_extragraph_;
at::Tensor offset_extragraph_;
uint64_t wholegraph_increment_;
};
} // namespace cuda

View File

@ -31,6 +31,18 @@ c10::intrusive_ptr<GeneratorImpl> GeneratorImpl::clone() const {
return c10::intrusive_ptr<GeneratorImpl>::reclaim(res);
}
void GeneratorImpl::graphsafe_set_state(
const c10::intrusive_ptr<c10::GeneratorImpl>& state) {
TORCH_CHECK_NOT_IMPLEMENTED(
false, "graphsafe_set_state is not supported in this Generator");
}
c10::intrusive_ptr<c10::GeneratorImpl> GeneratorImpl::graphsafe_get_state()
const {
TORCH_CHECK_NOT_IMPLEMENTED(
false, "graphsafe_get_state is not supported in this Generator");
}
/**
* Gets the device of a generator.
*/

View File

@ -73,6 +73,9 @@ struct C10_API GeneratorImpl : public c10::intrusive_ptr_target {
virtual uint64_t seed() = 0;
virtual void set_state(const c10::TensorImpl& new_state) = 0;
virtual c10::intrusive_ptr<c10::TensorImpl> get_state() const = 0;
virtual void graphsafe_set_state(
const c10::intrusive_ptr<c10::GeneratorImpl>& new_state);
virtual c10::intrusive_ptr<c10::GeneratorImpl> graphsafe_get_state() const;
Device device() const;
// See Note [Acquire lock when using random generators]

View File

@ -976,9 +976,12 @@ Violating any of these will likely cause a runtime error:
:class:`~torch.cuda.graph` and
:func:`~torch.cuda.make_graphed_callables` set a side stream for you.)
* Ops that synchronize the CPU with the GPU (e.g., ``.item()`` calls) are prohibited.
* CUDA RNG ops are allowed, but must use default generators. For example, explicitly constructing a
new :class:`torch.Generator` instance and passing it as the ``generator`` argument to an RNG function
is prohibited.
* CUDA RNG operations are permitted, and when using multiple :class:`torch.Generator` instances within a graph,
they must be registered using :meth:`CUDAGraph.register_generator_state<torch.cuda.CUDAGraph.register_generator_state>` before graph capture.
Avoid using :meth:`Generator.get_state<torch.get_state>` and :meth:`Generator.set_state<torch.set_state>` during capture;
instead, utilize :meth:`Generator.graphsafe_set_state<torch.Generator.graphsafe_set_state>` and :meth:`Generator.graphsafe_get_state<torch.Generator.graphsafe_get_state>`
for managing generator states safely within the graph context, as in the sketch after this list. This ensures proper RNG operation and generator management within CUDA graphs.
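A short registration sketch (illustrative, mirroring the test added in this PR):

>>> g = torch.cuda.CUDAGraph()
>>> gen = torch.Generator(device="cuda")
>>> state = gen.graphsafe_get_state()
>>> g.register_generator_state(state)  # must precede g.capture_begin()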
Violating any of these will likely cause silent numerical errors or undefined behavior:

View File

@ -171,7 +171,6 @@ class TestCuda(TestCase):
tensor.fill_(1)
self.assertTrue((tensor == 1).all())
@unittest.skipIf(TEST_CUDAMALLOCASYNC or IS_JETSON, "Segmentation fault (core dumped)")
def test_out_of_memory_retry(self):
torch.cuda.empty_cache()
@ -256,7 +255,6 @@ class TestCuda(TestCase):
c.copy_(b, non_blocking=True)
self.assertEqual(a, c, exact_dtype=False)
def test_to_non_blocking(self):
stream = torch.cuda.current_stream()
@ -384,7 +382,6 @@ class TestCuda(TestCase):
self.assertEqual(torch._C._get_cublas_allow_bf16_reduced_precision_reduction(), not orig)
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig
def test_cudnn_allow_tf32_get_set(self):
with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=False):
self.assertFalse(torch.backends.cudnn.allow_tf32)
@ -1463,8 +1460,6 @@ torch.cuda.synchronize()
for op, args in self.autocast_lists.nn_fp16:
self._run_autocast_outofplace(op, args, torch.float16, module=torch._C._nn)
@unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
def test_autocast_nn_bf16(self):
with torch.backends.cudnn.flags(enabled=True, deterministic=True):
@ -1783,6 +1778,163 @@ torch.cuda.synchronize()
self.assertTrue(b.sum().item() == 11000.)
@unittest.skipIf(
not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
)
def test_graphsafe_set_get_rng_state(self):
# Define a function to create generator states, with optional graph registration
def create_states(generator):
"""Initializes generator states and registers them with a CUDA graph if provided."""
# Ensure the CUDA generator is initialized
torch.rand(1, device="cuda")
generator.manual_seed(0)
# Save the current state of the generator
old_state = generator.graphsafe_get_state()
# Create and save a cloned state of the generator
new_state = generator.clone_state()
# Return the original generator and its two states
return generator, old_state, new_state
def register_states_to_graph(generator_state, graph):
generator, old_state, new_state = generator_state
graph.register_generator_state(old_state)
graph.register_generator_state(new_state)
# Define a function to perform specific RNG actions using the generator's states
def perform_random_generation_steps(generator_state):
generator, old_state, new_state = generator_state
random_values = []
# Generate random numbers with the new generator state
generator.graphsafe_set_state(new_state)
random_values.append(torch.rand(5, device="cuda", generator=generator))
# Generate random numbers twice with the old generator state
generator.graphsafe_set_state(old_state)
random_values.extend(
[torch.rand(5, device="cuda", generator=generator) for _ in range(2)]
)
return random_values
# Define a function to retrieve the final offsets of the original and new generator states
def get_final_offsets_of_states(generator_state):
generator, old_state, new_state = generator_state
old_state_offset = old_state.get_offset()
new_state_offset = new_state.get_offset()
return old_state_offset, new_state_offset
# Set up and test a new CUDA generator
generator = torch.Generator(device="cuda")
generator_state = create_states(generator)
# Set up and test the default CUDA generator with a CUDA Graph
g = torch.cuda.CUDAGraph()
s = torch.cuda.Stream()
default_generator = torch.cuda.default_generators[0]
default_generator_state = create_states(default_generator)
register_states_to_graph(default_generator_state, g)
# Perform random number generation within a CUDA graph
with torch.cuda.stream(s):
g.capture_begin()
graphed_random_values = perform_random_generation_steps(
default_generator_state
)
g.capture_end()
# Synchronize the streams and replay the graph
torch.cuda.current_stream().wait_stream(s)
for _ in range(3):
random_values = perform_random_generation_steps(generator_state)
g.replay()
offset = get_final_offsets_of_states(generator_state)
graph_offset = get_final_offsets_of_states(default_generator_state)
# Compare the final offsets of states for both generators to ensure consistency
self.assertTrue(offset == graph_offset)
# Compare the states generated outside and inside the graph
self.assertEqual(random_values, graphed_random_values)
@unittest.skipIf(
not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
)
def test_memory_stats_of_multiple_generators_and_graphs(self):
# Function to clear CUDA cache and collect garbage
def clear_cuda_cache():
gc.collect()
torch.cuda.empty_cache()
# Executes a simple graph task which includes capturing and executing a random number generation within a CUDA graph.
def simple_graph_task(graph):
s = torch.cuda.Stream()
with torch.cuda.stream(s):
graph.capture_begin()
torch.rand(1, device="cuda")
graph.capture_end()
torch.cuda.current_stream().wait_stream(s)
graph.replay() # Replays the captured operations
def get_memory_stats():
stats = torch.cuda.memory_stats()
num_blocks = stats["active.all.current"]
total_size = stats["active_bytes.all.current"]
return num_blocks, total_size
def test(num_graphs, num_generators):
baseline = get_memory_stats()
baseline_num_blocks, baseline_total_size = baseline
# Allocate CUDA graphs
graphs = [torch.cuda.CUDAGraph() for _ in range(num_graphs)]
# Allocate and manage generator states
default_generator = torch.cuda.default_generators[0]
generators = [default_generator.graphsafe_get_state()]
# Starts from 1 as one state is already added
for _ in range(1, num_generators):
generators.append(default_generator.clone_state())
for graph in graphs:
for generator_state in generators:
graph.register_generator_state(generator_state)
simple_graph_task(graph)
# Assert conditions after graph tasks
num_blocks, total_size = get_memory_stats()
# The allocated blocks should only be proportional to the number of generators
expected_blocks_diff = 2 * num_generators
expected_size_diff = 2 * 512 * num_generators # Each block's size is 512
self.assertTrue(
(num_blocks - baseline_num_blocks) == expected_blocks_diff,
"Unexpected number of active blocks.",
)
self.assertTrue(
(total_size - baseline_total_size) == expected_size_diff,
"Unexpected total memory size.",
)
# Cleanup graphs and clear CUDA cache
while graphs:
graph = graphs.pop()
del graph
clear_cuda_cache()
# Assert that memory stats return to baseline after cleanup
self.assertTrue(
get_memory_stats() == baseline,
"Memory stats do not match baseline after cleanup.",
)
# Running the test function with different parameters
test(1, 1)
test(3, 2)
test(10, 20)
@unittest.skipIf(not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs")
def test_graph_capture_reset_recapture(self):
s = torch.cuda.Stream()
@ -2332,14 +2484,18 @@ exit(2)
s = torch.cuda.Stream()
for (numel,
delta_cudaMallocs,
delta_cudaMalloc_bytes,
delta_cudaMalloc_bytes_post_del_g,
pool_string) in cases:
for (
numel,
delta_cudaMallocs,
delta_cudaMalloc_bytes,
delta_cudaMalloc_bytes_post_del_g,
pool_string,
) in cases:
if pool_string == "small_pool":
delta_active_blocks = 3 # one from "b" plus a sneaky two from CUDAGraph's one-element rng seed and offset holders
delta_active_bytes = numel * elem + 1024 # + 1024 for CUDAGraph's rng seed and offset holders each
delta_active_bytes = (
numel * elem + 1024
) # + 1024 for CUDAGraph's rng seed and offset holders each
else:
delta_active_blocks = 1 # We only check the large pool, which isn't affected by rng offset holder
delta_active_bytes = numel * elem
@ -3085,8 +3241,6 @@ exit(2)
self.assertEqual(rc, "3")
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCudaMallocAsync(TestCase):
@unittest.skipIf(TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync")
@ -3165,7 +3319,6 @@ class TestCudaMallocAsync(TestCase):
finally:
torch.cuda.memory._record_memory_history(None)
@skipIfRocm
def test_memory_profiler_viz(self):
with torch.profiler.profile(
@ -3448,7 +3601,6 @@ class TestCudaMallocAsync(TestCase):
with self.assertRaises(RuntimeError):
torch.cuda.memory._set_allocator_settings("pinned_num_register_threads:1024")
@parametrize(
"max_split_size_mb_setting", [False, True]
)

View File

@ -1385,6 +1385,9 @@ class Generator:
def __init__(self, device: Optional[DeviceLikeType] = None) -> None: ...
def get_state(self) -> Tensor: ...
def set_state(self, _new_state: Tensor) -> Generator: ...
def clone_state(self) -> Generator: ...
def graphsafe_get_state(self) -> Generator: ...
def graphsafe_set_state(self, _new_state: Generator) -> Generator: ...
def set_offset(self, offset: _int) -> Generator: ...
def get_offset(self) -> _int: ...
def manual_seed(self, seed: _int) -> Generator: ...
@ -1883,6 +1886,7 @@ class _CudaEventBase:
class _CUDAGraph:
def capture_begin(self, pool: Optional[Tuple[_int, _int]] = ..., capture_error_mode: str = "global") -> None: ...
def capture_end(self) -> None: ...
def register_generator_state(self, generator: Generator) -> None: ...
def replay(self) -> None: ...
def reset(self) -> None: ...
def pool(self) -> Tuple[_int, _int]: ...

View File

@ -13722,6 +13722,58 @@ Example::
""",
)
add_docstr(
torch.Generator.graphsafe_set_state,
r"""
Generator.graphsafe_set_state(state) -> None
Sets the state of the generator to the specified state in a manner that is safe for use in graph capture.
This method is crucial for ensuring that the generator's state can be captured in the CUDA graph.
Arguments:
state (torch.Generator): A Generator pointing to the new state for the generator, typically obtained from `graphsafe_get_state`.
Example:
>>> g_cuda = torch.Generator(device='cuda')
>>> g_cuda_other = torch.Generator(device='cuda')
>>> current_state = g_cuda_other.graphsafe_get_state()
>>> g_cuda.graphsafe_set_state(current_state)
""",
)
add_docstr(
torch.Generator.graphsafe_get_state,
r"""
Generator.graphsafe_get_state() -> torch.Generator
Retrieves the current state of the generator in a manner that is safe for graph capture.
This method is crucial for ensuring that the generator's state can be captured in the CUDA graph.
Returns:
torch.Generator: A Generator pointing to the current state of the generator
Example:
>>> g_cuda = torch.Generator(device='cuda')
>>> current_state = g_cuda.graphsafe_get_state()
""",
)
add_docstr(
torch.Generator.clone_state,
r"""
Generator.clone_state() -> torch.Generator
Clones the current state of the generator and returns a new generator pointing to this cloned state.
This method is beneficial for preserving a particular state of a generator to restore at a later point.
Returns:
torch.Generator: A Generator pointing to the newly cloned state.
Example:
>>> g_cuda = torch.Generator(device='cuda')
>>> cloned_state = g_cuda.clone_state()
""",
)
add_docstr(
torch.Generator.manual_seed,

View File

@ -143,6 +143,47 @@ uint64_t unpack_uint64(PyObject* pyobj) {
return unsigned_obj;
}
static PyObject* THPGenerator_graphSafeGetState(
PyObject* _self,
PyObject* noargs) {
HANDLE_TH_ERRORS
auto& gen = ((THPGenerator*)_self)->cdata;
// See Note [Acquire lock when using random generators]
std::scoped_lock<std::mutex> lock(gen.mutex());
return THPGenerator_Wrap(gen.graphsafe_get_state());
END_HANDLE_TH_ERRORS
}
static PyObject* THPGenerator_graphSafeSetState(
PyObject* _self,
PyObject* _state) {
HANDLE_TH_ERRORS
auto self = (THPGenerator*)_self;
auto& gen = self->cdata;
// See Note [Acquire lock when using random generators]
std::scoped_lock<std::mutex> lock(gen.mutex());
gen.graphsafe_set_state(THPGenerator_Unwrap(_state));
Py_INCREF(self);
return (PyObject*)self;
END_HANDLE_TH_ERRORS
}
static PyObject* THPGenerator_cloneState(PyObject* _self, PyObject* noargs) {
HANDLE_TH_ERRORS
auto& gen = ((THPGenerator*)_self)->cdata;
// See Note [Acquire lock when using random generators]
std::scoped_lock<std::mutex> lock(gen.mutex());
auto new_generator = gen.clone();
return THPGenerator_Wrap(new_generator);
END_HANDLE_TH_ERRORS
}
static PyObject* THPGenerator_manualSeed(PyObject* _self, PyObject* seed) {
HANDLE_TH_ERRORS
auto self = (THPGenerator*)_self;
@ -218,6 +259,12 @@ static struct PyGetSetDef THPGenerator_properties[] = {
static PyMethodDef THPGenerator_methods[] = {
{"get_state", THPGenerator_getState, METH_NOARGS, nullptr},
{"set_state", THPGenerator_setState, METH_O, nullptr},
{"clone_state", THPGenerator_cloneState, METH_NOARGS, nullptr},
{"graphsafe_get_state",
THPGenerator_graphSafeGetState,
METH_NOARGS,
nullptr},
{"graphsafe_set_state", THPGenerator_graphSafeSetState, METH_O, nullptr},
{"set_offset", THPGenerator_setOffset, METH_O, nullptr},
{"manual_seed", THPGenerator_manualSeed, METH_O, nullptr},
{"seed", THPGenerator_seed, METH_NOARGS, nullptr},
@ -304,6 +351,14 @@ PyObject* THPGenerator_Wrap(Generator gen) {
(PyTypeObject*)THPGeneratorClass, std::move(gen));
}
at::Generator THPGenerator_Unwrap(PyObject* state) {
if (!Py_IS_TYPE(state, &THPGeneratorType)) {
throw torch::TypeError(
"expected a Generator, but got %s", Py_TYPE(state)->tp_name);
}
return reinterpret_cast<THPGenerator*>(state)->cdata;
}
// Creates a new Python object for a Generator. The Generator must not already
// have a PyObject* associated with it.
PyObject* THPGenerator_NewWithVar(PyTypeObject* type, Generator gen) {

View File

@ -23,6 +23,8 @@ bool THPGenerator_init(PyObject* module);
TORCH_PYTHON_API PyObject* THPGenerator_Wrap(at::Generator gen);
TORCH_PYTHON_API at::Generator THPGenerator_Unwrap(PyObject* state);
// Creates a new Python object for a Generator. The Generator must not already
// have a PyObject* associated with it.
PyObject* THPGenerator_NewWithVar(PyTypeObject* type, at::Generator gen);

View File

@ -56,6 +56,16 @@ void THCPGraph_init(PyObject* module) {
.def(
"capture_end",
torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::capture_end))
.def(
"register_generator_state",
[](::at::cuda::CUDAGraph& self, py::handle raw_generator) {
auto generator = THPGenerator_Unwrap(raw_generator.ptr());
// We've unwrapped the Python object to a C++ object,
// so we can release the GIL before calling into C++
py::gil_scoped_release release;
return self.register_generator_state(generator);
},
py::arg("generator"))
.def(
"replay",
torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::replay))