Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-21 01:23:56 +08:00)

Compare commits: v4.43.3...thomas/add (64 commits)
Commits (SHA1):
46d37bece7, bd232fb97d, 5924480094, a009ca5a3b, 4e383e7a33, 870e180f08, b395423c8b, c63e0a7377,
ccf4b2d610, 989fac0556, 3fc5ebb9e3, c4edcf9c70, 94e6226b3a, 4787e17582, 07e9409032, 435f8fcbf2,
9d4257a74c, 68944f597d, 0c10634a36, bd7101dace, ee4fe12157, 44430fbe14, 6c25f536cc, 6fd4c59911,
b70a818d50, 2b266f0a39, 9e80f9b190, 9bdd6c8286, 389f196b1c, b0d6f72b55, 9b448ad29d, 218f457388,
616690fa8a, 12ac8b7ec3, 379ef0a888, 5bdb735c28, 8d8f5864b7, 56915604ac, 99c3d575aa, 961a341e31,
76e18bcfd1, 20d9aa6058, 82dbd6f64f, 066f1d8c47, 493dc31d78, 7168222a3a, dfaa2e37f9, aa47de5a56,
9067981e0c, e4bd0884ce, ee234a808f, 62ef07b72e, a7034e3211, c698398bb3, 0cad39a4e5, d92aa10587,
90821b2f13, 68f0b97531, 2c1e4554d0, 452b9d00f8, 8941c2df47, e8e15437c5, 4f98d8e136, f4d0dc3c15
README_BUILD.md (new file, 5 lines)

@@ -0,0 +1,5 @@
## We provide some functions in order to custom build some kernels

```bash
python setup.py build_ext --inplace
```
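A quick way to confirm that `build_ext --inplace` produced importable modules is a short Python probe. This is only a sketch (it assumes an editable install and a CUDA-enabled PyTorch toolchain at build time); the module names are the ones declared in `setup.py` below.

```python
# Sketch: verify the custom kernels built by `python setup.py build_ext --inplace`
# can be imported. Module names come from the CUDAExtension entries in setup.py.
import importlib

for name in (
    "transformers.models.bloom.custom_kernels.fused_bloom_attention_cuda",
    "transformers.models.bloom.custom_kernels.fused_bloom_gelu_cuda",
):
    try:
        importlib.import_module(name)
        print(f"OK: {name}")
    except ImportError as error:
        print(f"missing: {name} ({error})")
```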
setup.py (33 changed lines)

@@ -75,8 +75,10 @@ from pathlib import Path

from setuptools import find_packages, setup

from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# Remove stale transformers.egg-info directory to avoid https://github.com/pypa/pip/issues/5466
stale_egg_info = Path(__file__).parent / "transformers.egg-info"
if stale_egg_info.exists():
    print(

@@ -397,6 +399,31 @@ install_requires = [
    deps["tokenizers"],
    deps["tqdm"],  # progress bars in model download and training scripts
]


def get_extensions():
    # TODO @thomasw21 add cpp versions
    extensions = []

    # TODO @thomasw21 build cuda kernels only on some conditions
    if True:
        extensions += [
            CUDAExtension(
                name="transformers.models.bloom.custom_kernels.fused_bloom_attention_cuda",
                sources=["src/transformers/models/bloom/custom_kernels/fused_bloom_attention_cuda.cu"],
                # TODO: understand what that is, probably defines the target architecture
                # https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#options-for-steering-gpu-code-generation-gpu-architecture
                # Build for A100
                extra_compile_args=["-arch=compute_80", "-std=c++17"],
            ),
            CUDAExtension(
                name="transformers.models.bloom.custom_kernels.fused_bloom_gelu_cuda",
                sources=["src/transformers/models/bloom/custom_kernels/fused_bloom_gelu_cuda.cu"],
                # TODO: understand what that is, probably defines the target architecture
                # https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#options-for-steering-gpu-code-generation-gpu-architecture
                # Build for A100
                extra_compile_args=["-arch=compute_80", "-std=c++17"],
            ),
        ]
    return extensions


setup(
    name="transformers",

@@ -429,5 +456,9 @@ setup(
        "Programming Language :: Python :: 3.9",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
    cmdclass={"deps_table_update": DepsTableUpdateCommand},
    ext_modules=get_extensions(),
    cmdclass={
        "deps_table_update": DepsTableUpdateCommand,
        "build_ext": BuildExtension
    },
)
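For quick iteration on a single kernel, `torch.utils.cpp_extension.load` can JIT-compile the same `.cu` source without rerunning `setup.py`. The sketch below is not part of the diff; the source path and the `-std=c++17` flag mirror `get_extensions()` above, and the module name is chosen here for illustration.

```python
# Sketch: JIT-compile the GELU kernel with torch.utils.cpp_extension.load
# instead of `setup.py build_ext`. Path and flags mirror get_extensions().
from torch.utils.cpp_extension import load

fused_bloom_gelu_cuda = load(
    name="fused_bloom_gelu_cuda",
    sources=["src/transformers/models/bloom/custom_kernels/fused_bloom_gelu_cuda.cu"],
    extra_cuda_cflags=["-std=c++17"],  # the `-arch` flag is usually inferred from the local GPU
    verbose=True,
)
```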
@@ -481,11 +481,11 @@ class GenerationMixin:
        if self.config.is_encoder_decoder and encoder_outputs is not None:
            # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
            shape = encoder_outputs.last_hidden_state.size()[:-1]
            return torch.ones(shape, dtype=torch.long, device=self.device) * -100
            return torch.full(shape, fill_value=-100, dtype=torch.long, device=self.device)

        if bos_token_id is None:
            raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
        return torch.ones((1, 1), dtype=torch.long, device=self.device) * bos_token_id
        return torch.full((1, 1), fill_value=bos_token_id, dtype=torch.long, device=self.device)

    def _prepare_attention_mask_for_generation(
        self,

@@ -541,7 +541,7 @@ class GenerationMixin:
        decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
        if device is None:
            device = self.device
        return torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
        return torch.full((batch_size, 1), fill_value=decoder_start_token_id, dtype=torch.long, device=device)

    def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
        decoder_start_token_id = (

@@ -1691,7 +1691,7 @@ class GenerationMixin:
        )

        # keep track of which sequences are already finished
        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
        unfinished_sequences = input_ids.new_full((input_ids.shape[0],), fill_value=True, dtype=torch.bool)
        cur_len = input_ids.shape[-1]

        this_peer_finished = False  # used by synced_gpus only

@@ -1752,7 +1752,7 @@ class GenerationMixin:
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
                next_tokens.masked_fill_(~unfinished_sequences, pad_token_id)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

@@ -1763,10 +1763,10 @@ class GenerationMixin:

            # if eos_token was found in one sentence, set sentence to finished
            if eos_token_id is not None:
                unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
                unfinished_sequences &= next_tokens != eos_token_id

            # stop when each sentence is finished, or if we exceed the maximum length
            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
            if not (torch.any(unfinished_sequences) and not stopping_criteria(input_ids, scores)):
                if not synced_gpus:
                    break
                else:
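The generation changes above replace `torch.ones(...) * value` with `torch.full(...)` and keep `unfinished_sequences` as a boolean mask instead of a long tensor. A minimal sketch of the new bookkeeping, with made-up token ids, looks like this (not part of the diff):

```python
# Sketch of the bool-mask bookkeeping used in the updated sampling loop.
# The token ids and eos/pad values below are placeholders.
import torch

eos_token_id, pad_token_id = 2, 0
next_tokens = torch.tensor([5, 2, 7])                              # one step of sampled tokens
unfinished = torch.full((3,), True, dtype=torch.bool)              # replaces .new(...).fill_(1)

next_tokens = next_tokens.masked_fill(~unfinished, pad_token_id)   # pad rows that already finished
unfinished &= next_tokens != eos_token_id                          # mark rows that just hit EOS
done = not torch.any(unfinished)                                   # loop stop condition
print(next_tokens, unfinished, done)
```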
@@ -16,6 +16,7 @@
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, List, Mapping, Optional

import torch.distributed
from transformers import is_torch_available


@@ -130,6 +131,7 @@ class BloomConfig(PretrainedConfig):
        attention_dropout=0.0,
        pretraining_tp=1,  # TP rank used when training with megatron
        slow_but_exact=False,
        tp_parallel=False,
        **kwargs,
    ):
        self.vocab_size = vocab_size

@@ -149,6 +151,7 @@ class BloomConfig(PretrainedConfig):
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.slow_but_exact = slow_but_exact
        self.tp_parallel = tp_parallel

        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
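With the new field, a config can request the tensor-parallel code path. The sketch below only shows the flag added in this diff; the size values are illustrative and assume this branch's `BloomConfig`.

```python
# Sketch: the tp_parallel flag added to BloomConfig in this diff routes
# BloomForCausalLM to the tensor-parallel layers. Sizes are illustrative.
from transformers import BloomConfig

config = BloomConfig(hidden_size=1024, n_layer=2, n_head=16, tp_parallel=True)
print(config.tp_parallel, config.slow_but_exact)
```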
src/transformers/models/bloom/custom_kernels/fused_bloom_attention_cuda.cu (new file, 250 lines)

@@ -0,0 +1,250 @@
#include <ATen/Dispatch.h>
#include <THC/THCAtomics.cuh>
#include <ATen/ATen.h>
#include <torch/torch.h>
#include <vector>

#include <optional>

/**
* Friendly reminder of how multithreading works in CUDA: https://developer.nvidia.com/blog/even-easier-introduction-cuda
* Check example at https://github.com/thomasw21/LinearTransformers/blob/main/model/attention/fast_weight/fast_weight_cuda.cu
**/

// Available in pytorch main
//#define DISPATCH_CASE_FLOATING_TYPES(...) \
//   at::AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
//   at::AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
//   at::AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
//   at::AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \

/*
* Forward passes
*/

/**
* cast to fp32 if in fp16 + mask + softmax computation in fp32 + cast back to original dtype
**/
template<typename attention_scores_scalar, int64_t min_kv_length_shard_size_per_thread>
__global__ void forward_masked_softmax_kernel(
    const torch::PackedTensorAccessor32<attention_scores_scalar, 2, torch::RestrictPtrTraits> attention_scores, // [B, KV]
    const torch::PackedTensorAccessor32<bool, 2, torch::RestrictPtrTraits> mask, // [B, KV]
    torch::PackedTensorAccessor32<attention_scores_scalar, 2, torch::RestrictPtrTraits> result, // [B, KV]
    const int64_t effective_kv_length,
    const dim3 blockDim,
    const int64_t rows_per_block,
    const int64_t kv_length,
    const int64_t batch_size
) {
    const auto row_id = threadIdx.x / effective_kv_length;
    const auto effective_kv_length_id = threadIdx.x % effective_kv_length;
    const auto kv_length_start = effective_kv_length_id * min_kv_length_shard_size_per_thread;
    auto kv_length_end_ = (effective_kv_length_id + 1) * min_kv_length_shard_size_per_thread;
    kv_length_end_ = (kv_length_end_ > kv_length) ? kv_length : kv_length_end_;
    const auto kv_length_end = kv_length_end_;

    const auto batch_id = blockIdx.x * rows_per_block + row_id;

    // We need 2 float storage for each row, one for max computation, the other for normalizing exponential
    extern __shared__ float temp_storage[];
    const auto row_id_mem_offset = row_id * 2;
    if (effective_kv_length_id == 0) {
        temp_storage[row_id_mem_offset] = -std::numeric_limits<float>::infinity();
        temp_storage[row_id_mem_offset + 1] = 0;
    }
    __syncthreads();

    // Compute mask and max
    if (batch_id < batch_size) {
        float thread_max = -std::numeric_limits<float>::infinity();
        for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
            if (mask[batch_id][kv_length_id] == 0) {
                const float candidate = attention_scores[batch_id][kv_length_id];
                thread_max = (thread_max < candidate) ? candidate : thread_max;
            }
        }
        if (thread_max != -std::numeric_limits<float>::infinity()) {
            // TODO @thomasw21 with more memory we can probably compute a much faster `max-reduce` in parallel O(ln(n)) operations in each memory slot
            gpuAtomicMax(&temp_storage[row_id_mem_offset], thread_max);
        }
    }

    __syncthreads();

    // Compute exp(elt - max) masked
    float exponential[min_kv_length_shard_size_per_thread];
    if (batch_id < batch_size) {
        float thread_add = 0;
        for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
            if (mask[batch_id][kv_length_id] == 0) {
                exponential[kv_length_id - kv_length_start] = std::exp(static_cast<float>(attention_scores[batch_id][kv_length_id]) - temp_storage[row_id_mem_offset]);
                thread_add = thread_add + exponential[kv_length_id - kv_length_start];
            } else {
                exponential[kv_length_id - kv_length_start] = 0.;
            }
        }
        if (thread_add > 0) {
            // TODO @thomasw21 with more memory we can probably compute a much faster `sum-reduce` in parallel O(ln(n)) operations in each memory slot
            gpuAtomicAdd(&temp_storage[row_id_mem_offset + 1], thread_add);
        }
    }

    __syncthreads();

    // Compute softmax
    if (batch_id < batch_size) {
        // If sum of all exponential is 0, we set the softmax values to 0
        if (temp_storage[row_id_mem_offset + 1] == 0.) {
            for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
                result[batch_id][kv_length_id] = 0.;
            }
        } else {
            for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
                result[batch_id][kv_length_id] = static_cast<attention_scores_scalar>(exponential[kv_length_id - kv_length_start] / temp_storage[row_id_mem_offset + 1]);
            }
        }
    }
}

#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

std::tuple<at::Tensor, std::optional<std::vector<at::Tensor>>, at::Tensor> forward(
    const at::Tensor fused_qkv,
    const std::optional<std::vector<at::Tensor>> layer_past,
    const at::Tensor alibi,
    const at::Tensor attention_mask,
    const std::optional<at::Tensor> head_mask,
    const float beta,
    const float inv_norm_factor,
    const int num_heads,
    const bool use_cache
) {
    const auto batch_size = fused_qkv.size(0);
    const auto q_length = fused_qkv.size(1);
    const auto three_times_hidden_size = fused_qkv.size(2);
    const auto head_dim = three_times_hidden_size / (3 * num_heads);
    const auto batch_size_times_num_heads = batch_size * num_heads;

    // `split_heads`
    const auto fused_qkv_view = fused_qkv.view({batch_size, q_length, num_heads, 3 * head_dim});
    const auto tensor_list = fused_qkv_view.split(head_dim, -1);
    const auto query_layer = tensor_list[0].transpose(1, 2).reshape({batch_size_times_num_heads, q_length, head_dim});
    auto key_layer = tensor_list[1].permute({0, 2, 3, 1}).reshape({batch_size_times_num_heads, head_dim, q_length});
    auto value_layer = tensor_list[2].transpose(1, 2).reshape({batch_size_times_num_heads, q_length, head_dim});

    if (layer_past) {
        const auto past_key = (*layer_past).at(0);
        const auto past_value = (*layer_past).at(1);
        key_layer = at::cat({past_key, key_layer}, 2);
        value_layer = at::cat({past_value, value_layer}, 1);
    }

    std::optional<std::vector<at::Tensor>> present;
    if (use_cache) {
        present = {key_layer, value_layer};
    } else {
        present = {};
    }

    auto attention_scores = alibi.baddbmm(query_layer, key_layer, beta, inv_norm_factor);

    // Computing `optionally_cast_fp16_to_fp32 + masked_fill + softmax + cast_to_initial_dtype`
    at::Tensor attention_probs;
    if (true) {
        const auto kv_length = key_layer.size(2);

        // TODO @thomasw21: it's easier to think of attention_scores as 2D tensors
        const auto attention_scores_2d = attention_scores.view({batch_size_times_num_heads * q_length, kv_length});
        const auto attention_mask_2d = attention_mask.view({batch_size_times_num_heads * q_length, kv_length});

        // Custom kernel
        attention_probs = at::empty_like(attention_scores_2d);

        // Check that inputs are contiguous + cuda tensors
        CHECK_INPUT(attention_scores_2d);
        CHECK_INPUT(attention_mask_2d);

        // TODO @thomas21: switch to this as it's cleaner when pytorch 1.13 comes out
        // DISPATCH_CASE_FLOATING_TYPES(attention_scores.scalar_type(), "masked_softmax", [&] {
        AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, attention_scores.scalar_type(), "masked_softmax", [&] {
            /*
            * Understanding how GPUs work: https://developer.nvidia.com/blog/cuda-refresher-cuda-programming-model/
            * A100 specifications: https://images.nvidia.com/aem-dam/en-zz/Solutions/data-center/nvidia-ampere-architecture-whitepaper.pdf
            *  - SMs: 108
            *  - TPCs: 56 (What's that?)
            *  - Memory size: 40 GB
            *  - L2 Cache size: 40960 KB (shared across all SMs)
            *  - L1/Shared memory size: 192 KB (shared across all threads within a SM)
            *  - Max Threads / SM: 2048
            *  - Max Thread Blocks / SM: 32
            */

            /*
            * We should split [batch_size_times_num_heads_block, q_length] in separate blocks and [batch_size_times_num_heads_block_size, kv_length] a single block
            * with multiple threads as we need to `sync_threads` to run exponential sum.
            * We maximise the usage of threads within a single block
            */
            // TODO @thomasw21 figure out everything warp related:
            //  - why do they have to be power of 2
            // TODO @thomas21 check why everyone is setting 1024 when officially it's 2048
            const auto MAX_THREADS_PER_SM = 1024;
            // TODO @thomasw21 figure out how to have longer sequences, currently the maximum is `max_kv_length = MAX_THREADS_PER_SM * MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD`
            const auto MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD = 4;
            // `effective_kv_length = ceil(kv_length / MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD)`
            const auto effective_kv_length = (kv_length - 1) / MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD + 1;
            const auto rows_per_block = MAX_THREADS_PER_SM / effective_kv_length;
            const auto num_blocks = (batch_size_times_num_heads * q_length - 1) / rows_per_block + 1;

            const dim3 gridDim(num_blocks); // Number of blocks that run
            const dim3 blockDim(MAX_THREADS_PER_SM); // Number of threads that run per block
            const int shared_mem_forward = rows_per_block * 2 * sizeof(float);

            // 192 * 2 ** 10
            // const auto MAX_L1_MEMORY = 196608;
            // const auto MAX_SMs = 108;
            // TORCH_CHECK(batch_size_times_num_heads * q_length <= MAX_L1_MEMORY, "Shared memory exceeds 192KB limitation.");
            // TORCH_CHECK(gridDim.x * gridDim.y * gridDim.z <= MAX_SMs, "A100s only have 108 SMs. Raising as require blocks is bigger.");
            // TORCH_CHECK(blockDim.x * blockDim.y * blockDim.z <= MAX_THREADS_PER_SM, "A100s only have 2048 threads per block. Raising as require requested threads is higher.");

            forward_masked_softmax_kernel<scalar_t, MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD><<<gridDim, blockDim, shared_mem_forward>>>(
                attention_scores_2d.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
                attention_mask_2d.packed_accessor32<bool, 2, torch::RestrictPtrTraits>(),
                attention_probs.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
                effective_kv_length,
                blockDim,
                rows_per_block,
                kv_length,
                batch_size_times_num_heads * q_length
            );
        });
        attention_probs = attention_probs.view({batch_size_times_num_heads, q_length, kv_length});
    } else {
        // Pytorch C++ API
        auto input_dtype = attention_scores.scalar_type();
        if (input_dtype == at::ScalarType::Half) {
            attention_scores = attention_scores.to(at::ScalarType::Float);
        };
        // TODO @thomasw21 Figure out how to get minimum value
        auto attn_weights = attention_scores.masked_fill_(attention_mask, -1e34);
        attention_probs = attn_weights.softmax(-1, at::ScalarType::Float).to(input_dtype);
    }

    auto context_layer = attention_probs.bmm(value_layer);

    // `_merge_heads`
    context_layer = context_layer.view({batch_size, num_heads, q_length, head_dim});
    context_layer = context_layer.permute({0, 2, 1, 3});
    context_layer = context_layer.reshape({batch_size, q_length, three_times_hidden_size / 3});

    return std::make_tuple(context_layer, present, attention_probs);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def(
        "forward",
        &forward,
        "Bloom attention mechanism forward (CUDA)"
    );
}
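A plain-PyTorch reference for what `forward_masked_softmax_kernel` computes is handy when testing the extension: upcast to fp32, treat `mask != 0` as "masked out" for both the max and the sum, and return zeros for rows that are entirely masked. The sketch below is not part of the diff and makes no CUDA calls.

```python
# Sketch: eager reference for the masked softmax the kernel implements.
# `scores` and `mask` are 2D [rows, kv_length]; mask != 0 means "masked out".
import torch

def masked_softmax_reference(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    s = scores.float().masked_fill(mask.bool(), float("-inf"))   # fp32 + mask
    m = s.max(dim=-1, keepdim=True).values                       # per-row max over unmasked entries
    e = torch.exp(s - m).masked_fill(mask.bool(), 0.0)           # exp(x - max), zero at masked slots
    denom = e.sum(dim=-1, keepdim=True)
    out = torch.where(denom == 0, torch.zeros_like(e), e / denom)  # fully-masked rows -> all zeros
    return out.to(scores.dtype)

scores = torch.randn(4, 7, dtype=torch.float16)
mask = torch.zeros(4, 7, dtype=torch.bool)
mask[0] = True    # a fully-masked row should come back as zeros, matching the kernel
print(masked_softmax_reference(scores, mask))
```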
src/transformers/models/bloom/custom_kernels/fused_bloom_gelu_cuda.cu (new file, 97 lines)

@@ -0,0 +1,97 @@
#include <ATen/Dispatch.h>
#include <ATen/ATen.h>
#include <torch/extension.h>

/**
* Friendly reminder of how multithreading works in CUDA: https://developer.nvidia.com/blog/even-easier-introduction-cuda
* Check example at https://github.com/thomasw21/LinearTransformers/blob/main/model/attention/fast_weight/fast_weight_cuda.cu
**/

// Available in pytorch main
//#define DISPATCH_CASE_FLOATING_TYPES(...) \
//   at::AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
//   at::AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
//   at::AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
//   at::AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \

/*
* Forward passes
*/

/**
* compute GELU: `x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))`
**/
template<typename scalar_t, int64_t max_threads_per_sm>
__global__ void forward_kernel(
    const scalar_t* __restrict__ x, // [B, N, D]
    scalar_t* __restrict__ result, // [B, N, D]
    int64_t num_params
) {
    const auto id = blockIdx.x * max_threads_per_sm + threadIdx.x;

    if (num_params <= id) {
        return;
    }

    // We upcast to float always
    const float elt = x[id];

    // Compute gelu
    // TODO @thomasw21: Figure out where to find a tanh implementation that works for all kinds of scalar types. (I could hardcode it)
    result[id] = static_cast<scalar_t>(elt * 0.5 * (1.0 + std::tanh(0.79788456 * elt * (1 + 0.044715 * elt * elt))));
}

#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

at::Tensor forward(
    const at::Tensor x
) {
    CHECK_INPUT(x);

    const auto result = at::empty_like(x);

    // TODO @thomas21: switch to this as it's cleaner when pytorch 1.13 comes out
    // DISPATCH_CASE_FLOATING_TYPES(key_layer.scalar_type(), "masked_softmax", [&] {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(), "gelu", [&] {
        /*
        * Understanding how GPUs work: https://developer.nvidia.com/blog/cuda-refresher-cuda-programming-model/
        * A100 specifications: https://images.nvidia.com/aem-dam/en-zz/Solutions/data-center/nvidia-ampere-architecture-whitepaper.pdf
        *  - SMs: 108
        *  - TPCs: 56 (What's that?)
        *  - Memory size: 40 GB
        *  - L2 Cache size: 40960 KB (shared across all SMs)
        *  - L1/Shared memory size: 192 KB (shared across all threads within a SM)
        *  - Max Threads / SM: 2048
        *  - Max Thread Blocks / SM: 32
        */

        const auto MAX_THREADS_PER_SM = 1024; // TODO @thomas21 check why everyone is setting 1024 when officially it's 2048
        const auto num_params = x.numel(); // TODO @thomasw21 get `x.size()`
        const auto NUM_BLOCKS = (num_params - 1) / MAX_THREADS_PER_SM + 1;

        dim3 gridDim(NUM_BLOCKS); // Number of blocks that run
        dim3 blockDim(MAX_THREADS_PER_SM); // Number of threads that run per block

        // 192 * 2 ** 10
        // const auto MAX_L1_MEMORY = 196608;
        // const auto MAX_SMs = 108;

        forward_kernel<scalar_t, MAX_THREADS_PER_SM><<<gridDim, blockDim>>>(
            x.data<scalar_t>(),
            result.data<scalar_t>(),
            num_params
        );
    });

    return result;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def(
        "forward",
        &forward,
        "Bloom GELU mechanism forward (CUDA)"
    );
}
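The kernel evaluates the same tanh approximation that `bloom_gelu_forward` uses on the Python side. A small eager check of that formula (no custom kernel involved, and assuming a PyTorch recent enough to expose `approximate="tanh"`) looks like this:

```python
# Sketch: eager check of the tanh GELU approximation the kernel implements.
import torch

def gelu_tanh(x: torch.Tensor) -> torch.Tensor:
    # x * 0.5 * (1 + tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))

x = torch.randn(8, dtype=torch.float32)
print(torch.allclose(gelu_tanh(x), torch.nn.functional.gelu(x, approximate="tanh"), atol=1e-5))
```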
@@ -19,6 +19,7 @@ import warnings
from typing import Optional, Tuple, Union

import torch
import torch.distributed
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss

@@ -34,10 +35,18 @@ from ...modeling_outputs import (
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from .configuration_bloom import BloomConfig

from .parallel_layers import TensorParallelColumnLinear, TensorParallelEmbedding, TensorParallelRowLinear

logger = logging.get_logger(__name__)

CUSTOM_KERNELS_ENABLED = False
try:
    from .custom_kernels import fused_bloom_attention_cuda
    from .custom_kernels import fused_bloom_gelu_cuda
    CUSTOM_KERNELS_ENABLED = True
except ImportError:
    logger.warning("We're not using custom kernels.")

_CHECKPOINT_FOR_DOC = "bigscience/bloom-560m"
_CONFIG_FOR_DOC = "BloomConfig"
_TOKENIZER_FOR_DOC = "BloomTokenizerFast"

@@ -68,7 +77,7 @@ def _make_causal_mask(
    if past_key_values_length > 0:
        mask[:, :past_key_values_length] = False

    expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
    expanded_mask = mask[None, :, :].expand(batch_size, target_length, target_length + past_key_values_length)
    return expanded_mask


@@ -79,11 +88,11 @@ def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
    batch_size, src_length = mask.shape
    tgt_length = tgt_length if tgt_length is not None else src_length

    expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
    return expanded_mask.expand(batch_size, 1, tgt_length, src_length)
    expanded_mask = ~(mask[:, None, :].to(torch.bool))
    return expanded_mask.expand(batch_size, tgt_length, src_length)


def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int) -> torch.Tensor:
    """
    Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
    relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value

@@ -124,9 +133,9 @@ def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torc
    # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
    arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
    alibi = slopes[..., None] * arange_tensor
    return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
    return alibi


# @torch.jit.script
def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
    """
    Dropout add function

@@ -145,7 +154,7 @@ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training:
    out = residual + out
    return out


@torch.jit.script
def bloom_gelu_forward(x: torch.Tensor) -> torch.Tensor:
    """
    Custom bias GELU function. Adapted from Megatron-DeepSpeed code. Here we use a simple implementation (inference) to

@@ -179,8 +188,13 @@ def bloom_gelu_back(g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
class GeLUFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(input)
        return bloom_gelu_forward(input)
        if False and CUSTOM_KERNELS_ENABLED:
            """My kernel is actually still slow compared to what `jit` provides."""
            raise ValueError("WTF")
            fused_bloom_gelu_cuda.forward(input)
        else:
            ctx.save_for_backward(input)
            return bloom_gelu_forward(input)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:

@@ -207,9 +221,57 @@ class BloomGelu(nn.Module):
        else:
            return bloom_gelu_forward(x)


# @torch.jit.script  # this is shit for unknown reasons.
def _split_heads(fused_qkv: torch.Tensor, num_heads: int, head_dim: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
    storage as `fused_qkv`

    Args:
        fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]

    Returns:
        query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
        value: [batch_size, seq_length, num_heads, head_dim]
    """
    batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
    fused_qkv = fused_qkv.view(batch_size, seq_length, num_heads, 3 * head_dim)
    query_layer, key_layer, value_layer = fused_qkv.split(head_dim, dim=-1)

    query_layer = query_layer.transpose(1, 2).reshape(batch_size * num_heads, seq_length, head_dim)
    key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * num_heads, head_dim, seq_length)
    value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_heads, seq_length, head_dim)

    return query_layer, key_layer, value_layer


# @torch.jit.script
def _merge_heads(x: torch.Tensor, num_heads: int, head_dim: int) -> torch.Tensor:
    """
    Merge heads together over the last dimension

    Args:
        x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]

    Returns:
        torch.tensor: [batch_size, seq_length, num_heads * head_dim]
    """
    # What we want to achieve is:
    # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
    batch_size_and_num_heads, seq_length, _ = x.shape
    batch_size = batch_size_and_num_heads // num_heads

    # First view to decompose the batch size
    # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
    x = x.view(batch_size, num_heads, seq_length, head_dim)

    # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
    x = x.permute(0, 2, 1, 3)

    # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
    return x.reshape(batch_size, seq_length, num_heads * head_dim)


class BloomAttention(nn.Module):
    def __init__(self, config: BloomConfig):
    def __init__(self, config: BloomConfig, process_group: Optional[torch.distributed.ProcessGroup]):
        super().__init__()

        self.pretraining_tp = config.pretraining_tp

@@ -231,72 +293,37 @@ class BloomAttention(nn.Module):
        self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
        self.beta = 1.0

        self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
        self.dense = nn.Linear(self.hidden_size, self.hidden_size)
        if process_group is None:
            self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
            self.dense = nn.Linear(self.hidden_size, self.hidden_size)
        else:
            assert self.num_heads % process_group.size() == 0
            self.num_heads = self.num_heads // process_group.size()
            self.query_key_value = TensorParallelColumnLinear(self.hidden_size, 3 * self.hidden_size, process_group=process_group, bias=True)
            self.dense = TensorParallelRowLinear(self.hidden_size, self.hidden_size, process_group=process_group)
        self.attention_dropout = nn.Dropout(config.attention_dropout)

    def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
        storage as `fused_qkv`

        Args:
            fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]

        Returns:
            query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
            value: [batch_size, seq_length, num_heads, head_dim]
        """
        batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
        fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
        return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]

    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        """
        Merge heads together over the last dimension

        Args:
            x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]

        Returns:
            torch.tensor: [batch_size, seq_length, num_heads * head_dim]
        """
        # What we want to achieve is:
        # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
        batch_size_and_num_heads, seq_length, _ = x.shape
        batch_size = batch_size_and_num_heads // self.num_heads

        # First view to decompose the batch size
        # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
        x = x.view(batch_size, self.num_heads, seq_length, self.head_dim)

        # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
        x = x.permute(0, 2, 1, 3)

        # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
        return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
    @staticmethod
    def compute_attention(
        fused_qkv: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]],
        alibi: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        head_mask: Optional[torch.Tensor],
        beta: float,
        inv_norm_factor: float,
        num_heads: int,
        use_cache: bool
    ):
        fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
        batch_size, q_length, three_times_hidden_size = fused_qkv.shape
        head_dim = three_times_hidden_size // (3 * num_heads)
        batch_size_times_num_heads = batch_size * num_heads

        ### TODO @thomasw21: this takes quite a bit of time, how do I accelerate that?
        # 3 x [batch_size, seq_length, num_heads, head_dim]
        (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
        (query_layer, key_layer, value_layer) = _split_heads(fused_qkv, num_heads=num_heads, head_dim=head_dim)

        batch_size, q_length, _, _ = query_layer.shape

        query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
        key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.head_dim, q_length)
        value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
        if layer_past is not None:
            past_key, past_value = layer_past
            # concatenate along seq_length dimension:
@@ -311,41 +338,81 @@ class BloomAttention(nn.Module):
            present = (key_layer, value_layer)
        else:
            present = None
        ###

        # [batch_size * num_heads, q_length, kv_length]
        # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
        matmul_result = alibi.baddbmm(
        attention_scores = alibi.baddbmm(
            batch1=query_layer,
            batch2=key_layer,
            beta=self.beta,
            alpha=self.inv_norm_factor,
            beta=beta,
            alpha=inv_norm_factor,
        )

        # change view to [batch_size, num_heads, q_length, kv_length]
        attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)

        # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
        input_dtype = attention_scores.dtype
        # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
        if input_dtype == torch.float16:
            attention_scores = attention_scores.to(torch.float)
        attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
        # torch.finfo is not supported by torch.jit, we temporarily replace it with `-1e34`
        attn_weights = attention_scores.masked_fill_(attention_mask, torch.finfo(attention_scores.dtype).min)
        attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)

        # [batch_size, num_heads, q_length, kv_length]
        attention_probs = self.attention_dropout(attention_probs)
        # # [batch_size, num_heads, q_length, kv_length]
        # attention_probs = self.attention_dropout(attention_probs)

        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        # change view [batch_size x num_heads, q_length, kv_length]
        attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)

        # matmul: [batch_size * num_heads, q_length, head_dim]
        context_layer = torch.bmm(attention_probs_reshaped, value_layer)
        context_layer = torch.bmm(attention_probs, value_layer, out=query_layer)

        # change view [batch_size, num_heads, q_length, head_dim]
        context_layer = self._merge_heads(context_layer)
        context_layer = _merge_heads(context_layer, num_heads=num_heads, head_dim=head_dim)

        return context_layer, present, attention_probs

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        alibi: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
        batch_size, q_length, _ = fused_qkv.shape

        if CUSTOM_KERNELS_ENABLED:
            assert self.training is False, "Only the forward pass was implemented"
            assert attention_mask.shape[-1] < 4096, "Custom kernel supports only up to 4096 tokens"
            context_layer, present, attention_probs = fused_bloom_attention_cuda.forward(
                fused_qkv,
                layer_past,
                alibi,
                attention_mask,
                head_mask,
                self.beta,
                self.inv_norm_factor,
                self.num_heads,
                use_cache
            )
        else:
            raise ValueError("Block this path while we figure out how to run C++ code")
            context_layer, present, attention_probs = self.compute_attention(
                fused_qkv=fused_qkv,
                layer_past=layer_past,
                alibi=alibi,
                attention_mask=attention_mask,
                head_mask=head_mask,
                beta=self.beta,
                inv_norm_factor=self.inv_norm_factor,
                num_heads=self.num_heads,
                use_cache=use_cache
            )

        # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
        if self.pretraining_tp > 1 and self.slow_but_exact:

@@ -359,7 +426,8 @@ class BloomAttention(nn.Module):
        else:
            output_tensor = self.dense(context_layer)

        output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
        # output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
        output_tensor += residual

        outputs = (output_tensor, present)
        if output_attentions:

@@ -369,15 +437,19 @@ class BloomAttention(nn.Module):


class BloomMLP(nn.Module):
    def __init__(self, config: BloomConfig):
    def __init__(self, config: BloomConfig, process_group: Optional[torch.distributed.ProcessGroup]):
        super().__init__()
        hidden_size = config.hidden_size

        self.pretraining_tp = config.pretraining_tp
        self.slow_but_exact = config.slow_but_exact
        self.dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size)
        if process_group is None:
            self.dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size)
            self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size)
        else:
            self.dense_h_to_4h = TensorParallelColumnLinear(hidden_size, 4 * hidden_size, process_group=process_group)
            self.dense_4h_to_h = TensorParallelRowLinear(4 * hidden_size, hidden_size, process_group=process_group)
        self.gelu_impl = BloomGelu()
        self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size)
        self.hidden_dropout = config.hidden_dropout

    def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:

@@ -394,22 +466,23 @@ class BloomMLP(nn.Module):
        else:
            intermediate_output = self.dense_4h_to_h(hidden_states)

        output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
        # output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
        intermediate_output += residual

        return output
        return intermediate_output


class BloomBlock(nn.Module):
    def __init__(self, config: BloomConfig):
    def __init__(self, config: BloomConfig, process_group: Optional[torch.distributed.ProcessGroup]):
        super().__init__()
        hidden_size = config.hidden_size

        self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.num_heads = config.n_head
        self.self_attention = BloomAttention(config)
        self.self_attention = BloomAttention(config, process_group=process_group)
        self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        self.mlp = BloomMLP(config)
        self.mlp = BloomMLP(config, process_group=process_group)

        self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
        self.hidden_dropout = config.hidden_dropout

@@ -581,18 +654,23 @@ BLOOM_INPUTS_DOCSTRING = r"""
    BLOOM_START_DOCSTRING,
)
class BloomModel(BloomPreTrainedModel):
    def __init__(self, config: BloomConfig):
    def __init__(self, config: BloomConfig, process_group: Optional[torch.distributed.ProcessGroup] = None):
        super().__init__(config)

        self.embed_dim = config.hidden_size
        self.num_heads = config.n_head

        # Embedding + LN Embedding
        self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
        if process_group is None:
            self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
        else:
            self.word_embeddings = TensorParallelEmbedding(config.vocab_size, self.embed_dim, process_group=process_group)
            self.tp_rank = process_group.rank()
            self.tp_world_size = process_group.size()
        self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        # Transformer blocks
        self.h = nn.ModuleList([BloomBlock(config) for _ in range(config.num_hidden_layers)])
        self.h = nn.ModuleList([BloomBlock(config, process_group=process_group) for _ in range(config.num_hidden_layers)])

        # Final Layer Norm
        self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

@@ -609,7 +687,7 @@ class BloomModel(BloomPreTrainedModel):
        self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
    ) -> torch.BoolTensor:
        # create causal mask
        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
        # [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
        combined_attention_mask = None
        device = attention_mask.device
        _, src_length = input_shape

@@ -619,7 +697,7 @@ class BloomModel(BloomPreTrainedModel):
                input_shape, device=device, past_key_values_length=past_key_values_length
            )

        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
        # [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
        expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
        combined_attention_mask = (
            expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask

@@ -705,7 +783,7 @@ class BloomModel(BloomPreTrainedModel):
        else:
            attention_mask = attention_mask.to(hidden_states.device)

        alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
        alibi = build_alibi_tensor(attention_mask, self.num_heads)

        causal_mask = self._prepare_attn_mask(
            attention_mask,

@@ -713,6 +791,18 @@ class BloomModel(BloomPreTrainedModel):
            past_key_values_length=past_key_values_length,
        )

        if hasattr(self, "tp_rank"):
            assert self.num_heads % self.tp_world_size == 0
            block_size = self.num_heads // self.tp_world_size
            alibi = alibi[:, self.tp_rank * block_size: (self.tp_rank + 1) * block_size]
            alibi = alibi.reshape(batch_size * block_size, 1, seq_length_with_past)
            causal_mask = torch.repeat_interleave(causal_mask, block_size, dim=0)
        else:
            alibi = alibi.reshape(batch_size * self.num_heads, 1, seq_length_with_past)
            causal_mask = torch.repeat_interleave(causal_mask, self.num_heads, dim=0)

        alibi = alibi.to(hidden_states.dtype)

        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):

            if output_hidden_states:

@@ -787,8 +877,21 @@ class BloomForCausalLM(BloomPreTrainedModel):

    def __init__(self, config: BloomConfig):
        super().__init__(config)
        self.transformer = BloomModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        print(config)
        ### HACK
        if config.tp_parallel:
            process_group = torch.distributed.distributed_c10d._get_default_group()
        else:
            process_group = None
        ###
        self.transformer = BloomModel(config, process_group)

        if process_group is None:
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        else:
            # TODO @thomasw21: TensorParallelColumnLinear doesn't inherit from nn.Linear anymore as we switch the underlying storage
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size // process_group.size(), bias=False)
            # self.lm_head = TensorParallelColumnLinear(config.hidden_size, config.vocab_size, process_group=process_group, bias=False)

        # Initialize weights and apply final processing
        self.post_init()
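The module-level `_split_heads` and `_merge_heads` introduced above are pure reshapes, so their layouts can be checked standalone. The sketch below copies the two functions out of the diff so it runs without this branch; the sizes are made up.

```python
# Sketch: standalone round-trip check of the _split_heads/_merge_heads layouts
# from this diff (copied here so the snippet runs without the branch installed).
import torch

def _split_heads(fused_qkv, num_heads, head_dim):
    batch_size, seq_length, _ = fused_qkv.shape
    fused_qkv = fused_qkv.view(batch_size, seq_length, num_heads, 3 * head_dim)
    query, key, value = fused_qkv.split(head_dim, dim=-1)
    query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_length, head_dim)
    key = key.permute(0, 2, 3, 1).reshape(batch_size * num_heads, head_dim, seq_length)
    value = value.transpose(1, 2).reshape(batch_size * num_heads, seq_length, head_dim)
    return query, key, value

def _merge_heads(x, num_heads, head_dim):
    batch_size = x.shape[0] // num_heads
    x = x.view(batch_size, num_heads, -1, head_dim).permute(0, 2, 1, 3)
    return x.reshape(batch_size, -1, num_heads * head_dim)

batch_size, seq_length, num_heads, head_dim = 2, 5, 4, 8
fused_qkv = torch.randn(batch_size, seq_length, num_heads * 3 * head_dim)
q, k, v = _split_heads(fused_qkv, num_heads, head_dim)
print(q.shape, k.shape, v.shape)                    # (B*H, S, D), (B*H, D, S), (B*H, S, D)
print(_merge_heads(v, num_heads, head_dim).shape)   # back to (B, S, H*D)
```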
src/transformers/models/bloom/parallel_layers.py (new file, 235 lines)

@@ -0,0 +1,235 @@
import torch
import torch.distributed
import torch.nn.functional as F
from torch import nn
import math


class TensorParallelColumnLinear(nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        process_group: torch.distributed.ProcessGroup,
        bias=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()

        self.process_group = process_group
        self.tp_world_size = process_group.size()

        assert out_features % self.tp_world_size == 0

        self.in_features = in_features
        self.out_features = out_features // self.tp_world_size
        # We change from traditional `nn.Linear` and remove the unnecessary `torch.Tensor.transpose` operation
        self.weight = nn.Parameter(torch.empty((self.in_features, self.out_features), **factory_kwargs))
        if bias:
            self.bias = nn.Parameter(torch.empty(self.out_features, **factory_kwargs))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """From `torch.nn.Linear`"""
        # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
        # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
        # https://github.com/pytorch/pytorch/issues/57109
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def extra_repr(self) -> str:
        """From `torch.nn.Linear`"""
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )

    @staticmethod
    @torch.jit.script
    def linear(input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor):
        # # Note that the unsharded equivalent requires us to sum over bias instead of averaging.
        # in_features, out_features = weight.shape
        # size_out = input.size()[:-1] + (out_features,)
        # # TODO @thomasw21: when using torch.jit.script, `addmm` is decomposed to `add + mm`
        # return torch.addmm(bias, input.view(-1, in_features), weight).view(size_out)

        in_features, out_features = weight.shape
        size_out = input.size()[:-1] + (out_features,)
        # TODO @thomasw21: when using torch.jit.script, `addmm` is decomposed to `add + mm`
        input = input.view(-1, in_features)
        # HACK @thomas21: turns out `aten::addmm.out` is not decomposed
        out = torch.empty((0,), device=input.device, dtype=input.dtype)
        out = torch.addmm(bias, input, weight, out=out.view(-1, out_features))

        return out.view(size_out)

        # return F.linear(input, weight=weight.transpose(1, 0), bias=bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        out = self.linear(input, weight=self.weight, bias=self.bias)

        # ### DEBUG @thomasw21: Check that the sharded model outputs the same as the non-sharded version
        # out_from_tp_ranks = [torch.empty_like(out) for _ in range(self.process_group.size())]
        # torch.distributed.all_gather(out_from_tp_ranks, out, group=self.process_group)
        # sharded_out = torch.cat(out_from_tp_ranks, dim=-1)
        #
        # weight_from_tp_ranks = [torch.empty_like(self.weight) for _ in range(self.process_group.size())]
        # bias_from_tp_ranks = [torch.empty_like(self.bias) for _ in range(self.process_group.size())]
        # torch.distributed.all_gather(weight_from_tp_ranks, self.weight, group=self.process_group)
        # torch.distributed.all_gather(bias_from_tp_ranks, self.bias, group=self.process_group)
        # weight = torch.cat(weight_from_tp_ranks, dim=0)
        # bias = torch.cat(bias_from_tp_ranks, dim=0)
        # baseline_out = F.linear(input, weight, bias)
        #
        # torch.testing.assert_close(sharded_out, baseline_out, atol=0.0, rtol=0.0)
        # ###

        return out


class TensorParallelRowLinear(nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        process_group: torch.distributed.ProcessGroup,
        bias=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()

        self.process_group = process_group
        self.tp_world_size = process_group.size()

        assert in_features % self.tp_world_size == 0

        self.in_features = in_features // self.tp_world_size
        self.out_features = out_features
        # We change from traditional `nn.Linear` and remove the unnecessary `torch.Tensor.transpose` operation
        self.weight = nn.Parameter(torch.empty((self.in_features, self.out_features), **factory_kwargs))
        if bias:
            self.bias = nn.Parameter(torch.empty(self.out_features, **factory_kwargs))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """From `torch.nn.Linear`"""
        # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
        # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
        # https://github.com/pytorch/pytorch/issues/57109
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    @staticmethod
    @torch.jit.script
    def linear(input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor):
        # # Note that the unsharded equivalent requires us to sum over bias instead of averaging.
        # in_features, out_features = weight.shape
        # size_out = input.size()[:-1] + (out_features,)
        # # TODO @thomasw21: when using torch.jit.script, `addmm` is decomposed to `add + mm`
        # input = input.view(-1, in_features)
        # # with torch.jit.strict_fusion():
        # out = torch.addmm(bias, input, weight)
        # return out.view(size_out)

        in_features, out_features = weight.shape
        size_out = input.size()[:-1] + (out_features,)
        # TODO @thomasw21: when using torch.jit.script, `addmm` is decomposed to `add + mm`
        input = input.view(-1, in_features)
        # HACK @thomas21: turns out `aten::addmm.out` is not decomposed
        out = torch.empty((0,), device=input.device, dtype=input.dtype)
        out = torch.addmm(bias, input, weight, out=out.view(-1, out_features))

        return out.view(size_out)

        # return F.linear(input, weight=weight.transpose(1, 0), bias=bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        out = self.linear(input, weight=self.weight, bias=self.bias)
        torch.distributed.all_reduce(out, group=self.process_group)

        # ### DEBUG @thomasw21: Check that the sharded model outputs the same as the non-sharded version
        # sharded_out = out
        #
        # input_from_tp_ranks = [torch.empty_like(input) for _ in range(self.process_group.size())]
        # weight_from_tp_ranks = [torch.empty_like(self.weight) for _ in range(self.process_group.size())]
        # bias = self.bias.clone()
        # torch.distributed.all_gather(input_from_tp_ranks, input, group=self.process_group)
        # torch.distributed.all_gather(weight_from_tp_ranks, self.weight, group=self.process_group)
        # torch.distributed.all_reduce(bias, group=self.process_group)
        # input = torch.cat(input_from_tp_ranks, dim=-1)
        # weight = torch.cat(weight_from_tp_ranks, dim=1)
        # baseline_out = F.linear(input, weight, bias)
        #
        # if self.process_group.rank() == 0:
        #     torch.testing.assert_close(bias, self.bias, atol=0.0, rtol=0.0)
        # torch.distributed.barrier(self.process_group)
        # # torch.testing.assert_close(sharded_out, baseline_out)
        # ###

        return out

    def extra_repr(self) -> str:
        """From `torch.nn.Linear`"""
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )


class TensorParallelEmbedding(nn.Embedding):
    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        process_group: torch.distributed.ProcessGroup,
        padding_idx=None,
        max_norm=None,
        norm_type=2.0,
        scale_grad_by_freq=False,
        sparse=False,
        _weight=None,
        device=None,
        dtype=None
    ):
        self.process_group = process_group
        self.tp_rank = process_group.rank()
        self.tp_world_size = process_group.size()

        self.original_num_embeddings = num_embeddings

        # TODO @thomasw21 fix and remove that constraint
        assert num_embeddings % self.tp_world_size == 0
        block_size = num_embeddings // self.tp_world_size
        # inputs in `[min_id, max_id[` are handled by `self` to get embeddings
        self.min_id = self.tp_rank * block_size
        self.max_id = (self.tp_rank + 1) * block_size

        super().__init__(block_size, embedding_dim, padding_idx=padding_idx, max_norm=max_norm, norm_type=norm_type, scale_grad_by_freq=scale_grad_by_freq, sparse=sparse, _weight=_weight, device=device, dtype=dtype)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Sanity check
        if torch.any(torch.logical_or(0 > input, input >= self.original_num_embeddings)):
            raise IndexError(f"Input is required to be in [0, {self.original_num_embeddings}[, got min: {torch.min(input)} and max: {torch.max(input)}")

        # `0` if input is in the correct interval, else `1`
        input_mask = torch.logical_or(self.min_id > input, input >= self.max_id)
        # translate for [0, self.max_id - self.min_id[
        input = input - self.min_id
        # default all out-of-bounds values to `0`
        input[input_mask] = 0
        out = super().forward(input)
        out[input_mask] = 0.0
        torch.distributed.all_reduce(out, group=self.process_group)
        return out
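The row-parallel linear and the parallel embedding rely on the same trick: each rank computes a partial result from its shard and an all-reduce sums the partials. A single-process sketch of that arithmetic (simulating two ranks with plain tensors, no `torch.distributed`) shows why the sum reproduces the unsharded layer; the sizes are made up.

```python
# Sketch: single-process simulation of the TensorParallelRowLinear math with 2 "ranks".
# Each rank holds a slice of the input features and the matching weight rows; summing
# the partial outputs (what all_reduce does in forward()) matches the full matmul.
import torch

torch.manual_seed(0)
in_features, out_features, world_size = 8, 6, 2
x = torch.randn(3, in_features)
weight = torch.randn(in_features, out_features)   # (in, out) layout used in parallel_layers.py

shard = in_features // world_size
partials = [
    x[:, rank * shard:(rank + 1) * shard] @ weight[rank * shard:(rank + 1) * shard]
    for rank in range(world_size)
]
print(torch.allclose(sum(partials), x @ weight, atol=1e-5))  # True
```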