mirror of
https://github.com/huggingface/transformers.git
synced 2025-10-20 17:13:56 +08:00
[pipeline] missing import regarding assisted generation (#35752)
missing import
This commit is contained in:
@ -65,6 +65,7 @@ if is_torch_available():
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from ..modeling_utils import PreTrainedModel
|
||||
from ..models.auto.modeling_auto import AutoModel
|
||||
|
||||
# Re-export for backward compatibility
|
||||
@ -447,7 +448,7 @@ def load_assistant_model(
|
||||
if not model.can_generate() or assistant_model is None:
|
||||
return None, None
|
||||
|
||||
if not isinstance(model, PreTrainedModel):
|
||||
if getattr(model, "framework") != "pt" or not isinstance(model, PreTrainedModel):
|
||||
raise ValueError(
|
||||
"Assisted generation, triggered by the `assistant_model` argument, is only available for "
|
||||
"`PreTrainedModel` model instances. For instance, TF or JAX models are not supported."
|
||||
|
Reference in New Issue
Block a user