Compare commits

...

1 Commits

Author SHA1 Message Date
77de3a5ddc check 2025-06-07 18:54:46 +02:00

View File

@ -579,34 +579,27 @@ class GenerationMixin:
def _prepare_attention_mask_for_generation(
self,
inputs_tensor: torch.Tensor,
generation_config: GenerationConfig,
model_kwargs: Dict[str, Any],
inputs: torch.Tensor,
pad_token_id: Optional[torch.Tensor],
eos_token_id: Optional[torch.Tensor],
) -> torch.LongTensor:
pad_token_id = generation_config._pad_token_tensor
eos_token_id = generation_config._eos_token_tensor
# `input_ids` may be present in the model kwargs, instead of being the main input (e.g. multimodal model)
if "input_ids" in model_kwargs and model_kwargs["input_ids"].shape[1] > 0:
inputs_tensor = model_kwargs["input_ids"]
# No information for attention mask inference -> return default attention mask
default_attention_mask = torch.ones(inputs_tensor.shape[:2], dtype=torch.long, device=inputs_tensor.device)
default_attention_mask = torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
if pad_token_id is None:
return default_attention_mask
is_input_ids = len(inputs_tensor.shape) == 2 and inputs_tensor.dtype in [torch.int, torch.long]
is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
if not is_input_ids:
return default_attention_mask
is_pad_token_in_inputs = (pad_token_id is not None) and (
isin_mps_friendly(elements=inputs_tensor, test_elements=pad_token_id).any()
isin_mps_friendly(elements=inputs, test_elements=pad_token_id).any()
)
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~(
isin_mps_friendly(elements=eos_token_id, test_elements=pad_token_id).any()
)
can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
attention_mask_from_padding = inputs_tensor.ne(pad_token_id).long()
attention_mask_from_padding = inputs.ne(pad_token_id).long()
attention_mask = (
attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
@ -2053,7 +2046,7 @@ class GenerationMixin:
if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
inputs_tensor, generation_config, model_kwargs
inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor
)
elif kwargs_has_attention_mask:
# TODO (joao): generalize this check with other types of inputs