mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Improve device info with new flops and bandwidth formula based on hardware libraries (#162245)
Previously, DeviceInfo provided theoretical hardware information based on a hardcoded list manually created from various datasheets. This update: - Attempts to gather the information from a hardware library like `pynvml`, improving accuracy and expanding support to devices that don't have entries in the datasheet list. - Adjusts flops and bw calculation based on these hardware values. For example, if the memory or SMs are underclocked, it adjusts the theoretical max flops/bw accordingly. Pull Request resolved: https://github.com/pytorch/pytorch/pull/162245 Approved by: https://github.com/v0i0, https://github.com/shunting314
This commit is contained in:
committed by
PyTorch MergeBot
parent
0663bdb123
commit
35d7b32159
@ -60,7 +60,7 @@ import sympy
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._inductor.analysis.device_info import datasheet_tops
|
||||
from torch._inductor.analysis.device_info import DeviceInfo
|
||||
from torch._inductor.runtime.hints import DeviceProperties
|
||||
from torch.utils._dtype_abbrs import dtype_abbrs
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
@ -2381,7 +2381,9 @@ def get_device_tflops(dtype: torch.dtype) -> float:
|
||||
We don't want to throw errors in this function. First check to see if the device is in device_info.py,
|
||||
then fall back to the inaccurate triton estimation.
|
||||
"""
|
||||
ds_tops = datasheet_tops(dtype, is_tf32=torch.backends.cuda.matmul.allow_tf32)
|
||||
ds_tops = DeviceInfo.lookup_tops_current_device(
|
||||
dtype, is_tf32=torch.backends.cuda.matmul.allow_tf32
|
||||
)
|
||||
if ds_tops is not None:
|
||||
return ds_tops
|
||||
|
||||
|
||||
Reference in New Issue
Block a user