mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[CUDAGraph] add config to error on skipping cudagraph (#161862)
Many users want a config that forces all cuda ops to be captured by cudagraph. When that is not possible, pt2 should error. This PR adds `torch._inductor.config.triton.cudagraph_or_error` for that (defaulting to False). Also added an environment variable `TORCHINDUCTOR_CUDAGRAPH_OR_ERROR` to control it. Pull Request resolved: https://github.com/pytorch/pytorch/pull/161862 Approved by: https://github.com/ezyang, https://github.com/mlazos
This commit is contained in:
committed by
PyTorch MergeBot
parent
b7dad7dd49
commit
601ae8e483
@ -3937,6 +3937,17 @@ if HAS_CUDA_AND_TRITON:
|
||||
|
||||
self.assertEqual(self.get_manager().new_graph_id().id, 4)
|
||||
|
||||
@torch._inductor.config.patch("triton.cudagraph_or_error", True)
def test_cudagraph_or_error(self):
    """With ``triton.cudagraph_or_error`` enabled, a compiled graph that
    cannot be fully captured into a cudagraph must raise a RuntimeError
    instead of silently falling back to eager execution.
    """

    def increment(x):
        x.add_(1)
        return x

    compiled = torch.compile(increment, mode="reduce-overhead")

    # NOTE(review): presumably the in-place mutation of the graph input
    # makes this graph cudagraph-unsafe, and the patched config turns the
    # resulting skip into a hard error — confirm against the skip logic.
    with self.assertRaises(RuntimeError):
        compiled(torch.tensor(1, device="cuda"))
|
||||
|
||||
class TestSAC(TestCase):
|
||||
def _make_observer_mode(self):
|
||||
class ObserverMode(TorchDispatchMode):
|
||||
|
@ -1242,6 +1242,15 @@ class triton:
|
||||
# instead of recording and executing cudagraphs
force_cudagraphs_warmup = False

# If False (default), torch.compile skips cudagraph for a graph if it
# contains cudagraph-unsafe ops. If True, we require that all cuda ops
# be captured into cudagraph. If this is not possible, this will raise
# an error.
# Can be forced on via the TORCHINDUCTOR_CUDAGRAPH_OR_ERROR environment
# variable (see env_name_force below).
cudagraph_or_error: bool = Config(
    env_name_force="TORCHINDUCTOR_CUDAGRAPH_OR_ERROR",
    default=False,
)

# assertions on the fast path
fast_path_cudagraph_asserts = False
|
||||
|
||||
|
@ -204,6 +204,10 @@ def check_lowering_disable_cudagraph(
|
||||
def log_cudagraph_skip_and_bump_counter(msg: str) -> None:
    """Record that cudagraphs were skipped for a graph.

    Emits *msg* as a perf-hint warning, bumps the ``cudagraph_skips``
    counter, and — when a metrics context is active — stores *msg* as the
    skip reason. If ``triton.cudagraph_or_error`` is enabled, skipping is
    not allowed and a ``RuntimeError`` carrying *msg* is raised instead.
    """
    perf_hint_log.warning(msg)
    counters["inductor"]["cudagraph_skips"] += 1

    # With cudagraph_or_error set, escalate the skip to a hard failure.
    # Note that in this case the metrics context below is never updated.
    if torch._inductor.config.triton.cudagraph_or_error:
        raise RuntimeError(msg)

    ctx = get_metrics_context()
    if ctx.in_progress():
        ctx.set("cudagraph_skip_reason", msg, overwrite=True)
|
||||
|
Reference in New Issue
Block a user