Compare commits

...

2 Commits

Author SHA1 Message Date
0cf09f81e5 Correct base model detection 2024-03-11 14:34:31 +00:00
f221fb3426 Support PEFT models in pipelines 2024-03-11 14:34:31 +00:00

View File

@ -1038,7 +1038,12 @@ class Pipeline(_ScikitCompat):
else:
supported_models_names.append(model.__name__)
supported_models = supported_models_names
if self.model.__class__.__name__ not in supported_models:
if "Peft" in self.model.__class__.__name__ and hasattr(self.model, "base_model"):
# Peft models wrap a base model class, so let's look at the base class instead in that case
class_name = self.model.base_model.model.__class__.__name__
else:
class_name = self.model.__class__.__name__
if class_name not in supported_models:
logger.error(
f"The model '{self.model.__class__.__name__}' is not supported for {self.task}. Supported models are"
f" {supported_models}."