mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 15:33:51 +08:00
76 lines
2.9 KiB
Python
76 lines
2.9 KiB
Python
'''
|
|
Copyright 2022 The Microsoft DeepSpeed Team
|
|
'''
|
|
import torch
|
|
from deepspeed.accelerator import get_accelerator
|
|
from ..features.cuda_graph import CUDAGraph
|
|
|
|
|
|
class DSClipEncoder(CUDAGraph, torch.nn.Module):
    """Wrapper around a CLIP encoder that can run its forward pass via CUDA graphs.

    The wrapper keeps two independent "slots" of CUDA-graph state (captured
    graph, static inputs, static kwargs, static output). ``self.iter`` selects
    the active slot and is toggled between 0 and 1 after every graph-mode
    ``forward`` call, so consecutive calls alternate between the two slots.
    NOTE(review): presumably two slots exist because the encoder is invoked
    with two alternating input patterns per step -- confirm against callers.
    """

    def __init__(self, enc, enable_cuda_graph=False):
        """Wrap ``enc`` and initialise the two-slot CUDA-graph bookkeeping.

        Args:
            enc: The encoder module to wrap. It must expose ``text_model``,
                ``device``, ``dtype`` and ``config`` attributes.
            enable_cuda_graph: Forwarded to the parent constructor; when true,
                ``forward`` runs through captured CUDA graphs.
        """
        super().__init__(enable_cuda_graph=enable_cuda_graph)

        # Route the wrapped text model's causal-mask construction through this
        # class so the mask is allocated on the current accelerator device.
        enc.text_model._build_causal_attention_mask = self._build_causal_attention_mask

        self.enc = enc
        # Mirror a few attributes of the wrapped encoder on the wrapper itself.
        self.device = self.enc.device
        self.dtype = self.enc.dtype
        self.config = self.enc.config

        # Per-slot CUDA-graph state; index 0/1 is selected by ``self.iter``.
        self.cuda_graph_created = [False, False]
        self.static_inputs = [None, None]
        self.static_kwargs = [None, None]
        self.static_output = [None, None]
        self._cuda_graphs = [None, None]
        self.iter = 0

    def _build_causal_attention_mask(self, bsz, seq_len, dtype):
        """Build an additive causal attention mask.

        Args:
            bsz: Batch size.
            seq_len: Sequence length.
            dtype: Floating-point dtype of the mask.

        Returns:
            A tensor of shape ``(bsz, 1, seq_len, seq_len)`` on the current
            accelerator device, holding ``torch.finfo(dtype).min`` strictly
            above the diagonal and zero on and below it.
        """
        mask = torch.empty(
            bsz,
            seq_len,
            seq_len,
            dtype=dtype,
            device=get_accelerator().current_device_name(),
        )
        # A plain Python scalar is cast straight to the mask dtype, avoiding a
        # round-trip through a default-dtype temporary tensor.
        mask.fill_(torch.finfo(dtype).min)
        mask.triu_(1)  # zero out the diagonal and everything below it
        return mask.unsqueeze(1)  # add a singleton dim at position 1

    def _graph_replay(self, *inputs, **kwargs):
        """Copy new arguments into the active slot's static buffers and replay.

        Only tensor arguments are copied; non-tensor arguments are ignored here.

        Returns:
            The active slot's stored static output object (not a copy).
        """
        slot = self.iter

        slot_inputs = self.static_inputs[slot]
        for position, value in enumerate(inputs):
            if torch.is_tensor(value):
                slot_inputs[position].copy_(value)

        slot_kwargs = self.static_kwargs[slot]
        for key, value in kwargs.items():
            if torch.is_tensor(value):
                slot_kwargs[key].copy_(value)

        self._cuda_graphs[slot].replay()
        return self.static_output[slot]

    def forward(self, *inputs, **kwargs):
        """Run the wrapped encoder, through a CUDA graph when enabled.

        With CUDA graphs enabled, the active slot's graph is captured on first
        use and then replayed, after which the active slot is toggled. Otherwise
        the wrapped encoder is called directly.
        """
        if not self.enable_cuda_graph:
            return self.enc(*inputs, **kwargs)

        if not self.cuda_graph_created[self.iter]:
            self._create_cuda_graph(*inputs, **kwargs)
        outputs = self._graph_replay(*inputs, **kwargs)

        # Alternate between the two slots on every graph-mode call.
        self.iter = (self.iter + 1) % 2
        return outputs

    def _create_cuda_graph(self, *inputs, **kwargs):
        """Warm up the encoder and capture a CUDA graph into the active slot.

        NOTE: the static input buffers are the very argument objects passed to
        this call (no copies are made), so later replays on this slot overwrite
        those tensors in place via ``copy_``.
        """
        # warmup to create the workspace and cublas handle
        warmup_stream = torch.cuda.Stream()
        warmup_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(warmup_stream):
            for _ in range(3):
                self._forward(*inputs, **kwargs)
        torch.cuda.current_stream().wait_stream(warmup_stream)

        # create cuda_graph and assign static_inputs and static_outputs
        slot = self.iter
        graph = torch.cuda.CUDAGraph()
        self._cuda_graphs[slot] = graph
        self.static_inputs[slot] = inputs
        self.static_kwargs[slot] = kwargs

        with torch.cuda.graph(graph):
            self.static_output[slot] = self._forward(*self.static_inputs[slot], **self.static_kwargs[slot])

        self.cuda_graph_created[slot] = True

    def _forward(self, *inputs, **kwargs):
        """Call the wrapped encoder directly (used for warm-up and capture)."""
        return self.enc(*inputs, **kwargs)
|