[ROCm] Enable BF16 NCHW Mixed batchnorm on MIOpen if ROCm>=6.4 (#154611)

This PR enables the MIOpen backend for BF16 NCHW mixed-precision batchnorm when the MIOpen version is >= 3.4 (ROCm >= 6.4).
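For context, a minimal sketch (not part of this PR, illustrative only) of the workload this affects: a BF16 NCHW input fed through a BatchNorm2d module whose weight, bias, and running stats stay FP32, i.e. the "mixed" case that can now be routed to MIOpen on a ROCm >= 6.4 build.

import torch
import torch.nn as nn

# Illustrative sketch: BF16 NCHW input with FP32 affine params and running stats
# ("mixed" batchnorm). On a ROCm build with MIOpen >= 3.4 this case is now
# eligible for the MIOpen backend instead of the native kernel.
bn = nn.BatchNorm2d(64).cuda()                                     # weight/bias/stats remain FP32
x = torch.randn(8, 64, 32, 32, device="cuda").to(torch.bfloat16)  # contiguous NCHW, BF16
out = bn(x)
print(out.dtype)  # torch.bfloat16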

CUDAHooks::versionMIOpen() was added to detect the MIOpen version at runtime.
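The hook encodes the version as major * 10000 + minor * 100 + patch, so MIOpen 3.4.0 (ROCm 6.4) maps to 30400, the threshold checked in _select_batch_norm_backend(). A quick worked sketch (version values assumed for illustration):

# Sketch of the encoding used by CUDAHooks::versionMIOpen():
#   MIOPEN_VERSION_MAJOR * 10000 + MIOPEN_VERSION_MINOR * 100 + MIOPEN_VERSION_PATCH
major, minor, patch = 3, 4, 0          # e.g. MIOpen 3.4.0, shipped with ROCm 6.4
encoded = major * 10000 + minor * 100 + patch
assert encoded == 30400                # threshold compared in _select_batch_norm_backend()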

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154611
Approved by: https://github.com/jeffdaily, https://github.com/jithunnair-amd
Author:    Dmitry Nikolaev
Date:      2025-06-19 17:22:37 +00:00
Committer: PyTorch MergeBot
parent 085f270a00
commit f402eed4d9
5 changed files with 25 additions and 8 deletions

aten/src/ATen/cuda/detail/CUDAHooks.cpp

@@ -331,6 +331,16 @@ long CUDAHooks::versionCuDNN() const {
 #endif
 }
 
+long CUDAHooks::versionMIOpen() const {
+#if AT_ROCM_ENABLED()
+  return MIOPEN_VERSION_MAJOR * 10000 +
+         MIOPEN_VERSION_MINOR * 100 +
+         MIOPEN_VERSION_PATCH;
+#else
+  TORCH_CHECK(false, "Cannot query MIOpen version if ATen_cuda is not built with ROCm");
+#endif
+}
+
 long CUDAHooks::versionCUDART() const {
 #ifdef CUDART_VERSION
   return CUDART_VERSION;

aten/src/ATen/cuda/detail/CUDAHooks.h

@@ -46,6 +46,7 @@ struct CUDAHooks : public at::CUDAHooksInterface {
   bool hasCUDART() const override;
   long versionCUDART() const override;
   long versionCuDNN() const override;
+  long versionMIOpen() const override;
   std::string showConfig() const override;
   double batchnormMinEpsilonCuDNN() const override;
   int64_t cuFFTGetPlanCacheMaxSize(DeviceIndex device_index) const override;

aten/src/ATen/detail/CUDAHooksInterface.h

@@ -162,6 +162,10 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
     TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
   }
 
+  virtual long versionMIOpen() const {
+    TORCH_CHECK(false, "Cannot query MIOpen version without ATen_cuda library. ", CUDA_HELP);
+  }
+
   virtual long versionCUDART() const {
     TORCH_CHECK(false, "Cannot query CUDART version without ATen_cuda library. ", CUDA_HELP);
   }

aten/src/ATen/native/Normalization.cpp

@@ -521,17 +521,17 @@ BatchNormBackend _select_batch_norm_backend(
   }
 
   if (
-      input.is_cuda()
+      detail::getCUDAHooks().compiledWithMIOpen()
+      && cudnn_enabled
+      && input.is_cuda()
       && input.dim() <= MIOPEN_DIM_MAX
+      && input.dim() >= 3
       && input.scalar_type() != at::kDouble
-      && input.scalar_type() != at::kBFloat16
-      && (weight.scalar_type() != at::kHalf)
+      && (detail::getCUDAHooks().versionMIOpen() >= 30400 || input.scalar_type() != at::kBFloat16)
+      && weight.scalar_type() == at::kFloat // only FP32 weight for FP32 or FP16/BF16(mixed) input
       && weight.defined() && bias.defined()
       && ((running_mean.defined() && running_var.defined())
         || (!running_mean.defined() && !running_var.defined() && training))
-      && (input.dim() >= 3)
-      && detail::getCUDAHooks().compiledWithMIOpen()
-      && cudnn_enabled
       && input.suggest_memory_format() != MemoryFormat::ChannelsLast
       && input.suggest_memory_format() != MemoryFormat::ChannelsLast3d
   ) {

aten/src/ATen/native/miopen/BatchNorm_miopen.cpp

@@ -79,7 +79,9 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
     checkAllDefined(c, {running_mean, running_var});
   }
   checkAllSameGPU(c, {input, weight, bias, running_mean, running_var});
-  if (input->scalar_type() != ScalarType::Half) {
+  if (input->scalar_type() == ScalarType::Half || input->scalar_type() == ScalarType::BFloat16) {
+    checkScalarType(c, weight, ScalarType::Float);
+  } else {
     checkAllSameType(c, {input, weight});
   }
   checkAllSameType(c, {weight, bias, running_mean, running_var});
@@ -186,7 +188,7 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
 
   checkAllDefined(c, {input, grad_output, weight, save_mean, save_var});
   checkAllSameGPU(c, {input, grad_output, weight, save_mean, save_var});
-  if (input->scalar_type() == ScalarType::Half) {
+  if (input->scalar_type() == ScalarType::Half || input->scalar_type() == ScalarType::BFloat16) {
     checkScalarType(c, weight, ScalarType::Float);
   } else {
     checkAllSameType(c, {input, weight});