Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[inductor] Refactor is_big_gpu (#142220)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142220
Approved by: https://github.com/yanboliang
ghstack dependencies: #142219, #142033, #142222
committed by PyTorch MergeBot
parent dc7461d6f5
commit e343f46464
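The refactor below gives is_big_gpu a default argument and lets it accept either an integer index or a torch.device, so call sites can drop the explicit 0. A minimal sketch of the call forms the new signature supports, based only on what the diff shows (the cuda-availability guard is there just to keep the snippet safe to run anywhere):

import torch
from torch._inductor.utils import is_big_gpu

# All three forms refer to the default CUDA device under the new signature.
if torch.cuda.is_available():
    print(is_big_gpu())                         # default index 0
    print(is_big_gpu(0))                        # explicit integer index
    print(is_big_gpu(torch.device("cuda", 0)))  # a torch.device object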
@@ -242,7 +242,7 @@ if HAS_CUDA and not TEST_WITH_ASAN:

     def setUp(self):
         super().setUp()
-        if not is_big_gpu(0):
+        if not is_big_gpu():
             return self.skipTest("Need a big GPU to run max_autotune=True")

     def _equivalent_output_code_impl(self, size, first_dim=None, activation=True):
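Several of the test-file hunks share this pattern: the capability check lives in setUp, where unittest's skipTest raises SkipTest so the test body never runs. A small self-contained illustration of that mechanism (the class and test names here are invented for the example):

import unittest

def is_big_gpu() -> bool:
    # Stand-in for torch._inductor.utils.is_big_gpu; always "small" here.
    return False

class ExampleTest(unittest.TestCase):
    def setUp(self):
        super().setUp()
        if not is_big_gpu():
            # skipTest raises unittest.SkipTest, so the return is purely cosmetic;
            # the test method below is reported as skipped, not failed.
            return self.skipTest("Need a big GPU to run max_autotune=True")

    def test_placeholder(self):
        self.assertTrue(True)

if __name__ == "__main__":
    unittest.main()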
@@ -456,5 +456,5 @@ if __name__ == "__main__":
     from torch._inductor.utils import is_big_gpu

     # Set env to make it work in CI.
-    if HAS_CUDA and HAS_CPU and is_big_gpu(0):
+    if HAS_CUDA and HAS_CPU and is_big_gpu():
         run_tests()
@@ -937,5 +937,5 @@ if __name__ == "__main__":
     from torch._inductor.utils import is_big_gpu

     # Set env to make it work in CI.
-    if HAS_CUDA and HAS_CPU and is_big_gpu(0):
+    if HAS_CUDA and HAS_CPU and is_big_gpu():
         run_tests()
@@ -268,7 +268,7 @@ if RUN_GPU:

     from torch._inductor.utils import is_big_gpu

-    if GPU_TYPE == "cuda" and is_big_gpu(0):
+    if GPU_TYPE == "cuda" and is_big_gpu():
         skip_list = ["test_addmm", "test_linear_relu"]
         # need to skip instead of omit, otherwise fbcode ci can be flaky
         for test_name in skip_list:
@@ -974,5 +974,5 @@ if __name__ == "__main__":
     from torch._inductor.utils import is_big_gpu

     # Set env to make it work in CI.
-    if HAS_CUDA and HAS_CPU and is_big_gpu(0):
+    if HAS_CUDA and HAS_CPU and is_big_gpu():
         run_tests()
@@ -22,7 +22,7 @@ from torch.testing._internal.inductor_utils import HAS_CUDA
 class PadMMTest(TestCase):
     def setUp(self):
         super().setUp()
-        if not is_big_gpu(0):
+        if not is_big_gpu():
             return self.skipTest("Need a big GPU to run max_autotune=True")

     @inductor_config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
@@ -49,7 +49,7 @@ def patches(fn):
 class TestSelectAlgorithm(TestCase):
     def setUp(self):
         super().setUp()
-        if not is_big_gpu(0):
+        if not is_big_gpu():
             return self.skipTest("Need a big GPU to run max_autotune=True")

     @patches
@@ -361,5 +361,5 @@ class TestSelectAlgorithm(TestCase):


 if __name__ == "__main__":
-    if IS_LINUX and HAS_CUDA and is_big_gpu(0):
+    if IS_LINUX and HAS_CUDA and is_big_gpu():
         run_tests()
@@ -300,5 +300,5 @@ instantiate_device_type_tests(TestUnbackedSymints, globals(), allow_xpu=True)
 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests

-    if IS_LINUX and HAS_GPU and (not HAS_CUDA or is_big_gpu(0)):
+    if IS_LINUX and HAS_GPU and (not HAS_CUDA or is_big_gpu()):
         run_tests()
@@ -47,6 +47,7 @@ from unittest import mock
 import sympy

 import torch
+from torch._inductor.runtime.hints import DeviceProperties


 if TYPE_CHECKING:
@@ -1113,12 +1114,18 @@ class DelayReplaceLine(DeferredLineBase):


 @functools.lru_cache(None)
-def is_big_gpu(index) -> bool:
-    prop = torch.cuda.get_device_properties(index)
+def is_big_gpu(index_or_device: Union[int, torch.device] = 0) -> bool:
+    if isinstance(index_or_device, torch.device):
+        device = index_or_device
+    else:
+        device = torch.device("cuda", index_or_device)
+
+    prop = DeviceProperties.create(device)

     # SM logic is not relevant to ROCm gpus
     # Arbitrarily skipping the older models
     if torch.version.hip:
+        assert prop.major is not None
         if prop.major < 9 or prop.major == 10:
             log.warning("GPU arch does not support max_autotune_gemm mode usage")
             return False
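A side note on the new signature: because is_big_gpu is wrapped in functools.lru_cache, the cache is keyed on the arguments as passed, so is_big_gpu() and is_big_gpu(0) occupy separate cache entries even though they resolve to the same device. A toy demonstration of that lru_cache behavior (the function below is illustrative, not the inductor helper):

import functools

@functools.lru_cache(None)
def probe(index_or_device=0):
    print("computing for", index_or_device)
    return index_or_device

probe()   # computes; cached under the zero-argument key
probe(0)  # computes again: the cache keys on passed arguments, not resolved defaults
probe()   # cache hit; nothing is printed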
@@ -1145,7 +1152,7 @@ def _use_template_for_cuda(layout, allowed_layout_dtypes: List[torch.dtype]) ->
     return (
         layout.device.type == "cuda"
         and layout.dtype in allowed_layout_dtypes
-        and is_big_gpu(layout.device.index or 0)
+        and is_big_gpu(layout.device)
     )

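The call-site change above passes the device object instead of `layout.device.index or 0`; the old fallback handled the case where an index-less CUDA device reports its index as None. A quick check of that torch.device behavior (runnable without a GPU, since no device is actually touched):

import torch

d = torch.device("cuda")       # no explicit index
print(d.index)                 # None
print(d.index or 0)            # 0 -- the fallback the old call site relied on

d0 = torch.device("cuda", 0)   # explicit index
print(d0.index)                # 0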
@@ -127,4 +127,4 @@ IS_H100 = LazyVal(
     and get_gpu_shared_memory() == 232448
 )

-IS_BIG_GPU = LazyVal(lambda: HAS_CUDA and is_big_gpu(0))
+IS_BIG_GPU = LazyVal(lambda: HAS_CUDA and is_big_gpu())
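IS_BIG_GPU is wrapped in LazyVal so the GPU query is deferred past import time. A rough sketch of a lazy wrapper in the same spirit, under the assumption that LazyVal simply evaluates its callable on first use and caches the result (this stand-in is illustrative, not the torch.testing implementation):

from typing import Any, Callable

class LazyValSketch:
    # Evaluate a zero-argument callable on first use and cache the result.
    _UNSET = object()

    def __init__(self, fn: Callable[[], Any]) -> None:
        self._fn = fn
        self._value = LazyValSketch._UNSET

    def get(self) -> Any:
        if self._value is LazyValSketch._UNSET:
            self._value = self._fn()
        return self._value

    def __bool__(self) -> bool:
        return bool(self.get())

# Nothing is computed until the value is actually consulted.
IS_BIG_GPU_SKETCH = LazyValSketch(lambda: False)
if IS_BIG_GPU_SKETCH:
    print("big GPU available")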