mirror of https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 15:33:51 +08:00
* fix spelling error with deepspeed/runtime/
* fix typo in docs/
* fix typo in comments with deepspeed/
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
133 lines
4.8 KiB
Python
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch

from deepspeed.ops.op_builder import RandomLTDBuilder
"""
|
|
Returns:
|
|
sampled_indices: [layers, batch_size, reserved_length]
|
|
new_mask: [batch_size, 1, reserved_length, reserved_length]
|
|
"""
|
|
|
|
random_ltd_module = None
|
|
|
|
|
|
def gpt_sample_tokens(reserved_length: int,
                      seq_length: int,
                      batch_size: int,
                      layers: int = 1,
                      device: str = 'cpu',
                      attn_mask: torch.Tensor = None):

    # Uniform distribution over token positions; multinomial (without replacement)
    # picks `reserved_length` distinct positions per (layer, sample) row.
    prob_dist = torch.ones((layers * batch_size, seq_length), device=device)
    sampled_indices = torch.multinomial(prob_dist, reserved_length)

    sampled_indices = sampled_indices.reshape(layers, batch_size, reserved_length).to(torch.int32)
    global random_ltd_module
    if random_ltd_module is None:
        random_ltd_module = RandomLTDBuilder().load()
    sampled_indices = random_ltd_module.token_sort_(sampled_indices, seq_length)

    # Not certain the optimized kernel is actually better here, since it can break
    # alignment when the sequence length is not divisible by 16.
    # new_mask = random_ltd_module.mask_gather_gpt(attn_mask, reserved_length)
    if attn_mask is not None:
        new_mask = attn_mask[:, :, :reserved_length, :reserved_length]
    else:
        new_mask = None

    return sampled_indices, new_mask
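
# Illustrative usage sketch (not part of the original file), assuming the random_ltd op
# can be built on this machine and that a causal attention mask of shape
# [batch_size, 1, seq_length, seq_length] is available; the concrete sizes are
# assumptions for the example only.
#
#     causal_mask = torch.tril(torch.ones(4, 1, 512, 512, device='cuda'))
#     indices, mask = gpt_sample_tokens(reserved_length=128,
#                                       seq_length=512,
#                                       batch_size=4,
#                                       layers=24,
#                                       device='cuda',
#                                       attn_mask=causal_mask)
#     # indices: [24, 4, 128] (int32); mask: [4, 1, 128, 128]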


"""
Returns:
    sampled_indices: [layers, batch_size, reserved_length]
    new_mask: [layers, batch_size, 1, reserved_length, reserved_length]
"""


def bert_sample_tokens(reserved_length: int,
                       seq_length: int,
                       batch_size: int,
                       layers: int = 1,
                       device: str = 'cpu',
                       attn_mask: torch.Tensor = None):
    assert attn_mask is not None
    prob_dist = torch.ones((layers * batch_size, seq_length), device=device)
    sampled_indices = torch.multinomial(prob_dist, reserved_length)

    sampled_indices = sampled_indices.reshape(layers, batch_size, reserved_length).to(torch.int32)
    global random_ltd_module
    if random_ltd_module is None:
        random_ltd_module = RandomLTDBuilder().load()

    sampled_indices = random_ltd_module.token_sort_(sampled_indices, seq_length)
    dtype = sampled_indices.dtype

    sampled_indices = sampled_indices.to(torch.long)
    new_mask = []
    # Slice each sample's attention mask down to the kept rows and columns, one layer at a time.
    for l in range(layers):
        tmp_mask_list = []
        for i in range(batch_size):
            mask_tmp = attn_mask[i:i + 1, :, sampled_indices[l][i], :]
            tmp_mask_list.append(mask_tmp[:, :, :, sampled_indices[l][i]])
        new_mask.append(torch.cat(tmp_mask_list, dim=0))

    return sampled_indices.to(dtype), new_mask
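
# Illustrative usage sketch (not part of the original file), assuming the random_ltd op
# can be loaded and that `attn_mask` is a dense [batch_size, 1, seq_length, seq_length]
# mask; the concrete sizes below are assumptions for the example only.
#
#     attn_mask = torch.ones(8, 1, 256, 256)
#     indices, masks = bert_sample_tokens(reserved_length=64,
#                                         seq_length=256,
#                                         batch_size=8,
#                                         layers=12,
#                                         device='cpu',
#                                         attn_mask=attn_mask)
#     # indices: [12, 8, 64]; masks: a list of 12 tensors, each [8, 1, 64, 64]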


# Autograd function that gathers the tokens kept by random-LTD: forward returns both the
# full activations and the gathered subset; backward scatters the subset's gradients back
# into a full-size gradient tensor.
class GatherTokens(torch.autograd.Function):

    @staticmethod
    def forward(ctx, activations: torch.Tensor, sorted_indices: torch.Tensor, batch_first: bool):
        global random_ltd_module
        if random_ltd_module is None:
            random_ltd_module = RandomLTDBuilder().load()
        ctx.save_for_backward(activations, sorted_indices)
        ctx.batch_first = batch_first
        return activations, random_ltd_module.token_gather(activations, sorted_indices, batch_first)

    @staticmethod
    def backward(ctx, a_gradients: torch.Tensor, g_gradients: torch.Tensor):

        g_gradients = g_gradients.contiguous()
        global random_ltd_module
        if random_ltd_module is None:
            random_ltd_module = RandomLTDBuilder().load()
        activations, sorted_indices = ctx.saved_tensors
        batch_first = ctx.batch_first

        return random_ltd_module.token_scatter_(a_gradients, g_gradients, sorted_indices, batch_first), None, None
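
# Illustrative usage sketch (not part of the original file). GatherTokens.apply keeps the
# full activations and additionally returns only the tokens selected for one layer; the
# hidden size, `layer_id`, and the reuse of `indices` from the gpt_sample_tokens sketch
# above are assumptions for the example only.
#
#     hidden = torch.randn(4, 512, 1024, device='cuda', requires_grad=True)
#     full, kept = GatherTokens.apply(hidden, indices[layer_id], True)
#     # with batch_first=True, `kept` is expected to be [4, 128, 1024]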


# Autograd counterpart of GatherTokens: forward scatters per-layer activations back into a
# copy of the full activation tensor; backward gathers the matching gradient slices.
class ScatterTokens(torch.autograd.Function):

    @staticmethod
    def forward(ctx, all_activations: torch.Tensor, layer_activations: torch.Tensor, sorted_indices: torch.Tensor,
                batch_first: bool):
        global random_ltd_module
        if random_ltd_module is None:
            random_ltd_module = RandomLTDBuilder().load()
        scatter_results = random_ltd_module.token_scatter_(all_activations.clone(), layer_activations, sorted_indices,
                                                           batch_first)

        ctx.save_for_backward(sorted_indices)
        ctx.batch_first = batch_first
        return scatter_results

    @staticmethod
    def backward(ctx, out_gradients: torch.Tensor):

        out_gradients = out_gradients.contiguous()
        global random_ltd_module
        if random_ltd_module is None:
            random_ltd_module = RandomLTDBuilder().load()
        sorted_indices, = ctx.saved_tensors
        batch_first = ctx.batch_first

        ret_val = random_ltd_module.token_gather(out_gradients, sorted_indices, batch_first)
        return out_gradients, ret_val, None, None
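
# Illustrative round-trip sketch (not part of the original file): after a transformer layer
# has processed only the gathered tokens, ScatterTokens.apply writes that layer's outputs
# back into the full-size activation tensor at the sampled positions. `full`, `layer_out`,
# `indices`, and `layer_id` are assumptions carried over from the GatherTokens sketch above.
#
#     restored = ScatterTokens.apply(full, layer_out, indices[layer_id], True)
#     # restored: same shape as `full`, with the sampled positions replaced by `layer_out`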