mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Revert "Improve device info with new flops and bandwidth formula based on hardware libraries (#162245)"
This reverts commit 35d7b321597ed00245aad533a8fa6b7fdadd73ea. Reverted https://github.com/pytorch/pytorch/pull/162245 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/162245#issuecomment-3313669412))
This commit is contained in:
@ -289,12 +289,14 @@ class TestAnalysis(TestCase):
|
||||
om = _test_model(device, dtype)
|
||||
REPEAT = 5
|
||||
trace1, trace2 = trace_files()
|
||||
print(f"first trace {trace1}")
|
||||
torch._dynamo.reset() # reset the cache
|
||||
with fresh_inductor_cache():
|
||||
with torch.profiler.profile(record_shapes=True) as p:
|
||||
om()
|
||||
p.export_chrome_trace(trace1)
|
||||
|
||||
print(f"second trace {trace2}")
|
||||
torch._dynamo.reset() # reset the cache
|
||||
with fresh_inductor_cache():
|
||||
with torch.profiler.profile(record_shapes=True) as p:
|
||||
@ -302,6 +304,7 @@ class TestAnalysis(TestCase):
|
||||
om()
|
||||
p.export_chrome_trace(trace2)
|
||||
|
||||
print("diffing...")
|
||||
with patch(
|
||||
"sys.argv",
|
||||
[
|
||||
|
@ -1,697 +0,0 @@
|
||||
# Owner(s): ["module: inductor"]
|
||||
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
from torch._inductor.analysis.device_info import (
|
||||
_get_amd_smi,
|
||||
_get_pynvml,
|
||||
datasheet_tops,
|
||||
DeviceInfo,
|
||||
DeviceSpec,
|
||||
lookup_device_info,
|
||||
)
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
|
||||
class TestDeviceInfo(TestCase):
    """Tests for ``torch._inductor.analysis.device_info``.

    Covers the datasheet table (``lookup_device_info`` / ``datasheet_tops``),
    the lazy vendor-library accessors (``_get_pynvml`` / ``_get_amd_smi``) and
    the clock-scaled estimates returned by ``DeviceInfo.lookup_tops`` and
    ``DeviceInfo.lookup_dram_bw_gbs``.  Tests that need real pynvml/amdsmi
    installations are unconditionally skipped and meant to be run locally.
    """

    def test_lookup_device_info(self):
        """A known device name yields a spec; an unknown name yields None."""
        h100_info = lookup_device_info("NVIDIA H100")
        self.assertIsNotNone(h100_info)
        if h100_info is not None:
            self.assertEqual(h100_info.dram_gb, 80)
            self.assertIn(torch.float32, h100_info.tops)

        unknown_info = lookup_device_info("Unknown Device")
        self.assertIsNone(unknown_info)

    def test_datasheet_tops_function(self):
        """datasheet_tops follows the (patched) current CUDA device name."""
        with (
            patch("torch.cuda.get_device_name") as mock_get_device_name,
            patch("torch.cuda.is_available", return_value=True),
        ):
            mock_get_device_name.return_value = "NVIDIA H100"
            tops = datasheet_tops(torch.float32)
            self.assertIsNotNone(tops)
            self.assertEqual(tops, 67.0)

            tops_tf32 = datasheet_tops(torch.float32, is_tf32=True)
            self.assertEqual(tops_tf32, 989.0)

            mock_get_device_name.return_value = "Unknown Device"
            tops_unknown = datasheet_tops(torch.float32)
            self.assertIsNone(tops_unknown)

            mock_get_device_name.return_value = None
            tops_no_device = datasheet_tops(torch.float32)
            self.assertIsNone(tops_no_device)

    @unittest.skipIf(torch.version.hip, "only nvidia")
    def test_lazy_pynvml_import(self):
        """Test pynvml import through torch.cuda."""
        with (
            patch("torch.cuda._HAS_PYNVML", True),
            patch.object(torch.cuda, "pynvml", MagicMock(), create=True) as mock_pynvml,
        ):
            pynvml = _get_pynvml()
            self.assertEqual(pynvml, mock_pynvml)

        with patch("torch.cuda._HAS_PYNVML", False):
            pynvml = _get_pynvml()
            self.assertIsNone(pynvml)

    @patch("torch.version.hip", None)
    @patch("torch._inductor.analysis.device_info._get_pynvml")
    def test_hardware_lookup_clock_hz_success(self, mock_get_pynvml):
        """The max SM clock reported by a mocked pynvml is scaled by 1e6."""
        mock_pynvml = MagicMock()
        mock_pynvml.nvmlInit = MagicMock()
        mock_pynvml.nvmlDeviceGetHandleByIndex.return_value = "mock_handle"
        mock_pynvml.nvmlDeviceGetMaxClockInfo.return_value = 1500
        mock_pynvml.NVML_CLOCK_SM = "clock_key"
        mock_pynvml.nvmlShutdown = MagicMock()
        mock_get_pynvml.return_value = mock_pynvml

        result = DeviceInfo._hardware_lookup_clock_hz()
        self.assertEqual(result, 1500 * 1e6)

    @unittest.skipIf(torch.version.hip, "only nvidia")
    def test_lazy_pynvml_import_caching(self):
        """Test pynvml caching through torch.cuda (now handled by torch.cuda module)."""
        with (
            patch("torch.cuda._HAS_PYNVML", True),
            patch.object(torch.cuda, "pynvml", MagicMock(), create=True) as mock_pynvml,
        ):
            pynvml1 = _get_pynvml()
            self.assertEqual(pynvml1, mock_pynvml)

            pynvml2 = _get_pynvml()
            self.assertEqual(pynvml2, mock_pynvml)

            self.assertEqual(pynvml1, pynvml2)

    def test_hardware_lookup_exception_handling(self):
        """Hardware lookups return None when the underlying APIs raise."""
        with (
            patch("torch.version.hip", None),
            patch(
                "torch.cuda.get_device_properties", side_effect=Exception("CUDA Error")
            ),
            patch(
                "torch._inductor.analysis.device_info._get_pynvml"
            ) as mock_get_pynvml,
        ):
            mock_pynvml = MagicMock()
            mock_pynvml.nvmlInit.side_effect = Exception("NVML Error")
            mock_get_pynvml.return_value = mock_pynvml

            # Test direct hardware lookup methods, not the generic lookup methods
            result = DeviceInfo._hardware_lookup_sm_count()
            self.assertIsNone(result)

            result = DeviceInfo._hardware_lookup_clock_hz()
            self.assertIsNone(result)

    def test_device_mapping_aliases(self):
        """The "AMD INSTINCT ..." names resolve to the same spec as "AMD ..."."""
        mi300x_direct = lookup_device_info("AMD MI300X")
        mi300x_alias = lookup_device_info("AMD INSTINCT MI300X")
        self.assertEqual(mi300x_direct, mi300x_alias)

        mi210x_direct = lookup_device_info("AMD MI210X")
        mi210x_alias = lookup_device_info("AMD INSTINCT MI210X")
        self.assertEqual(mi210x_direct, mi210x_alias)

    def test_lazy_amd_smi_import_success(self):
        """Test that _get_amd_smi returns None when torch.cuda._HAS_PYNVML is False.

        NOTE(review): despite the method name, only the "library unavailable"
        path is exercised here; no successful import is simulated.
        """
        with patch("torch.cuda._HAS_PYNVML", False):
            amd_smi = _get_amd_smi()
            self.assertIsNone(amd_smi)

    def test_lazy_amd_smi_import_caching(self):
        """Test AMD SMI caching through torch.cuda."""
        # Test consistent behavior across multiple calls
        with patch("torch.cuda._HAS_PYNVML", True):
            amd_smi1 = _get_amd_smi()
            amd_smi2 = _get_amd_smi()
            # Both calls should return the same result (whatever torch.cuda
            # exposes as "amdsmi" in this environment, possibly None)
            self.assertEqual(amd_smi1, amd_smi2)

        with patch("torch.cuda._HAS_PYNVML", False):
            amd_smi1 = _get_amd_smi()
            amd_smi2 = _get_amd_smi()
            self.assertEqual(amd_smi1, amd_smi2)
            self.assertIsNone(amd_smi1)

    def test_amd_device_mapping_entries(self):
        """Test that AMD devices are properly represented in device mapping."""
        mi300x = lookup_device_info("AMD MI300X")
        self.assertIsNotNone(mi300x)
        if mi300x is not None:
            self.assertEqual(mi300x.dram_gb, 192.0)
            self.assertEqual(mi300x.dram_bw_gbs, 5300.0)
            self.assertIn(torch.float32, mi300x.tops)

        mi300x_instinct = lookup_device_info("AMD INSTINCT MI300X")
        self.assertEqual(mi300x, mi300x_instinct)

        mi300a = lookup_device_info("AMD MI300A")
        self.assertIsNotNone(mi300a)
        if mi300a is not None:
            self.assertEqual(mi300a.dram_gb, 128.0)
            self.assertEqual(mi300a.dram_bw_gbs, 5300.0)

        mi210x = lookup_device_info("AMD MI210X")
        self.assertIsNotNone(mi210x)
        if mi210x is not None:
            self.assertEqual(mi210x.dram_gb, 64.0)
            self.assertEqual(mi210x.dram_bw_gbs, 1600.0)

        mi210x_instinct = lookup_device_info("AMD INSTINCT MI210X")
        self.assertEqual(mi210x, mi210x_instinct)

    def test_amd_integration_with_datasheet_tops(self):
        """Test datasheet_tops function with AMD devices."""
        with (
            patch("torch.cuda.get_device_name") as mock_get_device_name,
            patch("torch.cuda.is_available", return_value=True),
        ):
            mock_get_device_name.return_value = "AMD MI300X"

            tops_fp32 = datasheet_tops(torch.float32)
            self.assertEqual(tops_fp32, 163.4)

            tops_fp16 = datasheet_tops(torch.float16)
            self.assertEqual(tops_fp16, 1307.4)

            tops_bf16 = datasheet_tops(torch.bfloat16)
            self.assertEqual(tops_bf16, 1307.4)

            tops_tf32 = datasheet_tops(torch.float32, is_tf32=True)
            self.assertEqual(tops_tf32, 653.7)

    def test_flops_hardware_calculation(self):
        """Test FLOPS calculation now uses datasheet values with clock adjustment."""
        with (
            patch.object(DeviceInfo, "lookup_clock_hz", return_value=1.5e9),
            patch("torch.cuda.is_available", return_value=True),
            patch("torch.cuda.get_device_name", return_value="AMD MI300X"),
        ):
            flops = DeviceInfo.lookup_tops(
                device_name="AMD MI300X", dtype=torch.float32
            )
            # Now uses datasheet value (163.4 TOPS) with clock adjustment
            # Device mapping has clock_hz=2100*1e6, so ratio = 1.5e9 / (2100*1e6) = ~0.714
            datasheet_flops = 163.4 * 1e12
            device_info = lookup_device_info("AMD MI300X")
            if device_info and device_info.clock_hz:
                clock_ratio = 1.5e9 / device_info.clock_hz
                expected_flops = datasheet_flops * clock_ratio
            else:
                expected_flops = datasheet_flops
            self.assertEqual(flops, expected_flops)

    def test_flops_datasheet_calculation(self):
        """Test FLOPS calculation using datasheet TOPS."""
        with (
            patch("torch.cuda.get_device_name") as mock_get_device_name,
            patch("torch.cuda.is_available", return_value=True),
            patch.object(
                DeviceInfo, "lookup_clock_hz", return_value=1.98e9 / 2
            ),  # Use half the datasheet clock
        ):
            mock_get_device_name.return_value = "NVIDIA H100"

            flops = DeviceInfo.lookup_tops(
                device_name="NVIDIA H100", dtype=torch.float32
            )
            expected_flops = 67.0 * 1e12 / 2
            self.assertEqual(flops, expected_flops)

    def test_flops_fallback_to_datasheet(self):
        """Test FLOPS fallback to datasheet when hardware lookup fails."""
        # NOTE(review): this body is identical to test_flops_datasheet_calculation;
        # no hardware-lookup failure is actually simulated here.
        with (
            patch("torch.cuda.get_device_name") as mock_get_device_name,
            patch("torch.cuda.is_available", return_value=True),
            patch.object(
                DeviceInfo, "lookup_clock_hz", return_value=1.98e9 / 2
            ),  # Use half the datasheet clock
        ):
            mock_get_device_name.return_value = "NVIDIA H100"

            flops = DeviceInfo.lookup_tops(
                device_name="NVIDIA H100", dtype=torch.float32
            )
            expected_flops = 67.0 * 1e12 / 2
            self.assertEqual(flops, expected_flops)

    def test_flops_clock_adjustment_in_fallback(self):
        """Test clock adjustment when falling back to datasheet."""
        custom_device_info = DeviceSpec(
            memory_clock_hz=100,
            tops={torch.float32: 100.0},
            dram_bw_gbs=1000.0,
            dram_gb=16.0,
            sm_count=None,
            clock_hz=1.5e9,
        )

        with (
            patch("torch.cuda.get_device_name") as mock_get_device_name,
            patch("torch.cuda.is_available", return_value=True),
            patch(
                "torch._inductor.analysis.device_info.lookup_device_info"
            ) as mock_lookup,
        ):
            mock_get_device_name.return_value = "Custom Device"
            mock_lookup.return_value = custom_device_info

            with patch.object(
                DeviceInfo, "_hardware_lookup_clock_hz", return_value=3.0e9
            ):
                flops = DeviceInfo.lookup_tops("Custom Device", dtype=torch.float32)

                datasheet_flops = 100.0 * 1e12
                clock_ratio = 3.0e9 / 1.5e9
                expected_flops = datasheet_flops * clock_ratio
                self.assertEqual(flops, expected_flops)

    @patch("torch._inductor.analysis.device_info.lookup_device_info")
    def test_flops_clock_adjustment_no_expected_clock(self, mock_lookup):
        """Test fallback behavior when device mapping has None for clock_hz."""
        device_info = DeviceSpec(
            memory_clock_hz=100,
            tops={torch.float32: 100.0},
            dram_bw_gbs=1000.0,
            dram_gb=16.0,
            sm_count=None,
            clock_hz=None,
        )
        mock_lookup.return_value = device_info

        with (
            patch("torch.cuda.get_device_name") as mock_get_device_name,
            patch("torch.cuda.is_available", return_value=True),
        ):
            mock_get_device_name.return_value = "NVIDIA H100"

            with patch.object(
                DeviceInfo, "_hardware_lookup_clock_hz", return_value=3.0e9
            ):
                flops = DeviceInfo.lookup_tops("NVIDIA H100", dtype=torch.float32)

                expected_flops = 100.0 * 1e12
                self.assertEqual(flops, expected_flops)

    def test_flops_clock_adjustment_none_clock(self):
        """Test fallback behavior when clock lookup returns None."""
        with (
            patch("torch.cuda.get_device_name") as mock_get_device_name,
            patch("torch.cuda.is_available", return_value=True),
        ):
            mock_get_device_name.return_value = "NVIDIA H100"

            with patch.object(
                DeviceInfo, "_hardware_lookup_clock_hz", return_value=None
            ):
                flops = DeviceInfo.lookup_tops("NVIDIA H100", dtype=torch.float32)

                expected_flops = 67.0 * 1e12
                self.assertEqual(flops, expected_flops)

    def test_flops_no_device_name(self):
        """Test FLOPS calculation when device name is unavailable."""
        with (
            patch("torch.cuda.get_device_name", return_value=None),
            patch("torch.cuda.is_available", return_value=False),
        ):
            # When there's no device name and we force datasheet, it should return None
            with patch(
                "torch._inductor.analysis.device_info.datasheet_tops", return_value=None
            ):
                flops = DeviceInfo.lookup_tops("NVIDIA H100", dtype=torch.float32)
                self.assertIsNone(flops)

            # When cuda is not available, hardware lookup is skipped and datasheet is used
            flops = DeviceInfo.lookup_tops("NVIDIA H100", dtype=torch.float32)
            self.assertIsNone(
                flops
            )  # Should be None since cuda.is_available() is False

    def test_flops_unknown_device(self):
        """Test FLOPS calculation with unknown device."""
        with (
            patch("torch.cuda.get_device_name") as mock_get_device_name,
            patch("torch.cuda.is_available", return_value=True),
        ):
            mock_get_device_name.return_value = "Unknown Device"

            flops = DeviceInfo.lookup_tops("Unknown Device", dtype=torch.float32)
            # Should be None for unknown device
            self.assertIsNone(flops)

    def test_flops_partial_hardware_values(self):
        """Test FLOPS calculation with some hardware values missing."""
        # NOTE(review): this body is identical to test_flops_datasheet_calculation;
        # no "missing" hardware value is actually simulated here.
        with (
            patch("torch.cuda.get_device_name") as mock_get_device_name,
            patch("torch.cuda.is_available", return_value=True),
            patch.object(
                DeviceInfo, "lookup_clock_hz", return_value=1.98e9 / 2
            ),  # Use half the datasheet clock
        ):
            mock_get_device_name.return_value = "NVIDIA H100"

            flops = DeviceInfo.lookup_tops(
                device_name="NVIDIA H100", dtype=torch.float32
            )
            expected_flops = 67.0 * 1e12 / 2
            self.assertEqual(flops, expected_flops)

    def test_flops_exception_handling(self):
        """Test FLOPS calculation handles exceptions gracefully."""
        with (
            patch.object(
                DeviceInfo,
                "_hardware_lookup_sm_count",
                side_effect=Exception("Hardware error"),
            ),
            patch("torch.cuda.get_device_name") as mock_get_device_name,
            patch("torch.cuda.is_available", return_value=True),
            patch.object(
                DeviceInfo, "lookup_clock_hz", return_value=1.98e9 / 2
            ),  # Use half the datasheet clock
        ):
            mock_get_device_name.return_value = "NVIDIA H100"

            flops = DeviceInfo.lookup_tops("NVIDIA H100", dtype=torch.float32)
            expected_flops = 67.0 * 1e12 / 2
            self.assertEqual(flops, expected_flops)

    def test_flops_integration_with_hardware_lookup(self):
        """Test FLOPS integration with datasheet values and clock adjustment."""
        dn = "NVIDIA H100"

        with (
            patch.object(DeviceInfo, "lookup_clock_hz", return_value=1500 * 1e6),
            patch("torch.cuda.is_available", return_value=True),
            patch("torch.cuda.get_device_name", return_value=dn),
        ):
            flops = DeviceInfo.lookup_tops(device_name=dn, dtype=torch.float32)
            # Now uses datasheet value (67.0 TOPS) with clock adjustment
            # Device mapping has clock_hz=1.98e9, so ratio = 1500*1e6 / 1.98e9 = ~0.7576
            datasheet_flops = 67.0 * 1e12
            device_info = lookup_device_info(dn)
            if device_info and device_info.clock_hz:
                clock_ratio = (1500 * 1e6) / device_info.clock_hz
                expected_flops = datasheet_flops * clock_ratio
            else:
                expected_flops = datasheet_flops
            self.assertEqual(flops, expected_flops)

    @unittest.skipIf(
        True, "pynvml and amdsmi are not available in CI, run these tests locally"
    )
    @unittest.skipIf(torch.version.hip, "only nvidia")
    def test_pynvml_integration(self):
        """Test direct pynvml library integration."""
        try:
            import pynvml

            # Test basic NVML initialization and device access
            pynvml.nvmlInit()
            try:
                handle = pynvml.nvmlDeviceGetHandleByIndex(0)

                # Test clock frequency retrieval
                sm_clock_mhz = pynvml.nvmlDeviceGetMaxClockInfo(
                    handle, pynvml.NVML_CLOCK_SM
                )
                self.assertIsInstance(sm_clock_mhz, int)
                self.assertGreater(sm_clock_mhz, 0)

                # Test memory clock frequency retrieval
                mem_clock_mhz = pynvml.nvmlDeviceGetMaxClockInfo(
                    handle, pynvml.NVML_CLOCK_MEM
                )
                self.assertIsInstance(mem_clock_mhz, int)
                self.assertGreater(mem_clock_mhz, 0)

                # Test memory bus width retrieval
                bus_width_bits = pynvml.nvmlDeviceGetMemoryBusWidth(handle)
                self.assertIsInstance(bus_width_bits, int)
                self.assertGreater(bus_width_bits, 0)

                # Test bandwidth calculation (same as device_info.py implementation)
                mem_clock_hz = mem_clock_mhz * 1e6
                effective_rate = mem_clock_hz * 2  # GDDR uses DDR so *2
                peak_bw = (effective_rate * bus_width_bits) / 8
                peak_bw_gbs = peak_bw / (1024**3)

                self.assertIsInstance(peak_bw_gbs, float)
                self.assertGreater(peak_bw_gbs, 0)
            finally:
                # Always pair the successful nvmlInit() with a shutdown, even
                # when one of the assertions above fails.
                pynvml.nvmlShutdown()

        except ImportError:
            self.fail(
                "pynvml library not available - install with 'pip install nvidia-ml-py'"
            )
        except Exception as e:
            self.fail(f"pynvml integration failed: {e}")

    @unittest.skipIf(
        True, "pynvml and amdsmi are not available in CI, run these tests locally"
    )
    @unittest.skipIf(not torch.version.hip, "only amd")
    def test_amdsmi_integration(self):
        """Test direct amdsmi library integration."""
        try:
            import amdsmi

            # Test basic AMD SMI initialization
            amdsmi.amdsmi_init()
            try:
                # Test device handle retrieval (matches current implementation)
                device_handle = amdsmi.amdsmi_get_processor_handles()[0]
                self.assertIsNotNone(device_handle)

                # Test GPU clock info retrieval (matches current implementation)
                clock_info = amdsmi.amdsmi_get_clock_info(
                    device_handle, amdsmi.AmdSmiClkType.SYS
                )
                self.assertTrue("max_clk" in clock_info)
                self.assertIsInstance(clock_info["max_clk"], int)
                self.assertGreater(clock_info["max_clk"], 0)

                # Test GPU memory clock info retrieval (matches current implementation)
                mem_clock_info = amdsmi.amdsmi_get_clock_info(
                    device_handle, amdsmi.AmdSmiClkType.MEM
                )
                self.assertTrue("max_clk" in mem_clock_info)
                self.assertIsInstance(mem_clock_info["max_clk"], int)
                self.assertGreater(mem_clock_info["max_clk"], 0)
            finally:
                # Always pair the successful amdsmi_init() with a shutdown.
                amdsmi.amdsmi_shut_down()

        except ImportError:
            self.fail("amdsmi library not available - install AMD SMI")
        except Exception as e:
            self.fail(f"amdsmi integration failed: {e}")

    @unittest.skipIf(
        True, "pynvml and amdsmi are not available in CI, run these tests locally"
    )
    @unittest.skipIf(torch.version.hip, "only nvidia")
    def test_pynvml_error_handling(self):
        """Test pynvml error handling for invalid operations."""
        try:
            import pynvml

            pynvml.nvmlInit()
            try:
                # Test invalid device index - should raise exception
                with self.assertRaises(Exception):
                    pynvml.nvmlDeviceGetHandleByIndex(999)  # Invalid index
            finally:
                pynvml.nvmlShutdown()

        except ImportError:
            self.skipTest("pynvml library not available")

    @unittest.skipIf(
        True, "pynvml and amdsmi are not available in CI, run these tests locally"
    )
    @unittest.skipIf(not torch.version.hip, "only amd")
    def test_amd_smi_error_handling(self):
        """Test AMD SMI error handling for invalid operations."""
        # Try amdsmi only
        try:
            import amdsmi

            amdsmi.amdsmi_init()
            try:
                # Test invalid device index - should raise exception
                with self.assertRaises(Exception):
                    amdsmi.amdsmi_get_processor_handle(999)  # Invalid index
            finally:
                amdsmi.amdsmi_shut_down()

        except ImportError:
            self.skipTest("amdsmi library not available")

    @unittest.skipIf(True, "amdsmi is not available in CI, run this test locally")
    @unittest.skipIf(not torch.version.hip, "only amd")
    def test_amd_hardware_lookup_clock_hz(self):
        """Test the _amd_hardware_lookup_clock_hz function with real AMD hardware."""
        # Test the actual function directly
        clock_hz = DeviceInfo._amd_hardware_lookup_clock_hz()

        self.assertIsInstance(clock_hz, float)
        self.assertGreater(clock_hz, 0)
        # Clock frequency should be reasonable (between 50MHz and 5GHz)
        self.assertGreater(clock_hz, 50e6)
        self.assertLess(clock_hz, 5e9)
        # Should return frequency in Hz, not MHz
        # Most AMD clocks are in GHz range, so check it's properly converted
        self.assertGreater(clock_hz, 1e8)  # At least 100MHz in Hz

    @unittest.skipIf(True, "amdsmi is not available in CI, run this test locally")
    @unittest.skipIf(not torch.version.hip, "only amd")
    def test_amd_hardware_lookup_memory_clock_hz(self):
        """Test the _amd_hardware_lookup_memory_clock_hz function with real AMD hardware."""
        try:
            memory_clock_hz = DeviceInfo._amd_hardware_lookup_memory_clock_hz()

            self.assertIsInstance(memory_clock_hz, float)
            self.assertGreater(memory_clock_hz, 0)
            # Memory clock frequency should be reasonable (between 500MHz and 10GHz)
            self.assertGreater(memory_clock_hz, 500e6)
            self.assertLess(memory_clock_hz, 10e9)
            # Should return frequency in Hz, not MHz
            # Most AMD memory clocks are in GHz range, so check it's properly converted
            self.assertGreater(memory_clock_hz, 1e8)  # At least 100MHz in Hz

        except self.failureException:
            # A failed assertion above is a genuine test failure; do not let the
            # generic handler below turn it into a misleading assertIsNone check.
            raise
        except ImportError:
            self.fail("amdsmi library not available - install AMD SMI")
        except Exception:
            # If there's a hardware error or no AMD device, the function should
            # handle it gracefully and return None rather than crash
            self.assertIsNone(DeviceInfo._amd_hardware_lookup_memory_clock_hz())

    def test_dram_bw_hardware_calculation(self):
        """Test DRAM bandwidth calculation with memory clock adjustment."""
        with (
            patch.object(DeviceInfo, "lookup_memory_clock_hz", return_value=7e9),
            patch("torch.cuda.is_available", return_value=True),
            patch("torch.cuda.get_device_name", return_value="AMD MI300X"),
        ):
            dram_bw = DeviceInfo.lookup_dram_bw_gbs(device_name="AMD MI300X")
            # Uses datasheet value (5300.0 GB/s) with memory clock adjustment
            # Device mapping has memory_clock_hz=5200*1e6, so ratio = 7e9 / (5200*1e6) = ~1.346
            datasheet_bw = 5300.0
            device_info = lookup_device_info("AMD MI300X")
            if device_info and device_info.memory_clock_hz:
                memory_clock_ratio = 7e9 / device_info.memory_clock_hz
                expected_bw = datasheet_bw * memory_clock_ratio
            else:
                expected_bw = datasheet_bw
            self.assertEqual(dram_bw, expected_bw)

    def test_dram_bw_datasheet_calculation(self):
        """Test DRAM bandwidth calculation using datasheet values."""
        with (
            patch("torch.cuda.get_device_name") as mock_get_device_name,
            patch("torch.cuda.is_available", return_value=True),
            patch.object(
                DeviceInfo, "lookup_memory_clock_hz", return_value=1.4e10 / 2
            ),  # Use half datasheet memory clock
        ):
            mock_get_device_name.return_value = "NVIDIA H100"

            dram_bw = DeviceInfo.lookup_dram_bw_gbs(device_name="NVIDIA H100")
            expected_bw = 3350 / 2  # Datasheet bandwidth scaled by memory clock ratio
            self.assertEqual(dram_bw, expected_bw)

    def test_dram_bw_fallback_to_datasheet(self):
        """Test DRAM bandwidth fallback to datasheet when hardware lookup fails."""
        # NOTE(review): this body is identical to test_dram_bw_datasheet_calculation;
        # no hardware-lookup failure is actually simulated here.
        with (
            patch("torch.cuda.get_device_name") as mock_get_device_name,
            patch("torch.cuda.is_available", return_value=True),
            patch.object(
                DeviceInfo, "lookup_memory_clock_hz", return_value=1.4e10 / 2
            ),  # Use half datasheet memory clock
        ):
            mock_get_device_name.return_value = "NVIDIA H100"

            dram_bw = DeviceInfo.lookup_dram_bw_gbs(device_name="NVIDIA H100")
            expected_bw = 3350 / 2  # Datasheet bandwidth scaled by memory clock ratio
            self.assertEqual(dram_bw, expected_bw)

    def test_dram_bw_memory_clock_adjustment_in_fallback(self):
        """Test memory clock adjustment when falling back to datasheet."""
        custom_device_info = DeviceSpec(
            memory_clock_hz=2e9,
            tops={torch.float32: 100.0},
            dram_bw_gbs=1000.0,
            dram_gb=16.0,
            sm_count=None,
            clock_hz=1.5e9,
        )

        with (
            patch("torch.cuda.get_device_name") as mock_get_device_name,
            patch("torch.cuda.is_available", return_value=True),
            patch(
                "torch._inductor.analysis.device_info.lookup_device_info"
            ) as mock_lookup,
        ):
            mock_get_device_name.return_value = "Custom Device"
            mock_lookup.return_value = custom_device_info

            with patch.object(DeviceInfo, "lookup_memory_clock_hz", return_value=4e9):
                dram_bw = DeviceInfo.lookup_dram_bw_gbs("Custom Device")

                datasheet_bw = 1000.0
                memory_clock_ratio = 4e9 / 2e9
                expected_bw = datasheet_bw * memory_clock_ratio
                self.assertEqual(dram_bw, expected_bw)

    @patch("torch._inductor.analysis.device_info.lookup_device_info")
    def test_dram_bw_memory_clock_adjustment_no_expected_clock(self, mock_lookup):
        """Test fallback behavior when device mapping has None for memory_clock_hz."""
        device_info = DeviceSpec(
            memory_clock_hz=None,
            tops={torch.float32: 100.0},
            dram_bw_gbs=1000.0,
            dram_gb=16.0,
            sm_count=None,
            clock_hz=1.5e9,
        )
        mock_lookup.return_value = device_info

        with (
            patch("torch.cuda.get_device_name") as mock_get_device_name,
            patch("torch.cuda.is_available", return_value=True),
        ):
            mock_get_device_name.return_value = "NVIDIA H100"

            with patch.object(DeviceInfo, "lookup_memory_clock_hz", return_value=4e9):
                dram_bw = DeviceInfo.lookup_dram_bw_gbs("NVIDIA H100")

                expected_bw = 1000.0  # No memory clock adjustment
                self.assertEqual(dram_bw, expected_bw)

    def test_dram_bw_memory_clock_adjustment_none_clock(self):
        """Test fallback behavior when memory clock lookup returns None."""
        with (
            patch("torch.cuda.get_device_name") as mock_get_device_name,
            patch("torch.cuda.is_available", return_value=True),
        ):
            mock_get_device_name.return_value = "NVIDIA H100"

            with patch.object(DeviceInfo, "lookup_memory_clock_hz", return_value=None):
                dram_bw = DeviceInfo.lookup_dram_bw_gbs("NVIDIA H100")

                expected_bw = 3350  # Datasheet value without adjustment
                self.assertEqual(dram_bw, expected_bw)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
@ -1,8 +1,6 @@
|
||||
import logging
|
||||
from collections.abc import Callable, Generator
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -10,360 +8,32 @@ import torch
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_pynvml() -> Optional[Any]:
    """Return the ``pynvml`` attribute of ``torch.cuda``, or ``None``.

    ``None`` is returned when ``torch.cuda._HAS_PYNVML`` is falsy, and also
    when the flag is set but ``torch.cuda`` has no ``pynvml`` attribute.
    """
    if not torch.cuda._HAS_PYNVML:
        return None
    return getattr(torch.cuda, "pynvml", None)
|
||||
|
||||
|
||||
def _get_amd_smi() -> Optional[Any]:
    """Return the ``amdsmi`` attribute of ``torch.cuda``, or ``None``.

    ``None`` is returned when ``torch.cuda._HAS_PYNVML`` is falsy, and also
    when the flag is set but ``torch.cuda`` has no ``amdsmi`` attribute.
    """
    # NOTE(review): the gate is the pynvml flag, not an amdsmi-specific one;
    # presumably torch.cuda shares this flag between both libraries - confirm.
    if not torch.cuda._HAS_PYNVML:
        return None
    return getattr(torch.cuda, "amdsmi", None)
|
||||
|
||||
|
||||
@contextmanager
def _device_library_context(
    library_getter: Callable[[], Optional[Any]],
    library_name: str,
    init_method: str,
    shutdown_method: str,
) -> Generator[Any, None, None]:
    """
    Generic context manager for device library operations.
    Handles initialization and best-effort cleanup.

    Args:
        library_getter: Function that returns the library module or None
        library_name: Name of the library for error messages
        init_method: Name of the initialization method to call
        shutdown_method: Name of the shutdown method to call

    Yields:
        The library object returned by ``library_getter`` after its
        ``init_method`` has been called successfully.

    Raises:
        RuntimeError: If ``library_getter`` returns None.
        Exception: Whatever ``init_method`` or the body of the ``with``
            block raises is propagated to the caller unchanged.
    """
    library = library_getter()
    if library is None:
        raise RuntimeError(f"{library_name} not available")

    # Initialize *before* entering the try/finally so that shutdown is only
    # ever called to balance a successful init.  An unmatched shutdown on a
    # reference-counted library (NVML documents its init/shutdown pair that
    # way) could release a session that another caller still relies on.
    getattr(library, init_method)()
    try:
        yield library
    finally:
        try:
            getattr(library, shutdown_method)()
        except Exception:
            # Cleanup stays best-effort (it must not mask an error from the
            # with-body), but leave a trace instead of failing silently.
            logging.getLogger(__name__).debug(
                "%s.%s() failed during cleanup",
                library_name,
                shutdown_method,
                exc_info=True,
            )
|
||||
|
||||
|
||||
@contextmanager
def _nvml_context() -> Generator[Any, None, None]:
    """Yield the pynvml library wrapped in an init/shutdown pair.

    Delegates to ``_device_library_context`` using ``_get_pynvml`` as the
    getter and ``nvmlInit`` / ``nvmlShutdown`` as the lifecycle methods.
    """
    nvml_session = _device_library_context(
        library_getter=_get_pynvml,
        library_name="pynvml",
        init_method="nvmlInit",
        shutdown_method="nvmlShutdown",
    )
    with nvml_session as nvml:
        yield nvml
|
||||
|
||||
|
||||
@contextmanager
def _amd_smi_context() -> Generator[Any, None, None]:
    """Yield the amdsmi library wrapped in an init/shutdown pair.

    Delegates to ``_device_library_context`` using ``_get_amd_smi`` as the
    getter and ``amdsmi_init`` / ``amdsmi_shut_down`` as the lifecycle methods.
    """
    smi_session = _device_library_context(
        library_getter=_get_amd_smi,
        library_name="amdsmi",
        init_method="amdsmi_init",
        shutdown_method="amdsmi_shut_down",
    )
    with smi_session as smi:
        yield smi
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DeviceSpec:
|
||||
class DeviceInfo:
|
||||
"""
|
||||
Theoretical Numbers from data sheet. If two numbers are given, Tensor/Matrix Core vs not,
|
||||
then the higher number is reported. Sparsity is not considered.
|
||||
|
||||
|
||||
Bandwidth numbers are tricky, because there are platform differences that may not show up in the profiler trace.
|
||||
For example,
|
||||
"""
|
||||
|
||||
tops: dict[Union[torch.dtype, str], float]
|
||||
dram_bw_gbs: float
|
||||
dram_gb: float
|
||||
sm_count: Optional[int]
|
||||
clock_hz: Optional[float]
|
||||
memory_clock_hz: Optional[float]
|
||||
|
||||
|
||||
class DeviceInfo:
    """
    Device information lookup utility for GPU hardware introspection.

    This class provides methods to retrieve various hardware specifications
    and performance characteristics of GPU devices. It supports both NVIDIA
    and AMD GPUs through hardware lookup methods and falls back to datasheet
    values when hardware information is not available.

    The class can provide information about:
    - Streaming multiprocessor (SM) count
    - Clock frequencies (core and memory)
    - DRAM capacity and bandwidth
    - Peak FLOPS/TOPS performance

    Methods use a two-tier lookup strategy:
    1. Hardware introspection via pynvml (NVIDIA) or AMD SMI libraries
    2. Fallback to predefined datasheet values for known device models

    Hardware introspection is only attempted when ``device_name`` equals
    ``torch.cuda.get_device_name()``.

    Example usage:
        device_name = torch.cuda.get_device_name()
        peak_tops = DeviceInfo.lookup_tops(device_name, torch.float32)
    """

    @staticmethod
    def _hardware_lookup_sm_count() -> Optional[int]:
        """Return the multiprocessor count of CUDA device 0, or None on any failure."""
        try:
            # NOTE(review): index 0 is hard-coded, while callers gate on the
            # *current* device's name -- confirm this is intended on hosts
            # with several different GPUs.
            return torch.cuda.get_device_properties(0).multi_processor_count
        except Exception:
            # Best effort: the caller falls back to the datasheet table.
            return None

    @staticmethod
    def _hardware_lookup_clock_hz() -> Optional[float]:
        """
        Return the maximum core clock in Hz, or None if it cannot be read.

        ROCm builds (``torch.version.hip`` set) are routed to AMD SMI; all
        other builds query NVML for the max SM clock of device 0.
        """
        if torch.version.hip is not None:
            return DeviceInfo._amd_hardware_lookup_clock_hz()

        try:
            with _nvml_context() as pynvml:
                handle = pynvml.nvmlDeviceGetHandleByIndex(0)
                clock_mhz = pynvml.nvmlDeviceGetMaxClockInfo(
                    handle, pynvml.NVML_CLOCK_SM
                )
                # The reading is treated as MHz; convert to Hz.
                return clock_mhz * 1e6
        except Exception as e:
            # Logged like the AMD path instead of being dropped silently.
            log.info("Failed to get NVIDIA clock frequency: %s", e)
            return None

    @staticmethod
    def _amd_hardware_lookup_clock_hz() -> Optional[float]:
        """Return the ``max_clk`` of the SYS clock domain in Hz via AMD SMI, or None."""
        try:
            with _amd_smi_context() as amd_smi:
                device_handle = amd_smi.amdsmi_get_processor_handles()[0]
                clock_info = amd_smi.amdsmi_get_clock_info(
                    device_handle, amd_smi.AmdSmiClkType.SYS
                )
                # The reading is treated as MHz; convert to Hz.
                return clock_info["max_clk"] * 1e6 if "max_clk" in clock_info else None
        except Exception as e:
            log.info("Failed to get AMD clock frequency: %s", e)
            return None

    @staticmethod
    def _hardware_lookup_memory_clock_hz() -> Optional[float]:
        """
        Return the maximum memory clock in Hz, or None if it cannot be read.

        ROCm builds (``torch.version.hip`` set) are routed to AMD SMI; all
        other builds query NVML for the max memory clock of device 0.
        """
        if torch.version.hip is not None:
            return DeviceInfo._amd_hardware_lookup_memory_clock_hz()

        try:
            with _nvml_context() as pynvml:
                handle = pynvml.nvmlDeviceGetHandleByIndex(0)
                mem_clock_mhz = pynvml.nvmlDeviceGetMaxClockInfo(
                    handle, pynvml.NVML_CLOCK_MEM
                )
                # The reading is treated as MHz; convert to Hz.
                return mem_clock_mhz * 1e6
        except Exception as e:
            # Logged like the AMD path instead of being dropped silently.
            log.info("Failed to get NVIDIA memory clock frequency: %s", e)
            return None

    @staticmethod
    def _amd_hardware_lookup_memory_clock_hz() -> Optional[float]:
        """Return the ``max_clk`` of the MEM clock domain in Hz via AMD SMI, or None."""
        try:
            with _amd_smi_context() as amd_smi:
                device_handle = amd_smi.amdsmi_get_processor_handles()[0]
                mem_clock_info = amd_smi.amdsmi_get_clock_info(
                    device_handle, amd_smi.AmdSmiClkType.MEM
                )
                return (
                    mem_clock_info["max_clk"] * 1e6
                    if "max_clk" in mem_clock_info
                    else None
                )
        except Exception as e:
            log.info("Failed to get AMD memory clock frequency: %s", e)
            return None

    @staticmethod
    def _hardware_dram_gb() -> Optional[float]:
        """Return the total memory of CUDA device 0 in GiB (bytes / 1024**3), or None."""
        try:
            device_props = torch.cuda.get_device_properties(0)
            return device_props.total_memory / (1024**3)
        except Exception:
            # Best effort: the caller falls back to the datasheet table.
            return None

    @staticmethod
    def _generic_lookup(
        device_name: str, element_name: str
    ) -> Optional[Union[int, float]]:
        """
        Generic lookup method for device elements.

        First attempts a hardware lookup (only when ``device_name`` is the
        name of the current CUDA device), then falls back to the static
        device mapping.

        Args:
            device_name: Device name as produced by ``torch.cuda.get_device_name()``.
            element_name: Name of the element to look up; one of 'sm_count',
                'clock_hz', 'memory_clock_hz' or 'dram_gb' for hardware
                lookup, or any attribute of the datasheet entry.

        Returns:
            The value from hardware lookup or device mapping, or None if not available.
        """
        hardware_lookup_methods = {
            "sm_count": DeviceInfo._hardware_lookup_sm_count,
            "clock_hz": DeviceInfo._hardware_lookup_clock_hz,
            "memory_clock_hz": DeviceInfo._hardware_lookup_memory_clock_hz,
            "dram_gb": DeviceInfo._hardware_dram_gb,
        }

        if torch.cuda.is_available() and torch.cuda.get_device_name() == device_name:
            # We are running on the device being asked about, so the
            # hardware libraries can be consulted.
            hardware_method = hardware_lookup_methods.get(element_name)
            if hardware_method is not None:
                hardware_value = hardware_method()
                if hardware_value is not None:
                    return hardware_value

        # Fall back to the static datasheet mapping.
        device_info = lookup_device_info(device_name)
        if device_info is not None:
            return getattr(device_info, element_name, None)

        return None

    @staticmethod
    def lookup_sm_count(device_name: str) -> Optional[int]:
        """Get the number of streaming multiprocessors for ``device_name``."""
        result = DeviceInfo._generic_lookup(device_name, "sm_count")
        # Anything that is not an int (including None) is reported as None.
        return result if isinstance(result, int) else None

    @staticmethod
    def lookup_clock_hz(device_name: str) -> Optional[float]:
        """Get the core clock speed in Hz for ``device_name``."""
        return DeviceInfo._generic_lookup(device_name, "clock_hz")

    @staticmethod
    def lookup_memory_clock_hz(device_name: str) -> Optional[float]:
        """Get the memory clock speed in Hz for ``device_name``."""
        return DeviceInfo._generic_lookup(device_name, "memory_clock_hz")

    @staticmethod
    def lookup_dram_gb(device_name: str) -> Optional[float]:
        """Get the DRAM memory size in GB for ``device_name``."""
        return DeviceInfo._generic_lookup(device_name, "dram_gb")

    @staticmethod
    def lookup_dram_bw_gbs(device_name: str) -> Optional[float]:
        """
        Get the DRAM bandwidth in GB/s for ``device_name``.

        The bandwidth itself always comes from the datasheet mapping. When
        ``device_name`` is the current CUDA device and both the measured and
        the datasheet memory clocks are known, the datasheet bandwidth is
        scaled by ``measured_clock / datasheet_clock``.

        Returns:
            The (possibly scaled) bandwidth, or None when the device is not
            in the mapping or has no datasheet bandwidth.
        """
        on_this_device = torch.cuda.is_available() and (
            torch.cuda.get_device_name() == device_name
        )

        device_info = lookup_device_info(device_name)
        if device_info is None:
            return None

        bandwidth = device_info.dram_bw_gbs
        if bandwidth is None:
            return None

        if on_this_device:
            current_memory_clock_hz = DeviceInfo.lookup_memory_clock_hz(device_name)
            expected_memory_clock_hz = device_info.memory_clock_hz
            # A missing or zero reference clock cannot be used as a divisor;
            # in that case the unscaled datasheet value is returned.
            if current_memory_clock_hz is not None and expected_memory_clock_hz:
                bandwidth *= current_memory_clock_hz / expected_memory_clock_hz

        return bandwidth

    @staticmethod
    def lookup_tops(
        device_name: str,
        dtype: torch.dtype,
        is_tf32: bool = False,
    ) -> Optional[float]:
        """
        Our best attempt to calculate the current peak throughput.

        Starts from ``datasheet_tops(dtype, is_tf32)`` converted to ops/s
        (x 1e12). When ``device_name`` is the current CUDA device and both
        the measured and the datasheet core clocks are known, the value is
        scaled by ``measured_clock / datasheet_clock``.

        NOTE(review): ``datasheet_tops`` is not given ``device_name``;
        presumably it resolves the device itself -- confirm.

        Args:
            device_name: Device name as produced by ``torch.cuda.get_device_name()``.
            dtype: Data type whose peak throughput is requested.
            is_tf32: Forwarded to ``datasheet_tops``.

        Returns:
            Peak FLOPS as a float, or None if no datasheet value is available.
        """
        on_this_device = torch.cuda.is_available() and (
            torch.cuda.get_device_name() == device_name
        )

        peak_ops = datasheet_tops(dtype, is_tf32)
        if peak_ops is None:
            return None
        peak_ops *= 1e12  # Convert TOPS to FLOPS

        # Clock scaling only makes sense when the live clock of this very
        # device can be read; otherwise report the plain datasheet number.
        if not on_this_device:
            return peak_ops

        device_info = lookup_device_info(device_name)
        if device_info is None:
            return peak_ops

        current_clock_hz = DeviceInfo.lookup_clock_hz(device_name)
        expected_clock_hz = device_info.clock_hz
        # A missing or zero reference clock cannot be used as a divisor.
        if current_clock_hz is not None and expected_clock_hz:
            peak_ops *= current_clock_hz / expected_clock_hz

        return peak_ops

    @staticmethod
    def lookup_tops_current_device(
        dtype: torch.dtype,
        is_tf32: bool = False,
    ) -> Optional[float]:
        """
        Peak throughput of the current CUDA device; see ``lookup_tops``.

        Returns:
            Peak FLOPS as a float, or None if CUDA is unavailable, the device
            has no name, or the calculation fails.
        """
        if not torch.cuda.is_available():
            return None
        name: Optional[str] = torch.cuda.get_device_name()
        if name is None:
            return None
        return DeviceInfo.lookup_tops(name, dtype, is_tf32)
|
||||
|
||||
|
||||
# Indexing is based on `torch.cuda.get_device_name()`
|
||||
# TODO investigate profiler support for tf32 and allow device to report correct number when it's turned on.
|
||||
_device_mapping: dict[str, DeviceSpec] = {
|
||||
_device_mapping: dict[str, DeviceInfo] = {
|
||||
# Source:
|
||||
# @lint-ignore https://www.nvidia.com/en-us/data-center/h100/
|
||||
# These are from H100 SXM.
|
||||
#
|
||||
"NVIDIA H100": DeviceSpec(
|
||||
"NVIDIA H100": DeviceInfo(
|
||||
tops={
|
||||
torch.float64: 34.0,
|
||||
torch.float32: 67.0,
|
||||
"torch.tf32": 989.0,
|
||||
torch.float64: 67.0,
|
||||
torch.float32: 67.5,
|
||||
"torch.tf32": 156.0,
|
||||
torch.bfloat16: 1979.0,
|
||||
torch.float16: 1979.0,
|
||||
torch.float8_e8m0fnu: 3958.0,
|
||||
@ -376,17 +46,11 @@ _device_mapping: dict[str, DeviceSpec] = {
|
||||
},
|
||||
dram_bw_gbs=3350,
|
||||
dram_gb=80,
|
||||
sm_count=132,
|
||||
# boost clock
|
||||
clock_hz=1.98e9,
|
||||
memory_clock_hz=1.4e10,
|
||||
# bus: 5120 bit
|
||||
),
|
||||
# Source:
|
||||
# @lint-ignore https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/
|
||||
# nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
|
||||
# Tensor cores enabled + SXM
|
||||
"NVIDIA A100": DeviceSpec(
|
||||
"NVIDIA A100": DeviceInfo(
|
||||
tops={
|
||||
torch.float64: 19.5,
|
||||
torch.float32: 19.5,
|
||||
@ -394,19 +58,14 @@ _device_mapping: dict[str, DeviceSpec] = {
|
||||
torch.float16: 312.5,
|
||||
# Not in datasheet: float8
|
||||
torch.int8: 624.0,
|
||||
"torch.tf32": 312.0,
|
||||
"torch.tf32": 156.0,
|
||||
},
|
||||
dram_bw_gbs=2039.0,
|
||||
dram_gb=80.0,
|
||||
sm_count=108,
|
||||
# boost clock
|
||||
clock_hz=1410 * 1e6,
|
||||
memory_clock_hz=1593 * 1e6,
|
||||
),
|
||||
# Source:
|
||||
# @lint-ignore https://resources.nvidia.com/en-us-gpu-resources/l4-tensor-datasheet
|
||||
# @lint-ignore https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/l4/PB-11316-001_v01.pdf
|
||||
"NVIDIA L4": DeviceSpec(
|
||||
"NVIDIA L4": DeviceInfo(
|
||||
tops={
|
||||
# This is a guess, not in datasheet
|
||||
torch.float64: 15.1,
|
||||
@ -424,15 +83,11 @@ _device_mapping: dict[str, DeviceSpec] = {
|
||||
},
|
||||
dram_bw_gbs=3350,
|
||||
dram_gb=24,
|
||||
sm_count=58,
|
||||
clock_hz=2040 * 1e6,
|
||||
# bus: 192 bit
|
||||
memory_clock_hz=6251 * 1e6,
|
||||
),
|
||||
# Source:
|
||||
# @lint-ignore https://www.amd.com/content/dam/amd/en/documents\
|
||||
# /instinct-tech-docs/data-sheets/amd-instinct-mi300a-data-sheet.pdf
|
||||
"AMD MI300A": DeviceSpec(
|
||||
"AMD MI300A": DeviceInfo(
|
||||
tops={
|
||||
torch.float64: 122.6,
|
||||
torch.float32: 122.6,
|
||||
@ -449,15 +104,11 @@ _device_mapping: dict[str, DeviceSpec] = {
|
||||
},
|
||||
dram_bw_gbs=5300.0,
|
||||
dram_gb=128.0,
|
||||
sm_count=228,
|
||||
# bus: 8192 bit
|
||||
clock_hz=2100 * 1e6,
|
||||
memory_clock_hz=2600 * 1e6,
|
||||
),
|
||||
# Source:
|
||||
# @lint-ignore https://www.amd.com/content/dam/amd/en/documents/\
|
||||
# instinct-tech-docs/data-sheets/amd-instinct-mi300x-data-sheet.pdf
|
||||
"AMD MI300X": DeviceSpec(
|
||||
"AMD MI300X": DeviceInfo(
|
||||
tops={
|
||||
torch.float64: 163.4,
|
||||
torch.float32: 163.4,
|
||||
@ -474,14 +125,11 @@ _device_mapping: dict[str, DeviceSpec] = {
|
||||
},
|
||||
dram_bw_gbs=5300.0,
|
||||
dram_gb=192.0,
|
||||
sm_count=304,
|
||||
clock_hz=2100 * 1e6,
|
||||
memory_clock_hz=5200 * 1e6,
|
||||
),
|
||||
# Source:
|
||||
# @lint-ignore https://www.amd.com/content/dam/amd/\
|
||||
# en/documents/instinct-business-docs/product-briefs/instinct-mi210-brochure.pdf
|
||||
"AMD MI210X": DeviceSpec(
|
||||
"AMD MI210X": DeviceInfo(
|
||||
tops={
|
||||
torch.float64: 45.3,
|
||||
torch.float32: 45.3,
|
||||
@ -501,21 +149,18 @@ _device_mapping: dict[str, DeviceSpec] = {
|
||||
# pcie4.0x16
|
||||
dram_bw_gbs=1600.0,
|
||||
dram_gb=64.0,
|
||||
sm_count=104,
|
||||
clock_hz=1700 * 1e6,
|
||||
memory_clock_hz=1600 * 1e6,
|
||||
),
|
||||
}
|
||||
_device_mapping["AMD INSTINCT MI300X"] = _device_mapping["AMD MI300X"]
|
||||
_device_mapping["AMD INSTINCT MI210X"] = _device_mapping["AMD MI210X"]
|
||||
|
||||
|
||||
def lookup_device_info(name: str) -> Optional[DeviceSpec]:
    """
    Return the static datasheet entry for a device name, or None if unknown.

    Problem: when diffing profiles between amd and nvidia, we don't have access to the device information
    of the other one. Also, since the analysis is static, we should be able to do it on another device unrelated
    to the recorded device. Therefore, _device_mapping statically contains the information for lots of devices.
    If one is missing, please run DeviceSpec.get_device_info() and add it to _device_mapping.
    NOTE(review): no `get_device_info` method is visible in this part of the file -- confirm it exists.

    name (str): name of the device to lookup. Should map onto torch.cuda.get_device_name().
    """
    return _device_mapping.get(name)
|
||||
|
@ -6,7 +6,7 @@ from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch._inductor.analysis.device_info import DeviceSpec, lookup_device_info
|
||||
from torch._inductor.analysis.device_info import DeviceInfo, lookup_device_info
|
||||
from torch._inductor.utils import tabulate_2d, zip_dicts
|
||||
from torch.utils import _pytree as pytree
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
@ -381,7 +381,7 @@ KernelNameMap = defaultdict[str, OrderedSet[KernelStats]]
|
||||
class Device:
|
||||
name: str
|
||||
index: int
|
||||
info: Optional[DeviceSpec]
|
||||
info: Optional[DeviceInfo]
|
||||
stats: KernelNameMap
|
||||
|
||||
def __repr__(self) -> str:
|
||||
|
@ -60,7 +60,7 @@ import sympy
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._inductor.analysis.device_info import DeviceInfo
|
||||
from torch._inductor.analysis.device_info import datasheet_tops
|
||||
from torch._inductor.runtime.hints import DeviceProperties
|
||||
from torch.utils._dtype_abbrs import dtype_abbrs
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
@ -2416,9 +2416,7 @@ def get_device_tflops(dtype: torch.dtype) -> float:
|
||||
We don't want to throw errors in this function. First check to see if the device is in device_info.py,
|
||||
then fall back to the inaccurate triton estimation.
|
||||
"""
|
||||
ds_tops = DeviceInfo.lookup_tops_current_device(
|
||||
dtype, is_tf32=torch.backends.cuda.matmul.allow_tf32
|
||||
)
|
||||
ds_tops = datasheet_tops(dtype, is_tf32=torch.backends.cuda.matmul.allow_tf32)
|
||||
if ds_tops is not None:
|
||||
return ds_tops
|
||||
|
||||
|
Reference in New Issue
Block a user