Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
Add torch.random.fork_rng, which forks the RNG temporarily.
There is a bit of nuance to this function. If one blindly charges in and initializes all GPUs, it takes a long time: about 20 seconds for 8 GPUs on my dev machine. But it is not obvious to a user that fork_rng will hit all of the GPUs by default (which it does for safety reasons), so we emit a warning when we notice we're touching more than one GPU. There is also a bit of extra generality which will be used by torch.jit in a subsequent commit.
Committed by: Soumith Chintala
Parent: 539ae451d2
Commit: 2861638e8a
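As a rough usage sketch of the new context manager (the name torch.random.fork_rng follows this commit; accessing it through the top-level torch package is assumed, and devices=[] skips the CUDA state entirely), random draws inside the forked region leave the surrounding RNG state untouched:

import torch

torch.manual_seed(42)
before = torch.randn(2)

torch.manual_seed(42)
with torch.random.fork_rng(devices=[]):   # devices=[]: fork the CPU state only
    torch.randn(100)                      # consume random numbers inside the fork
after = torch.randn(2)

assert torch.equal(before, after)         # the outer RNG state was restored on exit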
@@ -1,4 +1,6 @@
import torch
import contextlib
import warnings

from torch._C import default_generator
@@ -37,3 +39,72 @@ def initial_seed():
    python `long`.
    """
    return default_generator.initial_seed()


_fork_rng_warned_already = False


@contextlib.contextmanager
def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="devices"):
    """
    Forks the RNG, so that when you return, the RNG is reset
    to the state that it was previously in.

    Arguments:
        devices (iterable of CUDA IDs): CUDA devices for which to fork
            the RNG.  CPU RNG state is always forked.  By default, fork_rng operates
            on all devices, but will emit a warning if your machine has a lot
            of devices, since this function will run very slowly in that case.
            If you explicitly specify devices, this warning will be suppressed.
        enabled (bool): if False, the RNG is not forked.  This is a convenience
            argument for easily disabling the context manager without having
            to reindent your Python code.
    """

    import torch.cuda
    global _fork_rng_warned_already

    # Internal arguments:
    #   _caller: the function which called fork_rng, which the user used
    #   _devices_kw: the devices keyword of _caller

    if not enabled:
        yield
        return

    if devices is None:
        num_devices = torch.cuda.device_count()
        if num_devices > 1 and not _fork_rng_warned_already:
            warnings.warn(
                ("CUDA reports that you have {num_devices} available devices, and you "
                 "have used {caller} without explicitly specifying which devices are being used. "
                 "For safety, we initialize *every* CUDA device by default, which "
                 "can be quite slow if you have a lot of GPUs.  If you know that you are only "
                 "making use of a few CUDA devices, set the environment variable CUDA_VISIBLE_DEVICES "
                 "or the '{devices_kw}' keyword argument of {caller} with the set of devices "
                 "you are actually using.  For example, if you are using CPU only, "
                 "set CUDA_VISIBLE_DEVICES= or devices=[]; if you are using "
                 "GPU 0 only, set CUDA_VISIBLE_DEVICES=0 or devices=[0].  To initialize "
                 "all devices and suppress this warning, set the '{devices_kw}' keyword argument "
                 "to `range(torch.cuda.device_count())`."
                 ).format(num_devices=num_devices, caller=_caller, devices_kw=_devices_kw))
            _fork_rng_warned_already = True
        devices = list(range(num_devices))
    else:
        # Protect against user passing us a generator; we need to traverse this
        # multiple times but a generator will be exhausted upon first traversal
        devices = list(devices)

    cpu_rng_state = torch.get_rng_state()
    gpu_rng_states = []
    for device in devices:
        with torch.cuda.device(device):
            gpu_rng_states.append(torch.cuda.get_rng_state())

    try:
        yield
    finally:
        torch.set_rng_state(cpu_rng_state)
        for device, gpu_rng_state in zip(devices, gpu_rng_states):
            with torch.cuda.device(device):
                torch.cuda.set_rng_state(gpu_rng_state)
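The _caller and _devices_kw hooks are the extra generality mentioned in the commit message: a wrapper (such as the planned torch.jit use) can fork the RNG under its own name, so that the multi-GPU warning points at the user-facing function rather than at fork_rng itself. A hypothetical sketch, with the wrapper name invented for illustration:

import contextlib
import torch

@contextlib.contextmanager
def traced_region(devices=None):
    # Hypothetical caller: if devices is left unspecified on a multi-GPU box,
    # the warning will mention 'traced_region' and its 'devices' keyword.
    with torch.random.fork_rng(devices=devices, _caller="traced_region",
                               _devices_kw="devices"):
        yield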
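Similarly, a small sketch of the enabled flag (the helper function and seed choice below are made up for illustration): it lets a caller switch the fork off without re-indenting the body of the with block.

import torch

def sample(step, deterministic=False):
    # enabled=False makes fork_rng a no-op, so the same code path runs
    # whether or not we sandbox the RNG state.
    with torch.random.fork_rng(devices=[], enabled=deterministic):
        if deterministic:
            torch.manual_seed(step)
        return torch.randn(4)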