Files
pytorch/test/functorch/test_control_flow_cuda_initialization.py
Daniel Galvez c7515da7b0 Implement cuda graphs implementation of torch.cond and torch.while_loop (#140979)
This is a new PR replacing #130386, which went stale and was closed. Because I force-pushed to that branch to rebase it onto main, the original PR can no longer be reopened; see https://github.com/isaacs/github/issues/361

I fixed the possibly-not-warmed-up problem described here: https://github.com/pytorch/pytorch/pull/130386/files#r1690856534

Since I started this work, torch.cond and torch.while_loop have apparently gained support for backward passes. I will look into what it would take to support that here.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140979
Approved by: https://github.com/eqy, https://github.com/eellison
2025-02-11 18:16:15 +00:00

91 lines
2.9 KiB
Python

# Owner(s): ["module: functorch"]
import unittest
import torch
from torch.testing._internal.common_utils import (
run_tests,
TEST_CUDA_GRAPH_CONDITIONAL_NODES,
TestCase,
)
@unittest.skipIf(
    not TEST_CUDA_GRAPH_CONDITIONAL_NODES,
    "CUDA 12.4 or greater is required for CUDA Graphs with conditional nodes",
)
class TestControlFlowInCUDAGraphInitialization(TestCase):
    """Run torch.cond through the "cudagraphs" backend from a fresh process.

    Every test asserts that device 0 has no primary CUDA context when it
    starts, so the first CUDA work of the process is issued by the test
    itself. Each test must therefore run in its own subprocess (see
    CTX_ALREADY_CREATED_ERR_MSG).
    """

    # Duplicated from test_cuda_primary_ctx.py
    CTX_ALREADY_CREATED_ERR_MSG = (
        "Tests defined in TestControlFlowInCUDAGraphInitialization must be run in a process "
        "where CUDA contexts are never created. Use either run_test.py or add "
        "--subprocess to run each test in a different subprocess."
    )

    def setUp(self):
        """Fail the test up front if device 0 already has a primary context."""
        # Ensure context has not been created beforehand
        self.assertFalse(
            torch._C._cuda_hasPrimaryContext(0),
            TestControlFlowInCUDAGraphInitialization.CTX_ALREADY_CREATED_ERR_MSG,
        )

    def _check_compile_cudagraphs(self, f, pred, *other_args):
        """Compare ``f`` compiled with the "cudagraphs" backend against eager.

        The compiled function is called three times with ``pred`` and three
        times with ``torch.logical_not(pred)``, so both values of the
        predicate are exercised. Each result is then checked against the
        output of the uncompiled ``f`` for the same predicate.

        Args:
            f: Callable invoked as ``f(pred, *other_args)``; it must return a
                tensor (``.clone()`` is called on the result).
            pred: Predicate tensor passed as the first argument of ``f``.
            *other_args: Remaining positional arguments forwarded to ``f``.
        """
        # Keep the compiled callable under its own name so that ``f`` still
        # refers to the eager function when the reference is computed below.
        compiled_f = torch.compile(f, backend="cudagraphs")

        predicates = [pred, torch.logical_not(pred)]

        # Clone every result: a later call may reuse the output's storage.
        # NOTE(review): presumably cudagraph replays overwrite static output
        # buffers - the original code cloned for the same reason; confirm.
        compiled_outputs = [
            [compiled_f(p, *other_args).clone() for _ in range(3)]
            for p in predicates
        ]

        # We compute the eager output only after running cudagraphs
        # backend compiled function, in order to make sure that
        # cudagraph trees warms up the conditional part of the code
        # properly.
        for p, outputs in zip(predicates, compiled_outputs):
            eager_output = f(p, *other_args)
            for output in outputs:
                self.assertEqual(output, eager_output)

    def test_cond_cudnn(self):
        """torch.cond with a conv1d branch, starting without a CUDA context."""

        # Tests that cublasCreate() does not break stream capture
        # NOTE(review): the comment above names cublasCreate() while this test
        # is named "cudnn" and the branch calls conv1d - confirm which
        # library handle creation is meant.
        def f(pred, x, filters):
            return torch.cond(
                pred,
                lambda y: torch.sum(y),
                lambda y: torch.sum(torch.nn.functional.conv1d(y, filters)),
                [x],
            )

        self.assertFalse(torch._C._cuda_hasPrimaryContext(0))

        pred = torch.tensor(True, device="cuda")
        x = torch.randn(33, 16, 30, device="cuda")
        filters = torch.randn(20, 16, 5, device="cuda")

        self._check_compile_cudagraphs(f, pred, x, filters)

        # Running the test must have created the primary context on device 0.
        self.assertTrue(torch._C._cuda_hasPrimaryContext(0))

    def test_cond_stft(self):
        """torch.cond with an stft branch, starting without a CUDA context."""

        # Tests that cufft plan creation does not break stream capture
        def f(pred, x):
            return torch.cond(
                pred,
                lambda y: torch.sum(y),
                lambda y: torch.sum(torch.stft(y, 512, return_complex=False)),
                [x],
            )

        self.assertFalse(torch._C._cuda_hasPrimaryContext(0))

        pred = torch.tensor(True, device="cuda")
        x = torch.ones(1024 * 1024, device="cuda")

        self._check_compile_cudagraphs(f, pred, x)

        # Running the test must have created the primary context on device 0.
        self.assertTrue(torch._C._cuda_hasPrimaryContext(0))
if __name__ == "__main__":
    # Script entry point: delegate to the run_tests helper imported from
    # torch.testing._internal.common_utils.
    run_tests()