[CUDA graphs] Pool argument for make_graphed_callables (#121475)

This is a nice feature to have for situations where users want multiple graph captures and/or graphed callables to share the same memory pool.
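For example, a single pool handle can be created up front and passed to several graphed callables (a minimal usage sketch; the module and input shapes below are only illustrative, not part of this PR):

    import torch

    model_a = torch.nn.Linear(32, 128).cuda()
    model_b = torch.nn.Linear(128, 64).cuda()
    x_a = torch.randn(8, 32, device="cuda")
    x_b = torch.randn(8, 128, device="cuda")

    # One handle shared by both graphed callables, so both captures
    # draw from the same CUDA graph memory pool.
    mempool = torch.cuda.graph_pool_handle()
    graphed_a = torch.cuda.make_graphed_callables(model_a, (x_a,), pool=mempool)
    graphed_b = torch.cuda.make_graphed_callables(model_b, (x_b,), pool=mempool)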

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121475
Approved by: https://github.com/eellison, https://github.com/eqy
Aidyn-A
2024-03-09 00:15:34 +00:00
committed by PyTorch MergeBot
parent b2f19dd284
commit ca9678405a
2 changed files with 45 additions and 3 deletions


@@ -2643,6 +2643,46 @@ exit(2)
model_graphed({"x": real_inputs[0]}), model_control({"x": real_inputs[0]})
)

    @unittest.skipIf(not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs")
    def test_graph_make_graphed_callables_same_pool(self):
        torch.manual_seed(5)
        torch.cuda.manual_seed(5)
        models = []
        num_models = 3
        for _ in range(num_models):
            models.append(
                torch.nn.Sequential(
                    torch.nn.Linear(32, 128),
                    torch.nn.ReLU(),
                    torch.nn.Linear(128, 128),
                ).cuda()
            )
        # we will reuse the same pool for all graph captures
        mempool = torch.cuda.graph_pool_handle()
        graphed_models = []
        for model in models:
            x = torch.randn([64, 32], device="cuda")
            graphed_model = deepcopy(model)
            graphed_model = torch.cuda.make_graphed_callables(graphed_model, (x,), pool=mempool)
            graphed_models.append(graphed_model)
        for model, graphed_model in zip(models, graphed_models):
            x = torch.randn([64, 32], device="cuda")
            y = model(x)
            yg = graphed_model(x)
            l = y.norm()
            lg = yg.norm()
            l.backward()
            lg.backward()
            self.assertEqual(y, yg)
            self.assertEqual(l, lg)
            for p, pg in zip(model.parameters(), graphed_model.parameters()):
                self.assertEqual(p, pg)
                self.assertEqual(p.grad, pg.grad)
                self.assertNotEqual(p.data_ptr(), pg.data_ptr())
                self.assertNotEqual(p.grad.data_ptr(), pg.grad.data_ptr())

    def _test_graphed_optimizer(self, steps_warmup, steps_train, optimizer_ctor, kwargs):
        for actually_do_graphs in (True, False):
            params = [


@@ -187,7 +187,7 @@ class graph:
def make_graphed_callables(
    callables, sample_args, num_warmup_iters=3, allow_unused_input=False
    callables, sample_args, num_warmup_iters=3, allow_unused_input=False, pool=None
):
    r"""Accept callables (functions or :class:`nn.Module<torch.nn.Module>`\ s) and returns graphed versions.
@@ -218,7 +218,9 @@ def make_graphed_callables(
            11 iterations for warm up. Default: ``3``.
        allow_unused_input (bool): If False, specifying inputs that were not used when computing outputs
            (and therefore their grad is always zero) is an error. Defaults to False.
        pool (optional): Token (returned by :func:`~torch.cuda.graph_pool_handle` or
            :meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) that hints this graph may share memory
            with the indicated pool. See :ref:`Graph memory management<graph-memory-management>`.

    .. note::
        The ``requires_grad`` state of each Tensor in ``sample_args`` must match the state
        that's expected for the corresponding real input in the training loop.
@@ -304,7 +306,7 @@ def make_graphed_callables(
    fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))]
    bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))]

    mempool = graph_pool_handle()
    mempool = graph_pool_handle() if pool is None else pool

    # Warmup
    # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work
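
For reference, the selected ``mempool`` is what ``make_graphed_callables`` then hands to each capture, so all graphs built by one call (and any other captures given the same handle) share a pool. Roughly, as a simplified sketch rather than the verbatim implementation (warmup, backward capture, and autograd wiring omitted):

    # Each callable's forward graph is captured into the shared pool.
    for func, args, fwd_graph in zip(callables, sample_args, fwd_graphs):
        with torch.cuda.graph(fwd_graph, pool=mempool):
            outputs = func(*args)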