mirror of
https://github.com/huggingface/transformers.git
synced 2025-10-20 17:13:56 +08:00
Use | for Optional and Union typing (#41646)
Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>
This commit is contained in:
@ -16,7 +16,6 @@ import sys
|
|||||||
from logging import Logger
|
from logging import Logger
|
||||||
from threading import Event, Thread
|
from threading import Event, Thread
|
||||||
from time import perf_counter, sleep
|
from time import perf_counter, sleep
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
|
|
||||||
# Add the parent directory to Python path to import benchmarks_entrypoint
|
# Add the parent directory to Python path to import benchmarks_entrypoint
|
||||||
@ -145,7 +144,7 @@ def run_benchmark(
|
|||||||
q = torch.empty_like(probs_sort).exponential_(1)
|
q = torch.empty_like(probs_sort).exponential_(1)
|
||||||
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
|
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
|
||||||
|
|
||||||
def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
|
def logits_to_probs(logits, temperature: float = 1.0, top_k: int | None = None):
|
||||||
logits = logits / max(temperature, 1e-5)
|
logits = logits / max(temperature, 1e-5)
|
||||||
|
|
||||||
if top_k is not None:
|
if top_k is not None:
|
||||||
@ -155,7 +154,7 @@ def run_benchmark(
|
|||||||
probs = torch.nn.functional.softmax(logits, dim=-1)
|
probs = torch.nn.functional.softmax(logits, dim=-1)
|
||||||
return probs
|
return probs
|
||||||
|
|
||||||
def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
|
def sample(logits, temperature: float = 1.0, top_k: int | None = None):
|
||||||
probs = logits_to_probs(logits[0, -1], temperature, top_k)
|
probs = logits_to_probs(logits[0, -1], temperature, top_k)
|
||||||
idx_next = multinomial_sample_one_no_sync(probs)
|
idx_next = multinomial_sample_one_no_sync(probs)
|
||||||
return idx_next, probs
|
return idx_next, probs
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
KERNELIZATION_AVAILABLE = False
|
KERNELIZATION_AVAILABLE = False
|
||||||
@ -27,11 +27,11 @@ class BenchmarkConfig:
|
|||||||
sequence_length: int = 128,
|
sequence_length: int = 128,
|
||||||
num_tokens_to_generate: int = 128,
|
num_tokens_to_generate: int = 128,
|
||||||
attn_implementation: str = "eager",
|
attn_implementation: str = "eager",
|
||||||
sdpa_backend: Optional[str] = None,
|
sdpa_backend: str | None = None,
|
||||||
compile_mode: Optional[str] = None,
|
compile_mode: str | None = None,
|
||||||
compile_options: Optional[dict[str, Any]] = None,
|
compile_options: dict[str, Any] | None = None,
|
||||||
kernelize: bool = False,
|
kernelize: bool = False,
|
||||||
name: Optional[str] = None,
|
name: str | None = None,
|
||||||
skip_validity_check: bool = False,
|
skip_validity_check: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
# Benchmark parameters
|
# Benchmark parameters
|
||||||
@ -128,8 +128,8 @@ class BenchmarkConfig:
|
|||||||
|
|
||||||
|
|
||||||
def cross_generate_configs(
|
def cross_generate_configs(
|
||||||
attn_impl_and_sdpa_backend: list[tuple[str, Optional[str]]],
|
attn_impl_and_sdpa_backend: list[tuple[str, str | None]],
|
||||||
compiled_mode: list[Optional[str]],
|
compiled_mode: list[str | None],
|
||||||
kernelized: list[bool],
|
kernelized: list[bool],
|
||||||
warmup_iterations: int = 5,
|
warmup_iterations: int = 5,
|
||||||
measurement_iterations: int = 20,
|
measurement_iterations: int = 20,
|
||||||
|
@ -8,7 +8,7 @@ import time
|
|||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
@ -74,7 +74,7 @@ def get_git_revision() -> str:
|
|||||||
return git_hash.readline().strip()
|
return git_hash.readline().strip()
|
||||||
|
|
||||||
|
|
||||||
def get_sdpa_backend(backend_name: Optional[str]) -> Optional[torch.nn.attention.SDPBackend]:
|
def get_sdpa_backend(backend_name: str | None) -> torch.nn.attention.SDPBackend | None:
|
||||||
"""Get the SDPA backend enum from string name."""
|
"""Get the SDPA backend enum from string name."""
|
||||||
if backend_name is None:
|
if backend_name is None:
|
||||||
return None
|
return None
|
||||||
@ -145,7 +145,7 @@ class BenchmarkRunner:
|
|||||||
"""Main benchmark runner that coordinates benchmark execution."""
|
"""Main benchmark runner that coordinates benchmark execution."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, logger: logging.Logger, output_dir: str = "benchmark_results", commit_id: Optional[str] = None
|
self, logger: logging.Logger, output_dir: str = "benchmark_results", commit_id: str | None = None
|
||||||
) -> None:
|
) -> None:
|
||||||
# Those stay constant for the whole run
|
# Those stay constant for the whole run
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
@ -156,7 +156,7 @@ class BenchmarkRunner:
|
|||||||
# Attributes that are reset for each model
|
# Attributes that are reset for each model
|
||||||
self._setup_for = ""
|
self._setup_for = ""
|
||||||
# Attributes that are reset for each run
|
# Attributes that are reset for each run
|
||||||
self.model: Optional[GenerationMixin] = None
|
self.model: GenerationMixin | None = None
|
||||||
|
|
||||||
def cleanup(self) -> None:
|
def cleanup(self) -> None:
|
||||||
del self.model
|
del self.model
|
||||||
@ -251,8 +251,8 @@ class BenchmarkRunner:
|
|||||||
def time_generate(
|
def time_generate(
|
||||||
self,
|
self,
|
||||||
max_new_tokens: int,
|
max_new_tokens: int,
|
||||||
gpu_monitor: Optional[GPUMonitor] = None,
|
gpu_monitor: GPUMonitor | None = None,
|
||||||
) -> tuple[float, list[float], str, Optional[GPURawMetrics]]:
|
) -> tuple[float, list[float], str, GPURawMetrics | None]:
|
||||||
"""Time the latency of a call to model.generate() with the given (inputs) and (max_new_tokens)."""
|
"""Time the latency of a call to model.generate() with the given (inputs) and (max_new_tokens)."""
|
||||||
# Prepare gpu monitoring if needed
|
# Prepare gpu monitoring if needed
|
||||||
if gpu_monitor is not None:
|
if gpu_monitor is not None:
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Optional, Union
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -90,14 +90,14 @@ class BenchmarkResult:
|
|||||||
e2e_latency: float,
|
e2e_latency: float,
|
||||||
token_generation_times: list[float],
|
token_generation_times: list[float],
|
||||||
decoded_output: str,
|
decoded_output: str,
|
||||||
gpu_metrics: Optional[GPURawMetrics],
|
gpu_metrics: GPURawMetrics | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.e2e_latency.append(e2e_latency)
|
self.e2e_latency.append(e2e_latency)
|
||||||
self.token_generation_times.append(token_generation_times)
|
self.token_generation_times.append(token_generation_times)
|
||||||
self.decoded_outputs.append(decoded_output)
|
self.decoded_outputs.append(decoded_output)
|
||||||
self.gpu_metrics.append(gpu_metrics)
|
self.gpu_metrics.append(gpu_metrics)
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Union[None, int, float]]:
|
def to_dict(self) -> dict[str, None | int | float]:
|
||||||
# Save GPU metrics as None if it contains only None values
|
# Save GPU metrics as None if it contains only None values
|
||||||
if all(gm is None for gm in self.gpu_metrics):
|
if all(gm is None for gm in self.gpu_metrics):
|
||||||
gpu_metrics = None
|
gpu_metrics = None
|
||||||
@ -111,7 +111,7 @@ class BenchmarkResult:
|
|||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: dict[str, Union[None, int, float]]) -> "BenchmarkResult":
|
def from_dict(cls, data: dict[str, None | int | float]) -> "BenchmarkResult":
|
||||||
# Handle GPU metrics, which is saved as None if it contains only None values
|
# Handle GPU metrics, which is saved as None if it contains only None values
|
||||||
if data["gpu_metrics"] is None:
|
if data["gpu_metrics"] is None:
|
||||||
gpu_metrics = [None for _ in range(len(data["e2e_latency"]))]
|
gpu_metrics = [None for _ in range(len(data["e2e_latency"]))]
|
||||||
|
@ -7,7 +7,6 @@ import time
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
import gpustat
|
import gpustat
|
||||||
import psutil
|
import psutil
|
||||||
@ -42,7 +41,7 @@ class HardwareInfo:
|
|||||||
self.cpu_count = psutil.cpu_count()
|
self.cpu_count = psutil.cpu_count()
|
||||||
self.memory_total_mb = int(psutil.virtual_memory().total / (1024 * 1024))
|
self.memory_total_mb = int(psutil.virtual_memory().total / (1024 * 1024))
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Union[None, int, float, str]]:
|
def to_dict(self) -> dict[str, None | int | float | str]:
|
||||||
return {
|
return {
|
||||||
"gpu_name": self.gpu_name,
|
"gpu_name": self.gpu_name,
|
||||||
"gpu_memory_total_gb": self.gpu_memory_total_gb,
|
"gpu_memory_total_gb": self.gpu_memory_total_gb,
|
||||||
@ -109,7 +108,7 @@ class GPURawMetrics:
|
|||||||
timestamp_0: float # in seconds
|
timestamp_0: float # in seconds
|
||||||
monitoring_status: GPUMonitoringStatus
|
monitoring_status: GPUMonitoringStatus
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Union[None, int, float, str]]:
|
def to_dict(self) -> dict[str, None | int | float | str]:
|
||||||
return {
|
return {
|
||||||
"utilization": self.utilization,
|
"utilization": self.utilization,
|
||||||
"memory_used": self.memory_used,
|
"memory_used": self.memory_used,
|
||||||
@ -123,7 +122,7 @@ class GPURawMetrics:
|
|||||||
class GPUMonitor:
|
class GPUMonitor:
|
||||||
"""Monitor GPU utilization during benchmark execution."""
|
"""Monitor GPU utilization during benchmark execution."""
|
||||||
|
|
||||||
def __init__(self, sample_interval_sec: float = 0.1, logger: Optional[Logger] = None):
|
def __init__(self, sample_interval_sec: float = 0.1, logger: Logger | None = None):
|
||||||
self.sample_interval_sec = sample_interval_sec
|
self.sample_interval_sec = sample_interval_sec
|
||||||
self.logger = logger if logger is not None else logging.getLogger(__name__)
|
self.logger = logger if logger is not None else logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -19,7 +19,6 @@ import os
|
|||||||
import unittest
|
import unittest
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
@ -1254,8 +1253,8 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
|
|||||||
do_eval: bool = True,
|
do_eval: bool = True,
|
||||||
quality_checks: bool = True,
|
quality_checks: bool = True,
|
||||||
fp32: bool = False,
|
fp32: bool = False,
|
||||||
extra_args_str: Optional[str] = None,
|
extra_args_str: str | None = None,
|
||||||
remove_args_str: Optional[str] = None,
|
remove_args_str: str | None = None,
|
||||||
):
|
):
|
||||||
# we are doing quality testing so using a small real model
|
# we are doing quality testing so using a small real model
|
||||||
output_dir = self.run_trainer(
|
output_dir = self.run_trainer(
|
||||||
@ -1287,8 +1286,8 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
|
|||||||
do_eval: bool = True,
|
do_eval: bool = True,
|
||||||
distributed: bool = True,
|
distributed: bool = True,
|
||||||
fp32: bool = False,
|
fp32: bool = False,
|
||||||
extra_args_str: Optional[str] = None,
|
extra_args_str: str | None = None,
|
||||||
remove_args_str: Optional[str] = None,
|
remove_args_str: str | None = None,
|
||||||
):
|
):
|
||||||
max_len = 32
|
max_len = 32
|
||||||
data_dir = self.test_file_dir / "../fixtures/tests_samples/wmt_en_ro"
|
data_dir = self.test_file_dir / "../fixtures/tests_samples/wmt_en_ro"
|
||||||
|
@ -17,7 +17,6 @@ import os
|
|||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
@ -251,13 +250,13 @@ class TestTrainerExt(TestCasePlus):
|
|||||||
learning_rate: float = 3e-3,
|
learning_rate: float = 3e-3,
|
||||||
optim: str = "adafactor",
|
optim: str = "adafactor",
|
||||||
distributed: bool = False,
|
distributed: bool = False,
|
||||||
extra_args_str: Optional[str] = None,
|
extra_args_str: str | None = None,
|
||||||
eval_steps: int = 0,
|
eval_steps: int = 0,
|
||||||
predict_with_generate: bool = True,
|
predict_with_generate: bool = True,
|
||||||
do_train: bool = True,
|
do_train: bool = True,
|
||||||
do_eval: bool = True,
|
do_eval: bool = True,
|
||||||
do_predict: bool = True,
|
do_predict: bool = True,
|
||||||
n_gpus_to_use: Optional[int] = None,
|
n_gpus_to_use: int | None = None,
|
||||||
):
|
):
|
||||||
data_dir = self.test_file_dir / "../fixtures/tests_samples/wmt_en_ro"
|
data_dir = self.test_file_dir / "../fixtures/tests_samples/wmt_en_ro"
|
||||||
output_dir = self.get_auto_remove_tmp_dir()
|
output_dir = self.get_auto_remove_tmp_dir()
|
||||||
|
@ -13,7 +13,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
@ -43,8 +42,8 @@ class ContinuousBatchingTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
def test_group_layers(
|
def test_group_layers(
|
||||||
self,
|
self,
|
||||||
layer_types_str: Optional[str],
|
layer_types_str: str | None,
|
||||||
sliding_window: Optional[int],
|
sliding_window: int | None,
|
||||||
expected_groups: str,
|
expected_groups: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
# Take a config and change the layer_types attribute to the mix we want
|
# Take a config and change the layer_types attribute to the mix we want
|
||||||
|
@ -13,7 +13,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
@ -90,7 +89,7 @@ class LogitsProcessorTest(unittest.TestCase):
|
|||||||
self.assertFalse(torch.isinf(scores_before_min_length).any())
|
self.assertFalse(torch.isinf(scores_before_min_length).any())
|
||||||
|
|
||||||
@parameterized.expand([(0,), ([0, 18],)])
|
@parameterized.expand([(0,), ([0, 18],)])
|
||||||
def test_new_min_length_dist_processor(self, eos_token_id: Union[int, list[int]]):
|
def test_new_min_length_dist_processor(self, eos_token_id: int | list[int]):
|
||||||
vocab_size = 20
|
vocab_size = 20
|
||||||
batch_size = 4
|
batch_size = 4
|
||||||
|
|
||||||
|
@ -23,7 +23,6 @@ import tempfile
|
|||||||
import unittest
|
import unittest
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
@ -908,7 +907,7 @@ class GenerationTesterMixin:
|
|||||||
|
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
def test_left_padding_compatibility(
|
def test_left_padding_compatibility(
|
||||||
self, unpadded_custom_inputs: Optional[dict] = None, padded_custom_inputs: Optional[dict] = None
|
self, unpadded_custom_inputs: dict | None = None, padded_custom_inputs: dict | None = None
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Tests that adding left-padding yields the same logits as the original input. Exposes arguments for custom
|
Tests that adding left-padding yields the same logits as the original input. Exposes arguments for custom
|
||||||
|
@ -14,7 +14,6 @@
|
|||||||
|
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
from transformers.image_utils import load_image
|
from transformers.image_utils import load_image
|
||||||
from transformers.testing_utils import require_torch, require_vision
|
from transformers.testing_utils import require_torch, require_vision
|
||||||
@ -36,14 +35,14 @@ class BridgeTowerImageProcessingTester:
|
|||||||
self,
|
self,
|
||||||
parent,
|
parent,
|
||||||
do_resize: bool = True,
|
do_resize: bool = True,
|
||||||
size: Optional[dict[str, int]] = None,
|
size: dict[str, int] | None = None,
|
||||||
size_divisor: int = 32,
|
size_divisor: int = 32,
|
||||||
do_rescale: bool = True,
|
do_rescale: bool = True,
|
||||||
rescale_factor: Union[int, float] = 1 / 255,
|
rescale_factor: int | float = 1 / 255,
|
||||||
do_normalize: bool = True,
|
do_normalize: bool = True,
|
||||||
do_center_crop: bool = True,
|
do_center_crop: bool = True,
|
||||||
image_mean: Optional[Union[float, list[float]]] = [0.48145466, 0.4578275, 0.40821073],
|
image_mean: float | list[float] | None = [0.48145466, 0.4578275, 0.40821073],
|
||||||
image_std: Optional[Union[float, list[float]]] = [0.26862954, 0.26130258, 0.27577711],
|
image_std: float | list[float] | None = [0.26862954, 0.26130258, 0.27577711],
|
||||||
do_pad: bool = True,
|
do_pad: bool = True,
|
||||||
batch_size=7,
|
batch_size=7,
|
||||||
min_resolution=30,
|
min_resolution=30,
|
||||||
|
@ -15,7 +15,6 @@
|
|||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from transformers import Gemma3Processor, GemmaTokenizer
|
from transformers import Gemma3Processor, GemmaTokenizer
|
||||||
from transformers.testing_utils import get_tests_dir, require_vision
|
from transformers.testing_utils import get_tests_dir, require_vision
|
||||||
@ -83,7 +82,7 @@ class Gemma3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
|||||||
} # fmt: skip
|
} # fmt: skip
|
||||||
|
|
||||||
# Override as Gemma3 needs images to be an explicitly nested batch
|
# Override as Gemma3 needs images to be an explicitly nested batch
|
||||||
def prepare_image_inputs(self, batch_size: Optional[int] = None):
|
def prepare_image_inputs(self, batch_size: int | None = None):
|
||||||
"""This function prepares a list of PIL images for testing"""
|
"""This function prepares a list of PIL images for testing"""
|
||||||
images = super().prepare_image_inputs(batch_size)
|
images = super().prepare_image_inputs(batch_size)
|
||||||
if isinstance(images, (list, tuple)):
|
if isinstance(images, (list, tuple)):
|
||||||
|
@ -19,7 +19,6 @@ import random
|
|||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
@ -78,8 +77,8 @@ class Gemma3nAudioFeatureExtractionTester:
|
|||||||
dither: float = 0.0,
|
dither: float = 0.0,
|
||||||
input_scale_factor: float = 1.0,
|
input_scale_factor: float = 1.0,
|
||||||
mel_floor: float = 1e-5,
|
mel_floor: float = 1e-5,
|
||||||
per_bin_mean: Optional[Sequence[float]] = None,
|
per_bin_mean: Sequence[float] | None = None,
|
||||||
per_bin_stddev: Optional[Sequence[float]] = None,
|
per_bin_stddev: Sequence[float] | None = None,
|
||||||
):
|
):
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
|
@ -17,7 +17,6 @@ import os
|
|||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -79,7 +78,7 @@ class GroundingDinoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
|||||||
cls.embed_dim = 5
|
cls.embed_dim = 5
|
||||||
cls.seq_length = 5
|
cls.seq_length = 5
|
||||||
|
|
||||||
def prepare_text_inputs(self, batch_size: Optional[int] = None, **kwargs):
|
def prepare_text_inputs(self, batch_size: int | None = None, **kwargs):
|
||||||
labels = ["a cat", "remote control"]
|
labels = ["a cat", "remote control"]
|
||||||
labels_longer = ["a person", "a car", "a dog", "a cat"]
|
labels_longer = ["a person", "a car", "a dog", "a cat"]
|
||||||
|
|
||||||
|
@ -14,7 +14,6 @@
|
|||||||
|
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -151,7 +150,7 @@ class LlavaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
|||||||
|
|
||||||
# taken from original implementation: https://github.com/haotian-liu/LLaVA/blob/c121f0432da27facab705978f83c4ada465e46fd/llava/mm_utils.py#L152
|
# taken from original implementation: https://github.com/haotian-liu/LLaVA/blob/c121f0432da27facab705978f83c4ada465e46fd/llava/mm_utils.py#L152
|
||||||
def pad_to_square_original(
|
def pad_to_square_original(
|
||||||
image: Image.Image, background_color: Union[int, tuple[int, int, int]] = 0
|
image: Image.Image, background_color: int | tuple[int, int, int] = 0
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
width, height = image.size
|
width, height = image.size
|
||||||
if width == height:
|
if width == height:
|
||||||
|
@ -16,7 +16,6 @@ import json
|
|||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -60,7 +59,7 @@ class MllamaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
|||||||
return {"chat_template": "{% for message in messages %}{% if loop.index0 == 0 %}{{ bos_token }}{% endif %}{{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }}{% if message['content'] is string %}{{ message['content'] }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' %}{{ '<|image|>' }}{% elif content['type'] == 'text' %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{{ '<|eot_id|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"} # fmt: skip
|
return {"chat_template": "{% for message in messages %}{% if loop.index0 == 0 %}{{ bos_token }}{% endif %}{{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }}{% if message['content'] is string %}{{ message['content'] }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' %}{{ '<|image|>' }}{% elif content['type'] == 'text' %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{{ '<|eot_id|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"} # fmt: skip
|
||||||
|
|
||||||
# Override as Mllama needs images to be an explicitly nested batch
|
# Override as Mllama needs images to be an explicitly nested batch
|
||||||
def prepare_image_inputs(self, batch_size: Optional[int] = None):
|
def prepare_image_inputs(self, batch_size: int | None = None):
|
||||||
"""This function prepares a list of PIL images for testing"""
|
"""This function prepares a list of PIL images for testing"""
|
||||||
images = super().prepare_image_inputs(batch_size)
|
images = super().prepare_image_inputs(batch_size)
|
||||||
if isinstance(images, (list, tuple)):
|
if isinstance(images, (list, tuple)):
|
||||||
|
@ -18,7 +18,6 @@ import itertools
|
|||||||
import random
|
import random
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
@ -87,17 +86,17 @@ class PatchTSMixerModelTester:
|
|||||||
masked_loss: bool = False,
|
masked_loss: bool = False,
|
||||||
mask_mode: str = "mask_before_encoder",
|
mask_mode: str = "mask_before_encoder",
|
||||||
channel_consistent_masking: bool = True,
|
channel_consistent_masking: bool = True,
|
||||||
scaling: Optional[Union[str, bool]] = "std",
|
scaling: str | bool | None = "std",
|
||||||
# Head related
|
# Head related
|
||||||
head_dropout: float = 0.2,
|
head_dropout: float = 0.2,
|
||||||
# forecast related
|
# forecast related
|
||||||
prediction_length: int = 16,
|
prediction_length: int = 16,
|
||||||
out_channels: Optional[int] = None,
|
out_channels: int | None = None,
|
||||||
# Classification/regression related
|
# Classification/regression related
|
||||||
# num_labels: int = 3,
|
# num_labels: int = 3,
|
||||||
num_targets: int = 3,
|
num_targets: int = 3,
|
||||||
output_range: Optional[list] = None,
|
output_range: list | None = None,
|
||||||
head_aggregation: Optional[str] = None,
|
head_aggregation: str | None = None,
|
||||||
# Trainer related
|
# Trainer related
|
||||||
batch_size=13,
|
batch_size=13,
|
||||||
is_training=True,
|
is_training=True,
|
||||||
|
@ -15,7 +15,6 @@
|
|||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -94,14 +93,14 @@ class SmolVLMProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Override as SmolVLM needs images/video to be an explicitly nested batch
|
# Override as SmolVLM needs images/video to be an explicitly nested batch
|
||||||
def prepare_image_inputs(self, batch_size: Optional[int] = None):
|
def prepare_image_inputs(self, batch_size: int | None = None):
|
||||||
"""This function prepares a list of PIL images for testing"""
|
"""This function prepares a list of PIL images for testing"""
|
||||||
images = super().prepare_image_inputs(batch_size)
|
images = super().prepare_image_inputs(batch_size)
|
||||||
if isinstance(images, (list, tuple)):
|
if isinstance(images, (list, tuple)):
|
||||||
images = [[image] for image in images]
|
images = [[image] for image in images]
|
||||||
return images
|
return images
|
||||||
|
|
||||||
def prepare_video_inputs(self, batch_size: Optional[int] = None):
|
def prepare_video_inputs(self, batch_size: int | None = None):
|
||||||
"""This function prepares a list of numpy videos."""
|
"""This function prepares a list of numpy videos."""
|
||||||
video_input = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)] * 8
|
video_input = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)] * 8
|
||||||
if batch_size is None:
|
if batch_size is None:
|
||||||
|
@ -14,7 +14,6 @@
|
|||||||
|
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -41,16 +40,16 @@ class TvpImageProcessingTester:
|
|||||||
do_resize: bool = True,
|
do_resize: bool = True,
|
||||||
size: dict[str, int] = {"longest_edge": 40},
|
size: dict[str, int] = {"longest_edge": 40},
|
||||||
do_center_crop: bool = False,
|
do_center_crop: bool = False,
|
||||||
crop_size: Optional[dict[str, int]] = None,
|
crop_size: dict[str, int] | None = None,
|
||||||
do_rescale: bool = False,
|
do_rescale: bool = False,
|
||||||
rescale_factor: Union[int, float] = 1 / 255,
|
rescale_factor: int | float = 1 / 255,
|
||||||
do_pad: bool = True,
|
do_pad: bool = True,
|
||||||
pad_size: dict[str, int] = {"height": 80, "width": 80},
|
pad_size: dict[str, int] = {"height": 80, "width": 80},
|
||||||
fill: Optional[int] = None,
|
fill: int | None = None,
|
||||||
pad_mode: Optional[PaddingMode] = None,
|
pad_mode: PaddingMode | None = None,
|
||||||
do_normalize: bool = True,
|
do_normalize: bool = True,
|
||||||
image_mean: Optional[Union[float, list[float]]] = [0.48145466, 0.4578275, 0.40821073],
|
image_mean: float | list[float] | None = [0.48145466, 0.4578275, 0.40821073],
|
||||||
image_std: Optional[Union[float, list[float]]] = [0.26862954, 0.26130258, 0.27577711],
|
image_std: float | list[float] | None = [0.26862954, 0.26130258, 0.27577711],
|
||||||
batch_size=2,
|
batch_size=2,
|
||||||
min_resolution=40,
|
min_resolution=40,
|
||||||
max_resolution=80,
|
max_resolution=80,
|
||||||
|
@ -17,7 +17,6 @@ import inspect
|
|||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
@ -79,7 +78,7 @@ class VideoLlama3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
|||||||
shutil.rmtree(cls.tmpdirname, ignore_errors=True)
|
shutil.rmtree(cls.tmpdirname, ignore_errors=True)
|
||||||
|
|
||||||
@require_vision
|
@require_vision
|
||||||
def prepare_image_inputs(self, batch_size: Optional[int] = None):
|
def prepare_image_inputs(self, batch_size: int | None = None):
|
||||||
"""This function prepares a list of PIL images for testing"""
|
"""This function prepares a list of PIL images for testing"""
|
||||||
if batch_size is None:
|
if batch_size is None:
|
||||||
return prepare_image_inputs()[0]
|
return prepare_image_inputs()[0]
|
||||||
|
@ -20,7 +20,6 @@ import os
|
|||||||
import random
|
import random
|
||||||
import sys
|
import sys
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from datasets import load_dataset, load_metric
|
from datasets import load_dataset, load_metric
|
||||||
@ -70,7 +69,7 @@ class DataTrainingArguments:
|
|||||||
the command line.
|
the command line.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
task_name: Optional[str] = field(
|
task_name: str | None = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())},
|
metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())},
|
||||||
)
|
)
|
||||||
@ -95,7 +94,7 @@ class DataTrainingArguments:
|
|||||||
)
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_train_samples: Optional[int] = field(
|
max_train_samples: int | None = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": (
|
"help": (
|
||||||
@ -104,7 +103,7 @@ class DataTrainingArguments:
|
|||||||
)
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_val_samples: Optional[int] = field(
|
max_val_samples: int | None = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": (
|
"help": (
|
||||||
@ -113,7 +112,7 @@ class DataTrainingArguments:
|
|||||||
)
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_test_samples: Optional[int] = field(
|
max_test_samples: int | None = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": (
|
"help": (
|
||||||
@ -122,13 +121,13 @@ class DataTrainingArguments:
|
|||||||
)
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
train_file: Optional[str] = field(
|
train_file: str | None = field(
|
||||||
default=None, metadata={"help": "A csv or a json file containing the training data."}
|
default=None, metadata={"help": "A csv or a json file containing the training data."}
|
||||||
)
|
)
|
||||||
validation_file: Optional[str] = field(
|
validation_file: str | None = field(
|
||||||
default=None, metadata={"help": "A csv or a json file containing the validation data."}
|
default=None, metadata={"help": "A csv or a json file containing the validation data."}
|
||||||
)
|
)
|
||||||
test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."})
|
test_file: str | None = field(default=None, metadata={"help": "A csv or a json file containing the test data."})
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.task_name is not None:
|
if self.task_name is not None:
|
||||||
@ -155,13 +154,13 @@ class ModelArguments:
|
|||||||
model_name_or_path: str = field(
|
model_name_or_path: str = field(
|
||||||
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
|
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
|
||||||
)
|
)
|
||||||
config_name: Optional[str] = field(
|
config_name: str | None = field(
|
||||||
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||||
)
|
)
|
||||||
tokenizer_name: Optional[str] = field(
|
tokenizer_name: str | None = field(
|
||||||
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
||||||
)
|
)
|
||||||
cache_dir: Optional[str] = field(
|
cache_dir: str | None = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
|
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
|
||||||
)
|
)
|
||||||
|
@ -19,7 +19,6 @@ import random
|
|||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
@ -136,7 +135,7 @@ class ProcessorTesterMixin:
|
|||||||
processor = self.processor_class(**components, **self.prepare_processor_dict())
|
processor = self.processor_class(**components, **self.prepare_processor_dict())
|
||||||
return processor
|
return processor
|
||||||
|
|
||||||
def prepare_text_inputs(self, batch_size: Optional[int] = None, modalities: Optional[Union[str, list]] = None):
|
def prepare_text_inputs(self, batch_size: int | None = None, modalities: str | list | None = None):
|
||||||
if isinstance(modalities, str):
|
if isinstance(modalities, str):
|
||||||
modalities = [modalities]
|
modalities = [modalities]
|
||||||
|
|
||||||
@ -158,7 +157,7 @@ class ProcessorTesterMixin:
|
|||||||
] * (batch_size - 2)
|
] * (batch_size - 2)
|
||||||
|
|
||||||
@require_vision
|
@require_vision
|
||||||
def prepare_image_inputs(self, batch_size: Optional[int] = None):
|
def prepare_image_inputs(self, batch_size: int | None = None):
|
||||||
"""This function prepares a list of PIL images for testing"""
|
"""This function prepares a list of PIL images for testing"""
|
||||||
if batch_size is None:
|
if batch_size is None:
|
||||||
return prepare_image_inputs()[0]
|
return prepare_image_inputs()[0]
|
||||||
@ -167,7 +166,7 @@ class ProcessorTesterMixin:
|
|||||||
return prepare_image_inputs() * batch_size
|
return prepare_image_inputs() * batch_size
|
||||||
|
|
||||||
@require_vision
|
@require_vision
|
||||||
def prepare_video_inputs(self, batch_size: Optional[int] = None):
|
def prepare_video_inputs(self, batch_size: int | None = None):
|
||||||
"""This function prepares a list of numpy videos."""
|
"""This function prepares a list of numpy videos."""
|
||||||
video_input = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)] * 8
|
video_input = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)] * 8
|
||||||
video_input = np.array(video_input)
|
video_input = np.array(video_input)
|
||||||
@ -175,7 +174,7 @@ class ProcessorTesterMixin:
|
|||||||
return video_input
|
return video_input
|
||||||
return [video_input] * batch_size
|
return [video_input] * batch_size
|
||||||
|
|
||||||
def prepare_audio_inputs(self, batch_size: Optional[int] = None):
|
def prepare_audio_inputs(self, batch_size: int | None = None):
|
||||||
"""This function prepares a list of numpy audio."""
|
"""This function prepares a list of numpy audio."""
|
||||||
raw_speech = floats_list((1, 1000))
|
raw_speech = floats_list((1, 1000))
|
||||||
raw_speech = [np.asarray(audio) for audio in raw_speech]
|
raw_speech = [np.asarray(audio) for audio in raw_speech]
|
||||||
|
@ -26,7 +26,7 @@ import unittest
|
|||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from itertools import takewhile
|
from itertools import takewhile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
from typing import TYPE_CHECKING, Any, Union
|
||||||
|
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
|
|
||||||
@ -169,7 +169,7 @@ def _test_subword_regularization_tokenizer(in_queue, out_queue, timeout):
|
|||||||
|
|
||||||
def check_subword_sampling(
|
def check_subword_sampling(
|
||||||
tokenizer: PreTrainedTokenizer,
|
tokenizer: PreTrainedTokenizer,
|
||||||
text: Optional[str] = None,
|
text: str | None = None,
|
||||||
test_sentencepiece_ignore_case: bool = True,
|
test_sentencepiece_ignore_case: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
@ -313,9 +313,9 @@ class TokenizerTesterMixin:
|
|||||||
self,
|
self,
|
||||||
expected_encoding: dict,
|
expected_encoding: dict,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
revision: Optional[str] = None,
|
revision: str | None = None,
|
||||||
sequences: Optional[list[str]] = None,
|
sequences: list[str] | None = None,
|
||||||
decode_kwargs: Optional[dict[str, Any]] = None,
|
decode_kwargs: dict[str, Any] | None = None,
|
||||||
padding: bool = True,
|
padding: bool = True,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
@ -16,7 +16,6 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
from transformers.testing_utils import require_torch, slow
|
from transformers.testing_utils import require_torch, slow
|
||||||
@ -32,9 +31,9 @@ class TestCodeExamples(unittest.TestCase):
|
|||||||
def analyze_directory(
|
def analyze_directory(
|
||||||
self,
|
self,
|
||||||
directory: Path,
|
directory: Path,
|
||||||
identifier: Union[str, None] = None,
|
identifier: str | None = None,
|
||||||
ignore_files: Union[list[str], None] = None,
|
ignore_files: list[str] | None = None,
|
||||||
n_identifier: Union[str, list[str], None] = None,
|
n_identifier: str | list[str] | None = None,
|
||||||
only_modules: bool = True,
|
only_modules: bool = True,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
@ -22,7 +22,7 @@ from argparse import Namespace
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Literal, Optional, Union, get_args, get_origin
|
from typing import Literal, Union, get_args, get_origin
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
@ -59,7 +59,7 @@ class WithDefaultExample:
|
|||||||
class WithDefaultBoolExample:
|
class WithDefaultBoolExample:
|
||||||
foo: bool = False
|
foo: bool = False
|
||||||
baz: bool = True
|
baz: bool = True
|
||||||
opt: Optional[bool] = None
|
opt: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
class BasicEnum(Enum):
|
class BasicEnum(Enum):
|
||||||
@ -91,11 +91,11 @@ class MixedTypeEnumExample:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class OptionalExample:
|
class OptionalExample:
|
||||||
foo: Optional[int] = None
|
foo: int | None = None
|
||||||
bar: Optional[float] = field(default=None, metadata={"help": "help message"})
|
bar: float | None = field(default=None, metadata={"help": "help message"})
|
||||||
baz: Optional[str] = None
|
baz: str | None = None
|
||||||
ces: Optional[list[str]] = list_field(default=[])
|
ces: list[str] | None = list_field(default=[])
|
||||||
des: Optional[list[int]] = list_field(default=[])
|
des: list[int] | None = list_field(default=[])
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -120,7 +120,7 @@ class RequiredExample:
|
|||||||
class StringLiteralAnnotationExample:
|
class StringLiteralAnnotationExample:
|
||||||
foo: int
|
foo: int
|
||||||
required_enum: "BasicEnum" = field()
|
required_enum: "BasicEnum" = field()
|
||||||
opt: "Optional[bool]" = None
|
opt: "bool | None" = None
|
||||||
baz: "str" = field(default="toto", metadata={"help": "help message"})
|
baz: "str" = field(default="toto", metadata={"help": "help message"})
|
||||||
foo_str: "list[str]" = list_field(default=["Hallo", "Bonjour", "Hello"])
|
foo_str: "list[str]" = list_field(default=["Hallo", "Bonjour", "Hello"])
|
||||||
|
|
||||||
|
@ -17,7 +17,6 @@ import os
|
|||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -46,7 +45,7 @@ if is_vision_available():
|
|||||||
from transformers.image_utils import get_image_size, infer_channel_dimension_format, load_image
|
from transformers.image_utils import get_image_size, infer_channel_dimension_format, load_image
|
||||||
|
|
||||||
|
|
||||||
def get_image_from_hub_dataset(dataset_id: str, filename: str, revision: Optional[str] = None) -> "PIL.Image.Image":
|
def get_image_from_hub_dataset(dataset_id: str, filename: str, revision: str | None = None) -> "PIL.Image.Image":
|
||||||
url = hf_hub_url(dataset_id, filename, repo_type="dataset", revision=revision)
|
url = hf_hub_url(dataset_id, filename, repo_type="dataset", revision=revision)
|
||||||
return PIL.Image.open(BytesIO(httpx.get(url, follow_redirects=True).content))
|
return PIL.Image.open(BytesIO(httpx.get(url, follow_redirects=True).content))
|
||||||
|
|
||||||
|
@ -15,7 +15,6 @@
|
|||||||
import io
|
import io
|
||||||
import unittest
|
import unittest
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -31,8 +30,8 @@ if is_torch_available():
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ModelOutputTest(ModelOutput):
|
class ModelOutputTest(ModelOutput):
|
||||||
a: float
|
a: float
|
||||||
b: Optional[float] = None
|
b: float | None = None
|
||||||
c: Optional[float] = None
|
c: float | None = None
|
||||||
|
|
||||||
|
|
||||||
class ModelOutputTester(unittest.TestCase):
|
class ModelOutputTester(unittest.TestCase):
|
||||||
@ -182,8 +181,8 @@ class ModelOutputTestNoDataclass(ModelOutput):
|
|||||||
"""Invalid test subclass of ModelOutput where @dataclass decorator is not used"""
|
"""Invalid test subclass of ModelOutput where @dataclass decorator is not used"""
|
||||||
|
|
||||||
a: float
|
a: float
|
||||||
b: Optional[float] = None
|
b: float | None = None
|
||||||
c: Optional[float] = None
|
c: float | None = None
|
||||||
|
|
||||||
|
|
||||||
class ModelOutputSubclassTester(unittest.TestCase):
|
class ModelOutputSubclassTester(unittest.TestCase):
|
||||||
|
@ -3,7 +3,6 @@ import os
|
|||||||
import re
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
from datetime import date
|
from datetime import date
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from huggingface_hub import paper_info
|
from huggingface_hub import paper_info
|
||||||
|
|
||||||
@ -51,7 +50,7 @@ def get_modified_cards() -> list[str]:
|
|||||||
return model_names
|
return model_names
|
||||||
|
|
||||||
|
|
||||||
def get_paper_link(model_card: Optional[str], path: Optional[str]) -> str:
|
def get_paper_link(model_card: str | None, path: str | None) -> str:
|
||||||
"""Get the first paper link from the model card content."""
|
"""Get the first paper link from the model card content."""
|
||||||
|
|
||||||
if model_card is not None and not model_card.endswith(".md"):
|
if model_card is not None and not model_card.endswith(".md"):
|
||||||
@ -91,7 +90,7 @@ def get_paper_link(model_card: Optional[str], path: Optional[str]) -> str:
|
|||||||
return paper_ids[0]
|
return paper_ids[0]
|
||||||
|
|
||||||
|
|
||||||
def get_first_commit_date(model_name: Optional[str]) -> str:
|
def get_first_commit_date(model_name: str | None) -> str:
|
||||||
"""Get the first commit date of the model's init file or model.md. This date is considered as the date the model was added to HF transformers"""
|
"""Get the first commit date of the model's init file or model.md. This date is considered as the date the model was added to HF transformers"""
|
||||||
|
|
||||||
if model_name.endswith(".md"):
|
if model_name.endswith(".md"):
|
||||||
|
@ -42,7 +42,6 @@ import os
|
|||||||
import re
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
from transformers.utils import direct_transformers_import
|
from transformers.utils import direct_transformers_import
|
||||||
|
|
||||||
@ -384,8 +383,8 @@ def split_code_into_blocks(
|
|||||||
|
|
||||||
|
|
||||||
def find_code_in_transformers(
|
def find_code_in_transformers(
|
||||||
object_name: str, base_path: Optional[str] = None, return_indices: bool = False
|
object_name: str, base_path: str | None = None, return_indices: bool = False
|
||||||
) -> Union[str, tuple[list[str], int, int]]:
|
) -> str | tuple[list[str], int, int]:
|
||||||
"""
|
"""
|
||||||
Find and return the source code of an object.
|
Find and return the source code of an object.
|
||||||
|
|
||||||
@ -485,7 +484,7 @@ def replace_code(code: str, replace_pattern: str) -> str:
|
|||||||
return code
|
return code
|
||||||
|
|
||||||
|
|
||||||
def find_code_and_splits(object_name: str, base_path: str, buffer: Optional[dict] = None):
|
def find_code_and_splits(object_name: str, base_path: str, buffer: dict | None = None):
|
||||||
"""Find the code of an object (specified by `object_name`) and split it into blocks.
|
"""Find the code of an object (specified by `object_name`) and split it into blocks.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -581,7 +580,7 @@ def stylify(code: str) -> str:
|
|||||||
return formatted_code[len("class Bla:\n") :] if has_indent else formatted_code
|
return formatted_code[len("class Bla:\n") :] if has_indent else formatted_code
|
||||||
|
|
||||||
|
|
||||||
def check_codes_match(observed_code: str, theoretical_code: str) -> Optional[int]:
|
def check_codes_match(observed_code: str, theoretical_code: str) -> int | None:
|
||||||
"""
|
"""
|
||||||
Checks if two version of a code match with the exception of the class/function name.
|
Checks if two version of a code match with the exception of the class/function name.
|
||||||
|
|
||||||
@ -633,8 +632,8 @@ def check_codes_match(observed_code: str, theoretical_code: str) -> Optional[int
|
|||||||
|
|
||||||
|
|
||||||
def is_copy_consistent(
|
def is_copy_consistent(
|
||||||
filename: str, overwrite: bool = False, buffer: Optional[dict] = None
|
filename: str, overwrite: bool = False, buffer: dict | None = None
|
||||||
) -> Optional[list[tuple[str, int]]]:
|
) -> list[tuple[str, int]] | None:
|
||||||
"""
|
"""
|
||||||
Check if the code commented as a copy in a file matches the original.
|
Check if the code commented as a copy in a file matches the original.
|
||||||
|
|
||||||
@ -826,7 +825,7 @@ def is_copy_consistent(
|
|||||||
return diffs
|
return diffs
|
||||||
|
|
||||||
|
|
||||||
def check_copies(overwrite: bool = False, file: Optional[str] = None):
|
def check_copies(overwrite: bool = False, file: str | None = None):
|
||||||
"""
|
"""
|
||||||
Check every file is copy-consistent with the original. Also check the model list in the main README and other
|
Check every file is copy-consistent with the original. Also check the model list in the main README and other
|
||||||
READMEs are consistent.
|
READMEs are consistent.
|
||||||
|
@ -43,7 +43,7 @@ import os
|
|||||||
import re
|
import re
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Optional, Union
|
from typing import Any
|
||||||
|
|
||||||
from check_repo import ignore_undocumented
|
from check_repo import ignore_undocumented
|
||||||
from git import Repo
|
from git import Repo
|
||||||
@ -525,7 +525,7 @@ def stringify_default(default: Any) -> str:
|
|||||||
return f"`{default}`"
|
return f"`{default}`"
|
||||||
|
|
||||||
|
|
||||||
def eval_math_expression(expression: str) -> Optional[Union[float, int]]:
|
def eval_math_expression(expression: str) -> float | int | None:
|
||||||
# Mainly taken from the excellent https://stackoverflow.com/a/9558001
|
# Mainly taken from the excellent https://stackoverflow.com/a/9558001
|
||||||
"""
|
"""
|
||||||
Evaluate (safely) a mathematial expression and returns its value.
|
Evaluate (safely) a mathematial expression and returns its value.
|
||||||
@ -673,7 +673,7 @@ def find_source_file(obj: Any) -> Path:
|
|||||||
return obj_file.with_suffix(".py")
|
return obj_file.with_suffix(".py")
|
||||||
|
|
||||||
|
|
||||||
def match_docstring_with_signature(obj: Any) -> Optional[tuple[str, str]]:
|
def match_docstring_with_signature(obj: Any) -> tuple[str, str] | None:
|
||||||
"""
|
"""
|
||||||
Matches the docstring of an object with its signature.
|
Matches the docstring of an object with its signature.
|
||||||
|
|
||||||
|
@ -37,7 +37,6 @@ python utils/check_dummies.py --fix_and_overwrite
|
|||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
|
|
||||||
# All paths are set with the intent you should run this script from the root of the repo with the command
|
# All paths are set with the intent you should run this script from the root of the repo with the command
|
||||||
@ -73,7 +72,7 @@ def {0}(*args, **kwargs):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def find_backend(line: str) -> Optional[str]:
|
def find_backend(line: str) -> str | None:
|
||||||
"""
|
"""
|
||||||
Find one (or multiple) backend in a code line of the init.
|
Find one (or multiple) backend in a code line of the init.
|
||||||
|
|
||||||
@ -156,7 +155,7 @@ def create_dummy_object(name: str, backend_name: str) -> str:
|
|||||||
return DUMMY_CLASS.format(name, backend_name)
|
return DUMMY_CLASS.format(name, backend_name)
|
||||||
|
|
||||||
|
|
||||||
def create_dummy_files(backend_specific_objects: Optional[dict[str, list[str]]] = None) -> dict[str, str]:
|
def create_dummy_files(backend_specific_objects: dict[str, list[str]] | None = None) -> dict[str, str]:
|
||||||
"""
|
"""
|
||||||
Create the content of the dummy files.
|
Create the content of the dummy files.
|
||||||
|
|
||||||
|
@ -39,7 +39,6 @@ import collections
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
|
|
||||||
# Path is set with the intent you should run this script from the root of the repo.
|
# Path is set with the intent you should run this script from the root of the repo.
|
||||||
@ -70,7 +69,7 @@ _re_try = re.compile(r"^\s*try:")
|
|||||||
_re_else = re.compile(r"^\s*else:")
|
_re_else = re.compile(r"^\s*else:")
|
||||||
|
|
||||||
|
|
||||||
def find_backend(line: str) -> Optional[str]:
|
def find_backend(line: str) -> str | None:
|
||||||
"""
|
"""
|
||||||
Find one (or multiple) backend in a code line of the init.
|
Find one (or multiple) backend in a code line of the init.
|
||||||
|
|
||||||
@ -89,7 +88,7 @@ def find_backend(line: str) -> Optional[str]:
|
|||||||
return "_and_".join(backends)
|
return "_and_".join(backends)
|
||||||
|
|
||||||
|
|
||||||
def parse_init(init_file) -> Optional[tuple[dict[str, list[str]], dict[str, list[str]]]]:
|
def parse_init(init_file) -> tuple[dict[str, list[str]], dict[str, list[str]]] | None:
|
||||||
"""
|
"""
|
||||||
Read an init_file and parse (per backend) the `_import_structure` objects defined and the `TYPE_CHECKING` objects
|
Read an init_file and parse (per backend) the `_import_structure` objects defined and the `TYPE_CHECKING` objects
|
||||||
defined.
|
defined.
|
||||||
|
@ -39,7 +39,7 @@ import argparse
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
# Path is defined with the intent you should run this script from the root of the repo.
|
# Path is defined with the intent you should run this script from the root of the repo.
|
||||||
@ -64,7 +64,7 @@ def get_indent(line: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def split_code_in_indented_blocks(
|
def split_code_in_indented_blocks(
|
||||||
code: str, indent_level: str = "", start_prompt: Optional[str] = None, end_prompt: Optional[str] = None
|
code: str, indent_level: str = "", start_prompt: str | None = None, end_prompt: str | None = None
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
"""
|
"""
|
||||||
Split some code into its indented blocks, starting at a given level.
|
Split some code into its indented blocks, starting at a given level.
|
||||||
@ -141,7 +141,7 @@ def ignore_underscore_and_lowercase(key: Callable[[Any], str]) -> Callable[[Any]
|
|||||||
return _inner
|
return _inner
|
||||||
|
|
||||||
|
|
||||||
def sort_objects(objects: list[Any], key: Optional[Callable[[Any], str]] = None) -> list[Any]:
|
def sort_objects(objects: list[Any], key: Callable[[Any], str] | None = None) -> list[Any]:
|
||||||
"""
|
"""
|
||||||
Sort a list of objects following the rules of isort (all uppercased first, camel-cased second and lower-cased
|
Sort a list of objects following the rules of isort (all uppercased first, camel-cased second and lower-cased
|
||||||
last).
|
last).
|
||||||
|
@ -9,7 +9,6 @@ import argparse
|
|||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from custom_init_isort import sort_imports_in_all_inits
|
from custom_init_isort import sort_imports_in_all_inits
|
||||||
@ -77,7 +76,7 @@ def insert_tip_to_model_doc(model_doc_path, tip_message):
|
|||||||
f.write("\n".join(new_model_lines))
|
f.write("\n".join(new_model_lines))
|
||||||
|
|
||||||
|
|
||||||
def get_model_doc_path(model: str) -> tuple[Optional[str], Optional[str]]:
|
def get_model_doc_path(model: str) -> tuple[str | None, str | None]:
|
||||||
# Possible variants of the model name in the model doc path
|
# Possible variants of the model name in the model doc path
|
||||||
model_names = [model, model.replace("_", "-"), model.replace("_", "")]
|
model_names = [model, model.replace("_", "-"), model.replace("_", "")]
|
||||||
|
|
||||||
|
@ -33,7 +33,6 @@ import os
|
|||||||
import subprocess
|
import subprocess
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -85,8 +84,8 @@ def handle_suite(
|
|||||||
machine_type: str,
|
machine_type: str,
|
||||||
dry_run: bool,
|
dry_run: bool,
|
||||||
tmp_cache: str = "",
|
tmp_cache: str = "",
|
||||||
resume_at: Optional[str] = None,
|
resume_at: str | None = None,
|
||||||
only_in: Optional[list[str]] = None,
|
only_in: list[str] | None = None,
|
||||||
cpu_tests: bool = False,
|
cpu_tests: bool = False,
|
||||||
process_id: int = 1,
|
process_id: int = 1,
|
||||||
total_processes: int = 1,
|
total_processes: int = 1,
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import libcst as cst
|
import libcst as cst
|
||||||
|
|
||||||
@ -14,7 +13,7 @@ EXCLUDED_EXTERNAL_FILES = {
|
|||||||
def convert_relative_import_to_absolute(
|
def convert_relative_import_to_absolute(
|
||||||
import_node: cst.ImportFrom,
|
import_node: cst.ImportFrom,
|
||||||
file_path: str,
|
file_path: str,
|
||||||
package_name: Optional[str] = "transformers",
|
package_name: str | None = "transformers",
|
||||||
) -> cst.ImportFrom:
|
) -> cst.ImportFrom:
|
||||||
"""
|
"""
|
||||||
Convert a relative libcst.ImportFrom node into an absolute one,
|
Convert a relative libcst.ImportFrom node into an absolute one,
|
||||||
@ -51,7 +50,7 @@ def convert_relative_import_to_absolute(
|
|||||||
base_parts = module_parts[:-rel_level]
|
base_parts = module_parts[:-rel_level]
|
||||||
|
|
||||||
# Flatten the module being imported (if any)
|
# Flatten the module being imported (if any)
|
||||||
def flatten_module(module: Optional[cst.BaseExpression]) -> list[str]:
|
def flatten_module(module: cst.BaseExpression | None) -> list[str]:
|
||||||
if not module:
|
if not module:
|
||||||
return []
|
return []
|
||||||
if isinstance(module, cst.Name):
|
if isinstance(module, cst.Name):
|
||||||
@ -76,7 +75,7 @@ def convert_relative_import_to_absolute(
|
|||||||
full_parts = [file_parts[pkg_index - 1]] + full_parts
|
full_parts = [file_parts[pkg_index - 1]] + full_parts
|
||||||
|
|
||||||
# Build the dotted module path
|
# Build the dotted module path
|
||||||
dotted_module: Optional[cst.BaseExpression] = None
|
dotted_module: cst.BaseExpression | None = None
|
||||||
for part in full_parts:
|
for part in full_parts:
|
||||||
name = cst.Name(part)
|
name = cst.Name(part)
|
||||||
dotted_module = name if dotted_module is None else cst.Attribute(value=dotted_module, attr=name)
|
dotted_module = name if dotted_module is None else cst.Attribute(value=dotted_module, attr=name)
|
||||||
|
@ -22,7 +22,6 @@ import subprocess
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections import Counter, defaultdict, deque
|
from collections import Counter, defaultdict, deque
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
import libcst as cst
|
import libcst as cst
|
||||||
from create_dependency_mapping import find_priority_list
|
from create_dependency_mapping import find_priority_list
|
||||||
@ -177,7 +176,7 @@ DOCSTRING_NODE = m.SimpleStatementLine(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_full_attribute_name(node: Union[cst.Attribute, cst.Name]) -> Optional[str]:
|
def get_full_attribute_name(node: cst.Attribute | cst.Name) -> str | None:
|
||||||
"""Get the full name of an Attribute or Name node (e.g. `"nn.Module"` for an Attribute representing it). If the
|
"""Get the full name of an Attribute or Name node (e.g. `"nn.Module"` for an Attribute representing it). If the
|
||||||
successive value of an Attribute are not Name nodes, return `None`."""
|
successive value of an Attribute are not Name nodes, return `None`."""
|
||||||
if m.matches(node, m.Name()):
|
if m.matches(node, m.Name()):
|
||||||
@ -378,11 +377,11 @@ class ReplaceSuperCallTransformer(cst.CSTTransformer):
|
|||||||
|
|
||||||
def find_all_dependencies(
|
def find_all_dependencies(
|
||||||
dependency_mapping: dict[str, set],
|
dependency_mapping: dict[str, set],
|
||||||
start_entity: Optional[str] = None,
|
start_entity: str | None = None,
|
||||||
initial_dependencies: Optional[set] = None,
|
initial_dependencies: set | None = None,
|
||||||
initial_checked_dependencies: Optional[set] = None,
|
initial_checked_dependencies: set | None = None,
|
||||||
return_parent: bool = False,
|
return_parent: bool = False,
|
||||||
) -> Union[list, set]:
|
) -> list | set:
|
||||||
"""Return all the dependencies of the given `start_entity` or `initial_dependencies`. This is basically some kind of
|
"""Return all the dependencies of the given `start_entity` or `initial_dependencies`. This is basically some kind of
|
||||||
BFS traversal algorithm. It can either start from `start_entity`, or `initial_dependencies`.
|
BFS traversal algorithm. It can either start from `start_entity`, or `initial_dependencies`.
|
||||||
|
|
||||||
@ -476,7 +475,7 @@ class ClassDependencyMapper(CSTVisitor):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, class_name: str, global_names: set[str], objects_imported_from_modeling: Optional[set[str]] = None
|
self, class_name: str, global_names: set[str], objects_imported_from_modeling: set[str] | None = None
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.class_name = class_name
|
self.class_name = class_name
|
||||||
@ -504,7 +503,7 @@ def dependencies_for_class_node(node: cst.ClassDef, global_names: set[str]) -> s
|
|||||||
|
|
||||||
|
|
||||||
def augmented_dependencies_for_class_node(
|
def augmented_dependencies_for_class_node(
|
||||||
node: cst.ClassDef, mapper: "ModuleMapper", objects_imported_from_modeling: Optional[set[str]] = None
|
node: cst.ClassDef, mapper: "ModuleMapper", objects_imported_from_modeling: set[str] | None = None
|
||||||
) -> set:
|
) -> set:
|
||||||
"""Create augmented dependencies for a class node based on a `mapper`.
|
"""Create augmented dependencies for a class node based on a `mapper`.
|
||||||
Augmented dependencies means immediate dependencies + recursive function and assignments dependencies.
|
Augmented dependencies means immediate dependencies + recursive function and assignments dependencies.
|
||||||
@ -1659,8 +1658,8 @@ def get_class_node_and_dependencies(
|
|||||||
|
|
||||||
def create_modules(
|
def create_modules(
|
||||||
modular_mapper: ModularFileMapper,
|
modular_mapper: ModularFileMapper,
|
||||||
file_path: Optional[str] = None,
|
file_path: str | None = None,
|
||||||
package_name: Optional[str] = "transformers",
|
package_name: str | None = "transformers",
|
||||||
) -> dict[str, cst.Module]:
|
) -> dict[str, cst.Module]:
|
||||||
"""Create all the new modules based on visiting the modular file. It replaces all classes as necessary."""
|
"""Create all the new modules based on visiting the modular file. It replaces all classes as necessary."""
|
||||||
files = defaultdict(dict)
|
files = defaultdict(dict)
|
||||||
@ -1747,7 +1746,7 @@ def run_ruff(code, check=False):
|
|||||||
return stdout.decode()
|
return stdout.decode()
|
||||||
|
|
||||||
|
|
||||||
def convert_modular_file(modular_file: str, source_library: Optional[str] = "transformers") -> dict[str, str]:
|
def convert_modular_file(modular_file: str, source_library: str | None = "transformers") -> dict[str, str]:
|
||||||
"""Convert a `modular_file` into all the different model-specific files it depicts."""
|
"""Convert a `modular_file` into all the different model-specific files it depicts."""
|
||||||
pattern = re.search(r"modular_(.*)(?=\.py$)", modular_file)
|
pattern = re.search(r"modular_(.*)(?=\.py$)", modular_file)
|
||||||
output = {}
|
output = {}
|
||||||
@ -1818,7 +1817,7 @@ def count_loc(file_path: str) -> int:
|
|||||||
return len([line for line in comment_less_code.split("\n") if line.strip()])
|
return len([line for line in comment_less_code.split("\n") if line.strip()])
|
||||||
|
|
||||||
|
|
||||||
def run_converter(modular_file: str, source_library: Optional[str] = "transformers"):
|
def run_converter(modular_file: str, source_library: str | None = "transformers"):
|
||||||
"""Convert a modular file, and save resulting files."""
|
"""Convert a modular file, and save resulting files."""
|
||||||
print(f"Converting {modular_file} to a single model single file format")
|
print(f"Converting {modular_file} to a single model single file format")
|
||||||
converted_files = convert_modular_file(modular_file, source_library=source_library)
|
converted_files = convert_modular_file(modular_file, source_library=source_library)
|
||||||
|
@ -21,7 +21,7 @@ import os
|
|||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from typing import Any, Optional, Union
|
from typing import Any
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from compare_test_runs import compare_job_sets
|
from compare_test_runs import compare_job_sets
|
||||||
@ -120,7 +120,7 @@ def handle_stacktraces(test_results):
|
|||||||
return stacktraces
|
return stacktraces
|
||||||
|
|
||||||
|
|
||||||
def dicts_to_sum(objects: Union[dict[str, dict], list[dict]]):
|
def dicts_to_sum(objects: dict[str, dict] | list[dict]):
|
||||||
if isinstance(objects, dict):
|
if isinstance(objects, dict):
|
||||||
lists = objects.values()
|
lists = objects.values()
|
||||||
else:
|
else:
|
||||||
@ -139,7 +139,7 @@ class Message:
|
|||||||
ci_title: str,
|
ci_title: str,
|
||||||
model_results: dict,
|
model_results: dict,
|
||||||
additional_results: dict,
|
additional_results: dict,
|
||||||
selected_warnings: Optional[list] = None,
|
selected_warnings: list | None = None,
|
||||||
prev_ci_artifacts=None,
|
prev_ci_artifacts=None,
|
||||||
other_ci_artifacts=None,
|
other_ci_artifacts=None,
|
||||||
):
|
):
|
||||||
@ -941,7 +941,7 @@ class Message:
|
|||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
|
||||||
|
|
||||||
def retrieve_artifact(artifact_path: str, gpu: Optional[str]):
|
def retrieve_artifact(artifact_path: str, gpu: str | None):
|
||||||
if gpu not in [None, "single", "multi"]:
|
if gpu not in [None, "single", "multi"]:
|
||||||
raise ValueError(f"Invalid GPU for artifact. Passed GPU: `{gpu}`.")
|
raise ValueError(f"Invalid GPU for artifact. Passed GPU: `{gpu}`.")
|
||||||
|
|
||||||
@ -970,7 +970,7 @@ def retrieve_available_artifacts():
|
|||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.name
|
return self.name
|
||||||
|
|
||||||
def add_path(self, path: str, gpu: Optional[str] = None):
|
def add_path(self, path: str, gpu: str | None = None):
|
||||||
self.paths.append({"name": self.name, "path": path, "gpu": gpu})
|
self.paths.append({"name": self.name, "path": path, "gpu": gpu})
|
||||||
|
|
||||||
_available_artifacts: dict[str, Artifact] = {}
|
_available_artifacts: dict[str, Artifact] = {}
|
||||||
|
@ -33,7 +33,6 @@ python utils/sort_auto_mappings.py --check_only
|
|||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
|
|
||||||
# Path are set with the intent you should run this script from the root of the repo.
|
# Path are set with the intent you should run this script from the root of the repo.
|
||||||
@ -47,7 +46,7 @@ _re_intro_mapping = re.compile(r"[A-Z_]+_MAPPING(\s+|_[A-Z_]+\s+)=\s+OrderedDict
|
|||||||
_re_identifier = re.compile(r'\s*\(\s*"(\S[^"]+)"')
|
_re_identifier = re.compile(r'\s*\(\s*"(\S[^"]+)"')
|
||||||
|
|
||||||
|
|
||||||
def sort_auto_mapping(fname: str, overwrite: bool = False) -> Optional[bool]:
|
def sort_auto_mapping(fname: str, overwrite: bool = False) -> bool | None:
|
||||||
"""
|
"""
|
||||||
Sort all auto mappings in a file.
|
Sort all auto mappings in a file.
|
||||||
|
|
||||||
|
@ -57,7 +57,6 @@ import os
|
|||||||
import re
|
import re
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
from git import Repo
|
from git import Repo
|
||||||
|
|
||||||
@ -555,7 +554,7 @@ _re_single_line_direct_imports = re.compile(r"(?:^|\n)\s*from\s+transformers(\S*
|
|||||||
_re_multi_line_direct_imports = re.compile(r"(?:^|\n)\s*from\s+transformers(\S*)\s+import\s+\(([^\)]+)\)")
|
_re_multi_line_direct_imports = re.compile(r"(?:^|\n)\s*from\s+transformers(\S*)\s+import\s+\(([^\)]+)\)")
|
||||||
|
|
||||||
|
|
||||||
def extract_imports(module_fname: str, cache: Optional[dict[str, list[str]]] = None) -> list[str]:
|
def extract_imports(module_fname: str, cache: dict[str, list[str]] | None = None) -> list[str]:
|
||||||
"""
|
"""
|
||||||
Get the imports a given module makes.
|
Get the imports a given module makes.
|
||||||
|
|
||||||
@ -637,7 +636,7 @@ def extract_imports(module_fname: str, cache: Optional[dict[str, list[str]]] = N
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def get_module_dependencies(module_fname: str, cache: Optional[dict[str, list[str]]] = None) -> list[str]:
|
def get_module_dependencies(module_fname: str, cache: dict[str, list[str]] | None = None) -> list[str]:
|
||||||
"""
|
"""
|
||||||
Refines the result of `extract_imports` to remove subfolders and get a proper list of module filenames: if a file
|
Refines the result of `extract_imports` to remove subfolders and get a proper list of module filenames: if a file
|
||||||
as an import `from utils import Foo, Bar`, with `utils` being a subfolder containing many files, this will traverse
|
as an import `from utils import Foo, Bar`, with `utils` being a subfolder containing many files, this will traverse
|
||||||
@ -734,7 +733,7 @@ def create_reverse_dependency_tree() -> list[tuple[str, str]]:
|
|||||||
return list(set(edges))
|
return list(set(edges))
|
||||||
|
|
||||||
|
|
||||||
def get_tree_starting_at(module: str, edges: list[tuple[str, str]]) -> list[Union[str, list[str]]]:
|
def get_tree_starting_at(module: str, edges: list[tuple[str, str]]) -> list[str | list[str]]:
|
||||||
"""
|
"""
|
||||||
Returns the tree starting at a given module following all edges.
|
Returns the tree starting at a given module following all edges.
|
||||||
|
|
||||||
@ -883,7 +882,7 @@ def create_reverse_dependency_map() -> dict[str, list[str]]:
|
|||||||
|
|
||||||
|
|
||||||
def create_module_to_test_map(
|
def create_module_to_test_map(
|
||||||
reverse_map: Optional[dict[str, list[str]]] = None, filter_models: bool = False
|
reverse_map: dict[str, list[str]] | None = None, filter_models: bool = False
|
||||||
) -> dict[str, list[str]]:
|
) -> dict[str, list[str]]:
|
||||||
"""
|
"""
|
||||||
Extract the tests from the reverse_dependency_map and potentially filters the model tests.
|
Extract the tests from the reverse_dependency_map and potentially filters the model tests.
|
||||||
|
Reference in New Issue
Block a user