Mirror of https://github.com/huggingface/transformers.git (synced 2025-11-04 12:04:37 +08:00)

Compare commits: v4.35.0 ... thomas/add (64 commits)
| SHA1 |
|---|
| 46d37bece7 |
| bd232fb97d |
| 5924480094 |
| a009ca5a3b |
| 4e383e7a33 |
| 870e180f08 |
| b395423c8b |
| c63e0a7377 |
| ccf4b2d610 |
| 989fac0556 |
| 3fc5ebb9e3 |
| c4edcf9c70 |
| 94e6226b3a |
| 4787e17582 |
| 07e9409032 |
| 435f8fcbf2 |
| 9d4257a74c |
| 68944f597d |
| 0c10634a36 |
| bd7101dace |
| ee4fe12157 |
| 44430fbe14 |
| 6c25f536cc |
| 6fd4c59911 |
| b70a818d50 |
| 2b266f0a39 |
| 9e80f9b190 |
| 9bdd6c8286 |
| 389f196b1c |
| b0d6f72b55 |
| 9b448ad29d |
| 218f457388 |
| 616690fa8a |
| 12ac8b7ec3 |
| 379ef0a888 |
| 5bdb735c28 |
| 8d8f5864b7 |
| 56915604ac |
| 99c3d575aa |
| 961a341e31 |
| 76e18bcfd1 |
| 20d9aa6058 |
| 82dbd6f64f |
| 066f1d8c47 |
| 493dc31d78 |
| 7168222a3a |
| dfaa2e37f9 |
| aa47de5a56 |
| 9067981e0c |
| e4bd0884ce |
| ee234a808f |
| 62ef07b72e |
| a7034e3211 |
| c698398bb3 |
| 0cad39a4e5 |
| d92aa10587 |
| 90821b2f13 |
| 68f0b97531 |
| 2c1e4554d0 |
| 452b9d00f8 |
| 8941c2df47 |
| e8e15437c5 |
| 4f98d8e136 |
| f4d0dc3c15 |
README_BUILD.md (new file, +5 lines)
@@ -0,0 +1,5 @@
## We provide some functions to custom-build some kernels

```bash
python setup.py build_ext --inplace
```
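The command above builds the CUDA extensions in place. As a quick sanity check (a sketch, assuming the build succeeded and the repository is on the path), the compiled module should then be importable under the package name registered in setup.py below; modeling_bloom.py further down uses the same try/except pattern to fall back when the kernels are missing.

```python
# Sketch: verify that the in-place build produced an importable extension module.
# The import path comes from the `name=` argument in setup.py's get_extensions().
try:
    from transformers.models.bloom.custom_kernels import fused_bloom_attention_cuda
    print("custom attention kernel available:", fused_bloom_attention_cuda.__name__)
except ImportError as err:
    # Same behaviour as modeling_bloom.py: warn and fall back to the pure-PyTorch path.
    print("custom kernels not built:", err)
```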
setup.py (33 changed lines)
@@ -75,8 +75,10 @@ from pathlib import Path
from setuptools import find_packages, setup

from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# Remove stale transformers.egg-info directory to avoid https://github.com/pypa/pip/issues/5466
stale_egg_info = Path(__file__).parent / "transformers.egg-info"
if stale_egg_info.exists():
    print(
@@ -397,6 +399,31 @@ install_requires = [
    deps["tokenizers"],
    deps["tqdm"],  # progress bars in model download and training scripts
]
def get_extensions():
    # TODO @thomasw21 add cpp versions
    extensions = []

    # TODO @thomasw21 build cuda kernels only on some conditions
    if True:
        extensions += [
            CUDAExtension(
                name="transformers.models.bloom.custom_kernels.fused_bloom_attention_cuda",
                sources=["src/transformers/models/bloom/custom_kernels/fused_bloom_attention_cuda.cu"],
                # TODO: understand what that is, probably defines the target architecture
                #  https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#options-for-steering-gpu-code-generation-gpu-architecture
                # Build for A100
                extra_compile_args=["-arch=compute_80", "-std=c++17"],
            ),
            CUDAExtension(
                name="transformers.models.bloom.custom_kernels.fused_bloom_gelu_cuda",
                sources=["src/transformers/models/bloom/custom_kernels/fused_bloom_gelu_cuda.cu"],
                # TODO: understand what that is, probably defines the target architecture
                #  https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#options-for-steering-gpu-code-generation-gpu-architecture
                # Build for A100
                extra_compile_args=["-arch=compute_80", "-std=c++17"],
            ),
        ]
    return extensions

setup(
    name="transformers",
@@ -429,5 +456,9 @@ setup(
        "Programming Language :: Python :: 3.9",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
    cmdclass={"deps_table_update": DepsTableUpdateCommand},
    ext_modules=get_extensions(),
    cmdclass={
        "deps_table_update": DepsTableUpdateCommand,
        "build_ext": BuildExtension
    },
)
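For readers unfamiliar with `torch.utils.cpp_extension`, here is a minimal, self-contained sketch of the pattern the hunk above adds: each `CUDAExtension` names the Python import path of the compiled module and its `.cu` sources, and `BuildExtension` is wired into `cmdclass` so that `python setup.py build_ext --inplace` (see README_BUILD.md) compiles them. The package and file names below are hypothetical, not the repository's.

```python
# Minimal illustrative setup.py for a CUDA extension (hypothetical names).
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="my_kernels",
    ext_modules=[
        CUDAExtension(
            name="my_kernels.fused_op_cuda",      # import path of the built module
            sources=["csrc/fused_op_cuda.cu"],    # CUDA source file(s)
            # compute_80 targets A100 (compute capability 8.0); adjust for other GPUs.
            extra_compile_args=["-arch=compute_80", "-std=c++17"],
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)
```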
Generation utilities (class GenerationMixin)

@@ -481,11 +481,11 @@ class GenerationMixin:
        if self.config.is_encoder_decoder and encoder_outputs is not None:
            # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
            shape = encoder_outputs.last_hidden_state.size()[:-1]
            return torch.ones(shape, dtype=torch.long, device=self.device) * -100
            return torch.full(shape, fill_value=-100, dtype=torch.long, device=self.device)

        if bos_token_id is None:
            raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
        return torch.ones((1, 1), dtype=torch.long, device=self.device) * bos_token_id
        return torch.full((1, 1), fill_value=bos_token_id, dtype=torch.long, device=self.device)

    def _prepare_attention_mask_for_generation(
        self,
@@ -541,7 +541,7 @@ class GenerationMixin:
            decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
            if device is None:
                device = self.device
            return torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
            return torch.full((batch_size, 1), fill_value=decoder_start_token_id, dtype=torch.long, device=device)

    def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
        decoder_start_token_id = (
@@ -1691,7 +1691,7 @@ class GenerationMixin:
            )

        # keep track of which sequences are already finished
        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
        unfinished_sequences = input_ids.new_full((input_ids.shape[0],), fill_value=True, dtype=torch.bool)
        cur_len = input_ids.shape[-1]

        this_peer_finished = False  # used by synced_gpus only
@@ -1752,7 +1752,7 @@ class GenerationMixin:
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
                next_tokens.masked_fill_(~unfinished_sequences, pad_token_id)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
@@ -1763,10 +1763,10 @@ class GenerationMixin:

            # if eos_token was found in one sentence, set sentence to finished
            if eos_token_id is not None:
                unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
                unfinished_sequences &= next_tokens != eos_token_id

            # stop when each sentence is finished, or if we exceed the maximum length
            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
            if not (torch.any(unfinished_sequences) and not stopping_criteria(input_ids, scores)):
                if not synced_gpus:
                    break
                else:
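The generation changes are behaviour-preserving rewrites: `torch.ones(...) * value` becomes `torch.full(...)`, and the 0/1 long-tensor bookkeeping of `unfinished_sequences` becomes a boolean mask updated in place. A small sketch (with hypothetical token ids) of the equivalences:

```python
import torch

# torch.full(...) builds the constant tensor directly instead of ones(...) * value.
bos_token_id = 2  # hypothetical id
a = torch.ones((1, 1), dtype=torch.long) * bos_token_id
b = torch.full((1, 1), fill_value=bos_token_id, dtype=torch.long)
assert torch.equal(a, b)

# Padding finished sequences: arithmetic on 0/1 flags vs masked_fill_ on a bool mask.
next_tokens = torch.tensor([5, 3, 3])                 # hypothetical sampled tokens
eos_token_id, pad_token_id = 3, 0
unfinished = torch.tensor([1, 1, 0])                  # old: long 0/1 flags
old = next_tokens * unfinished + pad_token_id * (1 - unfinished)

unfinished_bool = torch.tensor([True, True, False])   # new: bool flags
new = next_tokens.clone().masked_fill_(~unfinished_bool, pad_token_id)
assert torch.equal(old, new)

# Updating the flags: multiply-by-long vs in-place boolean AND.
unfinished = unfinished.mul((next_tokens != eos_token_id).long())
unfinished_bool &= next_tokens != eos_token_id
assert torch.equal(unfinished.bool(), unfinished_bool)
```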
configuration_bloom.py (class BloomConfig)

@@ -16,6 +16,7 @@
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, List, Mapping, Optional

import torch.distributed
from transformers import is_torch_available


@@ -130,6 +131,7 @@ class BloomConfig(PretrainedConfig):
        attention_dropout=0.0,
        pretraining_tp=1,  # TP rank used when training with megatron
        slow_but_exact=False,
        tp_parallel=False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
@@ -149,6 +151,7 @@
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.slow_but_exact = slow_but_exact
        self.tp_parallel = tp_parallel

        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

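The configuration change only adds a stored flag (plus an as-yet-unused `torch.distributed` import); nothing in this file consumes it. A tiny usage sketch, assuming this branch is installed:

```python
from transformers.models.bloom.configuration_bloom import BloomConfig

config = BloomConfig(tp_parallel=True)  # new flag, stored as-is on the config
print(config.tp_parallel)               # -> True; the modeling code decides what to do with it
```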
src/transformers/models/bloom/custom_kernels/fused_bloom_attention_cuda.cu (new file, +250 lines)

@@ -0,0 +1,250 @@
#include <ATen/Dispatch.h>
#include <THC/THCAtomics.cuh>
#include <ATen/ATen.h>
#include <torch/torch.h>
#include <vector>

#include <optional>

/**
* Friendly reminder of how multithreading works in CUDA: https://developer.nvidia.com/blog/even-easier-introduction-cuda
* Check example at https://github.com/thomasw21/LinearTransformers/blob/main/model/attention/fast_weight/fast_weight_cuda.cu
**/

// Available in pytorch main
//#define DISPATCH_CASE_FLOATING_TYPES(...) \
//  at::AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
//  at::AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
//  at::AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
//  at::AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \

/*
* Forward passes
*/

/**
* cast to fp32 if in fp16 + mask + softmax computation in fp32 + cast back to original dtype
**/
template<typename attention_scores_scalar, int64_t min_kv_length_shard_size_per_thread>
__global__ void forward_masked_softmax_kernel(
    const torch::PackedTensorAccessor32<attention_scores_scalar, 2, torch::RestrictPtrTraits> attention_scores, // [B, KV]
    const torch::PackedTensorAccessor32<bool, 2, torch::RestrictPtrTraits> mask, // [B, KV]
    torch::PackedTensorAccessor32<attention_scores_scalar, 2, torch::RestrictPtrTraits> result, // [B, KV]
    const int64_t effective_kv_length,
    const dim3 blockDim,
    const int64_t rows_per_block,
    const int64_t kv_length,
    const int64_t batch_size
) {
    const auto row_id = threadIdx.x / effective_kv_length;
    const auto effective_kv_length_id = threadIdx.x % effective_kv_length;
    const auto kv_length_start = effective_kv_length_id * min_kv_length_shard_size_per_thread;
    auto kv_length_end_ = (effective_kv_length_id + 1) * min_kv_length_shard_size_per_thread;
    kv_length_end_ = (kv_length_end_ > kv_length) ? kv_length : kv_length_end_;
    const auto kv_length_end = kv_length_end_;

    const auto batch_id = blockIdx.x * rows_per_block + row_id;

    // We need 2 float storage for each row, one for max computation, the other for normalizing exponential
    extern __shared__ float temp_storage[];
    const auto row_id_mem_offset = row_id * 2;
    if (effective_kv_length_id == 0) {
        temp_storage[row_id_mem_offset] = -std::numeric_limits<float>::infinity();
        temp_storage[row_id_mem_offset + 1] = 0;
    }
    __syncthreads();

    // Compute mask and max
    if (batch_id < batch_size) {
        float thread_max = -std::numeric_limits<float>::infinity();
        for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
            if (mask[batch_id][kv_length_id] == 0) {
                const float candidate = attention_scores[batch_id][kv_length_id];
                thread_max = (thread_max < candidate) ? candidate : thread_max;
            }
        }
        if (thread_max != -std::numeric_limits<float>::infinity()) {
            // TODO @thomasw21 with more memory we can probably compute a much faster `max-reduce` in parallel O(ln(n)) operations in each memory slot
            gpuAtomicMax(&temp_storage[row_id_mem_offset], thread_max);
        }
    }

    __syncthreads();

    // Compute exp(elt - max) masked
    float exponential[min_kv_length_shard_size_per_thread];
    if (batch_id < batch_size) {
        float thread_add = 0;
        for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
            if (mask[batch_id][kv_length_id] == 0) {
                exponential[kv_length_id - kv_length_start] = std::exp(static_cast<float>(attention_scores[batch_id][kv_length_id]) - temp_storage[row_id_mem_offset]);
                thread_add = thread_add + exponential[kv_length_id - kv_length_start];
            } else {
                exponential[kv_length_id - kv_length_start] = 0.;
            }
        }
        if (thread_add > 0) {
            // TODO @thomasw21 with more memory we can probably compute a much faster `sum-reduce` in parallel O(ln(n)) operations in each memory slot
            gpuAtomicAdd(&temp_storage[row_id_mem_offset + 1], thread_add);
        }
    }

    __syncthreads();

    // Compute softmax
    if (batch_id < batch_size) {
        // If sum of all exponential is 0, we set the softmax values to 0
        if (temp_storage[row_id_mem_offset + 1] == 0.) {
            for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
                result[batch_id][kv_length_id] = 0.;
            }
        } else {
            for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
                result[batch_id][kv_length_id] = static_cast<attention_scores_scalar>(exponential[kv_length_id - kv_length_start] / temp_storage[row_id_mem_offset + 1]);
            }
        }
    }
}

#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

std::tuple<at::Tensor, std::optional<std::vector<at::Tensor>>, at::Tensor> forward(
    const at::Tensor fused_qkv,
    const std::optional<std::vector<at::Tensor>> layer_past,
    const at::Tensor alibi,
    const at::Tensor attention_mask,
    const std::optional<at::Tensor> head_mask,
    const float beta,
    const float inv_norm_factor,
    const int num_heads,
    const bool use_cache
) {
    const auto batch_size = fused_qkv.size(0);
    const auto q_length = fused_qkv.size(1);
    const auto three_times_hidden_size = fused_qkv.size(2);
    const auto head_dim = three_times_hidden_size / (3 * num_heads);
    const auto batch_size_times_num_heads = batch_size * num_heads;

    // `split_heads`
    const auto fused_qkv_view = fused_qkv.view({batch_size, q_length, num_heads, 3 * head_dim});
    const auto tensor_list = fused_qkv_view.split(head_dim, -1);
    const auto query_layer = tensor_list[0].transpose(1, 2).reshape({batch_size_times_num_heads, q_length, head_dim});
    auto key_layer = tensor_list[1].permute({0, 2, 3, 1}).reshape({batch_size_times_num_heads, head_dim, q_length});
    auto value_layer = tensor_list[2].transpose(1, 2).reshape({batch_size_times_num_heads, q_length, head_dim});

    if (layer_past) {
        const auto past_key = (*layer_past).at(0);
        const auto past_value = (*layer_past).at(1);
        key_layer = at::cat({past_key, key_layer}, 2);
        value_layer = at::cat({past_value, value_layer}, 1);
    }

    std::optional<std::vector<at::Tensor>> present;
    if (use_cache) {
        present = {key_layer, value_layer};
    } else {
        present = {};
    }

    auto attention_scores = alibi.baddbmm(query_layer, key_layer, beta, inv_norm_factor);

    // Computing `optionally_cast_fp16_to_fp32 + masked_fill + softmax + cast_to_intial_dtype`
    at::Tensor attention_probs;
    if (true) {
        const auto kv_length = key_layer.size(2);

        // TODO @thomasw21: it's easier to think of attention_scores as 2D tensors
        const auto attention_scores_2d = attention_scores.view({batch_size_times_num_heads * q_length, kv_length});
        const auto attention_mask_2d = attention_mask.view({batch_size_times_num_heads * q_length, kv_length});

        // Custom kernel
        attention_probs = at::empty_like(attention_scores_2d);

        // Check that inputs and contiguous + cuda tensors
        CHECK_INPUT(attention_scores_2d);
        CHECK_INPUT(attention_mask_2d);

        // TODO @thomas21: change by to this as it's cleaner when pytorch 1.13 comes out
        // DISPATCH_CASE_FLOATING_TYPES(attention_scores.scalar_type(), "masked_softmax", [&] {
        AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, attention_scores.scalar_type(), "masked_softmax", [&] {
            /*
            * Understanding how GPUs work: https://developer.nvidia.com/blog/cuda-refresher-cuda-programming-model/
            * A100 specifications: https://images.nvidia.com/aem-dam/en-zz/Solutions/data-center/nvidia-ampere-architecture-whitepaper.pdf
            *  - SMs: 108
            *  - TPCs: 56 (What's that?)
            *  - Memory size: 40 GB
            *  - L2 Cache size: 40960 KB (shared across all SMs)
            *  - L1/Shared memory size: 192 KB (shared across all threads within a SM)
            *  - Max Threads / SM: 2048
            *  - Max Thread Blocks / SM: 32
            */

            /*
            * We should split [batch_size_times_num_heads_block, q_length] in seperate blocks and [batch_size_times_num_heads_block_size, kv_length] a single block
            * with multiple threads as we need to `sync_threads` to run exponential sum.
            * We maximise the usage of threads within a single block
            */
            // TODO @thomasw21 figure out everything warp related:
            //  - why do they have to be power of 2
            // TODO @thomas21 check why everyone is setting 1024 when officially it's 2048
            const auto MAX_THREADS_PER_SM = 1024;
            // TODO @thomasw21 figure out how to have longer sequences, currently the maximum is `max_kv_length = MAX_THREADS_PER_SM * MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD`
            const auto MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD = 4;
            // `effective_kv_length = ceil(kv_length / MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD)`
            const auto effective_kv_length = (kv_length - 1)/ MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD + 1;
            const auto rows_per_block = MAX_THREADS_PER_SM / effective_kv_length;
            const auto num_blocks = (batch_size_times_num_heads * q_length - 1) / rows_per_block + 1;

            const dim3 gridDim(num_blocks); // Number of blocks that run
            const dim3 blockDim(MAX_THREADS_PER_SM); // Number of threads that run per block
            const int shared_mem_forward = rows_per_block * 2 * sizeof(float);

            // 192 * 2 ** 10
            // const auto MAX_L1_MEMORY = 196608;
            // const auto MAX_SMs = 108;
            // TORCH_CHECK(batch_size_times_num_heads * q_length <= MAX_L1_MEMORY, "Shared memory exceeds 192KB limitation.");
            // TORCH_CHECK(gridDim.x * gridDim.y * gridDim.z <= MAX_SMs, "A100s only have 108 SMs. Raising as require blocks is bigger.");
            // TORCH_CHECK(blockDim.x * blockDim.y * blockDim.z <= MAX_THREADS_PER_SM, "A100s only have 2048 threads per block. Raising as require requested threads is higher.");

            forward_masked_softmax_kernel<scalar_t, MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD><<<gridDim, blockDim, shared_mem_forward>>>(
                attention_scores_2d.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
                attention_mask_2d.packed_accessor32<bool, 2, torch::RestrictPtrTraits>(),
                attention_probs.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
                effective_kv_length,
                blockDim,
                rows_per_block,
                kv_length,
                batch_size_times_num_heads * q_length
            );
        });
        attention_probs = attention_probs.view({batch_size_times_num_heads, q_length, kv_length});
    } else {
        // Pytorch C++ API
        auto input_dtype = attention_scores.scalar_type();
        if (input_dtype == at::ScalarType::Float) {
            attention_scores = attention_scores.to(at::ScalarType::Float);
        };
        // TODO @thomasw21 Figure out how to get minimum value
        auto attn_weights = attention_scores.masked_fill_(attention_mask, -1e34);
        attention_probs = attn_weights.softmax(-1, at::ScalarType::Float).to(input_dtype);
    }

    auto context_layer = attention_probs.bmm(value_layer);

    // `_merge_heads`
    context_layer = context_layer.view({batch_size, num_heads, q_length, head_dim});
    context_layer = context_layer.permute({0, 2, 1, 3});
    context_layer = context_layer.reshape({batch_size, q_length, three_times_hidden_size / 3});

    return std::make_tuple(context_layer, present, attention_probs);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def(
        "forward",
        &forward,
        "Bloom attention mechanism forward (CUDA)"
    );
}
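The kernel above computes, per row of the flattened `[batch_size * num_heads * q_length, kv_length]` score matrix, a numerically stable masked softmax: a per-row max (via `gpuAtomicMax` in shared memory) and exponential sum (via `gpuAtomicAdd`), accumulated in fp32, with fully masked rows written out as zeros. With `MAX_THREADS_PER_SM = 1024` and 4 key/value positions per thread, a single block covers at most 4096 positions, which is where the 4096-token assert in modeling_bloom.py further down comes from. A PyTorch reference of the same semantics, useful for testing the kernel (a sketch, not part of the branch):

```python
import torch

def masked_softmax_reference(attention_scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Reference for the fused kernel's semantics: softmax over the last dim, computed in
    fp32, with masked positions excluded and fully masked rows set to 0.
    attention_scores: [rows, kv_length]; mask: [rows, kv_length] with True = masked out."""
    scores = attention_scores.float().masked_fill(mask, float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    # A fully masked row yields NaN (softmax over all -inf); the kernel writes zeros instead.
    probs = torch.where(torch.isnan(probs), torch.zeros_like(probs), probs)
    return probs.to(attention_scores.dtype)

# Hypothetical shapes: rows = batch_size * num_heads * q_length.
scores = torch.randn(6, 10, dtype=torch.float16)
mask = torch.zeros(6, 10, dtype=torch.bool)
mask[0] = True  # a fully masked row
print(masked_softmax_reference(scores, mask)[0])  # all zeros
```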
src/transformers/models/bloom/custom_kernels/fused_bloom_gelu_cuda.cu (new file, +97 lines)

@@ -0,0 +1,97 @@
#include <ATen/Dispatch.h>
#include <ATen/ATen.h>
#include <torch/extension.h>

/**
* Friendly reminder of how multithreading works in CUDA: https://developer.nvidia.com/blog/even-easier-introduction-cuda
* Check example at https://github.com/thomasw21/LinearTransformers/blob/main/model/attention/fast_weight/fast_weight_cuda.cu
**/

// Available in pytorch main
//#define DISPATCH_CASE_FLOATING_TYPES(...) \
//  at::AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
//  at::AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
//  at::AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
//  at::AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \

/*
* Forward passes
*/

/**
* compute GELU: `x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))`
**/
template<typename scalar_t, int64_t max_threads_per_sm>
__global__ void forward_kernel(
    const scalar_t* __restrict__ x, // [B, N, D]
    scalar_t* __restrict__ result, // [B, N, D]
    int64_t num_params
) {
    const auto id = blockIdx.x * max_threads_per_sm + threadIdx.x;

    if (num_params <= id) {
        return;
    }

    // We upcast to float always
    const float elt = x[id];

    // Compute gelu
    // TODO @thomasw21: Figure out where to find a tanh implementation that works for all kinds of scalar types. (I could hardcode it)
    result[id] = static_cast<scalar_t>(elt * 0.5 * (1.0 + std::tanh(0.79788456 * elt * (1 + 0.044715 * elt * elt))));
}

#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

at::Tensor forward(
    const at::Tensor x
) {
    CHECK_INPUT(x);

    const auto result = at::empty_like(x);

    // TODO @thomas21: change by to this as it's cleaner when pytorch 1.13 comes out
    // DISPATCH_CASE_FLOATING_TYPES(key_layer.scalar_type(), "masked_softmax", [&] {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(), "gelu", [&] {
        /*
        * Understanding how GPUs work: https://developer.nvidia.com/blog/cuda-refresher-cuda-programming-model/
        * A100 specifications: https://images.nvidia.com/aem-dam/en-zz/Solutions/data-center/nvidia-ampere-architecture-whitepaper.pdf
        *  - SMs: 108
        *  - TPCs: 56 (What's that?)
        *  - Memory size: 40 GB
        *  - L2 Cache size: 40960 KB (shared across all SMs)
        *  - L1/Shared memory size: 192 KB (shared across all threads within a SM)
        *  - Max Threads / SM: 2048
        *  - Max Thread Blocks / SM: 32
        */

        const auto MAX_THREADS_PER_SM = 1024; // TODO @thomas21 check why everyone is setting 1024 when officially it's 1024
        const auto num_params = x.numel(); // TODO @thomasw21 get `x.size()`
        const auto NUM_BLOCKS = (num_params - 1) / MAX_THREADS_PER_SM + 1;

        dim3 gridDim(NUM_BLOCKS); // Number of blocks that run
        dim3 blockDim(MAX_THREADS_PER_SM); // Number of threads that run per block

        // 192 * 2 ** 10
        // const auto MAX_L1_MEMORY = 196608;
        // const auto MAX_SMs = 108;

        forward_kernel<scalar_t, MAX_THREADS_PER_SM><<<gridDim, blockDim>>>(
            x.data<scalar_t>(),
            result.data<scalar_t>(),
            num_params
        );
    });

    return result;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def(
        "forward",
        &forward,
        "Bloom GELU mechanism forward (CUDA)"
    );
}
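This kernel applies the tanh approximation of GELU elementwise, upcasting each element to fp32 and launching `ceil(numel / 1024)` blocks of 1024 threads. The formula matches the existing `bloom_gelu_forward` and, on recent PyTorch (>= 1.12), the built-in tanh-approximate GELU; a quick equivalence check (a sketch):

```python
import torch
import torch.nn.functional as F

def gelu_tanh_reference(x: torch.Tensor) -> torch.Tensor:
    # Same formula the kernel evaluates per element (in fp32):
    # x * 0.5 * (1 + tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

x = torch.randn(4, 8)
# On PyTorch >= 1.12 the built-in tanh approximation should agree closely.
print(torch.allclose(gelu_tanh_reference(x), F.gelu(x, approximate="tanh"), atol=1e-6))
```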
modeling_bloom.py

@@ -19,6 +19,7 @@ import warnings
from typing import Optional, Tuple, Union

import torch
import torch.distributed
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
@@ -34,10 +35,18 @@ from ...modeling_outputs import (
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from .configuration_bloom import BloomConfig

from .parallel_layers import TensorParallelColumnLinear, TensorParallelEmbedding, TensorParallelRowLinear

logger = logging.get_logger(__name__)

CUSTOM_KERNELS_ENABLED=False
try:
    from .custom_kernels import fused_bloom_attention_cuda
    from .custom_kernels import fused_bloom_gelu_cuda
    CUSTOM_KERNELS_ENABLED=True
except ImportError:
    logger.warning("We're not using custom kernels.")

_CHECKPOINT_FOR_DOC = "bigscience/bloom-560m"
_CONFIG_FOR_DOC = "BloomConfig"
_TOKENIZER_FOR_DOC = "BloomTokenizerFast"
@@ -68,7 +77,7 @@ def _make_causal_mask(
    if past_key_values_length > 0:
        mask[:, :past_key_values_length] = False

    expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
    expanded_mask = mask[None, :, :].expand(batch_size, target_length, target_length + past_key_values_length)
    return expanded_mask


@@ -79,11 +88,11 @@ def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
    batch_size, src_length = mask.shape
    tgt_length = tgt_length if tgt_length is not None else src_length

    expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
    return expanded_mask.expand(batch_size, 1, tgt_length, src_length)
    expanded_mask = ~(mask[:, None, :].to(torch.bool))
    return expanded_mask.expand(batch_size, tgt_length, src_length)


def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int) -> torch.Tensor:
    """
    Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
    relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
@@ -124,9 +133,9 @@ def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
    # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
    arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
    alibi = slopes[..., None] * arange_tensor
    return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)

    return alibi

# @torch.jit.script
def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
    """
    Dropout add function
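`build_alibi_tensor` now returns the un-flattened `[batch_size, num_heads, seq_length]` tensor and no longer casts the dtype; the reshape to `[batch_size * num_heads, 1, seq_length]` moves to the caller, so that a tensor-parallel rank can first slice its share of the head dimension (see the last hunk of this file). A toy shape walk-through:

```python
import torch

# Shape walk-through for the modified build_alibi_tensor (toy sizes).
batch_size, num_heads, seq_length = 2, 4, 5
attention_mask = torch.ones(batch_size, seq_length)
slopes = torch.rand(num_heads)  # stands in for the real power-of-2 ALiBi slopes

arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]  # [B, 1, S]
alibi = slopes[..., None] * arange_tensor                                            # [B, H, S]
assert alibi.shape == (batch_size, num_heads, seq_length)

# The reshape (and dtype cast) that used to live inside the function is now the caller's job,
# which may first slice the head dimension per tensor-parallel rank:
alibi_flat = alibi.reshape(batch_size * num_heads, 1, seq_length)
assert alibi_flat.shape == (batch_size * num_heads, 1, seq_length)
```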
@@ -145,7 +154,7 @@ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training:
    out = residual + out
    return out


@torch.jit.script
def bloom_gelu_forward(x: torch.Tensor) -> torch.Tensor:
    """
    Custom bias GELU function. Adapted from Megatron-DeepSpeed code. Here we use a simple implementation (inference) to
@@ -179,8 +188,13 @@ def bloom_gelu_back(g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
class GeLUFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(input)
        return bloom_gelu_forward(input)
        if False and CUSTOM_KERNELS_ENABLED:
            """My kernel is actually still slow compared to what `jit` provides."""
            raise ValueError("WTF")
            fused_bloom_gelu_cuda.foward(input)
        else:
            ctx.save_for_backward(input)
            return bloom_gelu_forward(input)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
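`GeLUFunction.forward` gains a (currently disabled) branch for the CUDA kernel; the autograd plumbing itself is unchanged. As a reminder of how a `torch.autograd.Function` with an explicit backward is used, a generic sketch (illustrative stand-in, not code from this branch):

```python
import torch

# A custom autograd.Function pairs a forward with an explicit backward and is
# invoked through .apply() rather than by instantiating the class.
class Square(torch.autograd.Function):  # hypothetical example op
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)   # same pattern as GeLUFunction.forward above
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_output

x = torch.randn(3, requires_grad=True)
Square.apply(x).sum().backward()
assert torch.allclose(x.grad, 2 * x)
```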
@@ -207,9 +221,57 @@ class BloomGelu(nn.Module):
        else:
            return bloom_gelu_forward(x)

# @torch.jit.script # this is shit for unknow reasons.
def _split_heads(fused_qkv: torch.Tensor, num_heads: int, head_dim: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
    storage as `fused_qkv`

    Args:
        fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]

    Returns:
        query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
        value: [batch_size, seq_length, num_heads, head_dim]
    """
    batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
    fused_qkv = fused_qkv.view(batch_size, seq_length, num_heads, 3 * head_dim)
    query_layer, key_layer, value_layer = fused_qkv.split(head_dim, dim=-1)

    query_layer = query_layer.transpose(1, 2).reshape(batch_size * num_heads, seq_length, head_dim)
    key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * num_heads, head_dim, seq_length)
    value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_heads, seq_length, head_dim)

    return query_layer, key_layer, value_layer

# @torch.jit.script
def _merge_heads(x: torch.Tensor, num_heads: int, head_dim: int) -> torch.Tensor:
    """
    Merge heads together over the last dimenstion

    Args:
        x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]

    Returns:
        torch.tensor: [batch_size, seq_length, num_heads * head_dim]
    """
    # What we want to achieve is:
    # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
    batch_size_and_num_heads, seq_length, _ = x.shape
    batch_size = batch_size_and_num_heads // num_heads

    # First view to decompose the batch size
    # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
    x = x.view(batch_size, num_heads, seq_length, head_dim)

    # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
    x = x.permute(0, 2, 1, 3)

    # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
    return x.reshape(batch_size, seq_length, num_heads * head_dim)

class BloomAttention(nn.Module):
    def __init__(self, config: BloomConfig):
    def __init__(self, config: BloomConfig, process_group: Optional[torch.distributed.ProcessGroup]):
        super().__init__()

        self.pretraining_tp = config.pretraining_tp
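`_split_heads` and `_merge_heads` move from methods to module-level functions and now return the flattened `[batch_size * num_heads, ...]` layouts directly (query and value as `[B*H, S, D]`, key pre-transposed to `[B*H, D, S]` for `baddbmm`), matching what the CUDA host code does. A toy shape check of that layout (hypothetical sizes):

```python
import torch

batch_size, seq_length, num_heads, head_dim = 2, 5, 4, 8
fused_qkv = torch.randn(batch_size, seq_length, num_heads * 3 * head_dim)

# Same sequence of view/split/transpose as the new module-level _split_heads.
fused_qkv = fused_qkv.view(batch_size, seq_length, num_heads, 3 * head_dim)
query, key, value = fused_qkv.split(head_dim, dim=-1)

query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_length, head_dim)
key = key.permute(0, 2, 3, 1).reshape(batch_size * num_heads, head_dim, seq_length)
value = value.transpose(1, 2).reshape(batch_size * num_heads, seq_length, head_dim)

assert query.shape == (batch_size * num_heads, seq_length, head_dim)
assert key.shape == (batch_size * num_heads, head_dim, seq_length)   # pre-transposed for baddbmm
assert value.shape == (batch_size * num_heads, seq_length, head_dim)
```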
@@ -231,72 +293,37 @@ class BloomAttention(nn.Module):
        self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
        self.beta = 1.0

        self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
        self.dense = nn.Linear(self.hidden_size, self.hidden_size)
        if process_group is None:
            self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
            self.dense = nn.Linear(self.hidden_size, self.hidden_size)
        else:
            assert self.num_heads % process_group.size() == 0
            self.num_heads = self.num_heads // process_group.size()
            self.query_key_value = TensorParallelColumnLinear(self.hidden_size, 3 * self.hidden_size, process_group=process_group, bias=True)
            self.dense = TensorParallelRowLinear(self.hidden_size, self.hidden_size, process_group=process_group)
        self.attention_dropout = nn.Dropout(config.attention_dropout)

    def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
        storage as `fused_qkv`

        Args:
            fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]

        Returns:
            query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
            value: [batch_size, seq_length, num_heads, head_dim]
        """
        batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
        fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
        return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]

    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        """
        Merge heads together over the last dimenstion

        Args:
            x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]

        Returns:
            torch.tensor: [batch_size, seq_length, num_heads * head_dim]
        """
        # What we want to achieve is:
        # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
        batch_size_and_num_heads, seq_length, _ = x.shape
        batch_size = batch_size_and_num_heads // self.num_heads

        # First view to decompose the batch size
        # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
        x = x.view(batch_size, self.num_heads, seq_length, self.head_dim)

        # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
        x = x.permute(0, 2, 1, 3)

        # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
        return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
    @staticmethod
    def compute_attention(
        fused_qkv: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]],
        alibi: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        head_mask: Optional[torch.Tensor],
        beta: float,
        inv_norm_factor: float,
        num_heads: int,
        use_cache: bool
    ):
        fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
        batch_size, q_length, three_times_hidden_size = fused_qkv.shape
        head_dim = three_times_hidden_size // (3 * num_heads)
        batch_size_times_num_heads = batch_size * num_heads

        ### TODO @thomasw21: this takes quite a bit of time, how do I accelerate that?
        # 3 x [batch_size, seq_length, num_heads, head_dim]
        (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
        (query_layer, key_layer, value_layer) = _split_heads(fused_qkv, num_heads=num_heads,
                                                             head_dim=head_dim)

        batch_size, q_length, _, _ = query_layer.shape

        query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
        key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.head_dim, q_length)
        value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
        if layer_past is not None:
            past_key, past_value = layer_past
            # concatenate along seq_length dimension:
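When a `process_group` is passed, the dense projections become `TensorParallelColumnLinear` / `TensorParallelRowLinear` from `.parallel_layers`, a module that is not part of this compare view. The sketch below shows the usual Megatron-style split those names suggest (an assumption about intent, not the branch's actual implementation): column-parallel layers shard the output features per rank, row-parallel layers shard the input features and all-reduce their partial results.

```python
import torch
import torch.nn as nn

class ColumnParallelLinearSketch(nn.Module):
    """Each rank holds out_features // world_size output columns (weight loading not shown)."""
    def __init__(self, in_features: int, out_features: int, world_size: int):
        super().__init__()
        assert out_features % world_size == 0
        self.local = nn.Linear(in_features, out_features // world_size)

    def forward(self, x):
        return self.local(x)  # output is this rank's slice of the features

class RowParallelLinearSketch(nn.Module):
    """Each rank holds in_features // world_size input rows; outputs are summed across ranks."""
    def __init__(self, in_features: int, out_features: int, world_size: int):
        super().__init__()
        assert in_features % world_size == 0
        self.local = nn.Linear(in_features // world_size, out_features)

    def forward(self, x_shard, process_group=None):
        out = self.local(x_shard)
        if process_group is not None:
            torch.distributed.all_reduce(out, group=process_group)  # sum the partial products
        return out
```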
@@ -311,41 +338,81 @@ class BloomAttention(nn.Module):
            present = (key_layer, value_layer)
        else:
            present = None
        ###

        # [batch_size * num_heads, q_length, kv_length]
        # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
        matmul_result = alibi.baddbmm(
        attention_scores = alibi.baddbmm(
            batch1=query_layer,
            batch2=key_layer,
            beta=self.beta,
            alpha=self.inv_norm_factor,
            beta=beta,
            alpha=inv_norm_factor,
        )

        # change view to [batch_size, num_heads, q_length, kv_length]
        attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)

        # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
        input_dtype = attention_scores.dtype
        # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
        if input_dtype == torch.float16:
            attention_scores = attention_scores.to(torch.float)
        attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
        # torch.finfo not supported by torch.jit, we temporarily remplace with `-1e34`
        attn_weights = attention_scores.masked_fill_(attention_mask, torch.finfo(attention_scores.dtype).min)
        attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)

        # [batch_size, num_heads, q_length, kv_length]
        attention_probs = self.attention_dropout(attention_probs)
        # # [batch_size, num_heads, q_length, kv_length]
        # attention_probs = self.attention_dropout(attention_probs)

        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        # change view [batch_size x num_heads, q_length, kv_length]
        attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)

        # matmul: [batch_size * num_heads, q_length, head_dim]
        context_layer = torch.bmm(attention_probs_reshaped, value_layer)
        context_layer = torch.bmm(attention_probs, value_layer, out=query_layer)

        # change view [batch_size, num_heads, q_length, head_dim]
        context_layer = self._merge_heads(context_layer)
        context_layer = _merge_heads(context_layer, num_heads=num_heads, head_dim=head_dim)

        return context_layer, present, attention_probs

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        alibi: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
        batch_size, q_length, _ = fused_qkv.shape

        if CUSTOM_KERNELS_ENABLED:
            assert self.training is False, "Only foward pass was implemented"
            assert attention_mask.shape[-1] < 4096, "Custom kernel support only up to 4096 tokens"
            context_layer, present, attention_probs = fused_bloom_attention_cuda.forward(
                fused_qkv,
                layer_past,
                alibi,
                attention_mask,
                head_mask,
                self.beta,
                self.inv_norm_factor,
                self.num_heads,
                use_cache
            )
        else:
            raise ValueError("Block this path while we figure out how to run C++ code")
            context_layer, present, attention_probs = self.compute_attention(
                fused_qkv=fused_qkv,
                layer_past=layer_past,
                alibi=alibi,
                attention_mask=attention_mask,
                head_mask=head_mask,
                beta=self.beta,
                inv_norm_factor=self.inv_norm_factor,
                num_heads=self.num_heads,
                use_cache=use_cache
            )

        # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
        if self.pretraining_tp > 1 and self.slow_but_exact:
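Both `compute_attention` and the CUDA host function build the attention scores with `alibi.baddbmm(batch1=query, batch2=key, beta=..., alpha=...)`, i.e. `beta * alibi + alpha * (query @ key)` batched over `batch_size * num_heads`, with the ALiBi bias broadcast across the query dimension. A small check of that identity (hypothetical sizes):

```python
import torch

bh, q_len, kv_len, head_dim = 8, 5, 5, 16   # bh = batch_size * num_heads (toy values)
alibi = torch.randn(bh, 1, kv_len)          # broadcasts over the q_length dimension
query = torch.randn(bh, q_len, head_dim)
key = torch.randn(bh, head_dim, kv_len)     # already transposed by _split_heads
beta, inv_norm_factor = 1.0, 1.0 / head_dim**0.5

scores = alibi.baddbmm(batch1=query, batch2=key, beta=beta, alpha=inv_norm_factor)
expected = beta * alibi + inv_norm_factor * torch.bmm(query, key)
assert torch.allclose(scores, expected, atol=1e-5)
assert scores.shape == (bh, q_len, kv_len)
```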
@@ -359,7 +426,8 @@ class BloomAttention(nn.Module):
        else:
            output_tensor = self.dense(context_layer)

        output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
        # output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
        output_tensor += residual

        outputs = (output_tensor, present)
        if output_attentions:
@@ -369,15 +437,19 @@ class BloomAttention(nn.Module):


class BloomMLP(nn.Module):
    def __init__(self, config: BloomConfig):
    def __init__(self, config: BloomConfig, process_group: Optional[torch.distributed.ProcessGroup]):
        super().__init__()
        hidden_size = config.hidden_size

        self.pretraining_tp = config.pretraining_tp
        self.slow_but_exact = config.slow_but_exact
        self.dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size)
        if process_group is None:
            self.dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size)
            self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size)
        else:
            self.dense_h_to_4h = TensorParallelColumnLinear(hidden_size, 4 * hidden_size, process_group=process_group)
            self.dense_4h_to_h = TensorParallelRowLinear(4 * hidden_size, hidden_size, process_group=process_group)
        self.gelu_impl = BloomGelu()
        self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size)
        self.hidden_dropout = config.hidden_dropout

    def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
@@ -394,22 +466,23 @@ class BloomMLP(nn.Module):
        else:
            intermediate_output = self.dense_4h_to_h(hidden_states)

        output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
        # output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
        intermediate_output += residual

        return output
        return intermediate_output


class BloomBlock(nn.Module):
    def __init__(self, config: BloomConfig):
    def __init__(self, config: BloomConfig, process_group: Optional[torch.distributed.ProcessGroup]):
        super().__init__()
        hidden_size = config.hidden_size

        self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.num_heads = config.n_head
        self.self_attention = BloomAttention(config)
        self.self_attention = BloomAttention(config, process_group=process_group)
        self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        self.mlp = BloomMLP(config)
        self.mlp = BloomMLP(config, process_group=process_group)

        self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
        self.hidden_dropout = config.hidden_dropout
@@ -581,18 +654,23 @@ BLOOM_INPUTS_DOCSTRING = r"""
    BLOOM_START_DOCSTRING,
)
class BloomModel(BloomPreTrainedModel):
    def __init__(self, config: BloomConfig):
    def __init__(self, config: BloomConfig, process_group: Optional[torch.distributed.ProcessGroup]=None):
        super().__init__(config)

        self.embed_dim = config.hidden_size
        self.num_heads = config.n_head

        # Embedding + LN Embedding
        self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
        if process_group is None:
            self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
        else:
            self.word_embeddings = TensorParallelEmbedding(config.vocab_size, self.embed_dim, process_group=process_group)
            self.tp_rank = process_group.rank()
            self.tp_world_size = process_group.size()
        self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        # Transformer blocks
        self.h = nn.ModuleList([BloomBlock(config) for _ in range(config.num_hidden_layers)])
        self.h = nn.ModuleList([BloomBlock(config, process_group=process_group) for _ in range(config.num_hidden_layers)])

        # Final Layer Norm
        self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@ -609,7 +687,7 @@ class BloomModel(BloomPreTrainedModel):
 | 
			
		||||
        self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
 | 
			
		||||
    ) -> torch.BoolTensor:
 | 
			
		||||
        # create causal mask
 | 
			
		||||
        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
 | 
			
		||||
        # [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
 | 
			
		||||
        combined_attention_mask = None
 | 
			
		||||
        device = attention_mask.device
 | 
			
		||||
        _, src_length = input_shape
 | 
			
		||||
@ -619,7 +697,7 @@ class BloomModel(BloomPreTrainedModel):
 | 
			
		||||
                input_shape, device=device, past_key_values_length=past_key_values_length
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
 | 
			
		||||
        # [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
 | 
			
		||||
        expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
 | 
			
		||||
        combined_attention_mask = (
 | 
			
		||||
            expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
 | 
			
		||||
@ -705,7 +783,7 @@ class BloomModel(BloomPreTrainedModel):
 | 
			
		||||
        else:
 | 
			
		||||
            attention_mask = attention_mask.to(hidden_states.device)
 | 
			
		||||
 | 
			
		||||
        alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
 | 
			
		||||
        alibi = build_alibi_tensor(attention_mask, self.num_heads)
 | 
			
		||||
 | 
			
		||||
        causal_mask = self._prepare_attn_mask(
 | 
			
		||||
            attention_mask,
 | 
			
		||||
@ -713,6 +791,18 @@ class BloomModel(BloomPreTrainedModel):
 | 
			
		||||
            past_key_values_length=past_key_values_length,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if hasattr(self, "tp_rank"):
 | 
			
		||||
            assert self.num_heads % self.tp_world_size == 0
 | 
			
		||||
            block_size = self.num_heads // self.tp_world_size
 | 
			
		||||
            alibi = alibi[:, self.tp_rank * block_size: (self.tp_rank + 1) * block_size]
 | 
			
		||||
            alibi = alibi.reshape(batch_size * block_size, 1, seq_length_with_past)
 | 
			
		||||
            causal_mask = torch.repeat_interleave(causal_mask, block_size, dim=0)
 | 
			
		||||
        else:
 | 
			
		||||
            alibi = alibi.reshape(batch_size * self.num_heads, 1, seq_length_with_past)
 | 
			
		||||
            causal_mask = torch.repeat_interleave(causal_mask, self.num_heads, dim=0)
 | 
			
		||||
 | 
			
		||||
        alibi = alibi.to(hidden_states.dtype)
 | 
			
		||||
 | 
			
		||||
        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
 | 
			
		||||
 | 
			
		||||
            if output_hidden_states:
 | 
			
		||||
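
With tensor parallelism each rank runs only `num_heads // tp_world_size` heads, so the hunk above hands every rank its own slice of the alibi biases and repeats the causal mask per local head instead of per global head. A single-process sketch (illustrative only) of how the sliced-and-reshaped alibi lines up with the rows of the non-parallel reshape:

```python
import torch

batch, num_heads, seq = 2, 8, 5        # toy sizes, for illustration only
tp_world_size, tp_rank = 4, 1
block_size = num_heads // tp_world_size

alibi = torch.randn(batch, num_heads, seq)  # stand-in for build_alibi_tensor's per-head biases

# Tensor-parallel path: keep only this rank's heads, then fold (batch, local heads) together.
sharded = alibi[:, tp_rank * block_size:(tp_rank + 1) * block_size].reshape(batch * block_size, 1, seq)
# Non-parallel path: fold all heads.
full = alibi.reshape(batch * num_heads, 1, seq)

# Row b * block_size + j of the sharded tensor is row b * num_heads + tp_rank * block_size + j of the full one.
for b in range(batch):
    for j in range(block_size):
        assert torch.equal(sharded[b * block_size + j], full[b * num_heads + tp_rank * block_size + j])
```
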
@@ -787,8 +877,21 @@ class BloomForCausalLM(BloomPreTrainedModel):

    def __init__(self, config: BloomConfig):
        super().__init__(config)
        self.transformer = BloomModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        print(config)
        ### HACK
        if config.tp_parallel:
            process_group = torch.distributed.distributed_c10d._get_default_group()
        else:
            process_group = None
        ###
        self.transformer = BloomModel(config, process_group)

        if process_group is None:
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        else:
            # TODO @thomasw21: TensorParallelColumnLinear doesn't inherit nn.Linear anymore as we switch the underlying storage
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size // process_group.size(), bias=False)
            # self.lm_head = TensorParallelColumnLinear(config.hidden_size, config.vocab_size, process_group=process_group, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

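The tensor-parallel path is toggled purely through a `tp_parallel` attribute on the config plus the default process group, so exercising it looks roughly like the sketch below, launched with one process per GPU (for example `torchrun --nproc_per_node=2 run_bloom_tp.py`). The script name, model id, dtype and the absence of checkpoint loading are assumptions for illustration; only `config.tp_parallel` and the default-group lookup come from the change above.

```python
# Hypothetical launch sketch, not part of the diff.
import torch
import torch.distributed

from transformers import AutoConfig
from transformers.models.bloom.modeling_bloom import BloomForCausalLM

torch.distributed.init_process_group(backend="nccl")
rank = torch.distributed.get_rank()
torch.cuda.set_device(rank)

config = AutoConfig.from_pretrained("bigscience/bloom-560m")
config.tp_parallel = True  # makes BloomForCausalLM pick up the default process group

# Randomly initialized here; loading real sharded checkpoints is out of scope for this sketch.
model = BloomForCausalLM(config).to(device=torch.cuda.current_device(), dtype=torch.bfloat16)
model.eval()
```

With 2 ranks the constructor then builds `TensorParallelColumnLinear` / `TensorParallelRowLinear` / `TensorParallelEmbedding` shards, so each process holds roughly half of the MLP, attention and embedding weights.
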
235  src/transformers/models/bloom/parallel_layers.py  Normal file
@@ -0,0 +1,235 @@
import torch
import torch.distributed
import torch.nn.functional as F
from torch import nn
import math


class TensorParallelColumnLinear(nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        process_group: torch.distributed.ProcessGroup,
        bias=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()

        self.process_group = process_group
        self.tp_world_size = process_group.size()

        assert out_features % self.tp_world_size == 0

        self.in_features = in_features
        self.out_features = out_features // self.tp_world_size
        # We differ from the traditional `nn.Linear` layout and store the weight as (in_features, out_features) to remove an unnecessary `torch.Tensor.transpose` operation
        self.weight = nn.Parameter(torch.empty((self.in_features, self.out_features), **factory_kwargs))
        if bias:
            self.bias = nn.Parameter(torch.empty(self.out_features, **factory_kwargs))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """From `torch.nn.Linear`"""
        # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
        # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
        # https://github.com/pytorch/pytorch/issues/57109
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def extra_repr(self) -> str:
        """From `torch.nn.Linear`"""
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )

    @staticmethod
    @torch.jit.script
    def linear(input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor):
        # # Note that the unsharded equivalent requires us to sum over bias instead of averaging.
        # in_features, out_features = weight.shape
        # size_out = input.size()[:-1] + (out_features,)
        # # TODO @thomasw21: when using torch.jit.script, `addmm` is decomposed to `add + mm`
        # return torch.addmm(bias, input.view(-1, in_features), weight).view(size_out)

        in_features, out_features = weight.shape
        size_out = input.size()[:-1] + (out_features,)
        # TODO @thomasw21: when using torch.jit.script, `addmm` is decomposed to `add + mm`
        input = input.view(-1, in_features)
        # HACK @thomasw21: turns out `aten::addmm.out` is not decomposed
        out = torch.empty((0,), device=input.device, dtype=input.dtype)
        out = torch.addmm(bias, input, weight, out=out.view(-1, out_features))

        return out.view(size_out)

        # return F.linear(input, weight=weight.transpose(1, 0), bias=bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        out = self.linear(input, weight=self.weight, bias=self.bias)

        # ### DEBUG @thomasw21: Check that the sharded model outputs the same as the non-sharded version
        # out_from_tp_ranks = [torch.empty_like(out) for _ in range(self.process_group.size())]
        # torch.distributed.all_gather(out_from_tp_ranks, out, group=self.process_group)
        # sharded_out = torch.cat(out_from_tp_ranks, dim=-1)
        #
        # weight_from_tp_ranks = [torch.empty_like(self.weight) for _ in range(self.process_group.size())]
        # bias_from_tp_ranks = [torch.empty_like(self.bias) for _ in range(self.process_group.size())]
        # torch.distributed.all_gather(weight_from_tp_ranks, self.weight, group=self.process_group)
        # torch.distributed.all_gather(bias_from_tp_ranks, self.bias, group=self.process_group)
        # weight = torch.cat(weight_from_tp_ranks, dim=0)
        # bias = torch.cat(bias_from_tp_ranks, dim=0)
        # baseline_out = F.linear(input, weight, bias)
        #
        # torch.testing.assert_close(sharded_out, baseline_out, atol=0.0, rtol=0.0)
        # ###

        return out

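Because the weight is stored as `(in_features, out_features // world_size)`, a dense `nn.Linear` checkpoint has to be transposed and column-sliced before it can populate these shards. The helper below is hypothetical (the branch may load weights differently); it also checks, on a single process, that concatenating the per-rank outputs reproduces the dense layer:

```python
import torch
from torch import nn

def shard_column_linear(dense: nn.Linear, rank: int, world_size: int):
    """Slice a dense nn.Linear into the (in, out // world_size) layout used above. Illustrative only."""
    out_features, in_features = dense.weight.shape
    block = out_features // world_size
    weight = dense.weight.detach().t()[:, rank * block:(rank + 1) * block].contiguous()
    bias = None if dense.bias is None else dense.bias.detach()[rank * block:(rank + 1) * block].contiguous()
    return weight, bias

dense = nn.Linear(16, 8)
x = torch.randn(4, 16)
outs = []
for rank in range(2):
    w, b = shard_column_linear(dense, rank, world_size=2)
    outs.append(x @ w + b)  # what each rank computes locally

# Column parallelism needs no reduction: the shards are simply concatenated on the feature axis.
torch.testing.assert_close(torch.cat(outs, dim=-1), dense(x))
```
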
class TensorParallelRowLinear(nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        process_group: torch.distributed.ProcessGroup,
        bias=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()

        self.process_group = process_group
        self.tp_world_size = process_group.size()

        assert in_features % self.tp_world_size == 0

        self.in_features = in_features // self.tp_world_size
        self.out_features = out_features
        # We differ from the traditional `nn.Linear` layout and store the weight as (in_features, out_features) to remove an unnecessary `torch.Tensor.transpose` operation
        self.weight = nn.Parameter(torch.empty((self.in_features, self.out_features), **factory_kwargs))
        if bias:
            self.bias = nn.Parameter(torch.empty(self.out_features, **factory_kwargs))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """From `torch.nn.Linear`"""
        # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
        # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
        # https://github.com/pytorch/pytorch/issues/57109
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    @staticmethod
    @torch.jit.script
    def linear(input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor):
        # # Note that the unsharded equivalent requires us to sum over bias instead of averaging.
        # in_features, out_features = weight.shape
        # size_out = input.size()[:-1] + (out_features,)
        # # TODO @thomasw21: when using torch.jit.script, `addmm` is decomposed to `add + mm`
        # input = input.view(-1, in_features)
        # # with torch.jit.strict_fusion():
        # out = torch.addmm(bias, input, weight)
        # return out.view(size_out)

        in_features, out_features = weight.shape
        size_out = input.size()[:-1] + (out_features,)
        # TODO @thomasw21: when using torch.jit.script, `addmm` is decomposed to `add + mm`
        input = input.view(-1, in_features)
        # HACK @thomasw21: turns out `aten::addmm.out` is not decomposed
        out = torch.empty((0,), device=input.device, dtype=input.dtype)
        out = torch.addmm(bias, input, weight, out=out.view(-1, out_features))

        return out.view(size_out)

        # return F.linear(input, weight=weight.transpose(1, 0), bias=bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        out = self.linear(input, weight=self.weight, bias=self.bias)
        torch.distributed.all_reduce(out, group=self.process_group)

        # ### DEBUG @thomasw21: Check that the sharded model outputs the same as the non-sharded version
        # sharded_out = out
        #
        # input_from_tp_ranks = [torch.empty_like(input) for _ in range(self.process_group.size())]
        # weight_from_tp_ranks = [torch.empty_like(self.weight) for _ in range(self.process_group.size())]
        # bias = self.bias.clone()
        # torch.distributed.all_gather(input_from_tp_ranks, input, group=self.process_group)
        # torch.distributed.all_gather(weight_from_tp_ranks, self.weight, group=self.process_group)
        # torch.distributed.all_reduce(bias, group=self.process_group)
        # input = torch.cat(input_from_tp_ranks, dim=-1)
        # weight = torch.cat(weight_from_tp_ranks, dim=1)
        # baseline_out = F.linear(input, weight, bias)
        #
        # if self.process_group.rank() == 0:
        #     torch.testing.assert_close(bias, self.bias, atol=0.0, rtol=0.0)
        # torch.distributed.barrier(self.process_group)
        # # torch.testing.assert_close(sharded_out, baseline_out)
        # ###

        return out

    def extra_repr(self) -> str:
        """From `torch.nn.Linear`"""
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )

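The commented note above ("the unsharded equivalent requires us to sum over bias instead of averaging") follows from the fact that every rank adds its local `bias` before the `all_reduce`. One consistent convention, assumed here purely for illustration, is to give each rank `bias / world_size` so the summed partials carry the bias exactly once:

```python
import torch

torch.manual_seed(0)
x = torch.randn(3, 4)
w = torch.randn(4, 5)   # full (in, out) weight, row-split across 2 "ranks"
b = torch.randn(5)      # bias of the dense reference layer

# Each rank sees half the input features and half the weight rows, plus b / world_size.
partials = [x[:, :2] @ w[:2] + b / 2, x[:, 2:] @ w[2:] + b / 2]

# The all_reduce sums the partials, so the biases add back up to b exactly once.
torch.testing.assert_close(partials[0] + partials[1], x @ w + b)
```
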
class TensorParallelEmbedding(nn.Embedding):
    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        process_group: torch.distributed.ProcessGroup,
        padding_idx=None,
        max_norm=None,
        norm_type=2.0,
        scale_grad_by_freq=False,
        sparse=False,
        _weight=None,
        device=None,
        dtype=None
    ):
        self.process_group = process_group
        self.tp_rank = process_group.rank()
        self.tp_world_size = process_group.size()

        self.original_num_embeddings = num_embeddings

        # TODO @thomasw21 fix and remove that constraint
        assert num_embeddings % self.tp_world_size == 0
        block_size = num_embeddings // self.tp_world_size
        # inputs in `[min_id, max_id[` are handled by `self` to get embeddings
        self.min_id = self.tp_rank * block_size
        self.max_id = (self.tp_rank + 1) * block_size

        super().__init__(block_size, embedding_dim, padding_idx=padding_idx, max_norm=max_norm, norm_type=norm_type, scale_grad_by_freq=scale_grad_by_freq, sparse=sparse, _weight=_weight, device=device, dtype=dtype)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Sanity check
        if torch.any(torch.logical_or(0 > input, input >= self.original_num_embeddings)):
            raise IndexError(f"Input is required to be in [0, {self.original_num_embeddings}[, got min: {torch.min(input)} and max: {torch.max(input)}")

        # `True` for ids that fall outside this rank's `[min_id, max_id[` block
        input_mask = torch.logical_or(self.min_id > input, input >= self.max_id)
        # translate to [0, self.max_id - self.min_id[
        input = input - self.min_id
        # default all out-of-bounds values to `0`
        input[input_mask] = 0
        out = super().forward(input)
        out[input_mask] = 0.0
        torch.distributed.all_reduce(out, group=self.process_group)
        return out
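
Each rank thus answers only for its own slice of the vocabulary and contributes zero rows for everything else, so the `all_reduce` reassembles the full lookup. A single-process sketch of that trick, simulating the ranks in a loop purely for illustration:

```python
import torch
from torch import nn

vocab, dim, world_size = 10, 4, 2
block_size = vocab // world_size
full = nn.Embedding(vocab, dim)
ids = torch.tensor([[0, 3, 7, 9]])

with torch.no_grad():
    summed = torch.zeros(ids.shape + (dim,))
    for rank in range(world_size):
        min_id, max_id = rank * block_size, (rank + 1) * block_size
        shard = nn.Embedding(block_size, dim)
        shard.weight.copy_(full.weight[min_id:max_id])    # this rank's slice of the table
        mask = (ids < min_id) | (ids >= max_id)           # tokens owned by some other rank
        local_ids = (ids - min_id).masked_fill(mask, 0)   # shift into [0, block_size) and zero the rest
        out = shard(local_ids)
        out[mask] = 0.0                                   # foreign tokens contribute nothing
        summed += out                                     # stand-in for the all_reduce across ranks
    torch.testing.assert_close(summed, full(ids))
```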