Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
Support torch.compile rng selective activation checkpointing with cudagraph (#146878)
TODO:
- [x] Add handling for when forward is invoked multiple times without invoking backward, so that the fwd/bwd states are out of sync
- [x] Update rng state initialization to take from the correct device
- [x] Tests
- [x] Handling of retain_graph
- [x] Respect fallback_random

Fix for https://github.com/pytorch/pytorch/issues/130123.

Updates the aot_eager and cudagraph compilation of `run_and_save_rng_state` to use the mechanism added by https://github.com/pytorch/pytorch/pull/114068 for CUDAGraph-safe rng states. We keep a pair of rng states for the forward and backward respectively. In both forward and backward, the rng op is run with `graphsafe_run_with_rng_state`, which takes in an RNG state and hooks it onto the current RNG generator before running the operator. The rng states for fwd/bwd are initialized with the same value, and we ensure that for any given run of the forward, the corresponding backward run sees the same rng state for the op as was observed in the forward.

```
===== Forward graph 1 =====
/data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", fwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = fwd_rng_state_0);  fwd_rng_state_0 = None
        ...

===== Backward graph 1 =====
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", tangents_1: "f32[4, 4][4, 1]cuda:0", bwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = bwd_rng_state_0);  bwd_rng_state_0 = None
```

There is some extra complication when a user calls backward with retain_graph, or calls the backwards in a different order than the forwards. If a user has states fwd_rng_state0, bwd_rng_state0 and calls:
- fwd0: fwd_rng_state0 -> fwd_rng_state1
- fwd1: fwd_rng_state1 -> fwd_rng_state2
- bwd1
- bwd0

then, naively, when bwd1 is invoked the bwd rng states would not equal the states that were observed in fwd1. I added handling for this in the aot runtime wrappers: they detect pending backward invocations and the current position of the bwd rng states, and update the states when necessary.

Other notes: because nodes that appear later in the forward appear earlier in the backward, we need a separate rng state for each operator. If we reused a single rng state across ops, the forward and backward would be run with different rng states, i.e. the states would not be applied in the same order.

Questions for reviewers: this does change numerics, because the rng of the op is now taken from the input rng state instead of whatever the rng would be midway through running the graph. Technically we only need this for cudagraphs, but I'd prefer not to have an rng divergence just for cudagraphs, so I am making it respect `fallback_random`. Edit: decided to apply this to non-cudagraph compilation as well, so long as fallback_random is not set.

I was initially initializing the rng states by cloning the current state. If you had, say, five different rands in the model with the same shape, they'd all get the same value, which doesn't seem great. I could use some other initialization scheme, such as taking the seed from graph position. Let me know your thoughts. Update: the rng states are now initialized from torch.randint.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146878
Approved by: https://github.com/anijain2305, https://github.com/bdhirsh
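For context, a minimal sketch of the user-level pattern this change targets, assuming a CUDA device and this patch applied. The `gn`/`fn` shapes mirror the tests added in the diff below (a random op inside a reentrant-checkpointed region, compiled with cudagraphs via `mode="reduce-overhead"`); it is illustrative only, not part of the diff:

```python
import torch


def gn(x, y):
    # rng op inside the checkpointed region: its randomness must match between
    # the forward and the recomputation that happens during backward
    return torch.sigmoid(torch.rand_like(x) * y) * x


def fn(x, y):
    x = torch.sin(x)
    x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
    return torch.sin(x)


# "reduce-overhead" enables cudagraphs; with this PR the checkpointed rand is
# compiled through graphsafe_run_with_rng_state (unless fallback_random is set),
# so the rng op no longer forces a cudagraph skip.
compiled = torch.compile(fn, mode="reduce-overhead")

x = torch.randn(4, 4, device="cuda", requires_grad=True)
y = torch.randn(4, 4, device="cuda", requires_grad=True)
for _ in range(3):
    compiled(x, y).sum().backward()
```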
Committed by: PyTorch MergeBot
Parent commit: c6d1038aaa
This commit: 481a57bc37
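Before the diff itself, a hedged sketch of the graph-safe generator mechanism the change builds on, distilled from the `test_cudagraph_uneven_forward_backward` test added below. It assumes a CUDA build and only uses generator calls that appear in that test (`graphsafe_get_state`, `clone_state`, `graphsafe_set_state`, `register_generator_state`):

```python
import torch

generator = torch.cuda.default_generators[0]

# Clone the generator's state before capture and register the clone to the graph.
old_state = generator.graphsafe_get_state()
new_state = old_state.clone_state()

graph = torch.cuda.CUDAGraph()
graph.register_generator_state(new_state)

with torch.cuda.graph(graph):
    # Point the default generator at the registered state for the captured rng op,
    # then restore it, mirroring graphsafe_backward() in the test below.
    generator.graphsafe_set_state(new_state)
    replayed_perm = torch.randperm(10, device="cuda")
    generator.graphsafe_set_state(old_state)

# Run an "uneven" number of forwards eagerly, then remember the state the
# forward of interest observed right before its rng op.
for _ in range(3):
    torch.randperm(10, device="cuda")
state_before_fwd = generator.get_state()
perm = torch.randperm(10, device="cuda")

# Sync the registered state to that snapshot so replaying the captured
# "backward" reproduces the forward's randomness; replayed_perm matches perm.
new_state.set_state(state_before_fwd)
graph.replay()
```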
@@ -54,6 +54,8 @@ def count_ops(
                 return node.args[0] == op
             elif node.name == "run_with_rng_state":
                 return node.args[1] == op
+            elif node.name == "graphsafe_run_with_rng_state":
+                return node.args[0] == op
         return False

     # assert ((freq or freq_ge) and op) or ((freqs or freqs_ge) and ops)
@@ -1018,6 +1020,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no

     @requires_cuda
     @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
+    @torch._inductor.config.patch(fallback_random=True)
     def test_compile_selective_checkpoint_random_op(self, device):
         for preserve_rng_state in [True, False]:

@@ -4,13 +4,17 @@ import contextlib
 import functools
 import gc
 import importlib
+import itertools
 import sys
 import unittest
 import warnings
+from collections import defaultdict
+from collections.abc import Mapping, Sequence

 import torch
 import torch._dynamo.config as dynamo_config
 import torch.nn as nn
+from torch._dynamo.backends.debugging import aot_eager_decomp_partition_with_mode
 from torch._dynamo.utils import counters
 from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
 from torch._inductor import config
@@ -19,7 +23,9 @@ from torch._inductor.compile_fx import compile_fx_inner
 from torch._inductor.cudagraph_trees import cudagraphify_impl as tree_cudagraphify_impl
 from torch._inductor.cudagraph_utils import FunctionID
 from torch._inductor.test_case import TestCase as InductorTestCase
+from torch._ops import OpOverload
 from torch.fx.experimental.proxy_tensor import make_fx
+from torch.fx.immutable_collections import immutable_dict
 from torch.testing import FileCheck
 from torch.testing._internal.common_cuda import TEST_MULTIGPU
 from torch.testing._internal.common_utils import (
@@ -32,6 +38,7 @@ from torch.testing._internal.common_utils import (
     TEST_CUDA_GRAPH,
     TEST_WITH_ASAN,
 )
+from torch.utils._mode_utils import no_dispatch
 from torch.utils._python_dispatch import TorchDispatchMode


@@ -46,7 +53,7 @@ if IS_WINDOWS and IS_CI:
     importlib.import_module("functorch")
     importlib.import_module("filelock")

-from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
+from torch.testing._internal.inductor_utils import HAS_CUDA


 aten = torch.ops.aten
@@ -2504,7 +2511,407 @@ if HAS_CUDA and not TEST_WITH_ASAN:
             eager_result = f(example_input)
             self.assertEqual(compiled_result, eager_result)

+    class TestSAC(TestCase):
+        def _make_observer_mode(self):
+            class ObserverMode(TorchDispatchMode):
+                def __init__(self):
+                    super().__init__()
+                    self.curr_run = 0
+                    self.op_outputs = defaultdict(list)
+
+                def __torch_dispatch__(
+                    self,
+                    func: OpOverload,
+                    types: Sequence[type],
+                    args: Sequence[object] = (),
+                    kwargs: Mapping[str, object] = immutable_dict(),
+                ) -> object:
+                    return func(*args, **kwargs)
+
+            return ObserverMode
+
+        def test_simple(self):
+            device = "cuda"
+
+            from torch._prims.rng_prims import graphsafe_run_with_rng_state
+
+            ObserverMode = self._make_observer_mode()
+
+            @graphsafe_run_with_rng_state.py_impl(ObserverMode)
+            def _(mode, op, *args, **kwargs):
+                with no_dispatch():
+                    out = graphsafe_run_with_rng_state(op, *args, **kwargs)
+
+                mode.op_outputs[op].append(out)
+                return out
+
+            obs = ObserverMode()
+
+            x = torch.randn(4, 4, device=device, requires_grad=True)
+            y = torch.randn(4, 4, device=device, requires_grad=True)
+
+            for _ in range(2):
+                torch._dynamo.reset()
+
+                def gn(x, y):
+                    return torch.sigmoid(torch.rand_like(x) * y) * x
+
+                def fn(x, y):
+                    x = torch.sin(x)
+                    x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
+                    x = torch.sin(x)
+                    return x
+
+                aot_eager_decomp_partition = functools.partial(
+                    aot_eager_decomp_partition_with_mode, mode=obs
+                )
+
+                fn = torch.compile(fn, backend=aot_eager_decomp_partition)
+
+                fn(x, y).sum().backward()
+
+            self.assertEqual(len(obs.op_outputs[aten.rand.default]), 4)
+            for i in range(2):
+                self.assertEqual(
+                    obs.op_outputs[aten.rand.default][0 + 2 * i],
+                    obs.op_outputs[aten.rand.default][1 + 2 * i],
+                )
+            self.assertNotEqual(
+                obs.op_outputs[aten.rand.default][0],
+                obs.op_outputs[aten.rand.default][2],
+            )
+
+        def test_cudagraph_uneven_forward_backward(self):
+            # torch.compile cudagraphs are difficult to test
+            # the rng updating bc is sensitive to duration of pending backwards, etc.
+            # this is a short repro to mimic the runtime wrappers integration
+            # and show that updating the backward rng state with cudagraphs works:
+            def forward():
+                state = torch.cuda.get_rng_state()
+                perm = torch.randperm(10, device="cuda")
+                return state, perm
+
+            def backward(rng_state):
+                current_state = torch.cuda.get_rng_state()
+                torch.cuda.set_rng_state(rng_state.cpu())
+                perm = torch.randperm(10, device="cuda")
+                torch.cuda.set_rng_state(current_state)
+                return perm
+
+            def normal_test():
+                state, perm = forward()
+                repro_perm = backward(state)
+                return perm, repro_perm
+
+            def graphsafe_forward():
+                perm = torch.randperm(10, device="cuda")
+                return perm
+
+            def graphsafe_backward(generator, new_state):
+                current_state = generator.graphsafe_get_state()
+                generator.graphsafe_set_state(new_state)
+                perm = torch.randperm(10, device="cuda")
+                generator.graphsafe_set_state(current_state)
+                return perm
+
+            def graph_test(generator, capture_cuda_graph):
+                if capture_cuda_graph:
+                    graph = torch.cuda.CUDAGraph()
+
+                # state should be cloned before the graph
+                old_state = generator.graphsafe_get_state()
+                new_state = old_state.clone_state()
+
+                if capture_cuda_graph:
+                    # state should be register to the graph
+                    graph.register_generator_state(new_state)
+
+                    # only capturing the backward
+                    with torch.cuda.graph(graph):
+                        repro_perm = graphsafe_backward(generator, new_state)
+
+                # some number of uneven forwards
+                graphsafe_forward()
+                graphsafe_forward()
+                graphsafe_forward()
+
+                # state prior to rng invocation
+                state = generator.get_state()
+                perm = graphsafe_forward()
+
+                new_state.set_state(state)
+
+                if capture_cuda_graph:
+                    graph.replay()
+                else:
+                    repro_perm = graphsafe_backward(generator, new_state)
+
+                return perm, repro_perm
+
+            self.assertEqual(*normal_test())
+            generator = torch.cuda.default_generators[0]
+            self.assertEqual(*graph_test(generator, capture_cuda_graph=False))
+            self.assertEqual(*graph_test(generator, capture_cuda_graph=True))
+
+        def test_cpu_and_cuda_rng(self):
+            device = "cuda"
+
+            ObserverMode = self._make_observer_mode()
+            from torch._prims.rng_prims import (
+                graphsafe_run_with_rng_state,
+                run_and_save_rng_state,
+                run_with_rng_state,
+            )
+
+            for hop in [
+                graphsafe_run_with_rng_state,
+                run_and_save_rng_state,
+                run_with_rng_state,
+            ]:
+
+                def make_impl(hop):
+                    @hop.py_impl(ObserverMode)
+                    def _(mode, *args, **kwargs):
+                        with no_dispatch():
+                            out = hop(*args, **kwargs)
+
+                        op = None
+                        for inp in itertools.chain(args, kwargs.values()):
+                            if isinstance(inp, torch._ops.OpOverload):
+                                op = inp
+                                break
+                        assert op is not None
+                        if hop is run_and_save_rng_state:
+                            mode.op_outputs[op].append(out[1])
+                        else:
+                            mode.op_outputs[op].append(out)
+                        return out
+
+                make_impl(hop)
+
+            obs = ObserverMode()
+
+            def gn(x, y):
+                return torch.sigmoid(torch.rand_like(x) * y) * x
+
+            def gn2(x):
+                return x * torch.randperm(x.numel(), device=x.device).reshape(x.shape)
+
+            def fn(x, y, z):
+                x = torch.sin(x)
+                x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
+                x = torch.sin(x)
+                z = torch.utils.checkpoint.checkpoint(gn2, z, use_reentrant=True)
+                return x * z.cuda()
+
+            aot_eager_decomp_partition = functools.partial(
+                aot_eager_decomp_partition_with_mode, mode=obs
+            )
+
+            fn = torch.compile(fn, backend=aot_eager_decomp_partition)
+
+            x = torch.randn(4, 4, device=device, requires_grad=True)
+            y = torch.randn(4, 4, device=device, requires_grad=True)
+            z = torch.randn(4, 4, requires_grad=True)
+
+            fn(x, y, z).sum().backward()
+            for op in [aten.rand.default, aten.randperm.default]:
+                self.assertEqual(len(obs.op_outputs[op]), 2)
+                self.assertEqual(
+                    obs.op_outputs[op][0],
+                    obs.op_outputs[op][1],
+                )
+                self.assertEqual(
+                    obs.op_outputs[op][0].device.type,
+                    "cpu" if op == aten.randperm.default else "cuda",
+                )
+
+        @parametrize("order", (list(itertools.permutations([0, 1, 2]))))
+        def test_uneven_forward_backward(self, order):
+            device = "cuda"
+
+            ObserverMode = self._make_observer_mode()
+            from torch._prims.rng_prims import graphsafe_run_with_rng_state
+
+            @graphsafe_run_with_rng_state.py_impl(ObserverMode)
+            def _(mode, op, *args, **kwargs):
+                with no_dispatch():
+                    out = graphsafe_run_with_rng_state(op, *args, **kwargs)
+
+                mode.op_outputs[(mode.curr_run, op)].append(out)
+                return out
+
+            obs = ObserverMode()
+
+            def gn(x, y):
+                return torch.sigmoid(torch.rand_like(x) * y) * x
+
+            def gn2(x):
+                return x * torch.randperm(x.numel(), device=x.device).reshape(x.shape)
+
+            def fn(x, y):
+                x = torch.sin(x)
+                x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
+                x = torch.sin(x)
+                x = torch.utils.checkpoint.checkpoint(gn2, x, use_reentrant=True)
+                return x
+
+            aot_eager_decomp_partition = functools.partial(
+                aot_eager_decomp_partition_with_mode, mode=obs
+            )
+
+            fn_c = torch.compile(fn, backend=aot_eager_decomp_partition)
+
+            torch.manual_seed(0)
+            outs = []
+            for i in range(len(order)):
+                obs.curr_run = i
+                x = torch.randn(4, 4, device=device, requires_grad=True)
+                y = torch.randn(4, 4, device=device, requires_grad=True)
+                outs.append(fn_c(x, y))
+
+            for idx in order:
+                obs.curr_run = idx
+                outs[idx].sum().backward()
+
+            for run in range(len(order)):
+                for op in (aten.rand.default, aten.randperm.default):
+                    self.assertEqual(len(obs.op_outputs[(run, op)]), 2)
+                    self.assertEqual(
+                        obs.op_outputs[(run, op)][0],
+                        obs.op_outputs[(run, op)][1],
+                    )
+                    if run != 0:
+                        self.assertNotEqual(
+                            obs.op_outputs[(run - 1, op)][0],
+                            obs.op_outputs[(run, op)][0],
+                        )
+
+        @config.patch(fallback_random=True)
+        @config.patch("test_configs.graphsafe_rng_func_ignores_fallback_random", True)
+        def _test_cudagraphs_aot_eager_compat_equal(self, device):
+            def gn(x, y):
+                return torch.sigmoid(torch.rand_like(x) * y) * x
+
+            def fn(x, y):
+                x = torch.sin(x)
+                x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
+                x = torch.sin(x)
+                return x
+
+            outs = []
+            grads = []
+
+            outs2 = []
+            grads2 = []
+
+            compile_fns = [
+                lambda fn: torch.compile(fn, backend="aot_eager_decomp_partition"),
+                lambda fn: torch.compile(fn, mode="reduce-overhead"),
+            ]
+            for i, compile_fn in enumerate(compile_fns):
+                torch.manual_seed(0)
+                for index in range(3):
+                    x = torch.randn(4, 4, device=device, requires_grad=True)
+                    y = torch.randn(4, 4, device=device, requires_grad=True)
+
+                    out = compile_fn(fn)(x, y)
+                    torch.cuda.synchronize()
+                    out.sum().backward()
+                    if i == 0:
+                        outs.append(out.clone())
+                        grads.append((x.grad.clone(), y.grad.clone()))
+                    else:
+                        outs2.append(out.clone())
+                        grads2.append((x.grad.clone(), y.grad.clone()))
+
+            self.assertEqual(outs, outs2)
+            self.assertEqual(grads, grads2)
+            self.assertEqual(counters["inductor"]["cudagraph_skips"], 0)
+
+        def test_cudagraphs_aot_eager_compat_equal(self):
+            self._test_cudagraphs_aot_eager_compat_equal(torch.device("cuda:0"))
+
+        @requires_multigpu()
+        def test_cudagraphs_aot_eager_compat_equal_device_one(self):
+            self._test_cudagraphs_aot_eager_compat_equal(torch.device("cuda:1"))
+
+        @requires_multigpu()
+        def test_multi_device(self):
+            def gn(x, y):
+                return torch.sigmoid(torch.rand_like(x) * y) * x
+
+            def fn(x, y):
+                x = torch.sin(x)
+                x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
+                x = torch.sin(x)
+                return x
+
+            def multi_fn(x, y, a, b):
+                return fn(x, y), fn(a, b)
+
+            x = torch.randn(4, 4, device="cuda:0", requires_grad=True)
+            y = torch.randn(4, 4, device="cuda:0", requires_grad=True)
+
+            a = torch.randn(4, 4, device="cuda:1", requires_grad=True)
+            b = torch.randn(4, 4, device="cuda:1", requires_grad=True)
+
+            # No errors. TODO - get graphs from logging, couldnt figure out how
+            multi_fn_c = torch.compile(multi_fn, backend="aot_eager_decomp_partition")
+
+            out = multi_fn_c(x, y, a, b)
+            out[0].sum().backward()
+
+        def test_retain_graph(self):
+            device = "cuda"
+
+            ObserverMode = self._make_observer_mode()
+            from torch._prims.rng_prims import graphsafe_run_with_rng_state
+
+            @graphsafe_run_with_rng_state.py_impl(ObserverMode)
+            def _(mode, op, *args, **kwargs):
+                with no_dispatch():
+                    out = graphsafe_run_with_rng_state(op, *args, **kwargs)
+
+                mode.op_outputs[op].append(out)
+                return out
+
+            obs = ObserverMode()
+
+            def gn(x, y):
+                return torch.sigmoid(torch.rand_like(x) * y) * x
+
+            def fn(x, y):
+                x = torch.sin(x)
+                x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
+                x = torch.sin(x)
+                return x
+
+            x = torch.randn(4, 4, device=device, requires_grad=True)
+            y = torch.randn(4, 4, device=device, requires_grad=True)
+
+            aot_eager_decomp_partition = functools.partial(
+                aot_eager_decomp_partition_with_mode, mode=obs
+            )
+
+            fn = torch.compile(fn, backend=aot_eager_decomp_partition)
+
+            out = fn(x, y).sum()
+            out.backward(retain_graph=True)
+            out.backward()
+            self.assertEqual(len(obs.op_outputs[aten.rand.default]), 3)
+            self.assertEqual(
+                obs.op_outputs[aten.rand.default][0],
+                obs.op_outputs[aten.rand.default][1],
+            )
+            self.assertEqual(
+                obs.op_outputs[aten.rand.default][1],
+                obs.op_outputs[aten.rand.default][2],
+            )
+
     instantiate_parametrized_tests(CudaGraphTreeTests)
+    instantiate_parametrized_tests(TestSAC)


 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests

@@ -2514,5 +2921,5 @@ if __name__ == "__main__":
             sys.exit(0)
         raise unittest.SkipTest("cuda graph test is skipped")

-    if HAS_CPU or HAS_CUDA:
+    if HAS_CUDA:
         run_tests(needs="filelock")
@@ -62,6 +62,7 @@ class TestHOPInfra(TestCase):
         FIXME_ALLOWLIST = {
             "autograd_function_apply",
             "run_with_rng_state",
+            "graphsafe_run_with_rng_state",
             "map_impl",
             "_export_tracepoint",
             "run_and_save_rng_state",
@@ -66,6 +66,7 @@ from .subclass_utils import compute_inner_mutated_inp_indices_from_subclass_meta
 from .utils import (
     _get_symint_hints,
     contain_metadata_mutation_ops,
+    get_cuda_generator_meta_val,
     make_boxed_func,
     strict_zip,
     unlift_tokens,
@@ -458,9 +459,20 @@ def aot_dispatch_autograd(
         fake_mode = detect_fake_mode()
         if fake_mode is not None and fake_mode.shape_env is not None:
             tensorify_python_scalars(fx_g, fake_mode.shape_env, fake_mode)

         fw_module, bw_module = aot_config.partition_fn(
             fx_g, joint_inputs, num_fwd_outputs=num_inner_fwd_outputs
         )
+        rng_states = [
+            n
+            for n in fw_module.graph.find_nodes(op="placeholder")
+            if "fwd_rng_state" in n.name
+        ]
+        fw_metadata.num_graphsafe_rng_states = len(rng_states)
+        if rng_states:
+            fw_metadata.graphsafe_rng_state_index = (
+                rng_states[0].meta["val"].device.index
+            )
+
         # See Note [Side-Effectful Tokens in AOTAutograd]
         if config.unlift_effect_tokens and (
@@ -684,6 +696,16 @@ def aot_dispatch_autograd(
     functionalized_rng_wrapper = FunctionalizedRngRuntimeWrapper(
         return_new_outs=False
     )

+    if rng_states:
+        index = fw_metadata.graphsafe_rng_state_index
+        assert index is not None
+        rng_states = [
+            get_cuda_generator_meta_val(index)
+            for _ in range(fw_metadata.num_graphsafe_rng_states)
+        ]
+        adjusted_flat_args.extend(rng_states)  # type: ignore[arg-type]
+
     (
         fw_module,
         adjusted_flat_args,
@@ -1684,6 +1684,45 @@ def _backward_prologue_functional(
     return all_args


+def initialize_rng_states(
+    num_rng: int,
+    graphsafe_idx: int,
+    fwd_rng_states: list[torch.Generator],
+    bwd_rng_states: list[torch.Generator],
+):
+    """
+    Initialize the cudagraph safe rng states.
+
+    Initialization of rng states should have a few properties:
+    - the initialization for each rng state should be independent
+    - the initialization should be deterministic
+    - the initialization should be based off current rng state, so that independent graphs do not
+      have equal rng behavior
+
+    We defer initialization of rng states until runtime because compilation is wrapped
+    with preserve_rng_states. Seed initialization should advance the rng states so consecutive compilations
+    do not give equal randomness.
+    """
+    with torch.utils._python_dispatch._disable_current_modes():
+        seeds = torch.randint(0, torch.iinfo(torch.int64).max, (num_rng,), device="cpu")
+        fwd_rng_states.extend(
+            [
+                torch.cuda.default_generators[graphsafe_idx]
+                .clone_state()
+                .manual_seed(int(seeds[i]))
+                for i in range(num_rng)
+            ]
+        )
+        bwd_rng_states.extend(
+            [
+                torch.cuda.default_generators[graphsafe_idx]
+                .clone_state()
+                .manual_seed(int(seeds[i]))
+                for i in range(num_rng)
+            ]
+        )
+
+
 # NOTE: this function must be torch._dynamo.allow_in_graph-able. Non tensor/symnode inputs must be constants.
 def _backward_epilogue_functional(
     metadata, maybe_subclass_metadata, out, *, make_subclass_override=None
@@ -1819,6 +1858,34 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
     fw_metadata: ViewAndMutationMeta,  # runtime metadata
     try_save_cache_entry: Optional[Callable],  # Save cache entry after compilation
 ):
+    # For additional context see Note [CUDA Graph Safe RNG Functionalization]
+    # Each pair forward, backward rng states must be equal prior to its invocation on any
+    # iteration of forward, backward. Because they are initialized equal, and are computing the same rng op,
+    # running forward then backward advances them the same amount and keeps them equal.
+    # However, a user may invoke multiple forwards, then backwards, such that they are not in sync.
+    # Initially we have:
+    # fwd_state0 == bwd_state0.
+    # Lets say we run:
+    # fwd0: fwd_state0 -> fwd_state1
+    # fwd1: fwd_state1 -> fwd_state2
+    # fwd2: fwd_state2 -> fwd_state3
+    # If we now invoke bwd2,
+    # we need to update bwd_state equal to the rng that was observed in fwd2.
+    # we save the rng_state fwd_state2 in forward because we detect that it is not the
+    # current backward state and therefore would not be accessible if we do not save it.
+    # Similarly, if we are going to update the backward state to a new value, and there is a pending
+    # forwards which needs its current state, we will save it.
+    # Within the autograd context, we keep track of the curr iteration so that on backward
+    # we know what the generator state must be before the backward is run.
+    num_rng = fw_metadata.num_graphsafe_rng_states
+    graphsafe_idx = fw_metadata.graphsafe_rng_state_index
+    fwd_rng_states: list[torch.Generator] = []
+    bwd_rng_states: list[torch.Generator] = []
+    curr_fwd_iter = itertools.count(0)
+    backward_state_position = 0
+    pending_forwards: set[int] = set()
+    saved_backward_tensor_states: dict[int, list[torch.Tensor]] = {}
+
     class CompiledFunction(torch.autograd.Function):
         compiled_fw = compiled_fw_func
         compiled_bw = compiled_bw_func
@@ -1840,6 +1907,26 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
                 assert isinstance(bw_state, BackwardState)
                 ctx._compiled_autograd_backward_state = bw_state

+            if num_rng:
+                if len(fwd_rng_states) == 0:
+                    assert graphsafe_idx is not None
+                    initialize_rng_states(
+                        num_rng, graphsafe_idx, fwd_rng_states, bwd_rng_states
+                    )
+
+                _curr_iter = next(curr_fwd_iter)
+                ctx._curr_iter = _curr_iter
+
+                # if this state is not contained in the backward,
+                # we need to save it for when its backward pass happens
+                if _curr_iter != backward_state_position:
+                    saved_backward_tensor_states[_curr_iter] = [
+                        rng_state.get_state() for rng_state in fwd_rng_states
+                    ]
+
+                pending_forwards.add(_curr_iter)
+                args = (*args, *fwd_rng_states)
+
             # There is a pretty complicated calling convention around what the compiled fw returns.
             # The full list of outputs and their relative order is:
             # (*tokens, *mutated_inputs, *fw_outs, *fw_intermediate_bases, *saved_tensors, *saved_symints)
@@ -1967,6 +2054,45 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
                 *flat_args,
             )

+            if num_rng:
+                nonlocal backward_state_position, bwd_rng_states
+                curr_backward_iter = ctx._curr_iter
+                retain_graph = (
+                    torch._C._autograd._get_current_graph_task_keep_graph()
+                )
+
+                # Save current state if we have a pending forward that needs this state
+                # or this state may be needed again because of retain graph
+                if (
+                    backward_state_position in pending_forwards
+                    and backward_state_position not in saved_backward_tensor_states
+                    and (
+                        backward_state_position != curr_backward_iter
+                        or retain_graph
+                    )
+                ):
+                    saved_backward_tensor_states[backward_state_position] = [
+                        rng_state.get_state() for rng_state in bwd_rng_states
+                    ]
+
+                # Restore saved states if needed
+                if curr_backward_iter in saved_backward_tensor_states:
+                    if backward_state_position != curr_backward_iter:
+                        for bwd_state, saved_state in zip(
+                            bwd_rng_states,
+                            saved_backward_tensor_states[curr_backward_iter],
+                        ):
+                            bwd_state.set_state(saved_state)
+                    if not retain_graph:
+                        del saved_backward_tensor_states[curr_backward_iter]
+                else:
+                    assert backward_state_position == curr_backward_iter
+
+                backward_state_position = curr_backward_iter + 1
+                if not retain_graph:
+                    pending_forwards.remove(curr_backward_iter)
+                all_args.extend(bwd_rng_states)
+
             def impl_fn(double_ctx=None):
                 out = CompiledFunction._backward_impl(ctx, all_args)
                 return _backward_epilogue_functional(
@@ -420,11 +420,16 @@ class ViewAndMutationMeta:
     # Filled after tracing joint function.
     num_backward_tokens: int = 0

+    # Number of rng states that will get thread into the forward and backward for
+    # cudagraph compatible run_and_save_rng
+    num_graphsafe_rng_states: int = 0
+
+    graphsafe_rng_state_index: Optional[int] = None
+
     def __post_init__(self):
         # pre-compute the indices of the inputs that are mutated.
         # When keep_input_mutations is set, we don't need to worry about our epilogue
         # handling data-only mutations, because we keep them directly in the graph.

         mutated_inp_runtime_indices = [
             i
             for i, m in enumerate(self.input_info)
@@ -489,3 +489,14 @@ def contain_metadata_mutation_ops(module: torch.fx.GraphModule) -> bool:
     ):
         return True
     return False
+
+
+def get_cuda_generator_meta_val(device_idx: int):
+    """
+    Get a generator value to use as a meta val
+
+    newly cloned generator will not contain tensors. it is only Generators that are
+    registered to a CUDAGraph that contain tensors. since this does not contain Tensor
+    it is fine to use in the meta.
+    """
+    return torch.cuda.default_generators[device_idx].clone_state()
@@ -210,6 +210,10 @@ torch_compile_graph_format = os.environ.get("TORCH_COMPILE_GRAPH_FORMAT", "svg")
 # real tensor outputs.
 generate_fake_kernels_from_real_mismatches = False

+# CUDAGraph save run_with_rng functionalization.
+# TODO: turn on by default
+graphsafe_rng_functionalization = True
+

 # Error on BypassAOTAutogradCache instead of just a warning
 # Used for tests
@@ -40,7 +40,7 @@ from ._activation_checkpointing.knapsack import (
 )
 from ._activation_checkpointing.knapsack_evaluator import KnapsackEvaluator
 from ._aot_autograd.logging_utils import get_aot_graph_name
-from ._aot_autograd.utils import is_with_effects
+from ._aot_autograd.utils import get_cuda_generator_meta_val, is_with_effects
 from .compile_utils import fx_graph_cse, get_aten_target


@@ -622,6 +622,99 @@ def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule:
     return new_gm


+def apply_graphsafe_rng_functionalization(
+    fw_module: torch.fx.GraphModule,
+    bw_module: torch.fx.GraphModule,
+    fw_node: torch.fx.Node,
+    bw_node: torch.fx.Node,
+    device: torch.device,
+    rng_count: int,
+    last_fwd_input: torch.fx.Node,
+    last_bwd_input: torch.fx.Node,
+):
+    """
+    Note [CUDA Graph Safe RNG Functionalization]
+
+    CUDA Graph capture doesn't work with get_rng_state and set_rng_state because these functions operate on CPU values,
+    while CUDA Graph RNG capture uses on-device CUDA tensors. To solve this, we use graphsafe_set_state with a
+    CUDA Generator registered to the CUDA Graph before capture begins. graphsafe_set_state updates the generator's pointer
+    to reference a different GeneratorImpl, ensuring subsequent calls are correctly forwarded to the desired generator
+    (and its cuda-tensor RNG state during graph capture).
+
+    For each RNG operation's forward/backward pair:
+
+    - We create two generators initialized with identical values
+    - Each forward and backward call advances its respective generator equally
+    - This keeps generators synchronized so forward and backward operations use matching RNG values
+
+    When forward is called multiple times before backward (causing desynchronization):
+
+    - We save the forward RNG state
+    - We update the backward Generator's state before executing backward
+
+    Before each CUDA Graph replay, replay_prologue updates captured RNG pointers with current states, ensuring backward Generator
+    changes are reflected during replay.
+
+    This function modifies both forward and backward computation graphs by:
+
+    Creating RNG state placeholders for both passes
+    Updating the forward node to use graph-safe RNG state
+    Updating the backward node to use graph-safe RNG state
+
+    For more details: https://github.com/pytorch/pytorch/issues/113541
+    """
+    device_idx = device.index
+    assert device_idx is not None
+    fw_graph = fw_module.graph
+    bw_graph = bw_module.graph
+    graphsafe_run_with_rng_state = torch._prims.rng_prims.graphsafe_run_with_rng_state
+
+    # Handle forward pass
+
+    # Note: [Generator arguments in AOTDispatcher]
+    # Generator arguments in AOTDispatcher are added to support graphsafe rng
+    # functionalization. See note above [CUDA Graph Safe RNG Functionalization]
+    with fw_module.graph.inserting_after(last_fwd_input):
+        fwd_rng_state = fw_module.graph.placeholder(f"fwd_rng_state_{rng_count}")
+        fwd_rng_state.meta["val"] = get_cuda_generator_meta_val(device_idx)
+        last_fwd_input = fwd_rng_state
+
+    # Handle backward pass
+    with bw_module.graph.inserting_after(last_bwd_input):
+        bwd_rng_state = bw_module.graph.placeholder(f"bwd_rng_state_{rng_count}")
+        # as above, clone so that meta val generator will not contain tensors
+        bwd_rng_state.meta["val"] = get_cuda_generator_meta_val(device_idx)
+        last_bwd_input = bwd_rng_state
+
+    # Update forward node
+    fw_kwargs = dict(fw_node.kwargs)
+    fw_kwargs["rng_state"] = fwd_rng_state
+    with fw_module.graph.inserting_after(fw_node):
+        functional_fw_node = fw_graph.create_node(
+            "call_function",
+            graphsafe_run_with_rng_state,
+            args=(fw_node.target, *fw_node.args),  # type: ignore[arg-type]
+            kwargs=fw_kwargs,
+        )
+    fw_node.replace_all_uses_with(functional_fw_node)
+    fw_graph.erase_node(fw_node)
+
+    # Update backward node
+    bwd_kwargs = dict(bw_node.kwargs)
+    bwd_kwargs["rng_state"] = bwd_rng_state
+    with bw_graph.inserting_before(bw_node):
+        rng_output = bw_graph.create_node(
+            "call_function",
+            graphsafe_run_with_rng_state,
+            args=(bw_node.target, *bw_node.args),  # type: ignore[arg-type]
+            kwargs=bwd_kwargs,
+        )
+    bw_node.replace_all_uses_with(rng_output)
+    bw_graph.erase_node(bw_node)
+
+    return last_fwd_input, last_bwd_input
+
+
 def functionalize_rng_ops(
     joint_module: fx.GraphModule,
     fw_module: fx.GraphModule,
@@ -660,7 +753,7 @@ def functionalize_rng_ops(
             random_nodes[node.name] = node
         return random_nodes

-    def get_device(node):
+    def get_device(node) -> Optional[torch.device]:
         """
         Check the example value of the node outputs to find the device type.
         """
@@ -674,12 +767,12 @@ def functionalize_rng_ops(
         for candidate in candidates:
             if isinstance(candidate, torch.Tensor):
                 if candidate.device.type == "cuda":
-                    return "cuda"
+                    return candidate.device

-        return "cpu"
+        return torch.device("cpu")

-    def get_sample_rng_state(device):
-        if device == "cuda":
+    def get_sample_rng_state(device: Optional[torch.device]):
+        if device is not None and device.type == "cuda":
             return torch.cuda.get_rng_state()
         return torch.get_rng_state()

@@ -701,6 +794,7 @@ def functionalize_rng_ops(

     run_and_save_rng = torch._prims.rng_prims.run_and_save_rng_state
     run_with_rng_state = torch._prims.rng_prims.run_with_rng_state
+
     bw_tangent_start_node = None
     for node in bw_module.graph.find_nodes(op="placeholder"):
         if "tangent" in node.name:
@@ -712,68 +806,113 @@ def functionalize_rng_ops(
     )

     fw_rng_state_outputs = []
-    for base_node, node_pair in recomputable_rng_ops_map.items():
+
+    last_fwd_input = next(reversed(fw_module.graph.find_nodes(op="placeholder")))
+    last_bwd_input = next(reversed(bw_module.graph.find_nodes(op="placeholder")))
+
+    devices = OrderedSet(
+        get_device(node_pair["fwd"]) for node_pair in recomputable_rng_ops_map.values()
+    )
+    devices.discard(torch.device("cpu"))
+    # multiple cuda devices wont work with cudagraphs anyway,
+    # fallback to non graphsafe rng checkpointing
+    multi_cuda_devices = len(devices) > 1
+
+    # this changes numerics, so if fallback_random is set we will not use it
+    ind_config = torch._inductor.config
+    use_rng_graphsafe_rng_functionalization = (
+        config.graphsafe_rng_functionalization
+        and not multi_cuda_devices
+        and (
+            not ind_config.fallback_random
+            or ind_config.test_configs.graphsafe_rng_func_ignores_fallback_random
+        )
+    )
+
+    for rng_count, (base_node, node_pair) in enumerate(
+        recomputable_rng_ops_map.items()
+    ):
         # Step 2 - Modify the fwd pass such that
         fw_node = node_pair["fwd"]
         bw_node = node_pair["bwd"]
+        device = get_device(fw_node)
+
         fw_graph = fw_module.graph
         bw_graph = bw_module.graph
-        with fw_graph.inserting_before(fw_node):
-            functional_fw_node = fw_graph.create_node(
-                "call_function",
-                run_and_save_rng,
-                args=(fw_node.target, *fw_node.args),
-                kwargs=fw_node.kwargs,
-            )
-            state = fw_graph.create_node(
-                "call_function",
-                operator.getitem,
-                args=(functional_fw_node, 0),
-                kwargs={},
-            )
-            rng_output = fw_graph.create_node(
-                "call_function",
-                operator.getitem,
-                args=(
-                    functional_fw_node,
-                    1,
-                ),
-                kwargs={},
-            )
-            fw_node.replace_all_uses_with(rng_output)
-            fw_graph.erase_node(fw_node)
-            fw_rng_state_outputs.append(state)
-
-        # Step 3 - Modify the bwd pass such that
-        with bw_graph.inserting_before(bw_tangent_start_node):
-            state_name = f"rng_state_output_{next(uid)}"
-            bw_rng_state_node = bw_graph.placeholder(state_name)
-            bw_rng_state_node.meta["val"] = get_sample_rng_state(get_device(fw_node))
-
-        with bw_graph.inserting_before(bw_node):
-            rng_output = bw_graph.create_node(
-                "call_function",
-                run_with_rng_state,
-                args=(bw_rng_state_node, bw_node.target, *bw_node.args),
-                kwargs=bw_node.kwargs,
+        if (
+            use_rng_graphsafe_rng_functionalization
+            and device is not None
+            and device.type == "cuda"
+        ):
+            last_fwd_input, last_bwd_input = apply_graphsafe_rng_functionalization(
+                fw_module,
+                bw_module,
+                fw_node,
+                bw_node,
+                device,
+                rng_count,
+                last_fwd_input,
+                last_bwd_input,
             )
+        else:
+            with fw_graph.inserting_before(fw_node):
+                functional_fw_node = fw_graph.create_node(
+                    "call_function",
+                    run_and_save_rng,
+                    args=(fw_node.target, *fw_node.args),
+                    kwargs=fw_node.kwargs,
+                )
+                state = fw_graph.create_node(
+                    "call_function",
+                    operator.getitem,
+                    args=(functional_fw_node, 0),
+                    kwargs={},
+                )
+                rng_output = fw_graph.create_node(
+                    "call_function",
+                    operator.getitem,
+                    args=(
+                        functional_fw_node,
+                        1,
+                    ),
+                    kwargs={},
+                )
+                fw_node.replace_all_uses_with(rng_output)
+                fw_graph.erase_node(fw_node)
+                fw_rng_state_outputs.append(state)

-        bw_node.replace_all_uses_with(rng_output)
-        bw_graph.erase_node(bw_node)
+            # Step 3 - Modify the bwd pass such that
+            with bw_graph.inserting_before(bw_tangent_start_node):
+                state_name = f"rng_state_output_{next(uid)}"
+                bw_rng_state_node = bw_graph.placeholder(state_name)
+                bw_rng_state_node.meta["val"] = get_sample_rng_state(device)
+
+            with bw_graph.inserting_before(bw_node):
+                rng_output = bw_graph.create_node(
+                    "call_function",
+                    run_with_rng_state,
+                    args=(bw_rng_state_node, bw_node.target, *bw_node.args),
+                    kwargs=bw_node.kwargs,
+                )
+
+            bw_node.replace_all_uses_with(rng_output)
+            bw_graph.erase_node(bw_node)

     # Add the rng states in the output of the fwd graph. AOT Autograd assumes
     # that symints are at the end of forward graph outputs. So, insert the new
     # rng states accordingly.
-    fw_output_node = next(iter(fw_module.graph.find_nodes(op="output")))
-    fw_outputs = fw_output_node.args[0]
-    sym_node_start_idx = len(fw_outputs) - num_sym_nodes
-    outputs = (
-        fw_outputs[:sym_node_start_idx]
-        + tuple(fw_rng_state_outputs)
-        + fw_outputs[sym_node_start_idx:]
-    )
-    fw_module.graph.output(outputs)
-    fw_module.graph.erase_node(fw_output_node)
+    if fw_rng_state_outputs:
+        fw_output_node = next(iter(fw_module.graph.find_nodes(op="output")))
+        fw_outputs = fw_output_node.args[0]
+        sym_node_start_idx = len(fw_outputs) - num_sym_nodes
+        outputs = (
+            fw_outputs[:sym_node_start_idx]
+            + tuple(fw_rng_state_outputs)
+            + fw_outputs[sym_node_start_idx:]
+        )
+        fw_module.graph.output(outputs)
+        fw_module.graph.erase_node(fw_output_node)
     fw_module.recompile()
     bw_module.recompile()
     return fw_module, bw_module
@@ -1849,7 +1988,6 @@ def min_cut_rematerialization_partition(
         saved_sym_nodes=saved_sym_nodes,
         num_fwd_outputs=num_fwd_outputs,
     )
-
     if graph_has_recomputable_ops:
         if graph_has_recomputable_rng_ops:
             fw_module, bw_module = functionalize_rng_ops(
@@ -898,7 +898,9 @@ class PythonWrapperCodegen(CodeGen):
                 continue

             # a graph partition may take an IRNode output from a previous partition
-            if name not in V.graph.graph_input_names:
+            if name not in V.graph.graph_input_names or isinstance(
+                buf, ir.GeneratorState
+            ):
                 continue

             # comparing strides for 0 size tensor is tricky. Ignore them for now.
@@ -1418,7 +1420,9 @@ class PythonWrapperCodegen(CodeGen):
                     code.writeline(f"{stride} = {strideof(name)}[{dim}]")
                     bound_vars.add(stride)
             elif isinstance(value, ir.TorchBindObject):
-                pass
+                return
+            elif isinstance(value, ir.GeneratorState):
+                return
             else:
                 if torch._inductor.config.graph_partition:
                     pass
@@ -1612,6 +1616,11 @@ class PythonWrapperCodegen(CodeGen):
                 # is actually a valid value for the kernel in question.
                 # See https://github.com/pytorch/pytorch/issues/124686
                 add_expr_input(name, V.graph.sizevars.size_hint(value, fallback=42))
+            elif isinstance(value, ir.GeneratorState):
+                add_expr_input(
+                    name,
+                    f"torch.cuda.default_generators[{value.device.index}].graphsafe_get_state()",
+                )
             else:
                 shape = [
                     V.graph.sizevars.size_hint(x, fallback=42)
@@ -2287,6 +2296,8 @@ class PythonWrapperCodegen(CodeGen):
             return s.codegen_reference()
         elif has_triton_package() and isinstance(s, triton.language.dtype):  # type: ignore[possibly-undefined]
             return dtype_to_string(s)
+        elif isinstance(s, ir.GeneratorState):
+            return s.codegen_reference()
         else:
             return repr(s)

@@ -1517,6 +1517,8 @@ class test_configs:
     autotune_choice_name_regex: Optional[str] = None
     autotune_choice_desc_regex: Optional[str] = None

+    graphsafe_rng_func_ignores_fallback_random = False
+

 if TYPE_CHECKING:
     from torch.utils._config_typing import *  # noqa: F401, F403
@ -908,6 +908,7 @@ class CUDAGraphNode:
|
|||||||
self.recorded_liveness_before_graph = curr_liveness
|
self.recorded_liveness_before_graph = curr_liveness
|
||||||
self.expected_dead_indices_before_graph = different_indices
|
self.expected_dead_indices_before_graph = different_indices
|
||||||
|
|
||||||
|
rng_states = [inp for inp in inputs if isinstance(inp, torch.Generator)]
|
||||||
recording_inputs = self._allocate_and_copy_recording_inputs(inputs)
|
recording_inputs = self._allocate_and_copy_recording_inputs(inputs)
|
||||||
# recording inputs will copy over memory, so we can free non recording inputs
|
# recording inputs will copy over memory, so we can free non recording inputs
|
||||||
inputs.clear()
|
inputs.clear()
|
||||||
@ -916,6 +917,11 @@ class CUDAGraphNode:
|
|||||||
# graph used for recording model invocation
|
# graph used for recording model invocation
|
||||||
self.graph: Optional[torch.cuda.CUDAGraph] = torch.cuda.CUDAGraph()
|
self.graph: Optional[torch.cuda.CUDAGraph] = torch.cuda.CUDAGraph()
|
||||||
|
|
||||||
|
# TODO: register_generator_state should potentially take explicit device
|
||||||
|
with torch.cuda.device(self.device):
|
||||||
|
for rng_state in rng_states:
|
||||||
|
self.graph.register_generator_state(rng_state)
|
||||||
|
|
||||||
# we allocate non-static inputs within the same memory pool as the CUDAGraph
|
# we allocate non-static inputs within the same memory pool as the CUDAGraph
|
||||||
# which we will record the model with. For memory efficiency, it is important
|
# which we will record the model with. For memory efficiency, it is important
|
||||||
# to reclaim the input memory when the inputs are no longer live. To accomplish this,
|
# to reclaim the input memory when the inputs are no longer live. To accomplish this,
|
||||||
@@ -1602,7 +1608,7 @@ class CUDAGraphNode:
         ):
             for i, inp in enumerate(inputs):
                 if not isinstance(inp, torch.Tensor):
-                    assert isinstance(inp, int)
+                    assert isinstance(inp, (int, torch.Generator))
                     recording_inputs.append(inp)
                 elif i not in self.static_input_idxs:
                     # static_input does an allocation!
@@ -1040,6 +1040,18 @@ class GraphLowering(torch.fx.Interpreter):
             # Alternately we could filter this out in AotAutograd
             self.graph_input_names.append(target)
             return None
+        # See note: Note: [Generator arguments in AOTDispatcher]
+        elif isinstance(example, torch.Generator):
+            assert (
+                len(V.graph.current_node.users) == 1
+                and next(iter(V.graph.current_node.users)).target
+                is torch._prims.rng_prims.graphsafe_run_with_rng_state
+            )
+            gen = ir.GeneratorState(name=target, device=example.device)
+            self.graph_inputs[target] = gen  # type: ignore[assignment]
+            self.graph_input_names.append(target)
+            return gen
+
         assert isinstance(example, torch.Tensor), example
         # todo(chilli): We can remove the last check once we turn buffers into
         # static shape tensors. That's a hack to workaround Inductor believing
@@ -1288,7 +1300,7 @@ class GraphLowering(torch.fx.Interpreter):
             if isinstance(value, TorchBindObject):
                 continue
             assert isinstance(
-                value, (TensorBox, sympy.Expr)
+                value, (TensorBox, sympy.Expr, torch._inductor.ir.GeneratorState)
             ), f"Unsupported inductor graph input type: {type(value)}"
             if not isinstance(value, TensorBox):
                 continue
@@ -4977,7 +4977,9 @@ class ExternKernel(InputsKernel):
         tensor_args = []
         non_tensor_args: list[Any] = []
         for arg in args_flat:
-            is_arg_tensor.append(isinstance(arg, IRNode))
+            is_arg_tensor.append(
+                isinstance(arg, IRNode) and not isinstance(arg, GeneratorState)
+            )
             if is_arg_tensor[-1]:
                 tensor_args.append(arg)
             else:
@@ -5008,7 +5010,9 @@ class ExternKernel(InputsKernel):
         # Rerun fake tensor propagation, because Inductor may have changed the
         # strides of inputs and we need to determine accurately what the
         # output stride will be.
-        example_args: list[Union[torch.Tensor, torch._C.ScriptObject]] = []
+        example_args: list[
+            Union[torch.Tensor, torch._C.ScriptObject, torch.Generator]
+        ] = []
 
         # We need to retain the constant values of fake tensors that we originally
         # propagated the graph with, because for some operators running without a
@@ -5025,6 +5029,12 @@ class ExternKernel(InputsKernel):
                 example_args.append(V.graph.torchbind_constants[x.get_name()])
             elif isinstance(x, TorchBindObject):
                 example_args.append(x.get_real_obj())
+            elif isinstance(x, torch._inductor.ir.GeneratorState):
+                device_index = x.device.index
+                assert x.device.type == "cuda" and device_index is not None
+                example_args.append(
+                    torch.cuda.default_generators[device_index].clone_state()
+                )
             else:
                 example_args.append(ir_node_to_tensor(x, guard_shape=True))
 
@@ -5155,7 +5165,7 @@ class ExternKernel(InputsKernel):
             # TODO(jansel): impose layout preference on realized buffer
             x.realize()
             return x
-        if isinstance(x, TorchBindObject):
+        if isinstance(x, (NonTensorObj)):
             return x
         return cls.copy_input(x)
 
@@ -7570,8 +7580,12 @@ class EffectfulKernel(FallbackKernel):
         return True
 
 
+class NonTensorObj(IRNode):
+    pass
+
+
 @ir_dataclass
-class TorchBindObject(IRNode):
+class TorchBindObject(NonTensorObj):
     from torch._library.fake_class_registry import FakeScriptObject
 
     name: str
@@ -7605,6 +7619,18 @@ class TorchBindObject(IRNode):
         return functools.reduce(lambda x, y: x + y, flat_sizes, 0)
 
 
+@ir_dataclass
+class GeneratorState(NonTensorObj):
+    name: str
+    device: torch.device
+
+    def get_name(self):  # type: ignore[no-untyped-def]
+        return self.name
+
+    def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str:
+        return self.name
+
+
 class _CollectiveKernel(FallbackKernel):
     def should_allocate(self) -> bool:
         return False
@@ -2718,6 +2718,8 @@ make_fallback(aten.gcd.default, warn=False)
 make_fallback(aten._thnn_fused_lstm_cell, require_dense)
 make_fallback(torch._prims.rng_prims.run_and_save_rng_state)
 make_fallback(torch._prims.rng_prims.run_with_rng_state)
+make_fallback(torch._prims.rng_prims.graphsafe_run_with_rng_state)
+
 
 # Implmented / Half implemented
 # Scans. Implemented for CUDA, missing CPU
@@ -426,7 +426,7 @@ class CompiledFxGraph(OutputCode):
             (not complex_memory_overlap_inputs, "complex memory overlap"),
             (
                 all(
-                    isinstance(t, (torch.Tensor, torch.SymInt))
+                    isinstance(t, (torch.Tensor, torch.SymInt, torch.Generator))
                     for t in example_inputs
                 ),
                 "non-Tensor inputs",
@@ -2104,6 +2104,7 @@ def count_tangents(fx_g: torch.fx.GraphModule) -> int:
             "tangents" not in x.name
             and "bwd_seed" not in x.name
            and "bwd_base_offset" not in x.name
+            and "bwd_rng_state" not in x.name
         )
 
     arg_count = 0
@@ -315,5 +315,75 @@ run_and_save_rng_state = register_run_and_save_rng_state_op()
 run_with_rng_state = register_run_with_rng_state_op()
 
 
+def register_graphsafe_run_with_rng_state_op():
+    class GraphSafeRunWithRngState(HigherOrderOperator):
+        def __init__(self):
+            super().__init__("graphsafe_run_with_rng_state")
+
+        def __call__(self, op, *args, rng_state=None, **kwargs):
+            return super().__call__(op, *args, rng_state=rng_state, **kwargs)
+
+    graphsafe_run_with_rng_state = GraphSafeRunWithRngState()
+
+    graphsafe_run_with_rng_state.py_impl(DispatchKey.Autograd)(
+        autograd_not_implemented(graphsafe_run_with_rng_state, deferred_error=True)
+    )
+
+    @graphsafe_run_with_rng_state.py_impl(DispatchKey.CUDA)
+    def impl_cuda(op, *args, rng_state=None, **kwargs):
+        device_idx = rng_state.device.index
+        generator = torch.cuda.default_generators[device_idx]
+        current_state = generator.graphsafe_get_state()
+        generator.graphsafe_set_state(rng_state)
+        out = op(*args, **kwargs)
+        generator.graphsafe_set_state(current_state)
+        return out
+
+    @graphsafe_run_with_rng_state.py_impl(DispatchKey.BackendSelect)
+    def impl_backend_select(op, *args, rng_state=None, **kwargs):
+        device = get_device(args, kwargs)
+        assert (
+            device == "cuda"
+        ), f"GraphSafe RNG operations only supported for CUDA, got {device}"
+        return impl_cuda(op, *args, rng_state=rng_state, **kwargs)
+
+    @graphsafe_run_with_rng_state.py_impl(FakeTensorMode)
+    def impl_fake_tensor_mode(mode, op, *args, rng_state=None, **kwargs):
+        with mode:
+            return op(*args, **kwargs)
+
+    @graphsafe_run_with_rng_state.py_impl(ProxyTorchDispatchMode)
+    def impl_proxy_dispatch_mode(mode, op, *args, rng_state=None, **kwargs):
+        with disable_proxy_modes_tracing():
+            out = graphsafe_run_with_rng_state(op, *args, rng_state=rng_state, **kwargs)
+        proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, (op, *args))
+        proxy_kwargs = pytree.tree_map(
+            mode.tracer.unwrap_proxy, {"rng_state": rng_state, **kwargs}
+        )
+        out_proxy = mode.tracer.create_proxy(
+            "call_function", graphsafe_run_with_rng_state, proxy_args, proxy_kwargs
+        )
+        return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
+
+    @graphsafe_run_with_rng_state.py_functionalize_impl
+    def impl_functional(ctx, op, *args, rng_state=None, **kwargs):
+        unwrapped_rng_state = (
+            ctx.unwrap_tensors(rng_state) if rng_state is not None else None
+        )
+        unwrapped_args = ctx.unwrap_tensors(args)
+        unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
+
+        with ctx.redispatch_to_next():
+            out = graphsafe_run_with_rng_state(
+                op, *unwrapped_args, rng_state=unwrapped_rng_state, **unwrapped_kwargs
+            )
+            return ctx.wrap_tensors(out)
+
+    return graphsafe_run_with_rng_state
+
+
+graphsafe_run_with_rng_state = register_graphsafe_run_with_rng_state_op()
+
+
 def register_rng_prims():
     register_philox_rand()
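At its core, `impl_cuda` above swaps the device's default generator onto the provided graph-safe state, runs the wrapped op, and then restores the previous state. A minimal eager sketch of that same set/run/restore dance, assuming a CUDA build where these generator methods are available (the op and shape are arbitrary examples):

```python
import torch

gen = torch.cuda.default_generators[0]

rng_state = gen.clone_state()       # an independent, graph-safe RNG state
prev = gen.graphsafe_get_state()    # remember the currently attached state
gen.graphsafe_set_state(rng_state)  # run the op under the provided state
out = torch.ops.aten.rand.default([4, 4], device="cuda")
gen.graphsafe_set_state(prev)       # restore the original state
```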
@@ -819,6 +819,9 @@ def _iterate_exprs(val: IterateExprs) -> Iterator[sympy.Basic]:
         yield from _iterate_exprs(val.storage_offset())
     elif val is None:
         pass
+    # see Note: [Generator arguments in AOTDispatcher]
+    elif isinstance(val, torch.Generator):
+        pass
     else:
         raise AssertionError(f"cannot extract sympy expressions from {val} {type(val)}")
 
@@ -77,6 +77,7 @@ FIXME_hop_that_doesnt_have_opinfo_test_allowlist = [
     "autograd_function_apply",
     "run_and_save_rng_state",
     "run_with_rng_state",
+    "graphsafe_run_with_rng_state",
     "out_dtype",
     "trace_wrapped",
     'tag_activation_checkpoint',