mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Continued code migration to enable ruff `UP035`. Most changes are about moving `Callable` from `typing` to `collections.abc`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164423 Approved by: https://github.com/ezyang
		
			
				
	
	
		
			267 lines
		
	
	
		
			9.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			267 lines
		
	
	
		
			9.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # mypy: allow-untyped-decorators
 | |
| # mypy: allow-untyped-defs
 | |
| import functools
 | |
| from collections import defaultdict
 | |
| from collections.abc import Callable
 | |
| 
 | |
| import torch
 | |
| import torch._decomp as decomp
 | |
| from torch._decomp import get_decompositions
 | |
| from torch._ops import OpOverload
 | |
| 
 | |
| 
 | |
aten = torch.ops.aten

# Registry of decompositions used to functionalize RNG ops. It is filled by
# register_rng_decomposition() below and, at the end of this module, by
# merging in extra_random_decomps.
# NOTE(review): entries appear to be keyed by operator (the registry is handed
# to decomp.register_decomposition, and the final .update() needs a
# `type: ignore[arg-type]`), so the `dict[str, ...]` annotation does not seem
# to match actual usage — confirm before tightening it.
rng_decompositions: dict[str, dict[OpOverload, Callable]] = defaultdict(dict)
 | |
| 
 | |
| 
 | |
def register_rng_decomposition(aten_op):
    """Build a decorator that registers a decomposition for ``aten_op``.

    The decorated function is recorded in the module-level
    ``rng_decompositions`` registry (used as
    ``@register_rng_decomposition(aten.rand)``).
    """
    return decomp.register_decomposition(aten_op, registry=rng_decompositions)
 | |
| 
 | |
| 
 | |
def throw_on_non_cuda(device):
    """Reject RNG functionalization on a non-CUDA device.

    Args:
        device: The device an RNG op was requested on. Only its ``type``
            attribute is read, to build the error message.

    Raises:
        RuntimeError: Always.
    """
    kind = device.type
    message = (
        f"You are trying to functionalize a {kind} RNG operator but {kind} does not "
        f"use Philox/counter-based RNG. Therefore, functionalizing a {kind} RNG operator is "
        "not supported. We are discussing the possibility of a Philox-based RNG implementation for CPU."
    )
    raise RuntimeError(message)
 | |
| 
 | |
| 
 | |
# TODO - We have to register many more distributions here, and also higher-level
# ops like dropout, which have fused implementations and can hide the rand inside.
@register_rng_decomposition(aten.rand)
def rand(shape, dtype=None, layout=torch.strided, device=None, pin_memory=False):
    """Decompose ``aten.rand`` into ``rngprims.philox_rand``.

    Reads the current ``(seed, offset)`` from ``PhiloxStateTracker``, draws a
    tensor of ``shape`` with ``philox_rand``, then advances the tracker by the
    offset the primitive reports as consumed.

    ``dtype`` falls back to ``torch.float32`` when not given. ``layout`` and
    ``pin_memory`` are accepted for signature compatibility but are not used.

    Raises:
        RuntimeError: If a non-CUDA ``device`` is passed explicitly.
    """
    # A falsy device (e.g. None) skips the CUDA check and is forwarded as-is.
    if device and device.type != "cuda":
        throw_on_non_cuda(device)
    seed, offset = PhiloxStateTracker.get_state_as_tuple()
    out_dtype = dtype or torch.float32
    result, consumed = torch.ops.rngprims.philox_rand(
        shape, seed, offset, None, device, out_dtype
    )
    PhiloxStateTracker.advance_offset(consumed)
    return result
 | |
| 
 | |
| 
 | |
@register_rng_decomposition(aten.rand_like)
def rand_like(
    x: torch.Tensor,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=False,
    memory_format=torch.preserve_format,
):
    """Decompose ``aten.rand_like`` into ``rngprims.philox_rand``.

    ``device`` and ``dtype`` fall back to those of ``x`` when not given. Only
    ``x.shape`` is forwarded to ``philox_rand``. ``layout``, ``pin_memory`` and
    ``memory_format`` are accepted for signature compatibility but are not used.
    The tracker's offset is advanced by what the primitive reports as consumed.

    Raises:
        RuntimeError: If the resolved device is not CUDA.
    """
    target_device = device or x.device
    if target_device.type != "cuda":
        throw_on_non_cuda(target_device)
    target_dtype = dtype or x.dtype
    seed, offset = PhiloxStateTracker.get_state_as_tuple()
    result, consumed = torch.ops.rngprims.philox_rand(
        x.shape, seed, offset, None, target_device, target_dtype
    )
    PhiloxStateTracker.advance_offset(consumed)
    return result
 | |
| 
 | |
| 
 | |
class PhiloxState:
    """
    A Philox RNG state ``(seed, offset)`` where
    ``offset = base_offset + relative_offset``.

    ``seed`` and ``base_offset`` describe the RNG state just before tracing
    starts. ``relative_offset`` accumulates the offset consumed by RNG ops
    observed at trace time.
    """

    def __init__(self) -> None:
        self.reset()

    def reset(self):
        """Return to the unset state: empty seed/base-offset tensors, zero offset."""
        self.seed = torch.tensor(())
        self.base_offset = torch.tensor(())
        self.relative_offset = 0
        # Attribute name (sic) is kept because PhiloxStateTracker reads it.
        self.offset_advanced_alteast_once = False

    def validate_state(self):
        """Assert that both a seed and a base offset have been set (non-empty)."""
        assert self.seed.numel() != 0 and self.base_offset.numel() != 0

    def advance_offset(self, consumed_offset):
        """Add ``consumed_offset`` to the relative offset and flag the advance."""
        self.offset_advanced_alteast_once = True
        # Deliberately `a = a + b` rather than `+=`: if the offset is ever a
        # tensor, this rebinds to a new value instead of mutating in place.
        self.relative_offset = self.relative_offset + consumed_offset

    def set_state(self, seed, base_offset, relative_offset=0):
        """Install a seed and base offset, with an optional relative offset."""
        self.seed, self.base_offset, self.relative_offset = (
            seed,
            base_offset,
            relative_offset,
        )

    def get_state_as_tuple(self):
        """Return ``(seed, base_offset + relative_offset)`` after validation."""
        self.validate_state()
        current_offset = self.base_offset + self.relative_offset
        return (self.seed, current_offset)

    def get_state_as_tensor(self):
        """Return ``seed`` and the current offset stacked into one tensor.

        Only needed because get_rng_state is overridden.
        """
        self.validate_state()
        current_offset = self.base_offset + self.relative_offset
        return torch.stack([self.seed, current_offset])

    def set_state_from_tensor(self, state):
        """Split ``state`` into seed and base offset; reset the relative offset.

        Inverse of ``get_state_as_tensor``. Only needed because set_rng_state
        is overridden. ``offset_advanced_alteast_once`` is left untouched.
        """
        new_seed, new_base_offset = torch.unbind(state)
        self.seed = new_seed
        self.base_offset = new_base_offset
        self.relative_offset = 0
 | |
| 
 | |
| 
 | |
class PhiloxStateTracker:
    """
    Singleton class to track the philox rng state during AOT Autograd tracing.
    For each aot tracing instance, AOT Autograd resets this tracker and keeps
    track of both forward and backward offsets. At runtime, we only care about
    the total consumed forward and backward offsets. For dynamic shapes, these
    offsets are a function of input shapes. Therefore, the AOT generated graphs
    have additional outputs that compute total consumed forward and backward
    offsets.

    All state lives in class attributes. Apart from ``__enter__``/``__exit__``,
    every method is a classmethod or staticmethod. An instance is only useful
    as a context manager, which resets the state on entry and on exit. The
    annotated attributes below exist only after ``reset()`` has been called.
    """

    # The state that RNG decompositions read from and advance.
    running_state: PhiloxState
    fwd_state: PhiloxState
    bwd_state: PhiloxState

    def __enter__(self):
        PhiloxStateTracker.reset()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returns None, so exceptions raised inside the `with` body propagate.
        PhiloxStateTracker.reset()

    @classmethod
    def reset(cls):
        """Replace the running, forward and backward states with fresh ones.

        Right after a reset, ``running_state`` is a separate object from both
        ``fwd_state`` and ``bwd_state`` until one of the ``mark_beginning_of_*``
        methods (or ``record_state(..., "forward")``) is called.
        """
        cls.running_state = PhiloxState()
        cls.fwd_state = PhiloxState()
        cls.bwd_state = PhiloxState()

    @classmethod
    def mark_beginning_of_forward(cls):
        """Tell the tracker to use ``fwd_state`` as the running state."""
        cls.running_state = cls.fwd_state

    @classmethod
    def mark_beginning_of_backward(cls):
        """Tell the tracker to use ``bwd_state`` as the running state."""
        cls.running_state = cls.bwd_state

    @classmethod
    def record_state(cls, seed, offset, mode):
        """Record the seed and base-offset tensors for one pass.

        These tensors are used to invoke the philox_rand functional primitives.

        Args:
            seed: Seed tensor for the pass.
            offset: Base offset tensor for the pass.
            mode: ``"forward"`` or ``"backward"``.

        Recording the forward state also makes it the running state. Recording
        the backward state does NOT switch the running state;
        ``mark_beginning_of_backward()`` has to be called for that.
        NOTE(review): presumably the caller does this once backward tracing
        starts — confirm.
        """
        if mode == "forward":
            cls.fwd_state.set_state(seed, offset)
            cls.mark_beginning_of_forward()
        else:
            assert mode == "backward"
            cls.bwd_state.set_state(seed, offset)

    @classmethod
    def get_state_as_tensor(cls):
        """Return the running ``(seed, offset)`` stacked into a single tensor."""
        # The only reason this exists is that we override get_rng_state and
        # set_rng_state during tracing. get_rng_state is expected to return a
        # tensor, so returning a (seed, offset) tuple would upset other parts
        # of the program, such as ctx.saved_tensors.

        # A bad consequence is that if the user saves and restores the rng
        # state, the generated code gets a little ugly: we first concat the
        # (seed, offset) pair into a tensor for get_rng_state, and then split it
        # back into a (seed, offset) tuple in set_rng_state.

        # TODO: Investigate whether there is a better way to wrap the tuple in a
        # fake Tensor object and then desugar it later on.
        return cls.running_state.get_state_as_tensor()

    @classmethod
    def get_state_as_tuple(cls):
        """Return the running state as ``(seed, base_offset + relative_offset)``."""
        return cls.running_state.get_state_as_tuple()

    @classmethod
    def set_state_from_tensor(cls, x):
        """Load ``(seed, base_offset)`` from tensor ``x`` into the running state."""
        # This is only needed because we override set_rng_state. See the
        # comment in the get_state_as_tensor method.
        cls.running_state.set_state_from_tensor(x)

    @classmethod
    def advance_offset(cls, consumed_offset):
        """Advance the running state's relative offset by ``consumed_offset``."""
        cls.running_state.advance_offset(consumed_offset)

    @classmethod
    def get_current_relative_offset(cls):
        """Return the offset consumed so far by the running state."""
        return cls.running_state.relative_offset

    @staticmethod
    def multiple_of_4(offset):
        """Round ``offset`` up to the nearest multiple of 4."""
        # torch cuda rng state offset must be a multiple of 4. For inductor, as
        # we sum up all the numel, the result might not be a multiple of 4.
        return (offset + 3) // 4 * 4

    @classmethod
    def _updated_offset(cls, state):
        """Return ``state``'s total offset, rounded up to a multiple of 4.

        Short-circuits to the untouched base offset when no rand op advanced
        ``state``.
        """
        if not state.offset_advanced_alteast_once:
            return state.base_offset
        return cls.multiple_of_4(state.base_offset + state.relative_offset)

    @classmethod
    def get_updated_fwd_offset(cls):
        """Return the forward pass's final offset (see ``_updated_offset``)."""
        return cls._updated_offset(cls.fwd_state)

    @classmethod
    def get_updated_bwd_offset(cls):
        """Return the backward pass's final offset (see ``_updated_offset``)."""
        return cls._updated_offset(cls.bwd_state)
 | |
| 
 | |
| 
 | |
# Extra decompositions for random ops whose decompositions eventually use
# rand_like. Adding them to rng_decompositions ensures that the rand_like ops
# used inside these decomps are functionalized too. The list is copied from the
# inductor codebase, which uses it for a similar purpose.
#
# Caution: these decomps do not have the same accuracy as eager. However, we
# can't simply disable them with a config flag like fallback_random, because
# functionalizing rng ops requires decomposing these ops.
extra_random_decomps = get_decompositions(
    [
        aten.cauchy,
        aten.cauchy_,
        aten.exponential,
        aten.exponential_,
        aten.geometric,
        aten.geometric_,
        aten.native_dropout,
        aten.normal,
        aten.normal_,
        aten.normal_functional,
        aten.log_normal,
        aten.log_normal_,
        aten.rrelu_with_noise,
        aten.rrelu_with_noise_,
        aten.uniform_,
    ]
)
# Decorator factory (e.g. @register_extra_random_decomp([aten.bernoulli_]))
# that records a decomposition in extra_random_decomps. That dict is merged
# into rng_decompositions at the end of this module.
register_extra_random_decomp = functools.partial(
    decomp.register_decomposition, registry=extra_random_decomps
)
 | |
| 
 | |
| 
 | |
@register_extra_random_decomp([aten.bernoulli_])
def bernoulli_(self, p=0.5):
    """In-place Bernoulli sampling expressed through ``torch.rand_like``.

    Draws a float32 uniform sample shaped like ``self``, compares it with
    ``p``, and copies the resulting mask into ``self`` via ``copy_``.
    Returns ``NotImplemented`` for tensors on the CPU device.
    """
    cpu = torch.device("cpu")
    if self.device == cpu:
        return NotImplemented
    uniform = torch.rand_like(self, dtype=torch.float32)
    return self.copy_(uniform < p)
 | |
| 
 | |
| 
 | |
@register_extra_random_decomp([aten.bernoulli.p])
def bernoulli_p(self, p=0.5, *, generator=None):
    """Out-of-place Bernoulli sampling expressed through ``torch.rand_like``.

    Returns the mask obtained by comparing a float32 ``rand_like(self)``
    sample against ``p``. Returns ``NotImplemented`` for tensors on the CPU
    device. An explicit ``generator`` is rejected by an assertion.
    """
    cpu = torch.device("cpu")
    if self.device == cpu:
        return NotImplemented
    assert generator is None
    uniform = torch.rand_like(self, dtype=torch.float32)
    return uniform < p
 | |
| 
 | |
| 
 | |
# Merge the extra random decomps (including the bernoulli decompositions
# registered above) into rng_decompositions.
rng_decompositions.update(extra_random_decomps)  # type: ignore[arg-type]
 |