[inductor] Refactor is_big_gpu (#142220)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142220
Approved by: https://github.com/yanboliang
ghstack dependencies: #142219, #142033, #142222
Author: Jason Ansel
Date: 2024-12-06 21:56:30 -08:00
Committed by: PyTorch MergeBot
Parent: dc7461d6f5
Commit: e343f46464
10 changed files with 20 additions and 13 deletions


@@ -242,7 +242,7 @@ if HAS_CUDA and not TEST_WITH_ASAN:
     def setUp(self):
         super().setUp()
-        if not is_big_gpu(0):
+        if not is_big_gpu():
             return self.skipTest("Need a big GPU to run max_autotune=True")

     def _equivalent_output_code_impl(self, size, first_dim=None, activation=True):


@@ -456,5 +456,5 @@ if __name__ == "__main__":
     from torch._inductor.utils import is_big_gpu

     # Set env to make it work in CI.
-    if HAS_CUDA and HAS_CPU and is_big_gpu(0):
+    if HAS_CUDA and HAS_CPU and is_big_gpu():
         run_tests()


@@ -937,5 +937,5 @@ if __name__ == "__main__":
     from torch._inductor.utils import is_big_gpu

     # Set env to make it work in CI.
-    if HAS_CUDA and HAS_CPU and is_big_gpu(0):
+    if HAS_CUDA and HAS_CPU and is_big_gpu():
         run_tests()


@@ -268,7 +268,7 @@ if RUN_GPU:
     from torch._inductor.utils import is_big_gpu

-    if GPU_TYPE == "cuda" and is_big_gpu(0):
+    if GPU_TYPE == "cuda" and is_big_gpu():
         skip_list = ["test_addmm", "test_linear_relu"]
         # need to skip instead of omit, otherwise fbcode ci can be flaky
         for test_name in skip_list:


@@ -974,5 +974,5 @@ if __name__ == "__main__":
     from torch._inductor.utils import is_big_gpu

     # Set env to make it work in CI.
-    if HAS_CUDA and HAS_CPU and is_big_gpu(0):
+    if HAS_CUDA and HAS_CPU and is_big_gpu():
         run_tests()


@@ -22,7 +22,7 @@ from torch.testing._internal.inductor_utils import HAS_CUDA
 class PadMMTest(TestCase):
     def setUp(self):
         super().setUp()
-        if not is_big_gpu(0):
+        if not is_big_gpu():
             return self.skipTest("Need a big GPU to run max_autotune=True")

     @inductor_config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")


@@ -49,7 +49,7 @@ def patches(fn):
 class TestSelectAlgorithm(TestCase):
     def setUp(self):
         super().setUp()
-        if not is_big_gpu(0):
+        if not is_big_gpu():
             return self.skipTest("Need a big GPU to run max_autotune=True")

     @patches
@@ -361,5 +361,5 @@ class TestSelectAlgorithm(TestCase):
 if __name__ == "__main__":
-    if IS_LINUX and HAS_CUDA and is_big_gpu(0):
+    if IS_LINUX and HAS_CUDA and is_big_gpu():
         run_tests()


@@ -300,5 +300,5 @@ instantiate_device_type_tests(TestUnbackedSymints, globals(), allow_xpu=True)
 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests

-    if IS_LINUX and HAS_GPU and (not HAS_CUDA or is_big_gpu(0)):
+    if IS_LINUX and HAS_GPU and (not HAS_CUDA or is_big_gpu()):
         run_tests()


@@ -47,6 +47,7 @@ from unittest import mock
 import sympy

 import torch
+from torch._inductor.runtime.hints import DeviceProperties

 if TYPE_CHECKING:
@@ -1113,12 +1114,18 @@ class DelayReplaceLine(DeferredLineBase):
 @functools.lru_cache(None)
-def is_big_gpu(index) -> bool:
-    prop = torch.cuda.get_device_properties(index)
+def is_big_gpu(index_or_device: Union[int, torch.device] = 0) -> bool:
+    if isinstance(index_or_device, torch.device):
+        device = index_or_device
+    else:
+        device = torch.device("cuda", index_or_device)
+
+    prop = DeviceProperties.create(device)

     # SM logic is not relevant to ROCm gpus
     # Arbitrarily skipping the older models
     if torch.version.hip:
+        assert prop.major is not None
         if prop.major < 9 or prop.major == 10:
             log.warning("GPU arch does not support max_autotune_gemm mode usage")
             return False
@@ -1145,7 +1152,7 @@ def _use_template_for_cuda(layout, allowed_layout_dtypes: List[torch.dtype]) -> bool:
     return (
         layout.device.type == "cuda"
         and layout.dtype in allowed_layout_dtypes
-        and is_big_gpu(layout.device.index or 0)
+        and is_big_gpu(layout.device)
     )
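
For reference, a minimal usage sketch (illustrative, not part of the commit) of the refactored helper: the index argument now defaults to 0, and a torch.device can be passed directly instead of an integer index. It assumes a CUDA-enabled build of PyTorch.

# Illustrative only; assumes a CUDA build with at least one visible GPU.
import torch
from torch._inductor.utils import is_big_gpu

if torch.cuda.is_available():
    is_big_gpu()                         # defaults to device index 0
    is_big_gpu(0)                        # an explicit integer index still works
    is_big_gpu(torch.device("cuda", 0))  # a torch.device is now accepted directly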


@@ -127,4 +127,4 @@ IS_H100 = LazyVal(
     and get_gpu_shared_memory() == 232448
 )

-IS_BIG_GPU = LazyVal(lambda: HAS_CUDA and is_big_gpu(0))
+IS_BIG_GPU = LazyVal(lambda: HAS_CUDA and is_big_gpu())