mirror of
https://github.com/huggingface/trl.git
synced 2025-10-20 10:03:51 +08:00
Fix callable annotations (#4216)
This commit is contained in:
committed by
GitHub
parent
521db3520a
commit
a944890ff1
@ -16,6 +16,7 @@ import functools
|
||||
import random
|
||||
import signal
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
|
||||
import psutil
|
||||
import pytest
|
||||
@ -73,7 +74,7 @@ class TrlTestCase:
|
||||
self.tmp_dir = str(tmp_path)
|
||||
|
||||
|
||||
def ignore_warnings(message: str = None, category: type[Warning] = Warning) -> callable:
|
||||
def ignore_warnings(message: str = None, category: type[Warning] = Warning) -> Callable:
|
||||
"""
|
||||
Decorator to ignore warnings with a specific message and/or category.
|
||||
|
||||
|
@ -15,7 +15,7 @@
|
||||
import contextlib
|
||||
import functools
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Callable, Generator
|
||||
|
||||
from transformers import Trainer
|
||||
from transformers.integrations import is_mlflow_available, is_wandb_available
|
||||
@ -68,12 +68,12 @@ def profiling_context(trainer: Trainer, name: str) -> Generator[None, None, None
|
||||
mlflow.log_metrics(profiling_metrics, step=trainer.state.global_step)
|
||||
|
||||
|
||||
def profiling_decorator(func: callable) -> callable:
|
||||
def profiling_decorator(func: Callable) -> Callable:
|
||||
"""
|
||||
Decorator to profile a function and log execution time using [`extras.profiling.profiling_context`].
|
||||
|
||||
Args:
|
||||
func (`callable`):
|
||||
func (`Callable`):
|
||||
Function to be profiled.
|
||||
|
||||
Example:
|
||||
|
@ -14,6 +14,7 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from typing import Optional, Union
|
||||
|
||||
import pandas as pd
|
||||
@ -583,7 +584,7 @@ class WeaveCallback(TrainerCallback):
|
||||
self,
|
||||
trainer: Trainer,
|
||||
project_name: Optional[str] = None,
|
||||
scorers: Optional[dict[str, callable]] = None,
|
||||
scorers: Optional[dict[str, Callable]] = None,
|
||||
generation_config: Optional[GenerationConfig] = None,
|
||||
num_prompts: Optional[int] = None,
|
||||
dataset_name: str = "eval_dataset",
|
||||
|
Reference in New Issue
Block a user