[ROCm] fix miopen batchnorm changing output format (#162112)
It was found that the MIOpen batch norm integration caused the output to always be in the default contiguous (NCHW) memory format, even when the input was channels-last. The change makes the MIOpen path respect the input's suggested memory format and unskips a number of related unit tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162112
Approved by: https://github.com/jeffdaily
Co-authored-by: Jeff Daily <jeff.daily@amd.com>
Co-authored-by: Dmitry Nikolaev <dmitry.nikolaev@amd.com>
Co-authored-by: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com>
commit d65ffdef3d (parent ac72f81c12)
committed by PyTorch MergeBot
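For orientation, a minimal sketch (not part of the PR) of the behavior the fix targets: on a ROCm build that routes batch norm through MIOpen, a channels-last input to BatchNorm2d is expected to produce a channels-last output instead of being silently converted to the default contiguous format. The environment variables mirror the test setup in the diff below; shapes and the "cuda" device string (which maps to HIP on ROCm) are illustrative assumptions.

import os
import torch
import torch.nn as nn

# Assumption: as in the test changes below, the NHWC batch norm path on ROCm is
# opted into via these environment variables.
if torch.version.hip:
    os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC"] = "1"
    os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM"] = "1"

# Channels-last (NHWC) input.
x = torch.randn(8, 32, 16, 16, device="cuda").to(memory_format=torch.channels_last)
bn = nn.BatchNorm2d(32).cuda().train()

y = bn(x)

# Before this fix the MIOpen path returned a default-contiguous (NCHW) tensor here;
# after it, the output is expected to keep the channels-last layout.
print(y.is_contiguous(memory_format=torch.channels_last))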
@@ -624,7 +624,9 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
   if (backend == BatchNormBackend::Miopen) {
     return std::tuple_cat(
         at::miopen_batch_norm(
-          input.contiguous(), weight.contiguous(), bias.contiguous(),
+          input.contiguous(input.suggest_memory_format()),
+          weight.contiguous(),
+          bias.contiguous(),
           running_mean.defined() ? running_mean.contiguous() : running_mean,
           running_var.defined() ? running_var.contiguous() : running_var,
           training, momentum, eps),
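The hunk above replaces a bare input.contiguous() with input.contiguous(input.suggest_memory_format()). A short Python-side sketch of why that matters (tensor shape is illustrative): .contiguous() with no argument re-lays a channels-last tensor out as NCHW, while passing the tensor's preferred format keeps the layout.

import torch

x = torch.randn(2, 3, 4, 4).to(memory_format=torch.channels_last)

# Plain .contiguous() defaults to torch.contiguous_format (NCHW) and drops the
# channels-last layout.
nchw = x.contiguous()
print(nchw.is_contiguous(memory_format=torch.channels_last))   # False

# Passing the tensor's own preferred format keeps the channels-last strides,
# the Python analogue of contiguous(input.suggest_memory_format()) in the C++ change.
nhwc = x.contiguous(memory_format=torch.channels_last)
print(nhwc.is_contiguous(memory_format=torch.channels_last))   # True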
@@ -7,6 +7,7 @@
 #include <ATen/NativeFunctions.h>
 #else
 #include <ATen/ops/empty.h>
+#include <ATen/ops/empty_like.h>
 #include <ATen/ops/miopen_batch_norm_native.h>
 #include <ATen/ops/miopen_batch_norm_backward_native.h>
 #endif
@@ -102,7 +103,7 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
     mode = miopenBNSpatial;
   }

-  auto output_t = at::empty(input->sizes(), input->options());
+  auto output_t = at::empty_like(input_t, input_t.options(), input_t.suggest_memory_format());
   TensorArg output{ output_t, "output", 0 };

   auto handle = getMiopenHandle();
@@ -170,20 +171,15 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
     const std::optional<Tensor>& save_var_t_opt,
     double epsilon) {
   // See [Note: hacky wrapper removal for optional tensor]
-  const Tensor& running_mean =
-      running_mean_opt.value_or(Tensor());
-  const Tensor& running_var =
-      running_var_opt.value_or(Tensor());
-  const Tensor& save_mean_t =
-      save_mean_t_opt.value_or(Tensor());
-  const Tensor& save_var_t =
-      save_var_t_opt.value_or(Tensor());
+  const Tensor& save_mean_t = save_mean_t_opt.value_or(Tensor());
+  const Tensor& save_var_t = save_var_t_opt.value_or(Tensor());

-  TensorArg input{ input_t, "input", 1 },
-            grad_output{ grad_output_t, "grad_output", 2 },
-            weight{ weight_t, "weight", 3 },
-            save_mean{ save_mean_t, "save_mean", 4 },
-            save_var{ save_var_t, "save_var", 5 };
+  auto grad_output_contig =
+      grad_output_t.contiguous(input_t.suggest_memory_format());
+  TensorArg input{input_t, "input", 1},
+      grad_output{grad_output_contig, "grad_output", 2},
+      weight{weight_t, "weight", 3}, save_mean{save_mean_t, "save_mean", 4},
+      save_var{save_var_t, "save_var", 5};
   CheckedFrom c = "miopen_batch_norm_backward";

   checkAllDefined(c, {input, grad_output, weight, save_mean, save_var});
@@ -195,7 +191,11 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
   }
   checkAllSameType(c, {input, grad_output});
   checkAllSameType(c, {weight, save_mean, save_var});
-  checkAllContiguous(c, {input, grad_output, save_mean, save_var});
+  // TODO: is weight required to be contiguous?
+  checkAllContiguous(c, {save_mean, save_var});
+  // TODO: TensorArg check should start handle memory format
+  TORCH_CHECK(input->is_contiguous(input->suggest_memory_format()));
+  TORCH_CHECK(grad_output->is_contiguous(input->suggest_memory_format()));
   checkDimRange(c, input, 2, 6 /* exclusive */);
   checkSameSize(c, input, grad_output);
   auto num_features = input->size(1);
@@ -210,7 +210,7 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
     mode = miopenBNSpatial;
   }

-  auto grad_input_t = at::empty(input->sizes(), input->options());
+  auto grad_input_t = at::empty(input->sizes(), input->options(), input->suggest_memory_format());
   auto grad_weight_t = at::empty(weight->sizes(), weight->options());
   auto grad_bias_t = at::empty(weight->sizes(), weight->options());

@@ -468,13 +468,6 @@ class TestOperators(TestCase):
                 ),  # Works on ROCm
                 xfail("torch.ops.aten._flash_attention_forward"),
                 xfail("torch.ops.aten._efficient_attention_forward"),
-                # RuntimeError: Expected contiguous tensor, but got
-                # non-contiguous tensor for argument #2 'grad_output'
-                decorate(
-                    "_batch_norm_with_update",
-                    decorator=expectedFailureIf(TEST_WITH_ROCM),
-                    device_type="cuda",
-                ),
             }
         ),
     )
@@ -2400,13 +2393,6 @@ class TestOperators(TestCase):
             skip("sparse.sampled_addmm", ""),
             skip("sparse.mm", "reduce"),
             skip("native_layer_norm", "", device_type="cpu"),
-            # RuntimeError: Expected contiguous tensor, but got
-            # non-contiguous tensor for argument #2 'grad_output'
-            decorate(
-                "_batch_norm_with_update",
-                decorator=expectedFailureIf(TEST_WITH_ROCM),
-                device_type="cuda",
-            ),
         },
     )
     @opsToleranceOverride(
@@ -30,7 +30,6 @@ from torch.testing._internal.common_device_type import (
     skipCUDAIfMiopen,
     skipCUDAIfNoCudnn,
     skipCUDAIfNoMiopen,
-    skipCUDAIfNotMiopenSuggestNHWC,
     skipCUDAIfRocm,
     skipMeta,
     skipMPS,
@@ -51,8 +50,6 @@ from torch.testing._internal.common_utils import (
     parametrize as parametrize_test,
     run_tests,
     set_default_dtype,
-    skipIfNotMiopenSuggestNHWC,
-    skipIfRocmVersionLessThan,
     subtest,
     TEST_SCIPY,
     TEST_WITH_ROCM,
@@ -64,6 +61,7 @@ AMPERE_OR_ROCM = TEST_WITH_ROCM or torch.cuda.is_tf32_supported()

 if TEST_WITH_ROCM:
     os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC"] = "1"
+    os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM"] = "1"


 if TEST_SCIPY:
@@ -715,7 +713,6 @@ class TestConvolutionNN(NNTestCase):
     # Almost identical to the above `test_Conv2d_naive_groups`
     @torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
     @tf32_on_and_off(0.001)
-    @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7")
     def test_Conv2d_groups_nobias(self):
         dev_dtypes = [("cpu", torch.float)]
         if TEST_CUDA:
@@ -761,7 +758,6 @@ class TestConvolutionNN(NNTestCase):
     # and https://github.com/pytorch/pytorch/pull/18463#issuecomment-477001024
     @torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
     @tf32_on_and_off(0.001)
-    @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7")
     def test_Conv2d_groups_nobias_v2(self):
         torch.manual_seed(123)
         dev_dtypes = [("cpu", torch.float)]
@@ -896,7 +892,6 @@ class TestConvolutionNN(NNTestCase):

     @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
     @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
-    @skipIfNotMiopenSuggestNHWC
     def test_grouped_conv_cudnn_nhwc_support(self):
         # in order to catch the hols in grouped convolution in nhwc support for earlier cudnn version
         input = torch.randn((16, 16, 8, 8), dtype=torch.float16, device="cuda").to(
@@ -3146,7 +3141,6 @@ class TestConvolutionNNDeviceType(NNTestCase):

     @onlyCUDA
     @largeTensorTest("12GB")
-    @skipIfRocmVersionLessThan((6, 0))
     def test_conv_transposed_large(self, device):
         dtype = torch.half if self.device_type == "cuda" else torch.float
         conv = nn.ConvTranspose2d(1, 1, 1, 1, bias=False).to(device).to(dtype)
@@ -3190,7 +3184,6 @@ class TestConvolutionNNDeviceType(NNTestCase):
         self.assertEqual(maxdiff3, 0)

     @onlyCUDA
-    @skipCUDAIfRocm
     @largeTensorTest("12GB")
     def test_conv_large(self, device):
         dtype = torch.half if self.device_type == "cuda" else torch.float
@@ -3223,7 +3216,6 @@ class TestConvolutionNNDeviceType(NNTestCase):
         self.assertEqual(grad1, grad2, atol=5e-2, rtol=5e-3)

     @onlyCUDA
-    @skipCUDAIfRocm
     @largeTensorTest("20GB", "cpu")
     @largeTensorTest("60GB", "cuda")
     def test_conv_large_batch_1(self, device):
@@ -3360,7 +3352,6 @@ class TestConvolutionNNDeviceType(NNTestCase):
     @dtypes(torch.float)
     @torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
     @tf32_on_and_off(0.001)
-    @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7")
     def test_Conv2d_naive_groups(self, device, dtype):
         # Check that grouped convolutions matches two half convolutions
         m = nn.Conv2d(4, 4, kernel_size=3, groups=2).to(device, dtype)
@@ -3629,19 +3620,21 @@ class TestConvolutionNNDeviceType(NNTestCase):
         )

     @onlyCUDA
-    @skipCUDAIfNotMiopenSuggestNHWC
     @dtypes(torch.half, torch.float, torch.cfloat)
     def test_conv_cudnn_nhwc(self, device, dtype):
         def helper(n, c, h, w, out_channels, kernel_size, groups):
-            input = torch.randint(-3, 3, (n, c, h, w), dtype=dtype, device=device).to(
-                memory_format=torch.channels_last
-            )
+            # randint with dtype=torch.cfloat fails with
+            # RuntimeError: check_random_bounds handles only integral, floating-point and boolean types
+            # must create randint and randint_like using default int64, then cast to desired
+            input = torch.randint(
+                -3, 3, (n, c, h, w), dtype=torch.int64, device=device
+            ).to(dtype, memory_format=torch.channels_last)
             input.requires_grad_()
             conv = nn.Conv2d(c, out_channels, kernel_size, groups=groups).to(
                 device="cuda", dtype=dtype, memory_format=torch.channels_last
             )
             for p in conv.parameters():
-                p.data = torch.randint_like(p, -3, 3)
+                p.data = torch.randint_like(p, -3, 3, dtype=torch.int64).to(p.dtype)

             # use FP64 channels-first conv as reference
             ref_input = input.detach().clone().contiguous().double().requires_grad_()
@@ -3655,7 +3648,7 @@ class TestConvolutionNNDeviceType(NNTestCase):
             out = conv(input)
             ref_out = ref_conv(ref_input)

-            grad = torch.randint_like(out, -3, 3)
+            grad = torch.randint_like(out, -3, 3, dtype=torch.int64).to(out.dtype)
             ref_grad = grad.detach().clone().double().contiguous()

             out.backward(grad)
@@ -3682,7 +3675,6 @@ class TestConvolutionNNDeviceType(NNTestCase):
         helper(1, 16, 56, 56, out_channels=16, kernel_size=3, groups=16)

     @onlyCUDA
-    @skipCUDAIfRocm
     @dtypes(torch.half, torch.float)
     def test_conv_cudnn_ndhwc(self, device, dtype):
         def helper(n, c, d, h, w, out_channels, kernel_size, groups):
@@ -3812,7 +3804,6 @@ class TestConvolutionNNDeviceType(NNTestCase):
         )

     @onlyCUDA
-    @skipCUDAIfNotMiopenSuggestNHWC
     @tf32_on_and_off(0.05)
     def test_conv_cudnn_mismatch_memory_format(self, device):
         configs = [
@@ -3945,7 +3936,6 @@ class TestConvolutionNNDeviceType(NNTestCase):
         self.assertEqual(F.relu(conv2d_out + alpha * z), cudnn_out)

     @onlyCUDA
-    @skipCUDAIfRocm
     def test_convert_conv2d_weight_memory_format(self, device):
         input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, device=device)
         model = nn.Sequential(nn.Conv2d(8, 4, 3), nn.BatchNorm2d(4)).to(device).float()
@@ -3965,7 +3955,6 @@ class TestConvolutionNNDeviceType(NNTestCase):
         self.assertTrue(out.is_contiguous(memory_format=memory_format))

     @onlyCUDA
-    @skipCUDAIfRocm
     def test_convert_conv3d_weight_memory_format(self, device):
         input = torch.randint(
             1, 10, (2, 8, 4, 4, 4), dtype=torch.float32, device=device
@@ -62,6 +62,7 @@ AMPERE_OR_ROCM = TEST_WITH_ROCM or torch.cuda.is_tf32_supported()

 if TEST_WITH_ROCM:
     os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC"] = "1"
+    os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM"] = "1"

 # load_tests from common_utils is used to automatically filter tests for
 # sharding on sandcastle. This line silences flake warnings
@@ -3514,7 +3515,6 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
         self.assertRaisesRegex(RuntimeError, re.escape("input.size(-1) must be equal to input_size"), rnn, x_wrong)

     @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
-    @skipIfRocm
     def test_cudnn_weight_format(self):
         rnns = [
             nn.LSTM(10, 20, batch_first=True),
@@ -3522,7 +3522,8 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
             nn.GRU(10, 20, batch_first=True),
             nn.RNN(10, 20, batch_first=True)
         ]
-        first_warn = True
+        # ROCm RNN does not issue warning about single contig chunk of memory, so don't assert it
+        first_warn = False if torch.version.hip else True
         for rnn in rnns:
             rnn.cuda()
             input = torch.randn(5, 4, 10, requires_grad=True, device="cuda")
@@ -5171,24 +5172,38 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
             ("NCHW", "native", False, torch.float),
             ("NCHW", "native", True, torch.half),
             ("NCHW", "native", True, torch.bfloat16),
+
+            ("NHWC", "cpu", False, torch.float),
+            ("NHWC", "cpu", True, torch.half),
+            ("NHWC", "cpu", True, torch.bfloat16),
+
+            ("NHWC", "native", False, torch.float),
+            ("NHWC", "native", True, torch.half),
+            ("NHWC", "native", True, torch.bfloat16),
+
+            ("NHWC", "NCHW", False, torch.float),
+            ("NHWC", "NCHW", True, torch.half),
+            ("NHWC", "NCHW", True, torch.bfloat16),
         ],
         name_fn=lambda f, b, m, t: f"{f}_vs_{b}{'_mixed' if m else ''}_{dtype_name(t)}"
     )
     def test_batchnorm(self, dims, mode, memory_format, ref_backend, mixed, dtype):
         if torch.version.cuda:
             if self._testMethodName in ("test_batchnorm_2D_train_NCHW_vs_cpu_mixed_bfloat16",
-                                        "test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16"):
-                self.skipTest("bfloat16 NHWC train failed on CUDA due to native tolerance issue "
-                              "https://github.com/pytorch/pytorch/issues/156513")
-            if self._testMethodName == "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16":
-                self.skipTest("Batchnorm 3D NHWC train failed on CUDA")
+                                        "test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16",
+                                        "test_batchnorm_2D_train_NHWC_vs_NCHW_mixed_bfloat16",
+                                        "test_batchnorm_3D_train_NHWC_vs_NCHW_mixed_bfloat16",
+                                        "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16"):
                self.skipTest("Failed on CUDA")

         if torch.version.hip:
             if self._testMethodName in ("test_batchnorm_2D_train_NCHW_vs_cpu_mixed_bfloat16",
-                                        "test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16") \
+                                        "test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16",
+                                        "test_batchnorm_2D_train_NHWC_vs_NCHW_mixed_bfloat16",
+                                        "test_batchnorm_3D_train_NHWC_vs_NCHW_mixed_bfloat16") \
                     and _get_torch_rocm_version() < (6, 4):
                 # NCHW bfloat16 path uses native kernels for rocm<=6.3
-                # train failed on rocm<=6.3 due to native tolerance issue
+                # train failed on rocm<=6.3 due to native accuracy issue
                 # https://github.com/pytorch/pytorch/issues/156513
                 self.skipTest("bfloat16 NHWC train failed on ROCm <= 6.3")
@@ -5198,9 +5213,8 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
                 # https://github.com/pytorch/pytorch/issues/156513
                 self.skipTest("bfloat16 NCHW train failed due to native tolerance issue")

-            if self._testMethodName == "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16" \
-                    and _get_torch_rocm_version() < (7, 0):
-                self.skipTest("3D float16 NCHW train failed on ROCm<7.0")
+            if self._testMethodName == "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16":
+                self.skipTest("3D float16 NCHW train failed on ROCm")

         if dims == 3 and memory_format in ("NHWC", "NCHW"):
             memory_format = memory_format + "3D"
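The skip lists above key off generated test-method names. A small sketch of how the parametrization's name_fn maps a (memory_format, ref_backend, mixed, dtype) tuple onto suffixes like the one in test_batchnorm_2D_train_NHWC_vs_NCHW_mixed_bfloat16; the dtype_name helper here is an assumption standing in for the one in the test file, and the 2D_train prefix appears to come from the test's other parametrizations (dims and mode).

import torch

def dtype_name(t):
    # Assumed helper: turns torch.bfloat16 into "bfloat16", torch.half into "float16", etc.
    return str(t).replace("torch.", "")

name_fn = lambda f, b, m, t: f"{f}_vs_{b}{'_mixed' if m else ''}_{dtype_name(t)}"

print(name_fn("NHWC", "NCHW", True, torch.bfloat16))   # NHWC_vs_NCHW_mixed_bfloat16
print(name_fn("NCHW", "native", True, torch.half))     # NCHW_vs_native_mixed_float16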
@@ -2801,7 +2801,7 @@
   self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector<c10::SymInt>(padding.size(), 0), groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"

 - name: miopen_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor)
-  input, weight, bias: "grad.defined() ? (training ? miopen_batch_norm_backward(input, grad.contiguous(), weight, running_mean, running_var, result1, result2, epsilon) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)) : std::tuple<Tensor, Tensor, Tensor>()"
+  input, weight, bias: "grad.defined() ? (training ? miopen_batch_norm_backward(input, grad.contiguous(input.suggest_memory_format()), weight, running_mean, running_var, result1, result2, epsilon) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)) : std::tuple<Tensor, Tensor, Tensor>()"
   result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, epsilon)

 - name: miopen_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon) -> (Tensor, Tensor, Tensor)
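With the derivative definition above now passing grad.contiguous(input.suggest_memory_format()) and the backward kernel allocating grad_input in the same format, gradients flowing through a channels-last batch norm are expected to come back channels-last as well. A hedged sketch of that check (shapes and the training-mode setup are illustrative assumptions, not the PR's code):

import torch
import torch.nn as nn

x = (
    torch.randn(8, 32, 16, 16, device="cuda")
    .to(memory_format=torch.channels_last)
    .requires_grad_()
)
bn = nn.BatchNorm2d(32).cuda().train()

y = bn(x)
grad = torch.randn_like(y).to(memory_format=torch.channels_last)
y.backward(grad)

# With the fix, grad_input is allocated using input.suggest_memory_format(),
# so a channels-last input should receive a channels-last gradient.
print(x.grad.is_contiguous(memory_format=torch.channels_last))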