skip various unit tests for Jetson (#122531)

Skip the multiprocessing, CUDA expandable segments, memory-efficient attention, and flash attention tests on Jetson due to hanging / SIGKILL issues observed in NVIDIA internal testing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122531
Approved by: https://github.com/eqy, https://github.com/malfet
Fuzzkatt
2024-04-16 01:26:22 +00:00
committed by PyTorch MergeBot
parent aaad0554b4
commit 1cf62e86a4
4 changed files with 16 additions and 5 deletions
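
For context, the following is a minimal, self-contained sketch (not part of this commit) of the skip pattern the diff below applies. The test class, test name, and test body are hypothetical placeholders; IS_JETSON is the lazily evaluated flag this change adds to torch.testing._internal.common_cuda.

import unittest

import torch
from torch.testing._internal.common_cuda import IS_JETSON
from torch.testing._internal.common_utils import run_tests, TestCase


class ExampleCudaTest(TestCase):
    # IS_JETSON is wrapped in LazyVal, so the CUDA device query is presumably
    # deferred until skipIf coerces the flag to bool here at decoration time.
    @unittest.skipIf(IS_JETSON, "causing sigkill on Jetson")
    def test_cuda_heavy_placeholder(self):
        # Placeholder standing in for the memory-hungry multiprocessing / SDPA
        # tests that hang or get SIGKILLed on Jetson boards.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        x = torch.randn(8, 8, device=device)
        self.assertEqual(x.shape, (8, 8))


if __name__ == "__main__":
    run_tests()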


@@ -4,9 +4,11 @@
import os
import torch
if torch.cuda.is_available():
from torch.testing._internal.common_cuda import IS_JETSON
if torch.cuda.is_available() and not IS_JETSON:
    torch.cuda.memory._set_allocator_settings('expandable_segments:True')
current_dir = os.path.dirname(os.path.abspath(__file__))
filepath = os.path.join(current_dir, 'test_cuda.py')
exec(compile(open(filepath).read(), filepath, mode='exec'))
    current_dir = os.path.dirname(os.path.abspath(__file__))
    filepath = os.path.join(current_dir, 'test_cuda.py')
    exec(compile(open(filepath).read(), filepath, mode='exec'))


@@ -14,6 +14,7 @@ import torch.cuda
import torch.multiprocessing as mp
import torch.utils.hooks
from torch.nn import Parameter
from torch.testing._internal.common_cuda import IS_JETSON
from torch.testing._internal.common_utils import (
    IS_MACOS,
    IS_WINDOWS,
@@ -36,12 +37,15 @@ load_tests = load_tests
TEST_REPEATS = 30
HAS_SHM_FILES = os.path.isdir("/dev/shm")
MAX_WAITING_TIME_IN_SECONDS = 30
TEST_CUDA_IPC = (
    torch.cuda.is_available()
    and sys.platform != "darwin"
    and sys.platform != "win32"
    and not IS_JETSON
    and not TEST_WITH_ROCM
) # https://github.com/pytorch/pytorch/issues/90940
TEST_MULTIGPU = TEST_CUDA_IPC and torch.cuda.device_count() > 1
if TEST_CUDA_IPC:


@@ -40,7 +40,7 @@ from torch._dynamo.testing import CompileCounterWithBackend
from torch.testing._internal.common_methods_invocations import wrapper_set_seed
from torch.testing._internal.common_cuda import (
    SM80OrLater, PLATFORM_SUPPORTS_FLASH_ATTENTION,
    IS_JETSON, SM80OrLater, PLATFORM_SUPPORTS_FLASH_ATTENTION,
    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
    PLATFORM_SUPPORTS_FUSED_ATTENTION,
    PLATFORM_SUPPORTS_CUDNN_ATTENTION
@@ -2570,6 +2570,7 @@ class TestSDPACudaOnly(NNTestCase):
# verified passing successfully on H100
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA")
@unittest.skipIf(IS_JETSON, "causing sigkill on Jetson")
@parametrize("batch_size", [1, 8])
@parametrize("seq_len_q", [4, 8, 64, 128, 256, 512, 1024, 2048] if SM80OrLater else [4, 8, 64, 128, 256, 512])
@parametrize("seq_len_k", [4, 8, 64, 128, 256, 512, 1024, 2048] if SM80OrLater else [4, 8, 64, 128, 256, 512])
@@ -2671,6 +2672,7 @@ class TestSDPACudaOnly(NNTestCase):
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA")
@unittest.skipIf(IS_JETSON, "causing sigkill on Jetson")
@parametrize("batch_size", [1, 8])
@parametrize("seq_len_q", [4, 8, 64, 128, 256, 312, 512, 1024, 2048] if SM80OrLater else [4, 8, 64, 128, 152, 256, 512])
@parametrize("seq_len_k", [4, 8, 64, 65, 128, 256, 408, 512, 1024, 2048] if SM80OrLater else [4, 8, 37, 64, 128, 256, 512])
@@ -2788,6 +2790,7 @@ class TestSDPACudaOnly(NNTestCase):
atol=grad_attn_mask_atol, rtol=grad_attn_mask_rtol)
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
@unittest.skipIf(IS_JETSON, "causing sigkill on Jetson")
@parametrize("batch_size", [1, 8])
@parametrize("seq_len_q", [4, 8, 64, 143, 256, 512, 1024, 2048])
@parametrize("seq_len_k", [4, 8, 64, 128, 256, 587, 1024, 2048])


@@ -31,6 +31,8 @@ SM75OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_devic
SM80OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0))
SM90OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0))
IS_JETSON = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() in [(7, 2), (8, 7)])
def evaluate_gfx_arch_exact(matching_arch):
    if not torch.cuda.is_available():
        return False
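
The IS_JETSON check above keys off compute capability: (7, 2) corresponds to Jetson Xavier and (8, 7) to Jetson Orin. As a rough illustration of why the flag is wrapped in LazyVal, here is a simplified, hypothetical stand-in (not the actual torch.testing implementation) for that deferred evaluation:

import torch


class LazyFlag:
    # Hypothetical stand-in for LazyVal: the thunk runs only the first time the
    # flag is coerced to bool, so importing the module does not have to touch
    # the CUDA driver.
    def __init__(self, thunk):
        self._thunk = thunk
        self._value = None

    def __bool__(self):
        if self._value is None:
            self._value = bool(self._thunk())
        return self._value


# Mirrors the IS_JETSON definition added above; Jetson boards report compute
# capability (7, 2) (Xavier) or (8, 7) (Orin).
IS_JETSON_EXAMPLE = LazyFlag(
    lambda: torch.cuda.is_available()
    and torch.cuda.get_device_capability() in [(7, 2), (8, 7)]
)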