[ROCm] Enable MI355 CI on PRs, and run full set of UTs on PRs (#160215)

Useful to have PR testing for PRs such as https://github.com/pytorch/pytorch/pull/151360

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160215
Approved by: https://github.com/malfet, https://github.com/atalman

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
This commit is contained in:
Jithun Nair
2025-10-09 18:03:08 +00:00
committed by PyTorch MergeBot
parent 3c0577bd15
commit ee6a1ecb0a
7 changed files with 18 additions and 6 deletions

View File

@@ -30,6 +30,7 @@ ciflow_push_tags:
- ciflow/riscv64
- ciflow/rocm
- ciflow/rocm-mi300
- ciflow/rocm-mi355
- ciflow/s390
- ciflow/slow
- ciflow/torchbench

View File

@@ -1,6 +1,9 @@
name: rocm-mi355
on:
push:
tags:
- ciflow/rocm-mi355/*
workflow_dispatch:
schedule:
- cron: 30 11,1 * * * # about 4:30am PDT and 6:30pm PDT
@@ -64,5 +67,7 @@ jobs:
build-environment: linux-noble-rocm-py3.12-mi355
docker-image: ${{ needs.linux-noble-rocm-py3_12-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-noble-rocm-py3_12-build.outputs.test-matrix }}
tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor"
tests-to-include: >-
${{ github.event_name == 'schedule' && 'test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor test_matmul_cuda test_scaled_matmul_cuda'
|| '' }}
secrets: inherit

View File

@@ -127,7 +127,7 @@ inline __host__ __device__ uint32_t getAlignmentRoundUp(const void* p) {
return diff == 0 ? 0 : uint32_t(Align) - diff;
}
#if defined (__gfx90a__) || defined(__gfx942__)
#if defined (__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)
#define CDNA2_OR_LATER 1
#else
#define CDNA2_OR_LATER 0
@@ -143,7 +143,7 @@ template<typename T, uint32_t Rank>
using VecT = T __attribute__((ext_vector_type(Rank)));
static bool isCDNA2orLater(int index) {
return at::detail::getCUDAHooks().isGPUArch({"gfx90a", "gfx942"}, index);
return at::detail::getCUDAHooks().isGPUArch({"gfx90a", "gfx942", "gfx950"}, index);
}
#else

View File

@@ -39,6 +39,8 @@ from torch.testing._internal.common_utils import (
DeterministicGuard,
freeze_rng_state,
IS_FBCODE,
MI350_ARCH,
skipIfRocmArch,
TEST_WITH_ASAN,
TEST_WITH_ROCM,
xfailIfPy312Plus,
@@ -218,6 +220,7 @@ class CudaReproTests(TestCase):
# dont check rng state
self.assertEqual(out[:2], fn(query, key, value, input_tensor2)[:2])
@skipIfRocmArch(MI350_ARCH)
def test_effn_attn_bias_padding_misaligned(self):
seqlen_start = 1008

View File

@@ -31,6 +31,7 @@ from torch.testing._internal.common_utils import (
IS_LINUX,
IS_X86,
MI300_ARCH,
MI350_ARCH,
parametrize,
skipIfNoXPU,
skipIfRocm,
@@ -1187,7 +1188,7 @@ class TestPatternMatcher(TestPatternMatcherBase):
@skipIfNoDynamoSupport
@skipIfNoONEDNNBF16
@skipIfNoONEDNN
@skipIfRocmArch(MI300_ARCH)
@skipIfRocmArch(MI300_ARCH + MI350_ARCH)
def test_qconv2d_int8_mixed_bf16(self):
r"""
This testcase will quantize a single Conv2d module with int8_mixed_bf16 quantization.
@@ -1197,7 +1198,7 @@ class TestPatternMatcher(TestPatternMatcherBase):
@skipIfNoDynamoSupport
@skipIfNoONEDNNBF16
@skipIfNoONEDNN
@skipIfRocmArch(MI300_ARCH)
@skipIfRocmArch(MI300_ARCH + MI350_ARCH)
def test_qconv2d_int8_mixed_bf16_use_autocast(self):
r"""
This testcase will quantize a single Conv2d module with int8_mixed_bf16 quantization.

View File

@@ -13,6 +13,7 @@ from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_FP8,
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
)
from torch.testing._internal.common_device_type import e4m3_type
from torch.testing._internal.common_utils import (
run_tests,
TEST_WITH_TORCHDYNAMO,
@@ -853,7 +854,7 @@ class TestFlopCounter(TestCase):
"FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
)
def test_scaled_mm(self):
dtype = torch.float8_e4m3fnuz if torch.version.hip else torch.float8_e4m3fn
dtype = e4m3_type
with FlopCounterMode() as mode:
torch._scaled_mm(
torch.randn((3 * 16, 5 * 16), device="cuda").to(dtype),

View File

@@ -102,6 +102,7 @@ except ImportError:
SEED = 1234
MI350_ARCH = ("gfx950",)
MI300_ARCH = ("gfx942",)
MI200_ARCH = ("gfx90a")
NAVI_ARCH = ("gfx1030", "gfx1100", "gfx1101", "gfx1200", "gfx1201")