Fix callable annotations (#4216)

This commit is contained in:
Albert Villanova del Moral
2025-10-08 21:21:21 +02:00
committed by GitHub
parent 521db3520a
commit a944890ff1
3 changed files with 7 additions and 5 deletions

View File

@ -16,6 +16,7 @@ import functools
import random import random
import signal import signal
import warnings import warnings
from collections.abc import Callable
import psutil import psutil
import pytest import pytest
@ -73,7 +74,7 @@ class TrlTestCase:
self.tmp_dir = str(tmp_path) self.tmp_dir = str(tmp_path)
def ignore_warnings(message: str = None, category: type[Warning] = Warning) -> callable: def ignore_warnings(message: str = None, category: type[Warning] = Warning) -> Callable:
""" """
Decorator to ignore warnings with a specific message and/or category. Decorator to ignore warnings with a specific message and/or category.

View File

@ -15,7 +15,7 @@
import contextlib import contextlib
import functools import functools
import time import time
from collections.abc import Generator from collections.abc import Callable, Generator
from transformers import Trainer from transformers import Trainer
from transformers.integrations import is_mlflow_available, is_wandb_available from transformers.integrations import is_mlflow_available, is_wandb_available
@ -68,12 +68,12 @@ def profiling_context(trainer: Trainer, name: str) -> Generator[None, None, None
mlflow.log_metrics(profiling_metrics, step=trainer.state.global_step) mlflow.log_metrics(profiling_metrics, step=trainer.state.global_step)
def profiling_decorator(func: callable) -> callable: def profiling_decorator(func: Callable) -> Callable:
""" """
Decorator to profile a function and log execution time using [`extras.profiling.profiling_context`]. Decorator to profile a function and log execution time using [`extras.profiling.profiling_context`].
Args: Args:
func (`callable`): func (`Callable`):
Function to be profiled. Function to be profiled.
Example: Example:

View File

@ -14,6 +14,7 @@
import logging import logging
import os import os
from collections.abc import Callable
from typing import Optional, Union from typing import Optional, Union
import pandas as pd import pandas as pd
@ -583,7 +584,7 @@ class WeaveCallback(TrainerCallback):
self, self,
trainer: Trainer, trainer: Trainer,
project_name: Optional[str] = None, project_name: Optional[str] = None,
scorers: Optional[dict[str, callable]] = None, scorers: Optional[dict[str, Callable]] = None,
generation_config: Optional[GenerationConfig] = None, generation_config: Optional[GenerationConfig] = None,
num_prompts: Optional[int] = None, num_prompts: Optional[int] = None,
dataset_name: str = "eval_dataset", dataset_name: str = "eval_dataset",