Enable CI on SM89 (#140305)

Using EC2 G6 instance, based on NVIDIA L4, added to scale config in https://github.com/pytorch/test-infra/pull/5376

To enable more balanced sharding, had to push 148ae19935

Added `@xfailIfSM89` to the following tests:
 - test_fp8_pattern_2
 - test_original_aten_preserved_split_addmm
 - test_sparse_semi_structured_scaled_mm
 - test_sparse_semi_structured_scaled_mm_fp8
 - test_sparse_fp8fp8_mm

Increased tolerance to 2e-4 for `RNNTest.BidirectionalMultilayerGRU_CPU_vs_CUDA`

Skipped following inductor tests (that either flaky OOMs or timeouts):
 - test_reduction_fn_std_float64
 - test_reduction_fn_var_mean_float64
 - test_multi_output_unbacked_custom_op

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140305
Approved by: https://github.com/wdvr, https://github.com/ZainRizvi
This commit is contained in:
Nikita Shulga
2024-12-03 04:49:46 +00:00
committed by PyTorch MergeBot
parent af88326250
commit 38bbe37187
9 changed files with 74 additions and 21 deletions

View File

@ -476,35 +476,35 @@ jobs:
]}
secrets: inherit
linux-focal-cuda12_4-py3_10-gcc9-sm86-build:
name: linux-focal-cuda12.4-py3.10-gcc9-sm86
linux-focal-cuda12_4-py3_10-gcc9-sm89-build:
name: linux-focal-cuda12.4-py3.10-gcc9-sm89
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm86
build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm89
docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9
cuda-arch-list: 8.6
cuda-arch-list: 8.9
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
{ config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
{ config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
{ config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
{ config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
{ config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
]}
secrets: inherit
linux-focal-cuda12_4-py3_10-gcc9-sm86-test:
name: linux-focal-cuda12.4-py3.10-gcc9-sm86
linux-focal-cuda12_4-py3_10-gcc9-sm89-test:
name: linux-focal-cuda12.4-py3.10-gcc9-sm89
uses: ./.github/workflows/_linux-test.yml
needs:
- linux-focal-cuda12_4-py3_10-gcc9-sm86-build
- linux-focal-cuda12_4-py3_10-gcc9-sm89-build
- target-determination
with:
build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm86
docker-image: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-sm86-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-sm86-build.outputs.test-matrix }}
build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm89
docker-image: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-sm89-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-sm89-build.outputs.test-matrix }}
secrets: inherit
linux-jammy-py3-clang12-executorch-build:

View File

@ -3,6 +3,9 @@
#include <torch/torch.h>
#include <test/cpp/api/support.h>
#ifdef USE_CUDA
#include <ATen/cuda/CUDAContext.h>
#endif
using namespace torch::nn;
using namespace torch::test;
@ -552,6 +555,15 @@ TEST_F(RNNTest, BidirectionalLSTMReverseForward_CUDA) {
}
TEST_F(RNNTest, BidirectionalMultilayerGRU_CPU_vs_CUDA) {
#ifdef USE_CUDA
// Get device properties
const auto prop = at::cuda::getCurrentDeviceProperties();
// TODO: Investigate why results on sm89 are much less accurate
// See https://github.com/pytorch/pytorch/issues/141915
const auto tolerance = prop->major == 8 && prop->minor == 9 ? 2e-4 : 1e-5;
#else
constexpr auto tolerance = 1e-5;
#endif
// Create two GRUs with the same options
auto opt =
GRUOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true);
@ -600,13 +612,22 @@ TEST_F(RNNTest, BidirectionalMultilayerGRU_CPU_vs_CUDA) {
ASSERT_NEAR(
std::get<0>(output_cpu)[i][j][k].item<float>(),
std::get<0>(output_cuda)[i][j][k].item<float>(),
1e-5);
tolerance);
}
}
}
}
TEST_F(RNNTest, BidirectionalMultilayerLSTM_CPU_vs_CUDA) {
#ifdef USE_CUDA
// Get device properties
const auto prop = at::cuda::getCurrentDeviceProperties();
// TODO: Investigate why results on sm89 are much less accurate
// See https://github.com/pytorch/pytorch/issues/141915
const auto tolerance = prop->major == 8 && prop->minor == 9 ? 2e-4 : 1e-5;
#else
constexpr auto tolerance = 1e-5;
#endif
// Create two LSTMs with the same options
auto opt =
LSTMOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true);
@ -654,13 +675,22 @@ TEST_F(RNNTest, BidirectionalMultilayerLSTM_CPU_vs_CUDA) {
ASSERT_NEAR(
std::get<0>(output_cpu)[i][j][k].item<float>(),
std::get<0>(output_cuda)[i][j][k].item<float>(),
1e-5);
tolerance);
}
}
}
}
TEST_F(RNNTest, BidirectionalMultilayerLSTMProj_CPU_vs_CUDA) {
#ifdef USE_CUDA
// Get device properties
const auto prop = at::cuda::getCurrentDeviceProperties();
// TODO: Investigate why results on sm89 are much less accurate
// See https://github.com/pytorch/pytorch/issues/141915
const auto tolerance = prop->major == 8 && prop->minor == 9 ? 2e-4 : 1e-5;
#else
constexpr auto tolerance = 1e-5;
#endif
// Create two LSTMs with the same options
auto opt = LSTMOptions(2, 4)
.num_layers(3)
@ -711,7 +741,7 @@ TEST_F(RNNTest, BidirectionalMultilayerLSTMProj_CPU_vs_CUDA) {
ASSERT_NEAR(
std::get<0>(output_cpu)[i][j][k].item<float>(),
std::get<0>(output_cuda)[i][j][k].item<float>(),
1e-5);
tolerance);
}
}
}

View File

@ -1,4 +1,5 @@
# Owner(s): ["module: inductor"]
import unittest
from typing import Any, Dict, List, Type
import sympy
@ -11,6 +12,7 @@ from torch._inductor.codegen.simd_kernel_features import SIMDKernelFeatures
from torch._inductor.codegen.triton import FixedTritonConfig, TritonKernel
from torch._inductor.test_case import TestCase
from torch._inductor.utils import run_and_get_code
from torch.testing._internal.common_cuda import IS_SM89
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
@ -60,6 +62,9 @@ class CooperativeReductionTests(TestCase):
)
@parametrize("dtype", [torch.float16, torch.float32, torch.float64])
def test_reduction_fns(self, name, dtype):
if IS_SM89 and dtype == torch.float64 and name in ["std", "var_mean"]:
raise unittest.SkipTest("Timeouts on SM89")
def fn(x, y):
return reduction_fn(x + y, dim=-1)

View File

@ -13,6 +13,7 @@ from torch._inductor.codecache import PyCodeCache
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import fresh_inductor_cache
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import xfailIfSM89
from torch.testing._internal.common_device_type import expectedFailureXPU
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
@ -384,6 +385,7 @@ class TestKernelBenchmark(TestCase):
self.check_bandwidth(compiled_module, "0.006")
@expectedFailureXPU
@xfailIfSM89
@config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
def test_slice_mm_bandwidth_computation(self):
M, N, K = 1000, 2000, 3000

View File

@ -18,7 +18,7 @@ from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.test_operators import realize
from torch._inductor.utils import sympy_index_symbol
from torch._inductor.virtualized import ops, V
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8, xfailIfSM89
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
from torch.utils._pytree import tree_map
from torch.utils._sympy.functions import ModularIndexing
@ -406,6 +406,7 @@ class LoopOrderingTest(TestCase):
self.assertEqual(1, metrics.generated_kernel_count)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 requires H100+ and MI300+")
@xfailIfSM89
def test_fp8_pattern_2(self):
"""
This test repros the fp8 fusion relation issue here:

View File

@ -33,7 +33,7 @@ from torch._inductor.utils import run_and_get_code
from torch._inductor.virtualized import V
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import SM80OrLater
from torch.testing._internal.common_cuda import SM80OrLater, xfailIfSM89
from torch.testing._internal.common_device_type import expectedFailureXPU, skipCUDAIf
from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm, skipIfXpu
from torch.testing._internal.inductor_utils import (
@ -1309,6 +1309,7 @@ class TestPatternMatcher(TestCase):
self.assertTrue(pattern.pattern_eq(search_fn_pattern))
@skipIfXpu
@xfailIfSM89
@inductor_config.patch(
{
"triton.unique_kernel_names": "original_aten",

View File

@ -20,6 +20,7 @@ from torch._inductor.test_case import TestCase
from torch._inductor.utils import run_and_get_code
from torch._inductor.virtualized import V
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import IS_SM89
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
onlyCPU,
@ -575,6 +576,10 @@ class TestInductorDynamic(TestCase):
f(torch.tensor([3], device=device))
@unittest.skipIf(
IS_SM89,
"Fails(with OOMS) on SM89, see https://github.com/pytorch/pytorch/issues/141915",
)
@torch._dynamo.config.patch(
capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
)

View File

@ -21,7 +21,7 @@ from torch.sparse._semi_structured_conversions import (
)
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import _get_torch_cuda_version, PLATFORM_SUPPORTS_FP8
from torch.testing._internal.common_cuda import _get_torch_cuda_version, PLATFORM_SUPPORTS_FP8, xfailIfSM89
from torch.testing._internal.common_device_type import (
dtypes,
instantiate_device_type_tests,
@ -1047,6 +1047,7 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
self.skipTest('cuSPARSELt not enabled')
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+ and sm_89 and MI300+ devices")
@xfailIfSM89
@parametrize("dense_input_shape", [(256, 128)])
def test_sparse_fp8fp8_mm(self, dense_input_shape, device):
if torch.backends.cusparselt.version() < 602:
@ -1066,6 +1067,7 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
dense_result = torch.mm(A_fp8_sparse, B_fp8)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+ and sm_89 and MI300+ devices")
@xfailIfSM89
def test_sparse_semi_structured_scaled_mm_fp8(self, device) -> None:
(k, l, m) = (32, 64, 32)
x = rand_sparse_semi_structured_mask(k, l, dtype=torch.float8_e4m3fn, device=device)
@ -1082,6 +1084,7 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
torch.testing.assert_close(out_fp32, out_fp32_sparse, rtol=1e-1, atol=1e-1)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+ and sm_89 and MI300+ devices")
@xfailIfSM89
@parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.float32])
@parametrize("dense_input_shape", [(256, 128)])
def test_sparse_semi_structured_scaled_mm(

View File

@ -9,6 +9,7 @@ from torch.testing._internal.common_utils import LazyVal, TEST_NUMBA, TEST_WITH_
import inspect
import contextlib
import os
import unittest
CUDA_ALREADY_INITIALIZED_ON_IMPORT = torch.cuda.is_initialized()
@ -33,6 +34,7 @@ SM89OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_devic
SM90OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0))
IS_JETSON = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() in [(7, 2), (8, 7)])
IS_SM89 = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() == (8, 9))
def CDNA2OrLater():
if TEST_WITH_ROCM:
@ -316,6 +318,10 @@ def _create_scaling_case(device="cuda", dtype=torch.float, optimizer_ctor=torch.
) + (data, loss_fn, skip_iter)
def xfailIfSM89(func):
return func if not IS_SM89 else unittest.expectedFailure(func)
# Importing this module should NOT eagerly initialize CUDA
if not CUDA_ALREADY_INITIALIZED_ON_IMPORT:
assert not torch.cuda.is_initialized()