pytorch/torch/csrc/inductor/cpp_prefix.h
PyTorch MergeBot 7614338b69 Revert "Add SVE128 ISA (#158932)"
This reverts commit 92284fb2ff44f09a9c7df0d8cf6cac9903e376a4.

Reverted https://github.com/pytorch/pytorch/pull/158932 on behalf of https://github.com/malfet due to Hmm, but from OSS point of view, this is a no-op ([comment](https://github.com/pytorch/pytorch/pull/158932#issuecomment-3387961238))
2025-10-10 01:17:02 +00:00

#pragma once
#include <omp.h>
#include <algorithm>
#include <atomic>
#include <cmath>
#include <cstdlib>
#include <limits>
#include <map>
#include <memory>
#include <optional>
// WARNING: be extra careful when including more ATen/c10 header files here!
// Because AOTInductor generated code will copy-paste this cpp_prefix.h for
// the CPU backend, we have to make sure the used headers are implemented
// in a header-only way, i.e. all the function and class definitions are
// in .h files instead of .cpp files, to avoid ABI backward-compatibility
// breakage.
#include <ATen/NumericUtils.h>
#include <ATen/core/PhiloxRNGEngine.h>
#include <c10/util/BFloat16-math.h>
#include <c10/util/BFloat16.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/Float8_e5m2fnuz.h>
#include <c10/util/Half.h>
#include <c10/util/TypeCast.h>
#include <c10/util/generic_math.h>
#include <c10/util/irange.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || \
defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON) || \
defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_SVE256)
#define INDUCTOR_USE_VECTOR_TYPES() 1
#else
#define INDUCTOR_USE_VECTOR_TYPES() 0
#endif
#if INDUCTOR_USE_VECTOR_TYPES()
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#else
// For calc_erfinv
#include <ATen/native/Math.h>
#endif
template <typename T>
struct Welford {
T mean = T(0);
T m2 = T(0);
// Use weight for tail cases since the index of each element in the vec may be
// different. A single index cannot express a masked Welford reduction.
T weight = T(0);
uint64_t index = 0;
};
template <typename T>
struct IsVecType : std::false_type {};
template <typename T>
struct IsVecMaskType : std::false_type {};
#if INDUCTOR_USE_VECTOR_TYPES()
template <typename T>
struct IsVecType<at::vec::Vectorized<T>> : std::true_type {};
template <typename T, int N>
struct IsVecType<at::vec::VectorizedN<T, N>> : std::true_type {};
template <typename T, int N>
struct IsVecMaskType<at::vec::VecMask<T, N>> : std::true_type {};
#endif
template <typename T, uint64_t kChunkSize>
struct CascadeSumHelper {
// A data struct to help cascade summation:
std::vector<T> sum_stk{};
uint64_t depth{0}; // depth of sum_stk.
uint64_t num_chunks{0}; // number of chunks stored in sum_stk.
uint64_t index{0}; // index of the current data.
CascadeSumHelper() = default;
CascadeSumHelper(uint64_t N) {
uint64_t m = (N + kChunkSize - 1) / kChunkSize; // div up
depth = m > 0
? static_cast<std::uint64_t>(ceil(log2(static_cast<double>(m))))
: 0;
if constexpr (IsVecType<T>::value) {
sum_stk.assign(
std::max(depth, static_cast<uint64_t>(1)),
T(typename T::value_type(0)));
} else {
sum_stk.assign(std::max(depth, static_cast<uint64_t>(1)), T(0));
}
}
};
template <typename T, uint64_t kChunkSize = 0>
inline T cascade_sum_combine(T& data, CascadeSumHelper<T, kChunkSize>* c) {
// Note: to stay consistent with the other reductions in Inductor, the value
// returned here may be only a partial sum; cascade_sum_final must be called
// to obtain the final, correct result. Inductor uses the reduction suffix to
// ensure that cascade_sum_final is called at the end.
c->sum_stk[0] = c->sum_stk[0] + data;
// Use cascade summation to improve numerical stability.
// https://en.wikipedia.org/wiki/Pairwise_summation
if (c->depth > 0) {
c->index++;
if (c->index == kChunkSize) {
c->num_chunks += 1;
c->index = 0;
uint64_t mask = c->num_chunks;
uint64_t j = 1;
for (; j < c->depth && (mask & 1) == 0; ++j) {
c->sum_stk[j] = c->sum_stk[j] + c->sum_stk[j - 1];
c->sum_stk[j - 1] = T(0);
mask >>= 1;
}
return c->sum_stk[j - 1];
}
}
return c->sum_stk[0];
}
template <typename T, uint64_t kChunkSize = 0>
inline T cascade_sum_final(CascadeSumHelper<T, kChunkSize>* c) {
T result = c->sum_stk[0];
for (const auto i : c10::irange(1, c->depth)) {
result = result + c->sum_stk[i];
}
return result;
}
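// Illustrative sketch (not part of the original header): how generated code
// typically drives cascade summation. `example_cascade_sum` and the chunk
// size of 4096 are hypothetical choices for this sketch; Inductor picks its
// own chunk size and calls cascade_sum_final in the reduction suffix.
inline float example_cascade_sum(const float* data, uint64_t N) {
  CascadeSumHelper<float, 4096> helper(N);
  for (uint64_t i = 0; i < N; ++i) {
    float x = data[i];
    // The return value here is only a partial sum (see the note above).
    (void)cascade_sum_combine(x, &helper);
  }
  return cascade_sum_final(&helper);
}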
template <typename T, uint64_t kChunkSize>
struct WelfordHelper {
// A data struct to help welford reduction:
// 1. Save the reciprocal of weights to avoid redundant divisions.
// 2. Save the welford stack, which is used to combine welford reduction
// with cascade summation to improve numerical stability.
static std::vector<typename T::value_type> weight_recps;
std::vector<Welford<T>> welford_stk{};
uint64_t depth{0}; // depth of welford_stk.
uint64_t num_chunks{0}; // number of chunks stored in welford_stk.
WelfordHelper() = default;
WelfordHelper(uint64_t N) {
uint64_t m = (N + kChunkSize - 1) / kChunkSize; // div up
depth = m > 0
? static_cast<std::uint64_t>(ceil(log2(static_cast<double>(m))))
: 0;
welford_stk.assign(depth, Welford<T>());
}
};
template <typename T, uint64_t kChunkSize>
std::vector<typename T::value_type> WelfordHelper<T, kChunkSize>::weight_recps =
[]() {
using scalar_t = typename T::value_type;
std::vector<scalar_t> temp(kChunkSize);
for (const auto i : c10::irange(kChunkSize)) {
temp[i] = scalar_t(static_cast<double>(1) / static_cast<double>(i + 1));
}
return temp;
}();
template <typename T>
Welford<T> welford_combine(
const Welford<T>& a,
const Welford<T>& b,
bool use_index = false) {
if (a.index == 0) {
return b;
}
if (b.index == 0) {
return a;
}
auto delta = b.mean - a.mean;
auto a_weight = use_index ? T(a.index) : a.weight;
auto b_weight = use_index ? T(b.index) : b.weight;
auto new_weight = a_weight + b_weight;
auto new_index = a.index + b.index;
auto wb_over_w = b_weight / new_weight;
if constexpr (IsVecType<T>::value) {
// Guard against division by zero
wb_over_w = T::blendv(wb_over_w, T(0), new_weight == T(0));
}
auto result = Welford<T>{
a.mean + delta * wb_over_w,
a.m2 + b.m2 + delta * delta * a_weight * wb_over_w,
new_weight,
new_index};
return result;
}
template <typename T, uint64_t kChunkSize = 0>
Welford<T> welford_combine(
Welford<T>& acc,
T& data,
WelfordHelper<T, kChunkSize>* w = nullptr) {
// Combine welford reduction with cascade summation to improve numerical
// stability.
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
// https://en.wikipedia.org/wiki/Pairwise_summation
if constexpr (IsVecType<T>::value) {
if (w != nullptr && w->depth > 0 && acc.index == kChunkSize) {
w->welford_stk[0] = welford_combine(w->welford_stk[0], acc);
w->num_chunks += 1;
acc.mean = T(0);
acc.m2 = T(0);
acc.weight = T(0);
acc.index = 0;
uint64_t mask = w->num_chunks;
for (uint64_t j = 1; j < w->depth && (mask & 1) == 0; ++j) {
w->welford_stk[j] =
welford_combine(w->welford_stk[j], w->welford_stk[j - 1]);
w->welford_stk[j - 1] = Welford<T>();
mask >>= 1;
}
}
}
// Add a single data point
uint64_t new_index = acc.index + 1;
auto new_weight = acc.weight + T(1);
auto delta = data - acc.mean;
T new_mean;
if constexpr (!IsVecType<T>::value) {
new_mean = acc.mean + delta / new_weight;
} else {
// fetch the precomputed 1 / new_weight (weight_recps[acc.index]) to avoid a division
new_mean = acc.mean +
((w == nullptr || acc.index >= w->weight_recps.size())
? delta / new_weight
: delta * T(w->weight_recps[acc.index]));
}
auto new_delta = data - new_mean;
auto result =
Welford<T>{new_mean, acc.m2 + delta * new_delta, new_weight, new_index};
return result;
}
template <typename T, uint64_t kChunkSize = 0>
Welford<T> welford_combine(Welford<T>& acc, WelfordHelper<T, kChunkSize>* w) {
for (const auto i : c10::irange(w->depth)) {
acc = welford_combine(acc, w->welford_stk[i]);
}
return acc;
}
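// Illustrative sketch (not part of the original header): a scalar mean /
// variance reduction built from the helpers above. `example_welford` is a
// hypothetical name; vectorized kernels additionally thread a WelfordHelper
// through welford_combine and finish with welford_vec_reduce_all.
inline Welford<float> example_welford(const float* data, uint64_t N) {
  Welford<float> acc;
  for (uint64_t i = 0; i < N; ++i) {
    float x = data[i];
    acc = welford_combine(acc, x);
  }
  // acc.mean is the mean; acc.m2 / acc.weight is the (biased) variance.
  return acc;
}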
template <typename T>
struct IndexValue {
int64_t index{};
T value;
IndexValue(int64_t idx, T val) : index(idx), value(val) {}
IndexValue() = default;
};
#if INDUCTOR_USE_VECTOR_TYPES()
template <typename T, uint64_t kChunkSize>
Welford<T> welford_combine(
Welford<T>& acc,
T& data,
int64_t tail_size,
WelfordHelper<T, kChunkSize>* w = nullptr) {
auto out = welford_combine(acc, data, w);
return Welford<T>{
T::set(acc.mean, out.mean, tail_size),
T::set(acc.m2, out.m2, tail_size),
T::set(acc.weight, out.weight, tail_size),
out.index};
}
template <typename T, uint64_t kChunkSize = 0>
inline T cascade_sum_combine(
T& data,
int64_t tail_size,
CascadeSumHelper<T, kChunkSize>* c) {
auto out = c->sum_stk[0] + data;
c->sum_stk[0] = T::set(c->sum_stk[0], out, tail_size);
if (c->depth > 0) {
c->index++;
if (c->index == kChunkSize) {
c->num_chunks += 1;
c->index = 0;
uint64_t mask = c->num_chunks;
uint64_t j = 1;
for (; j < c->depth && (mask & 1) == 0; ++j) {
c->sum_stk[j] = c->sum_stk[j] + c->sum_stk[j - 1];
c->sum_stk[j - 1] = T(0);
mask >>= 1;
}
return c->sum_stk[j - 1];
}
}
return c->sum_stk[0];
}
template <typename T>
T max_masked_reduce(const T& a, const T& b, const int64_t tail_size) {
auto out = at::vec::maximum(a, b);
return T::set(a, out, tail_size);
}
template <typename T>
T min_masked_reduce(const T& a, const T& b, const int64_t tail_size) {
auto out = at::vec::minimum(a, b);
return T::set(a, out, tail_size);
}
template <typename T>
T sum_masked_reduce(const T& a, const T& b, const int64_t tail_size) {
auto out = a + b;
return T::set(a, out, tail_size);
}
template <typename T>
T prod_masked_reduce(const T& a, const T& b, const int64_t tail_size) {
auto out = a * b;
return T::set(a, out, tail_size);
}
template <typename T>
T xor_sum_masked_reduce(const T& a, const T& b, const int64_t tail_size) {
auto out = a ^ b;
return T::set(a, out, tail_size);
}
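// Illustrative sketch (not part of the original header): summing a buffer
// whose length is not a multiple of the vector width. `example_sum_with_tail`
// is a hypothetical name. Only the first `tail` lanes of the partial vector
// contribute; T::set keeps the accumulator's remaining lanes untouched, so
// anything past the tail is ignored.
inline float example_sum_with_tail(const float* data, int64_t N) {
  using Vec = at::vec::Vectorized<float>;
  constexpr int64_t kVLen = Vec::size();
  Vec acc(0.f);
  int64_t i = 0;
  for (; i + kVLen <= N; i += kVLen) {
    acc = acc + Vec::loadu(data + i);
  }
  if (const int64_t tail = N - i; tail > 0) {
    acc = sum_masked_reduce(acc, Vec::loadu(data + i, tail), tail);
  }
  // Horizontal sum of the accumulator lanes.
  __at_align__ float buf[Vec::size()];
  acc.store(buf);
  float total = 0.f;
  for (int64_t k = 0; k < kVLen; ++k) {
    total += buf[k];
  }
  return total;
}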
#endif
// Refer to
// https://github.com/pytorch/pytorch/blob/b5b36cf0c4e1958f1ff25120f5d4beeef3288187/
// aten/src/ATen/native/SharedReduceOps.h#L419-L445
template <typename scalar_t>
inline bool greater_or_nan(
scalar_t a,
scalar_t b,
int64_t idx_a,
int64_t idx_b) {
// If (a == b), then choose the one with lower idx, else max(a, b)
if (at::_isnan(a)) {
if (at::_isnan(b)) {
return idx_a < idx_b;
}
return true;
}
return (a == b) ? idx_a < idx_b : (a > b);
}
template <typename scalar_t>
inline bool less_or_nan(scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) {
// If (a == b), then choose the one with lower idx, else min(a, b)
if (at::_isnan(a)) {
if (at::_isnan(b)) {
return idx_a < idx_b;
}
return true;
}
return (a == b) ? idx_a < idx_b : (a < b);
}
template <typename T>
inline IndexValue<T>& argmin_combine(
IndexValue<T>& a,
T next_value,
int64_t next_index) {
if (!(less_or_nan(a.value, next_value, a.index, next_index))) {
a.value = next_value;
a.index = next_index;
}
return a;
}
template <typename T>
inline IndexValue<T>& argmax_combine(
IndexValue<T>& a,
T next_value,
int64_t next_index) {
if (!(greater_or_nan(a.value, next_value, a.index, next_index))) {
a.value = next_value;
a.index = next_index;
}
return a;
}
template <typename T>
inline IndexValue<T>& argmin_combine(
IndexValue<T>& a,
const IndexValue<T>& next) {
return argmin_combine(a, next.value, next.index);
}
template <typename T>
inline IndexValue<T>& argmax_combine(
IndexValue<T>& a,
const IndexValue<T>& next) {
return argmax_combine(a, next.value, next.index);
}
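// Illustrative sketch (not part of the original header): a scalar argmax loop
// built from the combine helpers above. `example_argmax` is a hypothetical
// name. Per greater_or_nan, ties keep the lower index and a NaN, once seen,
// takes precedence over later values.
inline IndexValue<float> example_argmax(const float* data, int64_t N) {
  IndexValue<float> best(0, data[0]);
  for (int64_t i = 1; i < N; ++i) {
    argmax_combine(best, data[i], i);
  }
  return best;
}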
#if INDUCTOR_USE_VECTOR_TYPES()
template <typename scalar_t>
inline at::vec::Vectorized<scalar_t> div_floor_floating_vec(
const at::vec::Vectorized<scalar_t>& a,
const at::vec::Vectorized<scalar_t>& b) {
using vec_t = at::vec::Vectorized<scalar_t>;
const auto basic_div = a / b;
vec_t inf(std::numeric_limits<scalar_t>::infinity());
auto mod = a.fmod(b);
// Fixup for a case that isn't properly handled by Sleef_fmod
auto floor =
vec_t::blendv(a - mod, a, (basic_div.abs() == inf) & (a.abs() != inf));
auto div = floor / b;
const auto zero = vec_t(0);
auto mask = (mod != zero) & ((b < zero) ^ (mod < zero));
const auto one = vec_t(1);
div = vec_t::blendv(div, div - one, mask);
auto floordiv = div.floor();
mask = (div - floordiv) > vec_t(0.5);
floordiv = vec_t::blendv(floordiv, floordiv + one, mask);
floordiv = vec_t::blendv(floordiv, zero.copysign(basic_div), div == zero);
floordiv = vec_t::blendv(floordiv, basic_div, b == zero);
return floordiv;
};
template <typename scalar_t, int N>
inline at::vec::VectorizedN<scalar_t, N> div_floor_floating_vec(
const at::vec::VectorizedN<scalar_t, N>& a,
const at::vec::VectorizedN<scalar_t, N>& b) {
at::vec::VectorizedN<scalar_t, N> result;
#ifndef _MSC_VER
#pragma unroll
#endif
for (int i = 0; i < N; ++i) {
result[i] = div_floor_floating_vec(a[i], b[i]);
}
return result;
}
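// Illustrative sketch (not part of the original header): div_floor_floating_vec
// implements Python-style floor division, e.g. 5.0 // 2.0 == 2.0 and
// 5.0 // -2.0 == -3.0 (the quotient is rounded toward negative infinity).
// `example_floor_div` is a hypothetical wrapper that only demonstrates the call.
inline at::vec::Vectorized<float> example_floor_div(
    const float* a,
    const float* b) {
  using Vec = at::vec::Vectorized<float>;
  return div_floor_floating_vec(Vec::loadu(a), Vec::loadu(b));
}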
template <typename T, int NV, int NI>
struct IndexValueVec {
at::vec::VectorizedN<T, NV> value;
at::vec::VectorizedN<int64_t, NI> index;
IndexValueVec(const T _value) {
value = at::vec::VectorizedN<T, NV>(_value);
index = at::vec::VectorizedN<int64_t, NI>(0);
};
IndexValueVec() {};
};
template <
typename T,
int NV,
int NI,
typename std::enable_if_t<at::vec::is_floating_point_v<T>, int> = 0>
at::vec::VecMask<int64_t, NI> inline get_mask_for_argmin_argmax(
const at::vec::VecMask<T, NV>& vmask,
const IndexValueVec<T, NV, NI>& a,
const at::vec::VectorizedN<T, NV>& value,
const at::vec::VectorizedN<int64_t, NI>& index) {
/*
Vectorized implementation of less_or_nan and greater_or_nan.
Example for argmin:
  a.value = [NaN, NaN, 0, 2, 1, 0]
  value   = [NaN, 0, 0, 1, 2, NaN]
  vmask   = [false, false, false, false, true, false]
  all_nan_or_equal = [true, false, true, false, false, false]
  imask   = [a.index[0] < index[0], ..., a.index[-1] < index[-1]]
  iv_mask = blendv(vmask, imask, all_nan_or_equal)
          = [a.index[0] < index[0], false, a.index[2] < index[2], false, true, false]
  a_nan_b_not = [false, true, false, false, false, false]
  mask = iv_mask | a_nan_b_not
       = [a.index[0] < index[0], true, a.index[2] < index[2], false, true, false]
*/
using v_t = at::vec::VecMask<T, NV>;
using i_t = at::vec::VecMask<int64_t, NI>;
i_t vmask_itype = vmask.template cast<int64_t, NI>();
// Use the integer mask type (itype) here since operator~ has a vectorized
// implementation for it, while the value mask type (vtype) may not.
v_t isnan_a = a.value.isnan();
i_t isnan_a_itype = isnan_a.template cast<int64_t, NI>();
v_t isnan_b = value.isnan();
i_t isnan_b_type = isnan_b.template cast<int64_t, NI>();
i_t all_nan_mask = isnan_a_itype & isnan_b_type;
v_t equal_mask = (a.value == value);
i_t equal_mask_itype = equal_mask.template cast<int64_t, NI>();
i_t all_nan_or_equal = all_nan_mask | equal_mask_itype;
i_t imask(a.index < index);
i_t iv_mask = i_t::blendv(vmask_itype, imask, all_nan_or_equal);
i_t isnan_a_notnan_b = isnan_a_itype & (~isnan_b_type);
return iv_mask | isnan_a_notnan_b;
}
template <
typename T,
int NV,
int NI,
typename std::enable_if_t<!at::vec::is_floating_point_v<T>, int> = 0>
at::vec::VecMask<int64_t, NI> inline get_mask_for_argmin_argmax(
const at::vec::VecMask<T, NV>& vmask,
const IndexValueVec<T, NV, NI>& a,
const at::vec::VectorizedN<T, NV>& value,
const at::vec::VectorizedN<int64_t, NI>& index) {
using v_t = at::vec::VecMask<T, NV>;
using i_t = at::vec::VecMask<int64_t, NI>;
i_t vmask_itype = vmask.template cast<int64_t, NI>();
v_t equal_mask = (a.value == value);
i_t equal_mask_itype = equal_mask.template cast<int64_t, NI>();
i_t imask(a.index < index);
return i_t::blendv(vmask_itype, imask, equal_mask_itype);
}
template <typename T, int NV, int NI>
inline IndexValueVec<T, NV, NI>& argmin_vec_impl(
IndexValueVec<T, NV, NI>& a,
at::vec::VectorizedN<T, NV> value,
at::vec::VectorizedN<int64_t, NI> index,
std::optional<int64_t> tail_size) {
at::vec::VecMask<T, NV> vmask(a.value < value);
at::vec::VecMask<int64_t, NI> final_mask =
get_mask_for_argmin_argmax<T, NV, NI>(vmask, a, value, index);
if (tail_size.has_value()) {
a.value = at::vec::VectorizedN<T, NV>::set(
a.value, at::vec::minimum(a.value, value), tail_size.value());
a.index = at::vec::VectorizedN<int64_t, NI>::set(
a.index,
at::vec::VecMask<int64_t, NI>::blendv(index, a.index, final_mask),
tail_size.value());
} else {
a.value = at::vec::minimum(a.value, value);
a.index = at::vec::VecMask<int64_t, NI>::blendv(index, a.index, final_mask);
}
return a;
}
template <typename T, int NV, int NI>
inline IndexValueVec<T, NV, NI>& argmax_vec_impl(
IndexValueVec<T, NV, NI>& a,
at::vec::VectorizedN<T, NV> value,
at::vec::VectorizedN<int64_t, NI> index,
std::optional<int64_t> tail_size) {
at::vec::VecMask<T, NV> vmask(a.value > value);
at::vec::VecMask<int64_t, NI> final_mask =
get_mask_for_argmin_argmax<T, NV, NI>(vmask, a, value, index);
if (tail_size.has_value()) {
a.value = at::vec::VectorizedN<T, NV>::set(
a.value, at::vec::maximum(a.value, value), tail_size.value());
a.index = at::vec::VectorizedN<int64_t, NI>::set(
a.index,
at::vec::VecMask<int64_t, NI>::blendv(index, a.index, final_mask),
tail_size.value());
} else {
a.value = at::vec::maximum(a.value, value);
a.index = at::vec::VecMask<int64_t, NI>::blendv(index, a.index, final_mask);
}
return a;
}
template <typename T, int NI, bool horizontal>
inline at::vec::VectorizedN<int64_t, NI> create_index(int64_t next_index) {
at::vec::VectorizedN<int64_t, NI> next_idx;
if constexpr (horizontal) {
next_idx = at::vec::VectorizedN<int64_t, NI>::arange(next_index, 1);
} else {
next_idx = at::vec::VectorizedN<int64_t, NI>(next_index);
}
return next_idx;
}
template <typename T, int NV, int NI, bool horizontal>
inline IndexValueVec<T, NV, NI>& argmin_combine_vec(
IndexValueVec<T, NV, NI>& a,
at::vec::VectorizedN<T, NV> next_value,
int64_t next_index,
std::optional<int64_t> tail_size = std::nullopt) {
auto next_idx = create_index<T, NI, horizontal>(next_index);
return argmin_vec_impl(a, next_value, next_idx, tail_size);
}
template <typename T, int NV, int NI, bool horizontal>
inline IndexValueVec<T, NV, NI>& argmax_combine_vec(
IndexValueVec<T, NV, NI>& a,
at::vec::VectorizedN<T, NV> next_value,
int64_t next_index,
std::optional<int64_t> tail_size = std::nullopt) {
auto next_idx = create_index<T, NI, horizontal>(next_index);
return argmax_vec_impl(a, next_value, next_idx, tail_size);
}
template <typename T, int NV, int NI>
inline IndexValue<T> argmin_vec_reduce_all(
const IndexValueVec<T, NV, NI>& vec) {
constexpr int len = at::vec::VectorizedN<T, NV>::size();
__at_align__ T tmpval[len];
__at_align__ int64_t tmpidx[len];
vec.value.store(tmpval);
vec.index.store(tmpidx);
IndexValue res = IndexValue<T>(tmpidx[0], tmpval[0]);
for (int i = 1; i < len; i++) {
res = argmin_combine(res, tmpval[i], tmpidx[i]);
}
return res;
}
template <typename T, int NV, int NI>
inline IndexValue<T> argmax_vec_reduce_all(
const IndexValueVec<T, NV, NI>& vec) {
constexpr int len = at::vec::VectorizedN<T, NV>::size();
__at_align__ T tmpval[len];
__at_align__ int64_t tmpidx[len];
vec.value.store(tmpval);
vec.index.store(tmpidx);
IndexValue res = IndexValue<T>(tmpidx[0], tmpval[0]);
for (int i = 1; i < len; i++) {
res = argmax_combine(res, tmpval[i], tmpidx[i]);
}
return res;
}
template <typename T, int NV, int NI>
inline IndexValueVec<T, NV, NI>& argmin_combine_vec(
IndexValueVec<T, NV, NI>& vec_a,
const IndexValueVec<T, NV, NI>& vec_b,
std::optional<int64_t> tail_size = std::nullopt) {
return argmin_vec_impl(vec_a, vec_b.value, vec_b.index, tail_size);
}
template <typename T, int NV, int NI>
inline IndexValueVec<T, NV, NI>& argmax_combine_vec(
IndexValueVec<T, NV, NI>& vec_a,
const IndexValueVec<T, NV, NI>& vec_b,
std::optional<int64_t> tail_size = std::nullopt) {
return argmax_vec_impl(vec_a, vec_b.value, vec_b.index, tail_size);
}
template <typename scalar_t>
inline at::vec::Vectorized<scalar_t> vec_shuffle_down(
at::vec::Vectorized<scalar_t> x,
size_t n) {
using Vec = at::vec::Vectorized<scalar_t>;
alignas(alignof(Vec)) scalar_t array[Vec::size()];
x.store(array);
for (size_t i = 0; i + n < Vec::size(); i += 2 * n) {
array[i] = array[i + n];
}
return Vec::loadu(array);
}
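// Illustrative note (not part of the original header): vec_shuffle_down(x, n)
// moves lane i+n into lane i at a stride of 2*n, which is enough for the
// log2(width) butterfly reduction used in welford_vec_reduce_all below. In the
// generic implementation above, for an 8-lane vector [a0..a7] and n = 1, lanes
// 0/2/4/6 become a1/a3/a5/a7 and the other lanes are left as-is; only lane 0
// needs to be meaningful once the reduction finishes.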
#ifdef CPU_CAPABILITY_AVX2
inline at::vec::Vectorized<float> vec_shuffle_down(
at::vec::Vectorized<float> x,
size_t n) {
using vec_t = at::vec::Vectorized<float>;
#define SHUFFLE_MASK(z, y, x, w) ((z << 6) | (y << 4) | (x << 2) | w)
switch (n) {
case 1:
return vec_t(_mm256_permute_ps(x, SHUFFLE_MASK(1, 1, 3, 3)));
case 2:
return vec_t(_mm256_permute_ps(x, SHUFFLE_MASK(2, 2, 2, 2)));
case 4:
return vec_t(_mm256_permute2f128_ps(x, x, SHUFFLE_MASK(1, 1, 1, 1)));
}
throw std::runtime_error(
"Unhandled vec_shuffle_down value " + std::to_string(n));
}
#endif
#ifdef CPU_CAPABILITY_AVX512
inline at::vec::Vectorized<float> vec_shuffle_down(
at::vec::Vectorized<float> x,
size_t n) {
using vec_t = at::vec::Vectorized<float>;
#define SHUFFLE_MASK(z, y, x, w) ((z << 6) | (y << 4) | (x << 2) | w)
switch (n) {
case 1:
return vec_t(_mm512_permute_ps(x, SHUFFLE_MASK(1, 1, 3, 3)));
case 2:
return vec_t(_mm512_permute_ps(x, SHUFFLE_MASK(2, 2, 2, 2)));
case 4:
return vec_t(_mm512_permutexvar_ps(
_mm512_set_epi32(
12, 12, 12, 12, 12, 12, 12, 12, 4, 4, 4, 4, 4, 4, 4, 4),
x));
case 8:
return vec_t(_mm512_permutexvar_ps(
_mm512_set_epi32(8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8), x));
}
throw std::runtime_error(
"Unhandled vec_shuffle_down value " + std::to_string(n));
}
#endif
template <typename scalar_t>
Welford<scalar_t> welford_vec_reduce_all(
Welford<at::vec::Vectorized<scalar_t>> acc) {
using Vec = at::vec::Vectorized<scalar_t>;
Welford<scalar_t> result;
if (acc.index == 0) {
return result;
}
// If every lane of acc.weight equals acc.index, reduce using the index to
// save the overhead of vec_shuffle_down on acc.weight.
bool use_index = (acc.weight - Vec(acc.index)).zero_mask() ==
static_cast<int>((1 << Vec::size()) - 1);
for (size_t n = 1; n < Vec::size(); n *= 2) {
auto shuffled = Welford<Vec>{
vec_shuffle_down(acc.mean, n),
vec_shuffle_down(acc.m2, n),
use_index ? Vec(0) : vec_shuffle_down(acc.weight, n),
acc.index};
acc = welford_combine(acc, shuffled, use_index);
}
alignas(alignof(Vec)) scalar_t array[Vec::size()];
acc.mean.store(array);
result.mean = array[0];
acc.m2.store(array);
result.m2 = array[0];
acc.weight.store(array);
result.weight = array[0];
result.index = result.weight;
return result;
}
template <typename scalar_t>
Welford<scalar_t> welford_vec_reduce_all(
Welford<at::vec::VectorizedN<scalar_t, 2>> acc) {
auto Welford0 = Welford<at::vec::Vectorized<scalar_t>>{
acc.mean[0], acc.m2[0], acc.weight[0], acc.index};
auto Welford1 = Welford<at::vec::Vectorized<scalar_t>>{
acc.mean[1], acc.m2[1], acc.weight[1], acc.index};
return welford_vec_reduce_all(welford_combine(Welford0, Welford1));
}
#endif
template <typename T, typename U>
inline typename std::common_type_t<T, U> mod(T a, U b) {
return a % b;
}
template <>
inline float mod(float a, float b) {
return std::fmod(a, b);
}
template <>
inline double mod(double a, double b) {
return std::fmod(a, b);
}
template <typename scalar_t>
inline scalar_t max_propagate_nan(scalar_t a, scalar_t b) {
if (at::_isnan(a)) {
return a;
}
return a > b ? a : b;
}
template <typename scalar_t>
inline scalar_t min_propagate_nan(scalar_t a, scalar_t b) {
if (at::_isnan(a)) {
return a;
}
return a < b ? a : b;
}
constexpr float uint32_to_uniform_float(uint32_t value) {
// maximum value such that `MAX_INT * scale < 1.0` (with float rounding)
constexpr float scale = 4.6566127342e-10;
return static_cast<float>(value & 0x7FFFFFFF) * scale;
}
inline float normalized_rand_cpu(uint32_t seed, uint32_t offset) {
return uint32_to_uniform_float(at::Philox4_32(seed, 0, offset)());
}
inline float randn_cpu(uint32_t seed, uint32_t offset) {
at::Philox4_32 engine(seed, 0, offset);
return engine.randn(10);
}
inline int64_t randint64_cpu(
uint32_t seed,
uint32_t offset,
int64_t low,
int64_t high) {
auto gen = at::Philox4_32(seed, 0, offset);
uint64_t r0 = gen();
uint64_t r1 = gen();
uint64_t result = r0 | (r1 << 32);
return static_cast<int64_t>(result % (high - low)) + low;
}
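// Illustrative note (not part of the original header): the RNG helpers above
// are keyed by (seed, offset) so each element of a generated kernel can draw
// its own value deterministically. `i` below stands for a hypothetical
// flattened element index:
//   float u = normalized_rand_cpu(seed, i);    // uniform in [0, 1)
//   float g = randn_cpu(seed, i);              // standard normal
//   int64_t r = randint64_cpu(seed, i, 0, 10); // integer in [0, 10)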
template <typename T>
struct AsIntegerType {
typedef T type;
};
template <>
struct AsIntegerType<float> {
typedef uint32_t type;
};
template <>
struct AsIntegerType<double> {
typedef uint64_t type;
};
template <>
struct AsIntegerType<at::BFloat16> {
typedef uint16_t type;
};
template <typename T>
typename std::enable_if_t<
!c10::is_reduced_floating_point_v<T>,
T> inline fetch_value(volatile T* addr) {
return *addr;
}
template <typename T>
typename std::enable_if_t<
c10::is_reduced_floating_point_v<T>,
T> inline fetch_value(volatile T* addr) {
return T(addr->x, T::from_bits());
}
template <typename T>
typename std::enable_if_t<!std::is_integral_v<T>> atomic_add(
volatile T* addr,
T offset) {
typedef typename AsIntegerType<T>::type alt_type;
static_assert(
sizeof(std::atomic<alt_type>) == sizeof(T), "std::atomic issue");
alt_type expected;
alt_type desired;
std::atomic<alt_type>* atomic_addr = (std::atomic<alt_type>*)addr;
do {
T val = fetch_value(addr);
reinterpret_cast<T*>(&expected)[0] = val;
reinterpret_cast<T*>(&desired)[0] = val + offset;
} while (!atomic_addr->compare_exchange_weak(
expected, desired, std::memory_order_relaxed));
}
// Since C++20, fetch_add supports float, but its performance may not be
// better than compare_exchange_weak; this can be checked with the
// microbenchmark inductor_cpu_atomic.py.
template <typename T>
typename std::enable_if_t<std::is_integral_v<T>> atomic_add(
volatile T* addr,
T offset) {
static_assert(sizeof(std::atomic<T>) == sizeof(T), "std::atomic issue");
std::atomic<T>* atomic_addr = (std::atomic<T>*)addr;
atomic_addr->fetch_add(offset, std::memory_order_relaxed);
}
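// Illustrative sketch (not part of the original header): a scatter-add over an
// OpenMP parallel loop, the pattern these helpers are written for.
// `example_scatter_add` is a hypothetical name. Reduced-precision types such
// as at::BFloat16 go through the compare_exchange_weak path above via
// fetch_value; integral types use fetch_add directly.
inline void example_scatter_add(
    float* out,
    const int64_t* index,
    const float* src,
    int64_t N) {
#pragma omp parallel for
  for (int64_t i = 0; i < N; ++i) {
    atomic_add(&out[index[i]], src[i]);
  }
}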
#if INDUCTOR_USE_VECTOR_TYPES()
template <typename T, int NI, int NV>
void atomic_add_vec(
T* addr,
at::vec::VectorizedN<int64_t, NI> index,
at::vec::VectorizedN<T, NV> offset) {
constexpr int len = at::vec::VectorizedN<int64_t, NI>::size();
static_assert(len <= at::vec::VectorizedN<T, NV>::size());
__at_align__ std::array<T, len> tmpbuf;
__at_align__ std::array<int64_t, len> tmpidx;
offset.store(tmpbuf.data(), len);
index.store(tmpidx.data(), len);
for (int i = 0; i < len; i++) {
atomic_add(addr + tmpidx[i], tmpbuf[i]);
}
}
template <typename T, bool atomic_add>
struct transpose_mxn_helper;
template <typename T>
struct transpose_mxn_helper<T, true> {
static void call(
const T* src,
int64_t ld_src,
T* dst,
int64_t ld_dst,
int M,
int N) {
for (int i = 0; i < M; i++) {
for (int j = 0; j < N; j++) {
atomic_add(&dst[j * ld_dst + i], src[i * ld_src + j]);
}
}
}
};
template <typename T>
struct transpose_mxn_helper<T, false> {
static void call(
const T* src,
int64_t ld_src,
T* dst,
int64_t ld_dst,
int M,
int N) {
at::vec::transpose_mxn<T>(src, ld_src, dst, ld_dst, M, N);
}
};
template <typename T, bool atomic_add>
inline void transpose_mxn(
const T* src,
int64_t ld_src,
T* dst,
int64_t ld_dst,
int M,
int N) {
transpose_mxn_helper<T, atomic_add>::call(src, ld_src, dst, ld_dst, M, N);
}
template <typename T, int M, int N, bool atomic_add>
inline void transpose_mxn(
const T* src,
int64_t ld_src,
T* dst,
int64_t ld_dst) {
transpose_mxn<T, atomic_add>(src, ld_src, dst, ld_dst, M, N);
}
#endif
// NOLINTBEGIN(*-avoid-c-arrays)
inline std::tuple<std::shared_ptr<int64_t[]>, int> _get_factors(
int64_t number) {
int count = 0;
for (auto i = static_cast<int64_t>(std::sqrt(number)); i > 0; --i) {
if (number % i == 0) {
count += 2;
}
}
auto factors = std::shared_ptr<int64_t[]>(new int64_t[count]);
int index = 0;
for (auto i = static_cast<int64_t>(std::sqrt(number)); i > 0; --i) {
if (number % i == 0) {
factors[index++] = number / i;
factors[index++] = i;
}
}
return std::make_tuple(factors, count);
}
inline std::tuple<std::shared_ptr<int64_t[]>, int> get_factors(int64_t number) {
thread_local std::map<int64_t, std::tuple<std::shared_ptr<int64_t[]>, int>>
cache;
auto it = cache.find(number);
if (it != cache.end()) {
return it->second;
} else {
auto factors = _get_factors(number);
cache[number] = factors;
return factors;
}
}
// NOLINTEND(*-avoid-c-arrays)
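// Illustrative note (not part of the original header): _get_factors walks i
// from floor(sqrt(number)) down to 1 and records the pair (number / i, i) for
// every divisor i, so get_factors(12) yields {4, 3, 6, 2, 12, 1} with count 6.
// The thread-local cache in get_factors avoids recomputing this for the same
// thread count on every GEMM call.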
inline void _mm_get_thread_blocking(
int num_threads,
int max_k_slices,
int64_t M,
int64_t N,
int64_t K,
int64_t Mr,
int64_t Nr,
int64_t Kr,
int64_t& Mt,
int64_t& Nt,
int64_t& Kt) {
// see NOTE [Thread blocking in Cpp GEMM] for heuristics
Mt = Nt = Kt = 0;
auto get_blocking = [](int64_t m_factor,
int64_t n_factor,
int64_t k_factor,
int64_t m_blocks,
int64_t n_blocks,
int64_t k_blocks) {
int64_t thread_block_k = (k_blocks + k_factor - 1) / k_factor;
int64_t thread_block_n = (n_blocks + n_factor - 1) / n_factor;
int64_t thread_block_m = (m_blocks + m_factor - 1) / m_factor;
return std::make_tuple(thread_block_m, thread_block_n, thread_block_k);
};
auto is_better_blocking = [=](int64_t Mt_,
int64_t Nt_,
int64_t Kt_,
int64_t Mt,
int64_t Nt,
int64_t Kt) {
return Mt == 0 || Kt_ < Kt || Mt_ * Mr + Nt_ * Nr < Mt * Mr + Nt * Nr;
};
int64_t m_blocks = (M + Mr - 1) / Mr;
int64_t n_blocks = (N + Nr - 1) / Nr;
int64_t k_blocks = (K + Kr - 1) / Kr;
auto [factors, count] = get_factors(num_threads);
assert(count > 0);
for (int i = 0; i < count; ++i) {
int64_t n_factor = factors[i];
int64_t m_factor = num_threads / n_factor;
if (n_blocks >= n_factor && m_blocks >= m_factor) {
auto [Mt_, Nt_, Kt_] =
get_blocking(m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks);
if (is_better_blocking(Mt_, Nt_, Kt_, Mt, Nt, Kt)) {
std::tie(Mt, Nt, Kt) = std::make_tuple(Mt_, Nt_, Kt_);
}
}
}
if (Mt != 0) {
return;
}
for (int i = 0; i < count; ++i) {
int64_t k_factor = factors[i];
if (k_blocks >= k_factor &&
(max_k_slices == 0 || k_factor <= max_k_slices)) {
auto [mxn_factors, mxn_count] = get_factors(num_threads / k_factor);
for (int j = 0; j < mxn_count; ++j) {
int64_t n_factor = mxn_factors[j];
int64_t m_factor = num_threads / (k_factor * n_factor);
if (n_blocks >= n_factor && m_blocks >= m_factor) {
auto [Mt_, Nt_, Kt_] = get_blocking(
m_factor, n_factor, k_factor, m_blocks, n_blocks, k_blocks);
if (is_better_blocking(Mt_, Nt_, Kt_, Mt, Nt, Kt)) {
std::tie(Mt, Nt, Kt) = std::make_tuple(Mt_, Nt_, Kt_);
}
}
}
}
}
if (Mt != 0) {
return;
}
for (int i = 0; i < count; ++i) {
int64_t n_factor = factors[i];
int64_t m_factor = num_threads / n_factor;
if (n_blocks >= n_factor || m_blocks >= m_factor) {
auto [Mt_, Nt_, Kt_] =
get_blocking(m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks);
if (is_better_blocking(Mt_, Nt_, Kt_, Mt, Nt, Kt)) {
std::tie(Mt, Nt, Kt) = std::make_tuple(Mt_, Nt_, Kt_);
}
}
}
assert(Mt != 0);
}
inline void mm_get_thread_blocking(
int num_threads,
int max_k_slices,
int64_t M,
int64_t N,
int64_t K,
int64_t Mr,
int64_t Nr,
int64_t Kr,
int64_t& Mt,
int64_t& Nt,
int64_t& Kt) {
thread_local std::map<
std::
tuple<int, int, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t>,
std::tuple<int64_t, int64_t, int64_t>>
cache;
auto key = std::make_tuple(num_threads, max_k_slices, M, N, K, Mr, Nr, Kr);
auto it = cache.find(key);
if (it != cache.end()) {
std::tie(Mt, Nt, Kt) = it->second;
return;
} else {
_mm_get_thread_blocking(
num_threads, max_k_slices, M, N, K, Mr, Nr, Kr, Mt, Nt, Kt);
cache[key] = std::make_tuple(Mt, Nt, Kt);
}
}
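// Illustrative sketch (not part of the original header): how a generated GEMM
// kernel might query the thread blocking. The variable names are hypothetical;
// Mr/Nr/Kr are the register-tile sizes of the micro-kernel, and passing 0 for
// max_k_slices leaves the number of K slices unconstrained.
//   int64_t Mt, Nt, Kt;
//   mm_get_thread_blocking(
//       omp_get_max_threads(), /*max_k_slices=*/0, M, N, K, Mr, Nr, Kr,
//       Mt, Nt, Kt);
//   // Mt/Nt/Kt are per-thread block counts measured in units of Mr/Nr/Kr.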
// NOLINTBEGIN(*-narrowing-conversions)
template <typename X_t, typename W_t>
void _mm_get_cache_blocking(
int num_threads,
int64_t M,
int64_t N,
int64_t K,
int64_t Mr,
int64_t Nr,
int64_t Kr,
int64_t Mt_blocks,
int64_t Nt_blocks,
int64_t Kt_blocks,
int64_t& Mc_blocks,
int64_t& Nc_blocks,
int64_t& Kc_blocks,
uint32_t L1_cache_size,
uint32_t L2_cache_size) {
// See NOTE [CPP GEMM Cache Blocking Algorithm] for the cache blocking
// algorithm.
// TODO(jgong5): cache the cache-blocking results
// TODO: tune the limit factors below
float L1_limit_factor = 0.8;
float L2_limit_factor = 0.5;
auto L1 = L1_cache_size * L1_limit_factor;
auto L2 = L2_cache_size * L2_limit_factor;
constexpr size_t num_byte_A = sizeof(X_t);
constexpr size_t num_byte_B = sizeof(W_t);
int64_t size_cache_B = Kr * Kt_blocks * Nr * num_byte_B;
Kc_blocks = Kt_blocks;
if (size_cache_B > L1) {
Kc_blocks = (int64_t)std::floor(L1 / (Kr * Nr * num_byte_B));
}
float min_Mc_ratio = 2;
int64_t min_Mc_blocks = std::ceil(min_Mc_ratio * Mr / Nr);
auto Kt_bytes = Kt_blocks * Kr * num_byte_A;
if (min_Mc_blocks * Mr * Kt_bytes < L2) {
Mc_blocks = std::min(Mt_blocks, (int64_t)std::floor(L2 / (Mr * Kt_bytes)));
Nc_blocks = 1;
} else {
Mc_blocks = Mt_blocks;
Nc_blocks =
std::min((int64_t)std::ceil((float)Mc_blocks * Mr / Nr), Nt_blocks);
auto Nc_bytes = Nc_blocks * Nr * 4;
auto Kc_bytes = Kc_blocks * Kr * num_byte_A;
if (Mc_blocks * Mr * (Kc_bytes + Nc_bytes) > L2) {
auto M_max = (std::sqrt(Kc_bytes * Kc_bytes + 16 * L2) - Kc_bytes) / 8;
if (M_max < Mc_blocks * Mr) {
Mc_blocks = (int64_t)std::floor(M_max / Mr);
Nc_blocks =
std::min((int64_t)std::ceil((float)Mc_blocks * Mr / Nr), Nt_blocks);
}
}
}
}
// NOLINTEND(*-narrowing-conversions)
template <typename X_t, typename W_t>
void mm_get_cache_blocking(
int num_threads,
int64_t M,
int64_t N,
int64_t K,
int64_t Mr,
int64_t Nr,
int64_t Kr,
int64_t Mt_blocks,
int64_t Nt_blocks,
int64_t Kt_blocks,
int64_t& Mc_blocks,
int64_t& Nc_blocks,
int64_t& Kc_blocks,
uint32_t L1_cache_size,
uint32_t L2_cache_size) {
thread_local std::map<
std::tuple<
int,
int64_t,
int64_t,
int64_t,
int64_t,
int64_t,
int64_t,
int64_t,
int64_t,
int64_t,
int64_t,
int64_t>,
std::tuple<int64_t, int64_t, int64_t>>
cache;
auto key = std::make_tuple(
num_threads,
M,
N,
K,
Mr,
Nr,
Kr,
Mt_blocks,
Nt_blocks,
Kt_blocks,
L1_cache_size,
L2_cache_size);
auto it = cache.find(key);
if (it != cache.end()) {
std::tie(Mc_blocks, Nc_blocks, Kc_blocks) = it->second;
return;
} else {
_mm_get_cache_blocking<X_t, W_t>(
num_threads,
M,
N,
K,
Mr,
Nr,
Kr,
Mt_blocks,
Nt_blocks,
Kt_blocks,
Mc_blocks,
Nc_blocks,
Kc_blocks,
L1_cache_size,
L2_cache_size);
cache[key] = std::make_tuple(Mc_blocks, Nc_blocks, Kc_blocks);
}
}
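// Illustrative sketch (not part of the original header): cache blocking is
// normally derived from the thread blocking computed above. X_t/W_t are the
// activation and weight element types; the variable names below are
// hypothetical.
//   int64_t Mc, Nc, Kc;
//   mm_get_cache_blocking<float, float>(
//       num_threads, M, N, K, Mr, Nr, Kr, Mt, Nt, Kt, Mc, Nc, Kc,
//       L1_cache_size, L2_cache_size);
//   // Mc/Nc/Kc are per-thread cache block counts (again in units of
//   // Mr/Nr/Kr) sized so the working set fits the L1/L2 budgets; the code
//   // above reserves 80% of L1 and 50% of L2.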
struct amx_tilecfg {
uint8_t palette_id{0};
uint8_t start_row{0};
std::array<uint8_t, 14> reserved_0{};
std::array<uint16_t, 16> colsb{};
std::array<uint8_t, 16> rows{};
};
class AMXState {
private:
amx_tilecfg tilecfg_{};
uint8_t rows_{0};
uint16_t colsb_{0};
uint8_t num_tile_rows_{0};
uint8_t num_tile_columns_{0};
public:
AMXState() = default;
inline void configure(
uint8_t rows,
uint16_t colsb,
uint8_t num_tile_rows,
uint8_t num_tile_columns,
void (*loadconfig)(const amx_tilecfg&)) {
if (tilecfg_.palette_id == 1 && rows_ == rows && colsb_ == colsb &&
num_tile_rows_ == num_tile_rows &&
num_tile_columns_ == num_tile_columns) {
return;
}
tilecfg_.palette_id = 1;
rows_ = rows;
colsb_ = colsb;
num_tile_rows_ = num_tile_rows;
num_tile_columns_ = num_tile_columns;
const auto num_c_tiles = num_tile_rows * num_tile_columns;
// For C
for (int i = 0; i < num_c_tiles; i++) {
tilecfg_.rows[i] = rows;
tilecfg_.colsb[i] = 64;
}
// For A
for (int i = 0; i < num_tile_rows; i++) {
tilecfg_.rows[i + num_c_tiles] = rows;
tilecfg_.colsb[i + num_c_tiles] = colsb;
}
// For B
for (int i = 0; i < num_tile_columns; i++) {
tilecfg_.rows[i + num_c_tiles + num_tile_rows] = colsb / 4;
tilecfg_.colsb[i + num_c_tiles + num_tile_rows] = 64;
}
loadconfig(tilecfg_);
}
inline void release(void (*tile_release)()) {
tilecfg_.palette_id = 0;
tile_release();
}
};
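// Illustrative sketch (not part of the original header): generated AMX
// micro-kernels keep one AMXState per thread and (re)configure it before the
// tile loop. `my_loadconfig` and `my_tile_release` are hypothetical thin
// wrappers around the _tile_loadconfig / _tile_release intrinsics supplied by
// the generated code; the tile shape below is just an example.
//   AMXState amx_state;
//   amx_state.configure(
//       /*rows=*/16, /*colsb=*/64,
//       /*num_tile_rows=*/2, /*num_tile_columns=*/2,
//       &my_loadconfig);
//   // ... run the _tile_dpbf16ps / _tile_dpbusd based micro-kernel ...
//   amx_state.release(&my_tile_release);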