[ROCm] Enable BF16 NCHW Mixed batchnorm on MIOpen if ROCm>=6.4 (#154611)
This PR enables the MIOpen backend for BF16 NCHW mixed batchnorm when the MIOpen version is >= 3.4 (ROCm >= 6.4). CUDAHooks::versionMIOpen() was added to detect the MIOpen version.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154611
Approved by: https://github.com/jeffdaily, https://github.com/jithunnair-amd
commit f402eed4d9 (parent 085f270a00)
committed by PyTorch MergeBot
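For orientation, the dispatch check added in this commit compares an integer-encoded MIOpen version against 30400. Below is a minimal standalone sketch of that encoding (MAJOR * 10000 + MINOR * 100 + PATCH, so MIOpen 3.4.0 maps to 30400); encode_miopen_version is an illustrative helper, not part of the PR:

#include <cassert>

// Same encoding as the new CUDAHooks::versionMIOpen() below:
// MAJOR * 10000 + MINOR * 100 + PATCH.
long encode_miopen_version(long major, long minor, long patch) {
  return major * 10000 + minor * 100 + patch;
}

int main() {
  assert(encode_miopen_version(3, 4, 0) == 30400);  // MIOpen 3.4 (ROCm 6.4) passes the new check
  assert(encode_miopen_version(3, 3, 0) < 30400);   // older MIOpen keeps BF16 NCHW off this path
  return 0;
}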
@@ -331,6 +331,16 @@ long CUDAHooks::versionCuDNN() const {
 #endif
 }
 
+long CUDAHooks::versionMIOpen() const {
+#if AT_ROCM_ENABLED()
+  return MIOPEN_VERSION_MAJOR * 10000 +
+         MIOPEN_VERSION_MINOR * 100 +
+         MIOPEN_VERSION_PATCH;
+#else
+  TORCH_CHECK(false, "Cannot query MIOpen version if ATen_cuda is not built with ROCm");
+#endif
+}
+
 long CUDAHooks::versionCUDART() const {
 #ifdef CUDART_VERSION
   return CUDART_VERSION;
@@ -46,6 +46,7 @@ struct CUDAHooks : public at::CUDAHooksInterface {
   bool hasCUDART() const override;
   long versionCUDART() const override;
   long versionCuDNN() const override;
+  long versionMIOpen() const override;
   std::string showConfig() const override;
   double batchnormMinEpsilonCuDNN() const override;
   int64_t cuFFTGetPlanCacheMaxSize(DeviceIndex device_index) const override;
@@ -162,6 +162,10 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
     TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
   }
 
+  virtual long versionMIOpen() const {
+    TORCH_CHECK(false, "Cannot query MIOpen version without ATen_cuda library. ", CUDA_HELP);
+  }
+
   virtual long versionCUDART() const {
     TORCH_CHECK(false, "Cannot query CUDART version without ATen_cuda library. ", CUDA_HELP);
   }
@@ -521,17 +521,17 @@ BatchNormBackend _select_batch_norm_backend(
   }
 
   if (
-      input.is_cuda()
+      detail::getCUDAHooks().compiledWithMIOpen()
+      && cudnn_enabled
+      && input.is_cuda()
       && input.dim() <= MIOPEN_DIM_MAX
+      && input.dim() >= 3
       && input.scalar_type() != at::kDouble
-      && input.scalar_type() != at::kBFloat16
-      && (weight.scalar_type() != at::kHalf)
+      && (detail::getCUDAHooks().versionMIOpen() >= 30400 || input.scalar_type() != at::kBFloat16)
+      && weight.scalar_type() == at::kFloat // only FP32 weight for FP32 or FP16/BF16(mixed) input
       && weight.defined() && bias.defined()
       && ((running_mean.defined() && running_var.defined())
         || (!running_mean.defined() && !running_var.defined() && training))
-      && (input.dim() >= 3)
-      && detail::getCUDAHooks().compiledWithMIOpen()
-      && cudnn_enabled
       && input.suggest_memory_format() != MemoryFormat::ChannelsLast
       && input.suggest_memory_format() != MemoryFormat::ChannelsLast3d
   ) {
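A compact restatement of the dtype gating in the condition above, as a standalone sketch; the function and argument names are hypothetical, while the 30400 threshold and the FP32-weight requirement come from the diff:

#include <cassert>

bool miopen_dtype_ok(long miopen_version, bool input_is_bf16,
                     bool input_is_double, bool weight_is_float) {
  return !input_is_double
      && (miopen_version >= 30400 || !input_is_bf16)  // BF16 input needs MIOpen >= 3.4
      && weight_is_float;                             // FP32 weight for FP32 or FP16/BF16 (mixed) input
}

int main() {
  assert( miopen_dtype_ok(30400, /*bf16=*/true,  /*double=*/false, /*fp32 weight=*/true));  // BF16 mixed on ROCm 6.4+
  assert(!miopen_dtype_ok(30300, /*bf16=*/true,  /*double=*/false, /*fp32 weight=*/true));  // BF16 rejected on older MIOpen
  assert( miopen_dtype_ok(30300, /*bf16=*/false, /*double=*/false, /*fp32 weight=*/true));  // FP32/FP16 path unchanged
  return 0;
}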
@@ -79,7 +79,9 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
     checkAllDefined(c, {running_mean, running_var});
   }
   checkAllSameGPU(c, {input, weight, bias, running_mean, running_var});
-  if (input->scalar_type() != ScalarType::Half) {
+  if (input->scalar_type() == ScalarType::Half || input->scalar_type() == ScalarType::BFloat16) {
+    checkScalarType(c, weight, ScalarType::Float);
+  } else {
     checkAllSameType(c, {input, weight});
   }
   checkAllSameType(c, {weight, bias, running_mean, running_var});
@@ -186,7 +188,7 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
 
   checkAllDefined(c, {input, grad_output, weight, save_mean, save_var});
   checkAllSameGPU(c, {input, grad_output, weight, save_mean, save_var});
-  if (input->scalar_type() == ScalarType::Half) {
+  if (input->scalar_type() == ScalarType::Half || input->scalar_type() == ScalarType::BFloat16) {
     checkScalarType(c, weight, ScalarType::Float);
   } else {
     checkAllSameType(c, {input, weight});
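A hedged usage sketch of the newly enabled combination, assuming an ATen build with ROCm, MIOpen >= 3.4, and a visible HIP device: BF16 NCHW input with FP32 weight, bias, and running stats is the mixed case that _select_batch_norm_backend can now route to miopen_batch_norm.

#include <ATen/ATen.h>

int main() {
  // BF16 NCHW input on the ROCm ("cuda") device.
  auto input = at::randn({8, 32, 16, 16},
                         at::device(at::kCUDA).dtype(at::kBFloat16));
  // FP32 affine parameters and running statistics, as the new condition requires.
  auto opts_f32 = at::device(at::kCUDA).dtype(at::kFloat);
  auto weight = at::ones({32}, opts_f32);
  auto bias = at::zeros({32}, opts_f32);
  auto running_mean = at::zeros({32}, opts_f32);
  auto running_var = at::ones({32}, opts_f32);

  // cudnn_enabled=true lets the backend selection pick MIOpen on ROCm builds.
  auto out = at::batch_norm(input, weight, bias, running_mean, running_var,
                            /*training=*/true, /*momentum=*/0.1, /*eps=*/1e-5,
                            /*cudnn_enabled=*/true);
  return 0;
}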