mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65610 - Replace HIP_PLATFORM_HCC with USE_ROCM - Dont rely on CUDA_VERSION or HIP_VERSION and use USE_ROCM and ROCM_VERSION. - In the next PR - Will be removing the mapping from CUDA_VERSION to HIP_VERSION and CUDA to HIP in hipify. - HIP_PLATFORM_HCC is deprecated, so will add HIP_PLATFORM_AMD to support HIP host code compilation on gcc. cc jeffdaily sunway513 jithunnair-amd ROCmSupport amathews-amd Reviewed By: jbschlosser Differential Revision: D30909053 Pulled By: ezyang fbshipit-source-id: 224a966ebf1aaec79beccbbd686fdf3d49267e06
273 lines
8.2 KiB
Plaintext
273 lines
8.2 KiB
Plaintext
#include "caffe2/core/common_gpu.h"
|
|
#include "caffe2/core/context_gpu.h"
|
|
|
|
#include "caffe2/sgd/fp16_momentum_sgd_op.h"
|
|
|
|
namespace caffe2 {
|
|
namespace {
|
|
|
|
#ifdef __HIPCC__
|
|
typedef __half half;
|
|
typedef __half2 half2;
|
|
#endif
|
|
|
|
__global__ void FP16MomentumSGDKernel(
|
|
int N,
|
|
const half2* g,
|
|
const half2* m,
|
|
half2* ng,
|
|
half2* nm,
|
|
const float* lr,
|
|
const float mom,
|
|
bool nesterov,
|
|
const float wd,
|
|
half2* param) {
|
|
#if __CUDA_ARCH__ >= 530 || defined(USE_ROCM)
|
|
const float lr2 = lr[0];
|
|
const half2 LR = __float2half2_rn(lr2);
|
|
const half2 momentum = __float2half2_rn(mom);
|
|
const half2 weight_decay = __float2half2_rn(wd);
|
|
|
|
int n = N / 2;
|
|
if (!nesterov) {
|
|
CUDA_1D_KERNEL_LOOP(i, n) {
|
|
ng[i] = __hfma2(weight_decay, param[i], g[i]);
|
|
const half2 adjusted_gradient =
|
|
__hfma2(LR, ng[i], __hmul2(momentum, m[i]));
|
|
nm[i] = adjusted_gradient;
|
|
ng[i] = adjusted_gradient;
|
|
if (param) {
|
|
param[i] = __hsub2(param[i], ng[i]);
|
|
}
|
|
|
|
// odd number of elements
|
|
if (i == 0 && (N % 2)) {
|
|
half *g_half = (half*)g, *param_half = (half*)param, *m_half = (half*)m,
|
|
*nm_half = (half*)nm, *ng_half = (half*)ng;
|
|
ng_half[N - 1] =
|
|
__hfma(__high2half(weight_decay), param_half[N - 1], g_half[N - 1]);
|
|
const half adjusted_gradient_half = __hfma(
|
|
__high2half(LR),
|
|
ng_half[N - 1],
|
|
__hmul(__high2half(momentum), m_half[N - 1]));
|
|
nm_half[N - 1] = adjusted_gradient_half;
|
|
ng_half[N - 1] = adjusted_gradient_half;
|
|
if (param) {
|
|
param_half[N - 1] = __hsub(param_half[N - 1], adjusted_gradient_half);
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
CUDA_1D_KERNEL_LOOP(i, n) {
|
|
// computing the term (grad + lambda*weight)
|
|
// might need to change in case of denormalization
|
|
ng[i] = __hfma2(weight_decay, param[i], g[i]);
|
|
const half2 mi = m[i];
|
|
const half2 mom_mi = __hmul2(momentum, mi);
|
|
const half2 mi_new = __hfma2(LR, ng[i], mom_mi);
|
|
nm[i] = mi_new;
|
|
ng[i] = __hsub2(__hfma2(mi_new, momentum, mi_new), mom_mi);
|
|
|
|
if (param) {
|
|
param[i] = __hsub2(param[i], ng[i]);
|
|
}
|
|
|
|
// odd number of elements
|
|
if (i == 0 && (N % 2)) {
|
|
half *g_half = (half*)g, *param_half = (half*)param, *m_half = (half*)m,
|
|
*nm_half = (half*)nm, *ng_half = (half*)ng;
|
|
ng_half[N - 1] =
|
|
__hfma(__high2half(weight_decay), param_half[N - 1], g_half[N - 1]);
|
|
const half mi_half = m_half[N - 1];
|
|
const half mom_mi_half = __hmul(__high2half(momentum), mi_half);
|
|
const half mi_new_half =
|
|
__hfma(__high2half(LR), ng_half[N - 1], mom_mi_half);
|
|
nm_half[N - 1] = mi_new_half;
|
|
ng_half[N - 1] = __hsub(
|
|
__hfma(mi_new_half, __high2half(momentum), mi_new_half),
|
|
mom_mi_half);
|
|
if (param) {
|
|
param_half[N - 1] = __hsub(param_half[N - 1], ng_half[N - 1]);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#else
|
|
CUDA_KERNEL_ASSERT(false);
|
|
#endif // CAFFE_HAS_CUDA_FP16
|
|
}
|
|
|
|
__global__ void FP16MomentumSGDFP32Kernel(
|
|
int N,
|
|
const half2* g,
|
|
const half2* m,
|
|
half2* ng,
|
|
half2* nm,
|
|
const float* lr,
|
|
const float mom,
|
|
bool nesterov,
|
|
const float wd,
|
|
half2* param) {
|
|
#if __CUDA_ARCH__ >= 530 || defined(USE_ROCM)
|
|
const float lr2 = lr[0];
|
|
const float LR = lr2;
|
|
const float momentum = mom;
|
|
const float weight_decay = wd;
|
|
|
|
int n = N / 2;
|
|
if (!nesterov) {
|
|
CUDA_1D_KERNEL_LOOP(i, n) {
|
|
float2 param_float2 = __half22float2(param[i]);
|
|
float2 g_float2 = __half22float2(g[i]);
|
|
|
|
float2 ng_float2;
|
|
ng_float2.x = __fmaf_rn(weight_decay, param_float2.x, g_float2.x);
|
|
ng_float2.y = __fmaf_rn(weight_decay, param_float2.y, g_float2.y);
|
|
|
|
float2 m_float2 = __half22float2(m[i]);
|
|
float2 adjusted_gradient_float2;
|
|
adjusted_gradient_float2.x =
|
|
__fmaf_rn(LR, ng_float2.x, __fmul_rn(momentum, m_float2.x));
|
|
adjusted_gradient_float2.y =
|
|
__fmaf_rn(LR, ng_float2.y, __fmul_rn(momentum, m_float2.y));
|
|
|
|
nm[i] = __float22half2_rn(adjusted_gradient_float2);
|
|
ng[i] = __float22half2_rn(adjusted_gradient_float2);
|
|
|
|
if (param) {
|
|
param_float2.x = __fsub_rn(param_float2.x, adjusted_gradient_float2.x);
|
|
param_float2.y = __fsub_rn(param_float2.y, adjusted_gradient_float2.y);
|
|
param[i] = __float22half2_rn(param_float2);
|
|
}
|
|
}
|
|
} else {
|
|
CUDA_1D_KERNEL_LOOP(i, n) {
|
|
// computing the term (grad + lambda*weight)
|
|
// might need to change in case of denormalization
|
|
|
|
float2 param_float2 = __half22float2(param[i]);
|
|
float2 g_float2 = __half22float2(g[i]);
|
|
|
|
float2 ng_float2;
|
|
ng_float2.x = __fmaf_rn(weight_decay, param_float2.x, g_float2.x);
|
|
ng_float2.y = __fmaf_rn(weight_decay, param_float2.y, g_float2.y);
|
|
|
|
const float2 mi_float2 = __half22float2(m[i]);
|
|
float2 mom_mi_float2;
|
|
mom_mi_float2.x = __fmul_rn(momentum, mi_float2.x);
|
|
mom_mi_float2.y = __fmul_rn(momentum, mi_float2.y);
|
|
float2 mi_new_float2;
|
|
mi_new_float2.x = __fmaf_rn(LR, ng_float2.x, mom_mi_float2.x);
|
|
mi_new_float2.y = __fmaf_rn(LR, ng_float2.y, mom_mi_float2.y);
|
|
|
|
nm[i] = __float22half2_rn(mi_new_float2);
|
|
ng_float2.x = __fsub_rn(
|
|
__fmaf_rn(mi_new_float2.x, momentum, mi_new_float2.x),
|
|
mom_mi_float2.x);
|
|
ng_float2.y = __fsub_rn(
|
|
__fmaf_rn(mi_new_float2.y, momentum, mi_new_float2.y),
|
|
mom_mi_float2.y);
|
|
ng[i] = __float22half2_rn(ng_float2);
|
|
|
|
if (param) {
|
|
param_float2.x = __fsub_rn(param_float2.x, ng_float2.x);
|
|
param_float2.y = __fsub_rn(param_float2.y, ng_float2.y);
|
|
param[i] = __float22half2_rn(param_float2);
|
|
}
|
|
}
|
|
}
|
|
#else
|
|
CUDA_KERNEL_ASSERT(false);
|
|
#endif // CAFFE_HAS_CUDA_FP16
|
|
}
|
|
}
|
|
|
|
template <>
|
|
void fp16_momentum_sgd_update<CUDAContext>(
|
|
int N,
|
|
const at::Half* g,
|
|
const at::Half* m,
|
|
at::Half* ng,
|
|
at::Half* nm,
|
|
const float* lr,
|
|
float momentum,
|
|
bool nesterov,
|
|
float weight_decay,
|
|
bool fp32_update,
|
|
at::Half* param,
|
|
CUDAContext* context) {
|
|
const cudaDeviceProp& prop = GetDeviceProperty(0);
|
|
if (prop.major >= kFp16CUDADevicePropMajor) {
|
|
if (!fp32_update) {
|
|
FP16MomentumSGDKernel<<<
|
|
CAFFE_GET_BLOCKS(N / 2),
|
|
CAFFE_CUDA_NUM_THREADS,
|
|
0,
|
|
context->cuda_stream()>>>(
|
|
N,
|
|
reinterpret_cast<const half2*>(g),
|
|
reinterpret_cast<const half2*>(m),
|
|
reinterpret_cast<half2*>(ng),
|
|
reinterpret_cast<half2*>(nm),
|
|
lr,
|
|
momentum,
|
|
nesterov,
|
|
weight_decay,
|
|
reinterpret_cast<half2*>(param));
|
|
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
|
// not setting N to N/2
|
|
} else {
|
|
FP16MomentumSGDFP32Kernel<<<
|
|
CAFFE_GET_BLOCKS(N / 2),
|
|
CAFFE_CUDA_NUM_THREADS,
|
|
0,
|
|
context->cuda_stream()>>>(
|
|
N,
|
|
reinterpret_cast<const half2*>(g),
|
|
reinterpret_cast<const half2*>(m),
|
|
reinterpret_cast<half2*>(ng),
|
|
reinterpret_cast<half2*>(nm),
|
|
lr,
|
|
momentum,
|
|
nesterov,
|
|
weight_decay,
|
|
reinterpret_cast<half2*>(param));
|
|
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
|
// not setting N to N/2
|
|
}
|
|
|
|
} else {
|
|
CAFFE_ENFORCE(false, "FP16MomentumSGDUpdate not supported. Major: ",
|
|
prop.major, " Minor: ", prop.minor);
|
|
}
|
|
}
|
|
|
|
REGISTER_CUDA_OPERATOR(
|
|
FP16MomentumSGDUpdate,
|
|
FP16MomentumSGDUpdateOp<at::Half, CUDAContext>);
|
|
OPERATOR_SCHEMA(FP16MomentumSGDUpdate)
|
|
.NumInputs(4)
|
|
.NumOutputs(3)
|
|
.AllowInplace({{0, 0}, {1, 1}, {3, 2}})
|
|
.TensorInferenceFunction([](const OperatorDef& /* unused */,
|
|
const vector<TensorShape>& in) {
|
|
vector<TensorShape> out(3);
|
|
out[0] = in[0];
|
|
out[1] = in[1];
|
|
out[2] = in[3];
|
|
return out;
|
|
})
|
|
.SetDoc(R"DOC(
|
|
|
|
Computes the momentum SGD update similarly to the MomentumSGDUpdateOp,
|
|
however this op also performs the weight decay update at the same time, thus
|
|
making it more efficient.
|
|
|
|
This op is also functionally equivalent to the FP32MomentumSGDUpdateOp, however
|
|
it expects FP16 data and performs its updates in either FP16 precision
|
|
(default), or FP32 precision if the 'fp32_update' flag is set to True.
|
|
|
|
)DOC");
|
|
}
|