Mirror of https://github.com/deepspeedai/DeepSpeed.git
Add riscv64 cpu support in deepspeed_shm_comm op (#7519)
This patch adds riscv64 support to the deepspeed_shm_comm operator, enabling DeepSpeed to perform CPU training/inference on RISC-V 64-bit hosts for research purposes. Based on the discussion in pull #7387, the patch also refactors the original code so that multiple CPU architectures can share a single implementation. The related tests pass on both x86 and RISC-V CPUs, and Qwen2.5 runs successfully on a RISC-V CPU.

```bash
(myenv) [root@openeuler-riscv64 DeepSpeed ]$ pytest tests/unit/comm/test_dist.py::TestDistInferenceAllReduce -vv
====================================================================== test session starts =======================================================================
platform linux -- Python 3.11.4, pytest-7.2.0, pluggy-1.6.0 -- /root/myenv/bin/python3
cachedir: .pytest_cache
hypothesis profile 'default'
rootdir: /root/ecosystem/DeepSpeed/tests, configfile: pytest.ini
plugins: mock-3.14.1, hypothesis-6.135.14, forked-1.6.0
collected 3 items

tests/unit/comm/test_dist.py::TestDistInferenceAllReduce::test[dtype0] PASSED [ 33%]
tests/unit/comm/test_dist.py::TestDistInferenceAllReduce::test[dtype1] PASSED [ 66%]
tests/unit/comm/test_dist.py::TestDistInferenceAllReduce::test[dtype2] PASSED [100%]

(myenv) root@ubuntu-2204:~/soft-working-dir/DeepSpeed# pytest tests/unit/comm/test_dist.py::TestDistInferenceAllReduce -vv
====================================================================== test session starts =======================================================================
platform linux -- Python 3.12.3, pytest-7.2.0, pluggy-1.6.0 -- /root/soft-working-dir/myenv/bin/python3
cachedir: .pytest_cache
rootdir: /root/soft-working-dir/DeepSpeed/tests, configfile: pytest.ini
plugins: forked-1.6.0
collected 3 items

tests/unit/comm/test_dist.py::TestDistInferenceAllReduce::test[dtype0] PASSED [ 33%]
tests/unit/comm/test_dist.py::TestDistInferenceAllReduce::test[dtype1] PASSED [ 66%]
tests/unit/comm/test_dist.py::TestDistInferenceAllReduce::test[dtype2] PASSED [100%]
```

---------

Signed-off-by: heyujiao99 <he.yujiao@sanechips.com.cn>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Co-authored-by: Ma, Guokai <guokai.ma@gmail.com>
csrc/cpu/comm/riscv64/shm.h (new file, 83 lines)
```cpp
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <riscv_vector.h>
#include <cmath>

using float16_t = _Float16;

inline vfloat32m2_t cvt_bf16_to_fp32(vuint16m1_t src, size_t vl) __attribute__((target("arch=+v")));
inline vfloat32m2_t cvt_bf16_to_fp32(vuint16m1_t src, size_t vl)
{
    vuint32m2_t widened = __riscv_vwcvtu_x_x_v_u32m2(src, vl);
    return __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vsll_vx_u32m2(widened, 16, vl));
}

inline vuint16m1_t cvt_fp32_to_bf16(vfloat32m2_t src, size_t vl) __attribute__((target("arch=+v")));
inline vuint16m1_t cvt_fp32_to_bf16(vfloat32m2_t src, size_t vl)
{
    vuint32m2_t value = __riscv_vreinterpret_v_f32m2_u32m2(src);
    vuint32m2_t nan = __riscv_vmv_v_x_u32m2(0xFFFF, vl);
    vbool16_t mask_value = __riscv_vmfne_vv_f32m2_b16(src, src, vl);
    vuint32m2_t ones = __riscv_vmv_v_x_u32m2(0x1, vl);
    vuint32m2_t vec_bias = __riscv_vmv_v_x_u32m2(0x7FFF, vl);
    // uint32_t lsb = (input >> 16) & 1;
    vuint32m2_t t_value = __riscv_vand_vx_u32m2(__riscv_vsrl_vx_u32m2(value, 16, vl), 0x1, vl);
    // uint32_t rounding_bias = 0x7fff + lsb;
    t_value = __riscv_vadd_vv_u32m2(t_value, vec_bias, vl);
    // input += rounding_bias;
    t_value = __riscv_vadd_vv_u32m2(t_value, value, vl);
    // input = input >> 16;
    t_value = __riscv_vsrl_vx_u32m2(t_value, 16, vl);
    // Check NaN before converting back to bf16
    t_value = __riscv_vmerge_vvm_u32m2(t_value, nan, mask_value, vl);

    return __riscv_vncvt_x_x_w_u16m1(t_value, vl);
}

inline vfloat32m2_t cvt_fp16_to_fp32(vfloat16m1_t src, size_t vl)
    __attribute__((target("arch=+v,+zvfh")));
inline vfloat32m2_t cvt_fp16_to_fp32(vfloat16m1_t src, size_t vl)
{
    return __riscv_vfwcvt_f_f_v_f32m2(src, vl);
}

inline vfloat16m1_t cvt_fp32_to_fp16(vfloat32m2_t src, size_t vl)
    __attribute__((target("arch=+v,+zvfh")));
inline vfloat16m1_t cvt_fp32_to_fp16(vfloat32m2_t src, size_t vl)
{
    return __riscv_vfncvt_rod_f_f_w_f16m1(src, vl);
}

// The reduce functions below use a vectorized algorithm; the number of bytes processed each
// iteration depends on the vector length, which is acquired dynamically via the vsetvl
// instruction so the code stays compatible with different vector lengths.
static int vector_length_in_bytes = -1;

void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
    __attribute__((target("arch=+v")));
void reduce_fp16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
    __attribute__((target("arch=+v,+zvfh")));
void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
    __attribute__((target("arch=+v")));

void parallel_memcpy(void* to, void* from, size_t n_bytes) __attribute__((target("arch=+v")));

#define VLOAD_U8(X) __riscv_vle8_v_u8m1((uint8_t*)(X), vl)
#define VLOAD_U16(X) __riscv_vle16_v_u16m1((uint16_t*)(X), vl)
#define VLOAD_F16(X) __riscv_vle16_v_f16m1((float16_t*)(X), vl)
#define VLOAD_F32(X) __riscv_vle32_v_f32m1((float*)(X), vl)

#define VSTORE_U8(A, B) __riscv_vse8_v_u8m1((uint8_t*)(A), B, vl)
#define VSTORE_U16(A, B) __riscv_vse16_v_u16m1((uint16_t*)(A), B, vl)
#define VSTORE_F16(A, B) __riscv_vse16_v_f16m1((float16_t*)(A), B, vl)
#define VSTORE_F32(A, B) __riscv_vse32_v_f32m1((float*)(A), B, vl)

#define VADD_F32(A, B) __riscv_vfadd_vv_f32m1(A, B, vl)
#define VADD_F32_2VL(A, B) __riscv_vfadd_vv_f32m2(A, B, vl)

#define CVT_BF16_TO_FP32(X) cvt_bf16_to_fp32(X, vl)
#define CVT_FP16_TO_FP32(X) cvt_fp16_to_fp32(X, vl)
#define CVT_FP32_TO_BF16(X) cvt_fp32_to_bf16(X, vl)
#define CVT_FP32_TO_FP16(X) cvt_fp32_to_fp16(X, vl)
```
csrc/cpu/comm/shm.cpp (modified)

```diff
@@ -7,11 +7,17 @@
 #include <ATen/ATen.h>
 #include <fcntl.h>
-#include <immintrin.h>
 #include <semaphore.h>
 #include <sys/mman.h>
 #include "shm.h"
 
+#if defined(__riscv)
+#define TARGET_RISCV 1
+#include "riscv64/shm.h"
+#else
+#include "x86_64/shm.h"
+#endif
+
 // #define DO_PROFILE
 #ifdef DO_PROFILE
 #include <cfloat>
```
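The hunk above is the core of the refactor mentioned in the commit message: shm.cpp keeps a single architecture-neutral implementation written against the `VLOAD_*`/`VSTORE_*`/`VADD_*`/`CVT_*` macros, and the header selected at compile time binds those macros to AVX-512 or RVV intrinsics. The remaining hunks below continue in the same file and simply swap each raw intrinsic call for the corresponding macro. The sketch below shows the shape of that pattern with scalar stand-ins so it compiles anywhere; it is illustrative only, not the DeepSpeed code.

```cpp
// Illustrative only: the per-architecture macro dispatch pattern, using scalar
// stand-ins for the vector types so this sketch is self-contained.
#include <cstddef>

#if defined(__riscv)
// In the real code: #include "riscv64/shm.h"  (macros expand to RVV intrinsics)
#else
// In the real code: #include "x86_64/shm.h"   (macros expand to AVX-512 intrinsics)
#endif

// Scalar stand-ins playing the role of the per-arch macro bindings:
#define VLOAD_F32(p) (*(const float*)(p))
#define VADD_F32(a, b) ((a) + (b))
#define VSTORE_F32(p, v) (*(float*)(p) = (v))

// The architecture-neutral "kernel" only ever sees the macros.
void add_fp32_pair(char* to, const char* a, const char* b, size_t n_bytes)
{
    for (size_t i = 0; i < n_bytes; i += sizeof(float)) {
        float v = VADD_F32(VLOAD_F32(a + i), VLOAD_F32(b + i));
        VSTORE_F32(to + i, v);
    }
}
```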
```diff
@@ -115,52 +121,6 @@ void wait_buffer_state_until_2(int index,
     }
 }
 
-__m512 cvt_bf16_to_fp32(const __m256i src) __attribute__((target("avx512bw")));
-inline __m512 cvt_bf16_to_fp32(const __m256i src)
-{
-    auto y = _mm512_cvtepu16_epi32(src);
-    return _mm512_castsi512_ps(_mm512_bslli_epi128(y, 2));
-}
-
-inline __m256i cvt_fp32_to_bf16(const __m512 src) __attribute__((target("avx512bw")));
-inline __m256i cvt_fp32_to_bf16(const __m512 src)
-{
-    __m512i value = _mm512_castps_si512(src);
-    __m512i nan = _mm512_set1_epi32(0xffff);
-    auto mask_value = _mm512_cmp_ps_mask(src, src, _CMP_ORD_Q);
-    __m512i ones = _mm512_set1_epi32(0x1);
-    __m512i vec_bias = _mm512_set1_epi32(0x7fff);
-    // uint32_t lsb = (input >> 16) & 1;
-    auto t_value = _mm512_and_si512(_mm512_srli_epi32(value, 16), ones);
-    // uint32_t rounding_bias = 0x7fff + lsb;
-    t_value = _mm512_add_epi32(t_value, vec_bias);
-    // input += rounding_bias;
-    t_value = _mm512_add_epi32(t_value, value);
-    // input = input >> 16;
-    t_value = _mm512_srli_epi32(t_value, 16);
-    // Check NaN before converting back to bf16
-    t_value = _mm512_mask_blend_epi32(mask_value, nan, t_value);
-    return _mm512_cvtusepi32_epi16(t_value);
-}
-
-__m512 cvt_fp16_to_fp32(const __m256i src) __attribute__((target("avx512bw")));
-inline __m512 cvt_fp16_to_fp32(const __m256i src) { return _mm512_cvtph_ps(src); }
-
-inline __m256i cvt_fp32_to_fp16(const __m512 src) __attribute__((target("avx512bw")));
-inline __m256i cvt_fp32_to_fp16(const __m512 src)
-{
-    return _mm512_cvtps_ph(src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
-}
-
-void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
-    __attribute__((target("avx512bw")));
-
-void reduce_fp16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
-    __attribute__((target("avx512bw")));
-
-void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
-    __attribute__((target("avx512bw")));
-
 void reduce_all_buffers(int start_elements,
                         int num_elements,
                         c10::ScalarType scalar_type,
```
```diff
@@ -182,30 +142,29 @@ void reduce_all_buffers(int start_elements,
         }
     }
 }
 
 #define CVT_ADD_BF16(x)                                                  \
     do {                                                                 \
-        auto in##x##_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[x] + i))); \
-        inout_val = _mm512_add_ps(inout_val, in##x##_val);               \
+        auto in##x##_val = CVT_BF16_TO_FP32(VLOAD_U16(buffers[x] + i));  \
+        inout_val = VADD_F32_2VL(inout_val, in##x##_val);                \
     } while (0)
 
-// Reduce functions down below use vectorized algorithm, the number of bytes processed each
-// iteration depends on vector length. 256bit vector ==> 32 bytes, 512bit vector ==> 64 bytes
-// If you change implementation of reduce_bf16_buffers, etc. , check whether this number needs
-// to be changed
-#define VECTOR_LENGTH_IN_BYTES 32
-
 void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
 {
     const int element_size = 2;
-    const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size;
-    int main_elements = num_elements - (num_elements % vector_length);
-    int remain_elements = num_elements % vector_length;
+#if TARGET_RISCV
+    size_t vl = __riscv_vsetvl_e16m1(num_elements);
+    vector_length_in_bytes = vl * element_size;
+#else
+    const int vl = vector_length_in_bytes / element_size;
+#endif
+    int main_elements = num_elements - (num_elements % vl);
+    int remain_elements = num_elements % vl;
 
     // process aligned part
 #pragma omp parallel for
     for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size;
-         i += VECTOR_LENGTH_IN_BYTES) {
-        auto inout_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[0] + i)));
+         i += vector_length_in_bytes) {
+        auto inout_val = CVT_BF16_TO_FP32(VLOAD_U16(buffers[0] + i));
         switch (world_size) {
             case 16: CVT_ADD_BF16(15);
             case 15: CVT_ADD_BF16(14);
@@ -225,11 +184,11 @@ void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer,
             case 1: break;
             default:
                 for (int j = 1; j < world_size; j++) {
-                    auto in_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[j] + i)));
-                    inout_val = _mm512_add_ps(inout_val, in_val);
+                    auto in_val = CVT_BF16_TO_FP32(VLOAD_U16(buffers[j] + i));
+                    inout_val = VADD_F32_2VL(inout_val, in_val);
                 }
         }
-        _mm256_storeu_si256((__m256i*)(to_buffer + i), cvt_fp32_to_bf16(inout_val));
+        VSTORE_U16(to_buffer + i, CVT_FP32_TO_BF16(inout_val));
     }
 
     // process remaining part
```
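A note on the RISC-V branch above, which the fp16 and fp32 hunks below repeat with their own element sizes: where the x86 path keeps a fixed stride of `vector_length_in_bytes = 32`, the RVV path asks the hardware how many elements fit in one vector register via `__riscv_vsetvl_e16m1()` and derives the loop stride from that, so the same binary adapts to any VLEN. The standalone sketch below (illustrative only, assumes a RISC-V toolchain with the V intrinsics) just prints the quantities the kernel derives.

```cpp
// Illustrative only: how the RVV reduce path derives its stride (assumes a
// RISC-V build with <riscv_vector.h>; not part of the patch).
#include <riscv_vector.h>
#include <cstdio>

__attribute__((target("arch=+v")))
static size_t pick_vl(size_t num_elements)
{
    // min(num_elements, VLMAX) for 16-bit elements at LMUL=1
    return __riscv_vsetvl_e16m1(num_elements);
}

int main()
{
    size_t num_elements = 1000;  // hypothetical element count
    size_t vl = pick_vl(num_elements);
    size_t stride_bytes = vl * 2;  // element_size == 2 for bf16/fp16
    size_t main_elements = num_elements - (num_elements % vl);
    std::printf("vl=%zu stride_bytes=%zu main=%zu tail=%zu\n",
                vl, stride_bytes, main_elements, num_elements % vl);
}
```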
```diff
@@ -243,24 +202,29 @@ void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer,
     }
 }
 
 #define CVT_ADD_FP16(x)                                                  \
     do {                                                                 \
-        auto in##x##_val = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[x] + i))); \
-        inout_val = _mm512_add_ps(inout_val, in##x##_val);               \
+        auto in##x##_val = CVT_FP16_TO_FP32(VLOAD_F16(buffers[x] + i));  \
+        inout_val = VADD_F32_2VL(inout_val, in##x##_val);                \
     } while (0)
 
 void reduce_fp16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
 {
     const int element_size = 2;
-    const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size;
-    int main_elements = num_elements - (num_elements % vector_length);
-    int remain_elements = num_elements % vector_length;
+#if TARGET_RISCV
+    size_t vl = __riscv_vsetvl_e16m1(num_elements);
+    vector_length_in_bytes = vl * element_size;
+#else
+    const int vl = vector_length_in_bytes / element_size;
+#endif
+    int main_elements = num_elements - (num_elements % vl);
+    int remain_elements = num_elements % vl;
 
     // process aligned part
 #pragma omp parallel for
     for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size;
-         i += VECTOR_LENGTH_IN_BYTES) {
-        auto inout_val = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[0] + i)));
+         i += vector_length_in_bytes) {
+        auto inout_val = CVT_FP16_TO_FP32(VLOAD_F16(buffers[0] + i));
         switch (world_size) {
             case 16: CVT_ADD_FP16(15);
             case 15: CVT_ADD_FP16(14);
@@ -280,11 +244,11 @@ void reduce_fp16_buffers(int start_elements, int num_elements, char* to_buffer,
             case 1: break;
             default:
                 for (int j = 1; j < world_size; j++) {
-                    auto in_val = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[j] + i)));
-                    inout_val = _mm512_add_ps(inout_val, in_val);
+                    auto in_val = CVT_FP16_TO_FP32(VLOAD_F16(buffers[j] + i));
+                    inout_val = VADD_F32_2VL(inout_val, in_val);
                 }
         }
-        _mm256_storeu_si256((__m256i*)(to_buffer + i), cvt_fp32_to_fp16(inout_val));
+        VSTORE_F16(to_buffer + i, CVT_FP32_TO_FP16(inout_val));
     }
 
     // process remaining part
```
```diff
@@ -298,24 +262,29 @@ void reduce_fp16_buffers(int start_elements, int num_elements, char* to_buffer,
     }
 }
 
 #define CVT_ADD_F32(x)                                          \
     do {                                                        \
-        auto in##x##_val = _mm256_loadu_ps((float*)(buffers[x] + i)); \
-        inout_val = _mm256_add_ps(inout_val, in##x##_val);      \
+        auto in##x##_val = VLOAD_F32(buffers[x] + i);           \
+        inout_val = VADD_F32(inout_val, in##x##_val);           \
     } while (0)
 
 void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
 {
     const int element_size = 4;
-    const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size;
-    int main_elements = num_elements - (num_elements % vector_length);
-    int remain_elements = num_elements % vector_length;
+#if TARGET_RISCV
+    size_t vl = __riscv_vsetvl_e32m1(num_elements);
+    vector_length_in_bytes = vl * element_size;
+#else
+    const int vl = vector_length_in_bytes / element_size;
+#endif
+    int main_elements = num_elements - (num_elements % vl);
+    int remain_elements = num_elements % vl;
 
     // process aligned part
 #pragma omp parallel for
     for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size;
-         i += VECTOR_LENGTH_IN_BYTES) {
-        auto inout_val = _mm256_loadu_ps((float*)(buffers[0] + i));
+         i += vector_length_in_bytes) {
+        auto inout_val = VLOAD_F32(buffers[0] + i);
         switch (world_size) {
             case 16: CVT_ADD_F32(15);
             case 15: CVT_ADD_F32(14);
@@ -335,11 +304,11 @@ void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer,
             case 1: break;
             default:
                 for (int j = 1; j < world_size; j++) {
-                    auto in_val = _mm256_loadu_ps((float*)(buffers[j] + i));
-                    inout_val = _mm256_add_ps(inout_val, in_val);
+                    auto in_val = VLOAD_F32(buffers[j] + i);
+                    inout_val = VADD_F32(inout_val, in_val);
                 }
         }
-        _mm256_storeu_ps((float*)(to_buffer + i), inout_val);
+        VSTORE_F32(to_buffer + i, inout_val);
     }
 
     // process remaining part
```
```diff
@@ -412,16 +381,18 @@ void shm_initialize(int size, int rank, char* addr_string, char* port_string)
     }
 }
 
-static void parallel_memcpy(void* to, void* from, size_t n_bytes)
-    __attribute__((target("avx512bw")));
-static void parallel_memcpy(void* to, void* from, size_t n_bytes)
+void parallel_memcpy(void* to, void* from, size_t n_bytes)
 {
-    auto aligned_bytes = n_bytes - (n_bytes % VECTOR_LENGTH_IN_BYTES);
+#if TARGET_RISCV
+    size_t vl = __riscv_vsetvl_e8m1(n_bytes);
+    vector_length_in_bytes = vl;
+#endif
+    auto aligned_bytes = n_bytes - (n_bytes % vector_length_in_bytes);
     // process aligned part
 #pragma omp parallel for
-    for (int i = 0; i < aligned_bytes; i += VECTOR_LENGTH_IN_BYTES) {
-        auto val = _mm256_loadu_si256((__m256i*)((char*)from + i));
-        _mm256_storeu_si256((__m256i*)((char*)to + i), val);
+    for (int i = 0; i < aligned_bytes; i += vector_length_in_bytes) {
+        auto val = VLOAD_U8((char*)from + i);
+        VSTORE_U8((char*)to + i, val);
     }
 
     // process remaining part
```
csrc/cpu/comm/shm.h (modified)

```diff
@@ -5,7 +5,6 @@
 
 #ifndef __SHM_COLLECTIVES__
 #define __SHM_COLLECTIVES__
-#define VECTOR_LENGTH_IN_BYTES 32
 void shm_initialize(int size, int rank, char* addr_string, char* port_string);
 void all_reduce_outer_loop(torch::Tensor& data, size_t numel, int data_size);
 void barrier_wait(int root_idx, int num_ranks);
```
csrc/cpu/comm/x86_64/shm.h (new file, 76 lines)

```cpp
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <immintrin.h>

inline __m512 cvt_bf16_to_fp32(const __m256i src) __attribute__((target("avx512bw")));
inline __m512 cvt_bf16_to_fp32(const __m256i src)
{
    auto y = _mm512_cvtepu16_epi32(src);
    return _mm512_castsi512_ps(_mm512_bslli_epi128(y, 2));
}

inline __m256i cvt_fp32_to_bf16(const __m512 src) __attribute__((target("avx512bw")));
inline __m256i cvt_fp32_to_bf16(const __m512 src)
{
    __m512i value = _mm512_castps_si512(src);
    __m512i nan = _mm512_set1_epi32(0xffff);
    auto mask_value = _mm512_cmp_ps_mask(src, src, _CMP_ORD_Q);
    __m512i ones = _mm512_set1_epi32(0x1);
    __m512i vec_bias = _mm512_set1_epi32(0x7fff);
    // uint32_t lsb = (input >> 16) & 1;
    auto t_value = _mm512_and_si512(_mm512_srli_epi32(value, 16), ones);
    // uint32_t rounding_bias = 0x7fff + lsb;
    t_value = _mm512_add_epi32(t_value, vec_bias);
    // input += rounding_bias;
    t_value = _mm512_add_epi32(t_value, value);
    // input = input >> 16;
    t_value = _mm512_srli_epi32(t_value, 16);
    // Check NaN before converting back to bf16
    t_value = _mm512_mask_blend_epi32(mask_value, nan, t_value);
    return _mm512_cvtusepi32_epi16(t_value);
}

inline __m512 cvt_fp16_to_fp32(const __m256i src) __attribute__((target("avx512bw")));
inline __m512 cvt_fp16_to_fp32(const __m256i src) { return _mm512_cvtph_ps(src); }

inline __m256i cvt_fp32_to_fp16(const __m512 src) __attribute__((target("avx512bw")));
inline __m256i cvt_fp32_to_fp16(const __m512 src)
{
    return _mm512_cvtps_ph(src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
}

// Reduce functions down below use vectorized algorithm, the number of bytes processed each
// iteration depends on vector length. 256bit vector ==> 32 bytes, 512bit vector ==> 64 bytes
// If you change implementation of reduce_bf16_buffers, etc. , check whether this number needs
// to be changed
static int vector_length_in_bytes = 32;

void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
    __attribute__((target("avx512bw")));
void reduce_fp16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
    __attribute__((target("avx512bw")));
void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
    __attribute__((target("avx512bw")));

void parallel_memcpy(void* to, void* from, size_t n_bytes) __attribute__((target("avx512bw")));

#define VLOAD_U8(X) _mm256_loadu_si256((__m256i*)(X))
#define VLOAD_U16(X) _mm256_loadu_si256((__m256i*)(X))
#define VLOAD_F16(X) _mm256_loadu_si256((__m256i*)(X))
#define VLOAD_F32(X) _mm256_loadu_ps((float*)(X))

#define VSTORE_U8(A, B) _mm256_storeu_si256((__m256i*)(A), B)
#define VSTORE_U16(A, B) _mm256_storeu_si256((__m256i*)(A), B)
#define VSTORE_F16(A, B) _mm256_storeu_si256((__m256i*)(A), B)
#define VSTORE_F32(A, B) _mm256_storeu_ps((float*)(A), B)

#define VADD_F32(A, B) _mm256_add_ps(A, B)
#define VADD_F32_2VL(A, B) _mm512_add_ps(A, B)

#define CVT_BF16_TO_FP32(X) cvt_bf16_to_fp32(X)
#define CVT_FP16_TO_FP32(X) cvt_fp16_to_fp32(X)
#define CVT_FP32_TO_BF16(X) cvt_fp32_to_bf16(X)
#define CVT_FP32_TO_FP16(X) cvt_fp32_to_fp16(X)
```
op_builder/builder.py (modified)

```diff
@@ -420,6 +420,8 @@ class OpBuilder(ABC):
         if cpu_info['arch'].startswith('PPC_'):
             # gcc does not provide -march on PowerPC, use -mcpu instead
             return '-mcpu=native'
+        elif cpu_info['arch'].startswith('riscv64'):
+            return '-march=rv64gc'
         return '-march=native'
 
     def get_cuda_compile_flag(self):
@@ -455,6 +457,8 @@ class OpBuilder(ABC):
             cpu_info['flags'] += 'avx2'
         elif 'ppc64le' in result:
             cpu_info['arch'] = "PPC_"
+        elif 'riscv64' in result:
+            cpu_info['arch'] = "riscv64"
 
         return cpu_info
 
```
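The builder change above passes the plain `-march=rv64gc` baseline rather than an arch string that includes the V extension; the new riscv64 header then enables vector code per function via `__attribute__((target("arch=+v")))`, so the rest of the extension still builds for generic RV64GC hardware. A minimal sketch of that split, assuming a recent GCC/Clang RISC-V toolchain that supports the `target` attribute (illustrative only, not part of the patch):

```cpp
// Illustrative only: compile this file with -march=rv64gc (no V extension), and
// let the vector routine opt in via the target attribute, as the new headers do.
#include <riscv_vector.h>
#include <cstdint>
#include <cstdio>

__attribute__((target("arch=+v")))
static size_t vlen_bytes()
{
    // vsetvl with a very large request returns VLMAX for e8/m1, i.e. VLEN/8 bytes.
    return __riscv_vsetvl_e8m1(SIZE_MAX);
}

int main() { std::printf("vector register width: %zu bytes\n", vlen_bytes()); }
```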