pytorch/test/inductor/test_cudagraph_trees.py
eellison 481a57bc37 Support torch.compile rng selective activation checkpointing with cudagraph (#146878)
TODO:
- [x] Add handling for when forward is invoked multiple times without invoking backward, leaving the fwd/backward rng states out of sync
- [x] Update rng state initialization to take from correct device
- [x]  Tests
- [x] handling of retain_graph
- [x] respect fallback random

Fix for https://github.com/pytorch/pytorch/issues/130123.

Updates the aot_eager and cudagraph compilation of `run_and_save_rng_state` to use the new mechanism added by https://github.com/pytorch/pytorch/pull/114068 for CUDAGraph safe rng states.

We keep a pair of rng states, one for the forward and one for the backward. In both the forward and the backward, the rng op is run with `graphsafe_run_with_rng_state`, which takes in an RNG state and registers it as the current RNG generator before running the operator. The fwd/backward rng states are initialized to the same value, and we ensure that for any given run of the forward, the corresponding backward run observes the same rng state for the op as the forward did.

```
 ===== Forward graph 1 =====
 /data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", fwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = fwd_rng_state_0);  fwd_rng_state_0 = None
        ...

 ===== Backward graph 1 =====
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", tangents_1: "f32[4, 4][4, 1]cuda:0", bwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = bwd_rng_state_0);  bwd_rng_state_0 = None
```

There is some extra complication when a user either calls backward with retain_graph, or calls the backwards in a different order than they called the forwards. If a user has states fwd_rng_state0 and bwd_rng_state0 and calls:
- fwd0: fwd_rng_state0 -> fwd_rng_state1
- fwd1: fwd_rng_state1 -> fwd_rng_state2
- bwd1
- bwd0

Then, naively, when bwd1 is invoked the bwd rng state would not equal the state observed in fwd1. I added handling for this in the aot runtime wrappers: they detect pending backward invocations and the current position of the bwd rng states, and update them when necessary.
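
As a rough illustration of that bookkeeping (the names and structure below are invented for the example, not the actual AOTAutograd wrapper code): each forward invocation snapshots the generator state its rng op will observe, and the matching backward restores that snapshot before replaying, regardless of the order the backwards are called in.

```
import torch

class RngStateSync:
    """Illustrative only: pair each forward invocation with the rng state
    its backward must replay, even when backwards run out of order."""

    def __init__(self, device="cuda"):
        seed = int(torch.randint(0, 2**31 - 1, (1,)).item())
        self.fwd_gen = torch.Generator(device=device).manual_seed(seed)
        self.bwd_gen = torch.Generator(device=device).manual_seed(seed)
        self.pending = {}  # run id -> fwd generator state at forward time

    def on_forward(self, run_id):
        # Snapshot the state the forward's rng op will observe.
        self.pending[run_id] = self.fwd_gen.get_state()
        return self.fwd_gen

    def on_backward(self, run_id):
        # Rewind/fast-forward the backward generator to the forward's snapshot
        # so the recomputed rng op sees identical values.
        self.bwd_gen.set_state(self.pending.pop(run_id))
        return self.bwd_gen
```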

Other notes:

Because nodes that appear later in the forward appear earlier in the backward, we need a separate rng state for each operator. If we reused a single rng state across ops, the forward and backward would consume it in different orders, so an op's recompute in the backward would see a different rng state than the op did in the forward.
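
A toy demonstration of that ordering problem, independent of the actual implementation (the generators and seeds below are purely illustrative):

```
import torch

# One shared state, consumed in program order: op0 then op1 in the forward,
# but op1 then op0 when recomputed in the backward.
g = torch.Generator().manual_seed(0)
fwd = [torch.rand(1, generator=g) for _ in range(2)]  # draws seen by [op0, op1]
g = torch.Generator().manual_seed(0)
bwd = [torch.rand(1, generator=g) for _ in range(2)]  # draws seen by [op1, op0]
assert not torch.equal(fwd[1], bwd[0])  # op1's recompute gets op0's value

# One state per op: each op replays exactly what it saw in the forward.
gens = [torch.Generator().manual_seed(i) for i in range(2)]
fwd = [torch.rand(1, generator=gens[i]) for i in range(2)]
gens = [torch.Generator().manual_seed(i) for i in range(2)]
bwd = {i: torch.rand(1, generator=gens[i]) for i in (1, 0)}  # backward order
assert all(torch.equal(fwd[i], bwd[i]) for i in range(2))
```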

Questions for reviewers:

This does change numerics, because the rng of the op is now taken from the input rng state instead of whatever the rng would be midway through running the graph. Technically we only need this for cuda graph, but I'd prefer not to have an rng divergence just for cudagraphs. I am making it respect `fallback_random`.

Edit: decided to apply this to non-cudagraph compilation as well, so long as `fallback_random` is not set.
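
For reference, `fallback_random` is the existing inductor config flag mentioned above; a usage sketch (not code from this PR) that keeps eager-matching rng and, per the note above, opts out of the new rng handling:

```
import torch
import torch._inductor.config as inductor_config

x = torch.ones(8, device="cuda")
fn = torch.compile(lambda t: torch.nn.functional.dropout(t, p=0.5))

# With fallback_random set, inductor falls back to eager ATen rng ops,
# so compiled numerics match eager.
with inductor_config.patch("fallback_random", True):
    out = fn(x)
```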

I'm initializing the rng states by cloning the current state. If you had something like 5 different rands in the model with the same shape, they'd all get the same value, which doesn't seem great. I could use some other initialization scheme, such as deriving the seed from graph position. Not sure; let me know your thoughts.

Edit: updated to initialize each rng state from `torch.randint()`.
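
A minimal sketch of that style of initialization (illustrative; the actual implementation in the PR may differ): each rng state gets its own independently drawn seed rather than a clone of the current generator state.

```
import torch

def init_rng_states(num_rng_ops, device="cuda"):
    # Draw an independent seed per rng op so that distinct rand calls with
    # the same shape do not all produce identical values.
    states = []
    for _ in range(num_rng_ops):
        seed = int(torch.randint(0, 2**31 - 1, (1,)).item())
        states.append(torch.Generator(device=device).manual_seed(seed))
    return states
```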

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146878
Approved by: https://github.com/anijain2305, https://github.com/bdhirsh
2025-02-28 00:47:03 +00:00


# Owner(s): ["module: inductor"]
# ruff: noqa: F841
import contextlib
import functools
import gc
import importlib
import itertools
import sys
import unittest
import warnings
from collections import defaultdict
from collections.abc import Mapping, Sequence
import torch
import torch._dynamo.config as dynamo_config
import torch.nn as nn
from torch._dynamo.backends.debugging import aot_eager_decomp_partition_with_mode
from torch._dynamo.utils import counters
from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
from torch._inductor import config
from torch._inductor.codecache import FxGraphCache
from torch._inductor.compile_fx import compile_fx_inner
from torch._inductor.cudagraph_trees import cudagraphify_impl as tree_cudagraphify_impl
from torch._inductor.cudagraph_utils import FunctionID
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._ops import OpOverload
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.immutable_collections import immutable_dict
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
IS_CI,
IS_LINUX,
IS_WINDOWS,
parametrize,
skipIfRocm,
TEST_CUDA_GRAPH,
TEST_WITH_ASAN,
)
from torch.utils._mode_utils import no_dispatch
from torch.utils._python_dispatch import TorchDispatchMode
if IS_WINDOWS and IS_CI:
sys.stderr.write(
"Windows CI does not have necessary dependencies for test_torchinductor yet\n"
)
if __name__ == "__main__":
sys.exit(0)
raise unittest.SkipTest("requires sympy/functorch/filelock")
importlib.import_module("functorch")
importlib.import_module("filelock")
from torch.testing._internal.inductor_utils import HAS_CUDA
aten = torch.ops.aten
requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
requires_multigpu = functools.partial(
unittest.skipIf, not TEST_MULTIGPU, "requires multiple cuda devices"
)
from io import StringIO
def get_compile_fn(backend):
if backend == "cudagraphs":
return functools.partial(torch.compile, backend="cudagraphs")
else:
return functools.partial(torch.compile, mode="reduce-overhead")
class capture_stderr(list):
"""
Replace sys.stderr with a temporary StringIO
"""
def __enter__(self):
self.sys_stderr = sys.stderr
self.stringio = StringIO()
sys.stderr = self.stringio
return self
def __exit__(self, *args):
self.append(str(self.stringio.getvalue()))
del self.stringio
sys.stderr = self.sys_stderr
def cdata(t):
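# Identity of the underlying storage: two tensors alias iff these values match.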
return t.untyped_storage()._cdata
class TestCase(InductorTestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls._stack = contextlib.ExitStack()
cls._stack.enter_context(
config.patch(
{
"debug": True,
"cpp.min_chunk_size": 1,
"triton.autotune_pointwise": False, # too slow
"implicit_fallbacks": False,
}
)
)
@classmethod
def tearDownClass(cls):
cls._stack.close()
super().tearDownClass()
def setUp(self):
torch._dynamo.reset()
super().setUp()
def tearDown(self):
super().tearDown()
torch._dynamo.reset()
if HAS_CUDA and not TEST_WITH_ASAN:
def get_all_cudagraph_segments():
segments = torch.cuda.memory_snapshot()
return [segment for segment in segments if segment["segment_pool_id"] != (0, 0)]
def all_live_blocks():
blocks_addrs = []
for segment in get_all_cudagraph_segments():
addr = segment["address"]
for block in segment["blocks"]:
if block["state"] == "active_allocated":
blocks_addrs.append(addr)
addr += block["size"]
return blocks_addrs
def all_live_block_count():
return len(all_live_blocks())
class CudaGraphTreeTests(TestCase):
def setUp(self):
super().setUp()
self.graph_stack = contextlib.ExitStack()
self.graph_stack.enter_context(
config.patch(
{
"triton.cudagraphs": True,
"triton.cudagraph_trees": True,
"triton.fast_path_cudagraph_asserts": True, # too slow
"triton.slow_path_cudagraph_asserts": True,
}
)
)
self.graph_stack.enter_context(
dynamo_config.patch(automatic_dynamic_shapes=True)
)
self.device_idx = torch.rand([0], device="cuda").device.index
warnings.filterwarnings("ignore")
def tearDown(self):
super().tearDown()
torch._dynamo.reset()
gc.collect()
torch.cuda.empty_cache()
self.graph_stack.close()
self.assertIsNone(self.get_manager())
self.assertEqual(all_live_block_count(), 0)
self.assertEqual(len(get_all_cudagraph_segments()), 0)
warnings.resetwarnings()
def get_manager(self, device_index=None):
return torch._inductor.cudagraph_trees.get_container(
self.device_idx if not device_index else device_index
).tree_manager
def get_roots(self):
return self.get_manager().get_roots()
def curr_node(self):
return self.get_manager().current_node
def get_root_children(self):
return [root.num_descendants() for root in self.get_roots()]
def cudagraphify_impl(
self, *args, is_inference=True, is_backward=False, **kwargs
):
return tree_cudagraphify_impl(
*args,
**kwargs,
device_index=self.device_idx,
is_inference=is_inference,
is_backward=is_backward,
)
@staticmethod
def run_twc(fn, *args, **kwargs):
fn(*args, **kwargs)
return fn(*args, **kwargs)
def num_checkpoints(self):
return self.get_manager().debug_checkpointing_counter
def test_run_simple(self):
def foo(x):
return x * x * x
foo_opt = torch.compile(foo)
ones = torch.ones([4, 4], device="cuda")
zeros = torch.zeros([5, 5], device="cuda")
self.run_twc(foo_opt, ones)
self.run_twc(foo_opt, zeros)
self.assertEqual(self.get_root_children(), [0, 0])
def check_rng(self):
@torch.compile(mode="reduce-overhead")
def foo():
return torch.rand([20])
torch.manual_seed(0)
out = foo()
out2 = foo()
out3 = foo()
torch.manual_seed(0)
self.assertEqual(out, foo())
self.assertEqual(out2, foo())
self.assertEqual(out3, foo())
@torch._inductor.config.patch("fallback_random", True)
def test_rng_trees(self):
self.check_rng()
@torch._inductor.config.patch("triton.cudagraph_trees", False)
@torch._inductor.config.patch("fallback_random", True)
def test_rng_non_trees(self):
self.check_rng()
def test_mutation_reinplaced(self):
import torch.nn as nn
class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, input, other, out):
input = torch.logical_xor(input=input, other=other, out=out)
return input
x = torch.rand([1, 2, 1, 4, 9, 7], dtype=torch.float32).cuda()
y = torch.rand([1, 2, 1, 4, 9, 7], dtype=torch.float32).cuda()
z = torch.rand([1, 2, 1, 4, 9, 7], dtype=torch.float16).cuda()
model = Model().cuda()
eag = model(x, y, z)
with capture_stderr() as captured_output:
opt = torch.compile(model.forward, mode="reduce-overhead")(x, y, z)
FileCheck().check(
"skipping cudagraphs due to mutated inputs (1 instances). Found from"
).check("torch.logical_xor").run(captured_output[0])
self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
@requires_multigpu()
@parametrize("backend", ("inductor", "cudagraphs"))
def test_multiple_devices_msg(self, backend):
def foo(x, y):
return (x + 1, y + 2)
foo = get_compile_fn(backend)(foo)
with capture_stderr() as captured_output:
foo(torch.ones([10], device="cuda"), torch.ones([20]))
FileCheck().check(
"skipping cudagraphs due to cpu device (arg1_1). Found from"
).check("y + 2").run(captured_output[0])
self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
with capture_stderr() as captured_output:
foo(
torch.ones([10], device="cuda:0"), torch.ones([10], device="cuda:1")
)
FileCheck().check("skipping cudagraphs due to multiple devices").run(
captured_output[0]
)
self.assertEqual(counters["inductor"]["cudagraph_skips"], 2)
@torch._inductor.config.patch("triton.cudagraph_skip_dynamic_graphs", True)
def test_skip_symbolic(self):
@torch.compile(dynamic=True)
def foo(x, y):
return x + y
with capture_stderr() as captured_output:
foo(torch.rand([10], device="cuda"), torch.rand([10], device="cuda"))
FileCheck().check(
"skipping cudagraphs due to graph with symbolic shapes inputs"
).check("x + y").run(captured_output[0])
self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
@parametrize("backend", ("inductor", "cudagraphs"))
@torch._dynamo.config.patch("cudagraph_backend_keep_input_mutation", True)
@torch._dynamo.config.patch("cudagraph_backend_support_input_mutation", True)
@torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True)
def test_mutation_on_inp(self, backend):
def foo(x):
x.add_(2)
return x
foo = get_compile_fn(backend)(foo)
def inp():
return torch.ones([10], device="cuda")
with capture_stderr() as captured_output:
foo(inp())
FileCheck().check(
"skipping cudagraphs due to mutated inputs (1 instances). Found from"
).check(".add_(2)").run(captured_output[0])
self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
# mutation on inp doesnt hit cudagraphs
self.assertEqual(len(self.get_manager().roots), 0)
# mutation on parameters/buffers hits cudagraphs
class Mod(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.buf = torch.ones([10], device="cuda")
def forward(self, x):
self.buf.add_(x)
return self.buf + x
def foo(mod, x):
return mod(x)
foo = get_compile_fn(backend)(foo)
mod = Mod()
mod2 = Mod()
for _ in range(3):
self.assertEqual(foo(mod, inp()), mod2(inp()))
self.assertEqual(mod.buf, mod2.buf)
self.assertIsNotNone(self.get_manager())
@parametrize("backend", ("inductor", "cudagraphs"))
@torch._dynamo.config.patch("cudagraph_backend_keep_input_mutation", True)
@torch._dynamo.config.patch("cudagraph_backend_support_input_mutation", False)
@torch._inductor.config.patch("triton.cudagraph_support_input_mutation", False)
def test_mutation_cudagraph_managed_tensors_config(self, backend):
def foo(x):
return x + 1
def mut(x):
x.add_(2)
return x
def non_mut(x):
return x.add(2)
mut = get_compile_fn(backend)(mut)
foo = get_compile_fn(backend)(foo)
with capture_stderr() as captured_output:
for i in range(3):
torch.compiler.cudagraph_mark_step_begin()
inp = torch.rand([4], device="cuda")
tmp = foo(inp)
mut_out = mut(tmp)
self.assertEqual(mut_out, non_mut(foo(inp)))
FileCheck().check_count(
"skipping cudagraphs due to mutated inputs (1 instances). Found from",
1,
exactly=True,
).run(captured_output[0])
@parametrize("backend", ("inductor", "cudagraphs"))
@torch._dynamo.config.patch("cudagraph_backend_keep_input_mutation", True)
@torch._dynamo.config.patch("cudagraph_backend_support_input_mutation", True)
@torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True)
def test_mutation_cudagraph_managed_tensors(self, backend):
def foo(x):
return x + 1
def mut(x):
x.add_(2)
return x
def non_mut(x):
return x.add(2)
mut = get_compile_fn(backend)(mut)
foo = get_compile_fn(backend)(foo)
with capture_stderr() as captured_output:
for i in range(3):
torch.compiler.cudagraph_mark_step_begin()
inp = torch.rand([4], device="cuda")
tmp = foo(inp)
mut_out = mut(tmp)
self.assertEqual(mut_out, non_mut(foo(inp)))
FileCheck().check_count(
"skipping cudagraphs due to mutated inputs (1 instances). Found from",
0,
exactly=True,
).run(captured_output[0])
self.assertTrue("cudagraph_skips" not in counters["inductor"])
torch.compiler.cudagraph_mark_step_begin()
inp = torch.rand([4], device="cuda")
tmp = foo(inp)
mut_inp = tmp.clone()
# in this case, what was previously a mutated cudagraph-managed tensor is no longer one;
# now it's an input from eager, so we should fall back to inductor without cudagraphs
with capture_stderr() as captured_output:
mut(mut_inp)
FileCheck().check(
"skipping cudagraphs due to mutated inputs (1 instances). Found from"
).check("x.add_(2)").run(captured_output[0])
self.assertEqual(mut_inp, non_mut(foo(inp)))
self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
@parametrize("backend", ("inductor", "cudagraphs"))
@torch._dynamo.config.patch("cudagraph_backend_keep_input_mutation", True)
@torch._dynamo.config.patch("cudagraph_backend_support_input_mutation", True)
@torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True)
def test_mutation_cudagraph_managed_tensor_warn(self, backend):
def foo(x):
return x.add_(1)
def fee(y, z):
return z.add(3)
def inp():
return torch.rand([4], device="cuda")
foo = get_compile_fn(backend)(foo)
fee = get_compile_fn(backend)(fee)
with capture_stderr() as captured_output:
for _ in range(3):
torch.compiler.cudagraph_mark_step_begin()
fee(inp(), foo(inp()))
FileCheck().check_count(
"skipping cudagraphs due to mutated inputs (1 instances). Found from",
1,
exactly=True,
).run(captured_output[0])
self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
@parametrize("backend", ("inductor", "cudagraphs"))
@torch._dynamo.config.patch("cudagraph_backend_keep_input_mutation", True)
@torch._dynamo.config.patch("cudagraph_backend_support_input_mutation", True)
@torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True)
def test_mutation_cudagraph_managed_tensor_warn_only_once(self, backend):
def foo(x):
return x + 1
def mut(x):
x.add_(2)
return x
def inp():
return torch.rand([4], device="cuda")
mut = get_compile_fn(backend)(mut)
foo = get_compile_fn(backend)(foo)
with capture_stderr() as captured_output:
# Should warn for current_node=None
mut(inp())
for i in range(3):
torch.compiler.cudagraph_mark_step_begin()
tmp = foo(inp())
mut(tmp) # should not warn
mut_inp = tmp.clone()
mut(mut_inp) # should not warn since mut has warned
FileCheck().check_count(
"skipping cudagraphs due to mutated inputs (1 instances). Found from",
1,
exactly=True,
).run(captured_output[0])
self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
def test_function_compiled_multiple_times(self):
def foo(x):
y = foo2(x)
y2 = foo2(y)
return y + y2
def foo2(x):
torch._dynamo.graph_break()
return x * x * x
foo_opt = torch.compile(foo)
ones = torch.ones([4, 4], device="cuda")
foo(ones)
foo_opt(ones)
foo_opt(ones)
self.assertEqual(foo_opt(ones), foo(ones))
# paths
children = self.get_root_children()
# one root with two children
self.assertEqual(children, [2])
def test_end_recording_early(self):
def foo(x):
y = x * x * x
torch._dynamo.graph_break()
z = x + y
return z
@torch.compile
def foo2(x):
return x + 4
foo_opt = torch.compile(foo)
for _ in range(3):
out = foo_opt(torch.ones([4, 4], device="cuda"))
del out
# when I tried inducing separate recordings via graph break,
# the frame kept interfering by keeping outputs alive.
# this isn't great but it simulates the logic.
from torch._dynamo.mutation_guard import GenerationTracker
GenerationTracker.generation -= 1
out = foo2(torch.ones([4, 4], device="cuda"))
del out
foo_opt(torch.ones([4, 4], device="cuda"))
# Two separate traces - one has a child, one doesnt
self.assertEqual(self.get_root_children(), [1, 0])
def test_execution_into_recording(self):
def foo(x):
y = x + x
if y.sum() > 0:
return y + 10
else:
return y - 10
foo_opt = torch.compile(foo)
inp = torch.zeros([4, 4], dtype=torch.float, device="cuda")
self.assertEqual(foo_opt(inp), foo(inp))
self.assertEqual(foo_opt(inp), foo(inp))
inp.add_(1)
out_eager = foo(inp)
out_warmup = foo_opt(inp)
self.assertEqual(out_warmup, out_eager)
# warmup should have the storage deallocator hooked on
self.assertEqual(all_live_block_count(), 1)
out_live = foo_opt(inp)
self.assertEqual(out_live, out_eager)
# should be in recording mode, with storage deallocator hooked on
self.assertEqual(all_live_block_count(), 1)
# warmup should have been freed
del out_warmup
# should be in recording mode, with storage deallocator hooked on
self.assertEqual(all_live_block_count(), 1)
del out_live
self.assertEqual(all_live_block_count(), 0)
out = foo_opt(inp)
self.assertEqual(foo(inp), out)
# should be in execution mode
self.assertEqual(all_live_block_count(), 0)
def test_forward_with_skipped_cudagraphed_backward(self):
@torch.compile(mode="reduce-overhead")
def foo(x):
return x * x * x
for _ in range(3):
inp = torch.rand([20, 20], device="cuda", requires_grad=True)
out = foo(inp)
with config.patch(always_complex_memory_overlap_TESTING_ONLY=True):
back_inp = torch.empty_strided([20, 20], [0, 1], device="cuda")
out.backward(back_inp)
# we should not have cudagraph'd the backwards
new_id = self.get_manager().new_graph_id().id
self.assertEqual(new_id, 1)
self.assertFalse(self.get_manager().running_forwards_with_pending_backwards)
@torch._functorch.config.patch("enable_autograd_cache", True)
@torch._inductor.config.patch("fx_graph_cache", True)
@torch._inductor.config.patch("fx_graph_remote_cache", False)
# Currently fx graph cache is turned off for specialize_float=False
@torch._dynamo.config.patch("specialize_float", True)
def test_cache_hit_forward_miss_backward(self):
# Test that we don't cache cudagraphs, skipping cudagraphs on backward on a cache miss
@torch.compile(mode="reduce-overhead")
def foo(x):
return x * x * x
# Run forwards, fx graph should cache miss
for _ in range(3):
torch._dynamo.reset()
counters.clear()
FxGraphCache.clear()
AOTAutogradCache.clear()
with config.patch(always_complex_memory_overlap_TESTING_ONLY=True):
inp = torch.rand([20, 20], device="cuda", requires_grad=True)
out = foo(inp)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
# Reset dynamo and related caches except for FXGraphCache
torch._dynamo.reset()
# Forwards should be a cache hit now, we still skip cudagraphs
inp = torch.rand([20, 20], device="cuda", requires_grad=True)
out = foo(inp)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
# Run backward without complex memory overlap being set
# Run the backward without complex memory overlap reason
# cache should miss, but cudagraphs should not run
# because forward skipped it
back_inp = torch.empty_strided([20, 20], [0, 1], device="cuda")
out.backward(back_inp)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
# Run it one more time, this time AOTAutogradCache will hit
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
torch._dynamo.reset()
inp = torch.rand([20, 20], device="cuda", requires_grad=True)
out = foo(inp)
back_inp = torch.empty_strided([20, 20], [0, 1], device="cuda")
out.backward(back_inp)
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
# we should not have cudagraph'd anything
assert self.get_manager() is None
@torch._functorch.config.patch("enable_autograd_cache", True)
@torch._inductor.config.patch("fx_graph_cache", True)
@torch._inductor.config.patch("fx_graph_remote_cache", False)
# Currently fx graph cache is turned off for specialize_float=False
@torch._dynamo.config.patch("specialize_float", True)
def test_backward_gets_cached_cudagraphs(self):
# We pass cpu tensors to foo and save that into the cache
# On a subsequent run in a new process, cudagraphs should be
# disabled properly on both forward and backwards runs.
@torch.compile(mode="reduce-overhead")
def foo(x):
return x * x * x
torch._dynamo.reset()
counters.clear()
FxGraphCache.clear()
AOTAutogradCache.clear()
# Use cpu device to disable cudagraphs during compilation
inp = torch.rand([20, 20], device="cpu", requires_grad=True)
out = foo(inp)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
back_inp = torch.empty_strided([20, 20], [0, 1], device="cpu")
out.backward(back_inp)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
# Run again on new process
torch._dynamo.reset()
# Forward and backward should also disable cudagraphs without compilation
inp = torch.rand([20, 20], device="cpu", requires_grad=True)
out = foo(inp)
# AOTAutogradCache will load the forward and the backward from cache immediately, so fx_graph_cache_hit will equal 2
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 2)
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
torch._dynamo.reset()
back_inp = torch.empty_strided([20, 20], [0, 1], device="cpu")
out.backward(back_inp)
# we should not have cudagraph'd anything
assert self.get_manager() is None
@torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
@torch._functorch.config.patch("enable_autograd_cache", True)
@torch._inductor.config.patch("fx_graph_cache", True)
@torch._inductor.config.patch("fx_graph_remote_cache", False)
# Currently fx graph cache is turned off for specialize_float=False
@torch._dynamo.config.patch("specialize_float", True)
def test_cached_forward_backward(self):
counters.clear()
AOTAutogradCache.clear()
FxGraphCache.clear()
@torch.compile
def foo(x):
torch.manual_seed(0)
y = x * 2
return torch.sin(y) * torch.nn.functional.dropout(x, p=0.4)
inp = torch.rand([4, 4], requires_grad=True, device="cuda")
inp2 = inp.detach().clone().requires_grad_(True)
out = foo(inp)
out.sum().backward()
self.assertEqual(self.get_root_children(), [1])
# the three saved tensors should die in the backward
# we kept alive the output
self.assertEqual(self.curr_node().expected_dead_indices_before_graph, [])
self.assertEqual(
self.curr_node().expected_dead_indices_after_graph,
[(0, 1), (0, 2)],
)
self.assertFalse(self.get_manager().new_graph_id().id == 0)
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
# Reset dynamo and rerun. We should see a cache hit now
torch._dynamo.reset()
out2 = foo(inp2)
out2.sum().backward()
self.assertEqual(out, out2)
self.assertEqual(inp.grad, inp2.grad)
self.assertEqual(self.get_root_children(), [1])
self.assertFalse(self.get_manager().new_graph_id().id == 0)
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
@parametrize("backend", ("inductor", "cudagraphs"))
def test_forward_backward_not_called(self, backend):
def foo(x, y):
x_out = x * x * x
torch._dynamo.graph_break()
y_out = y * y * y
return x_out, y_out
foo = get_compile_fn(backend)(foo)
for _ in range(3):
inps = [
torch.rand([20, 20], requires_grad=True, device="cuda")
for _ in range(2)
]
x_out, y_out = foo(inps[0], inps[1])
x_out.sum().backward()
self.assertFalse(self.get_manager().running_forwards_with_pending_backwards)
# we should not have cudagraph'd the y backward
new_id = self.get_manager().new_graph_id().id
self.assertEqual(new_id, 3)
def _test_unaligned_static_input_impl(self, expected_clones):
def fn(x, y):
return (x + y,)
def get_aligned_inputs():
return [torch.rand([5, 5], device="cuda") for _ in range(2)]
mod = make_fx(fn)(*get_aligned_inputs())
mode = torch._subclasses.FakeTensorMode()
with mode:
inps = [torch.rand([6, 5], device="cuda")[1:] for _ in range(2)]
compiled_f = compile_fx_inner(
mod, inps, static_input_idxs=[0], cudagraphs=True
)
def get_unaligned_inputs():
return [torch.rand([6, 5], device="cuda")[1:] for _ in range(2)]
class CloneCounterMode(TorchDispatchMode):
def __init__(self) -> None:
self.count = 0
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
kwargs = {} if kwargs is None else kwargs
self.count += func is torch.ops.aten.clone.default
return func(*args, **kwargs)
for _ in range(3):
with CloneCounterMode() as m:
compiled_f(get_unaligned_inputs())
self.assertEqual(m.count, expected_clones)
compiled_f(get_aligned_inputs())
self.assertEqual(m.count, expected_clones)
def test_unaligned_static_input_trees(self):
self._test_unaligned_static_input_impl(expected_clones=0)
@torch._inductor.config.patch("triton.cudagraph_trees", False)
def test_unaligned_static_input_non_trees(self):
self._test_unaligned_static_input_impl(expected_clones=0)
@torch._inductor.config.patch("triton.cudagraphs", False)
def test_unaligned_static_input_no_cudagraphs(self):
self._test_unaligned_static_input_impl(expected_clones=0)
def test_sparsity(self):
def foo(view_6, buf31):
return aten._sparse_coo_tensor_with_dims_and_tensors(
1,
1,
[1000000, 64],
view_6,
buf31,
dtype=torch.float32,
layout=torch.sparse_coo,
device="cuda",
pin_memory=None,
)
foo_opt = torch.compile(foo)
view_6 = torch.zeros([1, 102397], dtype=torch.int64, device="cuda")
buf31 = torch.rand([102397, 64], device="cuda")
for _ in range(3):
self.assertEqual(foo_opt(view_6, buf31), foo(view_6, buf31))
def test_accumulate_multiple_recordings(self):
def foo(x):
y = x + x + x
torch._dynamo.graph_break()
if y.sum() <= 0:
return y
else:
return y * 10
foo_opt = torch.compile(foo)
# two separate compilations & recordings
out1 = self.run_twc(foo_opt, torch.zeros([5], device="cuda"))
# out1 gets manually freed
out2 = self.run_twc(foo_opt, torch.zeros([6], device="cuda"))
self.assertEqual(all_live_block_count(), 1)
out3 = self.run_twc(foo_opt, torch.ones([5], device="cuda"))
self.assertEqual(out3, foo(torch.ones([5], device="cuda")))
self.assertEqual(all_live_block_count(), 1)
del out1, out2
self.assertEqual(all_live_block_count(), 1)
del out3
gc.collect()
self.assertEqual(all_live_block_count(), 0)
@torch._inductor.config.patch("freezing", True)
def test_constant_output(self):
class Mod(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.param = torch.nn.Parameter(
torch.tensor([float(i) for i in range(10)], device="cuda")
)
def forward(self, inp):
return self.param, self.param[0:2], inp + 2
inp = torch.tensor([2], device="cuda")
m = Mod()
with torch.no_grad():
out_eager = m(inp)
m_comp = torch.compile(m)
for _ in range(3):
self.assertEqual(out_eager, m_comp(inp))
def test_live_outputs_multiple_graphs(self):
def foo(x):
x = x + x + x
y = x + 1
torch._dynamo.graph_break()
z = x * x
if z.sum() > 0:
return y + 1
else:
return y
foo_opt = torch.compile(foo)
self.run_twc(foo_opt, torch.zeros([5], device="cuda"))
self.assertEqual(self.num_checkpoints(), 0)
out = self.run_twc(foo_opt, torch.ones([5], device="cuda"))
self.assertEqual(all_live_block_count(), 1)
del out
self.assertEqual(all_live_block_count(), 0)
# we need to checkpoint from function to warmup y + 1,
# and then again to record it
self.assertEqual(self.num_checkpoints(), 2)
def test_expanded_inputs(self):
x = torch.rand(1, 512, device="cuda").expand(4, 512)
def foo(x):
return x + 4 + torch.ones([4, 512], device="cuda")
foo_opt = torch.compile()(foo)
for _ in range(3):
self.assertEqual(foo_opt(x), foo(x))
self.assertFalse(self.get_manager().new_graph_id().id == 0)
@torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
def test_tensor_dies_between_checkpoint(self):
def foo(args):
x = args[0]
args.clear()
return x + 1, x + 2
inp = torch.rand([4], device="cuda")
inp_list = [inp]
foo_cg = self.cudagraphify_impl(foo, inp_list, ())
foo_cg(inp_list)
foo_cg([inp])
out1, out2 = foo_cg([inp])
inp = [out1]
del out1, out2
def foo2(args):
x = args[0]
args.clear()
return [x * x * x]
self.assertEqual(self.num_checkpoints(), 0)
foo2_cg = self.cudagraphify_impl(foo2, inp, ())
x = foo2_cg(inp)[0]
self.assertEqual(self.num_checkpoints(), 1)
# out2 dies between the previous recording and the new one, and
# needs to be manually deallocated after the checkpoint
self.assertEqual(all_live_block_count(), 1)
del x
self.assertEqual(all_live_block_count(), 0)
def test_aliased_storage_single_weakref(self):
@torch.compile(mode="reduce-overhead")
def foo(x):
x = x * 20
x_alias = x[0]
y = x * 10
y_alias = y[0]
torch._dynamo.graph_break()
ind = torch.tensor(4, device="cuda")
x_alias2 = x[ind:]
y_alias2 = y[ind:]
return x, x_alias, x_alias2, y_alias, y_alias2
for _ in range(4):
outs = foo(torch.rand([20, 20], device="cuda"))
ptr_to_ref = {
out.untyped_storage().data_ptr(): out.untyped_storage()._cdata
for out in outs
}
self.assertEqual(len(ptr_to_ref), 2)
for out in outs:
self.assertEqual(
ptr_to_ref[out.untyped_storage().data_ptr()],
out.untyped_storage()._cdata,
)
del outs
del out
node = self.get_manager().current_node
self.assertEqual(len(list(node.path_live_weakrefs())), 0)
self.assertFalse(self.get_manager().new_graph_id().id == 0)
def test_aliasing_static_ref(self):
class Mod(torch.nn.Linear):
def forward(self, x):
return self.weight.T @ x, self.weight.T, self.weight[0:4]
m = Mod(10, 10).cuda()
@torch.compile(mode="reduce-overhead")
def foo(mod, x):
return mod(x)
@torch.compile(mode="reduce-overhead")
def foo2(x):
return x[2:]
param_c = cdata(m.weight)
for _ in range(3):
x = torch.rand([10, 10], device="cuda", requires_grad=True)
torch.compiler.cudagraph_mark_step_begin()
out1, alias_1, alias_2 = foo(m, x)
self.assertEqual(len({param_c, cdata(alias_1), cdata(alias_2)}), 1)
out2 = foo2(out1)
out2.sum().backward()
self.assertEqual(cdata(out1), cdata(out2))
m.weight.grad = None
m.bias.grad = None
node = self.curr_node()
first_node = next(node._path_from_root)
self.assertFalse(first_node.unaliased_in_all_paths[0])
self.assertTrue(first_node.cached_tensor_outputs[0] is None)
@torch._inductor.config.patch("implicit_fallbacks", True)
def test_multinomial(self):
def sample_multinomial(probs, num_samples, replacement=True):
return torch.multinomial(probs, num_samples, replacement=replacement)
# Create and prepare probability tensor on GPU
probs = torch.tensor([0.1, 0.2, 0.3, 0.4]).cuda()
probs = probs / probs.sum()
# Sample using the function
num_skipped = counters["inductor"]["cudagraph_skips"]
with torch._dynamo.utils.preserve_rng_state():
samples = self.run_twc(
sample_multinomial, probs, num_samples=5, replacement=True
)
with torch._dynamo.utils.preserve_rng_state():
samples_compiled = self.run_twc(
torch.compile(sample_multinomial),
probs,
num_samples=5,
replacement=True,
)
self.assertEqual(samples, samples_compiled)
self.assertEqual(num_skipped, counters["inductor"]["cudagraph_skips"])
@skipIfRocm
def test_checkpointing_resets_persistent_refs(self):
@torch.compile(mode="reduce-overhead")
def foo(x):
return x @ x
def inp():
return torch.rand([20, 20], device="cuda", requires_grad=False)
for _ in range(3):
foo(inp())
self.assertEqual(self.num_checkpoints(), 0)
out = foo(inp())
out_id = id(out)
del out
self.assertEqual(id(foo(inp())), out_id)
@torch.compile(mode="reduce-overhead")
def foo2(x):
return x[0], x @ x
for i in range(2):
out = foo(inp())
from torch._dynamo.mutation_guard import GenerationTracker
GenerationTracker.generation -= 1
out_alias, out2 = foo2(out)
del out_alias
self.assertEqual(all_live_block_count(), 2)
del out
self.assertEqual(all_live_block_count(), 1)
del out2
self.assertEqual(all_live_block_count(), 0)
self.assertEqual(self.num_checkpoints(), i + 1)
new_out = foo(inp())
curr_node = self.curr_node()
self.assertFalse(curr_node.unaliased_in_all_paths[0])
self.assertFalse(out_id == id(new_out))
def test_aliased_static_parameter(self):
inp = torch.rand([20, 20], device="cuda")
def foo(args):
x = args[0]
args.clear()
return (x[0],)
foo_cg = self.cudagraphify_impl(foo, [inp], (0,))
for _ in range(3):
out = foo_cg([inp])[0]
self.assertEqual(cdata(inp), cdata(out))
node = self.curr_node()
self.assertEqual(node.cached_tensor_outputs, [None])
self.assertEqual(node.unaliased_in_all_paths, [False])
def test_warmup_stream_sync(self):
def foo(args):
x = args[0]
args.clear()
x_orig = x
for _ in range(100):
x = x @ x
return (x,)
inp = torch.rand([4096, 4096], device="cuda")
ref = foo([inp])[0]
torch.cuda.synchronize()
user_stream = torch.cuda.Stream()
with torch.cuda.stream(user_stream):
foo_cg = self.cudagraphify_impl(foo, [inp], (0,))
out = foo_cg([inp])[0]
y = out + 1
self.assertEqual(y, ref + 1)
def test_unaligned_static_parameter(self):
def gen_inp():
inp = torch.ones([20], device="cuda")
return [inp[1:]]
def foo(args):
x = args[0]
args.clear()
return (x + x,)
foo_cg = self.cudagraphify_impl(foo, gen_inp(), (0,))
for _ in range(3):
out = foo_cg(gen_inp())
self.assertEqual(out, foo(gen_inp()))
del out
node = self.curr_node()
self.assertEqual(node.static_input_data_ptrs, [None])
def test_amp_cache_disabled(self):
@torch.compile()
def foo(x):
return x + x
for _ in range(3):
out = foo(torch.rand([4, 4], device="cuda", requires_grad=True))
# amp cache for cudagraph outputs should be disabled
t2 = torch.rand([4, 4], device="cuda")
with torch.cuda.amp.autocast():
run_once = out @ t2
out.detach().zero_()
run_twice = out @ t2
self.assertNotEqual(run_once, run_twice)
def test_remove_hooks_on_cached_tensors(self):
@torch.compile()
def foo(x):
return x * x
inp = torch.rand([4], device="cuda", requires_grad=True)
for _ in range(5):
out = foo(inp)
self.assertIsNone(out._backward_hooks)
out.register_hook(lambda: None)
# today, torch.compile never outputs a leaf tensor which is the only
# tensor that can register _post_accumulate_grad_hooks
# add this as a preventative test
@torch.compile()
def foo(x):
return torch.rand([4], device="cuda", requires_grad=True)
for _ in range(5):
out = foo(inp)
self.assertIsNone(out._post_accumulate_grad_hooks)
out.register_post_accumulate_grad_hook(lambda: None)
def test_multiple_insert_removal_caching(self):
torch._C._set_cached_tensors_enabled(True)
try:
x = torch.rand([4], device="cuda")
torch._C._add_cached_tensor(x)
self.assertTrue(torch._C._is_cached_tensor(x))
torch._C._add_cached_tensor(x)
torch._C._remove_cached_tensor(x)
self.assertFalse(torch._C._is_cached_tensor(x))
finally:
torch._C._set_cached_tensors_enabled(False)
def test_accumulate_grad(self):
# cudagraph trees shouldnt interfere with accumulation logic
def compute_grad(grad_output, create_graph):
x = torch.randn(5, 5, requires_grad=True, device="cuda")
@torch.compile()
def foo(x):
return x + 2
y = foo(x)
y.backward(grad_output, retain_graph=True)
x_grad = x.grad
x_grad_clone = x.grad.clone()
y.backward(grad_output, create_graph=create_graph)
return x_grad, x_grad_clone
for _ in range(3):
grad_output = torch.ones(5, 5, device="cuda")
# Accumulate in-place when create_graph is False
x_grad, x_grad_clone = compute_grad(grad_output, create_graph=False)
self.assertEqual(x_grad, x_grad_clone * 2)
# Accumulate out-of-place when create_graph is False
x_grad, x_grad_clone = compute_grad(grad_output, create_graph=True)
self.assertEqual(x_grad, x_grad_clone)
def test_frozen_fn(self):
@torch.compile()
def foo(x):
return x @ x
for _ in range(3):
out = foo(torch.rand([10, 10], device="cuda"))
self.assertTrue(self.get_manager().new_graph_id().id == 1)
frozen = torch._dynamo.run(foo)
for _ in range(3):
out = frozen(torch.rand([10, 10], device="cuda"))
# didnt do additional recordings
self.assertTrue(self.get_manager().new_graph_id().id == 2)
def test_empty_cpu_tensor(self):
def foo(x):
return x @ x, torch.tensor([])
foo_opt = torch.compile(foo)
x = torch.rand([4], device="cuda")
for _ in range(3):
out_opt = foo_opt(x)
self.assertEqual(foo(x), out_opt)
self.assertTrue(self.get_manager().new_graph_id().id == 1)
def test_output_alias(self):
inp = torch.rand([20, 20], device="cuda")
def foo(args):
x = args[0]
args.clear()
out = x + x
return (x, x[0])
foo_cg = self.cudagraphify_impl(foo, [inp], ())
for _ in range(3):
out_1, out_2 = foo_cg([inp])
self.assertEqual(cdata(out_1), cdata(out_2))
del out_1, out_2
self.assertEqual(len(list(self.curr_node().path_live_weakrefs())), 0)
self.assertEqual(self.curr_node().cached_tensor_outputs, [None, None])
def test_empty_storage(self):
@torch.compile(mode="reduce-overhead")
def foo(x):
return (
(x + x + x),
torch.zeros([0], device="cuda"),
torch.zeros([100], device="cuda")[0:0],
)
inp = torch.rand([4], device="cuda")
for _ in range(3):
out = foo(inp)
node = self.curr_node()
self.assertEqual(len(list(node.path_live_weakrefs())), 1)
@torch.compile(mode="reduce-overhead")
def foo(x):
return (x + x + x), torch.rand([4], device="cuda") + 10
inp = torch.rand([0], device="cuda")
for _ in range(3):
out = foo(inp)
node = self.curr_node()
self.assertEqual(len(list(node.path_live_weakrefs())), 1)
@torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
def test_aliased_output_checkpoint(self):
def foo(args):
x = args[0]
args.clear()
y = x + 2
return x + 1, y, y[0]
inp = torch.rand([4, 4], device="cuda")
foo_cg = self.cudagraphify_impl(foo, [inp], ())
foo_cg([inp])
foo_cg([inp])
out1, out2, out3 = foo_cg([inp])
inp = [out1]
del out1, out2, out3
def foo2(args):
x = args[0]
args.clear()
return [x * x * x]
self.assertEqual(self.num_checkpoints(), 0)
foo2_cg = self.cudagraphify_impl(foo2, inp, ())
x = foo2_cg(inp)[0]
self.assertEqual(self.num_checkpoints(), 1)
# out2 and out3 die between the previous recording and the new one, and
# need to be manually deallocated after the checkpoint
self.assertEqual(all_live_block_count(), 1)
del x
self.assertEqual(all_live_block_count(), 0)
@skipIfRocm
@unittest.skipIf(not IS_LINUX, "cpp contexts are linux only")
@torch._inductor.config.patch("triton.cudagraph_trees_history_recording", True)
def test_workspace_allocation_error(self):
torch._C._cuda_clearCublasWorkspaces()
prev = torch._inductor.cudagraph_trees.clear_cublas_manager
try:
torch._inductor.cudagraph_trees.clear_cublas_manager = (
contextlib.nullcontext
)
@torch.compile()
def foo(x, y):
return x @ x
inps = [torch.rand([400, 400], device="cuda") for _ in range(2)]
thrown = False
try:
foo(*inps)
except Exception as e:
thrown = True
self.assertTrue(
"at::cuda::blas::gemm<float>" in str(e)
or "at::cuda::blas::gemm_internal_cublas<float>" in str(e)
)
self.assertTrue(
"getCurrentCUDABlasHandle" in str(e)
or "getNewWorkspace" in str(e)
)
self.assertTrue(thrown)
finally:
torch._C._cuda_clearCublasWorkspaces()
torch._inductor.cudagraph_trees.clear_cublas_manager = prev
torch._inductor.cudagraph_trees.get_container(
self.device_idx
).tree_manager = None
def test_peristed_output_livenes(self):
@torch.compile
def foo(x):
return x + x
for _ in range(3):
foo(torch.rand([2, 2], device="cuda"))
node = self.get_manager().current_node
self.assertEqual(len(list(node.path_live_weakrefs())), 0)
out = foo(torch.rand([2, 2], device="cuda"))
self.assertTrue(out is node.cached_tensor_outputs[0])
self.assertEqual(len(list(node.path_live_weakrefs())), 1)
out_ref = out[0:]
del out
self.assertEqual(len(list(node.path_live_weakrefs())), 1)
del out_ref
self.assertEqual(len(list(node.path_live_weakrefs())), 0)
@torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
def test_tensor_no_longer_in_pool(self):
def foo(args):
x = args[0]
args.clear()
return x + 1, x + 2
inp = torch.rand([4], device="cuda")
inp_list = [inp]
foo_cg = self.cudagraphify_impl(foo, inp_list, ())
x1, x2 = foo_cg(inp_list)
def foo2(args):
x = args[0]
args.clear()
return [x * x * x]
inp_list = [x1]
foo2_cg = self.cudagraphify_impl(foo2, inp_list, ())
foo2_cg(inp_list)
del x1, x2
# TODO make configurable
x1, x2 = foo_cg([inp])
self.assertEqual(self.num_checkpoints(), 0)
# input location has changed, should force recompile and checkpointing
foo2_cg([torch.zeros_like(x1)])
self.assertEqual(self.num_checkpoints(), 1)
self.assertEqual(self.get_root_children(), [2])
@torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
def test_checkpoint_shared_output_storage_deallocation(self):
def foo(args):
x = args[0]
args.clear()
x_tmp = x + 1
return x[0], x[1]
inp = torch.rand([2, 2], device="cuda")
inp_list = [inp]
foo_cg = self.cudagraphify_impl(foo, inp_list, ())
foo_cg(inp_list)
foo_cg([inp])
x1, x2 = foo_cg([inp])
inp = [x1]
def foo2(args):
x = args[0]
args.clear()
y = x * x
return y[0], y[1]
foo2_cg = self.cudagraphify_impl(foo2, inp, ())
foo2_cg(inp)
self.assertEqual(self.num_checkpoints(), 1)
self.assertEqual(
x1.untyped_storage().data_ptr(), x2.untyped_storage().data_ptr()
)
self.assertEqual(all_live_block_count(), 1)
del x1
self.assertEqual(all_live_block_count(), 1)
del x2
self.assertEqual(all_live_block_count(), 0)
@torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
def test_cleanup(self):
def test_closure():
@torch.compile
def foo(x):
return x + 1 + 2, x * 10
foo(torch.rand([4], device="cuda"))
return foo(torch.rand([4], device="cuda"))
out1, out2 = test_closure()
torch._dynamo.reset()
# TODO - deallocate on tensor deallocation
# self.assertTrue(self.get_manager() is not None)
# del out1
# self.assertTrue(self.get_manager() is not None)
# del out2
self.assertTrue(self.get_manager() is None)
@torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
def test_forward_backward(self):
@torch.compile
def foo(x):
y = x * 2
return torch.sin(y) * torch.nn.functional.dropout(x, p=0.4)
inp = torch.rand([4, 4], requires_grad=True, device="cuda")
out = foo(inp)
out.sum().backward()
self.assertEqual(self.get_root_children(), [1])
# the three saved tensors should die in the backward
# we kept alive the output
self.assertEqual(self.curr_node().expected_dead_indices_before_graph, [])
self.assertEqual(
self.curr_node().expected_dead_indices_after_graph,
[(0, 1), (0, 2)],
)
self.assertFalse(self.get_manager().new_graph_id().id == 0)
def test_separate_recordings(self):
def foo_unopt(x, y):
return (x + 1) @ y
foo = torch.compile(foo_unopt)
foo_unopt(
torch.ones([20, 20], device="cuda"), torch.ones([20, 20], device="cuda")
)
inps = [
torch.ones([20, 20], device="cuda", requires_grad=False)
for _ in range(2)
]
out = foo(*inps)
torch.cuda.synchronize()
foo(*inps)
torch.cuda.synchronize()
foo(*inps)
torch.cuda.synchronize()
foo_unopt(
torch.ones([20, 20], device="cuda"), torch.ones([20, 20], device="cuda")
)
inps2 = [
torch.rand([40, 40], device="cuda", requires_grad=False)
for _ in range(2)
]
foo(*inps2)
foo(*inps2)
foo(*inps2)
# two separate roots
self.assertEqual(self.get_root_children(), [0, 0])
def test_alias_of_parameter(self):
class AliasMod(nn.Module):
def __init__(self) -> None:
super().__init__()
self.param = torch.nn.Parameter(torch.rand([20, 20], device="cuda"))
def forward(self, x):
return self.param[0], self.param, self.param + x
@torch.compile(mode="reduce-overhead")
def foo(mod, inp):
return mod(inp)
inp = torch.rand([20, 20], device="cuda")
mod = AliasMod()
storage_ref = torch.multiprocessing.reductions.StorageWeakRef(
mod.param.untyped_storage()
)
for _ in range(3):
outs = foo(mod, inp)
self.assertEqual(mod(inp), outs)
self.assertFalse(storage_ref.expired())
node = self.get_manager().current_node
self.assertEqual(len(list(node.path_live_weakrefs())), 1)
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", False)
@torch._inductor.config.patch("triton.cudagraph_support_input_mutation", False)
def test_unstable_ptr(self):
import torch
@torch.compile(mode="reduce-overhead")
def foo(m, inp):
return m(inp)
def f():
l = []
m = torch.nn.Linear(20, 20).cuda()
for _ in range(4):
inp = torch.rand([20, 20], device="cuda")
foo(m, inp)
m.weight.data = torch.rand([20, 20], device="cuda")
self.assertRaises(RuntimeError, f)
@requires_multigpu()
def test_manager_per_device(self):
def test():
def foo(args):
x = args[0]
args.clear()
return (x + 3,)
inp = torch.rand([20, 20], device="cuda:1")
inp_list = [inp]
foo_cg = tree_cudagraphify_impl(
foo,
inp_list,
(),
device_index=1,
is_backward=False,
is_inference=True,
)
for _ in range(3):
self.assertEqual(foo_cg([inp]), foo([inp]))
self.assertTrue(self.get_manager(device_index=0) is None)
self.assertFalse(self.get_manager(device_index=1) is None)
test()
self.assertTrue(self.get_manager(device_index=1) is None)
def test_error_on_dealloc_use(self):
@torch.compile()
def foo(x):
return x * x * x
inp = torch.rand([4], device="cuda")
out = foo(inp)
out2 = foo(inp)
with self.assertRaisesRegex(Exception, "overwritten by a subsequent"):
out + out
foo(inp)
with self.assertRaisesRegex(Exception, "overwritten by a subsequent"):
out2 + out2
def test_error_on_dealloc_use2(self):
@torch.compile()
def foo(x):
return x * x * x
inp = torch.rand([4], device="cuda")
out = foo(inp).detach()
out2 = foo(inp).detach()
with self.assertRaises(Exception) as exc:
out + out
FileCheck().check("overwritten").check("x * x * x").run(repr(exc.exception))
foo(inp)
with self.assertRaises(Exception) as exc:
out2 + out2
FileCheck().check("overwritten").check("x * x * x").run(repr(exc.exception))
@unittest.skipIf(not torch.backends.cudnn.is_available(), "requires cudnn")
def test_conv_benchmark(self):
with torch.backends.cudnn.flags(
enabled=True, benchmark=True, deterministic=False
):
m = torch.nn.Conv2d(5, 6, [3, 3]).cuda()
inp = torch.randn([2, 5, 16, 16]).cuda()
@torch.compile()
def foo(m, inp):
return m(inp)
foo(m, inp)
def test_single_stream_use(self):
@torch.compile()
def foo(x):
return (x * x * x).relu()
inp = torch.rand([4], device="cuda", requires_grad=True)
streams = set()
streams_init = {seg["stream"] for seg in get_all_cudagraph_segments()}
for _ in range(4):
foo(inp).sum().backward()
inp.grad = None
streams = {
seg["stream"] for seg in get_all_cudagraph_segments()
} - streams_init
self.assertEqual(len(streams), 1)
self.assertFalse(self.get_manager().new_graph_id().id == 0)
@torch._dynamo.config.patch("assume_static_by_default", False)
def test_dynamic_backward(self):
def foo(x):
x = torch.cat([x, x])
return torch.addmm(x, x, x).relu(), x.size(0)
opt_foo = torch.compile(mode="reduce-overhead")(foo)
def run_test(foo, inp):
r, s = foo(inp)
r.sum().backward()
g = inp.grad.clone()
inp.grad = None
r = r.clone()
return r, s, g
def run_big_test(inp):
r0, s0, g0 = run_test(foo, inp)
r1, s1, g1 = run_test(opt_foo, inp)
r2, s2, g2 = run_test(opt_foo, inp)
self.assertEqual(r0, r1)
self.assertEqual(r0, r2)
self.assertEqual(s0, s1)
self.assertEqual(s0, s2)
self.assertEqual(g0, g1)
self.assertEqual(g0, g2)
inp = torch.randn(2, 4, device="cuda", requires_grad=True)
run_big_test(inp)
inp = torch.randn(3, 6, device="cuda", requires_grad=True)
run_big_test(inp)
def test_dynamic_warmup(self):
COUNTER = 0
def f(inps):
i, x = inps
inps.clear()
nonlocal COUNTER
COUNTER += 1
return x * 2
x = torch.randn(2, device="cuda")
inp_list = [2, x]
foo_cg = self.cudagraphify_impl(f, inp_list, ())
foo_cg(inp_list) # warmup
foo_cg([2, x]) # record
foo_cg([2, x]) # replay
self.assertEqual(COUNTER, 2)
# Switching the size will require a warmup again
x = torch.randn(3, device="cuda")
inp_list = [3, x]
foo_cg(inp_list) # warmup
foo_cg([3, x]) # record
foo_cg([3, x]) # replay
self.assertEqual(COUNTER, 4)
def test_forward_generation(self):
def foo(x):
return x * x * x
def foo2(x):
return x * 12
foo_opt = torch.compile(foo)
foo2_opt = torch.compile(foo2)
ones = torch.ones([4, 4], device="cuda", requires_grad=True)
out = foo_opt(ones)
out2 = foo2_opt(out)
self.assertEqual(all_live_block_count(), 2)
self.assertTrue(self.get_manager().running_forwards_with_pending_backwards)
out2.sum().backward()
self.assertFalse(self.get_manager().running_forwards_with_pending_backwards)
ones.grad = None
del out
del out2
foo2_opt(foo_opt(ones)).sum().backward()
out = foo_opt(ones.detach())
self.assertFalse(self.get_manager().running_forwards_with_pending_backwards)
self.assertFalse(self.get_manager().new_graph_id().id == 0)
def test_warn_on_pending_backward(self):
@torch.compile
def foo(x):
return x * x * x
out = foo(torch.rand([4, 4], device="cuda", requires_grad=True))
out = foo(torch.rand([4, 4], device="cuda", requires_grad=True))
warnings.resetwarnings()
with warnings.catch_warnings(record=True) as w:
out = foo(torch.rand([4, 4], device="cuda", requires_grad=True))
FileCheck().check(
"Unable to hit fast path of CUDAGraphs because of pending"
).run(str(w[0]))
self.assertTrue(self.get_manager().new_graph_id().id == 0)
def test_mark_step(self):
@torch.compile
def foo(x):
return x * x * x
torch.compiler.cudagraph_mark_step_begin()
out = foo(torch.rand([4, 4], device="cuda", requires_grad=True))
torch.compiler.cudagraph_mark_step_begin()
out = foo(torch.rand([4, 4], device="cuda", requires_grad=True))
self.assertFalse(self.get_manager().new_graph_id().id == 0)
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_incompatible_cudagraph_ops_item(self):
@torch.compile(mode="reduce-overhead")
def foo(x):
return x.item()
# NB: This doesn't work with float, because float unbacked codegen
# is currently broken. But testing the float case here is also
# awkward, because we plan to Tensor-ify the float compute, and as
# a result we'd actually expect this to work with cuda graphs!
with capture_stderr() as captured_output:
self.assertEqual(foo(torch.tensor(3, device="cuda")), 3)
self.assertEqual(foo(torch.tensor(6, device="cuda")), 6)
# NOTE: this test is named after incompatible ops, but is not skipping due to incompatible ops.
# This should get fixed.
FileCheck().check(
" to incompatible op aten._local_scalar_dense.default"
).run(captured_output[0])
self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
@torch._dynamo.config.patch("compiled_autograd", True)
def test_compiled_autograd_static_input_params(self):
@torch.compile(mode="reduce-overhead")
def bwd(loss):
loss.backward()
model = torch.nn.Linear(10, 10, bias=False, device="cuda")
x = torch.randn(10, 10, device="cuda")
for i in range(5):
out = model(x)
bwd(out.sum())
model.weight.grad = None
# i=0, 0 copies (warmup)
# i=1, 2 copies (record, 1/3 inputs marked as static)
# i>1, 0 copies (run)
self.assertEqual(
counters["inductor"]["cudagraph_recorded_non_static_inputs"], 2
)
@torch._dynamo.config.patch("capture_dynamic_output_shape_ops", True)
def test_incompatible_cudagraph_ops_nonzero(self):
@torch.compile(mode="reduce-overhead")
def foo(x):
return x.nonzero()
with capture_stderr() as captured_output:
self.assertEqual(
foo(torch.tensor([1, 0, 2], device="cuda")),
torch.tensor([[0], [2]]),
)
self.assertEqual(
foo(torch.tensor([1, 0, 0], device="cuda")), torch.tensor([[0]])
)
FileCheck().check("incompatible op aten.nonzero.default").check("foo").run(
captured_output[0]
)
self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
@torch._dynamo.config.patch("capture_dynamic_output_shape_ops", True)
def test_incompatible_cudagraph_ops_nonzero_graph_breaks(self):
@torch.compile(mode="reduce-overhead")
def foo(x):
y = x.nonzero() # skip
torch._dynamo.graph_break()
return y.nonzero() # skip 2 times (due to recompile)
foo(torch.tensor([1, 0, 2], device="cuda"))
foo(torch.tensor([1, 0, 0], device="cuda"))
self.assertEqual(counters["inductor"]["cudagraph_skips"], 3)
@torch._dynamo.config.patch("capture_dynamic_output_shape_ops", True)
def test_incompatible_cudagraph_ops_nonzero_backend(self):
@torch.compile(backend="cudagraphs")
def foo(x):
return x.nonzero()
with capture_stderr() as captured_output:
self.assertEqual(
foo(torch.tensor([1, 0, 2], device="cuda")),
torch.tensor([[0], [2]]),
)
self.assertEqual(
foo(torch.tensor([1, 0, 0], device="cuda")), torch.tensor([[0]])
)
FileCheck().check(
"skipping cudagraphs due to incompatible op (nonzero)"
).run(captured_output[0])
self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
def test_storage_access_error(self):
x = torch.rand([4], device="cuda")
torch._C._set_storage_access_error_msg(x, "custom error msg")
with self.assertRaisesRegex(Exception, "custom error msg"):
device = x.untyped_storage()
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", False)
@torch._inductor.config.patch("triton.cudagraph_support_input_mutation", False)
def test_static_inputs_address_mutation_log(self):
class Goo(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(2, 2, device="cuda")
def forward(self, x) -> torch.Tensor:
return self.linear(x)
class Foo(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.static_tensor = torch.zeros((2, 2), device="cuda")
self.goo = Goo()
def forward(self, x) -> torch.Tensor:
self.static_tensor.add_(torch.ones((2, 2), device="cuda"))
return self.static_tensor + x + self.goo(x)
foo = Foo()
foo = torch.compile(foo, mode="reduce-overhead")
inp = torch.rand((2, 2), device="cuda")
for _ in range(3):
foo(inp)
# mutates static input tensors' addresses
foo.static_tensor = torch.ones((2, 2), device="cuda")
foo.goo.linear.bias = torch.nn.Parameter(torch.ones((2,), device="cuda"))
with self.assertRaisesRegex(
Exception,
r"(?s)static input data pointer changed.\n"
r"input name: primals_2. data pointer changed from .* to .*. input stack trace:.*"
r"input name: primals_3. data pointer changed from .* to .*. input stack trace:.*,"
r" in forward\n.* self.static_tensor.add\_\(torch.ones\(\(2, 2\), device=\"cuda\"\)\).*\n",
):
self.curr_node().run(
[foo.goo.linear.weight, foo.goo.linear.bias, foo.static_tensor, inp]
)
def _run_iter(self, param, fn):
fwd_output = fn(torch.ones(2, 2), param)
fwd_output.sum().backward()
grad_output = param.grad.detach().clone()
param.grad = None
return fwd_output, grad_output
def _assert_equal_multi_loop(self, param, fn_eager, fn_compiled):
exp_output, exp_grad = self._run_iter(param, fn_eager)
for _ in range(5):
compiled_output, compiled_grad = self._run_iter(param, fn_compiled)
self.assertEqual(exp_output, compiled_output)
self.assertEqual(exp_grad, compiled_grad)
def run_static_input_param_test(self, fn_eager, num_graphs):
with torch.device("cuda"):
fn_compiled = torch.compile(fn_eager, mode="reduce-overhead")
p1 = torch.nn.Parameter(torch.rand([2, 2]))
self._assert_equal_multi_loop(p1, fn_eager, fn_compiled)
p2 = torch.nn.Parameter(torch.rand([2, 2]))
self._assert_equal_multi_loop(p2, fn_eager, fn_compiled)
# Run p1 again to ensure we reuse the previous recording
self._assert_equal_multi_loop(p1, fn_eager, fn_compiled)
self.assertEqual(self.get_manager().new_graph_id().id, num_graphs)
def _module_test(self, mod, name="weight", param_wrapping=True):
with torch.device("cuda"):
def fn(x, mod):
return mod(x)
fn_compiled = torch.compile(fn, mode="reduce-overhead", fullgraph=True)
def run_test_iter(mod, fn):
fwd_output = fn(torch.ones(2, 2), mod)
fwd_output.sum().backward()
grad_output = mod.weight.grad.detach().clone()
mod.zero_grad()
return fwd_output, grad_output
def run_test():
exp_output, exp_grad = run_test_iter(mod, fn)
for _ in range(5):
compiled_output, compiled_grad = run_test_iter(mod, fn_compiled)
self.assertEqual(exp_output, compiled_output)
self.assertEqual(exp_grad, compiled_grad)
run_test()
old_attr = getattr(mod, name)
modified_attr = torch.rand_like(old_attr)
if param_wrapping:
modified_attr = torch.nn.Parameter(modified_attr)
setattr(mod, name, modified_attr)
run_test()
# Run original version to verify we reuse the other recording
setattr(mod, name, old_attr)
run_test()
# Fwd + bwd graphs for each version of the function => 4 graphs
self.assertEqual(self.get_manager().new_graph_id().id, 4)
@torch._dynamo.config.patch("error_on_recompile", True)
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
def test_multi_dispatch_single_compile_param_inputs(self):
# Verify that we can record multiple cudagraphs for a single
# compiled function with param inputs
def fn(x, y):
return x * y
# Fwd + bwd graphs for each version of the function => 4 graphs
self.run_static_input_param_test(fn, 4)
@torch._dynamo.config.patch("error_on_recompile", True)
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
def test_multi_dispatch_single_compile_builtin_module(self):
# Verify that we don't recompile when changing the param of a builtin module
# and that we record another cudagraph
# Note: Linear is a builtin module so we enable that config setting above
self._module_test(torch.nn.Linear(2, 3, device="cuda"))
@torch._dynamo.config.patch("error_on_recompile", True)
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
def test_multi_dispatch_single_compile_builtin_module_buffers(self):
# Verify that we don't recompile when changing the buffer of a builtin module
# and that we record another cudagraph
self._module_test(
torch.nn.BatchNorm1d(2, device="cuda"),
name="running_mean",
param_wrapping=False,
)
@torch._inductor.config.patch("triton.cudagraphs", True)
@torch._dynamo.config.patch("error_on_recompile", True)
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
def test_multi_dispatch_custom_module(self):
# Test that we can correctly dispatch multiple graphs
# if params of a custom module change
class TestModule(torch.nn.Module):
def __init__(self, param) -> None:
super().__init__()
self.weight = param
def forward(self, x):
return self.weight * x
self._module_test(
TestModule(torch.nn.Parameter(torch.rand([2, 2], device="cuda")))
)
@torch._dynamo.config.patch("error_on_recompile", True)
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
def test_multi_dispatch_custom_module_buffer(self):
# Test that we can correctly dispatch multiple graphs
# if buffers of a custom module change
class TestModule(torch.nn.Module):
def __init__(self, param, buf) -> None:
super().__init__()
self.weight = param
self.buf = torch.nn.Buffer(buf)
def forward(self, x):
return x * self.weight + self.buf
self._module_test(
TestModule(
torch.nn.Parameter(torch.rand([2, 2], device="cuda")),
torch.rand([2, 2], device="cuda"),
),
name="buf",
param_wrapping=False,
)
@torch._inductor.config.patch("triton.cudagraphs", True)
@torch._dynamo.config.patch("error_on_recompile", True)
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
def test_multi_dispatch_child_node(self):
        # Test that we can correctly dispatch multiple graphs when a child node
        # in the tree has its stable input pointers change
def fn(x, p):
# Graph 1
y = x * x
torch._dynamo.graph_break()
# Graph 2
return y * p
# We have 5 graphs here
# Graph 1
# / \
# Graph 2 w/ p1 Graph 2 w/ p2
# and then two backward graphs
self.run_static_input_param_test(fn, 5)
@torch._dynamo.config.patch("error_on_recompile", True)
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
def test_multi_dispatch_parent_node(self):
def fn(x, p):
# Graph 1
y = x * p
torch._dynamo.graph_break()
# Graph 2
return y + x
# We have 6 graphs here
# Graph 1 w/ p1 Graph 1 w/ p2
# | |
# Graph 2 (v1) Graph 2 (v2)
# There are two versions of graph 2 because
# we re-record due to different memory state after running the
# two versions of Graph 1
# and then two backward graphs
self.run_static_input_param_test(fn, 6)
@torch._dynamo.config.patch("error_on_recompile", True)
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", False)
@torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True)
@torch._inductor.config.patch("triton.cudagraph_unexpected_rerecord_limit", 0)
def test_fallback_to_eager_if_recompiling_too_many_times(self):
class Foo(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.param = torch.nn.Parameter(torch.rand([2, 2], device="cuda"))
def forward(self, x):
return x * self.param
with capture_stderr() as captured_output:
# We have 3 graphs here
# None
# / \
# (fwd w/ p1, Graph 0) (bwd w/p2, Graph2)
# (bwd w/ p1, Graph 1)
# All other graphs are skipped because we hit the max recording limit
# (=0 for each node and function pair)
fn_compiled = torch.compile(Foo(), mode="reduce-overhead")
for _ in range(3):
fn_compiled(torch.rand([2, 2], device="cuda")).sum().backward()
fn_compiled.param.grad = None
# Change static tensor address
fn_compiled.param.data = torch.rand([2, 2], device="cuda")
fn_compiled(torch.rand([2, 2], device="cuda")).sum().backward()
self.assertEqual(self.get_manager().new_graph_id().id, 3)
FileCheck().check(
"skipping cudagraph due to function 0 exceeding max re-recording limit (=0) "
"on cudagraph node None due to static input data pointer changed."
).run(captured_output[0])
self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
@torch._dynamo.config.patch("error_on_recompile", True)
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", False)
@torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True)
@torch._inductor.config.patch("triton.cudagraph_unexpected_rerecord_limit", 0)
def test_fallback_to_eager_if_recompiling_too_many_times_warn_only_once(self):
class Foo(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.param = torch.nn.Parameter(torch.rand([2, 2], device="cuda"))
def forward(self, x):
return x * self.param
with capture_stderr() as captured_output:
with torch.device("cuda"):
# We have 3 graphs here
# None
# / \
# (fwd w/ p1, Graph 0) (bwd w/p2, Graph2)
# (bwd w/ p1, Graph 1)
# All other graphs are skipped because we hit the max recording limit
# (=0 for each node and function pair)
fn_compiled = torch.compile(Foo(), mode="reduce-overhead")
for _ in range(3):
fn_compiled(torch.rand([2, 2], device="cuda")).sum().backward()
fn_compiled.param.grad = None
for _ in range(5):
# Change static tensor address
fn_compiled.param.data = torch.rand([2, 2], device="cuda")
fn_compiled(torch.rand([2, 2], device="cuda")).sum().backward()
fn_compiled.param.grad = None
FileCheck().check_count(
"skipping cudagraph due to function 0 exceeding max re-recording limit (=0) "
"on cudagraph node None due to static input data pointer changed.",
1,
exactly=True,
).check_count(
"skipping cudagraph due to function 1 exceeding max re-recording limit (=0) "
"on cudagraph node None due to static input data pointer changed.",
1,
exactly=True,
).run(
captured_output[0]
)
self.assertEqual(counters["inductor"]["cudagraph_skips"], 2)
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", False)
@torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True)
@torch._inductor.config.patch("triton.cudagraph_unexpected_rerecord_limit", 0)
def test_fallback_to_eager_if_recompiling_too_many_times_due_to_cudagraph_managed_tensor(
self,
):
        # Setting triton.cudagraph_support_input_mutation=True forces a re-record
        # if cudagraph-managed tensor addresses change.
@torch.compile(mode="reduce-overhead")
def foo(x):
return x + 1
@torch.compile(mode="reduce-overhead")
def goo(x):
return x * 2
for _ in range(3):
torch.compiler.cudagraph_mark_step_begin()
inp = torch.rand((2, 3), device="cuda")
y = foo(inp)
z = goo(y)
with capture_stderr() as captured_output:
torch.compiler.cudagraph_mark_step_begin()
x = torch.rand(2, 3, device="cuda")
y = foo(x)
y_clone = y.clone()
z = goo(y_clone)
# eager function should run successfully
for _ in range(5):
torch.compiler.cudagraph_mark_step_begin()
x = torch.rand(2, 3, device="cuda")
y = foo(x)
y_clone = y.clone()
z = goo(y_clone)
FileCheck().check_count(
"skipping cudagraph due to function 1 exceeding max re-recording limit (=0) "
"on cudagraph node 0 due to cudagraph managed tensor data pointer changed",
1,
exactly=True,
).run(captured_output[0])
self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", False)
@torch._dynamo.config.patch("error_on_recompile", True)
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
@torch._inductor.config.patch("triton.cudagraph_unexpected_rerecord_limit", 1)
def test_not_fallback_to_eager_if_have_not_recompiling_too_many_times(self):
def fn(x, y):
return x * y
# We have 4 graphs here
# None
# / \
# (fwd w/ p1, Graph 0) (fwd w/p2, Graph2)
# (bwd w/ p1, Graph 1) (bwd w/p2, Graph3)
self.run_static_input_param_test(fn, 4)
self.assertEqual(counters["inductor"]["cudagraph_skips"], 0)
def test_tensor_constant_mutation(self):
class Foo(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.tensor_constant = torch.ones((2, 3), device="cuda")
def forward(self, x: torch.Tensor) -> torch.Tensor:
self.tensor_constant += 1
return x + self.tensor_constant
foo = Foo()
foo = torch.compile(foo, mode="reduce-overhead")
inp = torch.rand((2, 3), device="cuda")
for _ in range(3):
foo(inp)
@torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True)
def test_rerecord_if_static_input_address_changed(self):
        # Setting triton.cudagraph_support_input_mutation=True forces a re-record
        # if static tensor addresses change.
class Goo(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(2, 2, device="cuda")
def forward(self, x) -> torch.Tensor:
return self.linear(x)
class Foo(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.register_buffer(
"static_tensor", torch.zeros((2, 2), device="cuda")
)
self.goo = Goo()
def forward(self, x) -> torch.Tensor:
self.static_tensor.add_(torch.ones((2, 2), device="cuda"))
return self.static_tensor + x + self.goo(x)
foo = Foo()
foo = torch.compile(foo, mode="reduce-overhead")
inp = torch.rand((2, 2), device="cuda")
for _ in range(3):
foo(inp)
# mutates static input tensors' addresses
foo.static_tensor = torch.ones((2, 2), device="cuda")
foo.goo.linear.bias = torch.nn.Parameter(torch.ones((2,), device="cuda"))
if torch._dynamo.config.inline_inbuilt_nn_modules:
for _ in range(3):
foo(inp)
else:
# Run with specific function id to avoid dynamo recompiling
self.get_manager().run(
[
foo.goo.linear.weight,
foo.goo.linear.bias,
foo.static_tensor,
inp,
],
FunctionID(0),
)
self.assertEqual(self.get_manager().new_graph_id().id, 2)
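    # The following tests exercise the dynamic-shape warning: once more distinct
    # input sizes than cudagraph_dynamic_shape_warn_limit have been recorded, a
    # warning is emitted (and emitted only once).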
@torch._inductor.config.patch("triton.cudagraph_dynamic_shape_warn_limit", 1)
def test_skip_if_dynamic_shape_limit_reached1(self):
class Mod(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(3, 3, device="cuda")
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(x)
def iter(batch_size: int, mod: torch.nn.Module):
x = torch.rand((batch_size, 3), device="cuda")
for _ in range(3):
mod(x)
mod = torch.compile(Mod(), mode="reduce-overhead")
with capture_stderr() as captured_output:
for batch_size in range(10, 40, 10):
iter(batch_size, mod)
FileCheck().check(
"CUDAGraph supports dynamic shapes by recording a new graph for each "
"distinct input size. Recording too many CUDAGraphs may lead to "
"extra overhead. We have observed 2 distinct sizes. "
"Please consider the following options for better performance: "
"a) padding inputs to a few fixed number of shapes; or b) set "
"torch._inductor.config.triton.cudagraph_skip_dynamic_graphs=True. "
"Set torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit=None "
"to silence this warning."
).run("\n".join(captured_output))
@torch._inductor.config.patch("triton.cudagraph_dynamic_shape_warn_limit", 1)
def test_skip_if_dynamic_shape_limit_reached2(self):
class Mod(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.attn = torch.nn.MultiheadAttention(
embed_dim=3, num_heads=3, device="cuda"
)
def forward(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> torch.Tensor:
return self.attn(q, k, v)
mod = torch.compile(Mod(), mode="reduce-overhead")
def iter(batch_size: int, length: int):
q = torch.rand((batch_size, length, 3), device="cuda")
k = torch.rand((batch_size, length, 3), device="cuda")
v = torch.rand((batch_size, length, 3), device="cuda")
for _ in range(3):
mod(q, k, v)
with capture_stderr() as captured_output:
for batch_size in range(10, 40, 10):
for length in range(10, 30, 10):
iter(batch_size, length)
print(captured_output)
FileCheck().check(
"CUDAGraph supports dynamic shapes by recording a new graph for each "
"distinct input size. Recording too many CUDAGraphs may lead to "
"extra overhead. We have observed 2 distinct sizes. "
"Please consider the following options for better performance: "
"a) padding inputs to a few fixed number of shapes; or b) set "
"torch._inductor.config.triton.cudagraph_skip_dynamic_graphs=True. "
"Set torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit=None "
"to silence this warning."
).run(captured_output[0])
@torch._inductor.config.patch("triton.cudagraph_dynamic_shape_warn_limit", 1)
def test_warn_once_if_dynamic_shape_limit_reached(self):
class Mod(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(3, 3, device="cuda")
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(x)
def iter(batch_size: int, mod: torch.nn.Module):
x = torch.rand((batch_size, 3), device="cuda")
for _ in range(3):
mod(x)
mod = torch.compile(Mod(), mode="reduce-overhead")
with capture_stderr() as captured_output:
for batch_size in range(10, 200, 10):
iter(batch_size, mod)
print(captured_output)
FileCheck().check_count(
"CUDAGraph supports dynamic shapes by recording a new graph for each "
"distinct input size. Recording too many CUDAGraphs may lead to "
"extra overhead. We have observed 2 distinct sizes. "
"Please consider the following options for better performance: "
"a) padding inputs to a few fixed number of shapes; or b) set "
"torch._inductor.config.triton.cudagraph_skip_dynamic_graphs=True. "
"Set torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit=None "
"to silence this warning.",
1,
exactly=True,
).run("\n".join(captured_output))
@torch._inductor.config.patch("cpp_wrapper", 1)
def test_cpp_wrapper(self):
def f(x):
return torch.sin(x)
compiled = torch.compile(f, mode="reduce-overhead")
example_input = torch.randn(10, device="cuda")
compiled_result = self.run_twc(compiled, example_input)
eager_result = f(example_input)
self.assertEqual(compiled_result, eager_result)
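# Tests for rng handling under selective activation checkpointing (SAC): rng ops
# inside a checkpointed region should produce the same values when they are
# recomputed in the backward pass as they did in the forward pass.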
class TestSAC(TestCase):
def _make_observer_mode(self):
class ObserverMode(TorchDispatchMode):
def __init__(self):
super().__init__()
self.curr_run = 0
self.op_outputs = defaultdict(list)
def __torch_dispatch__(
self,
func: OpOverload,
types: Sequence[type],
args: Sequence[object] = (),
kwargs: Mapping[str, object] = immutable_dict(),
) -> object:
return func(*args, **kwargs)
return ObserverMode
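    # Each test below registers a py_impl of the rng higher-order op(s) on
    # ObserverMode so that every rng op output is recorded; matching
    # forward/backward outputs can then be compared directly.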
def test_simple(self):
device = "cuda"
from torch._prims.rng_prims import graphsafe_run_with_rng_state
ObserverMode = self._make_observer_mode()
@graphsafe_run_with_rng_state.py_impl(ObserverMode)
def _(mode, op, *args, **kwargs):
with no_dispatch():
out = graphsafe_run_with_rng_state(op, *args, **kwargs)
mode.op_outputs[op].append(out)
return out
obs = ObserverMode()
x = torch.randn(4, 4, device=device, requires_grad=True)
y = torch.randn(4, 4, device=device, requires_grad=True)
for _ in range(2):
torch._dynamo.reset()
def gn(x, y):
return torch.sigmoid(torch.rand_like(x) * y) * x
def fn(x, y):
x = torch.sin(x)
x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
x = torch.sin(x)
return x
aot_eager_decomp_partition = functools.partial(
aot_eager_decomp_partition_with_mode, mode=obs
)
fn = torch.compile(fn, backend=aot_eager_decomp_partition)
fn(x, y).sum().backward()
self.assertEqual(len(obs.op_outputs[aten.rand.default]), 4)
for i in range(2):
self.assertEqual(
obs.op_outputs[aten.rand.default][0 + 2 * i],
obs.op_outputs[aten.rand.default][1 + 2 * i],
)
self.assertNotEqual(
obs.op_outputs[aten.rand.default][0],
obs.op_outputs[aten.rand.default][2],
)
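    # A minimal sketch (in comment form) of the graphsafe generator API that the
    # next test exercises; the same calls appear in the test body below:
    #
    #     gen = torch.cuda.default_generators[0]
    #     old_state = gen.graphsafe_get_state()
    #     new_state = old_state.clone_state()   # independent state, same value
    #     gen.graphsafe_set_state(new_state)    # later rng ops draw from new_state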
def test_cudagraph_uneven_forward_backward(self):
        # torch.compile cudagraphs are difficult to test directly here because the
        # rng state updating is sensitive to the number of pending backwards, etc.
        # This is a short repro that mimics the runtime wrappers integration and
        # shows that updating the backward rng state works under cudagraphs:
def forward():
state = torch.cuda.get_rng_state()
perm = torch.randperm(10, device="cuda")
return state, perm
def backward(rng_state):
current_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state.cpu())
perm = torch.randperm(10, device="cuda")
torch.cuda.set_rng_state(current_state)
return perm
def normal_test():
state, perm = forward()
repro_perm = backward(state)
return perm, repro_perm
def graphsafe_forward():
perm = torch.randperm(10, device="cuda")
return perm
def graphsafe_backward(generator, new_state):
current_state = generator.graphsafe_get_state()
generator.graphsafe_set_state(new_state)
perm = torch.randperm(10, device="cuda")
generator.graphsafe_set_state(current_state)
return perm
def graph_test(generator, capture_cuda_graph):
if capture_cuda_graph:
graph = torch.cuda.CUDAGraph()
            # the rng state should be cloned before graph capture
old_state = generator.graphsafe_get_state()
new_state = old_state.clone_state()
if capture_cuda_graph:
                # the cloned state should be registered with the graph
graph.register_generator_state(new_state)
# only capturing the backward
with torch.cuda.graph(graph):
repro_perm = graphsafe_backward(generator, new_state)
# some number of uneven forwards
graphsafe_forward()
graphsafe_forward()
graphsafe_forward()
# state prior to rng invocation
state = generator.get_state()
perm = graphsafe_forward()
new_state.set_state(state)
if capture_cuda_graph:
graph.replay()
else:
repro_perm = graphsafe_backward(generator, new_state)
return perm, repro_perm
self.assertEqual(*normal_test())
generator = torch.cuda.default_generators[0]
self.assertEqual(*graph_test(generator, capture_cuda_graph=False))
self.assertEqual(*graph_test(generator, capture_cuda_graph=True))
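    # Mixed devices: the checkpointed rand runs on CUDA while the checkpointed
    # randperm runs on CPU. Both should replay the same values in the backward,
    # which is why the observer hooks all three rng higher-order ops below.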
def test_cpu_and_cuda_rng(self):
device = "cuda"
ObserverMode = self._make_observer_mode()
from torch._prims.rng_prims import (
graphsafe_run_with_rng_state,
run_and_save_rng_state,
run_with_rng_state,
)
for hop in [
graphsafe_run_with_rng_state,
run_and_save_rng_state,
run_with_rng_state,
]:
def make_impl(hop):
@hop.py_impl(ObserverMode)
def _(mode, *args, **kwargs):
with no_dispatch():
out = hop(*args, **kwargs)
op = None
for inp in itertools.chain(args, kwargs.values()):
if isinstance(inp, torch._ops.OpOverload):
op = inp
break
assert op is not None
if hop is run_and_save_rng_state:
mode.op_outputs[op].append(out[1])
else:
mode.op_outputs[op].append(out)
return out
make_impl(hop)
obs = ObserverMode()
def gn(x, y):
return torch.sigmoid(torch.rand_like(x) * y) * x
def gn2(x):
return x * torch.randperm(x.numel(), device=x.device).reshape(x.shape)
def fn(x, y, z):
x = torch.sin(x)
x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
x = torch.sin(x)
z = torch.utils.checkpoint.checkpoint(gn2, z, use_reentrant=True)
return x * z.cuda()
aot_eager_decomp_partition = functools.partial(
aot_eager_decomp_partition_with_mode, mode=obs
)
fn = torch.compile(fn, backend=aot_eager_decomp_partition)
x = torch.randn(4, 4, device=device, requires_grad=True)
y = torch.randn(4, 4, device=device, requires_grad=True)
z = torch.randn(4, 4, requires_grad=True)
fn(x, y, z).sum().backward()
for op in [aten.rand.default, aten.randperm.default]:
self.assertEqual(len(obs.op_outputs[op]), 2)
self.assertEqual(
obs.op_outputs[op][0],
obs.op_outputs[op][1],
)
self.assertEqual(
obs.op_outputs[op][0].device.type,
"cpu" if op == aten.randperm.default else "cuda",
)
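    # Backwards may be invoked in a different order than their forwards (or with
    # several forwards pending). Each run's backward should still observe the
    # same rng values as its own forward, and distinct runs should differ.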
@parametrize("order", (list(itertools.permutations([0, 1, 2]))))
def test_uneven_forward_backward(self, order):
device = "cuda"
ObserverMode = self._make_observer_mode()
from torch._prims.rng_prims import graphsafe_run_with_rng_state
@graphsafe_run_with_rng_state.py_impl(ObserverMode)
def _(mode, op, *args, **kwargs):
with no_dispatch():
out = graphsafe_run_with_rng_state(op, *args, **kwargs)
mode.op_outputs[(mode.curr_run, op)].append(out)
return out
obs = ObserverMode()
def gn(x, y):
return torch.sigmoid(torch.rand_like(x) * y) * x
def gn2(x):
return x * torch.randperm(x.numel(), device=x.device).reshape(x.shape)
def fn(x, y):
x = torch.sin(x)
x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
x = torch.sin(x)
x = torch.utils.checkpoint.checkpoint(gn2, x, use_reentrant=True)
return x
aot_eager_decomp_partition = functools.partial(
aot_eager_decomp_partition_with_mode, mode=obs
)
fn_c = torch.compile(fn, backend=aot_eager_decomp_partition)
torch.manual_seed(0)
outs = []
for i in range(len(order)):
obs.curr_run = i
x = torch.randn(4, 4, device=device, requires_grad=True)
y = torch.randn(4, 4, device=device, requires_grad=True)
outs.append(fn_c(x, y))
for idx in order:
obs.curr_run = idx
outs[idx].sum().backward()
for run in range(len(order)):
for op in (aten.rand.default, aten.randperm.default):
self.assertEqual(len(obs.op_outputs[(run, op)]), 2)
self.assertEqual(
obs.op_outputs[(run, op)][0],
obs.op_outputs[(run, op)][1],
)
if run != 0:
self.assertNotEqual(
obs.op_outputs[(run - 1, op)][0],
obs.op_outputs[(run, op)][0],
)
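    # With fallback_random enabled (and the test config that lets the graphsafe
    # rng path ignore it), the aot_eager_decomp_partition backend and the
    # cudagraph "reduce-overhead" backend should produce identical outputs and
    # gradients, with no cudagraph skips.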
@config.patch(fallback_random=True)
@config.patch("test_configs.graphsafe_rng_func_ignores_fallback_random", True)
def _test_cudagraphs_aot_eager_compat_equal(self, device):
def gn(x, y):
return torch.sigmoid(torch.rand_like(x) * y) * x
def fn(x, y):
x = torch.sin(x)
x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
x = torch.sin(x)
return x
outs = []
grads = []
outs2 = []
grads2 = []
compile_fns = [
lambda fn: torch.compile(fn, backend="aot_eager_decomp_partition"),
lambda fn: torch.compile(fn, mode="reduce-overhead"),
]
for i, compile_fn in enumerate(compile_fns):
torch.manual_seed(0)
for index in range(3):
x = torch.randn(4, 4, device=device, requires_grad=True)
y = torch.randn(4, 4, device=device, requires_grad=True)
out = compile_fn(fn)(x, y)
torch.cuda.synchronize()
out.sum().backward()
if i == 0:
outs.append(out.clone())
grads.append((x.grad.clone(), y.grad.clone()))
else:
outs2.append(out.clone())
grads2.append((x.grad.clone(), y.grad.clone()))
self.assertEqual(outs, outs2)
self.assertEqual(grads, grads2)
self.assertEqual(counters["inductor"]["cudagraph_skips"], 0)
def test_cudagraphs_aot_eager_compat_equal(self):
self._test_cudagraphs_aot_eager_compat_equal(torch.device("cuda:0"))
@requires_multigpu()
def test_cudagraphs_aot_eager_compat_equal_device_one(self):
self._test_cudagraphs_aot_eager_compat_equal(torch.device("cuda:1"))
@requires_multigpu()
def test_multi_device(self):
def gn(x, y):
return torch.sigmoid(torch.rand_like(x) * y) * x
def fn(x, y):
x = torch.sin(x)
x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
x = torch.sin(x)
return x
def multi_fn(x, y, a, b):
return fn(x, y), fn(a, b)
x = torch.randn(4, 4, device="cuda:0", requires_grad=True)
y = torch.randn(4, 4, device="cuda:0", requires_grad=True)
a = torch.randn(4, 4, device="cuda:1", requires_grad=True)
b = torch.randn(4, 4, device="cuda:1", requires_grad=True)
        # No errors expected. TODO - get graphs from logging, couldn't figure out how
multi_fn_c = torch.compile(multi_fn, backend="aot_eager_decomp_partition")
out = multi_fn_c(x, y, a, b)
out[0].sum().backward()
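    # backward(retain_graph=True) followed by a second backward replays the same
    # checkpointed rand twice; all three observed outputs (one forward, two
    # backward) should match.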
def test_retain_graph(self):
device = "cuda"
ObserverMode = self._make_observer_mode()
from torch._prims.rng_prims import graphsafe_run_with_rng_state
@graphsafe_run_with_rng_state.py_impl(ObserverMode)
def _(mode, op, *args, **kwargs):
with no_dispatch():
out = graphsafe_run_with_rng_state(op, *args, **kwargs)
mode.op_outputs[op].append(out)
return out
obs = ObserverMode()
def gn(x, y):
return torch.sigmoid(torch.rand_like(x) * y) * x
def fn(x, y):
x = torch.sin(x)
x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
x = torch.sin(x)
return x
x = torch.randn(4, 4, device=device, requires_grad=True)
y = torch.randn(4, 4, device=device, requires_grad=True)
aot_eager_decomp_partition = functools.partial(
aot_eager_decomp_partition_with_mode, mode=obs
)
fn = torch.compile(fn, backend=aot_eager_decomp_partition)
out = fn(x, y).sum()
out.backward(retain_graph=True)
out.backward()
self.assertEqual(len(obs.op_outputs[aten.rand.default]), 3)
self.assertEqual(
obs.op_outputs[aten.rand.default][0],
obs.op_outputs[aten.rand.default][1],
)
self.assertEqual(
obs.op_outputs[aten.rand.default][1],
obs.op_outputs[aten.rand.default][2],
)
instantiate_parametrized_tests(CudaGraphTreeTests)
instantiate_parametrized_tests(TestSAC)
if __name__ == "__main__":
from torch._inductor.test_case import run_tests
if not TEST_CUDA_GRAPH:
if __name__ == "__main__":
sys.exit(0)
raise unittest.SkipTest("cuda graph test is skipped")
if HAS_CUDA:
run_tests(needs="filelock")