[CUDAGraph] add config to error on skipping cudagraph (#161862)

Many users want a config that forces all CUDA ops to be captured by cudagraphs; when that is not possible, pt2 should raise an error.

This PR adds `torch._inductor.config.triton.cudagraph_or_error` for that purpose (default: `False`), along with an environment variable, `TORCHINDUCTOR_CUDAGRAPH_OR_ERROR`, to control it.
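
A minimal usage sketch (hypothetical script; assumes a CUDA build where `reduce-overhead` mode uses cudagraphs):

```python
import torch

# Either set the environment variable before launching:
#   TORCHINDUCTOR_CUDAGRAPH_OR_ERROR=1 python my_script.py
# or flip the inductor config in code:
torch._inductor.config.triton.cudagraph_or_error = True

@torch.compile(mode="reduce-overhead")  # "reduce-overhead" enables cudagraphs
def f(x):
    return x + 1

# With the flag set, any graph that would otherwise silently skip cudagraphs
# (e.g. because it contains a cudagraph-unsafe op) raises a RuntimeError
# carrying the skip reason instead of falling back.
f(torch.ones(4, device="cuda"))
```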

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161862
Approved by: https://github.com/ezyang, https://github.com/mlazos
Author: Boyuan Feng
Date: 2025-09-04 15:52:39 +00:00
Committed by: PyTorch MergeBot
parent b7dad7dd49
commit 601ae8e483
3 changed files with 24 additions and 0 deletions


@@ -3937,6 +3937,17 @@ if HAS_CUDA_AND_TRITON:
            self.assertEqual(self.get_manager().new_graph_id().id, 4)

        @torch._inductor.config.patch("triton.cudagraph_or_error", True)
        def test_cudagraph_or_error(self):
            def f(x):
                x.add_(1)
                return x

            f = torch.compile(f, mode="reduce-overhead")
            with self.assertRaises(RuntimeError):
                f(torch.tensor(1, device="cuda"))


class TestSAC(TestCase):
    def _make_observer_mode(self):
        class ObserverMode(TorchDispatchMode):


@@ -1242,6 +1242,15 @@ class triton:
    # instead of recording and executing cudagraphs
    force_cudagraphs_warmup = False

    # If False (default), torch.compile skips cudagraph for a graph if it
    # contains cudagraph-unsafe ops. If True, we require that all cuda ops
    # be captured into cudagraph. If this is not possible, this will raise
    # an error.
    cudagraph_or_error: bool = Config(
        env_name_force="TORCHINDUCTOR_CUDAGRAPH_OR_ERROR",
        default=False,
    )

    # assertions on the fast path
    fast_path_cudagraph_asserts = False


@@ -204,6 +204,10 @@ def check_lowering_disable_cudagraph(
def log_cudagraph_skip_and_bump_counter(msg: str) -> None:
    perf_hint_log.warning(msg)
    counters["inductor"]["cudagraph_skips"] += 1

    if torch._inductor.config.triton.cudagraph_or_error:
        raise RuntimeError(msg)

    metrics_context = get_metrics_context()
    if metrics_context.in_progress():
        metrics_context.set("cudagraph_skip_reason", msg, overwrite=True)
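
For reference, a small sketch of how the new flag surfaces to users (hypothetical code mirroring the test above; the exact skip reason depends on the graph):

```python
import torch

torch._inductor.config.triton.cudagraph_or_error = True

def mutate(x):
    x.add_(1)  # input mutation makes this graph skip cudagraphs
    return x

compiled = torch.compile(mutate, mode="reduce-overhead")

try:
    compiled(torch.tensor(1, device="cuda"))
except RuntimeError as e:
    # Previously the skip was only logged as a perf hint and counted in
    # counters["inductor"]["cudagraph_skips"]; now it raises with the reason.
    print(e)
```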