From 481a57bc37ffb90a670e63a7bdbfb68314190943 Mon Sep 17 00:00:00 2001 From: eellison Date: Thu, 27 Feb 2025 07:39:55 -0800 Subject: [PATCH] Support torch.compile rng selective activation checkpointing with cudagraph (#146878) TODO: - [x] Add handling for when forward is invoked multiple times without invoking backward, so that the fwd/backward states are out of sync - [x] Update rng state initialization to take from correct device - [x] Tests - [x] handling of retain_graph - [x] respect fallback random Fix for https://github.com/pytorch/pytorch/issues/130123. Updates the aot_eager and cudagraph compilation of `run_and_save_rng_state` to use the new mechanism added by https://github.com/pytorch/pytorch/pull/114068 for CUDAGraph safe rng states. We have a pair of rng states for the fwd and backward respectively. In both forward and backward the rng op will get run with `graphsafe_run_with_rng_state` which takes in RNG state and it hooks onto the current RNG generator before running the operator. The rng states for fwd/backward are initialized with the same value. We ensure that for any given run of the forward, the corresponding backward run will have the same rng states for the op as was observed in the forward. ``` ===== Forward graph 1 ===== /data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module): def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", fwd_rng_state_0): sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1) # No stacktrace found for following nodes graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = fwd_rng_state_0); fwd_rng_state_0 = None ... ===== Backward graph 1 ===== def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", tangents_1: "f32[4, 4][4, 1]cuda:0", bwd_rng_state_0): sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1) # No stacktrace found for following nodes graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = bwd_rng_state_0); bwd_rng_state_0 = None ``` There is some extra complication when a user either calls backward with retain_graph, or calls the backward in a different order as they called the forward. If a user has state fwd_rng_state0, bwd_rng_state0 and calls: - fwd0: fwd_rng_state0 -> fwd_rng_state1 - fwd1: fwd_rng_state1 -> fwd_rng_state2 - bwd1 - bwd0 Then naively, when bwd1 is invoked the bwd rng states would not be equal to the same states that were observed in fwd1. I added handling of this in the aot runtime wrappers to detect pending backward invocations, and the current position of the bwd rng states, and to update when necesssary. Other notes: Because nodes which appear later in the forward appear earlier in the backward, we need a separate rng state for each operator. If we reused the rng across ops, the forward and backward would be run with different rng states. I.e., not applied in the same order. Questions for reviewers: This does change numerics, bc the rng of the op is now taken from the input rng state instead of whatever the rng would be midway through running the graph. Technically, we only need this for cuda graph. But, I'd prefer to not have a rng divergence just for cudagraph. I am making it respect `fallback_random`. Edit: decided to apply to non cudagraphs as well, so long as fallback_random is not set I'm initializing the rng states by cloning the current state. If you had something like 5 different rands in the model with the same shape, theyd all get the same value. This doesn't seem great. I could use some other initialization scheme like taking seed from graph position, or etc etc. Not sure. Let me know thoughts. Edit: updated to be taken from randint() Update: initializing rng states from torch.randint.. Pull Request resolved: https://github.com/pytorch/pytorch/pull/146878 Approved by: https://github.com/anijain2305, https://github.com/bdhirsh --- test/dynamo/test_activation_checkpointing.py | 3 + test/inductor/test_cudagraph_trees.py | 411 +++++++++++++++++- test/test_hop_infra.py | 1 + .../jit_compile_runtime_wrappers.py | 22 + .../_aot_autograd/runtime_wrappers.py | 126 ++++++ torch/_functorch/_aot_autograd/schemas.py | 7 +- torch/_functorch/_aot_autograd/utils.py | 11 + torch/_functorch/config.py | 4 + torch/_functorch/partitioners.py | 252 ++++++++--- torch/_inductor/codegen/wrapper.py | 15 +- torch/_inductor/config.py | 2 + torch/_inductor/cudagraph_trees.py | 8 +- torch/_inductor/graph.py | 14 +- torch/_inductor/ir.py | 34 +- torch/_inductor/lowering.py | 2 + torch/_inductor/output_code.py | 2 +- torch/_inductor/utils.py | 1 + torch/_prims/rng_prims.py | 70 +++ torch/fx/experimental/symbolic_shapes.py | 3 + torch/testing/_internal/hop_db.py | 1 + 20 files changed, 920 insertions(+), 69 deletions(-) diff --git a/test/dynamo/test_activation_checkpointing.py b/test/dynamo/test_activation_checkpointing.py index baac1724a9d7..17021eb46565 100644 --- a/test/dynamo/test_activation_checkpointing.py +++ b/test/dynamo/test_activation_checkpointing.py @@ -54,6 +54,8 @@ def count_ops( return node.args[0] == op elif node.name == "run_with_rng_state": return node.args[1] == op + elif node.name == "graphsafe_run_with_rng_state": + return node.args[0] == op return False # assert ((freq or freq_ge) and op) or ((freqs or freqs_ge) and ops) @@ -1018,6 +1020,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no @requires_cuda @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") + @torch._inductor.config.patch(fallback_random=True) def test_compile_selective_checkpoint_random_op(self, device): for preserve_rng_state in [True, False]: diff --git a/test/inductor/test_cudagraph_trees.py b/test/inductor/test_cudagraph_trees.py index 5cf20869aa3e..2f870cffc376 100644 --- a/test/inductor/test_cudagraph_trees.py +++ b/test/inductor/test_cudagraph_trees.py @@ -4,13 +4,17 @@ import contextlib import functools import gc import importlib +import itertools import sys import unittest import warnings +from collections import defaultdict +from collections.abc import Mapping, Sequence import torch import torch._dynamo.config as dynamo_config import torch.nn as nn +from torch._dynamo.backends.debugging import aot_eager_decomp_partition_with_mode from torch._dynamo.utils import counters from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache from torch._inductor import config @@ -19,7 +23,9 @@ from torch._inductor.compile_fx import compile_fx_inner from torch._inductor.cudagraph_trees import cudagraphify_impl as tree_cudagraphify_impl from torch._inductor.cudagraph_utils import FunctionID from torch._inductor.test_case import TestCase as InductorTestCase +from torch._ops import OpOverload from torch.fx.experimental.proxy_tensor import make_fx +from torch.fx.immutable_collections import immutable_dict from torch.testing import FileCheck from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_utils import ( @@ -32,6 +38,7 @@ from torch.testing._internal.common_utils import ( TEST_CUDA_GRAPH, TEST_WITH_ASAN, ) +from torch.utils._mode_utils import no_dispatch from torch.utils._python_dispatch import TorchDispatchMode @@ -46,7 +53,7 @@ if IS_WINDOWS and IS_CI: importlib.import_module("functorch") importlib.import_module("filelock") -from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA +from torch.testing._internal.inductor_utils import HAS_CUDA aten = torch.ops.aten @@ -2504,7 +2511,407 @@ if HAS_CUDA and not TEST_WITH_ASAN: eager_result = f(example_input) self.assertEqual(compiled_result, eager_result) + class TestSAC(TestCase): + def _make_observer_mode(self): + class ObserverMode(TorchDispatchMode): + def __init__(self): + super().__init__() + self.curr_run = 0 + self.op_outputs = defaultdict(list) + + def __torch_dispatch__( + self, + func: OpOverload, + types: Sequence[type], + args: Sequence[object] = (), + kwargs: Mapping[str, object] = immutable_dict(), + ) -> object: + return func(*args, **kwargs) + + return ObserverMode + + def test_simple(self): + device = "cuda" + + from torch._prims.rng_prims import graphsafe_run_with_rng_state + + ObserverMode = self._make_observer_mode() + + @graphsafe_run_with_rng_state.py_impl(ObserverMode) + def _(mode, op, *args, **kwargs): + with no_dispatch(): + out = graphsafe_run_with_rng_state(op, *args, **kwargs) + + mode.op_outputs[op].append(out) + return out + + obs = ObserverMode() + + x = torch.randn(4, 4, device=device, requires_grad=True) + y = torch.randn(4, 4, device=device, requires_grad=True) + + for _ in range(2): + torch._dynamo.reset() + + def gn(x, y): + return torch.sigmoid(torch.rand_like(x) * y) * x + + def fn(x, y): + x = torch.sin(x) + x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True) + x = torch.sin(x) + return x + + aot_eager_decomp_partition = functools.partial( + aot_eager_decomp_partition_with_mode, mode=obs + ) + + fn = torch.compile(fn, backend=aot_eager_decomp_partition) + + fn(x, y).sum().backward() + + self.assertEqual(len(obs.op_outputs[aten.rand.default]), 4) + for i in range(2): + self.assertEqual( + obs.op_outputs[aten.rand.default][0 + 2 * i], + obs.op_outputs[aten.rand.default][1 + 2 * i], + ) + self.assertNotEqual( + obs.op_outputs[aten.rand.default][0], + obs.op_outputs[aten.rand.default][2], + ) + + def test_cudagraph_uneven_forward_backward(self): + # torch.compile cudagraphs are difficult to test + # the rng updating bc is sensitive to duration of pending backwards, etc. + # this is a short repro to mimic the runtime wrappers integration + # and show that updating the backward rng state with cudagraphs works: + def forward(): + state = torch.cuda.get_rng_state() + perm = torch.randperm(10, device="cuda") + return state, perm + + def backward(rng_state): + current_state = torch.cuda.get_rng_state() + torch.cuda.set_rng_state(rng_state.cpu()) + perm = torch.randperm(10, device="cuda") + torch.cuda.set_rng_state(current_state) + return perm + + def normal_test(): + state, perm = forward() + repro_perm = backward(state) + return perm, repro_perm + + def graphsafe_forward(): + perm = torch.randperm(10, device="cuda") + return perm + + def graphsafe_backward(generator, new_state): + current_state = generator.graphsafe_get_state() + generator.graphsafe_set_state(new_state) + perm = torch.randperm(10, device="cuda") + generator.graphsafe_set_state(current_state) + return perm + + def graph_test(generator, capture_cuda_graph): + if capture_cuda_graph: + graph = torch.cuda.CUDAGraph() + + # state should be cloned before the graph + old_state = generator.graphsafe_get_state() + new_state = old_state.clone_state() + + if capture_cuda_graph: + # state should be register to the graph + graph.register_generator_state(new_state) + + # only capturing the backward + with torch.cuda.graph(graph): + repro_perm = graphsafe_backward(generator, new_state) + + # some number of uneven forwards + graphsafe_forward() + graphsafe_forward() + graphsafe_forward() + + # state prior to rng invocation + state = generator.get_state() + perm = graphsafe_forward() + + new_state.set_state(state) + + if capture_cuda_graph: + graph.replay() + else: + repro_perm = graphsafe_backward(generator, new_state) + + return perm, repro_perm + + self.assertEqual(*normal_test()) + generator = torch.cuda.default_generators[0] + self.assertEqual(*graph_test(generator, capture_cuda_graph=False)) + self.assertEqual(*graph_test(generator, capture_cuda_graph=True)) + + def test_cpu_and_cuda_rng(self): + device = "cuda" + + ObserverMode = self._make_observer_mode() + from torch._prims.rng_prims import ( + graphsafe_run_with_rng_state, + run_and_save_rng_state, + run_with_rng_state, + ) + + for hop in [ + graphsafe_run_with_rng_state, + run_and_save_rng_state, + run_with_rng_state, + ]: + + def make_impl(hop): + @hop.py_impl(ObserverMode) + def _(mode, *args, **kwargs): + with no_dispatch(): + out = hop(*args, **kwargs) + + op = None + for inp in itertools.chain(args, kwargs.values()): + if isinstance(inp, torch._ops.OpOverload): + op = inp + break + assert op is not None + if hop is run_and_save_rng_state: + mode.op_outputs[op].append(out[1]) + else: + mode.op_outputs[op].append(out) + return out + + make_impl(hop) + + obs = ObserverMode() + + def gn(x, y): + return torch.sigmoid(torch.rand_like(x) * y) * x + + def gn2(x): + return x * torch.randperm(x.numel(), device=x.device).reshape(x.shape) + + def fn(x, y, z): + x = torch.sin(x) + x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True) + x = torch.sin(x) + z = torch.utils.checkpoint.checkpoint(gn2, z, use_reentrant=True) + return x * z.cuda() + + aot_eager_decomp_partition = functools.partial( + aot_eager_decomp_partition_with_mode, mode=obs + ) + + fn = torch.compile(fn, backend=aot_eager_decomp_partition) + + x = torch.randn(4, 4, device=device, requires_grad=True) + y = torch.randn(4, 4, device=device, requires_grad=True) + z = torch.randn(4, 4, requires_grad=True) + + fn(x, y, z).sum().backward() + for op in [aten.rand.default, aten.randperm.default]: + self.assertEqual(len(obs.op_outputs[op]), 2) + self.assertEqual( + obs.op_outputs[op][0], + obs.op_outputs[op][1], + ) + self.assertEqual( + obs.op_outputs[op][0].device.type, + "cpu" if op == aten.randperm.default else "cuda", + ) + + @parametrize("order", (list(itertools.permutations([0, 1, 2])))) + def test_uneven_forward_backward(self, order): + device = "cuda" + + ObserverMode = self._make_observer_mode() + from torch._prims.rng_prims import graphsafe_run_with_rng_state + + @graphsafe_run_with_rng_state.py_impl(ObserverMode) + def _(mode, op, *args, **kwargs): + with no_dispatch(): + out = graphsafe_run_with_rng_state(op, *args, **kwargs) + + mode.op_outputs[(mode.curr_run, op)].append(out) + return out + + obs = ObserverMode() + + def gn(x, y): + return torch.sigmoid(torch.rand_like(x) * y) * x + + def gn2(x): + return x * torch.randperm(x.numel(), device=x.device).reshape(x.shape) + + def fn(x, y): + x = torch.sin(x) + x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True) + x = torch.sin(x) + x = torch.utils.checkpoint.checkpoint(gn2, x, use_reentrant=True) + return x + + aot_eager_decomp_partition = functools.partial( + aot_eager_decomp_partition_with_mode, mode=obs + ) + + fn_c = torch.compile(fn, backend=aot_eager_decomp_partition) + + torch.manual_seed(0) + outs = [] + for i in range(len(order)): + obs.curr_run = i + x = torch.randn(4, 4, device=device, requires_grad=True) + y = torch.randn(4, 4, device=device, requires_grad=True) + outs.append(fn_c(x, y)) + + for idx in order: + obs.curr_run = idx + outs[idx].sum().backward() + + for run in range(len(order)): + for op in (aten.rand.default, aten.randperm.default): + self.assertEqual(len(obs.op_outputs[(run, op)]), 2) + self.assertEqual( + obs.op_outputs[(run, op)][0], + obs.op_outputs[(run, op)][1], + ) + if run != 0: + self.assertNotEqual( + obs.op_outputs[(run - 1, op)][0], + obs.op_outputs[(run, op)][0], + ) + + @config.patch(fallback_random=True) + @config.patch("test_configs.graphsafe_rng_func_ignores_fallback_random", True) + def _test_cudagraphs_aot_eager_compat_equal(self, device): + def gn(x, y): + return torch.sigmoid(torch.rand_like(x) * y) * x + + def fn(x, y): + x = torch.sin(x) + x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True) + x = torch.sin(x) + return x + + outs = [] + grads = [] + + outs2 = [] + grads2 = [] + + compile_fns = [ + lambda fn: torch.compile(fn, backend="aot_eager_decomp_partition"), + lambda fn: torch.compile(fn, mode="reduce-overhead"), + ] + for i, compile_fn in enumerate(compile_fns): + torch.manual_seed(0) + for index in range(3): + x = torch.randn(4, 4, device=device, requires_grad=True) + y = torch.randn(4, 4, device=device, requires_grad=True) + + out = compile_fn(fn)(x, y) + torch.cuda.synchronize() + out.sum().backward() + if i == 0: + outs.append(out.clone()) + grads.append((x.grad.clone(), y.grad.clone())) + else: + outs2.append(out.clone()) + grads2.append((x.grad.clone(), y.grad.clone())) + + self.assertEqual(outs, outs2) + self.assertEqual(grads, grads2) + self.assertEqual(counters["inductor"]["cudagraph_skips"], 0) + + def test_cudagraphs_aot_eager_compat_equal(self): + self._test_cudagraphs_aot_eager_compat_equal(torch.device("cuda:0")) + + @requires_multigpu() + def test_cudagraphs_aot_eager_compat_equal_device_one(self): + self._test_cudagraphs_aot_eager_compat_equal(torch.device("cuda:1")) + + @requires_multigpu() + def test_multi_device(self): + def gn(x, y): + return torch.sigmoid(torch.rand_like(x) * y) * x + + def fn(x, y): + x = torch.sin(x) + x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True) + x = torch.sin(x) + return x + + def multi_fn(x, y, a, b): + return fn(x, y), fn(a, b) + + x = torch.randn(4, 4, device="cuda:0", requires_grad=True) + y = torch.randn(4, 4, device="cuda:0", requires_grad=True) + + a = torch.randn(4, 4, device="cuda:1", requires_grad=True) + b = torch.randn(4, 4, device="cuda:1", requires_grad=True) + + # No errors. TODO - get graphs from logging, couldnt figure out how + multi_fn_c = torch.compile(multi_fn, backend="aot_eager_decomp_partition") + + out = multi_fn_c(x, y, a, b) + out[0].sum().backward() + + def test_retain_graph(self): + device = "cuda" + + ObserverMode = self._make_observer_mode() + from torch._prims.rng_prims import graphsafe_run_with_rng_state + + @graphsafe_run_with_rng_state.py_impl(ObserverMode) + def _(mode, op, *args, **kwargs): + with no_dispatch(): + out = graphsafe_run_with_rng_state(op, *args, **kwargs) + + mode.op_outputs[op].append(out) + return out + + obs = ObserverMode() + + def gn(x, y): + return torch.sigmoid(torch.rand_like(x) * y) * x + + def fn(x, y): + x = torch.sin(x) + x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True) + x = torch.sin(x) + return x + + x = torch.randn(4, 4, device=device, requires_grad=True) + y = torch.randn(4, 4, device=device, requires_grad=True) + + aot_eager_decomp_partition = functools.partial( + aot_eager_decomp_partition_with_mode, mode=obs + ) + + fn = torch.compile(fn, backend=aot_eager_decomp_partition) + + out = fn(x, y).sum() + out.backward(retain_graph=True) + out.backward() + self.assertEqual(len(obs.op_outputs[aten.rand.default]), 3) + self.assertEqual( + obs.op_outputs[aten.rand.default][0], + obs.op_outputs[aten.rand.default][1], + ) + self.assertEqual( + obs.op_outputs[aten.rand.default][1], + obs.op_outputs[aten.rand.default][2], + ) + instantiate_parametrized_tests(CudaGraphTreeTests) + instantiate_parametrized_tests(TestSAC) + if __name__ == "__main__": from torch._inductor.test_case import run_tests @@ -2514,5 +2921,5 @@ if __name__ == "__main__": sys.exit(0) raise unittest.SkipTest("cuda graph test is skipped") - if HAS_CPU or HAS_CUDA: + if HAS_CUDA: run_tests(needs="filelock") diff --git a/test/test_hop_infra.py b/test/test_hop_infra.py index 291c86f330a9..2ece25a78423 100644 --- a/test/test_hop_infra.py +++ b/test/test_hop_infra.py @@ -62,6 +62,7 @@ class TestHOPInfra(TestCase): FIXME_ALLOWLIST = { "autograd_function_apply", "run_with_rng_state", + "graphsafe_run_with_rng_state", "map_impl", "_export_tracepoint", "run_and_save_rng_state", diff --git a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py index 7e13f5313d2a..509fd643eddd 100644 --- a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py @@ -66,6 +66,7 @@ from .subclass_utils import compute_inner_mutated_inp_indices_from_subclass_meta from .utils import ( _get_symint_hints, contain_metadata_mutation_ops, + get_cuda_generator_meta_val, make_boxed_func, strict_zip, unlift_tokens, @@ -458,9 +459,20 @@ def aot_dispatch_autograd( fake_mode = detect_fake_mode() if fake_mode is not None and fake_mode.shape_env is not None: tensorify_python_scalars(fx_g, fake_mode.shape_env, fake_mode) + fw_module, bw_module = aot_config.partition_fn( fx_g, joint_inputs, num_fwd_outputs=num_inner_fwd_outputs ) + rng_states = [ + n + for n in fw_module.graph.find_nodes(op="placeholder") + if "fwd_rng_state" in n.name + ] + fw_metadata.num_graphsafe_rng_states = len(rng_states) + if rng_states: + fw_metadata.graphsafe_rng_state_index = ( + rng_states[0].meta["val"].device.index + ) # See Note [Side-Effectful Tokens in AOTAutograd] if config.unlift_effect_tokens and ( @@ -684,6 +696,16 @@ def aot_dispatch_autograd( functionalized_rng_wrapper = FunctionalizedRngRuntimeWrapper( return_new_outs=False ) + + if rng_states: + index = fw_metadata.graphsafe_rng_state_index + assert index is not None + rng_states = [ + get_cuda_generator_meta_val(index) + for _ in range(fw_metadata.num_graphsafe_rng_states) + ] + adjusted_flat_args.extend(rng_states) # type: ignore[arg-type] + ( fw_module, adjusted_flat_args, diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index 06c76c9f6617..a4b1f7550544 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -1684,6 +1684,45 @@ def _backward_prologue_functional( return all_args +def initialize_rng_states( + num_rng: int, + graphsafe_idx: int, + fwd_rng_states: list[torch.Generator], + bwd_rng_states: list[torch.Generator], +): + """ + Initialize the cudagraph safe rng states. + + Initialization of rng states should have a few properties: + - the initialization for each rng state should be independent + - the initialization should be deterministic + - the initialization should be based off current rng state, so that independent graphs do not + have equal rng behavior + + We defer initialization of rng states until runtime because compilation is wrapped + with preserve_rng_states. Seed initialization should advance the rng states so consecutive compilations + do not give equal randomness. + """ + with torch.utils._python_dispatch._disable_current_modes(): + seeds = torch.randint(0, torch.iinfo(torch.int64).max, (num_rng,), device="cpu") + fwd_rng_states.extend( + [ + torch.cuda.default_generators[graphsafe_idx] + .clone_state() + .manual_seed(int(seeds[i])) + for i in range(num_rng) + ] + ) + bwd_rng_states.extend( + [ + torch.cuda.default_generators[graphsafe_idx] + .clone_state() + .manual_seed(int(seeds[i])) + for i in range(num_rng) + ] + ) + + # NOTE: this function must be torch._dynamo.allow_in_graph-able. Non tensor/symnode inputs must be constants. def _backward_epilogue_functional( metadata, maybe_subclass_metadata, out, *, make_subclass_override=None @@ -1819,6 +1858,34 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa fw_metadata: ViewAndMutationMeta, # runtime metadata try_save_cache_entry: Optional[Callable], # Save cache entry after compilation ): + # For additional context see Note [CUDA Graph Safe RNG Functionalization] + # Each pair forward, backward rng states must be equal prior to its invocation on any + # iteration of forward, backward. Because they are initialized equal, and are computing the same rng op, + # running forward then backward advances them the same amount and keeps them equal. + # However, a user may invoke multiple forwards, then backwards, such that they are not in sync. + # Initially we have: + # fwd_state0 == bwd_state0. + # Lets say we run: + # fwd0: fwd_state0 -> fwd_state1 + # fwd1: fwd_state1 -> fwd_state2 + # fwd2: fwd_state2 -> fwd_state3 + # If we now invoke bwd2, + # we need to update bwd_state equal to the rng that was observed in fwd2. + # we save the rng_state fwd_state2 in forward because we detect that it is not the + # current backward state and therefore would not be accessible if we do not save it. + # Similarly, if we are going to update the backward state to a new value, and there is a pending + # forwards which needs its current state, we will save it. + # Within the autograd context, we keep track of the curr iteration so that on backward + # we know what the generator state must be before the backward is run. + num_rng = fw_metadata.num_graphsafe_rng_states + graphsafe_idx = fw_metadata.graphsafe_rng_state_index + fwd_rng_states: list[torch.Generator] = [] + bwd_rng_states: list[torch.Generator] = [] + curr_fwd_iter = itertools.count(0) + backward_state_position = 0 + pending_forwards: set[int] = set() + saved_backward_tensor_states: dict[int, list[torch.Tensor]] = {} + class CompiledFunction(torch.autograd.Function): compiled_fw = compiled_fw_func compiled_bw = compiled_bw_func @@ -1840,6 +1907,26 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa assert isinstance(bw_state, BackwardState) ctx._compiled_autograd_backward_state = bw_state + if num_rng: + if len(fwd_rng_states) == 0: + assert graphsafe_idx is not None + initialize_rng_states( + num_rng, graphsafe_idx, fwd_rng_states, bwd_rng_states + ) + + _curr_iter = next(curr_fwd_iter) + ctx._curr_iter = _curr_iter + + # if this state is not contained in the backward, + # we need to save it for when its backward pass happens + if _curr_iter != backward_state_position: + saved_backward_tensor_states[_curr_iter] = [ + rng_state.get_state() for rng_state in fwd_rng_states + ] + + pending_forwards.add(_curr_iter) + args = (*args, *fwd_rng_states) + # There is a pretty complicated calling convention around what the compiled fw returns. # The full list of outputs and their relative order is: # (*tokens, *mutated_inputs, *fw_outs, *fw_intermediate_bases, *saved_tensors, *saved_symints) @@ -1967,6 +2054,45 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa *flat_args, ) + if num_rng: + nonlocal backward_state_position, bwd_rng_states + curr_backward_iter = ctx._curr_iter + retain_graph = ( + torch._C._autograd._get_current_graph_task_keep_graph() + ) + + # Save current state if we have a pending forward that needs this state + # or this state may be needed again because of retain graph + if ( + backward_state_position in pending_forwards + and backward_state_position not in saved_backward_tensor_states + and ( + backward_state_position != curr_backward_iter + or retain_graph + ) + ): + saved_backward_tensor_states[backward_state_position] = [ + rng_state.get_state() for rng_state in bwd_rng_states + ] + + # Restore saved states if needed + if curr_backward_iter in saved_backward_tensor_states: + if backward_state_position != curr_backward_iter: + for bwd_state, saved_state in zip( + bwd_rng_states, + saved_backward_tensor_states[curr_backward_iter], + ): + bwd_state.set_state(saved_state) + if not retain_graph: + del saved_backward_tensor_states[curr_backward_iter] + else: + assert backward_state_position == curr_backward_iter + + backward_state_position = curr_backward_iter + 1 + if not retain_graph: + pending_forwards.remove(curr_backward_iter) + all_args.extend(bwd_rng_states) + def impl_fn(double_ctx=None): out = CompiledFunction._backward_impl(ctx, all_args) return _backward_epilogue_functional( diff --git a/torch/_functorch/_aot_autograd/schemas.py b/torch/_functorch/_aot_autograd/schemas.py index 923ca12f7db5..6259d082e2ae 100644 --- a/torch/_functorch/_aot_autograd/schemas.py +++ b/torch/_functorch/_aot_autograd/schemas.py @@ -420,11 +420,16 @@ class ViewAndMutationMeta: # Filled after tracing joint function. num_backward_tokens: int = 0 + # Number of rng states that will get thread into the forward and backward for + # cudagraph compatible run_and_save_rng + num_graphsafe_rng_states: int = 0 + + graphsafe_rng_state_index: Optional[int] = None + def __post_init__(self): # pre-compute the indices of the inputs that are mutated. # When keep_input_mutations is set, we don't need to worry about our epilogue # handling data-only mutations, because we keep them directly in the graph. - mutated_inp_runtime_indices = [ i for i, m in enumerate(self.input_info) diff --git a/torch/_functorch/_aot_autograd/utils.py b/torch/_functorch/_aot_autograd/utils.py index 940e7e8829af..cc14a77244f6 100644 --- a/torch/_functorch/_aot_autograd/utils.py +++ b/torch/_functorch/_aot_autograd/utils.py @@ -489,3 +489,14 @@ def contain_metadata_mutation_ops(module: torch.fx.GraphModule) -> bool: ): return True return False + + +def get_cuda_generator_meta_val(device_idx: int): + """ + Get a generator value to use as a meta val + + newly cloned generator will not contain tensors. it is only Generators that are + registered to a CUDAGraph that contain tensors. since this does not contain Tensor + it is fine to use in the meta. + """ + return torch.cuda.default_generators[device_idx].clone_state() diff --git a/torch/_functorch/config.py b/torch/_functorch/config.py index 58d471bce7f6..bfcaa27b313e 100644 --- a/torch/_functorch/config.py +++ b/torch/_functorch/config.py @@ -210,6 +210,10 @@ torch_compile_graph_format = os.environ.get("TORCH_COMPILE_GRAPH_FORMAT", "svg") # real tensor outputs. generate_fake_kernels_from_real_mismatches = False +# CUDAGraph save run_with_rng functionalization. +# TODO: turn on by default +graphsafe_rng_functionalization = True + # Error on BypassAOTAutogradCache instead of just a warning # Used for tests diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index f3aeb3765d9d..6e085924d8e2 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -40,7 +40,7 @@ from ._activation_checkpointing.knapsack import ( ) from ._activation_checkpointing.knapsack_evaluator import KnapsackEvaluator from ._aot_autograd.logging_utils import get_aot_graph_name -from ._aot_autograd.utils import is_with_effects +from ._aot_autograd.utils import get_cuda_generator_meta_val, is_with_effects from .compile_utils import fx_graph_cse, get_aten_target @@ -622,6 +622,99 @@ def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule: return new_gm +def apply_graphsafe_rng_functionalization( + fw_module: torch.fx.GraphModule, + bw_module: torch.fx.GraphModule, + fw_node: torch.fx.Node, + bw_node: torch.fx.Node, + device: torch.device, + rng_count: int, + last_fwd_input: torch.fx.Node, + last_bwd_input: torch.fx.Node, +): + """ + Note [CUDA Graph Safe RNG Functionalization] + + CUDA Graph capture doesn't work with get_rng_state and set_rng_state because these functions operate on CPU values, + while CUDA Graph RNG capture uses on-device CUDA tensors. To solve this, we use graphsafe_set_state with a + CUDA Generator registered to the CUDA Graph before capture begins. graphsafe_set_state updates the generator's pointer + to reference a different GeneratorImpl, ensuring subsequent calls are correctly forwarded to the desired generator + (and its cuda-tensor RNG state during graph capture). + + For each RNG operation's forward/backward pair: + + - We create two generators initialized with identical values + - Each forward and backward call advances its respective generator equally + - This keeps generators synchronized so forward and backward operations use matching RNG values + + When forward is called multiple times before backward (causing desynchronization): + + - We save the forward RNG state + - We update the backward Generator's state before executing backward + + Before each CUDA Graph replay, replay_prologue updates captured RNG pointers with current states, ensuring backward Generator + changes are reflected during replay. + + This function modifies both forward and backward computation graphs by: + + Creating RNG state placeholders for both passes + Updating the forward node to use graph-safe RNG state + Updating the backward node to use graph-safe RNG state + + For more details: https://github.com/pytorch/pytorch/issues/113541 + """ + device_idx = device.index + assert device_idx is not None + fw_graph = fw_module.graph + bw_graph = bw_module.graph + graphsafe_run_with_rng_state = torch._prims.rng_prims.graphsafe_run_with_rng_state + + # Handle forward pass + + # Note: [Generator arguments in AOTDispatcher] + # Generator arguments in AOTDispatcher are added to support graphsafe rng + # functionalization. See note above [CUDA Graph Safe RNG Functionalization] + with fw_module.graph.inserting_after(last_fwd_input): + fwd_rng_state = fw_module.graph.placeholder(f"fwd_rng_state_{rng_count}") + fwd_rng_state.meta["val"] = get_cuda_generator_meta_val(device_idx) + last_fwd_input = fwd_rng_state + + # Handle backward pass + with bw_module.graph.inserting_after(last_bwd_input): + bwd_rng_state = bw_module.graph.placeholder(f"bwd_rng_state_{rng_count}") + # as above, clone so that meta val generator will not contain tensors + bwd_rng_state.meta["val"] = get_cuda_generator_meta_val(device_idx) + last_bwd_input = bwd_rng_state + + # Update forward node + fw_kwargs = dict(fw_node.kwargs) + fw_kwargs["rng_state"] = fwd_rng_state + with fw_module.graph.inserting_after(fw_node): + functional_fw_node = fw_graph.create_node( + "call_function", + graphsafe_run_with_rng_state, + args=(fw_node.target, *fw_node.args), # type: ignore[arg-type] + kwargs=fw_kwargs, + ) + fw_node.replace_all_uses_with(functional_fw_node) + fw_graph.erase_node(fw_node) + + # Update backward node + bwd_kwargs = dict(bw_node.kwargs) + bwd_kwargs["rng_state"] = bwd_rng_state + with bw_graph.inserting_before(bw_node): + rng_output = bw_graph.create_node( + "call_function", + graphsafe_run_with_rng_state, + args=(bw_node.target, *bw_node.args), # type: ignore[arg-type] + kwargs=bwd_kwargs, + ) + bw_node.replace_all_uses_with(rng_output) + bw_graph.erase_node(bw_node) + + return last_fwd_input, last_bwd_input + + def functionalize_rng_ops( joint_module: fx.GraphModule, fw_module: fx.GraphModule, @@ -660,7 +753,7 @@ def functionalize_rng_ops( random_nodes[node.name] = node return random_nodes - def get_device(node): + def get_device(node) -> Optional[torch.device]: """ Check the example value of the node outputs to find the device type. """ @@ -674,12 +767,12 @@ def functionalize_rng_ops( for candidate in candidates: if isinstance(candidate, torch.Tensor): if candidate.device.type == "cuda": - return "cuda" + return candidate.device - return "cpu" + return torch.device("cpu") - def get_sample_rng_state(device): - if device == "cuda": + def get_sample_rng_state(device: Optional[torch.device]): + if device is not None and device.type == "cuda": return torch.cuda.get_rng_state() return torch.get_rng_state() @@ -701,6 +794,7 @@ def functionalize_rng_ops( run_and_save_rng = torch._prims.rng_prims.run_and_save_rng_state run_with_rng_state = torch._prims.rng_prims.run_with_rng_state + bw_tangent_start_node = None for node in bw_module.graph.find_nodes(op="placeholder"): if "tangent" in node.name: @@ -712,68 +806,113 @@ def functionalize_rng_ops( ) fw_rng_state_outputs = [] - for base_node, node_pair in recomputable_rng_ops_map.items(): + + last_fwd_input = next(reversed(fw_module.graph.find_nodes(op="placeholder"))) + last_bwd_input = next(reversed(bw_module.graph.find_nodes(op="placeholder"))) + + devices = OrderedSet( + get_device(node_pair["fwd"]) for node_pair in recomputable_rng_ops_map.values() + ) + devices.discard(torch.device("cpu")) + # multiple cuda devices wont work with cudagraphs anyway, + # fallback to non graphsafe rng checkpointing + multi_cuda_devices = len(devices) > 1 + + # this changes numerics, so if fallback_random is set we will not use it + ind_config = torch._inductor.config + use_rng_graphsafe_rng_functionalization = ( + config.graphsafe_rng_functionalization + and not multi_cuda_devices + and ( + not ind_config.fallback_random + or ind_config.test_configs.graphsafe_rng_func_ignores_fallback_random + ) + ) + + for rng_count, (base_node, node_pair) in enumerate( + recomputable_rng_ops_map.items() + ): # Step 2 - Modify the fwd pass such that fw_node = node_pair["fwd"] bw_node = node_pair["bwd"] + device = get_device(fw_node) + fw_graph = fw_module.graph - with fw_graph.inserting_before(fw_node): - functional_fw_node = fw_graph.create_node( - "call_function", - run_and_save_rng, - args=(fw_node.target, *fw_node.args), - kwargs=fw_node.kwargs, - ) - state = fw_graph.create_node( - "call_function", - operator.getitem, - args=(functional_fw_node, 0), - kwargs={}, - ) - rng_output = fw_graph.create_node( - "call_function", - operator.getitem, - args=( - functional_fw_node, - 1, - ), - kwargs={}, - ) - fw_node.replace_all_uses_with(rng_output) - fw_graph.erase_node(fw_node) - fw_rng_state_outputs.append(state) - - # Step 3 - Modify the bwd pass such that bw_graph = bw_module.graph - with bw_graph.inserting_before(bw_tangent_start_node): - state_name = f"rng_state_output_{next(uid)}" - bw_rng_state_node = bw_graph.placeholder(state_name) - bw_rng_state_node.meta["val"] = get_sample_rng_state(get_device(fw_node)) - with bw_graph.inserting_before(bw_node): - rng_output = bw_graph.create_node( - "call_function", - run_with_rng_state, - args=(bw_rng_state_node, bw_node.target, *bw_node.args), - kwargs=bw_node.kwargs, + if ( + use_rng_graphsafe_rng_functionalization + and device is not None + and device.type == "cuda" + ): + last_fwd_input, last_bwd_input = apply_graphsafe_rng_functionalization( + fw_module, + bw_module, + fw_node, + bw_node, + device, + rng_count, + last_fwd_input, + last_bwd_input, ) + else: + with fw_graph.inserting_before(fw_node): + functional_fw_node = fw_graph.create_node( + "call_function", + run_and_save_rng, + args=(fw_node.target, *fw_node.args), + kwargs=fw_node.kwargs, + ) + state = fw_graph.create_node( + "call_function", + operator.getitem, + args=(functional_fw_node, 0), + kwargs={}, + ) + rng_output = fw_graph.create_node( + "call_function", + operator.getitem, + args=( + functional_fw_node, + 1, + ), + kwargs={}, + ) + fw_node.replace_all_uses_with(rng_output) + fw_graph.erase_node(fw_node) + fw_rng_state_outputs.append(state) - bw_node.replace_all_uses_with(rng_output) - bw_graph.erase_node(bw_node) + # Step 3 - Modify the bwd pass such that + with bw_graph.inserting_before(bw_tangent_start_node): + state_name = f"rng_state_output_{next(uid)}" + bw_rng_state_node = bw_graph.placeholder(state_name) + bw_rng_state_node.meta["val"] = get_sample_rng_state(device) + + with bw_graph.inserting_before(bw_node): + rng_output = bw_graph.create_node( + "call_function", + run_with_rng_state, + args=(bw_rng_state_node, bw_node.target, *bw_node.args), + kwargs=bw_node.kwargs, + ) + + bw_node.replace_all_uses_with(rng_output) + bw_graph.erase_node(bw_node) # Add the rng states in the output of the fwd graph. AOT Autograd assumes # that symints are at the end of forward graph outputs. So, insert the new # rng states accordingly. - fw_output_node = next(iter(fw_module.graph.find_nodes(op="output"))) - fw_outputs = fw_output_node.args[0] - sym_node_start_idx = len(fw_outputs) - num_sym_nodes - outputs = ( - fw_outputs[:sym_node_start_idx] - + tuple(fw_rng_state_outputs) - + fw_outputs[sym_node_start_idx:] - ) - fw_module.graph.output(outputs) - fw_module.graph.erase_node(fw_output_node) + if fw_rng_state_outputs: + fw_output_node = next(iter(fw_module.graph.find_nodes(op="output"))) + fw_outputs = fw_output_node.args[0] + sym_node_start_idx = len(fw_outputs) - num_sym_nodes + outputs = ( + fw_outputs[:sym_node_start_idx] + + tuple(fw_rng_state_outputs) + + fw_outputs[sym_node_start_idx:] + ) + fw_module.graph.output(outputs) + fw_module.graph.erase_node(fw_output_node) fw_module.recompile() bw_module.recompile() return fw_module, bw_module @@ -1849,7 +1988,6 @@ def min_cut_rematerialization_partition( saved_sym_nodes=saved_sym_nodes, num_fwd_outputs=num_fwd_outputs, ) - if graph_has_recomputable_ops: if graph_has_recomputable_rng_ops: fw_module, bw_module = functionalize_rng_ops( diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 90aba45ce0ac..8872b049b8af 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -898,7 +898,9 @@ class PythonWrapperCodegen(CodeGen): continue # a graph partition may take an IRNode output from a previous partition - if name not in V.graph.graph_input_names: + if name not in V.graph.graph_input_names or isinstance( + buf, ir.GeneratorState + ): continue # comparing strides for 0 size tensor is tricky. Ignore them for now. @@ -1418,7 +1420,9 @@ class PythonWrapperCodegen(CodeGen): code.writeline(f"{stride} = {strideof(name)}[{dim}]") bound_vars.add(stride) elif isinstance(value, ir.TorchBindObject): - pass + return + elif isinstance(value, ir.GeneratorState): + return else: if torch._inductor.config.graph_partition: pass @@ -1612,6 +1616,11 @@ class PythonWrapperCodegen(CodeGen): # is actually a valid value for the kernel in question. # See https://github.com/pytorch/pytorch/issues/124686 add_expr_input(name, V.graph.sizevars.size_hint(value, fallback=42)) + elif isinstance(value, ir.GeneratorState): + add_expr_input( + name, + f"torch.cuda.default_generators[{value.device.index}].graphsafe_get_state()", + ) else: shape = [ V.graph.sizevars.size_hint(x, fallback=42) @@ -2287,6 +2296,8 @@ class PythonWrapperCodegen(CodeGen): return s.codegen_reference() elif has_triton_package() and isinstance(s, triton.language.dtype): # type: ignore[possibly-undefined] return dtype_to_string(s) + elif isinstance(s, ir.GeneratorState): + return s.codegen_reference() else: return repr(s) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index c563c3f1118d..18d25cf342cf 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1517,6 +1517,8 @@ class test_configs: autotune_choice_name_regex: Optional[str] = None autotune_choice_desc_regex: Optional[str] = None + graphsafe_rng_func_ignores_fallback_random = False + if TYPE_CHECKING: from torch.utils._config_typing import * # noqa: F401, F403 diff --git a/torch/_inductor/cudagraph_trees.py b/torch/_inductor/cudagraph_trees.py index 5421eecc6a7f..a1bbbb1f39d1 100644 --- a/torch/_inductor/cudagraph_trees.py +++ b/torch/_inductor/cudagraph_trees.py @@ -908,6 +908,7 @@ class CUDAGraphNode: self.recorded_liveness_before_graph = curr_liveness self.expected_dead_indices_before_graph = different_indices + rng_states = [inp for inp in inputs if isinstance(inp, torch.Generator)] recording_inputs = self._allocate_and_copy_recording_inputs(inputs) # recording inputs will copy over memory, so we can free non recording inputs inputs.clear() @@ -916,6 +917,11 @@ class CUDAGraphNode: # graph used for recording model invocation self.graph: Optional[torch.cuda.CUDAGraph] = torch.cuda.CUDAGraph() + # TODO: register_generator_state should potentially take explicit device + with torch.cuda.device(self.device): + for rng_state in rng_states: + self.graph.register_generator_state(rng_state) + # we allocate non-static inputs within the same memory pool as the CUDAGraph # which we will record the model with. For memory efficiency, it is important # to reclaim the input memory when the inputs are no longer live. To accomplish this, @@ -1602,7 +1608,7 @@ class CUDAGraphNode: ): for i, inp in enumerate(inputs): if not isinstance(inp, torch.Tensor): - assert isinstance(inp, int) + assert isinstance(inp, (int, torch.Generator)) recording_inputs.append(inp) elif i not in self.static_input_idxs: # static_input does an allocation! diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index f9bc10db1700..1c77f68b0b93 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -1040,6 +1040,18 @@ class GraphLowering(torch.fx.Interpreter): # Alternately we could filter this out in AotAutograd self.graph_input_names.append(target) return None + # See note: Note: [Generator arguments in AOTDispatcher] + elif isinstance(example, torch.Generator): + assert ( + len(V.graph.current_node.users) == 1 + and next(iter(V.graph.current_node.users)).target + is torch._prims.rng_prims.graphsafe_run_with_rng_state + ) + gen = ir.GeneratorState(name=target, device=example.device) + self.graph_inputs[target] = gen # type: ignore[assignment] + self.graph_input_names.append(target) + return gen + assert isinstance(example, torch.Tensor), example # todo(chilli): We can remove the last check once we turn buffers into # static shape tensors. That's a hack to workaround Inductor believing @@ -1288,7 +1300,7 @@ class GraphLowering(torch.fx.Interpreter): if isinstance(value, TorchBindObject): continue assert isinstance( - value, (TensorBox, sympy.Expr) + value, (TensorBox, sympy.Expr, torch._inductor.ir.GeneratorState) ), f"Unsupported inductor graph input type: {type(value)}" if not isinstance(value, TensorBox): continue diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 3fa472645aec..9c128e3949e2 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -4977,7 +4977,9 @@ class ExternKernel(InputsKernel): tensor_args = [] non_tensor_args: list[Any] = [] for arg in args_flat: - is_arg_tensor.append(isinstance(arg, IRNode)) + is_arg_tensor.append( + isinstance(arg, IRNode) and not isinstance(arg, GeneratorState) + ) if is_arg_tensor[-1]: tensor_args.append(arg) else: @@ -5008,7 +5010,9 @@ class ExternKernel(InputsKernel): # Rerun fake tensor propagation, because Inductor may have changed the # strides of inputs and we need to determine accurately what the # output stride will be. - example_args: list[Union[torch.Tensor, torch._C.ScriptObject]] = [] + example_args: list[ + Union[torch.Tensor, torch._C.ScriptObject, torch.Generator] + ] = [] # We need to retain the constant values of fake tensors that we originally # propagated the graph with, because for some operators running without a @@ -5025,6 +5029,12 @@ class ExternKernel(InputsKernel): example_args.append(V.graph.torchbind_constants[x.get_name()]) elif isinstance(x, TorchBindObject): example_args.append(x.get_real_obj()) + elif isinstance(x, torch._inductor.ir.GeneratorState): + device_index = x.device.index + assert x.device.type == "cuda" and device_index is not None + example_args.append( + torch.cuda.default_generators[device_index].clone_state() + ) else: example_args.append(ir_node_to_tensor(x, guard_shape=True)) @@ -5155,7 +5165,7 @@ class ExternKernel(InputsKernel): # TODO(jansel): impose layout preference on realized buffer x.realize() return x - if isinstance(x, TorchBindObject): + if isinstance(x, (NonTensorObj)): return x return cls.copy_input(x) @@ -7570,8 +7580,12 @@ class EffectfulKernel(FallbackKernel): return True +class NonTensorObj(IRNode): + pass + + @ir_dataclass -class TorchBindObject(IRNode): +class TorchBindObject(NonTensorObj): from torch._library.fake_class_registry import FakeScriptObject name: str @@ -7605,6 +7619,18 @@ class TorchBindObject(IRNode): return functools.reduce(lambda x, y: x + y, flat_sizes, 0) +@ir_dataclass +class GeneratorState(NonTensorObj): + name: str + device: torch.device + + def get_name(self): # type: ignore[no-untyped-def] + return self.name + + def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str: + return self.name + + class _CollectiveKernel(FallbackKernel): def should_allocate(self) -> bool: return False diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 40780a18b692..21c569f2dc7d 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -2718,6 +2718,8 @@ make_fallback(aten.gcd.default, warn=False) make_fallback(aten._thnn_fused_lstm_cell, require_dense) make_fallback(torch._prims.rng_prims.run_and_save_rng_state) make_fallback(torch._prims.rng_prims.run_with_rng_state) +make_fallback(torch._prims.rng_prims.graphsafe_run_with_rng_state) + # Implmented / Half implemented # Scans. Implemented for CUDA, missing CPU diff --git a/torch/_inductor/output_code.py b/torch/_inductor/output_code.py index 9637ce8c4c29..dd0adb36c49c 100644 --- a/torch/_inductor/output_code.py +++ b/torch/_inductor/output_code.py @@ -426,7 +426,7 @@ class CompiledFxGraph(OutputCode): (not complex_memory_overlap_inputs, "complex memory overlap"), ( all( - isinstance(t, (torch.Tensor, torch.SymInt)) + isinstance(t, (torch.Tensor, torch.SymInt, torch.Generator)) for t in example_inputs ), "non-Tensor inputs", diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 3f26d3312052..cb3f8a5417f1 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -2104,6 +2104,7 @@ def count_tangents(fx_g: torch.fx.GraphModule) -> int: "tangents" not in x.name and "bwd_seed" not in x.name and "bwd_base_offset" not in x.name + and "bwd_rng_state" not in x.name ) arg_count = 0 diff --git a/torch/_prims/rng_prims.py b/torch/_prims/rng_prims.py index d4d9203ef6ab..70b4bc472358 100644 --- a/torch/_prims/rng_prims.py +++ b/torch/_prims/rng_prims.py @@ -315,5 +315,75 @@ run_and_save_rng_state = register_run_and_save_rng_state_op() run_with_rng_state = register_run_with_rng_state_op() +def register_graphsafe_run_with_rng_state_op(): + class GraphSafeRunWithRngState(HigherOrderOperator): + def __init__(self): + super().__init__("graphsafe_run_with_rng_state") + + def __call__(self, op, *args, rng_state=None, **kwargs): + return super().__call__(op, *args, rng_state=rng_state, **kwargs) + + graphsafe_run_with_rng_state = GraphSafeRunWithRngState() + + graphsafe_run_with_rng_state.py_impl(DispatchKey.Autograd)( + autograd_not_implemented(graphsafe_run_with_rng_state, deferred_error=True) + ) + + @graphsafe_run_with_rng_state.py_impl(DispatchKey.CUDA) + def impl_cuda(op, *args, rng_state=None, **kwargs): + device_idx = rng_state.device.index + generator = torch.cuda.default_generators[device_idx] + current_state = generator.graphsafe_get_state() + generator.graphsafe_set_state(rng_state) + out = op(*args, **kwargs) + generator.graphsafe_set_state(current_state) + return out + + @graphsafe_run_with_rng_state.py_impl(DispatchKey.BackendSelect) + def impl_backend_select(op, *args, rng_state=None, **kwargs): + device = get_device(args, kwargs) + assert ( + device == "cuda" + ), f"GraphSafe RNG operations only supported for CUDA, got {device}" + return impl_cuda(op, *args, rng_state=rng_state, **kwargs) + + @graphsafe_run_with_rng_state.py_impl(FakeTensorMode) + def impl_fake_tensor_mode(mode, op, *args, rng_state=None, **kwargs): + with mode: + return op(*args, **kwargs) + + @graphsafe_run_with_rng_state.py_impl(ProxyTorchDispatchMode) + def impl_proxy_dispatch_mode(mode, op, *args, rng_state=None, **kwargs): + with disable_proxy_modes_tracing(): + out = graphsafe_run_with_rng_state(op, *args, rng_state=rng_state, **kwargs) + proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, (op, *args)) + proxy_kwargs = pytree.tree_map( + mode.tracer.unwrap_proxy, {"rng_state": rng_state, **kwargs} + ) + out_proxy = mode.tracer.create_proxy( + "call_function", graphsafe_run_with_rng_state, proxy_args, proxy_kwargs + ) + return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer) + + @graphsafe_run_with_rng_state.py_functionalize_impl + def impl_functional(ctx, op, *args, rng_state=None, **kwargs): + unwrapped_rng_state = ( + ctx.unwrap_tensors(rng_state) if rng_state is not None else None + ) + unwrapped_args = ctx.unwrap_tensors(args) + unwrapped_kwargs = ctx.unwrap_tensors(kwargs) + + with ctx.redispatch_to_next(): + out = graphsafe_run_with_rng_state( + op, *unwrapped_args, rng_state=unwrapped_rng_state, **unwrapped_kwargs + ) + return ctx.wrap_tensors(out) + + return graphsafe_run_with_rng_state + + +graphsafe_run_with_rng_state = register_graphsafe_run_with_rng_state_op() + + def register_rng_prims(): register_philox_rand() diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 7a4ecbf705ec..368ba88dcd11 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -819,6 +819,9 @@ def _iterate_exprs(val: IterateExprs) -> Iterator[sympy.Basic]: yield from _iterate_exprs(val.storage_offset()) elif val is None: pass + # see Note: [Generator arguments in AOTDispatcher] + elif isinstance(val, torch.Generator): + pass else: raise AssertionError(f"cannot extract sympy expressions from {val} {type(val)}") diff --git a/torch/testing/_internal/hop_db.py b/torch/testing/_internal/hop_db.py index 37c44ab7f725..814df49f0f71 100644 --- a/torch/testing/_internal/hop_db.py +++ b/torch/testing/_internal/hop_db.py @@ -77,6 +77,7 @@ FIXME_hop_that_doesnt_have_opinfo_test_allowlist = [ "autograd_function_apply", "run_and_save_rng_state", "run_with_rng_state", + "graphsafe_run_with_rng_state", "out_dtype", "trace_wrapped", 'tag_activation_checkpoint',