Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Make torch.cuda.set_rng_state() and torch.cuda.get_rng_state() work during stream capture. (#162505)
Note that this works only in a limited case: you *don't* change the seed, only the offset of the Philox generator. This captures the main use case we are interested in: rewinding the RNG to a previous state, which torch.utils.checkpoint.checkpoint does in particular. Calls to increase() change only the offset, not the seed, so we allow "no-op" calls to set_current_seed where the new seed is the same as the old seed. If a user does try to change the seed during stream capture, they receive an error.

Fixes #162504

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162505
Approved by: https://github.com/ngimel, https://github.com/eqy, https://github.com/eellison, https://github.com/eee4017, https://github.com/cyyever
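For orientation, here is a minimal sketch (not part of the patch) of the rewind pattern this change enables. It assumes a CUDA device and a PyTorch build that includes this commit:

import torch

# Minimal sketch, assuming a build with this change: inside graph capture,
# saving and restoring the CUDA RNG state only moves the Philox offset;
# the seed itself is never changed.
torch.cuda.manual_seed(42)  # setting the seed BEFORE capture is unrestricted

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    saved = torch.cuda.get_rng_state()   # snapshot: seed + Philox offset
    a = torch.randn(4, device="cuda")    # advances the offset
    torch.cuda.set_rng_state(saved)      # rewind: same seed, earlier offset
    b = torch.randn(4, device="cuda")    # re-runs the generator from the
                                         # rewound offset

g.replay()
# After replay, a and b hold identical values.
# torch.cuda.manual_seed(7) inside the capture above would raise, since
# during capture only a no-op re-set of the current seed is allowed.

The asymmetry (the offset may move, the seed may not) follows from the diff below: during capture, offset updates are redirected to state_->offset_intragraph_, while a seed change has no graph-safe counterpart and is therefore rejected with TORCH_CHECK.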
This commit is contained in:
Commit: 7a3791c5d0 (parent: e28983be76), committed by PyTorch MergeBot
@@ -266,11 +266,14 @@ CUDAGeneratorImpl::CUDAGeneratorImpl(
  * See Note [Acquire lock when using random generators]
  */
 void CUDAGeneratorImpl::set_current_seed(uint64_t seed) {
-  at::cuda::assertNotCapturing(
-      "Cannot call CUDAGeneratorImpl::set_current_seed");
-  state_->seed_ = seed;
-  state_->philox_offset_per_thread_ = 0;
-  no_reset_rnn_state_.clear();
+  if (C10_LIKELY(at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None)) {
+    state_->seed_ = seed;
+    state_->philox_offset_per_thread_ = 0;
+    no_reset_rnn_state_.clear();
+  } else {
+    TORCH_CHECK(state_->seed_ == seed, "CUDAGeneratorImpl::set_current_seed can be called during stream capture only if new seed is the same as the original seed.");
+    // no-op case
+  }
 }

 /**
@@ -299,9 +302,6 @@ uint64_t CUDAGeneratorImpl::get_offset() const {
  * Gets the current seed of CUDAGeneratorImpl.
  */
 uint64_t CUDAGeneratorImpl::current_seed() const {
-  // Debatable if current_seed() should be allowed in captured regions.
-  // Conservatively disallow it for now.
-  at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::current_seed");
   return state_->seed_;
 }

@@ -346,8 +346,6 @@ c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
  * and size of the internal state.
  */
 void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
-  at::cuda::assertNotCapturing(
-      "Please ensure to utilize the CUDAGeneratorImpl::set_state_index method during capturing.");
   static const size_t seed_size = sizeof(uint64_t);
   static const size_t offset_size = sizeof(int64_t);
   static const size_t total_size = seed_size + offset_size;
@@ -402,15 +400,27 @@ c10::intrusive_ptr<c10::GeneratorImpl> CUDAGeneratorImpl::graphsafe_get_state()
  */
 void CUDAGeneratorImpl::set_philox_offset_per_thread(uint64_t offset) {
   // see Note [Why enforce RNG offset % 4 == 0?]
+
+  // Note: If you use CUDNN RNN's, calling
+  // set_philox_offset_per_thread instead of set_offset will cause the
+  // cudnn RNN rng state to become stale.
   TORCH_CHECK(offset % 4 == 0, "offset must be a multiple of 4");
-  state_->philox_offset_per_thread_ = offset;
+  if (C10_LIKELY(at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None)) {
+    state_->philox_offset_per_thread_ = offset;
+  } else {
+    state_->offset_intragraph_ = offset;
+  }
 }

 /**
  * Gets the current philox_offset_per_thread_ of CUDAGeneratorImpl.
  */
 uint64_t CUDAGeneratorImpl::philox_offset_per_thread() const {
-  return state_->philox_offset_per_thread_;
+  if (C10_LIKELY(at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None)) {
+    return state_->philox_offset_per_thread_;
+  } else {
+    return state_->offset_intragraph_;
+  }
 }

 /**
@@ -3157,6 +3157,54 @@ exit(2)

         model(x)

+    @skipIfRocm
+    @unittest.skipIf(
+        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
+    )
+    @serialTest()
+    def test_graph_checkpoint_preserve_rng_state(self):
+        torch.cuda.manual_seed(42)
+
+        def fn(x):
+            return x * torch.sigmoid(torch.randn(1, device="cuda"))
+
+        fn(torch.ones(1, device="cuda"))
+
+        torch.cuda.manual_seed(42)
+        eager_in = torch.ones(1, device="cuda", requires_grad=True)
+        eager_out = torch.utils.checkpoint.checkpoint(
+            fn, eager_in, use_reentrant=False, preserve_rng_state=True
+        )
+        (eager_in_grad,) = torch.autograd.grad(eager_out, eager_in)
+
+        g = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(g):
+            graph_in = torch.ones(1, device="cuda", requires_grad=True)
+            graph_out = torch.utils.checkpoint.checkpoint(
+                fn, graph_in, use_reentrant=False, preserve_rng_state=True
+            )
+            (graph_in_grad,) = torch.autograd.grad(graph_out, graph_in)
+
+        torch.cuda.manual_seed(42)
+        g.replay()
+
+        self.assertEqual(eager_in_grad, graph_in_grad, rtol=0.0, atol=0.0)
+
+    @skipIfRocm
+    @unittest.skipIf(
+        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
+    )
+    @serialTest()
+    def test_graph_manual_seed_mismatch_raises(self):
+        torch.cuda.manual_seed(0)
+        g = torch.cuda.CUDAGraph()
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "CUDAGeneratorImpl::set_current_seed can be called during stream capture only if new seed is the same as the original seed.",  # noqa: B950
+        ):
+            with torch.cuda.graph(g):
+                torch.cuda.manual_seed(1)
+
     @unittest.skipIf(
         not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
     )