mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Clone has a performance issue, according to f49c3eb6e6/megatron/core/tensor_parallel/random.py (L77-L80)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149283
Approved by: https://github.com/cyyever, https://github.com/Skylion007
		
	
		
			
				
	
	
		
			185 lines
		
	
	
		
			5.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			185 lines
		
	
	
		
			5.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # mypy: allow-untyped-defs
 | |
| from collections.abc import Iterable
 | |
| from typing import Union
 | |
| 
 | |
| import torch
 | |
| from torch import Tensor
 | |
| 
 | |
| from . import _lazy_call, _lazy_init, current_device, device_count, is_initialized
 | |
| 
 | |
| 
 | |
# Public API of this module: helpers for reading, restoring, and seeding the
# state of the default CUDA random number generators.
__all__ = [
    "get_rng_state",
    "get_rng_state_all",
    "set_rng_state",
    "set_rng_state_all",
    "manual_seed",
    "manual_seed_all",
    "seed",
    "seed_all",
    "initial_seed",
]
 | |
| 
 | |
| 
 | |
def get_rng_state(device: Union[int, str, torch.device] = "cuda") -> Tensor:
    r"""Fetch the random number generator state of one GPU as a ByteTensor.

    Args:
        device (torch.device, int or str, optional): The device whose RNG state is
            returned. An ``int`` is treated as a CUDA device index and a ``str`` is
            converted with :class:`torch.device`.
            Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).

    Returns:
        Tensor: The state reported by the default generator of the selected device.

    .. warning::
        This function eagerly initializes CUDA.
    """
    _lazy_init()

    # Normalize the accepted shorthand forms into a ``torch.device``.
    if isinstance(device, int):
        device = torch.device("cuda", device)
    elif isinstance(device, str):
        device = torch.device(device)

    # A device without an explicit index falls back to the current device.
    gpu_index = current_device() if device.index is None else device.index
    return torch.cuda.default_generators[gpu_index].get_state()
 | |
| 
 | |
| 
 | |
def get_rng_state_all() -> list[Tensor]:
    r"""Collect the RNG state ByteTensor of every device, in device-index order."""
    return list(map(get_rng_state, range(device_count())))
 | |
| 
 | |
| 
 | |
def set_rng_state(
    new_state: Tensor, device: Union[int, str, torch.device] = "cuda"
) -> None:
    r"""Install ``new_state`` as the random number generator state of one GPU.

    Args:
        new_state (torch.ByteTensor): The desired state
        device (torch.device, int or str, optional): The device whose RNG state is
            set. An ``int`` is treated as a CUDA device index and a ``str`` is
            converted with :class:`torch.device`.
            Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).
    """
    if not is_initialized():
        # The callback registered below is triggered later, when CUDA is lazily
        # initialized, so take a private contiguous copy of the state now.
        # When CUDA is already initialized the tensor is used as passed in.
        with torch._C._DisableFuncTorch():
            new_state = new_state.clone(memory_format=torch.contiguous_format)

    # Normalize the accepted shorthand forms into a ``torch.device``.
    if isinstance(device, int):
        device = torch.device("cuda", device)
    elif isinstance(device, str):
        device = torch.device(device)

    def apply_state():
        # Resolve the index inside the callback: a device without an explicit
        # index means "whatever device is current when the callback runs".
        gpu_index = current_device() if device.index is None else device.index
        torch.cuda.default_generators[gpu_index].set_state(new_state)

    _lazy_call(apply_state)
 | |
| 
 | |
| 
 | |
def set_rng_state_all(new_states: Iterable[Tensor]) -> None:
    r"""Apply one RNG state per device, pairing states with device indices in order.

    Args:
        new_states (Iterable of torch.ByteTensor): The desired state for each device.
            The first item is applied to device 0, the second to device 1, and so on.
    """
    for gpu_index, gpu_state in enumerate(new_states):
        set_rng_state(gpu_state, device=gpu_index)
 | |
| 
 | |
| 
 | |
def manual_seed(seed: int) -> None:
    r"""Seed the random number generator of the current GPU.

    It's safe to call this function if CUDA is not available; in that
    case, it is silently ignored.

    Args:
        seed (int): The desired seed. It is converted with ``int()`` before use.

    .. warning::
        If you are working with a multi-GPU model, this function is insufficient
        to get determinism.  To seed all GPUs, use :func:`manual_seed_all`.
    """
    seed = int(seed)

    def apply_seed():
        # The current device is looked up when the callback actually runs.
        torch.cuda.default_generators[current_device()].manual_seed(seed)

    _lazy_call(apply_seed, seed=True)
 | |
| 
 | |
| 
 | |
def manual_seed_all(seed: int) -> None:
    r"""Seed the random number generators of all GPUs with the same value.

    It's safe to call this function if CUDA is not available; in that
    case, it is silently ignored.

    Args:
        seed (int): The desired seed. It is converted with ``int()`` before use.
    """
    seed = int(seed)

    def apply_seed_everywhere():
        for gpu_index in range(device_count()):
            torch.cuda.default_generators[gpu_index].manual_seed(seed)

    _lazy_call(apply_seed_everywhere, seed_all=True)
 | |
| 
 | |
| 
 | |
def seed() -> None:
    r"""Reseed the current GPU's random number generator with a random seed.

    It's safe to call this function if CUDA is not available; in that
    case, it is silently ignored.

    .. warning::
        If you are working with a multi-GPU model, this function will only initialize
        the seed on one GPU.  To initialize all GPUs, use :func:`seed_all`.
    """

    def reseed_current():
        # The current device is looked up when the callback actually runs.
        torch.cuda.default_generators[current_device()].seed()

    _lazy_call(reseed_current)
 | |
| 
 | |
| 
 | |
def seed_all() -> None:
    r"""Reseed every GPU's random number generator with one shared random seed.

    The generator of device 0 draws a fresh random seed; that same seed is then
    applied to the generators of all remaining devices.

    It's safe to call this function if CUDA is not available; in that
    case, it is silently ignored.
    """

    def reseed_everywhere():
        shared_seed = 0
        for gpu_index in range(device_count()):
            generator = torch.cuda.default_generators[gpu_index]
            if gpu_index == 0:
                # The first device picks the random seed that all others reuse.
                generator.seed()
                shared_seed = generator.initial_seed()
                continue
            generator.manual_seed(shared_seed)

    _lazy_call(reseed_everywhere)
 | |
| 
 | |
| 
 | |
def initial_seed() -> int:
    r"""Report the current random seed of the current GPU's default generator.

    .. warning::
        This function eagerly initializes CUDA.
    """
    _lazy_init()
    return torch.cuda.default_generators[current_device()].initial_seed()
 |