add Half support for layer_norm on CPU (#99590)

### Testing
Single socket (icx, 32cores):
| shape | fp32 forward (ms) | fp16 forward (ms) | mixed fp32 fp16 forward (ms) | fp32 backward (ms) | fp16 backward (ms) | mixed fp32 fp16 backward (ms) |
| -- | -- | -- | -- | -- | -- | -- |
| (1, 8, 16) | 0.012 | 0.011 | 0.011 | 0.051 | 0.051 | 0.050 |
| (8 ,8, 16) | 0.013 | 0.013 | 0.013 | 0.054 | 0.053 | 0.051 |
| (32, 8, 16) | 0.015 | 0.014 | 0.014 | 0.059 | 0.054 | 0.052 |
| (64, 128, 56, 56) | 1.875 | 0.790 | 1.016 | 12.845 | 7.151 | 6.985 |
| (64, 128, 256, 256) | 50.226 | 25.462 | 35.736 | 328.957 | 179.615 | 175.618 |

Single core (icx):

| shape | fp32 forward (ms) | fp16 forward (ms) | mixed fp32 fp16 forward (ms) | fp32 backward (ms) | fp16 backward (ms) | mixed fp32 fp16 backward (ms) |
| -- | -- | -- | -- | -- | -- | -- |
| (1, 8, 16) | 0.012 | 0.011 | 0.011 | 0.040 | 0.041 | 0.041 |
| (8 ,8, 16) | 0.012 | 0.012 | 0.012 | 0.042 | 0.042 | 0.042 |
| (32, 8, 16) | 0.027 | 0.014 | 0.014 | 0.048 | 0.048 | 0.046 |
| (64, 128, 56, 56) | 58.054 | 11.034 | 17.928 | 108.603 | 48.816 | 50.244 |
| (64, 128, 256, 256) | 1327.758 | 352.394 | 496.994 | 2846.182 | 1224.247 | 1218.422 |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99590
Approved by: https://github.com/mingfeima, https://github.com/jgong5, https://github.com/cpuhrsch
This commit is contained in:
Sun, Jiayi
2023-12-19 15:39:04 +08:00
committed by PyTorch MergeBot
parent 45cfe9cdf7
commit c173a9d9b3
7 changed files with 119 additions and 117 deletions

View File

@ -69,7 +69,7 @@ inline Vectorized<Half> convert_from_float<Half>(const Vectorized<float>& a, con
//
template <typename scalar_t, typename Op,
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
inline scalar_t reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) {
inline float reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) {
using bVec = vec::Vectorized<scalar_t>;
using fVec = vec::Vectorized<float>;
if (size < bVec::size()) {
@ -111,7 +111,7 @@ inline scalar_t reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size
template <typename scalar_t, typename Op1, typename Op2,
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
inline std::pair<scalar_t, scalar_t> reduce2_all(const Op1& vec_fun1, const Op2& vec_fun2,
inline std::pair<float, float> reduce2_all(const Op1& vec_fun1, const Op2& vec_fun2,
const scalar_t* data, int64_t size) {
using bVec = vec::Vectorized<scalar_t>;
using fVec = vec::Vectorized<float>;
@ -169,7 +169,7 @@ inline std::pair<scalar_t, scalar_t> reduce2_all(const Op1& vec_fun1, const Op2&
template <typename scalar_t, typename MapOp, typename ReduceOp,
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
inline scalar_t map_reduce_all(
inline float map_reduce_all(
const MapOp& map_fun,
const ReduceOp& red_fun,
const scalar_t* data,
@ -225,7 +225,7 @@ inline scalar_t map_reduce_all(
template <typename scalar_t, typename MapOp, typename ReduceOp,
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
inline scalar_t map2_reduce_all(
inline float map2_reduce_all(
const MapOp& map_fun,
const ReduceOp& red_fun,
const scalar_t* data,
@ -294,7 +294,7 @@ inline scalar_t map2_reduce_all(
template <typename scalar_t, typename MapOp, typename ReduceOp,
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
inline scalar_t map3_reduce_all(
inline float map3_reduce_all(
const MapOp& map_fun,
const ReduceOp& red_fun,
const scalar_t* data,

View File

@ -23,14 +23,15 @@ namespace at::native {
namespace {
template <typename T, typename T_ACC>
template <typename T,
typename std::enable_if_t<!is_reduced_floating_point_v<T>, int> = 0>
void LayerNormKernelImplInternal(
const Tensor& X,
const Tensor& gamma,
const Tensor& beta,
int64_t M,
int64_t N,
T_ACC eps,
T eps,
Tensor* Y,
Tensor* mean,
Tensor* rstd) {
@ -83,7 +84,8 @@ void LayerNormKernelImplInternal(
});
}
template <typename param_t>
template <typename T, typename param_t,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
void layer_norm_kernel_mixed_type(
const Tensor& X,
const Tensor& gamma,
@ -94,12 +96,12 @@ void layer_norm_kernel_mixed_type(
Tensor* Y,
Tensor* mean,
Tensor* rstd) {
using bVec = Vectorized<BFloat16>;
using bVec = Vectorized<T>;
using fVec = Vectorized<float>;
const BFloat16* X_data = X.data_ptr<BFloat16>();
const T* X_data = X.data_ptr<T>();
const param_t* gamma_data = gamma.defined() ? gamma.data_ptr<param_t>() : nullptr;
const param_t* beta_data = beta.defined() ? beta.data_ptr<param_t>() : nullptr;
BFloat16* Y_data = Y->data_ptr<BFloat16>();
T* Y_data = Y->data_ptr<T>();
param_t* mean_data = mean ? mean->data_ptr<param_t>() : nullptr;
param_t* rstd_data = rstd ? rstd->data_ptr<param_t>() : nullptr;
@ -109,38 +111,29 @@ void layer_norm_kernel_mixed_type(
const bool rstd_null = rstd_data == nullptr;
at::parallel_for(0, M, 1, [&](int64_t start, int64_t end) {
for (const auto i : c10::irange(start, end)) {
const BFloat16* X_ptr = X_data + i * N;
BFloat16* Y_ptr = Y_data + i * N;
const T* X_ptr = X_data + i * N;
T* Y_ptr = Y_data + i * N;
float mean_val;
float rstd_val;
std::tie(mean_val, rstd_val) = RowwiseMoments(X_ptr, N);
rstd_val = float(1) / std::sqrt(rstd_val + eps);
const float scale = rstd_val;
const float bias = -rstd_val * mean_val;
if (gamma_null || beta_null) {
for (const auto j : c10::irange(N)) {
const param_t gamma_v = gamma_null ? param_t(1) : gamma_data[j];
const param_t beta_v = beta_null ? param_t(0) : beta_data[j];
Y_ptr[j] = (X_ptr[j] * scale + bias) * gamma_v + beta_v;
}
} else {
int64_t d = 0;
for (; d < N - (N % bVec::size()); d += bVec::size()) {
bVec x_bvec = bVec::loadu(X_ptr + d);
fVec x_fvec0, x_fvec1;
std::tie(x_fvec0, x_fvec1) = convert_bfloat16_float(x_bvec);
fVec gamma_fvec0, gamma_fvec1;
std::tie(gamma_fvec0, gamma_fvec1) = load2f(gamma_data + d);
fVec beta_fvec0, beta_fvec1;
std::tie(beta_fvec0, beta_fvec1) = load2f(beta_data + d);
fVec y_fvec0 = (x_fvec0 * fVec(scale) + fVec(bias)) * gamma_fvec0 + beta_fvec0;
fVec y_fvec1 = (x_fvec1 * fVec(scale) + fVec(bias)) * gamma_fvec1 + beta_fvec1;
bVec y_bvec = convert_float_bfloat16(y_fvec0, y_fvec1);
y_bvec.store(Y_ptr + d);
}
for (; d < N; d++) {
Y_ptr[d] = (X_ptr[d] * scale + bias) * gamma_data[d] + beta_data[d];
}
int64_t d = 0;
for (; d < N - (N % bVec::size()); d += bVec::size()) {
bVec x_bvec = bVec::loadu(X_ptr + d);
auto [x_fvec0, x_fvec1] = convert_to_float<T>(x_bvec);
auto [gamma_fvec0, gamma_fvec1] = gamma_null ? std::make_tuple(fVec(1), fVec(1)) : load2f(gamma_data + d);
auto [beta_fvec0, beta_fvec1] = beta_null ? std::make_tuple(fVec(0), fVec(0)) : load2f(beta_data + d);
fVec y_fvec0 = (x_fvec0 * fVec(scale) + fVec(bias)) * gamma_fvec0 + beta_fvec0;
fVec y_fvec1 = (x_fvec1 * fVec(scale) + fVec(bias)) * gamma_fvec1 + beta_fvec1;
bVec y_bvec = convert_from_float<T>(y_fvec0, y_fvec1);
y_bvec.store(Y_ptr + d);
}
for (; d < N; d++) {
const float gamma_v = gamma_null ? float(1) : float(gamma_data[d]);
const float beta_v = beta_null ? float(0) : float(beta_data[d]);
Y_ptr[d] = (float(X_ptr[d]) * scale + bias) * gamma_v + beta_v;
}
if (!mean_null) {
mean_data[i] = mean_val;
@ -152,8 +145,9 @@ void layer_norm_kernel_mixed_type(
});
}
template <>
void LayerNormKernelImplInternal<BFloat16, float>(
template <typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
void LayerNormKernelImplInternal(
const Tensor& X,
const Tensor& gamma,
const Tensor& beta,
@ -165,9 +159,9 @@ void LayerNormKernelImplInternal<BFloat16, float>(
Tensor* rstd) {
const bool mixed_type = is_mixed_type(X, gamma, beta);
if (mixed_type) {
layer_norm_kernel_mixed_type<float>(X, gamma, beta, M, N, eps, Y, mean, rstd);
layer_norm_kernel_mixed_type<T, float>(X, gamma, beta, M, N, eps, Y, mean, rstd);
} else {
layer_norm_kernel_mixed_type<BFloat16>(X, gamma, beta, M, N, eps, Y, mean, rstd);
layer_norm_kernel_mixed_type<T, T>(X, gamma, beta, M, N, eps, Y, mean, rstd);
}
}
@ -184,15 +178,14 @@ void LayerNormKernelImpl(
TORCH_DCHECK_EQ(X.numel(), M * N);
DCHECK(!gamma.defined() || gamma.numel() == N);
DCHECK(!beta.defined() || beta.numel() == N);
AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, X.scalar_type(),
AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, X.scalar_type(),
"LayerNormKernelImpl", [&]() {
using acc_t = at::opmath_type<scalar_t>;
LayerNormKernelImplInternal<scalar_t, acc_t>(
X, gamma, beta, M, N, static_cast<acc_t>(eps), Y, mean, rstd);
LayerNormKernelImplInternal<scalar_t>(
X, gamma, beta, M, N, eps, Y, mean, rstd);
});
}
template <typename T, typename T2, typename T_ACC>
template <typename T, typename T2, typename opmath_t>
void layer_norm_backward_frame(
const T* dY_data,
const T* X_data,
@ -202,19 +195,19 @@ void layer_norm_backward_frame(
T* dX_data,
T* dgamma_buffer_ptr,
T* dbeta_buffer_ptr,
const T_ACC scale,
const opmath_t scale,
const bool gamma_null,
const bool dX_null,
const bool dgamma_null,
const bool dbeta_null,
int64_t N,
int64_t i) {
using Vec = vec::Vectorized<T_ACC>;
using Vec = vec::Vectorized<opmath_t>;
const T* dY_ptr = dY_data + i * N;
const T* X_ptr = X_data + i * N;
if (!dgamma_null) {
const T_ACC a = rstd_data[i];
const T_ACC b = -a * mean_data[i];
const opmath_t a = rstd_data[i];
const opmath_t b = -a * mean_data[i];
// Scalar math:
// for (const auto j : c10::irange(N)) {
// dgamma_data[j] += dY_ptr[j] * (a * X_ptr[j] + b);
@ -243,8 +236,8 @@ void layer_norm_backward_frame(
}
if (!dX_null) {
T* dX_ptr = dX_data + i * N;
T_ACC ds = T_ACC(0);
T_ACC db = T_ACC(0);
opmath_t ds = opmath_t(0);
opmath_t db = opmath_t(0);
// Scalar math:
// for (const auto j : c10::irange(N)) {
// const T gamma_v = gamma_null ? T(1) : gamma_data[j];
@ -275,9 +268,9 @@ void layer_norm_backward_frame(
gamma_data,
N);
}
const T_ACC a = rstd_data[i];
const T_ACC b = (db * mean_data[i] - ds) * a * a * a * scale;
const T_ACC c = -b * mean_data[i] - db * a * scale;
const opmath_t a = rstd_data[i];
const opmath_t b = (db * opmath_t(mean_data[i]) - ds) * a * a * a * scale;
const opmath_t c = -b * opmath_t(mean_data[i]) - db * a * scale;
// Scalar math:
// for (const auto j : c10::irange(N)) {
// const T gamma_v = gamma_null ? T(1) : gamma_data[j];
@ -306,16 +299,17 @@ void layer_norm_backward_frame(
}
}
template <>
void layer_norm_backward_frame<BFloat16, float, float>(
const BFloat16* dY_data,
const BFloat16* X_data,
template <typename T, typename T2, typename opmath_t,
typename std::enable_if_t<is_reduced_floating_point_v<T> && std::is_same<T2, float>::value, int> = 0>
void layer_norm_backward_frame(
const T* dY_data,
const T* X_data,
const float* mean_data,
const float* rstd_data,
const float* gamma_data,
BFloat16* dX_data,
BFloat16* dgamma_buffer_ptr,
BFloat16* dbeta_buffer_ptr,
T* dX_data,
T* dgamma_buffer_ptr,
T* dbeta_buffer_ptr,
const float scale,
const bool gamma_null,
const bool dX_null,
@ -323,10 +317,10 @@ void layer_norm_backward_frame<BFloat16, float, float>(
const bool dbeta_null,
int64_t N,
int64_t i) {
using bVec = Vectorized<BFloat16>;
using bVec = Vectorized<T>;
using fVec = Vectorized<float>;
const BFloat16* dY_ptr = dY_data + i * N;
const BFloat16* X_ptr = X_data + i * N;
const T* dY_ptr = dY_data + i * N;
const T* X_ptr = X_data + i * N;
if (!dgamma_null) {
const float a = rstd_data[i];
const float b = -a * mean_data[i];
@ -334,7 +328,7 @@ void layer_norm_backward_frame<BFloat16, float, float>(
// for (const auto j : c10::irange(N)) {
// dgamma_data[j] += dY_ptr[j] * (a * X_ptr[j] + b);
// }
vec::map3<BFloat16>(
vec::map3<T>(
[a, b](fVec dgamma, fVec dy, fVec x) {
return dgamma + dy * (fVec(a) * x + fVec(b));
},
@ -349,7 +343,7 @@ void layer_norm_backward_frame<BFloat16, float, float>(
// for (const auto j : c10::irange(N)) {
// dbeta_data[j] += dY_ptr[j];
// }
vec::map2<BFloat16>(
vec::map2<T>(
[](fVec dbeta, fVec dy) { return dbeta + dy; },
dbeta_buffer_ptr,
dbeta_buffer_ptr,
@ -357,7 +351,7 @@ void layer_norm_backward_frame<BFloat16, float, float>(
N);
}
if (!dX_null) {
BFloat16* dX_ptr = dX_data + i * N;
T* dX_ptr = dX_data + i * N;
float ds = float(0);
float db = float(0);
// Scalar math:
@ -367,21 +361,21 @@ void layer_norm_backward_frame<BFloat16, float, float>(
// db += dY_ptr[j] * gamma_v;
// }
if (gamma_null) {
ds = vec::map2_reduce_all<BFloat16>(
ds = vec::map2_reduce_all<T>(
[](fVec x, fVec y) { return x * y; },
[](fVec x, fVec y) { return x + y; },
dY_ptr,
X_ptr,
N);
db = vec::reduce_all<BFloat16>(
db = vec::reduce_all<T>(
[](fVec& x, fVec& y) { return x + y; }, dY_ptr, N);
} else {
if (N < bVec::size()) {
bVec x_bvec = bVec::loadu(X_ptr, N);
bVec dy_bvec = bVec::loadu(dY_ptr, N);
fVec x_fvec0, x_fvec1, dy_fvec0, dy_fvec1, gamma_fvec0, gamma_fvec1;
std::tie(x_fvec0, x_fvec1) = convert_bfloat16_float(x_bvec);
std::tie(dy_fvec0, dy_fvec1) = convert_bfloat16_float(dy_bvec);
std::tie(x_fvec0, x_fvec1) = convert_to_float<T>(x_bvec);
std::tie(dy_fvec0, dy_fvec1) = convert_to_float<T>(dy_bvec);
std::tie(gamma_fvec0, gamma_fvec1) = load2f(gamma_data, N);
if (N > fVec::size()) {
fVec db_fvec0 = dy_fvec0 * gamma_fvec0;
@ -404,8 +398,8 @@ void layer_norm_backward_frame<BFloat16, float, float>(
bVec dy_bvec = bVec::loadu(dY_ptr);
fVec x_fvec0, x_fvec1, dy_fvec0, dy_fvec1, gamma_fvec0, gamma_fvec1;
fVec ds_fvec0, ds_fvec1, db_fvec0, db_fvec1, acc_ds_fvec0, acc_ds_fvec1, acc_db_fvec0, acc_db_fvec1;
std::tie(x_fvec0, x_fvec1) = convert_bfloat16_float(x_bvec);
std::tie(dy_fvec0, dy_fvec1) = convert_bfloat16_float(dy_bvec);
std::tie(x_fvec0, x_fvec1) = convert_to_float<T>(x_bvec);
std::tie(dy_fvec0, dy_fvec1) = convert_to_float<T>(dy_bvec);
std::tie(gamma_fvec0, gamma_fvec1) = load2f(gamma_data);
acc_db_fvec0 = dy_fvec0 * gamma_fvec0;
acc_db_fvec1 = dy_fvec1 * gamma_fvec1;
@ -414,8 +408,8 @@ void layer_norm_backward_frame<BFloat16, float, float>(
for (; d < N - (N % bVec::size()); d += bVec::size()) {
x_bvec = bVec::loadu(X_ptr + d);
dy_bvec = bVec::loadu(dY_ptr + d);
std::tie(x_fvec0, x_fvec1) = convert_bfloat16_float(x_bvec);
std::tie(dy_fvec0, dy_fvec1) = convert_bfloat16_float(dy_bvec);
std::tie(x_fvec0, x_fvec1) = convert_to_float<T>(x_bvec);
std::tie(dy_fvec0, dy_fvec1) = convert_to_float<T>(dy_bvec);
std::tie(gamma_fvec0, gamma_fvec1) = load2f(gamma_data + d);
db_fvec0 = dy_fvec0 * gamma_fvec0;
db_fvec1 = dy_fvec1 * gamma_fvec1;
@ -429,8 +423,8 @@ void layer_norm_backward_frame<BFloat16, float, float>(
if (N - d > 0) {
x_bvec = bVec::loadu(X_ptr + d, N - d);
dy_bvec = bVec::loadu(dY_ptr + d, N - d);
std::tie(x_fvec0, x_fvec1) = convert_bfloat16_float(x_bvec);
std::tie(dy_fvec0, dy_fvec1) = convert_bfloat16_float(dy_bvec);
std::tie(x_fvec0, x_fvec1) = convert_to_float<T>(x_bvec);
std::tie(dy_fvec0, dy_fvec1) = convert_to_float<T>(dy_bvec);
std::tie(gamma_fvec0, gamma_fvec1) = load2f(gamma_data + d, N - d);
if (N - d > fVec::size()) {
db_fvec0 = dy_fvec0 * gamma_fvec0;
@ -463,7 +457,7 @@ void layer_norm_backward_frame<BFloat16, float, float>(
// dX_ptr[j] = a * dY_ptr[j] * gamma_v + b * X_ptr[j] + c;
// }
if (gamma_null) {
vec::map2<BFloat16>(
vec::map2<T>(
[a, b, c](fVec dy, fVec x) {
return fVec(a) * dy + fVec(b) * x + fVec(c);
},
@ -477,24 +471,24 @@ void layer_norm_backward_frame<BFloat16, float, float>(
bVec x_bvec = bVec::loadu(X_ptr + d);
bVec dy_bvec = bVec::loadu(dY_ptr + d);
fVec x_fvec0, x_fvec1, dy_fvec0, dy_fvec1, gamma_fvec0, gamma_fvec1;
std::tie(x_fvec0, x_fvec1) = convert_bfloat16_float(x_bvec);
std::tie(dy_fvec0, dy_fvec1) = convert_bfloat16_float(dy_bvec);
std::tie(x_fvec0, x_fvec1) = convert_to_float<T>(x_bvec);
std::tie(dy_fvec0, dy_fvec1) = convert_to_float<T>(dy_bvec);
std::tie(gamma_fvec0, gamma_fvec1) = load2f(gamma_data + d);
fVec r_fvec0 = fVec(a) * dy_fvec0 * gamma_fvec0 + fVec(b) * x_fvec0 + fVec(c);
fVec r_fvec1 = fVec(a) * dy_fvec1 * gamma_fvec1 + fVec(b) * x_fvec1 + fVec(c);
bVec r_bvec = convert_float_bfloat16(r_fvec0, r_fvec1);
bVec r_bvec = convert_from_float<T>(r_fvec0, r_fvec1);
r_bvec.store(dX_ptr + d);
}
if (N - d > 0) {
bVec x_bvec = bVec::loadu(X_ptr + d, N - d);
bVec dy_bvec = bVec::loadu(dY_ptr + d, N - d);
fVec x_fvec0, x_fvec1, dy_fvec0, dy_fvec1, gamma_fvec0, gamma_fvec1;
std::tie(x_fvec0, x_fvec1) = convert_bfloat16_float(x_bvec);
std::tie(dy_fvec0, dy_fvec1) = convert_bfloat16_float(dy_bvec);
std::tie(x_fvec0, x_fvec1) = convert_to_float<T>(x_bvec);
std::tie(dy_fvec0, dy_fvec1) = convert_to_float<T>(dy_bvec);
std::tie(gamma_fvec0, gamma_fvec1) = load2f(gamma_data + d, N - d);
fVec r_fvec0 = fVec(a) * dy_fvec0 * gamma_fvec0 + fVec(b) * x_fvec0 + fVec(c);
fVec r_fvec1 = fVec(a) * dy_fvec1 * gamma_fvec1 + fVec(b) * x_fvec1 + fVec(c);
bVec r_bvec = convert_float_bfloat16(r_fvec0, r_fvec1);
bVec r_bvec = convert_from_float<T>(r_fvec0, r_fvec1);
r_bvec.store(dX_ptr + d, N - d);
}
}
@ -513,7 +507,7 @@ void LayerNormBackwardKernelImplInternal(
Tensor* dX,
Tensor* dgamma,
Tensor* dbeta) {
using T_ACC = at::opmath_type<T>;
using opmath_t = at::opmath_type<T>;
TORCH_DCHECK_EQ(dY.numel(), M * N);
TORCH_DCHECK_EQ(X.numel(), M * N);
TORCH_DCHECK_EQ(mean.numel(), M);
@ -528,7 +522,7 @@ void LayerNormBackwardKernelImplInternal(
T* dX_data = dX->defined() ? dX->template data_ptr<T>() : nullptr;
T2* dgamma_data = dgamma->defined() ? dgamma->template data_ptr<T2>() : nullptr;
T2* dbeta_data = dbeta->defined() ? dbeta->template data_ptr<T2>() : nullptr;
const T_ACC scale = T_ACC(1) / static_cast<T_ACC>(N);
const opmath_t scale = opmath_t(1) / static_cast<opmath_t>(N);
const bool gamma_null = gamma_data == nullptr;
const bool dX_null = dX_data == nullptr;
const bool dgamma_null = dgamma_data == nullptr;
@ -565,7 +559,7 @@ void LayerNormBackwardKernelImplInternal(
T* dbeta_buffer_ptr =
dbeta_null ? nullptr : buffer_data + num_threads * N + tid * N;
for (const auto i : c10::irange(start, end)) {
layer_norm_backward_frame<T, T2, T_ACC>(dY_data, X_data, mean_data, rstd_data, gamma_data, dX_data, dgamma_buffer_ptr, dbeta_buffer_ptr, scale, gamma_null, dX_null, dgamma_null, dbeta_null, N, i);
layer_norm_backward_frame<T, T2, opmath_t>(dY_data, X_data, mean_data, rstd_data, gamma_data, dX_data, dgamma_buffer_ptr, dbeta_buffer_ptr, scale, gamma_null, dX_null, dgamma_null, dbeta_null, N, i);
}
});
@ -573,8 +567,8 @@ void LayerNormBackwardKernelImplInternal(
if (buffer_data != nullptr) {
parallel_for(0, N, 1, [&](int64_t start, int64_t end) {
for (const auto j : c10::irange(start, end)) {
T_ACC dgamma_v = T_ACC(0);
T_ACC dbeta_v = T_ACC(0);
opmath_t dgamma_v = opmath_t(0);
opmath_t dbeta_v = opmath_t(0);
for (const auto i : c10::irange(num_threads)) {
dgamma_v += buffer_data[i * N + j];
dbeta_v += buffer_data[num_threads * N + i * N + j];
@ -603,16 +597,22 @@ void LayerNormBackwardKernelImpl(
Tensor* dX,
Tensor* dgamma,
Tensor* dbeta) {
AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, X.scalar_type(),
"LayerNormBackwardKernelImpl", [&]() {
if (X.scalar_type() == at::kBFloat16 && gamma.scalar_type() == at::kFloat) {
LayerNormBackwardKernelImplInternal<BFloat16, float>(
dY.contiguous(), X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
} else {
if (at::isReducedFloatingType(X.scalar_type())) {
AT_DISPATCH_REDUCED_FLOATING_TYPES(X.scalar_type(), "LayerNormBackwardKernelImpl", [&]() {
if (gamma.scalar_type() == at::kFloat) {
LayerNormBackwardKernelImplInternal<scalar_t, float>(
dY.contiguous(), X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
} else {
LayerNormBackwardKernelImplInternal<scalar_t, scalar_t>(
dY.contiguous(), X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
}
});
} else {
AT_DISPATCH_FLOATING_TYPES(X.scalar_type(), "LayerNormBackwardKernelImpl", [&]() {
LayerNormBackwardKernelImplInternal<scalar_t, scalar_t>(
dY.contiguous(), X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
}
});
});
}
}
} // namespace

View File

@ -76,7 +76,7 @@ UpdateMomentsVec(
AddMomentsVec(m0, m1_vec, m2_vec, m0_stk0, m1_stk0, m2_stk0);
}
// each bfloat16 vector will be converted to two float vectors,
// each bfloat16/half vector will be converted to two float vectors,
// and accumulated successively on m1_stk0/m2_stk0.
template <typename T>
inline typename std::enable_if<!std::is_same<T, at::opmath_type<T>>::value, void>::type

View File

@ -710,7 +710,7 @@ meta_function_device_skips = defaultdict(dict)
meta_function_device_expected_failures['cpu'] = {
torch.native_batch_norm: {bf16, f16},
torch._native_batch_norm_legit: {bf16, f16},
torch.native_layer_norm: {bf16},
torch.native_layer_norm: {bf16, f16},
}
meta_function_device_expected_failures['cuda'] = {
@ -855,7 +855,7 @@ meta_dispatch_device_expected_failures['cpu'] = {
aten.native_batch_norm.default: {bf16, f16},
aten._native_batch_norm_legit.default: {bf16, f16},
aten._native_batch_norm_legit.no_stats: {bf16, f16},
aten.native_layer_norm.default: {bf16},
aten.native_layer_norm.default: {bf16, f16},
aten.histc.default: {f16},
aten.histc.out: {f16},
}

View File

@ -11104,6 +11104,8 @@ class TestConsistency(TestCaseMPS):
'cross', 'linalg.cross',
'prod', 'masked.prod',
'nextafter',
'native_layer_norm',
'nn.functional.layer_norm',
# for macOS 12
'masked.normalize', 'masked.sum', 'masked.var',

View File

@ -7948,7 +7948,7 @@ class TestNNDeviceType(NNTestCase):
mean = out_reshaped.mean(-1)
var = out_reshaped.var(-1, unbiased=False)
delta = 1e-1 if dtype == torch.bfloat16 else 1e-5
delta = 1e-1 if (dtype == torch.bfloat16 or dtype == torch.half) else 1e-5
self.assertEqual(torch.abs(mean.data).mean(), 0, atol=delta, rtol=0)
self.assertEqual(torch.abs(var.data).mean(), 1, atol=delta, rtol=0)
@ -7982,12 +7982,12 @@ class TestNNDeviceType(NNTestCase):
output.sum().backward()
self.assertEqualTypeString(output, input)
def _test_LayerNorm_cpu_mixed_dtype(self, device):
def _test_LayerNorm_cpu_mixed_dtype(self, device, dtype):
for elementwise_affine in [True, False]:
# layer norm input shape is normalized to m x n, cpu vectorized on n,
# so make sure n exceeds vector length
input = torch.empty(2, 3, 11, 3, device=device, dtype=torch.bfloat16).random_(1, 10)
m = nn.LayerNorm([11, 3], elementwise_affine=elementwise_affine).to(device, torch.bfloat16)
input = torch.empty(2, 3, 11, 3, device=device, dtype=dtype).random_(1, 10)
m = nn.LayerNorm([11, 3], elementwise_affine=elementwise_affine).to(device, dtype)
# fp32
m_fp32 = deepcopy(m).to(device, torch.float)
@ -7995,21 +7995,21 @@ class TestNNDeviceType(NNTestCase):
out_fp32 = m_fp32(x_fp32)
out_fp32.sum().backward()
# bf16
# bf16/half
m_bf16 = deepcopy(m)
x_bf16 = input.clone().detach().requires_grad_()
out_bf16 = m_bf16(x_bf16)
out_bf16.sum().backward()
# bf16 mixed type
# bf16/half mixed type
m_mix = deepcopy(m).to(device, torch.float)
x_mix = input.clone().detach().requires_grad_()
out_mix = m_mix(x_mix)
out_mix.sum().backward()
self.assertEqual(out_fp32.bfloat16(), out_bf16)
self.assertEqual(out_fp32.bfloat16(), out_mix)
self.assertEqual(x_fp32.grad.bfloat16(), x_bf16.grad, atol=1e-1, rtol=1e-1)
self.assertEqual(x_fp32.grad.bfloat16(), x_mix.grad, atol=1e-1, rtol=1e-1)
self.assertEqual(out_fp32.to(dtype=dtype), out_bf16)
self.assertEqual(out_fp32.to(dtype=dtype), out_mix)
self.assertEqual(x_fp32.grad.to(dtype=dtype), x_bf16.grad, atol=1e-1, rtol=1e-1)
self.assertEqual(x_fp32.grad.to(dtype=dtype), x_mix.grad, atol=1e-1, rtol=1e-1)
def _test_GroupNorm_general(self, device, dtype=torch.float):
good_shape_g = {
@ -8518,13 +8518,15 @@ class TestNNDeviceType(NNTestCase):
self._test_LayerNorm_general(device)
if self.device_type == 'cuda' or self.device_type == 'cpu':
self._test_LayerNorm_general(device, dtype=torch.bfloat16)
for dtype in [torch.half, torch.bfloat16]:
self._test_LayerNorm_general(device, dtype=dtype)
if self.device_type == 'cuda':
self._test_LayerNorm_cuda_half(device)
if self.device_type == 'cpu':
self._test_LayerNorm_cpu_mixed_dtype(device)
for dtype in [torch.half, torch.bfloat16]:
self._test_LayerNorm_cpu_mixed_dtype(device, dtype=dtype)
@onlyNativeDeviceTypes
def test_LayerNorm_numeric(self, device):

View File

@ -12844,8 +12844,7 @@ op_db: List[OpInfo] = [
OpInfo('native_layer_norm',
aten_name='native_layer_norm',
ref=reference_native_layer_norm,
dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
dtypes=floating_types_and(torch.half, torch.bfloat16),
supports_out=False,
assert_jit_shape_analysis=True,
supports_fwgrad_bwgrad=True,
@ -13378,8 +13377,7 @@ op_db: List[OpInfo] = [
aten_backward_name='layer_norm_backward',
aliases=('layer_norm',),
ref=reference_layer_norm,
dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
dtypes=floating_types_and(torch.half, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,