Removing unnecessary device=device in modeling_llama.py (#24696)

* Update modeling_llama.py

Removing unnecessary `device=device`

* fix in all occurrences of _make_causal_mask
This commit is contained in:
Liyang90
2023-07-13 02:30:22 -07:00
committed by GitHub
parent 906afa1d5c
commit 1f6f32c243
33 changed files with 33 additions and 33 deletions

View File

@ -1620,7 +1620,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min)
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)