Use | for Optional and Union typing (#41646)

Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>
Author: Yuanyuan Chen
Date: 2025-10-16 22:29:54 +08:00
Committed by: GitHub
Parent: bf815e9b5e
Commit: 9e99198e5e
40 changed files with 136 additions and 168 deletions
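The mechanical pattern across all 40 files is the PEP 604 rewrite: Optional[X] becomes X | None, Union[A, B] becomes A | B, and the now-unused typing imports are dropped. A minimal before/after sketch (illustrative signature, not taken from the diff):

    from typing import Optional, Union

    # Pre-PEP 604 spelling
    def resize(size: Optional[dict[str, int]] = None, factor: Union[int, float] = 1.0) -> None: ...

    # PEP 604 spelling, as applied throughout this commit
    def resize(size: dict[str, int] | None = None, factor: int | float = 1.0) -> None: ...

Unless a module uses from __future__ import annotations, the X | Y expression in a signature is evaluated when the def statement runs, so this spelling presumes a Python 3.10+ runtime.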

View File

@@ -16,7 +16,6 @@ import sys
 from logging import Logger
 from threading import Event, Thread
 from time import perf_counter, sleep
-from typing import Optional
 # Add the parent directory to Python path to import benchmarks_entrypoint
@@ -145,7 +144,7 @@ def run_benchmark(
 q = torch.empty_like(probs_sort).exponential_(1)
 return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
-def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
+def logits_to_probs(logits, temperature: float = 1.0, top_k: int | None = None):
 logits = logits / max(temperature, 1e-5)
 if top_k is not None:
@@ -155,7 +154,7 @@ def run_benchmark(
 probs = torch.nn.functional.softmax(logits, dim=-1)
 return probs
-def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
+def sample(logits, temperature: float = 1.0, top_k: int | None = None):
 probs = logits_to_probs(logits[0, -1], temperature, top_k)
 idx_next = multinomial_sample_one_no_sync(probs)
 return idx_next, probs

View File

@@ -1,7 +1,7 @@
 import hashlib
 import json
 import logging
-from typing import Any, Optional
+from typing import Any
 KERNELIZATION_AVAILABLE = False
@@ -27,11 +27,11 @@ class BenchmarkConfig:
 sequence_length: int = 128,
 num_tokens_to_generate: int = 128,
 attn_implementation: str = "eager",
-sdpa_backend: Optional[str] = None,
-compile_mode: Optional[str] = None,
-compile_options: Optional[dict[str, Any]] = None,
+sdpa_backend: str | None = None,
+compile_mode: str | None = None,
+compile_options: dict[str, Any] | None = None,
 kernelize: bool = False,
-name: Optional[str] = None,
+name: str | None = None,
 skip_validity_check: bool = False,
 ) -> None:
 # Benchmark parameters
@@ -128,8 +128,8 @@ class BenchmarkConfig:
 def cross_generate_configs(
-attn_impl_and_sdpa_backend: list[tuple[str, Optional[str]]],
-compiled_mode: list[Optional[str]],
+attn_impl_and_sdpa_backend: list[tuple[str, str | None]],
+compiled_mode: list[str | None],
 kernelized: list[bool],
 warmup_iterations: int = 5,
 measurement_iterations: int = 20,

View File

@@ -8,7 +8,7 @@ import time
 from contextlib import nullcontext
 from datetime import datetime
 from queue import Queue
-from typing import Any, Optional
+from typing import Any
 import torch
 from tqdm import trange
@@ -74,7 +74,7 @@ def get_git_revision() -> str:
 return git_hash.readline().strip()
-def get_sdpa_backend(backend_name: Optional[str]) -> Optional[torch.nn.attention.SDPBackend]:
+def get_sdpa_backend(backend_name: str | None) -> torch.nn.attention.SDPBackend | None:
 """Get the SDPA backend enum from string name."""
 if backend_name is None:
 return None
@@ -145,7 +145,7 @@ class BenchmarkRunner:
 """Main benchmark runner that coordinates benchmark execution."""
 def __init__(
-self, logger: logging.Logger, output_dir: str = "benchmark_results", commit_id: Optional[str] = None
+self, logger: logging.Logger, output_dir: str = "benchmark_results", commit_id: str | None = None
 ) -> None:
 # Those stay constant for the whole run
 self.logger = logger
@@ -156,7 +156,7 @@ class BenchmarkRunner:
 # Attributes that are reset for each model
 self._setup_for = ""
 # Attributes that are reset for each run
-self.model: Optional[GenerationMixin] = None
+self.model: GenerationMixin | None = None
 def cleanup(self) -> None:
 del self.model
@@ -251,8 +251,8 @@ class BenchmarkRunner:
 def time_generate(
 self,
 max_new_tokens: int,
-gpu_monitor: Optional[GPUMonitor] = None,
-) -> tuple[float, list[float], str, Optional[GPURawMetrics]]:
+gpu_monitor: GPUMonitor | None = None,
+) -> tuple[float, list[float], str, GPURawMetrics | None]:
 """Time the latency of a call to model.generate() with the given (inputs) and (max_new_tokens)."""
 # Prepare gpu monitoring if needed
 if gpu_monitor is not None:

View File

@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 from datetime import datetime
-from typing import Any, Optional, Union
+from typing import Any
 import numpy as np
@@ -90,14 +90,14 @@ class BenchmarkResult:
 e2e_latency: float,
 token_generation_times: list[float],
 decoded_output: str,
-gpu_metrics: Optional[GPURawMetrics],
+gpu_metrics: GPURawMetrics | None,
 ) -> None:
 self.e2e_latency.append(e2e_latency)
 self.token_generation_times.append(token_generation_times)
 self.decoded_outputs.append(decoded_output)
 self.gpu_metrics.append(gpu_metrics)
-def to_dict(self) -> dict[str, Union[None, int, float]]:
+def to_dict(self) -> dict[str, None | int | float]:
 # Save GPU metrics as None if it contains only None values
 if all(gm is None for gm in self.gpu_metrics):
 gpu_metrics = None
@@ -111,7 +111,7 @@ class BenchmarkResult:
 }
 @classmethod
-def from_dict(cls, data: dict[str, Union[None, int, float]]) -> "BenchmarkResult":
+def from_dict(cls, data: dict[str, None | int | float]) -> "BenchmarkResult":
 # Handle GPU metrics, which is saved as None if it contains only None values
 if data["gpu_metrics"] is None:
 gpu_metrics = [None for _ in range(len(data["e2e_latency"]))]

View File

@@ -7,7 +7,6 @@ import time
 from dataclasses import dataclass
 from enum import Enum
 from logging import Logger
-from typing import Optional, Union
 import gpustat
 import psutil
@@ -42,7 +41,7 @@ class HardwareInfo:
 self.cpu_count = psutil.cpu_count()
 self.memory_total_mb = int(psutil.virtual_memory().total / (1024 * 1024))
-def to_dict(self) -> dict[str, Union[None, int, float, str]]:
+def to_dict(self) -> dict[str, None | int | float | str]:
 return {
 "gpu_name": self.gpu_name,
 "gpu_memory_total_gb": self.gpu_memory_total_gb,
@@ -109,7 +108,7 @@ class GPURawMetrics:
 timestamp_0: float # in seconds
 monitoring_status: GPUMonitoringStatus
-def to_dict(self) -> dict[str, Union[None, int, float, str]]:
+def to_dict(self) -> dict[str, None | int | float | str]:
 return {
 "utilization": self.utilization,
 "memory_used": self.memory_used,
@@ -123,7 +122,7 @@ class GPURawMetrics:
 class GPUMonitor:
 """Monitor GPU utilization during benchmark execution."""
-def __init__(self, sample_interval_sec: float = 0.1, logger: Optional[Logger] = None):
+def __init__(self, sample_interval_sec: float = 0.1, logger: Logger | None = None):
 self.sample_interval_sec = sample_interval_sec
 self.logger = logger if logger is not None else logging.getLogger(__name__)

View File

@@ -19,7 +19,6 @@ import os
 import unittest
 from copy import deepcopy
 from functools import partial
-from typing import Optional
 import datasets
 from parameterized import parameterized
@@ -1254,8 +1253,8 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
 do_eval: bool = True,
 quality_checks: bool = True,
 fp32: bool = False,
-extra_args_str: Optional[str] = None,
-remove_args_str: Optional[str] = None,
+extra_args_str: str | None = None,
+remove_args_str: str | None = None,
 ):
 # we are doing quality testing so using a small real model
 output_dir = self.run_trainer(
@@ -1287,8 +1286,8 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
 do_eval: bool = True,
 distributed: bool = True,
 fp32: bool = False,
-extra_args_str: Optional[str] = None,
-remove_args_str: Optional[str] = None,
+extra_args_str: str | None = None,
+remove_args_str: str | None = None,
 ):
 max_len = 32
 data_dir = self.test_file_dir / "../fixtures/tests_samples/wmt_en_ro"

View File

@@ -17,7 +17,6 @@ import os
 import re
 import sys
 from pathlib import Path
-from typing import Optional
 from unittest.mock import patch
 from parameterized import parameterized
@@ -251,13 +250,13 @@ class TestTrainerExt(TestCasePlus):
 learning_rate: float = 3e-3,
 optim: str = "adafactor",
 distributed: bool = False,
-extra_args_str: Optional[str] = None,
+extra_args_str: str | None = None,
 eval_steps: int = 0,
 predict_with_generate: bool = True,
 do_train: bool = True,
 do_eval: bool = True,
 do_predict: bool = True,
-n_gpus_to_use: Optional[int] = None,
+n_gpus_to_use: int | None = None,
 ):
 data_dir = self.test_file_dir / "../fixtures/tests_samples/wmt_en_ro"
 output_dir = self.get_auto_remove_tmp_dir()

View File

@@ -13,7 +13,6 @@
 # limitations under the License.
 import unittest
-from typing import Optional
 import torch
 from parameterized import parameterized
@@ -43,8 +42,8 @@ class ContinuousBatchingTest(unittest.TestCase):
 )
 def test_group_layers(
 self,
-layer_types_str: Optional[str],
-sliding_window: Optional[int],
+layer_types_str: str | None,
+sliding_window: int | None,
 expected_groups: str,
 ) -> None:
 # Take a config and change the layer_types attribute to the mix we want

View File

@@ -13,7 +13,6 @@
 # limitations under the License.
 import unittest
-from typing import Union
 import numpy as np
 from parameterized import parameterized
@@ -90,7 +89,7 @@ class LogitsProcessorTest(unittest.TestCase):
 self.assertFalse(torch.isinf(scores_before_min_length).any())
 @parameterized.expand([(0,), ([0, 18],)])
-def test_new_min_length_dist_processor(self, eos_token_id: Union[int, list[int]]):
+def test_new_min_length_dist_processor(self, eos_token_id: int | list[int]):
 vocab_size = 20
 batch_size = 4

View File

@@ -23,7 +23,6 @@ import tempfile
 import unittest
 import warnings
 from pathlib import Path
-from typing import Optional
 import numpy as np
 import pytest
@@ -908,7 +907,7 @@ class GenerationTesterMixin:
 @pytest.mark.generate
 def test_left_padding_compatibility(
-self, unpadded_custom_inputs: Optional[dict] = None, padded_custom_inputs: Optional[dict] = None
+self, unpadded_custom_inputs: dict | None = None, padded_custom_inputs: dict | None = None
 ):
 """
 Tests that adding left-padding yields the same logits as the original input. Exposes arguments for custom

View File

@@ -14,7 +14,6 @@
 import unittest
-from typing import Optional, Union
 from transformers.image_utils import load_image
 from transformers.testing_utils import require_torch, require_vision
@@ -36,14 +35,14 @@ class BridgeTowerImageProcessingTester:
 self,
 parent,
 do_resize: bool = True,
-size: Optional[dict[str, int]] = None,
+size: dict[str, int] | None = None,
 size_divisor: int = 32,
 do_rescale: bool = True,
-rescale_factor: Union[int, float] = 1 / 255,
+rescale_factor: int | float = 1 / 255,
 do_normalize: bool = True,
 do_center_crop: bool = True,
-image_mean: Optional[Union[float, list[float]]] = [0.48145466, 0.4578275, 0.40821073],
-image_std: Optional[Union[float, list[float]]] = [0.26862954, 0.26130258, 0.27577711],
+image_mean: float | list[float] | None = [0.48145466, 0.4578275, 0.40821073],
+image_std: float | list[float] | None = [0.26862954, 0.26130258, 0.27577711],
 do_pad: bool = True,
 batch_size=7,
 min_resolution=30,
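A note on the nested rewrites above: Optional[Union[float, list[float]]] and float | list[float] | None are the same type at runtime, since unions flatten and deduplicate regardless of nesting or spelling. A quick sanity check (assumes Python 3.10+):

    from typing import Optional, Union, get_args

    old_style = Optional[Union[float, list[float]]]
    new_style = float | list[float] | None

    # PEP 604 guarantees the two spellings compare equal and flatten identically.
    assert old_style == new_style
    assert get_args(new_style) == (float, list[float], type(None))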

View File

@@ -15,7 +15,6 @@
 import shutil
 import tempfile
 import unittest
-from typing import Optional
 from transformers import Gemma3Processor, GemmaTokenizer
 from transformers.testing_utils import get_tests_dir, require_vision
@@ -83,7 +82,7 @@ class Gemma3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
 } # fmt: skip
 # Override as Gemma3 needs images to be an explicitly nested batch
-def prepare_image_inputs(self, batch_size: Optional[int] = None):
+def prepare_image_inputs(self, batch_size: int | None = None):
 """This function prepares a list of PIL images for testing"""
 images = super().prepare_image_inputs(batch_size)
 if isinstance(images, (list, tuple)):

View File

@@ -19,7 +19,6 @@ import random
 import tempfile
 import unittest
 from collections.abc import Sequence
-from typing import Optional
 import numpy as np
 from parameterized import parameterized
@@ -78,8 +77,8 @@ class Gemma3nAudioFeatureExtractionTester:
 dither: float = 0.0,
 input_scale_factor: float = 1.0,
 mel_floor: float = 1e-5,
-per_bin_mean: Optional[Sequence[float]] = None,
-per_bin_stddev: Optional[Sequence[float]] = None,
+per_bin_mean: Sequence[float] | None = None,
+per_bin_stddev: Sequence[float] | None = None,
 ):
 self.parent = parent
 self.batch_size = batch_size

View File

@@ -17,7 +17,6 @@ import os
 import shutil
 import tempfile
 import unittest
-from typing import Optional
 import pytest
@@ -79,7 +78,7 @@ class GroundingDinoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
 cls.embed_dim = 5
 cls.seq_length = 5
-def prepare_text_inputs(self, batch_size: Optional[int] = None, **kwargs):
+def prepare_text_inputs(self, batch_size: int | None = None, **kwargs):
 labels = ["a cat", "remote control"]
 labels_longer = ["a person", "a car", "a dog", "a cat"]

View File

@@ -14,7 +14,6 @@
 import unittest
-from typing import Union
 import numpy as np
@@ -151,7 +150,7 @@ class LlavaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
 # taken from original implementation: https://github.com/haotian-liu/LLaVA/blob/c121f0432da27facab705978f83c4ada465e46fd/llava/mm_utils.py#L152
 def pad_to_square_original(
-image: Image.Image, background_color: Union[int, tuple[int, int, int]] = 0
+image: Image.Image, background_color: int | tuple[int, int, int] = 0
 ) -> Image.Image:
 width, height = image.size
 if width == height:

View File

@@ -16,7 +16,6 @@ import json
 import shutil
 import tempfile
 import unittest
-from typing import Optional
 import numpy as np
@@ -60,7 +59,7 @@ class MllamaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
 return {"chat_template": "{% for message in messages %}{% if loop.index0 == 0 %}{{ bos_token }}{% endif %}{{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }}{% if message['content'] is string %}{{ message['content'] }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' %}{{ '<|image|>' }}{% elif content['type'] == 'text' %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{{ '<|eot_id|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"} # fmt: skip
 # Override as Mllama needs images to be an explicitly nested batch
-def prepare_image_inputs(self, batch_size: Optional[int] = None):
+def prepare_image_inputs(self, batch_size: int | None = None):
 """This function prepares a list of PIL images for testing"""
 images = super().prepare_image_inputs(batch_size)
 if isinstance(images, (list, tuple)):

View File

@@ -18,7 +18,6 @@ import itertools
 import random
 import tempfile
 import unittest
-from typing import Optional, Union
 import numpy as np
 from huggingface_hub import hf_hub_download
@@ -87,17 +86,17 @@ class PatchTSMixerModelTester:
 masked_loss: bool = False,
 mask_mode: str = "mask_before_encoder",
 channel_consistent_masking: bool = True,
-scaling: Optional[Union[str, bool]] = "std",
+scaling: str | bool | None = "std",
 # Head related
 head_dropout: float = 0.2,
 # forecast related
 prediction_length: int = 16,
-out_channels: Optional[int] = None,
+out_channels: int | None = None,
 # Classification/regression related
 # num_labels: int = 3,
 num_targets: int = 3,
-output_range: Optional[list] = None,
-head_aggregation: Optional[str] = None,
+output_range: list | None = None,
+head_aggregation: str | None = None,
 # Trainer related
 batch_size=13,
 is_training=True,

View File

@@ -15,7 +15,6 @@
 import shutil
 import tempfile
 import unittest
-from typing import Optional
 import numpy as np
@@ -94,14 +93,14 @@ class SmolVLMProcessorTest(ProcessorTesterMixin, unittest.TestCase):
 }
 # Override as SmolVLM needs images/video to be an explicitly nested batch
-def prepare_image_inputs(self, batch_size: Optional[int] = None):
+def prepare_image_inputs(self, batch_size: int | None = None):
 """This function prepares a list of PIL images for testing"""
 images = super().prepare_image_inputs(batch_size)
 if isinstance(images, (list, tuple)):
 images = [[image] for image in images]
 return images
-def prepare_video_inputs(self, batch_size: Optional[int] = None):
+def prepare_video_inputs(self, batch_size: int | None = None):
 """This function prepares a list of numpy videos."""
 video_input = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)] * 8
 if batch_size is None:

View File

@@ -14,7 +14,6 @@
 import unittest
-from typing import Optional, Union
 import numpy as np
@@ -41,16 +40,16 @@ class TvpImageProcessingTester:
 do_resize: bool = True,
 size: dict[str, int] = {"longest_edge": 40},
 do_center_crop: bool = False,
-crop_size: Optional[dict[str, int]] = None,
+crop_size: dict[str, int] | None = None,
 do_rescale: bool = False,
-rescale_factor: Union[int, float] = 1 / 255,
+rescale_factor: int | float = 1 / 255,
 do_pad: bool = True,
 pad_size: dict[str, int] = {"height": 80, "width": 80},
-fill: Optional[int] = None,
-pad_mode: Optional[PaddingMode] = None,
+fill: int | None = None,
+pad_mode: PaddingMode | None = None,
 do_normalize: bool = True,
-image_mean: Optional[Union[float, list[float]]] = [0.48145466, 0.4578275, 0.40821073],
-image_std: Optional[Union[float, list[float]]] = [0.26862954, 0.26130258, 0.27577711],
+image_mean: float | list[float] | None = [0.48145466, 0.4578275, 0.40821073],
+image_std: float | list[float] | None = [0.26862954, 0.26130258, 0.27577711],
 batch_size=2,
 min_resolution=40,
 max_resolution=80,

View File

@@ -17,7 +17,6 @@ import inspect
 import shutil
 import tempfile
 import unittest
-from typing import Optional
 import numpy as np
 import pytest
@@ -79,7 +78,7 @@ class VideoLlama3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
 shutil.rmtree(cls.tmpdirname, ignore_errors=True)
 @require_vision
-def prepare_image_inputs(self, batch_size: Optional[int] = None):
+def prepare_image_inputs(self, batch_size: int | None = None):
 """This function prepares a list of PIL images for testing"""
 if batch_size is None:
 return prepare_image_inputs()[0]

View File

@@ -20,7 +20,6 @@ import os
 import random
 import sys
 from dataclasses import dataclass, field
-from typing import Optional
 import numpy as np
 from datasets import load_dataset, load_metric
@@ -70,7 +69,7 @@ class DataTrainingArguments:
 the command line.
 """
-task_name: Optional[str] = field(
+task_name: str | None = field(
 default=None,
 metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())},
 )
@@ -95,7 +94,7 @@ class DataTrainingArguments:
 )
 },
 )
-max_train_samples: Optional[int] = field(
+max_train_samples: int | None = field(
 default=None,
 metadata={
 "help": (
@@ -104,7 +103,7 @@ class DataTrainingArguments:
 )
 },
 )
-max_val_samples: Optional[int] = field(
+max_val_samples: int | None = field(
 default=None,
 metadata={
 "help": (
@@ -113,7 +112,7 @@ class DataTrainingArguments:
 )
 },
 )
-max_test_samples: Optional[int] = field(
+max_test_samples: int | None = field(
 default=None,
 metadata={
 "help": (
@@ -122,13 +121,13 @@ class DataTrainingArguments:
 )
 },
 )
-train_file: Optional[str] = field(
+train_file: str | None = field(
 default=None, metadata={"help": "A csv or a json file containing the training data."}
 )
-validation_file: Optional[str] = field(
+validation_file: str | None = field(
 default=None, metadata={"help": "A csv or a json file containing the validation data."}
 )
-test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."})
+test_file: str | None = field(default=None, metadata={"help": "A csv or a json file containing the test data."})
 def __post_init__(self):
 if self.task_name is not None:
@@ -155,13 +154,13 @@ class ModelArguments:
 model_name_or_path: str = field(
 metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
 )
-config_name: Optional[str] = field(
+config_name: str | None = field(
 default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
 )
-tokenizer_name: Optional[str] = field(
+tokenizer_name: str | None = field(
 default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
 )
-cache_dir: Optional[str] = field(
+cache_dir: str | None = field(
 default=None,
 metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
 )

View File

@@ -19,7 +19,6 @@ import random
 import sys
 import tempfile
 from pathlib import Path
-from typing import Optional, Union
 import numpy as np
 from huggingface_hub import hf_hub_download
@@ -136,7 +135,7 @@ class ProcessorTesterMixin:
 processor = self.processor_class(**components, **self.prepare_processor_dict())
 return processor
-def prepare_text_inputs(self, batch_size: Optional[int] = None, modalities: Optional[Union[str, list]] = None):
+def prepare_text_inputs(self, batch_size: int | None = None, modalities: str | list | None = None):
 if isinstance(modalities, str):
 modalities = [modalities]
@@ -158,7 +157,7 @@ class ProcessorTesterMixin:
 ] * (batch_size - 2)
 @require_vision
-def prepare_image_inputs(self, batch_size: Optional[int] = None):
+def prepare_image_inputs(self, batch_size: int | None = None):
 """This function prepares a list of PIL images for testing"""
 if batch_size is None:
 return prepare_image_inputs()[0]
@@ -167,7 +166,7 @@ class ProcessorTesterMixin:
 return prepare_image_inputs() * batch_size
 @require_vision
-def prepare_video_inputs(self, batch_size: Optional[int] = None):
+def prepare_video_inputs(self, batch_size: int | None = None):
 """This function prepares a list of numpy videos."""
 video_input = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)] * 8
 video_input = np.array(video_input)
@@ -175,7 +174,7 @@ class ProcessorTesterMixin:
 return video_input
 return [video_input] * batch_size
-def prepare_audio_inputs(self, batch_size: Optional[int] = None):
+def prepare_audio_inputs(self, batch_size: int | None = None):
 """This function prepares a list of numpy audio."""
 raw_speech = floats_list((1, 1000))
 raw_speech = [np.asarray(audio) for audio in raw_speech]

View File

@@ -26,7 +26,7 @@ import unittest
 from collections import OrderedDict
 from itertools import takewhile
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Union
 from parameterized import parameterized
@@ -169,7 +169,7 @@ def _test_subword_regularization_tokenizer(in_queue, out_queue, timeout):
 def check_subword_sampling(
 tokenizer: PreTrainedTokenizer,
-text: Optional[str] = None,
+text: str | None = None,
 test_sentencepiece_ignore_case: bool = True,
 ) -> None:
 """
@@ -313,9 +313,9 @@ class TokenizerTesterMixin:
 self,
 expected_encoding: dict,
 model_name: str,
-revision: Optional[str] = None,
-sequences: Optional[list[str]] = None,
-decode_kwargs: Optional[dict[str, Any]] = None,
+revision: str | None = None,
+sequences: list[str] | None = None,
+decode_kwargs: dict[str, Any] | None = None,
 padding: bool = True,
 ):
 """

View File

@@ -16,7 +16,6 @@ import logging
 import os
 import unittest
 from pathlib import Path
-from typing import Union
 import transformers
 from transformers.testing_utils import require_torch, slow
@@ -32,9 +31,9 @@ class TestCodeExamples(unittest.TestCase):
 def analyze_directory(
 self,
 directory: Path,
-identifier: Union[str, None] = None,
-ignore_files: Union[list[str], None] = None,
-n_identifier: Union[str, list[str], None] = None,
+identifier: str | None = None,
+ignore_files: list[str] | None = None,
+n_identifier: str | list[str] | None = None,
 only_modules: bool = True,
 ):
 """

View File

@@ -22,7 +22,7 @@ from argparse import Namespace
 from dataclasses import dataclass, field
 from enum import Enum
 from pathlib import Path
-from typing import Literal, Optional, Union, get_args, get_origin
+from typing import Literal, Union, get_args, get_origin
 from unittest.mock import patch
 import yaml
@@ -59,7 +59,7 @@ class WithDefaultExample:
 class WithDefaultBoolExample:
 foo: bool = False
 baz: bool = True
-opt: Optional[bool] = None
+opt: bool | None = None
 class BasicEnum(Enum):
@@ -91,11 +91,11 @@ class MixedTypeEnumExample:
 @dataclass
 class OptionalExample:
-foo: Optional[int] = None
-bar: Optional[float] = field(default=None, metadata={"help": "help message"})
-baz: Optional[str] = None
-ces: Optional[list[str]] = list_field(default=[])
-des: Optional[list[int]] = list_field(default=[])
+foo: int | None = None
+bar: float | None = field(default=None, metadata={"help": "help message"})
+baz: str | None = None
+ces: list[str] | None = list_field(default=[])
+des: list[int] | None = list_field(default=[])
 @dataclass
@@ -120,7 +120,7 @@ class RequiredExample:
 class StringLiteralAnnotationExample:
 foo: int
 required_enum: "BasicEnum" = field()
-opt: "Optional[bool]" = None
+opt: "bool | None" = None
 baz: "str" = field(default="toto", metadata={"help": "help message"})
 foo_str: "list[str]" = list_field(default=["Hallo", "Bonjour", "Hello"])
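One subtlety behind this test file (note that get_args and get_origin stay imported above): the two spellings compare equal, but they expose different origins, so annotation-introspecting code such as HfArgumentParser has to accept both typing.Union and types.UnionType. String-literal annotations like "bool | None" are only resolved when evaluated (e.g. via typing.get_type_hints), which likewise needs a 3.10+ interpreter for the | syntax. A hedged sketch of such a check (is_optional is illustrative, not the library's API):

    import types
    from typing import Optional, Union, get_args, get_origin

    assert Optional[int] == (int | None)           # equal types...
    assert get_origin(Optional[int]) is Union      # ...but different origins
    assert get_origin(int | None) is types.UnionType

    def is_optional(annotation) -> bool:
        # Treat Union[X, None] and X | None alike when inspecting dataclass fields.
        return get_origin(annotation) in (Union, types.UnionType) and type(None) in get_args(annotation)

    assert is_optional(Optional[int]) and is_optional(int | None)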

View File

@@ -17,7 +17,6 @@ import os
 import tempfile
 import unittest
 from io import BytesIO
-from typing import Optional
 import httpx
 import numpy as np
@@ -46,7 +45,7 @@ if is_vision_available():
 from transformers.image_utils import get_image_size, infer_channel_dimension_format, load_image
-def get_image_from_hub_dataset(dataset_id: str, filename: str, revision: Optional[str] = None) -> "PIL.Image.Image":
+def get_image_from_hub_dataset(dataset_id: str, filename: str, revision: str | None = None) -> "PIL.Image.Image":
 url = hf_hub_url(dataset_id, filename, repo_type="dataset", revision=revision)
 return PIL.Image.open(BytesIO(httpx.get(url, follow_redirects=True).content))

View File

@@ -15,7 +15,6 @@
 import io
 import unittest
 from dataclasses import dataclass
-from typing import Optional
 import pytest
@@ -31,8 +30,8 @@ if is_torch_available():
 @dataclass
 class ModelOutputTest(ModelOutput):
 a: float
-b: Optional[float] = None
-c: Optional[float] = None
+b: float | None = None
+c: float | None = None
 class ModelOutputTester(unittest.TestCase):
@@ -182,8 +181,8 @@ class ModelOutputTestNoDataclass(ModelOutput):
 """Invalid test subclass of ModelOutput where @dataclass decorator is not used"""
 a: float
-b: Optional[float] = None
-c: Optional[float] = None
+b: float | None = None
+c: float | None = None
 class ModelOutputSubclassTester(unittest.TestCase):

View File

@@ -3,7 +3,6 @@ import os
 import re
 import subprocess
 from datetime import date
-from typing import Optional
 from huggingface_hub import paper_info
@@ -51,7 +50,7 @@ def get_modified_cards() -> list[str]:
 return model_names
-def get_paper_link(model_card: Optional[str], path: Optional[str]) -> str:
+def get_paper_link(model_card: str | None, path: str | None) -> str:
 """Get the first paper link from the model card content."""
 if model_card is not None and not model_card.endswith(".md"):
@@ -91,7 +90,7 @@ def get_paper_link(model_card: Optional[str], path: Optional[str]) -> str:
 return paper_ids[0]
-def get_first_commit_date(model_name: Optional[str]) -> str:
+def get_first_commit_date(model_name: str | None) -> str:
 """Get the first commit date of the model's init file or model.md. This date is considered as the date the model was added to HF transformers"""
 if model_name.endswith(".md"):

View File

@@ -42,7 +42,6 @@ import os
 import re
 import subprocess
 from collections import OrderedDict
-from typing import Optional, Union
 from transformers.utils import direct_transformers_import
@@ -384,8 +383,8 @@ def split_code_into_blocks(
 def find_code_in_transformers(
-object_name: str, base_path: Optional[str] = None, return_indices: bool = False
-) -> Union[str, tuple[list[str], int, int]]:
+object_name: str, base_path: str | None = None, return_indices: bool = False
+) -> str | tuple[list[str], int, int]:
 """
 Find and return the source code of an object.
@@ -485,7 +484,7 @@ def replace_code(code: str, replace_pattern: str) -> str:
 return code
-def find_code_and_splits(object_name: str, base_path: str, buffer: Optional[dict] = None):
+def find_code_and_splits(object_name: str, base_path: str, buffer: dict | None = None):
 """Find the code of an object (specified by `object_name`) and split it into blocks.
 Args:
@@ -581,7 +580,7 @@ def stylify(code: str) -> str:
 return formatted_code[len("class Bla:\n") :] if has_indent else formatted_code
-def check_codes_match(observed_code: str, theoretical_code: str) -> Optional[int]:
+def check_codes_match(observed_code: str, theoretical_code: str) -> int | None:
 """
 Checks if two version of a code match with the exception of the class/function name.
@@ -633,8 +632,8 @@ def check_codes_match(observed_code: str, theoretical_code: str) -> Optional[int
 def is_copy_consistent(
-filename: str, overwrite: bool = False, buffer: Optional[dict] = None
-) -> Optional[list[tuple[str, int]]]:
+filename: str, overwrite: bool = False, buffer: dict | None = None
+) -> list[tuple[str, int]] | None:
 """
 Check if the code commented as a copy in a file matches the original.
@@ -826,7 +825,7 @@ def is_copy_consistent(
 return diffs
-def check_copies(overwrite: bool = False, file: Optional[str] = None):
+def check_copies(overwrite: bool = False, file: str | None = None):
 """
 Check every file is copy-consistent with the original. Also check the model list in the main README and other
 READMEs are consistent.

View File

@@ -43,7 +43,7 @@ import os
 import re
 from collections import OrderedDict
 from pathlib import Path
-from typing import Any, Optional, Union
+from typing import Any
 from check_repo import ignore_undocumented
 from git import Repo
@@ -525,7 +525,7 @@ def stringify_default(default: Any) -> str:
 return f"`{default}`"
-def eval_math_expression(expression: str) -> Optional[Union[float, int]]:
+def eval_math_expression(expression: str) -> float | int | None:
 # Mainly taken from the excellent https://stackoverflow.com/a/9558001
 """
 Evaluate (safely) a mathematial expression and returns its value.
@@ -673,7 +673,7 @@ def find_source_file(obj: Any) -> Path:
 return obj_file.with_suffix(".py")
-def match_docstring_with_signature(obj: Any) -> Optional[tuple[str, str]]:
+def match_docstring_with_signature(obj: Any) -> tuple[str, str] | None:
 """
 Matches the docstring of an object with its signature.

View File

@@ -37,7 +37,6 @@ python utils/check_dummies.py --fix_and_overwrite
 import argparse
 import os
 import re
-from typing import Optional
 # All paths are set with the intent you should run this script from the root of the repo with the command
@@ -73,7 +72,7 @@ def {0}(*args, **kwargs):
 """
-def find_backend(line: str) -> Optional[str]:
+def find_backend(line: str) -> str | None:
 """
 Find one (or multiple) backend in a code line of the init.
@@ -156,7 +155,7 @@ def create_dummy_object(name: str, backend_name: str) -> str:
 return DUMMY_CLASS.format(name, backend_name)
-def create_dummy_files(backend_specific_objects: Optional[dict[str, list[str]]] = None) -> dict[str, str]:
+def create_dummy_files(backend_specific_objects: dict[str, list[str]] | None = None) -> dict[str, str]:
 """
 Create the content of the dummy files.

View File

@@ -39,7 +39,6 @@ import collections
 import os
 import re
 from pathlib import Path
-from typing import Optional
 # Path is set with the intent you should run this script from the root of the repo.
@@ -70,7 +69,7 @@ _re_try = re.compile(r"^\s*try:")
 _re_else = re.compile(r"^\s*else:")
-def find_backend(line: str) -> Optional[str]:
+def find_backend(line: str) -> str | None:
 """
 Find one (or multiple) backend in a code line of the init.
@@ -89,7 +88,7 @@ def find_backend(line: str) -> Optional[str]:
 return "_and_".join(backends)
-def parse_init(init_file) -> Optional[tuple[dict[str, list[str]], dict[str, list[str]]]]:
+def parse_init(init_file) -> tuple[dict[str, list[str]], dict[str, list[str]]] | None:
 """
 Read an init_file and parse (per backend) the `_import_structure` objects defined and the `TYPE_CHECKING` objects
 defined.

View File

@@ -39,7 +39,7 @@ import argparse
 import os
 import re
 from collections.abc import Callable
-from typing import Any, Optional
+from typing import Any
 # Path is defined with the intent you should run this script from the root of the repo.
@@ -64,7 +64,7 @@ def get_indent(line: str) -> str:
 def split_code_in_indented_blocks(
-code: str, indent_level: str = "", start_prompt: Optional[str] = None, end_prompt: Optional[str] = None
+code: str, indent_level: str = "", start_prompt: str | None = None, end_prompt: str | None = None
 ) -> list[str]:
 """
 Split some code into its indented blocks, starting at a given level.
@@ -141,7 +141,7 @@ def ignore_underscore_and_lowercase(key: Callable[[Any], str]) -> Callable[[Any]
 return _inner
-def sort_objects(objects: list[Any], key: Optional[Callable[[Any], str]] = None) -> list[Any]:
+def sort_objects(objects: list[Any], key: Callable[[Any], str] | None = None) -> list[Any]:
 """
 Sort a list of objects following the rules of isort (all uppercased first, camel-cased second and lower-cased
 last).

View File

@@ -9,7 +9,6 @@ import argparse
 import os
 from collections import defaultdict
 from pathlib import Path
-from typing import Optional
 import requests
 from custom_init_isort import sort_imports_in_all_inits
@@ -77,7 +76,7 @@ def insert_tip_to_model_doc(model_doc_path, tip_message):
 f.write("\n".join(new_model_lines))
-def get_model_doc_path(model: str) -> tuple[Optional[str], Optional[str]]:
+def get_model_doc_path(model: str) -> tuple[str | None, str | None]:
 # Possible variants of the model name in the model doc path
 model_names = [model, model.replace("_", "-"), model.replace("_", "")]

View File

@@ -33,7 +33,6 @@ import os
 import subprocess
 import tempfile
 from pathlib import Path
-from typing import Optional
 import torch
@@ -85,8 +84,8 @@ def handle_suite(
 machine_type: str,
 dry_run: bool,
 tmp_cache: str = "",
-resume_at: Optional[str] = None,
-only_in: Optional[list[str]] = None,
+resume_at: str | None = None,
+only_in: list[str] | None = None,
 cpu_tests: bool = False,
 process_id: int = 1,
 total_processes: int = 1,

View File

@@ -1,5 +1,4 @@
 import os
-from typing import Optional
 import libcst as cst
@@ -14,7 +13,7 @@ EXCLUDED_EXTERNAL_FILES = {
 def convert_relative_import_to_absolute(
 import_node: cst.ImportFrom,
 file_path: str,
-package_name: Optional[str] = "transformers",
+package_name: str | None = "transformers",
 ) -> cst.ImportFrom:
 """
 Convert a relative libcst.ImportFrom node into an absolute one,
@@ -51,7 +50,7 @@ def convert_relative_import_to_absolute(
 base_parts = module_parts[:-rel_level]
 # Flatten the module being imported (if any)
-def flatten_module(module: Optional[cst.BaseExpression]) -> list[str]:
+def flatten_module(module: cst.BaseExpression | None) -> list[str]:
 if not module:
 return []
 if isinstance(module, cst.Name):
@@ -76,7 +75,7 @@ def convert_relative_import_to_absolute(
 full_parts = [file_parts[pkg_index - 1]] + full_parts
 # Build the dotted module path
-dotted_module: Optional[cst.BaseExpression] = None
+dotted_module: cst.BaseExpression | None = None
 for part in full_parts:
 name = cst.Name(part)
 dotted_module = name if dotted_module is None else cst.Attribute(value=dotted_module, attr=name)

View File

@@ -22,7 +22,6 @@ import subprocess
 from abc import ABC, abstractmethod
 from collections import Counter, defaultdict, deque
 from functools import partial
-from typing import Optional, Union
 import libcst as cst
 from create_dependency_mapping import find_priority_list
@@ -177,7 +176,7 @@ DOCSTRING_NODE = m.SimpleStatementLine(
 )
-def get_full_attribute_name(node: Union[cst.Attribute, cst.Name]) -> Optional[str]:
+def get_full_attribute_name(node: cst.Attribute | cst.Name) -> str | None:
 """Get the full name of an Attribute or Name node (e.g. `"nn.Module"` for an Attribute representing it). If the
 successive value of an Attribute are not Name nodes, return `None`."""
 if m.matches(node, m.Name()):
@@ -378,11 +377,11 @@ class ReplaceSuperCallTransformer(cst.CSTTransformer):
 def find_all_dependencies(
 dependency_mapping: dict[str, set],
-start_entity: Optional[str] = None,
-initial_dependencies: Optional[set] = None,
-initial_checked_dependencies: Optional[set] = None,
+start_entity: str | None = None,
+initial_dependencies: set | None = None,
+initial_checked_dependencies: set | None = None,
 return_parent: bool = False,
-) -> Union[list, set]:
+) -> list | set:
 """Return all the dependencies of the given `start_entity` or `initial_dependencies`. This is basically some kind of
 BFS traversal algorithm. It can either start from `start_entity`, or `initial_dependencies`.
@@ -476,7 +475,7 @@ class ClassDependencyMapper(CSTVisitor):
 """
 def __init__(
-self, class_name: str, global_names: set[str], objects_imported_from_modeling: Optional[set[str]] = None
+self, class_name: str, global_names: set[str], objects_imported_from_modeling: set[str] | None = None
 ):
 super().__init__()
 self.class_name = class_name
@@ -504,7 +503,7 @@ def dependencies_for_class_node(node: cst.ClassDef, global_names: set[str]) -> s
 def augmented_dependencies_for_class_node(
-node: cst.ClassDef, mapper: "ModuleMapper", objects_imported_from_modeling: Optional[set[str]] = None
+node: cst.ClassDef, mapper: "ModuleMapper", objects_imported_from_modeling: set[str] | None = None
 ) -> set:
 """Create augmented dependencies for a class node based on a `mapper`.
 Augmented dependencies means immediate dependencies + recursive function and assignments dependencies.
@@ -1659,8 +1658,8 @@ def get_class_node_and_dependencies(
 def create_modules(
 modular_mapper: ModularFileMapper,
-file_path: Optional[str] = None,
-package_name: Optional[str] = "transformers",
+file_path: str | None = None,
+package_name: str | None = "transformers",
 ) -> dict[str, cst.Module]:
 """Create all the new modules based on visiting the modular file. It replaces all classes as necessary."""
 files = defaultdict(dict)
@@ -1747,7 +1746,7 @@ def run_ruff(code, check=False):
 return stdout.decode()
-def convert_modular_file(modular_file: str, source_library: Optional[str] = "transformers") -> dict[str, str]:
+def convert_modular_file(modular_file: str, source_library: str | None = "transformers") -> dict[str, str]:
 """Convert a `modular_file` into all the different model-specific files it depicts."""
 pattern = re.search(r"modular_(.*)(?=\.py$)", modular_file)
 output = {}
@@ -1818,7 +1817,7 @@ def count_loc(file_path: str) -> int:
 return len([line for line in comment_less_code.split("\n") if line.strip()])
-def run_converter(modular_file: str, source_library: Optional[str] = "transformers"):
+def run_converter(modular_file: str, source_library: str | None = "transformers"):
 """Convert a modular file, and save resulting files."""
 print(f"Converting {modular_file} to a single model single file format")
 converted_files = convert_modular_file(modular_file, source_library=source_library)

View File

@@ -21,7 +21,7 @@ import os
 import re
 import sys
 import time
-from typing import Any, Optional, Union
+from typing import Any
 import requests
 from compare_test_runs import compare_job_sets
@@ -120,7 +120,7 @@ def handle_stacktraces(test_results):
 return stacktraces
-def dicts_to_sum(objects: Union[dict[str, dict], list[dict]]):
+def dicts_to_sum(objects: dict[str, dict] | list[dict]):
 if isinstance(objects, dict):
 lists = objects.values()
 else:
@@ -139,7 +139,7 @@ class Message:
 ci_title: str,
 model_results: dict,
 additional_results: dict,
-selected_warnings: Optional[list] = None,
+selected_warnings: list | None = None,
 prev_ci_artifacts=None,
 other_ci_artifacts=None,
 ):
@@ -941,7 +941,7 @@ class Message:
 time.sleep(1)
-def retrieve_artifact(artifact_path: str, gpu: Optional[str]):
+def retrieve_artifact(artifact_path: str, gpu: str | None):
 if gpu not in [None, "single", "multi"]:
 raise ValueError(f"Invalid GPU for artifact. Passed GPU: `{gpu}`.")
@@ -970,7 +970,7 @@ def retrieve_available_artifacts():
 def __str__(self):
 return self.name
-def add_path(self, path: str, gpu: Optional[str] = None):
+def add_path(self, path: str, gpu: str | None = None):
 self.paths.append({"name": self.name, "path": path, "gpu": gpu})
 _available_artifacts: dict[str, Artifact] = {}

View File

@@ -33,7 +33,6 @@ python utils/sort_auto_mappings.py --check_only
 import argparse
 import os
 import re
-from typing import Optional
 # Path are set with the intent you should run this script from the root of the repo.
@@ -47,7 +46,7 @@ _re_intro_mapping = re.compile(r"[A-Z_]+_MAPPING(\s+|_[A-Z_]+\s+)=\s+OrderedDict
 _re_identifier = re.compile(r'\s*\(\s*"(\S[^"]+)"')
-def sort_auto_mapping(fname: str, overwrite: bool = False) -> Optional[bool]:
+def sort_auto_mapping(fname: str, overwrite: bool = False) -> bool | None:
 """
 Sort all auto mappings in a file.

View File

@@ -57,7 +57,6 @@ import os
 import re
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Optional, Union
 from git import Repo
@@ -555,7 +554,7 @@ _re_single_line_direct_imports = re.compile(r"(?:^|\n)\s*from\s+transformers(\S*
 _re_multi_line_direct_imports = re.compile(r"(?:^|\n)\s*from\s+transformers(\S*)\s+import\s+\(([^\)]+)\)")
-def extract_imports(module_fname: str, cache: Optional[dict[str, list[str]]] = None) -> list[str]:
+def extract_imports(module_fname: str, cache: dict[str, list[str]] | None = None) -> list[str]:
 """
 Get the imports a given module makes.
@@ -637,7 +636,7 @@ def extract_imports(module_fname: str, cache: Optional[dict[str, list[str]] = N
 return result
-def get_module_dependencies(module_fname: str, cache: Optional[dict[str, list[str]]] = None) -> list[str]:
+def get_module_dependencies(module_fname: str, cache: dict[str, list[str]] | None = None) -> list[str]:
 """
 Refines the result of `extract_imports` to remove subfolders and get a proper list of module filenames: if a file
 as an import `from utils import Foo, Bar`, with `utils` being a subfolder containing many files, this will traverse
@@ -734,7 +733,7 @@ def create_reverse_dependency_tree() -> list[tuple[str, str]]:
 return list(set(edges))
-def get_tree_starting_at(module: str, edges: list[tuple[str, str]]) -> list[Union[str, list[str]]]:
+def get_tree_starting_at(module: str, edges: list[tuple[str, str]]) -> list[str | list[str]]:
 """
 Returns the tree starting at a given module following all edges.
@@ -883,7 +882,7 @@ def create_reverse_dependency_map() -> dict[str, list[str]]:
 def create_module_to_test_map(
-reverse_map: Optional[dict[str, list[str]]] = None, filter_models: bool = False
+reverse_map: dict[str, list[str]] | None = None, filter_models: bool = False
 ) -> dict[str, list[str]]:
 """
 Extract the tests from the reverse_dependency_map and potentially filters the model tests.