[pipeline] missing import regarding assisted generation (#35752)

missing import
This commit is contained in:
Joao Gante
2025-01-22 10:34:28 +00:00
committed by GitHub
parent 36c9181f5c
commit ec28957f94

View File

@ -65,6 +65,7 @@ if is_torch_available():
import torch
from torch.utils.data import DataLoader, Dataset
from ..modeling_utils import PreTrainedModel
from ..models.auto.modeling_auto import AutoModel
# Re-export for backward compatibility
@ -447,7 +448,7 @@ def load_assistant_model(
if not model.can_generate() or assistant_model is None:
return None, None
if not isinstance(model, PreTrainedModel):
if getattr(model, "framework") != "pt" or not isinstance(model, PreTrainedModel):
raise ValueError(
"Assisted generation, triggered by the `assistant_model` argument, is only available for "
"`PreTrainedModel` model instances. For instance, TF or JAX models are not supported."