Repository: https://github.com/huggingface/trl.git
Commit: 08ea00289a (parent: 4ff8b4e007)

🧶 feat: Add WeaveCallback for W&B Weave integration (#4089)

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
@@ -23,3 +23,7 @@
 ## BEMACallback
 
 [[autodoc]] BEMACallback
+
+## WeaveCallback
+
+[[autodoc]] WeaveCallback
@@ -95,7 +95,13 @@ _import_structure = {
         "XPOConfig",
         "XPOTrainer",
     ],
-    "trainer.callbacks": ["BEMACallback", "MergeModelCallback", "RichProgressCallback", "SyncRefModelCallback"],
+    "trainer.callbacks": [
+        "BEMACallback",
+        "MergeModelCallback",
+        "RichProgressCallback",
+        "SyncRefModelCallback",
+        "WeaveCallback",
+    ],
     "trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config"],
 }
 
@@ -170,7 +176,13 @@ if TYPE_CHECKING:
         XPOConfig,
         XPOTrainer,
     )
-    from .trainer.callbacks import BEMACallback, MergeModelCallback, RichProgressCallback, SyncRefModelCallback
+    from .trainer.callbacks import (
+        BEMACallback,
+        MergeModelCallback,
+        RichProgressCallback,
+        SyncRefModelCallback,
+        WeaveCallback,
+    )
     from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config
 
 else:
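For readers unfamiliar with this file's layout: the paired `_import_structure` / `TYPE_CHECKING` blocks implement the lazy-import pattern TRL shares with transformers, so `from trl import WeaveCallback` only loads `trl.trainer.callbacks` when the name is resolved at runtime, while type checkers still see the real import. Below is a minimal sketch of the idea using PEP 562's module-level `__getattr__` for a hypothetical package; it is not TRL's actual `_LazyModule`, whose internals do not appear in this diff.

```python
# mypkg/__init__.py: simplified sketch of the lazy-import pattern (hypothetical package "mypkg")
import importlib
from typing import TYPE_CHECKING

_import_structure = {
    "trainer.callbacks": ["BEMACallback", "WeaveCallback"],
}

if TYPE_CHECKING:
    # Static analyzers and IDEs resolve the names eagerly.
    from .trainer.callbacks import BEMACallback, WeaveCallback
else:
    # Map each exported name to the submodule that defines it.
    _name_to_module = {name: module for module, names in _import_structure.items() for name in names}

    def __getattr__(name):
        # Import the submodule lazily on first attribute access (PEP 562).
        if name in _name_to_module:
            submodule = importlib.import_module(f".{_name_to_module[name]}", __name__)
            return getattr(submodule, name)
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```

With the entries added in the hunks above, `WeaveCallback` becomes importable from the top-level `trl` namespace.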
@@ -27,7 +27,8 @@ LIGER_KERNEL_MIN_VERSION = "0.5.8"
 # Use same as transformers.utils.import_utils
 _deepspeed_available = _is_package_available("deepspeed")
 _fastapi_available = _is_package_available("fastapi")
-_is_liger_kernel_available, _liger_kernel_version = _is_package_available("liger_kernel", return_version=True)
+_joblib_available = _is_package_available("joblib")
+_liger_kernel_available, _liger_kernel_version = _is_package_available("liger_kernel", return_version=True)
 _llm_blender_available = _is_package_available("llm_blender")
 _mergekit_available = _is_package_available("mergekit")
 _pydantic_available = _is_package_available("pydantic")
@@ -36,7 +37,7 @@ _unsloth_available = _is_package_available("unsloth")
 _uvicorn_available = _is_package_available("uvicorn")
 _vllm_available = _is_package_available("vllm")
 _vllm_ascend_available = _is_package_available("vllm_ascend")
-_joblib_available = _is_package_available("joblib")
+_weave_available = _is_package_available("weave")
 
 
 def is_deepspeed_available() -> bool:
@@ -47,8 +48,12 @@ def is_fastapi_available() -> bool:
     return _fastapi_available
 
 
+def is_joblib_available() -> bool:
+    return _joblib_available
+
+
 def is_liger_kernel_available(min_version: str = LIGER_KERNEL_MIN_VERSION) -> bool:
-    return _is_liger_kernel_available and version.parse(_liger_kernel_version) >= version.parse(min_version)
+    return _liger_kernel_available and version.parse(_liger_kernel_version) >= version.parse(min_version)
 
 
 def is_llm_blender_available() -> bool:
@@ -83,8 +88,8 @@ def is_vllm_ascend_available() -> bool:
     return _vllm_ascend_available
 
 
-def is_joblib_available() -> bool:
-    return _joblib_available
+def is_weave_available() -> bool:
+    return _weave_available
 
 
 class _LazyModule(ModuleType):
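For context, `_is_package_available` (shared with `transformers.utils.import_utils`) only checks whether an optional dependency can be imported, optionally returning its installed version. The sketch below is a rough, standard-library-only rendition of that pattern, shown as an assumption about its behaviour rather than the exact transformers implementation, together with the new weave flag added in this hunk.

```python
import importlib.metadata
import importlib.util


def _is_package_available(pkg_name: str, return_version: bool = False):
    # A package counts as available if an import spec can be found for it.
    available = importlib.util.find_spec(pkg_name) is not None
    if not return_version:
        return available
    try:
        pkg_version = importlib.metadata.version(pkg_name)
    except importlib.metadata.PackageNotFoundError:
        available, pkg_version = False, "N/A"
    return available, pkg_version


# Evaluated once at import time, exactly like the module-level flags above.
_weave_available = _is_package_available("weave")


def is_weave_available() -> bool:
    return _weave_available
```

Call sites then guard their optional imports, which is exactly what the callbacks.py hunk below does with `if is_weave_available(): import weave`.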
@@ -26,6 +26,7 @@ _import_structure = {
         "MergeModelCallback",
         "RichProgressCallback",
         "SyncRefModelCallback",
+        "WeaveCallback",
         "WinRateCallback",
     ],
     "cpo_config": ["CPOConfig"],
@@ -85,6 +86,7 @@ if TYPE_CHECKING:
         MergeModelCallback,
         RichProgressCallback,
         SyncRefModelCallback,
+        WeaveCallback,
         WinRateCallback,
     )
     from .cpo_config import CPOConfig
@@ -35,7 +35,7 @@ from transformers.trainer_utils import has_length
 from transformers.utils import is_rich_available
 
 from ..data_utils import maybe_apply_chat_template
-from ..import_utils import is_mergekit_available
+from ..import_utils import is_mergekit_available, is_weave_available
 from ..mergekit_utils import MergeConfig, merge_models, upload_model_to_hf
 from ..models.utils import unwrap_model_for_generation
 from .judges import BasePairwiseJudge
@@ -51,6 +51,11 @@ if is_rich_available():
 if is_wandb_available():
     import wandb
 
+if is_weave_available():
+    import weave
+    from weave import EvaluationLogger
+    from weave.trace.context import weave_client_context
+
 
 # Logger for module-level logging
 logger = logging.getLogger(__name__)
@@ -335,8 +340,6 @@ class WinRateCallback(TrainerCallback):
             self.trainer.log({"eval_win_rate": win_rate})
 
             if "wandb" in args.report_to:
-                import wandb
-
                 if wandb.run is not None:
                     df = _win_rate_completions_df(
                         state=state,
@@ -398,8 +401,6 @@ class WinRateCallback(TrainerCallback):
             self.trainer.log({"eval_win_rate": win_rate})
 
             if "wandb" in args.report_to:
-                import wandb
-
                 if wandb.run is not None:
                     df = _win_rate_completions_df(
                         state=state,
@@ -514,6 +515,235 @@ class LogCompletionsCallback(TrainerCallback):
         self._last_logged_step = state.global_step
 
 
+class WeaveCallback(TrainerCallback):
+    r"""
+    A [`~transformers.TrainerCallback`] that logs traces and evaluations to W&B Weave. The callback uses Weave's
+    `EvaluationLogger` (https://weave-docs.wandb.ai/guides/evaluation/evaluation_logger/) to log traces and
+    evaluations at each evaluation step.
+
+    Supports two modes based on the `scorers` parameter:
+
+    - **Tracing Mode** (when `scorers=None`): Logs predictions for data exploration and analysis
+    - **Evaluation Mode** (when `scorers` provided): Logs predictions with scoring and summary metrics
+
+    Both modes use Weave's `EvaluationLogger` for structured, consistent data logging.
+
+    The callback logs data during evaluation phases (`on_evaluate`) rather than at training steps, making it more
+    efficient and semantically correct. It gracefully handles a missing weave installation by logging a warning and
+    skipping weave-specific functionality. It also checks for an existing weave client before initializing a new one.
+
+    Usage:
+    ```python
+    # Tracing mode (just log predictions)
+    trainer = DPOTrainer(...)
+    weave_callback = WeaveCallback(trainer=trainer)  # project_name optional
+    trainer.add_callback(weave_callback)
+
+    # Or specify a project name
+    weave_callback = WeaveCallback(trainer=trainer, project_name="my-llm-training")
+    trainer.add_callback(weave_callback)
+
+
+    # Evaluation mode (log predictions + scores + summary)
+    def accuracy_scorer(prompt: str, completion: str) -> float:
+        # Your scoring logic here (metadata available via eval_attributes)
+        return score
+
+
+    weave_callback = WeaveCallback(
+        trainer=trainer,
+        project_name="my-llm-training",  # optional; needed only if a weave client is not already initialized
+        scorers={"accuracy": accuracy_scorer},
+    )
+    trainer.add_callback(weave_callback)
+    ```
+
+    Args:
+        trainer (`Trainer`):
+            Trainer to which the callback will be attached. The trainer's evaluation dataset must include a `"prompt"`
+            column containing the prompts for generating completions.
+        project_name (`str`, *optional*):
+            Name of the Weave project where data will be logged. If not provided, the callback tries to use an
+            existing weave client or falls back to the active wandb run's project name. Raises an error if none of
+            these are available.
+        scorers (`dict[str, Callable]`, *optional*):
+            Dictionary mapping scorer names to scorer functions. If `None`, operates in tracing mode (predictions
+            only). If provided, operates in evaluation mode (predictions + scores + summary). Scorer functions should
+            have the signature `scorer(prompt: str, completion: str) -> Union[float, int]`.
+        generation_config (`GenerationConfig`, *optional*):
+            Generation config to use for generating completions.
+        num_prompts (`int` or `None`, *optional*):
+            Number of prompts to generate completions for. If not provided, defaults to the number of examples in the
+            evaluation dataset.
+        dataset_name (`str`, *optional*, defaults to `"eval_dataset"`):
+            Name for the dataset metadata in Weave.
+        model_name (`str`, *optional*):
+            Name for the model metadata in Weave. If not provided, attempts to extract it from the model config.
+    """
+
+    def __init__(
+        self,
+        trainer: Trainer,
+        project_name: Optional[str] = None,
+        scorers: Optional[dict[str, callable]] = None,
+        generation_config: Optional[GenerationConfig] = None,
+        num_prompts: Optional[int] = None,
+        dataset_name: str = "eval_dataset",
+        model_name: Optional[str] = None,
+    ):
+        self.trainer = trainer
+        self.project_name = project_name
+        self.scorers = scorers or {}
+        self.generation_config = generation_config
+        self.dataset_name = dataset_name
+        self.model_name = model_name
+        self._last_logged_step = -1
+        self._weave_initialized = False
+        self._eval_logger = None
+
+        if self.trainer.eval_dataset is None:
+            raise ValueError("Trainer must have an evaluation dataset to use the WeaveCallback.")
+        else:
+            self.eval_dataset = self.trainer.eval_dataset
+
+        if num_prompts is not None:
+            self.eval_dataset = self.eval_dataset.select(range(num_prompts))
+
+    def _initialize_weave(self):
+        """Initialize Weave and EvaluationLogger if not already initialized."""
+        if not self._weave_initialized:
+            if not is_weave_available():
+                logger.warning("Weave is not available. Please install weave to enable logging: `pip install weave`")
+                return
+
+            if wc := weave_client_context.get_weave_client():
+                self._weave_client = wc
+            else:
+                if self.project_name is None:
+                    if is_wandb_available():
+                        if wandb.run is not None:
+                            self.project_name = wandb.run.entity + "/" + wandb.run.project
+                            logger.info(f"Using project name from active wandb run: {self.project_name}")
+
+                if self.project_name is None:
+                    raise ValueError(
+                        "No existing Weave client found and no project_name provided. "
+                        "Please either initialize weave with `weave.init('project-name')`, "
+                        "provide a project_name to the `WeaveCallback`, "
+                        "or ensure an active wandb run exists."
+                    )
+
+                self._weave_client = weave.init(self.project_name)
+                logger.info(f"Initialized Weave with project: {self.project_name}")
+
+            if self.model_name is None:
+                self.model_name = getattr(self.trainer.model_wrapped.config, "_name_or_path", "unknown_model")
+
+            self._EvaluationLogger = EvaluationLogger
+
+            self._weave_initialized = True
+
+    @property
+    def is_evaluation_mode(self) -> bool:
+        """True if scorers are provided (evaluation mode), False for tracing mode."""
+        return bool(self.scorers)
+
+    def on_train_begin(self, args, state, control, **kwargs):
+        """Initialize Weave when training begins."""
+        self._initialize_weave()
+
+    def on_evaluate(self, args, state, control, **kwargs):
+        if state.global_step == self._last_logged_step:
+            return
+
+        self._initialize_weave()
+
+        if not self._weave_initialized:
+            logger.debug("Weave not initialized, skipping logging")
+            return
+
+        tokenizer = kwargs["processing_class"]
+        tokenizer.padding_side = "left"
+        accelerator = self.trainer.accelerator
+        model = self.trainer.model_wrapped
+
+        with accelerator.split_between_processes(self.eval_dataset["prompt"]) as prompts:
+            prompts = [maybe_apply_chat_template({"prompt": prompt}, tokenizer)["prompt"] for prompt in prompts]
+
+            completions = _generate_completions(
+                prompts=prompts,
+                model=model,
+                tokenizer=tokenizer,
+                accelerator=accelerator,
+                generation_config=self.generation_config,
+                batch_size=args.per_device_eval_batch_size,
+            )
+
+            all_prompts = gather_object(prompts)
+            all_completions = gather_object(completions)
+
+        if self.trainer.accelerator.is_main_process:
+            eval_attributes = {
+                "training_step": state.global_step,
+                "model_name": self.model_name,
+                "generation_config": (self.generation_config.to_dict() if self.generation_config else None),
+            }
+
+            eval_logger = self._EvaluationLogger(
+                model=self.model_name,
+                dataset=self.dataset_name,
+                eval_attributes=eval_attributes,
+            )
+
+            successful_predictions = 0
+            total_score_values = {}  # For summary statistics
+
+            for prompt, completion in zip(all_prompts, all_completions):
+                try:
+                    pred_logger = eval_logger.log_prediction(inputs={"prompt": prompt}, output=completion)
+
+                    if self.is_evaluation_mode:
+                        for scorer_name, scorer_func in self.scorers.items():
+                            try:
+                                score = scorer_func(prompt, completion)
+                                pred_logger.log_score(scorer=scorer_name, score=score)
+
+                                if scorer_name not in total_score_values:
+                                    total_score_values[scorer_name] = []
+                                total_score_values[scorer_name].append(score)
+
+                            except Exception as scorer_e:
+                                logger.warning(f"Failed to apply scorer '{scorer_name}': {scorer_e}")
+
+                    pred_logger.finish()
+                    successful_predictions += 1
+
+                except Exception as pred_e:
+                    logger.warning(f"Failed to log prediction for prompt: {pred_e}")
+                    # Continue with other predictions even if one fails
+
+            if self.is_evaluation_mode and total_score_values:
+                try:
+                    summary_stats = {
+                        "total_predictions": len(all_prompts),
+                        "successful_predictions": successful_predictions,
+                    }
+
+                    for scorer_name, scores in total_score_values.items():
+                        if scores:  # Only if we have valid scores
+                            summary_stats[f"avg_{scorer_name}"] = sum(scores) / len(scores)
+
+                    eval_logger.log_summary(summary_stats)
+
+                except Exception as summary_e:
+                    logger.warning(f"Failed to log summary: {summary_e}")
+            else:
+                try:
+                    eval_logger.finish()
+                except Exception as finish_e:
+                    logger.warning(f"Failed to finish evaluation logger: {finish_e}")
+
+        self._last_logged_step = state.global_step
+
+
 class MergeModelCallback(TrainerCallback):
     r"""
     A [`~transformers.TrainerCallback`] that merges the policy model (the model being trained) with another model based
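As a usage note, the flow the new callback automates in `on_evaluate` can also be exercised standalone: initialize Weave, create an `EvaluationLogger`, log one prediction per completion, optionally score it, and finish with a summary. The sketch below mirrors the calls made by the class above with toy data; the project name, model name, prompts, and `length_scorer` are illustrative and not part of the commit.

```python
# Standalone sketch of the logging flow WeaveCallback runs in on_evaluate (toy data; names are illustrative).
import weave
from weave import EvaluationLogger


def length_scorer(prompt: str, completion: str) -> float:
    # Toy scorer using the documented (prompt, completion) -> float signature.
    return float(len(completion.split()))


weave.init("my-llm-training")

eval_logger = EvaluationLogger(
    model="my-model",
    dataset="eval_dataset",
    eval_attributes={"training_step": 0},
)

prompts = ["What is TRL?", "Name one W&B product."]
completions = ["TRL is a library for post-training LLMs.", "Weave."]

scores = []
for prompt, completion in zip(prompts, completions):
    pred_logger = eval_logger.log_prediction(inputs={"prompt": prompt}, output=completion)
    score = length_scorer(prompt, completion)
    pred_logger.log_score(scorer="length", score=score)
    scores.append(score)
    pred_logger.finish()

eval_logger.log_summary({"total_predictions": len(prompts), "avg_length": sum(scores) / len(scores)})
```

In a training script the equivalent is simply `trainer.add_callback(WeaveCallback(trainer=trainer, scorers={"length": length_scorer}))`; the callback then handles generation, gathering completions across processes, and the logging shown here at every evaluation.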