Files
pytorch/torch/_prims/rng_prims.py
eellison 481a57bc37 Support torch.compile rng selective activation checkpointing with cudagraph (#146878)
TODO:
- [x]  Add handling for when forward is invoked multiple times without invoking backward, so that the fwd/backward states are out of sync
- [x] Update rng state initialization to take from correct device
- [x]  Tests
- [x] handling of retain_graph
- [x] respect fallback random

Fix for https://github.com/pytorch/pytorch/issues/130123.

Updates the aot_eager and cudagraph compilation of `run_and_save_rng_state` to use the new mechanism added by https://github.com/pytorch/pytorch/pull/114068 for CUDAGraph safe rng states.

We have a pair of rng states for the fwd and backward respectively. In both forward and backward the rng op will get run with `graphsafe_run_with_rng_state` which takes in RNG state and it hooks onto the current RNG generator before running the operator. The rng states for fwd/backward are initialized with the same value. We ensure that for any given run of the forward, the corresponding backward run will have the same rng states for the op as was observed in the forward.

```
 ===== Forward graph 1 =====
 /data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", fwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = fwd_rng_state_0);  fwd_rng_state_0 = None
        ...

 ===== Backward graph 1 =====
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", tangents_1: "f32[4, 4][4, 1]cuda:0", bwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = bwd_rng_state_0);  bwd_rng_state_0 = None
```

There is some extra complication when a user either calls backward with retain_graph, or calls the backward in a different order as they called the forward. If a user has state fwd_rng_state0, bwd_rng_state0 and calls:
- fwd0: fwd_rng_state0 -> fwd_rng_state1
- fwd1: fwd_rng_state1 -> fwd_rng_state2
- bwd1
- bwd0

Then naively, when bwd1 is invoked the bwd rng states would not match the states that were observed in fwd1. I added handling of this in the aot runtime wrappers to detect pending backward invocations and the current position of the bwd rng states, and to update them when necessary.

Other notes:

Because nodes which appear later in the forward appear earlier in the backward, we need a separate rng state for each operator. If we reused the rng across ops, the forward and backward would be run with different rng states. I.e., not applied in the same order.

Questions for reviewers:

This does change numerics, bc the rng of the op is now taken from the input rng state instead of whatever the rng would be midway through running the graph. Technically, we only need this for cuda graph. But, I'd prefer to not have a rng divergence just for cudagraph. I am making it respect `fallback_random`.

Edit: decided to apply to non cudagraphs as well, so long as fallback_random is not set

I'm initializing the rng states by cloning the current state. If you had something like 5 different rands in the model with the same shape, they'd all get the same value. This doesn't seem great. I could use some other initialization scheme, like taking the seed from graph position, etc. Not sure. Let me know your thoughts.

Edit: updated to be taken from randint()

Update: initializing rng states from torch.randint..

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146878
Approved by: https://github.com/anijain2305, https://github.com/bdhirsh
2025-02-28 00:47:03 +00:00

390 lines
14 KiB
Python

# mypy: allow-untyped-defs
from typing import Optional
import torch
import torch.utils._pytree as pytree
from torch import _prims
from torch._C import DispatchKey
from torch._higher_order_ops.utils import autograd_not_implemented
from torch._ops import HigherOrderOperator
from torch._prims_common import CUDARngStateHelper, make_contiguous_strides_for
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
disable_proxy_modes_tracing,
ProxyTorchDispatchMode,
track_tensor_tree,
)
from torch.types import _device, _dtype
def throw_on_non_cuda(device):
    """Raise an error explaining that RNG functionalization needs a Philox backend.

    Only counter-based (Philox) generators can be functionalized; every other
    device type is rejected here. This function never returns — callers use
    the ``raise throw_on_non_cuda(device)`` idiom for clarity.
    """
    dev = device.type
    raise RuntimeError(
        f"You are trying to functionalize a {dev} RNG operator but {dev} does not "
        f"use Philox/counter-based RNG. Therefore, functionalizing a {dev} RNG operator is "
        "not supported. We are discussing the possibility of a Philox-based RNG implementation for CPU."
    )
def register_rng_prim(name, schema, impl_aten, impl_meta, doc, tags=None):
rngprim_def = torch.library.custom_op(
"rngprims::" + name, impl_aten, mutates_args=(), schema=schema
)
rngprim_def.register_fake(impl_meta)
prim_packet = getattr(torch._ops.ops.rngprims, name)
prim = prim_packet.default
if tags:
prim._tags = tags
for p in (prim_packet, prim):
p.__doc__ = doc
p.return_type = torch._prims_common.RETURN_TYPE.NEW # type: ignore[attr-defined]
p.schema = name + schema
p.impl_aten = impl_aten
p.prim_meta_impl = impl_meta
# Philox rand offsets could be shared in future with other philox ops, so
# keeping these functions in global scope.
def philox_rand_offset_meta(
    shape: torch.Size,
):
    """Meta implementation of the offset: a scalar int64 tensor.

    The concrete value is irrelevant at meta/tracing time, so the shape
    argument is ignored and a zero placeholder is returned.
    """
    zero_offset = torch.tensor(0, dtype=torch.int64)
    return _prims.TensorLike(zero_offset)
def philox_rand_offset(
    shape: torch.Size,
):
    """Compute the Philox offset increment consumed by a CUDA rand of `shape`.

    Mirrors calc_execution_policy in
    aten/src/ATen/native/cuda/DistributionTemplates.h; the impl was copied at
    commit hash 72aa0667bd16707d50eb8fa337092a1f5d11dfb6.
    """
    numel_scalar = 1
    for dim_size in shape:
        numel_scalar *= dim_size
    # Kept as a tensor so the returned offset is a tensor, matching the
    # (Tensor, Tensor) schema of philox_rand.
    numel = torch.scalar_tensor(numel_scalar, dtype=torch.int64)

    # Launch-configuration constants from the CUDA kernel.
    block_size = 256
    unroll = 4
    curand4_engine_calls = 4

    props = torch.cuda.get_device_properties(torch.cuda.current_device())
    blocks_per_sm = props.max_threads_per_multi_processor // block_size
    grid_size = (numel + block_size - 1) // block_size
    grid_size = min(grid_size, props.multi_processor_count * blocks_per_sm)

    return (
        (numel - 1) // (block_size * grid_size * unroll) + 1
    ) * curand4_engine_calls
def register_philox_rand():
    """Register the stateless ``rngprims::philox_rand`` operator.

    Unlike ``torch.rand``, this prim takes an explicit Philox ``(seed, offset)``
    pair instead of consuming global RNG state, and returns
    ``(random_values, offset_delta)``.
    """
    name = "philox_rand"
    schema = "(SymInt[] size, Tensor seed, Tensor offset, int[]? stride, Device? device=None, ScalarType? dtype=None) -> (Tensor, Tensor)"  # noqa: B950

    def _philox_rand_meta(
        shape: torch.Size,
        seed: torch.Tensor,
        offset: torch.Tensor,
        stride: Optional[tuple[int, ...]],
        device: _device,
        dtype: _dtype,
    ):
        # stride arg will be useful for distributed usecase. Currently, its unused.
        assert stride is None
        stride = make_contiguous_strides_for(shape)
        random_values = _prims.TensorMeta(
            shape=shape, strides=stride, dtype=dtype, device=device
        )
        offset = philox_rand_offset_meta(shape)
        return (random_values, offset)

    def _philox_rand(
        shape: torch.Size,
        seed: torch.Tensor,
        offset: torch.Tensor,
        stride: Optional[tuple[int, ...]],
        device: _device,
        dtype: _dtype,
    ):
        # stride arg will be useful for distributed usecase. Currently, its unused.
        assert stride is None
        if device.type == "cpu":
            devices = []
        else:
            devices = [device]
        # Only CUDA's counter-based (Philox) generator supports stateless replay.
        if device.type != "cuda":
            raise throw_on_non_cuda(device)
        # Fork RNG so setting (seed, offset) does not perturb the global state.
        with torch.random.fork_rng(devices):
            CUDARngStateHelper.set_torch_state_tensor(seed, offset)
            random_values = torch.rand(shape, device=device, dtype=dtype)
        return random_values, philox_rand_offset(shape)

    register_rng_prim(
        name=name,
        schema=schema,
        impl_aten=_philox_rand,
        impl_meta=_philox_rand_meta,
        doc="Philox based stateless rand operator",
        tags=(torch.Tag.nondeterministic_seeded,),
    )
def get_device(args, kwargs):
    """Infer a device-type string ("cuda"/"xpu"/"hpu"/"cpu") for an op call.

    An explicit ``device`` kwarg wins; otherwise the tensor arguments'
    devices are inspected in priority order. Returns None if no device
    can be determined.
    """
    explicit = kwargs.get("device")
    if explicit:
        if isinstance(explicit, str):
            explicit = torch.device(explicit)
        return explicit.type

    tensor_device_types = {
        a.device.type for a in args if isinstance(a, torch.Tensor)
    }
    # Accelerators take precedence over CPU.
    for candidate in ("cuda", "xpu", "hpu", "cpu"):
        if candidate in tensor_device_types:
            return candidate
    return None
def register_run_and_save_rng_state_op():
    """Create the ``run_and_save_rng_state`` HigherOrderOperator.

    The op runs ``op(*args, **kwargs)`` and also returns the RNG state that
    was current *before* the op ran, as ``(rng_state, op_output)``, so the
    same randomness can later be replayed by ``run_with_rng_state``.
    """

    class RunAndSaveRngState(HigherOrderOperator):
        def __init__(self):
            super().__init__("run_and_save_rng_state")

        def __call__(self, op, *args, **kwargs):
            return super().__call__(op, *args, **kwargs)

    run_and_save_rng_state = RunAndSaveRngState()

    # Differentiating through the wrapper itself is unsupported; defer the
    # error until autograd is actually exercised.
    run_and_save_rng_state.py_impl(DispatchKey.Autograd)(
        autograd_not_implemented(run_and_save_rng_state, deferred_error=True)
    )

    @run_and_save_rng_state.py_impl(DispatchKey.CUDA)
    def impl_cuda(op, *args, **kwargs):
        return torch.cuda.get_rng_state(), op(*args, **kwargs)

    @run_and_save_rng_state.py_impl(DispatchKey.CPU)
    def impl_cpu(op, *args, **kwargs):
        return torch.get_rng_state(), op(*args, **kwargs)

    @run_and_save_rng_state.py_impl(DispatchKey.HPU)
    def impl_hpu(op, *args, **kwargs):
        # torch.hpu only exists when the HPU backend is installed.
        if hasattr(torch, "hpu"):
            return torch.hpu.get_rng_state(), op(*args, **kwargs)
        raise RuntimeError("functionalize a hpu RNG operator is not supported.")

    @run_and_save_rng_state.py_impl(DispatchKey.XPU)
    def impl_xpu(op, *args, **kwargs):
        return torch.xpu.get_rng_state(), op(*args, **kwargs)

    @run_and_save_rng_state.py_impl(DispatchKey.BackendSelect)
    def impl_backend_select(op, *args, **kwargs):
        # Route to the per-device impl based on the args/kwargs device.
        impl_map = {
            "cuda": impl_cuda,
            "cpu": impl_cpu,
            "hpu": impl_hpu,
            "xpu": impl_xpu,
        }
        device = get_device(args, kwargs)
        assert device in impl_map, f"Backend not supported for {device}"
        impl = impl_map[device]
        return impl(op, *args, **kwargs)

    @run_and_save_rng_state.py_impl(FakeTensorMode)
    def impl_fake_tensor_mode(mode, op, *args, **kwargs):
        # Check device to call the right impl
        with mode:
            return impl_backend_select(op, *args, **kwargs)

    @run_and_save_rng_state.py_impl(ProxyTorchDispatchMode)
    def impl_proxy_dispatch_mode(mode, op, *args, **kwargs):
        # Run for real to obtain concrete outputs, then record a proxy node
        # so the call appears in the traced graph.
        out = impl_backend_select(op, *args, **kwargs)
        proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, (op, *args))
        proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
        out_proxy = mode.tracer.create_proxy(
            "call_function", run_and_save_rng_state, proxy_args, proxy_kwargs
        )
        return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)

    return run_and_save_rng_state
def register_run_with_rng_state_op():
    """Create the ``run_with_rng_state`` HigherOrderOperator.

    Runs ``op(*args, **kwargs)`` under a previously-saved RNG state (the
    tensor produced by ``run_and_save_rng_state``), restoring the caller's
    RNG state afterwards. Returns only the op output.
    """

    class RunWithRngState(HigherOrderOperator):
        def __init__(self):
            super().__init__("run_with_rng_state")

        def __call__(self, rng_state, op, *args, **kwargs):
            return super().__call__(rng_state, op, *args, **kwargs)

    run_with_rng_state = RunWithRngState()

    # Differentiating through the wrapper itself is unsupported; defer the
    # error until autograd is actually exercised.
    run_with_rng_state.py_impl(DispatchKey.Autograd)(
        autograd_not_implemented(run_with_rng_state, deferred_error=True)
    )

    @run_with_rng_state.py_impl(DispatchKey.CUDA)
    def impl_cuda(rng_state, op, *args, **kwargs):
        # Swap in the saved state, run the op, then restore the caller's state.
        current_state = torch.cuda.get_rng_state()
        torch.cuda.set_rng_state(rng_state.cpu())
        out = op(*args, **kwargs)
        torch.cuda.set_rng_state(current_state)
        return out

    @run_with_rng_state.py_impl(DispatchKey.CPU)
    def impl_cpu(rng_state, op, *args, **kwargs):
        current_state = torch.get_rng_state()
        torch.set_rng_state(rng_state)
        out = op(*args, **kwargs)
        torch.set_rng_state(current_state)
        return out

    @run_with_rng_state.py_impl(DispatchKey.HPU)
    def impl_hpu(rng_state, op, *args, **kwargs):
        # torch.hpu only exists when the HPU backend is installed.
        if hasattr(torch, "hpu"):
            current_state = torch.hpu.get_rng_state()
            torch.hpu.set_rng_state(rng_state)
            out = op(*args, **kwargs)
            torch.hpu.set_rng_state(current_state)
            return out
        raise RuntimeError("functionalize a hpu RNG operator is not supported.")

    @run_with_rng_state.py_impl(DispatchKey.XPU)
    def impl_xpu(rng_state, op, *args, **kwargs):
        current_state = torch.xpu.get_rng_state()
        torch.xpu.set_rng_state(rng_state)
        out = op(*args, **kwargs)
        torch.xpu.set_rng_state(current_state)
        return out

    @run_with_rng_state.py_impl(ProxyTorchDispatchMode)
    def impl_proxy_dispatch_mode(mode, rng_state, op, *args, **kwargs):
        # TODO: you don't need to do this, the dispatch here already disabled
        # it
        with disable_proxy_modes_tracing():
            out = run_with_rng_state(rng_state, op, *args, **kwargs)
        proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, (rng_state, op, *args))
        proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
        out_proxy = mode.tracer.create_proxy(
            "call_function", run_with_rng_state, proxy_args, proxy_kwargs
        )
        return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)

    @run_with_rng_state.py_impl(DispatchKey.BackendSelect)
    def impl_backend_select(rng_state, op, *args, **kwargs):
        # Route to the per-device impl based on the args/kwargs device.
        impl_map = {
            "cuda": impl_cuda,
            "cpu": impl_cpu,
            "hpu": impl_hpu,
            "xpu": impl_xpu,
        }
        device = get_device(args, kwargs)
        assert device in impl_map, f"Backend not supported for {device}"
        impl = impl_map[device]
        return impl(rng_state, op, *args, **kwargs)

    @run_with_rng_state.py_impl(FakeTensorMode)
    def impl_fake_tensor_mode(mode, rng_state, op, *args, **kwargs):
        # Skip setting the set_rng_state as it does not work well with fake tensors.
        # And it does not matter for the fake tensor mode.
        with mode:
            return op(*args, **kwargs)

    @run_with_rng_state.py_functionalize_impl
    def impl_functional(ctx, rng_state, op, *args, **kwargs):
        # Unwrap functional tensors, redispatch, and rewrap the result.
        unwrapped_rng_state = ctx.unwrap_tensors(rng_state)
        unwrapped_args = ctx.unwrap_tensors(args)
        unwrapped_kwargs = ctx.unwrap_tensors(kwargs)

        with ctx.redispatch_to_next():
            out = run_with_rng_state(
                unwrapped_rng_state, op, *unwrapped_args, **unwrapped_kwargs
            )
            return ctx.wrap_tensors(out)

    return run_with_rng_state
# Module-level HigherOrderOperator singletons used by AOTAutograd's RNG
# functionalization.
run_and_save_rng_state = register_run_and_save_rng_state_op()
run_with_rng_state = register_run_with_rng_state_op()
def register_graphsafe_run_with_rng_state_op():
    """Create the ``graphsafe_run_with_rng_state`` HigherOrderOperator.

    CUDA-graph-safe variant of ``run_with_rng_state``: instead of copying raw
    state tensors, it swaps the default generator's graph-safe state object
    around the op call, so replays inside a CUDA graph observe the intended
    RNG. ``rng_state`` is keyword-only and is expected to be a generator
    state living on a CUDA device.
    """

    class GraphSafeRunWithRngState(HigherOrderOperator):
        def __init__(self):
            super().__init__("graphsafe_run_with_rng_state")

        def __call__(self, op, *args, rng_state=None, **kwargs):
            return super().__call__(op, *args, rng_state=rng_state, **kwargs)

    graphsafe_run_with_rng_state = GraphSafeRunWithRngState()

    # Differentiating through the wrapper itself is unsupported; defer the
    # error until autograd is actually exercised.
    graphsafe_run_with_rng_state.py_impl(DispatchKey.Autograd)(
        autograd_not_implemented(graphsafe_run_with_rng_state, deferred_error=True)
    )

    @graphsafe_run_with_rng_state.py_impl(DispatchKey.CUDA)
    def impl_cuda(op, *args, rng_state=None, **kwargs):
        # Swap the generator's graph-safe state in, run the op, swap back.
        device_idx = rng_state.device.index
        generator = torch.cuda.default_generators[device_idx]
        current_state = generator.graphsafe_get_state()
        generator.graphsafe_set_state(rng_state)
        out = op(*args, **kwargs)
        generator.graphsafe_set_state(current_state)
        return out

    @graphsafe_run_with_rng_state.py_impl(DispatchKey.BackendSelect)
    def impl_backend_select(op, *args, rng_state=None, **kwargs):
        device = get_device(args, kwargs)
        assert (
            device == "cuda"
        ), f"GraphSafe RNG operations only supported for CUDA, got {device}"
        return impl_cuda(op, *args, rng_state=rng_state, **kwargs)

    @graphsafe_run_with_rng_state.py_impl(FakeTensorMode)
    def impl_fake_tensor_mode(mode, op, *args, rng_state=None, **kwargs):
        # State swapping is irrelevant under fake tensors; just run the op.
        with mode:
            return op(*args, **kwargs)

    @graphsafe_run_with_rng_state.py_impl(ProxyTorchDispatchMode)
    def impl_proxy_dispatch_mode(mode, op, *args, rng_state=None, **kwargs):
        # Run for real, then record a proxy node (rng_state stays a kwarg).
        with disable_proxy_modes_tracing():
            out = graphsafe_run_with_rng_state(op, *args, rng_state=rng_state, **kwargs)
        proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, (op, *args))
        proxy_kwargs = pytree.tree_map(
            mode.tracer.unwrap_proxy, {"rng_state": rng_state, **kwargs}
        )
        out_proxy = mode.tracer.create_proxy(
            "call_function", graphsafe_run_with_rng_state, proxy_args, proxy_kwargs
        )
        return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)

    @graphsafe_run_with_rng_state.py_functionalize_impl
    def impl_functional(ctx, op, *args, rng_state=None, **kwargs):
        # Unwrap functional tensors (rng_state may legitimately be None),
        # redispatch, and rewrap the result.
        unwrapped_rng_state = (
            ctx.unwrap_tensors(rng_state) if rng_state is not None else None
        )
        unwrapped_args = ctx.unwrap_tensors(args)
        unwrapped_kwargs = ctx.unwrap_tensors(kwargs)

        with ctx.redispatch_to_next():
            out = graphsafe_run_with_rng_state(
                op, *unwrapped_args, rng_state=unwrapped_rng_state, **unwrapped_kwargs
            )
            return ctx.wrap_tensors(out)

    return graphsafe_run_with_rng_state
# Singleton HOP used for CUDA-graph-safe RNG replay in compiled fwd/bwd graphs.
graphsafe_run_with_rng_state = register_graphsafe_run_with_rng_state_op()
def register_rng_prims():
register_philox_rand()