Graph-Safe RNG State Exchange for Tensor Parallelism (#114068)
See #113541. This PR allows registering and controlling multiple RNG states, ensuring cudagraph-safe operations, and includes both C++ and Python API changes to support this functionality. cc @eellison @anijain2305 @jansel @ezyang @ptrblck @csarofeen @mcarilli Pull Request resolved: https://github.com/pytorch/pytorch/pull/114068 Approved by: https://github.com/ezyang, https://github.com/eqy, https://github.com/xuzhao9
Commit: 249e65b92d
Parent: fe41ba4765
Committed by: PyTorch MergeBot
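At the Python level, the workflow this PR enables looks roughly like the following. This is a minimal sketch distilled from the tests and docstrings further down in this diff; the tensor shapes and the choice of the default generator are illustrative only:

import torch

g = torch.cuda.CUDAGraph()
s = torch.cuda.Stream()

torch.rand(1, device="cuda")  # make sure the CUDA RNG machinery is initialized
default_gen = torch.cuda.default_generators[0]
state = default_gen.graphsafe_get_state()  # handle that shares the live state
snapshot = default_gen.clone_state()       # independent copy of the state

# Generator states used inside the capture must be registered beforehand
# (see Note [Explicit Registration of Generators to the CUDA Graph] below).
g.register_generator_state(state)
g.register_generator_state(snapshot)

with torch.cuda.stream(s):
    g.capture_begin()
    default_gen.graphsafe_set_state(snapshot)
    values = torch.rand(5, device="cuda", generator=default_gen)
    g.capture_end()
torch.cuda.current_stream().wait_stream(s)

g.replay()  # each replay refreshes seed/offset and advances the registered states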
@@ -13,4 +13,12 @@ at::Tensor Generator::get_state() const {
   return at::Tensor::wrap_tensor_impl(this->impl_->get_state());
 }
 
+void Generator::graphsafe_set_state(const Generator& new_state) {
+  this->impl_->graphsafe_set_state(new_state.getIntrusivePtr());
+}
+
+Generator Generator::graphsafe_get_state() const {
+  return Generator(this->impl_->graphsafe_get_state());
+}
+
 } // namespace at

@@ -107,6 +107,10 @@ struct TORCH_API Generator {
 
   at::Tensor get_state() const;
 
+  void graphsafe_set_state(const Generator& new_state);
+
+  Generator graphsafe_get_state() const;
+
   std::mutex& mutex() {
     return impl_->mutex_;
   }
@@ -1,5 +1,8 @@
+#include <ATen/Functions.h>
+#include <ATen/Tensor.h>
 #include <ATen/Utils.h>
 #include <ATen/cuda/CUDAGeneratorImpl.h>
+#include <ATen/cuda/CUDAGraph.h>
 #include <ATen/cuda/CUDAGraphsUtils.cuh>
 #include <c10/core/StreamGuard.h>
 #include <c10/cuda/CUDAFunctions.h>

@@ -24,10 +27,10 @@ static std::deque<c10::once_flag> cuda_gens_init_flag;
 static std::vector<Generator> default_gens_cuda;
 
 /*
-* Populates the global variables related to CUDA generators
-* Warning: this function must only be called once!
-*/
-static void initCUDAGenVector(){
+ * Populates the global variables related to CUDA generators
+ * Warning: this function must only be called once!
+ */
+static void initCUDAGenVector() {
   num_gpus = c10::cuda::device_count();
   cuda_gens_init_flag.resize(num_gpus);
   default_gens_cuda.resize(num_gpus);
@@ -77,6 +80,150 @@ Generator createCUDAGenerator(DeviceIndex device_index) {
 
 } // namespace cuda::detail
 
+/**
+ * Creates a clone of this CUDA Generator State.
+ */
+c10::intrusive_ptr<CUDAGeneratorState> CUDAGeneratorState::clone() {
+  return make_intrusive<CUDAGeneratorState>(
+      seed_, philox_offset_per_thread_, offset_intragraph_);
+}
+
+/**
+ * Function to increase the internal offset based on the specified increment.
+ */
+void CUDAGeneratorState::increase(uint64_t increment) {
+  // Rounds increment up to the nearest multiple of 4 to meet alignment
+  // requirements.
+  // see Note [Why enforce RNG offset % 4 == 0?]
+  increment = ((increment + 3) / 4) * 4;
+  // Handling different behaviors based on whether capturing is active.
+  if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) {
+    // Ensures that the state is actually capturing.
+    TORCH_CHECK(
+        capturing_,
+        "Attempt to increase offset for a CUDA generator not in capture mode.");
+    // Ensures the offset is a multiple of 4
+    // see Note [Why enforce RNG offset % 4 == 0?]
+    TORCH_INTERNAL_ASSERT(
+        offset_intragraph_ % 4 == 0, "RNG offset must be a multiple of 4.");
+    // Ensures the increment does not cause overflow.
+    TORCH_INTERNAL_ASSERT(
+        offset_intragraph_ <= std::numeric_limits<uint32_t>::max() - increment,
+        "Increment causes overflow in the offset value.");
+    offset_intragraph_ += increment;
+  } else {
+    // Checks that the increment is expected outside graph capturing.
+    TORCH_CHECK(
+        !capturing_,
+        "Offset increment outside graph capture encountered unexpectedly.");
+    // Ensures the offset is a multiple of 4
+    // see Note [Why enforce RNG offset % 4 == 0?]
+    TORCH_INTERNAL_ASSERT(
+        philox_offset_per_thread_ % 4 == 0,
+        "RNG offset must be a multiple of 4.");
+    philox_offset_per_thread_ += increment;
+  }
+}
+
+/**
+ * Registers this state to a CUDA graph to manage within the graph.
+ */
+void CUDAGeneratorState::register_graph(cuda::CUDAGraph* graph) {
+  // Ensures that the RNG state is not currently being captured.
+  at::cuda::assertNotCapturing(
+      "Cannot register the state during capturing stage.");
+
+  // If this is the first graph to be registered, allocate memory for the seed
+  // and offset on the GPU.
+  if (registered_graphs_.empty()) {
+    auto options = at::TensorOptions().device(at::kCUDA).dtype(at::kLong);
+    seed_extragraph_ = at::empty({1}, options);
+    offset_extragraph_ = at::empty({1}, options);
+  }
+
+  // Insert the graph into the set of registered graphs if it's not already
+  // registered.
+  if (registered_graphs_.find(graph) == registered_graphs_.end()) {
+    registered_graphs_.insert(graph);
+  }
+}
+
+/**
+ * Unregisters a CUDA graph from the RNG state.
+ */
+void CUDAGeneratorState::unregister_graph(cuda::CUDAGraph* graph) {
+  // Ensures that the RNG state is not currently being captured.
+  at::cuda::assertNotCapturing(
+      "Cannot unregister the state during capturing stage.");
+  // Verify the graph was previously registered.
+  TORCH_CHECK(
+      registered_graphs_.find(graph) != registered_graphs_.end(),
+      "The graph should be registered to the state");
+
+  // Remove the graph from the set of registered graphs.
+  registered_graphs_.erase(graph);
+
+  // If no more graphs are registered, deallocate the GPU memory for the seed
+  // and offset.
+  if (registered_graphs_.empty()) {
+    seed_extragraph_.reset();
+    offset_extragraph_.reset();
+  }
+}
+
+/**
+ * Note [Explicit Registration of Generators to the CUDA Graph]
+ * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ *
+ * Ideally, it would be more user-friendly if the state could be exchanged and generators
+ * could be registered with the CUDA graph implicitly. However, resetting GPU tensors during
+ * the capture stage causes these reset operations to be recorded within the CUDA graph.
+ * This behavior is undesirable because we do not want these tensors to be reset during
+ * the replay stage of the graph.
+ *
+ * As of now, there is no available method to perform a CUDA operation during the graph's
+ * recording phase without having that operation be included in the CUDA graph.
+ * This limitation necessitates explicit user action to register generators with the graph.
+ * By requiring users to manually register their generators, we can ensure that state resets
+ * (capture_prologue) only occur before the graph capture begins, thus avoiding unintended
+ * resets during the replay of the graph. See https://github.com/pytorch/pytorch/pull/114068.
+ */
+
+/**
+ * Performs the prologue steps for capturing a CUDA graph state.
+ * This method is intended to reset graph-related state variables before capturing begins.
+ */
+void CUDAGeneratorState::capture_prologue() {
+  capturing_ = true;
+  offset_intragraph_ = 0;
+  seed_extragraph_.fill_(int64_t(seed_));
+  offset_extragraph_.fill_(int64_t(0));
+}
+
+/**
+ * Ends the capturing phase and resets related variables, returning the whole
+ * graph increment.
+ */
+uint64_t CUDAGeneratorState::capture_epilogue() {
+  capturing_ = false;
+  return offset_intragraph_;
+}
+
+/**
+ * Prepares the state for replay by setting initial state tensors and applying
+ * total increment.
+ */
+void CUDAGeneratorState::replay_prologue(uint64_t wholegraph_increment) {
+  // Ensures the generator is not in capturing mode.
+  at::cuda::assertNotCapturing(
+      "Cannot prepare for replay during capturing stage.");
+  seed_extragraph_.fill_(int64_t(seed_));
+  offset_extragraph_.fill_(int64_t(philox_offset_per_thread_));
+  // Applies the total increment achieved during previous captures to update the
+  // offset.
+  increase(wholegraph_increment);
+}
+
 /**
  * Note [Why enforce RNG offset % 4 == 0?]
  * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
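The capture/replay bookkeeping implemented by capture_prologue, capture_epilogue, and replay_prologue above can be summarized in a small Python model. This is illustrative pseudocode of the C++ logic (the class name and fields are hypothetical, not a real PyTorch API):

class GeneratorStateModel:
    """Mirrors CUDAGeneratorState: a seed, a global offset, and a per-capture offset."""

    def __init__(self, seed=0):
        self.seed = seed
        self.philox_offset = 0       # advanced by normal (non-captured) RNG ops
        self.offset_intragraph = 0   # advanced by RNG ops recorded during capture
        self.capturing = False

    def capture_prologue(self):
        # Before capture: zero the relative offset and publish seed / offset=0
        # into the GPU-side tensors the graphed kernels will read.
        self.capturing = True
        self.offset_intragraph = 0

    def capture_epilogue(self):
        # Total RNG consumed inside the graph becomes the per-replay increment.
        self.capturing = False
        return self.offset_intragraph

    def replay_prologue(self, wholegraph_increment):
        # Before each replay: publish the current seed and offset to the
        # GPU-side tensors, then reserve the whole graph's worth of offset
        # so subsequent eager RNG ops do not overlap the replayed streams.
        self.philox_offset += wholegraph_increment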
@@ -97,8 +244,18 @@ Generator createCUDAGenerator(DeviceIndex device_index) {
  */
 CUDAGeneratorImpl::CUDAGeneratorImpl(DeviceIndex device_index)
     : c10::GeneratorImpl{Device(DeviceType::CUDA, device_index),
-          DispatchKeySet(c10::DispatchKey::CUDA)} {
+                         DispatchKeySet(c10::DispatchKey::CUDA)} {
   at::cuda::assertNotCapturing("Cannot construct a new CUDAGeneratorImpl");
+  state_ = make_intrusive<CUDAGeneratorState>();
   no_reset_rnn_state_.clear();
 }
 
+CUDAGeneratorImpl::CUDAGeneratorImpl(
+    DeviceIndex device_index,
+    c10::intrusive_ptr<CUDAGeneratorState> state)
+    : c10::
+          GeneratorImpl{Device(DeviceType::CUDA, device_index), DispatchKeySet(c10::DispatchKey::CUDA)},
+      state_(std::move(state)) {
+  no_reset_rnn_state_.clear();
+}
+
@@ -109,9 +266,10 @@ CUDAGeneratorImpl::CUDAGeneratorImpl(DeviceIndex device_index)
  * See Note [Acquire lock when using random generators]
  */
 void CUDAGeneratorImpl::set_current_seed(uint64_t seed) {
-  at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::set_current_seed");
-  seed_ = seed;
-  philox_offset_per_thread_ = 0;
+  at::cuda::assertNotCapturing(
+      "Cannot call CUDAGeneratorImpl::set_current_seed");
+  state_->seed_ = seed;
+  state_->philox_offset_per_thread_ = 0;
   no_reset_rnn_state_.clear();
 }

@@ -134,15 +292,9 @@ uint64_t CUDAGeneratorImpl::get_offset() const {
   // Debatable if get_offset() should be allowed in captured regions.
   // Conservatively disallow it for now.
   at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::get_offset");
-  return philox_offset_per_thread_;
+  return state_->philox_offset_per_thread_;
 }
 
-#define CAPTURE_DEFAULT_GENS_MSG \
-"In regions captured by CUDA graphs, you may only use the default CUDA RNG " \
-"generator on the device that's current when capture begins. " \
-"If you need a non-default (user-supplied) generator, or a generator on another " \
-"device, please file an issue."
-
 /**
  * Gets the current seed of CUDAGeneratorImpl.
  */
@@ -150,7 +302,7 @@ uint64_t CUDAGeneratorImpl::current_seed() const {
   // Debatable if current_seed() should be allowed in captured regions.
   // Conservatively disallow it for now.
   at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::current_seed");
-  return seed_;
+  return state_->seed_;
 }
 
 /**
@@ -194,6 +346,8 @@ c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
  * and size of the internal state.
  */
 void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
+  at::cuda::assertNotCapturing(
+      "Please ensure to utilize the CUDAGeneratorImpl::set_state_index method during capturing.");
   static const size_t seed_size = sizeof(uint64_t);
   static const size_t offset_size = sizeof(int64_t);
   static const size_t total_size = seed_size + offset_size;
@@ -208,7 +362,7 @@ void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
     TORCH_CHECK(new_state_size == total_size, "RNG state is wrong size");
   }
 
-  uint64_t input_seed;
+  uint64_t input_seed = 0;
   auto new_rng_state = new_state.data_dtype_initialized<uint8_t>();
   memcpy(&input_seed, new_rng_state, seed_size);
   this->set_current_seed(input_seed);
@@ -219,44 +373,59 @@ void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
   this->set_philox_offset_per_thread(static_cast<uint64_t>(philox_offset));
 }
 
+/**
+ * Sets the generator's current state to the given new state.
+ * This function allows switching between different registered states of
+ * the generator.
+ */
+void CUDAGeneratorImpl::graphsafe_set_state(
+    const c10::intrusive_ptr<GeneratorImpl>& gen) {
+  c10::intrusive_ptr<CUDAGeneratorImpl> cuda_gen =
+      dynamic_intrusive_pointer_cast<CUDAGeneratorImpl>(gen);
+  TORCH_CHECK(cuda_gen, "Expected a CUDA Generator");
+  state_ = cuda_gen->state_;
+}
+
+/**
+ * Gets a GeneratorImpl that points to the current state_.
+ */
+c10::intrusive_ptr<c10::GeneratorImpl> CUDAGeneratorImpl::graphsafe_get_state()
+    const {
+  auto gen = make_intrusive<CUDAGeneratorImpl>(device().index(), state_);
+  return gen;
+}
+
 /**
  * Sets the philox_offset_per_thread_ to be used by curandStatePhilox4_32_10
  *
  * See Note [Acquire lock when using random generators]
  */
 void CUDAGeneratorImpl::set_philox_offset_per_thread(uint64_t offset) {
-  at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::set_philox_offset_per_thread");
   // see Note [Why enforce RNG offset % 4 == 0?]
   TORCH_CHECK(offset % 4 == 0, "offset must be a multiple of 4");
-  philox_offset_per_thread_ = offset;
+  state_->philox_offset_per_thread_ = offset;
 }
 
 /**
  * Gets the current philox_offset_per_thread_ of CUDAGeneratorImpl.
  */
 uint64_t CUDAGeneratorImpl::philox_offset_per_thread() const {
-  at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::philox_offset_per_thread");
-  return philox_offset_per_thread_;
+  return state_->philox_offset_per_thread_;
 }
 
 /**
- * Called by CUDAGraph to prepare this instance for a graph capture region.
- * offset_extragraph is the initial offset at the start of the graphed region.
- * offset_intragraph tracks the offset in the graphed region.
+ * Registers this state to a CUDA graph to manage within the graph.
  */
-void CUDAGeneratorImpl::capture_prologue(int64_t* seed_extragraph, int64_t* offset_extragraph) {
-  seed_extragraph_ = seed_extragraph;
-  offset_extragraph_ = offset_extragraph;
-  offset_intragraph_ = 0;
-  graph_expects_this_gen_ = true;
+void CUDAGeneratorImpl::register_graph(cuda::CUDAGraph* graph) {
+  graph->register_generator_state(state_);
+  state_->register_graph(graph);
 }
 
 /**
- * Called by CUDAGraph to finalize a graph capture region for this instance.
+ * Unregisters a CUDA graph from the RNG state.
  */
-uint64_t CUDAGeneratorImpl::capture_epilogue() {
-  graph_expects_this_gen_ = false;
-  return offset_intragraph_;
+void CUDAGeneratorImpl::unregister_graph(cuda::CUDAGraph* graph) {
+  state_->unregister_graph(graph);
 }
 
 /**
@@ -281,30 +450,17 @@ uint64_t CUDAGeneratorImpl::capture_epilogue() {
  * See Note [Acquire lock when using random generators]
  */
 PhiloxCudaState CUDAGeneratorImpl::philox_cuda_state(uint64_t increment) {
-  // rounds increment up to the nearest multiple of 4
-  increment = ((increment + 3) / 4) * 4;
   if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) {
-    TORCH_CHECK(graph_expects_this_gen_,
-                "philox_cuda_state for an unexpected CUDA generator used during capture. "
-                CAPTURE_DEFAULT_GENS_MSG);
-    // see Note [Why enforce RNG offset % 4 == 0?]
-    TORCH_INTERNAL_ASSERT(this->offset_intragraph_ % 4 == 0);
-    uint32_t offset = this->offset_intragraph_;
-    TORCH_INTERNAL_ASSERT(this->offset_intragraph_ <=
-                          std::numeric_limits<uint32_t>::max() - increment);
-    this->offset_intragraph_ += increment;
-    return PhiloxCudaState(this->seed_extragraph_,
-                           this->offset_extragraph_,
-                           offset);
+    uint32_t offset = state_->offset_intragraph_;
+    state_->increase(increment);
+    return PhiloxCudaState(
+        state_->seed_extragraph_.data_ptr<int64_t>(),
+        state_->offset_extragraph_.data_ptr<int64_t>(),
+        offset);
   } else {
-    TORCH_CHECK(!graph_expects_this_gen_,
-                "CUDA generator expects graph capture to be underway, "
-                "but the current stream is not capturing.");
-    // see Note [Why enforce RNG offset % 4 == 0?]
-    TORCH_INTERNAL_ASSERT(this->philox_offset_per_thread_ % 4 == 0);
-    uint64_t offset = this->philox_offset_per_thread_;
-    this->philox_offset_per_thread_ += increment;
-    return PhiloxCudaState(this->seed_, offset);
+    uint64_t offset = state_->philox_offset_per_thread_;
+    state_->increase(increment);
+    return PhiloxCudaState(state_->seed_, offset);
   }
 }

@@ -312,16 +468,13 @@ PhiloxCudaState CUDAGeneratorImpl::philox_cuda_state(uint64_t increment) {
  * Temporarily accommodates call sites that use philox_engine_inputs.
  * Allows incremental refactor of call sites to use philox_cuda_state.
  */
-std::pair<uint64_t, uint64_t> CUDAGeneratorImpl::philox_engine_inputs(uint64_t increment) {
-  at::cuda::assertNotCapturing("Refactor this op to use CUDAGeneratorImpl::philox_cuda_state. "
-                               "Cannot call CUDAGeneratorImpl::philox_engine_inputs");
-  // rounds increment up to the nearest multiple of 4
-  increment = ((increment + 3) / 4) * 4;
-  // see Note [Why enforce RNG offset % 4 == 0?]
-  TORCH_INTERNAL_ASSERT(this->philox_offset_per_thread_ % 4 == 0);
-  uint64_t offset = this->philox_offset_per_thread_;
-  this->philox_offset_per_thread_ += increment;
-  return std::make_pair(this->seed_, offset);
+std::pair<uint64_t, uint64_t> CUDAGeneratorImpl::philox_engine_inputs(
+    uint64_t increment) {
+  at::cuda::assertNotCapturing(
+      "Refactor this op to use CUDAGeneratorImpl::philox_cuda_state. Cannot call CUDAGeneratorImpl::philox_engine_inputs");
+  uint64_t offset = state_->philox_offset_per_thread_;
+  state_->increase(increment);
+  return std::make_pair(state_->seed_, offset);
 }
 
 /*
@@ -348,9 +501,7 @@ std::shared_ptr<CUDAGeneratorImpl> CUDAGeneratorImpl::clone() const {
  */
 CUDAGeneratorImpl* CUDAGeneratorImpl::clone_impl() const {
   at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::clone_impl");
-  auto gen = new CUDAGeneratorImpl(this->device().index());
-  gen->set_current_seed(this->seed_);
-  gen->set_philox_offset_per_thread(this->philox_offset_per_thread_);
+  auto gen = new CUDAGeneratorImpl(this->device().index(), state_->clone());
   return gen;
 }
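A practical consequence of routing everything through the shared state_ pointer: a generator returned by graphsafe_get_state aliases the live state, while clone_state (bound on the Python side later in this diff) yields an independent snapshot. A hedged Python illustration of that behavior:

gen = torch.Generator(device="cuda")
torch.rand(1, device="cuda")  # ensure the CUDA generator machinery is initialized

shared = gen.graphsafe_get_state()  # shares gen's underlying CUDAGeneratorState
frozen = gen.clone_state()          # detached copy of the state

gen.manual_seed(123)
# `shared` observes the reseed because it points at the same state object;
# `frozen` still holds the seed/offset captured before manual_seed.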
@@ -1,12 +1,19 @@
 #pragma once
 
-#include <ATen/core/Generator.h>
-#include <ATen/cuda/PhiloxCudaState.h>
 #include <ATen/Context.h>
-#include <limits>
+#include <ATen/core/Generator.h>
+#include <ATen/core/TensorBase.h>
+#include <ATen/cuda/PhiloxCudaState.h>
 #include <atomic>
 
+#include <limits>
+#include <memory>
+#include <unordered_set>
 namespace at {
 
+namespace cuda {
+struct CUDAGraph;
+}
+
@@ -87,9 +94,41 @@ namespace at {
  *
  */
 
+struct CUDAGeneratorState : public c10::intrusive_ptr_target {
+  uint64_t seed_;
+  uint64_t philox_offset_per_thread_;
+  uint32_t offset_intragraph_;
+  bool capturing_{};
+  std::unordered_set<cuda::CUDAGraph*> registered_graphs_;
+  at::TensorBase seed_extragraph_{};
+  at::TensorBase offset_extragraph_{};
+
+  CUDAGeneratorState(
+      uint64_t seed = default_rng_seed_val,
+      uint64_t philox_offset_per_thread = 0,
+      uint32_t offset_intragraph = 0)
+      : seed_(seed),
+        philox_offset_per_thread_(philox_offset_per_thread),
+        offset_intragraph_(offset_intragraph) {}
+
+  void increase(uint64_t increment);
+
+  void register_graph(cuda::CUDAGraph* graph);
+  void unregister_graph(cuda::CUDAGraph* graph);
+
+  void capture_prologue();
+  // capture_epilogue returns the wholegraph_increment
+  uint64_t capture_epilogue();
+  void replay_prologue(uint64_t wholegraph_increment);
+  c10::intrusive_ptr<CUDAGeneratorState> clone();
+};
+
 struct TORCH_CUDA_CPP_API CUDAGeneratorImpl : public c10::GeneratorImpl {
   // Constructors
   CUDAGeneratorImpl(DeviceIndex device_index = -1);
+  CUDAGeneratorImpl(
+      DeviceIndex device_index,
+      c10::intrusive_ptr<CUDAGeneratorState> state_);
   ~CUDAGeneratorImpl() override = default;
 
   // CUDAGeneratorImpl methods
@@ -101,10 +140,18 @@ struct TORCH_CUDA_CPP_API CUDAGeneratorImpl : public c10::GeneratorImpl {
   uint64_t seed() override;
   void set_state(const c10::TensorImpl& new_state) override;
   c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
+  void graphsafe_set_state(
+      const c10::intrusive_ptr<GeneratorImpl>& state) override;
+  c10::intrusive_ptr<c10::GeneratorImpl> graphsafe_get_state() const override;
+
   void set_philox_offset_per_thread(uint64_t offset);
   uint64_t philox_offset_per_thread() const;
-  void capture_prologue(int64_t* seed_extragraph, int64_t* offset_extragraph);
-  uint64_t capture_epilogue();
+
+  void register_graph(cuda::CUDAGraph* graph);
+  void unregister_graph(cuda::CUDAGraph* graph);
+
+  // Generates a PhiloxCudaState with a specified increment, and increments
+  // the current state
   PhiloxCudaState philox_cuda_state(uint64_t increment);
 
   bool reset_rnn_state() {
@@ -117,14 +164,10 @@ struct TORCH_CUDA_CPP_API CUDAGeneratorImpl : public c10::GeneratorImpl {
 
   static c10::DeviceType device_type();
 
-private:
+ private:
   CUDAGeneratorImpl* clone_impl() const override;
-  uint64_t seed_ = default_rng_seed_val;
-  uint64_t philox_offset_per_thread_ = 0;
-  int64_t* seed_extragraph_{};
-  int64_t* offset_extragraph_{};
-  uint32_t offset_intragraph_ = 0;
-  bool graph_expects_this_gen_ = false;
+
+  c10::intrusive_ptr<CUDAGeneratorState> state_;
   std::atomic_flag no_reset_rnn_state_;
 };
@@ -6,7 +6,10 @@
 #include <c10/cuda/CUDAFunctions.h>
 
 #include <chrono>
+#include <cstddef>
+#include <cstdint>
 #include <thread>
+#include <vector>
 
 namespace at::cuda {
 
@@ -86,26 +89,33 @@ CUDAGraph::CUDAGraph()
 #endif
 }
 
+void CUDAGraph::register_generator_state(
+    c10::intrusive_ptr<at::CUDAGeneratorState> state) {
+  captured_generator_states_[std::move(state)] = 0;
+}
+
+void CUDAGraph::register_generator_state(const at::Generator& generator) {
+  c10::intrusive_ptr<CUDAGeneratorImpl> cuda_gen =
+      dynamic_intrusive_pointer_cast<CUDAGeneratorImpl>(
+          generator.getIntrusivePtr());
+  cuda_gen->register_graph(this);
+}
+
 void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/, cudaStreamCaptureMode capture_mode) {
 #if !defined(USE_ROCM) || ROCM_VERSION >= 50300
   TORCH_CHECK(!has_graph_exec_,
               "This CUDAGraph instance already owns a captured graph. "
               "To capture a new graph, create a new instance.");
 
-  // For now, a CUDAGraph instance only accommodates the default generator on the device that's
-  // current when capture begins. If any op in the captured region uses a non-default generator,
-  // or a generator on another device, the offending generator will throw an error.
-  // These restrictions simplify CUDAGraph, but could be relaxed in the future:
-  // in principle, the underlying Cuda calls do permit cross-device ops to be captured.
+  // default generator is always registered
   auto* gen = get_generator_or_default<CUDAGeneratorImpl>(
       c10::nullopt, cuda::detail::getDefaultCUDAGenerator());
+  gen->register_graph(this);
 
-  auto options = TensorOptions().device(at::kCUDA).dtype(at::kLong);
-  seed_extragraph_ = at::empty({1}, options);
-  offset_extragraph_ = at::empty({1}, options);
-
-  seed_extragraph_.fill_(int64_t(gen->current_seed()));
-  gen->capture_prologue(seed_extragraph_.data_ptr<int64_t>(), offset_extragraph_.mutable_data_ptr<int64_t>());
+  for (auto& [generator_state, wholegraph_increments] :
+       captured_generator_states_) {
+    generator_state->capture_prologue();
+  }
 
   auto stream = at::cuda::getCurrentCUDAStream();
 
@@ -115,7 +125,6 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/, cudaStreamCaptureMode capt
               "default stream.)");
 
   capture_stream_ = stream;
-  capture_gen_ = gen;
   capture_dev_ = c10::cuda::current_device();
 
   id_ = capture_sequence_id();
@@ -215,13 +224,10 @@ void CUDAGraph::capture_end() {
 
   has_graph_exec_ = true;
 
-  auto* gen = get_generator_or_default<CUDAGeneratorImpl>(
-      c10::nullopt, cuda::detail::getDefaultCUDAGenerator());
-  TORCH_CHECK(gen == capture_gen_,
-              "Default CUDA RNG generator on current device at capture end "
-              "is different from default generator on current device "
-              "when capture began");
-  wholegraph_increment_ = gen->capture_epilogue();
+  for (auto& [generator_state, wholegraph_increments] :
+       captured_generator_states_) {
+    wholegraph_increments = generator_state->capture_epilogue();
+  }
 
   size_t numCUDAGraphNodes = 0;
   AT_CUDA_CHECK(cudaGraphGetNodes(graph_, NULL, &numCUDAGraphNodes));
@@ -251,17 +257,10 @@ void CUDAGraph::replay() {
 
   c10::OptionalDeviceGuard device_guard{capture_stream_.device()};
 
-  // Just like any RNG consumer kernel!
-  auto* gen = get_generator_or_default<CUDAGeneratorImpl>(
-      c10::nullopt, cuda::detail::getDefaultCUDAGenerator());
-  PhiloxCudaState rng_engine_inputs;
-  {
-    std::lock_guard<std::mutex> lock(gen->mutex_);
-    rng_engine_inputs = gen->philox_cuda_state(wholegraph_increment_);
+  for (auto& [generator_state, wholegraph_increments] :
+       captured_generator_states_) {
+    generator_state->replay_prologue(wholegraph_increments);
   }
-  seed_extragraph_.fill_(int64_t(gen->current_seed()));
-  offset_extragraph_.fill_(int64_t(rng_engine_inputs.offset_.val));
 
   // graph_exec_ may be replayed in any stream.
   AT_CUDA_CHECK(cudaGraphLaunch(graph_exec_, at::cuda::getCurrentCUDAStream()));

@@ -355,6 +354,10 @@ TORCH_CHECK(has_graph_exec_,
 }
 
 CUDAGraph::~CUDAGraph() {
+  for (auto& [generator_state, wholegraph_increments] :
+       captured_generator_states_) {
+    generator_state->unregister_graph(this);
+  }
   reset();
 }
@@ -4,12 +4,13 @@
 #include <c10/core/Device.h>
 #include <c10/cuda/CUDAGraphsC10Utils.h>
 #include <c10/cuda/CUDAStream.h>
-
-#include <mutex>
+#include <c10/util/flat_hash_map.h>
 
 namespace at {
 
 struct Generator;
 struct CUDAGeneratorImpl;
+struct CUDAGeneratorState;
 
 namespace cuda {
 
@@ -24,7 +25,12 @@ struct TORCH_CUDA_CPP_API CUDAGraph {
   static void inc_pending_event_queries();
   static void dec_pending_event_queries();
   static int num_pending_event_queries();
-  void capture_begin(MempoolId_t pool={0, 0}, cudaStreamCaptureMode capture_mode = cudaStreamCaptureModeGlobal);
+  // See Note [Explicit Registration of Generators to the CUDA Graph]
+  void register_generator_state(c10::intrusive_ptr<at::CUDAGeneratorState> state);
+  void register_generator_state(const at::Generator& generator);
+  void capture_begin(
+      MempoolId_t pool = {0, 0},
+      cudaStreamCaptureMode capture_mode = cudaStreamCaptureModeGlobal);
   void capture_end();
   void replay();
   void reset();
@@ -32,7 +38,7 @@ struct TORCH_CUDA_CPP_API CUDAGraph {
   void enable_debug_mode();
   void debug_dump(const std::string& debug_path);
 
-protected:
+ protected:
 #if !defined(USE_ROCM) || ROCM_VERSION >= 50300
   cudaGraph_t graph_ = NULL;
   cudaGraphExec_t graph_exec_ = NULL;
@@ -73,19 +79,16 @@ struct TORCH_CUDA_CPP_API CUDAGraph {
   // Stream on which capture began
   at::cuda::CUDAStream capture_stream_;
 
-  // Default generator on device where capture began
-  at::CUDAGeneratorImpl* capture_gen_;
+  // multiple generator states and their wholegraph_increments in this graph
+  // that are managed by the CUDA Graph
+  ska::flat_hash_map<c10::intrusive_ptr<at::CUDAGeneratorState>, uint64_t>
+      captured_generator_states_;
 
   // Device where capture occurred. Right now, for simplicity, we require all ops
   // in a capture to run on the same device, but this is a limitation of CUDAGraph,
   // not CUDA itself. We can straightforwardly modify CUDAGraph to support multi-device
   // captures if needed.
   int capture_dev_;
 
-  // RNG state trackers
-  at::Tensor seed_extragraph_;
-  at::Tensor offset_extragraph_;
-  uint64_t wholegraph_increment_;
 };
 
 } // namespace cuda
@@ -31,6 +31,18 @@ c10::intrusive_ptr<GeneratorImpl> GeneratorImpl::clone() const {
   return c10::intrusive_ptr<GeneratorImpl>::reclaim(res);
 }
 
+void GeneratorImpl::graphsafe_set_state(
+    const c10::intrusive_ptr<c10::GeneratorImpl>& state) {
+  TORCH_CHECK_NOT_IMPLEMENTED(
+      false, "graphsafe_set_state is not supported in this Generator");
+}
+
+c10::intrusive_ptr<c10::GeneratorImpl> GeneratorImpl::graphsafe_get_state()
+    const {
+  TORCH_CHECK_NOT_IMPLEMENTED(
+      false, "graphsafe_get_state is not supported in this Generator");
+}
+
 /**
  * Gets the device of a generator.
 */
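Because these base-class defaults fail unconditionally, only backends that override the hooks (in this PR, the CUDA generator) support the new calls. From Python, a CPU generator is expected to surface this as a NotImplementedError; a hedged sketch:

cpu_gen = torch.Generator()  # CPU generator: no graphsafe override
try:
    cpu_gen.graphsafe_get_state()
except NotImplementedError:
    pass  # "graphsafe_get_state is not supported in this Generator"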
@@ -73,6 +73,9 @@ struct C10_API GeneratorImpl : public c10::intrusive_ptr_target {
   virtual uint64_t seed() = 0;
   virtual void set_state(const c10::TensorImpl& new_state) = 0;
   virtual c10::intrusive_ptr<c10::TensorImpl> get_state() const = 0;
+  virtual void graphsafe_set_state(
+      const c10::intrusive_ptr<c10::GeneratorImpl>& new_state);
+  virtual c10::intrusive_ptr<c10::GeneratorImpl> graphsafe_get_state() const;
   Device device() const;
 
   // See Note [Acquire lock when using random generators]
@@ -976,9 +976,12 @@ Violating any of these will likely cause a runtime error:
   :class:`~torch.cuda.graph` and
   :func:`~torch.cuda.make_graphed_callables` set a side stream for you.)
 * Ops that synchronize the CPU with the GPU (e.g., ``.item()`` calls) are prohibited.
-* CUDA RNG ops are allowed, but must use default generators. For example, explicitly constructing a
-  new :class:`torch.Generator` instance and passing it as the ``generator`` argument to an RNG function
-  is prohibited.
+* CUDA RNG operations are permitted, and when using multiple :class:`torch.Generator` instances within a graph,
+  they must be registered using :meth:`CUDAGraph.register_generator_state<torch.cuda.CUDAGraph.register_generator_state>` before graph capture.
+  Avoid using :meth:`Generator.get_state<torch.get_state>` and :meth:`Generator.set_state<torch.set_state>` during capture;
+  instead, utilize :meth:`Generator.graphsafe_set_state<torch.Generator.graphsafe_set_state>` and :meth:`Generator.graphsafe_get_state<torch.Generator.graphsafe_get_state>`
+  for managing generator states safely within the graph context. This ensures proper RNG operation and generator management within CUDA graphs.
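A compact sketch of the documented pattern, assuming a user-created CUDA generator and a single graph (based on the tests added later in this diff):

gen = torch.Generator(device="cuda")
torch.rand(1, device="cuda")           # initialize the CUDA RNG machinery
state = gen.graphsafe_get_state()      # capture-safe handle to gen's state

g = torch.cuda.CUDAGraph()
g.register_generator_state(state)      # must happen before capture begins

s = torch.cuda.Stream()
with torch.cuda.stream(s):
    g.capture_begin()
    # Switch states with graphsafe_set_state; get_state()/set_state()
    # copy tensors and are not capture-safe here.
    gen.graphsafe_set_state(state)
    torch.rand(5, device="cuda", generator=gen)
    g.capture_end()
torch.cuda.current_stream().wait_stream(s)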

Violating any of these will likely cause silent numerical errors or undefined behavior:
@@ -171,7 +171,6 @@ class TestCuda(TestCase):
         tensor.fill_(1)
         self.assertTrue((tensor == 1).all())
 
-
     @unittest.skipIf(TEST_CUDAMALLOCASYNC or IS_JETSON, "Segmentation fault (core dumped)")
     def test_out_of_memory_retry(self):
         torch.cuda.empty_cache()
@@ -256,7 +255,6 @@ class TestCuda(TestCase):
             c.copy_(b, non_blocking=True)
         self.assertEqual(a, c, exact_dtype=False)
 
-
     def test_to_non_blocking(self):
         stream = torch.cuda.current_stream()
@@ -384,7 +382,6 @@ class TestCuda(TestCase):
         self.assertEqual(torch._C._get_cublas_allow_bf16_reduced_precision_reduction(), not orig)
         torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig
 
-
     def test_cudnn_allow_tf32_get_set(self):
         with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=False):
             self.assertFalse(torch.backends.cudnn.allow_tf32)
@@ -1463,8 +1460,6 @@ torch.cuda.synchronize()
         for op, args in self.autocast_lists.nn_fp16:
             self._run_autocast_outofplace(op, args, torch.float16, module=torch._C._nn)
 
-
-
     @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
     def test_autocast_nn_bf16(self):
         with torch.backends.cudnn.flags(enabled=True, deterministic=True):
@@ -1783,6 +1778,163 @@ torch.cuda.synchronize()
 
         self.assertTrue(b.sum().item() == 11000.)
 
+    @unittest.skipIf(
+        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
+    )
+    def test_graphsafe_set_get_rng_state(self):
+        # Define a function to create generator states, with optional graph registration
+        def create_states(generator):
+            """Initializes generator states and registers them with a CUDA graph if provided."""
+            # Ensure the CUDA generator is initialized
+            torch.rand(1, device="cuda")
+            generator.manual_seed(0)
+
+            # Save the current state of the generator
+            old_state = generator.graphsafe_get_state()
+            # Create and save a cloned state of the generator
+            new_state = generator.clone_state()
+            # Return the original generator and its two states
+            return generator, old_state, new_state
+
+        def register_states_to_graph(generator_state, graph):
+            generator, old_state, new_state = generator_state
+            graph.register_generator_state(old_state)
+            graph.register_generator_state(new_state)
+
+        # Define a function to perform specific RNG actions using the generator's states
+        def perform_random_generation_steps(generator_state):
+            generator, old_state, new_state = generator_state
+            random_values = []
+
+            # Generate random numbers with the new generator state
+            generator.graphsafe_set_state(new_state)
+            random_values.append(torch.rand(5, device="cuda", generator=generator))
+
+            # Generate random numbers twice with the old generator state
+            generator.graphsafe_set_state(old_state)
+            random_values.extend(
+                [torch.rand(5, device="cuda", generator=generator) for _ in range(2)]
+            )
+
+            return random_values
+
+        # Define a function to retrieve the final offsets of the original and new generator states
+        def get_final_offsets_of_states(generator_state):
+            generator, old_state, new_state = generator_state
+            old_state_offset = old_state.get_offset()
+            new_state_offset = new_state.get_offset()
+            return old_state_offset, new_state_offset
+
+        # Set up and test a new CUDA generator
+        generator = torch.Generator(device="cuda")
+        generator_state = create_states(generator)
+
+        # Set up and test the default CUDA generator with a CUDA Graph
+        g = torch.cuda.CUDAGraph()
+        s = torch.cuda.Stream()
+        default_generator = torch.cuda.default_generators[0]
+        default_generator_state = create_states(default_generator)
+        register_states_to_graph(default_generator_state, g)
+
+        # Perform random number generation within a CUDA graph
+        with torch.cuda.stream(s):
+            g.capture_begin()
+            graphed_random_values = perform_random_generation_steps(
+                default_generator_state
+            )
+            g.capture_end()
+
+        # Synchronize the streams and replay the graph
+        torch.cuda.current_stream().wait_stream(s)
+        for _ in range(3):
+            random_values = perform_random_generation_steps(generator_state)
+            g.replay()
+            offset = get_final_offsets_of_states(generator_state)
+            graph_offset = get_final_offsets_of_states(default_generator_state)
+
+            # Compare the final offsets of states for both generators to ensure consistency
+            self.assertTrue(offset == graph_offset)
+            # Compare the states generated outside and inside the graph
+            self.assertEqual(random_values, graphed_random_values)
+
+    @unittest.skipIf(
+        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
+    )
+    def test_memory_stats_of_multiple_generators_and_graphs(self):
+        # Function to clear CUDA cache and collect garbage
+        def clear_cuda_cache():
+            gc.collect()
+            torch.cuda.empty_cache()
+
+        # Executes a simple graph task which includes capturing and executing a random number generation within a CUDA graph.
+        def simple_graph_task(graph):
+            s = torch.cuda.Stream()
+            with torch.cuda.stream(s):
+                graph.capture_begin()
+                torch.rand(1, device="cuda")
+                graph.capture_end()
+            torch.cuda.current_stream().wait_stream(s)
+            graph.replay()  # Replays the captured operations
+
+        def get_memory_stats():
+            stats = torch.cuda.memory_stats()
+            num_blocks = stats["active.all.current"]
+            total_size = stats["active_bytes.all.current"]
+            return num_blocks, total_size
+
+        def test(num_graphs, num_generators):
+            baseline = get_memory_stats()
+            baseline_num_blocks, baseline_total_size = baseline
+
+            # Allocate CUDA graphs
+            graphs = [torch.cuda.CUDAGraph() for _ in range(num_graphs)]
+
+            # Allocate and manage generator states
+            default_generator = torch.cuda.default_generators[0]
+            generators = [default_generator.graphsafe_get_state()]
+
+            # Starts from 1 as one state is already added
+            for _ in range(1, num_generators):
+                generators.append(default_generator.clone_state())
+
+            for graph in graphs:
+                for generator_state in generators:
+                    graph.register_generator_state(generator_state)
+                simple_graph_task(graph)
+
+            # Assert conditions after graph tasks
+            num_blocks, total_size = get_memory_stats()
+            # The allocated blocks should only be proportional to the number of generators
+            expected_blocks_diff = 2 * num_generators
+            expected_size_diff = 2 * 512 * num_generators  # Each block's size is 512
+
+            self.assertTrue(
+                (num_blocks - baseline_num_blocks) == expected_blocks_diff,
+                "Unexpected number of active blocks.",
+            )
+            self.assertTrue(
+                (total_size - baseline_total_size) == expected_size_diff,
+                "Unexpected total memory size.",
+            )
+
+            # Cleanup graphs and clear CUDA cache
+            while graphs:
+                graph = graphs.pop()
+                del graph
+            clear_cuda_cache()
+
+            # Assert that memory stats return to baseline after cleanup
+            self.assertTrue(
+                get_memory_stats() == baseline,
+                "Memory stats do not match baseline after cleanup.",
+            )
+
+        # Running the test function with different parameters
+        test(1, 1)
+        test(3, 2)
+        test(10, 20)
+
     @unittest.skipIf(not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs")
     def test_graph_capture_reset_recapture(self):
         s = torch.cuda.Stream()
@@ -2332,14 +2484,18 @@ exit(2)
 
         s = torch.cuda.Stream()
 
-        for (numel,
-             delta_cudaMallocs,
-             delta_cudaMalloc_bytes,
-             delta_cudaMalloc_bytes_post_del_g,
-             pool_string) in cases:
+        for (
+            numel,
+            delta_cudaMallocs,
+            delta_cudaMalloc_bytes,
+            delta_cudaMalloc_bytes_post_del_g,
+            pool_string,
+        ) in cases:
             if pool_string == "small_pool":
                 delta_active_blocks = 3  # one from "b" plus a sneaky two from CUDAGraph's one-element rng seed and offset holders
-                delta_active_bytes = numel * elem + 1024  # + 1024 for CUDAGraph's rng seed and offset holders each
+                delta_active_bytes = (
+                    numel * elem + 1024
+                )  # + 1024 for CUDAGraph's rng seed and offset holders each
             else:
                 delta_active_blocks = 1  # We only check the large pool, which isn't affected by rng offset holder
                 delta_active_bytes = numel * elem
@@ -3085,8 +3241,6 @@ exit(2)
         self.assertEqual(rc, "3")
 
-
-
 @torch.testing._internal.common_utils.markDynamoStrictTest
 class TestCudaMallocAsync(TestCase):
     @unittest.skipIf(TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync")
@@ -3165,7 +3319,6 @@ class TestCudaMallocAsync(TestCase):
         finally:
             torch.cuda.memory._record_memory_history(None)
 
-
     @skipIfRocm
    def test_memory_profiler_viz(self):
        with torch.profiler.profile(
@@ -3448,7 +3601,6 @@ class TestCudaMallocAsync(TestCase):
         with self.assertRaises(RuntimeError):
             torch.cuda.memory._set_allocator_settings("pinned_num_register_threads:1024")
 
-
     @parametrize(
         "max_split_size_mb_setting", [False, True]
     )
@@ -1385,6 +1385,9 @@ class Generator:
    def __init__(self, device: Optional[DeviceLikeType] = None) -> None: ...
    def get_state(self) -> Tensor: ...
    def set_state(self, _new_state: Tensor) -> Generator: ...
+    def clone_state(self) -> Generator: ...
+    def graphsafe_get_state(self) -> Generator: ...
+    def graphsafe_set_state(self, _new_state: Generator) -> Generator: ...
    def set_offset(self, offset: _int) -> Generator: ...
    def get_offset(self) -> _int: ...
    def manual_seed(self, seed: _int) -> Generator: ...

@@ -1883,6 +1886,7 @@ class _CudaEventBase:
 class _CUDAGraph:
     def capture_begin(self, pool: Optional[Tuple[_int, _int]] = ..., capture_error_mode: str = "global") -> None: ...
     def capture_end(self) -> None: ...
+    def register_generator_state(self, Generator) -> None: ...
     def replay(self) -> None: ...
     def reset(self) -> None: ...
     def pool(self) -> Tuple[_int, _int]: ...
@@ -13722,6 +13722,58 @@ Example::
 """,
 )
 
+add_docstr(
+    torch.Generator.graphsafe_set_state,
+    r"""
+Generator.graphsafe_set_state(state) -> None
+
+Sets the state of the generator to the specified state in a manner that is safe for use in graph capture.
+This method is crucial for ensuring that the generator's state can be captured in the CUDA graph.
+
+Arguments:
+    state (torch.Generator): A Generator pointing to the new state for the generator, typically obtained from `graphsafe_get_state`.
+
+Example:
+    >>> g_cuda = torch.Generator(device='cuda')
+    >>> g_cuda_other = torch.Generator(device='cuda')
+    >>> current_state = g_cuda_other.graphsafe_get_state()
+    >>> g_cuda.graphsafe_set_state(current_state)
+""",
+)
+
+add_docstr(
+    torch.Generator.graphsafe_get_state,
+    r"""
+Generator.graphsafe_get_state() -> torch.Generator
+
+Retrieves the current state of the generator in a manner that is safe for graph capture.
+This method is crucial for ensuring that the generator's state can be captured in the CUDA graph.
+
+Returns:
+    torch.Generator: A Generator pointing to the current state of the generator
+
+Example:
+    >>> g_cuda = torch.Generator(device='cuda')
+    >>> current_state = g_cuda.graphsafe_get_state()
+""",
+)
+
+add_docstr(
+    torch.Generator.clone_state,
+    r"""
+Generator.clone_state() -> torch.Generator
+
+Clones the current state of the generator and returns a new generator pointing to this cloned state.
+This method is beneficial for preserving a particular state of a generator to restore at a later point.
+
+Returns:
+    torch.Generator: A Generator pointing to the newly cloned state.
+
+Example:
+    >>> g_cuda = torch.Generator(device='cuda')
+    >>> cloned_state = g_cuda.clone_state()
+""",
+)
+
 add_docstr(
     torch.Generator.manual_seed,
@@ -143,6 +143,47 @@ uint64_t unpack_uint64(PyObject* pyobj) {
   return unsigned_obj;
 }
 
+static PyObject* THPGenerator_graphSafeGetState(
+    PyObject* _self,
+    PyObject* noargs) {
+  HANDLE_TH_ERRORS
+  auto& gen = ((THPGenerator*)_self)->cdata;
+
+  // See Note [Acquire lock when using random generators]
+  std::scoped_lock<std::mutex> lock(gen.mutex());
+
+  return THPGenerator_Wrap(gen.graphsafe_get_state());
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THPGenerator_graphSafeSetState(
+    PyObject* _self,
+    PyObject* _state) {
+  HANDLE_TH_ERRORS
+  auto self = (THPGenerator*)_self;
+  auto& gen = self->cdata;
+
+  // See Note [Acquire lock when using random generators]
+  std::scoped_lock<std::mutex> lock(gen.mutex());
+  gen.graphsafe_set_state(THPGenerator_Unwrap(_state));
+
+  Py_INCREF(self);
+  return (PyObject*)self;
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THPGenerator_cloneState(PyObject* _self, PyObject* noargs) {
+  HANDLE_TH_ERRORS
+  auto& gen = ((THPGenerator*)_self)->cdata;
+
+  // See Note [Acquire lock when using random generators]
+  std::scoped_lock<std::mutex> lock(gen.mutex());
+  auto new_generator = gen.clone();
+
+  return THPGenerator_Wrap(new_generator);
+  END_HANDLE_TH_ERRORS
+}
+
 static PyObject* THPGenerator_manualSeed(PyObject* _self, PyObject* seed) {
   HANDLE_TH_ERRORS
   auto self = (THPGenerator*)_self;

@@ -218,6 +259,12 @@ static struct PyGetSetDef THPGenerator_properties[] = {
 static PyMethodDef THPGenerator_methods[] = {
     {"get_state", THPGenerator_getState, METH_NOARGS, nullptr},
     {"set_state", THPGenerator_setState, METH_O, nullptr},
+    {"clone_state", THPGenerator_cloneState, METH_NOARGS, nullptr},
+    {"graphsafe_get_state",
+     THPGenerator_graphSafeGetState,
+     METH_NOARGS,
+     nullptr},
+    {"graphsafe_set_state", THPGenerator_graphSafeSetState, METH_O, nullptr},
     {"set_offset", THPGenerator_setOffset, METH_O, nullptr},
     {"manual_seed", THPGenerator_manualSeed, METH_O, nullptr},
     {"seed", THPGenerator_seed, METH_NOARGS, nullptr},

@@ -304,6 +351,14 @@ PyObject* THPGenerator_Wrap(Generator gen) {
       (PyTypeObject*)THPGeneratorClass, std::move(gen));
 }
 
+at::Generator THPGenerator_Unwrap(PyObject* state) {
+  if (!Py_IS_TYPE(state, &THPGeneratorType)) {
+    throw torch::TypeError(
+        "expected a Generator, but got %s", Py_TYPE(state)->tp_name);
+  }
+  return reinterpret_cast<THPGenerator*>(state)->cdata;
+}
+
 // Creates a new Python object for a Generator. The Generator must not already
 // have a PyObject* associated with it.
 PyObject* THPGenerator_NewWithVar(PyTypeObject* type, Generator gen) {

@@ -23,6 +23,8 @@ bool THPGenerator_init(PyObject* module);
 
 TORCH_PYTHON_API PyObject* THPGenerator_Wrap(at::Generator gen);
 
+TORCH_PYTHON_API at::Generator THPGenerator_Unwrap(PyObject* state);
+
 // Creates a new Python object for a Generator. The Generator must not already
 // have a PyObject* associated with it.
 PyObject* THPGenerator_NewWithVar(PyTypeObject* type, at::Generator gen);
@@ -56,6 +56,16 @@ void THCPGraph_init(PyObject* module) {
       .def(
           "capture_end",
          torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::capture_end))
+      .def(
+          "register_generator_state",
+          [](::at::cuda::CUDAGraph& self, py::handle raw_generator) {
+            auto generator = THPGenerator_Unwrap(raw_generator.ptr());
+            // We've unwrapped the Python object to a C++ object,
+            // so we can release the GIL before calling into C++.
+            py::gil_scoped_release release;
+            return self.register_generator_state(generator);
+          },
+          py::arg("generator"))
       .def(
           "replay",
           torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::replay))