# DeepSpeed/deepspeed/compression/basic_layer.py
'''Copyright The Microsoft DeepSpeed Team'''
import torch
import math
from torch import nn
from torch.nn import init
import deepspeed.comm as dist
from .utils import TopKBinarizer, SymQuantizer, AsymQuantizer, TernaryQuantizer, BinaryQuantizer
from deepspeed.utils import logger
g_mpu = None
class QuantAct(nn.Module):
"""
Class to quantize the given activations. Note that when this module is used, the input activation
quantization range is fixed for all tokens/images at inference. This generally costs some accuracy
but gives better latency.
Parameters:
----------
act_range_momentum : float, default 0.95
Momentum for updating the activation quantization range.
quant_mode : str, default 'symmetric'
Quantization mode, either 'symmetric' or 'asymmetric'.
"""
def __init__(self, act_range_momentum=0.95, quant_mode='symmetric'):
super(QuantAct, self).__init__()
self.act_range_momentum = act_range_momentum
self.quant_mode = quant_mode
if quant_mode == 'symmetric':
self.act_function = SymQuantizer.apply
else:
self.act_function = AsymQuantizer.apply
self.register_buffer('x_min_max', torch.zeros(2))
def forward(self, x, num_bits, *args):
"""
x: the activation that we need to quantize
num_bits: the number of bits we need to quantize the activation to
*args: extra arguments that are unused here but keep the interface aligned with the other quantization functions
"""
if self.training:
x_min = x.data.min()
x_max = x.data.max()
# Initialization
if self.x_min_max[0] == self.x_min_max[1]:
self.x_min_max[0] = x_min
self.x_min_max[1] = x_max
# if momentum is not needed, set self.act_range_momentum = 0
self.x_min_max[0] = self.x_min_max[0] * self.act_range_momentum + x_min * (1 - self.act_range_momentum)
self.x_min_max[1] = self.x_min_max[1] * self.act_range_momentum + x_max * (1 - self.act_range_momentum)
x_q = self.act_function(x, num_bits, self.x_min_max[0], self.x_min_max[1])
return x_q
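# Example: static-range activation quantization with QuantAct. The shape and
# bit-width below are illustrative assumptions; during training the running
# min/max buffer is updated, and at inference the stored range is reused.
#
#   quant_act = QuantAct(act_range_momentum=0.95, quant_mode='symmetric')
#   x = torch.randn(8, 128)
#   x_q = quant_act(x, 8)   # quantize activations to 8 bits using the tracked range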
class Embedding_Compress(nn.Embedding):
def __init__(self, *kargs):
super(Embedding_Compress, self).__init__(*kargs)
self.weight.start_bits = None
self.weight.target_bits = None
self.weight.q_period = None
self.weight_quantization_enabled_in_forward = False
self.weight_quantization_enabled = False
def extra_repr(self):
return 'num_embeddings={}, embedding_dim={}, weight_quantization={}'.format(
self.num_embeddings, self.embedding_dim, self.weight.target_bits)
def enable_weight_quantization(self, start_bits, target_bits, quantization_period,
weight_quantization_enabled_in_forward, quantization_type, num_groups):
self.weight.start_bits = start_bits
self.weight.target_bits = target_bits
self.weight.q_period = quantization_period
self.weight_quantization_enabled_in_forward = weight_quantization_enabled_in_forward
if self.weight_quantization_enabled_in_forward:
logger.warning(
"************ A lot of MoQ features are not supported in quantize_weight_in_forward mode, please consider to use DS-FP16 optimizer************"
)
if self.weight.target_bits >= 3:
if quantization_type == 'symmetric':
self.weight_quantizer = SymQuantizer.apply
else:
self.weight_quantizer = AsymQuantizer.apply
elif self.weight.target_bits == 2:
assert quantization_type == 'symmetric', 'Only symmetric quantization is supported for ternary weight quantization'
self.weight_quantizer = TernaryQuantizer.apply
elif self.weight.target_bits == 1:
assert quantization_type == 'symmetric', 'Only symmetric quantization is supported for binary weight quantization'
self.weight_quantizer = BinaryQuantizer.apply
# for embedding, we always use token-wise quantization
self.weight_quantize_num_groups = self.weight.size(0)
def fix_weight_quantization(self):
self.weight.data = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
self.weight_quantize_num_groups).data
self.weight_quantization_enabled_in_forward = False
return None
def forward(self, input):
if self.weight_quantization_enabled_in_forward and self.weight_quantization_enabled:
weight = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
self.weight_quantize_num_groups)
else:
weight = self.weight
out = nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type,
self.scale_grad_by_freq, self.sparse)
return out
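# Example: enabling in-forward weight quantization on Embedding_Compress. The
# vocabulary size, hidden size, and bit settings are illustrative assumptions;
# in practice these settings typically come from the DeepSpeed compression
# config, and weight_quantization_enabled is toggled by the compression logic
# rather than by hand.
#
#   emb = Embedding_Compress(30522, 768)
#   emb.enable_weight_quantization(start_bits=12, target_bits=8,
#                                  quantization_period=100,
#                                  weight_quantization_enabled_in_forward=True,
#                                  quantization_type='symmetric', num_groups=1)
#   emb.weight_quantization_enabled = True
#   out = emb(torch.tensor([[101, 2054, 102]]))   # token-wise quantized lookup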
class LinearLayer_Compress(nn.Linear):
"""
Linear layer with compression.
"""
def __init__(self, *kargs, bias=True):
super(LinearLayer_Compress, self).__init__(*kargs, bias=bias)
self.sparse_pruning_method = None
self.row_pruning_method = None
self.head_pruning_method = None
self.activation_quantization_method = None
self.weight.start_bits = None
self.weight.target_bits = None
self.weight.q_period = None
self.weight_quantization_enabled_in_forward = False
self.weight_quantization_enabled = False
self.sparse_pruning_enabled = False
self.row_pruning_enabled = False
self.head_pruning_enabled = False
self.activation_quantization_enabled = False
def extra_repr(self):
return 'in_features={}, out_features={}, bias={}, sparse pruning={}, row pruning={}, head pruning={}, activation quantization={}, weight_quantization={}'.format(
self.in_features, self.out_features, self.bias is not None, self.sparse_pruning_method is not None, \
self.row_pruning_method is not None, self.head_pruning_method is not None, self.activation_quantization_method is not None, self.weight.target_bits)
def enable_sparse_pruning(self, ratio, method):
# Here, we support two cases: L1 norm based pruning and topk based pruning
self.sparse_pruning_ratio = ratio
self.sparse_pruning_method = method
if method == 'l1':
weight_norm = torch.abs(self.weight.data)
mask = TopKBinarizer.apply(weight_norm, self.sparse_pruning_ratio, False)
mask = mask.view(self.weight.size())
mask = mask.to(self.weight.device)
elif method == 'topk':
self.sparse_mask_scores = nn.Parameter(torch.Tensor(self.weight.size()))
self.sparse_mask_scores.data = self.sparse_mask_scores.data.to(self.weight.device)
init.kaiming_uniform_(self.sparse_mask_scores, a=math.sqrt(5))
mask = None
else:
raise NotImplementedError
self.register_buffer('sparse_pruning_mask', mask)
def enable_row_pruning(self, ratio, method):
# Here, we support two cases: L1 norm based pruning and topk based pruning
self.row_pruning_ratio = ratio
self.row_pruning_method = method
if method == 'l1':
# compute the l1 norm of each row (i.e., each output feature)
weight_norm = torch.norm(self.weight.data, p=1, dim=1)
mask = TopKBinarizer.apply(weight_norm, self.row_pruning_ratio, False)
mask = mask.view(-1, 1)
mask = mask.to(self.weight.device)
elif method == 'topk':
self.row_mask_scores = nn.Parameter(torch.Tensor(self.weight.size(0), 1))
self.row_mask_scores.data = self.row_mask_scores.data.to(self.weight.device)
init.kaiming_uniform_(self.row_mask_scores, a=math.sqrt(5))
mask = None
else:
raise NotImplementedError
self.register_buffer('row_pruning_mask', mask)
def enable_head_pruning(self, ratio, method, num_heads):
# Here, we support only topk based pruning
self.num_heads = num_heads
self.head_pruning_ratio = ratio
self.head_pruning_method = method
if method not in ['topk']:
raise NotImplementedError
else:
self.head_pruning_ratio = ratio
self.head_pruning_scores = nn.Parameter(torch.Tensor(1,
self.num_heads)) # we apply the pruning to the O matrix
self.head_pruning_scores.data = self.head_pruning_scores.data.to(self.weight.device)
init.kaiming_uniform_(self.head_pruning_scores, a=math.sqrt(5))
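# Example: top-k head pruning is applied on the attention output (O) projection;
# the hidden size, head count, and ratio below are illustrative assumptions.
#
#   o_proj = LinearLayer_Compress(768, 768)
#   o_proj.enable_head_pruning(ratio=0.5, method='topk', num_heads=12)
#   o_proj.head_pruning_enabled = True
#   y = o_proj(torch.randn(2, 16, 768))   # learned scores select which heads to keep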
def fix_sparse_pruning_helper(self):
mask = self.get_mask(pruning_type='sparse')
self.weight.data = self.weight.data * mask
del self.sparse_pruning_mask
if self.sparse_pruning_method == 'topk':
del self.sparse_mask_scores
self.sparse_pruning_method = None
self.sparse_pruning_enabled = False
return None
def fix_row_col_pruning_helper(self, mask=None, dim_reduction=False):
# This function is used for row/col pruning.
# In particular, if we have two back-to-back layers F1 and F2, removing
# rows from F1 also requires removing the corresponding columns from F2.
# If there is only one layer F1, we simply mask the pruned rows of F1 with 0.
if mask is None:
mask = self.get_mask(pruning_type='row').bool()
if dim_reduction:
start_bits = self.weight.start_bits
target_bits = self.weight.target_bits
q_period = self.weight.q_period
self.weight = nn.Parameter(self.weight.data[mask.view(-1), :])
self.weight.start_bits = start_bits
self.weight.target_bits = target_bits
self.weight.q_period = q_period
if self.bias is not None:
self.bias = nn.Parameter(self.bias.data[mask.view(-1)])
self.out_features = self.weight.size(0)
else:
self.weight.data = self.weight.data * mask.view(-1, 1)
if self.bias is not None:
self.bias.data = self.bias.data * mask.view(-1)
del self.row_pruning_mask
if self.row_pruning_method == 'topk':
del self.row_mask_scores
self.row_pruning_method = None
else:
# this is generally for column pruning
start_bits = self.weight.start_bits
target_bits = self.weight.target_bits
q_period = self.weight.q_period
self.weight = nn.Parameter(self.weight.data[:, mask.view(-1)])
self.weight.start_bits = start_bits
self.weight.target_bits = target_bits
self.weight.q_period = q_period
self.in_features = self.weight.size(1)
mask = None
self.row_pruning_enabled = False
return mask
def fix_head_pruning_helper(self, mask=None, num_heads=None, dim_reduction=False):
# similar to row/col pruning, head pruning also needs to prune the QKV projections associated with the O matrix
num_heads = num_heads if num_heads else self.num_heads
if mask is None:
if self.head_pruning_method == 'topk':
mask = self.get_mask(pruning_type='head').bool()
if dim_reduction:
shape = self.weight.size(0)
start_bits = self.weight.start_bits
target_bits = self.weight.target_bits
q_period = self.weight.q_period
self.weight = nn.Parameter(self.weight.data.t().reshape(num_heads,
-1)[mask.view(-1), :].reshape(-1,
shape).t())
self.weight.start_bits = start_bits
self.weight.target_bits = target_bits
self.weight.q_period = q_period
else:
shape = self.weight.size()
self.weight.data = (self.weight.data.t().reshape(self.num_heads, -1) * mask.view(-1, 1)).reshape(
shape[1], shape[0]).t()
if self.head_pruning_method == 'topk':
del self.head_pruning_scores
self.head_pruning_method = None
else:
raise NotImplementedError
else:
start_bits = self.weight.start_bits
target_bits = self.weight.target_bits
q_period = self.weight.q_period
shape = self.weight.size(1)
self.weight = nn.Parameter(self.weight.data.reshape(num_heads, -1)[mask.view(-1), :].reshape(-1, shape))
self.weight.start_bits = start_bits
self.weight.target_bits = target_bits
self.weight.q_period = q_period
if self.bias is not None:
self.bias = nn.Parameter(self.bias.data.reshape(num_heads, -1)[mask.view(-1), :].reshape(-1))
self.head_pruning_enabled = False
return mask
def get_mask(self, pruning_type='row'):
if pruning_type == 'sparse':
if self.sparse_pruning_method == 'l1':
return self.sparse_pruning_mask.to(self.weight.device)
elif self.sparse_pruning_method == 'topk':
return TopKBinarizer.apply(self.sparse_mask_scores, self.sparse_pruning_ratio, False)
else:
raise NotImplementedError
if pruning_type == 'row':
if self.row_pruning_method == 'l1':
return self.row_pruning_mask.to(self.weight.device)
elif self.row_pruning_method == 'topk':
return TopKBinarizer.apply(self.row_mask_scores, self.row_pruning_ratio, False)
else:
raise NotImplementedError
elif pruning_type == 'head':
if self.head_pruning_method == 'topk':
return TopKBinarizer.apply(self.head_pruning_scores, self.head_pruning_ratio, False)
else:
raise NotImplementedError
else:
raise NotImplementedError
def enable_weight_quantization(self, start_bits, target_bits, quantization_period,
weight_quantization_enabled_in_forward, quantization_type, num_groups):
self.weight.start_bits = start_bits
self.weight.target_bits = target_bits
self.weight.q_period = quantization_period
self.weight_quantization_enabled_in_forward = weight_quantization_enabled_in_forward
if self.weight_quantization_enabled_in_forward:
logger.warning(
"************ A lot of MoQ features are not supported in quantize_weight_in_forward mode, please consider to use DS-FP16 optimizer************"
)
if self.weight.target_bits >= 3:
if quantization_type == 'symmetric':
self.weight_quantizer = SymQuantizer.apply
else:
self.weight_quantizer = AsymQuantizer.apply
elif self.weight.target_bits == 2:
assert quantization_type == 'symmetric', 'Only symmetric quantization is supported for ternary weight quantization'
self.weight_quantizer = TernaryQuantizer.apply
elif self.weight.target_bits == 1:
assert quantization_type == 'symmetric', 'Only symmetric quantization is supported for binary weight quantization'
self.weight_quantizer = BinaryQuantizer.apply
self.weight_quantize_num_groups = num_groups
def fix_weight_quantization(self):
self.weight.data = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
self.weight_quantize_num_groups).data
self.weight_quantization_enabled_in_forward = False
return None
def enable_activation_quantization(self, bits, quantization_type, range_calibration):
assert bits in [4, 8], 'Only 4/8 bits activation quantization are supported for now'
self.activation_quantization_bits = bits
self.activation_quantization_method = f"{quantization_type}_{range_calibration}"
if range_calibration == 'static':
self.activation_quantizer = QuantAct(quant_mode=quantization_type)
else:
if quantization_type == 'symmetric':
self.activation_quantizer = SymQuantizer.apply
else:
self.activation_quantizer = AsymQuantizer.apply
def head_pruning_reshape(self, w, mask):
shape = w.shape
return (w.t().reshape(self.num_heads, -1) * mask.view(-1, 1)).reshape(shape[1], shape[0]).t()
def forward(self, input, skip_bias_add=False):
if self.weight_quantization_enabled_in_forward and self.weight_quantization_enabled:
weight = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
self.weight_quantize_num_groups)
bias = self.bias
else:
weight = self.weight
bias = self.bias
if self.sparse_pruning_enabled and self.sparse_pruning_method:
mask = self.get_mask(pruning_type='sparse')
weight = weight * mask.view(self.weight.size())
if self.row_pruning_enabled and self.row_pruning_method:
mask = self.get_mask(pruning_type='row')
weight = weight * mask.view(-1, 1)
if bias is not None:
bias = bias * mask.view(-1)
if self.head_pruning_enabled and self.head_pruning_method:
mask = self.get_mask(pruning_type='head')
weight = self.head_pruning_reshape(weight, mask)
if self.activation_quantization_enabled:
if 'dynamic' in self.activation_quantization_method:
num_groups = input.numel() // input.size(-1)
else:
num_groups = 1
input = self.activation_quantizer(input, self.activation_quantization_bits, None, None, num_groups)
if skip_bias_add:
# used for mpu linear layers
output = nn.functional.linear(input, weight, None)
return output, bias
else:
output = nn.functional.linear(input, weight, bias)
return output
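# Example: combining L1 sparse pruning with dynamic 8-bit activation
# quantization on LinearLayer_Compress. The sizes, ratio, and bit-width are
# illustrative assumptions; the *_enabled flags are normally set by the
# compression machinery rather than by hand.
#
#   layer = LinearLayer_Compress(1024, 1024, bias=True)
#   layer.enable_sparse_pruning(ratio=0.5, method='l1')
#   layer.enable_activation_quantization(bits=8, quantization_type='symmetric',
#                                        range_calibration='dynamic')
#   layer.sparse_pruning_enabled = True
#   layer.activation_quantization_enabled = True
#   y = layer(torch.randn(4, 1024))   # masked weight, quantized input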
class Conv2dLayer_Compress(nn.Conv2d):
"""
Conv2D layer with compression.
"""
def __init__(self, *kargs):
super(Conv2dLayer_Compress, self).__init__(*kargs)
self.sparse_pruning_method = None
self.channel_pruning_method = None
self.activation_quantization_method = None
self.weight.start_bits = None
self.weight.target_bits = None
self.weight.q_period = None
self.weight_quantization_enabled_in_forward = False
self.sparse_pruning_enabled = False
self.channel_pruning_enabled = False
self.activation_quantization_enabled = False
def __repr__(self):
s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
', stride={stride}')
if self.padding != (0, ) * len(self.padding):
s += ', padding={padding}'
if self.dilation != (1, ) * len(self.dilation):
s += ', dilation={dilation}'
if self.output_padding != (0, ) * len(self.output_padding):
s += ', output_padding={output_padding}'
if self.groups != 1:
s += ', groups={groups}'
if self.bias is None:
s += ', bias=False'
if self.padding_mode != 'zeros':
s += ', padding_mode={padding_mode}'
output = s.format(**self.__dict__)
return output + ' sparse pruning={}, channel pruning={}, activation quantization={}, weight_quantization={}'.format(
self.sparse_pruning_method is not None, self.channel_pruning_method is not None,
self.activation_quantization_method is not None, self.weight.target_bits)
def enable_sparse_pruning(self, ratio, method):
self.sparse_pruning_ratio = ratio
self.sparse_pruning_method = method
if method == 'l1':
weight_norm = torch.abs(self.weight.data)
mask = TopKBinarizer.apply(weight_norm, self.sparse_pruning_ratio, False)
mask = mask.view(self.weight.size())
mask = mask.to(self.weight.device)
elif method == 'topk':
self.sparse_mask_scores = nn.Parameter(torch.Tensor(self.weight.size()))
self.sparse_mask_scores.data = self.sparse_mask_scores.data.to(self.weight.device)
init.kaiming_uniform_(self.sparse_mask_scores, a=math.sqrt(5))
mask = None
else:
raise NotImplementedError
self.register_buffer('sparse_pruning_mask', mask)
def enable_channel_pruning(self, ratio, method):
# Here, we support two cases: L1 norm based pruning and topk based pruning
self.channel_pruning_ratio = ratio
self.channel_pruning_method = method
if method == 'l1':
# compute the l1 norm of each output channel's conv2d kernel (over the last three dimensions)
weight_norm = torch.norm(self.weight.data, p=1, dim=[1, 2, 3])
mask = TopKBinarizer.apply(weight_norm, self.channel_pruning_ratio, False)
mask = mask.view(-1, 1, 1, 1)
mask = mask.to(self.weight.device)
elif method == 'topk':
self.channel_mask_scores = nn.Parameter(torch.Tensor(self.weight.size(0), 1, 1, 1))
self.channel_mask_scores.data = self.channel_mask_scores.data.to(self.weight.device)
init.kaiming_uniform_(self.channel_mask_scores, a=math.sqrt(5))
mask = None
else:
raise NotImplementedError
self.register_buffer('channel_pruning_mask', mask)
def fix_sparse_pruning_helper(self):
mask = self.get_mask(pruning_type='sparse')
self.weight.data = self.weight.data * mask
del self.sparse_pruning_mask
if self.sparse_pruning_method == 'topk':
del self.sparse_mask_scores
self.sparse_pruning_method = None
self.sparse_pruning_enabled = False
return None
def fix_channel_pruning_helper(self, mask=None, dim_reduction=False):
if mask is None:
if self.channel_pruning_method in ['l1', 'topk']:
mask = self.get_mask(pruning_type='channel').bool()
if dim_reduction:
start_bits = self.weight.start_bits
target_bits = self.weight.target_bits
q_period = self.weight.q_period
self.weight = nn.Parameter(self.weight.data[mask.view(-1), ...])
self.weight.start_bits = start_bits
self.weight.target_bits = target_bits
self.weight.q_period = q_period
if self.bias is not None:
self.bias = nn.Parameter(self.bias.data[mask.view(-1)])
else:
self.weight.data = self.weight.data * mask.view(-1, 1, 1, 1)
if self.bias is not None:
self.bias.data = self.bias.data * mask.view(-1)
del self.channel_pruning_mask
if self.channel_pruning_method == 'topk':
del self.channel_mask_scores
self.channel_pruning_method = None
else:
raise NotImplementedError
else:
start_bits = self.weight.start_bits
target_bits = self.weight.target_bits
q_period = self.weight.q_period
self.weight = nn.Parameter(self.weight.data[:, mask.view(-1), ...])
self.weight.start_bits = start_bits
self.weight.target_bits = target_bits
self.weight.q_period = q_period
mask = None
self.channel_pruning_enabled = False
return mask
def get_mask(self, pruning_type='sparse'):
if pruning_type == 'sparse':
if self.sparse_pruning_method == 'l1':
return self.sparse_pruning_mask.to(self.weight.device)
elif self.sparse_pruning_method == 'topk':
return TopKBinarizer.apply(self.sparse_mask_scores, self.sparse_pruning_ratio, False)
else:
raise NotImplementedError
elif pruning_type == 'channel':
if self.channel_pruning_method == 'l1':
return self.channel_pruning_mask.to(self.weight.device)
elif self.channel_pruning_method == 'topk':
return TopKBinarizer.apply(self.channel_mask_scores, self.channel_pruning_ratio, False)
else:
raise NotImplementedError
else:
raise NotImplementedError
def fix_weight_quantization(self):
self.weight.data = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
self.weight_quantize_num_groups).data
self.weight_quantization_enabled_in_forward = False
return None
def enable_weight_quantization(self, start_bits, target_bits, quantization_period,
weight_quantization_enabled_in_forward, quantization_type, num_groups):
self.weight.start_bits = start_bits
self.weight.target_bits = target_bits
self.weight.q_period = quantization_period
self.weight_quantization_enabled_in_forward = weight_quantization_enabled_in_forward
if self.weight_quantization_enabled_in_forward:
assert self.weight.target_bits >= 4, 'Only >=4 bits weight quantization are supported during forward pass for now'
logger.warning(
"************ A lot of MoQ features are not supported in quantize_weight_in_forward mode, please consider to use DS-FP16 optimizer************"
)
if quantization_type == 'symmetric':
self.weight_quantizer = SymQuantizer.apply
else:
self.weight_quantizer = AsymQuantizer.apply
self.weight_quantize_num_groups = num_groups
def enable_activation_quantization(self, bits, quantization_type, range_calibration):
assert bits in [4, 8], 'Only 4/8 bits activation quantization are supported for now'
self.activation_quantization_bits = bits
self.activation_quantization_method = f"{quantization_type}_{range_calibration}"
if range_calibration == 'static':
self.activation_quantizer = QuantAct(quant_mode=quantization_type)
else:
if quantization_type == 'symmetric':
self.activation_quantizer = SymQuantizer.apply
else:
self.activation_quantizer = AsymQuantizer.apply
def forward(self, input):
if self.weight_quantization_enabled_in_forward and self.weight_quantization_enabled:
weight = self.weight_quantizer(self.weight, self.weight.target_bits, None, None,
self.weight_quantize_num_groups)
bias = self.bias
else:
weight = self.weight
bias = self.bias
if self.sparse_pruning_enabled and self.sparse_pruning_method:
mask = self.get_mask(pruning_type='sparse')
weight = weight * mask.view(self.weight.size())
if self.channel_pruning_enabled:
mask = self.get_mask(pruning_type='channel')
weight = weight * mask.view(-1, 1, 1, 1)
if bias is not None:
bias = bias * mask.view(-1)
if self.activation_quantization_enabled:
if 'dynamic' in self.activation_quantization_method:
num_groups = input.numel() // input[0].numel()
else:
num_groups = 1
input = self.activation_quantizer(input, self.activation_quantization_bits, None, None, num_groups)
return nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
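# Example: L1-norm channel pruning on Conv2dLayer_Compress (arguments are
# positional because __init__ only forwards *kargs to nn.Conv2d). The channel
# counts and ratio are illustrative assumptions.
#
#   conv = Conv2dLayer_Compress(64, 128, 3, 1, 1)   # in, out, kernel, stride, padding
#   conv.enable_channel_pruning(ratio=0.25, method='l1')
#   conv.channel_pruning_enabled = True
#   y = conv(torch.randn(2, 64, 32, 32))   # pruned output channels are zeroed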
class BNLayer_Compress(nn.BatchNorm2d):
def fix_channel_pruning_helper(self, mask, dim_reduction=True):
self.weight = nn.Parameter(self.weight.data[mask.view(-1)])
self.bias = nn.Parameter(self.bias.data[mask.view(-1)])
self.running_mean = self.running_mean[mask.view(-1)]
self.running_var = self.running_var[mask.view(-1)]
def _reduce(input_):
"""All-reduce the the input tensor across model parallel group."""
group = g_mpu.get_model_parallel_group()
# Bypass the function if we are using only 1 GPU.
if dist.get_world_size(group=group) == 1:
return input_
# All-reduce.
dist.all_reduce(input_, group=group)
return input_
def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
"""Split a tensor along its last dimension.
Arguments:
tensor: input tensor.
num_partitions: number of partitions to split the tensor into.
contiguous_split_chunks: If True, make each chunk contiguous
in memory.
"""
# Get the size and dimension.
last_dim = tensor.dim() - 1
assert tensor.size()[last_dim] % num_partitions == 0
last_dim_size = tensor.size()[last_dim] // num_partitions
# Split.
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
# Note: torch.split does not create contiguous tensors by default.
if contiguous_split_chunks:
return tuple(chunk.contiguous() for chunk in tensor_list)
return tensor_list
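# Example: splitting the last dimension into equal partitions (the tensor
# shape is an illustrative assumption).
#
#   t = torch.arange(12.).reshape(2, 6)
#   chunks = split_tensor_along_last_dim(t, 3)   # three views of shape (2, 2)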
def _split(input_):
"""Split the tensor along its last dimension and keep the
corresponding slice."""
group = g_mpu.get_model_parallel_group()
# Bypass the function if we are using only 1 GPU.
if dist.get_world_size(group=group) == 1:
return input_
# Split along last dimension.
world_size = dist.get_world_size(group=group)
input_list = split_tensor_along_last_dim(input_, world_size)
# Note: torch.split does not create contiguous tensors by default.
rank = dist.get_rank(group=group)
output = input_list[rank].contiguous()
return output
def _gather(input_):
"""Gather tensors and concatinate along the last dimension."""
group = g_mpu.get_model_parallel_group()
# Bypass the function if we are using only 1 GPU.
if dist.get_world_size(group=group) == 1:
return input_
# Size and dimension.
last_dim = input_.dim() - 1
rank = dist.get_rank(group=group)
world_size = dist.get_world_size(group=group)
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_
dist.all_gather(tensor_list, input_, group=group)
# Note: torch.cat already creates a contiguous tensor.
output = torch.cat(tensor_list, dim=last_dim).contiguous()
return output
class _CopyToModelParallelRegion(torch.autograd.Function):
"""Pass the input to the model parallel region."""
@staticmethod
def forward(ctx, input_):
return input_
@staticmethod
def backward(ctx, grad_output):
return _reduce(grad_output)
class _ReduceFromModelParallelRegion(torch.autograd.Function):
"""All-redcue the input from the model parallel region."""
@staticmethod
def forward(ctx, input_):
return _reduce(input_)
@staticmethod
def backward(ctx, grad_output):
return grad_output
class _ScatterToModelParallelRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chuck to the rank."""
@staticmethod
def forward(ctx, input_):
return _split(input_)
@staticmethod
def backward(ctx, grad_output):
return _gather(grad_output)
class _GatherFromModelParallelRegion(torch.autograd.Function):
"""Gather the input from model parallel region and concatinate."""
@staticmethod
def forward(ctx, input_):
return _gather(input_)
@staticmethod
def backward(ctx, grad_output):
return _split(grad_output)
# -----------------
# Helper functions.
# -----------------
def copy_to_model_parallel_region(input_):
return _CopyToModelParallelRegion.apply(input_)
def reduce_from_model_parallel_region(input_):
return _ReduceFromModelParallelRegion.apply(input_)
def scatter_to_model_parallel_region(input_):
return _ScatterToModelParallelRegion.apply(input_)
def gather_from_model_parallel_region(input_):
return _GatherFromModelParallelRegion.apply(input_)
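# The four regions above form conjugate autograd pairs:
# copy_to_model_parallel_region is the identity in forward and all-reduces
# gradients in backward, while reduce_from_model_parallel_region all-reduces
# in forward and is the identity in backward; scatter_to_model_parallel_region
# and gather_from_model_parallel_region are likewise each other's backward
# passes. ColumnParallelLinear_Compress and RowParallelLinear_Compress below
# rely on this pairing.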
class ColumnParallelLinear_Compress(LinearLayer_Compress):
def __init__(self, mpu, input_size, output_size, bias=True, gather_output=True, skip_bias_add=False):
# Keep input parameters
global g_mpu
g_mpu = mpu
self.input_size = input_size
self.output_size = output_size
self.gather_output = gather_output
self.skip_bias_add = skip_bias_add
# Divide the weight matrix along the last dimension.
world_size = mpu.get_model_parallel_world_size()
assert output_size % world_size == 0
self.output_size_per_partition = output_size // world_size
super(ColumnParallelLinear_Compress, self).__init__(self.input_size, self.output_size_per_partition, bias=bias)
def forward(self, input_):
# Set up backprop all-reduce.
input_parallel = copy_to_model_parallel_region(input_)
# Matrix multiply.
if self.skip_bias_add:
output_parallel, bias = super().forward(input_parallel, True)
else:
output_parallel = super().forward(input_parallel)
bias = None
if self.gather_output:
# All-gather across the partitions.
output = gather_from_model_parallel_region(output_parallel)
else:
output = output_parallel
return output, bias
class RowParallelLinear_Compress(LinearLayer_Compress):
def __init__(self, mpu, input_size, output_size, bias=True, input_is_parallel=False, skip_bias_add=False):
# Keep input parameters
global g_mpu
g_mpu = mpu
self.input_size = input_size
self.output_size = output_size
self.input_is_parallel = input_is_parallel
self.skip_bias_add = skip_bias_add
# Divide the weight matrix along the last dimension.
world_size = mpu.get_model_parallel_world_size()
assert input_size % world_size == 0
self.input_size_per_partition = input_size // world_size
super(RowParallelLinear_Compress, self).__init__(self.input_size_per_partition, self.output_size, bias=bias)
def forward(self, input_):
# Set up backprop all-reduce.
if self.input_is_parallel:
input_parallel = input_
else:
input_parallel = scatter_to_model_parallel_region(input_)
# Matrix multiply.
output_parallel, bias = super().forward(input_parallel, True)
# All-reduce across all the partitions.
output_ = reduce_from_model_parallel_region(output_parallel)
if not self.skip_bias_add:
if bias is not None:
output = output_ + bias
else:
output = output_
output_bias = None
else:
output = output_
output_bias = bias
return output, output_bias
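# Example: a column-parallel projection followed by a row-parallel projection,
# the usual Megatron-style pairing. The mpu object and sizes are illustrative
# assumptions; an initialized deepspeed.comm backend and model-parallel group
# are required.
#
#   col = ColumnParallelLinear_Compress(mpu, 1024, 4096, gather_output=False)
#   row = RowParallelLinear_Compress(mpu, 4096, 1024, input_is_parallel=True)
#   x = torch.randn(2, 16, 1024)
#   hidden, _ = col(x)            # local shard of the 4096-dim projection
#   out, out_bias = row(hidden)   # all-reduced across the model-parallel group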