[inductor] fix compile time regression by caching get_gpu_type (#128363)
We observed a significant compile time regression in torchtitan when recently turning on 2D parallel + torch.compile, so I decided to get a deeper understanding of why. It turns out this affects **all the trainings** that have functional collectives captured in the graph, not only 2D parallel (2D parallel was just the job that happened to have collectives captured in the TP region).

The root cause is that during inductor lowering we call the comm analysis pass to get an estimated collective time for each collective node in the graph, and each check of a collective node calls `get_gpu_type()`, which under the hood calls `torch.utils.collect_env.run` to get the GPU info. That call is very expensive: it effectively spawns a new process and runs `nvidia-smi`, so the cost is **linear** in the number of collective nodes in the graph. See https://github.com/pytorch/pytorch/blob/main/torch/utils/collect_env.py#L75

The fix is to add an lru cache to the function, so that we call it only once and reuse the cached result afterwards.

torchtitan benchmarks show:

* before this fix: 2D parallel + fp8 compile time: 6min+
* after this fix: 2D parallel + fp8 compile time: 2min 48s (more than 100% improvement)

There is more room to improve compile time, but this PR fixes the biggest regression found so far.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128363
Approved by: https://github.com/yf225
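To illustrate the pattern (a minimal sketch, not the PR's code; `_query_gpu_name` and the exact `nvidia-smi` flags are assumptions for the example):

```python
import functools
import subprocess

@functools.lru_cache
def _query_gpu_name() -> str:
    # Hypothetical stand-in for the torch.utils.collect_env.run path:
    # each uncached call spawns a new process, so during lowering the
    # cost would grow linearly with the number of collective nodes.
    out = subprocess.run(
        ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
        capture_output=True,
        text=True,
        check=False,
    )
    return out.stdout.strip()

# With @functools.lru_cache, only the first call pays the subprocess
# cost; every later call returns the cached string.
```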
Committed by: PyTorch MergeBot
Parent: 1d233b8f50
Commit: 8a09940a54
```diff
--- a/torch/_inductor/comm_analysis.py
+++ b/torch/_inductor/comm_analysis.py
@@ -1,3 +1,4 @@
+import functools
 import math
 from enum import IntEnum
 
@@ -22,6 +23,7 @@ class NVIDIA_GPU_TYPE(IntEnum):
     HOPPER = 2
 
 
+@functools.lru_cache
 def get_gpu_type() -> NVIDIA_GPU_TYPE:
     gpu_info = torch.utils.collect_env.get_gpu_info(torch.utils.collect_env.run) or ""
     if "V100" in gpu_info:
```
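As a quick sanity check of the caching behavior (illustrative only, not part of the PR; `expensive` is a made-up stand-in for the subprocess-backed call):

```python
import functools

calls = 0

@functools.lru_cache
def expensive() -> str:
    # Made-up stand-in for a function that shells out to nvidia-smi.
    global calls
    calls += 1
    return "NVIDIA H100"

for _ in range(1000):
    expensive()

print(calls)                   # 1 -- the expensive body ran exactly once
print(expensive.cache_info())  # CacheInfo(hits=999, misses=1, maxsize=128, currsize=1)
```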