Mirror of https://github.com/pytorch/pytorch.git
Revert "Support torch.compile rng selective activation checkpointing with cudagraph (#146878)"
This reverts commit ad0c879e2203145f6d56df0b95af36822220ab8f. Reverted https://github.com/pytorch/pytorch/pull/146878 on behalf of https://github.com/wdvr due to lint failure ([comment](https://github.com/pytorch/pytorch/pull/146878#issuecomment-2686767956))
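For context, a minimal sketch (not part of the patch, adapted from the tests it adds) of the kind of program the reverted change targets: a checkpointed region containing a random op, compiled with cudagraphs ("reduce-overhead"). With the reverted PR in place, the recomputed rand_like in the backward pass is meant to reuse the forward's RNG state in a CUDA Graph-safe way.

import torch

def gn(x, y):
    # random op inside the checkpointed (recomputed) region
    return torch.sigmoid(torch.rand_like(x) * y) * x

def fn(x, y):
    x = torch.sin(x)
    x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
    return torch.sin(x)

if torch.cuda.is_available():
    compiled = torch.compile(fn, mode="reduce-overhead")
    x = torch.randn(4, 4, device="cuda", requires_grad=True)
    y = torch.randn(4, 4, device="cuda", requires_grad=True)
    compiled(x, y).sum().backward()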
@@ -54,8 +54,6 @@ def count_ops(
return node.args[0] == op
elif node.name == "run_with_rng_state":
return node.args[1] == op
elif node.name == "graphsafe_run_with_rng_state":
return node.args[0] == op
return False

# assert ((freq or freq_ge) and op) or ((freqs or freqs_ge) and ops)
@@ -1020,7 +1018,6 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no

@requires_cuda
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@torch._inductor.config.patch(fallback_random=True)
def test_compile_selective_checkpoint_random_op(self, device):
for preserve_rng_state in [True, False]:
@@ -4,17 +4,13 @@ import contextlib
import functools
import gc
import importlib
import itertools
import sys
import unittest
import warnings
from collections import defaultdict
from typing import Mapping, Sequence

import torch
import torch._dynamo.config as dynamo_config
import torch.nn as nn
from torch._dynamo.backends.debugging import aot_eager_decomp_partition_with_mode
from torch._dynamo.utils import counters
from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
from torch._inductor import config
@@ -23,9 +19,7 @@ from torch._inductor.compile_fx import compile_fx_inner
from torch._inductor.cudagraph_trees import cudagraphify_impl as tree_cudagraphify_impl
from torch._inductor.cudagraph_utils import FunctionID
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._ops import OpOverload
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.immutable_collections import immutable_dict
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_utils import (
@@ -38,7 +32,6 @@ from torch.testing._internal.common_utils import (
TEST_CUDA_GRAPH,
TEST_WITH_ASAN,
)
from torch.utils._mode_utils import no_dispatch
from torch.utils._python_dispatch import TorchDispatchMode


@@ -53,7 +46,7 @@ if IS_WINDOWS and IS_CI:
importlib.import_module("functorch")
importlib.import_module("filelock")

from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA


aten = torch.ops.aten
@@ -2511,407 +2504,7 @@ if HAS_CUDA and not TEST_WITH_ASAN:
eager_result = f(example_input)
self.assertEqual(compiled_result, eager_result)

class TestSAC(TestCase):
def _make_observer_mode(self):
class ObserverMode(TorchDispatchMode):
def __init__(self):
super().__init__()
self.curr_run = 0
self.op_outputs = defaultdict(list)

def __torch_dispatch__(
self,
func: OpOverload,
types: Sequence[type],
args: Sequence[object] = (),
kwargs: Mapping[str, object] = immutable_dict(),
) -> object:
return func(*args, **kwargs)

return ObserverMode

def test_simple(self):
device = "cuda"

from torch._prims.rng_prims import graphsafe_run_with_rng_state

ObserverMode = self._make_observer_mode()

@graphsafe_run_with_rng_state.py_impl(ObserverMode)
def _(mode, op, *args, **kwargs):
with no_dispatch():
out = graphsafe_run_with_rng_state(op, *args, **kwargs)

mode.op_outputs[op].append(out)
return out

obs = ObserverMode()

x = torch.randn(4, 4, device=device, requires_grad=True)
y = torch.randn(4, 4, device=device, requires_grad=True)

for _ in range(2):
torch._dynamo.reset()

def gn(x, y):
return torch.sigmoid(torch.rand_like(x) * y) * x

def fn(x, y):
x = torch.sin(x)
x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
x = torch.sin(x)
return x

aot_eager_decomp_partition = functools.partial(
aot_eager_decomp_partition_with_mode, mode=obs
)

fn = torch.compile(fn, backend=aot_eager_decomp_partition)

fn(x, y).sum().backward()

self.assertEqual(len(obs.op_outputs[aten.rand.default]), 4)
for i in range(2):
self.assertEqual(
obs.op_outputs[aten.rand.default][0 + 2 * i],
obs.op_outputs[aten.rand.default][1 + 2 * i],
)
self.assertNotEqual(
obs.op_outputs[aten.rand.default][0],
obs.op_outputs[aten.rand.default][2],
)

def test_cudagraph_uneven_forward_backward(self):
# torch.compile cudagraphs are difficult to test
# the rng updating bc is sensitive to duration of pending backwards, etc.
# this is a short repro to mimic the runtime wrappers integration
# and show that updating the backward rng state with cudagraphs works:
def forward():
state = torch.cuda.get_rng_state()
perm = torch.randperm(10, device="cuda")
return state, perm

def backward(rng_state):
current_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state.cpu())
perm = torch.randperm(10, device="cuda")
torch.cuda.set_rng_state(current_state)
return perm

def normal_test():
state, perm = forward()
repro_perm = backward(state)
return perm, repro_perm

def graphsafe_forward():
perm = torch.randperm(10, device="cuda")
return perm

def graphsafe_backward(generator, new_state):
current_state = generator.graphsafe_get_state()
generator.graphsafe_set_state(new_state)
perm = torch.randperm(10, device="cuda")
generator.graphsafe_set_state(current_state)
return perm

def graph_test(generator, capture_cuda_graph):
if capture_cuda_graph:
graph = torch.cuda.CUDAGraph()

# state should be cloned before the graph
old_state = generator.graphsafe_get_state()
new_state = old_state.clone_state()

if capture_cuda_graph:
# state should be register to the graph
graph.register_generator_state(new_state)

# only capturing the backward
with torch.cuda.graph(graph):
repro_perm = graphsafe_backward(generator, new_state)

# some number of uneven forwards
graphsafe_forward()
graphsafe_forward()
graphsafe_forward()

# state prior to rng invocation
state = generator.get_state()
perm = graphsafe_forward()

new_state.set_state(state)

if capture_cuda_graph:
graph.replay()
else:
repro_perm = graphsafe_backward(generator, new_state)

return perm, repro_perm

self.assertEqual(*normal_test())
generator = torch.cuda.default_generators[0]
self.assertEqual(*graph_test(generator, capture_cuda_graph=False))
self.assertEqual(*graph_test(generator, capture_cuda_graph=True))

def test_cpu_and_cuda_rng(self):
device = "cuda"

ObserverMode = self._make_observer_mode()
from torch._prims.rng_prims import (
graphsafe_run_with_rng_state,
run_and_save_rng_state,
run_with_rng_state,
)

for hop in [
graphsafe_run_with_rng_state,
run_and_save_rng_state,
run_with_rng_state,
]:

def make_impl(hop):
@hop.py_impl(ObserverMode)
def _(mode, *args, **kwargs):
with no_dispatch():
out = hop(*args, **kwargs)

op = None
for inp in itertools.chain(args, kwargs.values()):
if isinstance(inp, torch._ops.OpOverload):
op = inp
break
assert op is not None
if hop is run_and_save_rng_state:
mode.op_outputs[op].append(out[1])
else:
mode.op_outputs[op].append(out)
return out

make_impl(hop)

obs = ObserverMode()

def gn(x, y):
return torch.sigmoid(torch.rand_like(x) * y) * x

def gn2(x):
return x * torch.randperm(x.numel(), device=x.device).reshape(x.shape)

def fn(x, y, z):
x = torch.sin(x)
x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
x = torch.sin(x)
z = torch.utils.checkpoint.checkpoint(gn2, z, use_reentrant=True)
return x * z.cuda()

aot_eager_decomp_partition = functools.partial(
aot_eager_decomp_partition_with_mode, mode=obs
)

fn = torch.compile(fn, backend=aot_eager_decomp_partition)

x = torch.randn(4, 4, device=device, requires_grad=True)
y = torch.randn(4, 4, device=device, requires_grad=True)
z = torch.randn(4, 4, requires_grad=True)

fn(x, y, z).sum().backward()
for op in [aten.rand.default, aten.randperm.default]:
self.assertEqual(len(obs.op_outputs[op]), 2)
self.assertEqual(
obs.op_outputs[op][0],
obs.op_outputs[op][1],
)
self.assertEqual(
obs.op_outputs[op][0].device.type,
"cpu" if op == aten.randperm.default else "cuda",
)

@parametrize("order", (list(itertools.permutations([0, 1, 2]))))
def test_uneven_forward_backward(self, order):
device = "cuda"

ObserverMode = self._make_observer_mode()
from torch._prims.rng_prims import graphsafe_run_with_rng_state

@graphsafe_run_with_rng_state.py_impl(ObserverMode)
def _(mode, op, *args, **kwargs):
with no_dispatch():
out = graphsafe_run_with_rng_state(op, *args, **kwargs)

mode.op_outputs[(mode.curr_run, op)].append(out)
return out

obs = ObserverMode()

def gn(x, y):
return torch.sigmoid(torch.rand_like(x) * y) * x

def gn2(x):
return x * torch.randperm(x.numel(), device=x.device).reshape(x.shape)

def fn(x, y):
x = torch.sin(x)
x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
x = torch.sin(x)
x = torch.utils.checkpoint.checkpoint(gn2, x, use_reentrant=True)
return x

aot_eager_decomp_partition = functools.partial(
aot_eager_decomp_partition_with_mode, mode=obs
)

fn_c = torch.compile(fn, backend=aot_eager_decomp_partition)

torch.manual_seed(0)
outs = []
for i in range(len(order)):
obs.curr_run = i
x = torch.randn(4, 4, device=device, requires_grad=True)
y = torch.randn(4, 4, device=device, requires_grad=True)
outs.append(fn_c(x, y))

for idx in order:
obs.curr_run = idx
outs[idx].sum().backward()

for run in range(len(order)):
for op in (aten.rand.default, aten.randperm.default):
self.assertEqual(len(obs.op_outputs[(run, op)]), 2)
self.assertEqual(
obs.op_outputs[(run, op)][0],
obs.op_outputs[(run, op)][1],
)
if run != 0:
self.assertNotEqual(
obs.op_outputs[(run - 1, op)][0],
obs.op_outputs[(run, op)][0],
)

@config.patch(fallback_random=True)
@config.patch("test_configs.graphsafe_rng_func_ignores_fallback_random", True)
def _test_cudagraphs_aot_eager_compat_equal(self, device):
def gn(x, y):
return torch.sigmoid(torch.rand_like(x) * y) * x

def fn(x, y):
x = torch.sin(x)
x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
x = torch.sin(x)
return x

outs = []
grads = []

outs2 = []
grads2 = []

compile_fns = [
lambda fn: torch.compile(fn, backend="aot_eager_decomp_partition"),
lambda fn: torch.compile(fn, mode="reduce-overhead"),
]
for i, compile_fn in enumerate(compile_fns):
torch.manual_seed(0)
for index in range(3):
x = torch.randn(4, 4, device=device, requires_grad=True)
y = torch.randn(4, 4, device=device, requires_grad=True)

out = compile_fn(fn)(x, y)
torch.cuda.synchronize()
out.sum().backward()
if i == 0:
outs.append(out.clone())
grads.append((x.grad.clone(), y.grad.clone()))
else:
outs2.append(out.clone())
grads2.append((x.grad.clone(), y.grad.clone()))

self.assertEqual(outs, outs2)
self.assertEqual(grads, grads2)
self.assertEqual(counters["inductor"]["cudagraph_skips"], 0)

def test_cudagraphs_aot_eager_compat_equal(self):
self._test_cudagraphs_aot_eager_compat_equal(torch.device("cuda:0"))

@requires_multigpu()
def test_cudagraphs_aot_eager_compat_equal_device_one(self):
self._test_cudagraphs_aot_eager_compat_equal(torch.device("cuda:1"))

@requires_multigpu()
def test_multi_device(self):
def gn(x, y):
return torch.sigmoid(torch.rand_like(x) * y) * x

def fn(x, y):
x = torch.sin(x)
x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
x = torch.sin(x)
return x

def multi_fn(x, y, a, b):
return fn(x, y), fn(a, b)

x = torch.randn(4, 4, device="cuda:0", requires_grad=True)
y = torch.randn(4, 4, device="cuda:0", requires_grad=True)

a = torch.randn(4, 4, device="cuda:1", requires_grad=True)
b = torch.randn(4, 4, device="cuda:1", requires_grad=True)

# No errors. TODO - get graphs from logging, couldnt figure out how
multi_fn_c = torch.compile(multi_fn, backend="aot_eager_decomp_partition")

out = multi_fn_c(x, y, a, b)
out[0].sum().backward()

def test_retain_graph(self):
device = "cuda"

ObserverMode = self._make_observer_mode()
from torch._prims.rng_prims import graphsafe_run_with_rng_state

@graphsafe_run_with_rng_state.py_impl(ObserverMode)
def _(mode, op, *args, **kwargs):
with no_dispatch():
out = graphsafe_run_with_rng_state(op, *args, **kwargs)

mode.op_outputs[op].append(out)
return out

obs = ObserverMode()

def gn(x, y):
return torch.sigmoid(torch.rand_like(x) * y) * x

def fn(x, y):
x = torch.sin(x)
x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
x = torch.sin(x)
return x

x = torch.randn(4, 4, device=device, requires_grad=True)
y = torch.randn(4, 4, device=device, requires_grad=True)

aot_eager_decomp_partition = functools.partial(
aot_eager_decomp_partition_with_mode, mode=obs
)

fn = torch.compile(fn, backend=aot_eager_decomp_partition)

out = fn(x, y).sum()
out.backward(retain_graph=True)
out.backward()
self.assertEqual(len(obs.op_outputs[aten.rand.default]), 3)
self.assertEqual(
obs.op_outputs[aten.rand.default][0],
obs.op_outputs[aten.rand.default][1],
)
self.assertEqual(
obs.op_outputs[aten.rand.default][1],
obs.op_outputs[aten.rand.default][2],
)

instantiate_parametrized_tests(CudaGraphTreeTests)
instantiate_parametrized_tests(TestSAC)


if __name__ == "__main__":
from torch._inductor.test_case import run_tests
@@ -2921,5 +2514,5 @@ if __name__ == "__main__":
sys.exit(0)
raise unittest.SkipTest("cuda graph test is skipped")

if HAS_CUDA:
if HAS_CPU or HAS_CUDA:
run_tests(needs="filelock")
@@ -62,7 +62,6 @@ class TestHOPInfra(TestCase):
FIXME_ALLOWLIST = {
"autograd_function_apply",
"run_with_rng_state",
"graphsafe_run_with_rng_state",
"map_impl",
"_export_tracepoint",
"run_and_save_rng_state",
@@ -66,7 +66,6 @@ from .subclass_utils import compute_inner_mutated_inp_indices_from_subclass_meta
from .utils import (
_get_symint_hints,
contain_metadata_mutation_ops,
get_cuda_generator_meta_val,
make_boxed_func,
strict_zip,
unlift_tokens,
@@ -449,20 +448,9 @@ def aot_dispatch_autograd(
fake_mode = detect_fake_mode()
if fake_mode is not None and fake_mode.shape_env is not None:
tensorify_python_scalars(fx_g, fake_mode.shape_env, fake_mode)

fw_module, bw_module = aot_config.partition_fn(
fx_g, joint_inputs, num_fwd_outputs=num_inner_fwd_outputs
)
rng_states = [
n
for n in fw_module.graph.find_nodes(op="placeholder")
if "fwd_rng_state" in n.name
]
fw_metadata.num_graphsafe_rng_states = len(rng_states)
if rng_states:
fw_metadata.graphsafe_rng_state_index = (
rng_states[0].meta["val"].device.index
)

# See Note [Side-Effectful Tokens in AOTAutograd]
if config.unlift_effect_tokens and (
@@ -678,16 +666,6 @@ def aot_dispatch_autograd(
functionalized_rng_wrapper = FunctionalizedRngRuntimeWrapper(
return_new_outs=False
)

if rng_states:
index = fw_metadata.graphsafe_rng_state_index
assert index is not None
rng_states = [
get_cuda_generator_meta_val(index)
for _ in range(fw_metadata.num_graphsafe_rng_states)
]
adjusted_flat_args.extend(rng_states) # type: ignore[arg-type]

(
fw_module,
adjusted_flat_args,
@@ -1684,45 +1684,6 @@ def _backward_prologue_functional(
return all_args


def initialize_rng_states(
num_rng: int,
graphsafe_idx: int,
fwd_rng_states: list[torch.Generator],
bwd_rng_states: list[torch.Generator],
):
"""
Initialize the cudagraph safe rng states.

Initialization of rng states should have a few properties:
- the initialization for each rng state should be independent
- the initialization should be deterministic
- the initialization should be based off current rng state, so that independent graphs do not
have equal rng behavior

We defer initialization of rng states until runtime because compilation is wrapped
with preserve_rng_states. Seed initialization should advance the rng states so consecutive compilations
do not give equal randomness.
"""
with torch.utils._python_dispatch._disable_current_modes():
seeds = torch.randint(0, torch.iinfo(torch.int64).max, (num_rng,), device="cpu")
fwd_rng_states.extend(
[
torch.cuda.default_generators[graphsafe_idx]
.clone_state()
.manual_seed(int(seeds[i]))
for i in range(num_rng)
]
)
bwd_rng_states.extend(
[
torch.cuda.default_generators[graphsafe_idx]
.clone_state()
.manual_seed(int(seeds[i]))
for i in range(num_rng)
]
)
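As an aside (not part of the patch), a minimal sketch of the invariant this initialization establishes: two states cloned from the same CUDA default generator and given the same seed produce identical random streams, so a forward/backward pair that advances them equally stays synchronized.

import torch

if torch.cuda.is_available():
    gen = torch.cuda.default_generators[0]
    fwd_state = gen.clone_state().manual_seed(123)  # stands in for one fwd_rng_state
    bwd_state = gen.clone_state().manual_seed(123)  # stands in for the paired bwd_rng_state
    a = torch.rand(4, device="cuda", generator=fwd_state)
    b = torch.rand(4, device="cuda", generator=bwd_state)
    assert torch.equal(a, b)  # streams stay equal as long as both advance together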

# NOTE: this function must be torch._dynamo.allow_in_graph-able. Non tensor/symnode inputs must be constants.
def _backward_epilogue_functional(
metadata, maybe_subclass_metadata, out, *, make_subclass_override=None
@@ -1858,34 +1819,6 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
fw_metadata: ViewAndMutationMeta, # runtime metadata
try_save_cache_entry: Optional[Callable], # Save cache entry after compilation
):
# For additional context see Note [CUDA Graph Safe RNG Functionalization]
# Each pair forward, backward rng states must be equal prior to its invocation on any
# iteration of forward, backward. Because they are initialized equal, and are computing the same rng op,
# running forward then backward advances them the same amount and keeps them equal.
# However, a user may invoke multiple forwards, then backwards, such that they are not in sync.
# Initially we have:
# fwd_state0 == bwd_state0.
# Lets say we run:
# fwd0: fwd_state0 -> fwd_state1
# fwd1: fwd_state1 -> fwd_state2
# fwd2: fwd_state2 -> fwd_state3
# If we now invoke bwd2,
# we need to update bwd_state equal to the rng that was observed in fwd2.
# we save the rng_state fwd_state2 in forward because we detect that it is not the
# current backward state and therefore would not be accessible if we do not save it.
# Similarly, if we are going to update the backward state to a new value, and there is a pending
# forwards which needs its current state, we will save it.
# Within the autograd context, we keep track of the curr iteration so that on backward
# we know what the generator state must be before the backward is run.
num_rng = fw_metadata.num_graphsafe_rng_states
graphsafe_idx = fw_metadata.graphsafe_rng_state_index
fwd_rng_states: list[torch.Generator] = []
bwd_rng_states: list[torch.Generator] = []
curr_fwd_iter = itertools.count(0)
backward_state_position = 0
pending_forwards: set[int] = set()
saved_backward_tensor_states: dict[int, list[torch.Tensor]] = {}

class CompiledFunction(torch.autograd.Function):
compiled_fw = compiled_fw_func
compiled_bw = compiled_bw_func
@@ -1907,26 +1840,6 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
assert isinstance(bw_state, BackwardState)
ctx._compiled_autograd_backward_state = bw_state

if num_rng:
if len(fwd_rng_states) == 0:
assert graphsafe_idx is not None
initialize_rng_states(
num_rng, graphsafe_idx, fwd_rng_states, bwd_rng_states
)

_curr_iter = next(curr_fwd_iter)
ctx._curr_iter = _curr_iter

# if this state is not contained in the backward,
# we need to save it for when its backward pass happens
if _curr_iter != backward_state_position:
saved_backward_tensor_states[_curr_iter] = [
rng_state.get_state() for rng_state in fwd_rng_states
]

pending_forwards.add(_curr_iter)
args = (*args, *fwd_rng_states)

# There is a pretty complicated calling convention around what the compiled fw returns.
# The full list of outputs and their relative order is:
# (*tokens, *mutated_inputs, *fw_outs, *fw_intermediate_bases, *saved_tensors, *saved_symints)
@@ -2054,45 +1967,6 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
*flat_args,
)

if num_rng:
nonlocal backward_state_position, bwd_rng_states
curr_backward_iter = ctx._curr_iter
retain_graph = (
torch._C._autograd._get_current_graph_task_keep_graph()
)

# Save current state if we have a pending forward that needs this state
# or this state may be needed again because of retain graph
if (
backward_state_position in pending_forwards
and backward_state_position not in saved_backward_tensor_states
and (
backward_state_position != curr_backward_iter
or retain_graph
)
):
saved_backward_tensor_states[backward_state_position] = [
rng_state.get_state() for rng_state in bwd_rng_states
]

# Restore saved states if needed
if curr_backward_iter in saved_backward_tensor_states:
if backward_state_position != curr_backward_iter:
for bwd_state, saved_state in zip(
bwd_rng_states,
saved_backward_tensor_states[curr_backward_iter],
):
bwd_state.set_state(saved_state)
if not retain_graph:
del saved_backward_tensor_states[curr_backward_iter]
else:
assert backward_state_position == curr_backward_iter

backward_state_position = curr_backward_iter + 1
if not retain_graph:
pending_forwards.remove(curr_backward_iter)
all_args.extend(bwd_rng_states)

def impl_fn(double_ctx=None):
out = CompiledFunction._backward_impl(ctx, all_args)
return _backward_epilogue_functional(
@@ -420,16 +420,11 @@ class ViewAndMutationMeta:
# Filled after tracing joint function.
num_backward_tokens: int = 0

# Number of rng states that will get thread into the forward and backward for
# cudagraph compatible run_and_save_rng
num_graphsafe_rng_states: int = 0

graphsafe_rng_state_index: Optional[int] = None

def __post_init__(self):
# pre-compute the indices of the inputs that are mutated.
# When keep_input_mutations is set, we don't need to worry about our epilogue
# handling data-only mutations, because we keep them directly in the graph.

mutated_inp_runtime_indices = [
i
for i, m in enumerate(self.input_info)

@@ -489,14 +489,3 @@ def contain_metadata_mutation_ops(module: torch.fx.GraphModule) -> bool:
):
return True
return False


def get_cuda_generator_meta_val(device_idx: int):
"""
Get a generator value to use as a meta val

newly cloned generator will not contain tensors. it is only Generators that are
registered to a CUDAGraph that contain tensors. since this does not contain Tensor
it is fine to use in the meta.
"""
return torch.cuda.default_generators[device_idx].clone_state()
@@ -210,10 +210,6 @@ torch_compile_graph_format = os.environ.get("TORCH_COMPILE_GRAPH_FORMAT", "svg")
# real tensor outputs.
generate_fake_kernels_from_real_mismatches = False

# CUDAGraph save run_with_rng functionalization.
# TODO: turn on by default
graphsafe_rng_functionalization = True


# Error on BypassAOTAutogradCache instead of just a warning
# Used for tests
@@ -40,7 +40,7 @@ from ._activation_checkpointing.knapsack import (
)
from ._activation_checkpointing.knapsack_evaluator import KnapsackEvaluator
from ._aot_autograd.logging_utils import get_aot_graph_name
from ._aot_autograd.utils import get_cuda_generator_meta_val, is_with_effects
from ._aot_autograd.utils import is_with_effects
from .compile_utils import fx_graph_cse, get_aten_target


@@ -622,99 +622,6 @@ def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule:
return new_gm


def apply_graphsafe_rng_functionalization(
fw_module: torch.fx.GraphModule,
bw_module: torch.fx.GraphModule,
fw_node: torch.fx.Node,
bw_node: torch.fx.Node,
device: torch.device,
rng_count: int,
last_fwd_input: torch.fx.Node,
last_bwd_input: torch.fx.Node,
):
"""
Note [CUDA Graph Safe RNG Functionalization]

CUDA Graph capture doesn't work with get_rng_state and set_rng_state because these functions operate on CPU values,
while CUDA Graph RNG capture uses on-device CUDA tensors. To solve this, we use graphsafe_set_state with a
CUDA Generator registered to the CUDA Graph before capture begins. graphsafe_set_state updates the generator's pointer
to reference a different GeneratorImpl, ensuring subsequent calls are correctly forwarded to the desired generator
(and its cuda-tensor RNG state during graph capture).

For each RNG operation's forward/backward pair:

- We create two generators initialized with identical values
- Each forward and backward call advances its respective generator equally
- This keeps generators synchronized so forward and backward operations use matching RNG values

When forward is called multiple times before backward (causing desynchronization):

- We save the forward RNG state
- We update the backward Generator's state before executing backward

Before each CUDA Graph replay, replay_prologue updates captured RNG pointers with current states, ensuring backward Generator
changes are reflected during replay.

This function modifies both forward and backward computation graphs by:

Creating RNG state placeholders for both passes
Updating the forward node to use graph-safe RNG state
Updating the backward node to use graph-safe RNG state

For more details: https://github.com/pytorch/pytorch/issues/113541
"""
device_idx = device.index
assert device_idx is not None
fw_graph = fw_module.graph
bw_graph = bw_module.graph
graphsafe_run_with_rng_state = torch._prims.rng_prims.graphsafe_run_with_rng_state

# Handle forward pass

# Note: [Generator arguments in AOTDispatcher]
# Generator arguments in AOTDispatcher are added to support graphsafe rng
# functionalization. See note above [CUDA Graph Safe RNG Functionalization]
with fw_module.graph.inserting_after(last_fwd_input):
fwd_rng_state = fw_module.graph.placeholder(f"fwd_rng_state_{rng_count}")
fwd_rng_state.meta["val"] = get_cuda_generator_meta_val(device_idx)
last_fwd_input = fwd_rng_state

# Handle backward pass
with bw_module.graph.inserting_after(last_bwd_input):
bwd_rng_state = bw_module.graph.placeholder(f"bwd_rng_state_{rng_count}")
# as above, clone so that meta val generator will not contain tensors
bwd_rng_state.meta["val"] = get_cuda_generator_meta_val(device_idx)
last_bwd_input = bwd_rng_state

# Update forward node
fw_kwargs = dict(fw_node.kwargs)
fw_kwargs["rng_state"] = fwd_rng_state
with fw_module.graph.inserting_after(fw_node):
functional_fw_node = fw_graph.create_node(
"call_function",
graphsafe_run_with_rng_state,
args=(fw_node.target, *fw_node.args), # type: ignore[arg-type]
kwargs=fw_kwargs,
)
fw_node.replace_all_uses_with(functional_fw_node)
fw_graph.erase_node(fw_node)

# Update backward node
bwd_kwargs = dict(bw_node.kwargs)
bwd_kwargs["rng_state"] = bwd_rng_state
with bw_graph.inserting_before(bw_node):
rng_output = bw_graph.create_node(
"call_function",
graphsafe_run_with_rng_state,
args=(bw_node.target, *bw_node.args), # type: ignore[arg-type]
kwargs=bwd_kwargs,
)
bw_node.replace_all_uses_with(rng_output)
bw_graph.erase_node(bw_node)

return last_fwd_input, last_bwd_input
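For intuition (an illustrative reconstruction, not part of the patch), the rewrite above roughly transforms one recomputable random call as follows; the placeholder names mirror the ones the function creates.

# Hypothetical before/after of a forward node and its backward recomputation:
#   before (fwd graph): rand = torch.ops.aten.rand.default([4, 4], device=dev)
#   after  (fwd graph): rand = graphsafe_run_with_rng_state(
#                           torch.ops.aten.rand.default, [4, 4], device=dev,
#                           rng_state=fwd_rng_state_0,
#                       )
# The matching node in the backward graph is wrapped the same way but against the
# bwd_rng_state_0 placeholder, so recomputation replays the forward's randomness.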


def functionalize_rng_ops(
joint_module: fx.GraphModule,
fw_module: fx.GraphModule,
@@ -753,7 +660,7 @@ def functionalize_rng_ops(
random_nodes[node.name] = node
return random_nodes

def get_device(node) -> Optional[torch.device]:
def get_device(node):
"""
Check the example value of the node outputs to find the device type.
"""
@@ -767,12 +674,12 @@ def functionalize_rng_ops(
for candidate in candidates:
if isinstance(candidate, torch.Tensor):
if candidate.device.type == "cuda":
return candidate.device
return "cuda"

return torch.device("cpu")
return "cpu"

def get_sample_rng_state(device: Optional[torch.device]):
if device is not None and device.type == "cuda":
def get_sample_rng_state(device):
if device == "cuda":
return torch.cuda.get_rng_state()
return torch.get_rng_state()

@@ -794,7 +701,6 @@ def functionalize_rng_ops(

run_and_save_rng = torch._prims.rng_prims.run_and_save_rng_state
run_with_rng_state = torch._prims.rng_prims.run_with_rng_state

bw_tangent_start_node = None
for node in bw_module.graph.find_nodes(op="placeholder"):
if "tangent" in node.name:
@@ -806,113 +712,68 @@ def functionalize_rng_ops(
)

fw_rng_state_outputs = []

last_fwd_input = next(reversed(fw_module.graph.find_nodes(op="placeholder")))
last_bwd_input = next(reversed(bw_module.graph.find_nodes(op="placeholder")))

devices = OrderedSet(
get_device(node_pair["fwd"]) for node_pair in recomputable_rng_ops_map.values()
)
devices.discard(torch.device("cpu"))
# multiple cuda devices wont work with cudagraphs anyway,
# fallback to non graphsafe rng checkpointing
multi_cuda_devices = len(devices) > 1

# this changes numerics, so if fallback_random is set we will not use it
ind_config = torch._inductor.config
use_rng_graphsafe_rng_functionalization = (
config.graphsafe_rng_functionalization
and not multi_cuda_devices
and (
not ind_config.fallback_random
or ind_config.test_configs.graphsafe_rng_func_ignores_fallback_random
)
)

for rng_count, (base_node, node_pair) in enumerate(
recomputable_rng_ops_map.items()
):
for base_node, node_pair in recomputable_rng_ops_map.items():
# Step 2 - Modify the fwd pass such that
fw_node = node_pair["fwd"]
bw_node = node_pair["bwd"]
device = get_device(fw_node)

fw_graph = fw_module.graph
bw_graph = bw_module.graph

if (
use_rng_graphsafe_rng_functionalization
and device is not None
and device.type == "cuda"
):
last_fwd_input, last_bwd_input = apply_graphsafe_rng_functionalization(
fw_module,
bw_module,
fw_node,
bw_node,
device,
rng_count,
last_fwd_input,
last_bwd_input,
with fw_graph.inserting_before(fw_node):
functional_fw_node = fw_graph.create_node(
"call_function",
run_and_save_rng,
args=(fw_node.target, *fw_node.args),
kwargs=fw_node.kwargs,
)
else:
with fw_graph.inserting_before(fw_node):
functional_fw_node = fw_graph.create_node(
"call_function",
run_and_save_rng,
args=(fw_node.target, *fw_node.args),
kwargs=fw_node.kwargs,
)
state = fw_graph.create_node(
"call_function",
operator.getitem,
args=(functional_fw_node, 0),
kwargs={},
)
rng_output = fw_graph.create_node(
"call_function",
operator.getitem,
args=(
functional_fw_node,
1,
),
kwargs={},
)
fw_node.replace_all_uses_with(rng_output)
fw_graph.erase_node(fw_node)
fw_rng_state_outputs.append(state)
state = fw_graph.create_node(
"call_function",
operator.getitem,
args=(functional_fw_node, 0),
kwargs={},
)
rng_output = fw_graph.create_node(
"call_function",
operator.getitem,
args=(
functional_fw_node,
1,
),
kwargs={},
)
fw_node.replace_all_uses_with(rng_output)
fw_graph.erase_node(fw_node)
fw_rng_state_outputs.append(state)

# Step 3 - Modify the bwd pass such that
with bw_graph.inserting_before(bw_tangent_start_node):
state_name = f"rng_state_output_{next(uid)}"
bw_rng_state_node = bw_graph.placeholder(state_name)
bw_rng_state_node.meta["val"] = get_sample_rng_state(device)
# Step 3 - Modify the bwd pass such that
bw_graph = bw_module.graph
with bw_graph.inserting_before(bw_tangent_start_node):
state_name = f"rng_state_output_{next(uid)}"
bw_rng_state_node = bw_graph.placeholder(state_name)
bw_rng_state_node.meta["val"] = get_sample_rng_state(get_device(fw_node))

with bw_graph.inserting_before(bw_node):
rng_output = bw_graph.create_node(
"call_function",
run_with_rng_state,
args=(bw_rng_state_node, bw_node.target, *bw_node.args),
kwargs=bw_node.kwargs,
)
with bw_graph.inserting_before(bw_node):
rng_output = bw_graph.create_node(
"call_function",
run_with_rng_state,
args=(bw_rng_state_node, bw_node.target, *bw_node.args),
kwargs=bw_node.kwargs,
)

bw_node.replace_all_uses_with(rng_output)
bw_graph.erase_node(bw_node)
bw_node.replace_all_uses_with(rng_output)
bw_graph.erase_node(bw_node)

# Add the rng states in the output of the fwd graph. AOT Autograd assumes
# that symints are at the end of forward graph outputs. So, insert the new
# rng states accordingly.
if fw_rng_state_outputs:
fw_output_node = next(iter(fw_module.graph.find_nodes(op="output")))
fw_outputs = fw_output_node.args[0]
sym_node_start_idx = len(fw_outputs) - num_sym_nodes
outputs = (
fw_outputs[:sym_node_start_idx]
+ tuple(fw_rng_state_outputs)
+ fw_outputs[sym_node_start_idx:]
)
fw_module.graph.output(outputs)
fw_module.graph.erase_node(fw_output_node)
fw_output_node = next(iter(fw_module.graph.find_nodes(op="output")))
fw_outputs = fw_output_node.args[0]
sym_node_start_idx = len(fw_outputs) - num_sym_nodes
outputs = (
fw_outputs[:sym_node_start_idx]
+ tuple(fw_rng_state_outputs)
+ fw_outputs[sym_node_start_idx:]
)
fw_module.graph.output(outputs)
fw_module.graph.erase_node(fw_output_node)
fw_module.recompile()
bw_module.recompile()
return fw_module, bw_module
@@ -1988,6 +1849,7 @@ def min_cut_rematerialization_partition(
saved_sym_nodes=saved_sym_nodes,
num_fwd_outputs=num_fwd_outputs,
)

if graph_has_recomputable_ops:
if graph_has_recomputable_rng_ops:
fw_module, bw_module = functionalize_rng_ops(
@@ -884,9 +884,6 @@ class PythonWrapperCodegen(CodeGen):
if isinstance(buf, (sympy.Expr, ir.TorchBindObject)):
continue

if isinstance(buf, ir.GeneratorState):
continue

# comparing strides for 0 size tensor is tricky. Ignore them for now.
if sympy_product(buf.get_size()) == 0:
continue
@@ -1349,9 +1346,7 @@ class PythonWrapperCodegen(CodeGen):
code.writeline(f"{stride} = {strideof(name)}[{dim}]")
bound_vars.add(stride)
elif isinstance(value, ir.TorchBindObject):
return
elif isinstance(value, ir.GeneratorState):
return
pass
else:
raise AssertionError(f"Unknown value type: {type(value)}")

@@ -1547,11 +1542,6 @@ class PythonWrapperCodegen(CodeGen):
# is actually a valid value for the kernel in question.
# See https://github.com/pytorch/pytorch/issues/124686
add_expr_input(name, V.graph.sizevars.size_hint(value, fallback=42))
elif isinstance(value, ir.GeneratorState):
add_expr_input(
name,
f"torch.cuda.default_generators[{value.device.index}].graphsafe_get_state()",
)
else:
shape = [
V.graph.sizevars.size_hint(x, fallback=42)
@@ -2222,8 +2212,6 @@ class PythonWrapperCodegen(CodeGen):
return s.codegen_reference()
elif has_triton_package() and isinstance(s, triton.language.dtype): # type: ignore[possibly-undefined]
return dtype_to_string(s)
elif isinstance(s, ir.GeneratorState):
return s.codegen_reference()
else:
return repr(s)
@@ -1505,8 +1505,6 @@ class test_configs:
autotune_choice_name_regex: Optional[str] = None
autotune_choice_desc_regex: Optional[str] = None

graphsafe_rng_func_ignores_fallback_random = False


if TYPE_CHECKING:
from torch.utils._config_typing import * # noqa: F401, F403
@@ -908,7 +908,6 @@ class CUDAGraphNode:
self.recorded_liveness_before_graph = curr_liveness
self.expected_dead_indices_before_graph = different_indices

rng_states = [inp for inp in inputs if isinstance(inp, torch.Generator)]
recording_inputs = self._allocate_and_copy_recording_inputs(inputs)
# recording inputs will copy over memory, so we can free non recording inputs
inputs.clear()
@@ -917,11 +916,6 @@ class CUDAGraphNode:
# graph used for recording model invocation
self.graph: Optional[torch.cuda.CUDAGraph] = torch.cuda.CUDAGraph()

# TODO: register_generator_state should potentially take explicit device
with torch.cuda.device(self.device):
for rng_state in rng_states:
self.graph.register_generator_state(rng_state)

# we allocate non-static inputs within the same memory pool as the CUDAGraph
# which we will record the model with. For memory efficiency, it is important
# to reclaim the input memory when the inputs are no longer live. To accomplish this,
@@ -1608,7 +1602,7 @@ class CUDAGraphNode:
):
for i, inp in enumerate(inputs):
if not isinstance(inp, torch.Tensor):
assert isinstance(inp, (int, torch.Generator))
assert isinstance(inp, int)
recording_inputs.append(inp)
elif i not in self.static_input_idxs:
# static_input does an allocation!
@@ -1039,18 +1039,6 @@ class GraphLowering(torch.fx.Interpreter):
# Alternately we could filter this out in AotAutograd
self.graph_input_names.append(target)
return None
# See note: Note: [Generator arguments in AOTDispatcher]
elif isinstance(example, torch.Generator):
assert (
len(V.graph.current_node.users) == 1
and next(iter(V.graph.current_node.users)).target
is torch._prims.rng_prims.graphsafe_run_with_rng_state
)
gen = ir.GeneratorState(name=target, device=example.device)
self.graph_inputs[target] = gen # type: ignore[assignment]
self.graph_input_names.append(target)
return gen

assert isinstance(example, torch.Tensor), example
# todo(chilli): We can remove the last check once we turn buffers into
# static shape tensors. That's a hack to workaround Inductor believing
@@ -1299,7 +1287,7 @@ class GraphLowering(torch.fx.Interpreter):
if isinstance(value, TorchBindObject):
continue
assert isinstance(
value, (TensorBox, sympy.Expr, torch._inductor.ir.GeneratorState)
value, (TensorBox, sympy.Expr)
), f"Unsupported inductor graph input type: {type(value)}"
if not isinstance(value, TensorBox):
continue
@@ -4965,9 +4965,7 @@ class ExternKernel(InputsKernel):
tensor_args = []
non_tensor_args: list[Any] = []
for arg in args_flat:
is_arg_tensor.append(
isinstance(arg, IRNode) and not isinstance(arg, NonTensorObj)
)
is_arg_tensor.append(isinstance(arg, IRNode))
if is_arg_tensor[-1]:
tensor_args.append(arg)
else:
@@ -4998,9 +4996,7 @@ class ExternKernel(InputsKernel):
# Rerun fake tensor propagation, because Inductor may have changed the
# strides of inputs and we need to determine accurately what the
# output stride will be.
example_args: list[
Union[torch.Tensor, torch._C.ScriptObject, torch.Generator]
] = []
example_args: list[Union[torch.Tensor, torch._C.ScriptObject]] = []

# We need to retain the constant values of fake tensors that we originally
# propagated the graph with, because for some operators running without a
@@ -5017,12 +5013,6 @@ class ExternKernel(InputsKernel):
example_args.append(V.graph.torchbind_constants[x.get_name()])
elif isinstance(x, TorchBindObject):
example_args.append(x.get_real_obj())
elif isinstance(x, torch._inductor.ir.GeneratorState):
device_index = x.device.index
assert x.device.type == "cuda" and device_index is not None
example_args.append(
torch.cuda.default_generators[device_index].clone_state()
)
else:
example_args.append(ir_node_to_tensor(x, guard_shape=True))

@@ -5153,7 +5143,7 @@ class ExternKernel(InputsKernel):
# TODO(jansel): impose layout preference on realized buffer
x.realize()
return x
if isinstance(x, (NonTensorObj)):
if isinstance(x, TorchBindObject):
return x
return cls.copy_input(x)

@@ -7568,12 +7558,8 @@ class EffectfulKernel(FallbackKernel):
return True


class NonTensorObj(IRNode):
pass


@ir_dataclass
class TorchBindObject(NonTensorObj):
class TorchBindObject(IRNode):
from torch._library.fake_class_registry import FakeScriptObject

name: str
@@ -7607,18 +7593,6 @@ class TorchBindObject(NonTensorObj):
return functools.reduce(lambda x, y: x + y, flat_sizes, 0)


@ir_dataclass
class GeneratorState(NonTensorObj):
name: str
device: torch.device

def get_name(self): # type: ignore[no-untyped-def]
return self.name

def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str:
return self.name


class _CollectiveKernel(FallbackKernel):
def should_allocate(self) -> bool:
return False
@@ -2718,8 +2718,6 @@ make_fallback(aten.gcd.default, warn=False)
make_fallback(aten._thnn_fused_lstm_cell, require_dense)
make_fallback(torch._prims.rng_prims.run_and_save_rng_state)
make_fallback(torch._prims.rng_prims.run_with_rng_state)
make_fallback(torch._prims.rng_prims.graphsafe_run_with_rng_state)


# Implmented / Half implemented
# Scans. Implemented for CUDA, missing CPU
@@ -426,7 +426,7 @@ class CompiledFxGraph(OutputCode):
(not complex_memory_overlap_inputs, "complex memory overlap"),
(
all(
isinstance(t, (torch.Tensor, torch.SymInt, torch.Generator))
isinstance(t, (torch.Tensor, torch.SymInt))
for t in example_inputs
),
"non-Tensor inputs",
@@ -2104,7 +2104,6 @@ def count_tangents(fx_g: torch.fx.GraphModule) -> int:
"tangents" not in x.name
and "bwd_seed" not in x.name
and "bwd_base_offset" not in x.name
and "bwd_rng_state" not in x.name
)

arg_count = 0
@@ -315,75 +315,5 @@ run_and_save_rng_state = register_run_and_save_rng_state_op()
run_with_rng_state = register_run_with_rng_state_op()


def register_graphsafe_run_with_rng_state_op():
class GraphSafeRunWithRngState(HigherOrderOperator):
def __init__(self):
super().__init__("graphsafe_run_with_rng_state")

def __call__(self, op, *args, rng_state=None, **kwargs):
return super().__call__(op, *args, rng_state=rng_state, **kwargs)

graphsafe_run_with_rng_state = GraphSafeRunWithRngState()

graphsafe_run_with_rng_state.py_impl(DispatchKey.Autograd)(
autograd_not_implemented(graphsafe_run_with_rng_state, deferred_error=True)
)

@graphsafe_run_with_rng_state.py_impl(DispatchKey.CUDA)
def impl_cuda(op, *args, rng_state=None, **kwargs):
device_idx = rng_state.device.index
generator = torch.cuda.default_generators[device_idx]
current_state = generator.graphsafe_get_state()
generator.graphsafe_set_state(rng_state)
out = op(*args, **kwargs)
generator.graphsafe_set_state(current_state)
return out

@graphsafe_run_with_rng_state.py_impl(DispatchKey.BackendSelect)
def impl_backend_select(op, *args, rng_state=None, **kwargs):
device = get_device(args, kwargs)
assert (
device == "cuda"
), f"GraphSafe RNG operations only supported for CUDA, got {device}"
return impl_cuda(op, *args, rng_state=rng_state, **kwargs)

@graphsafe_run_with_rng_state.py_impl(FakeTensorMode)
def impl_fake_tensor_mode(mode, op, *args, rng_state=None, **kwargs):
with mode:
return op(*args, **kwargs)

@graphsafe_run_with_rng_state.py_impl(ProxyTorchDispatchMode)
def impl_proxy_dispatch_mode(mode, op, *args, rng_state=None, **kwargs):
with disable_proxy_modes_tracing():
out = graphsafe_run_with_rng_state(op, *args, rng_state=rng_state, **kwargs)
proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, (op, *args))
proxy_kwargs = pytree.tree_map(
mode.tracer.unwrap_proxy, {"rng_state": rng_state, **kwargs}
)
out_proxy = mode.tracer.create_proxy(
"call_function", graphsafe_run_with_rng_state, proxy_args, proxy_kwargs
)
return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)

@graphsafe_run_with_rng_state.py_functionalize_impl
def impl_functional(ctx, op, *args, rng_state=None, **kwargs):
unwrapped_rng_state = (
ctx.unwrap_tensors(rng_state) if rng_state is not None else None
)
unwrapped_args = ctx.unwrap_tensors(args)
unwrapped_kwargs = ctx.unwrap_tensors(kwargs)

with ctx.redispatch_to_next():
out = graphsafe_run_with_rng_state(
op, *unwrapped_args, rng_state=unwrapped_rng_state, **unwrapped_kwargs
)
return ctx.wrap_tensors(out)

return graphsafe_run_with_rng_state


graphsafe_run_with_rng_state = register_graphsafe_run_with_rng_state_op()
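A minimal sketch (not part of the patch, run outside CUDA Graph capture) of the generator handling that impl_cuda above performs: swap in a cloned state, run the random op, then restore the previous state. The same clone_state / graphsafe_get_state / graphsafe_set_state API is what the patch's tests exercise.

import torch

if torch.cuda.is_available():
    gen = torch.cuda.default_generators[0]
    rng_state = gen.clone_state()          # stands in for the state the partitioner threads in
    saved = gen.graphsafe_get_state()      # remember where the default generator was
    gen.graphsafe_set_state(rng_state)     # run the op against the cloned state
    out = torch.ops.aten.rand.default([4], device=torch.device("cuda"))
    gen.graphsafe_set_state(saved)         # restore the previous state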


def register_rng_prims():
register_philox_rand()
@@ -819,9 +819,6 @@ def _iterate_exprs(val: IterateExprs) -> Iterator[sympy.Basic]:
yield from _iterate_exprs(val.storage_offset())
elif val is None:
pass
# see Note: [Generator arguments in AOTDispatcher]
elif isinstance(val, torch.Generator):
pass
else:
raise AssertionError(f"cannot extract sympy expressions from {val} {type(val)}")
@@ -77,7 +77,6 @@ FIXME_hop_that_doesnt_have_opinfo_test_allowlist = [
"autograd_function_apply",
"run_and_save_rng_state",
"run_with_rng_state",
"graphsafe_run_with_rng_state",
"out_dtype",
"trace_wrapped",
'tag_activation_checkpoint',