Revert "[cuDNN][64-bit indexing] update conv depthwise 64bit indexing dispatch condition to match native kernel (#156140)"

This reverts commit a5f59cc2eab3a5201712c52fe48c268357ba4f3c.

Reverted https://github.com/pytorch/pytorch/pull/156140 on behalf of https://github.com/atalman due to breaks internal builds ([comment](https://github.com/pytorch/pytorch/pull/156140#issuecomment-2988441548))
This commit is contained in:
PyTorch MergeBot
2025-06-19 15:09:29 +00:00
parent ab3393e923
commit 317af4c87b
2 changed files with 4 additions and 16 deletions

View File

@@ -3,7 +3,6 @@
#include <ATen/Config.h>
#include <ATen/Parallel.h>
#include <ATen/TensorOperators.h>
#include <ATen/native/CanUse32BitIndexMath.h>
#include <ATen/native/ConvolutionMM3d.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/native/Pool.h>
@@ -464,7 +463,7 @@ struct ConvParams {
return true;
}
// native kernel doesn't support 64-bit non-splittable case
if (cudnn_enabled && !(canUse32BitIndexMath(input) && canUse32BitIndexMath(weight))) {
if (cudnn_enabled && needs_64bit_indexing_no_split(input, weight)) {
static long cudnn_version = detail::getCUDAHooks().compiledWithCuDNN() ? detail::getCUDAHooks().versionCuDNN() : -1;
if (!(cudnn_version >= 90300 && at::native::cudnnv8_enabled_check_debug())) {
TORCH_WARN_ONCE("cuDNN cannot be used for large non-batch-splittable convolutions"

View File

@@ -4057,22 +4057,11 @@ class TestConvolutionNNDeviceType(NNTestCase):
@largeTensorTest("20GB")
@largeTensorTest("80GB", "cpu")
def test_depthwise_conv_64bit_indexing(self, device):
x = torch.randn(1, 2, 32800, 32800, dtype=torch.bfloat16).to(
memory_format=torch.channels_last
)
c = nn.Conv2d(
2, 2, kernel_size=3, stride=1, padding=1, groups=2, dtype=torch.bfloat16
).to(memory_format=torch.channels_last)
x = torch.randn(1, 2, 32800, 32800)
c = nn.Conv2d(2, 2, kernel_size=3, stride=1, padding=1, groups=2)
yref = c(x)
y = c.to(device=device)(x.to(device=device))
self.assertEqual(yref, y, atol=1e-3, rtol=1e-4)
del y, yref
# try a batch-splittable case
x = x.reshape(100, 2, 3280, 3280).contiguous(memory_format=torch.channels_last)
yref = c(x)
y = c.to(device=device)(x.to(device=device))
self.assertEqual(yref, y, atol=1e-3, rtol=1e-4)
self.assertEqual(yref, y)
instantiate_device_type_tests(TestConvolutionNNDeviceType, globals(), allow_mps=True)