DeepSpeed/deepspeed/ops/deepspeed4science/evoformer_attn.py
Conglong Li f876d81d34 DeepSpeed4Science (#4357)
* zero++ tutorial PR (#3783)

* [Fix] _conv_flops_compute when padding is a str and stride=1 (#3169)

* fix conv_flops_compute when padding is a str and stride=1

* fix error

* change type of paddings to tuple

* fix padding calculation

* apply formatting check

---------

Co-authored-by: Cheng Li <pistasable@gmail.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>

* fix interpolate flops compute (#3782)

* use `Flops Profiler` to test `model.generate()` (#2515)

* Update profiler.py

* pre-commit run --all-files

* Delete .DS_Store

* Delete .DS_Store

* Delete .DS_Store

---------

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Cheng Li <pistasable@gmail.com>

* revert PR #3611 (#3786)

* bump to 0.9.6

* ZeRO++ chinese blog (#3793)

* zeropp chinese blog

* try better quality images

* make title larger

* even larger...

* various fixes

* center captions

* more fixes

* fix format

* remove staging trigger (#3792)

* DeepSpeed-Triton for Inference (#3748)

Co-authored-by: Stephen Youn <styoun@microsoft.com>
Co-authored-by: Arash Bakhtiari <arash@bakhtiari.org>
Co-authored-by: Cheng Li <pistasable@gmail.com>
Co-authored-by: Ethan Doe <yidoe@microsoft.com>
Co-authored-by: yidoe <68296935+yidoe@users.noreply.github.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* ZeRO++ (#3784)

Co-authored-by: HeyangQin <heyangqin@microsoft.com>
Co-authored-by: GuanhuaWang <alexwgh333@gmail.com>
Co-authored-by: cmikeh2 <connorholmes@microsoft.com>
Co-authored-by: Ammar Ahmad Awan <ammar.awan@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Reza Yazdani <reyazda@microsoft.com>

* adding zero++ to navigation panel of deepspeed.ai (#3796)

* Add ZeRO++ Japanese blog (#3797)

* zeropp chinese blog

* try better quality images

* make title larger

* even larger...

* various fixes

* center captions

* more fixes

* fix format

* add ZeRO++ Japanese blog

* add links

---------

Co-authored-by: HeyangQin <heyangqin@microsoft.com>
Co-authored-by: Conglong Li <conglong.li@gmail.com>

* Bug Fixes for autotuner and flops profiler (#1880)

* fix autotuner when backward is not called

* fix format

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>

* Missing strided copy for gated MLP (#3788)

Co-authored-by: Ammar Ahmad Awan <ammar.awan@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>

* Requires grad checking. (#3789)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* bump to 0.10.0

* Fix Bug in transform.cu (#3534)

* Bug fix

* Fixed formatting error

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>

* bug fix: triton importing error (#3799)

Co-authored-by: Stephen Youn <styoun@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* DeepSpeed4Science (#569)

* Integrating evoformer attention

* add cutlass version check

* Update error message

* add benchmark

* Update

* Update evoformer_attn.py

* Update run_evoformer_test.py

* Update evoformer_attn.py

* Update run_evoformer_test.py

* support more GPU archs

* add copyright

* add tests

* Fix bugs

* Update benchmark

* update

* Fix nvcc macro

* clean code

* fix formatting

* fix yaml import

* skip unit test when not compatible

* fix yaml requirement

* revert changes

* update tutorial

* update

* fix formatting

* fix format

* skip evoformer attn in pre-compile-ops

* revert changes

* update tutorial

* fix cutlass check

* update tutorial

* refactor tutorial

* revise

* Updated the Megatron-DS section (#565)

* Updated the Megatron-DS section

* minor fix

* minor fix

* minor fix

* separate evoformer tutorial

* Revised the ds4science landing page (#566)

* Updated the Megatron-DS section

* minor fix

* minor fix

* minor fix

* Revised the landing page

* Revised the landing page

* Removing unused file

* fix links image position

* modify main page

* fix doc

---------

Co-authored-by: Shiyang Chen <csycfl@gmail.com>
Co-authored-by: Minjia Zhang <33713995+minjiaz@users.noreply.github.com>

---------

Co-authored-by: Heyang Qin <heyangqin@microsoft.com>
Co-authored-by: Bill Luo <50068224+zhiruiluo@users.noreply.github.com>
Co-authored-by: Cheng Li <pistasable@gmail.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Guorun <84232793+CaffreyR@users.noreply.github.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: stephen youn <13525892+stephen-youn@users.noreply.github.com>
Co-authored-by: Stephen Youn <styoun@microsoft.com>
Co-authored-by: Arash Bakhtiari <arash@bakhtiari.org>
Co-authored-by: Ethan Doe <yidoe@microsoft.com>
Co-authored-by: yidoe <68296935+yidoe@users.noreply.github.com>
Co-authored-by: GuanhuaWang <alexwgh333@gmail.com>
Co-authored-by: cmikeh2 <connorholmes@microsoft.com>
Co-authored-by: Ammar Ahmad Awan <ammar.awan@microsoft.com>
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
Co-authored-by: Reza Yazdani <reyazda@microsoft.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Joe Mayer <114769929+jomayeri@users.noreply.github.com>
Co-authored-by: Ramya Ramineni <62723901+rraminen@users.noreply.github.com>
Co-authored-by: Shiyang Chen <csycfl@gmail.com>
Co-authored-by: Minjia Zhang <33713995+minjiaz@users.noreply.github.com>
2023-09-18 22:16:08 +00:00

103 lines
3.7 KiB
Python

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import numpy as np
from deepspeed.ops.op_builder import EvoformerAttnBuilder
from deepspeed.accelerator import get_accelerator

kernel_ = None
def _attention(Q, K, V, bias1, bias2):
    assert Q.shape[-3] > 16, "seq_len must be greater than 16"
    O = torch.empty_like(Q, dtype=Q.dtype)
    assert get_accelerator().on_accelerator(Q), "Q must be on cuda"
    assert get_accelerator().on_accelerator(K), "K must be on cuda"
    assert get_accelerator().on_accelerator(V), "V must be on cuda"
    assert get_accelerator().on_accelerator(bias1), "bias1 must be on cuda"
    assert get_accelerator().on_accelerator(bias2), "bias2 must be on cuda"
    global kernel_
    if kernel_ is None:
        # Lazily JIT-build/load the CUDA kernel on first use.
        kernel_ = EvoformerAttnBuilder().load()
    nheads = Q.shape[-2]
    # Round the query length up to a multiple of 32 for the kernel's log-sum-exp buffer.
    nq = (Q.shape[-3] + 31) // 32 * 32
    nb = np.prod(Q.shape[:-3])
    lse = torch.empty((nb, nheads, nq), dtype=torch.float32, device=Q.device)
    kernel_.attention(Q, K, V, bias1, bias2, O, lse)
    return O, lse
def attention_bwd(dO, Q, K, V, O, lse, bias1, bias2):
    assert max(Q.shape[-1], V.shape[-1]) <= 64, "Hidden size is too large. Need to change kMax to a larger value"
    dQ = torch.empty_like(Q, dtype=Q.dtype)
    dK = torch.empty_like(K, dtype=K.dtype)
    dV = torch.empty_like(V, dtype=V.dtype)
    assert get_accelerator().on_accelerator(dO), "dO must be on cuda"
    assert get_accelerator().on_accelerator(Q), "Q must be on cuda"
    assert get_accelerator().on_accelerator(K), "K must be on cuda"
    assert get_accelerator().on_accelerator(V), "V must be on cuda"
    assert get_accelerator().on_accelerator(O), "O must be on cuda"
    global kernel_
    if kernel_ is None:
        kernel_ = EvoformerAttnBuilder().load()
    delta = torch.empty_like(lse)
    # Bias gradients are accumulated in fp32 and cast back to the input dtype on return.
    dB1 = torch.zeros_like(bias1, dtype=torch.float32)
    dB2 = torch.zeros_like(bias2, dtype=torch.float32)
    kernel_.attention_bwd(dO, Q, K, V, O, lse, delta, bias1, bias2, dQ, dK, dV, dB1, dB2)
    return dQ, dK, dV, dB1.to(dO.dtype), dB2.to(dO.dtype)
class EvoformerFusedAttention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, bias1=None, bias2=None):
        """
        q, k, v are in shape [*, L, H, D]
        """
        # Missing biases are passed to the kernel as empty tensors rather than None.
        bias1_ = bias1.contiguous() if bias1 is not None else torch.tensor([], dtype=q.dtype, device=q.device)
        bias2_ = bias2.contiguous() if bias2 is not None else torch.tensor([], dtype=q.dtype, device=q.device)
        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        o, lse = _attention(q, k, v, bias1_, bias2_)
        ctx.save_for_backward(q, k, v, o, lse, bias1_, bias2_)
        return o

    @staticmethod
    def backward(ctx, grad_output):
        q, k, v, o, lse, bias1, bias2 = ctx.saved_tensors
        dQ, dK, dV, dB1, dB2 = attention_bwd(grad_output, q, k, v, o, lse, bias1, bias2)
        # Biases that were not supplied get no gradient.
        if bias1.numel() == 0:
            dB1 = None
        if bias2.numel() == 0:
            dB2 = None
        return dQ, dK, dV, dB1, dB2
def DS4Sci_EvoformerAttention(Q, K, V, biases):
    assert len(biases) <= 2

    # Pad the bias list out to exactly two entries.
    if len(biases) == 0:
        biases.append(None)
    if len(biases) == 1:
        biases.append(None)

    # For a 5-D Q of shape [B, N, L, H, D]: bias1 is [B, N, 1, 1, L] and bias2 is [B, 1, H, L, L].
    bias_1_shape = lambda x: (x.shape[0], x.shape[1], 1, 1, x.shape[2])
    bias_2_shape = lambda x: (x.shape[0], 1, x.shape[3], x.shape[2], x.shape[2])

    if biases[0] is not None:
        assert biases[0].shape == bias_1_shape(Q)
    else:
        biases[0] = Q.new_zeros(bias_1_shape(Q))

    if biases[1] is not None:
        assert biases[1].shape == bias_2_shape(Q)
    else:
        biases[1] = Q.new_zeros(bias_2_shape(Q))

    return EvoformerFusedAttention.apply(Q, K, V, biases[0], biases[1])
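
For reference, a minimal usage sketch of the public entry point DS4Sci_EvoformerAttention (not part of this file). It assumes a CUDA machine where the EvoformerAttn op can be built (CUTLASS available, supported GPU arch); the import path follows the DeepSpeed4Science tutorial and all sizes are illustrative. Per the asserts above, Q.shape[-3] must exceed 16 and the head dimension must be at most 64.

import torch
from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention

# Q, K, V: [batch, n_seq, seq_len, heads, head_dim]; fp16 shown here.
batch, n_seq, seq_len, heads, head_dim = 1, 4, 64, 4, 32
kw = dict(dtype=torch.float16, device="cuda", requires_grad=True)
Q = torch.randn(batch, n_seq, seq_len, heads, head_dim, **kw)
K = torch.randn(batch, n_seq, seq_len, heads, head_dim, **kw)
V = torch.randn(batch, n_seq, seq_len, heads, head_dim, **kw)
# bias1: [B, N, 1, 1, L]; bias2: [B, 1, H, L, L] (see the shape lambdas above).
bias1 = torch.randn(batch, n_seq, 1, 1, seq_len, **kw)
bias2 = torch.randn(batch, 1, heads, seq_len, seq_len, **kw)

out = DS4Sci_EvoformerAttention(Q, K, V, [bias1, bias2])  # -> [B, N, L, H, D]
out.float().sum().backward()  # populates grads for Q, K, V and both biases

Passing fewer than two biases is also allowed: the function pads the list with None and substitutes zero tensors of the expected broadcast shapes, so the fused kernel always receives both bias arguments.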