Fix callable annotations (#4216)

This commit is contained in:
Albert Villanova del Moral
2025-10-08 21:21:21 +02:00
committed by GitHub
parent 521db3520a
commit a944890ff1
3 changed files with 7 additions and 5 deletions

View File

@ -16,6 +16,7 @@ import functools
import random
import signal
import warnings
from collections.abc import Callable
import psutil
import pytest
@ -73,7 +74,7 @@ class TrlTestCase:
self.tmp_dir = str(tmp_path)
def ignore_warnings(message: str = None, category: type[Warning] = Warning) -> callable:
def ignore_warnings(message: str = None, category: type[Warning] = Warning) -> Callable:
"""
Decorator to ignore warnings with a specific message and/or category.

View File

@ -15,7 +15,7 @@
import contextlib
import functools
import time
from collections.abc import Generator
from collections.abc import Callable, Generator
from transformers import Trainer
from transformers.integrations import is_mlflow_available, is_wandb_available
@ -68,12 +68,12 @@ def profiling_context(trainer: Trainer, name: str) -> Generator[None, None, None
mlflow.log_metrics(profiling_metrics, step=trainer.state.global_step)
def profiling_decorator(func: callable) -> callable:
def profiling_decorator(func: Callable) -> Callable:
"""
Decorator to profile a function and log execution time using [`extras.profiling.profiling_context`].
Args:
func (`callable`):
func (`Callable`):
Function to be profiled.
Example:

View File

@ -14,6 +14,7 @@
import logging
import os
from collections.abc import Callable
from typing import Optional, Union
import pandas as pd
@ -583,7 +584,7 @@ class WeaveCallback(TrainerCallback):
self,
trainer: Trainer,
project_name: Optional[str] = None,
scorers: Optional[dict[str, callable]] = None,
scorers: Optional[dict[str, Callable]] = None,
generation_config: Optional[GenerationConfig] = None,
num_prompts: Optional[int] = None,
dataset_name: str = "eval_dataset",