Files
DeepSpeed/deepspeed/model_implementations/transformers/clip_encoder.py
2023-03-27 07:55:19 -04:00

76 lines
2.9 KiB
Python

'''
Copyright 2022 The Microsoft DeepSpeed Team
'''
import torch
from deepspeed.accelerator import get_accelerator
from ..features.cuda_graph import CUDAGraph
class DSClipEncoder(CUDAGraph, torch.nn.Module):
def __init__(self, enc, enable_cuda_graph=False):
super().__init__(enable_cuda_graph=enable_cuda_graph)
enc.text_model._build_causal_attention_mask = self._build_causal_attention_mask
self.enc = enc
self.device = self.enc.device
self.dtype = self.enc.dtype
self.cuda_graph_created = [False, False]
self.static_inputs = [None, None]
self.static_kwargs = [None, None]
self.static_output = [None, None]
self._cuda_graphs = [None, None]
self.iter = 0
self.config = self.enc.config
def _build_causal_attention_mask(self, bsz, seq_len, dtype):
mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype, device=get_accelerator().current_device_name())
mask.fill_(torch.tensor(torch.finfo(dtype).min))
mask.triu_(1)
mask = mask.unsqueeze(1)
return mask
def _graph_replay(self, *inputs, **kwargs):
for i in range(len(inputs)):
if torch.is_tensor(inputs[i]):
self.static_inputs[self.iter][i].copy_(inputs[i])
for k in kwargs:
if torch.is_tensor(kwargs[k]):
self.static_kwargs[self.iter][k].copy_(kwargs[k])
self._cuda_graphs[self.iter].replay()
return self.static_output[self.iter]
def forward(self, *inputs, **kwargs):
if self.enable_cuda_graph:
if self.cuda_graph_created[self.iter]:
outputs = self._graph_replay(*inputs, **kwargs)
else:
self._create_cuda_graph(*inputs, **kwargs)
outputs = self._graph_replay(*inputs, **kwargs)
self.iter = (self.iter + 1) % 2
return outputs
else:
return self.enc(*inputs, **kwargs)
def _create_cuda_graph(self, *inputs, **kwargs):
# warmup to create the workspace and cublas handle
cuda_stream = torch.cuda.Stream()
cuda_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(cuda_stream):
for i in range(3):
ret = self._forward(*inputs, **kwargs)
torch.cuda.current_stream().wait_stream(cuda_stream)
# create cuda_graph and assign static_inputs and static_outputs
self._cuda_graphs[self.iter] = torch.cuda.CUDAGraph()
self.static_inputs[self.iter] = inputs
self.static_kwargs[self.iter] = kwargs
with torch.cuda.graph(self._cuda_graphs[self.iter]):
self.static_output[self.iter] = self._forward(*self.static_inputs[self.iter],
**self.static_kwargs[self.iter])
self.cuda_graph_created[self.iter] = True
def _forward(self, *inputs, **kwargs):
return self.enc(*inputs, **kwargs)