mirror of https://github.com/huggingface/transformers.git

Compare commits: v4.40.1...custom_blo (1 commit, SHA1 22ddccc40e)

src/transformers/models/bloom/gelu.cpp (new file, 30 lines)
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_fp16.h>

// Forward declaration of the CUDA launcher defined in gelu.cu.
template <typename T>
void launch_bias_gelu(T* input,
                      const T* bias,
                      int intermediate_size,
                      int batch_size);

// Adds the MLP bias and applies GELU in place on a contiguous copy of `input`.
template <typename T>
at::Tensor ds_bias_gelu(at::Tensor& input, at::Tensor& bias)
{
    auto input_cont = input.contiguous();

    int bsz = input_cont.size(0) * input_cont.size(1);
    int intermediate_size = input_cont.size(2);

    launch_bias_gelu((T*)input_cont.data_ptr(),
                     (T*)bias.data_ptr(),
                     intermediate_size,
                     bsz);
    return input_cont;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("bias_gelu_fp32", &ds_bias_gelu<float>, "DeepSpeed Gelu with fp32 (CUDA)");
    m.def("bias_gelu_fp16", &ds_bias_gelu<__half>, "DeepSpeed Gelu with fp16 (CUDA)");
}
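As a rough usage sketch (not part of the commit), the binding above can be built with torch.utils.cpp_extension.load and called directly from Python. The tensor shapes below are illustrative; the fp32 entry point is shown because the __half kernel in gelu.cu is only compiled when HALF_PRECISION_AVAILABLE is defined.

# Illustrative only: build the extension from the two files in this commit and
# apply the fused bias + GELU to an fp32 tensor. Shapes are made up for the example.
import os
import torch
from torch.utils.cpp_extension import load

dirname = os.path.dirname(__file__)
cuda_gelu = load(
    name="bias_gelu_fp16",
    sources=[os.path.join(dirname, "gelu.cpp"), os.path.join(dirname, "gelu.cu")],
)

# (batch, seq_len, intermediate_size); intermediate_size must be divisible by 4
# because each CUDA thread processes one float4.
hidden = torch.randn(2, 16, 4096, device="cuda", dtype=torch.float32)
bias = torch.randn(4096, device="cuda", dtype=torch.float32)

# Returns the contiguous buffer, modified in place (same storage if `hidden` is already contiguous).
out = cuda_gelu.bias_gelu_fp32(hidden, bias)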
src/transformers/models/bloom/gelu.cu (new file, 100 lines)
#include <cuda.h>
#include <cuda_fp16.h>

#define MAX_CAP 4
#define MAX_SEQ 2048

// Tanh approximation of GELU, matching the formulation used elsewhere in BLOOM.
inline __device__ float gelu(const float x)
{
    const float sqrt_param = 0.79788456080286535587989211986876f;
    const float mul_param = 0.044715;
    return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x)));
}

// fp32 kernel: each thread loads one float4 (4 elements), adds the bias and applies GELU in place.
__global__ void fused_bias_gelu(float* input,
                                const float* bias,
                                int total_count,
                                int intermediate_size)
{
    float4* input_cast = reinterpret_cast<float4*>(input);
    const float4* bias_cast = reinterpret_cast<const float4*>(bias);
    int offset = blockIdx.x * blockDim.x + threadIdx.x;

    if (offset < total_count) {
        float4 data = input_cast[offset];
        float4 bias_data = bias_cast[offset % intermediate_size];

        data.x += bias_data.x;
        data.y += bias_data.y;
        data.z += bias_data.z;
        data.w += bias_data.w;

        data.x = gelu(data.x);
        data.y = gelu(data.y);
        data.z = gelu(data.z);
        data.w = gelu(data.w);

        input_cast[offset] = data;
    }
}

// fp16 kernel: each thread loads one float2 (two __half2 values, i.e. 4 elements),
// converts to fp32 for the GELU math, then stores the result back as fp16.
// Note the body is compiled only when HALF_PRECISION_AVAILABLE is defined.
__global__ void fused_bias_gelu(__half* input,
                                const __half* bias,
                                int total_count,
                                int intermediate_size)
{
#ifdef HALF_PRECISION_AVAILABLE

    float2* input_cast = reinterpret_cast<float2*>(input);
    const float2* bias_cast = reinterpret_cast<const float2*>(bias);

    int offset = blockIdx.x * blockDim.x + threadIdx.x;

    if (offset < total_count) {
        float2 vals_vec = input_cast[offset];
        float2 bias_vec = bias_cast[offset % intermediate_size];

        __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
        __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);

        float2 low_data = __half22float2(vals_half[0]);
        float2 high_data = __half22float2(vals_half[1]);

        float2 low_bias = __half22float2(bias_half[0]);
        float2 high_bias = __half22float2(bias_half[1]);

        low_data.x += low_bias.x;
        low_data.y += low_bias.y;
        high_data.x += high_bias.x;
        high_data.y += high_bias.y;

        low_data.x = gelu(low_data.x);
        low_data.y = gelu(low_data.y);
        high_data.x = gelu(high_data.x);
        high_data.y = gelu(high_data.y);

        vals_half[0] = __float22half2_rn(low_data);
        vals_half[1] = __float22half2_rn(high_data);

        input_cast[offset] = vals_vec;
    }
#endif
}

// Host-side launcher: one thread per 4 elements; `intermediate_size` is passed to the
// kernels already divided by 4 so the bias lookup indexes float4/float2 vectors.
template <typename T>
void launch_bias_gelu(T* input,
                      const T* bias,
                      int intermediate_size,
                      int batch_size)
{
    int total_count = batch_size * (intermediate_size / 4);
    int threads = 1024;  // intermediate_size / iterations / 4;
    dim3 block_dims(threads);
    dim3 grid_dims(((total_count - 1) / 1024 + 1));  // (batch_size);

    fused_bias_gelu<<<grid_dims, block_dims>>>(
        input, bias, total_count, intermediate_size / 4);
}

template void launch_bias_gelu<float>(float*, const float*, int, int);
template void launch_bias_gelu<__half>(__half*, const __half*, int, int);
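To make the launch arithmetic concrete, here is a small helper (a sketch, not part of the commit) that mirrors the grid/block computation in launch_bias_gelu; the example shapes are arbitrary.

# Mirrors the geometry in launch_bias_gelu: each thread handles 4 elements
# (one float4 for fp32, two __half2 for fp16).
def launch_geometry(batch, seq_len, intermediate_size, threads_per_block=1024):
    bsz = batch * seq_len                         # bsz computed in ds_bias_gelu (gelu.cpp)
    total_count = bsz * (intermediate_size // 4)  # number of 4-element vectors
    blocks = (total_count - 1) // threads_per_block + 1
    return total_count, blocks, threads_per_block

# Example: 2 sequences of 16 tokens with an 8192-wide MLP intermediate
# -> 65536 vectors, launched as 64 blocks of 1024 threads.
print(launch_geometry(2, 16, 8192))  # (65536, 64, 1024)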
src/transformers/models/bloom/modeling_bloom.py

@@ -15,12 +15,14 @@
 """PyTorch BLOOM model."""

 import math
+import os
 from typing import Tuple, Union

 import torch
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
+from torch.utils.cpp_extension import load

 from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
 from ...modeling_outputs import (
@@ -51,6 +53,11 @@ BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST = [
 ]


+dirname = os.path.dirname(__file__)
+
+cuda_gelu = load(name="bias_gelu_fp16", sources=[os.path.join(dirname, "gelu.cpp"), os.path.join(dirname, "gelu.cu")])
+
+
 def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
     """
     Make causal mask used for bi-directional self-attention.
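The hunks above only build the extension at module import time; the change that actually calls it inside the BLOOM MLP is not shown on this page. A plausible wiring, with a hypothetical function name and dispatch logic that are not part of this commit, might look like:

# Hypothetical sketch of how the extension could replace the eager bias + GELU
# in the BLOOM MLP; names and dispatch are assumptions for illustration only.
def fused_bias_gelu(hidden_states: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    if hidden_states.is_cuda and hidden_states.dtype == torch.float16:
        return cuda_gelu.bias_gelu_fp16(hidden_states, bias)
    if hidden_states.is_cuda and hidden_states.dtype == torch.float32:
        return cuda_gelu.bias_gelu_fp32(hidden_states, bias)
    # Fallback: eager tanh-approximation GELU matching the CUDA kernel's formula.
    x = hidden_states + bias
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * (x + 0.044715 * x * x * x)))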