Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
skip various unit tests for Jetson (#122531)
Skip multiprocessing, CUDA expandable segments, and memory-efficient / flash attention tests on Jetson due to hanging / SIGKILL issues seen in NVIDIA internal testing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122531
Approved by: https://github.com/eqy, https://github.com/malfet
Committed by PyTorch MergeBot
Parent: aaad0554b4
Commit: 1cf62e86a4
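For orientation, here is a minimal sketch, not taken from the PR itself, of the module-level guard the diffs below introduce: IS_JETSON (added to torch.testing._internal.common_cuda by this change) gates Jetson-unsafe setup so it only runs on other CUDA platforms.

# Hedged sketch of the guard pattern; assumes a PyTorch build that includes this PR.
import torch
from torch.testing._internal.common_cuda import IS_JETSON

if torch.cuda.is_available() and not IS_JETSON:
    # Expandable segments are reported to hang on Jetson, so only opt in elsewhere.
    torch.cuda.memory._set_allocator_settings("expandable_segments:True")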
@@ -4,9 +4,11 @@
 import os
 import torch
 
-if torch.cuda.is_available():
+from torch.testing._internal.common_cuda import IS_JETSON
+
+if torch.cuda.is_available() and not IS_JETSON:
     torch.cuda.memory._set_allocator_settings('expandable_segments:True')
 
-current_dir = os.path.dirname(os.path.abspath(__file__))
-filepath = os.path.join(current_dir, 'test_cuda.py')
-exec(compile(open(filepath).read(), filepath, mode='exec'))
+    current_dir = os.path.dirname(os.path.abspath(__file__))
+    filepath = os.path.join(current_dir, 'test_cuda.py')
+    exec(compile(open(filepath).read(), filepath, mode='exec'))
@@ -14,6 +14,7 @@ import torch.cuda
 import torch.multiprocessing as mp
 import torch.utils.hooks
 from torch.nn import Parameter
+from torch.testing._internal.common_cuda import IS_JETSON
 from torch.testing._internal.common_utils import (
     IS_MACOS,
     IS_WINDOWS,
@@ -36,12 +37,15 @@ load_tests = load_tests
 TEST_REPEATS = 30
 HAS_SHM_FILES = os.path.isdir("/dev/shm")
 MAX_WAITING_TIME_IN_SECONDS = 30
 
 TEST_CUDA_IPC = (
     torch.cuda.is_available()
     and sys.platform != "darwin"
     and sys.platform != "win32"
+    and not IS_JETSON
     and not TEST_WITH_ROCM
 )  # https://github.com/pytorch/pytorch/issues/90940
 
 TEST_MULTIGPU = TEST_CUDA_IPC and torch.cuda.device_count() > 1
 
 if TEST_CUDA_IPC:
@@ -40,7 +40,7 @@ from torch._dynamo.testing import CompileCounterWithBackend
 
 from torch.testing._internal.common_methods_invocations import wrapper_set_seed
 from torch.testing._internal.common_cuda import (
-    SM80OrLater, PLATFORM_SUPPORTS_FLASH_ATTENTION,
+    IS_JETSON, SM80OrLater, PLATFORM_SUPPORTS_FLASH_ATTENTION,
     PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
     PLATFORM_SUPPORTS_FUSED_ATTENTION,
     PLATFORM_SUPPORTS_CUDNN_ATTENTION
@@ -2570,6 +2570,7 @@ class TestSDPACudaOnly(NNTestCase):
 
     # verified passing successfully on H100
     @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA")
+    @unittest.skipIf(IS_JETSON, "causing sigkill on Jetson")
     @parametrize("batch_size", [1, 8])
     @parametrize("seq_len_q", [4, 8, 64, 128, 256, 512, 1024, 2048] if SM80OrLater else [4, 8, 64, 128, 256, 512])
     @parametrize("seq_len_k", [4, 8, 64, 128, 256, 512, 1024, 2048] if SM80OrLater else [4, 8, 64, 128, 256, 512])
@@ -2671,6 +2672,7 @@ class TestSDPACudaOnly(NNTestCase):
 
 
     @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA")
+    @unittest.skipIf(IS_JETSON, "causing sigkill on Jetson")
     @parametrize("batch_size", [1, 8])
     @parametrize("seq_len_q", [4, 8, 64, 128, 256, 312, 512, 1024, 2048] if SM80OrLater else [4, 8, 64, 128, 152, 256, 512])
     @parametrize("seq_len_k", [4, 8, 64, 65, 128, 256, 408, 512, 1024, 2048] if SM80OrLater else [4, 8, 37, 64, 128, 256, 512])
@@ -2788,6 +2790,7 @@ class TestSDPACudaOnly(NNTestCase):
                              atol=grad_attn_mask_atol, rtol=grad_attn_mask_rtol)
 
     @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
+    @unittest.skipIf(IS_JETSON, "causing sigkill on Jetson")
     @parametrize("batch_size", [1, 8])
     @parametrize("seq_len_q", [4, 8, 64, 143, 256, 512, 1024, 2048])
     @parametrize("seq_len_k", [4, 8, 64, 128, 256, 587, 1024, 2048])
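The skip decorators above sit on parametrized SDPA tests. Below is a self-contained sketch of that stacking pattern; the class and test are hypothetical stand-ins, not the actual TestSDPACudaOnly tests.

# Hedged sketch: hypothetical test showing skipIf guards stacked on @parametrize.
import unittest

import torch
from torch.testing._internal.common_cuda import IS_JETSON
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)


class ExampleSDPASkips(TestCase):
    @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
    @unittest.skipIf(IS_JETSON, "causing sigkill on Jetson")
    @parametrize("batch_size", [1, 8])
    @parametrize("seq_len", [4, 64, 256])
    def test_sdpa_output_shape(self, batch_size, seq_len):
        # Each (batch_size, seq_len) combination becomes its own test case; all of
        # them are skipped on Jetson or when CUDA is unavailable.
        q = k = v = torch.randn(batch_size, 2, seq_len, 16, device="cuda")
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        self.assertEqual(out.shape, q.shape)


instantiate_parametrized_tests(ExampleSDPASkips)

if __name__ == "__main__":
    run_tests()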
@@ -31,6 +31,8 @@ SM75OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_devic
 SM80OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0))
 SM90OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0))
 
+IS_JETSON = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() in [(7, 2), (8, 7)])
+
 def evaluate_gfx_arch_exact(matching_arch):
     if not torch.cuda.is_available():
         return False
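IS_JETSON keys on CUDA compute capability: Jetson Xavier reports SM 7.2 and Jetson Orin reports SM 8.7. For illustration, a standalone equivalent of the lazy check above, written as a plain function:

# Plain-function version of the LazyVal above; guarding on torch.cuda.is_available()
# keeps the check safe on CPU-only hosts.
import torch

def is_jetson() -> bool:
    return torch.cuda.is_available() and torch.cuda.get_device_capability() in [(7, 2), (8, 7)]

print(is_jetson())  # False on CPU-only machines and on non-Jetson GPUs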