DeepSpeed/deepspeed/ops/random_ltd/dropping_utils.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from deepspeed.ops.op_builder import RandomLTDBuilder
"""
Returns:
    sampled_indices: [layers, batch_size, reserved_length]
    new_mask: [batch_size, 1, reserved_length, reserved_length]
"""

random_ltd_module = None


def gpt_sample_tokens(reserved_length: int,
                      seq_length: int,
                      batch_size: int,
                      layers: int = 1,
                      device: str = 'cpu',
                      attn_mask: torch.Tensor = None):

    prob_dist = torch.ones((layers * batch_size, seq_length), device=device)
    sampled_indices = torch.multinomial(prob_dist, reserved_length)

    sampled_indices = sampled_indices.reshape(layers, batch_size, reserved_length).to(torch.int32)
    global random_ltd_module
    if random_ltd_module is None:
        random_ltd_module = RandomLTDBuilder().load()
    sampled_indices = random_ltd_module.token_sort_(sampled_indices, seq_length)

    # It is not clear that the optimized kernel is better here, since it can break
    # alignment when the sequence length is not divisible by 16.
    # new_mask = random_ltd_module.mask_gather_gpt(attn_mask, reserved_length)
    if attn_mask is not None:
        new_mask = attn_mask[:, :, :reserved_length, :reserved_length]
    else:
        new_mask = None

    return sampled_indices, new_mask
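

# A minimal usage sketch for gpt_sample_tokens. The helper name, the sizes, and the
# 'cuda' device below are illustrative assumptions, and the fused RandomLTD op must
# be loadable on the current machine.
def _example_gpt_sampling(device: str = 'cuda'):
    layers, batch, seq, reserved = 3, 2, 16, 8
    # Causal mask in the [batch, 1, seq, seq] layout that gpt_sample_tokens slices.
    causal_mask = torch.tril(torch.ones(seq, seq, device=device)).view(1, 1, seq, seq).repeat(batch, 1, 1, 1)
    indices, new_mask = gpt_sample_tokens(reserved, seq, batch, layers=layers, device=device, attn_mask=causal_mask)
    # indices: [layers, batch, reserved] sorted token positions per layer;
    # new_mask: [batch, 1, reserved, reserved] slice of the original mask.
    return indices, new_mask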
"""
Returns:
sampled_indices: [layers, batch_size, reserved_length]
new_mask: [layers, batch_size, 1, reserved_length, reserved_length]
"""


def bert_sample_tokens(reserved_length: int,
                       seq_length: int,
                       batch_size: int,
                       layers: int = 1,
                       device: str = 'cpu',
                       attn_mask: torch.Tensor = None):
    assert attn_mask is not None

    prob_dist = torch.ones((layers * batch_size, seq_length), device=device)
    sampled_indices = torch.multinomial(prob_dist, reserved_length)

    sampled_indices = sampled_indices.reshape(layers, batch_size, reserved_length).to(torch.int32)
    global random_ltd_module
    if random_ltd_module is None:
        random_ltd_module = RandomLTDBuilder().load()
    sampled_indices = random_ltd_module.token_sort_(sampled_indices, seq_length)
    dtype = sampled_indices.dtype

    sampled_indices = sampled_indices.to(torch.long)
    new_mask = []
    for l in range(layers):
        tmp_mask_list = []
        for i in range(batch_size):
            mask_tmp = attn_mask[i:i + 1, :, sampled_indices[l][i], :]
            tmp_mask_list.append(mask_tmp[:, :, :, sampled_indices[l][i]])
        new_mask.append(torch.cat(tmp_mask_list, dim=0))

    return sampled_indices.to(dtype), new_mask
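

# A minimal usage sketch for bert_sample_tokens. The helper name, sizes, and 'cuda'
# device are illustrative assumptions; unlike the GPT path, the mask is gathered
# per layer rather than sliced.
def _example_bert_sampling(device: str = 'cuda'):
    layers, batch, seq, reserved = 3, 2, 16, 8
    attn_mask = torch.ones(batch, 1, seq, seq, device=device)
    indices, layer_masks = bert_sample_tokens(reserved, seq, batch, layers=layers, device=device, attn_mask=attn_mask)
    # indices: [layers, batch, reserved]; layer_masks: a list with one
    # [batch, 1, reserved, reserved] mask per layer.
    return indices, layer_masks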


# Differentiable gather of the reserved tokens: forward returns the full activations
# plus the gathered subset; backward scatters the subset's gradients back into the
# full gradient tensor.
class GatherTokens(torch.autograd.Function):

    @staticmethod
    def forward(ctx, activations: torch.Tensor, sorted_indices: torch.Tensor, batch_first: bool):
        global random_ltd_module
        if random_ltd_module is None:
            random_ltd_module = RandomLTDBuilder().load()
        ctx.save_for_backward(activations, sorted_indices)
        ctx.batch_first = batch_first
        return activations, random_ltd_module.token_gather(activations, sorted_indices, batch_first)

    @staticmethod
    def backward(ctx, a_gradients: torch.Tensor, g_gradients: torch.Tensor):
        g_gradients = g_gradients.contiguous()
        global random_ltd_module
        if random_ltd_module is None:
            random_ltd_module = RandomLTDBuilder().load()
        activations, sorted_indices = ctx.saved_tensors
        batch_first = ctx.batch_first

        return random_ltd_module.token_scatter_(a_gradients, g_gradients, sorted_indices, batch_first), None, None
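

# A minimal sketch of GatherTokens on one layer's indices. The helper name, sizes, and
# 'cuda' device are illustrative assumptions, as is the layout convention that
# batch_first=False means [seq, batch, hidden].
def _example_gather(device: str = 'cuda'):
    layers, batch, seq, hidden, reserved = 1, 2, 16, 8, 8
    hidden_states = torch.randn(seq, batch, hidden, device=device)
    indices, _ = gpt_sample_tokens(reserved, seq, batch, layers=layers, device=device)
    full, gathered = GatherTokens.apply(hidden_states, indices[0], False)
    # gathered holds only the reserved tokens: [reserved, batch, hidden].
    return full, gathered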


# Differentiable scatter: forward writes the processed reserved tokens back into a
# copy of the full activations; backward passes the full gradient through and gathers
# the slice belonging to the reserved tokens.
class ScatterTokens(torch.autograd.Function):

    @staticmethod
    def forward(ctx, all_activations: torch.Tensor, layer_activations: torch.Tensor, sorted_indices: torch.Tensor,
                batch_first: bool):
        global random_ltd_module
        if random_ltd_module is None:
            random_ltd_module = RandomLTDBuilder().load()
        scatter_results = random_ltd_module.token_scatter_(all_activations.clone(), layer_activations, sorted_indices,
                                                           batch_first)

        ctx.save_for_backward(sorted_indices)
        ctx.batch_first = batch_first
        return scatter_results

    @staticmethod
    def backward(ctx, out_gradients: torch.Tensor):
        out_gradients = out_gradients.contiguous()
        global random_ltd_module
        if random_ltd_module is None:
            random_ltd_module = RandomLTDBuilder().load()
        sorted_indices, = ctx.saved_tensors
        batch_first = ctx.batch_first

        ret_val = random_ltd_module.token_gather(out_gradients, sorted_indices, batch_first)
        return out_gradients, ret_val, None, None
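

# A minimal sketch of the gather/scatter round trip. The helper name and the
# transformer-layer stand-in are illustrative assumptions: gather the reserved tokens,
# process only those, then scatter the results back into the full sequence.
def _example_gather_scatter_round_trip(hidden_states, layer_indices, batch_first=False):
    full, gathered = GatherTokens.apply(hidden_states, layer_indices, batch_first)
    processed = gathered * 1.0  # stand-in for running a transformer layer on the subset
    restored = ScatterTokens.apply(full, processed, layer_indices, batch_first)
    # restored matches hidden_states except at the reserved positions, which now hold
    # the processed values.
    return restored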