Support torch.compile rng selective activation checkpointing with cudagraph (#146878)

TODO:
- [x]  Add handling for when forward is invoked multiple times without invoking backward, so that the fwd/backward states are out of sync
- [x] Update rng state initialization to take from correct device
- [x]  Tests
- [x] handling of retain_graph
- [x] respect fallback random

Fix for https://github.com/pytorch/pytorch/issues/130123.

Updates the aot_eager and cudagraph compilation of `run_and_save_rng_state` to use the new mechanism added by https://github.com/pytorch/pytorch/pull/114068 for CUDAGraph safe rng states.

We have a pair of rng states for the fwd and backward respectively. In both forward and backward the rng op will get run with `graphsafe_run_with_rng_state`, which takes in an RNG state and hooks it onto the current RNG generator before running the operator. The rng states for fwd/backward are initialized with the same value. We ensure that for any given run of the forward, the corresponding backward run will have the same rng states for the op as was observed in the forward.

```
 ===== Forward graph 1 =====
 /data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", fwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = fwd_rng_state_0);  fwd_rng_state_0 = None
        ...

 ===== Backward graph 1 =====
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", tangents_1: "f32[4, 4][4, 1]cuda:0", bwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = bwd_rng_state_0);  bwd_rng_state_0 = None
```

There is some extra complication when a user either calls backward with retain_graph, or calls the backward in a different order than they called the forward. If a user has state fwd_rng_state0, bwd_rng_state0 and calls:
- fwd0: fwd_rng_state0 -> fwd_rng_state1
- fwd1: fwd_rng_state1 -> fwd_rng_state2
- bwd1
- bwd0

Then naively, when bwd1 is invoked the bwd rng states would not be equal to the states that were observed in fwd1. I added handling of this in the aot runtime wrappers to detect pending backward invocations, and the current position of the bwd rng states, and to update when necessary.

Other notes:

Because nodes which appear later in the forward appear earlier in the backward, we need a separate rng state for each operator. If we reused the rng across ops, the forward and backward would be run with different rng states. I.e., not applied in the same order.

Questions for reviewers:

This does change numerics, bc the rng of the op is now taken from the input rng state instead of whatever the rng would be midway through running the graph. Technically, we only need this for cuda graph. But, I'd prefer to not have a rng divergence just for cudagraph. I am making it respect `fallback_random`.

Edit: decided to apply to non cudagraphs as well, so long as fallback_random is not set

I'm initializing the rng states by cloning the current state. If you had something like 5 different rands in the model with the same shape, they'd all get the same value. This doesn't seem great. I could use some other initialization scheme like taking seed from graph position, or etc etc. Not sure. Let me know thoughts.

Edit: updated to be taken from randint()

Update: initializing rng states from torch.randint..

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146878
Approved by: https://github.com/anijain2305, https://github.com/bdhirsh
This commit is contained in:
eellison
2025-02-27 07:39:55 -08:00
committed by PyTorch MergeBot
parent c6d1038aaa
commit 481a57bc37
20 changed files with 920 additions and 69 deletions

View File

@ -54,6 +54,8 @@ def count_ops(
return node.args[0] == op
elif node.name == "run_with_rng_state":
return node.args[1] == op
elif node.name == "graphsafe_run_with_rng_state":
return node.args[0] == op
return False
# assert ((freq or freq_ge) and op) or ((freqs or freqs_ge) and ops)
@ -1018,6 +1020,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@torch._inductor.config.patch(fallback_random=True)
def test_compile_selective_checkpoint_random_op(self, device):
for preserve_rng_state in [True, False]:

View File

@ -4,13 +4,17 @@ import contextlib
import functools
import gc
import importlib
import itertools
import sys
import unittest
import warnings
from collections import defaultdict
from collections.abc import Mapping, Sequence
import torch
import torch._dynamo.config as dynamo_config
import torch.nn as nn
from torch._dynamo.backends.debugging import aot_eager_decomp_partition_with_mode
from torch._dynamo.utils import counters
from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
from torch._inductor import config
@ -19,7 +23,9 @@ from torch._inductor.compile_fx import compile_fx_inner
from torch._inductor.cudagraph_trees import cudagraphify_impl as tree_cudagraphify_impl
from torch._inductor.cudagraph_utils import FunctionID
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._ops import OpOverload
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.immutable_collections import immutable_dict
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_utils import (
@ -32,6 +38,7 @@ from torch.testing._internal.common_utils import (
TEST_CUDA_GRAPH,
TEST_WITH_ASAN,
)
from torch.utils._mode_utils import no_dispatch
from torch.utils._python_dispatch import TorchDispatchMode
@ -46,7 +53,7 @@ if IS_WINDOWS and IS_CI:
importlib.import_module("functorch")
importlib.import_module("filelock")
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
from torch.testing._internal.inductor_utils import HAS_CUDA
aten = torch.ops.aten
@ -2504,7 +2511,407 @@ if HAS_CUDA and not TEST_WITH_ASAN:
eager_result = f(example_input)
self.assertEqual(compiled_result, eager_result)
class TestSAC(TestCase):
def _make_observer_mode(self):
# Define and return a new TorchDispatchMode subclass (the class itself, not an
# instance). Tests in this class register py_impl overrides keyed on it and
# record op outputs on the instance they create.
class ObserverMode(TorchDispatchMode):
def __init__(self):
super().__init__()
# Index of the "run" currently being observed; tests that track several
# forward/backward runs set this before invoking the compiled function.
self.curr_run = 0
# Recorded outputs, appended to by the py_impl overrides that tests register.
self.op_outputs = defaultdict(list)
def __torch_dispatch__(
self,
func: OpOverload,
types: Sequence[type],
args: Sequence[object] = (),
kwargs: Mapping[str, object] = immutable_dict(),
) -> object:
# Plain pass-through: this method records nothing itself.
return func(*args, **kwargs)
return ObserverMode
def test_simple(self):
# Compiles a checkpointed region containing rand_like with the
# aot_eager_decomp_partition_with_mode backend and inspects what the
# graphsafe_run_with_rng_state override recorded for aten.rand.default.
device = "cuda"
from torch._prims.rng_prims import graphsafe_run_with_rng_state
ObserverMode = self._make_observer_mode()
# Override the HOP under ObserverMode: run the real HOP with dispatch disabled,
# then record its output under the wrapped op.
@graphsafe_run_with_rng_state.py_impl(ObserverMode)
def _(mode, op, *args, **kwargs):
with no_dispatch():
out = graphsafe_run_with_rng_state(op, *args, **kwargs)
mode.op_outputs[op].append(out)
return out
obs = ObserverMode()
x = torch.randn(4, 4, device=device, requires_grad=True)
y = torch.randn(4, 4, device=device, requires_grad=True)
for _ in range(2):
# presumably forces a fresh compile on each round — TODO confirm
torch._dynamo.reset()
def gn(x, y):
return torch.sigmoid(torch.rand_like(x) * y) * x
def fn(x, y):
x = torch.sin(x)
x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
x = torch.sin(x)
return x
aot_eager_decomp_partition = functools.partial(
aot_eager_decomp_partition_with_mode, mode=obs
)
fn = torch.compile(fn, backend=aot_eager_decomp_partition)
fn(x, y).sum().backward()
# Four recordings are expected in total across the two rounds.
self.assertEqual(len(obs.op_outputs[aten.rand.default]), 4)
for i in range(2):
# Recordings (0, 1) and (2, 3) must match pairwise — presumably the forward
# draw and its recomputation in backward; verify against the real indentation.
self.assertEqual(
obs.op_outputs[aten.rand.default][0 + 2 * i],
obs.op_outputs[aten.rand.default][1 + 2 * i],
)
# The first recording of each pair must differ from the other pair's.
self.assertNotEqual(
obs.op_outputs[aten.rand.default][0],
obs.op_outputs[aten.rand.default][2],
)
def test_cudagraph_uneven_forward_backward(self):
# torch.compile cudagraphs are difficult to test
# the rng updating bc is sensitive to duration of pending backwards, etc.
# this is a short repro to mimic the runtime wrappers integration
# and show that updating the backward rng state with cudagraphs works:
# Reference (non graph-safe) path: snapshot the CUDA rng state, then draw.
def forward():
state = torch.cuda.get_rng_state()
perm = torch.randperm(10, device="cuda")
return state, perm
# Reference path: temporarily install the saved state, redraw, then restore
# the state that was current on entry.
def backward(rng_state):
current_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state.cpu())
perm = torch.randperm(10, device="cuda")
torch.cuda.set_rng_state(current_state)
return perm
def normal_test():
state, perm = forward()
repro_perm = backward(state)
return perm, repro_perm
def graphsafe_forward():
perm = torch.randperm(10, device="cuda")
return perm
# Graph-safe path: swap the generator onto new_state for the draw, then swap
# it back to whatever state object it had on entry.
def graphsafe_backward(generator, new_state):
current_state = generator.graphsafe_get_state()
generator.graphsafe_set_state(new_state)
perm = torch.randperm(10, device="cuda")
generator.graphsafe_set_state(current_state)
return perm
# Returns (perm drawn by the last forward, perm reproduced by the backward);
# the caller asserts that the two are equal.
def graph_test(generator, capture_cuda_graph):
if capture_cuda_graph:
graph = torch.cuda.CUDAGraph()
# state should be cloned before the graph
old_state = generator.graphsafe_get_state()
new_state = old_state.clone_state()
if capture_cuda_graph:
# state should be registered to the graph
graph.register_generator_state(new_state)
# only capturing the backward
with torch.cuda.graph(graph):
repro_perm = graphsafe_backward(generator, new_state)
# some number of uneven forwards
graphsafe_forward()
graphsafe_forward()
graphsafe_forward()
# state prior to rng invocation
state = generator.get_state()
perm = graphsafe_forward()
# Copy the pre-draw state into the state object the backward draws from.
new_state.set_state(state)
if capture_cuda_graph:
graph.replay()
else:
repro_perm = graphsafe_backward(generator, new_state)
return perm, repro_perm
self.assertEqual(*normal_test())
generator = torch.cuda.default_generators[0]
# Exercise the graph-safe path both eagerly and under CUDA graph capture.
self.assertEqual(*graph_test(generator, capture_cuda_graph=False))
self.assertEqual(*graph_test(generator, capture_cuda_graph=True))
def test_cpu_and_cuda_rng(self):
# Mixes a CUDA rng op (rand_like on a cuda tensor) and a CPU rng op (randperm
# on a cpu tensor) in checkpointed regions, and checks that each op is recorded
# twice with equal values on the expected device.
device = "cuda"
ObserverMode = self._make_observer_mode()
from torch._prims.rng_prims import (
graphsafe_run_with_rng_state,
run_and_save_rng_state,
run_with_rng_state,
)
# Install the same recording override on all three rng HOPs.
for hop in [
graphsafe_run_with_rng_state,
run_and_save_rng_state,
run_with_rng_state,
]:
# make_impl binds `hop` as a parameter so the override calls its own HOP.
def make_impl(hop):
@hop.py_impl(ObserverMode)
def _(mode, *args, **kwargs):
with no_dispatch():
out = hop(*args, **kwargs)
# The wrapped op is the first OpOverload found among args and kwargs.
op = None
for inp in itertools.chain(args, kwargs.values()):
if isinstance(inp, torch._ops.OpOverload):
op = inp
break
assert op is not None
# For run_and_save_rng_state only element 1 of the output is recorded
# (presumably element 0 is the saved rng state — TODO confirm).
if hop is run_and_save_rng_state:
mode.op_outputs[op].append(out[1])
else:
mode.op_outputs[op].append(out)
return out
make_impl(hop)
obs = ObserverMode()
def gn(x, y):
return torch.sigmoid(torch.rand_like(x) * y) * x
def gn2(x):
return x * torch.randperm(x.numel(), device=x.device).reshape(x.shape)
def fn(x, y, z):
x = torch.sin(x)
x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
x = torch.sin(x)
z = torch.utils.checkpoint.checkpoint(gn2, z, use_reentrant=True)
return x * z.cuda()
aot_eager_decomp_partition = functools.partial(
aot_eager_decomp_partition_with_mode, mode=obs
)
fn = torch.compile(fn, backend=aot_eager_decomp_partition)
x = torch.randn(4, 4, device=device, requires_grad=True)
y = torch.randn(4, 4, device=device, requires_grad=True)
# z is created without a device argument, so gn2's randperm uses z's device.
z = torch.randn(4, 4, requires_grad=True)
fn(x, y, z).sum().backward()
for op in [aten.rand.default, aten.randperm.default]:
# Two recordings per op, and they must be equal.
self.assertEqual(len(obs.op_outputs[op]), 2)
self.assertEqual(
obs.op_outputs[op][0],
obs.op_outputs[op][1],
)
self.assertEqual(
obs.op_outputs[op][0].device.type,
"cpu" if op == aten.randperm.default else "cuda",
)
@parametrize("order", (list(itertools.permutations([0, 1, 2]))))
def test_uneven_forward_backward(self, order):
# Runs len(order) forwards first, then runs their backwards in the permuted
# `order`, and checks that each run's recorded rng outputs still match.
device = "cuda"
ObserverMode = self._make_observer_mode()
from torch._prims.rng_prims import graphsafe_run_with_rng_state
# Record outputs keyed by (run index, op) so runs can be told apart.
@graphsafe_run_with_rng_state.py_impl(ObserverMode)
def _(mode, op, *args, **kwargs):
with no_dispatch():
out = graphsafe_run_with_rng_state(op, *args, **kwargs)
mode.op_outputs[(mode.curr_run, op)].append(out)
return out
obs = ObserverMode()
def gn(x, y):
return torch.sigmoid(torch.rand_like(x) * y) * x
def gn2(x):
return x * torch.randperm(x.numel(), device=x.device).reshape(x.shape)
def fn(x, y):
x = torch.sin(x)
x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
x = torch.sin(x)
x = torch.utils.checkpoint.checkpoint(gn2, x, use_reentrant=True)
return x
aot_eager_decomp_partition = functools.partial(
aot_eager_decomp_partition_with_mode, mode=obs
)
fn_c = torch.compile(fn, backend=aot_eager_decomp_partition)
torch.manual_seed(0)
outs = []
# All forwards run before any backward; curr_run tags each forward's recordings.
for i in range(len(order)):
obs.curr_run = i
x = torch.randn(4, 4, device=device, requires_grad=True)
y = torch.randn(4, 4, device=device, requires_grad=True)
outs.append(fn_c(x, y))
# Backwards run in the parametrized order; curr_run is set to the matching
# forward's index so recordings land under the same key.
for idx in order:
obs.curr_run = idx
outs[idx].sum().backward()
for run in range(len(order)):
for op in (aten.rand.default, aten.randperm.default):
# Two recordings per (run, op), and they must be equal.
self.assertEqual(len(obs.op_outputs[(run, op)]), 2)
self.assertEqual(
obs.op_outputs[(run, op)][0],
obs.op_outputs[(run, op)][1],
)
# Consecutive runs must not have produced the same random values.
if run != 0:
self.assertNotEqual(
obs.op_outputs[(run - 1, op)][0],
obs.op_outputs[(run, op)][0],
)
@config.patch(fallback_random=True)
@config.patch("test_configs.graphsafe_rng_func_ignores_fallback_random", True)
def _test_cudagraphs_aot_eager_compat_equal(self, device):
# Helper shared by the per-device tests: runs the same checkpointed function
# through two compile configurations from the same seed and requires equal
# outputs and gradients.
def gn(x, y):
return torch.sigmoid(torch.rand_like(x) * y) * x
def fn(x, y):
x = torch.sin(x)
x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
x = torch.sin(x)
return x
# outs/grads collect results of the first configuration, outs2/grads2 the second.
outs = []
grads = []
outs2 = []
grads2 = []
compile_fns = [
lambda fn: torch.compile(fn, backend="aot_eager_decomp_partition"),
lambda fn: torch.compile(fn, mode="reduce-overhead"),
]
for i, compile_fn in enumerate(compile_fns):
# Reseed so both configurations start from the same global rng state.
torch.manual_seed(0)
for index in range(3):
x = torch.randn(4, 4, device=device, requires_grad=True)
y = torch.randn(4, 4, device=device, requires_grad=True)
out = compile_fn(fn)(x, y)
torch.cuda.synchronize()
out.sum().backward()
if i == 0:
outs.append(out.clone())
grads.append((x.grad.clone(), y.grad.clone()))
else:
outs2.append(out.clone())
grads2.append((x.grad.clone(), y.grad.clone()))
self.assertEqual(outs, outs2)
self.assertEqual(grads, grads2)
# The "cudagraph_skips" inductor counter must still be zero afterwards.
self.assertEqual(counters["inductor"]["cudagraph_skips"], 0)
def test_cudagraphs_aot_eager_compat_equal(self):
# Run the shared compatibility check on device cuda:0.
self._test_cudagraphs_aot_eager_compat_equal(torch.device("cuda:0"))
@requires_multigpu()
def test_cudagraphs_aot_eager_compat_equal_device_one(self):
# Run the shared compatibility check on device cuda:1.
self._test_cudagraphs_aot_eager_compat_equal(torch.device("cuda:1"))
@requires_multigpu()
def test_multi_device(self):
# Smoke test: a single compiled function whose checkpointed rng regions run on
# two different CUDA devices must compile and run forward/backward without error.
def gn(x, y):
return torch.sigmoid(torch.rand_like(x) * y) * x
def fn(x, y):
x = torch.sin(x)
x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
x = torch.sin(x)
return x
def multi_fn(x, y, a, b):
return fn(x, y), fn(a, b)
x = torch.randn(4, 4, device="cuda:0", requires_grad=True)
y = torch.randn(4, 4, device="cuda:0", requires_grad=True)
a = torch.randn(4, 4, device="cuda:1", requires_grad=True)
b = torch.randn(4, 4, device="cuda:1", requires_grad=True)
# No errors. TODO - get graphs from logging, couldn't figure out how
multi_fn_c = torch.compile(multi_fn, backend="aot_eager_decomp_partition")
out = multi_fn_c(x, y, a, b)
# Only the cuda:0 output is backpropagated here.
out[0].sum().backward()
def test_retain_graph(self):
# Calls backward twice on the same output (the first time with
# retain_graph=True) and checks that every recorded rand output is identical.
device = "cuda"
ObserverMode = self._make_observer_mode()
from torch._prims.rng_prims import graphsafe_run_with_rng_state
# Record each graphsafe_run_with_rng_state output under the wrapped op.
@graphsafe_run_with_rng_state.py_impl(ObserverMode)
def _(mode, op, *args, **kwargs):
with no_dispatch():
out = graphsafe_run_with_rng_state(op, *args, **kwargs)
mode.op_outputs[op].append(out)
return out
obs = ObserverMode()
def gn(x, y):
return torch.sigmoid(torch.rand_like(x) * y) * x
def fn(x, y):
x = torch.sin(x)
x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
x = torch.sin(x)
return x
x = torch.randn(4, 4, device=device, requires_grad=True)
y = torch.randn(4, 4, device=device, requires_grad=True)
aot_eager_decomp_partition = functools.partial(
aot_eager_decomp_partition_with_mode, mode=obs
)
fn = torch.compile(fn, backend=aot_eager_decomp_partition)
out = fn(x, y).sum()
out.backward(retain_graph=True)
out.backward()
# Three recordings are expected — presumably one forward plus two backward
# recomputations — and all three must be equal.
self.assertEqual(len(obs.op_outputs[aten.rand.default]), 3)
self.assertEqual(
obs.op_outputs[aten.rand.default][0],
obs.op_outputs[aten.rand.default][1],
)
self.assertEqual(
obs.op_outputs[aten.rand.default][1],
obs.op_outputs[aten.rand.default][2],
)
instantiate_parametrized_tests(CudaGraphTreeTests)
instantiate_parametrized_tests(TestSAC)
if __name__ == "__main__":
from torch._inductor.test_case import run_tests
@ -2514,5 +2921,5 @@ if __name__ == "__main__":
sys.exit(0)
raise unittest.SkipTest("cuda graph test is skipped")
if HAS_CPU or HAS_CUDA:
if HAS_CUDA:
run_tests(needs="filelock")

View File

@ -62,6 +62,7 @@ class TestHOPInfra(TestCase):
FIXME_ALLOWLIST = {
"autograd_function_apply",
"run_with_rng_state",
"graphsafe_run_with_rng_state",
"map_impl",
"_export_tracepoint",
"run_and_save_rng_state",

View File

@ -66,6 +66,7 @@ from .subclass_utils import compute_inner_mutated_inp_indices_from_subclass_meta
from .utils import (
_get_symint_hints,
contain_metadata_mutation_ops,
get_cuda_generator_meta_val,
make_boxed_func,
strict_zip,
unlift_tokens,
@ -458,9 +459,20 @@ def aot_dispatch_autograd(
fake_mode = detect_fake_mode()
if fake_mode is not None and fake_mode.shape_env is not None:
tensorify_python_scalars(fx_g, fake_mode.shape_env, fake_mode)
fw_module, bw_module = aot_config.partition_fn(
fx_g, joint_inputs, num_fwd_outputs=num_inner_fwd_outputs
)
rng_states = [
n
for n in fw_module.graph.find_nodes(op="placeholder")
if "fwd_rng_state" in n.name
]
fw_metadata.num_graphsafe_rng_states = len(rng_states)
if rng_states:
fw_metadata.graphsafe_rng_state_index = (
rng_states[0].meta["val"].device.index
)
# See Note [Side-Effectful Tokens in AOTAutograd]
if config.unlift_effect_tokens and (
@ -684,6 +696,16 @@ def aot_dispatch_autograd(
functionalized_rng_wrapper = FunctionalizedRngRuntimeWrapper(
return_new_outs=False
)
if rng_states:
index = fw_metadata.graphsafe_rng_state_index
assert index is not None
rng_states = [
get_cuda_generator_meta_val(index)
for _ in range(fw_metadata.num_graphsafe_rng_states)
]
adjusted_flat_args.extend(rng_states) # type: ignore[arg-type]
(
fw_module,
adjusted_flat_args,

View File

@ -1684,6 +1684,45 @@ def _backward_prologue_functional(
return all_args
def initialize_rng_states(
    num_rng: int,
    graphsafe_idx: int,
    fwd_rng_states: list[torch.Generator],
    bwd_rng_states: list[torch.Generator],
):
    """
    Populate the cudagraph safe forward and backward rng state lists in place.

    ``num_rng`` seeds are drawn from the current global rng, and for every seed
    one generator is appended to ``fwd_rng_states`` and an identically seeded
    one to ``bwd_rng_states``. Each generator is a clone of the default CUDA
    generator of device ``graphsafe_idx``.

    Properties the initialization is meant to have:
    - every rng state is initialized independently of the others
    - the initialization is deterministic
    - it is derived from the current rng state, so that independent graphs do
      not end up with equal rng behavior

    Initialization is deferred until runtime because compilation is wrapped
    with preserve_rng_states. Drawing the seeds should advance the rng state so
    that consecutive compilations do not give equal randomness.
    """

    def make_seeded_generators(seed_values):
        # One fresh clone of the device's default generator per seed.
        return [
            torch.cuda.default_generators[graphsafe_idx].clone_state().manual_seed(seed)
            for seed in seed_values
        ]

    with torch.utils._python_dispatch._disable_current_modes():
        upper_bound = torch.iinfo(torch.int64).max
        seed_values = torch.randint(0, upper_bound, (num_rng,), device="cpu").tolist()
        # Both lists are built from the same seeds so forward state i and
        # backward state i start out equal.
        fwd_rng_states.extend(make_seeded_generators(seed_values))
        bwd_rng_states.extend(make_seeded_generators(seed_values))
# NOTE: this function must be torch._dynamo.allow_in_graph-able. Non tensor/symnode inputs must be constants.
def _backward_epilogue_functional(
metadata, maybe_subclass_metadata, out, *, make_subclass_override=None
@ -1819,6 +1858,34 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
fw_metadata: ViewAndMutationMeta, # runtime metadata
try_save_cache_entry: Optional[Callable], # Save cache entry after compilation
):
# For additional context see Note [CUDA Graph Safe RNG Functionalization]
# Each pair forward, backward rng states must be equal prior to its invocation on any
# iteration of forward, backward. Because they are initialized equal, and are computing the same rng op,
# running forward then backward advances them the same amount and keeps them equal.
# However, a user may invoke multiple forwards, then backwards, such that they are not in sync.
# Initially we have:
# fwd_state0 == bwd_state0.
# Let's say we run:
# fwd0: fwd_state0 -> fwd_state1
# fwd1: fwd_state1 -> fwd_state2
# fwd2: fwd_state2 -> fwd_state3
# If we now invoke bwd2,
# we need to update bwd_state equal to the rng that was observed in fwd2.
# we save the rng_state fwd_state2 in forward because we detect that it is not the
# current backward state and therefore would not be accessible if we do not save it.
# Similarly, if we are going to update the backward state to a new value, and there is a pending
# forwards which needs its current state, we will save it.
# Within the autograd context, we keep track of the curr iteration so that on backward
# we know what the generator state must be before the backward is run.
num_rng = fw_metadata.num_graphsafe_rng_states
graphsafe_idx = fw_metadata.graphsafe_rng_state_index
fwd_rng_states: list[torch.Generator] = []
bwd_rng_states: list[torch.Generator] = []
curr_fwd_iter = itertools.count(0)
backward_state_position = 0
pending_forwards: set[int] = set()
saved_backward_tensor_states: dict[int, list[torch.Tensor]] = {}
class CompiledFunction(torch.autograd.Function):
compiled_fw = compiled_fw_func
compiled_bw = compiled_bw_func
@ -1840,6 +1907,26 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
assert isinstance(bw_state, BackwardState)
ctx._compiled_autograd_backward_state = bw_state
if num_rng:
if len(fwd_rng_states) == 0:
assert graphsafe_idx is not None
initialize_rng_states(
num_rng, graphsafe_idx, fwd_rng_states, bwd_rng_states
)
_curr_iter = next(curr_fwd_iter)
ctx._curr_iter = _curr_iter
# if this state is not contained in the backward,
# we need to save it for when its backward pass happens
if _curr_iter != backward_state_position:
saved_backward_tensor_states[_curr_iter] = [
rng_state.get_state() for rng_state in fwd_rng_states
]
pending_forwards.add(_curr_iter)
args = (*args, *fwd_rng_states)
# There is a pretty complicated calling convention around what the compiled fw returns.
# The full list of outputs and their relative order is:
# (*tokens, *mutated_inputs, *fw_outs, *fw_intermediate_bases, *saved_tensors, *saved_symints)
@ -1967,6 +2054,45 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
*flat_args,
)
if num_rng:
nonlocal backward_state_position, bwd_rng_states
curr_backward_iter = ctx._curr_iter
retain_graph = (
torch._C._autograd._get_current_graph_task_keep_graph()
)
# Save current state if we have a pending forward that needs this state
# or this state may be needed again because of retain graph
if (
backward_state_position in pending_forwards
and backward_state_position not in saved_backward_tensor_states
and (
backward_state_position != curr_backward_iter
or retain_graph
)
):
saved_backward_tensor_states[backward_state_position] = [
rng_state.get_state() for rng_state in bwd_rng_states
]
# Restore saved states if needed
if curr_backward_iter in saved_backward_tensor_states:
if backward_state_position != curr_backward_iter:
for bwd_state, saved_state in zip(
bwd_rng_states,
saved_backward_tensor_states[curr_backward_iter],
):
bwd_state.set_state(saved_state)
if not retain_graph:
del saved_backward_tensor_states[curr_backward_iter]
else:
assert backward_state_position == curr_backward_iter
backward_state_position = curr_backward_iter + 1
if not retain_graph:
pending_forwards.remove(curr_backward_iter)
all_args.extend(bwd_rng_states)
def impl_fn(double_ctx=None):
out = CompiledFunction._backward_impl(ctx, all_args)
return _backward_epilogue_functional(

View File

@ -420,11 +420,16 @@ class ViewAndMutationMeta:
# Filled after tracing joint function.
num_backward_tokens: int = 0
# Number of rng states that will get threaded into the forward and backward for
# cudagraph compatible run_and_save_rng
num_graphsafe_rng_states: int = 0
graphsafe_rng_state_index: Optional[int] = None
def __post_init__(self):
# pre-compute the indices of the inputs that are mutated.
# When keep_input_mutations is set, we don't need to worry about our epilogue
# handling data-only mutations, because we keep them directly in the graph.
mutated_inp_runtime_indices = [
i
for i, m in enumerate(self.input_info)

View File

@ -489,3 +489,14 @@ def contain_metadata_mutation_ops(module: torch.fx.GraphModule) -> bool:
):
return True
return False
def get_cuda_generator_meta_val(device_idx: int):
    """
    Return a generator suitable for use as an FX node meta val.

    The result is a clone of the default CUDA generator for ``device_idx``.
    A newly cloned generator does not contain tensors; only Generators that are
    registered to a CUDAGraph contain tensors. Because this one holds no Tensor
    it is fine to use in the meta.
    """
    default_generator = torch.cuda.default_generators[device_idx]
    return default_generator.clone_state()

View File

@ -210,6 +210,10 @@ torch_compile_graph_format = os.environ.get("TORCH_COMPILE_GRAPH_FORMAT", "svg")
# real tensor outputs.
generate_fake_kernels_from_real_mismatches = False
# CUDAGraph safe run_with_rng functionalization.
# TODO: turn on by default
graphsafe_rng_functionalization = True
# Error on BypassAOTAutogradCache instead of just a warning
# Used for tests

View File

@ -40,7 +40,7 @@ from ._activation_checkpointing.knapsack import (
)
from ._activation_checkpointing.knapsack_evaluator import KnapsackEvaluator
from ._aot_autograd.logging_utils import get_aot_graph_name
from ._aot_autograd.utils import is_with_effects
from ._aot_autograd.utils import get_cuda_generator_meta_val, is_with_effects
from .compile_utils import fx_graph_cse, get_aten_target
@ -622,6 +622,99 @@ def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule:
return new_gm
def apply_graphsafe_rng_functionalization(
    fw_module: torch.fx.GraphModule,
    bw_module: torch.fx.GraphModule,
    fw_node: torch.fx.Node,
    bw_node: torch.fx.Node,
    device: torch.device,
    rng_count: int,
    last_fwd_input: torch.fx.Node,
    last_bwd_input: torch.fx.Node,
):
    """
    Note [CUDA Graph Safe RNG Functionalization]

    get_rng_state and set_rng_state operate on CPU values, whereas CUDA Graph
    RNG capture uses on-device CUDA tensors, so those functions do not work
    under CUDA Graph capture. Instead we use graphsafe_set_state together with
    a CUDA Generator that is registered to the CUDA Graph before capture
    begins. graphsafe_set_state repoints the generator at a different
    GeneratorImpl, so subsequent calls are forwarded to the desired generator
    (and to its cuda-tensor RNG state during graph capture).

    For the forward/backward pair of each RNG operation:
    - two generators are created and initialized with identical values
    - every forward call and every backward call advances its own generator by
      the same amount
    - the generators therefore stay in sync, and forward and backward see
      matching RNG values

    If forward is called several times before backward (desynchronizing them):
    - the forward RNG state is saved
    - the backward Generator's state is updated before backward executes

    Before each CUDA Graph replay, replay_prologue updates the captured RNG
    pointers with the current states, so changes to the backward Generator are
    reflected during replay.

    This function rewrites one RNG op in both graphs:
    - adds an RNG state placeholder to the forward and to the backward graph,
      each placed directly after the given last input placeholder
    - replaces ``fw_node`` with a graphsafe_run_with_rng_state call that uses
      the forward placeholder
    - replaces ``bw_node`` with a graphsafe_run_with_rng_state call that uses
      the backward placeholder

    Returns the two new placeholders ``(fwd, bwd)`` so the caller can use them
    as ``last_fwd_input`` / ``last_bwd_input`` for the next RNG op.

    For more details: https://github.com/pytorch/pytorch/issues/113541
    """
    device_idx = device.index
    assert device_idx is not None

    rng_hop = torch._prims.rng_prims.graphsafe_run_with_rng_state
    fw_graph, bw_graph = fw_module.graph, bw_module.graph

    # Note: [Generator arguments in AOTDispatcher]
    # Generator arguments in AOTDispatcher are added to support graphsafe rng
    # functionalization. See note above [CUDA Graph Safe RNG Functionalization]
    with fw_graph.inserting_after(last_fwd_input):
        fwd_rng_state = fw_graph.placeholder(f"fwd_rng_state_{rng_count}")
    fwd_rng_state.meta["val"] = get_cuda_generator_meta_val(device_idx)

    with bw_graph.inserting_after(last_bwd_input):
        bwd_rng_state = bw_graph.placeholder(f"bwd_rng_state_{rng_count}")
    # a cloned generator is used so that the meta val does not contain tensors
    bwd_rng_state.meta["val"] = get_cuda_generator_meta_val(device_idx)

    # Swap the forward rng op for the HOP call, threading in the new placeholder.
    with fw_graph.inserting_after(fw_node):
        functional_fw_node = fw_graph.create_node(
            "call_function",
            rng_hop,
            args=(fw_node.target, *fw_node.args),  # type: ignore[arg-type]
            kwargs={**fw_node.kwargs, "rng_state": fwd_rng_state},
        )
    fw_node.replace_all_uses_with(functional_fw_node)
    fw_graph.erase_node(fw_node)

    # Same rewrite for the backward rng op.
    with bw_graph.inserting_before(bw_node):
        functional_bw_node = bw_graph.create_node(
            "call_function",
            rng_hop,
            args=(bw_node.target, *bw_node.args),  # type: ignore[arg-type]
            kwargs={**bw_node.kwargs, "rng_state": bwd_rng_state},
        )
    bw_node.replace_all_uses_with(functional_bw_node)
    bw_graph.erase_node(bw_node)

    return fwd_rng_state, bwd_rng_state
def functionalize_rng_ops(
joint_module: fx.GraphModule,
fw_module: fx.GraphModule,
@ -660,7 +753,7 @@ def functionalize_rng_ops(
random_nodes[node.name] = node
return random_nodes
def get_device(node):
def get_device(node) -> Optional[torch.device]:
"""
Check the example value of the node outputs to find the device type.
"""
@ -674,12 +767,12 @@ def functionalize_rng_ops(
for candidate in candidates:
if isinstance(candidate, torch.Tensor):
if candidate.device.type == "cuda":
return "cuda"
return candidate.device
return "cpu"
return torch.device("cpu")
def get_sample_rng_state(device):
if device == "cuda":
def get_sample_rng_state(device: Optional[torch.device]):
if device is not None and device.type == "cuda":
return torch.cuda.get_rng_state()
return torch.get_rng_state()
@ -701,6 +794,7 @@ def functionalize_rng_ops(
run_and_save_rng = torch._prims.rng_prims.run_and_save_rng_state
run_with_rng_state = torch._prims.rng_prims.run_with_rng_state
bw_tangent_start_node = None
for node in bw_module.graph.find_nodes(op="placeholder"):
if "tangent" in node.name:
@ -712,68 +806,113 @@ def functionalize_rng_ops(
)
fw_rng_state_outputs = []
for base_node, node_pair in recomputable_rng_ops_map.items():
last_fwd_input = next(reversed(fw_module.graph.find_nodes(op="placeholder")))
last_bwd_input = next(reversed(bw_module.graph.find_nodes(op="placeholder")))
devices = OrderedSet(
get_device(node_pair["fwd"]) for node_pair in recomputable_rng_ops_map.values()
)
devices.discard(torch.device("cpu"))
# multiple cuda devices won't work with cudagraphs anyway,
# fallback to non graphsafe rng checkpointing
multi_cuda_devices = len(devices) > 1
# this changes numerics, so if fallback_random is set we will not use it
ind_config = torch._inductor.config
use_rng_graphsafe_rng_functionalization = (
config.graphsafe_rng_functionalization
and not multi_cuda_devices
and (
not ind_config.fallback_random
or ind_config.test_configs.graphsafe_rng_func_ignores_fallback_random
)
)
for rng_count, (base_node, node_pair) in enumerate(
recomputable_rng_ops_map.items()
):
# Step 2 - Modify the fwd pass such that
fw_node = node_pair["fwd"]
bw_node = node_pair["bwd"]
device = get_device(fw_node)
fw_graph = fw_module.graph
with fw_graph.inserting_before(fw_node):
functional_fw_node = fw_graph.create_node(
"call_function",
run_and_save_rng,
args=(fw_node.target, *fw_node.args),
kwargs=fw_node.kwargs,
)
state = fw_graph.create_node(
"call_function",
operator.getitem,
args=(functional_fw_node, 0),
kwargs={},
)
rng_output = fw_graph.create_node(
"call_function",
operator.getitem,
args=(
functional_fw_node,
1,
),
kwargs={},
)
fw_node.replace_all_uses_with(rng_output)
fw_graph.erase_node(fw_node)
fw_rng_state_outputs.append(state)
# Step 3 - Modify the bwd pass such that
bw_graph = bw_module.graph
with bw_graph.inserting_before(bw_tangent_start_node):
state_name = f"rng_state_output_{next(uid)}"
bw_rng_state_node = bw_graph.placeholder(state_name)
bw_rng_state_node.meta["val"] = get_sample_rng_state(get_device(fw_node))
with bw_graph.inserting_before(bw_node):
rng_output = bw_graph.create_node(
"call_function",
run_with_rng_state,
args=(bw_rng_state_node, bw_node.target, *bw_node.args),
kwargs=bw_node.kwargs,
if (
use_rng_graphsafe_rng_functionalization
and device is not None
and device.type == "cuda"
):
last_fwd_input, last_bwd_input = apply_graphsafe_rng_functionalization(
fw_module,
bw_module,
fw_node,
bw_node,
device,
rng_count,
last_fwd_input,
last_bwd_input,
)
else:
with fw_graph.inserting_before(fw_node):
functional_fw_node = fw_graph.create_node(
"call_function",
run_and_save_rng,
args=(fw_node.target, *fw_node.args),
kwargs=fw_node.kwargs,
)
state = fw_graph.create_node(
"call_function",
operator.getitem,
args=(functional_fw_node, 0),
kwargs={},
)
rng_output = fw_graph.create_node(
"call_function",
operator.getitem,
args=(
functional_fw_node,
1,
),
kwargs={},
)
fw_node.replace_all_uses_with(rng_output)
fw_graph.erase_node(fw_node)
fw_rng_state_outputs.append(state)
bw_node.replace_all_uses_with(rng_output)
bw_graph.erase_node(bw_node)
# Step 3 - Modify the bwd pass such that
with bw_graph.inserting_before(bw_tangent_start_node):
state_name = f"rng_state_output_{next(uid)}"
bw_rng_state_node = bw_graph.placeholder(state_name)
bw_rng_state_node.meta["val"] = get_sample_rng_state(device)
with bw_graph.inserting_before(bw_node):
rng_output = bw_graph.create_node(
"call_function",
run_with_rng_state,
args=(bw_rng_state_node, bw_node.target, *bw_node.args),
kwargs=bw_node.kwargs,
)
bw_node.replace_all_uses_with(rng_output)
bw_graph.erase_node(bw_node)
# Add the rng states in the output of the fwd graph. AOT Autograd assumes
# that symints are at the end of forward graph outputs. So, insert the new
# rng states accordingly.
fw_output_node = next(iter(fw_module.graph.find_nodes(op="output")))
fw_outputs = fw_output_node.args[0]
sym_node_start_idx = len(fw_outputs) - num_sym_nodes
outputs = (
fw_outputs[:sym_node_start_idx]
+ tuple(fw_rng_state_outputs)
+ fw_outputs[sym_node_start_idx:]
)
fw_module.graph.output(outputs)
fw_module.graph.erase_node(fw_output_node)
if fw_rng_state_outputs:
fw_output_node = next(iter(fw_module.graph.find_nodes(op="output")))
fw_outputs = fw_output_node.args[0]
sym_node_start_idx = len(fw_outputs) - num_sym_nodes
outputs = (
fw_outputs[:sym_node_start_idx]
+ tuple(fw_rng_state_outputs)
+ fw_outputs[sym_node_start_idx:]
)
fw_module.graph.output(outputs)
fw_module.graph.erase_node(fw_output_node)
fw_module.recompile()
bw_module.recompile()
return fw_module, bw_module
@ -1849,7 +1988,6 @@ def min_cut_rematerialization_partition(
saved_sym_nodes=saved_sym_nodes,
num_fwd_outputs=num_fwd_outputs,
)
if graph_has_recomputable_ops:
if graph_has_recomputable_rng_ops:
fw_module, bw_module = functionalize_rng_ops(

View File

@ -898,7 +898,9 @@ class PythonWrapperCodegen(CodeGen):
continue
# a graph partition may take an IRNode output from a previous partition
if name not in V.graph.graph_input_names:
if name not in V.graph.graph_input_names or isinstance(
buf, ir.GeneratorState
):
continue
# comparing strides for 0 size tensor is tricky. Ignore them for now.
@ -1418,7 +1420,9 @@ class PythonWrapperCodegen(CodeGen):
code.writeline(f"{stride} = {strideof(name)}[{dim}]")
bound_vars.add(stride)
elif isinstance(value, ir.TorchBindObject):
pass
return
elif isinstance(value, ir.GeneratorState):
return
else:
if torch._inductor.config.graph_partition:
pass
@ -1612,6 +1616,11 @@ class PythonWrapperCodegen(CodeGen):
# is actually a valid value for the kernel in question.
# See https://github.com/pytorch/pytorch/issues/124686
add_expr_input(name, V.graph.sizevars.size_hint(value, fallback=42))
elif isinstance(value, ir.GeneratorState):
add_expr_input(
name,
f"torch.cuda.default_generators[{value.device.index}].graphsafe_get_state()",
)
else:
shape = [
V.graph.sizevars.size_hint(x, fallback=42)
@ -2287,6 +2296,8 @@ class PythonWrapperCodegen(CodeGen):
return s.codegen_reference()
elif has_triton_package() and isinstance(s, triton.language.dtype): # type: ignore[possibly-undefined]
return dtype_to_string(s)
elif isinstance(s, ir.GeneratorState):
return s.codegen_reference()
else:
return repr(s)

View File

@ -1517,6 +1517,8 @@ class test_configs:
autotune_choice_name_regex: Optional[str] = None
autotune_choice_desc_regex: Optional[str] = None
graphsafe_rng_func_ignores_fallback_random = False
if TYPE_CHECKING:
from torch.utils._config_typing import * # noqa: F401, F403

View File

@ -908,6 +908,7 @@ class CUDAGraphNode:
self.recorded_liveness_before_graph = curr_liveness
self.expected_dead_indices_before_graph = different_indices
rng_states = [inp for inp in inputs if isinstance(inp, torch.Generator)]
recording_inputs = self._allocate_and_copy_recording_inputs(inputs)
# recording inputs will copy over memory, so we can free non recording inputs
inputs.clear()
@ -916,6 +917,11 @@ class CUDAGraphNode:
# graph used for recording model invocation
self.graph: Optional[torch.cuda.CUDAGraph] = torch.cuda.CUDAGraph()
# TODO: register_generator_state should potentially take explicit device
with torch.cuda.device(self.device):
for rng_state in rng_states:
self.graph.register_generator_state(rng_state)
# we allocate non-static inputs within the same memory pool as the CUDAGraph
# which we will record the model with. For memory efficiency, it is important
# to reclaim the input memory when the inputs are no longer live. To accomplish this,
@ -1602,7 +1608,7 @@ class CUDAGraphNode:
):
for i, inp in enumerate(inputs):
if not isinstance(inp, torch.Tensor):
assert isinstance(inp, int)
assert isinstance(inp, (int, torch.Generator))
recording_inputs.append(inp)
elif i not in self.static_input_idxs:
# static_input does an allocation!

View File

@ -1040,6 +1040,18 @@ class GraphLowering(torch.fx.Interpreter):
# Alternately we could filter this out in AotAutograd
self.graph_input_names.append(target)
return None
# see Note: [Generator arguments in AOTDispatcher]
elif isinstance(example, torch.Generator):
assert (
len(V.graph.current_node.users) == 1
and next(iter(V.graph.current_node.users)).target
is torch._prims.rng_prims.graphsafe_run_with_rng_state
)
gen = ir.GeneratorState(name=target, device=example.device)
self.graph_inputs[target] = gen # type: ignore[assignment]
self.graph_input_names.append(target)
return gen
assert isinstance(example, torch.Tensor), example
# todo(chilli): We can remove the last check once we turn buffers into
# static shape tensors. That's a hack to workaround Inductor believing
@ -1288,7 +1300,7 @@ class GraphLowering(torch.fx.Interpreter):
if isinstance(value, TorchBindObject):
continue
assert isinstance(
value, (TensorBox, sympy.Expr)
value, (TensorBox, sympy.Expr, torch._inductor.ir.GeneratorState)
), f"Unsupported inductor graph input type: {type(value)}"
if not isinstance(value, TensorBox):
continue

View File

@ -4977,7 +4977,9 @@ class ExternKernel(InputsKernel):
tensor_args = []
non_tensor_args: list[Any] = []
for arg in args_flat:
is_arg_tensor.append(isinstance(arg, IRNode))
is_arg_tensor.append(
isinstance(arg, IRNode) and not isinstance(arg, GeneratorState)
)
if is_arg_tensor[-1]:
tensor_args.append(arg)
else:
@ -5008,7 +5010,9 @@ class ExternKernel(InputsKernel):
# Rerun fake tensor propagation, because Inductor may have changed the
# strides of inputs and we need to determine accurately what the
# output stride will be.
example_args: list[Union[torch.Tensor, torch._C.ScriptObject]] = []
example_args: list[
Union[torch.Tensor, torch._C.ScriptObject, torch.Generator]
] = []
# We need to retain the constant values of fake tensors that we originally
# propagated the graph with, because for some operators running without a
@ -5025,6 +5029,12 @@ class ExternKernel(InputsKernel):
example_args.append(V.graph.torchbind_constants[x.get_name()])
elif isinstance(x, TorchBindObject):
example_args.append(x.get_real_obj())
elif isinstance(x, torch._inductor.ir.GeneratorState):
device_index = x.device.index
assert x.device.type == "cuda" and device_index is not None
example_args.append(
torch.cuda.default_generators[device_index].clone_state()
)
else:
example_args.append(ir_node_to_tensor(x, guard_shape=True))
@ -5155,7 +5165,7 @@ class ExternKernel(InputsKernel):
# TODO(jansel): impose layout preference on realized buffer
x.realize()
return x
if isinstance(x, TorchBindObject):
if isinstance(x, (NonTensorObj)):
return x
return cls.copy_input(x)
@ -7570,8 +7580,12 @@ class EffectfulKernel(FallbackKernel):
return True
class NonTensorObj(IRNode):
    """Common base class for IR nodes that stand for non-tensor objects.

    It adds no fields or behavior of its own; it exists so that callers can
    recognise such nodes with a single ``isinstance`` check.
    """
@ir_dataclass
class TorchBindObject(IRNode):
class TorchBindObject(NonTensorObj):
from torch._library.fake_class_registry import FakeScriptObject
name: str
@ -7605,6 +7619,18 @@ class TorchBindObject(IRNode):
return functools.reduce(lambda x, y: x + y, flat_sizes, 0)
@ir_dataclass
class GeneratorState(NonTensorObj):
    """IR node that refers to a generator-state graph input by name.

    The node carries no tensor data: it only records the input's name and
    the device the generator belongs to.
    """

    # Name of the graph input; also the text emitted when the node is
    # referenced in generated code.
    name: str
    # Device of the generator this state belongs to.
    device: torch.device

    def get_name(self):  # type: ignore[no-untyped-def]
        """Return the name of this generator-state input."""
        return self.name

    def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str:
        """Return the expression used to reference this node in generated code.

        ``writer`` is accepted but not used; the reference is the input name.
        """
        return self.name
class _CollectiveKernel(FallbackKernel):
def should_allocate(self) -> bool:
return False

View File

@ -2718,6 +2718,8 @@ make_fallback(aten.gcd.default, warn=False)
make_fallback(aten._thnn_fused_lstm_cell, require_dense)
make_fallback(torch._prims.rng_prims.run_and_save_rng_state)
make_fallback(torch._prims.rng_prims.run_with_rng_state)
make_fallback(torch._prims.rng_prims.graphsafe_run_with_rng_state)
# Implemented / Half implemented
# Scans. Implemented for CUDA, missing CPU

View File

@ -426,7 +426,7 @@ class CompiledFxGraph(OutputCode):
(not complex_memory_overlap_inputs, "complex memory overlap"),
(
all(
isinstance(t, (torch.Tensor, torch.SymInt))
isinstance(t, (torch.Tensor, torch.SymInt, torch.Generator))
for t in example_inputs
),
"non-Tensor inputs",

View File

@ -2104,6 +2104,7 @@ def count_tangents(fx_g: torch.fx.GraphModule) -> int:
"tangents" not in x.name
and "bwd_seed" not in x.name
and "bwd_base_offset" not in x.name
and "bwd_rng_state" not in x.name
)
arg_count = 0

View File

@ -315,5 +315,75 @@ run_and_save_rng_state = register_run_and_save_rng_state_op()
run_with_rng_state = register_run_with_rng_state_op()
def register_graphsafe_run_with_rng_state_op():
    """Build the ``graphsafe_run_with_rng_state`` higher-order operator.

    The operator is called as ``hop(op, *args, rng_state=..., **kwargs)`` and
    runs ``op(*args, **kwargs)``. This function registers its implementations
    for the Autograd, CUDA and BackendSelect dispatch keys, for
    ``FakeTensorMode`` and ``ProxyTorchDispatchMode``, and for
    functionalization.

    Returns:
        The newly created higher-order operator instance.
    """

    class GraphSafeRunWithRngState(HigherOrderOperator):
        def __init__(self):
            super().__init__("graphsafe_run_with_rng_state")

        def __call__(self, op, *args, rng_state=None, **kwargs):
            # Forward ``rng_state`` explicitly as a keyword so that every
            # implementation below receives it under the same name.
            return super().__call__(op, *args, rng_state=rng_state, **kwargs)

    graphsafe_run_with_rng_state = GraphSafeRunWithRngState()

    # No autograd formula is provided for this operator.
    graphsafe_run_with_rng_state.py_impl(DispatchKey.Autograd)(
        autograd_not_implemented(graphsafe_run_with_rng_state, deferred_error=True)
    )

    @graphsafe_run_with_rng_state.py_impl(DispatchKey.CUDA)
    def impl_cuda(op, *args, rng_state=None, **kwargs):
        """Run ``op`` with the default CUDA generator set to ``rng_state``."""
        device_idx = rng_state.device.index
        generator = torch.cuda.default_generators[device_idx]
        current_state = generator.graphsafe_get_state()
        generator.graphsafe_set_state(rng_state)
        try:
            return op(*args, **kwargs)
        finally:
            # Restore the previous state even when ``op`` raises; otherwise
            # the default generator would stay on ``rng_state`` afterwards.
            generator.graphsafe_set_state(current_state)

    @graphsafe_run_with_rng_state.py_impl(DispatchKey.BackendSelect)
    def impl_backend_select(op, *args, rng_state=None, **kwargs):
        """Check the device reported by ``get_device`` and run the CUDA impl."""
        device = get_device(args, kwargs)
        assert (
            device == "cuda"
        ), f"GraphSafe RNG operations only supported for CUDA, got {device}"
        return impl_cuda(op, *args, rng_state=rng_state, **kwargs)

    @graphsafe_run_with_rng_state.py_impl(FakeTensorMode)
    def impl_fake_tensor_mode(mode, op, *args, rng_state=None, **kwargs):
        """Run ``op`` under the fake mode; ``rng_state`` is not consulted."""
        with mode:
            return op(*args, **kwargs)

    @graphsafe_run_with_rng_state.py_impl(ProxyTorchDispatchMode)
    def impl_proxy_dispatch_mode(mode, op, *args, rng_state=None, **kwargs):
        """Compute the output untraced, then record one call node for the op."""
        # Run the operator with proxy tracing disabled so its internals are
        # not traced; only the single call created below is recorded.
        with disable_proxy_modes_tracing():
            out = graphsafe_run_with_rng_state(op, *args, rng_state=rng_state, **kwargs)
        proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, (op, *args))
        proxy_kwargs = pytree.tree_map(
            mode.tracer.unwrap_proxy, {"rng_state": rng_state, **kwargs}
        )
        out_proxy = mode.tracer.create_proxy(
            "call_function", graphsafe_run_with_rng_state, proxy_args, proxy_kwargs
        )
        return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)

    @graphsafe_run_with_rng_state.py_functionalize_impl
    def impl_functional(ctx, op, *args, rng_state=None, **kwargs):
        """Unwrap the inputs, redispatch, and re-wrap the output."""
        unwrapped_rng_state = (
            ctx.unwrap_tensors(rng_state) if rng_state is not None else None
        )
        unwrapped_args = ctx.unwrap_tensors(args)
        unwrapped_kwargs = ctx.unwrap_tensors(kwargs)

        with ctx.redispatch_to_next():
            out = graphsafe_run_with_rng_state(
                op, *unwrapped_args, rng_state=unwrapped_rng_state, **unwrapped_kwargs
            )
            return ctx.wrap_tensors(out)

    return graphsafe_run_with_rng_state
graphsafe_run_with_rng_state = register_graphsafe_run_with_rng_state_op()
def register_rng_prims():
register_philox_rand()

View File

@ -819,6 +819,9 @@ def _iterate_exprs(val: IterateExprs) -> Iterator[sympy.Basic]:
yield from _iterate_exprs(val.storage_offset())
elif val is None:
pass
# see Note: [Generator arguments in AOTDispatcher]
elif isinstance(val, torch.Generator):
pass
else:
raise AssertionError(f"cannot extract sympy expressions from {val} {type(val)}")

View File

@ -77,6 +77,7 @@ FIXME_hop_that_doesnt_have_opinfo_test_allowlist = [
"autograd_function_apply",
"run_and_save_rng_state",
"run_with_rng_state",
"graphsafe_run_with_rng_state",
"out_dtype",
"trace_wrapped",
'tag_activation_checkpoint',