[CUDNN] Remove defunct cuDNN V8 API build flag (#120006)

The flag basically does nothing following #95722

Let's see if the quantization tests break

CC @malfet @atalmanagement

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120006
Approved by: https://github.com/malfet
This commit is contained in:
eqy
2024-05-06 23:13:54 +00:00
committed by PyTorch MergeBot
parent b98c689261
commit ee4cafa098
3 changed files with 11 additions and 28 deletions

View File

@ -74,7 +74,6 @@ function(caffe2_print_configuration_summary)
message(STATUS " Split CUDA : ${BUILD_SPLIT_CUDA}")
message(STATUS " CUDA static link : ${CAFFE2_STATIC_LINK_CUDA}")
message(STATUS " USE_CUDNN : ${USE_CUDNN}")
message(STATUS " USE_EXPERIMENTAL_CUDNN_V8_API: ${USE_EXPERIMENTAL_CUDNN_V8_API}")
message(STATUS " USE_CUSPARSELT : ${USE_CUSPARSELT}")
message(STATUS " CUDA version : ${CUDA_VERSION}")
message(STATUS " USE_FLASH_ATTENTION : ${USE_FLASH_ATTENTION}")

View File

@ -33,7 +33,6 @@ default_compiler_flags = [
"-DTH_INDEX_BASE=0",
"-DMAGMA_V2",
"-DNO_CUDNN_DESTROY_HANDLE",
"-DUSE_EXPERIMENTAL_CUDNN_V8_API", # enable cudnn v8 api
"-DUSE_FBGEMM",
"-DUSE_QNNPACK",
"-DUSE_PYTORCH_QNNPACK",

View File

@ -21,6 +21,7 @@ from hypothesis import strategies as st
import torch.testing._internal.hypothesis_utils as hu
hu.assert_deadline_disabled()
from torch.testing._internal.common_cuda import SM80OrLater
from torch.testing._internal.common_utils import TestCase
from torch.testing._internal.common_utils import IS_PPC, TEST_WITH_UBSAN, IS_MACOS, BUILD_WITH_CAFFE2, IS_SANDCASTLE
from torch.testing._internal.common_quantization import skipIfNoFBGEMM, skipIfNoQNNPACK, skipIfNoONEDNN
@ -31,7 +32,7 @@ from torch.testing._internal.common_quantized import (
qengine_is_onednn,
)
from torch.ao.quantization import PerChannelMinMaxObserver
from torch.testing._internal.common_cuda import TEST_CUDNN, TEST_CUDA
from torch.testing._internal.common_cuda import TEST_CUDNN, TEST_CUDNN_VERSION, TEST_CUDA
from torch.testing._internal.optests import opcheck
import torch.backends.xnnpack
@ -905,9 +906,7 @@ class TestQuantizedOps(TestCase):
"""Tests the correctness of the cudnn add and add_relu op
(Similar to test_qadd_relu_different_qparams, will probably merge in the future)"""
@unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.")
@unittest.skip("Local only - currently the test_qadd_relu_cudnn op is bulid "
"with USE_EXPERIMENTAL_CUDNN_V8_API, we can enable the test "
"after it is built by default")
@unittest.skipIf(not SM80OrLater, "requires sm80 or later.")
def test_qadd_relu_cudnn(self):
dtype = torch.qint8
add_relu = torch.ops.quantized.add_relu
@ -940,9 +939,7 @@ class TestQuantizedOps(TestCase):
"""Tests the correctness of the cudnn add and add_relu op for nhwc format"""
@unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.")
@unittest.skip("Local only - currently the test_qadd_relu_cudnn_nhwc op is bulid "
"with USE_EXPERIMENTAL_CUDNN_V8_API, we can enable the test "
"after it is built by default")
@unittest.skipIf(not SM80OrLater, "requires sm80 or later.")
def test_qadd_relu_cudnn_nhwc(self):
dtype = torch.qint8
add_relu = torch.ops.quantized.add_relu
@ -1379,7 +1376,7 @@ class TestQuantizedOps(TestCase):
self.assertEqual(a_ref, a_hat.dequantize(),
msg="ops.quantized.max_pool1d results are off")
# TODO: merge this test with test_max_pool2d when USE_EXPERIMENTAL_CUDNN_V8_API flag is enabled in CI
# TODO: merge this test with test_max_pool2d
"""Tests 2D cudnn max pool operation on quantized tensors."""
@given(X=hu.tensor(shapes=hu.array_shapes(min_dims=3, max_dims=4,
min_side=1, max_side=10),
@ -1394,9 +1391,7 @@ class TestQuantizedOps(TestCase):
padding=st.integers(0, 2),
ceil_mode=st.booleans())
@unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.")
@unittest.skip("Local only - currently the qconv2d_cudnn op is bulid "
"with USE_EXPERIMENTAL_CUDNN_V8_API, we can enable the test "
"after it is built by default")
@unittest.skipIf(TEST_CUDNN_VERSION <= 90100, "cuDNN maxpool2d mishandles -128 before v90100")
def test_max_pool2d_cudnn(self, X, kernel, stride, dilation, padding, ceil_mode):
X, (scale, zero_point, torch_type) = X
assume(kernel // 2 >= padding) # Kernel cannot be overhanging!
@ -4050,9 +4045,7 @@ class TestQuantizedLinear(TestCase):
use_channelwise=st.sampled_from([False])) # channelwise currently not supported for qlinear cudnn
@skipIfNoFBGEMM
@unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.")
@unittest.skip("Local only - currently the qlinear_cudnn op is bulid "
"with USE_EXPERIMENTAL_CUDNN_V8_API, we can enable the test "
"after it is built by default")
@unittest.skipIf(not SM80OrLater, "requires sm80 or later.")
# TODO: check with yang regarding CUDNN flags
def test_qlinear_cudnn(self, batch_size, input_channels, output_channels, use_bias,
use_relu, use_multi_dim_input, use_channelwise):
@ -5427,9 +5420,7 @@ class TestQuantizedConv(TestCase):
use_channelwise=st.sampled_from([False]))
@skipIfNoFBGEMM
@unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.")
@unittest.skip("Local only - currently the qconv2d_cudnn op is bulid "
"with USE_EXPERIMENTAL_CUDNN_V8_API, we can enable the test "
"after it is built by default")
@unittest.skipIf(not SM80OrLater, "requires sm80 or later.")
def test_qconv2d_cudnn(
self,
batch_size,
@ -5510,9 +5501,7 @@ class TestQuantizedConv(TestCase):
use_channelwise=st.sampled_from([False]))
@skipIfNoFBGEMM
@unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.")
@unittest.skip("Local only - currently the qconv2d_cudnn op is bulid "
"with USE_EXPERIMENTAL_CUDNN_V8_API, we can enable the test "
"after it is built by default")
@unittest.skipIf(not SM80OrLater, "requires sm80 or later.")
def test_qconv2d_relu_cudnn(
self,
batch_size,
@ -6245,9 +6234,7 @@ class TestQuantizedConv(TestCase):
use_channelwise=st.sampled_from([False]))
@skipIfNoFBGEMM
@unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.")
@unittest.skip("Local only - currently the qconv1d_cudnn op is bulid "
"with USE_EXPERIMENTAL_CUDNN_V8_API, we can enable the test "
"after it is built by default")
@unittest.skipIf(not SM80OrLater, "requires sm80 or later.")
def test_qconv1d_cudnn(
self,
batch_size,
@ -6319,9 +6306,7 @@ class TestQuantizedConv(TestCase):
use_channelwise=st.sampled_from([False]))
@skipIfNoFBGEMM
@unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.")
@unittest.skip("Local only - currently the qconv1d_cudnn op is bulid "
"with USE_EXPERIMENTAL_CUDNN_V8_API, we can enable the test "
"after it is built by default")
@unittest.skipIf(not SM80OrLater, "requires sm80 or later.")
def test_qconv1d_relu_cudnn(
self,
batch_size,