Mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
add Half support for layer_norm on CPU (#99590)
### Testing

Single socket (icx, 32 cores):

| shape | fp32 forward (ms) | fp16 forward (ms) | mixed fp32 fp16 forward (ms) | fp32 backward (ms) | fp16 backward (ms) | mixed fp32 fp16 backward (ms) |
| -- | -- | -- | -- | -- | -- | -- |
| (1, 8, 16) | 0.012 | 0.011 | 0.011 | 0.051 | 0.051 | 0.050 |
| (8, 8, 16) | 0.013 | 0.013 | 0.013 | 0.054 | 0.053 | 0.051 |
| (32, 8, 16) | 0.015 | 0.014 | 0.014 | 0.059 | 0.054 | 0.052 |
| (64, 128, 56, 56) | 1.875 | 0.790 | 1.016 | 12.845 | 7.151 | 6.985 |
| (64, 128, 256, 256) | 50.226 | 25.462 | 35.736 | 328.957 | 179.615 | 175.618 |

Single core (icx):

| shape | fp32 forward (ms) | fp16 forward (ms) | mixed fp32 fp16 forward (ms) | fp32 backward (ms) | fp16 backward (ms) | mixed fp32 fp16 backward (ms) |
| -- | -- | -- | -- | -- | -- | -- |
| (1, 8, 16) | 0.012 | 0.011 | 0.011 | 0.040 | 0.041 | 0.041 |
| (8, 8, 16) | 0.012 | 0.012 | 0.012 | 0.042 | 0.042 | 0.042 |
| (32, 8, 16) | 0.027 | 0.014 | 0.014 | 0.048 | 0.048 | 0.046 |
| (64, 128, 56, 56) | 58.054 | 11.034 | 17.928 | 108.603 | 48.816 | 50.244 |
| (64, 128, 256, 256) | 1327.758 | 352.394 | 496.994 | 2846.182 | 1224.247 | 1218.422 |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99590
Approved by: https://github.com/mingfeima, https://github.com/jgong5, https://github.com/cpuhrsch
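The timings above come from the PR author's own runs; the script is not included in the PR, but a minimal sketch of how such numbers can be gathered (the shape list, iteration count, and use of `torch.utils.benchmark` are assumptions of this sketch) looks like:

```python
# Minimal timing sketch (assumed setup, not the script used for the PR):
# time layer_norm forward/backward on CPU for fp32, fp16, and the mixed case.
import torch
import torch.nn.functional as F
from torch.utils import benchmark

def run(shape, x_dtype, w_dtype):
    normalized_shape = shape[1:]
    x = torch.randn(shape, dtype=x_dtype, requires_grad=True)
    w = torch.ones(normalized_shape, dtype=w_dtype, requires_grad=True)
    b = torch.zeros(normalized_shape, dtype=w_dtype, requires_grad=True)
    y = F.layer_norm(x, normalized_shape, w, b)
    fwd = benchmark.Timer(
        stmt="F.layer_norm(x, ns, w, b)",
        globals={"F": F, "x": x, "ns": normalized_shape, "w": w, "b": b},
    ).timeit(50)
    bwd = benchmark.Timer(
        stmt="y.backward(g, retain_graph=True)",
        globals={"y": y, "g": torch.ones_like(y)},
    ).timeit(50)
    print(shape, x_dtype, w_dtype,
          f"fwd {fwd.mean * 1e3:.3f} ms", f"bwd {bwd.mean * 1e3:.3f} ms")

for shape in [(1, 8, 16), (64, 128, 56, 56)]:
    run(shape, torch.float32, torch.float32)   # fp32
    run(shape, torch.half, torch.half)         # fp16
    run(shape, torch.half, torch.float32)      # mixed: fp16 input, fp32 params
```

Pinning threads (for example `OMP_NUM_THREADS=1` versus 32) reproduces the single-core versus single-socket split in the tables.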
Committed by: PyTorch MergeBot
Parent: 45cfe9cdf7
Commit: c173a9d9b3
@@ -69,7 +69,7 @@ inline Vectorized<Half> convert_from_float<Half>(const Vectorized<float>& a, con
 //
 template <typename scalar_t, typename Op,
           typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
-inline scalar_t reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) {
+inline float reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) {
   using bVec = vec::Vectorized<scalar_t>;
   using fVec = vec::Vectorized<float>;
   if (size < bVec::size()) {
@@ -111,7 +111,7 @@ inline scalar_t reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size

 template <typename scalar_t, typename Op1, typename Op2,
           typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
-inline std::pair<scalar_t, scalar_t> reduce2_all(const Op1& vec_fun1, const Op2& vec_fun2,
+inline std::pair<float, float> reduce2_all(const Op1& vec_fun1, const Op2& vec_fun2,
     const scalar_t* data, int64_t size) {
   using bVec = vec::Vectorized<scalar_t>;
   using fVec = vec::Vectorized<float>;
@@ -169,7 +169,7 @@ inline std::pair<scalar_t, scalar_t> reduce2_all(const Op1& vec_fun1, const Op2&

 template <typename scalar_t, typename MapOp, typename ReduceOp,
           typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
-inline scalar_t map_reduce_all(
+inline float map_reduce_all(
     const MapOp& map_fun,
     const ReduceOp& red_fun,
     const scalar_t* data,
@@ -225,7 +225,7 @@ inline scalar_t map_reduce_all(

 template <typename scalar_t, typename MapOp, typename ReduceOp,
           typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
-inline scalar_t map2_reduce_all(
+inline float map2_reduce_all(
     const MapOp& map_fun,
     const ReduceOp& red_fun,
     const scalar_t* data,
@@ -294,7 +294,7 @@ inline scalar_t map2_reduce_all(

 template <typename scalar_t, typename MapOp, typename ReduceOp,
           typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
-inline scalar_t map3_reduce_all(
+inline float map3_reduce_all(
     const MapOp& map_fun,
     const ReduceOp& red_fun,
     const scalar_t* data,
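These signature changes make the reduced-precision reductions (`reduce_all`, `reduce2_all`, `map*_reduce_all`) hand back their float accumulator directly instead of rounding it to bfloat16/half first. A standalone illustration of why the accumulator width matters (plain Python/numpy, unrelated to the ATen code itself):

```python
# Illustration only: a half-precision accumulator rounds away small
# contributions once the running sum is large; a float accumulator does not.
import numpy as np

total_h = np.float16(0)
for _ in range(100_000):
    total_h = np.float16(total_h + np.float16(1))   # fp16 running sum stalls at 2048
total_f = np.float32(0)
for _ in range(100_000):
    total_f += np.float32(1)                        # fp32 running sum stays exact here
print(total_h, total_f)  # 2048.0 100000.0
```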
@@ -23,14 +23,15 @@ namespace at::native {

 namespace {

-template <typename T, typename T_ACC>
+template <typename T,
+          typename std::enable_if_t<!is_reduced_floating_point_v<T>, int> = 0>
 void LayerNormKernelImplInternal(
     const Tensor& X,
     const Tensor& gamma,
     const Tensor& beta,
     int64_t M,
     int64_t N,
-    T_ACC eps,
+    T eps,
     Tensor* Y,
     Tensor* mean,
     Tensor* rstd) {
@@ -83,7 +84,8 @@ void LayerNormKernelImplInternal(
   });
 }

-template <typename param_t>
+template <typename T, typename param_t,
+          typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
 void layer_norm_kernel_mixed_type(
     const Tensor& X,
     const Tensor& gamma,
@@ -94,12 +96,12 @@ void layer_norm_kernel_mixed_type(
     Tensor* Y,
     Tensor* mean,
     Tensor* rstd) {
-  using bVec = Vectorized<BFloat16>;
+  using bVec = Vectorized<T>;
   using fVec = Vectorized<float>;
-  const BFloat16* X_data = X.data_ptr<BFloat16>();
+  const T* X_data = X.data_ptr<T>();
   const param_t* gamma_data = gamma.defined() ? gamma.data_ptr<param_t>() : nullptr;
   const param_t* beta_data = beta.defined() ? beta.data_ptr<param_t>() : nullptr;
-  BFloat16* Y_data = Y->data_ptr<BFloat16>();
+  T* Y_data = Y->data_ptr<T>();
   param_t* mean_data = mean ? mean->data_ptr<param_t>() : nullptr;
   param_t* rstd_data = rstd ? rstd->data_ptr<param_t>() : nullptr;

@@ -109,38 +111,29 @@ void layer_norm_kernel_mixed_type(
   const bool rstd_null = rstd_data == nullptr;
   at::parallel_for(0, M, 1, [&](int64_t start, int64_t end) {
     for (const auto i : c10::irange(start, end)) {
-      const BFloat16* X_ptr = X_data + i * N;
-      BFloat16* Y_ptr = Y_data + i * N;
+      const T* X_ptr = X_data + i * N;
+      T* Y_ptr = Y_data + i * N;
       float mean_val;
       float rstd_val;
       std::tie(mean_val, rstd_val) = RowwiseMoments(X_ptr, N);
       rstd_val = float(1) / std::sqrt(rstd_val + eps);
       const float scale = rstd_val;
       const float bias = -rstd_val * mean_val;
-      if (gamma_null || beta_null) {
-        for (const auto j : c10::irange(N)) {
-          const param_t gamma_v = gamma_null ? param_t(1) : gamma_data[j];
-          const param_t beta_v = beta_null ? param_t(0) : beta_data[j];
-          Y_ptr[j] = (X_ptr[j] * scale + bias) * gamma_v + beta_v;
-        }
-      } else {
-        int64_t d = 0;
-        for (; d < N - (N % bVec::size()); d += bVec::size()) {
-          bVec x_bvec = bVec::loadu(X_ptr + d);
-          fVec x_fvec0, x_fvec1;
-          std::tie(x_fvec0, x_fvec1) = convert_bfloat16_float(x_bvec);
-          fVec gamma_fvec0, gamma_fvec1;
-          std::tie(gamma_fvec0, gamma_fvec1) = load2f(gamma_data + d);
-          fVec beta_fvec0, beta_fvec1;
-          std::tie(beta_fvec0, beta_fvec1) = load2f(beta_data + d);
-          fVec y_fvec0 = (x_fvec0 * fVec(scale) + fVec(bias)) * gamma_fvec0 + beta_fvec0;
-          fVec y_fvec1 = (x_fvec1 * fVec(scale) + fVec(bias)) * gamma_fvec1 + beta_fvec1;
-          bVec y_bvec = convert_float_bfloat16(y_fvec0, y_fvec1);
-          y_bvec.store(Y_ptr + d);
-        }
-        for (; d < N; d++) {
-          Y_ptr[d] = (X_ptr[d] * scale + bias) * gamma_data[d] + beta_data[d];
-        }
+      int64_t d = 0;
+      for (; d < N - (N % bVec::size()); d += bVec::size()) {
+        bVec x_bvec = bVec::loadu(X_ptr + d);
+        auto [x_fvec0, x_fvec1] = convert_to_float<T>(x_bvec);
+        auto [gamma_fvec0, gamma_fvec1] = gamma_null ? std::make_tuple(fVec(1), fVec(1)) : load2f(gamma_data + d);
+        auto [beta_fvec0, beta_fvec1] = beta_null ? std::make_tuple(fVec(0), fVec(0)) : load2f(beta_data + d);
+        fVec y_fvec0 = (x_fvec0 * fVec(scale) + fVec(bias)) * gamma_fvec0 + beta_fvec0;
+        fVec y_fvec1 = (x_fvec1 * fVec(scale) + fVec(bias)) * gamma_fvec1 + beta_fvec1;
+        bVec y_bvec = convert_from_float<T>(y_fvec0, y_fvec1);
+        y_bvec.store(Y_ptr + d);
+      }
+      for (; d < N; d++) {
+        const float gamma_v = gamma_null ? float(1) : float(gamma_data[d]);
+        const float beta_v = beta_null ? float(0) : float(beta_data[d]);
+        Y_ptr[d] = (float(X_ptr[d]) * scale + bias) * gamma_v + beta_v;
       }
       if (!mean_null) {
         mean_data[i] = mean_val;
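For intuition, the per-row math performed by the rewritten loop above (widen the reduced-precision input to float, normalize, apply the affine parameters, narrow the result back) can be sketched in a few lines of Python; this is a reference sketch for one row, not the kernel itself:

```python
# Reference sketch of what the vectorized loop computes per row
# (assumed helper for intuition; the real work is the C++ loop above).
import torch

def layer_norm_row_mixed(x_row_half, gamma_fp32, beta_fp32, eps):
    x = x_row_half.float()                       # widen input, like convert_to_float<T>
    mean = x.mean()
    rstd = (x.var(unbiased=False) + eps).rsqrt()
    scale, bias = rstd, -rstd * mean             # same scale/bias form as the kernel
    y = (x * scale + bias) * gamma_fp32 + beta_fp32
    return y.half(), mean, rstd                  # narrow the output, keep stats in float
```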
@@ -152,8 +145,9 @@ void layer_norm_kernel_mixed_type(
   });
 }

-template <>
-void LayerNormKernelImplInternal<BFloat16, float>(
+template <typename T,
+          typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+void LayerNormKernelImplInternal(
     const Tensor& X,
     const Tensor& gamma,
     const Tensor& beta,
@@ -165,9 +159,9 @@ void LayerNormKernelImplInternal<BFloat16, float>(
     Tensor* rstd) {
   const bool mixed_type = is_mixed_type(X, gamma, beta);
   if (mixed_type) {
-    layer_norm_kernel_mixed_type<float>(X, gamma, beta, M, N, eps, Y, mean, rstd);
+    layer_norm_kernel_mixed_type<T, float>(X, gamma, beta, M, N, eps, Y, mean, rstd);
   } else {
-    layer_norm_kernel_mixed_type<BFloat16>(X, gamma, beta, M, N, eps, Y, mean, rstd);
+    layer_norm_kernel_mixed_type<T, T>(X, gamma, beta, M, N, eps, Y, mean, rstd);
   }
 }

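In user-facing terms, the `is_mixed_type` branch above corresponds to fp32 affine parameters paired with a reduced-precision input, while the other branch keeps the parameters in the same reduced dtype. A usage sketch of both call patterns on CPU (shapes are arbitrary):

```python
# Both parameter layouts the dispatch above covers (usage sketch, shapes arbitrary).
import torch
import torch.nn as nn

x = torch.randn(8, 64, dtype=torch.half)         # fp16 input on CPU

m_mixed = nn.LayerNorm(64)                       # fp32 gamma/beta -> mixed-type branch
m_same = nn.LayerNorm(64).to(torch.half)         # fp16 gamma/beta -> same-type branch

y_mixed = m_mixed(x)   # output stays fp16, parameters stay fp32
y_same = m_same(x)     # everything stored in fp16, accumulated in float internally
```

Keeping the parameters in fp32 while only the activations are stored in half is the common low-precision recipe the mixed-type branch serves.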
@@ -184,15 +178,14 @@ void LayerNormKernelImpl(
   TORCH_DCHECK_EQ(X.numel(), M * N);
   DCHECK(!gamma.defined() || gamma.numel() == N);
   DCHECK(!beta.defined() || beta.numel() == N);
-  AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, X.scalar_type(),
+  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, X.scalar_type(),
       "LayerNormKernelImpl", [&]() {
-    using acc_t = at::opmath_type<scalar_t>;
-    LayerNormKernelImplInternal<scalar_t, acc_t>(
-        X, gamma, beta, M, N, static_cast<acc_t>(eps), Y, mean, rstd);
+    LayerNormKernelImplInternal<scalar_t>(
+        X, gamma, beta, M, N, eps, Y, mean, rstd);
   });
 }

-template <typename T, typename T2, typename T_ACC>
+template <typename T, typename T2, typename opmath_t>
 void layer_norm_backward_frame(
     const T* dY_data,
     const T* X_data,
@@ -202,19 +195,19 @@ void layer_norm_backward_frame(
     T* dX_data,
     T* dgamma_buffer_ptr,
     T* dbeta_buffer_ptr,
-    const T_ACC scale,
+    const opmath_t scale,
     const bool gamma_null,
     const bool dX_null,
     const bool dgamma_null,
     const bool dbeta_null,
     int64_t N,
     int64_t i) {
-  using Vec = vec::Vectorized<T_ACC>;
+  using Vec = vec::Vectorized<opmath_t>;
   const T* dY_ptr = dY_data + i * N;
   const T* X_ptr = X_data + i * N;
   if (!dgamma_null) {
-    const T_ACC a = rstd_data[i];
-    const T_ACC b = -a * mean_data[i];
+    const opmath_t a = rstd_data[i];
+    const opmath_t b = -a * mean_data[i];
     // Scalar math:
     // for (const auto j : c10::irange(N)) {
     //   dgamma_data[j] += dY_ptr[j] * (a * X_ptr[j] + b);
@@ -243,8 +236,8 @@ void layer_norm_backward_frame(
   }
   if (!dX_null) {
     T* dX_ptr = dX_data + i * N;
-    T_ACC ds = T_ACC(0);
-    T_ACC db = T_ACC(0);
+    opmath_t ds = opmath_t(0);
+    opmath_t db = opmath_t(0);
     // Scalar math:
     // for (const auto j : c10::irange(N)) {
     //   const T gamma_v = gamma_null ? T(1) : gamma_data[j];
@@ -275,9 +268,9 @@ void layer_norm_backward_frame(
         gamma_data,
         N);
     }
-    const T_ACC a = rstd_data[i];
-    const T_ACC b = (db * mean_data[i] - ds) * a * a * a * scale;
-    const T_ACC c = -b * mean_data[i] - db * a * scale;
+    const opmath_t a = rstd_data[i];
+    const opmath_t b = (db * opmath_t(mean_data[i]) - ds) * a * a * a * scale;
+    const opmath_t c = -b * opmath_t(mean_data[i]) - db * a * scale;
     // Scalar math:
     // for (const auto j : c10::irange(N)) {
     //   const T gamma_v = gamma_null ? T(1) : gamma_data[j];
@@ -306,16 +299,17 @@ void layer_norm_backward_frame(
   }
 }

-template <>
-void layer_norm_backward_frame<BFloat16, float, float>(
-    const BFloat16* dY_data,
-    const BFloat16* X_data,
+template <typename T, typename T2, typename opmath_t,
+          typename std::enable_if_t<is_reduced_floating_point_v<T> && std::is_same<T2, float>::value, int> = 0>
+void layer_norm_backward_frame(
+    const T* dY_data,
+    const T* X_data,
     const float* mean_data,
     const float* rstd_data,
     const float* gamma_data,
-    BFloat16* dX_data,
-    BFloat16* dgamma_buffer_ptr,
-    BFloat16* dbeta_buffer_ptr,
+    T* dX_data,
+    T* dgamma_buffer_ptr,
+    T* dbeta_buffer_ptr,
     const float scale,
     const bool gamma_null,
     const bool dX_null,
@@ -323,10 +317,10 @@ void layer_norm_backward_frame<BFloat16, float, float>(
     const bool dbeta_null,
     int64_t N,
     int64_t i) {
-  using bVec = Vectorized<BFloat16>;
+  using bVec = Vectorized<T>;
   using fVec = Vectorized<float>;
-  const BFloat16* dY_ptr = dY_data + i * N;
-  const BFloat16* X_ptr = X_data + i * N;
+  const T* dY_ptr = dY_data + i * N;
+  const T* X_ptr = X_data + i * N;
   if (!dgamma_null) {
     const float a = rstd_data[i];
     const float b = -a * mean_data[i];
@@ -334,7 +328,7 @@ void layer_norm_backward_frame<BFloat16, float, float>(
     // for (const auto j : c10::irange(N)) {
     //   dgamma_data[j] += dY_ptr[j] * (a * X_ptr[j] + b);
     // }
-    vec::map3<BFloat16>(
+    vec::map3<T>(
         [a, b](fVec dgamma, fVec dy, fVec x) {
           return dgamma + dy * (fVec(a) * x + fVec(b));
         },
@@ -349,7 +343,7 @@ void layer_norm_backward_frame<BFloat16, float, float>(
     // for (const auto j : c10::irange(N)) {
    //   dbeta_data[j] += dY_ptr[j];
     // }
-    vec::map2<BFloat16>(
+    vec::map2<T>(
         [](fVec dbeta, fVec dy) { return dbeta + dy; },
         dbeta_buffer_ptr,
         dbeta_buffer_ptr,
@@ -357,7 +351,7 @@ void layer_norm_backward_frame<BFloat16, float, float>(
         N);
   }
   if (!dX_null) {
-    BFloat16* dX_ptr = dX_data + i * N;
+    T* dX_ptr = dX_data + i * N;
     float ds = float(0);
     float db = float(0);
     // Scalar math:
@@ -367,21 +361,21 @@ void layer_norm_backward_frame<BFloat16, float, float>(
    //   db += dY_ptr[j] * gamma_v;
     // }
     if (gamma_null) {
-      ds = vec::map2_reduce_all<BFloat16>(
+      ds = vec::map2_reduce_all<T>(
           [](fVec x, fVec y) { return x * y; },
          [](fVec x, fVec y) { return x + y; },
          dY_ptr,
          X_ptr,
          N);
-      db = vec::reduce_all<BFloat16>(
+      db = vec::reduce_all<T>(
          [](fVec& x, fVec& y) { return x + y; }, dY_ptr, N);
     } else {
       if (N < bVec::size()) {
         bVec x_bvec = bVec::loadu(X_ptr, N);
         bVec dy_bvec = bVec::loadu(dY_ptr, N);
         fVec x_fvec0, x_fvec1, dy_fvec0, dy_fvec1, gamma_fvec0, gamma_fvec1;
-        std::tie(x_fvec0, x_fvec1) = convert_bfloat16_float(x_bvec);
-        std::tie(dy_fvec0, dy_fvec1) = convert_bfloat16_float(dy_bvec);
+        std::tie(x_fvec0, x_fvec1) = convert_to_float<T>(x_bvec);
+        std::tie(dy_fvec0, dy_fvec1) = convert_to_float<T>(dy_bvec);
         std::tie(gamma_fvec0, gamma_fvec1) = load2f(gamma_data, N);
         if (N > fVec::size()) {
           fVec db_fvec0 = dy_fvec0 * gamma_fvec0;
@@ -404,8 +398,8 @@ void layer_norm_backward_frame<BFloat16, float, float>(
         bVec dy_bvec = bVec::loadu(dY_ptr);
         fVec x_fvec0, x_fvec1, dy_fvec0, dy_fvec1, gamma_fvec0, gamma_fvec1;
         fVec ds_fvec0, ds_fvec1, db_fvec0, db_fvec1, acc_ds_fvec0, acc_ds_fvec1, acc_db_fvec0, acc_db_fvec1;
-        std::tie(x_fvec0, x_fvec1) = convert_bfloat16_float(x_bvec);
-        std::tie(dy_fvec0, dy_fvec1) = convert_bfloat16_float(dy_bvec);
+        std::tie(x_fvec0, x_fvec1) = convert_to_float<T>(x_bvec);
+        std::tie(dy_fvec0, dy_fvec1) = convert_to_float<T>(dy_bvec);
         std::tie(gamma_fvec0, gamma_fvec1) = load2f(gamma_data);
         acc_db_fvec0 = dy_fvec0 * gamma_fvec0;
         acc_db_fvec1 = dy_fvec1 * gamma_fvec1;
@@ -414,8 +408,8 @@ void layer_norm_backward_frame<BFloat16, float, float>(
        for (; d < N - (N % bVec::size()); d += bVec::size()) {
          x_bvec = bVec::loadu(X_ptr + d);
          dy_bvec = bVec::loadu(dY_ptr + d);
-          std::tie(x_fvec0, x_fvec1) = convert_bfloat16_float(x_bvec);
-          std::tie(dy_fvec0, dy_fvec1) = convert_bfloat16_float(dy_bvec);
+          std::tie(x_fvec0, x_fvec1) = convert_to_float<T>(x_bvec);
+          std::tie(dy_fvec0, dy_fvec1) = convert_to_float<T>(dy_bvec);
          std::tie(gamma_fvec0, gamma_fvec1) = load2f(gamma_data + d);
          db_fvec0 = dy_fvec0 * gamma_fvec0;
          db_fvec1 = dy_fvec1 * gamma_fvec1;
@@ -429,8 +423,8 @@ void layer_norm_backward_frame<BFloat16, float, float>(
        if (N - d > 0) {
          x_bvec = bVec::loadu(X_ptr + d, N - d);
          dy_bvec = bVec::loadu(dY_ptr + d, N - d);
-          std::tie(x_fvec0, x_fvec1) = convert_bfloat16_float(x_bvec);
-          std::tie(dy_fvec0, dy_fvec1) = convert_bfloat16_float(dy_bvec);
+          std::tie(x_fvec0, x_fvec1) = convert_to_float<T>(x_bvec);
+          std::tie(dy_fvec0, dy_fvec1) = convert_to_float<T>(dy_bvec);
          std::tie(gamma_fvec0, gamma_fvec1) = load2f(gamma_data + d, N - d);
          if (N - d > fVec::size()) {
            db_fvec0 = dy_fvec0 * gamma_fvec0;
@@ -463,7 +457,7 @@ void layer_norm_backward_frame<BFloat16, float, float>(
     //   dX_ptr[j] = a * dY_ptr[j] * gamma_v + b * X_ptr[j] + c;
     // }
     if (gamma_null) {
-      vec::map2<BFloat16>(
+      vec::map2<T>(
          [a, b, c](fVec dy, fVec x) {
            return fVec(a) * dy + fVec(b) * x + fVec(c);
          },
@@ -477,24 +471,24 @@ void layer_norm_backward_frame<BFloat16, float, float>(
      bVec x_bvec = bVec::loadu(X_ptr + d);
      bVec dy_bvec = bVec::loadu(dY_ptr + d);
      fVec x_fvec0, x_fvec1, dy_fvec0, dy_fvec1, gamma_fvec0, gamma_fvec1;
-      std::tie(x_fvec0, x_fvec1) = convert_bfloat16_float(x_bvec);
-      std::tie(dy_fvec0, dy_fvec1) = convert_bfloat16_float(dy_bvec);
+      std::tie(x_fvec0, x_fvec1) = convert_to_float<T>(x_bvec);
+      std::tie(dy_fvec0, dy_fvec1) = convert_to_float<T>(dy_bvec);
      std::tie(gamma_fvec0, gamma_fvec1) = load2f(gamma_data + d);
      fVec r_fvec0 = fVec(a) * dy_fvec0 * gamma_fvec0 + fVec(b) * x_fvec0 + fVec(c);
      fVec r_fvec1 = fVec(a) * dy_fvec1 * gamma_fvec1 + fVec(b) * x_fvec1 + fVec(c);
-      bVec r_bvec = convert_float_bfloat16(r_fvec0, r_fvec1);
+      bVec r_bvec = convert_from_float<T>(r_fvec0, r_fvec1);
      r_bvec.store(dX_ptr + d);
    }
    if (N - d > 0) {
      bVec x_bvec = bVec::loadu(X_ptr + d, N - d);
      bVec dy_bvec = bVec::loadu(dY_ptr + d, N - d);
      fVec x_fvec0, x_fvec1, dy_fvec0, dy_fvec1, gamma_fvec0, gamma_fvec1;
-      std::tie(x_fvec0, x_fvec1) = convert_bfloat16_float(x_bvec);
-      std::tie(dy_fvec0, dy_fvec1) = convert_bfloat16_float(dy_bvec);
+      std::tie(x_fvec0, x_fvec1) = convert_to_float<T>(x_bvec);
+      std::tie(dy_fvec0, dy_fvec1) = convert_to_float<T>(dy_bvec);
      std::tie(gamma_fvec0, gamma_fvec1) = load2f(gamma_data + d, N - d);
      fVec r_fvec0 = fVec(a) * dy_fvec0 * gamma_fvec0 + fVec(b) * x_fvec0 + fVec(c);
      fVec r_fvec1 = fVec(a) * dy_fvec1 * gamma_fvec1 + fVec(b) * x_fvec1 + fVec(c);
-      bVec r_bvec = convert_float_bfloat16(r_fvec0, r_fvec1);
+      bVec r_bvec = convert_from_float<T>(r_fvec0, r_fvec1);
      r_bvec.store(dX_ptr + d, N - d);
    }
  }
@@ -513,7 +507,7 @@ void LayerNormBackwardKernelImplInternal(
     Tensor* dX,
     Tensor* dgamma,
     Tensor* dbeta) {
-  using T_ACC = at::opmath_type<T>;
+  using opmath_t = at::opmath_type<T>;
   TORCH_DCHECK_EQ(dY.numel(), M * N);
   TORCH_DCHECK_EQ(X.numel(), M * N);
   TORCH_DCHECK_EQ(mean.numel(), M);
@@ -528,7 +522,7 @@ void LayerNormBackwardKernelImplInternal(
   T* dX_data = dX->defined() ? dX->template data_ptr<T>() : nullptr;
   T2* dgamma_data = dgamma->defined() ? dgamma->template data_ptr<T2>() : nullptr;
   T2* dbeta_data = dbeta->defined() ? dbeta->template data_ptr<T2>() : nullptr;
-  const T_ACC scale = T_ACC(1) / static_cast<T_ACC>(N);
+  const opmath_t scale = opmath_t(1) / static_cast<opmath_t>(N);
   const bool gamma_null = gamma_data == nullptr;
   const bool dX_null = dX_data == nullptr;
   const bool dgamma_null = dgamma_data == nullptr;
@@ -565,7 +559,7 @@ void LayerNormBackwardKernelImplInternal(
       T* dbeta_buffer_ptr =
           dbeta_null ? nullptr : buffer_data + num_threads * N + tid * N;
       for (const auto i : c10::irange(start, end)) {
-        layer_norm_backward_frame<T, T2, T_ACC>(dY_data, X_data, mean_data, rstd_data, gamma_data, dX_data, dgamma_buffer_ptr, dbeta_buffer_ptr, scale, gamma_null, dX_null, dgamma_null, dbeta_null, N, i);
+        layer_norm_backward_frame<T, T2, opmath_t>(dY_data, X_data, mean_data, rstd_data, gamma_data, dX_data, dgamma_buffer_ptr, dbeta_buffer_ptr, scale, gamma_null, dX_null, dgamma_null, dbeta_null, N, i);
       }
     });

@@ -573,8 +567,8 @@ void LayerNormBackwardKernelImplInternal(
   if (buffer_data != nullptr) {
     parallel_for(0, N, 1, [&](int64_t start, int64_t end) {
       for (const auto j : c10::irange(start, end)) {
-        T_ACC dgamma_v = T_ACC(0);
-        T_ACC dbeta_v = T_ACC(0);
+        opmath_t dgamma_v = opmath_t(0);
+        opmath_t dbeta_v = opmath_t(0);
        for (const auto i : c10::irange(num_threads)) {
          dgamma_v += buffer_data[i * N + j];
          dbeta_v += buffer_data[num_threads * N + i * N + j];
@@ -603,16 +597,22 @@ void LayerNormBackwardKernelImpl(
     Tensor* dX,
     Tensor* dgamma,
     Tensor* dbeta) {
-  AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, X.scalar_type(),
-      "LayerNormBackwardKernelImpl", [&]() {
-    if (X.scalar_type() == at::kBFloat16 && gamma.scalar_type() == at::kFloat) {
-      LayerNormBackwardKernelImplInternal<BFloat16, float>(
-          dY.contiguous(), X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
-    } else {
-      LayerNormBackwardKernelImplInternal<scalar_t, scalar_t>(
-          dY.contiguous(), X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
-    }
-  });
+  if (at::isReducedFloatingType(X.scalar_type())) {
+    AT_DISPATCH_REDUCED_FLOATING_TYPES(X.scalar_type(), "LayerNormBackwardKernelImpl", [&]() {
+      if (gamma.scalar_type() == at::kFloat) {
+        LayerNormBackwardKernelImplInternal<scalar_t, float>(
+            dY.contiguous(), X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
+      } else {
+        LayerNormBackwardKernelImplInternal<scalar_t, scalar_t>(
+            dY.contiguous(), X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
+      }
+    });
+  } else {
+    AT_DISPATCH_FLOATING_TYPES(X.scalar_type(), "LayerNormBackwardKernelImpl", [&]() {
+      LayerNormBackwardKernelImplInternal<scalar_t, scalar_t>(
+          dY.contiguous(), X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
+    });
+  }
 }

 } // namespace
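With both forward and backward now dispatched for kBFloat16 and kHalf, a quick end-to-end sanity check of the new half backward path is to compare its gradients against an fp32 run of the same data (a sketch; the shapes and tolerances are my own choices, not a test from the PR):

```python
# Sanity-check sketch for the CPU half backward path (assumed tolerances).
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x16 = torch.randn(4, 32, dtype=torch.half, requires_grad=True)
x32 = x16.detach().float().requires_grad_()
grad_out = torch.randn(4, 32)

F.layer_norm(x16, (32,)).backward(grad_out.half())
F.layer_norm(x32, (32,)).backward(grad_out)

# fp16 gradients should track the fp32 reference to roughly half precision
torch.testing.assert_close(x16.grad.float(), x32.grad, atol=1e-2, rtol=1e-2)
```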
@@ -76,7 +76,7 @@ UpdateMomentsVec(
   AddMomentsVec(m0, m1_vec, m2_vec, m0_stk0, m1_stk0, m2_stk0);
 }

-// each bfloat16 vector will be converted to two float vectors,
+// each bfloat16/half vector will be converted to two float vectors,
 // and accumulated successively on m1_stk0/m2_stk0.
 template <typename T>
 inline typename std::enable_if<!std::is_same<T, at::opmath_type<T>>::value, void>::type
@@ -710,7 +710,7 @@ meta_function_device_skips = defaultdict(dict)
 meta_function_device_expected_failures['cpu'] = {
     torch.native_batch_norm: {bf16, f16},
     torch._native_batch_norm_legit: {bf16, f16},
-    torch.native_layer_norm: {bf16},
+    torch.native_layer_norm: {bf16, f16},
 }

 meta_function_device_expected_failures['cuda'] = {
@@ -855,7 +855,7 @@ meta_dispatch_device_expected_failures['cpu'] = {
     aten.native_batch_norm.default: {bf16, f16},
     aten._native_batch_norm_legit.default: {bf16, f16},
     aten._native_batch_norm_legit.no_stats: {bf16, f16},
-    aten.native_layer_norm.default: {bf16},
+    aten.native_layer_norm.default: {bf16, f16},
     aten.histc.default: {f16},
     aten.histc.out: {f16},
 }
@@ -11104,6 +11104,8 @@ class TestConsistency(TestCaseMPS):
         'cross', 'linalg.cross',
         'prod', 'masked.prod',
         'nextafter',
+        'native_layer_norm',
+        'nn.functional.layer_norm',

         # for macOS 12
         'masked.normalize', 'masked.sum', 'masked.var',
|
||||
mean = out_reshaped.mean(-1)
|
||||
var = out_reshaped.var(-1, unbiased=False)
|
||||
|
||||
delta = 1e-1 if dtype == torch.bfloat16 else 1e-5
|
||||
delta = 1e-1 if (dtype == torch.bfloat16 or dtype == torch.half) else 1e-5
|
||||
self.assertEqual(torch.abs(mean.data).mean(), 0, atol=delta, rtol=0)
|
||||
self.assertEqual(torch.abs(var.data).mean(), 1, atol=delta, rtol=0)
|
||||
|
||||
@@ -7982,12 +7982,12 @@ class TestNNDeviceType(NNTestCase):
         output.sum().backward()
         self.assertEqualTypeString(output, input)

-    def _test_LayerNorm_cpu_mixed_dtype(self, device):
+    def _test_LayerNorm_cpu_mixed_dtype(self, device, dtype):
         for elementwise_affine in [True, False]:
             # layer norm input shape is normalized to m x n, cpu vectorized on n,
             # so make sure n exceeds vector length
-            input = torch.empty(2, 3, 11, 3, device=device, dtype=torch.bfloat16).random_(1, 10)
-            m = nn.LayerNorm([11, 3], elementwise_affine=elementwise_affine).to(device, torch.bfloat16)
+            input = torch.empty(2, 3, 11, 3, device=device, dtype=dtype).random_(1, 10)
+            m = nn.LayerNorm([11, 3], elementwise_affine=elementwise_affine).to(device, dtype)

             # fp32
             m_fp32 = deepcopy(m).to(device, torch.float)
@@ -7995,21 +7995,21 @@ class TestNNDeviceType(NNTestCase):
             out_fp32 = m_fp32(x_fp32)
             out_fp32.sum().backward()

-            # bf16
+            # bf16/half
             m_bf16 = deepcopy(m)
             x_bf16 = input.clone().detach().requires_grad_()
             out_bf16 = m_bf16(x_bf16)
             out_bf16.sum().backward()

-            # bf16 mixed type
+            # bf16/half mixed type
             m_mix = deepcopy(m).to(device, torch.float)
             x_mix = input.clone().detach().requires_grad_()
             out_mix = m_mix(x_mix)
             out_mix.sum().backward()
-            self.assertEqual(out_fp32.bfloat16(), out_bf16)
-            self.assertEqual(out_fp32.bfloat16(), out_mix)
-            self.assertEqual(x_fp32.grad.bfloat16(), x_bf16.grad, atol=1e-1, rtol=1e-1)
-            self.assertEqual(x_fp32.grad.bfloat16(), x_mix.grad, atol=1e-1, rtol=1e-1)
+            self.assertEqual(out_fp32.to(dtype=dtype), out_bf16)
+            self.assertEqual(out_fp32.to(dtype=dtype), out_mix)
+            self.assertEqual(x_fp32.grad.to(dtype=dtype), x_bf16.grad, atol=1e-1, rtol=1e-1)
+            self.assertEqual(x_fp32.grad.to(dtype=dtype), x_mix.grad, atol=1e-1, rtol=1e-1)

     def _test_GroupNorm_general(self, device, dtype=torch.float):
         good_shape_g = {
@@ -8518,13 +8518,15 @@ class TestNNDeviceType(NNTestCase):
         self._test_LayerNorm_general(device)

         if self.device_type == 'cuda' or self.device_type == 'cpu':
-            self._test_LayerNorm_general(device, dtype=torch.bfloat16)
+            for dtype in [torch.half, torch.bfloat16]:
+                self._test_LayerNorm_general(device, dtype=dtype)

         if self.device_type == 'cuda':
             self._test_LayerNorm_cuda_half(device)

         if self.device_type == 'cpu':
-            self._test_LayerNorm_cpu_mixed_dtype(device)
+            for dtype in [torch.half, torch.bfloat16]:
+                self._test_LayerNorm_cpu_mixed_dtype(device, dtype=dtype)

     @onlyNativeDeviceTypes
     def test_LayerNorm_numeric(self, device):
@@ -12844,8 +12844,7 @@ op_db: List[OpInfo] = [
    OpInfo('native_layer_norm',
           aten_name='native_layer_norm',
           ref=reference_native_layer_norm,
-           dtypes=floating_types_and(torch.bfloat16),
-           dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+           dtypes=floating_types_and(torch.half, torch.bfloat16),
           supports_out=False,
           assert_jit_shape_analysis=True,
           supports_fwgrad_bwgrad=True,
@@ -13378,8 +13377,7 @@ op_db: List[OpInfo] = [
           aten_backward_name='layer_norm_backward',
           aliases=('layer_norm',),
           ref=reference_layer_norm,
-           dtypes=floating_types_and(torch.bfloat16),
-           dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+           dtypes=floating_types_and(torch.half, torch.bfloat16),
           supports_out=False,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
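With `torch.half` added to the CPU dtype lists above, the OpInfo-driven suites exercise layer_norm in half on CPU; in eager mode the same coverage looks like this minimal sketch (an illustration, not part of the PR's test changes):

```python
# End-to-end sketch of what the expanded dtype coverage enables on CPU:
# a LayerNorm module running forward and backward entirely in float16.
import torch
import torch.nn as nn

m = nn.LayerNorm(16).half()
x = torch.randn(4, 16, dtype=torch.half, requires_grad=True)
y = m(x)
y.sum().backward()
print(y.dtype, x.grad.dtype, m.weight.grad.dtype)  # all torch.float16
```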