Improve device info with new flops and bandwidth formula based on hardware libraries (#162245)

Previously, DeviceInfo provided theoretical hardware information based on a hardcoded list manually created from various datasheets.

This update:
- Attempts to gather the information from a hardware library such as `pynvml`, improving accuracy and expanding support to devices that don't have entries in the datasheet list.
- Adjusts the flops and bandwidth calculations based on these hardware values. For example, if the memory or SMs are underclocked, the theoretical max flops/bandwidth is scaled down accordingly.
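The clock-based adjustment described above can be sketched as a simple derating of the datasheet peak by the clock ratio. This is an illustrative sketch, not the PR's actual code: the function name `derate_peak` and the clock values are hypothetical, and in practice the clocks would come from a hardware library such as `pynvml` (e.g. `nvmlDeviceGetMaxClockInfo`).

```python
def derate_peak(datasheet_peak: float, rated_clock_mhz: float,
                actual_clock_mhz: float) -> float:
    """Scale a theoretical peak (TFLOPS or GB/s) by the clock ratio.

    datasheet_peak assumes the device runs at rated_clock_mhz; if the
    SMs or memory are underclocked, the achievable peak shrinks
    proportionally.
    """
    if rated_clock_mhz <= 0:
        # Avoid division by zero; fall back to the datasheet value.
        return datasheet_peak
    return datasheet_peak * (actual_clock_mhz / rated_clock_mhz)

# A part underclocked to half its rated clock delivers half the peak:
print(derate_peak(100.0, 1000.0, 500.0))  # 50.0
```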

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162245
Approved by: https://github.com/v0i0, https://github.com/shunting314
This commit is contained in:
Gabriel Ferns
2025-09-10 21:19:09 +00:00
committed by PyTorch MergeBot
parent 0663bdb123
commit 35d7b32159
5 changed files with 1075 additions and 24 deletions


@@ -60,7 +60,7 @@ import sympy
 import torch
 import torch.utils._pytree as pytree
-from torch._inductor.analysis.device_info import datasheet_tops
+from torch._inductor.analysis.device_info import DeviceInfo
 from torch._inductor.runtime.hints import DeviceProperties
 from torch.utils._dtype_abbrs import dtype_abbrs
 from torch.utils._ordered_set import OrderedSet
@@ -2381,7 +2381,9 @@ def get_device_tflops(dtype: torch.dtype) -> float:
     We don't want to throw errors in this function. First check to see if the device is in device_info.py,
     then fall back to the inaccurate triton estimation.
     """
-    ds_tops = datasheet_tops(dtype, is_tf32=torch.backends.cuda.matmul.allow_tf32)
+    ds_tops = DeviceInfo.lookup_tops_current_device(
+        dtype, is_tf32=torch.backends.cuda.matmul.allow_tf32
+    )
     if ds_tops is not None:
         return ds_tops
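The hunk above swaps the static `datasheet_tops` lookup for `DeviceInfo.lookup_tops_current_device` while keeping the same fallback shape: prefer the accurate value, and only use the rough Triton estimate when the lookup returns `None`. A minimal sketch of that pattern, with illustrative names rather than the PR's actual functions:

```python
from typing import Callable, Optional

def tflops_with_fallback(lookup: Callable[[], Optional[float]],
                         rough_estimate: Callable[[], float]) -> float:
    """Return the hardware-library value when available, else the estimate.

    Mirrors the control flow in get_device_tflops: lookup() stands in for
    DeviceInfo.lookup_tops_current_device, rough_estimate() for the
    inaccurate Triton-based fallback (both hypothetical here).
    """
    tops = lookup()
    if tops is not None:
        return tops
    return rough_estimate()

# When the lookup succeeds, the fallback is never consulted:
print(tflops_with_fallback(lambda: 312.0, lambda: 100.0))  # 312.0
```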