Mirror of https://github.com/deepspeedai/DeepSpeed.git
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import torch
import numpy as np

from deepspeed.ops.op_builder import EvoformerAttnBuilder
from deepspeed.accelerator import get_accelerator

# Handle to the compiled evoformer attention kernel; built lazily on first
# use so that importing this module does not trigger a JIT build.
kernel_ = None


def _attention(Q, K, V, bias1, bias2):
    assert Q.shape[-3] > 16, "seq_len must be greater than 16"
    O = torch.empty_like(Q, dtype=Q.dtype)
    assert get_accelerator().on_accelerator(Q), "Q must be on cuda"
    assert get_accelerator().on_accelerator(K), "K must be on cuda"
    assert get_accelerator().on_accelerator(V), "V must be on cuda"
    assert get_accelerator().on_accelerator(bias1), "bias1 must be on cuda"
    assert get_accelerator().on_accelerator(bias2), "bias2 must be on cuda"
    global kernel_
    if kernel_ is None:
        kernel_ = EvoformerAttnBuilder().load()
    nheads = Q.shape[-2]
    # Round the query length up to a multiple of 32 to match the kernel's tile size.
    nq = (Q.shape[-3] + 31) // 32 * 32
    nb = np.prod(Q.shape[:-3])
    # Per-row log-sum-exp of the attention logits, saved for the backward pass.
    lse = torch.empty((nb, nheads, nq), dtype=torch.float32, device=Q.device)
    kernel_.attention(Q, K, V, bias1, bias2, O, lse)
    return O, lse
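

# An eager-mode sketch of what the fused kernel above computes, added here for
# documentation only (it is not used by this module). It assumes the kernel
# applies the standard 1/sqrt(D) softmax scaling and adds both biases to the
# attention logits via broadcasting.
def _attention_reference_sketch(Q, K, V, bias1, bias2):
    # Q, K, V: [*, L, H, D]; logits: [*, H, L, L]
    scores = torch.einsum('...qhd,...khd->...hqk', Q, K) / (Q.shape[-1]**0.5)
    scores = scores + bias1 + bias2
    weights = torch.softmax(scores.float(), dim=-1).to(Q.dtype)
    return torch.einsum('...hqk,...khd->...qhd', weights, V)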


def attention_bwd(dO, Q, K, V, O, lse, bias1, bias2):
    assert max(Q.shape[-1], V.shape[-1]) <= 64, "Hidden size is too large. Need to change kMax to a larger value"
    dQ = torch.empty_like(Q, dtype=Q.dtype)
    dK = torch.empty_like(K, dtype=K.dtype)
    dV = torch.empty_like(V, dtype=V.dtype)
    assert get_accelerator().on_accelerator(dO), "dO must be on cuda"
    assert get_accelerator().on_accelerator(Q), "Q must be on cuda"
    assert get_accelerator().on_accelerator(K), "K must be on cuda"
    assert get_accelerator().on_accelerator(V), "V must be on cuda"
    assert get_accelerator().on_accelerator(O), "O must be on cuda"
    global kernel_
    if kernel_ is None:
        kernel_ = EvoformerAttnBuilder().load()
    delta = torch.empty_like(lse)
    # Bias gradients are accumulated in fp32, then cast back to the input dtype below.
    dB1 = torch.zeros_like(bias1, dtype=torch.float32)
    dB2 = torch.zeros_like(bias2, dtype=torch.float32)
    kernel_.attention_bwd(dO, Q, K, V, O, lse, delta, bias1, bias2, dQ, dK, dV, dB1, dB2)
    return dQ, dK, dV, dB1.to(dO.dtype), dB2.to(dO.dtype)
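

# Note on the saved statistics: keeping the per-row log-sum-exp (lse) lets the
# backward kernel recompute the softmax weights from Q, K and the biases
# instead of materializing the full [*, H, L, L] probability matrix, the
# memory-saving recomputation used by FlashAttention-style kernels; delta is
# an auxiliary per-row workspace filled by the kernel (in such backward
# passes, typically the row-wise sum of dO * O).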


class EvoformerFusedAttention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, bias1=None, bias2=None):
        """
        q, k, v: tensors of shape [*, L, H, D]
        """
        # Missing biases are passed to the kernel as empty placeholder tensors.
        bias1_ = bias1.contiguous() if bias1 is not None else torch.tensor([], dtype=q.dtype, device=q.device)
        bias2_ = bias2.contiguous() if bias2 is not None else torch.tensor([], dtype=q.dtype, device=q.device)
        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        o, lse = _attention(q, k, v, bias1_, bias2_)
        ctx.save_for_backward(q, k, v, o, lse, bias1_, bias2_)
        return o

    @staticmethod
    def backward(ctx, grad_output):
        q, k, v, o, lse, bias1, bias2 = ctx.saved_tensors
        dQ, dK, dV, dB1, dB2 = attention_bwd(grad_output, q, k, v, o, lse, bias1, bias2)
        # Return None for the gradient of any bias that was not supplied.
        if bias1.numel() == 0:
            dB1 = None
        if bias2.numel() == 0:
            dB2 = None
        return dQ, dK, dV, dB1, dB2


def DS4Sci_EvoformerAttention(Q, K, V, biases):
    assert len(biases) <= 2

    # Pad the bias list out to exactly two entries.
    if len(biases) == 0:
        biases.append(None)

    if len(biases) == 1:
        biases.append(None)

    # bias1 broadcasts over heads and query positions; bias2 broadcasts over dim 1.
    bias_1_shape = lambda x: (x.shape[0], x.shape[1], 1, 1, x.shape[2])
    bias_2_shape = lambda x: (x.shape[0], 1, x.shape[3], x.shape[2], x.shape[2])

    if biases[0] is not None:
        assert biases[0].shape == bias_1_shape(Q)
    else:
        biases[0] = Q.new_zeros(bias_1_shape(Q))

    if biases[1] is not None:
        assert biases[1].shape == bias_2_shape(Q)
    else:
        biases[1] = Q.new_zeros(bias_2_shape(Q))

    return EvoformerFusedAttention.apply(Q, K, V, biases[0], biases[1])
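

# Usage sketch (illustrative, added for documentation; not part of the original
# module). Shapes follow the [*, L, H, D] convention documented above; the
# sizes are hypothetical, and a CUDA build of the evoformer attention op with
# fp16 inputs is assumed.
if __name__ == "__main__":
    device = get_accelerator().device_name()
    B, N, L, H, D = 1, 4, 64, 4, 32  # batch, sequences, residues, heads, head dim
    Q = torch.randn(B, N, L, H, D, dtype=torch.float16, device=device, requires_grad=True)
    K = torch.randn_like(Q, requires_grad=True)
    V = torch.randn_like(Q, requires_grad=True)
    bias1 = torch.randn(B, N, 1, 1, L, dtype=torch.float16, device=device)
    bias2 = torch.randn(B, 1, H, L, L, dtype=torch.float16, device=device)
    out = DS4Sci_EvoformerAttention(Q, K, V, [bias1, bias2])  # out: [B, N, L, H, D]
    out.float().sum().backward()
    print(out.shape, Q.grad.shape)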