diff --git a/test/dynamo/test_activation_checkpointing.py b/test/dynamo/test_activation_checkpointing.py index baac1724a9d7..17021eb46565 100644 --- a/test/dynamo/test_activation_checkpointing.py +++ b/test/dynamo/test_activation_checkpointing.py @@ -54,6 +54,8 @@ def count_ops( return node.args[0] == op elif node.name == "run_with_rng_state": return node.args[1] == op + elif node.name == "graphsafe_run_with_rng_state": + return node.args[0] == op return False # assert ((freq or freq_ge) and op) or ((freqs or freqs_ge) and ops) @@ -1018,6 +1020,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no @requires_cuda @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") + @torch._inductor.config.patch(fallback_random=True) def test_compile_selective_checkpoint_random_op(self, device): for preserve_rng_state in [True, False]: diff --git a/test/inductor/test_cudagraph_trees.py b/test/inductor/test_cudagraph_trees.py index 5cf20869aa3e..2f870cffc376 100644 --- a/test/inductor/test_cudagraph_trees.py +++ b/test/inductor/test_cudagraph_trees.py @@ -4,13 +4,17 @@ import contextlib import functools import gc import importlib +import itertools import sys import unittest import warnings +from collections import defaultdict +from collections.abc import Mapping, Sequence import torch import torch._dynamo.config as dynamo_config import torch.nn as nn +from torch._dynamo.backends.debugging import aot_eager_decomp_partition_with_mode from torch._dynamo.utils import counters from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache from torch._inductor import config @@ -19,7 +23,9 @@ from torch._inductor.compile_fx import compile_fx_inner from torch._inductor.cudagraph_trees import cudagraphify_impl as tree_cudagraphify_impl from torch._inductor.cudagraph_utils import FunctionID from torch._inductor.test_case import TestCase as InductorTestCase +from torch._ops import OpOverload from torch.fx.experimental.proxy_tensor import make_fx +from torch.fx.immutable_collections import immutable_dict from torch.testing import FileCheck from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_utils import ( @@ -32,6 +38,7 @@ from torch.testing._internal.common_utils import ( TEST_CUDA_GRAPH, TEST_WITH_ASAN, ) +from torch.utils._mode_utils import no_dispatch from torch.utils._python_dispatch import TorchDispatchMode @@ -46,7 +53,7 @@ if IS_WINDOWS and IS_CI: importlib.import_module("functorch") importlib.import_module("filelock") -from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA +from torch.testing._internal.inductor_utils import HAS_CUDA aten = torch.ops.aten @@ -2504,7 +2511,407 @@ if HAS_CUDA and not TEST_WITH_ASAN: eager_result = f(example_input) self.assertEqual(compiled_result, eager_result) + class TestSAC(TestCase): + def _make_observer_mode(self): + class ObserverMode(TorchDispatchMode): + def __init__(self): + super().__init__() + self.curr_run = 0 + self.op_outputs = defaultdict(list) + + def __torch_dispatch__( + self, + func: OpOverload, + types: Sequence[type], + args: Sequence[object] = (), + kwargs: Mapping[str, object] = immutable_dict(), + ) -> object: + return func(*args, **kwargs) + + return ObserverMode + + def test_simple(self): + device = "cuda" + + from torch._prims.rng_prims import graphsafe_run_with_rng_state + + ObserverMode = self._make_observer_mode() + + @graphsafe_run_with_rng_state.py_impl(ObserverMode) + def _(mode, op, *args, **kwargs): + with no_dispatch(): + 
out = graphsafe_run_with_rng_state(op, *args, **kwargs) + + mode.op_outputs[op].append(out) + return out + + obs = ObserverMode() + + x = torch.randn(4, 4, device=device, requires_grad=True) + y = torch.randn(4, 4, device=device, requires_grad=True) + + for _ in range(2): + torch._dynamo.reset() + + def gn(x, y): + return torch.sigmoid(torch.rand_like(x) * y) * x + + def fn(x, y): + x = torch.sin(x) + x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True) + x = torch.sin(x) + return x + + aot_eager_decomp_partition = functools.partial( + aot_eager_decomp_partition_with_mode, mode=obs + ) + + fn = torch.compile(fn, backend=aot_eager_decomp_partition) + + fn(x, y).sum().backward() + + self.assertEqual(len(obs.op_outputs[aten.rand.default]), 4) + for i in range(2): + self.assertEqual( + obs.op_outputs[aten.rand.default][0 + 2 * i], + obs.op_outputs[aten.rand.default][1 + 2 * i], + ) + self.assertNotEqual( + obs.op_outputs[aten.rand.default][0], + obs.op_outputs[aten.rand.default][2], + ) + + def test_cudagraph_uneven_forward_backward(self): + # torch.compile cudagraphs are difficult to test + # the rng updating bc is sensitive to duration of pending backwards, etc. + # this is a short repro to mimic the runtime wrappers integration + # and show that updating the backward rng state with cudagraphs works: + def forward(): + state = torch.cuda.get_rng_state() + perm = torch.randperm(10, device="cuda") + return state, perm + + def backward(rng_state): + current_state = torch.cuda.get_rng_state() + torch.cuda.set_rng_state(rng_state.cpu()) + perm = torch.randperm(10, device="cuda") + torch.cuda.set_rng_state(current_state) + return perm + + def normal_test(): + state, perm = forward() + repro_perm = backward(state) + return perm, repro_perm + + def graphsafe_forward(): + perm = torch.randperm(10, device="cuda") + return perm + + def graphsafe_backward(generator, new_state): + current_state = generator.graphsafe_get_state() + generator.graphsafe_set_state(new_state) + perm = torch.randperm(10, device="cuda") + generator.graphsafe_set_state(current_state) + return perm + + def graph_test(generator, capture_cuda_graph): + if capture_cuda_graph: + graph = torch.cuda.CUDAGraph() + + # state should be cloned before the graph + old_state = generator.graphsafe_get_state() + new_state = old_state.clone_state() + + if capture_cuda_graph: + # state should be register to the graph + graph.register_generator_state(new_state) + + # only capturing the backward + with torch.cuda.graph(graph): + repro_perm = graphsafe_backward(generator, new_state) + + # some number of uneven forwards + graphsafe_forward() + graphsafe_forward() + graphsafe_forward() + + # state prior to rng invocation + state = generator.get_state() + perm = graphsafe_forward() + + new_state.set_state(state) + + if capture_cuda_graph: + graph.replay() + else: + repro_perm = graphsafe_backward(generator, new_state) + + return perm, repro_perm + + self.assertEqual(*normal_test()) + generator = torch.cuda.default_generators[0] + self.assertEqual(*graph_test(generator, capture_cuda_graph=False)) + self.assertEqual(*graph_test(generator, capture_cuda_graph=True)) + + def test_cpu_and_cuda_rng(self): + device = "cuda" + + ObserverMode = self._make_observer_mode() + from torch._prims.rng_prims import ( + graphsafe_run_with_rng_state, + run_and_save_rng_state, + run_with_rng_state, + ) + + for hop in [ + graphsafe_run_with_rng_state, + run_and_save_rng_state, + run_with_rng_state, + ]: + + def make_impl(hop): + @hop.py_impl(ObserverMode) + 
def _(mode, *args, **kwargs): + with no_dispatch(): + out = hop(*args, **kwargs) + + op = None + for inp in itertools.chain(args, kwargs.values()): + if isinstance(inp, torch._ops.OpOverload): + op = inp + break + assert op is not None + if hop is run_and_save_rng_state: + mode.op_outputs[op].append(out[1]) + else: + mode.op_outputs[op].append(out) + return out + + make_impl(hop) + + obs = ObserverMode() + + def gn(x, y): + return torch.sigmoid(torch.rand_like(x) * y) * x + + def gn2(x): + return x * torch.randperm(x.numel(), device=x.device).reshape(x.shape) + + def fn(x, y, z): + x = torch.sin(x) + x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True) + x = torch.sin(x) + z = torch.utils.checkpoint.checkpoint(gn2, z, use_reentrant=True) + return x * z.cuda() + + aot_eager_decomp_partition = functools.partial( + aot_eager_decomp_partition_with_mode, mode=obs + ) + + fn = torch.compile(fn, backend=aot_eager_decomp_partition) + + x = torch.randn(4, 4, device=device, requires_grad=True) + y = torch.randn(4, 4, device=device, requires_grad=True) + z = torch.randn(4, 4, requires_grad=True) + + fn(x, y, z).sum().backward() + for op in [aten.rand.default, aten.randperm.default]: + self.assertEqual(len(obs.op_outputs[op]), 2) + self.assertEqual( + obs.op_outputs[op][0], + obs.op_outputs[op][1], + ) + self.assertEqual( + obs.op_outputs[op][0].device.type, + "cpu" if op == aten.randperm.default else "cuda", + ) + + @parametrize("order", (list(itertools.permutations([0, 1, 2])))) + def test_uneven_forward_backward(self, order): + device = "cuda" + + ObserverMode = self._make_observer_mode() + from torch._prims.rng_prims import graphsafe_run_with_rng_state + + @graphsafe_run_with_rng_state.py_impl(ObserverMode) + def _(mode, op, *args, **kwargs): + with no_dispatch(): + out = graphsafe_run_with_rng_state(op, *args, **kwargs) + + mode.op_outputs[(mode.curr_run, op)].append(out) + return out + + obs = ObserverMode() + + def gn(x, y): + return torch.sigmoid(torch.rand_like(x) * y) * x + + def gn2(x): + return x * torch.randperm(x.numel(), device=x.device).reshape(x.shape) + + def fn(x, y): + x = torch.sin(x) + x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True) + x = torch.sin(x) + x = torch.utils.checkpoint.checkpoint(gn2, x, use_reentrant=True) + return x + + aot_eager_decomp_partition = functools.partial( + aot_eager_decomp_partition_with_mode, mode=obs + ) + + fn_c = torch.compile(fn, backend=aot_eager_decomp_partition) + + torch.manual_seed(0) + outs = [] + for i in range(len(order)): + obs.curr_run = i + x = torch.randn(4, 4, device=device, requires_grad=True) + y = torch.randn(4, 4, device=device, requires_grad=True) + outs.append(fn_c(x, y)) + + for idx in order: + obs.curr_run = idx + outs[idx].sum().backward() + + for run in range(len(order)): + for op in (aten.rand.default, aten.randperm.default): + self.assertEqual(len(obs.op_outputs[(run, op)]), 2) + self.assertEqual( + obs.op_outputs[(run, op)][0], + obs.op_outputs[(run, op)][1], + ) + if run != 0: + self.assertNotEqual( + obs.op_outputs[(run - 1, op)][0], + obs.op_outputs[(run, op)][0], + ) + + @config.patch(fallback_random=True) + @config.patch("test_configs.graphsafe_rng_func_ignores_fallback_random", True) + def _test_cudagraphs_aot_eager_compat_equal(self, device): + def gn(x, y): + return torch.sigmoid(torch.rand_like(x) * y) * x + + def fn(x, y): + x = torch.sin(x) + x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True) + x = torch.sin(x) + return x + + outs = [] + grads = [] + + outs2 = 
[] + grads2 = [] + + compile_fns = [ + lambda fn: torch.compile(fn, backend="aot_eager_decomp_partition"), + lambda fn: torch.compile(fn, mode="reduce-overhead"), + ] + for i, compile_fn in enumerate(compile_fns): + torch.manual_seed(0) + for index in range(3): + x = torch.randn(4, 4, device=device, requires_grad=True) + y = torch.randn(4, 4, device=device, requires_grad=True) + + out = compile_fn(fn)(x, y) + torch.cuda.synchronize() + out.sum().backward() + if i == 0: + outs.append(out.clone()) + grads.append((x.grad.clone(), y.grad.clone())) + else: + outs2.append(out.clone()) + grads2.append((x.grad.clone(), y.grad.clone())) + + self.assertEqual(outs, outs2) + self.assertEqual(grads, grads2) + self.assertEqual(counters["inductor"]["cudagraph_skips"], 0) + + def test_cudagraphs_aot_eager_compat_equal(self): + self._test_cudagraphs_aot_eager_compat_equal(torch.device("cuda:0")) + + @requires_multigpu() + def test_cudagraphs_aot_eager_compat_equal_device_one(self): + self._test_cudagraphs_aot_eager_compat_equal(torch.device("cuda:1")) + + @requires_multigpu() + def test_multi_device(self): + def gn(x, y): + return torch.sigmoid(torch.rand_like(x) * y) * x + + def fn(x, y): + x = torch.sin(x) + x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True) + x = torch.sin(x) + return x + + def multi_fn(x, y, a, b): + return fn(x, y), fn(a, b) + + x = torch.randn(4, 4, device="cuda:0", requires_grad=True) + y = torch.randn(4, 4, device="cuda:0", requires_grad=True) + + a = torch.randn(4, 4, device="cuda:1", requires_grad=True) + b = torch.randn(4, 4, device="cuda:1", requires_grad=True) + + # No errors. TODO - get graphs from logging, couldnt figure out how + multi_fn_c = torch.compile(multi_fn, backend="aot_eager_decomp_partition") + + out = multi_fn_c(x, y, a, b) + out[0].sum().backward() + + def test_retain_graph(self): + device = "cuda" + + ObserverMode = self._make_observer_mode() + from torch._prims.rng_prims import graphsafe_run_with_rng_state + + @graphsafe_run_with_rng_state.py_impl(ObserverMode) + def _(mode, op, *args, **kwargs): + with no_dispatch(): + out = graphsafe_run_with_rng_state(op, *args, **kwargs) + + mode.op_outputs[op].append(out) + return out + + obs = ObserverMode() + + def gn(x, y): + return torch.sigmoid(torch.rand_like(x) * y) * x + + def fn(x, y): + x = torch.sin(x) + x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True) + x = torch.sin(x) + return x + + x = torch.randn(4, 4, device=device, requires_grad=True) + y = torch.randn(4, 4, device=device, requires_grad=True) + + aot_eager_decomp_partition = functools.partial( + aot_eager_decomp_partition_with_mode, mode=obs + ) + + fn = torch.compile(fn, backend=aot_eager_decomp_partition) + + out = fn(x, y).sum() + out.backward(retain_graph=True) + out.backward() + self.assertEqual(len(obs.op_outputs[aten.rand.default]), 3) + self.assertEqual( + obs.op_outputs[aten.rand.default][0], + obs.op_outputs[aten.rand.default][1], + ) + self.assertEqual( + obs.op_outputs[aten.rand.default][1], + obs.op_outputs[aten.rand.default][2], + ) + instantiate_parametrized_tests(CudaGraphTreeTests) + instantiate_parametrized_tests(TestSAC) + if __name__ == "__main__": from torch._inductor.test_case import run_tests @@ -2514,5 +2921,5 @@ if __name__ == "__main__": sys.exit(0) raise unittest.SkipTest("cuda graph test is skipped") - if HAS_CPU or HAS_CUDA: + if HAS_CUDA: run_tests(needs="filelock") diff --git a/test/test_hop_infra.py b/test/test_hop_infra.py index 291c86f330a9..2ece25a78423 100644 --- 
a/test/test_hop_infra.py +++ b/test/test_hop_infra.py @@ -62,6 +62,7 @@ class TestHOPInfra(TestCase): FIXME_ALLOWLIST = { "autograd_function_apply", "run_with_rng_state", + "graphsafe_run_with_rng_state", "map_impl", "_export_tracepoint", "run_and_save_rng_state", diff --git a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py index 7e13f5313d2a..509fd643eddd 100644 --- a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py @@ -66,6 +66,7 @@ from .subclass_utils import compute_inner_mutated_inp_indices_from_subclass_meta from .utils import ( _get_symint_hints, contain_metadata_mutation_ops, + get_cuda_generator_meta_val, make_boxed_func, strict_zip, unlift_tokens, @@ -458,9 +459,20 @@ def aot_dispatch_autograd( fake_mode = detect_fake_mode() if fake_mode is not None and fake_mode.shape_env is not None: tensorify_python_scalars(fx_g, fake_mode.shape_env, fake_mode) + fw_module, bw_module = aot_config.partition_fn( fx_g, joint_inputs, num_fwd_outputs=num_inner_fwd_outputs ) + rng_states = [ + n + for n in fw_module.graph.find_nodes(op="placeholder") + if "fwd_rng_state" in n.name + ] + fw_metadata.num_graphsafe_rng_states = len(rng_states) + if rng_states: + fw_metadata.graphsafe_rng_state_index = ( + rng_states[0].meta["val"].device.index + ) # See Note [Side-Effectful Tokens in AOTAutograd] if config.unlift_effect_tokens and ( @@ -684,6 +696,16 @@ def aot_dispatch_autograd( functionalized_rng_wrapper = FunctionalizedRngRuntimeWrapper( return_new_outs=False ) + + if rng_states: + index = fw_metadata.graphsafe_rng_state_index + assert index is not None + rng_states = [ + get_cuda_generator_meta_val(index) + for _ in range(fw_metadata.num_graphsafe_rng_states) + ] + adjusted_flat_args.extend(rng_states) # type: ignore[arg-type] + ( fw_module, adjusted_flat_args, diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index 06c76c9f6617..a4b1f7550544 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -1684,6 +1684,45 @@ def _backward_prologue_functional( return all_args +def initialize_rng_states( + num_rng: int, + graphsafe_idx: int, + fwd_rng_states: list[torch.Generator], + bwd_rng_states: list[torch.Generator], +): + """ + Initialize the cudagraph safe rng states. + + Initialization of rng states should have a few properties: + - the initialization for each rng state should be independent + - the initialization should be deterministic + - the initialization should be based off current rng state, so that independent graphs do not + have equal rng behavior + + We defer initialization of rng states until runtime because compilation is wrapped + with preserve_rng_states. Seed initialization should advance the rng states so consecutive compilations + do not give equal randomness. + """ + with torch.utils._python_dispatch._disable_current_modes(): + seeds = torch.randint(0, torch.iinfo(torch.int64).max, (num_rng,), device="cpu") + fwd_rng_states.extend( + [ + torch.cuda.default_generators[graphsafe_idx] + .clone_state() + .manual_seed(int(seeds[i])) + for i in range(num_rng) + ] + ) + bwd_rng_states.extend( + [ + torch.cuda.default_generators[graphsafe_idx] + .clone_state() + .manual_seed(int(seeds[i])) + for i in range(num_rng) + ] + ) + + # NOTE: this function must be torch._dynamo.allow_in_graph-able. 
Non tensor/symnode inputs must be constants. def _backward_epilogue_functional( metadata, maybe_subclass_metadata, out, *, make_subclass_override=None @@ -1819,6 +1858,34 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa fw_metadata: ViewAndMutationMeta, # runtime metadata try_save_cache_entry: Optional[Callable], # Save cache entry after compilation ): + # For additional context see Note [CUDA Graph Safe RNG Functionalization] + # Each pair forward, backward rng states must be equal prior to its invocation on any + # iteration of forward, backward. Because they are initialized equal, and are computing the same rng op, + # running forward then backward advances them the same amount and keeps them equal. + # However, a user may invoke multiple forwards, then backwards, such that they are not in sync. + # Initially we have: + # fwd_state0 == bwd_state0. + # Lets say we run: + # fwd0: fwd_state0 -> fwd_state1 + # fwd1: fwd_state1 -> fwd_state2 + # fwd2: fwd_state2 -> fwd_state3 + # If we now invoke bwd2, + # we need to update bwd_state equal to the rng that was observed in fwd2. + # we save the rng_state fwd_state2 in forward because we detect that it is not the + # current backward state and therefore would not be accessible if we do not save it. + # Similarly, if we are going to update the backward state to a new value, and there is a pending + # forwards which needs its current state, we will save it. + # Within the autograd context, we keep track of the curr iteration so that on backward + # we know what the generator state must be before the backward is run. + num_rng = fw_metadata.num_graphsafe_rng_states + graphsafe_idx = fw_metadata.graphsafe_rng_state_index + fwd_rng_states: list[torch.Generator] = [] + bwd_rng_states: list[torch.Generator] = [] + curr_fwd_iter = itertools.count(0) + backward_state_position = 0 + pending_forwards: set[int] = set() + saved_backward_tensor_states: dict[int, list[torch.Tensor]] = {} + class CompiledFunction(torch.autograd.Function): compiled_fw = compiled_fw_func compiled_bw = compiled_bw_func @@ -1840,6 +1907,26 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa assert isinstance(bw_state, BackwardState) ctx._compiled_autograd_backward_state = bw_state + if num_rng: + if len(fwd_rng_states) == 0: + assert graphsafe_idx is not None + initialize_rng_states( + num_rng, graphsafe_idx, fwd_rng_states, bwd_rng_states + ) + + _curr_iter = next(curr_fwd_iter) + ctx._curr_iter = _curr_iter + + # if this state is not contained in the backward, + # we need to save it for when its backward pass happens + if _curr_iter != backward_state_position: + saved_backward_tensor_states[_curr_iter] = [ + rng_state.get_state() for rng_state in fwd_rng_states + ] + + pending_forwards.add(_curr_iter) + args = (*args, *fwd_rng_states) + # There is a pretty complicated calling convention around what the compiled fw returns. 
# The full list of outputs and their relative order is: # (*tokens, *mutated_inputs, *fw_outs, *fw_intermediate_bases, *saved_tensors, *saved_symints) @@ -1967,6 +2054,45 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa *flat_args, ) + if num_rng: + nonlocal backward_state_position, bwd_rng_states + curr_backward_iter = ctx._curr_iter + retain_graph = ( + torch._C._autograd._get_current_graph_task_keep_graph() + ) + + # Save current state if we have a pending forward that needs this state + # or this state may be needed again because of retain graph + if ( + backward_state_position in pending_forwards + and backward_state_position not in saved_backward_tensor_states + and ( + backward_state_position != curr_backward_iter + or retain_graph + ) + ): + saved_backward_tensor_states[backward_state_position] = [ + rng_state.get_state() for rng_state in bwd_rng_states + ] + + # Restore saved states if needed + if curr_backward_iter in saved_backward_tensor_states: + if backward_state_position != curr_backward_iter: + for bwd_state, saved_state in zip( + bwd_rng_states, + saved_backward_tensor_states[curr_backward_iter], + ): + bwd_state.set_state(saved_state) + if not retain_graph: + del saved_backward_tensor_states[curr_backward_iter] + else: + assert backward_state_position == curr_backward_iter + + backward_state_position = curr_backward_iter + 1 + if not retain_graph: + pending_forwards.remove(curr_backward_iter) + all_args.extend(bwd_rng_states) + def impl_fn(double_ctx=None): out = CompiledFunction._backward_impl(ctx, all_args) return _backward_epilogue_functional( diff --git a/torch/_functorch/_aot_autograd/schemas.py b/torch/_functorch/_aot_autograd/schemas.py index 923ca12f7db5..6259d082e2ae 100644 --- a/torch/_functorch/_aot_autograd/schemas.py +++ b/torch/_functorch/_aot_autograd/schemas.py @@ -420,11 +420,16 @@ class ViewAndMutationMeta: # Filled after tracing joint function. num_backward_tokens: int = 0 + # Number of rng states that will get thread into the forward and backward for + # cudagraph compatible run_and_save_rng + num_graphsafe_rng_states: int = 0 + + graphsafe_rng_state_index: Optional[int] = None + def __post_init__(self): # pre-compute the indices of the inputs that are mutated. # When keep_input_mutations is set, we don't need to worry about our epilogue # handling data-only mutations, because we keep them directly in the graph. - mutated_inp_runtime_indices = [ i for i, m in enumerate(self.input_info) diff --git a/torch/_functorch/_aot_autograd/utils.py b/torch/_functorch/_aot_autograd/utils.py index 940e7e8829af..cc14a77244f6 100644 --- a/torch/_functorch/_aot_autograd/utils.py +++ b/torch/_functorch/_aot_autograd/utils.py @@ -489,3 +489,14 @@ def contain_metadata_mutation_ops(module: torch.fx.GraphModule) -> bool: ): return True return False + + +def get_cuda_generator_meta_val(device_idx: int): + """ + Get a generator value to use as a meta val + + newly cloned generator will not contain tensors. it is only Generators that are + registered to a CUDAGraph that contain tensors. since this does not contain Tensor + it is fine to use in the meta. + """ + return torch.cuda.default_generators[device_idx].clone_state() diff --git a/torch/_functorch/config.py b/torch/_functorch/config.py index 58d471bce7f6..bfcaa27b313e 100644 --- a/torch/_functorch/config.py +++ b/torch/_functorch/config.py @@ -210,6 +210,10 @@ torch_compile_graph_format = os.environ.get("TORCH_COMPILE_GRAPH_FORMAT", "svg") # real tensor outputs. 
generate_fake_kernels_from_real_mismatches = False +# CUDAGraph save run_with_rng functionalization. +# TODO: turn on by default +graphsafe_rng_functionalization = True + # Error on BypassAOTAutogradCache instead of just a warning # Used for tests diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index f3aeb3765d9d..6e085924d8e2 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -40,7 +40,7 @@ from ._activation_checkpointing.knapsack import ( ) from ._activation_checkpointing.knapsack_evaluator import KnapsackEvaluator from ._aot_autograd.logging_utils import get_aot_graph_name -from ._aot_autograd.utils import is_with_effects +from ._aot_autograd.utils import get_cuda_generator_meta_val, is_with_effects from .compile_utils import fx_graph_cse, get_aten_target @@ -622,6 +622,99 @@ def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule: return new_gm +def apply_graphsafe_rng_functionalization( + fw_module: torch.fx.GraphModule, + bw_module: torch.fx.GraphModule, + fw_node: torch.fx.Node, + bw_node: torch.fx.Node, + device: torch.device, + rng_count: int, + last_fwd_input: torch.fx.Node, + last_bwd_input: torch.fx.Node, +): + """ + Note [CUDA Graph Safe RNG Functionalization] + + CUDA Graph capture doesn't work with get_rng_state and set_rng_state because these functions operate on CPU values, + while CUDA Graph RNG capture uses on-device CUDA tensors. To solve this, we use graphsafe_set_state with a + CUDA Generator registered to the CUDA Graph before capture begins. graphsafe_set_state updates the generator's pointer + to reference a different GeneratorImpl, ensuring subsequent calls are correctly forwarded to the desired generator + (and its cuda-tensor RNG state during graph capture). + + For each RNG operation's forward/backward pair: + + - We create two generators initialized with identical values + - Each forward and backward call advances its respective generator equally + - This keeps generators synchronized so forward and backward operations use matching RNG values + + When forward is called multiple times before backward (causing desynchronization): + + - We save the forward RNG state + - We update the backward Generator's state before executing backward + + Before each CUDA Graph replay, replay_prologue updates captured RNG pointers with current states, ensuring backward Generator + changes are reflected during replay. + + This function modifies both forward and backward computation graphs by: + + Creating RNG state placeholders for both passes + Updating the forward node to use graph-safe RNG state + Updating the backward node to use graph-safe RNG state + + For more details: https://github.com/pytorch/pytorch/issues/113541 + """ + device_idx = device.index + assert device_idx is not None + fw_graph = fw_module.graph + bw_graph = bw_module.graph + graphsafe_run_with_rng_state = torch._prims.rng_prims.graphsafe_run_with_rng_state + + # Handle forward pass + + # Note: [Generator arguments in AOTDispatcher] + # Generator arguments in AOTDispatcher are added to support graphsafe rng + # functionalization. 
See note above [CUDA Graph Safe RNG Functionalization] + with fw_module.graph.inserting_after(last_fwd_input): + fwd_rng_state = fw_module.graph.placeholder(f"fwd_rng_state_{rng_count}") + fwd_rng_state.meta["val"] = get_cuda_generator_meta_val(device_idx) + last_fwd_input = fwd_rng_state + + # Handle backward pass + with bw_module.graph.inserting_after(last_bwd_input): + bwd_rng_state = bw_module.graph.placeholder(f"bwd_rng_state_{rng_count}") + # as above, clone so that meta val generator will not contain tensors + bwd_rng_state.meta["val"] = get_cuda_generator_meta_val(device_idx) + last_bwd_input = bwd_rng_state + + # Update forward node + fw_kwargs = dict(fw_node.kwargs) + fw_kwargs["rng_state"] = fwd_rng_state + with fw_module.graph.inserting_after(fw_node): + functional_fw_node = fw_graph.create_node( + "call_function", + graphsafe_run_with_rng_state, + args=(fw_node.target, *fw_node.args), # type: ignore[arg-type] + kwargs=fw_kwargs, + ) + fw_node.replace_all_uses_with(functional_fw_node) + fw_graph.erase_node(fw_node) + + # Update backward node + bwd_kwargs = dict(bw_node.kwargs) + bwd_kwargs["rng_state"] = bwd_rng_state + with bw_graph.inserting_before(bw_node): + rng_output = bw_graph.create_node( + "call_function", + graphsafe_run_with_rng_state, + args=(bw_node.target, *bw_node.args), # type: ignore[arg-type] + kwargs=bwd_kwargs, + ) + bw_node.replace_all_uses_with(rng_output) + bw_graph.erase_node(bw_node) + + return last_fwd_input, last_bwd_input + + def functionalize_rng_ops( joint_module: fx.GraphModule, fw_module: fx.GraphModule, @@ -660,7 +753,7 @@ def functionalize_rng_ops( random_nodes[node.name] = node return random_nodes - def get_device(node): + def get_device(node) -> Optional[torch.device]: """ Check the example value of the node outputs to find the device type. 
""" @@ -674,12 +767,12 @@ def functionalize_rng_ops( for candidate in candidates: if isinstance(candidate, torch.Tensor): if candidate.device.type == "cuda": - return "cuda" + return candidate.device - return "cpu" + return torch.device("cpu") - def get_sample_rng_state(device): - if device == "cuda": + def get_sample_rng_state(device: Optional[torch.device]): + if device is not None and device.type == "cuda": return torch.cuda.get_rng_state() return torch.get_rng_state() @@ -701,6 +794,7 @@ def functionalize_rng_ops( run_and_save_rng = torch._prims.rng_prims.run_and_save_rng_state run_with_rng_state = torch._prims.rng_prims.run_with_rng_state + bw_tangent_start_node = None for node in bw_module.graph.find_nodes(op="placeholder"): if "tangent" in node.name: @@ -712,68 +806,113 @@ def functionalize_rng_ops( ) fw_rng_state_outputs = [] - for base_node, node_pair in recomputable_rng_ops_map.items(): + + last_fwd_input = next(reversed(fw_module.graph.find_nodes(op="placeholder"))) + last_bwd_input = next(reversed(bw_module.graph.find_nodes(op="placeholder"))) + + devices = OrderedSet( + get_device(node_pair["fwd"]) for node_pair in recomputable_rng_ops_map.values() + ) + devices.discard(torch.device("cpu")) + # multiple cuda devices wont work with cudagraphs anyway, + # fallback to non graphsafe rng checkpointing + multi_cuda_devices = len(devices) > 1 + + # this changes numerics, so if fallback_random is set we will not use it + ind_config = torch._inductor.config + use_rng_graphsafe_rng_functionalization = ( + config.graphsafe_rng_functionalization + and not multi_cuda_devices + and ( + not ind_config.fallback_random + or ind_config.test_configs.graphsafe_rng_func_ignores_fallback_random + ) + ) + + for rng_count, (base_node, node_pair) in enumerate( + recomputable_rng_ops_map.items() + ): # Step 2 - Modify the fwd pass such that fw_node = node_pair["fwd"] bw_node = node_pair["bwd"] + device = get_device(fw_node) + fw_graph = fw_module.graph - with fw_graph.inserting_before(fw_node): - functional_fw_node = fw_graph.create_node( - "call_function", - run_and_save_rng, - args=(fw_node.target, *fw_node.args), - kwargs=fw_node.kwargs, - ) - state = fw_graph.create_node( - "call_function", - operator.getitem, - args=(functional_fw_node, 0), - kwargs={}, - ) - rng_output = fw_graph.create_node( - "call_function", - operator.getitem, - args=( - functional_fw_node, - 1, - ), - kwargs={}, - ) - fw_node.replace_all_uses_with(rng_output) - fw_graph.erase_node(fw_node) - fw_rng_state_outputs.append(state) - - # Step 3 - Modify the bwd pass such that bw_graph = bw_module.graph - with bw_graph.inserting_before(bw_tangent_start_node): - state_name = f"rng_state_output_{next(uid)}" - bw_rng_state_node = bw_graph.placeholder(state_name) - bw_rng_state_node.meta["val"] = get_sample_rng_state(get_device(fw_node)) - with bw_graph.inserting_before(bw_node): - rng_output = bw_graph.create_node( - "call_function", - run_with_rng_state, - args=(bw_rng_state_node, bw_node.target, *bw_node.args), - kwargs=bw_node.kwargs, + if ( + use_rng_graphsafe_rng_functionalization + and device is not None + and device.type == "cuda" + ): + last_fwd_input, last_bwd_input = apply_graphsafe_rng_functionalization( + fw_module, + bw_module, + fw_node, + bw_node, + device, + rng_count, + last_fwd_input, + last_bwd_input, ) + else: + with fw_graph.inserting_before(fw_node): + functional_fw_node = fw_graph.create_node( + "call_function", + run_and_save_rng, + args=(fw_node.target, *fw_node.args), + kwargs=fw_node.kwargs, + ) + state = 
fw_graph.create_node( + "call_function", + operator.getitem, + args=(functional_fw_node, 0), + kwargs={}, + ) + rng_output = fw_graph.create_node( + "call_function", + operator.getitem, + args=( + functional_fw_node, + 1, + ), + kwargs={}, + ) + fw_node.replace_all_uses_with(rng_output) + fw_graph.erase_node(fw_node) + fw_rng_state_outputs.append(state) - bw_node.replace_all_uses_with(rng_output) - bw_graph.erase_node(bw_node) + # Step 3 - Modify the bwd pass such that + with bw_graph.inserting_before(bw_tangent_start_node): + state_name = f"rng_state_output_{next(uid)}" + bw_rng_state_node = bw_graph.placeholder(state_name) + bw_rng_state_node.meta["val"] = get_sample_rng_state(device) + + with bw_graph.inserting_before(bw_node): + rng_output = bw_graph.create_node( + "call_function", + run_with_rng_state, + args=(bw_rng_state_node, bw_node.target, *bw_node.args), + kwargs=bw_node.kwargs, + ) + + bw_node.replace_all_uses_with(rng_output) + bw_graph.erase_node(bw_node) # Add the rng states in the output of the fwd graph. AOT Autograd assumes # that symints are at the end of forward graph outputs. So, insert the new # rng states accordingly. - fw_output_node = next(iter(fw_module.graph.find_nodes(op="output"))) - fw_outputs = fw_output_node.args[0] - sym_node_start_idx = len(fw_outputs) - num_sym_nodes - outputs = ( - fw_outputs[:sym_node_start_idx] - + tuple(fw_rng_state_outputs) - + fw_outputs[sym_node_start_idx:] - ) - fw_module.graph.output(outputs) - fw_module.graph.erase_node(fw_output_node) + if fw_rng_state_outputs: + fw_output_node = next(iter(fw_module.graph.find_nodes(op="output"))) + fw_outputs = fw_output_node.args[0] + sym_node_start_idx = len(fw_outputs) - num_sym_nodes + outputs = ( + fw_outputs[:sym_node_start_idx] + + tuple(fw_rng_state_outputs) + + fw_outputs[sym_node_start_idx:] + ) + fw_module.graph.output(outputs) + fw_module.graph.erase_node(fw_output_node) fw_module.recompile() bw_module.recompile() return fw_module, bw_module @@ -1849,7 +1988,6 @@ def min_cut_rematerialization_partition( saved_sym_nodes=saved_sym_nodes, num_fwd_outputs=num_fwd_outputs, ) - if graph_has_recomputable_ops: if graph_has_recomputable_rng_ops: fw_module, bw_module = functionalize_rng_ops( diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 90aba45ce0ac..8872b049b8af 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -898,7 +898,9 @@ class PythonWrapperCodegen(CodeGen): continue # a graph partition may take an IRNode output from a previous partition - if name not in V.graph.graph_input_names: + if name not in V.graph.graph_input_names or isinstance( + buf, ir.GeneratorState + ): continue # comparing strides for 0 size tensor is tricky. Ignore them for now. @@ -1418,7 +1420,9 @@ class PythonWrapperCodegen(CodeGen): code.writeline(f"{stride} = {strideof(name)}[{dim}]") bound_vars.add(stride) elif isinstance(value, ir.TorchBindObject): - pass + return + elif isinstance(value, ir.GeneratorState): + return else: if torch._inductor.config.graph_partition: pass @@ -1612,6 +1616,11 @@ class PythonWrapperCodegen(CodeGen): # is actually a valid value for the kernel in question. 
# See https://github.com/pytorch/pytorch/issues/124686 add_expr_input(name, V.graph.sizevars.size_hint(value, fallback=42)) + elif isinstance(value, ir.GeneratorState): + add_expr_input( + name, + f"torch.cuda.default_generators[{value.device.index}].graphsafe_get_state()", + ) else: shape = [ V.graph.sizevars.size_hint(x, fallback=42) @@ -2287,6 +2296,8 @@ class PythonWrapperCodegen(CodeGen): return s.codegen_reference() elif has_triton_package() and isinstance(s, triton.language.dtype): # type: ignore[possibly-undefined] return dtype_to_string(s) + elif isinstance(s, ir.GeneratorState): + return s.codegen_reference() else: return repr(s) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index c563c3f1118d..18d25cf342cf 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1517,6 +1517,8 @@ class test_configs: autotune_choice_name_regex: Optional[str] = None autotune_choice_desc_regex: Optional[str] = None + graphsafe_rng_func_ignores_fallback_random = False + if TYPE_CHECKING: from torch.utils._config_typing import * # noqa: F401, F403 diff --git a/torch/_inductor/cudagraph_trees.py b/torch/_inductor/cudagraph_trees.py index 5421eecc6a7f..a1bbbb1f39d1 100644 --- a/torch/_inductor/cudagraph_trees.py +++ b/torch/_inductor/cudagraph_trees.py @@ -908,6 +908,7 @@ class CUDAGraphNode: self.recorded_liveness_before_graph = curr_liveness self.expected_dead_indices_before_graph = different_indices + rng_states = [inp for inp in inputs if isinstance(inp, torch.Generator)] recording_inputs = self._allocate_and_copy_recording_inputs(inputs) # recording inputs will copy over memory, so we can free non recording inputs inputs.clear() @@ -916,6 +917,11 @@ class CUDAGraphNode: # graph used for recording model invocation self.graph: Optional[torch.cuda.CUDAGraph] = torch.cuda.CUDAGraph() + # TODO: register_generator_state should potentially take explicit device + with torch.cuda.device(self.device): + for rng_state in rng_states: + self.graph.register_generator_state(rng_state) + # we allocate non-static inputs within the same memory pool as the CUDAGraph # which we will record the model with. For memory efficiency, it is important # to reclaim the input memory when the inputs are no longer live. To accomplish this, @@ -1602,7 +1608,7 @@ class CUDAGraphNode: ): for i, inp in enumerate(inputs): if not isinstance(inp, torch.Tensor): - assert isinstance(inp, int) + assert isinstance(inp, (int, torch.Generator)) recording_inputs.append(inp) elif i not in self.static_input_idxs: # static_input does an allocation! diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index f9bc10db1700..1c77f68b0b93 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -1040,6 +1040,18 @@ class GraphLowering(torch.fx.Interpreter): # Alternately we could filter this out in AotAutograd self.graph_input_names.append(target) return None + # See note: Note: [Generator arguments in AOTDispatcher] + elif isinstance(example, torch.Generator): + assert ( + len(V.graph.current_node.users) == 1 + and next(iter(V.graph.current_node.users)).target + is torch._prims.rng_prims.graphsafe_run_with_rng_state + ) + gen = ir.GeneratorState(name=target, device=example.device) + self.graph_inputs[target] = gen # type: ignore[assignment] + self.graph_input_names.append(target) + return gen + assert isinstance(example, torch.Tensor), example # todo(chilli): We can remove the last check once we turn buffers into # static shape tensors. 
That's a hack to workaround Inductor believing @@ -1288,7 +1300,7 @@ class GraphLowering(torch.fx.Interpreter): if isinstance(value, TorchBindObject): continue assert isinstance( - value, (TensorBox, sympy.Expr) + value, (TensorBox, sympy.Expr, torch._inductor.ir.GeneratorState) ), f"Unsupported inductor graph input type: {type(value)}" if not isinstance(value, TensorBox): continue diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 3fa472645aec..9c128e3949e2 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -4977,7 +4977,9 @@ class ExternKernel(InputsKernel): tensor_args = [] non_tensor_args: list[Any] = [] for arg in args_flat: - is_arg_tensor.append(isinstance(arg, IRNode)) + is_arg_tensor.append( + isinstance(arg, IRNode) and not isinstance(arg, GeneratorState) + ) if is_arg_tensor[-1]: tensor_args.append(arg) else: @@ -5008,7 +5010,9 @@ class ExternKernel(InputsKernel): # Rerun fake tensor propagation, because Inductor may have changed the # strides of inputs and we need to determine accurately what the # output stride will be. - example_args: list[Union[torch.Tensor, torch._C.ScriptObject]] = [] + example_args: list[ + Union[torch.Tensor, torch._C.ScriptObject, torch.Generator] + ] = [] # We need to retain the constant values of fake tensors that we originally # propagated the graph with, because for some operators running without a @@ -5025,6 +5029,12 @@ class ExternKernel(InputsKernel): example_args.append(V.graph.torchbind_constants[x.get_name()]) elif isinstance(x, TorchBindObject): example_args.append(x.get_real_obj()) + elif isinstance(x, torch._inductor.ir.GeneratorState): + device_index = x.device.index + assert x.device.type == "cuda" and device_index is not None + example_args.append( + torch.cuda.default_generators[device_index].clone_state() + ) else: example_args.append(ir_node_to_tensor(x, guard_shape=True)) @@ -5155,7 +5165,7 @@ class ExternKernel(InputsKernel): # TODO(jansel): impose layout preference on realized buffer x.realize() return x - if isinstance(x, TorchBindObject): + if isinstance(x, (NonTensorObj)): return x return cls.copy_input(x) @@ -7570,8 +7580,12 @@ class EffectfulKernel(FallbackKernel): return True +class NonTensorObj(IRNode): + pass + + @ir_dataclass -class TorchBindObject(IRNode): +class TorchBindObject(NonTensorObj): from torch._library.fake_class_registry import FakeScriptObject name: str @@ -7605,6 +7619,18 @@ class TorchBindObject(IRNode): return functools.reduce(lambda x, y: x + y, flat_sizes, 0) +@ir_dataclass +class GeneratorState(NonTensorObj): + name: str + device: torch.device + + def get_name(self): # type: ignore[no-untyped-def] + return self.name + + def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str: + return self.name + + class _CollectiveKernel(FallbackKernel): def should_allocate(self) -> bool: return False diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 40780a18b692..21c569f2dc7d 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -2718,6 +2718,8 @@ make_fallback(aten.gcd.default, warn=False) make_fallback(aten._thnn_fused_lstm_cell, require_dense) make_fallback(torch._prims.rng_prims.run_and_save_rng_state) make_fallback(torch._prims.rng_prims.run_with_rng_state) +make_fallback(torch._prims.rng_prims.graphsafe_run_with_rng_state) + # Implmented / Half implemented # Scans. 
Implemented for CUDA, missing CPU diff --git a/torch/_inductor/output_code.py b/torch/_inductor/output_code.py index 9637ce8c4c29..dd0adb36c49c 100644 --- a/torch/_inductor/output_code.py +++ b/torch/_inductor/output_code.py @@ -426,7 +426,7 @@ class CompiledFxGraph(OutputCode): (not complex_memory_overlap_inputs, "complex memory overlap"), ( all( - isinstance(t, (torch.Tensor, torch.SymInt)) + isinstance(t, (torch.Tensor, torch.SymInt, torch.Generator)) for t in example_inputs ), "non-Tensor inputs", diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 3f26d3312052..cb3f8a5417f1 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -2104,6 +2104,7 @@ def count_tangents(fx_g: torch.fx.GraphModule) -> int: "tangents" not in x.name and "bwd_seed" not in x.name and "bwd_base_offset" not in x.name + and "bwd_rng_state" not in x.name ) arg_count = 0 diff --git a/torch/_prims/rng_prims.py b/torch/_prims/rng_prims.py index d4d9203ef6ab..70b4bc472358 100644 --- a/torch/_prims/rng_prims.py +++ b/torch/_prims/rng_prims.py @@ -315,5 +315,75 @@ run_and_save_rng_state = register_run_and_save_rng_state_op() run_with_rng_state = register_run_with_rng_state_op() +def register_graphsafe_run_with_rng_state_op(): + class GraphSafeRunWithRngState(HigherOrderOperator): + def __init__(self): + super().__init__("graphsafe_run_with_rng_state") + + def __call__(self, op, *args, rng_state=None, **kwargs): + return super().__call__(op, *args, rng_state=rng_state, **kwargs) + + graphsafe_run_with_rng_state = GraphSafeRunWithRngState() + + graphsafe_run_with_rng_state.py_impl(DispatchKey.Autograd)( + autograd_not_implemented(graphsafe_run_with_rng_state, deferred_error=True) + ) + + @graphsafe_run_with_rng_state.py_impl(DispatchKey.CUDA) + def impl_cuda(op, *args, rng_state=None, **kwargs): + device_idx = rng_state.device.index + generator = torch.cuda.default_generators[device_idx] + current_state = generator.graphsafe_get_state() + generator.graphsafe_set_state(rng_state) + out = op(*args, **kwargs) + generator.graphsafe_set_state(current_state) + return out + + @graphsafe_run_with_rng_state.py_impl(DispatchKey.BackendSelect) + def impl_backend_select(op, *args, rng_state=None, **kwargs): + device = get_device(args, kwargs) + assert ( + device == "cuda" + ), f"GraphSafe RNG operations only supported for CUDA, got {device}" + return impl_cuda(op, *args, rng_state=rng_state, **kwargs) + + @graphsafe_run_with_rng_state.py_impl(FakeTensorMode) + def impl_fake_tensor_mode(mode, op, *args, rng_state=None, **kwargs): + with mode: + return op(*args, **kwargs) + + @graphsafe_run_with_rng_state.py_impl(ProxyTorchDispatchMode) + def impl_proxy_dispatch_mode(mode, op, *args, rng_state=None, **kwargs): + with disable_proxy_modes_tracing(): + out = graphsafe_run_with_rng_state(op, *args, rng_state=rng_state, **kwargs) + proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, (op, *args)) + proxy_kwargs = pytree.tree_map( + mode.tracer.unwrap_proxy, {"rng_state": rng_state, **kwargs} + ) + out_proxy = mode.tracer.create_proxy( + "call_function", graphsafe_run_with_rng_state, proxy_args, proxy_kwargs + ) + return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer) + + @graphsafe_run_with_rng_state.py_functionalize_impl + def impl_functional(ctx, op, *args, rng_state=None, **kwargs): + unwrapped_rng_state = ( + ctx.unwrap_tensors(rng_state) if rng_state is not None else None + ) + unwrapped_args = ctx.unwrap_tensors(args) + unwrapped_kwargs = ctx.unwrap_tensors(kwargs) + + with 
ctx.redispatch_to_next(): + out = graphsafe_run_with_rng_state( + op, *unwrapped_args, rng_state=unwrapped_rng_state, **unwrapped_kwargs + ) + return ctx.wrap_tensors(out) + + return graphsafe_run_with_rng_state + + +graphsafe_run_with_rng_state = register_graphsafe_run_with_rng_state_op() + + def register_rng_prims(): register_philox_rand() diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 7a4ecbf705ec..368ba88dcd11 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -819,6 +819,9 @@ def _iterate_exprs(val: IterateExprs) -> Iterator[sympy.Basic]: yield from _iterate_exprs(val.storage_offset()) elif val is None: pass + # see Note: [Generator arguments in AOTDispatcher] + elif isinstance(val, torch.Generator): + pass else: raise AssertionError(f"cannot extract sympy expressions from {val} {type(val)}") diff --git a/torch/testing/_internal/hop_db.py b/torch/testing/_internal/hop_db.py index 37c44ab7f725..814df49f0f71 100644 --- a/torch/testing/_internal/hop_db.py +++ b/torch/testing/_internal/hop_db.py @@ -77,6 +77,7 @@ FIXME_hop_that_doesnt_have_opinfo_test_allowlist = [ "autograd_function_apply", "run_and_save_rng_state", "run_with_rng_state", + "graphsafe_run_with_rng_state", "out_dtype", "trace_wrapped", 'tag_activation_checkpoint',
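
Not part of the patch: a minimal standalone sketch of the generator-state mechanics the changes above rely on (clone_state, graphsafe_get_state/graphsafe_set_state, get_state/set_state), showing how the forward/backward generator pair described in Note [CUDA Graph Safe RNG Functionalization] stays synchronized and how an uneven forward/backward sequence is reconciled. The helper name `run_with_state` is hypothetical; everything else uses APIs that appear in the diff itself.

```python
# Sketch only; assumes a CUDA device is available.
import torch

assert torch.cuda.is_available()
device_idx = 0
default_gen = torch.cuda.default_generators[device_idx]

# Clone two states from the default generator and seed them identically,
# as initialize_rng_states does at runtime.
seed = int(torch.randint(0, 2**62, (1,)).item())
fwd_state = default_gen.clone_state().manual_seed(seed)
bwd_state = default_gen.clone_state().manual_seed(seed)


def run_with_state(state: torch.Generator) -> torch.Tensor:
    # Swap the default generator over to `state`, run the RNG op, then restore
    # the previous state -- mirrors the DispatchKey.CUDA impl of
    # graphsafe_run_with_rng_state.
    prev = default_gen.graphsafe_get_state()
    default_gen.graphsafe_set_state(state)
    out = torch.rand(4, device="cuda")
    default_gen.graphsafe_set_state(prev)
    return out


# Forward and recomputed backward observe the same random values, and both
# generators advance by the same amount, so the pair stays in sync for the
# next iteration.
torch.testing.assert_close(run_with_state(fwd_state), run_with_state(bwd_state))

# Uneven forward/backward: snapshot the forward state before a forward whose
# backward is not next in line, then restore it into the backward generator
# right before that backward runs (the role of saved_backward_tensor_states
# in the runtime wrapper).
saved = fwd_state.get_state()            # snapshot taken at forward time
fwd_out2 = run_with_state(fwd_state)     # another forward before any backward
bwd_state.set_state(saved)               # re-sync before the pending backward
torch.testing.assert_close(fwd_out2, run_with_state(bwd_state))
```

The same reconciliation works under CUDA graph capture because graphsafe_set_state only swaps which GeneratorImpl the default generator points at, and the captured graph reads the generator's on-device state at replay time (replay_prologue), so set_state calls made between replays are reflected in the replayed backward.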