mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 15:33:51 +08:00
[NFC] Typo fix in SP layer. (#7152)
Signed-off-by: c8ef <c8ef@outlook.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
This commit is contained in:
@@ -338,11 +338,11 @@ class DistributedAttention(torch.nn.Module):
|
||||
if sp_stream is not None:
|
||||
self.overlap_handles = {}
|
||||
self.sp_overlap_comm = True
|
||||
self.dafult_stream = get_accelerator().default_stream()
|
||||
self.default_stream = get_accelerator().default_stream()
|
||||
|
||||
def layer_sync(self, layer):
|
||||
if self.sp_overlap_comm and hasattr(layer, 'done_event'):
|
||||
self.dafult_stream.wait_event(layer.done_event)
|
||||
self.default_stream.wait_event(layer.done_event)
|
||||
|
||||
def forward(self,
|
||||
query: Tensor,
|
||||
@@ -374,7 +374,7 @@ class DistributedAttention(torch.nn.Module):
|
||||
def pre_hook_fun(grad):
|
||||
type = 'd' + layer_type
|
||||
self.overlap_handles[type + '_work'].wait()
|
||||
self.sp_stream.wait_stream(self.dafult_stream)
|
||||
self.sp_stream.wait_stream(self.default_stream)
|
||||
all2all_output = self.overlap_handles[type + '_grad']
|
||||
grad = list(grad)
|
||||
grad[0] = self.overlap_handles[type + '_post_all2all_func'](all2all_output)
|
||||
@@ -389,7 +389,7 @@ class DistributedAttention(torch.nn.Module):
|
||||
key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx, batch_dim_idx, None,
|
||||
self.overlap_handles, 'k')
|
||||
if self.sp_overlap_comm:
|
||||
self.dafult_stream.wait_stream(self.sp_stream)
|
||||
self.default_stream.wait_stream(self.sp_stream)
|
||||
|
||||
value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx, batch_dim_idx, None,
|
||||
self.overlap_handles, 'v')
|
||||
|
Reference in New Issue
Block a user