Mirror of https://github.com/pytorch/pytorch.git
(synced 2025-10-20 21:14:14 +08:00)
Revert "[cuDNN][64-bit indexing] update conv depthwise 64bit indexing dispatch condition to match native kernel (#156140)"
This reverts commit a5f59cc2eab3a5201712c52fe48c268357ba4f3c. Reverted https://github.com/pytorch/pytorch/pull/156140 on behalf of https://github.com/atalman due to breaks internal builds ([comment](https://github.com/pytorch/pytorch/pull/156140#issuecomment-2988441548))
This commit is contained in:
@ -3,7 +3,6 @@
|
||||
#include <ATen/Config.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <ATen/TensorOperators.h>
|
||||
#include <ATen/native/CanUse32BitIndexMath.h>
|
||||
#include <ATen/native/ConvolutionMM3d.h>
|
||||
#include <ATen/native/ConvUtils.h>
|
||||
#include <ATen/native/Pool.h>
|
||||
@ -464,7 +463,7 @@ struct ConvParams {
|
||||
return true;
|
||||
}
|
||||
// native kernel doesn't support 64-bit non-splittable case
|
||||
if (cudnn_enabled && !(canUse32BitIndexMath(input) && canUse32BitIndexMath(weight))) {
|
||||
if (cudnn_enabled && needs_64bit_indexing_no_split(input, weight)) {
|
||||
static long cudnn_version = detail::getCUDAHooks().compiledWithCuDNN() ? detail::getCUDAHooks().versionCuDNN() : -1;
|
||||
if (!(cudnn_version >= 90300 && at::native::cudnnv8_enabled_check_debug())) {
|
||||
TORCH_WARN_ONCE("cuDNN cannot be used for large non-batch-splittable convolutions"
|
||||
|
@ -4057,22 +4057,11 @@ class TestConvolutionNNDeviceType(NNTestCase):
|
||||
@largeTensorTest("20GB")
|
||||
@largeTensorTest("80GB", "cpu")
|
||||
def test_depthwise_conv_64bit_indexing(self, device):
|
||||
x = torch.randn(1, 2, 32800, 32800, dtype=torch.bfloat16).to(
|
||||
memory_format=torch.channels_last
|
||||
)
|
||||
c = nn.Conv2d(
|
||||
2, 2, kernel_size=3, stride=1, padding=1, groups=2, dtype=torch.bfloat16
|
||||
).to(memory_format=torch.channels_last)
|
||||
x = torch.randn(1, 2, 32800, 32800)
|
||||
c = nn.Conv2d(2, 2, kernel_size=3, stride=1, padding=1, groups=2)
|
||||
yref = c(x)
|
||||
y = c.to(device=device)(x.to(device=device))
|
||||
self.assertEqual(yref, y, atol=1e-3, rtol=1e-4)
|
||||
del y, yref
|
||||
|
||||
# try a batch-splittable case
|
||||
x = x.reshape(100, 2, 3280, 3280).contiguous(memory_format=torch.channels_last)
|
||||
yref = c(x)
|
||||
y = c.to(device=device)(x.to(device=device))
|
||||
self.assertEqual(yref, y, atol=1e-3, rtol=1e-4)
|
||||
self.assertEqual(yref, y)
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestConvolutionNNDeviceType, globals(), allow_mps=True)
|
||||
|
Reference in New Issue
Block a user