mirror of https://github.com/huggingface/transformers.git (synced 2025-10-20 09:03:53 +08:00)
Use | for Optional and Union typing (#41646)
Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>
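The diff below mechanically rewrites type annotations from `typing.Optional[X]` and `typing.Union[X, Y]` to the PEP 604 spellings `X | None` and `X | Y`. A minimal sketch of the pattern (hypothetical `ExampleConfig`/`describe` names, not taken from the diff; assumes Python 3.10+, where `|` between types is built in):

from dataclasses import dataclass, field


@dataclass
class ExampleConfig:
    # Before: name: Optional[str] = None  (needed `from typing import Optional`)
    # After: the union is spelled inline; no `typing` import is required.
    name: str | None = None
    tags: list[str] | None = field(default=None)


def describe(config: ExampleConfig) -> str:
    # PEP 604 unions are real runtime objects, so on Python 3.10+ they also
    # work outside annotations, e.g. directly in isinstance checks.
    assert isinstance(config.name, str | None)
    return config.name or "unnamed"


print(describe(ExampleConfig()))  # -> unnamed

On interpreters older than 3.10, the `|` spelling still works inside annotations under `from __future__ import annotations`, but not in runtime positions such as the `isinstance` check above.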
@@ -16,7 +16,6 @@ import sys
 from logging import Logger
 from threading import Event, Thread
 from time import perf_counter, sleep
-from typing import Optional
 
 
 # Add the parent directory to Python path to import benchmarks_entrypoint

@@ -145,7 +144,7 @@ def run_benchmark(
     q = torch.empty_like(probs_sort).exponential_(1)
     return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
 
-def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
+def logits_to_probs(logits, temperature: float = 1.0, top_k: int | None = None):
     logits = logits / max(temperature, 1e-5)
 
     if top_k is not None:

@@ -155,7 +154,7 @@ def run_benchmark(
     probs = torch.nn.functional.softmax(logits, dim=-1)
     return probs
 
-def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
+def sample(logits, temperature: float = 1.0, top_k: int | None = None):
     probs = logits_to_probs(logits[0, -1], temperature, top_k)
     idx_next = multinomial_sample_one_no_sync(probs)
     return idx_next, probs
@@ -1,7 +1,7 @@
 import hashlib
 import json
 import logging
-from typing import Any, Optional
+from typing import Any
 
 
 KERNELIZATION_AVAILABLE = False

@@ -27,11 +27,11 @@ class BenchmarkConfig:
         sequence_length: int = 128,
         num_tokens_to_generate: int = 128,
         attn_implementation: str = "eager",
-        sdpa_backend: Optional[str] = None,
-        compile_mode: Optional[str] = None,
-        compile_options: Optional[dict[str, Any]] = None,
+        sdpa_backend: str | None = None,
+        compile_mode: str | None = None,
+        compile_options: dict[str, Any] | None = None,
         kernelize: bool = False,
-        name: Optional[str] = None,
+        name: str | None = None,
         skip_validity_check: bool = False,
     ) -> None:
         # Benchmark parameters

@@ -128,8 +128,8 @@ class BenchmarkConfig:
 
 
 def cross_generate_configs(
-    attn_impl_and_sdpa_backend: list[tuple[str, Optional[str]]],
-    compiled_mode: list[Optional[str]],
+    attn_impl_and_sdpa_backend: list[tuple[str, str | None]],
+    compiled_mode: list[str | None],
     kernelized: list[bool],
     warmup_iterations: int = 5,
     measurement_iterations: int = 20,
@@ -8,7 +8,7 @@ import time
 from contextlib import nullcontext
 from datetime import datetime
 from queue import Queue
-from typing import Any, Optional
+from typing import Any
 
 import torch
 from tqdm import trange

@@ -74,7 +74,7 @@ def get_git_revision() -> str:
         return git_hash.readline().strip()
 
 
-def get_sdpa_backend(backend_name: Optional[str]) -> Optional[torch.nn.attention.SDPBackend]:
+def get_sdpa_backend(backend_name: str | None) -> torch.nn.attention.SDPBackend | None:
     """Get the SDPA backend enum from string name."""
     if backend_name is None:
         return None

@@ -145,7 +145,7 @@ class BenchmarkRunner:
     """Main benchmark runner that coordinates benchmark execution."""
 
     def __init__(
-        self, logger: logging.Logger, output_dir: str = "benchmark_results", commit_id: Optional[str] = None
+        self, logger: logging.Logger, output_dir: str = "benchmark_results", commit_id: str | None = None
     ) -> None:
         # Those stay constant for the whole run
         self.logger = logger

@@ -156,7 +156,7 @@ class BenchmarkRunner:
         # Attributes that are reset for each model
         self._setup_for = ""
         # Attributes that are reset for each run
-        self.model: Optional[GenerationMixin] = None
+        self.model: GenerationMixin | None = None
 
     def cleanup(self) -> None:
         del self.model

@@ -251,8 +251,8 @@ class BenchmarkRunner:
     def time_generate(
         self,
         max_new_tokens: int,
-        gpu_monitor: Optional[GPUMonitor] = None,
-    ) -> tuple[float, list[float], str, Optional[GPURawMetrics]]:
+        gpu_monitor: GPUMonitor | None = None,
+    ) -> tuple[float, list[float], str, GPURawMetrics | None]:
         """Time the latency of a call to model.generate() with the given (inputs) and (max_new_tokens)."""
         # Prepare gpu monitoring if needed
         if gpu_monitor is not None:
@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 from datetime import datetime
-from typing import Any, Optional, Union
+from typing import Any
 
 import numpy as np
 

@@ -90,14 +90,14 @@ class BenchmarkResult:
         e2e_latency: float,
         token_generation_times: list[float],
         decoded_output: str,
-        gpu_metrics: Optional[GPURawMetrics],
+        gpu_metrics: GPURawMetrics | None,
     ) -> None:
         self.e2e_latency.append(e2e_latency)
         self.token_generation_times.append(token_generation_times)
         self.decoded_outputs.append(decoded_output)
         self.gpu_metrics.append(gpu_metrics)
 
-    def to_dict(self) -> dict[str, Union[None, int, float]]:
+    def to_dict(self) -> dict[str, None | int | float]:
         # Save GPU metrics as None if it contains only None values
         if all(gm is None for gm in self.gpu_metrics):
             gpu_metrics = None

@@ -111,7 +111,7 @@ class BenchmarkResult:
         }
 
     @classmethod
-    def from_dict(cls, data: dict[str, Union[None, int, float]]) -> "BenchmarkResult":
+    def from_dict(cls, data: dict[str, None | int | float]) -> "BenchmarkResult":
         # Handle GPU metrics, which is saved as None if it contains only None values
         if data["gpu_metrics"] is None:
             gpu_metrics = [None for _ in range(len(data["e2e_latency"]))]
@@ -7,7 +7,6 @@ import time
 from dataclasses import dataclass
 from enum import Enum
 from logging import Logger
-from typing import Optional, Union
 
 import gpustat
 import psutil

@@ -42,7 +41,7 @@ class HardwareInfo:
         self.cpu_count = psutil.cpu_count()
         self.memory_total_mb = int(psutil.virtual_memory().total / (1024 * 1024))
 
-    def to_dict(self) -> dict[str, Union[None, int, float, str]]:
+    def to_dict(self) -> dict[str, None | int | float | str]:
         return {
             "gpu_name": self.gpu_name,
             "gpu_memory_total_gb": self.gpu_memory_total_gb,

@@ -109,7 +108,7 @@ class GPURawMetrics:
     timestamp_0: float  # in seconds
     monitoring_status: GPUMonitoringStatus
 
-    def to_dict(self) -> dict[str, Union[None, int, float, str]]:
+    def to_dict(self) -> dict[str, None | int | float | str]:
         return {
             "utilization": self.utilization,
             "memory_used": self.memory_used,

@@ -123,7 +122,7 @@ class GPURawMetrics:
 class GPUMonitor:
     """Monitor GPU utilization during benchmark execution."""
 
-    def __init__(self, sample_interval_sec: float = 0.1, logger: Optional[Logger] = None):
+    def __init__(self, sample_interval_sec: float = 0.1, logger: Logger | None = None):
         self.sample_interval_sec = sample_interval_sec
         self.logger = logger if logger is not None else logging.getLogger(__name__)
 
@@ -19,7 +19,6 @@ import os
 import unittest
 from copy import deepcopy
 from functools import partial
-from typing import Optional
 
 import datasets
 from parameterized import parameterized

@@ -1254,8 +1253,8 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
         do_eval: bool = True,
         quality_checks: bool = True,
         fp32: bool = False,
-        extra_args_str: Optional[str] = None,
-        remove_args_str: Optional[str] = None,
+        extra_args_str: str | None = None,
+        remove_args_str: str | None = None,
     ):
         # we are doing quality testing so using a small real model
         output_dir = self.run_trainer(

@@ -1287,8 +1286,8 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
         do_eval: bool = True,
         distributed: bool = True,
         fp32: bool = False,
-        extra_args_str: Optional[str] = None,
-        remove_args_str: Optional[str] = None,
+        extra_args_str: str | None = None,
+        remove_args_str: str | None = None,
     ):
         max_len = 32
         data_dir = self.test_file_dir / "../fixtures/tests_samples/wmt_en_ro"
@@ -17,7 +17,6 @@ import os
 import re
 import sys
 from pathlib import Path
-from typing import Optional
 from unittest.mock import patch
 
 from parameterized import parameterized

@@ -251,13 +250,13 @@ class TestTrainerExt(TestCasePlus):
         learning_rate: float = 3e-3,
         optim: str = "adafactor",
         distributed: bool = False,
-        extra_args_str: Optional[str] = None,
+        extra_args_str: str | None = None,
         eval_steps: int = 0,
         predict_with_generate: bool = True,
         do_train: bool = True,
         do_eval: bool = True,
         do_predict: bool = True,
-        n_gpus_to_use: Optional[int] = None,
+        n_gpus_to_use: int | None = None,
     ):
         data_dir = self.test_file_dir / "../fixtures/tests_samples/wmt_en_ro"
         output_dir = self.get_auto_remove_tmp_dir()
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import unittest
-from typing import Optional
 
 import torch
 from parameterized import parameterized

@@ -43,8 +42,8 @@ class ContinuousBatchingTest(unittest.TestCase):
     )
     def test_group_layers(
         self,
-        layer_types_str: Optional[str],
-        sliding_window: Optional[int],
+        layer_types_str: str | None,
+        sliding_window: int | None,
         expected_groups: str,
     ) -> None:
         # Take a config and change the layer_types attribute to the mix we want
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import unittest
-from typing import Union
 
 import numpy as np
 from parameterized import parameterized

@@ -90,7 +89,7 @@ class LogitsProcessorTest(unittest.TestCase):
         self.assertFalse(torch.isinf(scores_before_min_length).any())
 
     @parameterized.expand([(0,), ([0, 18],)])
-    def test_new_min_length_dist_processor(self, eos_token_id: Union[int, list[int]]):
+    def test_new_min_length_dist_processor(self, eos_token_id: int | list[int]):
         vocab_size = 20
         batch_size = 4
 
@@ -23,7 +23,6 @@ import tempfile
 import unittest
 import warnings
 from pathlib import Path
-from typing import Optional
 
 import numpy as np
 import pytest

@@ -908,7 +907,7 @@ class GenerationTesterMixin:
 
     @pytest.mark.generate
     def test_left_padding_compatibility(
-        self, unpadded_custom_inputs: Optional[dict] = None, padded_custom_inputs: Optional[dict] = None
+        self, unpadded_custom_inputs: dict | None = None, padded_custom_inputs: dict | None = None
     ):
         """
         Tests that adding left-padding yields the same logits as the original input. Exposes arguments for custom
@@ -14,7 +14,6 @@
 
 
 import unittest
-from typing import Optional, Union
 
 from transformers.image_utils import load_image
 from transformers.testing_utils import require_torch, require_vision

@@ -36,14 +35,14 @@ class BridgeTowerImageProcessingTester:
         self,
         parent,
         do_resize: bool = True,
-        size: Optional[dict[str, int]] = None,
+        size: dict[str, int] | None = None,
         size_divisor: int = 32,
         do_rescale: bool = True,
-        rescale_factor: Union[int, float] = 1 / 255,
+        rescale_factor: int | float = 1 / 255,
         do_normalize: bool = True,
         do_center_crop: bool = True,
-        image_mean: Optional[Union[float, list[float]]] = [0.48145466, 0.4578275, 0.40821073],
-        image_std: Optional[Union[float, list[float]]] = [0.26862954, 0.26130258, 0.27577711],
+        image_mean: float | list[float] | None = [0.48145466, 0.4578275, 0.40821073],
+        image_std: float | list[float] | None = [0.26862954, 0.26130258, 0.27577711],
         do_pad: bool = True,
         batch_size=7,
         min_resolution=30,
@@ -15,7 +15,6 @@
 import shutil
 import tempfile
 import unittest
-from typing import Optional
 
 from transformers import Gemma3Processor, GemmaTokenizer
 from transformers.testing_utils import get_tests_dir, require_vision

@@ -83,7 +82,7 @@ class Gemma3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
         } # fmt: skip
 
     # Override as Gemma3 needs images to be an explicitly nested batch
-    def prepare_image_inputs(self, batch_size: Optional[int] = None):
+    def prepare_image_inputs(self, batch_size: int | None = None):
         """This function prepares a list of PIL images for testing"""
         images = super().prepare_image_inputs(batch_size)
         if isinstance(images, (list, tuple)):
@@ -19,7 +19,6 @@ import random
 import tempfile
 import unittest
 from collections.abc import Sequence
-from typing import Optional
 
 import numpy as np
 from parameterized import parameterized

@@ -78,8 +77,8 @@ class Gemma3nAudioFeatureExtractionTester:
         dither: float = 0.0,
         input_scale_factor: float = 1.0,
         mel_floor: float = 1e-5,
-        per_bin_mean: Optional[Sequence[float]] = None,
-        per_bin_stddev: Optional[Sequence[float]] = None,
+        per_bin_mean: Sequence[float] | None = None,
+        per_bin_stddev: Sequence[float] | None = None,
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -17,7 +17,6 @@ import os
 import shutil
 import tempfile
 import unittest
-from typing import Optional
 
 import pytest
 

@@ -79,7 +78,7 @@ class GroundingDinoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
         cls.embed_dim = 5
         cls.seq_length = 5
 
-    def prepare_text_inputs(self, batch_size: Optional[int] = None, **kwargs):
+    def prepare_text_inputs(self, batch_size: int | None = None, **kwargs):
         labels = ["a cat", "remote control"]
         labels_longer = ["a person", "a car", "a dog", "a cat"]
 
@@ -14,7 +14,6 @@
 
 
 import unittest
-from typing import Union
 
 import numpy as np
 

@@ -151,7 +150,7 @@ class LlavaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
 
        # taken from original implementation: https://github.com/haotian-liu/LLaVA/blob/c121f0432da27facab705978f83c4ada465e46fd/llava/mm_utils.py#L152
        def pad_to_square_original(
-            image: Image.Image, background_color: Union[int, tuple[int, int, int]] = 0
+            image: Image.Image, background_color: int | tuple[int, int, int] = 0
        ) -> Image.Image:
            width, height = image.size
            if width == height:
@@ -16,7 +16,6 @@ import json
 import shutil
 import tempfile
 import unittest
-from typing import Optional
 
 import numpy as np
 

@@ -60,7 +59,7 @@ class MllamaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
         return {"chat_template": "{% for message in messages %}{% if loop.index0 == 0 %}{{ bos_token }}{% endif %}{{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }}{% if message['content'] is string %}{{ message['content'] }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' %}{{ '<|image|>' }}{% elif content['type'] == 'text' %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{{ '<|eot_id|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"} # fmt: skip
 
     # Override as Mllama needs images to be an explicitly nested batch
-    def prepare_image_inputs(self, batch_size: Optional[int] = None):
+    def prepare_image_inputs(self, batch_size: int | None = None):
         """This function prepares a list of PIL images for testing"""
         images = super().prepare_image_inputs(batch_size)
         if isinstance(images, (list, tuple)):
@@ -18,7 +18,6 @@ import itertools
 import random
 import tempfile
 import unittest
-from typing import Optional, Union
 
 import numpy as np
 from huggingface_hub import hf_hub_download

@@ -87,17 +86,17 @@ class PatchTSMixerModelTester:
         masked_loss: bool = False,
         mask_mode: str = "mask_before_encoder",
         channel_consistent_masking: bool = True,
-        scaling: Optional[Union[str, bool]] = "std",
+        scaling: str | bool | None = "std",
         # Head related
         head_dropout: float = 0.2,
         # forecast related
         prediction_length: int = 16,
-        out_channels: Optional[int] = None,
+        out_channels: int | None = None,
         # Classification/regression related
         # num_labels: int = 3,
         num_targets: int = 3,
-        output_range: Optional[list] = None,
-        head_aggregation: Optional[str] = None,
+        output_range: list | None = None,
+        head_aggregation: str | None = None,
         # Trainer related
         batch_size=13,
         is_training=True,
@@ -15,7 +15,6 @@
 import shutil
 import tempfile
 import unittest
-from typing import Optional
 
 import numpy as np
 

@@ -94,14 +93,14 @@ class SmolVLMProcessorTest(ProcessorTesterMixin, unittest.TestCase):
         }
 
     # Override as SmolVLM needs images/video to be an explicitly nested batch
-    def prepare_image_inputs(self, batch_size: Optional[int] = None):
+    def prepare_image_inputs(self, batch_size: int | None = None):
         """This function prepares a list of PIL images for testing"""
         images = super().prepare_image_inputs(batch_size)
         if isinstance(images, (list, tuple)):
             images = [[image] for image in images]
         return images
 
-    def prepare_video_inputs(self, batch_size: Optional[int] = None):
+    def prepare_video_inputs(self, batch_size: int | None = None):
         """This function prepares a list of numpy videos."""
         video_input = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)] * 8
         if batch_size is None:
@@ -14,7 +14,6 @@
 
 
 import unittest
-from typing import Optional, Union
 
 import numpy as np
 

@@ -41,16 +40,16 @@ class TvpImageProcessingTester:
         do_resize: bool = True,
         size: dict[str, int] = {"longest_edge": 40},
         do_center_crop: bool = False,
-        crop_size: Optional[dict[str, int]] = None,
+        crop_size: dict[str, int] | None = None,
         do_rescale: bool = False,
-        rescale_factor: Union[int, float] = 1 / 255,
+        rescale_factor: int | float = 1 / 255,
         do_pad: bool = True,
         pad_size: dict[str, int] = {"height": 80, "width": 80},
-        fill: Optional[int] = None,
-        pad_mode: Optional[PaddingMode] = None,
+        fill: int | None = None,
+        pad_mode: PaddingMode | None = None,
         do_normalize: bool = True,
-        image_mean: Optional[Union[float, list[float]]] = [0.48145466, 0.4578275, 0.40821073],
-        image_std: Optional[Union[float, list[float]]] = [0.26862954, 0.26130258, 0.27577711],
+        image_mean: float | list[float] | None = [0.48145466, 0.4578275, 0.40821073],
+        image_std: float | list[float] | None = [0.26862954, 0.26130258, 0.27577711],
         batch_size=2,
         min_resolution=40,
         max_resolution=80,
@@ -17,7 +17,6 @@ import inspect
 import shutil
 import tempfile
 import unittest
-from typing import Optional
 
 import numpy as np
 import pytest

@@ -79,7 +78,7 @@ class VideoLlama3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
         shutil.rmtree(cls.tmpdirname, ignore_errors=True)
 
     @require_vision
-    def prepare_image_inputs(self, batch_size: Optional[int] = None):
+    def prepare_image_inputs(self, batch_size: int | None = None):
         """This function prepares a list of PIL images for testing"""
         if batch_size is None:
             return prepare_image_inputs()[0]
@@ -20,7 +20,6 @@ import os
 import random
 import sys
 from dataclasses import dataclass, field
-from typing import Optional
 
 import numpy as np
 from datasets import load_dataset, load_metric

@@ -70,7 +69,7 @@ class DataTrainingArguments:
     the command line.
     """
 
-    task_name: Optional[str] = field(
+    task_name: str | None = field(
         default=None,
         metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())},
     )

@@ -95,7 +94,7 @@ class DataTrainingArguments:
             )
         },
     )
-    max_train_samples: Optional[int] = field(
+    max_train_samples: int | None = field(
         default=None,
         metadata={
             "help": (

@@ -104,7 +103,7 @@ class DataTrainingArguments:
             )
         },
     )
-    max_val_samples: Optional[int] = field(
+    max_val_samples: int | None = field(
         default=None,
         metadata={
             "help": (

@@ -113,7 +112,7 @@ class DataTrainingArguments:
             )
         },
     )
-    max_test_samples: Optional[int] = field(
+    max_test_samples: int | None = field(
         default=None,
         metadata={
             "help": (

@@ -122,13 +121,13 @@ class DataTrainingArguments:
             )
         },
     )
-    train_file: Optional[str] = field(
+    train_file: str | None = field(
         default=None, metadata={"help": "A csv or a json file containing the training data."}
     )
-    validation_file: Optional[str] = field(
+    validation_file: str | None = field(
         default=None, metadata={"help": "A csv or a json file containing the validation data."}
     )
-    test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."})
+    test_file: str | None = field(default=None, metadata={"help": "A csv or a json file containing the test data."})
 
     def __post_init__(self):
         if self.task_name is not None:

@@ -155,13 +154,13 @@ class ModelArguments:
     model_name_or_path: str = field(
         metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
     )
-    config_name: Optional[str] = field(
+    config_name: str | None = field(
         default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
     )
-    tokenizer_name: Optional[str] = field(
+    tokenizer_name: str | None = field(
         default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
     )
-    cache_dir: Optional[str] = field(
+    cache_dir: str | None = field(
         default=None,
         metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
     )
@@ -19,7 +19,6 @@ import random
 import sys
 import tempfile
 from pathlib import Path
-from typing import Optional, Union
 
 import numpy as np
 from huggingface_hub import hf_hub_download

@@ -136,7 +135,7 @@ class ProcessorTesterMixin:
         processor = self.processor_class(**components, **self.prepare_processor_dict())
         return processor
 
-    def prepare_text_inputs(self, batch_size: Optional[int] = None, modalities: Optional[Union[str, list]] = None):
+    def prepare_text_inputs(self, batch_size: int | None = None, modalities: str | list | None = None):
         if isinstance(modalities, str):
             modalities = [modalities]
 

@@ -158,7 +157,7 @@ class ProcessorTesterMixin:
         ] * (batch_size - 2)
 
     @require_vision
-    def prepare_image_inputs(self, batch_size: Optional[int] = None):
+    def prepare_image_inputs(self, batch_size: int | None = None):
         """This function prepares a list of PIL images for testing"""
         if batch_size is None:
             return prepare_image_inputs()[0]

@@ -167,7 +166,7 @@ class ProcessorTesterMixin:
         return prepare_image_inputs() * batch_size
 
     @require_vision
-    def prepare_video_inputs(self, batch_size: Optional[int] = None):
+    def prepare_video_inputs(self, batch_size: int | None = None):
         """This function prepares a list of numpy videos."""
         video_input = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)] * 8
         video_input = np.array(video_input)

@@ -175,7 +174,7 @@ class ProcessorTesterMixin:
             return video_input
         return [video_input] * batch_size
 
-    def prepare_audio_inputs(self, batch_size: Optional[int] = None):
+    def prepare_audio_inputs(self, batch_size: int | None = None):
         """This function prepares a list of numpy audio."""
         raw_speech = floats_list((1, 1000))
         raw_speech = [np.asarray(audio) for audio in raw_speech]
@@ -26,7 +26,7 @@ import unittest
 from collections import OrderedDict
 from itertools import takewhile
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Union
 
 from parameterized import parameterized
 

@@ -169,7 +169,7 @@ def _test_subword_regularization_tokenizer(in_queue, out_queue, timeout):
 
 def check_subword_sampling(
     tokenizer: PreTrainedTokenizer,
-    text: Optional[str] = None,
+    text: str | None = None,
     test_sentencepiece_ignore_case: bool = True,
 ) -> None:
     """

@@ -313,9 +313,9 @@ class TokenizerTesterMixin:
         self,
         expected_encoding: dict,
         model_name: str,
-        revision: Optional[str] = None,
-        sequences: Optional[list[str]] = None,
-        decode_kwargs: Optional[dict[str, Any]] = None,
+        revision: str | None = None,
+        sequences: list[str] | None = None,
+        decode_kwargs: dict[str, Any] | None = None,
         padding: bool = True,
     ):
         """
@@ -16,7 +16,6 @@ import logging
 import os
 import unittest
 from pathlib import Path
-from typing import Union
 
 import transformers
 from transformers.testing_utils import require_torch, slow

@@ -32,9 +31,9 @@ class TestCodeExamples(unittest.TestCase):
     def analyze_directory(
         self,
         directory: Path,
-        identifier: Union[str, None] = None,
-        ignore_files: Union[list[str], None] = None,
-        n_identifier: Union[str, list[str], None] = None,
+        identifier: str | None = None,
+        ignore_files: list[str] | None = None,
+        n_identifier: str | list[str] | None = None,
         only_modules: bool = True,
     ):
         """
@@ -22,7 +22,7 @@ from argparse import Namespace
 from dataclasses import dataclass, field
 from enum import Enum
 from pathlib import Path
-from typing import Literal, Optional, Union, get_args, get_origin
+from typing import Literal, Union, get_args, get_origin
 from unittest.mock import patch
 
 import yaml

@@ -59,7 +59,7 @@ class WithDefaultExample:
 class WithDefaultBoolExample:
     foo: bool = False
     baz: bool = True
-    opt: Optional[bool] = None
+    opt: bool | None = None
 
 
 class BasicEnum(Enum):

@@ -91,11 +91,11 @@ class MixedTypeEnumExample:
 
 @dataclass
 class OptionalExample:
-    foo: Optional[int] = None
-    bar: Optional[float] = field(default=None, metadata={"help": "help message"})
-    baz: Optional[str] = None
-    ces: Optional[list[str]] = list_field(default=[])
-    des: Optional[list[int]] = list_field(default=[])
+    foo: int | None = None
+    bar: float | None = field(default=None, metadata={"help": "help message"})
+    baz: str | None = None
+    ces: list[str] | None = list_field(default=[])
+    des: list[int] | None = list_field(default=[])
 
 
 @dataclass

@@ -120,7 +120,7 @@ class RequiredExample:
 class StringLiteralAnnotationExample:
     foo: int
     required_enum: "BasicEnum" = field()
-    opt: "Optional[bool]" = None
+    opt: "bool | None" = None
     baz: "str" = field(default="toto", metadata={"help": "help message"})
     foo_str: "list[str]" = list_field(default=["Hallo", "Bonjour", "Hello"])
 
@@ -17,7 +17,6 @@ import os
 import tempfile
 import unittest
 from io import BytesIO
-from typing import Optional
 
 import httpx
 import numpy as np

@@ -46,7 +45,7 @@ if is_vision_available():
     from transformers.image_utils import get_image_size, infer_channel_dimension_format, load_image
 
 
-def get_image_from_hub_dataset(dataset_id: str, filename: str, revision: Optional[str] = None) -> "PIL.Image.Image":
+def get_image_from_hub_dataset(dataset_id: str, filename: str, revision: str | None = None) -> "PIL.Image.Image":
     url = hf_hub_url(dataset_id, filename, repo_type="dataset", revision=revision)
     return PIL.Image.open(BytesIO(httpx.get(url, follow_redirects=True).content))
 
@@ -15,7 +15,6 @@
 import io
 import unittest
 from dataclasses import dataclass
-from typing import Optional
 
 import pytest
 

@@ -31,8 +30,8 @@ if is_torch_available():
     @dataclass
     class ModelOutputTest(ModelOutput):
         a: float
-        b: Optional[float] = None
-        c: Optional[float] = None
+        b: float | None = None
+        c: float | None = None
 
 
 class ModelOutputTester(unittest.TestCase):

@@ -182,8 +181,8 @@ class ModelOutputTestNoDataclass(ModelOutput):
     """Invalid test subclass of ModelOutput where @dataclass decorator is not used"""
 
     a: float
-    b: Optional[float] = None
-    c: Optional[float] = None
+    b: float | None = None
+    c: float | None = None
 
 
 class ModelOutputSubclassTester(unittest.TestCase):
@@ -3,7 +3,6 @@ import os
 import re
 import subprocess
 from datetime import date
-from typing import Optional
 
 from huggingface_hub import paper_info
 

@@ -51,7 +50,7 @@ def get_modified_cards() -> list[str]:
     return model_names
 
 
-def get_paper_link(model_card: Optional[str], path: Optional[str]) -> str:
+def get_paper_link(model_card: str | None, path: str | None) -> str:
     """Get the first paper link from the model card content."""
 
     if model_card is not None and not model_card.endswith(".md"):

@@ -91,7 +90,7 @@ def get_paper_link(model_card: Optional[str], path: Optional[str]) -> str:
     return paper_ids[0]
 
 
-def get_first_commit_date(model_name: Optional[str]) -> str:
+def get_first_commit_date(model_name: str | None) -> str:
     """Get the first commit date of the model's init file or model.md. This date is considered as the date the model was added to HF transformers"""
 
     if model_name.endswith(".md"):
@@ -42,7 +42,6 @@ import os
 import re
 import subprocess
 from collections import OrderedDict
-from typing import Optional, Union
 
 from transformers.utils import direct_transformers_import
 

@@ -384,8 +383,8 @@ def split_code_into_blocks(
 
 
 def find_code_in_transformers(
-    object_name: str, base_path: Optional[str] = None, return_indices: bool = False
-) -> Union[str, tuple[list[str], int, int]]:
+    object_name: str, base_path: str | None = None, return_indices: bool = False
+) -> str | tuple[list[str], int, int]:
     """
     Find and return the source code of an object.
 

@@ -485,7 +484,7 @@ def replace_code(code: str, replace_pattern: str) -> str:
     return code
 
 
-def find_code_and_splits(object_name: str, base_path: str, buffer: Optional[dict] = None):
+def find_code_and_splits(object_name: str, base_path: str, buffer: dict | None = None):
     """Find the code of an object (specified by `object_name`) and split it into blocks.
 
     Args:

@@ -581,7 +580,7 @@ def stylify(code: str) -> str:
     return formatted_code[len("class Bla:\n") :] if has_indent else formatted_code
 
 
-def check_codes_match(observed_code: str, theoretical_code: str) -> Optional[int]:
+def check_codes_match(observed_code: str, theoretical_code: str) -> int | None:
     """
     Checks if two version of a code match with the exception of the class/function name.
 

@@ -633,8 +632,8 @@ def check_codes_match(observed_code: str, theoretical_code: str) -> Optional[int]:
 
 
 def is_copy_consistent(
-    filename: str, overwrite: bool = False, buffer: Optional[dict] = None
-) -> Optional[list[tuple[str, int]]]:
+    filename: str, overwrite: bool = False, buffer: dict | None = None
+) -> list[tuple[str, int]] | None:
     """
     Check if the code commented as a copy in a file matches the original.
 

@@ -826,7 +825,7 @@ def is_copy_consistent(
     return diffs
 
 
-def check_copies(overwrite: bool = False, file: Optional[str] = None):
+def check_copies(overwrite: bool = False, file: str | None = None):
     """
     Check every file is copy-consistent with the original. Also check the model list in the main README and other
     READMEs are consistent.
@@ -43,7 +43,7 @@ import os
 import re
 from collections import OrderedDict
 from pathlib import Path
-from typing import Any, Optional, Union
+from typing import Any
 
 from check_repo import ignore_undocumented
 from git import Repo

@@ -525,7 +525,7 @@ def stringify_default(default: Any) -> str:
     return f"`{default}`"
 
 
-def eval_math_expression(expression: str) -> Optional[Union[float, int]]:
+def eval_math_expression(expression: str) -> float | int | None:
     # Mainly taken from the excellent https://stackoverflow.com/a/9558001
     """
     Evaluate (safely) a mathematial expression and returns its value.

@@ -673,7 +673,7 @@ def find_source_file(obj: Any) -> Path:
     return obj_file.with_suffix(".py")
 
 
-def match_docstring_with_signature(obj: Any) -> Optional[tuple[str, str]]:
+def match_docstring_with_signature(obj: Any) -> tuple[str, str] | None:
     """
     Matches the docstring of an object with its signature.
 
@@ -37,7 +37,6 @@ python utils/check_dummies.py --fix_and_overwrite
 import argparse
 import os
 import re
-from typing import Optional
 
 
 # All paths are set with the intent you should run this script from the root of the repo with the command

@@ -73,7 +72,7 @@ def {0}(*args, **kwargs):
 """
 
 
-def find_backend(line: str) -> Optional[str]:
+def find_backend(line: str) -> str | None:
     """
     Find one (or multiple) backend in a code line of the init.
 

@@ -156,7 +155,7 @@ def create_dummy_object(name: str, backend_name: str) -> str:
     return DUMMY_CLASS.format(name, backend_name)
 
 
-def create_dummy_files(backend_specific_objects: Optional[dict[str, list[str]]] = None) -> dict[str, str]:
+def create_dummy_files(backend_specific_objects: dict[str, list[str]] | None = None) -> dict[str, str]:
     """
     Create the content of the dummy files.
 
@@ -39,7 +39,6 @@ import collections
 import os
 import re
 from pathlib import Path
-from typing import Optional
 
 
 # Path is set with the intent you should run this script from the root of the repo.

@@ -70,7 +69,7 @@ _re_try = re.compile(r"^\s*try:")
 _re_else = re.compile(r"^\s*else:")
 
 
-def find_backend(line: str) -> Optional[str]:
+def find_backend(line: str) -> str | None:
     """
     Find one (or multiple) backend in a code line of the init.
 

@@ -89,7 +88,7 @@ def find_backend(line: str) -> Optional[str]:
     return "_and_".join(backends)
 
 
-def parse_init(init_file) -> Optional[tuple[dict[str, list[str]], dict[str, list[str]]]]:
+def parse_init(init_file) -> tuple[dict[str, list[str]], dict[str, list[str]]] | None:
     """
     Read an init_file and parse (per backend) the `_import_structure` objects defined and the `TYPE_CHECKING` objects
     defined.
@@ -39,7 +39,7 @@ import argparse
 import os
 import re
 from collections.abc import Callable
-from typing import Any, Optional
+from typing import Any
 
 
 # Path is defined with the intent you should run this script from the root of the repo.

@@ -64,7 +64,7 @@ def get_indent(line: str) -> str:
 
 
 def split_code_in_indented_blocks(
-    code: str, indent_level: str = "", start_prompt: Optional[str] = None, end_prompt: Optional[str] = None
+    code: str, indent_level: str = "", start_prompt: str | None = None, end_prompt: str | None = None
 ) -> list[str]:
     """
     Split some code into its indented blocks, starting at a given level.

@@ -141,7 +141,7 @@ def ignore_underscore_and_lowercase(key: Callable[[Any], str]) -> Callable[[Any], str]:
     return _inner
 
 
-def sort_objects(objects: list[Any], key: Optional[Callable[[Any], str]] = None) -> list[Any]:
+def sort_objects(objects: list[Any], key: Callable[[Any], str] | None = None) -> list[Any]:
     """
     Sort a list of objects following the rules of isort (all uppercased first, camel-cased second and lower-cased
     last).
@@ -9,7 +9,6 @@ import argparse
 import os
 from collections import defaultdict
 from pathlib import Path
-from typing import Optional
 
 import requests
 from custom_init_isort import sort_imports_in_all_inits

@@ -77,7 +76,7 @@ def insert_tip_to_model_doc(model_doc_path, tip_message):
         f.write("\n".join(new_model_lines))
 
 
-def get_model_doc_path(model: str) -> tuple[Optional[str], Optional[str]]:
+def get_model_doc_path(model: str) -> tuple[str | None, str | None]:
     # Possible variants of the model name in the model doc path
     model_names = [model, model.replace("_", "-"), model.replace("_", "")]
 
@@ -33,7 +33,6 @@ import os
 import subprocess
 import tempfile
 from pathlib import Path
-from typing import Optional
 
 import torch
 

@@ -85,8 +84,8 @@ def handle_suite(
     machine_type: str,
     dry_run: bool,
     tmp_cache: str = "",
-    resume_at: Optional[str] = None,
-    only_in: Optional[list[str]] = None,
+    resume_at: str | None = None,
+    only_in: list[str] | None = None,
     cpu_tests: bool = False,
     process_id: int = 1,
     total_processes: int = 1,
@@ -1,5 +1,4 @@
 import os
-from typing import Optional
 
 import libcst as cst
 

@@ -14,7 +13,7 @@ EXCLUDED_EXTERNAL_FILES = {
 def convert_relative_import_to_absolute(
     import_node: cst.ImportFrom,
     file_path: str,
-    package_name: Optional[str] = "transformers",
+    package_name: str | None = "transformers",
 ) -> cst.ImportFrom:
     """
     Convert a relative libcst.ImportFrom node into an absolute one,

@@ -51,7 +50,7 @@ def convert_relative_import_to_absolute(
         base_parts = module_parts[:-rel_level]
 
     # Flatten the module being imported (if any)
-    def flatten_module(module: Optional[cst.BaseExpression]) -> list[str]:
+    def flatten_module(module: cst.BaseExpression | None) -> list[str]:
         if not module:
             return []
         if isinstance(module, cst.Name):

@@ -76,7 +75,7 @@ def convert_relative_import_to_absolute(
         full_parts = [file_parts[pkg_index - 1]] + full_parts
 
     # Build the dotted module path
-    dotted_module: Optional[cst.BaseExpression] = None
+    dotted_module: cst.BaseExpression | None = None
     for part in full_parts:
         name = cst.Name(part)
         dotted_module = name if dotted_module is None else cst.Attribute(value=dotted_module, attr=name)
@@ -22,7 +22,6 @@ import subprocess
 from abc import ABC, abstractmethod
 from collections import Counter, defaultdict, deque
 from functools import partial
-from typing import Optional, Union
 
 import libcst as cst
 from create_dependency_mapping import find_priority_list

@@ -177,7 +176,7 @@ DOCSTRING_NODE = m.SimpleStatementLine(
 )
 
 
-def get_full_attribute_name(node: Union[cst.Attribute, cst.Name]) -> Optional[str]:
+def get_full_attribute_name(node: cst.Attribute | cst.Name) -> str | None:
     """Get the full name of an Attribute or Name node (e.g. `"nn.Module"` for an Attribute representing it). If the
     successive value of an Attribute are not Name nodes, return `None`."""
     if m.matches(node, m.Name()):

@@ -378,11 +377,11 @@ class ReplaceSuperCallTransformer(cst.CSTTransformer):
 
 def find_all_dependencies(
     dependency_mapping: dict[str, set],
-    start_entity: Optional[str] = None,
-    initial_dependencies: Optional[set] = None,
-    initial_checked_dependencies: Optional[set] = None,
+    start_entity: str | None = None,
+    initial_dependencies: set | None = None,
+    initial_checked_dependencies: set | None = None,
     return_parent: bool = False,
-) -> Union[list, set]:
+) -> list | set:
     """Return all the dependencies of the given `start_entity` or `initial_dependencies`. This is basically some kind of
     BFS traversal algorithm. It can either start from `start_entity`, or `initial_dependencies`.
 

@@ -476,7 +475,7 @@ class ClassDependencyMapper(CSTVisitor):
     """
 
     def __init__(
-        self, class_name: str, global_names: set[str], objects_imported_from_modeling: Optional[set[str]] = None
+        self, class_name: str, global_names: set[str], objects_imported_from_modeling: set[str] | None = None
     ):
         super().__init__()
         self.class_name = class_name

@@ -504,7 +503,7 @@ def dependencies_for_class_node(node: cst.ClassDef, global_names: set[str]) -> set:
 
 
 def augmented_dependencies_for_class_node(
-    node: cst.ClassDef, mapper: "ModuleMapper", objects_imported_from_modeling: Optional[set[str]] = None
+    node: cst.ClassDef, mapper: "ModuleMapper", objects_imported_from_modeling: set[str] | None = None
 ) -> set:
     """Create augmented dependencies for a class node based on a `mapper`.
     Augmented dependencies means immediate dependencies + recursive function and assignments dependencies.

@@ -1659,8 +1658,8 @@ def get_class_node_and_dependencies(
 
 def create_modules(
     modular_mapper: ModularFileMapper,
-    file_path: Optional[str] = None,
-    package_name: Optional[str] = "transformers",
+    file_path: str | None = None,
+    package_name: str | None = "transformers",
 ) -> dict[str, cst.Module]:
     """Create all the new modules based on visiting the modular file. It replaces all classes as necessary."""
     files = defaultdict(dict)

@@ -1747,7 +1746,7 @@ def run_ruff(code, check=False):
     return stdout.decode()
 
 
-def convert_modular_file(modular_file: str, source_library: Optional[str] = "transformers") -> dict[str, str]:
+def convert_modular_file(modular_file: str, source_library: str | None = "transformers") -> dict[str, str]:
     """Convert a `modular_file` into all the different model-specific files it depicts."""
     pattern = re.search(r"modular_(.*)(?=\.py$)", modular_file)
     output = {}

@@ -1818,7 +1817,7 @@ def count_loc(file_path: str) -> int:
     return len([line for line in comment_less_code.split("\n") if line.strip()])
 
 
-def run_converter(modular_file: str, source_library: Optional[str] = "transformers"):
+def run_converter(modular_file: str, source_library: str | None = "transformers"):
     """Convert a modular file, and save resulting files."""
     print(f"Converting {modular_file} to a single model single file format")
     converted_files = convert_modular_file(modular_file, source_library=source_library)
@@ -21,7 +21,7 @@ import os
 import re
 import sys
 import time
-from typing import Any, Optional, Union
+from typing import Any
 
 import requests
 from compare_test_runs import compare_job_sets

@@ -120,7 +120,7 @@ def handle_stacktraces(test_results):
     return stacktraces
 
 
-def dicts_to_sum(objects: Union[dict[str, dict], list[dict]]):
+def dicts_to_sum(objects: dict[str, dict] | list[dict]):
     if isinstance(objects, dict):
         lists = objects.values()
     else:

@@ -139,7 +139,7 @@ class Message:
         ci_title: str,
         model_results: dict,
         additional_results: dict,
-        selected_warnings: Optional[list] = None,
+        selected_warnings: list | None = None,
         prev_ci_artifacts=None,
         other_ci_artifacts=None,
     ):

@@ -941,7 +941,7 @@ class Message:
             time.sleep(1)
 
 
-def retrieve_artifact(artifact_path: str, gpu: Optional[str]):
+def retrieve_artifact(artifact_path: str, gpu: str | None):
     if gpu not in [None, "single", "multi"]:
         raise ValueError(f"Invalid GPU for artifact. Passed GPU: `{gpu}`.")
 

@@ -970,7 +970,7 @@ def retrieve_available_artifacts():
         def __str__(self):
             return self.name
 
-        def add_path(self, path: str, gpu: Optional[str] = None):
+        def add_path(self, path: str, gpu: str | None = None):
             self.paths.append({"name": self.name, "path": path, "gpu": gpu})
 
     _available_artifacts: dict[str, Artifact] = {}
@@ -33,7 +33,6 @@ python utils/sort_auto_mappings.py --check_only
 import argparse
 import os
 import re
-from typing import Optional
 
 
 # Path are set with the intent you should run this script from the root of the repo.

@@ -47,7 +46,7 @@ _re_intro_mapping = re.compile(r"[A-Z_]+_MAPPING(\s+|_[A-Z_]+\s+)=\s+OrderedDict")
 _re_identifier = re.compile(r'\s*\(\s*"(\S[^"]+)"')
 
 
-def sort_auto_mapping(fname: str, overwrite: bool = False) -> Optional[bool]:
+def sort_auto_mapping(fname: str, overwrite: bool = False) -> bool | None:
     """
     Sort all auto mappings in a file.
 
@@ -57,7 +57,6 @@ import os
 import re
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Optional, Union
 
 from git import Repo
 

@@ -555,7 +554,7 @@ _re_single_line_direct_imports = re.compile(r"(?:^|\n)\s*from\s+transformers(\S*
 _re_multi_line_direct_imports = re.compile(r"(?:^|\n)\s*from\s+transformers(\S*)\s+import\s+\(([^\)]+)\)")
 
 
-def extract_imports(module_fname: str, cache: Optional[dict[str, list[str]]] = None) -> list[str]:
+def extract_imports(module_fname: str, cache: dict[str, list[str]] | None = None) -> list[str]:
     """
     Get the imports a given module makes.
 

@@ -637,7 +636,7 @@ def extract_imports(module_fname: str, cache: Optional[dict[str, list[str]]] = None) -> list[str]:
     return result
 
 
-def get_module_dependencies(module_fname: str, cache: Optional[dict[str, list[str]]] = None) -> list[str]:
+def get_module_dependencies(module_fname: str, cache: dict[str, list[str]] | None = None) -> list[str]:
    """
     Refines the result of `extract_imports` to remove subfolders and get a proper list of module filenames: if a file
     as an import `from utils import Foo, Bar`, with `utils` being a subfolder containing many files, this will traverse

@@ -734,7 +733,7 @@ def create_reverse_dependency_tree() -> list[tuple[str, str]]:
     return list(set(edges))
 
 
-def get_tree_starting_at(module: str, edges: list[tuple[str, str]]) -> list[Union[str, list[str]]]:
+def get_tree_starting_at(module: str, edges: list[tuple[str, str]]) -> list[str | list[str]]:
     """
     Returns the tree starting at a given module following all edges.
 

@@ -883,7 +882,7 @@ def create_reverse_dependency_map() -> dict[str, list[str]]:
 
 
 def create_module_to_test_map(
-    reverse_map: Optional[dict[str, list[str]]] = None, filter_models: bool = False
+    reverse_map: dict[str, list[str]] | None = None, filter_models: bool = False
 ) -> dict[str, list[str]]:
     """
     Extract the tests from the reverse_dependency_map and potentially filters the model tests.
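Every hunk above only changes the spelling of annotations, so a quick way to convince yourself the sweep is behavior-preserving is to check that both spellings resolve to the same runtime objects. A minimal sketch (hypothetical `old`/`new` functions, not from this commit; assumes Python 3.10+):

from typing import Optional, Union, get_type_hints


def old(x: Optional[int] = None, y: Union[int, str] = 0) -> None: ...


def new(x: int | None = None, y: int | str = 0) -> None: ...


# Per PEP 604, Optional[int], Union[int, None] and int | None compare equal,
# so tooling that inspects resolved annotations sees no difference.
assert Optional[int] == (int | None)
assert get_type_hints(old) == get_type_hints(new)
print("both spellings resolve to the same types")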