[Bug] Fix Negative Cuda Memory Usage (#25683)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Author: Wentao Ye
Date: 2025-10-01 18:16:26 -04:00
Committed by: GitHub
Parent: aac622e0cd
Commit: da554f932e

@@ -3517,7 +3517,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         compilation_counter.num_gpu_runner_capture_triggers += 1
 
         start_time = time.perf_counter()
-        start_free_gpu_memory = torch.cuda.mem_get_info()[0]
 
         @contextmanager
         def freeze_gc():
@@ -3540,6 +3539,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # can reuse the memory pool allocated for the large shapes.
         set_cudagraph_capturing_enabled(True)
         with freeze_gc(), graph_capture(device=self.device):
+            start_free_gpu_memory = torch.cuda.mem_get_info()[0]
             cudagraph_mode = self.compilation_config.cudagraph_mode
             assert cudagraph_mode is not None
             if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
@@ -3568,6 +3568,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     cudagraph_runtime_mode=CUDAGraphMode.FULL,
                     uniform_decode=True)
 
+        torch.cuda.synchronize()
+        end_free_gpu_memory = torch.cuda.mem_get_info()[0]
+
         # Disable cudagraph capturing globally, so any unexpected cudagraph
         # capturing will be detected and raise an error after here.
         # Note: We don't put it into graph_capture context manager because
@@ -3576,7 +3579,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         set_cudagraph_capturing_enabled(False)
 
         end_time = time.perf_counter()
-        end_free_gpu_memory = torch.cuda.mem_get_info()[0]
         elapsed_time = end_time - start_time
         cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
         # This usually takes 5~20 seconds.
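
In short, the change samples start_free_gpu_memory inside the graph_capture context and samples end_free_gpu_memory right after torch.cuda.synchronize(), before capturing is disabled, so cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory can no longer come out negative because of unrelated allocations and frees between the old, wider measurement points. Below is a minimal standalone sketch of that measurement pattern; measure_capture_memory and capture_fn are hypothetical names, not vLLM code.

import torch

def measure_capture_memory(capture_fn) -> int:
    """Return how many bytes of free GPU memory `capture_fn` consumed.

    Sketch of the pattern the fix adopts: sample free memory immediately
    before the capture work and again only after a synchronize, so work
    outside this window cannot skew the delta.
    """
    # torch.cuda.mem_get_info() returns (free_bytes, total_bytes).
    start_free = torch.cuda.mem_get_info()[0]

    capture_fn()  # e.g. the CUDA graph capture loop

    # Wait for all queued GPU work to finish before sampling again;
    # otherwise not-yet-materialized allocations would be missed.
    torch.cuda.synchronize()
    end_free = torch.cuda.mem_get_info()[0]

    return start_free - end_free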