DeepSpeed4Science (#4357)

* zero++ tutorial PR (#3783)

* [Fix] _conv_flops_compute when padding is a str and stride=1 (#3169)

* fix conv_flops_compute when padding is a str and stride=1

* fix error

* change type of paddings to tuple

* fix padding calculation

* apply formatting check

---------

Co-authored-by: Cheng Li <pistasable@gmail.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>

* fix interpolate flops compute (#3782)

* use `Flops Profiler` to test `model.generate()` (#2515)

* Update profiler.py

* pre-commit run --all-files

* Delete .DS_Store

* Delete .DS_Store

* Delete .DS_Store

---------

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Cheng Li <pistasable@gmail.com>
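
A hedged sketch of the `model.generate()` profiling path added in #2515, assuming the `get_model_profile` helper in `deepspeed.profiling.flops_profiler` accepts a `mode="generate"` switch and forwards `kwargs` to `generate()`; the GPT-2 model and prompt are placeholders, not part of this commit.

# Minimal sketch (assumptions noted above): profile generation instead of a plain forward pass.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from deepspeed.profiling.flops_profiler import get_model_profile

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").half().cuda()
inputs = tokenizer("DeepSpeed is", return_tensors="pt").to("cuda")

flops, macs, params = get_model_profile(
    model,
    kwargs={**inputs, "max_new_tokens": 32},  # forwarded to model.generate()
    mode="generate",                          # assumed switch introduced by PR #2515
    print_profile=True,
    as_string=True,
)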

* revert PR #3611 (#3786)

* bump to 0.9.6

* ZeRO++ chinese blog (#3793)

* zeropp chinese blog

* try better quality images

* make title larger

* even larger...

* various fixes

* center captions

* more fixes

* fix format

* remove staging trigger (#3792)

* DeepSpeed-Triton for Inference (#3748)

Co-authored-by: Stephen Youn <styoun@microsoft.com>
Co-authored-by: Arash Bakhtiari <arash@bakhtiari.org>
Co-authored-by: Cheng Li <pistasable@gmail.com>
Co-authored-by: Ethan Doe <yidoe@microsoft.com>
Co-authored-by: yidoe <68296935+yidoe@users.noreply.github.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
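
For the DeepSpeed-Triton inference entry above, a hedged sketch of how the Triton kernel path is typically switched on; the `use_triton` / `triton_autotune` options and the fp16 BERT model are illustrative assumptions, not taken from this diff.

# Minimal sketch (assumed flags): route supported transformer kernels through Triton.
import torch
import deepspeed
from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-cased")
engine = deepspeed.init_inference(
    model,
    dtype=torch.float16,
    replace_with_kernel_inject=True,
    use_triton=True,        # assumed: enable the DeepSpeed-Triton kernels
    triton_autotune=True,   # assumed: autotune Triton matmul configurations
)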

* ZeRO++ (#3784)

Co-authored-by: HeyangQin <heyangqin@microsoft.com>
Co-authored-by: GuanhuaWang <alexwgh333@gmail.com>
Co-authored-by: cmikeh2 <connorholmes@microsoft.com>
Co-authored-by: Ammar Ahmad Awan <ammar.awan@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Reza Yazdani <reyazda@microsoft.com>
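
ZeRO++ (#3784) is driven entirely from the DeepSpeed config. A minimal sketch of the three ZeRO++ knobs layered on a ZeRO stage-3 config, following the public ZeRO++ tutorial; the batch size and the hpZ partition size (GPUs per node) are illustrative values.

# Minimal sketch: ZeRO-3 config with quantized weights (qwZ), hierarchical
# partitioning (hpZ), and quantized gradients (qgZ) enabled.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        "zero_quantized_weights": True,     # qwZ: quantize all-gathered weights
        "zero_hpz_partition_size": 16,      # hpZ: secondary partition within a node (illustrative)
        "zero_quantized_gradients": True,   # qgZ: quantized gradient reduce-scatter
    },
}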

* adding zero++ to navigation panel of deepspeed.ai (#3796)

* Add ZeRO++ Japanese blog (#3797)

* zeropp chinese blog

* try better quality images

* make title larger

* even larger...

* various fixes

* center captions

* more fixes

* fix format

* add ZeRO++ Japanese blog

* add links

---------

Co-authored-by: HeyangQin <heyangqin@microsoft.com>
Co-authored-by: Conglong Li <conglong.li@gmail.com>

* Bug Fixes for autotuner and flops profiler (#1880)

* fix autotuner when backward is not called

* fix format

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>

* Missing strided copy for gated MLP (#3788)

Co-authored-by: Ammar Ahmad Awan <ammar.awan@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>

* Requires grad checking. (#3789)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* bump to 0.10.0

* Fix Bug in transform.cu (#3534)

* Bug fix

* Fixed formatting error

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>

* bug fix: triton importing error (#3799)

Co-authored-by: Stephen Youn <styoun@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* DeepSpeed4Science (#569)

* Integrating evoformer attention

* add cutlass version check

* Update error message

* add benchmark

* Update

* Update evoformer_attn.py

* Update run_evoformer_test.py

* Update evoformer_attn.py

* Update run_evoformer_test.py

* support more GPU archs

* add copyright

* add tests

* Fix bugs

* Update benchmark

* update

* Fix nvcc macro

* clean code

* fix formatting

* fix yaml import

* skip unit test when not compatible

* fix yaml requirement

* revert changes

* update tutorial

* update

* fix formatting

* fix format

* skip evoformer attn in pre-compile-ops

* revert changes

* update tutorial

* fix cutlass check

* update tutorial

* refactor tutorial

* revise

* Updated the Megatron-DS section (#565)

* Updated the Megatron-DS section

* minor fix

* minor fix

* minor fix

* separate evoformer tutorial

* Revised the ds4science landing page (#566)

* Updated the Megatron-DS section

* minor fix

* minor fix

* minor fix

* Revised the landing page

* Revised the landing page

* Removing unused file

* fix links image position

* modify main page

* fix doc

---------

Co-authored-by: Shiyang Chen <csycfl@gmail.com>
Co-authored-by: Minjia Zhang <33713995+minjiaz@users.noreply.github.com>

---------

Co-authored-by: Heyang Qin <heyangqin@microsoft.com>
Co-authored-by: Bill Luo <50068224+zhiruiluo@users.noreply.github.com>
Co-authored-by: Cheng Li <pistasable@gmail.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Guorun <84232793+CaffreyR@users.noreply.github.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: stephen youn <13525892+stephen-youn@users.noreply.github.com>
Co-authored-by: Stephen Youn <styoun@microsoft.com>
Co-authored-by: Arash Bakhtiari <arash@bakhtiari.org>
Co-authored-by: Ethan Doe <yidoe@microsoft.com>
Co-authored-by: yidoe <68296935+yidoe@users.noreply.github.com>
Co-authored-by: GuanhuaWang <alexwgh333@gmail.com>
Co-authored-by: cmikeh2 <connorholmes@microsoft.com>
Co-authored-by: Ammar Ahmad Awan <ammar.awan@microsoft.com>
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
Co-authored-by: Reza Yazdani <reyazda@microsoft.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Joe Mayer <114769929+jomayeri@users.noreply.github.com>
Co-authored-by: Ramya Ramineni <62723901+rraminen@users.noreply.github.com>
Co-authored-by: Shiyang Chen <csycfl@gmail.com>
Co-authored-by: Minjia Zhang <33713995+minjiaz@users.noreply.github.com>
Author: Conglong Li
Date: 2023-09-18 15:16:08 -07:00
Committed by: GitHub
Parent: 367d6f9cec
Commit: f876d81d34
42 changed files with 15421 additions and 7 deletions


@@ -33,7 +33,7 @@ jobs:
#python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
- name: Compile DeepSpeed Ops
run: |
TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 pip3 install .
TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install .
- name: DS Report
run: |
ds_report
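
The workflow change above adds `DS_BUILD_EVOFORMER_ATTN=0` so the new evoformer attention op is skipped during pre-compilation (it needs CUTLASS and specific GPU architectures) and is instead JIT-built when first used. A hedged sketch of probing that from Python, assuming the op builder added by this commit is exposed as `EvoformerAttnBuilder`:

# Minimal sketch (builder name assumed): check and JIT-build the evoformer attention op.
from deepspeed.ops.op_builder import EvoformerAttnBuilder

builder = EvoformerAttnBuilder()
if builder.is_compatible():
    evoformer_attn = builder.load()   # JIT-compiles the CUDA/CUTLASS sources on first use
else:
    print("EvoformerAttn not supported here (CUTLASS or GPU arch check failed)")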


@@ -15,11 +15,11 @@
## Latest News
<b> <span style="color:orange" > DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat)</span>.</b>
* [2023/09] Announcing the DeepSpeed4Science Initiative: Enabling large-scale scientific discovery through sophisticated AI system technologies [[DeepSpeed4Science website](https://deepspeed4science.ai/)] [[Tutorials](https://www.deepspeed.ai/deepspeed4science/)]
* [2023/08] [DeepSpeed ZeRO-Inference: 20X faster inference through weight quantization and KV cache offloading](https://github.com/microsoft/DeepSpeedExamples/blob/master/inference/huggingface/zero_inference/README.md)
* [2023/08] [DeepSpeed-Chat: Llama/Llama-2 system support, efficiency boost, and training stability improvements](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/ds-chat-release-8-31/README.md)
* [2023/08] [DeepSpeed Ulysses: System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ulysses)
* [2023/06] [ZeRO++: A leap in speed for LLM and chat model training with 4X less communication](https://www.microsoft.com/en-us/research/blog/deepspeed-zero-a-leap-in-speed-for-llm-and-chat-model-training-with-4x-less-communication/)[[English](https://www.microsoft.com/en-us/research/blog/deepspeed-zero-a-leap-in-speed-for-llm-and-chat-model-training-with-4x-less-communication/)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/zeropp/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/zeropp/japanese/README.md)]
* [2023/04] 🚀 [DeepSpeed Chat: Easy, Fast and Affordable RLHF Training of ChatGPT-like Models at All Scales](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/japanese/README.md)]🚀
---
@@ -35,9 +35,9 @@
---
# DeepSpeed's three innovation pillars
# DeepSpeed's four innovation pillars
<img src="docs/assets/images/3pillars.png" width="800px">
<img src="docs/assets/images/DeepSpeed-pillars.png" width="800px">
## DeepSpeed-Training
@@ -53,6 +53,10 @@ DeepSpeed brings together innovations in parallelism technology such as tensor,
To further increase the inference efficiency, DeepSpeed offers easy-to-use and flexible-to-compose compression techniques for researchers and practitioners to compress their models while delivering faster speed, smaller model size, and significantly reduced compression cost. Moreover, SoTA innovations on compression like ZeroQuant and XTC are included under the compression pillar. Learn more: [DeepSpeed-Compression](https://www.deepspeed.ai/compression)
## DeepSpeed4Science
In line with Microsoft's mission to solve humanity's most pressing challenges, the DeepSpeed team at Microsoft is responding to this opportunity by launching a new initiative called *DeepSpeed4Science*, aiming to build unique capabilities through AI system technology innovations to help domain experts unlock today's biggest science mysteries. Learn more: [DeepSpeed4Science website](https://deepspeed4science.ai/) and [tutorials](https://www.deepspeed.ai/deepspeed4science/)
---
# DeepSpeed Software Suite


@@ -0,0 +1,62 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include <torch/extension.h>
void attention_impl(torch::Tensor& q,
torch::Tensor& k,
torch::Tensor& v,
torch::Tensor& bias1,
torch::Tensor& bias2,
torch::Tensor& o,
torch::Tensor& lse);
void attention(torch::Tensor& q,
torch::Tensor& k,
torch::Tensor& v,
torch::Tensor& bias1,
torch::Tensor& bias2,
torch::Tensor& o,
torch::Tensor& lse)
{
attention_impl(q, k, v, bias1, bias2, o, lse);
}
void attention_back_impl(torch::Tensor& go,
torch::Tensor& q,
torch::Tensor& k,
torch::Tensor& v,
torch::Tensor& o,
torch::Tensor& lse,
torch::Tensor& delta,
torch::Tensor& bias1,
torch::Tensor& bias2,
torch::Tensor& gq,
torch::Tensor& gk,
torch::Tensor& gv,
torch::Tensor& gb1,
torch::Tensor& gb2);
void attention_bwd(torch::Tensor& go,
torch::Tensor& q,
torch::Tensor& k,
torch::Tensor& v,
torch::Tensor& o,
torch::Tensor& lse,
torch::Tensor& delta,
torch::Tensor& bias1,
torch::Tensor& bias2,
torch::Tensor& gq,
torch::Tensor& gk,
torch::Tensor& gv,
torch::Tensor& gb1,
torch::Tensor& gb2)
{
attention_back_impl(go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("attention", &attention, "");
m.def("attention_bwd", &attention_bwd, "");
}
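
The binding above is in-place: the caller preallocates the output `o` and the log-sum-exp buffer `lse`, and the op writes into them. A hedged sketch of driving the forward binding once the extension is loaded; the 5-D (batch, seqs, length, heads, head_dim) shape follows the view logic in the CUDA file below but is otherwise an assumption, and passing size-0 tensors for `bias1`/`bias2`/`lse` disables those paths, as the dispatch code shows.

# Minimal sketch (shapes illustrative): call the compiled forward kernel directly.
import torch
from deepspeed.ops.op_builder import EvoformerAttnBuilder

kernels = EvoformerAttnBuilder().load()      # module defined by the bindings above
B, N, S, H, D = 1, 4, 64, 8, 32
q = torch.randn(B, N, S, H, D, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
bias1 = torch.empty(0, device="cuda", dtype=torch.float16)  # size 0 => BroadcastNoLoad
bias2 = torch.empty(0, device="cuda", dtype=torch.float16)
o = torch.empty_like(q)
lse = torch.empty(0, device="cuda", dtype=torch.float32)    # only needed for backward
kernels.attention(q, k, v, bias1, bias2, o, lse)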


@@ -0,0 +1,160 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "gemm_kernel_utils.h"
#include "kernel_forward.h"
#include "transform/bias_broadcast.h"
template <typename arch,
typename scalar_t,
typename torch_scalar_t,
template <typename, typename, typename>
class Broadcast1_,
template <typename, typename, typename>
class Broadcast2_>
typename std::enable_if<!CheckArch<arch, scalar_t>::value>::type attention_impl_template(
torch::Tensor& q,
torch::Tensor& k,
torch::Tensor& v,
torch::Tensor& bias1,
torch::Tensor& bias2,
torch::Tensor& o,
float* lse_ptr)
{
EVOFORMER_CHECK(false, "Unsupported GPU and data type combination")
}
template <typename arch,
typename scalar_t,
typename torch_scalar_t,
template <typename, typename, typename>
class Broadcast1_,
template <typename, typename, typename>
class Broadcast2_>
typename std::enable_if<CheckArch<arch, scalar_t>::value>::type attention_impl_template(
torch::Tensor& q,
torch::Tensor& k,
torch::Tensor& v,
torch::Tensor& bias1,
torch::Tensor& bias2,
torch::Tensor& o,
float* lse_ptr)
{
// Attention definition goes here, replaced with BroadcastType1 and
// BroadcastType2
using Attention = AttentionKernel<scalar_t, /* scalar_t */
arch, /* ArchTag */
true, /* Memory is aligned */
64,
64,
true,
true, /* Supports bias */
Broadcast1_,
Broadcast2_>;
static_assert(!Attention::kNeedsOutputAccumulatorBuffer,
"This test does not support output accumulator buffer");
int head_size = q.size(-1);
int head_number = q.size(-2);
int seq_length = q.size(-3);
auto q_view = q.view({-1, seq_length, head_number, head_size});
auto k_view = k.view({-1, seq_length, head_number, head_size});
auto v_view = v.view({-1, seq_length, head_number, head_size});
auto o_view = o.view({-1, seq_length, head_number, head_size});
int batch_size = q_view.size(0);
auto q_ptr = reinterpret_cast<scalar_t*>(q.data_ptr<torch_scalar_t>());
auto k_ptr = reinterpret_cast<scalar_t*>(k.data_ptr<torch_scalar_t>());
auto v_ptr = reinterpret_cast<scalar_t*>(v.data_ptr<torch_scalar_t>());
auto o_ptr = reinterpret_cast<scalar_t*>(o.data_ptr<torch_scalar_t>());
auto bias1_ptr = reinterpret_cast<scalar_t*>(bias1.data_ptr<torch_scalar_t>());
auto bias2_ptr = reinterpret_cast<scalar_t*>(bias2.data_ptr<torch_scalar_t>());
typename Attention::Params p;
{ // set parameters
p.query_ptr = q_ptr;
p.key_ptr = k_ptr;
p.value_ptr = v_ptr;
p.logsumexp_ptr = lse_ptr; // Only needed for bw
p.output_accum_ptr = nullptr;
p.output_ptr = o_ptr;
p.scale = 1.0f / sqrt(float(head_size));
p.bias1_ptr = bias1_ptr;
p.bias2_ptr = bias2_ptr;
p.B = q.size(0);
p.N = q.size(1);
p.num_heads = head_number;
p.num_batches = batch_size;
p.head_dim = head_size;
p.head_dim_value = head_size;
p.num_queries = seq_length;
p.num_keys = seq_length;
// All tensors are in BMHK shapes
p.q_strideH = q_view.stride(-2);
p.k_strideH = k_view.stride(-2);
p.v_strideH = v_view.stride(-2);
p.q_strideM = q_view.stride(-3);
p.k_strideM = k_view.stride(-3);
p.v_strideM = v_view.stride(-3);
p.o_strideM = o_view.stride(-3);
p.q_strideB = q_view.stride(-4);
p.k_strideB = k_view.stride(-4);
p.v_strideB = v_view.stride(-4);
}
constexpr auto kernel_fn = attention_kernel_batched_impl<Attention>;
int smem_bytes = sizeof(typename Attention::SharedStorage);
if (smem_bytes > 0xc000) {
cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
}
if (!Attention::check_supported(p)) { throw std::runtime_error("Parameters not supported"); }
kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
}
#define CODE(scalar_t, torch_scalar_t) \
do { \
if (bias1.size(0) == 0 && bias2.size(0) == 0) { \
attention_impl_template<ArchTag, \
scalar_t, \
torch_scalar_t, \
BroadcastNoLoad, \
BroadcastNoLoad>(q, k, v, bias1, bias2, o, lse_ptr); \
} else if (bias1.size(0) == 0) { \
attention_impl_template<ArchTag, \
scalar_t, \
torch_scalar_t, \
BroadcastNoLoad, \
BroadcastB>(q, k, v, bias1, bias2, o, lse_ptr); \
} else if (bias2.size(0) == 0) { \
attention_impl_template<ArchTag, \
scalar_t, \
torch_scalar_t, \
BroadcastA, \
BroadcastNoLoad>(q, k, v, bias1, bias2, o, lse_ptr); \
} else { \
attention_impl_template<ArchTag, scalar_t, torch_scalar_t, BroadcastA, BroadcastB>( \
q, k, v, bias1, bias2, o, lse_ptr); \
} \
} while (0)
// Function to select and call the correct template based on bias sizes
void attention_impl(torch::Tensor& q,
torch::Tensor& k,
torch::Tensor& v,
torch::Tensor& bias1,
torch::Tensor& bias2,
torch::Tensor& o,
torch::Tensor& lse)
{
auto lse_ptr = lse.size(0) == 0 ? nullptr : reinterpret_cast<float*>(lse.data_ptr<float>());
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
DISPATCH_ARCHTAG(prop->major * 10 + prop->minor,
DISPATCH_TYPES(q, ([&]() { CODE(scalar_t, torch_scalar_t); })));
}
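
In equation form, with `p.scale = 1 / sqrt(head_size)` and the two optional bias terms selected by the macro above, the forward kernel computes (a sketch of the standard Evoformer-style biased attention; the exact broadcasting of the biases depends on the Broadcast template arguments):

\[
O \;=\; \mathrm{softmax}\!\Big(\tfrac{1}{\sqrt{d_{\text{head}}}}\, Q K^{\top} + b_1 + b_2\Big)\, V,
\qquad
\mathrm{lse}_i \;=\; \log \sum_j \exp\!\Big(\tfrac{q_i^{\top} k_j}{\sqrt{d_{\text{head}}}} + b_{1,ij} + b_{2,ij}\Big).
\]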


@@ -0,0 +1,218 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <type_traits>
#include "gemm_kernel_utils.h"
#include "kernel_backward.h"
#include "transform/bias_broadcast.h"
constexpr auto kBlockSizeI = 64;
constexpr auto kBlockSizeJ = 64;
template <typename arch,
typename scalar_t,
typename torch_scalar_t,
template <typename, typename, typename>
class Broadcast1_,
template <typename, typename, typename>
class Broadcast2_>
typename std::enable_if<!CheckArch<arch, scalar_t>::value>::type attention_back_impl_template(
torch::Tensor& go,
torch::Tensor& q,
torch::Tensor& k,
torch::Tensor& v,
torch::Tensor& o,
torch::Tensor& lse,
torch::Tensor& delta,
torch::Tensor& bias1,
torch::Tensor& bias2,
torch::Tensor& gq,
torch::Tensor& gk,
torch::Tensor& gv,
torch::Tensor& gb1,
torch::Tensor& gb2)
{
EVOFORMER_CHECK(false, "Unsupported GPU and data type combination")
}
template <typename arch,
typename scalar_t,
typename torch_scalar_t,
template <typename, typename, typename>
class Broadcast1_,
template <typename, typename, typename>
class Broadcast2_>
typename std::enable_if<CheckArch<arch, scalar_t>::value>::type attention_back_impl_template(
torch::Tensor& go,
torch::Tensor& q,
torch::Tensor& k,
torch::Tensor& v,
torch::Tensor& o,
torch::Tensor& lse,
torch::Tensor& delta,
torch::Tensor& bias1,
torch::Tensor& bias2,
torch::Tensor& gq,
torch::Tensor& gk,
torch::Tensor& gv,
torch::Tensor& gb1,
torch::Tensor& gb2)
{
constexpr bool kPreload_ = arch::kMinComputeCapability >= 80;
using Kernel = AttentionBackwardKernel<arch,
scalar_t, // scalar_t
true, // kIsAligned_
false, // kApplyDropout_
kPreload_, // kPreload_
kBlockSizeI, // kBlockSizeI_,
kBlockSizeJ, // kBlockSizeJ_,
64, // kMaxK
Broadcast1_,
Broadcast2_>;
int head_size = q.size(-1);
int head_number = q.size(-2);
int seq_length = q.size(-3);
auto q_view = q.view({-1, seq_length, head_number, head_size});
auto k_view = k.view({-1, seq_length, head_number, head_size});
auto v_view = v.view({-1, seq_length, head_number, head_size});
auto o_view = o.view({-1, seq_length, head_number, head_size});
auto do_view = go.view({-1, seq_length, head_number, head_size});
auto dk_view = gk.view({-1, seq_length, head_number, head_size});
auto dv_view = gv.view({-1, seq_length, head_number, head_size});
auto dq_view = gq.view({-1, seq_length, head_number, head_size});
auto q_ptr = reinterpret_cast<scalar_t*>(q.data_ptr<torch_scalar_t>());
auto k_ptr = reinterpret_cast<scalar_t*>(k.data_ptr<torch_scalar_t>());
auto v_ptr = reinterpret_cast<scalar_t*>(v.data_ptr<torch_scalar_t>());
auto o_ptr = reinterpret_cast<scalar_t*>(o.data_ptr<torch_scalar_t>());
auto do_ptr = reinterpret_cast<scalar_t*>(go.data_ptr<torch_scalar_t>());
auto dk_ptr = reinterpret_cast<scalar_t*>(gk.data_ptr<torch_scalar_t>());
auto dv_ptr = reinterpret_cast<scalar_t*>(gv.data_ptr<torch_scalar_t>());
auto dq_ptr = reinterpret_cast<scalar_t*>(gq.data_ptr<torch_scalar_t>());
auto db1_ptr = gb1.size(0) > 0 ? reinterpret_cast<float*>(gb1.data_ptr<float>()) : nullptr;
auto db2_ptr = gb2.size(0) > 0 ? reinterpret_cast<float*>(gb2.data_ptr<float>()) : nullptr;
auto lse_ptr = reinterpret_cast<float*>(lse.data_ptr<float>());
auto delta_ptr = reinterpret_cast<float*>(delta.data_ptr<float>());
auto bias1_ptr = reinterpret_cast<scalar_t*>(bias1.data_ptr<torch_scalar_t>());
auto bias2_ptr = reinterpret_cast<scalar_t*>(bias2.data_ptr<torch_scalar_t>());
static_assert(Kernel::kKernelComputesDelta, "Kernel must compute delta");
typename Kernel::Params p;
p.query_ptr = q_ptr;
p.key_ptr = k_ptr;
p.value_ptr = v_ptr;
p.logsumexp_ptr = lse_ptr;
p.output_ptr = o_ptr;
p.grad_output_ptr = do_ptr;
p.delta_ptr = delta_ptr;
p.grad_query_ptr = dq_ptr;
p.grad_key_ptr = dk_ptr;
p.grad_value_ptr = dv_ptr;
p.grad_bias1_ptr = db1_ptr;
p.grad_bias2_ptr = db2_ptr;
p.B = q.size(0);
p.N = q.size(1);
p.bias1_ptr = bias1.size(0) ? bias1_ptr : nullptr;
p.bias2_ptr = bias2.size(0) ? bias2_ptr : nullptr;
p.scale = 1.0f / sqrtf(head_size);
p.head_dim = head_size;
p.head_dim_value = head_size;
p.num_queries = seq_length;
p.num_keys = seq_length;
p.num_heads = head_number;
p.q_strideM = q_view.stride(-3);
p.k_strideM = k_view.stride(-3);
p.v_strideM = v_view.stride(-3);
p.gO_strideM = do_view.stride(-3);
p.o_strideH = o_view.stride(-2);
p.q_strideH = q_view.stride(-2);
p.k_strideH = k_view.stride(-2);
p.v_strideH = v_view.stride(-2);
p.o_strideB = o_view.stride(-4);
p.q_strideB = q_view.stride(-4);
p.k_strideB = k_view.stride(-4);
p.v_strideB = v_view.stride(-4);
p.lse_strideB = lse.stride(-3);
p.lse_strideH = lse.stride(-2);
p.delta_strideB = delta.stride(-3);
p.delta_strideH = delta.stride(-2);
p.num_batches = q_view.size(-4);
p.gO_strideB = do_view.stride(-4);
p.gQ_strideB = dq_view.stride(-4);
p.gK_strideB = dk_view.stride(-4);
p.gV_strideB = dv_view.stride(-4);
p.gO_strideH = do_view.stride(-2);
p.gQ_strideH = dq_view.stride(-2);
p.gK_strideH = dk_view.stride(-2);
p.gV_strideH = dv_view.stride(-2);
torch::Tensor workspace = torch::empty(p.workspace_size() / 4, lse.options());
p.workspace = workspace.data_ptr<float>();
auto kernel_fn = attention_kernel_backward_batched_impl<Kernel>;
size_t smem_bytes = sizeof(typename Kernel::SharedStorage);
cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, int(smem_bytes));
if (!Kernel::check_supported(p)) { throw std::runtime_error("Unsupported parameters"); }
kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
}
#define CODE(scalar_t, torch_scalar_t) \
do { \
if (bias1.size(0) == 0 && bias2.size(0) == 0) { \
attention_back_impl_template<ArchTag, \
scalar_t, \
torch_scalar_t, \
BroadcastNoLoad, \
BroadcastNoLoad>( \
go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2); \
} else if (bias1.size(0) > 0 && bias2.size(0) > 0) { \
attention_back_impl_template<ArchTag, \
scalar_t, \
torch_scalar_t, \
BroadcastA, \
BroadcastB>( \
go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2); \
} else if (bias1.size(0) > 0) { \
attention_back_impl_template<ArchTag, \
scalar_t, \
torch_scalar_t, \
BroadcastA, \
BroadcastNoLoad>( \
go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2); \
} else { \
attention_back_impl_template<ArchTag, \
scalar_t, \
torch_scalar_t, \
BroadcastNoLoad, \
BroadcastB>( \
go, q, k, v, o, lse, delta, bias1, bias2, gq, gk, gv, gb1, gb2); \
} \
} while (0)
void attention_back_impl(torch::Tensor& go,
torch::Tensor& q,
torch::Tensor& k,
torch::Tensor& v,
torch::Tensor& o,
torch::Tensor& lse,
torch::Tensor& delta,
torch::Tensor& bias1,
torch::Tensor& bias2,
torch::Tensor& gq,
torch::Tensor& gk,
torch::Tensor& gv,
torch::Tensor& gb1,
torch::Tensor& gb2)
{
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
DISPATCH_ARCHTAG(prop->major * 10 + prop->minor,
DISPATCH_TYPES(q, ([&]() { CODE(scalar_t, torch_scalar_t); })));
}
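
The `delta` tensor and the `kKernelComputesDelta` assertion above follow the usual memory-efficient-attention backward recipe: a per-query correction term is formed from the forward output and its gradient before the softmax gradient is applied (a sketch, assuming the standard FlashAttention-style formulation):

\[
\delta_i \;=\; \sum_d \frac{\partial L}{\partial O_{i,d}}\, O_{i,d},
\qquad
\frac{\partial L}{\partial S_{ij}} \;=\; P_{ij}\Big(\frac{\partial L}{\partial P_{ij}} - \delta_i\Big),
\quad P = \mathrm{softmax}(S).
\]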


@@ -0,0 +1,250 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#pragma once
#include <cutlass/epilogue/threadblock/default_epilogue_tensor_op.h>
#include <cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h>
#include "../iterators/predicated_tile_iterator_atomic.h"
#include "cutlass/epilogue/threadblock/epilogue.h"
namespace cutlass {
namespace epilogue {
namespace threadblock {
template <int Rank,
typename Shape_,
typename WarpMmaTensorOp_,
int PartitionsK,
typename OutputOp_,
int ElementsPerAccess>
struct EpilogueTensorOpAffineRankN : public DefaultEpilogueTensorOpAffineRankN<Rank,
Shape_,
WarpMmaTensorOp_,
PartitionsK,
OutputOp_,
ElementsPerAccess> {
using Base = DefaultEpilogueTensorOpAffineRankN<Rank,
Shape_,
WarpMmaTensorOp_,
PartitionsK,
OutputOp_,
ElementsPerAccess>;
using OutputTileIterator =
cutlass::epilogue::threadblock::PredicatedTileIteratorAffineRankNAtomic<
typename Base::OutputTileThreadMap,
typename Base::ElementOutput,
Rank>;
using Epilogue =
cutlass::epilogue::threadblock::Epilogue<typename Base::Shape,
typename Base::WarpMmaTensorOp,
Base::kPartitionsK,
OutputTileIterator,
typename Base::AccumulatorFragmentIterator,
typename Base::WarpTileIterator,
typename Base::SharedLoadIterator,
typename Base::OutputOp,
typename Base::Padding,
Base::kFragmentsPerIteration>;
};
template <int Rank,
typename Shape_,
typename WarpMmaTensorOp_,
int PartitionsK,
typename OutputOp_,
int ElementsPerAccess>
struct EpilogueVoltaTensorOpAffineRankN
: public DefaultEpilogueVoltaTensorOpAffineRankN<Rank,
Shape_,
WarpMmaTensorOp_,
PartitionsK,
OutputOp_,
ElementsPerAccess> {
using Base = DefaultEpilogueVoltaTensorOpAffineRankN<Rank,
Shape_,
WarpMmaTensorOp_,
PartitionsK,
OutputOp_,
ElementsPerAccess>;
using OutputTileIterator =
cutlass::epilogue::threadblock::PredicatedTileIteratorAffineRankNAtomic<
typename Base::OutputTileThreadMap,
typename Base::ElementOutput,
Rank>;
using Epilogue =
cutlass::epilogue::threadblock::Epilogue<typename Base::Shape,
typename Base::WarpMmaTensorOp,
Base::kPartitionsK,
OutputTileIterator,
typename Base::AccumulatorFragmentIterator,
typename Base::WarpTileIterator,
typename Base::SharedLoadIterator,
typename Base::OutputOp,
typename Base::Padding>;
};
template <typename Shape_,
typename WarpMmaTensorOp_,
int PartitionsK,
typename OutputOp_,
int ElementsPerAccess,
bool ScatterD = false,
typename PermuteDLayout = layout::NoPermute>
struct EpilogueTensorOp : public DefaultEpilogueTensorOp<Shape_,
WarpMmaTensorOp_,
PartitionsK,
OutputOp_,
ElementsPerAccess,
ScatterD,
PermuteDLayout> {
using Base = DefaultEpilogueTensorOp<Shape_,
WarpMmaTensorOp_,
PartitionsK,
OutputOp_,
ElementsPerAccess,
ScatterD,
PermuteDLayout>;
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorAtomic<
typename Base::OutputTileThreadMap,
typename Base::ElementOutput,
ScatterD,
PermuteDLayout>;
using Epilogue =
cutlass::epilogue::threadblock::Epilogue<typename Base::Shape,
typename Base::WarpMmaTensorOp,
Base::kPartitionsK,
OutputTileIterator,
typename Base::AccumulatorFragmentIterator,
typename Base::WarpTileIterator,
typename Base::SharedLoadIterator,
typename Base::OutputOp,
typename Base::Padding,
Base::kFragmentsPerIteration>;
};
template <typename Shape_,
typename WarpMmaTensorOp_,
int PartitionsK,
typename OutputOp_,
int ElementsPerAccess,
bool ScatterD = false,
typename PermuteDLayout = layout::NoPermute>
struct EpilogueVoltaTensorOp : public DefaultEpilogueVoltaTensorOp<Shape_,
WarpMmaTensorOp_,
PartitionsK,
OutputOp_,
ElementsPerAccess,
ScatterD,
PermuteDLayout> {
using Base = DefaultEpilogueVoltaTensorOp<Shape_,
WarpMmaTensorOp_,
PartitionsK,
OutputOp_,
ElementsPerAccess,
ScatterD,
PermuteDLayout>;
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorAtomic<
typename Base::OutputTileThreadMap,
typename Base::ElementOutput,
ScatterD,
PermuteDLayout>;
using Epilogue =
cutlass::epilogue::threadblock::Epilogue<typename Base::Shape,
typename Base::WarpMmaTensorOp,
Base::kPartitionsK,
OutputTileIterator,
typename Base::AccumulatorFragmentIterator,
typename Base::WarpTileIterator,
typename Base::SharedLoadIterator,
typename Base::OutputOp,
typename Base::Padding>;
};
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
template <typename Arch_,
typename Shape_,
typename WarpMmaTensorOp_,
int PartitionsK,
typename OutputOp_,
int ElementsPerAccess,
bool ScatterD = false,
typename PermuteDLayout = cutlass::layout::NoPermute>
struct BiasGradEpilogue {
using Epilogue =
typename cutlass::epilogue::threadblock::EpilogueTensorOp<Shape_,
WarpMmaTensorOp_,
PartitionsK,
OutputOp_,
ElementsPerAccess,
ScatterD,
PermuteDLayout>::Epilogue;
};
template <typename Shape_,
typename WarpMmaTensorOp_,
int PartitionsK,
typename OutputOp_,
int ElementsPerAccess,
bool ScatterD,
typename PermuteDLayout>
struct BiasGradEpilogue<cutlass::arch::Sm70,
Shape_,
WarpMmaTensorOp_,
PartitionsK,
OutputOp_,
ElementsPerAccess,
ScatterD,
PermuteDLayout> {
using Epilogue =
typename cutlass::epilogue::threadblock::EpilogueVoltaTensorOp<Shape_,
WarpMmaTensorOp_,
PartitionsK,
OutputOp_,
ElementsPerAccess,
ScatterD,
PermuteDLayout>::Epilogue;
};
template <typename Arch_,
int Rank,
typename Shape_,
typename WarpMmaTensorOp_,
int PartitionsK,
typename OutputOp_,
int ElementsPerAccess>
struct BiasGradEpilogueAffineRankN {
using Epilogue = typename cutlass::epilogue::threadblock::EpilogueTensorOpAffineRankN<
Rank,
Shape_,
WarpMmaTensorOp_,
PartitionsK,
OutputOp_,
ElementsPerAccess>::Epilogue;
};
template <int Rank,
typename Shape_,
typename WarpMmaTensorOp_,
int PartitionsK,
typename OutputOp_,
int ElementsPerAccess>
struct BiasGradEpilogueAffineRankN<cutlass::arch::Sm70,
Rank,
Shape_,
WarpMmaTensorOp_,
PartitionsK,
OutputOp_,
ElementsPerAccess> {
using Epilogue = typename cutlass::epilogue::threadblock::EpilogueVoltaTensorOpAffineRankN<
Rank,
Shape_,
WarpMmaTensorOp_,
PartitionsK,
OutputOp_,
ElementsPerAccess>::Epilogue;
};


@@ -0,0 +1,592 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
/*! \file
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
File copied from "cutlass/epilogue/threadblock/epilogue.h"
then modified to:
(1) load 2 source fragments at the same time (pipelining)
(2) support reading from a different dtype
(3) pass the row id to the OutputOp if it takes it
(see MemoryEfficientAttentionNormalize)
Note that in general the fragment passed to the OutputOp could
span multiple rows but it does not happen with the configurations we have
*/
#pragma once
#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/functional.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/vector.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_coord.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
#include "cutlass/epilogue/threadblock/epilogue_base.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
#include "cutlass/numeric_types.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
template <typename Op>
struct ApplyEpilogueOp {
static CUTLASS_DEVICE typename Op::FragmentOutput apply(
Op const& output_op,
int row_id,
typename Op::FragmentAccumulator const& accum,
typename Op::FragmentOutput const& source)
{
return output_op(accum, source);
}
static CUTLASS_DEVICE typename Op::FragmentOutput
apply(Op const& output_op, int row_id, typename Op::FragmentAccumulator const& accum)
{
return output_op(accum);
}
};
////////////////////////////////////////////////////////////////////////////////
/// Epilogue operator
template <typename Shape_, ///< Shape of threadblock tile (concept: GemmShape)
typename WarpMmaOperator_, ///< Warp-level MMA operator (concept:
///< gemm::warp::MmaTensorOp)
int PartitionsK, ///< Number of partitions of the K dimension
typename OutputTileIterator_, ///< Tile iterator writing output tensors
typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting
///< accumulators
typename WarpTileIterator_, ///< Warp-scoped tile iterator writing
///< accumulators to SMEM
typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading
///< from SMEM
typename OutputOp_, ///< Output operator
typename Padding_, ///< Padding added to SMEM allocation to avoid bank
///< conflicts (concept: MatrixShape)
int FragmentsPerPartition = 1, ///< Used to coarsen the epilogue granularity
int IterationsUnroll = ///< Used to reduce binary size when epilogue op is
///< large
(!IsEpilogueFunctorHeavy<OutputOp_>::value),
typename OutputTileSourceIterator_ =
OutputTileIterator_ ///< Tile iterator reading tensors
>
class EpiloguePipelined : public EpilogueBase<Shape_,
typename WarpMmaOperator_::Shape,
PartitionsK,
AccumulatorFragmentIterator_,
WarpTileIterator_,
Padding_,
FragmentsPerPartition> {
public:
using Base = EpilogueBase<Shape_,
typename WarpMmaOperator_::Shape,
PartitionsK,
AccumulatorFragmentIterator_,
WarpTileIterator_,
Padding_,
FragmentsPerPartition>;
using Shape = Shape_;
using WarpMmaOperator = WarpMmaOperator_;
static int const kPartitionsK = PartitionsK;
using OutputTileIterator = OutputTileIterator_;
using OutputTileSourceIterator = OutputTileSourceIterator_;
using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
using WarpTileIterator = WarpTileIterator_;
using SharedLoadIterator = SharedLoadIterator_;
using OutputOp = OutputOp_;
using Padding = Padding_;
using Layout = layout::RowMajor;
using LongIndex = typename Layout::LongIndex;
/// The complete warp-level accumulator tile
using AccumulatorTile = typename Base::AccumulatorTile;
/// Accumulator element
using ElementAccumulator = typename WarpTileIterator::Element;
/// Output element
using ElementOutput = typename OutputTileIterator::Element;
using ElementSource = typename OutputTileSourceIterator::Element;
/// Output access size
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
/// Tensor reference to destination tensor
using TensorRef = typename OutputTileIterator::TensorRef;
/// Tensor reference to sync tensor
using SyncTensorRef = typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;
/// Const tensor reference to source tensor
using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;
/// Array type used to output
using OutputAccessType =
Array<typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
using SourceAccessType = Array<typename OutputTileSourceIterator::Element,
OutputTileSourceIterator::kElementsPerAccess>;
/// Array type used by output functor
using AccumulatorAccessType =
Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
/// Number of warps
using WarpCount = typename Base::WarpCount;
static int constexpr kSmemTiles =
Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK;
static int constexpr kSmemPointerOffset =
Base::SharedStorage::StorageShape::kCount / kSmemTiles;
public:
static_assert(OutputTileSourceIterator::Fragment::kElements ==
OutputTileIterator::Fragment::kElements,
"Mismatch between input tile and output tile iterator (kElements)");
static_assert(OutputTileSourceIterator::kIterations == OutputTileIterator::kIterations,
"Mismatch between input tile and output tile iterator (kIterations)");
static_assert(SharedLoadIterator::Fragment::kElements ==
OutputTileIterator::Fragment::kElements,
"Mismatch between shared load iterator and output tile iterator.");
static_assert(OutputTileIterator::kElementsPerAccess,
"OutputTileIterator::kElementsPerAccess must not be zero.");
static_assert(!(OutputTileIterator::Fragment::kElements %
OutputTileIterator::kElementsPerAccess),
"Divisibility");
private:
/// Loads fragment from shared memory aligned with output tensor
SharedLoadIterator shared_load_iterator_;
public:
/// Constructor
CUTLASS_DEVICE
EpiloguePipelined(typename Base::SharedStorage& shared_storage, ///< Shared storage object
int thread_idx, ///< ID of a thread within the threadblock
int warp_idx, ///< ID of warp within threadblock
int lane_idx ///< Id of thread within warp
)
: Base(shared_storage, thread_idx, warp_idx, lane_idx),
shared_load_iterator_(shared_storage.reference(), thread_idx)
{
}
/// Streams the result to global memory
CUTLASS_DEVICE
void operator()(OutputOp const& output_op, ///< Output operator
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const& accumulators, ///< Complete warp-level accumulator tile
OutputTileSourceIterator source_iterator)
{ ///< Threadblock tile coordinate in GEMM (in units
///< of threadblock tiles)
if (!output_op.is_source_needed()) {
compute_source_not_needed_(output_op, destination_iterator, accumulators);
} else {
compute_source_needed_(output_op, destination_iterator, accumulators, source_iterator);
}
}
CUTLASS_DEVICE
void operator()(OutputOp const& output_op, ///< Output operator
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const& accumulators)
{ ///< Complete warp-level accumulator tile
compute_source_not_needed_(output_op, destination_iterator, accumulators);
}
private:
template <class Seq>
struct acc2smem_source_not_needed;
template <size_t... Seq>
struct acc2smem_source_not_needed<cutlass::index_sequence<Seq...>> {
template <int Advance>
CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
WarpTileIterator& warp_tile_iterator)
{
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Advance; i++) { ++accum_fragment_iterator; }
CUTLASS_PRAGMA_UNROLL
for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
typename AccumulatorFragmentIterator::Fragment accum_fragment;
accum_fragment_iterator.load(accum_fragment);
++accum_fragment_iterator;
warp_tile_iterator.store(accum_fragment);
if (p < Base::kFragmentsPerIteration - 1) {
warp_tile_iterator.add_pointer_offset(kSmemPointerOffset);
}
}
if (Base::kFragmentsPerIteration > 1) {
warp_tile_iterator.add_pointer_offset(kSmemPointerOffset *
(1 - Base::kFragmentsPerIteration));
}
}
CUTLASS_DEVICE
static void push(size_t pos,
AccumulatorFragmentIterator const& iterator_begin,
WarpTileIterator& warp_tile_iterator)
{
int dummy[] = {
(pos == (Seq * Base::kFragmentsPerIteration)) &&
(helper<Seq * Base::kFragmentsPerIteration>(iterator_begin, warp_tile_iterator),
0)...};
CUTLASS_UNUSED(dummy[0]);
}
};
static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1,
"One of these must be exactly 1.");
/// Streams the result to global memory
CUTLASS_DEVICE
void compute_source_not_needed_(
OutputOp const& output_op, ///< Output operator
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const& accumulators ///< Complete warp-level accumulator tile
)
{
//
// Iterator over warp-level accumulator fragment
//
AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
//
// Iterate over accumulator tile
//
#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration \
: 1)
for (int iter = 0; iter < OutputTileIterator::kIterations;
iter += Base::kFragmentsPerIteration) {
//
// Convert and store fragment
//
__syncthreads();
acc2smem_source_not_needed<cutlass::make_index_sequence<
OutputTileIterator::kIterations / Base::kFragmentsPerIteration>>::
push(iter, accum_fragment_iterator, this->warp_tile_iterator_);
__syncthreads();
//
// Load fragments from shared memory
//
CUTLASS_PRAGMA_UNROLL
for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
shared_load_iterator_.load(aligned_accum_fragment[0]);
if (p < Base::kFragmentsPerIteration - 1) {
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
} else if (kPartitionsK > 1) {
plus<typename SharedLoadIterator::Fragment> add_fragments;
CUTLASS_PRAGMA_UNROLL
for (int i = 1; i < kPartitionsK; ++i) {
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
shared_load_iterator_.load(aligned_accum_fragment[i]);
aligned_accum_fragment[0] =
add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
}
shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) *
kSmemPointerOffset);
}
//
// Compute the output result
//
typename OutputTileIterator::Fragment output_fragment;
apply_output_operator_source_not_needed_(destination_iterator.thread_start_row(),
output_fragment,
output_op,
aligned_accum_fragment[0]);
//
// Store the final result
//
destination_iterator.store(output_fragment);
++destination_iterator;
}
if (Base::kFragmentsPerIteration > 1) {
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset *
(1 - Base::kFragmentsPerIteration));
}
}
}
template <class Seq>
struct acc2smem_source_needed;
template <size_t... Seq>
struct acc2smem_source_needed<cutlass::index_sequence<Seq...>> {
template <int Advance>
CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator,
WarpTileIterator& warp_tile_iterator)
{
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Advance; i++) { ++accum_fragment_iterator; }
typename AccumulatorFragmentIterator::Fragment accum_fragment;
accum_fragment_iterator.load(accum_fragment);
warp_tile_iterator.store(accum_fragment);
}
CUTLASS_DEVICE
static void push(size_t pos,
AccumulatorFragmentIterator const& iterator_begin,
WarpTileIterator& warp_tile_iterator)
{
int dummy[] = {(pos == Seq) && (helper<Seq>(iterator_begin, warp_tile_iterator), 0)...};
}
};
/// Streams the result to global memory
CUTLASS_DEVICE
void compute_source_needed_(
OutputOp const& output_op, ///< Output operator
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const& accumulators, ///< Complete warp-level accumulator tile
OutputTileSourceIterator source_iterator ///< Threadblock tile coordinate in GEMM (in units
///< of threadblock tiles)
)
{
typename OutputTileSourceIterator::Fragment source_fragment[2];
source_fragment[0].clear();
source_iterator.load(source_fragment[0]);
++source_iterator;
source_fragment[1].clear();
//
// Iterator over warp-level accumulator fragment
//
AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
//
// Iterate over accumulator tile
//
#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1)
for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
if (iter > 0) { __syncthreads(); }
//
// Load the source for next iteration (pipelining)
//
if (iter + 1 < OutputTileIterator::kIterations) {
source_iterator.load(source_fragment[(iter + 1) % 2]);
}
++source_iterator;
acc2smem_source_needed<cutlass::make_index_sequence<OutputTileIterator::kIterations>>::
push(iter, accum_fragment_iterator, this->warp_tile_iterator_);
__syncthreads();
//
// Load fragments from shared memory
//
typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
shared_load_iterator_.load(aligned_accum_fragment[0]);
// If the number of k-slices is > 1 - perform a reduction amongst the
// k-slices
if (kPartitionsK > 1) {
plus<typename SharedLoadIterator::Fragment> add_fragments;
CUTLASS_PRAGMA_UNROLL
for (int i = 1; i < kPartitionsK; ++i) {
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
shared_load_iterator_.load(aligned_accum_fragment[i]);
aligned_accum_fragment[0] =
add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
}
shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset);
}
//
// Compute the output result
//
typename OutputTileIterator::Fragment output_fragment;
apply_output_operator_(destination_iterator.thread_start_row(),
output_fragment,
output_op,
aligned_accum_fragment[0],
source_fragment[iter % 2]);
//
// Store the final result
//
destination_iterator.store(output_fragment);
++destination_iterator;
}
}
/// Helper to invoke the output functor over each vector of output
CUTLASS_DEVICE
void apply_output_operator_(int begin_row,
typename OutputTileIterator::Fragment& output_fragment,
OutputOp const& output_op, ///< Output operator
typename SharedLoadIterator::Fragment const& aligned_accum_fragment,
typename OutputTileSourceIterator::Fragment const& source_fragment)
{
OutputAccessType* output_frag_ptr = reinterpret_cast<OutputAccessType*>(&output_fragment);
AccumulatorAccessType const* compute_frag_ptr =
reinterpret_cast<AccumulatorAccessType const*>(&aligned_accum_fragment);
SourceAccessType const* source_frag_ptr =
reinterpret_cast<SourceAccessType const*>(&source_fragment);
int const kOutputOpIterations =
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kOutputOpIterations; ++i) {
// Call the output operator
output_frag_ptr[i] = ApplyEpilogueOp<OutputOp>::apply(
output_op,
begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess),
compute_frag_ptr[i],
source_frag_ptr[i]);
}
}
/// Helper to invoke the output functor over each vector of output
CUTLASS_DEVICE
void apply_output_operator_source_not_needed_(
int begin_row,
typename OutputTileIterator::Fragment& output_fragment,
OutputOp const& output_op, ///< Output operator
typename SharedLoadIterator::Fragment const& aligned_accum_fragment)
{
OutputAccessType* output_frag_ptr = reinterpret_cast<OutputAccessType*>(&output_fragment);
AccumulatorAccessType const* compute_frag_ptr =
reinterpret_cast<AccumulatorAccessType const*>(&aligned_accum_fragment);
int const kOutputOpIterations =
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kOutputOpIterations; ++i) {
// Call the output operator
output_frag_ptr[i] = ApplyEpilogueOp<OutputOp>::apply(
output_op,
begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess),
compute_frag_ptr[i]);
}
}
// This should be constexpr, but it's only supported on c++14
static int CUTLASS_HOST_DEVICE getRowOffset(int i)
{
using ThreadMap = typename OutputTileIterator::ThreadMap;
CUTLASS_PRAGMA_UNROLL
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) {
CUTLASS_PRAGMA_UNROLL
for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
int row_offset = row * ThreadMap::Delta::kRow +
group * ThreadMap::Delta::kGroup +
cluster * ThreadMap::Delta::kCluster;
int frag_row_idx =
(row + ThreadMap::Iterations::kRow *
(group + ThreadMap::Iterations::kGroup * cluster));
CUTLASS_PRAGMA_UNROLL
for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
int frag_idx = ThreadMap::kElementsPerAccess *
(frag_row_idx * ThreadMap::Iterations::kColumn + column);
if (i < frag_idx + ThreadMap::kElementsPerAccess) { return row_offset; }
}
}
}
}
return -1;
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////


@@ -0,0 +1,251 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
/*! \file
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
The epilogue rearranges the result of a matrix product through shared memory
to match canonical tensor layouts in global memory. Epilogues support
conversion and reduction operations.
This is a copy of cutlass/epilogue/threadblock/epilogue.h that can
handle "row_id" as a first argument, as uses it to get the corresponding
`m_prime` / `s_prime` to rescale the output.
*/
#pragma once
#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/functional.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/vector.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_coord.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
#include "cutlass/epilogue/threadblock/epilogue_base.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/thread/scale_type.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
#include "epilogue_pipelined.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace thread {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Applies a linear combination operator to an array of elements.
// output <- alpha * accumulator + beta * source
// with:
// alpha = 1 / s_prime (to normalize when isLast=True, 1 otherwise)
// beta = alpha * m_prime (renormalize the output when the max changes)
// source is the current output
template <typename ElementOutput_, ///< Data type used to store tensors
typename ElementSource_, //< Data type for source (usually matches
//`ElementOutput`)
int Count, ///< Number of elements computed per operation.
///< Usually it is 128/sizeof_bits<ElementOutput_>,
///< but we use 64 or 32 sometimes when there are not enough data
///< to store
typename ElementAccumulator_, ///< Accumulator data type
typename ElementCompute_, ///< Data type used to compute linear combination
bool isFirst,
bool isLast,
typename FragmentAlphaBeta_,
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest>
class MemoryEfficientAttentionNormalize {
public:
using ElementOutput = ElementOutput_;
using ElementSource = ElementSource_;
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
static int const kCount = Count;
using FragmentOutput = Array<ElementOutput, kCount>;
using FragmentSource = Array<ElementSource, kCount>;
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
using ComputeFragment = Array<ElementCompute, kCount>;
using FragmentAlphaBeta = FragmentAlphaBeta_;
static FloatRoundStyle const kRound = Round;
private:
//
// Data members
//
FragmentAlphaBeta const& s_prime_;
FragmentAlphaBeta const& m_prime_;
public:
/// Constructs the function object, possibly loading from pointers in host
/// memory
CUTLASS_HOST_DEVICE
MemoryEfficientAttentionNormalize(FragmentAlphaBeta const& s_prime,
FragmentAlphaBeta const& m_prime)
: s_prime_(s_prime), m_prime_(m_prime)
{
}
/// Returns true if source is needed
CUTLASS_HOST_DEVICE
bool is_source_needed() const { return !isFirst; }
/// Functionally required for serial reduction in the epilogue
CUTLASS_HOST_DEVICE
void set_k_partition(int k_partition, int k_partition_count) {}
/// Computes linear scaling: D = alpha * accumulator + beta * source
CUTLASS_HOST_DEVICE
FragmentOutput operator()(int row,
FragmentAccumulator const& accumulator,
FragmentSource const& source) const
{
assert(!isFirst);
// Convert source to internal compute numeric type
NumericArrayConverter<ElementCompute, ElementSource, kCount, Round> source_converter;
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
accumulator_converter;
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
ComputeFragment converted_source = source_converter(source);
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
// Perform binary operations
ComputeFragment intermediate;
multiplies<ComputeFragment> mul_add_source;
multiply_add<ComputeFragment> mul_add_accumulator;
ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1;
ElementCompute beta = alpha * m_prime_[row];
intermediate = mul_add_source(beta, converted_source); // X = beta * C
intermediate = mul_add_accumulator(
alpha, converted_accumulator, intermediate); // D = alpha * Accum + X
return destination_converter(intermediate);
}
/// Computes linear scaling: D = alpha * accumulator
CUTLASS_HOST_DEVICE
FragmentOutput operator()(int row, FragmentAccumulator const& accumulator) const
{
assert(isFirst);
// Convert source to internal compute numeric type
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
accumulator_converter;
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
ComputeFragment intermediate;
multiplies<ComputeFragment> mul_accumulator;
ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1;
intermediate = mul_accumulator(alpha, converted_accumulator); // D = alpha * Accum
return destination_converter(intermediate);
}
};
} // namespace thread
namespace threadblock {
template <typename EO,
typename ES,
int Count,
typename EA,
typename EC,
bool F,
bool L,
typename FAB,
FloatRoundStyle R>
struct ApplyEpilogueOp<
thread::MemoryEfficientAttentionNormalize<EO, ES, Count, EA, EC, F, L, FAB, R>> {
using Op = thread::MemoryEfficientAttentionNormalize<EO, ES, Count, EA, EC, F, L, FAB, R>;
static CUTLASS_DEVICE typename Op::FragmentOutput apply(
Op const& output_op,
int row_id,
typename Op::FragmentAccumulator const& accum,
typename Op::FragmentSource const& source)
{
return output_op(row_id, accum, source);
}
static CUTLASS_DEVICE typename Op::FragmentOutput
apply(Op const& output_op, int row_id, typename Op::FragmentAccumulator const& accum)
{
return output_op(row_id, accum);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
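
Restating the comments in this file as an equation: for each output row r, the normalize epilogue applies

\[
D_r \;=\; \alpha_r A_r + \beta_r C_r,
\qquad
\alpha_r = \begin{cases} 1/s'_r & \text{if isLast} \\ 1 & \text{otherwise} \end{cases},
\qquad
\beta_r = \alpha_r\, m'_r,
\]

where A is the accumulator tile, C is the previously written partial output (the "source"), and s', m' are the running softmax normalizer and max-rescaling factors kept by the attention kernel.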


@@ -0,0 +1,168 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
/*! \file
\brief Functor performing linear combination operations used by epilogues.
*/
#pragma once
#include <cuda_fp16.h>
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace thread {
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {
template <typename Element, int ElementsPerAccess>
struct ArrayExponential {
CUTLASS_HOST_DEVICE
Array<Element, ElementsPerAccess> operator()(
Array<Element, ElementsPerAccess> const& input) const
{
Array<Element, ElementsPerAccess> result;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ElementsPerAccess; ++i) { result[i] = expf(input[i]); }
return result;
}
};
template <int ElementsPerAccess>
struct ArrayExponential<half_t, ElementsPerAccess> {
CUTLASS_DEVICE
Array<half_t, ElementsPerAccess> operator()(Array<half_t, ElementsPerAccess> const& input) const
{
Array<half_t, ElementsPerAccess> result;
int const kVectorCount = ElementsPerAccess / 2;
__half2 const* input_ptr = reinterpret_cast<__half2 const*>(input.raw_data());
__half2* res_ptr = reinterpret_cast<__half2*>(result.raw_data());
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kVectorCount; ++i) { res_ptr[i] = h2exp(input_ptr[i]); }
return result;
}
};
} // namespace detail
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Applies:
/// output <- (input - lse).exp()
template <typename ElementOutput_, // output
typename ElementLSE_, // accumulator from LSE
typename ElementAccumulator_, // accumulator from matmul
typename ElementCompute_, // intermediate compute (and exp calculation)
int ElementsPerAccess>
class ApplyLogSumExp {
public:
using ElementOutput = ElementOutput_;
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
using ElementLSE = ElementLSE_;
static int const kElementsPerAccess = ElementsPerAccess;
static int const kCount = kElementsPerAccess;
static const ScaleType::Kind kScale = cutlass::epilogue::thread::ScaleType::NoBetaScaling;
using FragmentOutput = Array<ElementOutput, kCount>;
using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
using FragmentLSE = Array<ElementLSE, kElementsPerAccess>;
using FragmentScaleBias = FragmentLSE; // Used by epilogue_smem_accumulator.h
public:
//
// Methods
//
CUTLASS_HOST_DEVICE
ApplyLogSumExp() {}
/// Returns true if source is needed
CUTLASS_HOST_DEVICE
bool is_source_needed() const { return true; }
/// Functionally required for serial reduction in the epilogue
CUTLASS_HOST_DEVICE
void set_k_partition(int k_partition, int k_partition_count) {}
CUTLASS_HOST_DEVICE
FragmentOutput operator()(FragmentAccumulator const& AB,
FragmentLSE const& scale_unused,
// bias used as LSE
FragmentLSE const& bias) const
{
FragmentCompute frag_AB =
NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess>()(AB);
FragmentCompute frag_lse_compute =
NumericArrayConverter<ElementCompute, ElementLSE, kElementsPerAccess>()(bias);
FragmentCompute frag_compute;
minus<FragmentCompute> minus_lse;
detail::ArrayExponential<ElementCompute, kElementsPerAccess> apply_exp;
frag_compute = minus_lse(frag_AB, frag_lse_compute);
frag_compute = apply_exp(frag_compute);
return NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess>()(
frag_compute);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace thread
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
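ApplyLogSumExp turns raw attention logits back into normalized probabilities by subtracting the per-row log-sum-exp and exponentiating (output <- exp(input - lse)); the half_t specialization above does the same thing two elements at a time with h2exp. A host-side scalar sketch of the transform (assumed names, plain floats):

#include <cmath>
#include <cstddef>
#include <vector>

// Host-side reference for ApplyLogSumExp: exp(x - lse) over one row.
// Because lse = log(sum_j exp(x_j)), the result sums to 1, i.e. it is the
// softmax of the row recovered from the logits and the stored log-sum-exp.
std::vector<float> apply_log_sum_exp(std::vector<float> const& logits, float lse)
{
    std::vector<float> out(logits.size());
    for (std::size_t i = 0; i < logits.size(); ++i) { out[i] = std::exp(logits[i] - lse); }
    return out;
}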


@@ -0,0 +1,119 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#pragma once
#include "custom_mma_multistage.h"
#include "custom_mma_pipelined.h"
#include "cutlass/gemm/threadblock/mma_multistage.h"
#include "cutlass/gemm/threadblock/mma_pipelined.h"
template <typename Mma, int kMaxK>
struct MakeCustomMma;
template <typename Shape,
typename IteratorA,
typename SmemIteratorA,
cutlass::arch::CacheOperation::Kind CacheOpA,
typename IteratorB,
typename SmemIteratorB,
cutlass::arch::CacheOperation::Kind CacheOpB,
typename ElementC,
typename LayoutC,
typename Policy,
int Stages,
cutlass::gemm::SharedMemoryClearOption SharedMemoryClear,
int kMaxK>
struct MakeCustomMma<cutlass::gemm::threadblock::MmaMultistage<Shape,
IteratorA,
SmemIteratorA,
CacheOpA,
IteratorB,
SmemIteratorB,
CacheOpB,
ElementC,
LayoutC,
Policy,
Stages,
SharedMemoryClear>,
kMaxK> {
// Reduce the number of stages if we don't need that many
static int constexpr kStages =
kMaxK == cutlass::platform::numeric_limits<int>::max()
? Stages
: cutlass::const_min(Stages, (kMaxK + int(Shape::kK) - 1) / int(Shape::kK));
using Mma = cutlass::gemm::threadblock::CustomMmaMultistage<Shape,
IteratorA,
SmemIteratorA,
CacheOpA,
IteratorB,
SmemIteratorB,
CacheOpB,
ElementC,
LayoutC,
Policy,
kStages,
SharedMemoryClear,
kMaxK>;
};
template <typename Shape,
typename IteratorA,
typename SmemIteratorA,
typename IteratorB,
typename SmemIteratorB,
typename ElementC,
typename LayoutC,
typename Policy,
int kMaxK>
struct MakeCustomMma<cutlass::gemm::threadblock::MmaPipelined<Shape,
IteratorA,
SmemIteratorA,
IteratorB,
SmemIteratorB,
ElementC,
LayoutC,
Policy>,
kMaxK> {
using Mma = cutlass::gemm::threadblock::CustomMmaPipelined<Shape,
IteratorA,
SmemIteratorA,
IteratorB,
SmemIteratorB,
ElementC,
LayoutC,
Policy>;
};
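The MmaMultistage specialization above clamps the stage count to the number of K-tiles actually needed: kStages = min(Stages, ceil(kMaxK / Shape::kK)) unless kMaxK is unbounded. A small host-side sketch of that computation with made-up tile sizes:

#include <algorithm>
#include <limits>

// Mirrors the kStages reduction in MakeCustomMma: never keep more pipeline
// stages in shared memory than there are K-tiles to load.
constexpr int reduced_stages(int stages, int k_max, int shape_k)
{
    return k_max == std::numeric_limits<int>::max()
               ? stages
               : std::min(stages, (k_max + shape_k - 1) / shape_k);
}

// Example values (illustrative only): a 5-stage pipeline with Shape::kK = 32.
static_assert(reduced_stages(5, 96, 32) == 3, "only 3 K-tiles are needed");
static_assert(reduced_stages(5, std::numeric_limits<int>::max(), 32) == 5, "K unbounded");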


@@ -0,0 +1,181 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/mma_base.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Number of stages,
int Stages,
/// Used for partial specialization
typename Enable = bool>
class CustomMmaBase {
public:
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
///< Policy describing tuning details
using Policy = Policy_;
//
// Dependent types
//
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Shape describing the overall GEMM computed from shared memory
/// by each warp.
using WarpGemm = typename Policy::Operator::Shape;
/// Shape describing the number of warps filling the CTA
using WarpCount =
GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN, Shape::kK / WarpGemm::kK>;
/// Number of warp-level GEMM operations
static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK);
/// Number of stages
static int const kStages = Stages;
//
// Nested structs
//
/// Shared storage object needed by threadblock-scoped GEMM
template <typename Element, typename OperandShape, typename OperandLayout>
struct OperandSharedStorage {
AlignedBuffer<Element, OperandShape::kCount> buffer;
using TensorRef = TensorRef<Element, OperandLayout>;
CUTLASS_DEVICE
static OperandLayout Layout()
{
return OperandLayout::packed({OperandShape::kRow, OperandShape::kColumn});
}
/// Returns a TensorRef to the operand
CUTLASS_HOST_DEVICE
TensorRef ref() { return TensorRef{buffer.data(), Layout()}; }
};
/// Shape of the A matrix operand in shared memory
using ShapeA = MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow,
Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;
/// Shape of the B matrix operand in shared memory
using ShapeB = MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
Shape::kN + Policy::SmemPaddingB::kColumn>;
using SharedStorageA =
OperandSharedStorage<typename Operator::ElementA, ShapeA, typename Operator::LayoutA>;
using SharedStorageB =
OperandSharedStorage<typename Operator::ElementB, ShapeB, typename Operator::LayoutB>;
using TensorRefA = typename SharedStorageA::TensorRef;
using TensorRefB = typename SharedStorageB::TensorRef;
struct SharedStorage {
/// Buffer for A operand
SharedStorageA operand_A;
/// Buffer for B operand
SharedStorageB operand_B;
};
protected:
//
// Data members
//
/// Iterator to load a warp-scoped tile of A operand from shared memory
typename Operator::IteratorA warp_tile_iterator_A_;
/// Iterator to load a warp-scoped tile of B operand from shared memory
typename Operator::IteratorB warp_tile_iterator_B_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
CustomMmaBase(
///< Shared storage needed for internal use by threadblock-scoped GEMM
SharedStorageA& shared_storageA,
SharedStorageB& shared_storageB,
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx)
: warp_tile_iterator_A_(shared_storageA.ref(), lane_idx),
warp_tile_iterator_B_(shared_storageB.ref(), lane_idx)
{
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
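CustomMmaBase sizes its shared-memory buffers from the threadblock tile and the stage count: operand A occupies Shape::kM x (Shape::kK * kStages) elements and operand B occupies (Shape::kK * kStages) x Shape::kN, plus policy-dependent padding. A quick host-side estimate for a hypothetical configuration (numbers are illustrative, padding ignored):

#include <cstdio>

int main()
{
    // Hypothetical 128x128x32 threadblock tile, 3 stages, half-precision operands.
    int const kM = 128, kN = 128, kK = 32, kStages = 3;
    int const elem_bytes = 2;                             // sizeof(half_t)
    int const smem_A = kM * (kK * kStages) * elem_bytes;  // ShapeA buffer
    int const smem_B = (kK * kStages) * kN * elem_bytes;  // ShapeB buffer
    std::printf("A: %d KiB, B: %d KiB, total: %d KiB\n",
                smem_A / 1024, smem_B / 1024, (smem_A + smem_B) / 1024);
    return 0;
}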


@@ -0,0 +1,706 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/cache_operation.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "custom_mma_base.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorA_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
/// Cache operation for operand A
cutlass::arch::CacheOperation::Kind CacheOpA,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorB_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB_,
/// Cache operation for operand B
cutlass::arch::CacheOperation::Kind CacheOpB,
/// Data type of accumulator matrix
typename ElementC_,
/// Data type of accumulator matrix
typename LayoutC_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Number of stages,
int Stages,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
/// Upper bound on the K dimension
int kMaxK = cutlass::platform::numeric_limits<int>::max(),
/// Used for partial specialization
typename Enable = bool>
class CustomMmaMultistage : public CustomMmaBase<Shape_, Policy_, Stages> {
public:
///< Base class
using Base = CustomMmaBase<Shape_, Policy_, Stages>;
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
///< Iterates over tiles of A operand in global memory
using IteratorA = IteratorA_;
///< Iterates over tiles of B operand in global memory
using IteratorB = IteratorB_;
///< Data type of accumulator matrix
using ElementC = ElementC_;
///< Layout of accumulator matrix
using LayoutC = LayoutC_;
///< Policy describing tuning details
using Policy = Policy_;
using SmemIteratorA = SmemIteratorA_;
using SmemIteratorB = SmemIteratorB_;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
//
// Dependent types
//
/// Fragment of accumulator tile
using FragmentC = typename Policy::Operator::FragmentC;
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Minimum architecture is Sm80 to support cp.async
using ArchTag = arch::Sm80;
/// Complex transform on A operand
static ComplexTransform const kTransformA = Operator::kTransformA;
/// Complex transform on B operand
static ComplexTransform const kTransformB = Operator::kTransformB;
/// Internal structure exposed for introspection.
struct Detail {
static_assert(Base::kWarpGemmIterations > 1,
"The pipelined structure requires at least two warp-level "
"GEMM operations.");
/// Number of cp.async instructions to load one stage of operand A
static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount;
/// Number of cp.async instructions to load one stage of operand B
static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount;
/// Number of stages
static int const kStages = Stages;
/// Number of cp.async instructions to load one group of operand A
static int const kAccessesPerGroupA =
(AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) /
Base::kWarpGemmIterations;
/// Number of cp.async instructions to load one group of operand B
static int const kAccessesPerGroupB =
(AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) /
Base::kWarpGemmIterations;
};
static bool const kSmemContainsEntireMat = kMaxK <= Shape::kK * Stages;
static constexpr int kNumStagesConcurrentLoad = kSmemContainsEntireMat ? Stages : Stages - 1;
private:
using WarpLoadedFragmentA = typename Operator::FragmentA;
using WarpLoadedFragmentB = typename Operator::FragmentB;
using WarpTransformedFragmentA = typename Operator::TransformedFragmentA;
using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;
private:
//
// Data members
//
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA smem_iterator_A_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB smem_iterator_B_;
bool prologue_done_;
// Set to `true` to ensure the accumulator will be zero outside the GEMM
// footprint
bool zero_outside_bounds_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
CustomMmaMultistage(
///< Shared storage needed for internal use by threadblock-scoped GEMM
typename Base::SharedStorageA& shared_storageA,
typename Base::SharedStorageB& shared_storageB,
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx)
: Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx),
smem_iterator_A_(shared_storageA.ref(), thread_idx),
smem_iterator_B_(shared_storageB.ref(), thread_idx),
prologue_done_(false),
zero_outside_bounds_(false)
{
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
// Add per-warp offsets in units of warp-level tiles
this->warp_tile_iterator_A_.add_tile_offset(
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_B_.add_tile_offset(
{Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
}
CUTLASS_DEVICE
CustomMmaMultistage(
///< Shared storage needed for internal use by threadblock-scoped GEMM
typename Base::SharedStorage& st,
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx)
: CustomMmaMultistage(st.operand_A, st.operand_B, thread_idx, warp_idx, lane_idx)
{
}
CUTLASS_DEVICE
void set_prologue_done(bool value) { prologue_done_ = value; }
CUTLASS_DEVICE
void set_zero_outside_bounds(bool value) { zero_outside_bounds_ = value; }
template <bool kLoadA = true, bool kLoadB = true>
CUTLASS_DEVICE static void prologue(typename Base::SharedStorage& shared_storage,
///< iterator over A operand in global memory
IteratorA iterator_A,
///< iterator over B operand in global memory
IteratorB iterator_B,
int thread_idx,
int problem_size_k)
{
prologue<kLoadA, kLoadB>(shared_storage.operand_A,
shared_storage.operand_B,
iterator_A,
iterator_B,
thread_idx,
problem_size_k);
}
template <bool kLoadA = true, bool kLoadB = true>
CUTLASS_DEVICE static void prologue(typename Base::SharedStorageA& shared_storageA,
typename Base::SharedStorageB& shared_storageB,
///< iterator over A operand in global memory
IteratorA iterator_A,
///< iterator over B operand in global memory
IteratorB iterator_B,
int thread_idx,
int problem_size_k)
{
SmemIteratorA smem_iterator_A(shared_storageA.ref(), thread_idx);
SmemIteratorB smem_iterator_B(shared_storageB.ref(), thread_idx);
int32_t iter = (problem_size_k + Base::Shape::kK - 1) / Base::Shape::kK;
_prologue<kLoadA, kLoadB>(iterator_A, iterator_B, iter, smem_iterator_A, smem_iterator_B);
}
CUTLASS_DEVICE
void copy_tiles_and_advance(IteratorA& iterator_A,
IteratorB& iterator_B,
int group_start_A = 0,
int group_start_B = 0)
{
iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector);
this->smem_iterator_A_.set_iteration_index(group_start_A);
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) {
if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) {
typename IteratorA::AccessType* dst_ptr =
reinterpret_cast<typename IteratorA::AccessType*>(this->smem_iterator_A_.get());
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value *
IteratorA::ThreadMap::kElementsPerAccess /
IteratorA::kAccessesPerVector / 8;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
auto gmem_ptr = iterator_A.get();
if (zero_outside_bounds_ ||
SharedMemoryClear == SharedMemoryClearOption::kZfill) {
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
dst_ptr + v, gmem_ptr, iterator_A.valid());
} else {
cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(
dst_ptr + v, gmem_ptr, iterator_A.valid());
}
++iterator_A;
}
++this->smem_iterator_A_;
}
}
iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector);
this->smem_iterator_B_.set_iteration_index(group_start_B);
// Async Copy for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) {
if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) {
typename IteratorB::AccessType* dst_ptr =
reinterpret_cast<typename IteratorB::AccessType*>(this->smem_iterator_B_.get());
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value *
IteratorB::ThreadMap::kElementsPerAccess /
IteratorB::kAccessesPerVector / 8;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
auto gmem_ptr = iterator_B.get();
if (zero_outside_bounds_ ||
SharedMemoryClear == SharedMemoryClearOption::kZfill) {
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
dst_ptr + v, gmem_ptr, iterator_B.valid());
} else {
cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(
dst_ptr + v, gmem_ptr, iterator_B.valid());
}
++iterator_B;
}
++this->smem_iterator_B_;
}
}
}
template <bool kLoadA = true, bool kLoadB = true>
CUTLASS_DEVICE static void _prologue(IteratorA& iterator_A,
IteratorB& iterator_B,
int32_t& gemm_k_iterations,
SmemIteratorA& smem_iterator_A_,
SmemIteratorB& smem_iterator_B_)
{
// Issue several complete stages
CUTLASS_PRAGMA_UNROLL
for (int stage = 0; stage < kNumStagesConcurrentLoad; ++stage, --gemm_k_iterations) {
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
iterator_A.set_iteration_index(0);
smem_iterator_A_.set_iteration_index(0);
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
typename IteratorA::AccessType* dst_ptr =
reinterpret_cast<typename IteratorA::AccessType*>(smem_iterator_A_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value *
IteratorA::ThreadMap::kElementsPerAccess /
IteratorA::kAccessesPerVector / 8;
int src_bytes = (iterator_A.valid() ? kSrcBytes : 0);
if (kLoadA) {
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
dst_ptr + v, iterator_A.get(), iterator_A.valid());
}
++iterator_A;
}
++smem_iterator_A_;
}
iterator_B.set_iteration_index(0);
smem_iterator_B_.set_iteration_index(0);
// Async Copy for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
typename IteratorB::AccessType* dst_ptr =
reinterpret_cast<typename IteratorB::AccessType*>(smem_iterator_B_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value *
IteratorB::ThreadMap::kElementsPerAccess /
IteratorB::kAccessesPerVector / 8;
if (kLoadB) {
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
dst_ptr + v, iterator_B.get(), iterator_B.valid());
}
++iterator_B;
}
++smem_iterator_B_;
}
// Move to the next stage
iterator_A.add_tile_offset({0, 1});
iterator_B.add_tile_offset({1, 0});
smem_iterator_A_.add_tile_offset({0, 1});
smem_iterator_B_.add_tile_offset({1, 0});
// Defines the boundary of a stage of cp.async.
cutlass::arch::cp_async_fence();
}
}
/// Perform a threadblock-scoped matrix multiply-accumulate
CUTLASS_DEVICE
void operator()(
///< problem size of GEMM
int gemm_k_iterations,
///< destination accumulator tile
FragmentC& accum,
///< iterator over A operand in global memory
IteratorA iterator_A,
///< iterator over B operand in global memory
IteratorB iterator_B,
///< initial value of accumulator
FragmentC const& src_accum)
{
//
// Prologue
//
if (!prologue_done_) {
_prologue<true, true>(
iterator_A, iterator_B, gemm_k_iterations, smem_iterator_A_, smem_iterator_B_);
} else if (!kSmemContainsEntireMat) {
_prologue<false, false>(
iterator_A, iterator_B, gemm_k_iterations, smem_iterator_A_, smem_iterator_B_);
} else {
gemm_k_iterations -= kNumStagesConcurrentLoad;
}
// Perform accumulation in the 'd' output operand
accum = src_accum;
//
// Clear the remaining tiles of SMEM. This is a functional requirement for
// some kernels so that all accumulator elements outside the GEMM footprint
// are zero.
//
if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) {
/// Iterator to write threadblock-scoped tile of A operand to shared
/// memory
SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_);
typename IteratorA::AccessType zero_A;
zero_A.clear();
last_smem_iterator_A.set_iteration_index(0);
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
typename IteratorA::AccessType* dst_ptr =
reinterpret_cast<typename IteratorA::AccessType*>(last_smem_iterator_A.get());
*dst_ptr = zero_A;
++last_smem_iterator_A;
}
/// Iterator to write threadblock-scoped tile of B operand to shared
/// memory
SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_);
typename IteratorB::AccessType zero_B;
zero_B.clear();
last_smem_iterator_B.set_iteration_index(0);
// Async Copy for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
typename IteratorB::AccessType* dst_ptr =
reinterpret_cast<typename IteratorB::AccessType*>(last_smem_iterator_B.get());
*dst_ptr = zero_B;
++last_smem_iterator_B;
}
}
// Waits until kStages-2 stages have committed.
cutlass::arch::cp_async_wait<kNumStagesConcurrentLoad - 1>();
__syncthreads();
// Pair of fragments used to overlap shared memory loads and math
// instructions
WarpLoadedFragmentA warp_loaded_frag_A[2];
WarpLoadedFragmentB warp_loaded_frag_B[2];
WarpTransformedFragmentA warp_transformed_frag_A[2];
WarpTransformedFragmentB warp_transformed_frag_B[2];
Operator warp_mma;
this->warp_tile_iterator_A_.set_kgroup_index(0);
this->warp_tile_iterator_B_.set_kgroup_index(0);
this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]);
this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
int smem_write_stage_idx = Base::kStages - 1;
int smem_read_stage_idx = 0;
warp_mma.transform(warp_transformed_frag_A[0],
warp_transformed_frag_B[0],
warp_loaded_frag_A[0],
warp_loaded_frag_B[0]);
// tf32x3 kernels use staging accumulation. warp_mma uses a temporary
// accumulator and this temporary accumulator is added to the final
// accumulator once in every mainloop iteration.
plus<FragmentC> plus_accum;
FragmentC tmp_accum;
if (platform::is_same<typename Operator::MathOperator, arch::OpMultiplyAddFastF32>::value ||
platform::is_same<typename Operator::MathOperator,
arch::OpMultiplyAddComplexFastF32>::value) {
tmp_accum.clear();
}
//
// Mainloop
//
CUTLASS_GEMM_LOOP
for (; gemm_k_iterations > (-kNumStagesConcurrentLoad);) {
//
// Loop over GEMM K dimension
//
// Computes a warp-level GEMM on data held in shared memory
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
// Load warp-level tiles from shared memory, wrapping to k offset if
// this is the last group as the case may be.
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) %
Base::kWarpGemmIterations);
this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) %
Base::kWarpGemmIterations);
// In case of a non-circular buffer ("kSmemContainsEntireMat")
// make sure we don't load out of bounds data.
if (!kSmemContainsEntireMat || gemm_k_iterations > (-kNumStagesConcurrentLoad) ||
warp_mma_k < Base::kWarpGemmIterations - 1) {
this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]);
this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]);
}
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
if (warp_mma_k > 0)
warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2],
warp_transformed_frag_B[warp_mma_k % 2],
warp_loaded_frag_A[warp_mma_k % 2],
warp_loaded_frag_B[warp_mma_k % 2]);
if (platform::is_same<typename Operator::MathOperator,
arch::OpMultiplyAddFastF32>::value ||
platform::is_same<typename Operator::MathOperator,
arch::OpMultiplyAddComplexFastF32>::value) {
warp_mma(tmp_accum,
warp_transformed_frag_A[warp_mma_k % 2],
warp_transformed_frag_B[warp_mma_k % 2],
tmp_accum);
if (warp_mma_k == 0) {
accum = plus_accum(accum, tmp_accum);
tmp_accum.clear();
}
} else {
warp_mma(accum,
warp_transformed_frag_A[warp_mma_k % 2],
warp_transformed_frag_B[warp_mma_k % 2],
accum);
}
// Issue global->shared copies for this stage
if (!kSmemContainsEntireMat && warp_mma_k < Base::kWarpGemmIterations - 1) {
int group_start_iteration_A, group_start_iteration_B;
group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;
copy_tiles_and_advance(
iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
}
if (warp_mma_k + 2 == Base::kWarpGemmIterations) {
if (!kSmemContainsEntireMat) {
int group_start_iteration_A, group_start_iteration_B;
group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB;
copy_tiles_and_advance(iterator_A,
iterator_B,
group_start_iteration_A,
group_start_iteration_B);
}
// Inserts a memory fence between stages of cp.async instructions.
cutlass::arch::cp_async_fence();
// Waits until kStages-2 stages have committed.
cutlass::arch::cp_async_wait<kNumStagesConcurrentLoad - 1>();
__syncthreads();
// Move to the next stage
iterator_A.add_tile_offset({0, 1});
iterator_B.add_tile_offset({1, 0});
this->smem_iterator_A_.add_tile_offset({0, 1});
this->smem_iterator_B_.add_tile_offset({1, 0});
// Add negative offsets to return iterators to the 'start' of the
// circular buffer in shared memory
if (smem_write_stage_idx == (Base::kStages - 1)) {
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
smem_write_stage_idx = 0;
} else {
++smem_write_stage_idx;
}
if (!kSmemContainsEntireMat && smem_read_stage_idx == (Base::kStages - 1)) {
this->warp_tile_iterator_A_.add_tile_offset(
{0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
this->warp_tile_iterator_B_.add_tile_offset(
{-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0});
smem_read_stage_idx = 0;
} else {
++smem_read_stage_idx;
}
--gemm_k_iterations;
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
}
// Do any conversions feeding the first stage at the end of the loop so
// we can start right away on mma instructions
if (warp_mma_k + 1 == Base::kWarpGemmIterations)
warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2],
warp_transformed_frag_B[(warp_mma_k + 1) % 2],
warp_loaded_frag_A[(warp_mma_k + 1) % 2],
warp_loaded_frag_B[(warp_mma_k + 1) % 2]);
}
}
if (platform::is_same<typename Operator::MathOperator, arch::OpMultiplyAddFastF32>::value ||
platform::is_same<typename Operator::MathOperator,
arch::OpMultiplyAddComplexFastF32>::value) {
accum = plus_accum(accum, tmp_accum);
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
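Two constants drive the pipelining in CustomMmaMultistage: kSmemContainsEntireMat is true when the whole K extent fits in shared memory (kMaxK <= Shape::kK * Stages), and kNumStagesConcurrentLoad is then the full Stages (everything is prefetched in the prologue) instead of the usual Stages - 1 kept in flight by the circular buffer. A host-side sketch with illustrative numbers only:

#include <cstdio>

int main()
{
    int const shape_k = 32, stages = 4;  // illustrative tile K and stage count
    int const k_max_values[] = {64, 128, 256};
    for (int k_max : k_max_values) {
        bool const entire_mat_in_smem = k_max <= shape_k * stages;
        int const concurrent_load_stages = entire_mat_in_smem ? stages : stages - 1;
        std::printf("kMaxK=%d  entire matrix in SMEM: %d  stages loaded concurrently: %d\n",
                    k_max, int(entire_mat_in_smem), concurrent_load_stages);
    }
    return 0;
}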


@@ -0,0 +1,388 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "custom_mma_base.h"
#include "cutlass/gemm/gemm.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorA_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorB_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB_,
/// Data type of accumulator matrix
typename ElementC_,
/// Data type of accumulator matrix
typename LayoutC_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Transformation applied to A operand
typename TransformA_ = NumericArrayConverter<typename SmemIteratorA_::Element,
typename IteratorA_::Element,
IteratorA_::Fragment::kElements>,
///
/// Transformation applied to B operand
typename TransformB_ = NumericArrayConverter<typename SmemIteratorB_::Element,
typename IteratorB_::Element,
IteratorB_::Fragment::kElements>,
/// Used for partial specialization
typename Enable = bool>
class CustomMmaPipelined : public CustomMmaBase<Shape_, Policy_, 2> {
public:
///< Base class
using Base = CustomMmaBase<Shape_, Policy_, 2>;
using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory
using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory
using ElementC = ElementC_; ///< Data type of accumulator matrix
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
using Policy = Policy_; ///< Policy describing tuning details
using SmemIteratorA = SmemIteratorA_;
using SmemIteratorB = SmemIteratorB_;
using TransformA = TransformA_;
using TransformB = TransformB_;
//
// Dependent types
//
/// Fragment of operand A loaded from global memory
using FragmentA = typename IteratorA::Fragment;
/// Fragment of operand B loaded from global memory
using FragmentB = typename IteratorB::Fragment;
/// Fragment of accumulator tile
using FragmentC = typename Policy::Operator::FragmentC;
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Obtain the arch tag from the warp-level operator
using ArchTag = typename Policy::Operator::ArchTag;
/// Complex transform on A operand
static ComplexTransform const kTransformA = Operator::kTransformA;
/// Complex transform on B operand
static ComplexTransform const kTransformB = Operator::kTransformB;
// statically assert kStages for MmaPipelined is two (double-buffered pipeline)
static_assert((Base::kStages == 2), "MmaPipelined requires kStages set to value 2");
static bool const kSmemContainsEntireMat = false;
private:
using WarpFragmentA = typename Operator::FragmentA;
using WarpFragmentB = typename Operator::FragmentB;
protected:
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA smem_iterator_A_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB smem_iterator_B_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
CustomMmaPipelined(typename Base::SharedStorageA& shared_storageA,
typename Base::SharedStorageB& shared_storageB,
int thread_idx, ///< ID within the threadblock
int warp_idx, ///< ID of warp
int lane_idx ///< ID of each thread within a warp
)
: Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx),
smem_iterator_A_(shared_storageA.ref(), thread_idx),
smem_iterator_B_(shared_storageB.ref(), thread_idx)
{
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
// Add per-warp offsets in units of warp-level tiles
this->warp_tile_iterator_A_.add_tile_offset(
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_B_.add_tile_offset(
{Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
}
CUTLASS_DEVICE
CustomMmaPipelined(
///< Shared storage needed for internal use by threadblock-scoped GEMM
typename Base::SharedStorage& st,
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx)
: CustomMmaPipelined(st.operand_A, st.operand_B, thread_idx, warp_idx, lane_idx)
{
}
CUTLASS_DEVICE
void set_prologue_done(bool value)
{
// NOT IMPLEMENTED FOR PIPELINED
}
CUTLASS_DEVICE
void set_zero_outside_bounds(bool value)
{
// NOT NEEDED FOR PIPELINED
// shared memory will always be zero-filled
}
template <bool kLoadA = true, bool kLoadB = true>
CUTLASS_DEVICE static void prologue(typename Base::SharedStorage& shared_storage,
///< iterator over A operand in global memory
IteratorA iterator_A,
///< iterator over B operand in global memory
IteratorB iterator_B,
int thread_idx,
int problem_size_k)
{
prologue<kLoadA, kLoadB>(shared_storage.operand_A,
shared_storage.operand_B,
iterator_A,
iterator_B,
thread_idx,
problem_size_k);
}
template <bool kLoadA = true, bool kLoadB = true>
CUTLASS_DEVICE static void prologue(typename Base::SharedStorageA& shared_storageA,
typename Base::SharedStorageB& shared_storageB,
///< iterator over A operand in global memory
IteratorA iterator_A,
///< iterator over B operand in global memory
IteratorB iterator_B,
int thread_idx,
int problem_size_k)
{
// NOT IMPLEMENTED FOR PIPELINED
}
/// Perform a threadblock-scoped matrix multiply-accumulate
CUTLASS_DEVICE
void operator()(
int gemm_k_iterations, ///< number of iterations of the mainloop
FragmentC& accum, ///< destination accumulator tile
IteratorA iterator_A, ///< iterator over A operand in global memory
IteratorB iterator_B, ///< iterator over B operand in global memory
FragmentC const& src_accum, ///< source accumulator tile
TransformA transform_A = TransformA(), ///< transformation applied to A fragment
TransformB transform_B = TransformB())
{ ///< transformation applied to B fragment
//
// Prologue
//
// Perform accumulation in the 'd' output operand
accum = src_accum;
FragmentA tb_frag_A;
FragmentB tb_frag_B;
tb_frag_A.clear();
tb_frag_B.clear();
// The last kblock is loaded in the prolog
iterator_A.load(tb_frag_A);
iterator_B.load(tb_frag_B);
++iterator_A;
++iterator_B;
this->smem_iterator_A_.store(transform_A(tb_frag_A));
this->smem_iterator_B_.store(transform_B(tb_frag_B));
++this->smem_iterator_A_;
++this->smem_iterator_B_;
__syncthreads();
// Pair of fragments used to overlap shared memory loads and math
// instructions
WarpFragmentA warp_frag_A[2];
WarpFragmentB warp_frag_B[2];
this->warp_tile_iterator_A_.set_kgroup_index(0);
this->warp_tile_iterator_B_.set_kgroup_index(0);
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
Operator warp_mma;
int smem_write_stage_idx = 1;
// Avoid reading out of bounds
iterator_A.clear_mask(gemm_k_iterations <= 1);
iterator_B.clear_mask(gemm_k_iterations <= 1);
// Issue loads during the first warp-level matrix multiply-add *AFTER*
// issuing shared memory loads (which have the tightest latency requirement).
//
// Mainloop
//
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
CUTLASS_GEMM_LOOP
for (; gemm_k_iterations > 0; --gemm_k_iterations) {
//
// Loop over GEMM K dimension
//
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
// Load warp-level tiles from shared memory, wrapping to k offset if
// this is the last group as the case may be.
if (warp_mma_k == Base::kWarpGemmIterations - 1) {
// Write fragments to shared memory
this->smem_iterator_A_.store(transform_A(tb_frag_A));
this->smem_iterator_B_.store(transform_B(tb_frag_B));
__syncthreads();
++this->smem_iterator_A_;
++this->smem_iterator_B_;
// Add negative offsets to return iterators to the 'start' of the
// circular buffer in shared memory
if (smem_write_stage_idx == 1) {
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
} else {
this->warp_tile_iterator_A_.add_tile_offset(
{0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
this->warp_tile_iterator_B_.add_tile_offset(
{-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0});
}
smem_write_stage_idx ^= 1;
}
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) %
Base::kWarpGemmIterations);
this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) %
Base::kWarpGemmIterations);
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
if (warp_mma_k == 0) {
iterator_A.load(tb_frag_A);
iterator_B.load(tb_frag_B);
++iterator_A;
++iterator_B;
// Avoid reading out of bounds if this was the last loop iteration
iterator_A.clear_mask(gemm_k_iterations <= 2);
iterator_B.clear_mask(gemm_k_iterations <= 2);
}
warp_mma(accum, warp_frag_A[warp_mma_k % 2], warp_frag_B[warp_mma_k % 2], accum);
}
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
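CustomMmaPipelined is the classic double-buffered mainloop: two shared-memory stages and two register fragments are toggled so that the loads for iteration k+1 overlap the math of iteration k, with smem_write_stage_idx flipped (^= 1) at the end of each K-tile. A host-side sketch of the index ping-pong only, with illustrative values and no actual data movement:

#include <cstdio>

int main()
{
    int const warp_gemm_iterations = 4;  // illustrative Base::kWarpGemmIterations
    int smem_write_stage_idx = 1;
    for (int gemm_k_iterations = 3; gemm_k_iterations > 0; --gemm_k_iterations) {
        for (int warp_mma_k = 0; warp_mma_k < warp_gemm_iterations; ++warp_mma_k) {
            int const compute_frag = warp_mma_k % 2;     // fragment fed to warp_mma
            int const load_frag = (warp_mma_k + 1) % 2;  // fragment being refilled
            std::printf("k_iter=%d mma_k=%d compute=%d load=%d write_stage=%d\n",
                        gemm_k_iterations, warp_mma_k, compute_frag, load_frag,
                        smem_write_stage_idx);
            // The write stage flips once per K-tile, at the last warp-level MMA.
            if (warp_mma_k == warp_gemm_iterations - 1) { smem_write_stage_idx ^= 1; }
        }
    }
    return 0;
}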


@@ -0,0 +1,191 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
/*! \file
\brief Cutlass provides helper template functions to figure out the right
data structures to instantiate to run a GEMM with various parameters (see
`cutlass/gemm/threadblock/default_mma.h`). However, due to template
instantiation priority rules, it will only create an MmaMultistage with
kStages=3 (otherwise it creates an MmaPipelined, which is not compatible with
FastF32). kStages=3 uses too much shared memory and we want kStages=2,
so we copied some code from the `default_mma.h` and
`default_mma_core.h` files and wrapped this template to support our use case.
This is only needed for the FastF32 case, i.e. using TensorCores with fp32.
#pragma once
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
namespace cutlass {
namespace gemm {
namespace threadblock {
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Layout type for C and D matrix operand
typename LayoutC,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Operation performed by GEMM
typename Operator,
typename Enable_ = void>
struct FindDefaultMma {
static constexpr bool AccumulatorsInRowMajor = false;
static constexpr SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone;
using DefaultMma = cutlass::gemm::threadblock::DefaultMma<ElementA,
LayoutA,
kAlignmentA,
ElementB,
LayoutB,
kAlignmentB,
ElementAccumulator,
LayoutC,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
Stages,
Operator,
AccumulatorsInRowMajor,
SharedMemoryClear>;
};
/// Specialization for sm80 / FastF32 / multistage with kStages=2
template <typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
typename ElementAccumulator,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
int kStages,
typename Operator>
struct FindDefaultMma<ElementA_,
LayoutA_,
kAlignmentA,
ElementB_,
LayoutB_,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
arch::Sm80,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
typename cutlass::platform::enable_if<(kAlignmentA > 1)>::type> {
using LayoutC = layout::RowMajor;
using OperatorClass = arch::OpClassTensorOp;
using ArchTag = arch::Sm80;
using DefaultMma_ = cutlass::gemm::threadblock::DefaultMma<ElementA_,
LayoutA_,
kAlignmentA,
ElementB_,
LayoutB_,
kAlignmentB,
ElementAccumulator,
LayoutC,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
3,
Operator>;
struct DefaultMma : DefaultMma_ {
using MmaCore_ = typename DefaultMma_::MmaCore;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma =
cutlass::gemm::threadblock::MmaMultistage<typename MmaCore_::Shape,
typename DefaultMma_::IteratorA,
typename MmaCore_::SmemIteratorA,
MmaCore_::kCacheOpA,
typename DefaultMma_::IteratorB,
typename MmaCore_::SmemIteratorB,
MmaCore_::kCacheOpB,
ElementAccumulator,
LayoutC,
typename MmaCore_::MmaPolicy,
kStages>;
};
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
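Under the constraints described in the file comment, the wrapper is meant to be used roughly as below. The shapes, element types, and alignments here are illustrative placeholders rather than the ones the attention kernel actually picks, and the snippet assumes CUTLASS plus the header above are already on the include path.

// Illustrative instantiation only; real parameters come from the kernel traits.
using ThreadblockMma = typename cutlass::gemm::threadblock::FindDefaultMma<
    float, cutlass::layout::RowMajor, /*kAlignmentA=*/4,
    float, cutlass::layout::RowMajor, /*kAlignmentB=*/4,
    float,                      // accumulator element
    cutlass::layout::RowMajor,  // C/D layout
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 32>,  // threadblock tile
    cutlass::gemm::GemmShape<64, 64, 32>,    // warp tile
    cutlass::gemm::GemmShape<16, 8, 8>,      // TF32 tensor-op instruction
    /*Stages=*/2,
    cutlass::arch::OpMultiplyAddFastF32>::DefaultMma::ThreadblockMma;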


@@ -0,0 +1,347 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#pragma once
#include "cutlass/functional.h"
#include "cutlass/gemm/warp/mma_simt_tile_iterator.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h"
#include "cutlass/matrix_shape.h"
/*
TensorCores have different accumulator layouts.
This file provides helper classes that map the i-th accumulator
element to its corresponding matrix row/column.
*/
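// Illustrative only (not used by the kernels below): the begin/op/end lambda
// pattern that the AccumLambdaIterator* classes implement, shown on a toy
// 2x4 row-major fragment. The real classes compute (accum_m, accum_n, idx)
// from the tensor-core accumulator layout instead of this trivial mapping.
template <typename FA, typename FB, typename FC>
CUTLASS_HOST_DEVICE void toy_iterate_rows(FA beginRow, FB op, FC endRow)
{
    for (int m = 0; m < 2; ++m) {
        beginRow(m);  // e.g. reset a per-row running max / sum
        for (int n = 0; n < 4; ++n) {
            int idx = m * 4 + n;  // linear position inside the fragment
            op(m, n, idx);        // visit accumulator element idx at (row m, col n)
        }
        endRow(m);  // e.g. write the per-row result out
    }
}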
template <typename T, typename accum_t, int kWarpSize>
struct AccumLambdaIteratorSm80 {
static_assert(cutlass::platform::is_same<typename T::Layout, cutlass::layout::RowMajor>::value,
"only RowMajor is supported");
using Policy = typename T::Policy;
using InstructionShape = typename T::InstructionShape;
using OpDelta = typename T::OpDelta;
using Shape = typename T::Shape;
static int const kElementsPerAccess = InstructionShape::kN / 4;
static int const kRowsPerTile = 8;
static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile;
static cutlass::MatrixCoord CUTLASS_DEVICE
get_lane_offset(int8_t lane_id, int8_t warp_id, typename T::TensorCoord const& tile_offset)
{
int quad = (lane_id >> 2);
int lane_in_quad = (lane_id & 3);
return cutlass::MatrixCoord(
quad + tile_offset.row() * Shape::kRow,
lane_in_quad * kElementsPerAccess + tile_offset.column() * Shape::kColumn);
}
template <typename FA, typename FB, typename FC>
CUTLASS_DEVICE static void iterateRows(cutlass::MatrixCoord& lane_offset,
FA beginRow,
FB op,
FC endRow)
{
// See cutlass/gemm/warp/mma_tensor_op_tile_iterator.h
CUTLASS_PRAGMA_UNROLL
for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) {
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < kAccumulatorRows; ++row) {
int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + row * kRowsPerTile +
lane_offset.row();
beginRow(accum_m);
CUTLASS_PRAGMA_UNROLL
for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) {
int mma_accum_start = kAccumulatorRows * kElementsPerAccess *
(mma_n * Policy::MmaIterations::kRow + mma_m);
CUTLASS_PRAGMA_UNROLL
for (int col = 0; col < kElementsPerAccess; ++col) {
int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col +
lane_offset.column();
int idx = mma_accum_start + row * kElementsPerAccess + col;
op(accum_m, accum_n, idx);
}
}
endRow(accum_m);
}
}
}
template <typename DT, typename F>
CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn)
{
// In each warp, 4 threads will work on the same row
// - the ones with the same `quad`
auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1);
myValue = fn(myValue, otherV);
otherV = __shfl_xor_sync(0xffffffff, myValue, 2);
myValue = fn(myValue, otherV);
int lane_in_quad = (lane_id & 3);
return lane_in_quad == 0;
}
};
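// Illustrative only: the quad-level butterfly reduction performed by
// AccumLambdaIteratorSm80::reduceSameRow above, written as a standalone
// device helper. `fn` is any commutative, associative combine (max, sum, ...).
template <typename T, typename F>
CUTLASS_DEVICE T example_quad_reduce(T value, F fn)
{
    value = fn(value, __shfl_xor_sync(0xffffffff, value, 1));  // combine lane pairs {0,1}, {2,3}, ...
    value = fn(value, __shfl_xor_sync(0xffffffff, value, 2));  // merge pairs into full quads
    return value;  // every lane of the quad now holds the reduced value
}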
template <typename T, typename accum_t, int kWarpSize>
struct AccumLambdaIteratorSm70 {
static_assert(cutlass::platform::is_same<typename T::Layout, cutlass::layout::RowMajor>::value,
"only RowMajor is supported");
using Policy = typename T::Policy;
using InstructionShape = typename T::InstructionShape;
using OpDelta = typename T::OpDelta;
using Shape = typename T::Shape;
using Element = accum_t;
static int const kElementsPerPartial = 4;
using EleShapePerPatial =
typename cutlass::platform::conditional<cutlass::platform::is_same<Element, float>::value,
cutlass::MatrixShape<2, 2>,
cutlass::MatrixShape<1, 4>>::type;
static int const kElementsPerMma = 8;
static int const kAccumulatorPatials = 2;
using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>;
static cutlass::MatrixCoord CUTLASS_DEVICE
get_lane_offset(int8_t lane_id, int8_t warp_id, typename T::TensorCoord const& tile_offset)
{
int quad = (lane_id >> 2);
int lane_in_quad = (lane_id & 3);
int accum_m, accum_n;
if (cutlass::platform::is_same<Element, float>::value) {
// (quad[2],quad[0])+lane_in_quad[0]
accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1);
// (quad[1])+lane_in_quad[1]
accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials +
(lane_in_quad & 2);
} else {
accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + lane_in_quad; // (quad[2],quad[0])
accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials;
}
return cutlass::MatrixCoord(accum_m + tile_offset.row() * Shape::kRow,
accum_n + tile_offset.column() * Shape::kColumn);
}
template <typename DT, typename F>
CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn)
{
static_assert(cutlass::platform::is_same<Element, float>::value,
"update to support non-float accum");
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
// T0 & T2 share same line within a quad
auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 1);
myValue = fn(myValue, otherV);
// quad 0 and quad 2 are on the same lines
otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 3);
myValue = fn(myValue, otherV);
return (lane_id & ((1 << 1) | (1 << 3))) == 0;
}
template <typename FA, typename FB, typename FC>
CUTLASS_DEVICE static void iterateRows(cutlass::MatrixCoord& lane_offset,
FA beginRow,
FB op,
FC endRow)
{
CUTLASS_PRAGMA_UNROLL
for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) {
CUTLASS_PRAGMA_UNROLL
for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) {
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < EleShapePerPatial::kRow; ++m) {
int accum_m = tile_m * Policy::InterleavedTile::kRow +
mma_m * QuadShapePerPatialMma::kRow + m * 2 + lane_offset.row();
beginRow(accum_m);
CUTLASS_PRAGMA_UNROLL
for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; ++tile_n) {
CUTLASS_PRAGMA_UNROLL
for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) {
CUTLASS_PRAGMA_UNROLL
for (int p = 0; p < kAccumulatorPatials; ++p) {
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < EleShapePerPatial::kColumn; ++n) {
int mma_accum_start =
(((tile_n * Policy::TileIterations::kRow + tile_m) *
Policy::MmaIterations::kColumn +
mma_n) *
Policy::MmaIterations::kRow +
mma_m) *
kElementsPerMma;
int accum_n = tile_n * Policy::InterleavedTile::kColumn +
mma_n * QuadShapePerPatialMma::kColumn +
p * Policy::InterleavedTile::kColumn / 2 + n +
lane_offset.column();
int idx = mma_accum_start + p * kElementsPerPartial +
m * EleShapePerPatial::kColumn + n;
op(accum_m, accum_n, idx);
}
}
}
}
endRow(accum_m);
}
}
}
}
};
template <typename T, typename accum_t, int kWarpSize>
struct AccumLambdaIteratorSimt {
using Policy = typename T::Policy;
using Iterations = typename T::Iterations;
using Element = typename T::Element;
using Delta = typename T::Delta;
using Shape = typename T::Shape;
static_assert(cutlass::platform::is_same<typename T::Layout, cutlass::layout::RowMajor>::value,
"only RowMajor is supported");
template <typename DT, typename F>
CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn)
{
CUTLASS_PRAGMA_UNROLL
for (int bit = 1; bit < Policy::WarpShape::kColumn; bit *= 2) {
auto otherV = __shfl_xor_sync(0xffffffff, myValue, bit);
myValue = fn(myValue, otherV);
}
return (lane_id & (Policy::WarpShape::kColumn - 1)) == 0;
}
template <typename FA, typename FB, typename FC>
CUTLASS_DEVICE static void iterateRows(cutlass::MatrixCoord& lane_offset,
FA beginRow,
FB op,
FC endRow)
{
CUTLASS_PRAGMA_UNROLL
for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) {
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) {
int accum_m = mma_m * Delta::kRow + m + lane_offset.row();
beginRow(accum_m);
CUTLASS_PRAGMA_UNROLL
for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) {
int accum_n = mma_n * Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN +
lane_offset.column();
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) {
int idx = n + Policy::LaneMmaShape::kN *
(mma_n + Iterations::kColumn *
(m + mma_m * Policy::LaneMmaShape::kM));
op(accum_m, accum_n + n, idx);
}
}
endRow(accum_m);
}
}
}
static cutlass::MatrixCoord CUTLASS_DEVICE
get_lane_offset(int8_t lane_id, int8_t warp_id, typename T::TensorCoord const& tile_offset)
{
static_assert(cutlass::platform::is_same<typename Policy::LaneLayout,
cutlass::layout::RowMajorInterleaved<1>>::value,
"");
typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();
cutlass::MatrixCoord lane_offset =
lane_layout.inverse(lane_id) *
cutlass::MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN);
return lane_offset + tile_offset * cutlass::MatrixCoord(Shape::kRow, Shape::kColumn);
}
};
template <typename T, typename accum_t, int kWarpSize>
struct DefaultMmaAccumLambdaIterator;
// Simt
template <typename S, typename P, typename accum_t, int kWarpSize>
struct DefaultMmaAccumLambdaIterator<
cutlass::gemm::warp::MmaSimtTileIterator<S,
cutlass::gemm::Operand::kC,
accum_t,
cutlass::layout::RowMajor,
P,
1,
1>,
accum_t,
kWarpSize> {
using WarpIterator =
typename cutlass::gemm::warp::MmaSimtTileIterator<S,
cutlass::gemm::Operand::kC,
accum_t,
cutlass::layout::RowMajor,
P,
1,
1>;
using Iterator = AccumLambdaIteratorSimt<WarpIterator, accum_t, kWarpSize>;
};
// TensorOp - Volta
template <typename S1, typename S2, typename accum_t, int kWarpSize>
struct DefaultMmaAccumLambdaIterator<
cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator<S1,
accum_t,
cutlass::layout::RowMajor,
S2,
cutlass::MatrixShape<1, 1>>,
accum_t,
kWarpSize> {
using WarpIterator = typename cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator<
S1,
accum_t,
cutlass::layout::RowMajor,
S2,
cutlass::MatrixShape<1, 1>>;
using Iterator = AccumLambdaIteratorSm70<WarpIterator, accum_t, kWarpSize>;
};
// TensorOp - Sm75+
template <typename S1, typename S2, typename S3, typename accum_t, int kWarpSize>
struct DefaultMmaAccumLambdaIterator<
cutlass::gemm::warp::
MmaTensorOpAccumulatorTileIterator<S1, accum_t, cutlass::layout::RowMajor, S2, S3>,
accum_t,
kWarpSize> {
using WarpIterator = typename cutlass::gemm::warp::
MmaTensorOpAccumulatorTileIterator<S1, accum_t, cutlass::layout::RowMajor, S2, S3>;
using Iterator = AccumLambdaIteratorSm80<WarpIterator, accum_t, kWarpSize>;
};

File diff suppressed because it is too large.


@@ -0,0 +1,242 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#pragma once
#include <type_traits>
#include "cutlass/arch/mma.h"
template <typename arch, typename scalar_t>
struct CheckArch {
static constexpr bool isPreVolta = arch::kMinComputeCapability < 70;
static constexpr bool isPreAmpere =
arch::kMinComputeCapability < 80 && arch::kMinComputeCapability >= 70;
static constexpr bool isAmpere = arch::kMinComputeCapability >= 80;
#if defined(__CUDA_ARCH__)
static constexpr bool compiler_cc = arch::kMinComputeCapability * 10 <= __CUDA_ARCH__;
#else
static constexpr bool compiler_cc = true;
#endif
static constexpr bool value = (isPreVolta && std::is_same_v<scalar_t, float>) ||
(isPreAmpere && !std::is_same_v<scalar_t, cutlass::bfloat16_t>) ||
isAmpere && compiler_cc;
};
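// Illustrative only: compile-time checks one could keep next to a launcher to
// document which (arch, dtype) pairs CheckArch admits. These asserts are not
// part of the original kernels.
static_assert(!CheckArch<cutlass::arch::Sm70, cutlass::bfloat16_t>::value,
              "bf16 Evoformer kernels require Ampere (Sm80) or newer");
static_assert(CheckArch<cutlass::arch::Sm75, cutlass::half_t>::value,
              "fp16 is expected to be usable from Turing (Sm75) onwards");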
#define DISPATCH_ARCHTAG(CC, func) \
{ \
if (CC >= 80) { \
using ArchTag = cutlass::arch::Sm80; \
func; \
} else if (CC >= 75) { \
using ArchTag = cutlass::arch::Sm75; \
func; \
} else if (CC >= 70) { \
using ArchTag = cutlass::arch::Sm70; \
func; \
} else { \
EVOFORMER_CHECK(false, "Only GPUs with Tensor Core are supported for now"); \
} \
}
#define DISPATCH_TYPES(tensor, func) \
{ \
if (tensor.scalar_type() == at::ScalarType::Half) { \
using scalar_t = cutlass::half_t; \
using torch_scalar_t = at::Half; \
func(); \
} else if (tensor.scalar_type() == at::ScalarType::BFloat16) { \
using scalar_t = cutlass::bfloat16_t; \
using torch_scalar_t = at::BFloat16; \
func(); \
} else { \
EVOFORMER_CHECK(false, "Only fp16 and bf16 supported at the moment"); \
} \
}
#define DISPATCH_BOOL(BOOL_V, BOOL_NAME, F) \
{ \
if (BOOL_V) { \
constexpr bool BOOL_NAME = true; \
F(); \
} else { \
constexpr bool BOOL_NAME = false; \
F(); \
} \
}
#ifdef TORCH_CHECK
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
EVOFORMER_CHECK(uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned")
#define EVOFORMER_CHECK TORCH_CHECK
#elif defined(__CUDACC_RTC__)
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
if (!(uint64_t(PTR) % ALIGNMENT == 0)) { return false; }
#define EVOFORMER_CHECK(COND, ERR) \
if (!(COND)) { return false; }
#else
#include <iostream>
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
if (!(uint64_t(PTR) % ALIGNMENT == 0)) { \
std::cerr << #PTR " is not correctly aligned\n"; \
return false; \
}
#define EVOFORMER_CHECK(COND, ERR) \
if (!(COND)) { \
std::cerr << "[Evoformer Attention]" \
<< "'" #COND "' failed: " << ERR << "\n"; \
return false; \
}
#endif
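// Illustrative only: a hypothetical launcher showing how the dispatch helpers
// above are meant to compose. `example_launch` stands in for a real kernel
// launcher; neither function below is part of this file.
template <typename ArchTag, bool kAddBias>
bool example_launch(int /*batch*/)
{
    return true;  // a real launcher would configure and launch a kernel here
}
inline bool example_dispatch(int compute_capability, bool add_bias, int batch)
{
    bool ok = false;
    DISPATCH_ARCHTAG(compute_capability,
                     DISPATCH_BOOL(add_bias, kAddBias, ([&] {
                                       ok = example_launch<ArchTag, kAddBias>(batch);
                                   })));
    return ok;
}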
namespace gemm_kernel_utils {
template <typename integer>
constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m)
{
return (n + m - 1) / m;
}
template <typename integer>
constexpr CUTLASS_HOST_DEVICE integer align_up(integer n, integer m)
{
return ((n + m - 1) / m) * m;
}
////////////////////////////////////////////////////////////////////////////////
// Determine the type of GEMM we do (TensorCores or not, Shapes ...)
// TODO: Maybe we could rely on Cutlass's DefaultGemm templates
////////////////////////////////////////////////////////////////////////////////
// Fall back to SIMT (FMA on CUDA cores) when none of the specializations below applies
template <typename ArchTag, typename scalar_t_, typename Enable = void>
struct DefaultGemmType {
static constexpr int ThreadK = 8;
static constexpr int WarpK = 8;
static constexpr int kMinimumAlignment = 1;
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
using OpClass = cutlass::arch::OpClassSimt;
using Operator = cutlass::arch::OpMultiplyAdd;
};
// Specialization for tensorcores with f32
template <typename ArchTag>
struct DefaultGemmType<
ArchTag,
float,
typename cutlass::platform::enable_if<ArchTag::kMinComputeCapability >= 80>::type> {
static constexpr int ThreadK = 32;
static constexpr int WarpK = 32;
static constexpr int kMinimumAlignment = 4;
using OpClass = cutlass::arch::OpClassTensorOp;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
using Operator = cutlass::arch::OpMultiplyAddFastF32;
};
// Specialization for tensorcores with f16/bf16 - Sm75+
template <typename ArchTag, typename scalar_t>
struct DefaultGemmType<
ArchTag,
scalar_t,
typename cutlass::platform::enable_if<ArchTag::kMinComputeCapability >= 75 &&
cutlass::sizeof_bits<scalar_t>::value == 16>::type> {
static constexpr int ThreadK = 32;
static constexpr int WarpK = 32;
static constexpr int kMinimumAlignment = 4;
using OpClass = cutlass::arch::OpClassTensorOp;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
using Operator = cutlass::arch::OpMultiplyAdd;
};
// Specialization for tensorcores with f16 - Volta
template <>
struct DefaultGemmType<cutlass::arch::Sm70, cutlass::half_t, void> {
static constexpr int ThreadK = 32;
static constexpr int WarpK = 32;
static constexpr int kMinimumAlignment = 2;
using OpClass = cutlass::arch::OpClassTensorOp;
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
using Operator = cutlass::arch::OpMultiplyAdd;
};
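// Illustrative only: selecting the defaults for half_t on Sm80 and checking a
// couple of the resulting choices. Nothing below is used by the kernels.
namespace default_gemm_type_example {
using Picked = DefaultGemmType<cutlass::arch::Sm80, cutlass::half_t>;
static_assert(std::is_same_v<Picked::OpClass, cutlass::arch::OpClassTensorOp>,
              "half_t on Sm80 should run on tensor cores");
static_assert(Picked::kMinimumAlignment == 4, "16-bit inputs need 4-element alignment here");
}  // namespace default_gemm_type_example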
// Allows writing
// `auto x = kCondition ? fa(arg) : fb(arg)`
// when `fa` and `fb` have different return types
template <bool kVal, typename TA, typename TB>
struct call_conditional;
template <typename TA, typename TB>
struct call_conditional<true, TA, TB> {
template <typename Arg>
static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg) -> decltype(ta(arg))
{
return ta(arg);
}
};
template <typename TA, typename TB>
struct call_conditional<false, TA, TB> {
template <typename Arg>
static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg) -> decltype(tb(arg))
{
return tb(arg);
}
};
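// Illustrative only: call_conditional in action with two hypothetical lambdas
// whose return types differ (int vs. double). With the ternary operator both
// branches would have to convert to one common type; here each instantiation
// keeps the return type of the branch it selects.
template <bool kAsDouble>
auto example_scale(int x)
{
    auto as_int = [](int v) { return v * 2; };
    auto as_double = [](int v) { return v * 2.0; };
    return call_conditional<kAsDouble, decltype(as_double), decltype(as_int)>::apply(
        as_double, as_int, x);  // returns double when kAsDouble, int otherwise
}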
////////////////////////////////////////////////////////////////////////////////
// Mark a variable as warp-uniform - enables some compiler optimizations
// The cheapest way to do it is just to broadcast it from lane 0
////////////////////////////////////////////////////////////////////////////////
CUTLASS_DEVICE int32_t warp_uniform(int32_t value)
{
return (int32_t)__shfl_sync(0xffffffff, (unsigned)value, 0);
}
template <typename T>
CUTLASS_DEVICE T* warp_uniform(T* ptr)
{
struct {
union {
T* ptr;
uint32_t asInt[2];
};
} p;
p.ptr = ptr;
p.asInt[0] = warp_uniform(p.asInt[0]);
p.asInt[1] = warp_uniform(p.asInt[1]);
return p.ptr;
}
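// Illustrative only: the intended use of warp_uniform. `row` and `stride` are
// assumed to hold the same value in every lane of the warp; re-deriving the
// pointer through warp_uniform lets the compiler treat it as warp-uniform.
CUTLASS_DEVICE float example_load_row_start(float* base, int row, int stride)
{
    float* row_ptr = warp_uniform(base + row * stride);
    return row_ptr[0];
}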
} // namespace gemm_kernel_utils


@@ -0,0 +1,691 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
/*! \file
\brief Epilogue iterator that supports prefetching
Mostly copied from "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
*/
#pragma once
#include "cutlass/arch/arch.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/threadblock/output_tile_thread_map.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
////////////////////////////////////////////////////////////////////////////////
namespace epilogue {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
/// Tile iterator used to load and store output tile from global memory in
/// epilogue.
///
/// Satisfies: ReadableTileIterator | PredicatedTileIterator |
/// ForwardTileIterator
///
template <typename ThreadMap_, ///< Thread map (concept: OutputTileThreadMap)
typename Element_, ///< Element data type
bool ScatterD = false, ///< Scatter D operand or not
bool UseCUDAStore = false>
class PredicatedTileIteratorPrefetch {
public:
using ThreadMap = ThreadMap_;
using Shape = typename ThreadMap::Shape;
using Element = Element_;
using Layout = layout::RowMajor;
using TensorRef = TensorRef<Element, Layout>;
using ConstTensorRef = typename TensorRef::ConstTensorRef;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
using TensorCoord = MatrixCoord;
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
static int const kThreads = ThreadMap::kThreads;
static int const kIterations = ThreadMap::Count::kTile;
static_assert(ThreadMap::Iterations::kRow > 0, "ThreadMap::Iterations::kRow must be > 0");
static_assert(ThreadMap::Iterations::kGroup > 0, "ThreadMap::Iterations::kGroup must be > 0");
static_assert(ThreadMap::Iterations::kCluster > 0,
"ThreadMap::Iterations::kCluster must be > 0");
static_assert(ThreadMap::Iterations::kColumn > 0, "ThreadMap::Iterations::kColumn must be > 0");
/// Fragment object
using Fragment = Array<Element,
ThreadMap::Iterations::kColumn * ThreadMap::Iterations::kRow *
ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster *
ThreadMap::kElementsPerAccess>;
/// Memory access size
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
//
// Parameters struct
//
/// Uses a non-template class
struct Params : PredicatedTileIteratorParams {
using Base = PredicatedTileIteratorParams;
CUTLASS_HOST_DEVICE
Params() {}
CUTLASS_HOST_DEVICE
Params(Layout const& layout)
: PredicatedTileIteratorParams(
layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess,
make_OutputTileThreadMapDesc<ThreadMap>())
{
}
CUTLASS_HOST_DEVICE
Params(Base const& base) : Base(base) {}
};
/// Mask object
struct Mask {
static int const kCount = ThreadMap::Iterations::kColumn;
/// Predicate state
bool predicates[kCount];
//
// Mask
//
CUTLASS_HOST_DEVICE
Mask() { enable(); }
///< Efficiently disables all accesses guarded by mask
CUTLASS_HOST_DEVICE void clear()
{
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kCount; ++i) { predicates[i] = false; }
}
///< Efficiently enables all accesses guarded by mask
CUTLASS_DEVICE void enable()
{
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kCount; ++i) { predicates[i] = true; }
}
};
private:
//
// Data members
//
/// Parameters structure containing reference and precomputed state.
PredicatedTileIteratorParams params_;
/// Byte-level pointer
uint8_t* byte_pointer_;
/// Array of boolean values to contain steady-state predicates
Mask mask_;
/// Extent of the matrix tile in rows
Index extent_row_;
/// Extent of the matrix tile in columns
Index extent_column_;
/// A thread's starting row position (assuming steady-state predicates have
/// been computed)
Index thread_start_row_;
/// A thread's starting column
Index thread_start_column_;
/// Internal state counter
int state_[3];
/// Scatter indices
int const* indices_;
//
// Static asserts about internal strides
//
static_assert(sizeof(extent_row_) == 4, "Expected 32b extents");
static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents");
static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides");
private:
//
// Methods
//
public:
//
// Methods
//
/// Constructor
CUTLASS_DEVICE
PredicatedTileIteratorPrefetch(PredicatedTileIteratorParams const& params,
Element* pointer,
TensorCoord extent,
int thread_idx,
TensorCoord threadblock_offset = TensorCoord(),
int const* indices = nullptr)
: params_(params), indices_(indices)
{
TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset;
extent_row_ = extent.row();
extent_column_ = extent.column();
thread_start_row_ = thread_offset.row();
thread_start_column_ = thread_offset.column();
// Initialize predicates
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) {
mask_.predicates[c] =
((thread_offset.column() + ThreadMap::Delta::kColumn * c) < extent.column());
}
// Null pointer performs no accesses
if (!pointer) { mask_.clear(); }
if (ScatterD && !indices) { mask_.clear(); }
// Initialize pointer
byte_pointer_ = reinterpret_cast<uint8_t*>(pointer) +
LongIndex(thread_offset.row()) * LongIndex(params_.stride) +
LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess;
if (ScatterD) {
byte_pointer_ =
reinterpret_cast<uint8_t*>(pointer) +
LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess;
}
// Initialize internal state counter
state_[0] = state_[1] = state_[2] = 0;
}
/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset)
{
byte_pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
}
CUTLASS_DEVICE
void prefetch_all()
{
CUTLASS_PRAGMA_UNROLL
for (int iter = 0; iter < kIterations; ++iter) {
prefetch();
++(*this);
}
}
CUTLASS_DEVICE
void prefetch()
{
uint8_t* byte_pointer = byte_pointer_;
CUTLASS_PRAGMA_UNROLL
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) {
CUTLASS_PRAGMA_UNROLL
for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
int row_offset = row * ThreadMap::Delta::kRow +
group * ThreadMap::Delta::kGroup +
cluster * ThreadMap::Delta::kCluster;
AccessType* memory_pointer = reinterpret_cast<AccessType*>(byte_pointer);
CUTLASS_PRAGMA_UNROLL
for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
// On Windows, using `unsigned long` here gives:
// error: asm operand type size(4) does not match
// type/size implied by constraint 'l'
uint64_t addr =
(uint64_t)((void*)&memory_pointer[column * ThreadMap::Delta::kColumn /
kElementsPerAccess]);
asm volatile("prefetch.global.L1 [ %1 ];" : "=l"(addr) : "l"(addr));
}
if (row + 1 < ThreadMap::Iterations::kRow) {
if (!ScatterD) { byte_pointer += params_.increment_row; }
}
}
if (group + 1 < ThreadMap::Iterations::kGroup) {
byte_pointer += params_.increment_group;
}
}
if (cluster + 1 < ThreadMap::Iterations::kCluster) {
byte_pointer += params_.increment_cluster;
}
}
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load_with_byte_offset(Fragment& frag, int64_t byte_offset) const
{
uint8_t* byte_pointer = byte_pointer_;
AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) {
CUTLASS_PRAGMA_UNROLL
for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
int frag_row_idx =
(row + ThreadMap::Iterations::kRow *
(group + ThreadMap::Iterations::kGroup * cluster));
int row_offset = row * ThreadMap::Delta::kRow +
group * ThreadMap::Delta::kGroup +
cluster * ThreadMap::Delta::kCluster;
bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
AccessType* memory_pointer =
reinterpret_cast<AccessType*>(byte_pointer + byte_offset);
if (ScatterD && row_guard) {
assert(indices_);
memory_pointer = reinterpret_cast<AccessType*>(
byte_pointer + byte_offset +
LongIndex(indices_[row_offset + thread_start_row_]) *
LongIndex(params_.stride));
}
CUTLASS_PRAGMA_UNROLL
for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
bool guard = row_guard && mask_.predicates[column];
cutlass::arch::global_load<AccessType, sizeof(AccessType)>(
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column],
(void*)&memory_pointer[column * ThreadMap::Delta::kColumn /
kElementsPerAccess],
guard);
}
if (row + 1 < ThreadMap::Iterations::kRow) {
if (!ScatterD) { byte_pointer += params_.increment_row; }
}
}
if (group + 1 < ThreadMap::Iterations::kGroup) {
byte_pointer += params_.increment_group;
}
}
if (cluster + 1 < ThreadMap::Iterations::kCluster) {
byte_pointer += params_.increment_cluster;
}
}
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load(Fragment& frag) const { load_with_byte_offset(frag, 0); }
/// Stores a fragment to memory
CUTLASS_DEVICE
void store_with_byte_offset(Fragment const& frag, int64_t byte_offset) const
{
uint8_t* byte_pointer = byte_pointer_;
AccessType const* frag_ptr = reinterpret_cast<AccessType const*>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) {
CUTLASS_PRAGMA_UNROLL
for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
int frag_row_idx =
(row + ThreadMap::Iterations::kRow *
(group + ThreadMap::Iterations::kGroup * cluster));
int row_offset = row * ThreadMap::Delta::kRow +
group * ThreadMap::Delta::kGroup +
cluster * ThreadMap::Delta::kCluster;
bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
AccessType* memory_pointer =
reinterpret_cast<AccessType*>(byte_pointer + byte_offset);
if (ScatterD && row_guard) {
assert(indices_);
memory_pointer = reinterpret_cast<AccessType*>(
byte_pointer + byte_offset +
LongIndex(indices_[row_offset + thread_start_row_]) *
LongIndex(params_.stride));
}
CUTLASS_PRAGMA_UNROLL
for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
bool guard = row_guard && mask_.predicates[column];
if (UseCUDAStore) {
if (guard) {
memory_pointer[column * ThreadMap::Delta::kColumn /
kElementsPerAccess] =
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn +
column];
}
} else {
cutlass::arch::global_store<AccessType, sizeof(AccessType)>(
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column],
(void*)&memory_pointer[column * ThreadMap::Delta::kColumn /
kElementsPerAccess],
guard);
}
}
if (row + 1 < ThreadMap::Iterations::kRow) {
if (!ScatterD) { byte_pointer += params_.increment_row; }
}
}
if (group + 1 < ThreadMap::Iterations::kGroup) {
byte_pointer += params_.increment_group;
}
}
if (cluster + 1 < ThreadMap::Iterations::kCluster) {
byte_pointer += params_.increment_cluster;
}
}
}
/// Stores a fragment to memory
CUTLASS_DEVICE
void store(Fragment const& frag) const { store_with_byte_offset(frag, 0); }
/// Loads a fragment from memory
CUTLASS_DEVICE
void downsample_load_with_byte_offset(Fragment& frag,
int64_t byte_offset,
int convolution_P,
int convolution_Q,
int add_P,
int add_Q,
int problem_N) const
{
uint8_t* byte_pointer = byte_pointer_;
AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) {
CUTLASS_PRAGMA_UNROLL
for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
int frag_row_idx =
(row + ThreadMap::Iterations::kRow *
(group + ThreadMap::Iterations::kGroup * cluster));
int row_offset = row * ThreadMap::Delta::kRow +
group * ThreadMap::Delta::kGroup +
cluster * ThreadMap::Delta::kCluster;
bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
int output_row = row_offset + thread_start_row_;
int output_N = output_row / (convolution_P * convolution_Q);
int output_PQ = output_row % (convolution_P * convolution_Q);
int output_P = output_PQ / convolution_Q;
int output_Q = output_PQ % convolution_Q;
int input_row = output_N * 2 * convolution_P * 2 * convolution_Q +
(2 * output_P + add_P) * 2 * convolution_Q + 2 * output_Q +
add_Q;
int64_t byte_offset = (input_row - output_row) * problem_N * sizeof(float);
AccessType* memory_pointer =
reinterpret_cast<AccessType*>(byte_pointer + byte_offset);
CUTLASS_PRAGMA_UNROLL
for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
bool guard = row_guard && mask_.predicates[column];
cutlass::arch::global_load<AccessType, sizeof(AccessType)>(
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column],
(void*)&memory_pointer[column * ThreadMap::Delta::kColumn /
kElementsPerAccess],
guard);
}
if (row + 1 < ThreadMap::Iterations::kRow) {
byte_pointer += params_.increment_row;
}
}
if (group + 1 < ThreadMap::Iterations::kGroup) {
byte_pointer += params_.increment_group;
}
}
if (cluster + 1 < ThreadMap::Iterations::kCluster) {
byte_pointer += params_.increment_cluster;
}
}
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void upsample_load_with_byte_offset(Fragment& frag,
int64_t byte_offset,
int convolution_P,
int convolution_Q,
int add_P,
int add_Q,
int problem_N) const
{
uint8_t* byte_pointer = byte_pointer_;
AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) {
CUTLASS_PRAGMA_UNROLL
for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
int frag_row_idx =
(row + ThreadMap::Iterations::kRow *
(group + ThreadMap::Iterations::kGroup * cluster));
int row_offset = row * ThreadMap::Delta::kRow +
group * ThreadMap::Delta::kGroup +
cluster * ThreadMap::Delta::kCluster;
bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
int output_row = row_offset + thread_start_row_;
int output_N = output_row / (convolution_P * convolution_Q);
int output_PQ = output_row % (convolution_P * convolution_Q);
int output_P = output_PQ / convolution_Q;
int output_Q = output_PQ % convolution_Q;
int row_add_P = add_P;
int row_add_Q = add_Q;
if (output_P > convolution_P - 2) row_add_P = 0;
if (output_Q > convolution_Q - 2) row_add_Q = 0;
int input_row = output_N * (convolution_P / 2) * (convolution_Q / 2) +
((output_P + row_add_P) / 2) * (convolution_Q / 2) +
(output_Q + row_add_Q) / 2;
int64_t byte_offset = (input_row - output_row) * problem_N * sizeof(float);
AccessType* memory_pointer =
reinterpret_cast<AccessType*>(byte_pointer + byte_offset);
CUTLASS_PRAGMA_UNROLL
for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
bool guard = row_guard && mask_.predicates[column];
cutlass::arch::global_load<AccessType, sizeof(AccessType)>(
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column],
(void*)&memory_pointer[column * ThreadMap::Delta::kColumn /
kElementsPerAccess],
guard);
}
if (row + 1 < ThreadMap::Iterations::kRow) {
byte_pointer += params_.increment_row;
}
}
if (group + 1 < ThreadMap::Iterations::kGroup) {
byte_pointer += params_.increment_group;
}
}
if (cluster + 1 < ThreadMap::Iterations::kCluster) {
byte_pointer += params_.increment_cluster;
}
}
}
CUTLASS_DEVICE
MatrixCoord thread_start() const
{
return MatrixCoord(thread_start_row_, thread_start_column_);
}
/// Need to get the thread start row from the tile iterator
CUTLASS_DEVICE
int32_t thread_start_row() const { return thread_start_row_; }
/// Need to get the thread start row from the tile iterator
CUTLASS_DEVICE
int32_t thread_start_column() const { return thread_start_column_; }
/// Extent of the matrix in rows
CUTLASS_DEVICE
Index extent_row() const { return extent_row_; }
/// Extent of the matrix in columns
CUTLASS_DEVICE
Index extent_column() const { return extent_column_; }
/// Advances to the next position to load or store
CUTLASS_HOST_DEVICE
PredicatedTileIteratorPrefetch& operator++()
{
++state_[0];
if (!ScatterD) { byte_pointer_ += params_.advance_row; }
thread_start_row_ += ThreadMap::Shape::kRow;
if (state_[0] == ThreadMap::Count::kRow) {
state_[0] = 0;
++state_[1];
byte_pointer_ += params_.advance_group;
thread_start_row_ +=
(ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow;
if (state_[1] == ThreadMap::Count::kGroup) {
state_[1] = 0;
++state_[2];
byte_pointer_ += params_.advance_cluster;
thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup *
ThreadMap::Count::kRow * ThreadMap::Shape::kRow;
if (state_[2] == ThreadMap::Count::kCluster) {
state_[2] = 0;
byte_pointer_ += params_.advance_tile;
}
}
}
return *this;
}
///< Efficiently disables all accesses guarded by mask
CUTLASS_DEVICE void clear_mask() { mask_.clear(); }
///< Efficiently enables all accesses guarded by mask
CUTLASS_DEVICE void enable_mask() { mask_.enable(); }
///< Gets the mask
CUTLASS_DEVICE void get_mask(Mask& mask) const { mask = mask_; }
///< Sets the mask
CUTLASS_DEVICE void set_mask(Mask const& mask) { mask_ = mask; }
};
template <typename IT>
struct MakePrefetchableIterator {
using Iterator = PredicatedTileIteratorPrefetch<typename IT::ThreadMap, typename IT::Element>;
};
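// Illustrative only: how an epilogue might combine the prefetch pass with the
// usual store. `SourceIterator` stands for an existing epilogue output tile
// iterator type (hypothetical here); the prefetch runs on a copy so the
// original iterator still points at its first tile when storing.
template <typename SourceIterator>
CUTLASS_DEVICE void example_prefetch_then_store(
    typename MakePrefetchableIterator<SourceIterator>::Iterator iter,
    typename MakePrefetchableIterator<SourceIterator>::Iterator::Fragment const& frag)
{
    auto prefetch_iter = iter;     // keep `iter` untouched for the real accesses
    prefetch_iter.prefetch_all();  // issue L1 prefetches for every tile this thread will visit
    iter.store(frag);              // then store the accumulator fragment as usual
}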
///////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////


@@ -0,0 +1,91 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#pragma once
#include "predicated_tile_access_iterator_residual_last.h"
#include "predicated_tile_iterator_residual_last.h"
namespace cutlass {
namespace transform {
namespace threadblock {
template <typename BaseIterator>
struct MakeIteratorResidualLast;
template <typename Shape,
typename Element,
typename Layout,
int AdvanceRank,
typename ThreadMap,
int AccessSize,
bool Gather>
struct MakeIteratorResidualLast<
PredicatedTileIterator<Shape, Element, Layout, AdvanceRank, ThreadMap, AccessSize, Gather>> {
using Iterator = PredicatedTileIteratorResidualLast<Shape,
Element,
Layout,
AdvanceRank,
ThreadMap,
AccessSize,
Gather>;
};
template <typename Shape,
typename Element,
typename Layout,
int AdvanceRank,
typename ThreadMap,
typename AccessType,
bool Gather>
struct MakeIteratorResidualLast<PredicatedTileAccessIterator<Shape,
Element,
Layout,
AdvanceRank,
ThreadMap,
AccessType,
Gather>> {
using Iterator = PredicatedTileAccessIteratorResidualLast<Shape,
Element,
Layout,
AdvanceRank,
ThreadMap,
AccessType,
Gather>;
};
} // namespace threadblock
} // namespace transform
} // namespace cutlass


@@ -0,0 +1,886 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#pragma once
#include <cutlass/cutlass.h>
#include <cutlass/epilogue/threadblock/predicated_tile_iterator.h>
#include <cutlass/tensor_coord.h>
namespace cutlass {
namespace epilogue {
namespace threadblock {
template <class AccessType, class Enable = void>
struct atomic_store {};
template <class AccessType>
struct atomic_store<AccessType,
typename platform::enable_if<
platform::is_same<typename AccessType::Element, half_t>::value>::type> {
using Element = typename AccessType::Element;
static const int kCount = AccessType::kElements;
CUTLASS_DEVICE
atomic_store(AccessType const& D, void* ptr, bool pred_guard)
{
static_assert(!(kCount % 2), "kCount must be even");
half2* p = reinterpret_cast<half2*>(ptr);
uint const* data = reinterpret_cast<uint const*>(&D);
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
:
: "r"((int)pred_guard));
for (int i = 0; i < kCount / 2; i++) {
asm volatile(" @p red.relaxed.global.add.noftz.f16x2 [%0], %1;\n"
:
: "l"(p + i), "r"(data[i]));
}
asm volatile("}\n" ::);
}
};
template <class AccessType>
struct atomic_store<AccessType,
typename platform::enable_if<
platform::is_same<typename AccessType::Element, float>::value>::type> {
using Element = typename AccessType::Element;
static const int kCount = AccessType::kElements;
CUTLASS_DEVICE
atomic_store(AccessType const& D, void* ptr, bool pred_guard)
{
Element* p = reinterpret_cast<Element*>(ptr);
uint const* data = reinterpret_cast<uint const*>(&D);
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
:
: "r"((int)pred_guard));
for (int i = 0; i < kCount; i++) {
asm volatile(" @p red.relaxed.global.add.f32 [%0], %1;\n"
:
: "l"(p + i), "r"(data[i]));
}
asm volatile("}\n" ::);
}
};
template <typename ThreadMap_, ///< Thread map (concept: OutputTileThreadMap)
typename Element_, ///< Element data type
int Rank>
class PredicatedTileIteratorAffineRankNAtomic {
public:
using ThreadMap = ThreadMap_;
using Shape = typename ThreadMap::Shape;
using Element = Element_;
using Layout = layout::AffineRankN<Rank>;
using TensorRef = TensorRef<Element, Layout>;
using TensorView = TensorView<Element, Layout>;
using ConstTensorRef = typename TensorRef::ConstTensorRef;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
using TensorCoord = typename Layout::TensorCoord;
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
static int const kThreads = ThreadMap::kThreads;
static int const kIterations = ThreadMap::Count::kTile;
static_assert(ThreadMap::Iterations::kRow > 0, "ThreadMap::Iterations::kRow must be > 0");
static_assert(ThreadMap::Iterations::kGroup > 0, "ThreadMap::Iterations::kGroup must be > 0");
static_assert(ThreadMap::Iterations::kCluster > 0,
"ThreadMap::Iterations::kCluster must be > 0");
static_assert(ThreadMap::Iterations::kColumn > 0, "ThreadMap::Iterations::kColumn must be > 0");
static_assert(!(Layout::kRank % 2),
"Layout rank must be even. This assumes the first half of the "
"modes correspond to the 'row' "
"and the second half of the modes correspond to the 'column'");
static bool const kBigEndian = false;
/// Fragment object
using Fragment = Array<Element,
ThreadMap::Iterations::kColumn * ThreadMap::Iterations::kRow *
ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster *
ThreadMap::kElementsPerAccess>;
/// Memory access size
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
//
// Parameters struct
//
/// Parameters structure
struct Params {
//
// Data members
//
Layout layout;
/// Stride in units of bytes along M modes
Coord<Layout::kRank / 2, typename Layout::LongIndex> stride_m;
/// Stride in units of bytes along N modes
Coord<Layout::kRank / 2, typename Layout::LongIndex> stride_n;
/// Fast divmod objects divided by tensor extents
FastDivmod divmod_m[(Layout::kRank == 2) ? 1 : (Layout::kRank / 2 - 1)];
/// Fast divmod objects divided by tensor extents
FastDivmod divmod_n[(Layout::kRank == 2) ? 1 : (Layout::kRank / 2 - 1)];
int64_t rank2_inc_col;
int64_t rank2_inc_row;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params() {}
CUTLASS_HOST_DEVICE
Params(TensorCoord const& extent, Layout const& layout_) : layout(layout_)
{
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Layout::kRank / 2; ++i) {
stride_m[i] = OffsetBytes<Element>(layout_.stride()[i]);
stride_n[i] = OffsetBytes<Element>(layout_.stride()[i + Layout::kRank / 2]);
}
if (kBigEndian) {
// "Big Endian" scheme
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Layout::kRank / 2 - 1; ++i) {
divmod_m[i] = FastDivmod(extent[i + 1]);
divmod_n[i] = FastDivmod(extent[i + Layout::kRank / 2 + 1]);
}
} else {
// "Little Endian" scheme
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Layout::kRank / 2 - 1; ++i) {
divmod_m[i] = FastDivmod(extent[i]);
divmod_n[i] = FastDivmod(extent[i + Layout::kRank / 2]);
}
}
}
CUTLASS_HOST_DEVICE
Params(Layout const& layout_) : layout(layout_)
{
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Layout::kRank / 2; ++i) {
stride_m[i] = OffsetBytes<Element>(layout_.stride()[i]);
stride_n[i] = OffsetBytes<Element>(layout_.stride()[i + Layout::kRank / 2]);
}
rank2_inc_col = ThreadMap::Delta::kColumn * stride_n[0];
rank2_inc_row = ThreadMap::Delta::kRow * stride_m[0];
}
};
/// Mask object
struct Mask {
static int const kCount = ThreadMap::Iterations::kColumn;
/// Predicate state
bool predicates[kCount];
//
// Mask
//
CUTLASS_HOST_DEVICE
Mask() { enable(); }
///< Efficiently disables all accesses guarded by mask
CUTLASS_HOST_DEVICE void clear()
{
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kCount; ++i) { predicates[i] = false; }
}
///< Efficiently enables all accesses guarded by mask
CUTLASS_DEVICE void enable()
{
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kCount; ++i) { predicates[i] = true; }
}
};
private:
//
// Data members
//
/// Parameters structure containing reference and precomputed state.
Params params_;
/// Byte-level pointer
uint8_t* byte_pointer_;
/// Array of boolean values to contain steady-state predicates
Mask mask_;
/// Extent of the matrix tile in rows
Index extent_row_;
/// Extent of the matrix tile in columns
Index extent_col_;
/// A thread's starting row position (assuming steady-state predicates have
/// been computed)
Index thread_start_row_;
/// A thread's starting column position (assuming steady-state predicates have
/// been computed)
Index thread_start_column_;
/// Internal state counter
int state_[3];
/// Offsets in columns, cached for performance
int64_t offset_modes_n_[ThreadMap::Iterations::kColumn];
//
// Static asserts about internal strides
//
static_assert(sizeof(extent_row_) == 4, "Expected 32b extents");
static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents");
private:
//
// Methods
//
public:
//
// Methods
//
/// Constructor
CUTLASS_DEVICE
PredicatedTileIteratorAffineRankNAtomic(
Params const& params,
Element* pointer,
MatrixCoord extent,
int thread_idx,
MatrixCoord threadblock_offset = MatrixCoord(),
int const* indices = nullptr ///< gather/scatter indices; note that gather/scatter
///< is not supported by this specialization
)
: params_(params)
{
MatrixCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset;
extent_row_ = extent.row();
extent_col_ = extent.column();
thread_start_row_ = thread_offset.row();
thread_start_column_ = thread_offset.column();
if (Layout::kRank > 2) {
// Initialize predicates
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) {
//
// Compute coordinate and decompose into N modes
//
int coord_n = thread_start_column_ + c * ThreadMap::Delta::kColumn;
mask_.predicates[c] = coord_n < extent.column();
Coord<Layout::kRank / 2, Index> modes_n;
int64_t offset_modes_n = 0;
if (kBigEndian) {
modes_n = CoordinateDecomposition<Layout::kRank / 2>(coord_n, params_.divmod_n);
offset_modes_n = dot(modes_n, params_.stride_n);
} else {
modes_n = CoordinateDecompositionLittleEndian<Layout::kRank / 2>(
coord_n, params_.divmod_n);
offset_modes_n = dot(modes_n, params_.stride_n);
}
offset_modes_n_[c] = offset_modes_n;
}
if (!pointer) { mask_.clear(); }
}
// Initialize pointer
byte_pointer_ = reinterpret_cast<uint8_t*>(pointer);
// Initialize internal state counter
state_[0] = state_[1] = state_[2] = 0;
}
/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset)
{
byte_pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
}
/// Stores a fragment to memory
CUTLASS_DEVICE
void store_with_byte_offset(Fragment const& frag, int64_t byte_offset)
{
uint8_t* byte_pointer = byte_pointer_;
AccessType const* frag_ptr = reinterpret_cast<AccessType const*>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) {
CUTLASS_PRAGMA_UNROLL
for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
int row_begin = thread_start_row_ + group * ThreadMap::Delta::kGroup +
cluster * ThreadMap::Delta::kCluster;
int64_t offset_modes_m = row_begin * params_.stride_m[0];
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
int frag_row_idx =
(row + ThreadMap::Iterations::kRow *
(group + ThreadMap::Iterations::kGroup * cluster));
//
// Compute coordinate and decompose into M modes
//
int coord_m = row * ThreadMap::Delta::kRow + row_begin;
Coord<Layout::kRank / 2, Index> modes_m;
if (Layout::kRank > 2) {
if (kBigEndian) {
modes_m = CoordinateDecomposition<Layout::kRank / 2>(coord_m,
params_.divmod_m);
} else {
modes_m = CoordinateDecompositionLittleEndian<Layout::kRank / 2>(
coord_m, params_.divmod_m);
}
offset_modes_m = dot(modes_m, params_.stride_m);
}
//
// Compute the offset due to modes M
//
bool row_guard = (coord_m < extent_row_);
int64_t offset_modes_n = thread_start_column_ * params_.stride_n[0];
CUTLASS_PRAGMA_UNROLL
for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
//
// Compute coordinate and decompose into N modes
//
if (Layout::kRank > 2) { offset_modes_n = offset_modes_n_[column]; }
//
// Compute the pointer and access
//
bool guard;
if (Layout::kRank > 2) {
guard = row_guard && mask_.predicates[column];
} else {
guard = (coord_m < extent_row_) &&
((thread_start_column_ + ThreadMap::Delta::kColumn * column) <
extent_col_);
}
atomic_store<AccessType>(
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column],
(void*)(byte_pointer + offset_modes_m + offset_modes_n + byte_offset),
guard);
if (Layout::kRank == 2) { offset_modes_n += params_.rank2_inc_col; }
}
if (Layout::kRank == 2) { offset_modes_m += params_.rank2_inc_row; }
}
}
}
}
/// Stores a fragment to memory
CUTLASS_DEVICE
void store(Fragment const& frag) { store_with_byte_offset(frag, 0); }
CUTLASS_DEVICE
void load(Fragment& frag) {}
/// Advances to the next position to load or store
CUTLASS_HOST_DEVICE
PredicatedTileIteratorAffineRankNAtomic& operator++()
{
++state_[0];
thread_start_row_ += ThreadMap::Shape::kRow;
if (state_[0] == ThreadMap::Count::kRow) {
state_[0] = 0;
++state_[1];
thread_start_row_ +=
(ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow;
if (state_[1] == ThreadMap::Count::kGroup) {
state_[1] = 0;
++state_[2];
thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup *
ThreadMap::Count::kRow * ThreadMap::Shape::kRow;
if (state_[2] == ThreadMap::Count::kCluster) { state_[2] = 0; }
}
}
return *this;
}
///< Efficiently disables all accesses guarded by mask
CUTLASS_DEVICE void clear_mask() { mask_.clear(); }
///< Efficiently enables all accesses guarded by mask
CUTLASS_DEVICE void enable_mask() { mask_.enable(); }
///< Gets the mask
CUTLASS_DEVICE void get_mask(Mask& mask) { mask = mask_; }
///< Sets the mask
CUTLASS_DEVICE void set_mask(Mask const& mask) { mask_ = mask; }
};
template <typename ThreadMap_, ///< Thread map (concept: OutputTileThreadMap)
typename Element_, ///< Element data type
bool ScatterD = false, ///< Scatter D operand or not
typename PermuteDLayout = layout::NoPermute, ///< Permute D operand or not
bool UseCUDAStore = false>
class PredicatedTileIteratorAtomic {
public:
using ThreadMap = ThreadMap_;
using Shape = typename ThreadMap::Shape;
using Element = Element_;
using Layout = layout::RowMajor;
using TensorRef = TensorRef<Element, Layout>;
using ConstTensorRef = typename TensorRef::ConstTensorRef;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
using TensorCoord = MatrixCoord;
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
static int const kThreads = ThreadMap::kThreads;
static int const kIterations = ThreadMap::Count::kTile;
static bool constexpr PermuteD = !layout::is_trivial_permute<PermuteDLayout>;
static_assert(ThreadMap::Iterations::kRow > 0, "ThreadMap::Iterations::kRow must be > 0");
static_assert(ThreadMap::Iterations::kGroup > 0, "ThreadMap::Iterations::kGroup must be > 0");
static_assert(ThreadMap::Iterations::kCluster > 0,
"ThreadMap::Iterations::kCluster must be > 0");
static_assert(ThreadMap::Iterations::kColumn > 0, "ThreadMap::Iterations::kColumn must be > 0");
/// Fragment object
using Fragment = Array<Element,
ThreadMap::Iterations::kColumn * ThreadMap::Iterations::kRow *
ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster *
ThreadMap::kElementsPerAccess>;
/// Memory access size
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
//
// Parameters struct
//
/// Uses a non-template class
struct Params : PredicatedTileIteratorParams {
using Base = PredicatedTileIteratorParams;
CUTLASS_HOST_DEVICE
Params() {}
CUTLASS_HOST_DEVICE
Params(Layout const& layout)
: PredicatedTileIteratorParams(
layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess,
make_OutputTileThreadMapDesc<ThreadMap>())
{
}
CUTLASS_HOST_DEVICE
Params(Base const& base) : Base(base) {}
};
/// Mask object
struct Mask {
static int const kCount = ThreadMap::Iterations::kColumn;
/// Predicate state
bool predicates[kCount];
//
// Mask
//
CUTLASS_HOST_DEVICE
Mask() { enable(); }
///< Efficiently disables all accesses guarded by mask
CUTLASS_HOST_DEVICE void clear()
{
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kCount; ++i) { predicates[i] = false; }
}
///< Efficiently enables all accesses guarded by mask
CUTLASS_DEVICE void enable()
{
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kCount; ++i) { predicates[i] = true; }
}
};
private:
//
// Data members
//
/// Parameters structure containing reference and precomputed state.
PredicatedTileIteratorParams params_;
/// Byte-level pointer. This pointer is usually for both load() and store(),
/// unless PermuteD is performed. When having PermuteD, byte_pointer_ is only
/// for load().
uint8_t* byte_pointer_;
/// Byte-level pointer for store(). Due to PermuteD Op, store_byte_pointer_
/// may be with different address computation compared to byte_pointer_.
uint8_t* store_byte_pointer_;
/// Array of boolean values to contain steady-state predicates
Mask mask_;
/// Extent of the matrix tile in rows
Index extent_row_;
/// Extent of the matrix tile in columns
Index extent_column_;
/// A thread's starting row position (assuming steady-state predicates have
/// been computed)
Index thread_start_row_;
/// A thread's starting column
Index thread_start_column_;
/// Internal state counter
int state_[3];
/// Scatter indices
int const* indices_;
/// PermuteDLayout
PermuteDLayout permute_layout_;
//
// Static asserts about internal strides
//
static_assert(sizeof(extent_row_) == 4, "Expected 32b extents");
static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents");
static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides");
private:
//
// Methods
//
public:
//
// Methods
//
/// Constructor
CUTLASS_DEVICE
PredicatedTileIteratorAtomic(PredicatedTileIteratorParams const& params,
Element* pointer,
TensorCoord extent,
int thread_idx,
TensorCoord threadblock_offset = TensorCoord(),
int const* indices = nullptr)
: params_(params),
indices_(indices),
permute_layout_(PitchLinearCoord(extent.column(), extent.row()),
params_.stride * kElementsPerAccess / sizeof(AccessType))
{
TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset;
extent_row_ = extent.row();
extent_column_ = extent.column();
thread_start_row_ = thread_offset.row();
thread_start_column_ = thread_offset.column();
// Initialize predicates
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) {
mask_.predicates[c] =
((thread_offset.column() + ThreadMap::Delta::kColumn * c) < extent.column());
}
// Null pointer performs no accesses
if (!pointer) { mask_.clear(); }
if (ScatterD && !indices) { mask_.clear(); }
// Initialize byte_pointer_
byte_pointer_ = reinterpret_cast<uint8_t*>(pointer) +
LongIndex(thread_offset.row()) * LongIndex(params_.stride) +
LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess;
if (ScatterD) {
byte_pointer_ =
reinterpret_cast<uint8_t*>(pointer) +
LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess;
}
// store_byte_pointer_ is set to be the same with byte_pointer_ unless
// PermuteD is used.
store_byte_pointer_ = PermuteD ? reinterpret_cast<uint8_t*>(pointer) : byte_pointer_;
// Initialize internal state counter
state_[0] = state_[1] = state_[2] = 0;
}
/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset)
{
store_byte_pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
byte_pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
}
/// Stores a fragment to memory
CUTLASS_DEVICE
void store_with_byte_offset(Fragment const& frag, int64_t byte_offset) const
{
uint8_t* byte_pointer = store_byte_pointer_;
AccessType const* frag_ptr = reinterpret_cast<AccessType const*>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) {
CUTLASS_PRAGMA_UNROLL
for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
int frag_row_idx =
(row + ThreadMap::Iterations::kRow *
(group + ThreadMap::Iterations::kGroup * cluster));
int row_offset = row * ThreadMap::Delta::kRow +
group * ThreadMap::Delta::kGroup +
cluster * ThreadMap::Delta::kCluster;
bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
AccessType* memory_pointer =
reinterpret_cast<AccessType*>(byte_pointer + byte_offset);
if (ScatterD && row_guard) {
assert(indices_);
memory_pointer = reinterpret_cast<AccessType*>(
byte_pointer + byte_offset +
LongIndex(indices_[row_offset + thread_start_row_]) *
LongIndex(params_.stride));
}
CUTLASS_PRAGMA_UNROLL
for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
bool guard = row_guard && mask_.predicates[column];
if (PermuteD) {
int col_offset = column * ThreadMap::Delta::kColumn;
int col = col_offset + thread_start_column_;
int row = row_offset + thread_start_row_;
// Locate memory_pointer
memory_pointer = reinterpret_cast<AccessType*>(
byte_pointer + byte_offset +
permute_layout_(PitchLinearCoord(col, row)) * sizeof(AccessType) /
kElementsPerAccess);
}
atomic_store<AccessType>(
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column],
(void*)&memory_pointer[0],
guard);
if (!PermuteD) {
memory_pointer += (ThreadMap::Delta::kColumn / kElementsPerAccess);
}
}
if (row + 1 < ThreadMap::Iterations::kRow) {
if (!ScatterD && !PermuteD) { byte_pointer += params_.increment_row; }
}
}
if (group + 1 < ThreadMap::Iterations::kGroup) {
byte_pointer += params_.increment_group;
}
}
if (cluster + 1 < ThreadMap::Iterations::kCluster) {
byte_pointer += params_.increment_cluster;
}
}
}
/// Stores a fragment to memory
CUTLASS_DEVICE
void store(Fragment const& frag) const { store_with_byte_offset(frag, 0); }
CUTLASS_DEVICE
void load(Fragment& frag) {}
CUTLASS_DEVICE
MatrixCoord thread_start() const
{
return MatrixCoord(thread_start_row_, thread_start_column_);
}
/// Need to get the thread start row from the tile iterator
CUTLASS_DEVICE
int32_t thread_start_row() const { return thread_start_row_; }
/// Need to get the thread start row from the tile iterator
CUTLASS_DEVICE
int32_t thread_start_column() const { return thread_start_column_; }
/// Extent of the matrix in rows
CUTLASS_DEVICE
Index extent_row() const { return extent_row_; }
/// Extent of the matrix in columns
CUTLASS_DEVICE
Index extent_column() const { return extent_column_; }
/// Advances to the next position to load or store
CUTLASS_HOST_DEVICE
PredicatedTileIteratorAtomic& operator++()
{
++state_[0];
if (!ScatterD && !PermuteD) { store_byte_pointer_ += params_.advance_row; }
if (!ScatterD) { byte_pointer_ += params_.advance_row; }
thread_start_row_ += ThreadMap::Shape::kRow;
if (state_[0] == ThreadMap::Count::kRow) {
state_[0] = 0;
++state_[1];
byte_pointer_ += params_.advance_group;
store_byte_pointer_ += params_.advance_group;
thread_start_row_ +=
(ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow;
if (state_[1] == ThreadMap::Count::kGroup) {
state_[1] = 0;
++state_[2];
byte_pointer_ += params_.advance_cluster;
store_byte_pointer_ += params_.advance_cluster;
thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup *
ThreadMap::Count::kRow * ThreadMap::Shape::kRow;
if (state_[2] == ThreadMap::Count::kCluster) {
state_[2] = 0;
byte_pointer_ += params_.advance_tile;
store_byte_pointer_ += params_.advance_tile;
thread_start_row_ += ThreadMap::Shape::kGroup * ThreadMap::Shape::kRow *
ThreadMap::Shape::kCluster * ThreadMap::Shape::kTile;
}
}
}
return *this;
}
/// Advances a number of positions to load or store
CUTLASS_HOST_DEVICE
PredicatedTileIteratorAtomic& operator+=(int increment)
{
// Row
state_[0] += increment;
int increment_row = state_[0] / ThreadMap::Count::kRow;
state_[0] = state_[0] % ThreadMap::Count::kRow;
byte_pointer_ += (params_.advance_row * increment);
store_byte_pointer_ += (params_.advance_row * increment);
thread_start_row_ += (ThreadMap::Shape::kRow * increment);
// Group
state_[1] += increment_row;
int increment_group = state_[1] / ThreadMap::Count::kGroup;
state_[1] = state_[1] % ThreadMap::Count::kGroup;
byte_pointer_ += (params_.advance_group * increment_row);
store_byte_pointer_ += (params_.advance_group * increment_row);
thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow *
ThreadMap::Count::kRow * increment_row;
// Cluster
state_[2] += increment_group;
int increment_cluster = state_[2] / ThreadMap::Count::kCluster;
state_[2] = state_[2] % ThreadMap::Count::kCluster;
byte_pointer_ += (params_.advance_cluster * increment_group);
store_byte_pointer_ += (params_.advance_cluster * increment_group);
thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup *
ThreadMap::Count::kRow * ThreadMap::Shape::kRow * increment_group;
// Tile
byte_pointer_ += (params_.advance_tile * increment_cluster);
store_byte_pointer_ += (params_.advance_tile * increment_cluster);
thread_start_row_ += ThreadMap::Shape::kGroup * ThreadMap::Shape::kRow *
ThreadMap::Shape::kCluster * ThreadMap::Shape::kTile *
increment_cluster;
return *this;
}
///< Efficiently disables all accesses guarded by mask
CUTLASS_DEVICE void clear_mask() { mask_.clear(); }
///< Efficiently enables all accesses guarded by mask
CUTLASS_DEVICE void enable_mask() { mask_.enable(); }
///< Gets the mask
CUTLASS_DEVICE void get_mask(Mask& mask) const { mask = mask_; }
///< Sets the mask
CUTLASS_DEVICE void set_mask(Mask const& mask) { mask_ = mask; }
};
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass


@@ -0,0 +1,57 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
 * 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#pragma once
#include "warp_iterator_from_smem.h"
template <typename WarpIterator>
struct TransposeWarpIterator {
using Iterator = char;
static bool constexpr kSupportsTranspose = false;
};
template <
/// Operand identity
cutlass::gemm::Operand Operand,
/// Data type of A elements
typename Element,
bool kTranspose>
struct TransposeWarpIterator<
cutlass::gemm::warp::WarpIteratorFromSmem<Operand, Element, kTranspose>> {
using Iterator = cutlass::gemm::warp::WarpIteratorFromSmem<Operand, Element, !kTranspose>;
static bool constexpr kSupportsTranspose = true;
};
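// Usage sketch: TransposeWarpIterator<Iter>::Iterator is Iter with kTranspose flipped,
// and kSupportsTranspose reports whether a transposed read is available; the generic
// fallback above maps Iterator to a dummy `char` type.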


@@ -0,0 +1,269 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
/*! \file
\brief Inspired by "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h".
Loads tiles of GEMM operands from a RowMajor shared-memory layout into
registers for use by A100 TensorCores.
The differences with "mma_tensor_op_tile_access_iterator.h" are that:
(1) We use "ldmatrix" to load tiles, rather than manual loads (slightly faster)
(2) We support transposing the operand (e.g. reading `A.transpose()` when
the shared memory holds `A`)
This is only implemented for specific shapes.
*/
#pragma once
#include <cutlass/gemm/gemm.h>
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace warp {
template <
/// Operand identity
Operand Operand_,
/// Data type of A elements
typename Element_,
bool kTranspose = false>
class WarpIteratorFromSmem {
public:
/// Shape of tile to load (concept: MatrixShape)
using Shape = cutlass::MatrixShape<32, 32>;
/// Operand tag
static Operand const kOperand = Operand_;
/// Basic check
static_assert(
kOperand == Operand::kA || kOperand == Operand::kB,
"WarpIteratorFromSmem may only be instantiated for A or B operands to warp-level Mma.");
/// Element type
using Element = Element_;
static_assert(sizeof_bits<Element>::value == 16, "Only supported for half");
/// Layout of source tile
using Layout = cutlass::layout::RowMajor;
/// Shape of one matrix product operation (concept: MatrixShape)
using InstructionShape = cutlass::MatrixShape<16, 8>;
/// Delta between *MMA operations (in units of *MMA operations, concept:
/// MatrixShape)
static int const kOpDelta = 1;
/// Number of participating threads
static int const kThreads = 32;
/// TensorRef type for loading element from a tensor
using TensorRef = TensorRef<Element, Layout>;
/// Index type
using Index = typename TensorRef::Index;
/// Long Index type
using LongIndex = typename TensorRef::LongIndex;
/// Coordinate for an element in the tensor
using TensorCoord = typename TensorRef::TensorCoord;
/// Number of elements accessed per Shared Memory load
static int const kElementsPerAccess =
(sizeof_bits<Element>::value >= 32 ? 1 : 32 / sizeof_bits<Element>::value);
using InstructionCount = MatrixShape<Shape::kRow / InstructionShape::kRow,
Shape::kColumn / InstructionShape::kColumn>;
static int const kIterations = (kOperand == Operand::kA) ? InstructionCount::kColumn
: InstructionCount::kRow;
public:
//
// Derived quantities
//
/// Fragment object holding a thread's part of a tile
using Fragment =
Array<Element,
(kOperand == Operand::kA) ? (Shape::kRow* InstructionShape::kColumn / kThreads)
: (Shape::kColumn* InstructionShape::kRow / kThreads)>;
/// Memory access type
// using AccessType = AlignedArray<Element, kElementsPerAccess>;
using AccessType = Array<unsigned, 4>;
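// An ldmatrix.x4 issued in load() fills four 32-bit registers per thread, hence
// AccessType is four unsigned words rather than an element-typed AlignedArray.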
static int constexpr kWarpShapeDivisibleInner =
(kOperand == Operand::kA ? InstructionShape::kColumn : InstructionShape::kRow);
static int constexpr kAccessesInner = (kWarpShapeDivisibleInner / kElementsPerAccess) / 4;
static int const kTilesPerInstruction = InstructionShape::kRow / 8;
private:
/// Underlying tensor reference
TensorRef ref_;
/// Origin
MatrixCoord origin_;
/// Iterations in a tile
int iterations_;
public:
/// Constructor from TensorRef
CUTLASS_HOST_DEVICE
WarpIteratorFromSmem(TensorRef const& ref, int lane_id)
: WarpIteratorFromSmem(ref, {Shape::kRow, Shape::kColumn}, lane_id)
{
}
CUTLASS_HOST_DEVICE
WarpIteratorFromSmem(TensorRef const& ref, TensorCoord extent, int lane_id)
: ref_(ref), iterations_(0)
{
int ldsm_vec_num = (lane_id >> 3);
if (kOperand == Operand::kA) {
origin_ = MatrixCoord(lane_id % 8, 0);
static_assert(InstructionCount::kRow * kAccessesInner * kTilesPerInstruction == 4, "");
CUTLASS_PRAGMA_UNROLL
for (int inst_m_idx = 0; inst_m_idx < InstructionCount::kRow; ++inst_m_idx) {
CUTLASS_PRAGMA_UNROLL
for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) {
CUTLASS_PRAGMA_UNROLL
for (int access_m_idx = 0; access_m_idx < kTilesPerInstruction;
++access_m_idx) {
int access_idx =
access_m_idx +
kTilesPerInstruction * (inner_idx + kAccessesInner * inst_m_idx);
MatrixCoord offset(access_m_idx * 8 + inst_m_idx * InstructionShape::kRow,
inner_idx * 4 * kElementsPerAccess);
if (access_idx == ldsm_vec_num) {
if (kTranspose) { offset = MatrixCoord(offset.column(), offset.row()); }
origin_ += offset;
}
}
}
}
} else {
origin_ = MatrixCoord(0, lane_id % 8);
static_assert(InstructionCount::kColumn * kAccessesInner == 4, "");
CUTLASS_PRAGMA_UNROLL
for (int inst_n_idx = 0; inst_n_idx < InstructionCount::kColumn; ++inst_n_idx) {
CUTLASS_PRAGMA_UNROLL
for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) {
int access_idx = inner_idx + kAccessesInner * inst_n_idx;
MatrixCoord offset(inner_idx * 4 * kElementsPerAccess, inst_n_idx * 8);
if (access_idx == ldsm_vec_num) {
if (kTranspose) { offset = MatrixCoord(offset.column(), offset.row()); }
origin_ += offset;
}
}
}
}
ref_.add_coord_offset(origin_);
}
/// Advances an iterator along logical dimensions of matrix in units of whole
/// tiles
CUTLASS_HOST_DEVICE
WarpIteratorFromSmem& add_tile_offset(TensorCoord const& tile_offset)
{
TensorCoord coord_offset(tile_offset.row() * Shape::kRow,
tile_offset.column() * Shape::kColumn);
if (kTranspose) { coord_offset = TensorCoord{coord_offset.column(), coord_offset.row()}; }
origin_ += coord_offset;
ref_.add_coord_offset(coord_offset);
return *this;
}
/// Advances the iterator along the advance dimension
CUTLASS_DEVICE
void advance()
{
if (kOperand == Operand::kA) {
add_tile_offset({0, 1});
} else {
add_tile_offset({1, 0});
}
iterations_ = 0;
}
/// increase iterations in a tile
CUTLASS_HOST_DEVICE
WarpIteratorFromSmem& operator++()
{
iterations_++;
if (iterations_ >= kIterations) advance();
return *this;
}
/// Loads a fragment from memory at the location pointed to by the iterator.
CUTLASS_DEVICE
void load(Fragment& frag) const
{
AccessType* access_ptr = reinterpret_cast<AccessType*>(&frag);
using LoadLayout =
typename platform::conditional<kTranspose, layout::ColumnMajor, layout::RowMajor>::type;
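// layout::ColumnMajor selects the transposing flavor of ldmatrix (ldmatrix .trans),
// so with kTranspose the fragment is read as A.transpose() directly from the
// shared-memory tile holding A, with no separate transpose pass.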
MatrixCoord offset;
if (kOperand == Operand::kA) {
offset = MatrixCoord(0, iterations_ * InstructionShape::kColumn);
} else {
offset = MatrixCoord(iterations_ * InstructionShape::kRow, 0);
}
if (kTranspose) { offset = MatrixCoord(offset.column(), offset.row()); }
cutlass::arch::ldsm<LoadLayout, 4>(access_ptr[0], ref_.data() + ref_.offset(offset));
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace warp
} // namespace gemm
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////

File diff suppressed because it is too large


@@ -0,0 +1,986 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
 * 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#pragma once
#include <curand_kernel.h>
#include <cmath>
#include <vector>
#include "cutlass/bfloat16.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/vector.h"
#include "cutlass/matrix.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/platform/platform.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
#include "epilogue/epilogue_pipelined.h"
#include "epilogue/epilogue_rescale_output.h"
#include "gemm/find_default_mma.h"
#include "gemm/mma_from_smem.h"
#include "gemm_kernel_utils.h"
#include "transform/bias_broadcast.h"
#include "transform/tile_smem_loader.h"
#include <inttypes.h>
using namespace gemm_kernel_utils;
namespace {
template <typename scalar_t, typename Arch>
constexpr int getWarpsPerSm()
{
return (Arch::kMinComputeCapability >= 80 && !cutlass::platform::is_same<scalar_t, float>::value
? 16
: 12);
}
static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value)
{
// source: https://stackoverflow.com/a/51549250
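// For non-negative floats the IEEE-754 bit pattern, read as a signed int, orders the
// same way as the float value, so atomicMax on the int view yields the float max; for
// negative floats the unsigned view orders the opposite way, hence atomicMin.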
return (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
: __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
}
} // namespace
template <
// The datatype of Q/K/V
typename scalar_t_,
// Architecture we are targeting (eg `cutlass::arch::Sm80`)
typename ArchTag,
// If Q/K/V are correctly aligned in memory and we can run a fast kernel
bool isAligned_,
int kQueriesPerBlock,
int kKeysPerBlock_,
bool kSingleValueIteration_, // = `value.shape[-1] <= kKeysPerBlock`
// Set to false if you know at compile-time you will never need an attention bias
bool kSupportsBias_ = false,
template <typename, typename, typename> class Broadcast1_ = BroadcastNoLoad,
template <typename, typename, typename> class Broadcast2_ = BroadcastNoLoad>
struct AttentionKernel {
using scalar_t = scalar_t_;
using accum_t = float;
using lse_scalar_t = float;
using output_t = scalar_t;
// Accumulator between 2 iterations
// Using `accum_t` improves perf on f16 at the cost of
// numerical errors
using output_accum_t = accum_t;
static constexpr bool kSupportsBias = kSupportsBias_;
static constexpr int kKeysPerBlock = kKeysPerBlock_;
static constexpr bool kIsAligned = isAligned_;
static constexpr bool kSingleValueIteration = kSingleValueIteration_;
static constexpr int32_t kAlignLSE = 32; // block size of backward
static constexpr bool kPreloadV =
ArchTag::kMinComputeCapability >= 80 && cutlass::sizeof_bits<scalar_t>::value == 16;
static constexpr bool kKeepOutputInRF = kSingleValueIteration;
static constexpr bool kNeedsOutputAccumulatorBuffer =
!kKeepOutputInRF && !cutlass::platform::is_same<output_accum_t, output_t>::value;
static_assert(kQueriesPerBlock % 32 == 0, "");
static_assert(kKeysPerBlock % 32 == 0, "");
static constexpr int kNumWarpsPerBlock = kQueriesPerBlock * kKeysPerBlock / (32 * 32);
static constexpr int kWarpSize = 32;
// Launch bounds
static constexpr int kNumThreads = kWarpSize * kNumWarpsPerBlock;
static constexpr int kMinBlocksPerSm = getWarpsPerSm<scalar_t, ArchTag>() / kNumWarpsPerBlock;
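// Sizing example (not a requirement of this header): with kQueriesPerBlock ==
// kKeysPerBlock == 64, kNumWarpsPerBlock = 64 * 64 / (32 * 32) = 4 warps (128 threads),
// and on SM80 with half-precision inputs getWarpsPerSm() = 16, giving
// kMinBlocksPerSm = 4 resident blocks requested via __launch_bounds__ further down.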
struct Params {
// Input tensors
scalar_t* query_ptr; // [num_queries, num_heads, head_dim]
scalar_t* key_ptr; // [num_keys, num_heads, head_dim]
scalar_t* value_ptr; // [num_keys, num_heads, head_dim_value]
// Output tensors
output_t* output_ptr; // [num_queries, num_heads, head_dim_value]
output_accum_t* output_accum_ptr; // [num_queries, num_heads, head_dim_value]
lse_scalar_t* logsumexp_ptr; // [num_heads, num_queries] - can be null
// Scale
accum_t scale;
// Dimensions/strides
int32_t head_dim;
int32_t head_dim_value;
int32_t num_queries;
int32_t num_keys;
int32_t q_strideM;
int32_t k_strideM;
int32_t v_strideM;
// int32_t bias_strideM = 0;
int32_t o_strideM = 0;
// Everything below is only used in `advance_to_block`
// and shouldn't use registers
int32_t q_strideH;
int32_t k_strideH;
int32_t v_strideH;
// int32_t bias_strideH = 0;
int64_t q_strideB;
int64_t k_strideB;
int64_t v_strideB;
// int32_t bias_strideB = 0;
int32_t num_batches;
int32_t num_heads;
// Parameters for biases
scalar_t* bias1_ptr = nullptr;
scalar_t* bias2_ptr = nullptr;
int32_t B = 0;
int32_t N = 0;
// Moves pointers to what we should process
// Returns "false" if there is no work to do
CUTLASS_DEVICE bool advance_to_block()
{
auto batch_id = blockIdx.z;
auto head_id = blockIdx.y;
auto query_start = blockIdx.x * kQueriesPerBlock;
auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE;
query_ptr += batch_id * q_strideB;
key_ptr += batch_id * k_strideB;
value_ptr += batch_id * v_strideB;
output_ptr += int64_t(batch_id * num_queries) * o_strideM;
if (output_accum_ptr != nullptr) {
output_accum_ptr += int64_t(batch_id * num_queries) * (head_dim_value * num_heads);
}
int64_t q_start = 0, k_start = 0;
// Advance to the current batch / head / query_start
query_ptr += (q_start + query_start) * q_strideM + head_id * q_strideH;
key_ptr += k_start * k_strideM + head_id * k_strideH;
value_ptr += k_start * v_strideM + head_id * v_strideH;
output_ptr += int64_t(q_start + query_start) * o_strideM + head_id * head_dim_value;
if (output_accum_ptr != nullptr) {
output_accum_ptr += int64_t(q_start + query_start) * (head_dim_value * num_heads) +
head_id * head_dim_value;
} else {
// Accumulate directly in the destination buffer (eg for f32)
output_accum_ptr = (accum_t*)output_ptr;
}
if (logsumexp_ptr != nullptr) {
// lse[batch_id, head_id, query_start]
logsumexp_ptr += batch_id * lse_dim * num_heads + head_id * lse_dim + query_start;
}
using broadcast_1 = Broadcast1_<typename MM0::BiasLoader::ThreadMap,
typename MM0::BiasLoader::Shape,
scalar_t>;
if (kSupportsBias && broadcast_1::kEnable && bias1_ptr) {
bias1_ptr = broadcast_1::advance(bias1_ptr,
batch_id / N,
batch_id % N,
head_id,
num_queries * N,
num_queries,
0);
}
using broadcast_2 = Broadcast2_<typename MM0::BiasLoader::ThreadMap,
typename MM0::BiasLoader::Shape,
scalar_t>;
if (kSupportsBias && broadcast_2::kEnable && bias2_ptr) {
auto strideB = num_heads * num_queries * num_keys;
auto strideH = num_queries * num_keys;
bias2_ptr = broadcast_2::advance(
bias2_ptr, batch_id / N, batch_id % N, head_id, strideB, 0, strideH);
}
num_queries -= query_start;
num_batches = 0; // no longer used after
// If num_queries == 1, and there is only one key head, we're wasting
// 15/16th of tensor core compute. In that case:
// - we only launch kernels for head_id % kQueriesPerBlock == 0
// - we iterate over heads instead of queries (strideM = strideH)
if (num_queries == 1 && k_strideH == 0 && v_strideH == 0) {
if (head_id % kQueriesPerBlock != 0) return false;
q_strideM = q_strideH;
num_queries = num_heads;
num_heads = 1; // unused but here for intent
o_strideM = head_dim_value;
}
// Make sure the compiler knows these variables are the same on all
// the threads of the warp.
query_ptr = warp_uniform(query_ptr);
key_ptr = warp_uniform(key_ptr);
value_ptr = warp_uniform(value_ptr);
output_ptr = warp_uniform(output_ptr);
output_accum_ptr = warp_uniform(output_accum_ptr);
logsumexp_ptr = warp_uniform(logsumexp_ptr);
num_queries = warp_uniform(num_queries);
num_keys = warp_uniform(num_keys);
num_heads = warp_uniform(num_heads);
head_dim = warp_uniform(head_dim);
head_dim_value = warp_uniform(head_dim_value);
o_strideM = warp_uniform(o_strideM);
if (kSupportsBias && broadcast_1::kEnable) { bias1_ptr = warp_uniform(bias1_ptr); }
if (kSupportsBias && broadcast_2::kEnable) { bias2_ptr = warp_uniform(bias2_ptr); }
return true;
}
__host__ dim3 getBlocksGrid() const
{
return dim3(ceil_div(num_queries, (int32_t)kQueriesPerBlock), num_heads, num_batches);
}
__host__ dim3 getThreadsGrid() const { return dim3(kWarpSize, kNumWarpsPerBlock, 1); }
};
struct MM0 {
/*
In this first matmul, we compute a block of `Q @ K.T`.
While the calculation result is still hot in registers, we update
`mi`, `m_prime`, `s_prime` in shared memory, and then store the block
into a shared-memory buffer ("AccumulatorSharedStorage") that is used later as
operand A for the second matmul (see MM1).
*/
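// In flash-attention terms: MM0 produces S = Q @ K.T for one key block,
// iterative_softmax then rewrites the accumulator in registers as
// P = exp(scale * S + optional bias - mi), and B2bGemm::accumToSmem stages P
// in shared memory as operand A of MM1.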
using GemmType = DefaultGemmType<ArchTag, scalar_t>;
using OpClass = typename GemmType::OpClass;
using DefaultConfig =
typename cutlass::gemm::device::DefaultGemmConfiguration<OpClass,
ArchTag,
scalar_t,
scalar_t,
scalar_t, // ElementC
accum_t // ElementAccumulator
>;
static constexpr int kAlignmentA = kIsAligned ? DefaultConfig::kAlignmentA
: GemmType::kMinimumAlignment;
static constexpr int kAlignmentB = kIsAligned ? DefaultConfig::kAlignmentB
: GemmType::kMinimumAlignment;
using ThreadblockShape =
cutlass::gemm::GemmShape<kQueriesPerBlock, kKeysPerBlock, GemmType::ThreadK>;
using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
using DefaultMma = typename cutlass::gemm::threadblock::FindDefaultMma<
scalar_t, // ElementA,
cutlass::layout::RowMajor, // LayoutA,
kAlignmentA,
scalar_t, // ElementB,
cutlass::layout::ColumnMajor, // LayoutB,
kAlignmentB,
accum_t,
cutlass::layout::RowMajor, // LayoutC,
OpClass,
ArchTag, // ArchTag
ThreadblockShape, // ThreadblockShape
WarpShape, // WarpShape
typename GemmType::InstructionShape, // InstructionShape
DefaultConfig::kStages, // Should use `DefaultConfig::kStages`, but that
// uses too much smem
typename GemmType::Operator // Operator
>::DefaultMma;
using MmaCore = typename DefaultMma::MmaCore;
using IteratorA = typename DefaultMma::IteratorA;
using IteratorB = typename DefaultMma::IteratorB;
using Mma = typename DefaultMma::ThreadblockMma;
using AccumLambdaIterator =
typename DefaultMmaAccumLambdaIterator<typename Mma::Operator::IteratorC,
accum_t,
kWarpSize>::Iterator;
static_assert(MmaCore::WarpCount::kM * MmaCore::WarpCount::kN * MmaCore::WarpCount::kK ==
kNumWarpsPerBlock,
"");
// used for efficient load of bias tile Bij from global to shared memory
using BiasLoader =
TileSmemLoader<scalar_t,
cutlass::MatrixShape<kQueriesPerBlock, kKeysPerBlock>,
MmaCore::kThreads,
// input restriction: kv_len has to be a multiple of this value
128 / cutlass::sizeof_bits<scalar_t>::value>;
// Epilogue to store to shared-memory in a format that we can use later for
// the second matmul
using B2bGemm =
typename cutlass::gemm::threadblock::B2bGemm<typename Mma::Operator::IteratorC,
typename Mma::Operator,
scalar_t,
WarpShape,
ThreadblockShape>;
using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage;
};
struct MM1 {
/**
Second matmul: perform `attn @ V` where `attn` is the attention (not
normalized) and stored in shared memory
*/
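// Note: division by the softmax denominator is deferred to the epilogue
// (MemoryEfficientAttentionNormalize is constructed from s_prime and m_prime),
// so MM1 only accumulates the un-normalized numerator P @ V.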
using GemmType = DefaultGemmType<ArchTag, scalar_t>;
using OpClass = typename GemmType::OpClass;
using DefaultConfig =
typename cutlass::gemm::device::DefaultGemmConfiguration<OpClass,
ArchTag,
scalar_t,
scalar_t,
output_accum_t, // ElementC
accum_t // ElementAccumulator
>;
static constexpr int kAlignmentA = DefaultConfig::kAlignmentA; // from smem
static constexpr int kAlignmentB = kIsAligned ? DefaultConfig::kAlignmentB
: GemmType::kMinimumAlignment;
using ThreadblockShape =
cutlass::gemm::GemmShape<kQueriesPerBlock, kKeysPerBlock, GemmType::ThreadK>;
using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
using InstructionShape = typename GemmType::InstructionShape;
using LayoutB = cutlass::layout::RowMajor;
using DefaultGemm =
cutlass::gemm::kernel::DefaultGemm<scalar_t, // ElementA,
cutlass::layout::RowMajor, // LayoutA,
kAlignmentA,
scalar_t, // ElementB,
LayoutB, // LayoutB,
kAlignmentB,
output_accum_t,
cutlass::layout::RowMajor, // LayoutC,
accum_t,
OpClass,
ArchTag,
ThreadblockShape,
WarpShape,
typename GemmType::InstructionShape,
typename DefaultConfig::EpilogueOutputOp,
void, // ThreadblockSwizzle - not used
DefaultConfig::kStages,
false, // SplitKSerial
typename GemmType::Operator>;
using DefaultMmaFromSmem = typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
typename DefaultGemm::Mma,
typename MM0::AccumulatorSharedStorage,
false>; // kScaleOperandA
using Mma = typename DefaultMmaFromSmem::Mma;
using IteratorB = typename Mma::IteratorB;
using WarpCount = typename Mma::WarpCount;
static_assert(WarpCount::kM * WarpCount::kN * WarpCount::kK == kNumWarpsPerBlock, "");
using DefaultEpilogue = typename DefaultGemm::Epilogue;
using OutputTileIterator = typename cutlass::epilogue::threadblock::PredicatedTileIterator<
typename DefaultEpilogue::OutputTileIterator::ThreadMap,
output_t>;
using OutputTileIteratorAccum =
typename cutlass::epilogue::threadblock::PredicatedTileIterator<
typename DefaultEpilogue::OutputTileIterator::ThreadMap,
output_accum_t>;
struct SharedStorageMM1 {
typename Mma::SharedStorage mm;
};
};
static constexpr int64_t kAlignmentQ = MM0::kAlignmentA;
static constexpr int64_t kAlignmentK = MM0::kAlignmentB;
static constexpr int64_t kAlignmentV = 1;
// Shared storage - depends on kernel params
struct ScalingCoefs {
cutlass::Array<accum_t, kQueriesPerBlock> m_prime;
cutlass::Array<accum_t, kQueriesPerBlock> s_prime;
cutlass::Array<accum_t, kQueriesPerBlock> mi;
};
struct SharedStorageEpilogueAtEnd : ScalingCoefs {
struct SharedStorageAfterMM0 {
// Everything here might be overwritten during MM0
union {
// typename MM0::BiasLoader::SmemTile bias;
cutlass::AlignedBuffer<float, MM0::BiasLoader::Shape::kCount> bias;
typename MM0::AccumulatorSharedStorage si;
};
typename MM1::SharedStorageMM1 mm1;
};
union {
typename MM0::Mma::SharedStorage mm0;
SharedStorageAfterMM0 after_mm0;
typename MM1::DefaultEpilogue::SharedStorage epilogue;
};
CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& epilogue_shared_storage()
{
return epilogue;
}
};
struct SharedStorageEpilogueInLoop : ScalingCoefs {
struct SharedStorageAfterMM0 {
// Everything here might be overwritten during MM0
union {
// typename MM0::BiasLoader::SmemTile bias;
cutlass::AlignedBuffer<float, MM0::BiasLoader::Shape::kCount> bias;
typename MM0::AccumulatorSharedStorage si;
};
typename MM1::SharedStorageMM1 mm1;
typename MM1::DefaultEpilogue::SharedStorage epilogue;
};
union {
typename MM0::Mma::SharedStorage mm0;
SharedStorageAfterMM0 after_mm0;
};
CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& epilogue_shared_storage()
{
return after_mm0.epilogue;
}
};
using SharedStorage =
typename cutlass::platform::conditional<kSingleValueIteration || kKeepOutputInRF,
SharedStorageEpilogueAtEnd,
SharedStorageEpilogueInLoop>::type;
static bool __host__ check_supported(Params const& p)
{
CHECK_ALIGNED_PTR(p.query_ptr, kAlignmentQ);
CHECK_ALIGNED_PTR(p.key_ptr, kAlignmentK);
CHECK_ALIGNED_PTR(p.value_ptr, kAlignmentV);
EVOFORMER_CHECK(p.q_strideM % kAlignmentQ == 0, "query is not correctly aligned (strideM)");
EVOFORMER_CHECK(p.k_strideM % kAlignmentK == 0, "key is not correctly aligned (strideM)");
EVOFORMER_CHECK(p.v_strideM % kAlignmentV == 0, "value is not correctly aligned (strideM)");
EVOFORMER_CHECK(p.num_heads <= 1 || p.q_strideH % kAlignmentQ == 0,
"query is not correctly aligned (strideH)");
EVOFORMER_CHECK(p.num_heads <= 1 || p.k_strideH % kAlignmentK == 0,
"key is not correctly aligned (strideH)");
EVOFORMER_CHECK(p.num_heads <= 1 || p.v_strideH % kAlignmentV == 0,
"value is not correctly aligned (strideH)");
return true;
}
static void CUTLASS_DEVICE attention_kernel(Params& p)
{
// In this block, we will only ever:
// - read query[query_start:query_end, :]
// - write to output[query_start:query_end, :]
extern __shared__ char smem_buffer[];
SharedStorage& shared_storage = *((SharedStorage*)smem_buffer);
auto& m_prime = shared_storage.m_prime;
auto& s_prime = shared_storage.s_prime;
auto& mi = shared_storage.mi;
const uint32_t query_start = blockIdx.x * kQueriesPerBlock;
static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
if (thread_id() < kQueriesPerBlock) {
s_prime[thread_id()] = accum_t(0);
m_prime[thread_id()] = -cutlass::platform::numeric_limits<accum_t>::infinity();
mi[thread_id()] = -cutlass::platform::numeric_limits<accum_t>::infinity();
}
typename MM1::Mma::FragmentC accum_o;
accum_o.clear();
auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator {
using OutputTileIterator = typename MM1::OutputTileIterator;
return OutputTileIterator(
typename OutputTileIterator::Params{(int32_t)p.o_strideM},
p.output_ptr,
typename OutputTileIterator::TensorCoord{p.num_queries, p.head_dim_value},
thread_id(),
{0, col});
};
auto createOutputAccumIter = [&](int col) -> typename MM1::OutputTileIteratorAccum {
using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum;
return OutputTileIteratorAccum(
typename OutputTileIteratorAccum::Params{(int32_t)(p.head_dim_value * p.num_heads)},
p.output_accum_ptr,
typename OutputTileIteratorAccum::TensorCoord{p.num_queries, p.head_dim_value},
thread_id(),
{0, col});
};
// Iterate through keys
for (int32_t iter_key_start = 0; iter_key_start < p.num_keys;
iter_key_start += kKeysPerBlock) {
int32_t problem_size_0_m = cutlass::fast_min((int32_t)kQueriesPerBlock, p.num_queries);
int32_t problem_size_0_n =
cutlass::fast_min(int32_t(kKeysPerBlock), p.num_keys - iter_key_start);
int32_t const& problem_size_0_k = p.head_dim;
int32_t const& problem_size_1_n = p.head_dim_value;
int32_t const& problem_size_1_k = problem_size_0_n;
auto prologueV = [&](int blockN) {
typename MM1::Mma::IteratorB iterator_V(
typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)},
p.value_ptr + iter_key_start * p.v_strideM,
{problem_size_1_k, problem_size_1_n},
thread_id(),
cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
MM1::Mma::prologue(
shared_storage.after_mm0.mm1.mm, iterator_V, thread_id(), problem_size_1_k);
};
__syncthreads(); // Need to have shared memory initialized, and `m_prime`
// updated from end of prev iter
//
// MATMUL: Q.K_t
//
// Computes the block-matrix product of:
// (a) query[query_start:query_end, :]
// with
// (b) key[iter_key_start:iter_key_start + kKeysPerBlock]
// and stores that into `shared_storage.si`
//
// Compute threadblock location
cutlass::gemm::GemmCoord tb_tile_offset = {0, 0, 0};
cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * MM0::Mma::Shape::kM,
tb_tile_offset.k()};
cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(),
tb_tile_offset.n() * MM0::Mma::Shape::kN};
// Construct iterators to A and B operands
typename MM0::IteratorA iterator_A(
typename MM0::IteratorA::Params(typename MM0::MmaCore::LayoutA(p.q_strideM)),
p.query_ptr,
{problem_size_0_m, problem_size_0_k},
thread_id(),
tb_offset_A);
typename MM0::IteratorB iterator_B(
typename MM0::IteratorB::Params(typename MM0::MmaCore::LayoutB(p.k_strideM)),
p.key_ptr + iter_key_start * p.k_strideM,
{problem_size_0_k, problem_size_0_n},
thread_id(),
tb_offset_B);
auto my_warp_id = warp_id();
auto my_lane_id = lane_id();
// Construct thread-scoped matrix multiply
typename MM0::Mma mma(shared_storage.mm0, thread_id(), my_warp_id, my_lane_id);
typename MM0::Mma::FragmentC accum;
accum.clear();
auto gemm_k_iterations =
(problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK;
// Compute threadblock-scoped matrix multiply-add
mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);
__syncthreads();
if (kPreloadV) {
prologueV(0);
} else {
MM1::Mma::drain_cp_asyncs();
}
typename MM0::Mma::Operator::IteratorC::TensorCoord iteratorC_tile_offset = {
(tb_tile_offset.m() * MM0::Mma::WarpCount::kM) +
(my_warp_id % MM0::Mma::WarpCount::kM),
(tb_tile_offset.n() * MM0::Mma::WarpCount::kN) +
(my_warp_id / MM0::Mma::WarpCount::kM)};
// multiply by scaling factor
// if (kSupportsBias) {
// accum =
// cutlass::multiplies<typename MM0::Mma::FragmentC>()(p.scale,
// accum);
// }
if (kSupportsBias) {
cutlass::TensorRef<float, cutlass::layout::RowMajor> bias_tensor_ref(
shared_storage.after_mm0.bias.data(),
cutlass::layout::RowMajor(MM0::ThreadblockShape::kN));
using Shape =
cutlass::MatrixShape<MM0::ThreadblockShape::kM, MM0::ThreadblockShape::kN>;
AttentionBiasEpilogue<Shape,
scalar_t,
MM0::MmaCore::kThreads,
Broadcast1_,
Broadcast2_>
bias_epilogue;
bias_epilogue(bias_tensor_ref,
p.bias1_ptr + iter_key_start,
p.bias2_ptr + query_start * p.num_keys + iter_key_start,
thread_id(),
{problem_size_0_m, problem_size_0_n},
p.num_keys);
// Pij += Bij, Pij is in register fragment and Bij is in shared memory
auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
lane_id(), warp_id(), iteratorC_tile_offset);
MM0::AccumLambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) {},
[&](int accum_m, int accum_n, int idx) {
if (accum_m < problem_size_0_m && accum_n < problem_size_0_n) {
accum[idx] =
accum[idx] * p.scale + bias_tensor_ref.at({accum_m, accum_n});
}
},
[&](int accum_m) {});
}
DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] {
DISPATCH_BOOL(
p.num_keys - iter_key_start >= kKeysPerBlock, kFullColumns, ([&] {
// Update `mi` from accum stored in registers
// Also does accum[i] <- exp(accum[i] - mi)
iterative_softmax<typename MM0::Mma::Operator::IteratorC,
kFullColumns,
kIsFirst>(accum_o,
accum,
mi,
m_prime,
s_prime,
lane_id(),
thread_id(),
warp_id(),
p.num_keys - iter_key_start,
iteratorC_tile_offset,
kSupportsBias ? 1.0f : p.scale);
}));
}));
// Output results to shared-memory
int warp_idx_mn_0 =
my_warp_id % (MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN);
auto output_tile_coords =
cutlass::MatrixCoord{warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM,
warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM};
MM0::B2bGemm::accumToSmem(
shared_storage.after_mm0.si, accum, my_lane_id, output_tile_coords);
__syncthreads();
//
// MATMUL: Attn . V
// Run the matmul `attn @ V` for a block of attn and V.
// `attn` is read from shared memory (in `shared_storage.after_mm0.si`)
// `V` is read from global memory (with iterator_V)
//
const int64_t nBlockN =
kSingleValueIteration
? 1
: ceil_div((int64_t)problem_size_1_n, int64_t(MM1::ThreadblockShape::kN));
for (int blockN = 0; blockN < nBlockN; ++blockN) {
int gemm_k_iterations =
(problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK;
// Compute threadblock-scoped matrix multiply-add and store it in accum
// (in registers)
if (!kPreloadV) {
__syncthreads(); // we share shmem between mma and epilogue
}
typename MM1::Mma::IteratorB iterator_V(
typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)},
p.value_ptr + iter_key_start * p.v_strideM,
{problem_size_1_k, problem_size_1_n},
thread_id(),
cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
typename MM1::Mma mma_pv(shared_storage.after_mm0.mm1.mm,
shared_storage.after_mm0.si,
(int)thread_id(),
(int)warp_id(),
(int)lane_id(),
(int)problem_size_1_k);
mma_pv.set_prologue_done(kPreloadV);
if (!kKeepOutputInRF) { accum_o.clear(); }
mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o);
__syncthreads();
if (kPreloadV && !kSingleValueIteration && blockN + 1 < nBlockN) {
prologueV(blockN + 1);
}
if (!kKeepOutputInRF) {
MM1::Mma::drain_cp_asyncs();
DISPATCH_BOOL(
iter_key_start == 0, kIsFirst, ([&] {
DISPATCH_BOOL(
(iter_key_start + kKeysPerBlock) >= p.num_keys, kIsLast, ([&] {
using DefaultEpilogue = typename MM1::DefaultEpilogue;
using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp;
using ElementCompute = typename DefaultOp::ElementCompute;
using EpilogueOutputOp = typename cutlass::epilogue::thread::
MemoryEfficientAttentionNormalize<
typename cutlass::platform::
conditional<kIsLast, output_t, output_accum_t>::
type,
output_accum_t,
DefaultOp::kCount,
typename DefaultOp::ElementAccumulator,
ElementCompute,
kIsFirst,
kIsLast,
cutlass::Array<ElementCompute, kQueriesPerBlock>>;
using Epilogue =
typename cutlass::epilogue::threadblock::EpiloguePipelined<
typename DefaultEpilogue::Shape,
typename MM1::Mma::Operator,
DefaultEpilogue::kPartitionsK,
typename cutlass::platform::conditional<
kIsLast,
typename MM1::OutputTileIterator,
typename MM1::OutputTileIteratorAccum>::type,
typename DefaultEpilogue::AccumulatorFragmentIterator,
typename DefaultEpilogue::WarpTileIterator,
typename DefaultEpilogue::SharedLoadIterator,
EpilogueOutputOp,
typename DefaultEpilogue::Padding,
DefaultEpilogue::kFragmentsPerIteration,
true, // IterationsUnroll
typename MM1::OutputTileIteratorAccum // Read
// iterator
>;
int col = blockN * MM1::Mma::Shape::kN;
auto source_iter = createOutputAccumIter(col);
auto dest_iter =
call_conditional<kIsLast,
decltype(createOutputIter),
decltype(createOutputAccumIter)>::
apply(createOutputIter, createOutputAccumIter, col);
EpilogueOutputOp rescale(s_prime, m_prime);
Epilogue epilogue(shared_storage.epilogue_shared_storage(),
thread_id(),
warp_id(),
lane_id());
epilogue(rescale, dest_iter, accum_o, source_iter);
}));
}));
if (!kSingleValueIteration) { __syncthreads(); }
}
}
__syncthreads(); // we modify `m_prime` after
}
if (kKeepOutputInRF) {
constexpr bool kIsFirst = true;
constexpr bool kIsLast = true;
using DefaultEpilogue = typename MM1::DefaultEpilogue;
using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp;
using ElementCompute = typename DefaultOp::ElementCompute;
using EpilogueOutputOp =
typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize<
output_t, // output
output_accum_t, // source
DefaultOp::kCount,
typename DefaultOp::ElementAccumulator, // accum
output_accum_t, // compute
kIsFirst,
kIsLast,
cutlass::Array<ElementCompute, kQueriesPerBlock>>;
using Epilogue = typename cutlass::epilogue::threadblock::EpiloguePipelined<
typename DefaultEpilogue::Shape,
typename MM1::Mma::Operator,
DefaultEpilogue::kPartitionsK,
typename MM1::OutputTileIterator, // destination
typename DefaultEpilogue::AccumulatorFragmentIterator,
typename DefaultEpilogue::WarpTileIterator,
typename DefaultEpilogue::SharedLoadIterator,
EpilogueOutputOp,
typename DefaultEpilogue::Padding,
DefaultEpilogue::kFragmentsPerIteration,
true, // IterationsUnroll
typename MM1::OutputTileIteratorAccum // source tile
>;
auto dest_iter = createOutputIter(0);
EpilogueOutputOp rescale(s_prime, m_prime);
Epilogue epilogue(
shared_storage.epilogue_shared_storage(), thread_id(), warp_id(), lane_id());
MM1::Mma::drain_cp_asyncs();
epilogue(rescale, dest_iter, accum_o);
}
// 7. Calculate logsumexp
// To make the backward easier, we pad logsumexp with `inf`
// this avoids a few bound checks, and is not more expensive during fwd
static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) {
auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE;
if (thread_id() < p.num_queries) {
p.logsumexp_ptr[thread_id()] =
accum_t(mi[thread_id()]) + cutlass::fast_log(accum_t(s_prime[thread_id()]));
} else if (thread_id() < lse_dim) {
p.logsumexp_ptr[thread_id()] =
cutlass::platform::numeric_limits<accum_t>::infinity();
}
}
}
template <typename WarpIteratorC,
bool kFullColumns,
bool kIsFirst>
CUTLASS_DEVICE static void iterative_softmax(
typename WarpIteratorC::Fragment& frag_o, // output so far
typename WarpIteratorC::Fragment& frag,
cutlass::Array<accum_t, kQueriesPerBlock>& mi,
cutlass::Array<accum_t, kQueriesPerBlock>& m_prime,
cutlass::Array<accum_t, kQueriesPerBlock>& s_prime,
int8_t lane_id,
int8_t thread_id,
int8_t warp_id,
int16_t max_col,
typename WarpIteratorC::TensorCoord const& tile_offset,
float scaling)
{
/* Iterates on the accumulator and corresponding position on result matrix
(1) Update `mi[r]` to the max value of the row `r`
(2) In a second iteration do the following:
(a) accum <- exp(accum - mi)
(b) m_prime <- exp(m_prime - mi)
(c) s_prime <- s_prime * m_prime + sum(accum)
All of this is done on registers, before we store all of this
on shared memory for the next matmul with Value.
*/
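// Worked form of the update below (online-softmax / flash-attention style), for one
// row r and the current key block S[r, :]:
//   mi[r]      <- max(mi[r], scaling * max_j S[r, j])          (atomicMaxFloat)
//   m_prime[r] <- exp(old_mi[r] - mi[r])
//   s_prime[r] <- s_prime[r] * m_prime[r] + sum_j exp(scaling * S[r, j] - mi[r])
//   frag_o[r]  <- frag_o[r] * m_prime[r]                       (when kKeepOutputInRF)
// so after the last block, frag_o / s_prime equals softmax(scaling * S) @ V for row r.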
using Fragment = typename WarpIteratorC::Fragment;
using LambdaIterator =
typename DefaultMmaAccumLambdaIterator<WarpIteratorC, accum_t, kWarpSize>::Iterator;
// Convert to `accum_t` (rather than double)
constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
if (!kIsFirst) {
if (thread_id < kQueriesPerBlock) { m_prime[thread_id] = mi[thread_id]; }
__syncthreads();
}
auto lane_offset = LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset);
// First update `mi` to the max per-row
{
accum_t max;
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) { max = -cutlass::platform::numeric_limits<accum_t>::infinity(); },
[&](int accum_m, int accum_n, int idx) {
if (kFullColumns || accum_n < max_col) {
max = cutlass::fast_max(max, frag[idx]);
}
},
[&](int accum_m) {
// Having 4x atomicMax seems faster than reduce within warp
// first...
atomicMaxFloat(&mi[accum_m], max * scaling);
});
}
frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);
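// frag now holds log2(e) * scaling * S and mi_row below is log2(e) * mi, so the
// exp2f(frag[idx] - mi_row) calls compute exp(scaling * S - mi) via the fast exp2 path.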
// Make sure we all share the update values for `mi`
__syncthreads();
if (thread_id < kQueriesPerBlock) {
auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id]));
m_prime[thread_id] = m_prime_exp;
s_prime[thread_id] *= m_prime_exp;
}
__syncthreads(); // Update output fragments
if (kKeepOutputInRF && !kIsFirst) {
accum_t mp;
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) { mp = m_prime[accum_m]; },
[&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; },
[&](int accum_m) {});
__syncthreads();
}
// Update accum_m, accum_n, ...
{
accum_t mi_row, total_row;
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) { mi_row = kLog2e * mi[accum_m]; },
[&](int accum_m, int accum_n, int idx) {
frag[idx] = (kFullColumns || accum_n < max_col) ? exp2f(frag[idx] - mi_row)
: accum_t(0.0);
},
[&](int accum_m) {});
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) { total_row = 0.0; },
[&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; },
[&](int accum_m) {
if (LambdaIterator::reduceSameRow(
lane_id, total_row, [](accum_t a, accum_t b) { return a + b; })) {
atomicAdd(&s_prime[accum_m], total_row);
}
});
}
}
static CUTLASS_DEVICE int8_t lane_id() { return threadIdx.x; }
static CUTLASS_DEVICE int8_t warp_id() { return threadIdx.y; }
static CUTLASS_DEVICE int16_t thread_id() { return threadIdx.x + threadIdx.y * blockDim.x; }
};
template <typename AK>
__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
attention_kernel_batched_impl(typename AK::Params p)
{
if (!p.advance_to_block()) { return; }
AK::attention_kernel(p);
}
template <typename AK>
__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
attention_kernel_batched(typename AK::Params params);


@@ -0,0 +1,148 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#pragma once
// This does nothing.
template <typename ThreadMap, typename Shape, typename scalar_t>
struct BroadcastNoLoad {
using Fragment =
cutlass::Array<scalar_t, ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;
static const bool kEnable = false;
CUTLASS_DEVICE static void load(Fragment& frag,
scalar_t* ptr,
int thread_id,
const cutlass::MatrixCoord& extent,
int stride)
{
}
CUTLASS_DEVICE static scalar_t*
advance(scalar_t* ptr, int B_id, int N_id, int H_id, int strideB, int strideN, int strideH)
{
return ptr;
}
};
// This is to load the bias matrix from the global memory with on-the-fly
// broadcast. The shape in global memory is [B, N, 1, 1, L]. Each time we load
// the last dimension as an L-element row vector, and we further broadcast it
// to a tile of size [L, L] by repeating the vector L times.
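// Concretely: the AffineRank2RowMajor strides (0, 1) used in load() give a zero row
// stride, so every row of the [L, L] tile reads the same L-vector: tile(i, j) = ptr[j].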
template <typename ThreadMap, typename Shape, typename scalar_t>
struct BroadcastA : public BroadcastNoLoad<ThreadMap, Shape, scalar_t> {
using Base = BroadcastNoLoad<ThreadMap, Shape, scalar_t>;
static const bool kEnable = true;
using layout = cutlass::layout::AffineRank2RowMajor;
using GmemTileIterator = cutlass::transform::threadblock::
PredicatedTileIterator<Shape, scalar_t, layout, 0, ThreadMap>;
using Fragment = typename GmemTileIterator::Fragment;
CUTLASS_DEVICE static void load(Fragment& frag,
scalar_t* ptr,
int thread_id,
const cutlass::MatrixCoord& extent,
int stride)
{
GmemTileIterator iter({layout(0, 1)}, ptr, extent, thread_id);
iter.load(frag);
}
CUTLASS_DEVICE static scalar_t*
advance(scalar_t* ptr, int B_id, int N_id, int H_id, int strideB, int strideN, int strideH)
{
return ptr + B_id * strideB + N_id * strideN;
}
};
// This is to load the bias matrix from the global memory with on-the-fly
// broadcast. The shape in global memory is [B, 1, H, L, L]. Each time we load
// an [L, L] matrix. Different N share the same bias matrix when B and H are the
// same.
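// Concretely: with the plain RowMajor layout(stride) used in load(), tile(i, j) =
// ptr[i * stride + j], and advance() ignores N_id, so all N share one bias matrix for
// a given (B, H).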
template <typename ThreadMap, typename Shape, typename scalar_t>
struct BroadcastB : public BroadcastNoLoad<ThreadMap, Shape, scalar_t> {
using Base = BroadcastNoLoad<ThreadMap, Shape, scalar_t>;
static const bool kEnable = true;
using layout = cutlass::layout::RowMajor;
using GmemTileIterator = cutlass::transform::threadblock::
PredicatedTileIterator<Shape, scalar_t, layout, 0, ThreadMap>;
using Fragment = typename GmemTileIterator::Fragment;
CUTLASS_DEVICE static void load(Fragment& frag,
scalar_t* ptr,
int thread_id,
const cutlass::MatrixCoord& extent,
int stride)
{
GmemTileIterator iter({layout(stride)}, ptr, extent, thread_id);
iter.load(frag);
}
CUTLASS_DEVICE static scalar_t*
advance(scalar_t* ptr, int B_id, int N_id, int H_id, int strideB, int strideN, int strideH)
{
return ptr + B_id * strideB + H_id * strideH;
}
};
template <typename Shape,
typename scalar_t,
int kThreads,
template <typename, typename, typename>
class Broadcast1_,
template <typename, typename, typename>
class Broadcast2_>
struct AttentionBiasEpilogue {
using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap<
cutlass::layout::PitchLinearShape<Shape::kColumn, Shape::kRow>,
kThreads,
128 / cutlass::sizeof_bits<scalar_t>::value>;
using Broadcast1 = Broadcast1_<ThreadMap, Shape, scalar_t>;
using Broadcast2 = Broadcast2_<ThreadMap, Shape, scalar_t>;
Broadcast1 broadcast1;
Broadcast2 broadcast2;
using Ref = cutlass::TensorRef<float, cutlass::layout::RowMajor>;
using SmemTileIterator = cutlass::transform::threadblock::
RegularTileIterator<Shape, float, cutlass::layout::RowMajor, 0, ThreadMap>;
CUTLASS_DEVICE void operator()(const Ref& ref,
scalar_t* ptr1,
scalar_t* ptr2,
int thread_id,
const cutlass::MatrixCoord& extent,
int stride)
{
static_assert(Broadcast1::Fragment::kElements == Broadcast2::Fragment::kElements,
"The two broadcast fragments must have the same number of "
"elements");
typename SmemTileIterator::Fragment frag;
frag.clear();
float* frag_ptr = reinterpret_cast<float*>(&frag);
if (Broadcast1::kEnable) {
typename Broadcast1::Fragment frag1;
frag1.clear();
broadcast1.load(frag1, ptr1, thread_id, extent, stride);
scalar_t* frag1_ptr = reinterpret_cast<scalar_t*>(&frag1);
for (int i = 0; i < Broadcast1::Fragment::kElements; ++i) {
frag_ptr[i] += static_cast<float>(frag1_ptr[i]);
}
}
if (Broadcast2::kEnable) {
typename Broadcast2::Fragment frag2;
frag2.clear();
broadcast2.load(frag2, ptr2, thread_id, extent, stride);
scalar_t* frag2_ptr = reinterpret_cast<scalar_t*>(&frag2);
for (int i = 0; i < Broadcast2::Fragment::kElements; ++i) {
frag_ptr[i] += static_cast<float>(frag2_ptr[i]);
}
}
SmemTileIterator iter(ref, thread_id);
iter.store(frag);
__syncthreads();
}
};

View File

@ -0,0 +1,93 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#pragma once
#include <cutlass/cutlass.h>
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/coord.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/numeric_types.h"
#include "cutlass/platform/platform.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
template <typename scalar_t, // scalar type
typename ThreadblockTileShape, // size of tile to load
int Threads, // number of participating threads
int ElementsPerAccess> // thread access width in elements
class TileSmemLoader {
public:
using Shape = ThreadblockTileShape;
using SmemTile = cutlass::AlignedBuffer<scalar_t, ThreadblockTileShape::kCount>;
using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap<
cutlass::layout::PitchLinearShape<ThreadblockTileShape::kColumn, // contiguous
ThreadblockTileShape::kRow>, // strided
Threads, // Threads
ElementsPerAccess>; // ElementsPerAccess
using GmemTileIterator = cutlass::transform::threadblock::PredicatedTileIterator<
ThreadblockTileShape, // Shape
scalar_t, // Element
cutlass::layout::RowMajor, // Layout
0, // AdvanceRank
ThreadMap>; // ThreadMap
using SmemTileIterator =
cutlass::transform::threadblock::RegularTileIterator<ThreadblockTileShape, // Shape
scalar_t, // Element
cutlass::layout::RowMajor, // Layout
0, // AdvanceRank
ThreadMap>; // ThreadMap
using Fragment = typename GmemTileIterator::Fragment;
/// load a tile from global memory into shared memory
CUTLASS_DEVICE
static void load(GmemTileIterator tile_load_iter, SmemTileIterator tile_store_iter)
{
Fragment tb_frag;
tb_frag.clear();
tile_load_iter.load(tb_frag);
tile_store_iter.store(tb_frag);
__syncthreads();
}
};

View File

@ -0,0 +1,6 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .evoformer_attn import DS4Sci_EvoformerAttention, EvoformerFusedAttention

View File

@ -0,0 +1,102 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
import numpy as np
from deepspeed.ops.op_builder import EvoformerAttnBuilder
from deepspeed.accelerator import get_accelerator
kernel_ = None
def _attention(Q, K, V, bias1, bias2):
assert Q.shape[-3] > 16, "seq_len must be greater than 16"
O = torch.empty_like(Q, dtype=Q.dtype)
assert get_accelerator().on_accelerator(Q), "Q must be on cuda"
assert get_accelerator().on_accelerator(K), "K must be on cuda"
assert get_accelerator().on_accelerator(V), "V must be on cuda"
assert get_accelerator().on_accelerator(bias1), "bias1 must be on cuda"
assert get_accelerator().on_accelerator(bias2), "bias2 must be on cuda"
global kernel_
if kernel_ is None:
kernel_ = EvoformerAttnBuilder().load()
nheads = Q.shape[-2]
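# round the query length up to a multiple of 32; lse stores the per-head softmax log-sum-exp that the backward kernel consumes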
nq = (Q.shape[-3] + 31) // 32 * 32
nb = np.prod(Q.shape[:-3])
lse = torch.empty((nb, nheads, nq), dtype=torch.float32, device=Q.device)
kernel_.attention(Q, K, V, bias1, bias2, O, lse)
return O, lse
def attention_bwd(dO, Q, K, V, O, lse, bias1, bias2):
assert max(Q.shape[-1], V.shape[-1]) <= 64, "Hidden size is too large. Need to change kMax to a larger value"
dQ = torch.empty_like(Q, dtype=Q.dtype)
dK = torch.empty_like(K, dtype=K.dtype)
dV = torch.empty_like(V, dtype=V.dtype)
assert get_accelerator().on_accelerator(dO), "dO must be on cuda"
assert get_accelerator().on_accelerator(Q), "Q must be on cuda"
assert get_accelerator().on_accelerator(K), "K must be on cuda"
assert get_accelerator().on_accelerator(V), "V must be on cuda"
assert get_accelerator().on_accelerator(O), "O must be on cuda"
global kernel_
if kernel_ is None:
kernel_ = EvoformerAttnBuilder().load()
delta = torch.empty_like(lse)
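# bias gradients are accumulated in fp32 by the kernel and cast back to the activation dtype on return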
dB1 = torch.zeros_like(bias1, dtype=torch.float32)
dB2 = torch.zeros_like(bias2, dtype=torch.float32)
kernel_.attention_bwd(dO, Q, K, V, O, lse, delta, bias1, bias2, dQ, dK, dV, dB1, dB2)
return dQ, dK, dV, dB1.to(dO.dtype), dB2.to(dO.dtype)
class EvoformerFusedAttention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, bias1=None, bias2=None):
"""
q, k, v: are in shape [*, L, H, D]
"""
bias1_ = bias1.contiguous() if bias1 is not None else torch.tensor([], dtype=q.dtype, device=q.device)
bias2_ = bias2.contiguous() if bias2 is not None else torch.tensor([], dtype=q.dtype, device=q.device)
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
o, lse = _attention(q, k, v, bias1_, bias2_)
ctx.save_for_backward(q, k, v, o, lse, bias1_, bias2_)
return o
@staticmethod
def backward(ctx, grad_output):
q, k, v, o, lse, bias1, bias2 = ctx.saved_tensors
dQ, dK, dV, dB1, dB2 = attention_bwd(grad_output, q, k, v, o, lse, bias1, bias2)
if bias1.numel() == 0:
dB1 = None
if bias2.numel() == 0:
dB2 = None
return dQ, dK, dV, dB1, dB2
def DS4Sci_EvoformerAttention(Q, K, V, biases):
assert len(biases) <= 2
if (len(biases) == 0):
biases.append(None)
if (len(biases) == 1):
biases.append(None)
bias_1_shape = lambda x: (x.shape[0], x.shape[1], 1, 1, x.shape[2])
bias_2_shape = lambda x: (x.shape[0], 1, x.shape[3], x.shape[2], x.shape[2])
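# bias1 is broadcast over query rows ([B, N, 1, 1, L]); bias2 is shared across the N dimension ([B, 1, H, L, L])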
if biases[0] is not None:
assert biases[0].shape == bias_1_shape(Q)
else:
biases[0] = Q.new_zeros(bias_1_shape(Q))
if biases[1] is not None:
assert biases[1].shape == bias_2_shape(Q)
else:
biases[1] = Q.new_zeros(bias_2_shape(Q))
return EvoformerFusedAttention.apply(Q, K, V, biases[0], biases[1])

View File

@ -41,6 +41,7 @@ collections:
- cifar-10.md
- curriculum-learning.md
- data-efficiency.md
- ds4sci_evoformerattention.md
- flops-profiler.md
- pytorch-profiler.md
- autotuning.md

View File

@ -17,6 +17,8 @@ lnav:
url: /inference/
- title: 'Compression'
url: /compression/
- title: 'Science'
url: /deepspeed4science/
- title: 'Getting Started'
url: /getting-started/
- title: 'ds_config'
@ -67,6 +69,8 @@ lnav:
url: /tutorials/curriculum-learning/
- title: 'Data Efficiency'
url: /tutorials/data-efficiency/
- title: 'DS4Sci_EvoformerAttention'
url: /tutorials/ds4sci_evoformerattention/
- title: 'Flops Profiler'
url: /tutorials/flops-profiler/
- title: 'PyTorch Profiler'

View File

@ -0,0 +1,39 @@
---
title: "DeepSpeed4Science Overview and Tutorial"
permalink: /deepspeed4science/
toc: true
toc_label: "Contents"
toc_sticky: true
---
In line with Microsoft's mission to solve humanity's most pressing challenges, the DeepSpeed team at Microsoft is responding to this opportunity by launching a new initiative called *DeepSpeed4Science*, aiming to build unique capabilities through AI system technology innovations to help domain experts unlock today's biggest science mysteries. This page serves as an overview of all technologies released (or to be released) as part of the DeepSpeed4Science initiative, making it easier for scientists to find the techniques they need. Details of the DeepSpeed4Science initiative can be found at [our website](https://deepspeed4science.ai/). For each technique we introduce what it is for, when to use it, links to how to use it, and existing scientific applications of the technique (we welcome users to contribute more showcases if you apply our techniques in your scientific research):
* [2023/09] We are releasing two techniques: the [DeepSpeed4Science large-scale training framework](#new-megatron-deepspeed-for-large-scale-ai4science-model-training) and [DS4Sci_EvoformerAttention](#memory-efficient-evoformerattention-kernels), together with their scientific applications in structural biology research.
## New Megatron-DeepSpeed for Large-Scale AI4Science Model Training
We are proud to introduce [new Megatron-DeepSpeed](https://github.com/microsoft/Megatron-DeepSpeed), an updated framework for large-scale model training. We rebased and enabled DeepSpeed with the newest Megatron-LM for long-sequence support and many other capabilities. With the new Megatron-DeepSpeed, users can now train large AI4Science models like GenSLMs with much longer sequences via a synergistic combination of ZeRO-style data parallelism, tensor parallelism, sequence parallelism, pipeline parallelism, model state offloading, and several newly added memory optimization techniques such as attention mask offloading and position embedding partitioning.
![new Megatron-DeepSpeed](/assets/images/new-megatron-ds.png){: .align-center}
<p align="center">
<em>The figure depicts system capability in terms of enabling long sequence lengths for training a 33B-parameter GPT-like model using our new Megatron-DeepSpeed framework. The results show that the new Megatron-DeepSpeed enables 9x longer sequence lengths than NVIDIA's Megatron-LM without triggering out-of-memory errors. </em>
</p>
To see how the new Megatron-DeepSpeed helps enable new system capabilities, such as training models with massive sequence lengths, please read our [tutorial](https://github.com/microsoft/Megatron-DeepSpeed/tree/main/examples_deepspeed/deepspeed4science/megatron_long_seq_support).
Meanwhile, our new Megatron-DeepSpeed has been applied to the genome-scale foundation model [GenSLMs](https://github.com/ramanathanlab/genslm), a 2022 [ACM Gordon Bell award](https://www.acm.org/media-center/2022/november/gordon-bell-special-prize-covid-research-2022) winning genome-scale language model from Argonne National Lab. To achieve their scientific goals, GenSLMs and similar models require very long sequence support for both training and inference, beyond what generic LLM long-sequence strategies provide. By leveraging DeepSpeed4Science's new Megatron-DeepSpeed, the GenSLMs team is able to train their 25B model with a 512K sequence length, much longer than their original 42K sequence length. Detailed information about this application can be found at [our website](https://deepspeed4science.ai/). The GenSLMs team also hosts an [example](https://github.com/ramanathanlab/genslm/tree/main/examples/long-sequences) of how to use DeepSpeed4Science in the GenSLMs repo.
## Memory-Efficient EvoformerAttention Kernels
[Evoformer](https://www.nature.com/articles/s41586-021-03819-2) is a key building block for scientific models such as DeepMind's AlphaFold. However, Evoformer's multiple sequence alignment (MSA) attention frequently runs into memory explosion problems during training and inference, for example in protein structure prediction models. Existing techniques such as FlashAttention cannot effectively support Evoformer because Evoformer attention uses row-wise, column-wise, and triangle attention, which differ from standard Transformer self-attention and cross-attention and therefore require custom optimizations. To mitigate the memory explosion problem, we introduce `DS4Sci_EvoformerAttention`, a collection of kernels that improve the memory efficiency of Evoformer variants. `DS4Sci_EvoformerAttention` is easy to use. To see how you can use it, please refer to our [tutorial](/tutorials/ds4sci_evoformerattention/).
`DS4Sci_EvoformerAttention` has already been applied to [OpenFold](https://github.com/aqlaboratory/openfold), a community reproduction of DeepMind's AlphaFold2 that makes it possible to train or finetune AlphaFold2 on new datasets. With the DS4Sci_EvoformerAttention kernels, the OpenFold team is able to reduce the peak memory requirement by 13x without accuracy loss. Detailed information about this application can be found at [our website](https://deepspeed4science.ai/).
<!-- OpenFold team also hosts an [example](https://github.com/aqlaboratory/openfold/blob/main/tests/test_deepspeed_evo_attention.py) about how to use DS4Sci_EvoformerAttention in the OpenFold repo. -->
![DS4Sci_EvoformerAttention](/assets/images/evoformer.png){: .align-center}
<p align="center">
<em>The figure shows that DeepSpeed's EvoformerAttention kernels help reduce OpenFold's peak memory requirement for training by 13x. </em>
</p>

View File

@ -0,0 +1,74 @@
---
title: "DS4Sci_EvoformerAttention eliminates memory explosion problems for scaling Evoformer-centric structural biology models"
tags: training inference
---
## 1. What is DS4Sci_EvoformerAttention
`DS4Sci_EvoformerAttention` is a collection of kernels built to scale the [Evoformer](https://www.nature.com/articles/s41586-021-03819-2) computation to a larger number of sequences and residues by reducing the memory footprint and increasing the training speed.
## 2. When to use DS4Sci_EvoformerAttention
`DS4Sci_EvoformerAttention` is most beneficial when the number of sequences and residues is large. The forward kernel is optimized to accelerate computation, so it is beneficial to use during inference for the various attention mechanisms. The associated backward kernel can be used during training to reduce the memory footprint at the cost of some extra computation; it is therefore most useful in training for memory-constrained operations such as MSA row-wise attention and MSA column-wise attention.
## 3. How to use DS4Sci_EvoformerAttention
### 3.1 Installation
`DS4Sci_EvoformerAttention` is released as part of DeepSpeed >= 0.10.3. `DS4Sci_EvoformerAttention` is implemented based on [CUTLASS](https://github.com/NVIDIA/cutlass). You need to clone the CUTLASS repository and specify the path to it in the environment variable `CUTLASS_PATH`.
```shell
git clone https://github.com/NVIDIA/cutlass
export CUTLASS_PATH=/path/to/cutlass
```
The kernels will be compiled when `DS4Sci_EvoformerAttention` is called for the first time.
`DS4Sci_EvoformerAttention` requires GPUs with compute capability 7.0 or higher (NVIDIA V100 or later) and a minimum CUDA version of 11.3. CUDA 11.7 or later is recommended for better performance. Also note that the performance of the backward kernel on V100 is currently not as good as on A100.
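If you are unsure whether your environment meets these requirements, a quick check along the following lines can help before the first (JIT-compiled) call. This is only a minimal sketch built on standard PyTorch APIs; the thresholds simply mirror the requirements above.
```python
import os
import torch

# minimal environment check before relying on DS4Sci_EvoformerAttention
assert os.environ.get("CUTLASS_PATH"), "CUTLASS_PATH must point to a CUTLASS clone"

major, minor = torch.cuda.get_device_capability(0)
assert (major, minor) >= (7, 0), f"compute capability {major}.{minor} is below 7.0"

assert torch.version.cuda is not None, "a CUDA build of PyTorch is required"
cuda_major, cuda_minor = (int(v) for v in torch.version.cuda.split(".")[:2])
assert (cuda_major, cuda_minor) >= (11, 3), "CUDA 11.3 or newer is required"
```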
### 3.2 Unit test and benchmark
The unit test and benchmark are available in the `tests` folder of the DeepSpeed repo. You can use the following commands to run them.
```shell
pytest -s tests/unit/ops/deepspeed4science/test_DS4Sci_EvoformerAttention.py
python tests/benchmarks/DS4Sci_EvoformerAttention_bench.py
```
### 3.3 Applying DS4Sci_EvoformerAttention to your own model
To use `DS4Sci_EvoformerAttention` in your own model, you need to import it from `deepspeed.ops.deepspeed4science`.
```python
from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention
```
`DS4Sci_EvoformerAttention` supports four attention mechanisms in Evoformer (MSA row-wise, MSA column-wise, and two kinds of triangular attention) by using different inputs, as shown in the following examples. In the examples, we denote the number of sequences as `N_seq` and the number of residues as `N_res`. The hidden dimension `Dim` and the number of heads `Head` differ among the attention variants. Note that `DS4Sci_EvoformerAttention` requires the input tensors to be in `torch.float16` or `torch.bfloat16`.
(a) **MSA row-wise attention** builds attention weights for residue pairs and integrates the information from the pair representation as an additional bias term.
```python
# Q, K, V: [Batch, N_seq, N_res, Head, Dim]
# res_mask: [Batch, N_seq, 1, 1, N_res]
# pair_bias: [Batch, 1, Head, N_res, N_res]
out = DS4Sci_EvoformerAttention(Q, K, V, [res_mask, pair_bias])
```
(b) **MSA column-wise attention** lets the elements that belong to the same target residue exchange information.
```python
# Q, K, V: [Batch, N_res, N_seq, Head, Dim]
# res_mask: [Batch, N_seq, 1, 1, N_res]
out = DS4Sci_EvoformerAttention(Q, K, V, [res_mask])
```
(c) **Triangular self-attention** updates the pair representation. There are two kinds of triangular self-attention: around the starting node and around the ending node. Below is an example of triangular self-attention around the starting node; the variant around the ending node is similar.
```python
# Q, K, V: [Batch, N_res, N_res, Head, Dim]
# res_mask: [Batch, N_res, 1, 1, N_res]
# right_edges: [Batch, 1, Head, N_res, N_res]
out = DS4Sci_EvoformerAttention(Q, K, V, [res_mask, right_edges])
```
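Putting these pieces together, a minimal end-to-end sketch for MSA row-wise attention could look like the following. The tensor sizes are illustrative placeholders only (the current kernels assert that `N_res` is larger than 16), and the zero-valued mask means nothing is masked.
```python
import torch
from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention

# illustrative sizes only; real MSA stacks are much larger
batch, n_seq, n_res, heads, dim = 1, 64, 128, 4, 32
device, dtype = "cuda", torch.bfloat16

# Q, K, V: [Batch, N_seq, N_res, Head, Dim]
Q = torch.randn(batch, n_seq, n_res, heads, dim, dtype=dtype, device=device, requires_grad=True)
K = torch.randn(batch, n_seq, n_res, heads, dim, dtype=dtype, device=device, requires_grad=True)
V = torch.randn(batch, n_seq, n_res, heads, dim, dtype=dtype, device=device, requires_grad=True)

# res_mask: [Batch, N_seq, 1, 1, N_res]; pair_bias: [Batch, 1, Head, N_res, N_res]
res_mask = torch.zeros(batch, n_seq, 1, 1, n_res, dtype=dtype, device=device)
pair_bias = torch.randn(batch, 1, heads, n_res, n_res, dtype=dtype, device=device, requires_grad=True)

out = DS4Sci_EvoformerAttention(Q, K, V, [res_mask, pair_bias])  # [Batch, N_seq, N_res, Head, Dim]
out.mean().backward()  # the backward kernel trades extra computation for lower peak memory
```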
## 4. DS4Sci_EvoformerAttention scientific application
### 4.1 DS4Sci_EvoformerAttention eliminates memory explosion problems for scaling Evoformer-centric structural biology models in OpenFold
[OpenFold](https://github.com/aqlaboratory/openfold) is a community reproduction of DeepMind's AlphaFold2 that makes it possible to train or finetune AlphaFold2 on new datasets. Training AlphaFold2 incurs a memory explosion problem because it contains several custom Evoformer attention variants that manifest unusually large activations. By leveraging DeepSpeed4Science's DS4Sci_EvoformerAttention kernels, the OpenFold team is able to reduce the peak memory requirement by 13x without accuracy loss. Detailed information about this application can be found at [our website](https://deepspeed4science.ai/).
<!-- OpenFold team also hosts an [example](https://github.com/aqlaboratory/openfold/blob/main/tests/test_deepspeed_evo_attention.py) about how to use DS4Sci_EvoformerAttention in the OpenFold repo. -->

Binary image files changed (not shown), including docs/assets/images/evoformer.png.

View File

@ -7,11 +7,11 @@ title: "Latest News"
---
<b> <span style="color:orange" > DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat)</span>.</b>
* [2023/09] Announcing the DeepSpeed4Science Initiative: Enabling large-scale scientific discovery through sophisticated AI system technologies [[DeepSpeed4Science website](https://deepspeed4science.ai/)] [[Tutorials](/deepspeed4science/)]
* [2023/08] [DeepSpeed ZeRO-Inference: 20X faster inference through weight quantization and KV cache offloading](https://github.com/microsoft/DeepSpeedExamples/blob/master/inference/huggingface/zero_inference/README.md)
* [2023/08] [DeepSpeed-Chat: Llama/Llama-2 system support, efficiency boost, and training stability improvements](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/ds-chat-release-8-31/README.md)
* [2023/08] [DeepSpeed Ulysses: System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ulysses)
* [2023/06] [ZeRO++: A leap in speed for LLM and chat model training with 4X less communication](https://www.microsoft.com/en-us/research/blog/deepspeed-zero-a-leap-in-speed-for-llm-and-chat-model-training-with-4x-less-communication/)[[English](https://www.microsoft.com/en-us/research/blog/deepspeed-zero-a-leap-in-speed-for-llm-and-chat-model-training-with-4x-less-communication/)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/zeropp/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/zeropp/japanese/README.md)]
* [2023/04] 🚀 [DeepSpeed Chat: Easy, Fast and Affordable RLHF Training of ChatGPT-like Models at All Scales](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/japanese/README.md)]🚀
# Extreme Speed and Scale for DL Training and Inference
@ -24,9 +24,9 @@ title: "Latest News"
* Achieve extreme compression for an unparalleled inference latency and model size reduction with low costs
# DeepSpeed has three innovation pillars:
# DeepSpeed has four innovation pillars:
![Three innovation pillars](/assets/images/3pillars.png){: .align-center}
[![Four innovation pillars](/assets/images/DeepSpeed-pillars.png){: .align-center}](https://deepspeed4science.ai/)
## DeepSpeed-Training
@ -41,6 +41,10 @@ DeepSpeed brings together innovations in parallelism technology such as tensor,
To further increase the inference efficiency, DeepSpeed offers easy-to-use and flexible-to-compose compression techniques for researchers and practitioners to compress their models while delivering faster speed, smaller model size, and significantly reduced compression cost. Moreover, SoTA innovations on compression like ZeroQuant and XTC are included under the DeepSpeed-Compression pillar. Learn more: [DeepSpeed-Compression](https://www.deepspeed.ai/compression)
## DeepSpeed4Science
In line with Microsoft's mission to solve humanity's most pressing challenges, the DeepSpeed team at Microsoft is responding to this opportunity by launching a new initiative called *DeepSpeed4Science*, aiming to build unique capabilities through AI system technology innovations to help domain experts unlock today's biggest science mysteries. Learn more: [DeepSpeed4Science website](https://deepspeed4science.ai/) and [tutorials](/deepspeed4science/)
# DeepSpeed Software Suite
## DeepSpeed Library

View File

@ -0,0 +1,60 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .builder import CUDAOpBuilder, installed_cuda_version
import os
class EvoformerAttnBuilder(CUDAOpBuilder):
BUILD_VAR = "DS_BUILD_EVOFORMER_ATTN"
NAME = "evoformer_attn"
def __init__(self, name=None):
name = self.NAME if name is None else name
super().__init__(name=name)
self.cutlass_path = os.environ.get('CUTLASS_PATH')
def absolute_name(self):
return f'deepspeed.ops.{self.NAME}_op'
def extra_ldflags(self):
if not self.is_rocm_pytorch():
return ['-lcurand']
else:
return []
def sources(self):
src_dir = 'csrc/deepspeed4science/evoformer_attn'
return [f'{src_dir}/attention.cpp', f'{src_dir}/attention_back.cu', f'{src_dir}/attention.cu']
def is_compatible(self, verbose=True):
try:
import torch
except ImportError:
self.warning("Please install torch if trying to pre-compile kernels")
return False
if self.cutlass_path is None:
self.warning("Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH")
return False
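# heuristic version check: require the CUTLASS CHANGELOG to mention the 3.1.0 release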
with open(f'{self.cutlass_path}/CHANGELOG.md', 'r') as f:
if '3.1.0' not in f.read():
self.warning("Please use CUTLASS version >= 3.1.0")
return False
cuda_okay = True
if not self.is_rocm_pytorch() and torch.cuda.is_available(): #ignore-cuda
sys_cuda_major, _ = installed_cuda_version()
torch_cuda_major = int(torch.version.cuda.split('.')[0])
cuda_capability = torch.cuda.get_device_properties(0).major #ignore-cuda
if cuda_capability < 7:
self.warning("Please use a GPU with compute capability >= 7.0")
cuda_okay = False
if torch_cuda_major < 11 or sys_cuda_major < 11:
self.warning("Please use CUDA 11+")
cuda_okay = False
return super().is_compatible(verbose) and cuda_okay
def include_paths(self):
includes = [f'{self.cutlass_path}/include', f'{self.cutlass_path}/tools/util/include']
return includes

View File

@ -0,0 +1,108 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
This script benchmarks the performance of the DS4Sci_EvoformerAttention op against a PyTorch reference implementation.
To run the script,
1. Clone the CUTLASS repo. E.g. git clone https://github.com/NVIDIA/cutlass.git
2. Specify the CUTLASS_PATH environment variable. E.g. export CUTLASS_PATH=$(pwd)/cutlass
3. Run the script. E.g. python DS4Sci_EvoformerAttention_bench.py
"""
import contextlib
import torch
from typing import List
from torch.nn import functional as F
from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention
from deepspeed.accelerator import get_accelerator
def attention_reference(
q_input: torch.Tensor, # [*, Dim_Q, H, C_hid]
k_input: torch.Tensor, # [*, Dim_Q, H, C_hid]
v_input: torch.Tensor, # [*, Dim_Q, H, C_hid]
biases: List[torch.Tensor],
sm_scale: float) -> torch.Tensor:
# Original shape: [*, Dim_Q, H, C_hid] -> Transpose to: [*, H, Dim_Q, C_hid]
q = q_input.transpose(-2, -3)
k = k_input.transpose(-2, -3)
v = v_input.transpose(-2, -3)
# Now, q, k, v are in shape: [*, H, Dim_Q, C_hid]
# Transpose k to shape [*, H, C_hid, Dim_Q]
k_t = k.transpose(-1, -2)
# Now, q and k_t are in shapes: [*, H, Dim_Q, C_hid] and [*, H, C_hid, Dim_Q] respectively
# [*, H, Dim_Q, Dim_Q]
a = torch.matmul(q, k_t) * sm_scale
for b in biases:
a += b
a = F.softmax(a, dim=-1)
# Now, a is in shape [*, H, Dim_Q, Dim_Q], v is in shape [*, H, Dim_Q, C_hid]
# Matmul operation results in [*, H, Dim_Q, C_hid]
a_v = torch.matmul(a, v)
# [*, Dim_Q, H, C_hid]
o = a_v.transpose(-2, -3)
return o
dtype = torch.float16
batch = 1
N = 256
heads = 4
dim = 32
seq_len = 256
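# times the enclosed GPU work using accelerator events and appends the elapsed milliseconds to res_list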
@contextlib.contextmanager
def cuda_timer(res_list):
start = get_accelerator().Event(enable_timing=True)
end = get_accelerator().Event(enable_timing=True)
start.record()
yield
end.record()
get_accelerator().synchronize()
res_list.append(start.elapsed_time(end))
def benchmark():
ours_fw = []
ours_bw = []
baseline_fw = []
baseline_bw = []
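# sweep batch sizes 1..16, timing the fused kernel and the PyTorch reference (forward and backward)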
for batch in range(1, 17):
Q = torch.randn(batch, N, seq_len, heads, dim, dtype=dtype, device="cuda", requires_grad=True)
K = torch.randn(batch, N, seq_len, heads, dim, dtype=dtype, device="cuda", requires_grad=True)
V = torch.randn(batch, N, seq_len, heads, dim, dtype=dtype, device="cuda", requires_grad=True)
bias1 = torch.randn(batch, N, 1, 1, seq_len, dtype=dtype, device="cuda", requires_grad=True)
bias2 = torch.randn(batch, 1, heads, seq_len, seq_len, dtype=dtype, device="cuda", requires_grad=True)
# warm up
DS4Sci_EvoformerAttention(Q, K, V, [bias1, bias2])
with cuda_timer(ours_fw):
out = DS4Sci_EvoformerAttention(Q, K, V, [bias1, bias2])
d_out = torch.rand_like(out)
with cuda_timer(ours_bw):
out.backward(d_out)
# warm up
attention_reference(Q, K, V, [bias1, bias2], 1 / (dim**0.5))
with cuda_timer(baseline_fw):
ref_out = attention_reference(Q, K, V, [bias1, bias2], 1 / (dim**0.5))
with cuda_timer(baseline_bw):
ref_out.backward(d_out)
print(f"batch size\tours (FW)\tbaseline (FW)\tours (BW)\tbaseline (BW)")
for i in range(len(ours_fw)):
print(f"{i+1}\t{ours_fw[i]}\t{baseline_fw[i]}\t{ours_bw[i]}\t{baseline_bw[i]}")
benchmark()

View File

@ -0,0 +1,110 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from typing import List
import pytest
import torch
from torch.nn import functional as F
import deepspeed
from deepspeed.ops.op_builder import EvoformerAttnBuilder
from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention
from deepspeed.accelerator import get_accelerator
from unit.util import skip_on_arch
if not deepspeed.ops.__compatible_ops__[EvoformerAttnBuilder.NAME]:
pytest.skip("DS4Sci_EvoformerAttention ops are not available on this system", allow_module_level=True)
def attention_reference(
q_input: torch.Tensor, # [*, Dim_Q, H, C_hid]
k_input: torch.Tensor, # [*, Dim_Q, H, C_hid]
v_input: torch.Tensor, # [*, Dim_Q, H, C_hid]
biases: List[torch.Tensor],
sm_scale: float) -> torch.Tensor:
q = q_input.transpose(-2, -3)
k = k_input.transpose(-2, -3)
v = v_input.transpose(-2, -3)
k_t = k.transpose(-1, -2)
a = torch.matmul(q, k_t) * sm_scale
for b in biases:
a += b
a = F.softmax(a, dim=-1)
a_v = torch.matmul(a, v)
o = a_v.transpose(-2, -3)
return o
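# compare the fused kernel against the PyTorch reference; bf16 needs compute capability 8.0 and fp16 needs 7.0 (enforced via skip_on_arch)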
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("tensor_shape", [(1, 256, 256, 4, 32), (1, 512, 256, 8, 8)])
def test_DS4Sci_EvoformerAttention(dtype, tensor_shape):
skip_on_arch(8 if dtype == torch.bfloat16 else 7)
batch, n, seq_len, heads, dim = tensor_shape
Q = torch.randn(batch,
n,
seq_len,
heads,
dim,
dtype=dtype,
device=get_accelerator().device_name(),
requires_grad=True)
K = torch.randn(batch,
n,
seq_len,
heads,
dim,
dtype=dtype,
device=get_accelerator().device_name(),
requires_grad=True)
V = torch.randn(batch,
n,
seq_len,
heads,
dim,
dtype=dtype,
device=get_accelerator().device_name(),
requires_grad=True)
bias1 = torch.randn(batch,
n,
1,
1,
seq_len,
dtype=dtype,
device=get_accelerator().device_name(),
requires_grad=True)
bias2 = torch.randn(batch,
1,
heads,
seq_len,
seq_len,
dtype=dtype,
device=get_accelerator().device_name(),
requires_grad=True)
dummy_out = torch.rand_like(Q, dtype=dtype, device=get_accelerator().device_name())
ref_out = attention_reference(Q, K, V, [bias1, bias2], 1 / (dim**0.5))
ref_out.backward(dummy_out)
ref_dv, V.grad = V.grad.clone(), None
ref_dk, K.grad = K.grad.clone(), None
ref_dq, Q.grad = Q.grad.clone(), None
ref_db1, bias1.grad = bias1.grad.clone(), None
ref_db2, bias2.grad = bias2.grad.clone(), None
out = DS4Sci_EvoformerAttention(Q, K, V, [bias1, bias2])
out.backward(dummy_out)
dv, v_grad = V.grad.clone(), None
dk, k_grad = K.grad.clone(), None
dq, q_grad = Q.grad.clone(), None
db1, bias1.grad = bias1.grad.clone(), None
db2, bias2.grad = bias2.grad.clone(), None
assert torch.allclose(ref_out, out, atol=2e-2, rtol=0), f"\n{ref_out} \n {out}"
assert torch.allclose(ref_dv, dv, atol=2e-2, rtol=0), f"\n{ref_dv} \n {dv}"
assert torch.allclose(ref_dk, dk, atol=2e-2, rtol=0), f"\n{ref_dk} \n {dk}"
assert torch.allclose(ref_dq, dq, atol=2e-2, rtol=0), f"\n{ref_dq} \n {dq}"
assert torch.allclose(ref_db1, db1, atol=2e-2, rtol=1e-2), f"{ref_db1} \n {db1}"
assert torch.allclose(ref_db2, db2, atol=2e-2, rtol=1e-2), f"{ref_db2} \n {db2}"