Make torch.cuda.set_rng_state() and torch.cuda.get_rng_state() work during stream capture. (#162505)

Note that this works only in a limited case: the seed must *not* change; only the offset of the Philox generator may change. That covers the main use case we care about, rewinding the RNG to a previous state, which is exactly what torch.utils.checkpoint.checkpoint does.
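
As a rough sketch of that rewind pattern (illustrative only, not part of this diff; it assumes the graph-safe get/set behavior described above, and the replay equality is the intended outcome rather than a guarantee stated here):

```python
import torch

# Hypothetical illustration of the rewind use case: snapshot the RNG state
# during capture, draw some numbers, then rewind so a later draw replays the
# same sequence. Only the Philox offset moves; the seed stays fixed.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    state = torch.cuda.get_rng_state()  # snapshot: seed + Philox offset
    a = torch.rand(4, device="cuda")    # advances the offset
    torch.cuda.set_rng_state(state)     # rewind: same seed, earlier offset
    b = torch.rand(4, device="cuda")    # intended to reproduce the values of `a` on replay
```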

Calls to increase() change only the offset, never the seed. We therefore allow "no-op" calls to set_current_seed in which the new seed equals the old seed. If a user does try to change the seed during stream capture, they will receive an error.
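
For example (a hedged sketch mirroring the tests added below, not part of the diff itself):

```python
import torch

torch.cuda.manual_seed(42)
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    torch.cuda.manual_seed(42)   # allowed: same seed as before capture, treated as a no-op
    # torch.cuda.manual_seed(7)  # would raise: the seed cannot change during capture
```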

Fixes #162504

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162505
Approved by: https://github.com/ngimel, https://github.com/eqy, https://github.com/eellison, https://github.com/eee4017, https://github.com/cyyever
Author: Daniel Galvez
Date: 2025-09-17 03:57:34 +00:00
Committed by: PyTorch MergeBot
Commit: 7a3791c5d0
Parent: e28983be76
2 changed files with 70 additions and 12 deletions


@@ -266,11 +266,14 @@ CUDAGeneratorImpl::CUDAGeneratorImpl(
  * See Note [Acquire lock when using random generators]
  */
 void CUDAGeneratorImpl::set_current_seed(uint64_t seed) {
-  at::cuda::assertNotCapturing(
-      "Cannot call CUDAGeneratorImpl::set_current_seed");
-  state_->seed_ = seed;
-  state_->philox_offset_per_thread_ = 0;
-  no_reset_rnn_state_.clear();
+  if (C10_LIKELY(at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None)) {
+    state_->seed_ = seed;
+    state_->philox_offset_per_thread_ = 0;
+    no_reset_rnn_state_.clear();
+  } else {
+    TORCH_CHECK(state_->seed_ == seed, "CUDAGeneratorImpl::set_current_seed can be called during stream capture only if new seed is the same as the original seed.");
+    // no-op case
+  }
 }
 
 /**
@@ -299,9 +302,6 @@ uint64_t CUDAGeneratorImpl::get_offset() const {
  * Gets the current seed of CUDAGeneratorImpl.
  */
 uint64_t CUDAGeneratorImpl::current_seed() const {
-  // Debatable if current_seed() should be allowed in captured regions.
-  // Conservatively disallow it for now.
-  at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::current_seed");
   return state_->seed_;
 }
@@ -346,8 +346,6 @@ c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
  * and size of the internal state.
  */
 void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
-  at::cuda::assertNotCapturing(
-      "Please ensure to utilize the CUDAGeneratorImpl::set_state_index method during capturing.");
   static const size_t seed_size = sizeof(uint64_t);
   static const size_t offset_size = sizeof(int64_t);
   static const size_t total_size = seed_size + offset_size;
@@ -402,15 +400,27 @@ c10::intrusive_ptr<c10::GeneratorImpl> CUDAGeneratorImpl::graphsafe_get_state()
  */
 void CUDAGeneratorImpl::set_philox_offset_per_thread(uint64_t offset) {
   // see Note [Why enforce RNG offset % 4 == 0?]
+  // Note: If you use CUDNN RNN's, calling
+  // set_philox_offset_per_thread instead of set_offset will cause the
+  // cudnn RNN rng state to become stale.
   TORCH_CHECK(offset % 4 == 0, "offset must be a multiple of 4");
-  state_->philox_offset_per_thread_ = offset;
+  if (C10_LIKELY(at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None)) {
+    state_->philox_offset_per_thread_ = offset;
+  } else {
+    state_->offset_intragraph_ = offset;
+  }
 }
 
 /**
  * Gets the current philox_offset_per_thread_ of CUDAGeneratorImpl.
  */
 uint64_t CUDAGeneratorImpl::philox_offset_per_thread() const {
-  return state_->philox_offset_per_thread_;
+  if (C10_LIKELY(at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None)) {
+    return state_->philox_offset_per_thread_;
+  } else {
+    return state_->offset_intragraph_;
+  }
 }
 
 /**


@@ -3157,6 +3157,54 @@ exit(2)
 
         model(x)
 
+    @skipIfRocm
+    @unittest.skipIf(
+        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
+    )
+    @serialTest()
+    def test_graph_checkpoint_preserve_rng_state(self):
+        torch.cuda.manual_seed(42)
+
+        def fn(x):
+            return x * torch.sigmoid(torch.randn(1, device="cuda"))
+
+        fn(torch.ones(1, device="cuda"))
+
+        torch.cuda.manual_seed(42)
+        eager_in = torch.ones(1, device="cuda", requires_grad=True)
+        eager_out = torch.utils.checkpoint.checkpoint(
+            fn, eager_in, use_reentrant=False, preserve_rng_state=True
+        )
+        (eager_in_grad,) = torch.autograd.grad(eager_out, eager_in)
+
+        g = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(g):
+            graph_in = torch.ones(1, device="cuda", requires_grad=True)
+            graph_out = torch.utils.checkpoint.checkpoint(
+                fn, graph_in, use_reentrant=False, preserve_rng_state=True
+            )
+            (graph_in_grad,) = torch.autograd.grad(graph_out, graph_in)
+
+        torch.cuda.manual_seed(42)
+        g.replay()
+        self.assertEqual(eager_in_grad, graph_in_grad, rtol=0.0, atol=0.0)
+
+    @skipIfRocm
+    @unittest.skipIf(
+        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
+    )
+    @serialTest()
+    def test_graph_manual_seed_mismatch_raises(self):
+        torch.cuda.manual_seed(0)
+        g = torch.cuda.CUDAGraph()
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "CUDAGeneratorImpl::set_current_seed can be called during stream capture only if new seed is the same as the original seed.",  # noqa: B950
+        ):
+            with torch.cuda.graph(g):
+                torch.cuda.manual_seed(1)
+
     @unittest.skipIf(
         not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
     )