Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
TODO:
- [x] Add handling for when forward is invoked multiple times without invoking backward, so that the fwd/backward states are out of sync
- [x] Update rng state initialization to take from the correct device
- [x] Tests
- [x] Handling of retain_graph
- [x] Respect fallback_random

Fix for https://github.com/pytorch/pytorch/issues/130123.

Updates the aot_eager and cudagraph compilation of `run_and_save_rng_state` to use the new mechanism added by https://github.com/pytorch/pytorch/pull/114068 for CUDA-graph-safe rng states. We keep a pair of rng states, one for the forward and one for the backward. In both the forward and the backward, the rng op is run with `graphsafe_run_with_rng_state`, which takes an RNG state and hooks it onto the current RNG generator before running the operator. The fwd/bwd rng states are initialized with the same value, and we ensure that for any given run of the forward, the corresponding backward run sees the same rng state for the op as was observed in the forward.

```
===== Forward graph 1 =====
/data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", fwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = fwd_rng_state_0);  fwd_rng_state_0 = None
        ...

===== Backward graph 1 =====
    def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", tangents_1: "f32[4, 4][4, 1]cuda:0", bwd_rng_state_0):
        sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)

        # No stacktrace found for following nodes
        graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = bwd_rng_state_0);  bwd_rng_state_0 = None
```

There is some extra complication when a user either calls backward with retain_graph or calls the backwards in a different order than the forwards. If a user has states fwd_rng_state0, bwd_rng_state0 and calls:

- fwd0: fwd_rng_state0 -> fwd_rng_state1
- fwd1: fwd_rng_state1 -> fwd_rng_state2
- bwd1
- bwd0

then naively, when bwd1 is invoked, the bwd rng states would not equal the states that were observed in fwd1. I added handling for this in the aot runtime wrappers: they detect pending backward invocations and the current position of the bwd rng states, and update the states when necessary.

Other notes:

Because nodes that appear later in the forward appear earlier in the backward, we need a separate rng state for each operator. If we reused one rng state across ops, the forward and backward would run with different rng states, i.e. the states would not be applied in the same order.

Questions for reviewers:

This does change numerics, because the rng of the op is now taken from the input rng state instead of whatever the rng would be midway through running the graph. Technically we only need this for cuda graphs, but I'd prefer not to have an rng divergence just for cudagraphs, so I am making it respect `fallback_random`. Edit: decided to apply this to non-cudagraphs as well, so long as `fallback_random` is not set.

I'm initializing the rng states by cloning the current state. If you had something like 5 different rands in the model with the same shape, they'd all get the same value, which doesn't seem great. I could use some other initialization scheme, like taking the seed from graph position. Not sure; let me know your thoughts. Edit/Update: changed so the initial rng states are taken from torch.randint().

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146878
Approved by: https://github.com/anijain2305, https://github.com/bdhirsh
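To make the bookkeeping above concrete, here is a minimal sketch, not the actual AOTAutograd runtime-wrapper code, of how pairing a forward and a backward generator per op and snapshotting the forward state for each pending backward keeps the backward RNG in sync even when backwards run out of order or with retain_graph. It uses plain CPU `torch.Generator` objects for clarity; `PerOpRNGState`, `run_forward`, `run_backward`, and `run_id` are illustrative names, not PyTorch APIs. The real mechanism threads CUDA-graph-safe generator states through `graphsafe_run_with_rng_state`.

```python
import torch


class PerOpRNGState:
    """Toy model of the per-op fwd/bwd RNG-state pairing (illustration only)."""

    def __init__(self, seed: int) -> None:
        # Both generators start from the same state, mirroring how the fwd and
        # bwd rng states are initialized to the same value.
        self.fwd_gen = torch.Generator().manual_seed(seed)
        self.bwd_gen = torch.Generator().manual_seed(seed)
        self.pending: dict[int, torch.Tensor] = {}  # fwd snapshots awaiting their backward

    def run_forward(self, run_id: int, shape) -> torch.Tensor:
        # Snapshot the forward generator *before* drawing, so the matching
        # backward can replay from exactly this point, in any order.
        self.pending[run_id] = self.fwd_gen.get_state()
        return torch.rand(shape, generator=self.fwd_gen)

    def run_backward(self, run_id: int, shape) -> torch.Tensor:
        # Rewind the backward generator to the state its forward observed.
        self.bwd_gen.set_state(self.pending.pop(run_id))
        return torch.rand(shape, generator=self.bwd_gen)


# Mirrors the fwd0 / fwd1 / bwd1 / bwd0 scenario from the description.
state = PerOpRNGState(seed=0)
f0 = state.run_forward(0, [4])   # fwd0: advances fwd state 0 -> 1
f1 = state.run_forward(1, [4])   # fwd1: advances fwd state 1 -> 2
b1 = state.run_backward(1, [4])  # bwd1 runs first, still matches fwd1
b0 = state.run_backward(0, [4])  # bwd0 matches fwd0
assert torch.equal(f1, b1) and torch.equal(f0, b0)
```

The same idea explains why a separate state is needed per rng op: each op's forward snapshot must be replayed by exactly that op's backward, so states cannot be shared across ops that run in reversed order in the backward.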
2926 lines | 108 KiB | Python
# Owner(s): ["module: inductor"]
# ruff: noqa: F841
import contextlib
import functools
import gc
import importlib
import itertools
import sys
import unittest
import warnings
from collections import defaultdict
from collections.abc import Mapping, Sequence

import torch
import torch._dynamo.config as dynamo_config
import torch.nn as nn
from torch._dynamo.backends.debugging import aot_eager_decomp_partition_with_mode
from torch._dynamo.utils import counters
from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
from torch._inductor import config
from torch._inductor.codecache import FxGraphCache
from torch._inductor.compile_fx import compile_fx_inner
from torch._inductor.cudagraph_trees import cudagraphify_impl as tree_cudagraphify_impl
from torch._inductor.cudagraph_utils import FunctionID
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._ops import OpOverload
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.immutable_collections import immutable_dict
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    IS_CI,
    IS_LINUX,
    IS_WINDOWS,
    parametrize,
    skipIfRocm,
    TEST_CUDA_GRAPH,
    TEST_WITH_ASAN,
)
from torch.utils._mode_utils import no_dispatch
from torch.utils._python_dispatch import TorchDispatchMode


if IS_WINDOWS and IS_CI:
    sys.stderr.write(
        "Windows CI does not have necessary dependencies for test_torchinductor yet\n"
    )
    if __name__ == "__main__":
        sys.exit(0)
    raise unittest.SkipTest("requires sympy/functorch/filelock")

importlib.import_module("functorch")
importlib.import_module("filelock")

from torch.testing._internal.inductor_utils import HAS_CUDA


aten = torch.ops.aten
requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
requires_multigpu = functools.partial(
    unittest.skipIf, not TEST_MULTIGPU, "requires multiple cuda devices"
)
from io import StringIO

def get_compile_fn(backend):
    if backend == "cudagraphs":
        return functools.partial(torch.compile, backend="cudagraphs")
    else:
        return functools.partial(torch.compile, mode="reduce-overhead")


class capture_stderr(list):
    """
    Replace sys.stderr with a temporary StringIO
    """

    def __enter__(self):
        self.sys_stderr = sys.stderr
        self.stringio = StringIO()
        sys.stderr = self.stringio
        return self

    def __exit__(self, *args):
        self.append(str(self.stringio.getvalue()))
        del self.stringio
        sys.stderr = self.sys_stderr


def cdata(t):
    return t.untyped_storage()._cdata


class TestCase(InductorTestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls._stack = contextlib.ExitStack()
        cls._stack.enter_context(
            config.patch(
                {
                    "debug": True,
                    "cpp.min_chunk_size": 1,
                    "triton.autotune_pointwise": False,  # too slow
                    "implicit_fallbacks": False,
                }
            )
        )

    @classmethod
    def tearDownClass(cls):
        cls._stack.close()
        super().tearDownClass()

    def setUp(self):
        torch._dynamo.reset()
        super().setUp()

    def tearDown(self):
        super().tearDown()
        torch._dynamo.reset()

if HAS_CUDA and not TEST_WITH_ASAN:

    def get_all_cudagraph_segments():
        segments = torch.cuda.memory_snapshot()
        return [segment for segment in segments if segment["segment_pool_id"] != (0, 0)]

    def all_live_blocks():
        blocks_addrs = []
        for segment in get_all_cudagraph_segments():
            addr = segment["address"]
            for block in segment["blocks"]:
                if block["state"] == "active_allocated":
                    blocks_addrs.append(addr)
                addr += block["size"]

        return blocks_addrs

    def all_live_block_count():
        return len(all_live_blocks())

    class CudaGraphTreeTests(TestCase):
        def setUp(self):
            super().setUp()
            self.graph_stack = contextlib.ExitStack()
            self.graph_stack.enter_context(
                config.patch(
                    {
                        "triton.cudagraphs": True,
                        "triton.cudagraph_trees": True,
                        "triton.fast_path_cudagraph_asserts": True,  # too slow
                        "triton.slow_path_cudagraph_asserts": True,
                    }
                )
            )
            self.graph_stack.enter_context(
                dynamo_config.patch(automatic_dynamic_shapes=True)
            )
            self.device_idx = torch.rand([0], device="cuda").device.index
            warnings.filterwarnings("ignore")

        def tearDown(self):
            super().tearDown()
            torch._dynamo.reset()
            gc.collect()
            torch.cuda.empty_cache()
            self.graph_stack.close()

            self.assertIsNone(self.get_manager())
            self.assertEqual(all_live_block_count(), 0)
            self.assertEqual(len(get_all_cudagraph_segments()), 0)
            warnings.resetwarnings()

        def get_manager(self, device_index=None):
            return torch._inductor.cudagraph_trees.get_container(
                self.device_idx if not device_index else device_index
            ).tree_manager

        def get_roots(self):
            return self.get_manager().get_roots()

        def curr_node(self):
            return self.get_manager().current_node

        def get_root_children(self):
            return [root.num_descendants() for root in self.get_roots()]

        def cudagraphify_impl(
            self, *args, is_inference=True, is_backward=False, **kwargs
        ):
            return tree_cudagraphify_impl(
                *args,
                **kwargs,
                device_index=self.device_idx,
                is_inference=is_inference,
                is_backward=is_backward,
            )

        @staticmethod
        def run_twc(fn, *args, **kwargs):
            fn(*args, **kwargs)
            return fn(*args, **kwargs)

        def num_checkpoints(self):
            return self.get_manager().debug_checkpointing_counter

        def test_run_simple(self):
            def foo(x):
                return x * x * x

            foo_opt = torch.compile(foo)
            ones = torch.ones([4, 4], device="cuda")
            zeros = torch.zeros([5, 5], device="cuda")
            self.run_twc(foo_opt, ones)
            self.run_twc(foo_opt, zeros)
            self.assertEqual(self.get_root_children(), [0, 0])

        def check_rng(self):
            @torch.compile(mode="reduce-overhead")
            def foo():
                return torch.rand([20])

            torch.manual_seed(0)

            out = foo()
            out2 = foo()
            out3 = foo()

            torch.manual_seed(0)

            self.assertEqual(out, foo())
            self.assertEqual(out2, foo())
            self.assertEqual(out3, foo())

        @torch._inductor.config.patch("fallback_random", True)
        def test_rng_trees(self):
            self.check_rng()

        @torch._inductor.config.patch("triton.cudagraph_trees", False)
        @torch._inductor.config.patch("fallback_random", True)
        def test_rng_non_trees(self):
            self.check_rng()

def test_mutation_reinplaced(self):
|
|
import torch.nn as nn
|
|
|
|
class Model(nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, input, other, out):
|
|
input = torch.logical_xor(input=input, other=other, out=out)
|
|
return input
|
|
|
|
x = torch.rand([1, 2, 1, 4, 9, 7], dtype=torch.float32).cuda()
|
|
y = torch.rand([1, 2, 1, 4, 9, 7], dtype=torch.float32).cuda()
|
|
z = torch.rand([1, 2, 1, 4, 9, 7], dtype=torch.float16).cuda()
|
|
|
|
model = Model().cuda()
|
|
eag = model(x, y, z)
|
|
with capture_stderr() as captured_output:
|
|
opt = torch.compile(model.forward, mode="reduce-overhead")(x, y, z)
|
|
|
|
FileCheck().check(
|
|
"skipping cudagraphs due to mutated inputs (1 instances). Found from"
|
|
).check("torch.logical_xor").run(captured_output[0])
|
|
self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
|
|
|
|
@requires_multigpu()
|
|
@parametrize("backend", ("inductor", "cudagraphs"))
|
|
def test_multiple_devices_msg(self, backend):
|
|
def foo(x, y):
|
|
return (x + 1, y + 2)
|
|
|
|
foo = get_compile_fn(backend)(foo)
|
|
with capture_stderr() as captured_output:
|
|
foo(torch.ones([10], device="cuda"), torch.ones([20]))
|
|
|
|
FileCheck().check(
|
|
"skipping cudagraphs due to cpu device (arg1_1). Found from"
|
|
).check("y + 2").run(captured_output[0])
|
|
self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
|
|
|
|
with capture_stderr() as captured_output:
|
|
foo(
|
|
torch.ones([10], device="cuda:0"), torch.ones([10], device="cuda:1")
|
|
)
|
|
|
|
FileCheck().check("skipping cudagraphs due to multiple devices").run(
|
|
captured_output[0]
|
|
)
|
|
self.assertEqual(counters["inductor"]["cudagraph_skips"], 2)
|
|
|
|
@torch._inductor.config.patch("triton.cudagraph_skip_dynamic_graphs", True)
|
|
def test_skip_symbolic(self):
|
|
@torch.compile(dynamic=True)
|
|
def foo(x, y):
|
|
return x + y
|
|
|
|
with capture_stderr() as captured_output:
|
|
foo(torch.rand([10], device="cuda"), torch.rand([10], device="cuda"))
|
|
|
|
FileCheck().check(
|
|
"skipping cudagraphs due to graph with symbolic shapes inputs"
|
|
).check("x + y").run(captured_output[0])
|
|
self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
|
|
|
|
@parametrize("backend", ("inductor", "cudagraphs"))
|
|
@torch._dynamo.config.patch("cudagraph_backend_keep_input_mutation", True)
|
|
@torch._dynamo.config.patch("cudagraph_backend_support_input_mutation", True)
|
|
@torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True)
|
|
def test_mutation_on_inp(self, backend):
|
|
def foo(x):
|
|
x.add_(2)
|
|
return x
|
|
|
|
foo = get_compile_fn(backend)(foo)
|
|
|
|
def inp():
|
|
return torch.ones([10], device="cuda")
|
|
|
|
with capture_stderr() as captured_output:
|
|
foo(inp())
|
|
|
|
FileCheck().check(
|
|
"skipping cudagraphs due to mutated inputs (1 instances). Found from"
|
|
).check(".add_(2)").run(captured_output[0])
|
|
self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
|
|
|
|
# mutation on inp doesnt hit cudagraphs
|
|
self.assertEqual(len(self.get_manager().roots), 0)
|
|
|
|
# mutation on parameters/buffers hits cudagraphs
|
|
class Mod(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.buf = torch.ones([10], device="cuda")
|
|
|
|
def forward(self, x):
|
|
self.buf.add_(x)
|
|
return self.buf + x
|
|
|
|
def foo(mod, x):
|
|
return mod(x)
|
|
|
|
foo = get_compile_fn(backend)(foo)
|
|
mod = Mod()
|
|
mod2 = Mod()
|
|
|
|
for _ in range(3):
|
|
self.assertEqual(foo(mod, inp()), mod2(inp()))
|
|
self.assertEqual(mod.buf, mod2.buf)
|
|
|
|
self.assertIsNotNone(self.get_manager())
|
|
|
|
@parametrize("backend", ("inductor", "cudagraphs"))
|
|
@torch._dynamo.config.patch("cudagraph_backend_keep_input_mutation", True)
|
|
@torch._dynamo.config.patch("cudagraph_backend_support_input_mutation", False)
|
|
@torch._inductor.config.patch("triton.cudagraph_support_input_mutation", False)
|
|
def test_mutation_cudagraph_managed_tensors_config(self, backend):
|
|
def foo(x):
|
|
return x + 1
|
|
|
|
def mut(x):
|
|
x.add_(2)
|
|
return x
|
|
|
|
def non_mut(x):
|
|
return x.add(2)
|
|
|
|
mut = get_compile_fn(backend)(mut)
|
|
foo = get_compile_fn(backend)(foo)
|
|
|
|
with capture_stderr() as captured_output:
|
|
for i in range(3):
|
|
torch.compiler.cudagraph_mark_step_begin()
|
|
inp = torch.rand([4], device="cuda")
|
|
|
|
tmp = foo(inp)
|
|
mut_out = mut(tmp)
|
|
self.assertEqual(mut_out, non_mut(foo(inp)))
|
|
FileCheck().check_count(
|
|
"skipping cudagraphs due to mutated inputs (1 instances). Found from",
|
|
1,
|
|
exactly=True,
|
|
).run(captured_output[0])
|
|
|
|
@parametrize("backend", ("inductor", "cudagraphs"))
|
|
@torch._dynamo.config.patch("cudagraph_backend_keep_input_mutation", True)
|
|
@torch._dynamo.config.patch("cudagraph_backend_support_input_mutation", True)
|
|
@torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True)
|
|
def test_mutation_cudagraph_managed_tensors(self, backend):
|
|
def foo(x):
|
|
return x + 1
|
|
|
|
def mut(x):
|
|
x.add_(2)
|
|
return x
|
|
|
|
def non_mut(x):
|
|
return x.add(2)
|
|
|
|
mut = get_compile_fn(backend)(mut)
|
|
foo = get_compile_fn(backend)(foo)
|
|
|
|
with capture_stderr() as captured_output:
|
|
for i in range(3):
|
|
torch.compiler.cudagraph_mark_step_begin()
|
|
inp = torch.rand([4], device="cuda")
|
|
|
|
tmp = foo(inp)
|
|
mut_out = mut(tmp)
|
|
self.assertEqual(mut_out, non_mut(foo(inp)))
|
|
FileCheck().check_count(
|
|
"skipping cudagraphs due to mutated inputs (1 instances). Found from",
|
|
0,
|
|
exactly=True,
|
|
).run(captured_output[0])
|
|
self.assertTrue("cudagraph_skips" not in counters["inductor"])
|
|
|
|
torch.compiler.cudagraph_mark_step_begin()
|
|
inp = torch.rand([4], device="cuda")
|
|
tmp = foo(inp)
|
|
mut_inp = tmp.clone()
|
|
# in this case, what previously a mutated cudagraph managed tensor is no longer,
|
|
# now its an input from eager we should fallback to inductor without cudagraphs
|
|
with capture_stderr() as captured_output:
|
|
mut(mut_inp)
|
|
FileCheck().check(
|
|
"skipping cudagraphs due to mutated inputs (1 instances). Found from"
|
|
).check("x.add_(2)").run(captured_output[0])
|
|
self.assertEqual(mut_inp, non_mut(foo(inp)))
|
|
self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
|
|
|
|
@parametrize("backend", ("inductor", "cudagraphs"))
|
|
@torch._dynamo.config.patch("cudagraph_backend_keep_input_mutation", True)
|
|
@torch._dynamo.config.patch("cudagraph_backend_support_input_mutation", True)
|
|
@torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True)
|
|
def test_mutation_cudagraph_managed_tensor_warn(self, backend):
|
|
def foo(x):
|
|
return x.add_(1)
|
|
|
|
def fee(y, z):
|
|
return z.add(3)
|
|
|
|
def inp():
|
|
return torch.rand([4], device="cuda")
|
|
|
|
foo = get_compile_fn(backend)(foo)
|
|
fee = get_compile_fn(backend)(fee)
|
|
|
|
with capture_stderr() as captured_output:
|
|
for _ in range(3):
|
|
torch.compiler.cudagraph_mark_step_begin()
|
|
fee(inp(), foo(inp()))
|
|
FileCheck().check_count(
|
|
"skipping cudagraphs due to mutated inputs (1 instances). Found from",
|
|
1,
|
|
exactly=True,
|
|
).run(captured_output[0])
|
|
self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
|
|
|
|
@parametrize("backend", ("inductor", "cudagraphs"))
|
|
@torch._dynamo.config.patch("cudagraph_backend_keep_input_mutation", True)
|
|
@torch._dynamo.config.patch("cudagraph_backend_support_input_mutation", True)
|
|
@torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True)
|
|
def test_mutation_cudagraph_managed_tensor_warn_only_once(self, backend):
|
|
def foo(x):
|
|
return x + 1
|
|
|
|
def mut(x):
|
|
x.add_(2)
|
|
return x
|
|
|
|
def inp():
|
|
return torch.rand([4], device="cuda")
|
|
|
|
mut = get_compile_fn(backend)(mut)
|
|
foo = get_compile_fn(backend)(foo)
|
|
|
|
with capture_stderr() as captured_output:
|
|
# Should warn for current_node=None
|
|
mut(inp())
|
|
|
|
for i in range(3):
|
|
torch.compiler.cudagraph_mark_step_begin()
|
|
tmp = foo(inp())
|
|
mut(tmp) # should not warn
|
|
|
|
mut_inp = tmp.clone()
|
|
mut(mut_inp) # should not warn since mut has warned
|
|
|
|
FileCheck().check_count(
|
|
"skipping cudagraphs due to mutated inputs (1 instances). Found from",
|
|
1,
|
|
exactly=True,
|
|
).run(captured_output[0])
|
|
self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
|
|
|
|
def test_function_compiled_multiple_times(self):
|
|
def foo(x):
|
|
y = foo2(x)
|
|
y2 = foo2(y)
|
|
return y + y2
|
|
|
|
def foo2(x):
|
|
torch._dynamo.graph_break()
|
|
return x * x * x
|
|
|
|
foo_opt = torch.compile(foo)
|
|
ones = torch.ones([4, 4], device="cuda")
|
|
foo(ones)
|
|
foo_opt(ones)
|
|
foo_opt(ones)
|
|
self.assertEqual(foo_opt(ones), foo(ones))
|
|
# paths
|
|
children = self.get_root_children()
|
|
# one root with two children
|
|
self.assertEqual(children, [2])
|
|
|
|
def test_end_recording_early(self):
|
|
def foo(x):
|
|
y = x * x * x
|
|
torch._dynamo.graph_break()
|
|
z = x + y
|
|
return z
|
|
|
|
@torch.compile
|
|
def foo2(x):
|
|
return x + 4
|
|
|
|
foo_opt = torch.compile(foo)
|
|
|
|
for _ in range(3):
|
|
out = foo_opt(torch.ones([4, 4], device="cuda"))
|
|
del out
|
|
|
|
# when I tried inducing separate recordings via graph break,
|
|
# the frame kept interferring by keeping outputs alive
|
|
# this isnt great by simulates the logic.
|
|
from torch._dynamo.mutation_guard import GenerationTracker
|
|
|
|
GenerationTracker.generation -= 1
|
|
|
|
out = foo2(torch.ones([4, 4], device="cuda"))
|
|
del out
|
|
|
|
foo_opt(torch.ones([4, 4], device="cuda"))
|
|
|
|
# Two separate traces - one has a child, one doesnt
|
|
self.assertEqual(self.get_root_children(), [1, 0])
|
|
|
|
def test_execution_into_recording(self):
|
|
def foo(x):
|
|
y = x + x
|
|
|
|
if y.sum() > 0:
|
|
return y + 10
|
|
else:
|
|
return y - 10
|
|
|
|
foo_opt = torch.compile(foo)
|
|
inp = torch.zeros([4, 4], dtype=torch.float, device="cuda")
|
|
self.assertEqual(foo_opt(inp), foo(inp))
|
|
self.assertEqual(foo_opt(inp), foo(inp))
|
|
|
|
inp.add_(1)
|
|
out_eager = foo(inp)
|
|
out_warmup = foo_opt(inp)
|
|
self.assertEqual(out_warmup, out_eager)
|
|
# warmup should be have storage deallocator hooked on
|
|
self.assertEqual(all_live_block_count(), 1)
|
|
|
|
out_live = foo_opt(inp)
|
|
self.assertEqual(out_live, out_eager)
|
|
|
|
# should be in recording mode, with storage deallocator hooked on
|
|
self.assertEqual(all_live_block_count(), 1)
|
|
# warmup should have been freed
|
|
del out_warmup
|
|
# should be in recording mode, with storage deallocator hooked on
|
|
self.assertEqual(all_live_block_count(), 1)
|
|
|
|
del out_live
|
|
self.assertEqual(all_live_block_count(), 0)
|
|
|
|
out = foo_opt(inp)
|
|
self.assertEqual(foo(inp), out)
|
|
|
|
# should be in execution mode
|
|
self.assertEqual(all_live_block_count(), 0)
|
|
|
|
def test_forward_with_skipped_cudagraphed_backward(self):
|
|
@torch.compile(mode="reduce-overhead")
|
|
def foo(x):
|
|
return x * x * x
|
|
|
|
for _ in range(3):
|
|
inp = torch.rand([20, 20], device="cuda", requires_grad=True)
|
|
out = foo(inp)
|
|
|
|
with config.patch(always_complex_memory_overlap_TESTING_ONLY=True):
|
|
back_inp = torch.empty_strided([20, 20], [0, 1], device="cuda")
|
|
out.backward(back_inp)
|
|
|
|
# we should not have cudagraph'd the backwards
|
|
new_id = self.get_manager().new_graph_id().id
|
|
self.assertEqual(new_id, 1)
|
|
|
|
self.assertFalse(self.get_manager().running_forwards_with_pending_backwards)
|
|
|
|
@torch._functorch.config.patch("enable_autograd_cache", True)
|
|
@torch._inductor.config.patch("fx_graph_cache", True)
|
|
@torch._inductor.config.patch("fx_graph_remote_cache", False)
|
|
# Currently fx graph cache is turned off for specialize_float=False
|
|
@torch._dynamo.config.patch("specialize_float", True)
|
|
def test_cache_hit_forward_miss_backward(self):
|
|
# Test that we don't cache cudagraphs, skipping cudagraphs on backward on a cache miss
|
|
|
|
@torch.compile(mode="reduce-overhead")
|
|
def foo(x):
|
|
return x * x * x
|
|
|
|
# Run forwards, fx graph should cache miss
|
|
for _ in range(3):
|
|
torch._dynamo.reset()
|
|
counters.clear()
|
|
FxGraphCache.clear()
|
|
AOTAutogradCache.clear()
|
|
|
|
with config.patch(always_complex_memory_overlap_TESTING_ONLY=True):
|
|
inp = torch.rand([20, 20], device="cuda", requires_grad=True)
|
|
out = foo(inp)
|
|
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
|
|
|
|
# Reset dynamo and related caches except for FXGraphCache
|
|
torch._dynamo.reset()
|
|
# Forwards should be a cache hit now, we still skip cudagraphs
|
|
inp = torch.rand([20, 20], device="cuda", requires_grad=True)
|
|
out = foo(inp)
|
|
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
|
|
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
|
|
|
|
# Run backward without complex memory overlap being set
|
|
|
|
# Run the backward without complex memory overlap reason
|
|
# cache should miss, but cudagraphs should not run
|
|
# because forward skipped it
|
|
back_inp = torch.empty_strided([20, 20], [0, 1], device="cuda")
|
|
out.backward(back_inp)
|
|
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
|
|
|
|
# Run it one more time, this time AOTAutogradCache will hit
|
|
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
|
|
self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
|
|
|
|
torch._dynamo.reset()
|
|
inp = torch.rand([20, 20], device="cuda", requires_grad=True)
|
|
out = foo(inp)
|
|
back_inp = torch.empty_strided([20, 20], [0, 1], device="cuda")
|
|
out.backward(back_inp)
|
|
|
|
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
|
|
|
|
# we should not have cudagraph'd anything
|
|
assert self.get_manager() is None
|
|
|
|
@torch._functorch.config.patch("enable_autograd_cache", True)
|
|
@torch._inductor.config.patch("fx_graph_cache", True)
|
|
@torch._inductor.config.patch("fx_graph_remote_cache", False)
|
|
# Currently fx graph cache is turned off for specialize_float=False
|
|
@torch._dynamo.config.patch("specialize_float", True)
|
|
def test_backward_gets_cached_cudagraphs(self):
|
|
# We pass cpu tensors to foo and save that into the cache
|
|
# On a subsequent run in a new process, cudagraphs should be
|
|
# disabled properly on both forward and backwards runs.
|
|
|
|
@torch.compile(mode="reduce-overhead")
|
|
def foo(x):
|
|
return x * x * x
|
|
|
|
torch._dynamo.reset()
|
|
counters.clear()
|
|
FxGraphCache.clear()
|
|
AOTAutogradCache.clear()
|
|
|
|
# Use cpu device to disable cudagraphs during compilation
|
|
inp = torch.rand([20, 20], device="cpu", requires_grad=True)
|
|
out = foo(inp)
|
|
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
|
|
|
|
back_inp = torch.empty_strided([20, 20], [0, 1], device="cpu")
|
|
out.backward(back_inp)
|
|
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
|
|
|
|
# Run again on new process
|
|
torch._dynamo.reset()
|
|
|
|
# Forward and backward should also disable cudagraphs without compilation
|
|
inp = torch.rand([20, 20], device="cpu", requires_grad=True)
|
|
out = foo(inp)
|
|
# AOTAutogradCache will load the forward and the backward from cache immediately, so fx_graph_cache_hit will equal 2
|
|
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 2)
|
|
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
|
|
torch._dynamo.reset()
|
|
|
|
back_inp = torch.empty_strided([20, 20], [0, 1], device="cpu")
|
|
out.backward(back_inp)
|
|
|
|
# we should not have cudagraph'd anything
|
|
assert self.get_manager() is None
|
|
|
|
@torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
|
|
@torch._functorch.config.patch("enable_autograd_cache", True)
|
|
@torch._inductor.config.patch("fx_graph_cache", True)
|
|
@torch._inductor.config.patch("fx_graph_remote_cache", False)
|
|
# Currently fx graph cache is turned off for specialize_float=False
|
|
@torch._dynamo.config.patch("specialize_float", True)
|
|
def test_cached_forward_backward(self):
|
|
counters.clear()
|
|
AOTAutogradCache.clear()
|
|
FxGraphCache.clear()
|
|
|
|
@torch.compile
|
|
def foo(x):
|
|
torch.manual_seed(0)
|
|
y = x * 2
|
|
return torch.sin(y) * torch.nn.functional.dropout(x, p=0.4)
|
|
|
|
inp = torch.rand([4, 4], requires_grad=True, device="cuda")
|
|
inp2 = inp.detach().clone().requires_grad_(True)
|
|
out = foo(inp)
|
|
|
|
out.sum().backward()
|
|
|
|
self.assertEqual(self.get_root_children(), [1])
|
|
|
|
# the three saved tensors should die in the backward
|
|
# we kept alive the output
|
|
self.assertEqual(self.curr_node().expected_dead_indices_before_graph, [])
|
|
self.assertEqual(
|
|
self.curr_node().expected_dead_indices_after_graph,
|
|
[(0, 1), (0, 2)],
|
|
)
|
|
self.assertFalse(self.get_manager().new_graph_id().id == 0)
|
|
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
|
|
|
|
# Reset dynamo and rerun. We should see a cache hit now
|
|
torch._dynamo.reset()
|
|
|
|
out2 = foo(inp2)
|
|
out2.sum().backward()
|
|
self.assertEqual(out, out2)
|
|
self.assertEqual(inp.grad, inp2.grad)
|
|
|
|
self.assertEqual(self.get_root_children(), [1])
|
|
self.assertFalse(self.get_manager().new_graph_id().id == 0)
|
|
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
|
|
|
|
@parametrize("backend", ("inductor", "cudagraphs"))
|
|
def test_forward_backward_not_called(self, backend):
|
|
def foo(x, y):
|
|
x_out = x * x * x
|
|
torch._dynamo.graph_break()
|
|
y_out = y * y * y
|
|
return x_out, y_out
|
|
|
|
foo = get_compile_fn(backend)(foo)
|
|
|
|
for _ in range(3):
|
|
inps = [
|
|
torch.rand([20, 20], requires_grad=True, device="cuda")
|
|
for _ in range(2)
|
|
]
|
|
x_out, y_out = foo(inps[0], inps[1])
|
|
x_out.sum().backward()
|
|
|
|
self.assertFalse(self.get_manager().running_forwards_with_pending_backwards)
|
|
|
|
# we should not have cudagraph'd the y backward
|
|
new_id = self.get_manager().new_graph_id().id
|
|
self.assertEqual(new_id, 3)
|
|
|
|
def _test_unaligned_static_input_impl(self, expected_clones):
|
|
def fn(x, y):
|
|
return (x + y,)
|
|
|
|
def get_aligned_inputs():
|
|
return [torch.rand([5, 5], device="cuda") for _ in range(2)]
|
|
|
|
mod = make_fx(fn)(*get_aligned_inputs())
|
|
|
|
mode = torch._subclasses.FakeTensorMode()
|
|
|
|
with mode:
|
|
inps = [torch.rand([6, 5], device="cuda")[1:] for _ in range(2)]
|
|
|
|
compiled_f = compile_fx_inner(
|
|
mod, inps, static_input_idxs=[0], cudagraphs=True
|
|
)
|
|
|
|
def get_unaligned_inputs():
|
|
return [torch.rand([6, 5], device="cuda")[1:] for _ in range(2)]
|
|
|
|
class CloneCounterMode(TorchDispatchMode):
|
|
def __init__(self) -> None:
|
|
self.count = 0
|
|
|
|
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
|
kwargs = {} if kwargs is None else kwargs
|
|
self.count += func is torch.ops.aten.clone.default
|
|
return func(*args, **kwargs)
|
|
|
|
for _ in range(3):
|
|
with CloneCounterMode() as m:
|
|
compiled_f(get_unaligned_inputs())
|
|
self.assertEqual(m.count, expected_clones)
|
|
|
|
compiled_f(get_aligned_inputs())
|
|
self.assertEqual(m.count, expected_clones)
|
|
|
|
def test_unaligned_static_input_trees(self):
|
|
self._test_unaligned_static_input_impl(expected_clones=0)
|
|
|
|
@torch._inductor.config.patch("triton.cudagraph_trees", False)
|
|
def test_unaligned_static_input_non_trees(self):
|
|
self._test_unaligned_static_input_impl(expected_clones=0)
|
|
|
|
@torch._inductor.config.patch("triton.cudagraphs", False)
|
|
def test_unaligned_static_input_no_cudagraphs(self):
|
|
self._test_unaligned_static_input_impl(expected_clones=0)
|
|
|
|
def test_sparsity(self):
|
|
def foo(view_6, buf31):
|
|
return aten._sparse_coo_tensor_with_dims_and_tensors(
|
|
1,
|
|
1,
|
|
[1000000, 64],
|
|
view_6,
|
|
buf31,
|
|
dtype=torch.float32,
|
|
layout=torch.sparse_coo,
|
|
device="cuda",
|
|
pin_memory=None,
|
|
)
|
|
|
|
foo_opt = torch.compile(foo)
|
|
|
|
view_6 = torch.zeros([1, 102397], dtype=torch.int64, device="cuda")
|
|
buf31 = torch.rand([102397, 64], device="cuda")
|
|
|
|
for _ in range(3):
|
|
self.assertEqual(foo_opt(view_6, buf31), foo(view_6, buf31))
|
|
|
|
def test_accumulate_multiple_recordings(self):
|
|
def foo(x):
|
|
y = x + x + x
|
|
torch._dynamo.graph_break()
|
|
if y.sum() <= 0:
|
|
return y
|
|
else:
|
|
return y * 10
|
|
|
|
foo_opt = torch.compile(foo)
|
|
|
|
# two separate compilations & recordings
|
|
out1 = self.run_twc(foo_opt, torch.zeros([5], device="cuda"))
|
|
|
|
# out1 gets manually freed
|
|
out2 = self.run_twc(foo_opt, torch.zeros([6], device="cuda"))
|
|
|
|
self.assertEqual(all_live_block_count(), 1)
|
|
|
|
out3 = self.run_twc(foo_opt, torch.ones([5], device="cuda"))
|
|
|
|
self.assertEqual(out3, foo(torch.ones([5], device="cuda")))
|
|
|
|
self.assertEqual(all_live_block_count(), 1)
|
|
del out1, out2
|
|
self.assertEqual(all_live_block_count(), 1)
|
|
|
|
del out3
|
|
gc.collect()
|
|
self.assertEqual(all_live_block_count(), 0)
|
|
|
|
@torch._inductor.config.patch("freezing", True)
|
|
def test_constant_output(self):
|
|
class Mod(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.param = torch.nn.Parameter(
|
|
torch.tensor([float(i) for i in range(10)], device="cuda")
|
|
)
|
|
|
|
def forward(self, inp):
|
|
return self.param, self.param[0:2], inp + 2
|
|
|
|
inp = torch.tensor([2], device="cuda")
|
|
m = Mod()
|
|
with torch.no_grad():
|
|
out_eager = m(inp)
|
|
|
|
m_comp = torch.compile(m)
|
|
for _ in range(3):
|
|
self.assertEqual(out_eager, m_comp(inp))
|
|
|
|
def test_live_outputs_multiple_graphs(self):
|
|
def foo(x):
|
|
x = x + x + x
|
|
y = x + 1
|
|
torch._dynamo.graph_break()
|
|
z = x * x
|
|
if z.sum() > 0:
|
|
return y + 1
|
|
else:
|
|
return y
|
|
|
|
foo_opt = torch.compile(foo)
|
|
|
|
self.run_twc(foo_opt, torch.zeros([5], device="cuda"))
|
|
self.assertEqual(self.num_checkpoints(), 0)
|
|
out = self.run_twc(foo_opt, torch.ones([5], device="cuda"))
|
|
|
|
self.assertEqual(all_live_block_count(), 1)
|
|
|
|
del out
|
|
self.assertEqual(all_live_block_count(), 0)
|
|
|
|
# we need to checkpoint from function to warmup y + 1,
|
|
# and then again to record it
|
|
self.assertEqual(self.num_checkpoints(), 2)
|
|
|
|
def test_expanded_inputs(self):
|
|
x = torch.rand(1, 512, device="cuda").expand(4, 512)
|
|
|
|
def foo(x):
|
|
return x + 4 + torch.ones([4, 512], device="cuda")
|
|
|
|
foo_opt = torch.compile()(foo)
|
|
|
|
for _ in range(3):
|
|
self.assertEqual(foo_opt(x), foo(x))
|
|
|
|
self.assertFalse(self.get_manager().new_graph_id().id == 0)
|
|
|
|
@torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
|
|
def test_tensor_dies_between_checkpoint(self):
|
|
def foo(args):
|
|
x = args[0]
|
|
args.clear()
|
|
return x + 1, x + 2
|
|
|
|
inp = torch.rand([4], device="cuda")
|
|
inp_list = [inp]
|
|
foo_cg = self.cudagraphify_impl(foo, inp_list, ())
|
|
foo_cg(inp_list)
|
|
foo_cg([inp])
|
|
|
|
out1, out2 = foo_cg([inp])
|
|
inp = [out1]
|
|
|
|
del out1, out2
|
|
|
|
def foo2(args):
|
|
x = args[0]
|
|
args.clear()
|
|
return [x * x * x]
|
|
|
|
self.assertEqual(self.num_checkpoints(), 0)
|
|
foo2_cg = self.cudagraphify_impl(foo2, inp, ())
|
|
|
|
x = foo2_cg(inp)[0]
|
|
|
|
self.assertEqual(self.num_checkpoints(), 1)
|
|
# out2 dies between the previous recording and the new one,
|
|
# need to be manually deallocated after the checkpoint
|
|
|
|
self.assertEqual(all_live_block_count(), 1)
|
|
del x
|
|
self.assertEqual(all_live_block_count(), 0)
|
|
|
|
def test_aliased_storage_single_weakref(self):
|
|
@torch.compile(mode="reduce-overhead")
|
|
def foo(x):
|
|
x = x * 20
|
|
x_alias = x[0]
|
|
y = x * 10
|
|
y_alias = y[0]
|
|
torch._dynamo.graph_break()
|
|
ind = torch.tensor(4, device="cuda")
|
|
x_alias2 = x[ind:]
|
|
y_alias2 = y[ind:]
|
|
return x, x_alias, x_alias2, y_alias, y_alias2
|
|
|
|
for _ in range(4):
|
|
outs = foo(torch.rand([20, 20], device="cuda"))
|
|
|
|
ptr_to_ref = {
|
|
out.untyped_storage().data_ptr(): out.untyped_storage()._cdata
|
|
for out in outs
|
|
}
|
|
|
|
self.assertEqual(len(ptr_to_ref), 2)
|
|
for out in outs:
|
|
self.assertEqual(
|
|
ptr_to_ref[out.untyped_storage().data_ptr()],
|
|
out.untyped_storage()._cdata,
|
|
)
|
|
del outs
|
|
del out
|
|
|
|
node = self.get_manager().current_node
|
|
self.assertEqual(len(list(node.path_live_weakrefs())), 0)
|
|
self.assertFalse(self.get_manager().new_graph_id().id == 0)
|
|
|
|
def test_aliasing_static_ref(self):
|
|
class Mod(torch.nn.Linear):
|
|
def forward(self, x):
|
|
return self.weight.T @ x, self.weight.T, self.weight[0:4]
|
|
|
|
m = Mod(10, 10).cuda()
|
|
|
|
@torch.compile(mode="reduce-overhead")
|
|
def foo(mod, x):
|
|
return mod(x)
|
|
|
|
@torch.compile(mode="reduce-overhead")
|
|
def foo2(x):
|
|
return x[2:]
|
|
|
|
param_c = cdata(m.weight)
|
|
for _ in range(3):
|
|
x = torch.rand([10, 10], device="cuda", requires_grad=True)
|
|
torch.compiler.cudagraph_mark_step_begin()
|
|
out1, alias_1, alias_2 = foo(m, x)
|
|
self.assertEqual(len({param_c, cdata(alias_1), cdata(alias_2)}), 1)
|
|
|
|
out2 = foo2(out1)
|
|
out2.sum().backward()
|
|
self.assertEqual(cdata(out1), cdata(out2))
|
|
m.weight.grad = None
|
|
m.bias.grad = None
|
|
|
|
node = self.curr_node()
|
|
first_node = next(node._path_from_root)
|
|
self.assertFalse(first_node.unaliased_in_all_paths[0])
|
|
self.assertTrue(first_node.cached_tensor_outputs[0] is None)
|
|
|
|
@torch._inductor.config.patch("implicit_fallbacks", True)
|
|
def test_multinomial(self):
|
|
def sample_multinomial(probs, num_samples, replacement=True):
|
|
return torch.multinomial(probs, num_samples, replacement=replacement)
|
|
|
|
# Create and prepare probability tensor on GPU
|
|
probs = torch.tensor([0.1, 0.2, 0.3, 0.4]).cuda()
|
|
probs = probs / probs.sum()
|
|
|
|
# Sample using the function
|
|
num_skipped = counters["inductor"]["cudagraph_skips"]
|
|
|
|
with torch._dynamo.utils.preserve_rng_state():
|
|
samples = self.run_twc(
|
|
sample_multinomial, probs, num_samples=5, replacement=True
|
|
)
|
|
|
|
with torch._dynamo.utils.preserve_rng_state():
|
|
samples_compiled = self.run_twc(
|
|
torch.compile(sample_multinomial),
|
|
probs,
|
|
num_samples=5,
|
|
replacement=True,
|
|
)
|
|
|
|
self.assertEqual(samples, samples_compiled)
|
|
self.assertEqual(num_skipped, counters["inductor"]["cudagraph_skips"])
|
|
|
|
@skipIfRocm
|
|
def test_checkpointing_resets_persistent_refs(self):
|
|
@torch.compile(mode="reduce-overhead")
|
|
def foo(x):
|
|
return x @ x
|
|
|
|
def inp():
|
|
return torch.rand([20, 20], device="cuda", requires_grad=False)
|
|
|
|
for _ in range(3):
|
|
foo(inp())
|
|
|
|
self.assertEqual(self.num_checkpoints(), 0)
|
|
|
|
out = foo(inp())
|
|
out_id = id(out)
|
|
del out
|
|
self.assertEqual(id(foo(inp())), out_id)
|
|
|
|
@torch.compile(mode="reduce-overhead")
|
|
def foo2(x):
|
|
return x[0], x @ x
|
|
|
|
for i in range(2):
|
|
out = foo(inp())
|
|
|
|
from torch._dynamo.mutation_guard import GenerationTracker
|
|
|
|
GenerationTracker.generation -= 1
|
|
|
|
out_alias, out2 = foo2(out)
|
|
del out_alias
|
|
|
|
self.assertEqual(all_live_block_count(), 2)
|
|
del out
|
|
self.assertEqual(all_live_block_count(), 1)
|
|
del out2
|
|
self.assertEqual(all_live_block_count(), 0)
|
|
|
|
self.assertEqual(self.num_checkpoints(), i + 1)
|
|
|
|
new_out = foo(inp())
|
|
curr_node = self.curr_node()
|
|
self.assertFalse(curr_node.unaliased_in_all_paths[0])
|
|
self.assertFalse(out_id == id(new_out))
|
|
|
|
def test_aliased_static_parameter(self):
|
|
inp = torch.rand([20, 20], device="cuda")
|
|
|
|
def foo(args):
|
|
x = args[0]
|
|
args.clear()
|
|
return (x[0],)
|
|
|
|
foo_cg = self.cudagraphify_impl(foo, [inp], (0,))
|
|
|
|
for _ in range(3):
|
|
out = foo_cg([inp])[0]
|
|
self.assertEqual(cdata(inp), cdata(out))
|
|
|
|
node = self.curr_node()
|
|
self.assertEqual(node.cached_tensor_outputs, [None])
|
|
self.assertEqual(node.unaliased_in_all_paths, [False])
|
|
|
|
def test_warmup_stream_sync(self):
|
|
def foo(args):
|
|
x = args[0]
|
|
args.clear()
|
|
x_orig = x
|
|
for _ in range(100):
|
|
x = x @ x
|
|
return (x,)
|
|
|
|
inp = torch.rand([4096, 4096], device="cuda")
|
|
ref = foo([inp])[0]
|
|
torch.cuda.synchronize()
|
|
|
|
user_stream = torch.cuda.Stream()
|
|
with torch.cuda.stream(user_stream):
|
|
foo_cg = self.cudagraphify_impl(foo, [inp], (0,))
|
|
out = foo_cg([inp])[0]
|
|
y = out + 1
|
|
self.assertEqual(y, ref + 1)
|
|
|
|
def test_unaligned_static_parameter(self):
|
|
def gen_inp():
|
|
inp = torch.ones([20], device="cuda")
|
|
return [inp[1:]]
|
|
|
|
def foo(args):
|
|
x = args[0]
|
|
args.clear()
|
|
return (x + x,)
|
|
|
|
foo_cg = self.cudagraphify_impl(foo, gen_inp(), (0,))
|
|
|
|
for _ in range(3):
|
|
out = foo_cg(gen_inp())
|
|
self.assertEqual(out, foo(gen_inp()))
|
|
del out
|
|
|
|
node = self.curr_node()
|
|
self.assertEqual(node.static_input_data_ptrs, [None])
|
|
|
|
def test_amp_cache_disabled(self):
|
|
@torch.compile()
|
|
def foo(x):
|
|
return x + x
|
|
|
|
for _ in range(3):
|
|
out = foo(torch.rand([4, 4], device="cuda", requires_grad=True))
|
|
|
|
# amp cache for cudagraph outputs should be disabled
|
|
t2 = torch.rand([4, 4], device="cuda")
|
|
|
|
with torch.cuda.amp.autocast():
|
|
run_once = out @ t2
|
|
|
|
out.detach().zero_()
|
|
|
|
run_twice = out @ t2
|
|
|
|
self.assertNotEqual(run_once, run_twice)
|
|
|
|
def test_remove_hooks_on_cached_tensors(self):
|
|
@torch.compile()
|
|
def foo(x):
|
|
return x * x
|
|
|
|
inp = torch.rand([4], device="cuda", requires_grad=True)
|
|
|
|
for _ in range(5):
|
|
out = foo(inp)
|
|
self.assertIsNone(out._backward_hooks)
|
|
out.register_hook(lambda: None)
|
|
|
|
# today, torch.compile never outputs a leaf tensor which is the only
|
|
# tensor that can register _post_accumulate_grad_hooks
|
|
# add this as a preventative test
|
|
|
|
@torch.compile()
|
|
def foo(x):
|
|
return torch.rand([4], device="cuda", requires_grad=True)
|
|
|
|
for _ in range(5):
|
|
out = foo(inp)
|
|
self.assertIsNone(out._post_accumulate_grad_hooks)
|
|
out.register_post_accumulate_grad_hook(lambda: None)
|
|
|
|
def test_multiple_insert_removal_caching(self):
|
|
torch._C._set_cached_tensors_enabled(True)
|
|
try:
|
|
x = torch.rand([4], device="cuda")
|
|
|
|
torch._C._add_cached_tensor(x)
|
|
self.assertTrue(torch._C._is_cached_tensor(x))
|
|
|
|
torch._C._add_cached_tensor(x)
|
|
torch._C._remove_cached_tensor(x)
|
|
|
|
self.assertFalse(torch._C._is_cached_tensor(x))
|
|
finally:
|
|
torch._C._set_cached_tensors_enabled(False)
|
|
|
|
def test_accumulate_grad(self):
|
|
# cudagraph trees shouldnt interfere with accumulation logic
|
|
|
|
def compute_grad(grad_output, create_graph):
|
|
x = torch.randn(5, 5, requires_grad=True, device="cuda")
|
|
|
|
@torch.compile()
|
|
def foo(x):
|
|
return x + 2
|
|
|
|
y = foo(x)
|
|
y.backward(grad_output, retain_graph=True)
|
|
x_grad = x.grad
|
|
x_grad_clone = x.grad.clone()
|
|
y.backward(grad_output, create_graph=create_graph)
|
|
return x_grad, x_grad_clone
|
|
|
|
for _ in range(3):
|
|
grad_output = torch.ones(5, 5, device="cuda")
|
|
|
|
# Accumulate in-place when create_graph is False
|
|
x_grad, x_grad_clone = compute_grad(grad_output, create_graph=False)
|
|
self.assertEqual(x_grad, x_grad_clone * 2)
|
|
|
|
# Accumulate out-of-place when create_graph is False
|
|
x_grad, x_grad_clone = compute_grad(grad_output, create_graph=True)
|
|
self.assertEqual(x_grad, x_grad_clone)
|
|
|
|
def test_frozen_fn(self):
|
|
@torch.compile()
|
|
def foo(x):
|
|
return x @ x
|
|
|
|
for _ in range(3):
|
|
out = foo(torch.rand([10, 10], device="cuda"))
|
|
|
|
self.assertTrue(self.get_manager().new_graph_id().id == 1)
|
|
frozen = torch._dynamo.run(foo)
|
|
|
|
for _ in range(3):
|
|
out = frozen(torch.rand([10, 10], device="cuda"))
|
|
|
|
# didnt do additional recordings
|
|
self.assertTrue(self.get_manager().new_graph_id().id == 2)
|
|
|
|
def test_empty_cpu_tensor(self):
|
|
def foo(x):
|
|
return x @ x, torch.tensor([])
|
|
|
|
foo_opt = torch.compile(foo)
|
|
x = torch.rand([4], device="cuda")
|
|
|
|
for _ in range(3):
|
|
out_opt = foo_opt(x)
|
|
self.assertEqual(foo(x), out_opt)
|
|
|
|
self.assertTrue(self.get_manager().new_graph_id().id == 1)
|
|
|
|
def test_output_alias(self):
|
|
inp = torch.rand([20, 20], device="cuda")
|
|
|
|
def foo(args):
|
|
x = args[0]
|
|
args.clear()
|
|
out = x + x
|
|
return (x, x[0])
|
|
|
|
foo_cg = self.cudagraphify_impl(foo, [inp], ())
|
|
|
|
for _ in range(3):
|
|
out_1, out_2 = foo_cg([inp])
|
|
self.assertEqual(cdata(out_1), cdata(out_2))
|
|
del out_1, out_2
|
|
self.assertEqual(len(list(self.curr_node().path_live_weakrefs())), 0)
|
|
|
|
self.assertEqual(self.curr_node().cached_tensor_outputs, [None, None])
|
|
|
|
def test_empty_storage(self):
|
|
@torch.compile(mode="reduce-overhead")
|
|
def foo(x):
|
|
return (
|
|
(x + x + x),
|
|
torch.zeros([0], device="cuda"),
|
|
torch.zeros([100], device="cuda")[0:0],
|
|
)
|
|
|
|
inp = torch.rand([4], device="cuda")
|
|
for _ in range(3):
|
|
out = foo(inp)
|
|
node = self.curr_node()
|
|
self.assertEqual(len(list(node.path_live_weakrefs())), 1)
|
|
|
|
@torch.compile(mode="reduce-overhead")
|
|
def foo(x):
|
|
return (x + x + x), torch.rand([4], device="cuda") + 10
|
|
|
|
inp = torch.rand([0], device="cuda")
|
|
for _ in range(3):
|
|
out = foo(inp)
|
|
node = self.curr_node()
|
|
self.assertEqual(len(list(node.path_live_weakrefs())), 1)
|
|
|
|
@torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
|
|
def test_aliased_output_checkpoint(self):
|
|
def foo(args):
|
|
x = args[0]
|
|
args.clear()
|
|
y = x + 2
|
|
return x + 1, y, y[0]
|
|
|
|
inp = torch.rand([4, 4], device="cuda")
|
|
foo_cg = self.cudagraphify_impl(foo, [inp], ())
|
|
foo_cg([inp])
|
|
foo_cg([inp])
|
|
|
|
out1, out2, out3 = foo_cg([inp])
|
|
inp = [out1]
|
|
|
|
del out1, out2, out3
|
|
|
|
def foo2(args):
|
|
x = args[0]
|
|
args.clear()
|
|
return [x * x * x]
|
|
|
|
self.assertEqual(self.num_checkpoints(), 0)
|
|
foo2_cg = self.cudagraphify_impl(foo2, inp, ())
|
|
|
|
x = foo2_cg(inp)[0]
|
|
|
|
self.assertEqual(self.num_checkpoints(), 1)
|
|
# out2 and out3 dies between the previous recording and the new one,
|
|
# need to be manually deallocated after the checkpoint
|
|
|
|
self.assertEqual(all_live_block_count(), 1)
|
|
del x
|
|
self.assertEqual(all_live_block_count(), 0)
|
|
|
|
@skipIfRocm
|
|
@unittest.skipIf(not IS_LINUX, "cpp contexts are linux only")
|
|
@torch._inductor.config.patch("triton.cudagraph_trees_history_recording", True)
|
|
def test_workspace_allocation_error(self):
|
|
torch._C._cuda_clearCublasWorkspaces()
|
|
|
|
prev = torch._inductor.cudagraph_trees.clear_cublas_manager
|
|
|
|
try:
|
|
torch._inductor.cudagraph_trees.clear_cublas_manager = (
|
|
contextlib.nullcontext
|
|
)
|
|
|
|
@torch.compile()
|
|
def foo(x, y):
|
|
return x @ x
|
|
|
|
inps = [torch.rand([400, 400], device="cuda") for _ in range(2)]
|
|
|
|
thrown = False
|
|
try:
|
|
foo(*inps)
|
|
except Exception as e:
|
|
thrown = True
|
|
self.assertTrue(
|
|
"at::cuda::blas::gemm<float>" in str(e)
|
|
or "at::cuda::blas::gemm_internal_cublas<float>" in str(e)
|
|
)
|
|
self.assertTrue(
|
|
"getCurrentCUDABlasHandle" in str(e)
|
|
or "getNewWorkspace" in str(e)
|
|
)
|
|
|
|
self.assertTrue(thrown)
|
|
|
|
finally:
|
|
torch._C._cuda_clearCublasWorkspaces()
|
|
torch._inductor.cudagraph_trees.clear_cublas_manager = prev
|
|
torch._inductor.cudagraph_trees.get_container(
|
|
self.device_idx
|
|
).tree_manager = None
|
|
|
|
def test_peristed_output_livenes(self):
|
|
@torch.compile
|
|
def foo(x):
|
|
return x + x
|
|
|
|
for _ in range(3):
|
|
foo(torch.rand([2, 2], device="cuda"))
|
|
|
|
node = self.get_manager().current_node
|
|
self.assertEqual(len(list(node.path_live_weakrefs())), 0)
|
|
|
|
out = foo(torch.rand([2, 2], device="cuda"))
|
|
self.assertTrue(out is node.cached_tensor_outputs[0])
|
|
self.assertEqual(len(list(node.path_live_weakrefs())), 1)
|
|
|
|
out_ref = out[0:]
|
|
del out
|
|
self.assertEqual(len(list(node.path_live_weakrefs())), 1)
|
|
|
|
del out_ref
|
|
self.assertEqual(len(list(node.path_live_weakrefs())), 0)
|
|
|
|
@torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
|
|
def test_tensor_no_longer_in_pool(self):
|
|
def foo(args):
|
|
x = args[0]
|
|
args.clear()
|
|
return x + 1, x + 2
|
|
|
|
inp = torch.rand([4], device="cuda")
|
|
inp_list = [inp]
|
|
foo_cg = self.cudagraphify_impl(foo, inp_list, ())
|
|
x1, x2 = foo_cg(inp_list)
|
|
|
|
def foo2(args):
|
|
x = args[0]
|
|
args.clear()
|
|
return [x * x * x]
|
|
|
|
inp_list = [x1]
|
|
foo2_cg = self.cudagraphify_impl(foo2, inp_list, ())
|
|
foo2_cg(inp_list)
|
|
|
|
del x1, x2
|
|
# TODO make configurable
|
|
|
|
x1, x2 = foo_cg([inp])
|
|
self.assertEqual(self.num_checkpoints(), 0)
|
|
|
|
# input location has changed, should force recompile and checkpointing
|
|
foo2_cg([torch.zeros_like(x1)])
|
|
|
|
self.assertEqual(self.num_checkpoints(), 1)
|
|
self.assertEqual(self.get_root_children(), [2])
|
|
|
|
@torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
|
|
def test_checkpoint_shared_output_storage_deallocation(self):
|
|
def foo(args):
|
|
x = args[0]
|
|
args.clear()
|
|
x_tmp = x + 1
|
|
return x[0], x[1]
|
|
|
|
inp = torch.rand([2, 2], device="cuda")
|
|
inp_list = [inp]
|
|
foo_cg = self.cudagraphify_impl(foo, inp_list, ())
|
|
foo_cg(inp_list)
|
|
foo_cg([inp])
|
|
|
|
x1, x2 = foo_cg([inp])
|
|
inp = [x1]
|
|
|
|
def foo2(args):
|
|
x = args[0]
|
|
args.clear()
|
|
y = x * x
|
|
return y[0], y[1]
|
|
|
|
foo2_cg = self.cudagraphify_impl(foo2, inp, ())
|
|
foo2_cg(inp)
|
|
|
|
self.assertEqual(self.num_checkpoints(), 1)
|
|
self.assertEqual(
|
|
x1.untyped_storage().data_ptr(), x2.untyped_storage().data_ptr()
|
|
)
|
|
self.assertEqual(all_live_block_count(), 1)
|
|
del x1
|
|
self.assertEqual(all_live_block_count(), 1)
|
|
del x2
|
|
self.assertEqual(all_live_block_count(), 0)
|
|
|
|
@torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
|
|
def test_cleanup(self):
|
|
def test_closure():
|
|
@torch.compile
|
|
def foo(x):
|
|
return x + 1 + 2, x * 10
|
|
|
|
foo(torch.rand([4], device="cuda"))
|
|
return foo(torch.rand([4], device="cuda"))
|
|
|
|
out1, out2 = test_closure()
|
|
torch._dynamo.reset()
|
|
|
|
# TODO - deallocate on tensor deallocation
|
|
# self.assertTrue(self.get_manager() is not None)
|
|
# del out1
|
|
# self.assertTrue(self.get_manager() is not None)
|
|
# del out2
|
|
self.assertTrue(self.get_manager() is None)
|
|
|
|
@torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
|
|
def test_forward_backward(self):
|
|
@torch.compile
|
|
def foo(x):
|
|
y = x * 2
|
|
return torch.sin(y) * torch.nn.functional.dropout(x, p=0.4)
|
|
|
|
inp = torch.rand([4, 4], requires_grad=True, device="cuda")
|
|
out = foo(inp)
|
|
out.sum().backward()
|
|
|
|
self.assertEqual(self.get_root_children(), [1])
|
|
|
|
# the three saved tensors should die in the backward
|
|
# we kept alive the output
|
|
self.assertEqual(self.curr_node().expected_dead_indices_before_graph, [])
|
|
self.assertEqual(
|
|
self.curr_node().expected_dead_indices_after_graph,
|
|
[(0, 1), (0, 2)],
|
|
)
|
|
self.assertFalse(self.get_manager().new_graph_id().id == 0)
|
|
|
|
def test_separate_recordings(self):
|
|
def foo_unopt(x, y):
|
|
return (x + 1) @ y
|
|
|
|
foo = torch.compile(foo_unopt)
|
|
|
|
foo_unopt(
|
|
torch.ones([20, 20], device="cuda"), torch.ones([20, 20], device="cuda")
|
|
)
|
|
|
|
inps = [
|
|
torch.ones([20, 20], device="cuda", requires_grad=False)
|
|
for _ in range(2)
|
|
]
|
|
|
|
out = foo(*inps)
|
|
torch.cuda.synchronize()
|
|
foo(*inps)
|
|
torch.cuda.synchronize()
|
|
foo(*inps)
|
|
torch.cuda.synchronize()
|
|
|
|
foo_unopt(
|
|
torch.ones([20, 20], device="cuda"), torch.ones([20, 20], device="cuda")
|
|
)
|
|
|
|
inps2 = [
|
|
torch.rand([40, 40], device="cuda", requires_grad=False)
|
|
for _ in range(2)
|
|
]
|
|
|
|
foo(*inps2)
|
|
foo(*inps2)
|
|
foo(*inps2)
|
|
|
|
# two separate roots
|
|
self.assertEqual(self.get_root_children(), [0, 0])
|
|
|
|
def test_alias_of_parameter(self):
|
|
class AliasMod(nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.param = torch.nn.Parameter(torch.rand([20, 20], device="cuda"))
|
|
|
|
def forward(self, x):
|
|
return self.param[0], self.param, self.param + x
|
|
|
|
@torch.compile(mode="reduce-overhead")
|
|
def foo(mod, inp):
|
|
return mod(inp)
|
|
|
|
inp = torch.rand([20, 20], device="cuda")
|
|
mod = AliasMod()
|
|
|
|
storage_ref = torch.multiprocessing.reductions.StorageWeakRef(
|
|
mod.param.untyped_storage()
|
|
)
|
|
|
|
for _ in range(3):
|
|
outs = foo(mod, inp)
|
|
|
|
self.assertEqual(mod(inp), outs)
|
|
|
|
self.assertFalse(storage_ref.expired())
|
|
|
|
node = self.get_manager().current_node
|
|
self.assertEqual(len(list(node.path_live_weakrefs())), 1)
|
|
|
|
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", False)
|
|
@torch._inductor.config.patch("triton.cudagraph_support_input_mutation", False)
|
|
def test_unstable_ptr(self):
|
|
import torch
|
|
|
|
@torch.compile(mode="reduce-overhead")
|
|
def foo(m, inp):
|
|
return m(inp)
|
|
|
|
def f():
|
|
l = []
|
|
m = torch.nn.Linear(20, 20).cuda()
|
|
for _ in range(4):
|
|
inp = torch.rand([20, 20], device="cuda")
|
|
foo(m, inp)
|
|
m.weight.data = torch.rand([20, 20], device="cuda")
|
|
|
|
self.assertRaises(RuntimeError, f)
|
|
|
|
@requires_multigpu()
|
|
def test_manager_per_device(self):
|
|
def test():
|
|
def foo(args):
|
|
x = args[0]
|
|
args.clear()
|
|
return (x + 3,)
|
|
|
|
inp = torch.rand([20, 20], device="cuda:1")
|
|
|
|
inp_list = [inp]
|
|
foo_cg = tree_cudagraphify_impl(
|
|
foo,
|
|
inp_list,
|
|
(),
|
|
device_index=1,
|
|
is_backward=False,
|
|
is_inference=True,
|
|
)
|
|
for _ in range(3):
|
|
self.assertEqual(foo_cg([inp]), foo([inp]))
|
|
|
|
self.assertTrue(self.get_manager(device_index=0) is None)
|
|
self.assertFalse(self.get_manager(device_index=1) is None)
|
|
|
|
test()
|
|
self.assertTrue(self.get_manager(device_index=1) is None)
|
|
|
|
def test_error_on_dealloc_use(self):
|
|
@torch.compile()
|
|
def foo(x):
|
|
return x * x * x
|
|
|
|
inp = torch.rand([4], device="cuda")
|
|
out = foo(inp)
|
|
out2 = foo(inp)
|
|
|
|
with self.assertRaisesRegex(Exception, "overwritten by a subsequent"):
|
|
out + out
|
|
|
|
foo(inp)
|
|
|
|
with self.assertRaisesRegex(Exception, "overwritten by a subsequent"):
|
|
out2 + out2
|
|
|
|
def test_error_on_dealloc_use2(self):
|
|
@torch.compile()
|
|
def foo(x):
|
|
return x * x * x
|
|
|
|
inp = torch.rand([4], device="cuda")
|
|
out = foo(inp).detach()
|
|
out2 = foo(inp).detach()
|
|
|
|
with self.assertRaises(Exception) as exc:
|
|
out + out
|
|
|
|
FileCheck().check("overwritten").check("x * x * x").run(repr(exc.exception))
|
|
|
|
foo(inp)
|
|
|
|
with self.assertRaises(Exception) as exc:
|
|
out2 + out2
|
|
|
|
FileCheck().check("overwritten").check("x * x * x").run(repr(exc.exception))
|
|
|
|
@unittest.skipIf(not torch.backends.cudnn.is_available(), "requires cudnn")
|
|
def test_conv_benchmark(self):
|
|
with torch.backends.cudnn.flags(
|
|
enabled=True, benchmark=True, deterministic=False
|
|
):
|
|
m = torch.nn.Conv2d(5, 6, [3, 3]).cuda()
|
|
inp = torch.randn([2, 5, 16, 16]).cuda()
|
|
|
|
@torch.compile()
|
|
def foo(m, inp):
|
|
return m(inp)
|
|
|
|
foo(m, inp)
|
|
|
|
def test_single_stream_use(self):
|
|
@torch.compile()
|
|
def foo(x):
|
|
return (x * x * x).relu()
|
|
|
|
inp = torch.rand([4], device="cuda", requires_grad=True)
|
|
streams = set()
|
|
streams_init = {seg["stream"] for seg in get_all_cudagraph_segments()}
|
|
for _ in range(4):
|
|
foo(inp).sum().backward()
|
|
inp.grad = None
|
|
|
|
streams = {
|
|
seg["stream"] for seg in get_all_cudagraph_segments()
|
|
} - streams_init
|
|
self.assertEqual(len(streams), 1)
|
|
self.assertFalse(self.get_manager().new_graph_id().id == 0)
|
|
|
|
@torch._dynamo.config.patch("assume_static_by_default", False)
|
|
def test_dynamic_backward(self):
|
|
def foo(x):
|
|
x = torch.cat([x, x])
|
|
return torch.addmm(x, x, x).relu(), x.size(0)
|
|
|
|
opt_foo = torch.compile(mode="reduce-overhead")(foo)
|
|
|
|
def run_test(foo, inp):
|
|
r, s = foo(inp)
|
|
r.sum().backward()
|
|
g = inp.grad.clone()
|
|
inp.grad = None
|
|
r = r.clone()
|
|
return r, s, g
|
|
|
|
def run_big_test(inp):
|
|
r0, s0, g0 = run_test(foo, inp)
|
|
r1, s1, g1 = run_test(opt_foo, inp)
|
|
r2, s2, g2 = run_test(opt_foo, inp)
|
|
self.assertEqual(r0, r1)
|
|
self.assertEqual(r0, r2)
|
|
self.assertEqual(s0, s1)
|
|
self.assertEqual(s0, s2)
|
|
self.assertEqual(g0, g1)
|
|
self.assertEqual(g0, g2)
|
|
|
|
inp = torch.randn(2, 4, device="cuda", requires_grad=True)
|
|
run_big_test(inp)
|
|
|
|
inp = torch.randn(3, 6, device="cuda", requires_grad=True)
|
|
run_big_test(inp)
|
|
|
|
def test_dynamic_warmup(self):
|
|
COUNTER = 0
|
|
|
|
def f(inps):
|
|
i, x = inps
|
|
inps.clear()
|
|
nonlocal COUNTER
|
|
COUNTER += 1
|
|
return x * 2
|
|
|
|
x = torch.randn(2, device="cuda")
|
|
inp_list = [2, x]
|
|
foo_cg = self.cudagraphify_impl(f, inp_list, ())
|
|
foo_cg(inp_list) # warmup
|
|
foo_cg([2, x]) # record
|
|
foo_cg([2, x]) # replay
|
|
self.assertEqual(COUNTER, 2)
|
|
|
|
# Switching the size will require a warmup again
|
|
x = torch.randn(3, device="cuda")
|
|
inp_list = [3, x]
|
|
foo_cg(inp_list) # warmup
|
|
foo_cg([3, x]) # record
|
|
foo_cg([3, x]) # replay
|
|
self.assertEqual(COUNTER, 4)
|
|
|
|
    def test_forward_generation(self):
        def foo(x):
            return x * x * x

        def foo2(x):
            return x * 12

        foo_opt = torch.compile(foo)
        foo2_opt = torch.compile(foo2)
        ones = torch.ones([4, 4], device="cuda", requires_grad=True)

        out = foo_opt(ones)
        out2 = foo2_opt(out)

        self.assertEqual(all_live_block_count(), 2)

        self.assertTrue(self.get_manager().running_forwards_with_pending_backwards)

        out2.sum().backward()
        self.assertFalse(self.get_manager().running_forwards_with_pending_backwards)

        ones.grad = None
        del out
        del out2

        foo2_opt(foo_opt(ones)).sum().backward()

        out = foo_opt(ones.detach())
        self.assertFalse(self.get_manager().running_forwards_with_pending_backwards)
        self.assertFalse(self.get_manager().new_graph_id().id == 0)

    def test_warn_on_pending_backward(self):
        @torch.compile
        def foo(x):
            return x * x * x

        out = foo(torch.rand([4, 4], device="cuda", requires_grad=True))
        out = foo(torch.rand([4, 4], device="cuda", requires_grad=True))

        warnings.resetwarnings()
        with warnings.catch_warnings(record=True) as w:
            out = foo(torch.rand([4, 4], device="cuda", requires_grad=True))

        FileCheck().check(
            "Unable to hit fast path of CUDAGraphs because of pending"
        ).run(str(w[0]))
        self.assertTrue(self.get_manager().new_graph_id().id == 0)

    def test_mark_step(self):
        @torch.compile
        def foo(x):
            return x * x * x

        torch.compiler.cudagraph_mark_step_begin()
        out = foo(torch.rand([4, 4], device="cuda", requires_grad=True))

        torch.compiler.cudagraph_mark_step_begin()
        out = foo(torch.rand([4, 4], device="cuda", requires_grad=True))
        self.assertFalse(self.get_manager().new_graph_id().id == 0)

@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
|
def test_incompatible_cudagraph_ops_item(self):
|
|
@torch.compile(mode="reduce-overhead")
|
|
def foo(x):
|
|
return x.item()
|
|
|
|
# NB: This doesn't work with float, because float unbacked codegen
|
|
# is currently broken. But testing the float case here is also
|
|
# awkward, because we plan to Tensor-ify the float compute, and as
|
|
# a result we'd actually expect this to work with cuda graphs!
|
|
with capture_stderr() as captured_output:
|
|
self.assertEqual(foo(torch.tensor(3, device="cuda")), 3)
|
|
self.assertEqual(foo(torch.tensor(6, device="cuda")), 6)
|
|
|
|
# NOTE: this test is named after incompatible ops, but is not skipping due to incompatible ops.
|
|
# This should get fixed.
|
|
FileCheck().check(
|
|
" to incompatible op aten._local_scalar_dense.default"
|
|
).run(captured_output[0])
|
|
self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
|
|
|
|
@torch._dynamo.config.patch("compiled_autograd", True)
|
|
def test_compiled_autograd_static_input_params(self):
|
|
@torch.compile(mode="reduce-overhead")
|
|
def bwd(loss):
|
|
loss.backward()
|
|
|
|
model = torch.nn.Linear(10, 10, bias=False, device="cuda")
|
|
x = torch.randn(10, 10, device="cuda")
|
|
for i in range(5):
|
|
out = model(x)
|
|
bwd(out.sum())
|
|
model.weight.grad = None
|
|
|
|
# i=0, 0 copies (warmup)
|
|
# i=1, 2 copies (record, 1/3 inputs marked as static)
|
|
# i>1, 0 copies (run)
|
|
self.assertEqual(
|
|
counters["inductor"]["cudagraph_recorded_non_static_inputs"], 2
|
|
)
|
|
|
|
@torch._dynamo.config.patch("capture_dynamic_output_shape_ops", True)
|
|
def test_incompatible_cudagraph_ops_nonzero(self):
|
|
@torch.compile(mode="reduce-overhead")
|
|
def foo(x):
|
|
return x.nonzero()
|
|
|
|
with capture_stderr() as captured_output:
|
|
self.assertEqual(
|
|
foo(torch.tensor([1, 0, 2], device="cuda")),
|
|
torch.tensor([[0], [2]]),
|
|
)
|
|
self.assertEqual(
|
|
foo(torch.tensor([1, 0, 0], device="cuda")), torch.tensor([[0]])
|
|
)
|
|
|
|
FileCheck().check("incompatible op aten.nonzero.default").check("foo").run(
|
|
captured_output[0]
|
|
)
|
|
self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
|
|
|
|
@torch._dynamo.config.patch("capture_dynamic_output_shape_ops", True)
|
|
def test_incompatible_cudagraph_ops_nonzero_graph_breaks(self):
|
|
@torch.compile(mode="reduce-overhead")
|
|
def foo(x):
|
|
y = x.nonzero() # skip
|
|
torch._dynamo.graph_break()
|
|
return y.nonzero() # skip 2 times (due to recompile)
|
|
|
|
foo(torch.tensor([1, 0, 2], device="cuda"))
|
|
foo(torch.tensor([1, 0, 0], device="cuda"))
|
|
|
|
self.assertEqual(counters["inductor"]["cudagraph_skips"], 3)
|
|
|
|
@torch._dynamo.config.patch("capture_dynamic_output_shape_ops", True)
|
|
def test_incompatible_cudagraph_ops_nonzero_backend(self):
|
|
@torch.compile(backend="cudagraphs")
|
|
def foo(x):
|
|
return x.nonzero()
|
|
|
|
with capture_stderr() as captured_output:
|
|
self.assertEqual(
|
|
foo(torch.tensor([1, 0, 2], device="cuda")),
|
|
torch.tensor([[0], [2]]),
|
|
)
|
|
self.assertEqual(
|
|
foo(torch.tensor([1, 0, 0], device="cuda")), torch.tensor([[0]])
|
|
)
|
|
|
|
FileCheck().check(
|
|
"skipping cudagraphs due to incompatible op (nonzero)"
|
|
).run(captured_output[0])
|
|
self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
|
|
|
|
    def test_storage_access_error(self):
        x = torch.rand([4], device="cuda")
        torch._C._set_storage_access_error_msg(x, "custom error msg")

        with self.assertRaisesRegex(Exception, "custom error msg"):
            device = x.untyped_storage()

@torch._dynamo.config.patch("inline_inbuilt_nn_modules", False)
|
|
@torch._inductor.config.patch("triton.cudagraph_support_input_mutation", False)
|
|
def test_static_inputs_address_mutation_log(self):
|
|
class Goo(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(2, 2, device="cuda")
|
|
|
|
def forward(self, x) -> torch.Tensor:
|
|
return self.linear(x)
|
|
|
|
class Foo(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.static_tensor = torch.zeros((2, 2), device="cuda")
|
|
self.goo = Goo()
|
|
|
|
def forward(self, x) -> torch.Tensor:
|
|
self.static_tensor.add_(torch.ones((2, 2), device="cuda"))
|
|
return self.static_tensor + x + self.goo(x)
|
|
|
|
foo = Foo()
|
|
foo = torch.compile(foo, mode="reduce-overhead")
|
|
inp = torch.rand((2, 2), device="cuda")
|
|
|
|
for _ in range(3):
|
|
foo(inp)
|
|
|
|
# mutates static input tensors' addresses
|
|
foo.static_tensor = torch.ones((2, 2), device="cuda")
|
|
foo.goo.linear.bias = torch.nn.Parameter(torch.ones((2,), device="cuda"))
|
|
|
|
with self.assertRaisesRegex(
|
|
Exception,
|
|
r"(?s)static input data pointer changed.\n"
|
|
r"input name: primals_2. data pointer changed from .* to .*. input stack trace:.*"
|
|
r"input name: primals_3. data pointer changed from .* to .*. input stack trace:.*,"
|
|
r" in forward\n.* self.static_tensor.add\_\(torch.ones\(\(2, 2\), device=\"cuda\"\)\).*\n",
|
|
):
|
|
self.curr_node().run(
|
|
[foo.goo.linear.weight, foo.goo.linear.bias, foo.static_tensor, inp]
|
|
)
|
|
|
|
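    # Helpers for the static-input-parameter tests below: run one forward and
    # backward for a given parameter, and compare eager vs. compiled outputs
    # and gradients across repeated calls.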
    def _run_iter(self, param, fn):
        fwd_output = fn(torch.ones(2, 2), param)
        fwd_output.sum().backward()
        grad_output = param.grad.detach().clone()
        param.grad = None
        return fwd_output, grad_output

    def _assert_equal_multi_loop(self, param, fn_eager, fn_compiled):
        exp_output, exp_grad = self._run_iter(param, fn_eager)
        for _ in range(5):
            compiled_output, compiled_grad = self._run_iter(param, fn_compiled)
            self.assertEqual(exp_output, compiled_output)
            self.assertEqual(exp_grad, compiled_grad)

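    # Runs the compiled function with two different parameters, then with the
    # first parameter again, and checks that the expected number of cudagraphs
    # was recorded.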
    def run_static_input_param_test(self, fn_eager, num_graphs):
        with torch.device("cuda"):
            fn_compiled = torch.compile(fn_eager, mode="reduce-overhead")

            p1 = torch.nn.Parameter(torch.rand([2, 2]))
            self._assert_equal_multi_loop(p1, fn_eager, fn_compiled)

            p2 = torch.nn.Parameter(torch.rand([2, 2]))
            self._assert_equal_multi_loop(p2, fn_eager, fn_compiled)

            # Run p1 again to ensure we reuse the previous recording
            self._assert_equal_multi_loop(p1, fn_eager, fn_compiled)

            self.assertEqual(self.get_manager().new_graph_id().id, num_graphs)

def _module_test(self, mod, name="weight", param_wrapping=True):
|
|
with torch.device("cuda"):
|
|
|
|
def fn(x, mod):
|
|
return mod(x)
|
|
|
|
fn_compiled = torch.compile(fn, mode="reduce-overhead", fullgraph=True)
|
|
|
|
def run_test_iter(mod, fn):
|
|
fwd_output = fn(torch.ones(2, 2), mod)
|
|
fwd_output.sum().backward()
|
|
grad_output = mod.weight.grad.detach().clone()
|
|
mod.zero_grad()
|
|
return fwd_output, grad_output
|
|
|
|
def run_test():
|
|
exp_output, exp_grad = run_test_iter(mod, fn)
|
|
for _ in range(5):
|
|
compiled_output, compiled_grad = run_test_iter(mod, fn_compiled)
|
|
self.assertEqual(exp_output, compiled_output)
|
|
self.assertEqual(exp_grad, compiled_grad)
|
|
|
|
run_test()
|
|
old_attr = getattr(mod, name)
|
|
modified_attr = torch.rand_like(old_attr)
|
|
if param_wrapping:
|
|
modified_attr = torch.nn.Parameter(modified_attr)
|
|
setattr(mod, name, modified_attr)
|
|
run_test()
|
|
# Run original version to verify we reuse the other recording
|
|
setattr(mod, name, old_attr)
|
|
run_test()
|
|
|
|
# Fwd + bwd graphs for each version of the function => 4 graphs
|
|
self.assertEqual(self.get_manager().new_graph_id().id, 4)
|
|
|
|
@torch._dynamo.config.patch("error_on_recompile", True)
|
|
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
|
|
def test_multi_dispatch_single_compile_param_inputs(self):
|
|
# Verify that we can record multiple cudagraphs for a single
|
|
# compiled function with param inputs
|
|
def fn(x, y):
|
|
return x * y
|
|
|
|
# Fwd + bwd graphs for each version of the function => 4 graphs
|
|
self.run_static_input_param_test(fn, 4)
|
|
|
|
@torch._dynamo.config.patch("error_on_recompile", True)
|
|
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
|
|
def test_multi_dispatch_single_compile_builtin_module(self):
|
|
# Verify that we don't recompile when changing the param of a builtin module
|
|
# and that we record another cudagraph
|
|
# Note: Linear is a builtin module so we enable that config setting above
|
|
self._module_test(torch.nn.Linear(2, 3, device="cuda"))
|
|
|
|
@torch._dynamo.config.patch("error_on_recompile", True)
|
|
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
|
|
def test_multi_dispatch_single_compile_builtin_module_buffers(self):
|
|
# Verify that we don't recompile when changing the buffer of a builtin module
|
|
# and that we record another cudagraph
|
|
self._module_test(
|
|
torch.nn.BatchNorm1d(2, device="cuda"),
|
|
name="running_mean",
|
|
param_wrapping=False,
|
|
)
|
|
|
|
@torch._inductor.config.patch("triton.cudagraphs", True)
|
|
@torch._dynamo.config.patch("error_on_recompile", True)
|
|
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
|
|
def test_multi_dispatch_custom_module(self):
|
|
# Test that we can correctly dispatch multiple graphs
|
|
# if params of a custom module change
|
|
class TestModule(torch.nn.Module):
|
|
def __init__(self, param) -> None:
|
|
super().__init__()
|
|
self.weight = param
|
|
|
|
def forward(self, x):
|
|
return self.weight * x
|
|
|
|
self._module_test(
|
|
TestModule(torch.nn.Parameter(torch.rand([2, 2], device="cuda")))
|
|
)
|
|
|
|
@torch._dynamo.config.patch("error_on_recompile", True)
|
|
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
|
|
def test_multi_dispatch_custom_module_buffer(self):
|
|
# Test that we can correctly dispatch multiple graphs
|
|
# if buffers of a custom module change
|
|
class TestModule(torch.nn.Module):
|
|
def __init__(self, param, buf) -> None:
|
|
super().__init__()
|
|
self.weight = param
|
|
self.buf = torch.nn.Buffer(buf)
|
|
|
|
def forward(self, x):
|
|
return x * self.weight + self.buf
|
|
|
|
self._module_test(
|
|
TestModule(
|
|
torch.nn.Parameter(torch.rand([2, 2], device="cuda")),
|
|
torch.rand([2, 2], device="cuda"),
|
|
),
|
|
name="buf",
|
|
param_wrapping=False,
|
|
)
|
|
|
|
@torch._inductor.config.patch("triton.cudagraphs", True)
|
|
@torch._dynamo.config.patch("error_on_recompile", True)
|
|
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
|
|
def test_multi_dispatch_child_node(self):
|
|
# Test that we can correctly dispatch multiple graphs if a child node
|
|
# in the tree has stable input pointers change
|
|
def fn(x, p):
|
|
# Graph 1
|
|
y = x * x
|
|
torch._dynamo.graph_break()
|
|
# Graph 2
|
|
return y * p
|
|
|
|
# We have 5 graphs here
|
|
# Graph 1
|
|
# / \
|
|
# Graph 2 w/ p1 Graph 2 w/ p2
|
|
# and then two backward graphs
|
|
self.run_static_input_param_test(fn, 5)
|
|
|
|
@torch._dynamo.config.patch("error_on_recompile", True)
|
|
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
|
|
def test_multi_dispatch_parent_node(self):
|
|
def fn(x, p):
|
|
# Graph 1
|
|
y = x * p
|
|
torch._dynamo.graph_break()
|
|
# Graph 2
|
|
return y + x
|
|
|
|
# We have 6 graphs here
|
|
# Graph 1 w/ p1 Graph 1 w/ p2
|
|
# | |
|
|
# Graph 2 (v1) Graph 2 (v2)
|
|
# There are two versions of graph 2 because
|
|
# we re-record due to different memory state after running the
|
|
# two versions of Graph 1
|
|
# and then two backward graphs
|
|
self.run_static_input_param_test(fn, 6)
|
|
|
|
@torch._dynamo.config.patch("error_on_recompile", True)
|
|
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", False)
|
|
@torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True)
|
|
@torch._inductor.config.patch("triton.cudagraph_unexpected_rerecord_limit", 0)
|
|
def test_fallback_to_eager_if_recompiling_too_many_times(self):
|
|
class Foo(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.param = torch.nn.Parameter(torch.rand([2, 2], device="cuda"))
|
|
|
|
def forward(self, x):
|
|
return x * self.param
|
|
|
|
with capture_stderr() as captured_output:
|
|
# We have 3 graphs here
|
|
# None
|
|
# / \
|
|
# (fwd w/ p1, Graph 0) (bwd w/p2, Graph2)
|
|
# (bwd w/ p1, Graph 1)
|
|
# All other graphs are skipped because we hit the max recording limit
|
|
# (=0 for each node and function pair)
|
|
fn_compiled = torch.compile(Foo(), mode="reduce-overhead")
|
|
for _ in range(3):
|
|
fn_compiled(torch.rand([2, 2], device="cuda")).sum().backward()
|
|
fn_compiled.param.grad = None
|
|
|
|
# Change static tensor address
|
|
fn_compiled.param.data = torch.rand([2, 2], device="cuda")
|
|
fn_compiled(torch.rand([2, 2], device="cuda")).sum().backward()
|
|
self.assertEqual(self.get_manager().new_graph_id().id, 3)
|
|
|
|
FileCheck().check(
|
|
"skipping cudagraph due to function 0 exceeding max re-recording limit (=0) "
|
|
"on cudagraph node None due to static input data pointer changed."
|
|
).run(captured_output[0])
|
|
self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
|
|
|
|
@torch._dynamo.config.patch("error_on_recompile", True)
|
|
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", False)
|
|
@torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True)
|
|
@torch._inductor.config.patch("triton.cudagraph_unexpected_rerecord_limit", 0)
|
|
def test_fallback_to_eager_if_recompiling_too_many_times_warn_only_once(self):
|
|
class Foo(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.param = torch.nn.Parameter(torch.rand([2, 2], device="cuda"))
|
|
|
|
def forward(self, x):
|
|
return x * self.param
|
|
|
|
with capture_stderr() as captured_output:
|
|
with torch.device("cuda"):
|
|
# We have 3 graphs here
|
|
# None
|
|
# / \
|
|
# (fwd w/ p1, Graph 0) (bwd w/p2, Graph2)
|
|
# (bwd w/ p1, Graph 1)
|
|
# All other graphs are skipped because we hit the max recording limit
|
|
# (=0 for each node and function pair)
|
|
fn_compiled = torch.compile(Foo(), mode="reduce-overhead")
|
|
for _ in range(3):
|
|
fn_compiled(torch.rand([2, 2], device="cuda")).sum().backward()
|
|
fn_compiled.param.grad = None
|
|
|
|
for _ in range(5):
|
|
# Change static tensor address
|
|
fn_compiled.param.data = torch.rand([2, 2], device="cuda")
|
|
fn_compiled(torch.rand([2, 2], device="cuda")).sum().backward()
|
|
fn_compiled.param.grad = None
|
|
|
|
FileCheck().check_count(
|
|
"skipping cudagraph due to function 0 exceeding max re-recording limit (=0) "
|
|
"on cudagraph node None due to static input data pointer changed.",
|
|
1,
|
|
exactly=True,
|
|
).check_count(
|
|
"skipping cudagraph due to function 1 exceeding max re-recording limit (=0) "
|
|
"on cudagraph node None due to static input data pointer changed.",
|
|
1,
|
|
exactly=True,
|
|
).run(
|
|
captured_output[0]
|
|
)
|
|
self.assertEqual(counters["inductor"]["cudagraph_skips"], 2)
|
|
|
|
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", False)
|
|
@torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True)
|
|
@torch._inductor.config.patch("triton.cudagraph_unexpected_rerecord_limit", 0)
|
|
def test_fallback_to_eager_if_recompiling_too_many_times_due_to_cudagraph_managed_tensor(
|
|
self,
|
|
):
|
|
# By setting triton.cudagraph_support_input_mutation=True, we force re-record
|
|
# if cudagraph managed tensor addresses changed.
|
|
@torch.compile(mode="reduce-overhead")
|
|
def foo(x):
|
|
return x + 1
|
|
|
|
@torch.compile(mode="reduce-overhead")
|
|
def goo(x):
|
|
return x * 2
|
|
|
|
for _ in range(3):
|
|
torch.compiler.cudagraph_mark_step_begin()
|
|
inp = torch.rand((2, 3), device="cuda")
|
|
y = foo(inp)
|
|
z = goo(y)
|
|
|
|
with capture_stderr() as captured_output:
|
|
torch.compiler.cudagraph_mark_step_begin()
|
|
x = torch.rand(2, 3, device="cuda")
|
|
y = foo(x)
|
|
y_clone = y.clone()
|
|
z = goo(y_clone)
|
|
|
|
# eager function should run successfully
|
|
for _ in range(5):
|
|
torch.compiler.cudagraph_mark_step_begin()
|
|
x = torch.rand(2, 3, device="cuda")
|
|
y = foo(x)
|
|
y_clone = y.clone()
|
|
z = goo(y_clone)
|
|
|
|
FileCheck().check_count(
|
|
"skipping cudagraph due to function 1 exceeding max re-recording limit (=0) "
|
|
"on cudagraph node 0 due to cudagraph managed tensor data pointer changed",
|
|
1,
|
|
exactly=True,
|
|
).run(captured_output[0])
|
|
self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
|
|
|
|
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", False)
|
|
@torch._dynamo.config.patch("error_on_recompile", True)
|
|
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
|
|
@torch._inductor.config.patch("triton.cudagraph_unexpected_rerecord_limit", 1)
|
|
def test_not_fallback_to_eager_if_have_not_recompiling_too_many_times(self):
|
|
def fn(x, y):
|
|
return x * y
|
|
|
|
# We have 4 graphs here
|
|
# None
|
|
# / \
|
|
# (fwd w/ p1, Graph 0) (fwd w/p2, Graph2)
|
|
# (bwd w/ p1, Graph 1) (bwd w/p2, Graph3)
|
|
self.run_static_input_param_test(fn, 4)
|
|
self.assertEqual(counters["inductor"]["cudagraph_skips"], 0)
|
|
|
|
    def test_tensor_constant_mutation(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.tensor_constant = torch.ones((2, 3), device="cuda")

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                self.tensor_constant += 1
                return x + self.tensor_constant

        foo = Foo()
        foo = torch.compile(foo, mode="reduce-overhead")
        inp = torch.rand((2, 3), device="cuda")
        for _ in range(3):
            foo(inp)

@torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True)
|
|
def test_rerecord_if_static_input_address_changed(self):
|
|
# By setting triton.cudagraph_support_input_mutation=True, we force re-record
|
|
# if static tensor addresses changed.
|
|
class Goo(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(2, 2, device="cuda")
|
|
|
|
def forward(self, x) -> torch.Tensor:
|
|
return self.linear(x)
|
|
|
|
class Foo(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.register_buffer(
|
|
"static_tensor", torch.zeros((2, 2), device="cuda")
|
|
)
|
|
self.goo = Goo()
|
|
|
|
def forward(self, x) -> torch.Tensor:
|
|
self.static_tensor.add_(torch.ones((2, 2), device="cuda"))
|
|
return self.static_tensor + x + self.goo(x)
|
|
|
|
foo = Foo()
|
|
foo = torch.compile(foo, mode="reduce-overhead")
|
|
inp = torch.rand((2, 2), device="cuda")
|
|
|
|
for _ in range(3):
|
|
foo(inp)
|
|
|
|
# mutates static input tensors' addresses
|
|
foo.static_tensor = torch.ones((2, 2), device="cuda")
|
|
foo.goo.linear.bias = torch.nn.Parameter(torch.ones((2,), device="cuda"))
|
|
|
|
if torch._dynamo.config.inline_inbuilt_nn_modules:
|
|
for _ in range(3):
|
|
foo(inp)
|
|
else:
|
|
# Run with specific function id to avoid dynamo recompiling
|
|
self.get_manager().run(
|
|
[
|
|
foo.goo.linear.weight,
|
|
foo.goo.linear.bias,
|
|
foo.static_tensor,
|
|
inp,
|
|
],
|
|
FunctionID(0),
|
|
)
|
|
|
|
self.assertEqual(self.get_manager().new_graph_id().id, 2)
|
|
|
|
@torch._inductor.config.patch("triton.cudagraph_dynamic_shape_warn_limit", 1)
|
|
def test_skip_if_dynamic_shape_limit_reached1(self):
|
|
class Mod(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(3, 3, device="cuda")
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
return self.linear(x)
|
|
|
|
def iter(batch_size: int, mod: torch.nn.Module):
|
|
x = torch.rand((batch_size, 3), device="cuda")
|
|
for _ in range(3):
|
|
mod(x)
|
|
|
|
mod = torch.compile(Mod(), mode="reduce-overhead")
|
|
|
|
with capture_stderr() as captured_output:
|
|
for batch_size in range(10, 40, 10):
|
|
iter(batch_size, mod)
|
|
|
|
FileCheck().check(
|
|
"CUDAGraph supports dynamic shapes by recording a new graph for each "
|
|
"distinct input size. Recording too many CUDAGraphs may lead to "
|
|
"extra overhead. We have observed 2 distinct sizes. "
|
|
"Please consider the following options for better performance: "
|
|
"a) padding inputs to a few fixed number of shapes; or b) set "
|
|
"torch._inductor.config.triton.cudagraph_skip_dynamic_graphs=True. "
|
|
"Set torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit=None "
|
|
"to silence this warning."
|
|
).run("\n".join(captured_output))
|
|
|
|
@torch._inductor.config.patch("triton.cudagraph_dynamic_shape_warn_limit", 1)
|
|
def test_skip_if_dynamic_shape_limit_reached2(self):
|
|
class Mod(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.attn = torch.nn.MultiheadAttention(
|
|
embed_dim=3, num_heads=3, device="cuda"
|
|
)
|
|
|
|
def forward(
|
|
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
|
|
) -> torch.Tensor:
|
|
return self.attn(q, k, v)
|
|
|
|
mod = torch.compile(Mod(), mode="reduce-overhead")
|
|
|
|
def iter(batch_size: int, length: int):
|
|
q = torch.rand((batch_size, length, 3), device="cuda")
|
|
k = torch.rand((batch_size, length, 3), device="cuda")
|
|
v = torch.rand((batch_size, length, 3), device="cuda")
|
|
for _ in range(3):
|
|
mod(q, k, v)
|
|
|
|
with capture_stderr() as captured_output:
|
|
for batch_size in range(10, 40, 10):
|
|
for length in range(10, 30, 10):
|
|
iter(batch_size, length)
|
|
|
|
print(captured_output)
|
|
FileCheck().check(
|
|
"CUDAGraph supports dynamic shapes by recording a new graph for each "
|
|
"distinct input size. Recording too many CUDAGraphs may lead to "
|
|
"extra overhead. We have observed 2 distinct sizes. "
|
|
"Please consider the following options for better performance: "
|
|
"a) padding inputs to a few fixed number of shapes; or b) set "
|
|
"torch._inductor.config.triton.cudagraph_skip_dynamic_graphs=True. "
|
|
"Set torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit=None "
|
|
"to silence this warning."
|
|
).run(captured_output[0])
|
|
|
|
@torch._inductor.config.patch("triton.cudagraph_dynamic_shape_warn_limit", 1)
|
|
def test_warn_once_if_dynamic_shape_limit_reached(self):
|
|
class Mod(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(3, 3, device="cuda")
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
return self.linear(x)
|
|
|
|
def iter(batch_size: int, mod: torch.nn.Module):
|
|
x = torch.rand((batch_size, 3), device="cuda")
|
|
for _ in range(3):
|
|
mod(x)
|
|
|
|
mod = torch.compile(Mod(), mode="reduce-overhead")
|
|
|
|
with capture_stderr() as captured_output:
|
|
for batch_size in range(10, 200, 10):
|
|
iter(batch_size, mod)
|
|
|
|
print(captured_output)
|
|
|
|
FileCheck().check_count(
|
|
"CUDAGraph supports dynamic shapes by recording a new graph for each "
|
|
"distinct input size. Recording too many CUDAGraphs may lead to "
|
|
"extra overhead. We have observed 2 distinct sizes. "
|
|
"Please consider the following options for better performance: "
|
|
"a) padding inputs to a few fixed number of shapes; or b) set "
|
|
"torch._inductor.config.triton.cudagraph_skip_dynamic_graphs=True. "
|
|
"Set torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit=None "
|
|
"to silence this warning.",
|
|
1,
|
|
exactly=True,
|
|
).run("\n".join(captured_output))
|
|
|
|
@torch._inductor.config.patch("cpp_wrapper", 1)
|
|
def test_cpp_wrapper(self):
|
|
def f(x):
|
|
return torch.sin(x)
|
|
|
|
compiled = torch.compile(f, mode="reduce-overhead")
|
|
example_input = torch.randn(10, device="cuda")
|
|
compiled_result = self.run_twc(compiled, example_input)
|
|
eager_result = f(example_input)
|
|
self.assertEqual(compiled_result, eager_result)
|
|
|
|
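# The tests below exercise the graphsafe RNG handling used with activation
# checkpointing: a TorchDispatchMode observer intercepts the rng higher-order
# ops (e.g. graphsafe_run_with_rng_state) and records the outputs of the
# wrapped rng ops, so the values drawn in forward and in the recomputed
# backward can be compared directly.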
class TestSAC(TestCase):
    def _make_observer_mode(self):
        class ObserverMode(TorchDispatchMode):
            def __init__(self):
                super().__init__()
                self.curr_run = 0
                self.op_outputs = defaultdict(list)

            def __torch_dispatch__(
                self,
                func: OpOverload,
                types: Sequence[type],
                args: Sequence[object] = (),
                kwargs: Mapping[str, object] = immutable_dict(),
            ) -> object:
                return func(*args, **kwargs)

        return ObserverMode

    def test_simple(self):
        device = "cuda"

        from torch._prims.rng_prims import graphsafe_run_with_rng_state

        ObserverMode = self._make_observer_mode()

        @graphsafe_run_with_rng_state.py_impl(ObserverMode)
        def _(mode, op, *args, **kwargs):
            with no_dispatch():
                out = graphsafe_run_with_rng_state(op, *args, **kwargs)

            mode.op_outputs[op].append(out)
            return out

        obs = ObserverMode()

        x = torch.randn(4, 4, device=device, requires_grad=True)
        y = torch.randn(4, 4, device=device, requires_grad=True)

        for _ in range(2):
            torch._dynamo.reset()

            def gn(x, y):
                return torch.sigmoid(torch.rand_like(x) * y) * x

            def fn(x, y):
                x = torch.sin(x)
                x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
                x = torch.sin(x)
                return x

            aot_eager_decomp_partition = functools.partial(
                aot_eager_decomp_partition_with_mode, mode=obs
            )

            fn = torch.compile(fn, backend=aot_eager_decomp_partition)

            fn(x, y).sum().backward()

        self.assertEqual(len(obs.op_outputs[aten.rand.default]), 4)
        for i in range(2):
            self.assertEqual(
                obs.op_outputs[aten.rand.default][0 + 2 * i],
                obs.op_outputs[aten.rand.default][1 + 2 * i],
            )
        self.assertNotEqual(
            obs.op_outputs[aten.rand.default][0],
            obs.op_outputs[aten.rand.default][2],
        )

    def test_cudagraph_uneven_forward_backward(self):
        # torch.compile cudagraphs are difficult to test directly here because
        # the rng updating is sensitive to the duration of pending backwards, etc.
        # This is a short repro that mimics the runtime wrappers integration
        # and shows that updating the backward rng state with cudagraphs works:
        def forward():
            state = torch.cuda.get_rng_state()
            perm = torch.randperm(10, device="cuda")
            return state, perm

        def backward(rng_state):
            current_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state.cpu())
            perm = torch.randperm(10, device="cuda")
            torch.cuda.set_rng_state(current_state)
            return perm

        def normal_test():
            state, perm = forward()
            repro_perm = backward(state)
            return perm, repro_perm

        def graphsafe_forward():
            perm = torch.randperm(10, device="cuda")
            return perm

        def graphsafe_backward(generator, new_state):
            current_state = generator.graphsafe_get_state()
            generator.graphsafe_set_state(new_state)
            perm = torch.randperm(10, device="cuda")
            generator.graphsafe_set_state(current_state)
            return perm

        def graph_test(generator, capture_cuda_graph):
            if capture_cuda_graph:
                graph = torch.cuda.CUDAGraph()

            # state should be cloned before the graph
            old_state = generator.graphsafe_get_state()
            new_state = old_state.clone_state()

            if capture_cuda_graph:
                # state should be registered to the graph
                graph.register_generator_state(new_state)

                # only capturing the backward
                with torch.cuda.graph(graph):
                    repro_perm = graphsafe_backward(generator, new_state)

            # some number of uneven forwards
            graphsafe_forward()
            graphsafe_forward()
            graphsafe_forward()

            # state prior to rng invocation
            state = generator.get_state()
            perm = graphsafe_forward()

            new_state.set_state(state)

            if capture_cuda_graph:
                graph.replay()
            else:
                repro_perm = graphsafe_backward(generator, new_state)

            return perm, repro_perm

        self.assertEqual(*normal_test())
        generator = torch.cuda.default_generators[0]
        self.assertEqual(*graph_test(generator, capture_cuda_graph=False))
        self.assertEqual(*graph_test(generator, capture_cuda_graph=True))

    def test_cpu_and_cuda_rng(self):
        device = "cuda"

        ObserverMode = self._make_observer_mode()
        from torch._prims.rng_prims import (
            graphsafe_run_with_rng_state,
            run_and_save_rng_state,
            run_with_rng_state,
        )

        for hop in [
            graphsafe_run_with_rng_state,
            run_and_save_rng_state,
            run_with_rng_state,
        ]:

            def make_impl(hop):
                @hop.py_impl(ObserverMode)
                def _(mode, *args, **kwargs):
                    with no_dispatch():
                        out = hop(*args, **kwargs)

                    op = None
                    for inp in itertools.chain(args, kwargs.values()):
                        if isinstance(inp, torch._ops.OpOverload):
                            op = inp
                            break
                    assert op is not None
                    if hop is run_and_save_rng_state:
                        mode.op_outputs[op].append(out[1])
                    else:
                        mode.op_outputs[op].append(out)
                    return out

            make_impl(hop)

        obs = ObserverMode()

        def gn(x, y):
            return torch.sigmoid(torch.rand_like(x) * y) * x

        def gn2(x):
            return x * torch.randperm(x.numel(), device=x.device).reshape(x.shape)

        def fn(x, y, z):
            x = torch.sin(x)
            x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
            x = torch.sin(x)
            z = torch.utils.checkpoint.checkpoint(gn2, z, use_reentrant=True)
            return x * z.cuda()

        aot_eager_decomp_partition = functools.partial(
            aot_eager_decomp_partition_with_mode, mode=obs
        )

        fn = torch.compile(fn, backend=aot_eager_decomp_partition)

        x = torch.randn(4, 4, device=device, requires_grad=True)
        y = torch.randn(4, 4, device=device, requires_grad=True)
        z = torch.randn(4, 4, requires_grad=True)

        fn(x, y, z).sum().backward()
        for op in [aten.rand.default, aten.randperm.default]:
            self.assertEqual(len(obs.op_outputs[op]), 2)
            self.assertEqual(
                obs.op_outputs[op][0],
                obs.op_outputs[op][1],
            )
            self.assertEqual(
                obs.op_outputs[op][0].device.type,
                "cpu" if op == aten.randperm.default else "cuda",
            )

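    # Run several forwards, then replay their backwards in a permuted order;
    # each backward's rand/randperm outputs should still match the values
    # observed in its own forward, and should differ between runs.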
@parametrize("order", (list(itertools.permutations([0, 1, 2]))))
|
|
def test_uneven_forward_backward(self, order):
|
|
device = "cuda"
|
|
|
|
ObserverMode = self._make_observer_mode()
|
|
from torch._prims.rng_prims import graphsafe_run_with_rng_state
|
|
|
|
@graphsafe_run_with_rng_state.py_impl(ObserverMode)
|
|
def _(mode, op, *args, **kwargs):
|
|
with no_dispatch():
|
|
out = graphsafe_run_with_rng_state(op, *args, **kwargs)
|
|
|
|
mode.op_outputs[(mode.curr_run, op)].append(out)
|
|
return out
|
|
|
|
obs = ObserverMode()
|
|
|
|
def gn(x, y):
|
|
return torch.sigmoid(torch.rand_like(x) * y) * x
|
|
|
|
def gn2(x):
|
|
return x * torch.randperm(x.numel(), device=x.device).reshape(x.shape)
|
|
|
|
def fn(x, y):
|
|
x = torch.sin(x)
|
|
x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
|
|
x = torch.sin(x)
|
|
x = torch.utils.checkpoint.checkpoint(gn2, x, use_reentrant=True)
|
|
return x
|
|
|
|
aot_eager_decomp_partition = functools.partial(
|
|
aot_eager_decomp_partition_with_mode, mode=obs
|
|
)
|
|
|
|
fn_c = torch.compile(fn, backend=aot_eager_decomp_partition)
|
|
|
|
torch.manual_seed(0)
|
|
outs = []
|
|
for i in range(len(order)):
|
|
obs.curr_run = i
|
|
x = torch.randn(4, 4, device=device, requires_grad=True)
|
|
y = torch.randn(4, 4, device=device, requires_grad=True)
|
|
outs.append(fn_c(x, y))
|
|
|
|
for idx in order:
|
|
obs.curr_run = idx
|
|
outs[idx].sum().backward()
|
|
|
|
for run in range(len(order)):
|
|
for op in (aten.rand.default, aten.randperm.default):
|
|
self.assertEqual(len(obs.op_outputs[(run, op)]), 2)
|
|
self.assertEqual(
|
|
obs.op_outputs[(run, op)][0],
|
|
obs.op_outputs[(run, op)][1],
|
|
)
|
|
if run != 0:
|
|
self.assertNotEqual(
|
|
obs.op_outputs[(run - 1, op)][0],
|
|
obs.op_outputs[(run, op)][0],
|
|
)
|
|
|
|
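    # With fallback_random set, the aot_eager_decomp_partition backend and
    # reduce-overhead (cudagraphs) should produce matching outputs and
    # gradients when run from the same seed.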
    @config.patch(fallback_random=True)
    @config.patch("test_configs.graphsafe_rng_func_ignores_fallback_random", True)
    def _test_cudagraphs_aot_eager_compat_equal(self, device):
        def gn(x, y):
            return torch.sigmoid(torch.rand_like(x) * y) * x

        def fn(x, y):
            x = torch.sin(x)
            x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
            x = torch.sin(x)
            return x

        outs = []
        grads = []

        outs2 = []
        grads2 = []

        compile_fns = [
            lambda fn: torch.compile(fn, backend="aot_eager_decomp_partition"),
            lambda fn: torch.compile(fn, mode="reduce-overhead"),
        ]
        for i, compile_fn in enumerate(compile_fns):
            torch.manual_seed(0)
            for index in range(3):
                x = torch.randn(4, 4, device=device, requires_grad=True)
                y = torch.randn(4, 4, device=device, requires_grad=True)

                out = compile_fn(fn)(x, y)
                torch.cuda.synchronize()
                out.sum().backward()
                if i == 0:
                    outs.append(out.clone())
                    grads.append((x.grad.clone(), y.grad.clone()))
                else:
                    outs2.append(out.clone())
                    grads2.append((x.grad.clone(), y.grad.clone()))

        self.assertEqual(outs, outs2)
        self.assertEqual(grads, grads2)
        self.assertEqual(counters["inductor"]["cudagraph_skips"], 0)

    def test_cudagraphs_aot_eager_compat_equal(self):
        self._test_cudagraphs_aot_eager_compat_equal(torch.device("cuda:0"))

    @requires_multigpu()
    def test_cudagraphs_aot_eager_compat_equal_device_one(self):
        self._test_cudagraphs_aot_eager_compat_equal(torch.device("cuda:1"))

    @requires_multigpu()
    def test_multi_device(self):
        def gn(x, y):
            return torch.sigmoid(torch.rand_like(x) * y) * x

        def fn(x, y):
            x = torch.sin(x)
            x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
            x = torch.sin(x)
            return x

        def multi_fn(x, y, a, b):
            return fn(x, y), fn(a, b)

        x = torch.randn(4, 4, device="cuda:0", requires_grad=True)
        y = torch.randn(4, 4, device="cuda:0", requires_grad=True)

        a = torch.randn(4, 4, device="cuda:1", requires_grad=True)
        b = torch.randn(4, 4, device="cuda:1", requires_grad=True)

        # No errors. TODO - get graphs from logging, couldn't figure out how
        multi_fn_c = torch.compile(multi_fn, backend="aot_eager_decomp_partition")

        out = multi_fn_c(x, y, a, b)
        out[0].sum().backward()

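    # Calling backward twice on the same graph (first with retain_graph=True)
    # should replay the same rng state each time, so all three recorded rand
    # outputs match.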
    def test_retain_graph(self):
        device = "cuda"

        ObserverMode = self._make_observer_mode()
        from torch._prims.rng_prims import graphsafe_run_with_rng_state

        @graphsafe_run_with_rng_state.py_impl(ObserverMode)
        def _(mode, op, *args, **kwargs):
            with no_dispatch():
                out = graphsafe_run_with_rng_state(op, *args, **kwargs)

            mode.op_outputs[op].append(out)
            return out

        obs = ObserverMode()

        def gn(x, y):
            return torch.sigmoid(torch.rand_like(x) * y) * x

        def fn(x, y):
            x = torch.sin(x)
            x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
            x = torch.sin(x)
            return x

        x = torch.randn(4, 4, device=device, requires_grad=True)
        y = torch.randn(4, 4, device=device, requires_grad=True)

        aot_eager_decomp_partition = functools.partial(
            aot_eager_decomp_partition_with_mode, mode=obs
        )

        fn = torch.compile(fn, backend=aot_eager_decomp_partition)

        out = fn(x, y).sum()
        out.backward(retain_graph=True)
        out.backward()
        self.assertEqual(len(obs.op_outputs[aten.rand.default]), 3)
        self.assertEqual(
            obs.op_outputs[aten.rand.default][0],
            obs.op_outputs[aten.rand.default][1],
        )
        self.assertEqual(
            obs.op_outputs[aten.rand.default][1],
            obs.op_outputs[aten.rand.default][2],
        )


instantiate_parametrized_tests(CudaGraphTreeTests)
instantiate_parametrized_tests(TestSAC)


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if not TEST_CUDA_GRAPH:
        if __name__ == "__main__":
            sys.exit(0)
        raise unittest.SkipTest("cuda graph test is skipped")

    if HAS_CUDA:
        run_tests(needs="filelock")