cudagraph trees (#89146)

CUDA Graph Trees

Design doc: https://docs.google.com/document/d/1ZrxLGWz7T45MSX6gPsL6Ln4t0eZCSfWewtJ_qLd_D0E/edit

Not currently implemented:

- Right now, we are using weak tensor refs from outputs to check whether a tensor has died. This doesn't work because of a) aliasing, and b) aot_autograd detaching tensors (see note [Detaching saved tensors in AOTAutograd]). We would need either https://github.com/pytorch/pytorch/issues/91395 to land so that we can use storage weak refs, or to manually add a deleter fn that does what I want. This is doable, but there are some interactions with caching allocator checkpointing, so it is saved for a stacked PR. (A minimal sketch of the storage-weak-ref approach follows this list.)

- Reclaiming memory from the inputs during model recording. This isn't terribly difficult, but you would need to write over the input memory during warmup, and therefore copy the inputs to CPU first. Saving for a stacked PR.

- Warning when outputs from a previous generation are overwritten, and handling nested torch.compile() calls in generation tracking.
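
Below is a minimal sketch (not part of this PR) of the liveness problem from the first bullet: a weak ref to an output tensor dies as soon as that particular Python object goes away, even though an alias still holds the underlying CUDA memory, whereas a storage-level weak ref (the same StorageWeakRef helper used in the tests below) only expires once every alias is gone. It assumes a CUDA device is available.

```python
import weakref

import torch
from torch.multiprocessing.reductions import StorageWeakRef

def outputs_sharing_storage():
    buf = torch.ones(4, device="cuda")
    # two outputs that alias the same underlying storage
    return buf[:2], buf[2:]

out1, out2 = outputs_sharing_storage()

tensor_ref = weakref.ref(out1)
storage_ref = StorageWeakRef(out1.untyped_storage())

del out1
# The per-tensor weak ref for out1 is dead...
assert tensor_ref() is None
# ...but its memory cannot be reclaimed yet: out2 aliases the same storage.
assert not storage_ref.expired()

del out2
# Only once every alias is gone does the storage-level ref expire.
assert storage_ref.expired()
```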

Differential Revision: [D43999887](https://our.internmc.facebook.com/intern/diff/D43999887)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89146
Approved by: https://github.com/ezyang
Elias Ellison
2023-03-16 20:46:16 +00:00
committed by PyTorch MergeBot
parent cf732053e4
commit 571f96bf59
8 changed files with 1983 additions and 39 deletions

View File

@@ -0,0 +1,534 @@
# Owner(s): ["module: inductor"]
import contextlib
import functools
import gc
import importlib
import sys
import unittest
import torch
import torch._dynamo
import torch.nn as nn
from torch._inductor import config
from torch._inductor.cudagraph_trees import cudagraphify_impl as tree_cudagraphify_impl
from torch.testing._internal.common_utils import (
IS_CI,
IS_WINDOWS,
TEST_WITH_ASAN,
TEST_WITH_ROCM,
TestCase as TorchTestCase,
)
if IS_WINDOWS and IS_CI:
sys.stderr.write(
"Windows CI does not have necessary dependencies for test_torchinductor yet\n"
)
if __name__ == "__main__":
sys.exit(0)
raise unittest.SkipTest("requires sympy/functorch/filelock")
importlib.import_module("functorch")
importlib.import_module("filelock")
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
HAS_MULTIGPU = HAS_CUDA and torch.cuda.device_count() >= 2
aten = torch.ops.aten
requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda")
requires_multigpu = functools.partial(
unittest.skipIf, not HAS_MULTIGPU, "requires multiple cuda devices"
)
class TestCase(TorchTestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls._stack = contextlib.ExitStack()
cls._stack.enter_context(
config.patch(
{
"debug": True,
"cpp.min_chunk_size": 1,
"triton.autotune_pointwise": False, # too slow
"implicit_fallbacks": False,
}
)
)
@classmethod
def tearDownClass(cls):
cls._stack.close()
super().tearDownClass()
def setUp(self):
torch._dynamo.reset()
super().setUp()
def tearDown(self):
super().tearDown()
torch._dynamo.reset()
if HAS_CUDA and not TEST_WITH_ASAN:
def get_all_cudagraph_segments():
segments = torch.cuda.memory_snapshot()
return [segment for segment in segments if segment["segment_pool_id"] != (0, 0)]
def all_live_blocks():
blocks_addrs = []
for segment in get_all_cudagraph_segments():
addr = segment["address"]
for block in segment["blocks"]:
if block["state"] == "active_allocated":
blocks_addrs.append(addr)
addr += block["size"]
return blocks_addrs
def all_live_block_count():
return len(all_live_blocks())
class CudaGraphTreeTests(TestCase):
def setUp(self):
super().setUp()
self.prev_enabled = config.triton.cudagraphs
self.tapes_enabled = config.triton.cudagraph_trees
config.triton.cudagraphs = True
config.triton.cudagraph_trees = True
self.device_idx = torch.rand([0], device="cuda").device.index
def tearDown(self):
super().tearDown()
torch._dynamo.reset()
gc.collect()
config.triton.cudagraphs = self.prev_enabled
config.triton.cudagraph_trees = self.tapes_enabled
self.assertIsNone(self.get_manager())
self.assertEqual(all_live_block_count(), 0)
def get_manager(self, device_index=None):
return torch._inductor.cudagraph_trees.get_container(
(self.device_idx if device_index is None else device_index)
).tree_manager
def get_roots(self):
return self.get_manager().get_roots()
def curr_node(self):
return self.get_manager().current_node
def get_root_children(self):
return [root.num_descendants() for root in self.get_roots()]
def cudagraphify_impl(self, *args, **kwargs):
return tree_cudagraphify_impl(*args, **kwargs, device_index=self.device_idx)
@staticmethod
def run_twc(fn, *args, **kwargs):
fn(*args, **kwargs)
return fn(*args, **kwargs)
def num_checkpoints(self):
return self.get_manager().debug_checkpointing_counter
def test_run_simple(self):
def foo(x):
return x * x * x
foo_opt = torch._dynamo.optimize()(foo)
ones = torch.ones([4, 4], device="cuda")
zeros = torch.zeros([5, 5], device="cuda")
self.run_twc(foo_opt, ones)
self.run_twc(foo_opt, zeros)
self.assertEqual(self.get_root_children(), [0, 0])
def test_function_compiled_multiple_times(self):
def foo(x):
y = foo2(x)
y2 = foo2(y)
return y + y2
def foo2(x):
torch._dynamo.graph_break()
return x * x * x
foo_opt = torch._dynamo.optimize()(foo)
ones = torch.ones([4, 4], device="cuda")
foo(ones)
foo_opt(ones)
foo_opt(ones)
self.assertEqual(foo_opt(ones), foo(ones))
# paths
children = self.get_root_children()
# one root with two children
self.assertEqual(children, [2])
def test_end_recording_early(self):
def foo(x):
y = x * x * x
torch._dynamo.graph_break()
z = x + y
return z
@torch._dynamo.optimize()
def foo2(x):
return x + 4
foo_opt = torch._dynamo.optimize()(foo)
for _ in range(3):
out = foo_opt(torch.ones([4, 4], device="cuda"))
del out
# when I tried inducing separate recordings via graph break,
# the frame kept interfering by keeping outputs alive
# this isn't great but it simulates the logic.
from torch._dynamo.mutation_guard import GenerationTracker
GenerationTracker.generation -= 1
out = foo2(torch.ones([4, 4], device="cuda"))
del out
foo_opt(torch.ones([4, 4], device="cuda"))
# Two separate traces - one has a child, one doesn't
self.assertEqual(self.get_root_children(), [1, 0])
def test_execution_into_recording(self):
def foo(x):
y = x + x
if y.sum() > 0:
return y + 10
else:
return y - 10
foo_opt = torch._dynamo.optimize()(foo)
inp = torch.zeros([4, 4], dtype=torch.float, device="cuda")
self.assertEqual(foo_opt(inp), foo(inp))
self.assertEqual(foo_opt(inp), foo(inp))
inp.add_(1)
out_eager = foo(inp)
out_warmup = foo_opt(inp)
self.assertEqual(out_warmup, out_eager)
# warmup output should have the storage deallocator hooked on
self.assertEqual(all_live_block_count(), 1)
out_live = foo_opt(inp)
self.assertEqual(out_live, out_eager)
# should be in recording mode, with storage deallocator hooked on
self.assertEqual(all_live_block_count(), 1)
# warmup should have been freed
del out_warmup
# should be in recording mode, with storage deallocator hooked on
self.assertEqual(all_live_block_count(), 1)
del out_live
self.assertEqual(all_live_block_count(), 0)
out = foo_opt(inp)
self.assertEqual(foo(inp), out)
# should be in execution mode
self.assertEqual(all_live_block_count(), 0)
def test_accumulate_multiple_recordings(self):
def foo(x):
y = x + x + x
torch._dynamo.graph_break()
if y.sum() <= 0:
return y
else:
return y * 10
foo_opt = torch._dynamo.optimize()(foo)
# two separate compilations & recordings
out1 = self.run_twc(foo_opt, torch.zeros([5], device="cuda"))
# out1 gets manually freed
out2 = self.run_twc(foo_opt, torch.zeros([6], device="cuda"))
self.assertEqual(all_live_block_count(), 1)
out3 = self.run_twc(foo_opt, torch.ones([5], device="cuda"))
self.assertEqual(out3, foo(torch.ones([5], device="cuda")))
self.assertEqual(all_live_block_count(), 1)
del out1, out2
self.assertEqual(all_live_block_count(), 1)
del out3
gc.collect()
self.assertEqual(all_live_block_count(), 0)
def test_live_outputs_multiple_graphs(self):
def foo(x):
x = x + x + x
y = x + 1
torch._dynamo.graph_break()
z = x * x
if z.sum() > 0:
return y + 1
else:
return y
foo_opt = torch._dynamo.optimize()(foo)
self.run_twc(foo_opt, torch.zeros([5], device="cuda"))
self.assertEqual(self.num_checkpoints(), 0)
out = self.run_twc(foo_opt, torch.ones([5], device="cuda"))
self.assertEqual(all_live_block_count(), 1)
del out
self.assertEqual(all_live_block_count(), 0)
# we need to checkpoint the allocator state to warm up y + 1,
# and then again to record it
self.assertEqual(self.num_checkpoints(), 2)
@torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
def test_tensor_dies_between_checkpoint(self):
def foo(args):
x = args[0]
args.clear()
return x + 1, x + 2
inp = torch.rand([4], device="cuda")
foo_cg = self.cudagraphify_impl(foo, [inp], ())
foo_cg([inp])
foo_cg([inp])
out1, out2 = foo_cg([inp])
inp = [out1]
del out1, out2
def foo2(args):
x = args[0]
args.clear()
return [x * x * x]
self.assertEqual(self.num_checkpoints(), 0)
foo2_cg = self.cudagraphify_impl(foo2, inp, ())
x = foo2_cg(inp)[0]
self.assertEqual(self.num_checkpoints(), 1)
# out2 dies between the previous recording and the new one;
# it needs to be manually deallocated after the checkpoint
self.assertEqual(all_live_block_count(), 1)
del x
self.assertEqual(all_live_block_count(), 0)
@torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
def test_tensor_no_longer_in_pool(self):
def foo(args):
x = args[0]
args.clear()
return x + 1, x + 2
inp = torch.rand([4], device="cuda")
foo_cg = self.cudagraphify_impl(foo, [inp], ())
x1, x2 = foo_cg([inp])
def foo2(args):
x = args[0]
args.clear()
return [x * x * x]
foo2_cg = self.cudagraphify_impl(foo2, [x1], ())
foo2_cg([x1])
del x1, x2
# TODO make configurable
x1, x2 = foo_cg([inp])
self.assertEqual(self.num_checkpoints(), 0)
# input location has changed, should force recompile and checkpointing
foo2_cg([torch.zeros_like(x1)])
self.assertEqual(self.num_checkpoints(), 1)
self.assertEqual(self.get_root_children(), [2])
@torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
def test_checkpoint_shared_output_storage_deallocation(self):
def foo(args):
x = args[0]
args.clear()
x_tmp = x + 1
return x[0], x[1]
inp = torch.rand([2, 2], device="cuda")
foo_cg = self.cudagraphify_impl(foo, [inp], ())
foo_cg([inp])
foo_cg([inp])
x1, x2 = foo_cg([inp])
inp = [x1]
def foo2(args):
x = args[0]
args.clear()
y = x * x
return y[0], y[1]
foo2_cg = self.cudagraphify_impl(foo2, inp, ())
foo2_cg(inp)
self.assertEqual(self.num_checkpoints(), 1)
self.assertEqual(
x1.untyped_storage().data_ptr(), x2.untyped_storage().data_ptr()
)
self.assertEqual(all_live_block_count(), 1)
del x1
self.assertEqual(all_live_block_count(), 1)
del x2
self.assertEqual(all_live_block_count(), 0)
@torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
def test_cleanup(self):
def test_closure():
@torch._dynamo.optimize()
def foo(x):
return x + 1 + 2, x * 10
foo(torch.rand([4], device="cuda"))
return foo(torch.rand([4], device="cuda"))
out1, out2 = test_closure()
torch._dynamo.reset()
# TODO - deallocate on tensor deallocation
# self.assertTrue(self.get_manager() is not None)
# del out1
# self.assertTrue(self.get_manager() is not None)
# del out2
self.assertTrue(self.get_manager() is None)
@torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
def test_forward_backward(self):
@torch._dynamo.optimize()
def foo(x):
y = x * 2
return torch.sin(y) * torch.nn.functional.dropout(x, p=0.4)
inp = torch.rand([4, 4], requires_grad=True, device="cuda")
print("Input ID", id(inp))
out = foo(inp)
out.sum().backward()
self.assertEqual(self.get_root_children(), [1])
# the three saved tensors should die in the backward;
# we kept the output alive
self.assertEqual(self.curr_node().expected_dead_indices_before_graph, [])
self.assertEqual(
self.curr_node().expected_dead_indices_after_graph,
[(0, 1), (0, 2), (0, 3)],
)
def test_separate_recordings(self):
def foo_unopt(x, y):
return (x + 1) @ y
foo = torch._dynamo.optimize()(foo_unopt)
foo_unopt(
torch.ones([20, 20], device="cuda"), torch.ones([20, 20], device="cuda")
)
inps = [
torch.ones([20, 20], device="cuda", requires_grad=False)
for _ in range(2)
]
out = foo(*inps)
torch.cuda.synchronize()
foo(*inps)
torch.cuda.synchronize()
foo(*inps)
torch.cuda.synchronize()
foo_unopt(
torch.ones([20, 20], device="cuda"), torch.ones([20, 20], device="cuda")
)
inps2 = [
torch.rand([40, 40], device="cuda", requires_grad=False)
for _ in range(2)
]
foo(*inps2)
foo(*inps2)
foo(*inps2)
# two separate roots
self.assertEqual(self.get_root_children(), [0, 0])
def test_alias_of_parameter(self):
class AliasMod(nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.rand([20, 20], device="cuda"))
def forward(self, x):
return self.param[0], self.param, self.param + x
@torch.compile(mode="reduce-overhead")
def foo(mod, inp):
return mod(inp)
inp = torch.rand([20, 20], device="cuda")
mod = AliasMod()
storage_ref = torch.multiprocessing.reductions.StorageWeakRef(
mod.param.untyped_storage()
)
for _ in range(3):
outs = foo(mod, inp)
self.assertEqual(mod(inp), outs)
self.assertFalse(storage_ref.expired())
node = self.get_manager().current_node
self.assertEqual(len(list(node.path_live_weakrefs())), 1)
@requires_multigpu()
def test_manager_per_device(self):
def test():
def foo(args):
x = args[0]
args.clear()
return x + 3
inp = torch.rand([20, 20], device="cuda:1")
foo_cg = tree_cudagraphify_impl(foo, [inp], (), device_index=1)
self.assertEqual(foo_cg([inp]), foo([inp]))
self.assertTrue(self.get_manager(device_index=0) is None)
self.assertFalse(self.get_manager(device_index=1) is None)
test()
self.assertTrue(self.get_manager(device_index=1) is None)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
if (HAS_CPU or HAS_CUDA) and not TEST_WITH_ROCM:
run_tests(needs="filelock")

View File

@@ -5343,6 +5343,12 @@ class TestBlockStateAbsorption(TestCase):
for before_block, after_block in zip(before_segment["blocks"], after_segment["blocks"]):
self.checkCheckpointedBlock(before_block, after_block)
@staticmethod
def setCheckpointPoolState(device, state, stale_storages_ptr, storages_deleters=None):
stale_storages_ptr = [t.untyped_storage()._cdata for t in stale_storages_ptr]
storages_deleters = [] if not storages_deleters else [t.untyped_storage()._cdata for t in storages_deleters]
torch._C._cuda_setCheckpointPoolState(device, state, stale_storages_ptr, storages_deleters)
def checkFunction(self, fn, inputs, pool=None):
graph, outputs = cudagraphify(fn, inputs, pool=pool)
@@ -5352,7 +5358,7 @@ class TestBlockStateAbsorption(TestCase):
segments_before_checkpoint = get_cudagraph_segments(pool_id)
state = torch._C._cuda_getCheckpointState(device, pool_id)
torch._C._cuda_setCheckpointPoolState(device, state, [], [])
self.setCheckpointPoolState(device, state, [], [])
self.checkCheckpointedState(segments_before_checkpoint, get_cudagraph_segments(pool_id))
@@ -5422,7 +5428,7 @@ class TestBlockStateAbsorption(TestCase):
graph2, outputs2 = cudagraphify(foo2, [], pool=graph.pool())
torch._C._cuda_setCheckpointPoolState(outputs[0].device.index, state, outputs2, [])
self.setCheckpointPoolState(outputs[0].device.index, state, outputs2, [])
del outputs2
@@ -5445,7 +5451,7 @@ class TestBlockStateAbsorption(TestCase):
# graph2, outputs2 = cudagraphify(foo2, [], pool=graph.pool())
# with self.assertRaisesRegex(Exception, "being manually freed must be passed"):
# torch._C._cuda_setCheckpointPoolState(outputs[0].device.index, state, [], [])
# self.setCheckpointPoolState(outputs[0].device.index, state, [], [])
def test_tensor_dies_after_checkpoint(self):
@@ -5463,7 +5469,7 @@ class TestBlockStateAbsorption(TestCase):
del outputs
torch._C._cuda_setCheckpointPoolState(device, state, [], [])
self.setCheckpointPoolState(device, state, [], [])
self.assertEqual(live_blocks(pool_id), 2)
torch._C._cuda_cudaCachingAllocator_raw_delete(output_data_ptrs[0])
@@ -5508,7 +5514,7 @@ class TestBlockStateAbsorption(TestCase):
for i in range(len(reconstructed_tensors)):
self.assertTrue(reconstructed_tensors[i].mean(dtype=torch.float) == 3)
torch._C._cuda_setCheckpointPoolState(device, state, [], [reconstructed_tensors[0], reconstructed_tensors[1]])
self.setCheckpointPoolState(device, state, [], [reconstructed_tensors[0], reconstructed_tensors[1]])
self.assertEqual(live_blocks(pool_id), 3)

View File

@@ -193,9 +193,13 @@ def compile_fx_inner(
and not graph.mutated_inputs
and not has_incompatible_cudagraph_ops(gm)
and not complex_memory_overlap_inputs
and (len(graph.device_idxs) == 1 or not config.triton.cudagraph_trees)
):
compiled_fn = cudagraphify(
compiled_fn, example_inputs, static_input_idxs=range(num_fixed)
compiled_fn,
example_inputs,
static_input_idxs=range(num_fixed),
device_index=next(iter(graph.device_idxs)),
)
else:
BoxedBool.disable(cudagraphs)
@@ -209,6 +213,10 @@ def compile_fx_inner(
developer_warning(
"skipping cudagraphs due to complex input striding"
)
elif len(graph.device_idxs) > 1 and config.triton.cudagraph_trees:
developer_warning(
"skipping cudagraphs due to multiple device indexes"
)
result = align_inputs(compiled_fn, example_inputs, range(num_fixed))
_step_logger()(
@@ -259,10 +267,21 @@ def align_inputs(model, inputs, static_input_idxs=()):
@dynamo_utils.dynamo_timed
def cudagraphify(model, inputs, static_input_idxs=()):
def cudagraphify(model, inputs, static_input_idxs=(), *, device_index: int):
from torch._inductor.cudagraph_trees import (
cudagraphify_impl as new_cudagraphify_impl,
)
if config.triton.cudagraph_trees:
cudagraphify_fn = functools.partial(
new_cudagraphify_impl, device_index=device_index
)
else:
cudagraphify_fn = cudagraphify_impl
# if using fake tensors, defer cudagraphs until we get real inputs at runtime
if not any(isinstance(inp, FakeTensor) for inp in inputs):
return cudagraphify_impl(model, inputs, static_input_idxs)
return cudagraphify_fn(model, inputs, static_input_idxs)
compiled_fn = None
@@ -270,8 +289,7 @@ def cudagraphify(model, inputs, static_input_idxs=()):
nonlocal compiled_fn
if compiled_fn is None:
with dynamo_utils.preserve_rng_state():
compiled_fn = cudagraphify_impl(model, new_inputs, static_input_idxs)
compiled_fn = cudagraphify_fn(model, new_inputs, static_input_idxs)
return compiled_fn(new_inputs)
return run
@@ -290,27 +308,28 @@ def remove_unaligned_input_idxs(inputs, static_input_idxs):
return static_input_idxs
def static_input(x):
"""
Copy an input while preserving strides
"""
# TODO(jansel): figure out why this version doesn't work:
# return torch.empty_strided(x.size(), x.stride(), dtype=x.dtype, device=x.device)
needed_size = (
sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())) + 1
)
buffer = torch.empty(needed_size, dtype=x.dtype, device=x.device)
return torch.as_strided(buffer, x.size(), x.stride())
def cudagraphify_impl(model, inputs, static_input_idxs=()):
"""
Assumes inputs[static_input_idxs[i]] are always the same memory address
"""
static_input_idxs = remove_unaligned_input_idxs(inputs, static_input_idxs)
def static_input(x):
"""
Copy an input while preserving strides
"""
# TODO(jansel): figure out why this version doesn't work:
# return torch.empty_strided(x.size(), x.stride(), dtype=x.dtype, device=x.device)
needed_size = (
sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())) + 1
)
buffer = torch.zeros(needed_size, dtype=x.dtype, device=x.device)
return torch.as_strided(buffer, x.size(), x.stride())
assert isinstance(inputs, (list, tuple))
static_inputs = [
static_input(x) if idx not in static_input_idxs else x.detach()
static_input(x).zero_() if idx not in static_input_idxs else x.detach()
for idx, x in enumerate(inputs)
]

View File

@@ -186,6 +186,14 @@ class triton:
# Use cudagraphs on output code
cudagraphs = False
# Use cudagraph trees for memory pooling if `cudagraphs` is True
cudagraph_trees = False
debug_cudagraph_trees = True
# skip warmup for cudagraph trees
skip_cudagraph_warmup = False
# Synchronize before and after every compiled graph.
debug_sync_graph = False
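
For reference, a minimal sketch of toggling these flags from user code, mirroring how the test setup earlier in this commit patches them (the flag names are those added in this diff; defaults may change later):

```python
import torch
import torch._dynamo
from torch._inductor import config

# Enable CUDA graphs plus the new tree-based memory pooling from this PR.
config.triton.cudagraphs = True
config.triton.cudagraph_trees = True
# Keep the warmup run before recording (the tests skip it via config.patch).
config.triton.skip_cudagraph_warmup = False

@torch._dynamo.optimize()
def foo(x):
    return x * x * x

out = foo(torch.ones(4, 4, device="cuda"))
```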

File diff suppressed because it is too large.

View File

@@ -130,6 +130,7 @@ class GraphLowering(torch.fx.Interpreter):
self.graph_inputs_original: Dict[str, InputBuffer] = {}
self.graph_outputs: Optional[List[ir.IRNode]] = None
self.device_types: Set[str] = set()
self.device_idxs: Set[int] = set()
self.buffers: List[ir.ComputedBuffer] = []
self.constants: Dict[str, torch.Tensor] = {}
self.removed_buffers: Set[str] = set()
@@ -319,6 +320,8 @@ class GraphLowering(torch.fx.Interpreter):
self.graph_inputs[target] = tensor
self.graph_inputs_original[target] = tensor.data.data
self.device_types.add(example.device.type)
if example.device.type == "cuda":
self.device_idxs.add(example.device.index)
return tensor
def call_function(self, target, args, kwargs):

View File

@@ -1,5 +1,6 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAConfig.h>
#include <c10/util/UniqueVoidPtr.h>
#include <unordered_set>
#if AT_CUDNN_ENABLED()
@@ -883,8 +884,6 @@ static void registerCudaDeviceProperties(PyObject* module) {
});
}
void no_op_delete(void* ptr){};
// We choose to ignore certain blocks that are currently allocated
// when we set the pool to its checkpoint. For those blocks, we need
// to swap out the deleter function of their corresponding blocks
@@ -898,7 +897,7 @@ void removeStorageDeleterFns(
TORCH_CHECK(allocated_pointer != definitely_stale_pointers.end());
auto t = c10::cuda::CUDACachingAllocator::get();
bool succeeded = stale_storage->data_ptr().compare_exchange_deleter(
t->raw_deleter(), &no_op_delete);
t->raw_deleter(), &c10::detail::deleteNothing);
TORCH_CHECK(
succeeded,
@@ -907,12 +906,11 @@
}
void addStorageDeleterFns(
std::vector<at::Tensor>& tensors_to_add_deleters_to,
std::vector<c10::StorageImpl*>& storages_to_add_deleters_to,
c10::cuda::CUDACachingAllocator::CheckpointDelta& delta) {
std::unordered_map<void*, c10::StorageImpl*> storages;
for (auto& tensor : tensors_to_add_deleters_to) {
storages[tensor.storage().data_ptr().get()] =
tensor.storage().unsafeGetStorageImpl();
for (auto& storage : storages_to_add_deleters_to) {
storages[storage->data_ptr().get()] = storage;
}
for (auto& data_ptr : delta.dataptrs_allocd) {
@@ -1044,6 +1042,23 @@ static void registerCudaPluggableAllocator(PyObject* module) {
return c10::cuda::CUDACachingAllocator::getCheckpointState(device, id);
});
m.def("_free_And_Remove_DeleterFn", [](size_t storage_impl_ptr) {
c10::StorageImpl* storage_impl = (c10::StorageImpl*)storage_impl_ptr;
auto alloc = c10::cuda::CUDACachingAllocator::get();
auto data_ptr = storage_impl->data_ptr().get();
bool succeeded = storage_impl->data_ptr().compare_exchange_deleter(
alloc->raw_deleter(), c10::detail::deleteNothing);
TORCH_CHECK("Expected standard deleter");
c10::cuda::CUDACachingAllocator::raw_delete(data_ptr);
});
m.def("_has_Standard_Deleter", [](size_t storage_impl_ptr) {
c10::StorageImpl* storage_impl = (c10::StorageImpl*)storage_impl_ptr;
auto alloc = c10::cuda::CUDACachingAllocator::get();
auto data_ptr = storage_impl->data_ptr().get();
return (storage_impl->data_ptr().get_deleter() == alloc->raw_deleter());
});
m.def(
"_cuda_beginAllocateCurrentStreamToPool",
[](int device, at::cuda::MempoolId_t mempool_id) {
@@ -1067,17 +1082,16 @@ static void registerCudaPluggableAllocator(PyObject* module) {
"_cuda_setCheckpointPoolState",
[](int device,
std::shared_ptr<c10::cuda::CUDACachingAllocator::AllocatorState> pps,
std::vector<at::Tensor> stale_tensors,
std::vector<at::Tensor> tensors_to_add_deleters_to = {}) {
// Could pass in Storage Pointers instead
std::vector<size_t> stale_storages_ptr,
std::vector<size_t> storages_to_add_deleters_to_ptr = {}) {
std::unordered_set<c10::StorageImpl*> ptr_set;
// iterate on std::vector for determinism
std::vector<c10::StorageImpl*> ptrs;
for (const auto& ten : stale_tensors) {
auto ptr = ten.storage().unsafeGetStorageImpl();
for (size_t ptr_int : stale_storages_ptr) {
c10::StorageImpl* ptr = (c10::StorageImpl*)ptr_int;
if (!ptr_set.count(ptr)) {
ptrs.push_back(ten.storage().unsafeGetStorageImpl());
ptr_set.insert(ten.storage().unsafeGetStorageImpl());
ptrs.push_back(ptr);
ptr_set.insert(ptr);
}
}
auto delta = c10::cuda::CUDACachingAllocator::setCheckpointPoolState(
@@ -1107,8 +1121,12 @@ static void registerCudaPluggableAllocator(PyObject* module) {
" must be passed to set checkpoint");
removeStorageDeleterFns(ptrs, freed_pointer_set);
std::vector<c10::StorageImpl*> storages_to_add_deleters_to;
for (size_t ptr_int : storages_to_add_deleters_to_ptr) {
storages_to_add_deleters_to.push_back((c10::StorageImpl*)ptr_int);
}
addStorageDeleterFns(tensors_to_add_deleters_to, delta);
addStorageDeleterFns(storages_to_add_deleters_to, delta);
});
}
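
As a rough illustration (not part of this diff), the new bindings are driven from Python with raw StorageImpl pointers, the same way the setCheckpointPoolState helper earlier in this commit passes untyped_storage()._cdata values; treat the exact call sites as an assumption, since these are private torch._C APIs.

```python
import torch

def storage_impl_ptr(t: torch.Tensor) -> int:
    # The bindings take raw StorageImpl pointers (ints), not Tensors.
    return t.untyped_storage()._cdata

t = torch.ones(4, device="cuda")

# True while the storage still uses the caching allocator's standard deleter.
print(torch._C._has_Standard_Deleter(storage_impl_ptr(t)))

# _free_And_Remove_DeleterFn swaps in a no-op deleter and returns the block to
# the allocator immediately; only safe once the memory is known to be managed
# elsewhere (e.g. by a cudagraph pool), so it is not called here.
```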

View File

@@ -1,12 +1,18 @@
from __future__ import annotations
import weakref
from weakref import ref
from _weakrefset import _IterationGuard # type: ignore[attr-defined]
from collections.abc import MutableMapping, Mapping
from typing import Dict
from torch import Tensor
import collections.abc as _collections_abc
__all__ = ['WeakIdRef', 'WeakIdKeyDictionary', 'WeakTensorKeyDictionary']
WeakRef = ref
__all__ = ['TensorWeakRef', 'WeakIdRef', 'WeakIdKeyDictionary', 'WeakTensorKeyDictionary']
# This file defines a variant of WeakKeyDictionary that overrides the hashing
@@ -261,3 +267,25 @@ class WeakIdKeyDictionary(MutableMapping):
# Convenience alias
WeakTensorKeyDictionary = WeakIdKeyDictionary
class TensorWeakRef:
"""
Wrapper around a weak ref of a Tensor that handles the _fix_weakref() call required
when unwrapping a Tensor weakref.
"""
ref: WeakRef[Tensor]
def __init__(self, tensor: Tensor):
assert isinstance(tensor, Tensor)
self.ref = weakref.ref(tensor)
def __call__(self):
out = self.ref()
if out is None:
return out
assert isinstance(out, Tensor)
# TODO, add _fix_weakref type binding
out._fix_weakref() # type: ignore[attr-defined]
return out
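
A minimal usage sketch of TensorWeakRef (assuming the module is importable as torch.utils.weak, which is where this file lives in the current tree):

```python
import torch
from torch.utils.weak import TensorWeakRef

t = torch.ones(3)
ref = TensorWeakRef(t)
assert ref() is t      # dereferencing returns the live tensor
del t
assert ref() is None   # once the tensor dies, the ref yields None
```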