[NFC] Typo fix in SP layer. (#7152)

Signed-off-by: c8ef <c8ef@outlook.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
This commit is contained in:
Connector Switch
2025-03-25 06:30:33 +08:00
committed by GitHub
parent 9ae010e629
commit 2b245a999e

View File

@ -338,11 +338,11 @@ class DistributedAttention(torch.nn.Module):
if sp_stream is not None:
self.overlap_handles = {}
self.sp_overlap_comm = True
self.dafult_stream = get_accelerator().default_stream()
self.default_stream = get_accelerator().default_stream()
def layer_sync(self, layer):
if self.sp_overlap_comm and hasattr(layer, 'done_event'):
self.dafult_stream.wait_event(layer.done_event)
self.default_stream.wait_event(layer.done_event)
def forward(self,
query: Tensor,
@ -374,7 +374,7 @@ class DistributedAttention(torch.nn.Module):
def pre_hook_fun(grad):
type = 'd' + layer_type
self.overlap_handles[type + '_work'].wait()
self.sp_stream.wait_stream(self.dafult_stream)
self.sp_stream.wait_stream(self.default_stream)
all2all_output = self.overlap_handles[type + '_grad']
grad = list(grad)
grad[0] = self.overlap_handles[type + '_post_all2all_func'](all2all_output)
@ -389,7 +389,7 @@ class DistributedAttention(torch.nn.Module):
key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx, batch_dim_idx, None,
self.overlap_handles, 'k')
if self.sp_overlap_comm:
self.dafult_stream.wait_stream(self.sp_stream)
self.default_stream.wait_stream(self.sp_stream)
value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx, batch_dim_idx, None,
self.overlap_handles, 'v')