mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
cudagraph trees (#89146)
CUDA Graph Trees Design doc: https://docs.google.com/document/d/1ZrxLGWz7T45MSX6gPsL6Ln4t0eZCSfWewtJ_qLd_D0E/edit Not currently implemented : - Right now, we are using weak tensor refs from outputs to check if a tensor has dies. This doesn't work because a) aliasing, and b) aot_autograd detaches tensors (see note [Detaching saved tensors in AOTAutograd]). Would need either https://github.com/pytorch/pytorch/issues/91395 to land to use storage weak refs or manually add a deleter fn that does what I want. This is doable but theres some interactions with the caching allocator checkpointing so saving for a stacked pr. - Reclaiming memory from the inputs during model recording. This isn't terribly difficult but deferring to another PR. You would need to write over the input memory during warmup, and therefore copy the inputs to cpu. Saving for a stacked pr. - Warning on overwriting previous generation outputs. and handling nested torch.compile() calls in generation tracking Differential Revision: [D43999887](https://our.internmc.facebook.com/intern/diff/D43999887) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89146 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
cf732053e4
commit
571f96bf59
534
test/inductor/test_cudagraph_trees.py
Normal file
534
test/inductor/test_cudagraph_trees.py
Normal file
@ -0,0 +1,534 @@
|
||||
# Owner(s): ["module: inductor"]
|
||||
import contextlib
|
||||
import functools
|
||||
import gc
|
||||
import importlib
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
import torch._dynamo
|
||||
import torch.nn as nn
|
||||
from torch._inductor import config
|
||||
from torch._inductor.cudagraph_trees import cudagraphify_impl as tree_cudagraphify_impl
|
||||
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_CI,
|
||||
IS_WINDOWS,
|
||||
TEST_WITH_ASAN,
|
||||
TEST_WITH_ROCM,
|
||||
TestCase as TorchTestCase,
|
||||
)
|
||||
|
||||
if IS_WINDOWS and IS_CI:
|
||||
sys.stderr.write(
|
||||
"Windows CI does not have necessary dependencies for test_torchinductor yet\n"
|
||||
)
|
||||
if __name__ == "__main__":
|
||||
sys.exit(0)
|
||||
raise unittest.SkipTest("requires sympy/functorch/filelock")
|
||||
|
||||
importlib.import_module("functorch")
|
||||
importlib.import_module("filelock")
|
||||
|
||||
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
|
||||
|
||||
HAS_MULTIGPU = HAS_CUDA and torch.cuda.device_count() >= 2
|
||||
aten = torch.ops.aten
|
||||
requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda")
|
||||
requires_multigpu = functools.partial(
|
||||
unittest.skipIf, not HAS_MULTIGPU, "requires multiple cuda devices"
|
||||
)
|
||||
|
||||
|
||||
class TestCase(TorchTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls._stack = contextlib.ExitStack()
|
||||
cls._stack.enter_context(
|
||||
config.patch(
|
||||
{
|
||||
"debug": True,
|
||||
"cpp.min_chunk_size": 1,
|
||||
"triton.autotune_pointwise": False, # too slow
|
||||
"implicit_fallbacks": False,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls._stack.close()
|
||||
super().tearDownClass()
|
||||
|
||||
def setUp(self):
|
||||
torch._dynamo.reset()
|
||||
super().setUp()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
torch._dynamo.reset()
|
||||
|
||||
|
||||
if HAS_CUDA and not TEST_WITH_ASAN:
|
||||
|
||||
def get_all_cudagraph_segments():
|
||||
segments = torch.cuda.memory_snapshot()
|
||||
return [segment for segment in segments if segment["segment_pool_id"] != (0, 0)]
|
||||
|
||||
def all_live_blocks():
|
||||
blocks_addrs = []
|
||||
for segment in get_all_cudagraph_segments():
|
||||
addr = segment["address"]
|
||||
for block in segment["blocks"]:
|
||||
if block["state"] == "active_allocated":
|
||||
blocks_addrs.append(addr)
|
||||
addr += block["size"]
|
||||
|
||||
return blocks_addrs
|
||||
|
||||
def all_live_block_count():
|
||||
return len(all_live_blocks())
|
||||
|
||||
class CudaGraphTreeTests(TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.prev_enabled = config.triton.cudagraphs
|
||||
self.tapes_enabled = config.triton.cudagraph_trees
|
||||
config.triton.cudagraphs = True
|
||||
config.triton.cudagraph_trees = True
|
||||
self.device_idx = torch.rand([0], device="cuda").device.index
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
torch._dynamo.reset()
|
||||
gc.collect()
|
||||
config.triton.cudagraphs = self.prev_enabled
|
||||
config.triton.cudagraph_trees = self.tapes_enabled
|
||||
self.assertIsNone(self.get_manager())
|
||||
self.assertEqual(all_live_block_count(), 0)
|
||||
|
||||
def get_manager(self, device_index=None):
|
||||
return torch._inductor.cudagraph_trees.get_container(
|
||||
(self.device_idx if not device_index else device_index)
|
||||
).tree_manager
|
||||
|
||||
def get_roots(self):
|
||||
return self.get_manager().get_roots()
|
||||
|
||||
def curr_node(self):
|
||||
return self.get_manager().current_node
|
||||
|
||||
def get_root_children(self):
|
||||
return [root.num_descendants() for root in self.get_roots()]
|
||||
|
||||
def cudagraphify_impl(self, *args, **kwargs):
|
||||
return tree_cudagraphify_impl(*args, **kwargs, device_index=self.device_idx)
|
||||
|
||||
@staticmethod
|
||||
def run_twc(fn, *args, **kwargs):
|
||||
fn(*args, **kwargs)
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
def num_checkpoints(self):
|
||||
return self.get_manager().debug_checkpointing_counter
|
||||
|
||||
def test_run_simple(self):
|
||||
def foo(x):
|
||||
return x * x * x
|
||||
|
||||
foo_opt = torch._dynamo.optimize()(foo)
|
||||
ones = torch.ones([4, 4], device="cuda")
|
||||
zeros = torch.zeros([5, 5], device="cuda")
|
||||
self.run_twc(foo_opt, ones)
|
||||
self.run_twc(foo_opt, zeros)
|
||||
self.assertEqual(self.get_root_children(), [0, 0])
|
||||
|
||||
def test_function_compiled_multiple_times(self):
|
||||
def foo(x):
|
||||
y = foo2(x)
|
||||
y2 = foo2(y)
|
||||
return y + y2
|
||||
|
||||
def foo2(x):
|
||||
torch._dynamo.graph_break()
|
||||
return x * x * x
|
||||
|
||||
foo_opt = torch._dynamo.optimize()(foo)
|
||||
ones = torch.ones([4, 4], device="cuda")
|
||||
foo(ones)
|
||||
foo_opt(ones)
|
||||
foo_opt(ones)
|
||||
self.assertEqual(foo_opt(ones), foo(ones))
|
||||
# paths
|
||||
children = self.get_root_children()
|
||||
# one root with two children
|
||||
self.assertEqual(children, [2])
|
||||
|
||||
def test_end_recording_early(self):
|
||||
def foo(x):
|
||||
y = x * x * x
|
||||
torch._dynamo.graph_break()
|
||||
z = x + y
|
||||
return z
|
||||
|
||||
@torch._dynamo.optimize()
|
||||
def foo2(x):
|
||||
return x + 4
|
||||
|
||||
foo_opt = torch._dynamo.optimize()(foo)
|
||||
|
||||
for _ in range(3):
|
||||
out = foo_opt(torch.ones([4, 4], device="cuda"))
|
||||
del out
|
||||
|
||||
# when I tried inducing separate recordings via graph break,
|
||||
# the frame kept interferring by keeping outputs alive
|
||||
# this isnt great by simulates the logic.
|
||||
from torch._dynamo.mutation_guard import GenerationTracker
|
||||
|
||||
GenerationTracker.generation -= 1
|
||||
|
||||
out = foo2(torch.ones([4, 4], device="cuda"))
|
||||
del out
|
||||
|
||||
foo_opt(torch.ones([4, 4], device="cuda"))
|
||||
|
||||
# Two separate traces - one has a child, one doesnt
|
||||
self.assertEqual(self.get_root_children(), [1, 0])
|
||||
|
||||
def test_execution_into_recording(self):
|
||||
def foo(x):
|
||||
y = x + x
|
||||
|
||||
if y.sum() > 0:
|
||||
return y + 10
|
||||
else:
|
||||
return y - 10
|
||||
|
||||
foo_opt = torch._dynamo.optimize()(foo)
|
||||
inp = torch.zeros([4, 4], dtype=torch.float, device="cuda")
|
||||
self.assertEqual(foo_opt(inp), foo(inp))
|
||||
self.assertEqual(foo_opt(inp), foo(inp))
|
||||
|
||||
inp.add_(1)
|
||||
out_eager = foo(inp)
|
||||
out_warmup = foo_opt(inp)
|
||||
self.assertEqual(out_warmup, out_eager)
|
||||
# warmup should be have storage deallocator hooked on
|
||||
self.assertEqual(all_live_block_count(), 1)
|
||||
|
||||
out_live = foo_opt(inp)
|
||||
self.assertEqual(out_live, out_eager)
|
||||
|
||||
# should be in recording mode, with storage deallocator hooked on
|
||||
self.assertEqual(all_live_block_count(), 1)
|
||||
# warmup should have been freed
|
||||
del out_warmup
|
||||
# should be in recording mode, with storage deallocator hooked on
|
||||
self.assertEqual(all_live_block_count(), 1)
|
||||
|
||||
del out_live
|
||||
self.assertEqual(all_live_block_count(), 0)
|
||||
|
||||
out = foo_opt(inp)
|
||||
self.assertEqual(foo(inp), out)
|
||||
|
||||
# should be in execution mode
|
||||
self.assertEqual(all_live_block_count(), 0)
|
||||
|
||||
def test_accumulate_multiple_recordings(self):
|
||||
def foo(x):
|
||||
y = x + x + x
|
||||
torch._dynamo.graph_break()
|
||||
if y.sum() <= 0:
|
||||
return y
|
||||
else:
|
||||
return y * 10
|
||||
|
||||
foo_opt = torch._dynamo.optimize()(foo)
|
||||
|
||||
# two separate compilations & recordings
|
||||
out1 = self.run_twc(foo_opt, torch.zeros([5], device="cuda"))
|
||||
|
||||
# out1 gets manually freed
|
||||
out2 = self.run_twc(foo_opt, torch.zeros([6], device="cuda"))
|
||||
|
||||
self.assertEqual(all_live_block_count(), 1)
|
||||
|
||||
out3 = self.run_twc(foo_opt, torch.ones([5], device="cuda"))
|
||||
|
||||
self.assertEqual(out3, foo(torch.ones([5], device="cuda")))
|
||||
|
||||
self.assertEqual(all_live_block_count(), 1)
|
||||
del out1, out2
|
||||
self.assertEqual(all_live_block_count(), 1)
|
||||
|
||||
del out3
|
||||
gc.collect()
|
||||
self.assertEqual(all_live_block_count(), 0)
|
||||
|
||||
def test_live_outputs_multiple_graphs(self):
|
||||
def foo(x):
|
||||
x = x + x + x
|
||||
y = x + 1
|
||||
torch._dynamo.graph_break()
|
||||
z = x * x
|
||||
if z.sum() > 0:
|
||||
return y + 1
|
||||
else:
|
||||
return y
|
||||
|
||||
foo_opt = torch._dynamo.optimize()(foo)
|
||||
|
||||
self.run_twc(foo_opt, torch.zeros([5], device="cuda"))
|
||||
self.assertEqual(self.num_checkpoints(), 0)
|
||||
out = self.run_twc(foo_opt, torch.ones([5], device="cuda"))
|
||||
|
||||
self.assertEqual(all_live_block_count(), 1)
|
||||
|
||||
del out
|
||||
self.assertEqual(all_live_block_count(), 0)
|
||||
|
||||
# we need to checkpoint from function to warmup y + 1,
|
||||
# and then again to record it
|
||||
self.assertEqual(self.num_checkpoints(), 2)
|
||||
|
||||
@torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
|
||||
def test_tensor_dies_between_checkpoint(self):
|
||||
def foo(args):
|
||||
x = args[0]
|
||||
args.clear()
|
||||
return x + 1, x + 2
|
||||
|
||||
inp = torch.rand([4], device="cuda")
|
||||
foo_cg = self.cudagraphify_impl(foo, [inp], ())
|
||||
foo_cg([inp])
|
||||
foo_cg([inp])
|
||||
|
||||
out1, out2 = foo_cg([inp])
|
||||
inp = [out1]
|
||||
|
||||
del out1, out2
|
||||
|
||||
def foo2(args):
|
||||
x = args[0]
|
||||
args.clear()
|
||||
return [x * x * x]
|
||||
|
||||
self.assertEqual(self.num_checkpoints(), 0)
|
||||
foo2_cg = self.cudagraphify_impl(foo2, inp, ())
|
||||
|
||||
x = foo2_cg(inp)[0]
|
||||
|
||||
self.assertEqual(self.num_checkpoints(), 1)
|
||||
# out2 dies between the previous recording and the new one,
|
||||
# need to be manually deallocated after the checkpoint
|
||||
|
||||
self.assertEqual(all_live_block_count(), 1)
|
||||
del x
|
||||
self.assertEqual(all_live_block_count(), 0)
|
||||
|
||||
@torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
|
||||
def test_tensor_no_longer_in_pool(self):
|
||||
def foo(args):
|
||||
x = args[0]
|
||||
args.clear()
|
||||
return x + 1, x + 2
|
||||
|
||||
inp = torch.rand([4], device="cuda")
|
||||
foo_cg = self.cudagraphify_impl(foo, [inp], ())
|
||||
x1, x2 = foo_cg([inp])
|
||||
|
||||
def foo2(args):
|
||||
x = args[0]
|
||||
args.clear()
|
||||
return [x * x * x]
|
||||
|
||||
foo2_cg = self.cudagraphify_impl(foo2, [x1], ())
|
||||
foo2_cg([x1])
|
||||
|
||||
del x1, x2
|
||||
# TODO make configurable
|
||||
|
||||
x1, x2 = foo_cg([inp])
|
||||
self.assertEqual(self.num_checkpoints(), 0)
|
||||
|
||||
# input location has changed, should force recompile and checkpointing
|
||||
foo2_cg([torch.zeros_like(x1)])
|
||||
|
||||
self.assertEqual(self.num_checkpoints(), 1)
|
||||
self.assertEqual(self.get_root_children(), [2])
|
||||
|
||||
@torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
|
||||
def test_checkpoint_shared_output_storage_deallocation(self):
|
||||
def foo(args):
|
||||
x = args[0]
|
||||
args.clear()
|
||||
x_tmp = x + 1
|
||||
return x[0], x[1]
|
||||
|
||||
inp = torch.rand([2, 2], device="cuda")
|
||||
foo_cg = self.cudagraphify_impl(foo, [inp], ())
|
||||
foo_cg([inp])
|
||||
foo_cg([inp])
|
||||
|
||||
x1, x2 = foo_cg([inp])
|
||||
inp = [x1]
|
||||
|
||||
def foo2(args):
|
||||
x = args[0]
|
||||
args.clear()
|
||||
y = x * x
|
||||
return y[0], y[1]
|
||||
|
||||
foo2_cg = self.cudagraphify_impl(foo2, inp, ())
|
||||
foo2_cg(inp)
|
||||
|
||||
self.assertEqual(self.num_checkpoints(), 1)
|
||||
self.assertEqual(
|
||||
x1.untyped_storage().data_ptr(), x2.untyped_storage().data_ptr()
|
||||
)
|
||||
self.assertEqual(all_live_block_count(), 1)
|
||||
del x1
|
||||
self.assertEqual(all_live_block_count(), 1)
|
||||
del x2
|
||||
self.assertEqual(all_live_block_count(), 0)
|
||||
|
||||
@torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
|
||||
def test_cleanup(self):
|
||||
def test_closure():
|
||||
@torch._dynamo.optimize()
|
||||
def foo(x):
|
||||
return x + 1 + 2, x * 10
|
||||
|
||||
foo(torch.rand([4], device="cuda"))
|
||||
return foo(torch.rand([4], device="cuda"))
|
||||
|
||||
out1, out2 = test_closure()
|
||||
torch._dynamo.reset()
|
||||
|
||||
# TODO - deallocate on tensor deallocation
|
||||
# self.assertTrue(self.get_manager() is not None)
|
||||
# del out1
|
||||
# self.assertTrue(self.get_manager() is not None)
|
||||
# del out2
|
||||
self.assertTrue(self.get_manager() is None)
|
||||
|
||||
@torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
|
||||
def test_forward_backward(self):
|
||||
@torch._dynamo.optimize()
|
||||
def foo(x):
|
||||
y = x * 2
|
||||
return torch.sin(y) * torch.nn.functional.dropout(x, p=0.4)
|
||||
|
||||
inp = torch.rand([4, 4], requires_grad=True, device="cuda")
|
||||
print("Input ID", id(inp))
|
||||
out = foo(inp)
|
||||
out.sum().backward()
|
||||
|
||||
self.assertEqual(self.get_root_children(), [1])
|
||||
|
||||
# the three saved tensors should die in the backward
|
||||
# we kept alive the output
|
||||
self.assertEqual(self.curr_node().expected_dead_indices_before_graph, [])
|
||||
self.assertEqual(
|
||||
self.curr_node().expected_dead_indices_after_graph,
|
||||
[(0, 1), (0, 2), (0, 3)],
|
||||
)
|
||||
|
||||
def test_separate_recordings(self):
|
||||
def foo_unopt(x, y):
|
||||
return (x + 1) @ y
|
||||
|
||||
foo = torch._dynamo.optimize()(foo_unopt)
|
||||
|
||||
foo_unopt(
|
||||
torch.ones([20, 20], device="cuda"), torch.ones([20, 20], device="cuda")
|
||||
)
|
||||
|
||||
inps = [
|
||||
torch.ones([20, 20], device="cuda", requires_grad=False)
|
||||
for _ in range(2)
|
||||
]
|
||||
|
||||
out = foo(*inps)
|
||||
torch.cuda.synchronize()
|
||||
foo(*inps)
|
||||
torch.cuda.synchronize()
|
||||
foo(*inps)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
foo_unopt(
|
||||
torch.ones([20, 20], device="cuda"), torch.ones([20, 20], device="cuda")
|
||||
)
|
||||
|
||||
inps2 = [
|
||||
torch.rand([40, 40], device="cuda", requires_grad=False)
|
||||
for _ in range(2)
|
||||
]
|
||||
|
||||
foo(*inps2)
|
||||
foo(*inps2)
|
||||
foo(*inps2)
|
||||
|
||||
# two separate roots
|
||||
self.assertEqual(self.get_root_children(), [0, 0])
|
||||
|
||||
def test_alias_of_parameter(self):
|
||||
class AliasMod(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.param = torch.nn.Parameter(torch.rand([20, 20], device="cuda"))
|
||||
|
||||
def forward(self, x):
|
||||
return self.param[0], self.param, self.param + x
|
||||
|
||||
@torch.compile(mode="reduce-overhead")
|
||||
def foo(mod, inp):
|
||||
return mod(inp)
|
||||
|
||||
inp = torch.rand([20, 20], device="cuda")
|
||||
mod = AliasMod()
|
||||
|
||||
storage_ref = torch.multiprocessing.reductions.StorageWeakRef(
|
||||
mod.param.untyped_storage()
|
||||
)
|
||||
|
||||
for _ in range(3):
|
||||
outs = foo(mod, inp)
|
||||
|
||||
self.assertEqual(mod(inp), outs)
|
||||
|
||||
self.assertFalse(storage_ref.expired())
|
||||
|
||||
node = self.get_manager().current_node
|
||||
self.assertEqual(len(list(node.path_live_weakrefs())), 1)
|
||||
|
||||
@requires_multigpu()
|
||||
def test_manager_per_device(self):
|
||||
def test():
|
||||
def foo(args):
|
||||
x = args[0]
|
||||
args.clear()
|
||||
return x + 3
|
||||
|
||||
inp = torch.rand([20, 20], device="cuda:1")
|
||||
|
||||
foo_cg = tree_cudagraphify_impl(foo, [inp], (), device_index=1)
|
||||
self.assertEqual(foo_cg([inp]), foo([inp]))
|
||||
|
||||
self.assertTrue(self.get_manager(device_index=0) is None)
|
||||
self.assertFalse(self.get_manager(device_index=1) is None)
|
||||
|
||||
test()
|
||||
self.assertTrue(self.get_manager(device_index=1) is None)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
if (HAS_CPU or HAS_CUDA) and not TEST_WITH_ROCM:
|
||||
run_tests(needs="filelock")
|
@ -5343,6 +5343,12 @@ class TestBlockStateAbsorption(TestCase):
|
||||
for before_block, after_block in zip(before_segment["blocks"], after_segment["blocks"]):
|
||||
self.checkCheckpointedBlock(before_block, after_block)
|
||||
|
||||
@staticmethod
|
||||
def setCheckpointPoolState(device, state, stale_storages_ptr, storages_deleters=None):
|
||||
stale_storages_ptr = [t.untyped_storage()._cdata for t in stale_storages_ptr]
|
||||
storages_deleters = [] if not storages_deleters else [t.untyped_storage()._cdata for t in storages_deleters]
|
||||
torch._C._cuda_setCheckpointPoolState(device, state, stale_storages_ptr, storages_deleters)
|
||||
|
||||
def checkFunction(self, fn, inputs, pool=None):
|
||||
graph, outputs = cudagraphify(fn, inputs, pool=pool)
|
||||
|
||||
@ -5352,7 +5358,7 @@ class TestBlockStateAbsorption(TestCase):
|
||||
segments_before_checkpoint = get_cudagraph_segments(pool_id)
|
||||
|
||||
state = torch._C._cuda_getCheckpointState(device, pool_id)
|
||||
torch._C._cuda_setCheckpointPoolState(device, state, [], [])
|
||||
self.setCheckpointPoolState(device, state, [], [])
|
||||
|
||||
self.checkCheckpointedState(segments_before_checkpoint, get_cudagraph_segments(pool_id))
|
||||
|
||||
@ -5422,7 +5428,7 @@ class TestBlockStateAbsorption(TestCase):
|
||||
graph2, outputs2 = cudagraphify(foo2, [], pool=graph.pool())
|
||||
|
||||
|
||||
torch._C._cuda_setCheckpointPoolState(outputs[0].device.index, state, outputs2, [])
|
||||
self.setCheckpointPoolState(outputs[0].device.index, state, outputs2, [])
|
||||
|
||||
del outputs2
|
||||
|
||||
@ -5445,7 +5451,7 @@ class TestBlockStateAbsorption(TestCase):
|
||||
|
||||
# graph2, outputs2 = cudagraphify(foo2, [], pool=graph.pool())
|
||||
# with self.assertRaisesRegex(Exception, "being manually freed must be passed"):
|
||||
# torch._C._cuda_setCheckpointPoolState(outputs[0].device.index, state, [], [])
|
||||
# self.setCheckpointPoolState(outputs[0].device.index, state, [], [])
|
||||
|
||||
def test_tensor_dies_after_checkpoint(self):
|
||||
|
||||
@ -5463,7 +5469,7 @@ class TestBlockStateAbsorption(TestCase):
|
||||
|
||||
del outputs
|
||||
|
||||
torch._C._cuda_setCheckpointPoolState(device, state, [], [])
|
||||
self.setCheckpointPoolState(device, state, [], [])
|
||||
|
||||
self.assertEqual(live_blocks(pool_id), 2)
|
||||
torch._C._cuda_cudaCachingAllocator_raw_delete(output_data_ptrs[0])
|
||||
@ -5508,7 +5514,7 @@ class TestBlockStateAbsorption(TestCase):
|
||||
for i in range(len(reconstructed_tensors)):
|
||||
self.assertTrue(reconstructed_tensors[i].mean(dtype=torch.float) == 3)
|
||||
|
||||
torch._C._cuda_setCheckpointPoolState(device, state, [], [reconstructed_tensors[0], reconstructed_tensors[1]])
|
||||
self.setCheckpointPoolState(device, state, [], [reconstructed_tensors[0], reconstructed_tensors[1]])
|
||||
|
||||
self.assertEqual(live_blocks(pool_id), 3)
|
||||
|
||||
|
@ -193,9 +193,13 @@ def compile_fx_inner(
|
||||
and not graph.mutated_inputs
|
||||
and not has_incompatible_cudagraph_ops(gm)
|
||||
and not complex_memory_overlap_inputs
|
||||
and (len(graph.device_idxs) == 1 or not config.triton.cudagraph_trees)
|
||||
):
|
||||
compiled_fn = cudagraphify(
|
||||
compiled_fn, example_inputs, static_input_idxs=range(num_fixed)
|
||||
compiled_fn,
|
||||
example_inputs,
|
||||
static_input_idxs=range(num_fixed),
|
||||
device_index=next(iter(graph.device_idxs)),
|
||||
)
|
||||
else:
|
||||
BoxedBool.disable(cudagraphs)
|
||||
@ -209,6 +213,10 @@ def compile_fx_inner(
|
||||
developer_warning(
|
||||
"skipping cudagraphs due to complex input striding"
|
||||
)
|
||||
elif len(graph.device_idxs) > 1 and config.triton.cudagraph_trees:
|
||||
developer_warning(
|
||||
"skipping cudagraphs due to multiple device indexes"
|
||||
)
|
||||
|
||||
result = align_inputs(compiled_fn, example_inputs, range(num_fixed))
|
||||
_step_logger()(
|
||||
@ -259,10 +267,21 @@ def align_inputs(model, inputs, static_input_idxs=()):
|
||||
|
||||
|
||||
@dynamo_utils.dynamo_timed
|
||||
def cudagraphify(model, inputs, static_input_idxs=()):
|
||||
def cudagraphify(model, inputs, static_input_idxs=(), *, device_index: int):
|
||||
from torch._inductor.cudagraph_trees import (
|
||||
cudagraphify_impl as new_cudagraphify_impl,
|
||||
)
|
||||
|
||||
if config.triton.cudagraph_trees:
|
||||
cudagraphify_fn = functools.partial(
|
||||
new_cudagraphify_impl, device_index=device_index
|
||||
)
|
||||
else:
|
||||
cudagraphify_fn = cudagraphify_impl
|
||||
|
||||
# if using fake tensors, defer cudagraphs until we get real inputs at runtime
|
||||
if not any(isinstance(inp, FakeTensor) for inp in inputs):
|
||||
return cudagraphify_impl(model, inputs, static_input_idxs)
|
||||
return cudagraphify_fn(model, inputs, static_input_idxs)
|
||||
|
||||
compiled_fn = None
|
||||
|
||||
@ -270,8 +289,7 @@ def cudagraphify(model, inputs, static_input_idxs=()):
|
||||
nonlocal compiled_fn
|
||||
if compiled_fn is None:
|
||||
with dynamo_utils.preserve_rng_state():
|
||||
compiled_fn = cudagraphify_impl(model, new_inputs, static_input_idxs)
|
||||
|
||||
compiled_fn = cudagraphify_fn(model, new_inputs, static_input_idxs)
|
||||
return compiled_fn(new_inputs)
|
||||
|
||||
return run
|
||||
@ -290,27 +308,28 @@ def remove_unaligned_input_idxs(inputs, static_input_idxs):
|
||||
return static_input_idxs
|
||||
|
||||
|
||||
def static_input(x):
|
||||
"""
|
||||
Copy and input while preserving strides
|
||||
"""
|
||||
# TODO(jansel): figure out why this version doesn't work:
|
||||
# return torch.empty_strided(x.size(), x.stride(), dtype=x.dtype, device=x.device)
|
||||
needed_size = (
|
||||
sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())) + 1
|
||||
)
|
||||
buffer = torch.empty(needed_size, dtype=x.dtype, device=x.device)
|
||||
return torch.as_strided(buffer, x.size(), x.stride())
|
||||
|
||||
|
||||
def cudagraphify_impl(model, inputs, static_input_idxs=()):
|
||||
"""
|
||||
Assumes inputs[static_input_idxs[i]] are always the same memory address
|
||||
"""
|
||||
static_input_idxs = remove_unaligned_input_idxs(inputs, static_input_idxs)
|
||||
|
||||
def static_input(x):
|
||||
"""
|
||||
Copy and input while preserving strides
|
||||
"""
|
||||
# TODO(jansel): figure out why this version doesn't work:
|
||||
# return torch.empty_strided(x.size(), x.stride(), dtype=x.dtype, device=x.device)
|
||||
needed_size = (
|
||||
sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())) + 1
|
||||
)
|
||||
buffer = torch.zeros(needed_size, dtype=x.dtype, device=x.device)
|
||||
return torch.as_strided(buffer, x.size(), x.stride())
|
||||
|
||||
assert isinstance(inputs, (list, tuple))
|
||||
static_inputs = [
|
||||
static_input(x) if idx not in static_input_idxs else x.detach()
|
||||
static_input(x).zero_() if idx not in static_input_idxs else x.detach()
|
||||
for idx, x in enumerate(inputs)
|
||||
]
|
||||
|
||||
|
@ -186,6 +186,14 @@ class triton:
|
||||
# Use cudagraphs on output code
|
||||
cudagraphs = False
|
||||
|
||||
# Use cudagraph trees for memory pooling if `cudagraphs` is True
|
||||
cudagraph_trees = False
|
||||
|
||||
debug_cudagraph_trees = True
|
||||
|
||||
# skip warmup for cudagraph trees
|
||||
skip_cudagraph_warmup = False
|
||||
|
||||
# Synchronize before and after every compiled graph.
|
||||
debug_sync_graph = False
|
||||
|
||||
|
1328
torch/_inductor/cudagraph_trees.py
Normal file
1328
torch/_inductor/cudagraph_trees.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -130,6 +130,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
self.graph_inputs_original: Dict[str, InputBuffer] = {}
|
||||
self.graph_outputs: Optional[List[ir.IRNode]] = None
|
||||
self.device_types: Set[str] = set()
|
||||
self.device_idxs: Set[int] = set()
|
||||
self.buffers: List[ir.ComputedBuffer] = []
|
||||
self.constants: Dict[str, torch.Tensor] = {}
|
||||
self.removed_buffers: Set[str] = set()
|
||||
@ -319,6 +320,8 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
self.graph_inputs[target] = tensor
|
||||
self.graph_inputs_original[target] = tensor.data.data
|
||||
self.device_types.add(example.device.type)
|
||||
if example.device.type == "cuda":
|
||||
self.device_idxs.add(example.device.index)
|
||||
return tensor
|
||||
|
||||
def call_function(self, target, args, kwargs):
|
||||
|
@ -1,5 +1,6 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAConfig.h>
|
||||
#include <c10/util/UniqueVoidPtr.h>
|
||||
#include <unordered_set>
|
||||
#if AT_CUDNN_ENABLED()
|
||||
|
||||
@ -883,8 +884,6 @@ static void registerCudaDeviceProperties(PyObject* module) {
|
||||
});
|
||||
}
|
||||
|
||||
void no_op_delete(void* ptr){};
|
||||
|
||||
// We choose to ignore certain blocks that are currently allocated
|
||||
// when we set the pool to its checkpoint. For those blocks, we need
|
||||
// to swap out the deleter function of their corresponding blocks
|
||||
@ -898,7 +897,7 @@ void removeStorageDeleterFns(
|
||||
TORCH_CHECK(allocated_pointer != definitely_stale_pointers.end());
|
||||
auto t = c10::cuda::CUDACachingAllocator::get();
|
||||
bool succeeded = stale_storage->data_ptr().compare_exchange_deleter(
|
||||
t->raw_deleter(), &no_op_delete);
|
||||
t->raw_deleter(), &c10::detail::deleteNothing);
|
||||
|
||||
TORCH_CHECK(
|
||||
succeeded,
|
||||
@ -907,12 +906,11 @@ void removeStorageDeleterFns(
|
||||
}
|
||||
|
||||
void addStorageDeleterFns(
|
||||
std::vector<at::Tensor>& tensors_to_add_deleters_to,
|
||||
std::vector<c10::StorageImpl*>& storages_to_add_deleters_to,
|
||||
c10::cuda::CUDACachingAllocator::CheckpointDelta& delta) {
|
||||
std::unordered_map<void*, c10::StorageImpl*> storages;
|
||||
for (auto& tensor : tensors_to_add_deleters_to) {
|
||||
storages[tensor.storage().data_ptr().get()] =
|
||||
tensor.storage().unsafeGetStorageImpl();
|
||||
for (auto& storage : storages_to_add_deleters_to) {
|
||||
storages[storage->data_ptr().get()] = storage;
|
||||
}
|
||||
|
||||
for (auto& data_ptr : delta.dataptrs_allocd) {
|
||||
@ -1044,6 +1042,23 @@ static void registerCudaPluggableAllocator(PyObject* module) {
|
||||
return c10::cuda::CUDACachingAllocator::getCheckpointState(device, id);
|
||||
});
|
||||
|
||||
m.def("_free_And_Remove_DeleterFn", [](size_t storage_impl_ptr) {
|
||||
c10::StorageImpl* storage_impl = (c10::StorageImpl*)storage_impl_ptr;
|
||||
auto alloc = c10::cuda::CUDACachingAllocator::get();
|
||||
auto data_ptr = storage_impl->data_ptr().get();
|
||||
bool succeeded = storage_impl->data_ptr().compare_exchange_deleter(
|
||||
alloc->raw_deleter(), c10::detail::deleteNothing);
|
||||
TORCH_CHECK("Expected standard deleter");
|
||||
c10::cuda::CUDACachingAllocator::raw_delete(data_ptr);
|
||||
});
|
||||
|
||||
m.def("_has_Standard_Deleter", [](size_t storage_impl_ptr) {
|
||||
c10::StorageImpl* storage_impl = (c10::StorageImpl*)storage_impl_ptr;
|
||||
auto alloc = c10::cuda::CUDACachingAllocator::get();
|
||||
auto data_ptr = storage_impl->data_ptr().get();
|
||||
return (storage_impl->data_ptr().get_deleter() == alloc->raw_deleter());
|
||||
});
|
||||
|
||||
m.def(
|
||||
"_cuda_beginAllocateCurrentStreamToPool",
|
||||
[](int device, at::cuda::MempoolId_t mempool_id) {
|
||||
@ -1067,17 +1082,16 @@ static void registerCudaPluggableAllocator(PyObject* module) {
|
||||
"_cuda_setCheckpointPoolState",
|
||||
[](int device,
|
||||
std::shared_ptr<c10::cuda::CUDACachingAllocator::AllocatorState> pps,
|
||||
std::vector<at::Tensor> stale_tensors,
|
||||
std::vector<at::Tensor> tensors_to_add_deleters_to = {}) {
|
||||
// Could pass in Storage Pointers instead
|
||||
std::vector<size_t> stale_storages_ptr,
|
||||
std::vector<size_t> storages_to_add_deleters_to_ptr = {}) {
|
||||
std::unordered_set<c10::StorageImpl*> ptr_set;
|
||||
// iterate on std::vector for determinism
|
||||
std::vector<c10::StorageImpl*> ptrs;
|
||||
for (const auto& ten : stale_tensors) {
|
||||
auto ptr = ten.storage().unsafeGetStorageImpl();
|
||||
for (size_t ptr_int : stale_storages_ptr) {
|
||||
c10::StorageImpl* ptr = (c10::StorageImpl*)ptr_int;
|
||||
if (!ptr_set.count(ptr)) {
|
||||
ptrs.push_back(ten.storage().unsafeGetStorageImpl());
|
||||
ptr_set.insert(ten.storage().unsafeGetStorageImpl());
|
||||
ptrs.push_back(ptr);
|
||||
ptr_set.insert(ptr);
|
||||
}
|
||||
}
|
||||
auto delta = c10::cuda::CUDACachingAllocator::setCheckpointPoolState(
|
||||
@ -1107,8 +1121,12 @@ static void registerCudaPluggableAllocator(PyObject* module) {
|
||||
" must be passed to set checkpoint");
|
||||
|
||||
removeStorageDeleterFns(ptrs, freed_pointer_set);
|
||||
std::vector<c10::StorageImpl*> storages_to_add_deleters_to;
|
||||
for (size_t ptr_int : storages_to_add_deleters_to_ptr) {
|
||||
storages_to_add_deleters_to.push_back((c10::StorageImpl*)ptr_int);
|
||||
}
|
||||
|
||||
addStorageDeleterFns(tensors_to_add_deleters_to, delta);
|
||||
addStorageDeleterFns(storages_to_add_deleters_to, delta);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -1,12 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import weakref
|
||||
from weakref import ref
|
||||
from _weakrefset import _IterationGuard # type: ignore[attr-defined]
|
||||
from collections.abc import MutableMapping, Mapping
|
||||
from typing import Dict
|
||||
from torch import Tensor
|
||||
import collections.abc as _collections_abc
|
||||
|
||||
|
||||
__all__ = ['WeakIdRef', 'WeakIdKeyDictionary', 'WeakTensorKeyDictionary']
|
||||
WeakRef = ref
|
||||
|
||||
|
||||
__all__ = ['TensorWeakRef', 'WeakIdRef', 'WeakIdKeyDictionary', 'WeakTensorKeyDictionary']
|
||||
|
||||
|
||||
# This file defines a variant of WeakKeyDictionary that overrides the hashing
|
||||
@ -261,3 +267,25 @@ class WeakIdKeyDictionary(MutableMapping):
|
||||
|
||||
# Convenience alias
|
||||
WeakTensorKeyDictionary = WeakIdKeyDictionary
|
||||
|
||||
|
||||
class TensorWeakRef:
|
||||
"""
|
||||
Wrapper around a weak ref of a Tensor that handles the _fix_weakref() call required
|
||||
when unwrapping a Tensor weakref.
|
||||
"""
|
||||
|
||||
ref: WeakRef[Tensor]
|
||||
|
||||
def __init__(self, tensor: Tensor):
|
||||
assert isinstance(tensor, Tensor)
|
||||
self.ref = weakref.ref(tensor)
|
||||
|
||||
def __call__(self):
|
||||
out = self.ref()
|
||||
if out is None:
|
||||
return out
|
||||
assert isinstance(out, Tensor)
|
||||
# TODO, add _fix_weakref type binding
|
||||
out._fix_weakref() # type: ignore[attr-defined]
|
||||
return out
|
||||
|
Reference in New Issue
Block a user