Compare commits

...

1 Commit

3 changed files with 137 additions and 0 deletions

gelu.cpp

@@ -0,0 +1,30 @@
#include <torch/extension.h>

#include <cuda.h>
#include <cuda_fp16.h>

// Forward declaration of the CUDA launcher implemented in gelu.cu.
template <typename T>
void launch_bias_gelu(T* input,
                      const T* bias,
                      int intermediate_size,
                      int batch_size);
template <typename T>
at::Tensor ds_bias_gelu(at::Tensor& input, at::Tensor& bias)
{
    auto input_cont = input.contiguous();

    // Flatten [batch, seq_len, intermediate] into rows of intermediate_size;
    // the kernel adds the bias and applies GELU in place on input_cont.
    int bsz = input_cont.size(0) * input_cont.size(1);
    int intermediate_size = input_cont.size(2);

    launch_bias_gelu((T*)input_cont.data_ptr(),
                     (T*)bias.data_ptr(),
                     intermediate_size,
                     bsz);
    return input_cont;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("bias_gelu_fp32", &ds_bias_gelu<float>, "DeepSpeed Gelu with fp32 (CUDA)");
    m.def("bias_gelu_fp16", &ds_bias_gelu<__half>, "DeepSpeed Gelu with fp16 (CUDA)");
}
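
The extension above mutates the contiguous input in place and returns it. A minimal usage sketch, assuming the files are JIT-compiled with torch.utils.cpp_extension.load as the model change below does (paths, shapes, and the out name are illustrative, not part of the commit):

import torch
from torch.utils.cpp_extension import load

# Hypothetical paths; the commit keeps gelu.cpp and gelu.cu next to the model file.
cuda_gelu = load(name="bias_gelu_fp16", sources=["gelu.cpp", "gelu.cu"])

x = torch.randn(8, 128, 4096, device="cuda", dtype=torch.float16)  # [batch, seq, intermediate]
bias = torch.randn(4096, device="cuda", dtype=torch.float16)

out = cuda_gelu.bias_gelu_fp16(x, bias)  # computed in place on the contiguous input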

gelu.cu

@@ -0,0 +1,100 @@
#include <cuda.h>
#include <cuda_fp16.h>

#define MAX_CAP 4
#define MAX_SEQ 2048

// The __half kernel below is guarded by this macro; define it for
// architectures with half-precision arithmetic (sm_53 and newer).
// Without it, the fp16 kernel body compiles to a no-op.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
#define HALF_PRECISION_AVAILABLE
#endif
// Tanh approximation of GELU:
// gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
inline __device__ float gelu(const float x)
{
    const float sqrt_param = 0.79788456080286535587989211986876f;
    const float mul_param = 0.044715f;
    return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x)));
}
// fp32 kernel: each thread loads one float4 (four activations), adds the
// broadcast bias, and applies GELU elementwise. total_count and
// intermediate_size are given in float4 units here.
__global__ void fused_bias_gelu(float* input,
                                const float* bias,
                                int total_count,
                                int intermediate_size)
{
    float4* input_cast = reinterpret_cast<float4*>(input);
    const float4* bias_cast = reinterpret_cast<const float4*>(bias);
    int offset = blockIdx.x * blockDim.x + threadIdx.x;

    if (offset < total_count) {
        float4 data = input_cast[offset];
        // offset % intermediate_size selects this thread's chunk of the bias row.
        float4 bias_data = bias_cast[offset % intermediate_size];
        data.x += bias_data.x;
        data.y += bias_data.y;
        data.z += bias_data.z;
        data.w += bias_data.w;

        data.x = gelu(data.x);
        data.y = gelu(data.y);
        data.z = gelu(data.z);
        data.w = gelu(data.w);

        input_cast[offset] = data;
    }
}
// fp16 kernel: each thread loads one float2 (four __half values), widens
// them to fp32 for the bias-add and GELU, then rounds back to half.
__global__ void fused_bias_gelu(__half* input,
                                const __half* bias,
                                int total_count,
                                int intermediate_size)
{
#ifdef HALF_PRECISION_AVAILABLE
    float2* input_cast = reinterpret_cast<float2*>(input);
    const float2* bias_cast = reinterpret_cast<const float2*>(bias);
    int offset = blockIdx.x * blockDim.x + threadIdx.x;

    if (offset < total_count) {
        float2 vals_vec = input_cast[offset];
        float2 bias_vec = bias_cast[offset % intermediate_size];

        // Reinterpret each 8-byte load as two __half2 values, then convert
        // to float2 pairs so the arithmetic runs in fp32.
        __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
        __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
        float2 low_data = __half22float2(vals_half[0]);
        float2 high_data = __half22float2(vals_half[1]);
        float2 low_bias = __half22float2(bias_half[0]);
        float2 high_bias = __half22float2(bias_half[1]);
        low_data.x += low_bias.x;
        low_data.y += low_bias.y;
        high_data.x += high_bias.x;
        high_data.y += high_bias.y;

        low_data.x = gelu(low_data.x);
        low_data.y = gelu(low_data.y);
        high_data.x = gelu(high_data.x);
        high_data.y = gelu(high_data.y);

        // Round back to half precision and store.
        vals_half[0] = __float22half2_rn(low_data);
        vals_half[1] = __float22half2_rn(high_data);
        input_cast[offset] = vals_vec;
    }
#endif
}
template <typename T>
void launch_bias_gelu(T* input,
                      const T* bias,
                      int intermediate_size,
                      int batch_size)
{
    // Each thread handles 4 elements (one float4 / two __half2), so both the
    // element count and the bias stride are given in 4-element chunks. This
    // assumes intermediate_size is a multiple of 4.
    int total_count = batch_size * (intermediate_size / 4);
    int threads = 1024;

    dim3 block_dims(threads);
    dim3 grid_dims((total_count + threads - 1) / threads);

    fused_bias_gelu<<<grid_dims, block_dims>>>(
        input, bias, total_count, intermediate_size / 4);
}
template void launch_bias_gelu<float>(float*, const float*, int, int);
template void launch_bias_gelu<__half>(__half*, const __half*, int, int);
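
Since both kernels implement the tanh approximation of GELU, the fused output can be sanity-checked against PyTorch's built-in approximate GELU. A sketch under the same assumptions as the usage example above (illustrative paths, not part of the commit):

import torch
from torch.utils.cpp_extension import load

cuda_gelu = load(name="bias_gelu_fp16", sources=["gelu.cpp", "gelu.cu"])

x = torch.randn(2, 16, 1024, device="cuda")
bias = torch.randn(1024, device="cuda")

# approximate="tanh" uses the same sqrt(2/pi) * (x + 0.044715 * x^3) formula as gelu() above.
expected = torch.nn.functional.gelu(x + bias, approximate="tanh")
actual = cuda_gelu.bias_gelu_fp32(x.clone(), bias)  # clone: the kernel writes in place

assert torch.allclose(actual, expected, atol=1e-5)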

modeling_bloom.py

@@ -15,12 +15,14 @@
"""PyTorch BLOOM model."""
import math
import os
from typing import Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
from torch.utils.cpp_extension import load
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_outputs import (
@@ -51,6 +53,11 @@ BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
dirname = os.path.dirname(__file__)
cuda_gelu = load(name="bias_gelu_fp16", sources=[os.path.join(dirname, "gelu.cpp"), os.path.join(dirname, "gelu.cu")])
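# NOTE: sketch, not part of this diff. The call site is not shown in these
# hunks; a plausible use (assuming the usual BloomMLP layout with a
# dense_h_to_4h projection) would replace the bias-add + GELU in the MLP,
# with hidden_states computed as F.linear(x, self.dense_h_to_4h.weight)
# so the bias is deferred to the fused kernel:
#
#     if hidden_states.dtype == torch.float16:
#         hidden_states = cuda_gelu.bias_gelu_fp16(hidden_states, self.dense_h_to_4h.bias)
#     else:
#         hidden_states = cuda_gelu.bias_gelu_fp32(hidden_states, self.dense_h_to_4h.bias)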
def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
    """
    Make causal mask used for bi-directional self-attention.