mirror of
https://github.com/huggingface/trl.git
synced 2025-10-20 18:43:52 +08:00
Fix callable annotations (#4216)
This commit is contained in:
committed by
GitHub
parent
521db3520a
commit
a944890ff1
@ -16,6 +16,7 @@ import functools
|
|||||||
import random
|
import random
|
||||||
import signal
|
import signal
|
||||||
import warnings
|
import warnings
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
import pytest
|
import pytest
|
||||||
@ -73,7 +74,7 @@ class TrlTestCase:
|
|||||||
self.tmp_dir = str(tmp_path)
|
self.tmp_dir = str(tmp_path)
|
||||||
|
|
||||||
|
|
||||||
def ignore_warnings(message: str = None, category: type[Warning] = Warning) -> callable:
|
def ignore_warnings(message: str = None, category: type[Warning] = Warning) -> Callable:
|
||||||
"""
|
"""
|
||||||
Decorator to ignore warnings with a specific message and/or category.
|
Decorator to ignore warnings with a specific message and/or category.
|
||||||
|
|
||||||
|
@ -15,7 +15,7 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
import functools
|
import functools
|
||||||
import time
|
import time
|
||||||
from collections.abc import Generator
|
from collections.abc import Callable, Generator
|
||||||
|
|
||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
from transformers.integrations import is_mlflow_available, is_wandb_available
|
from transformers.integrations import is_mlflow_available, is_wandb_available
|
||||||
@ -68,12 +68,12 @@ def profiling_context(trainer: Trainer, name: str) -> Generator[None, None, None
|
|||||||
mlflow.log_metrics(profiling_metrics, step=trainer.state.global_step)
|
mlflow.log_metrics(profiling_metrics, step=trainer.state.global_step)
|
||||||
|
|
||||||
|
|
||||||
def profiling_decorator(func: callable) -> callable:
|
def profiling_decorator(func: Callable) -> Callable:
|
||||||
"""
|
"""
|
||||||
Decorator to profile a function and log execution time using [`extras.profiling.profiling_context`].
|
Decorator to profile a function and log execution time using [`extras.profiling.profiling_context`].
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
func (`callable`):
|
func (`Callable`):
|
||||||
Function to be profiled.
|
Function to be profiled.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
@ -14,6 +14,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from collections.abc import Callable
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
@ -583,7 +584,7 @@ class WeaveCallback(TrainerCallback):
|
|||||||
self,
|
self,
|
||||||
trainer: Trainer,
|
trainer: Trainer,
|
||||||
project_name: Optional[str] = None,
|
project_name: Optional[str] = None,
|
||||||
scorers: Optional[dict[str, callable]] = None,
|
scorers: Optional[dict[str, Callable]] = None,
|
||||||
generation_config: Optional[GenerationConfig] = None,
|
generation_config: Optional[GenerationConfig] = None,
|
||||||
num_prompts: Optional[int] = None,
|
num_prompts: Optional[int] = None,
|
||||||
dataset_name: str = "eval_dataset",
|
dataset_name: str = "eval_dataset",
|
||||||
|
Reference in New Issue
Block a user