Compare commits


64 Commits

Author SHA1 Message Date
46d37bece7 Add safety in case the entire row is 0 2022-09-01 12:09:07 +02:00
bd232fb97d Off by 1 error .... 2022-09-01 12:06:19 +02:00
5924480094 Woops 2022-08-18 19:01:11 +02:00
a009ca5a3b Only jit gelu 2022-08-18 19:00:52 +02:00
4e383e7a33 Add back jit 2022-08-18 18:27:08 +02:00
870e180f08 Woops 2022-08-18 18:14:49 +02:00
b395423c8b Remove all jit 2022-08-18 18:12:52 +02:00
c63e0a7377 Woops 2022-08-18 17:32:32 +02:00
ccf4b2d610 Woops 2022-08-18 17:12:08 +02:00
989fac0556 Woops 2022-08-18 17:08:07 +02:00
3fc5ebb9e3 Woops 2022-08-18 17:06:24 +02:00
c4edcf9c70 Let's try build gelu 2022-08-18 17:01:58 +02:00
94e6226b3a Let's try build gelu 2022-08-18 16:49:20 +02:00
4787e17582 Allow sequences to up to 4k ... 2022-08-18 16:13:07 +02:00
07e9409032 Allow sequences to up to 4k ... 2022-08-18 16:07:47 +02:00
435f8fcbf2 Add view to change 2D attention_probs to 3D 2022-08-18 15:56:20 +02:00
9d4257a74c maybe? 2022-08-18 15:53:02 +02:00
68944f597d Know after? 2022-08-18 15:45:31 +02:00
0c10634a36 Maybe? 2022-08-18 15:34:46 +02:00
bd7101dace Woops 2022-08-18 15:31:09 +02:00
ee4fe12157 Maybe? 2022-08-18 15:27:06 +02:00
44430fbe14 Woops 2022-08-18 15:17:32 +02:00
6c25f536cc Woops 2022-08-18 15:10:10 +02:00
6fd4c59911 Woops 2022-08-18 15:05:33 +02:00
b70a818d50 Let's have a test at our new improved masked softmax kernel 2022-08-18 15:00:36 +02:00
2b266f0a39 GELU 2022-08-17 18:00:35 +02:00
9e80f9b190 WIP 2022-08-17 15:36:45 +02:00
9bdd6c8286 Minor comments 2022-08-17 15:36:45 +02:00
389f196b1c I'm stupid 2022-08-17 15:36:45 +02:00
b0d6f72b55 Need to initialize correctly shared memory 2022-08-17 15:36:45 +02:00
9b448ad29d Woops 2022-08-17 15:36:45 +02:00
218f457388 Woops 2022-08-17 15:36:45 +02:00
616690fa8a Woops 2022-08-17 15:36:45 +02:00
12ac8b7ec3 Woops 2022-08-17 15:36:45 +02:00
379ef0a888 Woops 2022-08-17 15:36:45 +02:00
5bdb735c28 Maybe custom kernels don't support named arguments 2022-08-17 15:36:45 +02:00
8d8f5864b7 Woops 2022-08-17 15:36:45 +02:00
56915604ac Woops 2022-08-17 15:36:45 +02:00
99c3d575aa Woops 2022-08-17 15:36:45 +02:00
961a341e31 Woops 2022-08-17 15:36:45 +02:00
76e18bcfd1 Woops 2022-08-17 15:36:45 +02:00
20d9aa6058 Woops 2022-08-17 15:36:45 +02:00
82dbd6f64f Woops 2022-08-17 15:36:45 +02:00
066f1d8c47 Woops 2022-08-17 15:36:45 +02:00
493dc31d78 Woops 2022-08-17 15:36:44 +02:00
7168222a3a Woops 2022-08-17 15:36:44 +02:00
dfaa2e37f9 Woops 2022-08-17 15:36:44 +02:00
aa47de5a56 Woops 2022-08-17 15:36:44 +02:00
9067981e0c Woops 2022-08-17 15:36:44 +02:00
e4bd0884ce Woops 2022-08-17 15:36:44 +02:00
ee234a808f Woops 2022-08-17 15:36:44 +02:00
62ef07b72e Woops 2022-08-17 15:36:44 +02:00
a7034e3211 Woops 2022-08-17 15:36:44 +02:00
c698398bb3 Woops 2022-08-17 15:36:44 +02:00
0cad39a4e5 Have a go at coding my own cuda kernel 2022-08-17 15:36:44 +02:00
d92aa10587 Add a default to fix a bunch of test 2022-08-17 15:36:21 +02:00
90821b2f13 Try and plug C++ kernel 2022-08-17 15:36:20 +02:00
68f0b97531 WIP: explore cuda kernels 2022-08-17 15:35:08 +02:00
2c1e4554d0 WIP: explore cuda kernels 2022-08-17 15:35:08 +02:00
452b9d00f8 WIP: explore cuda kernels 2022-08-17 15:35:08 +02:00
8941c2df47 WIP: explore cuda kernels 2022-08-17 15:35:08 +02:00
e8e15437c5 WIP: explore cuda kernels 2022-08-17 15:35:08 +02:00
4f98d8e136 WIP: explore cuda kernels 2022-08-17 15:35:08 +02:00
f4d0dc3c15 WIP 2022-08-17 15:32:58 +02:00
8 changed files with 832 additions and 108 deletions

README_BUILD.md Normal file

@ -0,0 +1,5 @@
## We provide a build step to compile the custom CUDA kernels
```bash
python setup.py build_ext --inplace
```
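After running the command above, the two extensions declared in setup.py should be importable from the package. A minimal check, not part of this branch, with the module paths taken from setup.py below:

```python
try:
    from transformers.models.bloom.custom_kernels import (
        fused_bloom_attention_cuda,
        fused_bloom_gelu_cuda,
    )
    print("custom kernels available")
except ImportError as exc:
    print(f"custom kernels not built: {exc}")
```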

setup.py

@ -75,8 +75,10 @@ from pathlib import Path
from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
# Remove stale transformers.egg-info directory to avoid https://github.com/pypa/pip/issues/5466
stale_egg_info = Path(__file__).parent / "transformers.egg-info"
if stale_egg_info.exists():
print(
@ -397,6 +399,31 @@ install_requires = [
deps["tokenizers"],
deps["tqdm"], # progress bars in model download and training scripts
]
def get_extensions():
# TODO @thomasw21 add cpp versions
extensions = []
# TODO @thomasw21 build cuda kernels only on some conditions
if True:
extensions += [
CUDAExtension(
name="transformers.models.bloom.custom_kernels.fused_bloom_attention_cuda",
sources=["src/transformers/models/bloom/custom_kernels/fused_bloom_attention_cuda.cu"],
# TODO: understand what that is, probably defines the target architecture
# https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#options-for-steering-gpu-code-generation-gpu-architecture
# Build for A100
extra_compile_args=["-arch=compute_80", "-std=c++17"],
),
CUDAExtension(
name="transformers.models.bloom.custom_kernels.fused_bloom_gelu_cuda",
sources=["src/transformers/models/bloom/custom_kernels/fused_bloom_gelu_cuda.cu"],
# TODO: understand what that is, probably defines the target architecture
# https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#options-for-steering-gpu-code-generation-gpu-architecture
# Build for A100
extra_compile_args=["-arch=compute_80", "-std=c++17"],
),
]
return extensions
setup(
name="transformers",
@ -429,5 +456,9 @@ setup(
"Programming Language :: Python :: 3.9",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
cmdclass={"deps_table_update": DepsTableUpdateCommand},
ext_modules=get_extensions(),
cmdclass={
"deps_table_update": DepsTableUpdateCommand,
"build_ext": BuildExtension
},
)
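The `if True:` in `get_extensions()` carries a TODO to build the CUDA kernels only under some conditions. One possible shape for that guard, sketched here with a made-up environment variable name (nothing in this branch defines it):

```python
import os
import shutil

def should_build_cuda_kernels() -> bool:
    # Hypothetical opt-in switch; the variable name is an assumption for illustration.
    if os.environ.get("TRANSFORMERS_BUILD_CUDA_KERNELS") == "1":
        return True
    # A usable `nvcc` on PATH is a rough proxy for a CUDA toolkit being present.
    return shutil.which("nvcc") is not None
```

`get_extensions()` could then test `if should_build_cuda_kernels():` instead of `if True:`.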

src/transformers/generation_utils.py

@ -481,11 +481,11 @@ class GenerationMixin:
if self.config.is_encoder_decoder and encoder_outputs is not None:
# make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
shape = encoder_outputs.last_hidden_state.size()[:-1]
return torch.ones(shape, dtype=torch.long, device=self.device) * -100
return torch.full(shape, fill_value=-100, dtype=torch.long, device=self.device)
if bos_token_id is None:
raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
return torch.ones((1, 1), dtype=torch.long, device=self.device) * bos_token_id
return torch.full((1, 1), fill_value=bos_token_id, dtype=torch.long, device=self.device)
def _prepare_attention_mask_for_generation(
self,
@ -541,7 +541,7 @@ class GenerationMixin:
decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
if device is None:
device = self.device
return torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
return torch.full((batch_size, 1), fill_value=decoder_start_token_id, dtype=torch.long, device=device)
def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
decoder_start_token_id = (
@ -1691,7 +1691,7 @@ class GenerationMixin:
)
# keep track of which sequences are already finished
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
unfinished_sequences = input_ids.new_full((input_ids.shape[0],), fill_value=True, dtype=torch.bool)
cur_len = input_ids.shape[-1]
this_peer_finished = False # used by synced_gpus only
@ -1752,7 +1752,7 @@ class GenerationMixin:
if eos_token_id is not None:
if pad_token_id is None:
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
next_tokens.masked_fill_(~unfinished_sequences, pad_token_id)
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
@ -1763,10 +1763,10 @@ class GenerationMixin:
# if eos_token was found in one sentence, set sentence to finished
if eos_token_id is not None:
unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
unfinished_sequences &= next_tokens != eos_token_id
# stop when each sentence is finished, or if we exceed the maximum length
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
if not (torch.any(unfinished_sequences) and not stopping_criteria(input_ids, scores)):
if not synced_gpus:
break
else:
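The generation changes are meant to be behaviour-preserving: `torch.full` replaces the `torch.ones(...) * value` pattern, and `unfinished_sequences` becomes a boolean mask updated in place. A small stand-alone illustration with made-up token ids:

```python
import torch

bos_token_id = 1  # hypothetical value
assert torch.equal(
    torch.ones((1, 1), dtype=torch.long) * bos_token_id,
    torch.full((1, 1), fill_value=bos_token_id, dtype=torch.long),
)

# Boolean bookkeeping instead of 0/1 longs
eos_token_id, pad_token_id = 2, 0
next_tokens = torch.tensor([5, 7, 2])
unfinished_sequences = torch.tensor([True, False, True])
next_tokens.masked_fill_(~unfinished_sequences, pad_token_id)  # pad already-finished rows
unfinished_sequences &= next_tokens != eos_token_id            # third row just hit EOS
print(next_tokens, unfinished_sequences)  # tensor([5, 0, 2]) tensor([ True, False, False])
```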

src/transformers/models/bloom/configuration_bloom.py

@ -16,6 +16,7 @@
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, List, Mapping, Optional
import torch.distributed
from transformers import is_torch_available
@ -130,6 +131,7 @@ class BloomConfig(PretrainedConfig):
attention_dropout=0.0,
pretraining_tp=1, # TP rank used when training with megatron
slow_but_exact=False,
tp_parallel=False,
**kwargs,
):
self.vocab_size = vocab_size
@ -149,6 +151,7 @@ class BloomConfig(PretrainedConfig):
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.slow_but_exact = slow_but_exact
self.tp_parallel = tp_parallel
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
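The new `tp_parallel` flag is consumed further down by `BloomForCausalLM`, which grabs the default `torch.distributed` process group when it is set. A hypothetical usage sketch, not taken from the diff:

```python
from transformers import BloomConfig

# Only meaningful once torch.distributed has been initialized; defaults to False.
config = BloomConfig(tp_parallel=True)
```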

src/transformers/models/bloom/custom_kernels/fused_bloom_attention_cuda.cu Normal file

@ -0,0 +1,250 @@
#include <ATen/Dispatch.h>
#include <THC/THCAtomics.cuh>
#include <ATen/ATen.h>
#include <torch/torch.h>
#include <vector>
#include <optional>
/**
* Friendly reminder of how multithreading works in CUDA: https://developer.nvidia.com/blog/even-easier-introduction-cuda
* Check example at https://github.com/thomasw21/LinearTransformers/blob/main/model/attention/fast_weight/fast_weight_cuda.cu
**/
// Available in pytorch main
//#define DISPATCH_CASE_FLOATING_TYPES(...) \
// at::AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
// at::AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
// at::AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
// at::AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
/*
* Forward passes
*/
/**
* cast to fp32 if in fp16 + mask + softmax computation in fp32 + cast back to original dtype
**/
template<typename attention_scores_scalar, int64_t min_kv_length_shard_size_per_thread>
__global__ void forward_masked_softmax_kernel(
const torch::PackedTensorAccessor32<attention_scores_scalar, 2, torch::RestrictPtrTraits> attention_scores, // [B, KV]
const torch::PackedTensorAccessor32<bool, 2, torch::RestrictPtrTraits> mask, // [B, KV]
torch::PackedTensorAccessor32<attention_scores_scalar, 2, torch::RestrictPtrTraits> result, // [B, KV]
const int64_t effective_kv_length,
const dim3 blockDim,
const int64_t rows_per_block,
const int64_t kv_length,
const int64_t batch_size
) {
const auto row_id = threadIdx.x / effective_kv_length;
const auto effective_kv_length_id = threadIdx.x % effective_kv_length;
const auto kv_length_start = effective_kv_length_id * min_kv_length_shard_size_per_thread;
auto kv_length_end_ = (effective_kv_length_id + 1) * min_kv_length_shard_size_per_thread;
kv_length_end_ = (kv_length_end_ > kv_length) ? kv_length : kv_length_end_;
const auto kv_length_end = kv_length_end_;
const auto batch_id = blockIdx.x * rows_per_block + row_id;
// We need two floats of shared storage per row: one for the max computation, the other for the normalizing exponential sum
extern __shared__ float temp_storage[];
const auto row_id_mem_offset = row_id * 2;
if (effective_kv_length_id == 0) {
temp_storage[row_id_mem_offset] = -std::numeric_limits<float>::infinity();
temp_storage[row_id_mem_offset + 1] = 0;
}
__syncthreads();
// Compute mask and max
if (batch_id < batch_size) {
float thread_max = -std::numeric_limits<float>::infinity();
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
if (mask[batch_id][kv_length_id] == 0) {
const float candidate = attention_scores[batch_id][kv_length_id];
thread_max = (thread_max < candidate) ? candidate : thread_max;
}
}
if (thread_max != -std::numeric_limits<float>::infinity()) {
// TODO @thomasw21 with more memory we can probably compute a much faster `max-reduce` in parallel O(ln(n)) operations in each memory slot
gpuAtomicMax(&temp_storage[row_id_mem_offset], thread_max);
}
}
__syncthreads();
// Compute exp(elt - max) masked
float exponential[min_kv_length_shard_size_per_thread];
if (batch_id < batch_size) {
float thread_add = 0;
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
if (mask[batch_id][kv_length_id] == 0) {
exponential[kv_length_id - kv_length_start] = std::exp(static_cast<float>(attention_scores[batch_id][kv_length_id]) - temp_storage[row_id_mem_offset]);
thread_add = thread_add + exponential[kv_length_id - kv_length_start];
} else {
exponential[kv_length_id - kv_length_start] = 0.;
}
}
if (thread_add > 0) {
// TODO @thomasw21 with more memory we can probably compute a much faster `sum-reduce` in parallel O(ln(n)) operations in each memory slot
gpuAtomicAdd(&temp_storage[row_id_mem_offset + 1], thread_add);
}
}
__syncthreads();
// Compute softmax
if (batch_id < batch_size) {
// If the sum of all exponentials is 0, we set the softmax values to 0
if (temp_storage[row_id_mem_offset + 1] == 0.) {
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
result[batch_id][kv_length_id] = 0.;
}
} else {
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
result[batch_id][kv_length_id] = static_cast<attention_scores_scalar>(exponential[kv_length_id - kv_length_start] / temp_storage[row_id_mem_offset + 1]);
}
}
}
}
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::tuple<at::Tensor, std::optional<std::vector<at::Tensor>>, at::Tensor> forward(
const at::Tensor fused_qkv,
const std::optional<std::vector<at::Tensor>> layer_past,
const at::Tensor alibi,
const at::Tensor attention_mask,
const std::optional<at::Tensor> head_mask,
const float beta,
const float inv_norm_factor,
const int num_heads,
const bool use_cache
) {
const auto batch_size = fused_qkv.size(0);
const auto q_length = fused_qkv.size(1);
const auto three_times_hidden_size = fused_qkv.size(2);
const auto head_dim = three_times_hidden_size / (3 * num_heads);
const auto batch_size_times_num_heads = batch_size * num_heads;
// `split_heads`
const auto fused_qkv_view = fused_qkv.view({batch_size, q_length, num_heads, 3 * head_dim});
const auto tensor_list = fused_qkv_view.split(head_dim, -1);
const auto query_layer = tensor_list[0].transpose(1, 2).reshape({batch_size_times_num_heads, q_length, head_dim});
auto key_layer = tensor_list[1].permute({0, 2, 3, 1}).reshape({batch_size_times_num_heads, head_dim, q_length});
auto value_layer = tensor_list[2].transpose(1, 2).reshape({batch_size_times_num_heads, q_length, head_dim});
if (layer_past) {
const auto past_key = (*layer_past).at(0);
const auto past_value = (*layer_past).at(1);
key_layer = at::cat({past_key, key_layer}, 2);
value_layer = at::cat({past_value, value_layer}, 1);
}
std::optional<std::vector<at::Tensor>> present;
if (use_cache) {
present = {key_layer, value_layer};
} else {
present = {};
}
auto attention_scores = alibi.baddbmm(query_layer, key_layer, beta, inv_norm_factor);
// Computing `optionally_cast_fp16_to_fp32 + masked_fill + softmax + cast_to_initial_dtype`
at::Tensor attention_probs;
if (true) {
const auto kv_length = key_layer.size(2);
// TODO @thomasw21: it's easier to think of attention_scores as 2D tensors
const auto attention_scores_2d = attention_scores.view({batch_size_times_num_heads * q_length, kv_length});
const auto attention_mask_2d = attention_mask.view({batch_size_times_num_heads * q_length, kv_length});
// Custom kernel
attention_probs = at::empty_like(attention_scores_2d);
// Check that inputs are contiguous CUDA tensors
CHECK_INPUT(attention_scores_2d);
CHECK_INPUT(attention_mask_2d);
// TODO @thomas21: switch to this as it's cleaner when pytorch 1.13 comes out
// DISPATCH_CASE_FLOATING_TYPES(attention_scores.scalar_type(), "masked_softmax", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, attention_scores.scalar_type(), "masked_softmax", [&] {
/*
* Understanding how GPUs work: https://developer.nvidia.com/blog/cuda-refresher-cuda-programming-model/
* A100 specifications: https://images.nvidia.com/aem-dam/en-zz/Solutions/data-center/nvidia-ampere-architecture-whitepaper.pdf
* - SMs: 108
* - TPCs: 56 (What's that?)
* - Memory size: 40 GB
* - L2 Cache size: 40960 KB (shared across all SMs)
* - L1/Shared memory size: 192 KB (shared across all threads within a SM)
* - Max Threads / SM: 2048
* - Max Thread Blocks / SM: 32
*/
/*
* We should split [batch_size_times_num_heads_block, q_length] into separate blocks and handle [batch_size_times_num_heads_block_size, kv_length] within a single block
* with multiple threads, as we need `__syncthreads` to run the exponential sum.
* We maximise the usage of threads within a single block
*/
// TODO @thomasw21 figure out everything warp related:
// - why do they have to be power of 2
// TODO @thomas21 check why everyone is setting 1024 when officially it's 2048
const auto MAX_THREADS_PER_SM = 1024;
// TODO @thomasw21 figure out how to have longer sequences, currently the maximum is `max_kv_length = MAX_THREADS_PER_SM * MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD`
const auto MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD = 4;
// `effective_kv_length = ceil(kv_length / MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD)`
const auto effective_kv_length = (kv_length - 1)/ MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD + 1;
const auto rows_per_block = MAX_THREADS_PER_SM / effective_kv_length;
const auto num_blocks = (batch_size_times_num_heads * q_length - 1) / rows_per_block + 1;
const dim3 gridDim(num_blocks); // Number of blocks that run
const dim3 blockDim(MAX_THREADS_PER_SM); // Number of threads that run per block
const int shared_mem_forward = rows_per_block * 2 * sizeof(float);
// 192 * 2 ** 10
// const auto MAX_L1_MEMORY = 196608;
// const auto MAX_SMs = 108;
// TORCH_CHECK(batch_size_times_num_heads * q_length <= MAX_L1_MEMORY, "Shared memory exceeds 192KB limitation.");
// TORCH_CHECK(gridDim.x * gridDim.y * gridDim.z <= MAX_SMs, "A100s only have 108 SMs. Raising as the required number of blocks is bigger.");
// TORCH_CHECK(blockDim.x * blockDim.y * blockDim.z <= MAX_THREADS_PER_SM, "A100s only have 2048 threads per SM. Raising as the requested number of threads is higher.");
forward_masked_softmax_kernel<scalar_t, MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD><<<gridDim, blockDim, shared_mem_forward>>>(
attention_scores_2d.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
attention_mask_2d.packed_accessor32<bool, 2, torch::RestrictPtrTraits>(),
attention_probs.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
effective_kv_length,
blockDim,
rows_per_block,
kv_length,
batch_size_times_num_heads * q_length
);
});
attention_probs = attention_probs.view({batch_size_times_num_heads, q_length, kv_length});
} else {
// Pytorch C++ API
auto input_dtype = attention_scores.scalar_type();
if (input_dtype == at::ScalarType::Half) {
attention_scores = attention_scores.to(at::ScalarType::Float);
};
// TODO @thomasw21 Figure out how to get minimum value
auto attn_weights = attention_scores.masked_fill_(attention_mask, -1e34);
attention_probs = attn_weights.softmax(-1, at::ScalarType::Float).to(input_dtype);
}
auto context_layer = attention_probs.bmm(value_layer);
// `_merge_heads`
context_layer = context_layer.view({batch_size, num_heads, q_length, head_dim});
context_layer = context_layer.permute({0, 2, 1, 3});
context_layer = context_layer.reshape({batch_size, q_length, three_times_hidden_size / 3});
return std::make_tuple(context_layer, present, attention_probs);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"forward",
&forward,
"Bloom attention mechanism forward (CUDA)"
);
}
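For reference, the launch configuration computed inside the `AT_DISPATCH` block above can be reproduced in a few lines of Python. This is an illustrative sketch only (the helper name is made up); it also shows where the 4096-token limit asserted in modeling_bloom.py comes from: each thread covers at most `MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD` key/value positions, and every row must fit in one 1024-thread block so the shared-memory max/sum can be synchronized.

```python
MAX_THREADS_PER_SM = 1024
MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD = 4

def launch_geometry(batch_size_times_num_heads: int, q_length: int, kv_length: int):
    # ceil(kv_length / MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD)
    effective_kv_length = (kv_length - 1) // MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD + 1
    rows_per_block = MAX_THREADS_PER_SM // effective_kv_length
    assert rows_per_block > 0, "kv_length above 1024 * 4 = 4096 does not fit in one block"
    num_blocks = (batch_size_times_num_heads * q_length - 1) // rows_per_block + 1
    shared_mem_bytes = rows_per_block * 2 * 4  # two float32 slots per row (max, sum)
    return num_blocks, MAX_THREADS_PER_SM, shared_mem_bytes

# Hypothetical decode-step shapes: batch 16, 16 heads, 1 query token, 2048 kv positions
print(launch_geometry(batch_size_times_num_heads=16 * 16, q_length=1, kv_length=2048))
```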

src/transformers/models/bloom/custom_kernels/fused_bloom_gelu_cuda.cu Normal file

@ -0,0 +1,97 @@
#include <ATen/Dispatch.h>
#include <ATen/ATen.h>
#include <torch/extension.h>
/**
* Friendly reminder of how multithreading works in CUDA: https://developer.nvidia.com/blog/even-easier-introduction-cuda
* Check example at https://github.com/thomasw21/LinearTransformers/blob/main/model/attention/fast_weight/fast_weight_cuda.cu
**/
// Available in pytorch main
//#define DISPATCH_CASE_FLOATING_TYPES(...) \
// at::AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
// at::AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
// at::AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
// at::AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
/*
* Forward passes
*/
/**
* compute GELU: `x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))`
**/
template<typename scalar_t, int64_t max_threads_per_sm>
__global__ void forward_kernel(
const scalar_t* __restrict__ x, // [B, N, D]
scalar_t* __restrict__ result, // [B, N, D]
int64_t num_params
) {
const auto id = blockIdx.x * max_threads_per_sm + threadIdx.x;
if (num_params <= id) {
return;
}
// We upcast to float always
const float elt = x[id];
// Compute gelu
// TODO @thomasw21: Figure out where to find a tanh implementation that works for all kinds of scalar types. (I could hardcode it)
result[id] = static_cast<scalar_t>(elt * 0.5 * (1.0 + std::tanh(0.79788456 * elt * (1 + 0.044715 * elt * elt))));
}
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
at::Tensor forward(
const at::Tensor x
) {
CHECK_INPUT(x);
const auto result = at::empty_like(x);
// TODO @thomas21: switch to this as it's cleaner when pytorch 1.13 comes out
// DISPATCH_CASE_FLOATING_TYPES(key_layer.scalar_type(), "masked_softmax", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(), "gelu", [&] {
/*
* Understanding how GPUs work: https://developer.nvidia.com/blog/cuda-refresher-cuda-programming-model/
* A100 specifications: https://images.nvidia.com/aem-dam/en-zz/Solutions/data-center/nvidia-ampere-architecture-whitepaper.pdf
* - SMs: 108
* - TPCs: 56 (What's that?)
* - Memory size: 40 GB
* - L2 Cache size: 40960 KB (shared across all SMs)
* - L1/Shared memory size: 192 KB (shared across all threads within a SM)
* - Max Threads / SM: 2048
* - Max Thread Blocks / SM: 32
*/
const auto MAX_THREADS_PER_SM = 1024; // TODO @thomas21 check why everyone is setting 1024 when officially it's 2048
const auto num_params = x.numel(); // TODO @thomasw21 get `x.size()`
const auto NUM_BLOCKS = (num_params - 1) / MAX_THREADS_PER_SM + 1;
dim3 gridDim(NUM_BLOCKS); // Number of blocks that run
dim3 blockDim(MAX_THREADS_PER_SM); // Number of threads that run per block
// 192 * 2 ** 10
// const auto MAX_L1_MEMORY = 196608;
// const auto MAX_SMs = 108;
forward_kernel<scalar_t, MAX_THREADS_PER_SM><<<gridDim, blockDim>>>(
x.data<scalar_t>(),
result.data<scalar_t>(),
num_params
);
});
return result;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"forward",
&forward,
"Bloom GELU mechanism forward (CUDA)"
);
}
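The kernel evaluates the same tanh approximation of GELU quoted in the header comment (and implemented by `bloom_gelu_forward` in modeling_bloom.py). A quick stand-alone check against PyTorch's exact GELU; the tolerance is chosen loosely for illustration:

```python
import torch

def tanh_gelu(x: torch.Tensor) -> torch.Tensor:
    # Same formula as the kernel above, evaluated elementwise
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

x = torch.linspace(-5, 5, steps=101)
assert torch.allclose(tanh_gelu(x), torch.nn.functional.gelu(x), atol=1e-3)
```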

src/transformers/models/bloom/modeling_bloom.py

@ -19,6 +19,7 @@ import warnings
from typing import Optional, Tuple, Union
import torch
import torch.distributed
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
@ -34,10 +35,18 @@ from ...modeling_outputs import (
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from .configuration_bloom import BloomConfig
from .parallel_layers import TensorParallelColumnLinear, TensorParallelEmbedding, TensorParallelRowLinear
logger = logging.get_logger(__name__)
CUSTOM_KERNELS_ENABLED=False
try:
from .custom_kernels import fused_bloom_attention_cuda
from .custom_kernels import fused_bloom_gelu_cuda
CUSTOM_KERNELS_ENABLED=True
except ImportError:
logger.warning("We're not using custom kernels.")
_CHECKPOINT_FOR_DOC = "bigscience/bloom-560m"
_CONFIG_FOR_DOC = "BloomConfig"
_TOKENIZER_FOR_DOC = "BloomTokenizerFast"
@ -68,7 +77,7 @@ def _make_causal_mask(
if past_key_values_length > 0:
mask[:, :past_key_values_length] = False
expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
expanded_mask = mask[None, :, :].expand(batch_size, target_length, target_length + past_key_values_length)
return expanded_mask
@ -79,11 +88,11 @@ def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
batch_size, src_length = mask.shape
tgt_length = tgt_length if tgt_length is not None else src_length
expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
return expanded_mask.expand(batch_size, 1, tgt_length, src_length)
expanded_mask = ~(mask[:, None, :].to(torch.bool))
return expanded_mask.expand(batch_size, tgt_length, src_length)
def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int) -> torch.Tensor:
"""
Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
@ -124,9 +133,9 @@ def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torc
# https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
alibi = slopes[..., None] * arange_tensor
return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
return alibi
# @torch.jit.script
def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
"""
Dropout add function
@ -145,7 +154,7 @@ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training:
out = residual + out
return out
@torch.jit.script
def bloom_gelu_forward(x: torch.Tensor) -> torch.Tensor:
"""
Custom bias GELU function. Adapted from Megatron-DeepSpeed code. Here we use a simple implementation (inference) to
@ -179,8 +188,13 @@ def bloom_gelu_back(g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
class GeLUFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input: torch.Tensor) -> torch.Tensor:
ctx.save_for_backward(input)
return bloom_gelu_forward(input)
if False and CUSTOM_KERNELS_ENABLED:
"""My kernel is actually still slow compared to what `jit` provides."""
raise ValueError("WTF")
fused_bloom_gelu_cuda.forward(input)
else:
ctx.save_for_backward(input)
return bloom_gelu_forward(input)
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
@ -207,9 +221,57 @@ class BloomGelu(nn.Module):
else:
return bloom_gelu_forward(x)
# @torch.jit.script # this is shit for unknown reasons.
def _split_heads(fused_qkv: torch.Tensor, num_heads: int, head_dim: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
storage as `fused_qkv`
Args:
fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
Returns:
query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
value: [batch_size, seq_length, num_heads, head_dim]
"""
batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
fused_qkv = fused_qkv.view(batch_size, seq_length, num_heads, 3 * head_dim)
query_layer, key_layer, value_layer = fused_qkv.split(head_dim, dim=-1)
query_layer = query_layer.transpose(1, 2).reshape(batch_size * num_heads, seq_length, head_dim)
key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * num_heads, head_dim, seq_length)
value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_heads, seq_length, head_dim)
return query_layer, key_layer, value_layer
# @torch.jit.script
def _merge_heads(x: torch.Tensor, num_heads: int, head_dim: int) -> torch.Tensor:
"""
Merge heads together over the last dimension
Args:
x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
Returns:
torch.tensor: [batch_size, seq_length, num_heads * head_dim]
"""
# What we want to achieve is:
# batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
batch_size_and_num_heads, seq_length, _ = x.shape
batch_size = batch_size_and_num_heads // num_heads
# First view to decompose the batch size
# batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
x = x.view(batch_size, num_heads, seq_length, head_dim)
# batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
x = x.permute(0, 2, 1, 3)
# batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
return x.reshape(batch_size, seq_length, num_heads * head_dim)
class BloomAttention(nn.Module):
def __init__(self, config: BloomConfig):
def __init__(self, config: BloomConfig, process_group: Optional[torch.distributed.ProcessGroup]):
super().__init__()
self.pretraining_tp = config.pretraining_tp
@ -231,72 +293,37 @@ class BloomAttention(nn.Module):
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
self.beta = 1.0
self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
self.dense = nn.Linear(self.hidden_size, self.hidden_size)
if process_group is None:
self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
self.dense = nn.Linear(self.hidden_size, self.hidden_size)
else:
assert self.num_heads % process_group.size() == 0
self.num_heads = self.num_heads // process_group.size()
self.query_key_value = TensorParallelColumnLinear(self.hidden_size, 3 * self.hidden_size, process_group=process_group, bias=True)
self.dense = TensorParallelRowLinear(self.hidden_size, self.hidden_size, process_group=process_group)
self.attention_dropout = nn.Dropout(config.attention_dropout)
def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
storage as `fused_qkv`
Args:
fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
Returns:
query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
value: [batch_size, seq_length, num_heads, head_dim]
"""
batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
"""
Merge heads together over the last dimension
Args:
x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
Returns:
torch.tensor: [batch_size, seq_length, num_heads * head_dim]
"""
# What we want to achieve is:
# batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
batch_size_and_num_heads, seq_length, _ = x.shape
batch_size = batch_size_and_num_heads // self.num_heads
# First view to decompose the batch size
# batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
x = x.view(batch_size, self.num_heads, seq_length, self.head_dim)
# batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
x = x.permute(0, 2, 1, 3)
# batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim)
def forward(
self,
hidden_states: torch.Tensor,
residual: torch.Tensor,
@staticmethod
def compute_attention(
fused_qkv: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]],
alibi: torch.Tensor,
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
head_mask: Optional[torch.Tensor],
beta: float,
inv_norm_factor: float,
num_heads: int,
use_cache: bool
):
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
batch_size, q_length, three_times_hidden_size = fused_qkv.shape
head_dim = three_times_hidden_size // (3 * num_heads)
batch_size_times_num_heads = batch_size * num_heads
### TODO @thomasw21: this takes quite a bit of time, how do I accelerate that?
# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
(query_layer, key_layer, value_layer) = _split_heads(fused_qkv, num_heads=num_heads,
head_dim=head_dim)
batch_size, q_length, _, _ = query_layer.shape
query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.head_dim, q_length)
value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
if layer_past is not None:
past_key, past_value = layer_past
# concatenate along seq_length dimension:
@ -311,41 +338,81 @@ class BloomAttention(nn.Module):
present = (key_layer, value_layer)
else:
present = None
###
# [batch_size * num_heads, q_length, kv_length]
# we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
matmul_result = alibi.baddbmm(
attention_scores = alibi.baddbmm(
batch1=query_layer,
batch2=key_layer,
beta=self.beta,
alpha=self.inv_norm_factor,
beta=beta,
alpha=inv_norm_factor,
)
# change view to [batch_size, num_heads, q_length, kv_length]
attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)
# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
input_dtype = attention_scores.dtype
# `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
if input_dtype == torch.float16:
attention_scores = attention_scores.to(torch.float)
attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
# torch.finfo not supported by torch.jit, we temporarily replace it with `-1e34`
attn_weights = attention_scores.masked_fill_(attention_mask, torch.finfo(attention_scores.dtype).min)
attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)
# [batch_size, num_heads, q_length, kv_length]
attention_probs = self.attention_dropout(attention_probs)
# # [batch_size, num_heads, q_length, kv_length]
# attention_probs = self.attention_dropout(attention_probs)
if head_mask is not None:
attention_probs = attention_probs * head_mask
# change view [batch_size x num_heads, q_length, kv_length]
attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)
# matmul: [batch_size * num_heads, q_length, head_dim]
context_layer = torch.bmm(attention_probs_reshaped, value_layer)
context_layer = torch.bmm(attention_probs, value_layer, out=query_layer)
# change view [batch_size, num_heads, q_length, head_dim]
context_layer = self._merge_heads(context_layer)
context_layer = _merge_heads(context_layer, num_heads=num_heads, head_dim=head_dim)
return context_layer, present, attention_probs
def forward(
self,
hidden_states: torch.Tensor,
residual: torch.Tensor,
alibi: torch.Tensor,
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
batch_size, q_length, _ = fused_qkv.shape
if CUSTOM_KERNELS_ENABLED:
assert self.training is False, "Only the forward pass is implemented"
assert attention_mask.shape[-1] < 4096, "Custom kernel supports only up to 4096 tokens"
context_layer, present, attention_probs = fused_bloom_attention_cuda.forward(
fused_qkv,
layer_past,
alibi,
attention_mask,
head_mask,
self.beta,
self.inv_norm_factor,
self.num_heads,
use_cache
)
else:
raise ValueError("Block this path while we figure out how to run C++ code")
context_layer, present, attention_probs = self.compute_attention(
fused_qkv=fused_qkv,
layer_past=layer_past,
alibi=alibi,
attention_mask=attention_mask,
head_mask=head_mask,
beta=self.beta,
inv_norm_factor=self.inv_norm_factor,
num_heads=self.num_heads,
use_cache=use_cache
)
# aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
if self.pretraining_tp > 1 and self.slow_but_exact:
@ -359,7 +426,8 @@ class BloomAttention(nn.Module):
else:
output_tensor = self.dense(context_layer)
output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
# output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
output_tensor += residual
outputs = (output_tensor, present)
if output_attentions:
@ -369,15 +437,19 @@ class BloomAttention(nn.Module):
class BloomMLP(nn.Module):
def __init__(self, config: BloomConfig):
def __init__(self, config: BloomConfig, process_group: Optional[torch.distributed.ProcessGroup]):
super().__init__()
hidden_size = config.hidden_size
self.pretraining_tp = config.pretraining_tp
self.slow_but_exact = config.slow_but_exact
self.dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size)
if process_group is None:
self.dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size)
self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size)
else:
self.dense_h_to_4h = TensorParallelColumnLinear(hidden_size, 4 * hidden_size, process_group=process_group)
self.dense_4h_to_h = TensorParallelRowLinear(4 * hidden_size, hidden_size, process_group=process_group)
self.gelu_impl = BloomGelu()
self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size)
self.hidden_dropout = config.hidden_dropout
def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
@ -394,22 +466,23 @@ class BloomMLP(nn.Module):
else:
intermediate_output = self.dense_4h_to_h(hidden_states)
output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
# output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
intermediate_output += residual
return output
return intermediate_output
class BloomBlock(nn.Module):
def __init__(self, config: BloomConfig):
def __init__(self, config: BloomConfig, process_group: Optional[torch.distributed.ProcessGroup]):
super().__init__()
hidden_size = config.hidden_size
self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.num_heads = config.n_head
self.self_attention = BloomAttention(config)
self.self_attention = BloomAttention(config, process_group=process_group)
self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = BloomMLP(config)
self.mlp = BloomMLP(config, process_group=process_group)
self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
self.hidden_dropout = config.hidden_dropout
@ -581,18 +654,23 @@ BLOOM_INPUTS_DOCSTRING = r"""
BLOOM_START_DOCSTRING,
)
class BloomModel(BloomPreTrainedModel):
def __init__(self, config: BloomConfig):
def __init__(self, config: BloomConfig, process_group: Optional[torch.distributed.ProcessGroup]=None):
super().__init__(config)
self.embed_dim = config.hidden_size
self.num_heads = config.n_head
# Embedding + LN Embedding
self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
if process_group is None:
self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
else:
self.word_embeddings = TensorParallelEmbedding(config.vocab_size, self.embed_dim, process_group=process_group)
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
# Transformer blocks
self.h = nn.ModuleList([BloomBlock(config) for _ in range(config.num_hidden_layers)])
self.h = nn.ModuleList([BloomBlock(config, process_group=process_group) for _ in range(config.num_hidden_layers)])
# Final Layer Norm
self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@ -609,7 +687,7 @@ class BloomModel(BloomPreTrainedModel):
self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
) -> torch.BoolTensor:
# create causal mask
# [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
# [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
combined_attention_mask = None
device = attention_mask.device
_, src_length = input_shape
@ -619,7 +697,7 @@ class BloomModel(BloomPreTrainedModel):
input_shape, device=device, past_key_values_length=past_key_values_length
)
# [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
# [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
@ -705,7 +783,7 @@ class BloomModel(BloomPreTrainedModel):
else:
attention_mask = attention_mask.to(hidden_states.device)
alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
alibi = build_alibi_tensor(attention_mask, self.num_heads)
causal_mask = self._prepare_attn_mask(
attention_mask,
@ -713,6 +791,18 @@ class BloomModel(BloomPreTrainedModel):
past_key_values_length=past_key_values_length,
)
if hasattr(self, "tp_rank"):
assert self.num_heads % self.tp_world_size == 0
block_size = self.num_heads // self.tp_world_size
alibi = alibi[:, self.tp_rank * block_size: (self.tp_rank + 1) * block_size]
alibi = alibi.reshape(batch_size * block_size, 1, seq_length_with_past)
causal_mask = torch.repeat_interleave(causal_mask, block_size, dim=0)
else:
alibi = alibi.reshape(batch_size * self.num_heads, 1, seq_length_with_past)
causal_mask = torch.repeat_interleave(causal_mask, self.num_heads, dim=0)
alibi = alibi.to(hidden_states.dtype)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
if output_hidden_states:
@ -787,8 +877,21 @@ class BloomForCausalLM(BloomPreTrainedModel):
def __init__(self, config: BloomConfig):
super().__init__(config)
self.transformer = BloomModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
print(config)
### HACK
if config.tp_parallel:
process_group = torch.distributed.distributed_c10d._get_default_group()
else:
process_group = None
###
self.transformer = BloomModel(config, process_group)
if process_group is None:
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
else:
# TODO @thomasw21: TensorParallelColumnLinear doesn't inherit nn.Linear anymore as we switch the underlying storage
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size // process_group.size(), bias=False)
# self.lm_head = TensorParallelColumnLinear(config.hidden_size, config.vocab_size, process_group=process_group, bias=False)
# Initialize weights and apply final processing
self.post_init()
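A note on the mask/alibi plumbing: with `_make_causal_mask`/`_expand_mask` now producing 3D `[batch_size, tgt_length, src_length]` masks and `build_alibi_tensor` no longer reshaping, `BloomModel.forward` flattens both to a `[batch_size * num_heads, ...]` layout before the blocks run (slicing alibi per tensor-parallel rank when `tp_rank` is set). A stand-alone sketch of the non-parallel path, with made-up sizes:

```python
import torch

batch_size, num_heads, seq_len = 2, 4, 8
alibi = torch.zeros(batch_size, num_heads, seq_len)             # as returned by build_alibi_tensor
causal_mask = torch.zeros(batch_size, seq_len, seq_len).bool()  # 3D now, not 4D

# One (q, kv) plane per (batch, head) pair, matching what the fused kernel expects
alibi_3d = alibi.reshape(batch_size * num_heads, 1, seq_len)
mask_3d = torch.repeat_interleave(causal_mask, num_heads, dim=0)
assert alibi_3d.shape == (batch_size * num_heads, 1, seq_len)
assert mask_3d.shape == (batch_size * num_heads, seq_len, seq_len)
```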

src/transformers/models/bloom/parallel_layers.py Normal file

@ -0,0 +1,235 @@
import torch
import torch.distributed
import torch.nn.functional as F
from torch import nn
import math
class TensorParallelColumnLinear(nn.Module):
def __init__(
self,
in_features,
out_features,
process_group: torch.distributed.ProcessGroup,
bias=True,
device=None,
dtype=None,
):
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.process_group = process_group
self.tp_world_size = process_group.size()
assert out_features % self.tp_world_size == 0
self.in_features = in_features
self.out_features = out_features // self.tp_world_size
# We change from traditional `nn.Linear` and remove the unnecessary `torch.Tensor.transpose` operation
self.weight = nn.Parameter(torch.empty((self.in_features, self.out_features), **factory_kwargs))
if bias:
self.bias = nn.Parameter(torch.empty(self.out_features, **factory_kwargs))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self) -> None:
"""From `torch.nn.Linear`"""
# Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
# uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
# https://github.com/pytorch/pytorch/issues/57109
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
nn.init.uniform_(self.bias, -bound, bound)
def extra_repr(self) -> str:
"""From `torch.nn.Linear`"""
return 'in_features={}, out_features={}, bias={}'.format(
self.in_features, self.out_features, self.bias is not None
)
@staticmethod
@torch.jit.script
def linear(input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor):
# # Note that the unsharded equivalent requires us to sum over bias instead of averaging.
# in_features, out_features = weight.shape
# size_out = input.size()[:-1] + (out_features,)
# # TODO @thomasw21: when using torch.jit.script, `addmm` is decomposed to `add + mm`
# return torch.addmm(bias, input.view(-1, in_features), weight).view(size_out)
in_features, out_features = weight.shape
size_out = input.size()[:-1] + (out_features,)
# TODO @thomasw21: when using torch.jit.script, `addmm` is decomposed to `add + mm`
input = input.view(-1, in_features)
# HACK @thomas21: turns out `aten::addmm.out` is not decomposed
out = torch.empty((0,), device=input.device, dtype=input.dtype)
out = torch.addmm(bias, input, weight, out=out.view(-1, out_features))
return out.view(size_out)
# return F.linear(input, weight=weight.transpose(1,0), bias=bias)
def forward(self, input: torch.Tensor) -> torch.Tensor:
out = self.linear(input, weight=self.weight, bias=self.bias)
# ### DEBUG @thomasw21: Check that the sharded model outputs the same as the non-sharded version
# out_from_tp_ranks = [torch.empty_like(out) for _ in range(self.process_group.size())]
# torch.distributed.all_gather(out_from_tp_ranks, out, group=self.process_group)
# sharded_out = torch.cat(out_from_tp_ranks, dim=-1)
#
# weight_from_tp_ranks = [torch.empty_like(self.weight) for _ in range(self.process_group.size())]
# bias_from_tp_ranks = [torch.empty_like(self.bias) for _ in range(self.process_group.size())]
# torch.distributed.all_gather(weight_from_tp_ranks, self.weight, group=self.process_group)
# torch.distributed.all_gather(bias_from_tp_ranks, self.bias, group=self.process_group)
# weight = torch.cat(weight_from_tp_ranks, dim=0)
# bias = torch.cat(bias_from_tp_ranks, dim=0)
# baseline_out = F.linear(input, weight, bias)
#
# torch.testing.assert_close(sharded_out, baseline_out, atol=0.0, rtol=0.0)
# ###
return out
class TensorParallelRowLinear(nn.Module):
def __init__(
self,
in_features,
out_features,
process_group: torch.distributed.ProcessGroup,
bias=True,
device=None,
dtype=None,
):
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.process_group = process_group
self.tp_world_size = process_group.size()
assert in_features % self.tp_world_size == 0
self.in_features = in_features // self.tp_world_size
self.out_features = out_features
# We change from traditional `nn.Linear` and remove the unnecessary `torch.Tensor.transpose` operation
self.weight = nn.Parameter(torch.empty((self.in_features, self.out_features), **factory_kwargs))
if bias:
self.bias = nn.Parameter(torch.empty(self.out_features, **factory_kwargs))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self) -> None:
"""From `torch.nn.Linear`"""
# Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
# uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
# https://github.com/pytorch/pytorch/issues/57109
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
nn.init.uniform_(self.bias, -bound, bound)
@staticmethod
@torch.jit.script
def linear(input: torch.Tensor, weight: torch.Tensor, bias:torch.Tensor):
# # Note that the unsharded equivalent requires us to sum over bias instead of averaging.
# in_features, out_features = weight.shape
# size_out = input.size()[:-1] + (out_features,)
# # TODO @thomasw21: when using torch.jit.script, `addmm` is decomposed to `add + mm`
# input = input.view(-1, in_features)
# # with torch.jit.strict_fusion():
# out = torch.addmm(bias, input, weight)
# return out.view(size_out)
in_features, out_features = weight.shape
size_out = input.size()[:-1] + (out_features,)
# TODO @thomasw21: when using torch.jit.script, `addmm` is decomposed to `add + mm`
input = input.view(-1, in_features)
# HACK @thomas21: turns out `aten::addmm.out` is not decomposed
out = torch.empty((0,), device=input.device, dtype=input.dtype)
out = torch.addmm(bias, input, weight, out=out.view(-1, out_features))
return out.view(size_out)
# return F.linear(input, weight=weight.transpose(1,0), bias=bias)
def forward(self, input: torch.Tensor) -> torch.Tensor:
out = self.linear(input, weight=self.weight, bias=self.bias)
torch.distributed.all_reduce(out, group=self.process_group)
# ### DEBUG @thomasw21: Check that the sharded model outputs the same as the non-sharded version
# sharded_out = out
#
# input_from_tp_ranks = [torch.empty_like(input) for _ in range(self.process_group.size())]
# weight_from_tp_ranks = [torch.empty_like(self.weight) for _ in range(self.process_group.size())]
# bias = self.bias.clone()
# torch.distributed.all_gather(input_from_tp_ranks, input, group=self.process_group)
# torch.distributed.all_gather(weight_from_tp_ranks, self.weight, group=self.process_group)
# torch.distributed.all_reduce(bias, group=self.process_group)
# input = torch.cat(input_from_tp_ranks, dim=-1)
# weight = torch.cat(weight_from_tp_ranks, dim=1)
# baseline_out = F.linear(input, weight, bias)
#
# if self.process_group.rank() == 0:
# torch.testing.assert_close(bias, self.bias, atol=0.0, rtol=0.0)
# torch.distributed.barrier(self.process_group)
# # torch.testing.assert_close(sharded_out, baseline_out)
# ###
return out
def extra_repr(self) -> str:
"""From `torch.nn.Linear`"""
return 'in_features={}, out_features={}, bias={}'.format(
self.in_features, self.out_features, self.bias is not None
)
class TensorParallelEmbedding(nn.Embedding):
def __init__(
self,
num_embeddings,
embedding_dim,
process_group: torch.distributed.ProcessGroup,
padding_idx=None,
max_norm=None,
norm_type=2.0,
scale_grad_by_freq=False,
sparse=False,
_weight=None,
device=None,
dtype=None
):
self.process_group = process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.original_num_embeddings = num_embeddings
# TODO @thomasw21 fix and remove that constraint
assert num_embeddings % self.tp_world_size == 0
block_size = num_embeddings // self.tp_world_size
# inputs in `[min_id, max_id[` are handled by `self` to get embeddings
self.min_id = self.tp_rank * block_size
self.max_id = (self.tp_rank + 1) * block_size
super().__init__(block_size, embedding_dim, padding_idx=padding_idx, max_norm=max_norm, norm_type=norm_type, scale_grad_by_freq=scale_grad_by_freq, sparse=sparse, _weight=_weight, device=device, dtype=dtype)
def forward(self, input: torch.Tensor) -> torch.Tensor:
# Sanity check
if torch.any(torch.logical_or(0 > input, input >= self.original_num_embeddings)):
raise IndexError(f"Input is required to be in [0, {self.original_num_embeddings}[, got min: {torch.min(input)} and max: {torch.max(input)}")
# `0` if input is in the correct interval, else `1`
input_mask = torch.logical_or(self.min_id > input, input >= self.max_id)
# translate into [0, self.max_id - self.min_id[
input = input - self.min_id
# default all out of bounds values to `0`
input[input_mask] = 0
out = super().forward(input)
out[input_mask] = 0.0
torch.distributed.all_reduce(out, group=self.process_group)
return out
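The two linear shards are meant to compose the way `BloomMLP` uses them: a column-parallel projection up, then a row-parallel projection down whose partial results are summed by the `all_reduce`. Ignoring the elementwise GELU (which can be applied per shard), here is a single-process sketch of that equivalence with made-up sizes; note that, unlike this sketch, `TensorParallelRowLinear.linear` above adds the full bias on every rank before the reduce:

```python
import torch

d_model, d_ff, world_size = 6, 8, 2
x = torch.randn(3, d_model)
w1, b1 = torch.randn(d_model, d_ff), torch.randn(d_ff)     # up projection, stored (in, out) as above
w2, b2 = torch.randn(d_ff, d_model), torch.randn(d_model)  # down projection

full = (x @ w1 + b1) @ w2 + b2  # unsharded reference

partials = []
shard = d_ff // world_size
for rank in range(world_size):
    cols = slice(rank * shard, (rank + 1) * shard)
    y = x @ w1[:, cols] + b1[cols]                 # column-parallel shard
    # bias split across ranks here for clarity; see the note in the lead-in
    partials.append(y @ w2[cols, :] + b2 / world_size)

# The all_reduce in TensorParallelRowLinear.forward sums exactly these partials
torch.testing.assert_close(sum(partials), full, rtol=1e-4, atol=1e-4)
```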