Revert "Improve device info with new flops and bandwidth formula based on hardware libraries (#162245)"

This reverts commit 35d7b321597ed00245aad533a8fa6b7fdadd73ea.

Reverted https://github.com/pytorch/pytorch/pull/162245 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/162245#issuecomment-3313669412))
PyTorch MergeBot
2025-09-19 20:09:12 +00:00
parent 4a160dae3c
commit 2a308c7dee
5 changed files with 24 additions and 1075 deletions

View File

@@ -289,12 +289,14 @@ class TestAnalysis(TestCase):
om = _test_model(device, dtype)
REPEAT = 5
trace1, trace2 = trace_files()
print(f"first trace {trace1}")
torch._dynamo.reset() # reset the cache
with fresh_inductor_cache():
with torch.profiler.profile(record_shapes=True) as p:
om()
p.export_chrome_trace(trace1)
print(f"second trace {trace2}")
torch._dynamo.reset() # reset the cache
with fresh_inductor_cache():
with torch.profiler.profile(record_shapes=True) as p:
@@ -302,6 +304,7 @@ class TestAnalysis(TestCase):
om()
p.export_chrome_trace(trace2)
print("diffing...")
with patch(
"sys.argv",
[

View File

@@ -1,697 +0,0 @@
# Owner(s): ["module: inductor"]
import unittest
from unittest.mock import MagicMock, patch
import torch
from torch._inductor.analysis.device_info import (
_get_amd_smi,
_get_pynvml,
datasheet_tops,
DeviceInfo,
DeviceSpec,
lookup_device_info,
)
from torch.testing._internal.common_utils import run_tests, TestCase
class TestDeviceInfo(TestCase):
def test_lookup_device_info(self):
h100_info = lookup_device_info("NVIDIA H100")
self.assertIsNotNone(h100_info)
if h100_info is not None:
self.assertEqual(h100_info.dram_gb, 80)
self.assertIn(torch.float32, h100_info.tops)
unknown_info = lookup_device_info("Unknown Device")
self.assertIsNone(unknown_info)
def test_datasheet_tops_function(self):
with (
patch("torch.cuda.get_device_name") as mock_get_device_name,
patch("torch.cuda.is_available", return_value=True),
):
mock_get_device_name.return_value = "NVIDIA H100"
tops = datasheet_tops(torch.float32)
self.assertIsNotNone(tops)
self.assertEqual(tops, 67.0)
tops_tf32 = datasheet_tops(torch.float32, is_tf32=True)
self.assertEqual(tops_tf32, 989.0)
mock_get_device_name.return_value = "Unknown Device"
tops_unknown = datasheet_tops(torch.float32)
self.assertIsNone(tops_unknown)
mock_get_device_name.return_value = None
tops_no_device = datasheet_tops(torch.float32)
self.assertIsNone(tops_no_device)
@unittest.skipIf(torch.version.hip, "only nvidia")
def test_lazy_pynvml_import(self):
"""Test pynvml import through torch.cuda."""
with (
patch("torch.cuda._HAS_PYNVML", True),
patch.object(torch.cuda, "pynvml", MagicMock(), create=True) as mock_pynvml,
):
pynvml = _get_pynvml()
self.assertEqual(pynvml, mock_pynvml)
with patch("torch.cuda._HAS_PYNVML", False):
pynvml = _get_pynvml()
self.assertIsNone(pynvml)
@patch("torch.version.hip", None)
@patch("torch._inductor.analysis.device_info._get_pynvml")
def test_hardware_lookup_clock_hz_success(self, mock_get_pynvml):
mock_pynvml = MagicMock()
mock_pynvml.nvmlInit = MagicMock()
mock_pynvml.nvmlDeviceGetHandleByIndex.return_value = "mock_handle"
mock_pynvml.nvmlDeviceGetMaxClockInfo.return_value = 1500
mock_pynvml.NVML_CLOCK_SM = "clock_key"
mock_pynvml.nvmlShutdown = MagicMock()
mock_get_pynvml.return_value = mock_pynvml
result = DeviceInfo._hardware_lookup_clock_hz()
self.assertEqual(result, 1500 * 1e6)
@unittest.skipIf(torch.version.hip, "only nvidia")
def test_lazy_pynvml_import_caching(self):
"""Test pynvml caching through torch.cuda (now handled by torch.cuda module)."""
with (
patch("torch.cuda._HAS_PYNVML", True),
patch.object(torch.cuda, "pynvml", MagicMock(), create=True) as mock_pynvml,
):
pynvml1 = _get_pynvml()
self.assertEqual(pynvml1, mock_pynvml)
pynvml2 = _get_pynvml()
self.assertEqual(pynvml2, mock_pynvml)
self.assertEqual(pynvml1, pynvml2)
def test_hardware_lookup_exception_handling(self):
with (
patch("torch.version.hip", None),
patch(
"torch.cuda.get_device_properties", side_effect=Exception("CUDA Error")
),
patch(
"torch._inductor.analysis.device_info._get_pynvml"
) as mock_get_pynvml,
):
mock_pynvml = MagicMock()
mock_pynvml.nvmlInit.side_effect = Exception("NVML Error")
mock_get_pynvml.return_value = mock_pynvml
# Test direct hardware lookup methods, not the generic lookup methods
result = DeviceInfo._hardware_lookup_sm_count()
self.assertIsNone(result)
result = DeviceInfo._hardware_lookup_clock_hz()
self.assertIsNone(result)
def test_device_mapping_aliases(self):
mi300x_direct = lookup_device_info("AMD MI300X")
mi300x_alias = lookup_device_info("AMD INSTINCT MI300X")
self.assertEqual(mi300x_direct, mi300x_alias)
mi210x_direct = lookup_device_info("AMD MI210X")
mi210x_alias = lookup_device_info("AMD INSTINCT MI210X")
self.assertEqual(mi210x_direct, mi210x_alias)
def test_lazy_amd_smi_import_success(self):
"""Test AMD SMI import through torch.cuda."""
with patch("torch.cuda._HAS_PYNVML", False):
amd_smi = _get_amd_smi()
self.assertIsNone(amd_smi)
def test_lazy_amd_smi_import_caching(self):
"""Test AMD SMI caching through torch.cuda."""
# Test consistent behavior across multiple calls
with patch("torch.cuda._HAS_PYNVML", True):
amd_smi1 = _get_amd_smi()
amd_smi2 = _get_amd_smi()
# Both should return the same result (None in this environment)
self.assertEqual(amd_smi1, amd_smi2)
with patch("torch.cuda._HAS_PYNVML", False):
amd_smi1 = _get_amd_smi()
amd_smi2 = _get_amd_smi()
self.assertEqual(amd_smi1, amd_smi2)
self.assertIsNone(amd_smi1)
def test_amd_device_mapping_entries(self):
"""Test that AMD devices are properly represented in device mapping."""
mi300x = lookup_device_info("AMD MI300X")
self.assertIsNotNone(mi300x)
if mi300x is not None:
self.assertEqual(mi300x.dram_gb, 192.0)
self.assertEqual(mi300x.dram_bw_gbs, 5300.0)
self.assertIn(torch.float32, mi300x.tops)
mi300x_instinct = lookup_device_info("AMD INSTINCT MI300X")
self.assertEqual(mi300x, mi300x_instinct)
mi300a = lookup_device_info("AMD MI300A")
self.assertIsNotNone(mi300a)
if mi300a is not None:
self.assertEqual(mi300a.dram_gb, 128.0)
self.assertEqual(mi300a.dram_bw_gbs, 5300.0)
mi210x = lookup_device_info("AMD MI210X")
self.assertIsNotNone(mi210x)
if mi210x is not None:
self.assertEqual(mi210x.dram_gb, 64.0)
self.assertEqual(mi210x.dram_bw_gbs, 1600.0)
mi210x_instinct = lookup_device_info("AMD INSTINCT MI210X")
self.assertEqual(mi210x, mi210x_instinct)
def test_amd_integration_with_datasheet_tops(self):
"""Test datasheet_tops function with AMD devices."""
with (
patch("torch.cuda.get_device_name") as mock_get_device_name,
patch("torch.cuda.is_available", return_value=True),
):
mock_get_device_name.return_value = "AMD MI300X"
tops_fp32 = datasheet_tops(torch.float32)
self.assertEqual(tops_fp32, 163.4)
tops_fp16 = datasheet_tops(torch.float16)
self.assertEqual(tops_fp16, 1307.4)
tops_bf16 = datasheet_tops(torch.bfloat16)
self.assertEqual(tops_bf16, 1307.4)
tops_tf32 = datasheet_tops(torch.float32, is_tf32=True)
self.assertEqual(tops_tf32, 653.7)
def test_flops_hardware_calculation(self):
"""Test FLOPS calculation now uses datasheet values with clock adjustment."""
with (
patch.object(DeviceInfo, "lookup_clock_hz", return_value=1.5e9),
patch("torch.cuda.is_available", return_value=True),
patch("torch.cuda.get_device_name", return_value="AMD MI300X"),
):
flops = DeviceInfo.lookup_tops(
device_name="AMD MI300X", dtype=torch.float32
)
# Now uses datasheet value (163.4 TOPS) with clock adjustment
# Device mapping has clock_hz=2100*1e6, so ratio = 1.5e9 / (2100*1e6) = ~0.714
datasheet_flops = 163.4 * 1e12
device_info = lookup_device_info("AMD MI300X")
if device_info and device_info.clock_hz:
clock_ratio = 1.5e9 / device_info.clock_hz
expected_flops = datasheet_flops * clock_ratio
else:
expected_flops = datasheet_flops
self.assertEqual(flops, expected_flops)
def test_flops_datasheet_calculation(self):
"""Test FLOPS calculation using datasheet TOPS."""
with (
patch("torch.cuda.get_device_name") as mock_get_device_name,
patch("torch.cuda.is_available", return_value=True),
patch.object(
DeviceInfo, "lookup_clock_hz", return_value=1.98e9 / 2
), # Use datasheet clock
):
mock_get_device_name.return_value = "NVIDIA H100"
flops = DeviceInfo.lookup_tops(
device_name="NVIDIA H100", dtype=torch.float32
)
expected_flops = 67.0 * 1e12 / 2
self.assertEqual(flops, expected_flops)
def test_flops_fallback_to_datasheet(self):
"""Test FLOPS fallback to datasheet when hardware lookup fails."""
with (
patch("torch.cuda.get_device_name") as mock_get_device_name,
patch("torch.cuda.is_available", return_value=True),
patch.object(
DeviceInfo, "lookup_clock_hz", return_value=1.98e9 / 2
), # Use datasheet clock
):
mock_get_device_name.return_value = "NVIDIA H100"
flops = DeviceInfo.lookup_tops(
device_name="NVIDIA H100", dtype=torch.float32
)
expected_flops = 67.0 * 1e12 / 2
self.assertEqual(flops, expected_flops)
def test_flops_clock_adjustment_in_fallback(self):
"""Test clock adjustment when falling back to datasheet."""
custom_device_info = DeviceSpec(
memory_clock_hz=100,
tops={torch.float32: 100.0},
dram_bw_gbs=1000.0,
dram_gb=16.0,
sm_count=None,
clock_hz=1.5e9,
)
with (
patch("torch.cuda.get_device_name") as mock_get_device_name,
patch("torch.cuda.is_available", return_value=True),
patch(
"torch._inductor.analysis.device_info.lookup_device_info"
) as mock_lookup,
):
mock_get_device_name.return_value = "Custom Device"
mock_lookup.return_value = custom_device_info
with patch.object(
DeviceInfo, "_hardware_lookup_clock_hz", return_value=3.0e9
):
flops = DeviceInfo.lookup_tops("Custom Device", dtype=torch.float32)
datasheet_flops = 100.0 * 1e12
clock_ratio = 3.0e9 / 1.5e9
expected_flops = datasheet_flops * clock_ratio
self.assertEqual(flops, expected_flops)
@patch("torch._inductor.analysis.device_info.lookup_device_info")
def test_flops_clock_adjustment_no_expected_clock(self, mock_lookup):
"""Test fallback behavior when device mapping has None for clock_hz."""
device_info = DeviceSpec(
memory_clock_hz=100,
tops={torch.float32: 100.0},
dram_bw_gbs=1000.0,
dram_gb=16.0,
sm_count=None,
clock_hz=None,
)
mock_lookup.return_value = device_info
with (
patch("torch.cuda.get_device_name") as mock_get_device_name,
patch("torch.cuda.is_available", return_value=True),
):
mock_get_device_name.return_value = "NVIDIA H100"
with patch.object(
DeviceInfo, "_hardware_lookup_clock_hz", return_value=3.0e9
):
flops = DeviceInfo.lookup_tops("NVIDIA H100", dtype=torch.float32)
expected_flops = 100.0 * 1e12
self.assertEqual(flops, expected_flops)
def test_flops_clock_adjustment_none_clock(self):
"""Test fallback behavior when clock lookup returns None."""
with (
patch("torch.cuda.get_device_name") as mock_get_device_name,
patch("torch.cuda.is_available", return_value=True),
):
mock_get_device_name.return_value = "NVIDIA H100"
with patch.object(
DeviceInfo, "_hardware_lookup_clock_hz", return_value=None
):
flops = DeviceInfo.lookup_tops("NVIDIA H100", dtype=torch.float32)
expected_flops = 67.0 * 1e12
self.assertEqual(flops, expected_flops)
def test_flops_no_device_name(self):
"""Test FLOPS calculation when device name is unavailable."""
with (
patch("torch.cuda.get_device_name", return_value=None),
patch("torch.cuda.is_available", return_value=False),
):
# When there's no device name and we force datasheet, it should return None
with patch(
"torch._inductor.analysis.device_info.datasheet_tops", return_value=None
):
flops = DeviceInfo.lookup_tops("NVIDIA H100", dtype=torch.float32)
self.assertIsNone(flops)
# When cuda is not available, hardware lookup is skipped and datasheet is used
flops = DeviceInfo.lookup_tops("NVIDIA H100", dtype=torch.float32)
self.assertIsNone(
flops
) # Should be None since cuda.is_available() is False
def test_flops_unknown_device(self):
"""Test FLOPS calculation with unknown device."""
with (
patch("torch.cuda.get_device_name") as mock_get_device_name,
patch("torch.cuda.is_available", return_value=True),
):
mock_get_device_name.return_value = "Unknown Device"
flops = DeviceInfo.lookup_tops("Unknown Device", dtype=torch.float32)
# Should be None for unknown device
self.assertIsNone(flops)
def test_flops_partial_hardware_values(self):
"""Test FLOPS calculation with some hardware values missing."""
with (
patch("torch.cuda.get_device_name") as mock_get_device_name,
patch("torch.cuda.is_available", return_value=True),
patch.object(
DeviceInfo, "lookup_clock_hz", return_value=1.98e9 / 2
), # Use datasheet clock
):
mock_get_device_name.return_value = "NVIDIA H100"
flops = DeviceInfo.lookup_tops(
device_name="NVIDIA H100", dtype=torch.float32
)
expected_flops = 67.0 * 1e12 / 2
self.assertEqual(flops, expected_flops)
def test_flops_exception_handling(self):
"""Test FLOPS calculation handles exceptions gracefully."""
with (
patch.object(
DeviceInfo,
"_hardware_lookup_sm_count",
side_effect=Exception("Hardware error"),
),
patch("torch.cuda.get_device_name") as mock_get_device_name,
patch("torch.cuda.is_available", return_value=True),
patch.object(
DeviceInfo, "lookup_clock_hz", return_value=1.98e9 / 2
), # Use datasheet clock
):
mock_get_device_name.return_value = "NVIDIA H100"
flops = DeviceInfo.lookup_tops("NVIDIA H100", dtype=torch.float32)
expected_flops = 67.0 * 1e12 / 2
self.assertEqual(flops, expected_flops)
def test_flops_integration_with_hardware_lookup(self):
"""Test FLOPS integration with datasheet values and clock adjustment."""
dn = "NVIDIA H100"
with (
patch.object(DeviceInfo, "lookup_clock_hz", return_value=1500 * 1e6),
patch("torch.cuda.is_available", return_value=True),
patch("torch.cuda.get_device_name", return_value=dn),
):
flops = DeviceInfo.lookup_tops(device_name=dn, dtype=torch.float32)
# Now uses datasheet value (67.0 TOPS) with clock adjustment
# Device mapping has clock_hz=1.98e9, so ratio = 1500*1e6 / 1.98e9 = ~0.7576
datasheet_flops = 67.0 * 1e12
device_info = lookup_device_info(dn)
if device_info and device_info.clock_hz:
clock_ratio = (1500 * 1e6) / device_info.clock_hz
expected_flops = datasheet_flops * clock_ratio
else:
expected_flops = datasheet_flops
self.assertEqual(flops, expected_flops)
@unittest.skipIf(
True, "pynvml and amdsmi are not available in CI, run these tests locally"
)
@unittest.skipIf(torch.version.hip, "only nvidia")
def test_pynvml_integration(self):
"""Test direct pynvml library integration."""
try:
import pynvml
# Test basic NVML initialization and device access
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
# Test clock frequency retrieval
sm_clock_mhz = pynvml.nvmlDeviceGetMaxClockInfo(
handle, pynvml.NVML_CLOCK_SM
)
self.assertIsInstance(sm_clock_mhz, int)
self.assertGreater(sm_clock_mhz, 0)
# Test memory clock frequency retrieval
mem_clock_mhz = pynvml.nvmlDeviceGetMaxClockInfo(
handle, pynvml.NVML_CLOCK_MEM
)
self.assertIsInstance(mem_clock_mhz, int)
self.assertGreater(mem_clock_mhz, 0)
# Test memory bus width retrieval
bus_width_bits = pynvml.nvmlDeviceGetMemoryBusWidth(handle)
self.assertIsInstance(bus_width_bits, int)
self.assertGreater(bus_width_bits, 0)
# Test bandwidth calculation (same as device_info.py implementation)
mem_clock_hz = mem_clock_mhz * 1e6
effective_rate = mem_clock_hz * 2 # GDDR uses DDR so *2
peak_bw = (effective_rate * bus_width_bits) / 8
peak_bw_gbs = peak_bw / (1024**3)
self.assertIsInstance(peak_bw_gbs, float)
self.assertGreater(peak_bw_gbs, 0)
pynvml.nvmlShutdown()
except ImportError:
self.fail(
"pynvml library not available - install with 'pip install nvidia-ml-py'"
)
except Exception as e:
self.fail(f"pynvml integration failed: {e}")
@unittest.skipIf(
True, "pynvml and amdsmi are not available in CI, run these tests locally"
)
@unittest.skipIf(not torch.version.hip, "only amd")
def test_amdsmi_integration(self):
"""Test direct amdsmi library integration."""
try:
import amdsmi
# Test basic AMD SMI initialization
amdsmi.amdsmi_init()
# Test device handle retrieval (matches current implementation)
device_handle = amdsmi.amdsmi_get_processor_handles()[0]
self.assertIsNotNone(device_handle)
# Test GPU clock info retrieval (matches current implementation)
clock_info = amdsmi.amdsmi_get_clock_info(
device_handle, amdsmi.AmdSmiClkType.SYS
)
self.assertTrue("max_clk" in clock_info)
self.assertIsInstance(clock_info["max_clk"], int)
self.assertGreater(clock_info["max_clk"], 0)
# Test GPU memory clock info retrieval (matches current implementation)
mem_clock_info = amdsmi.amdsmi_get_clock_info(
device_handle, amdsmi.AmdSmiClkType.MEM
)
self.assertTrue("max_clk" in mem_clock_info)
self.assertIsInstance(mem_clock_info["max_clk"], int)
self.assertGreater(mem_clock_info["max_clk"], 0)
amdsmi.amdsmi_shut_down()
except ImportError:
self.fail("amdsmi library not available - install AMD SMI")
except Exception as e:
self.fail(f"amdsmi integration failed: {e}")
@unittest.skipIf(
True, "pynvml and amdsmi are not available in CI, run these tests locally"
)
@unittest.skipIf(torch.version.hip, "only nvidia")
def test_pynvml_error_handling(self):
"""Test pynvml error handling for invalid operations."""
try:
import pynvml
pynvml.nvmlInit()
# Test invalid device index - should raise exception
with self.assertRaises(Exception):
pynvml.nvmlDeviceGetHandleByIndex(999) # Invalid index
pynvml.nvmlShutdown()
except ImportError:
self.skipTest("pynvml library not available")
@unittest.skipIf(
True, "pynvml and amdsmi are not available in CI, run these tests locally"
)
@unittest.skipIf(not torch.version.hip, "only amd")
def test_amd_smi_error_handling(self):
"""Test AMD SMI error handling for invalid operations."""
# Try amdsmi only
try:
import amdsmi
amdsmi.amdsmi_init()
# Test invalid device index - should raise exception
with self.assertRaises(Exception):
amdsmi.amdsmi_get_processor_handle(999) # Invalid index
amdsmi.amdsmi_shut_down()
except ImportError:
self.skipTest("amdsmi library not available")
@unittest.skipIf(True, "amdsmi is not available in CI, run this test locally")
@unittest.skipIf(not torch.version.hip, "only amd")
def test_amd_hardware_lookup_clock_hz(self):
"""Test the _amd_hardware_lookup_clock_hz function with real AMD hardware."""
# Test the actual function directly
clock_hz = DeviceInfo._amd_hardware_lookup_clock_hz()
self.assertIsInstance(clock_hz, float)
self.assertGreater(clock_hz, 0)
# Clock frequency should be reasonable (at least 50MHz and below 5GHz)
self.assertGreater(clock_hz, 50e6)
self.assertLess(clock_hz, 5e9)
# Should return frequency in Hz, not MHz
# Most AMD clocks are in GHz range, so check it's properly converted
self.assertGreater(clock_hz, 1e8) # At least 100MHz in Hz
@unittest.skipIf(True, "amdsmi is not available in CI, run this test locally")
@unittest.skipIf(not torch.version.hip, "only amd")
def test_amd_hardware_lookup_memory_clock_hz(self):
"""Test the _amd_hardware_lookup_memory_clock_hz function with real AMD hardware."""
try:
memory_clock_hz = DeviceInfo._amd_hardware_lookup_memory_clock_hz()
self.assertIsInstance(memory_clock_hz, float)
self.assertGreater(memory_clock_hz, 0)
# Memory clock frequency should be reasonable (between 500MHz and 10GHz)
self.assertGreater(memory_clock_hz, 500e6)
self.assertLess(memory_clock_hz, 10e9)
# Should return frequency in Hz, not MHz
# Most AMD memory clocks are in GHz range, so check it's properly converted
self.assertGreater(memory_clock_hz, 1e8) # At least 100MHz in Hz
except ImportError:
self.fail("amdsmi library not available - install AMD SMI")
except Exception:
# If there's a hardware error or no AMD device, the function should
# handle it gracefully and return None rather than crash
self.assertIsNone(DeviceInfo._amd_hardware_lookup_memory_clock_hz())
def test_dram_bw_hardware_calculation(self):
"""Test DRAM bandwidth calculation with memory clock adjustment."""
with (
patch.object(DeviceInfo, "lookup_memory_clock_hz", return_value=7e9),
patch("torch.cuda.is_available", return_value=True),
patch("torch.cuda.get_device_name", return_value="AMD MI300X"),
):
dram_bw = DeviceInfo.lookup_dram_bw_gbs(device_name="AMD MI300X")
# Uses datasheet value (5300.0 GB/s) with memory clock adjustment
# Device mapping has memory_clock_hz=5200*1e6, so ratio = 7e9 / (5200*1e6) = ~1.346
datasheet_bw = 5300.0
device_info = lookup_device_info("AMD MI300X")
if device_info and device_info.memory_clock_hz:
memory_clock_ratio = 7e9 / device_info.memory_clock_hz
expected_bw = datasheet_bw * memory_clock_ratio
else:
expected_bw = datasheet_bw
self.assertEqual(dram_bw, expected_bw)
def test_dram_bw_datasheet_calculation(self):
"""Test DRAM bandwidth calculation using datasheet values."""
with (
patch("torch.cuda.get_device_name") as mock_get_device_name,
patch("torch.cuda.is_available", return_value=True),
patch.object(
DeviceInfo, "lookup_memory_clock_hz", return_value=1.4e10 / 2
), # Use half datasheet memory clock
):
mock_get_device_name.return_value = "NVIDIA H100"
dram_bw = DeviceInfo.lookup_dram_bw_gbs(device_name="NVIDIA H100")
expected_bw = 3350 / 2 # Datasheet bandwidth scaled by memory clock ratio
self.assertEqual(dram_bw, expected_bw)
def test_dram_bw_fallback_to_datasheet(self):
"""Test DRAM bandwidth fallback to datasheet when hardware lookup fails."""
with (
patch("torch.cuda.get_device_name") as mock_get_device_name,
patch("torch.cuda.is_available", return_value=True),
patch.object(
DeviceInfo, "lookup_memory_clock_hz", return_value=1.4e10 / 2
), # Use half datasheet memory clock
):
mock_get_device_name.return_value = "NVIDIA H100"
dram_bw = DeviceInfo.lookup_dram_bw_gbs(device_name="NVIDIA H100")
expected_bw = 3350 / 2 # Datasheet bandwidth scaled by memory clock ratio
self.assertEqual(dram_bw, expected_bw)
def test_dram_bw_memory_clock_adjustment_in_fallback(self):
"""Test memory clock adjustment when falling back to datasheet."""
custom_device_info = DeviceSpec(
memory_clock_hz=2e9,
tops={torch.float32: 100.0},
dram_bw_gbs=1000.0,
dram_gb=16.0,
sm_count=None,
clock_hz=1.5e9,
)
with (
patch("torch.cuda.get_device_name") as mock_get_device_name,
patch("torch.cuda.is_available", return_value=True),
patch(
"torch._inductor.analysis.device_info.lookup_device_info"
) as mock_lookup,
):
mock_get_device_name.return_value = "Custom Device"
mock_lookup.return_value = custom_device_info
with patch.object(DeviceInfo, "lookup_memory_clock_hz", return_value=4e9):
dram_bw = DeviceInfo.lookup_dram_bw_gbs("Custom Device")
datasheet_bw = 1000.0
memory_clock_ratio = 4e9 / 2e9
expected_bw = datasheet_bw * memory_clock_ratio
self.assertEqual(dram_bw, expected_bw)
@patch("torch._inductor.analysis.device_info.lookup_device_info")
def test_dram_bw_memory_clock_adjustment_no_expected_clock(self, mock_lookup):
"""Test fallback behavior when device mapping has None for memory_clock_hz."""
device_info = DeviceSpec(
memory_clock_hz=None,
tops={torch.float32: 100.0},
dram_bw_gbs=1000.0,
dram_gb=16.0,
sm_count=None,
clock_hz=1.5e9,
)
mock_lookup.return_value = device_info
with (
patch("torch.cuda.get_device_name") as mock_get_device_name,
patch("torch.cuda.is_available", return_value=True),
):
mock_get_device_name.return_value = "NVIDIA H100"
with patch.object(DeviceInfo, "lookup_memory_clock_hz", return_value=4e9):
dram_bw = DeviceInfo.lookup_dram_bw_gbs("NVIDIA H100")
expected_bw = 1000.0 # No memory clock adjustment
self.assertEqual(dram_bw, expected_bw)
def test_dram_bw_memory_clock_adjustment_none_clock(self):
"""Test fallback behavior when memory clock lookup returns None."""
with (
patch("torch.cuda.get_device_name") as mock_get_device_name,
patch("torch.cuda.is_available", return_value=True),
):
mock_get_device_name.return_value = "NVIDIA H100"
with patch.object(DeviceInfo, "lookup_memory_clock_hz", return_value=None):
dram_bw = DeviceInfo.lookup_dram_bw_gbs("NVIDIA H100")
expected_bw = 3350 # Datasheet value without adjustment
self.assertEqual(dram_bw, expected_bw)
if __name__ == "__main__":
run_tests()
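
As an aside, the peak-bandwidth arithmetic exercised in test_pynvml_integration above can be checked in isolation. A minimal standalone sketch follows; the clock and bus-width values are illustrative only, not taken from any particular datasheet entry:

def peak_dram_bw_gibs(mem_clock_mhz: int, bus_width_bits: int) -> float:
    # Same formula as the test: DDR transfers twice per memory clock tick,
    # bus width is in bits, and the result is reported in GiB/s (1024**3 bytes).
    mem_clock_hz = mem_clock_mhz * 1e6
    effective_rate = mem_clock_hz * 2
    peak_bw_bytes_per_s = (effective_rate * bus_width_bits) / 8
    return peak_bw_bytes_per_s / (1024**3)

# Illustrative values: a 1593 MHz memory clock on a 5120-bit bus yields
# about 2.04e12 bytes/s, i.e. roughly 1899 GiB/s.
print(peak_dram_bw_gibs(1593, 5120))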

View File

@@ -1,8 +1,6 @@
import logging
from collections.abc import Callable, Generator
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Optional, Union
from typing import Optional, Union
import torch
@@ -10,360 +8,32 @@ import torch
log = logging.getLogger(__name__)
def _get_pynvml() -> Optional[Any]:
"""Get pynvml from torch.cuda if available."""
return getattr(torch.cuda, "pynvml", None) if torch.cuda._HAS_PYNVML else None
def _get_amd_smi() -> Optional[Any]:
"""Get AMD SMI from torch.cuda if available."""
return getattr(torch.cuda, "amdsmi", None) if torch.cuda._HAS_PYNVML else None
@contextmanager
def _device_library_context(
library_getter: Callable[[], Optional[Any]],
library_name: str,
init_method: str,
shutdown_method: str,
) -> Generator[Any, None, None]:
"""
Generic context manager for device library operations.
Handles initialization, exception catching, and cleanup.
Args:
library_getter: Function that returns the library module or None
library_name: Name of the library for error messages
init_method: Name of the initialization method to call
shutdown_method: Name of the shutdown method to call
"""
library = library_getter()
if library is None:
raise RuntimeError(f"{library_name} not available")
try:
getattr(library, init_method)()
yield library
finally:
try:
getattr(library, shutdown_method)()
except Exception:
pass
@contextmanager
def _nvml_context() -> Generator[Any, None, None]:
"""Context manager for NVML operations."""
with _device_library_context(
_get_pynvml, "pynvml", "nvmlInit", "nvmlShutdown"
) as library:
yield library
@contextmanager
def _amd_smi_context() -> Generator[Any, None, None]:
"""Context manager for AMD SMI operations."""
with _device_library_context(
_get_amd_smi, "amdsmi", "amdsmi_init", "amdsmi_shut_down"
) as library:
yield library
@dataclass(frozen=True)
class DeviceSpec:
class DeviceInfo:
"""
Theoretical numbers from the datasheet. If two numbers are given (Tensor/Matrix Core vs. not),
then the higher number is reported. Sparsity is not considered.
Bandwidth numbers are tricky, because there are platform differences that may not show up in the profiler trace.
For example, the same GPU model can ship in SXM and PCIe variants with different peak memory bandwidths.
"""
tops: dict[Union[torch.dtype, str], float]
dram_bw_gbs: float
dram_gb: float
sm_count: Optional[int]
clock_hz: Optional[float]
memory_clock_hz: Optional[float]
class DeviceInfo:
"""
Device information lookup utility for GPU hardware introspection.
This class provides methods to retrieve various hardware specifications
and performance characteristics of GPU devices. It supports both NVIDIA
and AMD GPUs through hardware lookup methods and falls back to datasheet
values when hardware information is not available.
The class can provide information about:
- Streaming multiprocessor (SM) count
- Clock frequencies (core and memory)
- DRAM capacity and bandwidth
- Peak FLOPS/TOPS performance
Methods use a two-tier lookup strategy:
1. Hardware introspection via pynvml (NVIDIA) or AMD SMI libraries
2. Fallback to predefined datasheet values for known device models
Example usage:
device_name = torch.cuda.get_device_name()
peak_tops = DeviceInfo.lookup_tops(device_name, torch.float32)
"""
@staticmethod
def _hardware_lookup_sm_count() -> Optional[int]:
"""Get the number of streaming multiprocessors from the hardware."""
try:
# rely on device_properties api
device_props = torch.cuda.get_device_properties(0)
return device_props.multi_processor_count
except Exception:
return None
@staticmethod
def _hardware_lookup_clock_hz() -> Optional[float]:
"""Get the clock speed in Hz from the hardware."""
if torch.version.hip is not None:
amd_clock = DeviceInfo._amd_hardware_lookup_clock_hz()
return amd_clock
try:
with _nvml_context() as pynvml:
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
clock_mhz = pynvml.nvmlDeviceGetMaxClockInfo(
handle, pynvml.NVML_CLOCK_SM
)
return clock_mhz * 1e6
except Exception:
return None
@staticmethod
def _amd_hardware_lookup_clock_hz() -> Optional[float]:
"""Get the clock speed in Hz from AMD hardware."""
try:
with _amd_smi_context() as amd_smi:
device_handle = amd_smi.amdsmi_get_processor_handles()[0]
clock_info = amd_smi.amdsmi_get_clock_info(
device_handle, amd_smi.AmdSmiClkType.SYS
)
return clock_info["max_clk"] * 1e6 if "max_clk" in clock_info else None
except Exception as e:
log.info("Failed to get AMD clock frequency: %s", e)
return None
@staticmethod
def _hardware_lookup_memory_clock_hz() -> Optional[float]:
"""Get the memory clock speed in Hz from the hardware."""
if torch.version.hip is not None:
amd_memory_clock = DeviceInfo._amd_hardware_lookup_memory_clock_hz()
return amd_memory_clock
try:
with _nvml_context() as pynvml:
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
mem_clock_mhz = pynvml.nvmlDeviceGetMaxClockInfo(
handle, pynvml.NVML_CLOCK_MEM
)
return mem_clock_mhz * 1e6
except Exception:
return None
@staticmethod
def _amd_hardware_lookup_memory_clock_hz() -> Optional[float]:
"""Get the memory clock speed in Hz from AMD hardware."""
try:
with _amd_smi_context() as amd_smi:
device_handle = amd_smi.amdsmi_get_processor_handles()[0]
mem_clock_info = amd_smi.amdsmi_get_clock_info(
device_handle, amd_smi.AmdSmiClkType.MEM
)
return (
mem_clock_info["max_clk"] * 1e6
if "max_clk" in mem_clock_info
else None
)
except Exception as e:
log.info("Failed to get AMD memory clock frequency: %s", e)
return None
@staticmethod
def _hardware_dram_gb() -> Optional[float]:
"""Get the DRAM memory size in GB from the hardware."""
try:
device_props = torch.cuda.get_device_properties(0)
# Convert from bytes to GB
return device_props.total_memory / (1024**3)
except Exception:
return None
@staticmethod
def _generic_lookup(
device_name: str, element_name: str
) -> Optional[Union[int, float]]:
"""
Generic lookup method for device elements.
First attempts hardware lookup, then falls back to device mapping.
Args:
device_name: Device name, as reported by torch.cuda.get_device_name().
element_name: Name of the element to lookup (e.g., 'sm_count', 'clock_hz')
Returns:
The value from hardware lookup or device mapping, or None if not available.
"""
hardware_lookup_methods = {
"sm_count": DeviceInfo._hardware_lookup_sm_count,
"clock_hz": DeviceInfo._hardware_lookup_clock_hz,
"memory_clock_hz": DeviceInfo._hardware_lookup_memory_clock_hz,
"dram_gb": DeviceInfo._hardware_dram_gb,
}
if torch.cuda.is_available() and torch.cuda.get_device_name() == device_name:
# we're on the device that we're testing, so try to look up values via hardware libraries.
hardware_method = hardware_lookup_methods.get(element_name)
if hardware_method:
hardware_value = hardware_method()
if hardware_value is not None:
return hardware_value
# Attempt to lookup from device mapping
device_info = lookup_device_info(device_name)
if device_info is not None:
return getattr(device_info, element_name, None)
return None
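# Illustrative behavior of the two-tier lookup above (values are the datasheet
# entries from _device_mapping below; on a live H100 the pynvml-reported
# numbers would be returned instead):
#   DeviceInfo._generic_lookup("NVIDIA H100", "sm_count")  -> 132
#   DeviceInfo._generic_lookup("NVIDIA H100", "clock_hz")  -> 1.98e9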
@staticmethod
def lookup_sm_count(device_name: str) -> Optional[int]:
"""Get the number of streaming multiprocessors for the current device."""
result = DeviceInfo._generic_lookup(device_name, "sm_count")
return result if isinstance(result, int) or result is None else None
@staticmethod
def lookup_clock_hz(device_name: str) -> Optional[float]:
"""Get the clock speed in Hz for the current device."""
return DeviceInfo._generic_lookup(device_name, "clock_hz")
@staticmethod
def lookup_memory_clock_hz(device_name: str) -> Optional[float]:
"""Get the memory clock speed in Hz for the current device."""
return DeviceInfo._generic_lookup(device_name, "memory_clock_hz")
@staticmethod
def lookup_dram_gb(device_name: str) -> Optional[float]:
"""Get the DRAM memory size in GB for the current device."""
return DeviceInfo._generic_lookup(device_name, "dram_gb")
@staticmethod
def lookup_dram_bw_gbs(device_name: str) -> Optional[float]:
"""
Get the DRAM bandwidth in GB/s for the current device.
Returns the datasheet bandwidth, scaled by the ratio of the hardware-reported
memory clock to the datasheet memory clock when both are available.
"""
lookupable = torch.cuda.is_available() and (
torch.cuda.get_device_name() == device_name
)
# Fall back to datasheet value with memory clock scaling
device_info = lookup_device_info(device_name)
if device_info is None:
return None
datasheet_bw = device_info.dram_bw_gbs
if datasheet_bw is None:
return None
# Apply memory clock adjustment if current memory clock is available
if lookupable:
current_memory_clock_hz = DeviceInfo.lookup_memory_clock_hz(device_name)
if (
current_memory_clock_hz is not None
and device_info.memory_clock_hz is not None
):
# Scale bandwidth by memory clock ratio
expected_memory_clock_hz = device_info.memory_clock_hz
memory_clock_ratio = current_memory_clock_hz / expected_memory_clock_hz
datasheet_bw *= memory_clock_ratio
return datasheet_bw
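# Worked example of the scaling above, using the "NVIDIA H100" entry from
# _device_mapping below (dram_bw_gbs=3350 at memory_clock_hz=1.4e10): a
# measured memory clock of 7e9 Hz gives a ratio of 7e9 / 1.4e10 = 0.5, so the
# reported bandwidth is 3350 * 0.5 = 1675 GB/s.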
@staticmethod
def lookup_tops(
device_name: str,
dtype: torch.dtype,
is_tf32: bool = False,
) -> Optional[float]:
"""
Best-effort estimate of the device's current peak throughput: the datasheet TOPS value (converted to FLOPS) scaled by the ratio of the current clock speed to the datasheet clock.
Returns:
Peak FLOPS as a float, or None if calculation fails
"""
lookupable = torch.cuda.is_available() and (
torch.cuda.get_device_name() == device_name
)
# Use datasheet values adjusted by clock ratio
peak_ops = datasheet_tops(dtype, is_tf32)
if peak_ops is None:
return None
peak_ops *= 1e12 # Convert TOPS to FLOPS
# Apply clock adjustment for datasheet fallback calculations
if not torch.cuda.is_available():
return peak_ops
device_name = torch.cuda.get_device_name()
if device_name is None:
return peak_ops
device_info = lookup_device_info(device_name)
if device_info is None:
return peak_ops
if lookupable:
current_clock_hz = DeviceInfo.lookup_clock_hz(device_name)
if current_clock_hz is not None and device_info.clock_hz is not None:
# Use the expected clock speed from the device mapping for scaling
expected_clock_hz = device_info.clock_hz
clock_ratio = current_clock_hz / expected_clock_hz
peak_ops *= clock_ratio
return peak_ops
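# Worked example of the clock adjustment above, using the "NVIDIA H100" entry
# from _device_mapping below (67 fp32 TOPS at clock_hz=1.98e9): a measured SM
# clock of 0.99e9 Hz gives a ratio of 0.99e9 / 1.98e9 = 0.5, so the returned
# value is 67e12 * 0.5 = 33.5e12 FLOPS.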
@staticmethod
def lookup_tops_current_device(
dtype: torch.dtype,
is_tf32: bool = False,
) -> Optional[float]:
"""
Best-effort estimate of the current device's peak throughput; delegates to lookup_tops, which scales the datasheet TOPS by the ratio of the current clock speed to the datasheet clock.
Returns:
Peak FLOPS as a float, or None if calculation fails
"""
if not torch.cuda.is_available():
return None
name: Optional[str] = torch.cuda.get_device_name()
if name is None:
return None
return DeviceInfo.lookup_tops(name, dtype, is_tf32)
# Indexing is based on `torch.cuda.get_device_name()`
# TODO investigate profiler support for tf32 and allow device to report correct number when it's turned on.
_device_mapping: dict[str, DeviceSpec] = {
_device_mapping: dict[str, DeviceInfo] = {
# Source:
# @lint-ignore https://www.nvidia.com/en-us/data-center/h100/
# These are from H100 SXM.
#
"NVIDIA H100": DeviceSpec(
"NVIDIA H100": DeviceInfo(
tops={
torch.float64: 34.0,
torch.float32: 67.0,
"torch.tf32": 989.0,
torch.float64: 67.0,
torch.float32: 67.5,
"torch.tf32": 156.0,
torch.bfloat16: 1979.0,
torch.float16: 1979.0,
torch.float8_e8m0fnu: 3958.0,
@@ -376,17 +46,11 @@ _device_mapping: dict[str, DeviceSpec] = {
},
dram_bw_gbs=3350,
dram_gb=80,
sm_count=132,
# boost clock
clock_hz=1.98e9,
memory_clock_hz=1.4e10,
# bus: 5120 bit
),
# Source:
# @lint-ignore https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/
# nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
# Tensor cores enabled + SXM
"NVIDIA A100": DeviceSpec(
"NVIDIA A100": DeviceInfo(
tops={
torch.float64: 19.5,
torch.float32: 19.5,
@@ -394,19 +58,14 @@ _device_mapping: dict[str, DeviceSpec] = {
torch.float16: 312.5,
# Not in datasheet: float8
torch.int8: 624.0,
"torch.tf32": 312.0,
"torch.tf32": 156.0,
},
dram_bw_gbs=2039.0,
dram_gb=80.0,
sm_count=108,
# boost clock
clock_hz=1410 * 1e6,
memory_clock_hz=1593 * 1e6,
),
# Source:
# @lint-ignore https://resources.nvidia.com/en-us-gpu-resources/l4-tensor-datasheet
# @lint-ignore https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/l4/PB-11316-001_v01.pdf
"NVIDIA L4": DeviceSpec(
"NVIDIA L4": DeviceInfo(
tops={
# This is a guess, not in datasheet
torch.float64: 15.1,
@@ -424,15 +83,11 @@ _device_mapping: dict[str, DeviceSpec] = {
},
dram_bw_gbs=3350,
dram_gb=24,
sm_count=58,
clock_hz=2040 * 1e6,
# bus: 192 bit
memory_clock_hz=6251 * 1e6,
),
# Source:
# @lint-ignore https://www.amd.com/content/dam/amd/en/documents\
# /instinct-tech-docs/data-sheets/amd-instinct-mi300a-data-sheet.pdf
"AMD MI300A": DeviceSpec(
"AMD MI300A": DeviceInfo(
tops={
torch.float64: 122.6,
torch.float32: 122.6,
@@ -449,15 +104,11 @@ _device_mapping: dict[str, DeviceSpec] = {
},
dram_bw_gbs=5300.0,
dram_gb=128.0,
sm_count=228,
# bus: 8192 bit
clock_hz=2100 * 1e6,
memory_clock_hz=2600 * 1e6,
),
# Source:
# @lint-ignore https://www.amd.com/content/dam/amd/en/documents/\
# instinct-tech-docs/data-sheets/amd-instinct-mi300x-data-sheet.pdf
"AMD MI300X": DeviceSpec(
"AMD MI300X": DeviceInfo(
tops={
torch.float64: 163.4,
torch.float32: 163.4,
@@ -474,14 +125,11 @@ _device_mapping: dict[str, DeviceSpec] = {
},
dram_bw_gbs=5300.0,
dram_gb=192.0,
sm_count=304,
clock_hz=2100 * 1e6,
memory_clock_hz=5200 * 1e6,
),
# Source:
# @lint-ignore https://www.amd.com/content/dam/amd/\
# en/documents/instinct-business-docs/product-briefs/instinct-mi210-brochure.pdf
"AMD MI210X": DeviceSpec(
"AMD MI210X": DeviceInfo(
tops={
torch.float64: 45.3,
torch.float32: 45.3,
@@ -501,21 +149,18 @@ _device_mapping: dict[str, DeviceSpec] = {
# pcie4.0x16
dram_bw_gbs=1600.0,
dram_gb=64.0,
sm_count=104,
clock_hz=1700 * 1e6,
memory_clock_hz=1600 * 1e6,
),
}
_device_mapping["AMD INSTINCT MI300X"] = _device_mapping["AMD MI300X"]
_device_mapping["AMD INSTINCT MI210X"] = _device_mapping["AMD MI210X"]
def lookup_device_info(name: str) -> Optional[DeviceSpec]:
def lookup_device_info(name: str) -> Optional[DeviceInfo]:
"""
Problem: when diffing profiles between AMD and NVIDIA devices, we don't have access to the device information
of the other vendor. Also, since the analysis is static, it should be possible to run it on a device unrelated
to the one that recorded the trace. Therefore, _device_mapping statically contains the information for many devices.
If one is missing, please run DeviceSpec.get_device_info() and add it to _device_mapping.
If one is missing, please run DeviceInfo.get_device_info() and add it to _device_mapping.
name (str): name of the device to lookup. Should map onto torch.cuda.get_device_name().
"""
return _device_mapping.get(name, None)
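
A minimal usage sketch of the lookup API that this revert removes (it targets the pre-revert module shown above, assumes a CUDA device is present, and every lookup simply returns None for devices missing from _device_mapping):

import torch
from torch._inductor.analysis.device_info import DeviceInfo, lookup_device_info

if torch.cuda.is_available():
    name = torch.cuda.get_device_name()
    spec = lookup_device_info(name)                     # datasheet entry or None
    tops = DeviceInfo.lookup_tops(name, torch.float32)  # clock-adjusted peak FLOPS or None
    bw_gbs = DeviceInfo.lookup_dram_bw_gbs(name)        # memory-clock-adjusted GB/s or None
    print(name, spec, tops, bw_gbs)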

View File

@@ -6,7 +6,7 @@ from dataclasses import dataclass
from typing import Any, Callable, Optional, Union
import torch
from torch._inductor.analysis.device_info import DeviceSpec, lookup_device_info
from torch._inductor.analysis.device_info import DeviceInfo, lookup_device_info
from torch._inductor.utils import tabulate_2d, zip_dicts
from torch.utils import _pytree as pytree
from torch.utils._ordered_set import OrderedSet
@@ -381,7 +381,7 @@ KernelNameMap = defaultdict[str, OrderedSet[KernelStats]]
class Device:
name: str
index: int
info: Optional[DeviceSpec]
info: Optional[DeviceInfo]
stats: KernelNameMap
def __repr__(self) -> str:

View File

@@ -60,7 +60,7 @@ import sympy
import torch
import torch.utils._pytree as pytree
from torch._inductor.analysis.device_info import DeviceInfo
from torch._inductor.analysis.device_info import datasheet_tops
from torch._inductor.runtime.hints import DeviceProperties
from torch.utils._dtype_abbrs import dtype_abbrs
from torch.utils._ordered_set import OrderedSet
@@ -2416,9 +2416,7 @@ def get_device_tflops(dtype: torch.dtype) -> float:
We don't want to throw errors in this function. First check to see if the device is in device_info.py,
then fall back to the inaccurate triton estimation.
"""
ds_tops = DeviceInfo.lookup_tops_current_device(
dtype, is_tf32=torch.backends.cuda.matmul.allow_tf32
)
ds_tops = datasheet_tops(dtype, is_tf32=torch.backends.cuda.matmul.allow_tf32)
if ds_tops is not None:
return ds_tops