mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[profiler] don't disable CUPTI_LAZY_REINIT for cuda >= 12.6 (#150957)
Credit to @mgmtea who wrote the initial version of this PR: https://github.com/pytorch/pytorch/pull/146604

Context: CUPTI is the NVIDIA library that Kineto uses for collecting GPU-side info during profiling. The intended usage is to register a callback when you want profiling to occur, and then unregister the callback when you want profiling to stop. But a bug would cause crashes if CUPTI callbacks were de-registered when used with cudagraphs. The workaround was to disable "CUPTI_LAZY_REINIT" and "CUPTI_TEARDOWN" in Kineto, which prevents crashes but can result in slower execution after profiling has occurred and completed.

This bug is believed to be fixed in CUDA >= 12.6, so this PR gates the workaround such that DISABLE_CUPTI_LAZY_REINIT=1 and CUPTI_TEARDOWN=0 are only applied if CUDA < 12.6. Additionally, `profiler_allow_cudagraph_cupti_lazy_reinit_cuda12()` is added as an escape hatch so that we can add a killswitch in case we see more crashes related to this.

Differential Revision: [D72745929](https://our.internmc.facebook.com/intern/diff/D72745929)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150957
Approved by: https://github.com/aaronenyeshi, https://github.com/Skylion007
commit 37812009fd
parent 6720d23969
committed by PyTorch MergeBot
@@ -54,6 +54,7 @@ from torch._utils import (
 from torch._utils_internal import (
     get_file_path,
     prepare_multiprocessing_environment,
+    profiler_allow_cudagraph_cupti_lazy_reinit_cuda12,
     USE_GLOBAL_DEPS,
     USE_RTLD_GLOBAL_WITH_LIBTORCH,
 )
@@ -2294,6 +2295,7 @@ class _TorchCompileInductorWrapper:
 
     def __init__(self, mode, options, dynamic):
         from torch._inductor.compiler_bisector import CompilerBisector
+        from torch.torch_version import TorchVersion
 
         self.config: dict[str, _Any] = {}
         self.dynamic = dynamic
@@ -2301,7 +2303,13 @@ class _TorchCompileInductorWrapper:
         self.apply_options(options)
         self.apply_options(CompilerBisector.get_config_change("inductor"))
 
-        if self.config.get("triton.cudagraphs", False):
+        if self.config.get("triton.cudagraphs", False) and (
+            (
+                getattr(torch.version, "cuda", None)
+                and TorchVersion(torch.version.cuda) < "12.6"
+            )
+            or not profiler_allow_cudagraph_cupti_lazy_reinit_cuda12()
+        ):
             os.environ["DISABLE_CUPTI_LAZY_REINIT"] = "1"
             # FIXME: CUDA Graph does not work well with CUPTI teardown.
             #   1) crashes on 1st lazy CUPTI re-init after teardown (CUDA 11)