[env] feat: safely bump py version to 3.10 (#2421)

### What does this PR do?

This PR safely bumps the minimum Python version to 3.10 for two reasons (both illustrated in the sketch just below this list):

1. [`removeprefix`](https://docs.python.org/3.9/whatsnew/3.9.html#new-string-methods-to-remove-prefixes-and-suffixes), used in 588f9728f3/verl/single_controller/ray/base.py (L498-L505), was only introduced in Python 3.9.
2. [`match`](https://docs.python.org/3.10/whatsnew/3.10.html#simple-pattern-match-to-a-literal), used in 588f9728f3/verl/tools/utils/tool_registry.py (L81-L92), was only introduced in Python 3.10.
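
A minimal sketch of the two syntax features that set this version floor; the strings and keys are purely illustrative, not verl code:

```python
# str.removeprefix -- added in Python 3.9
name = "actor_rollout_ref.rollout.n".removeprefix("actor_rollout_ref.")
assert name == "rollout.n"

# structural pattern matching (match/case) -- added in Python 3.10
def describe(obj):
    match obj:
        case {"type": "tool", "name": name}:  # mapping pattern with capture
            return f"tool: {name}"
        case [first, *_rest]:                 # sequence pattern
            return f"sequence starting with {first!r}"
        case _:                               # wildcard
            return "something else"

print(describe({"type": "tool", "name": "search"}))  # tool: search
print(describe([1, 2, 3]))                           # sequence starting with 1
```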



### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
    - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`


### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [x] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
Qizhi Chen authored on 2025-07-13 07:29:39 +08:00 (committed by GitHub)
parent 6519220006 · commit eac4863ad7
142 changed files with 598 additions and 596 deletions

View File

@ -1,6 +1,6 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.11.4"
rev: "v0.12.2"
hooks:
- id: ruff
args: ["--fix", "--show-fixes", "--output-format=full"]

View File

@ -62,7 +62,7 @@ def generate_rm_dataset(target_hdfs_path_dir, local_dir="~/data/full_hh_rlh/rm")
local_dir = os.path.expanduser(local_dir)
os.makedirs(local_dir, exist_ok=True)
for dataset, name in zip([train_dataset, test_dataset], ["train", "test"]):
for dataset, name in zip([train_dataset, test_dataset], ["train", "test"], strict=True):
output = {"prompt": [], "chosen": [], "rejected": []}
for data in tqdm(dataset):
# add chosen

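The recurring `zip(..., strict=True)` edits here and throughout this diff rely on the `strict` keyword added to `zip()` in Python 3.10. A minimal sketch of the behavioral difference, using hypothetical lists rather than the actual datasets:

```python
splits = ["train", "test", "val"]
names = ["train", "test"]          # one entry short by mistake

# Default zip() silently truncates to the shortest iterable, hiding the bug
print(list(zip(splits, names)))    # [('train', 'train'), ('test', 'test')]

# strict=True (Python 3.10+) raises instead of truncating
try:
    list(zip(splits, names, strict=True))
except ValueError as err:
    print(err)                     # zip() argument 2 is shorter than argument 1
```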
View File

@ -18,7 +18,7 @@
import argparse
import json
import warnings
from typing import List, Optional
from typing import Optional
import datasets
import faiss
@ -75,7 +75,7 @@ class Encoder:
self.model.eval()
@torch.no_grad()
def encode(self, query_list: List[str], is_query=True) -> np.ndarray:
def encode(self, query_list: list[str], is_query=True) -> np.ndarray:
# processing query for different encoders
if isinstance(query_list, str):
query_list = [query_list]
@ -133,13 +133,13 @@ class BaseRetriever:
def _search(self, query: str, num: int, return_score: bool):
raise NotImplementedError
def _batch_search(self, query_list: List[str], num: int, return_score: bool):
def _batch_search(self, query_list: list[str], num: int, return_score: bool):
raise NotImplementedError
def search(self, query: str, num: int = None, return_score: bool = False):
return self._search(query, num, return_score)
def batch_search(self, query_list: List[str], num: int = None, return_score: bool = False):
def batch_search(self, query_list: list[str], num: int = None, return_score: bool = False):
return self._batch_search(query_list, num, return_score)
@ -190,7 +190,7 @@ class BM25Retriever(BaseRetriever):
else:
return results
def _batch_search(self, query_list: List[str], num: int = None, return_score: bool = False):
def _batch_search(self, query_list: list[str], num: int = None, return_score: bool = False):
results = []
scores = []
for query in query_list:
@ -237,7 +237,7 @@ class DenseRetriever(BaseRetriever):
else:
return results
def _batch_search(self, query_list: List[str], num: int = None, return_score: bool = False):
def _batch_search(self, query_list: list[str], num: int = None, return_score: bool = False):
if isinstance(query_list, str):
query_list = [query_list]
if num is None:
@ -318,7 +318,7 @@ class Config:
class QueryRequest(BaseModel):
queries: List[str]
queries: list[str]
topk: Optional[int] = None
return_scores: bool = False
@ -365,7 +365,7 @@ def retrieve_endpoint(request: QueryRequest):
if request.return_scores:
# If scores are returned, combine them with results
combined = []
for doc, score in zip(single_result, scores[i]):
for doc, score in zip(single_result, scores[i], strict=True):
combined.append({"document": doc, "score": score})
resp.append(combined)
else:

View File

@ -21,7 +21,7 @@ dynamic = ["version", "dependencies", "optional-dependencies", "authors", "urls"
description = "verl: Volcano Engine Reinforcement Learning for LLM"
license = {text = "Apache-2.0"} # Changed from file to text format
readme = {file = "README.md", content-type = "text/markdown"}
requires-python = ">=3.8"
requires-python = ">=3.10"
# -------------------------------
# tool.ruff - Linting configuration
@ -57,10 +57,12 @@ ignore = [
"B007",
# f-string format
"UP032",
# Can remove once 3.10+ is the minimum Python version
"UP007",
# `.log()` statement uses f-string
"G004",
# X | None for type annotations
"UP045",
# deprecated import
"UP035",
]
# -------------------------------

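With `requires-python = ">=3.10"` in place, the bulk of this diff rewrites annotations to the built-in generics of PEP 585 and the `X | Y` union syntax of PEP 604, which ruff's `UP` rules can then enforce. A small before/after sketch of the pattern, using a hypothetical function rather than actual verl code:

```python
# Before: typing-module generics and Union, needed while Python 3.8 was supported
from typing import Dict, List, Optional, Union

def lookup_old(queries: List[str], topk: Optional[int] = None) -> Dict[str, Union[int, float]]:
    return {q: len(q) for q in queries[:topk]}

# After: built-in generics (PEP 585) and | unions (PEP 604), valid once 3.10 is the floor
def lookup_new(queries: list[str], topk: int | None = None) -> dict[str, int | float]:
    return {q: len(q) for q in queries[:topk]}

# isinstance() also accepts | unions on 3.10+, as in the int | float | bool checks below
assert isinstance(3.14, int | float)
```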
View File

@ -214,7 +214,7 @@ class RayDAPOTrainer(RayPPOTrainer):
# Collect the sequence reward for each trajectory
prompt_uid2metric_vals = defaultdict(list)
for uid, metric_val in zip(
new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch[metric_name]
new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch[metric_name], strict=True
):
prompt_uid2metric_vals[uid].append(metric_val)

View File

@ -202,7 +202,7 @@ class RayEntropyTrainer(RayPPOTrainer):
# Collect the sequence reward for each trajectory
prompt_uid2metric_vals = defaultdict(list)
for uid, metric_val in zip(
new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch[metric_name]
new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch[metric_name], strict=True
):
prompt_uid2metric_vals[uid].append(metric_val)

View File

@ -28,7 +28,7 @@ def _default_compute_score(
if isinstance(res, dict):
return res
elif isinstance(res, (int, float, bool)):
elif isinstance(res, int | float | bool):
return float(res)
else:
return float(res[0])

View File

@ -977,7 +977,7 @@ def grade_answer_sympy(given_answer: str, ground_truth: str) -> bool:
elif len(ground_truth_elems) != len(given_elems):
is_correct = False
else:
for ground_truth_elem, given_elem in zip(ground_truth_elems, given_elems):
for ground_truth_elem, given_elem in zip(ground_truth_elems, given_elems, strict=True):
if _is_frac(ground_truth_elem) and _is_frac(given_elem):
# if fractions aren't reduced, then shouldn't be marked as correct
# so, we don't want to allow sympy.simplify in this case

View File

@ -96,7 +96,6 @@ import contextlib
import math
import re
from math import isclose
from typing import Union
# sympy related
from sympy import N, simplify
@ -173,8 +172,8 @@ def handle_pi(string, pi):
def math_equal(
prediction: Union[bool, float, str],
reference: Union[float, str],
prediction: bool | float | str,
reference: float | str,
include_percentage: bool = True,
tolerance: float = 1e-4,
timeout: float = 10.0,
@ -251,7 +250,7 @@ def math_equal(
if len(pred_parts) == len(ref_parts) and all(
[
math_equal(pred_pt, ref_pt, include_percentage, tolerance)
for pred_pt, ref_pt in zip(pred_parts, ref_parts)
for pred_pt, ref_pt in zip(pred_parts, ref_parts, strict=True)
]
):
return True
@ -277,7 +276,7 @@ def math_equal(
if len(pred_parts) == len(ref_parts) and all(
[
math_equal(pred_pt, ref_pt, include_percentage, tolerance)
for pred_pt, ref_pt in zip(pred_parts, ref_parts)
for pred_pt, ref_pt in zip(pred_parts, ref_parts, strict=True)
]
):
return True
@ -290,7 +289,7 @@ def math_equal(
if len(pred_matrix) == len(ref_matrix_items) and all(
[
math_equal(pred, ref, include_percentage, tolerance)
for ref, pred in zip(ref_matrix_items, pred_matrix)
for ref, pred in zip(ref_matrix_items, pred_matrix, strict=True)
]
):
return True
@ -312,7 +311,7 @@ def math_equal(
if len(pred_matrix) == len(ref_matrix_items) and all(
[
math_equal(pred, ref, include_percentage, tolerance)
for ref, pred in zip(ref_matrix_items, pred_matrix)
for ref, pred in zip(ref_matrix_items, pred_matrix, strict=True)
]
):
return True

View File

@ -100,7 +100,7 @@ def compute_score_batch(data_sources, solution_strs, ground_truths, extra_infos)
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
futures = []
for data_source, solution_str, ground_truth, extra_info in zip(
data_sources, solution_strs, ground_truths, extra_infos
data_sources, solution_strs, ground_truths, extra_infos, strict=True
):
future = executor.submit(compute_score, data_source, solution_str, ground_truth, extra_info)
futures.append(future)

View File

@ -19,7 +19,7 @@ import logging
import math
import os
import re
from typing import Dict, List, Optional, Union
from typing import Optional
import datasets
import torch
@ -88,7 +88,7 @@ def preprocess(
assert conversations[0]["role"] == "user", "the first role must be user"
if slice_config is not None:
assert isinstance(slice_config, Dict)
assert isinstance(slice_config, dict)
assert "patch_size" in slice_config
assert "max_slice_nums" in slice_config
assert "scale_resolution" in slice_config
@ -404,12 +404,12 @@ class RLHFDataset(Dataset):
def __init__(
self,
data_files: Union[str, List[str]],
data_files: str | list[str],
tokenizer: PreTrainedTokenizer,
config: DictConfig,
processor: Optional[ProcessorMixin] = None,
):
if not isinstance(data_files, (List, ListConfig)):
if not isinstance(data_files, list | ListConfig):
data_files = [data_files]
self.data_files = copy.deepcopy(data_files)

View File

@ -13,7 +13,7 @@
# limitations under the License.
import logging
import re
from typing import Any, Dict, Tuple
from typing import Any
import datasets
@ -32,7 +32,7 @@ class CustomSandboxFusionTool(SandboxFusionTool):
self.code_pattern = re.compile(r"```python(.*?)```", re.DOTALL)
@rollout_trace_op
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:
code = parameters["code"]
matches = self.code_pattern.findall(code)
if matches:
@ -81,7 +81,7 @@ class CustomRLHFDataset(RLHFDataset):
print(f"dataset len: {len(self.dataframe)}")
def map_fn(self, row: Dict, *, data_source: str = None):
def map_fn(self, row: dict, *, data_source: str = None):
if data_source == "Maxwell-Jia/AIME_2024":
problem, answer = row["Problem"], row["Answer"]
elif data_source == "yentinglin/aime_2025":
@ -97,7 +97,7 @@ class CustomRLHFDataset(RLHFDataset):
}
return data
def map_fn2(self, row: Dict):
def map_fn2(self, row: dict):
content = row["prompt"][0]["content"]
row["prompt"][0]["content"] = content + answer_format
row["agent_name"] = "tool_agent"

View File

@ -17,7 +17,7 @@ Convert JoeYing/ReTool-SFT to standard multi-turn tool calling messages.
import json
import re
from typing import Any, Dict, Tuple
from typing import Any
import datasets
from omegaconf import OmegaConf
@ -25,7 +25,7 @@ from omegaconf import OmegaConf
code_pattern = re.compile(r"```python(.*?)```", re.DOTALL)
def extract_code_message(content: str) -> Tuple[Dict[str, Any], str]:
def extract_code_message(content: str) -> tuple[dict[str, Any], str]:
start, stop = "<code>", "</code>"
i = content.find(start)
if i == -1:
@ -54,7 +54,7 @@ def extract_code_message(content: str) -> Tuple[Dict[str, Any], str]:
return message, content[j + len(stop) :]
def extract_answer_message(content: str) -> Tuple[Dict[str, Any], str]:
def extract_answer_message(content: str) -> tuple[dict[str, Any], str]:
start, stop = "<answer>", "</answer>"
i = content.find(start)
if i == -1:
@ -70,7 +70,7 @@ def extract_answer_message(content: str) -> Tuple[Dict[str, Any], str]:
return message, content[j + len(stop) :]
def extract_interpreter_message(content: str) -> Tuple[Dict[str, Any], str]:
def extract_interpreter_message(content: str) -> tuple[dict[str, Any], str]:
start, stop = "<interpreter>", "</interpreter>"
i = content.find(start)
if i == -1:
@ -86,7 +86,7 @@ def extract_interpreter_message(content: str) -> Tuple[Dict[str, Any], str]:
return message, content[j + len(stop) :]
def process(row: Dict, *, tools: str):
def process(row: dict, *, tools: str):
messages = []
# extract problem

View File

@ -21,7 +21,7 @@ from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from pprint import pprint
from typing import Any, Dict, Optional, Type
from typing import Any, Optional
import numpy as np
import ray
@ -49,7 +49,7 @@ from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seql
from verl.utils.torch_functional import masked_mean
from verl.utils.tracking import ValidationGenerationsLogger
WorkerType = Type[Worker]
WorkerType = type[Worker]
class Role(Enum):
@ -142,7 +142,7 @@ class ResourcePoolManager:
)
def _compute_response_info(batch: DataProto) -> Dict[str, Any]:
def _compute_response_info(batch: DataProto) -> dict[str, Any]:
"""Placeholder: Computes prompt and response lengths."""
try:
# Assuming 'prompts' and 'responses' keys exist after generation/union
@ -189,7 +189,7 @@ def _compute_response_info(batch: DataProto) -> Dict[str, Any]:
# --- Modified Metric Function ---
def compute_dpo_data_metrics(batch: DataProto) -> Dict[str, Any]:
def compute_dpo_data_metrics(batch: DataProto) -> dict[str, Any]:
"""
Computes and returns metrics relevant for the DPO-like process.
Assumes 'batch' contains results after generation and preference marking,
@ -354,7 +354,7 @@ def compute_onlineDPO_pref(data: DataProto):
@contextmanager
def _timer(name: str, timing_raw: Dict[str, float]):
def _timer(name: str, timing_raw: dict[str, float]):
with Timer(name=name, logger=None) as timer:
yield
timing_raw[name] = timer.last
@ -634,7 +634,7 @@ class RaySPINTrainer:
import numpy as np
# Create tuples of (input, output, score) and sort by input text
samples = list(zip(inputs, outputs, scores))
samples = list(zip(inputs, outputs, scores, strict=True))
samples.sort(key=lambda x: x[0]) # Sort by input text
# Use fixed random seed for deterministic shuffling
@ -1415,7 +1415,7 @@ class RaySPINTrainer:
postfix_metrics = {
k: f"{v:.3f}" if isinstance(v, float) else v
for k, v in metrics.items()
if isinstance(v, (int, float))
if isinstance(v, int | float)
}
progress_bar.set_postfix(postfix_metrics)

View File

@ -103,7 +103,7 @@ def convert_checkpoint_from_transformers_to_megatron(hf_model, model, hf_config)
has_share_expert = getattr(hf_config, "shared_expert_intermediate_size", None)
with torch.no_grad():
model.embedding.word_embeddings.weight.copy_(hf_model.model.embed_tokens.weight)
for layer, hf_layer in zip(model.decoder.layers, hf_model.model.layers):
for layer, hf_layer in zip(model.decoder.layers, hf_model.model.layers, strict=True):
layer.self_attention.linear_qkv.layer_norm_weight.copy_(hf_layer.input_layernorm.weight)
q = hf_layer.self_attn.q_proj.weight.view(
@ -178,7 +178,7 @@ def convert_checkpoint_from_transformers_to_megatron_qwen2_5_vl(hfmodel, mgmodel
copied_numel = 0
safe_copy(hfvision.rotary_pos_emb.inv_freq, mgvision.rotary_pos_emb.inv_freq)
copied_numel += safe_copy(hfvision.patch_embed.proj.weight, mgvision.patch_embed.proj.weight)
for hfblock, mgblock in zip(hfvision.blocks, mgvision.decoder.layers):
for hfblock, mgblock in zip(hfvision.blocks, mgvision.decoder.layers, strict=True):
# norm1 --> linear_qkv.norm
copied_numel += safe_copy(hfblock.norm1.weight, mgblock.self_attention.linear_qkv.layer_norm_weight)
# norm2 --> mlp.linear_fc1.norm
@ -227,7 +227,7 @@ def convert_checkpoint_from_transformers_to_megatron_qwen2_5_vl(hfmodel, mgmodel
mgllm = mgmodel.language_model
copied_numel = 0
copied_numel += safe_copy(hfllm.embed_tokens.weight, mgllm.embedding.word_embeddings.weight)
for mglayer, hflayer in zip(mgllm.decoder.layers, hfllm.layers):
for mglayer, hflayer in zip(mgllm.decoder.layers, hfllm.layers, strict=True):
copied_numel += safe_copy(hflayer.input_layernorm.weight, mglayer.self_attention.linear_qkv.layer_norm_weight)
q_proj_weight = hflayer.self_attn.q_proj.weight.view(num_query_groups, -1, head_dim, hidden_size)
@ -264,7 +264,7 @@ def convert_checkpoint_from_transformers_to_megatron_dpskv3(hf_model, model, hf_
numel: int = 0
numel += safe_copy(hf_model.model.embed_tokens.weight, model.embedding.word_embeddings.weight)
print(f"{numel=}")
for layer_idx, (layer, hf_layer) in enumerate(zip(model.decoder.layers, hf_model.model.layers)):
for layer_idx, (layer, hf_layer) in enumerate(zip(model.decoder.layers, hf_model.model.layers, strict=True)):
numel_cur: int = numel
numel += safe_copy(hf_layer.input_layernorm.weight, layer.input_layernorm.weight)

View File

@ -29,7 +29,7 @@ import argparse
import json
import os
import warnings
from typing import Any, Dict
from typing import Any
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, PretrainedConfig
@ -52,7 +52,7 @@ def check_output_path(output_path: str):
print(f"Output path '{output_path}' created.")
def check_configs(original_config: Dict[str, Any], new_config: Dict[str, Any]) -> bool:
def check_configs(original_config: dict[str, Any], new_config: dict[str, Any]) -> bool:
"""
Check if the original config and new config are compatible.
This is a placeholder function; actual implementation may vary based on requirements.

View File

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
import ray
from omegaconf import DictConfig
@ -23,7 +22,7 @@ from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker
def init_agent_loop_manager(config: DictConfig) -> Union[AgentLoopManager, RayWorkerGroup]:
def init_agent_loop_manager(config: DictConfig) -> AgentLoopManager | RayWorkerGroup:
# =========================== 1. Create hybrid ActorRollout workers ===========================
actor_rollout_cls = (
AsyncActorRolloutRefWorker if config.actor_rollout_ref.rollout.mode == "async" else ActorRolloutRefWorker

View File

@ -13,7 +13,7 @@
# limitations under the License.
import json
import os
from typing import Any, Tuple
from typing import Any
import numpy as np
import pytest
@ -119,7 +119,7 @@ class WeatherTool(BaseTool):
schema = get_json_schema(self.get_current_temperature)
return OpenAIFunctionToolSchema(**schema)
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:
try:
result = self.get_current_temperature(**parameters)
return json.dumps(result), 0, {}
@ -150,7 +150,7 @@ class WeatherToolWithData(BaseTool):
"unit": unit,
}
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:
try:
result = self.get_temperature_date(**parameters)
return json.dumps(result), 0, {}

View File

@ -96,7 +96,7 @@ def test_data_transfer():
# takes around 40 seconds
output_lst = ray.get(output_ref)
for input_data, output_data in zip(data_list, output_lst):
for input_data, output_data in zip(data_list, output_lst, strict=True):
for key in input_data.batch.keys():
assert torch.all(torch.eq(input_data.batch[key] + 1, output_data.batch[key])), (
input_data.batch[key],

View File

@ -80,7 +80,7 @@ def test_fused_workers():
print(y)
z = fused_wg.foo(0.1)
print(z)
for i, j in zip(y, z):
for i, j in zip(y, z, strict=True):
assert i == j
ray.shutdown()

View File

@ -21,7 +21,7 @@ This is heavily inspired from CanineTokenizer in transformers package.
import json
import os
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Union
from typing import Optional, Sequence
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
@ -86,7 +86,7 @@ class CharTokenizer(PreTrainedTokenizer):
def get_vocab(self):
return self._vocab_str_to_int
def _tokenize(self, text: str) -> List[str]:
def _tokenize(self, text: str) -> list[str]:
return list(text)
def _convert_token_to_id(self, token: str) -> int:
@ -99,8 +99,8 @@ class CharTokenizer(PreTrainedTokenizer):
return "".join(tokens)
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
) -> list[int]:
sep = [self.sep_token_id]
cls = [self.cls_token_id]
result = cls + token_ids_0 + sep
@ -110,10 +110,10 @@ class CharTokenizer(PreTrainedTokenizer):
def get_special_tokens_mask(
self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None,
token_ids_0: list[int],
token_ids_1: Optional[list[int]] = None,
already_has_special_tokens: bool = False,
) -> List[int]:
) -> list[int]:
if already_has_special_tokens:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0,
@ -126,7 +126,7 @@ class CharTokenizer(PreTrainedTokenizer):
result += ([0] * len(token_ids_1)) + [1]
return result
def get_config(self) -> Dict:
def get_config(self) -> dict:
return {
"char_ords": [ord(ch) for ch in self.characters],
"model_max_length": self.model_max_length,
@ -134,21 +134,21 @@ class CharTokenizer(PreTrainedTokenizer):
}
@classmethod
def from_config(cls, config: Dict):
def from_config(cls, config: dict):
cfg = {}
cfg["characters"] = [chr(i) for i in config["char_ords"]]
cfg["model_max_length"] = config["model_max_length"]
cfg["chat_template"] = config["chat_template"]
return cls(**cfg)
def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
def save_pretrained(self, save_directory: str | os.PathLike, **kwargs):
cfg_file = Path(save_directory) / "tokenizer_config.json"
cfg = self.get_config()
with open(cfg_file, "w") as f:
json.dump(cfg, f, indent=4)
@classmethod
def from_pretrained(cls, save_directory: Union[str, os.PathLike], **kwargs):
def from_pretrained(cls, save_directory: str | os.PathLike, **kwargs):
cfg_file = Path(save_directory) / "tokenizer_config.json"
with open(cfg_file) as f:
cfg = json.load(f)

View File

@ -20,7 +20,6 @@ Checks that every public function and class has proper docstring documentation.
import ast
import os
import sys
from typing import List, Tuple
class DocstringChecker(ast.NodeVisitor):
@ -28,7 +27,7 @@ class DocstringChecker(ast.NodeVisitor):
def __init__(self, filename: str):
self.filename = filename
self.missing_docstrings: List[Tuple[str, str, int]] = []
self.missing_docstrings: list[tuple[str, str, int]] = []
self.current_class = None
self.function_nesting_level = 0
@ -70,7 +69,7 @@ class DocstringChecker(ast.NodeVisitor):
return ast.get_docstring(node) is not None
def check_file_docstrings(filepath: str) -> List[Tuple[str, str, int]]:
def check_file_docstrings(filepath: str) -> list[tuple[str, str, int]]:
"""Check docstrings in a single file."""
try:
with open(filepath, encoding="utf-8") as f:

View File

@ -22,23 +22,22 @@ import ast
import linecache
import subprocess
from pathlib import Path
from typing import List, Set, Tuple
def get_changed_files() -> List[Path]:
def get_changed_files() -> list[Path]:
result = subprocess.run(
["git", "diff", "--name-only", "--diff-filter=AM", "origin/main...HEAD"], stdout=subprocess.PIPE, text=True
)
return [Path(f) for f in result.stdout.splitlines() if f.endswith(".py")]
def get_changed_lines(file_path: Path) -> Set[int]:
def get_changed_lines(file_path: Path) -> set[int]:
result = subprocess.run(
["git", "diff", "-U0", "origin/main...HEAD", "--", str(file_path)],
stdout=subprocess.PIPE,
text=True,
)
lines: Set[int] = set()
lines: set[int] = set()
for line in result.stdout.splitlines():
if line.startswith("@@"):
for part in line.split():
@ -84,19 +83,19 @@ def has_type_annotations(node: ast.AST, debug: bool = False) -> int:
def check_file(
file_path: Path, changed_lines: Set[int], debug: bool = False
) -> Tuple[int, int, List[Tuple[Path, int, str]], List[Tuple[Path, int, str]]]:
file_path: Path, changed_lines: set[int], debug: bool = False
) -> tuple[int, int, list[tuple[Path, int, str]], list[tuple[Path, int, str]]]:
with open(file_path) as f:
source: str = f.read()
tree = ast.parse(source, filename=str(file_path))
annotated = 0
total = 0
warning_lines: List[Tuple[Path, int, str]] = []
failure_lines: List[Tuple[Path, int, str]] = []
warning_lines: list[tuple[Path, int, str]] = []
failure_lines: list[tuple[Path, int, str]] = []
for node in ast.walk(tree):
if hasattr(node, "lineno") and node.lineno in changed_lines:
if isinstance(node, (ast.FunctionDef, ast.Assign, ast.AnnAssign)):
if isinstance(node, ast.FunctionDef | ast.Assign | ast.AnnAssign):
total += 1
result = has_type_annotations(node, debug)
if result == CHECK_SUCCESS or result == CHECK_WARNING:
@ -128,8 +127,8 @@ def main() -> None:
total_changed = 0
total_annotated = 0
all_warnings: List[Tuple[Path, int, str]] = []
all_failures: List[Tuple[Path, int, str]] = []
all_warnings: list[tuple[Path, int, str]] = []
all_failures: list[tuple[Path, int, str]] = []
target_files = [args.target_file] if args.target_file is not None else get_changed_files()
for fpath in target_files:

View File

@ -57,7 +57,7 @@ def test_memory_buffers():
change_ratio = (a - a_before) / a_before
assert change_ratio < 0.01, f"make sure the allocated change is less than 1%, Got {change_ratio}"
for (name1, param1), (name2, param2) in zip(model.named_parameters(), model_copy.named_parameters()):
for (name1, param1), (name2, param2) in zip(model.named_parameters(), model_copy.named_parameters(), strict=True):
assert name1 == name2
assert torch.eq(param1.data, param2.data).all(), f"{param1.data}, {param2.data}, {name1}"

View File

@ -79,7 +79,7 @@ def test_tensor_dict_make_iterator():
for data in data_iter_2:
data_list_2.append(data)
for data1, data2 in zip(data_list_1, data_list_2):
for data1, data2 in zip(data_list_1, data_list_2, strict=True):
assert isinstance(data1, DataProto)
assert isinstance(data2, DataProto)
result = torch.all(torch.eq(data1.batch["obs"], data2.batch["obs"]))

View File

@ -14,7 +14,7 @@
# Unit Tests for `initialize_tools_from_config`
import json
import os
from typing import Any, Tuple
from typing import Any
import pytest
from transformers.utils import get_json_schema
@ -44,7 +44,7 @@ class WeatherToolForTest(BaseTool):
schema = get_json_schema(self.get_current_temperature)
return OpenAIFunctionToolSchema(**schema)
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:
try:
result = self.get_current_temperature(**parameters)
return json.dumps(result), 0, {}
@ -75,7 +75,7 @@ class WeatherToolWithDataForTest(BaseTool):
"unit": unit,
}
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:
try:
result = self.get_temperature_date(**parameters)
return json.dumps(result), 0, {}

View File

@ -57,7 +57,7 @@ class TestConfigComparison(unittest.TestCase):
len(legacy_config),
f"List lengths differ at {path}: current={len(current_config)}, legacy={len(legacy_config)}",
)
for i, (current_item, legacy_item) in enumerate(zip(current_config, legacy_config)):
for i, (current_item, legacy_item) in enumerate(zip(current_config, legacy_config, strict=True)):
self._compare_configs_recursively(current_item, legacy_item, f"{path}[{i}]")
else:
self.assertEqual(

View File

@ -120,7 +120,7 @@ def test_prime_code():
Test PRIME code sandbox.
"""
data_source = "codecontests"
for completion, ground_truth, score_ in zip(prime_code_answers, prime_code_gts, prime_code_scores):
for completion, ground_truth, score_ in zip(prime_code_answers, prime_code_gts, prime_code_scores, strict=True):
score = default_compute_score(data_source, completion, ground_truth)
assert float(score) == score_
@ -136,7 +136,7 @@ def test_prime_code_sandbox_fusion():
sandbox_fusion_url = os.environ.get("SANDBOX_FUSION_URL")
# Removed the previous 'if not sandbox_url' check block
for completion, ground_truth, score_ in zip(prime_code_answers, prime_code_gts, prime_code_scores):
for completion, ground_truth, score_ in zip(prime_code_answers, prime_code_gts, prime_code_scores, strict=True):
score = default_compute_score(
data_source, completion, ground_truth, extra_info={"sandbox_fusion_url": sandbox_fusion_url}
) # <-- Use the URL obtained from the environment variable
@ -180,6 +180,6 @@ def test_check_correctness():
def test_prime_math():
data_source = "numina_aops_forum"
for completion, ground_truth in zip(prime_math_answers, prime_math_gts):
for completion, ground_truth in zip(prime_math_answers, prime_math_gts, strict=True):
score = default_compute_score(data_source, completion, ground_truth)
assert float(score) == 1.0

View File

@ -146,7 +146,9 @@ def test_flops_counter(config_type: str):
test_config = CONFIG[config_type]
config = Config(test_config["config"])
flops_counter = FlopsCounter(config)
for batch_seqlens, expected_flops in zip(test_config["batch_seqlens_tuple"], test_config["expected_flops_tuple"]):
for batch_seqlens, expected_flops in zip(
test_config["batch_seqlens_tuple"], test_config["expected_flops_tuple"], strict=True
):
# set delta time to 1 to get the flops
counted_flops, _ = flops_counter.estimate_flops(batch_seqlens, 1)
print(f"Expect flops for {test_config['config']} is {expected_flops}, but get {counted_flops}")

View File

@ -30,7 +30,6 @@
# limitations under the License.
import os
import typing
import torch
@ -48,7 +47,7 @@ MAX_TEST_CASES = os.environ.get("MAX_TEST_CASES", 5)
def run_torch_entropy(
hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, temperature: float, reduction="none"
) -> typing.List[torch.Tensor]:
) -> list[torch.Tensor]:
hidden = hidden.squeeze(0).to(torch.float32)
weight = weight.transpose(0, 1).to(torch.float32)
logits = torch.matmul(hidden, weight) # [num_tokens, vocab_size]
@ -67,7 +66,7 @@ def run_verl_original_entropy(
weight: torch.Tensor,
labels: torch.Tensor,
temperature: float,
) -> typing.List[torch.Tensor]:
) -> list[torch.Tensor]:
hidden = hidden.squeeze(0).to(torch.float32)
weight = weight.transpose(0, 1).to(torch.float32)
logits = torch.matmul(hidden, weight) # [num_tokens, vocab_size]

View File

@ -30,7 +30,6 @@
# limitations under the License.
import os
import typing
import torch
import torch.distributed as dist
@ -57,7 +56,7 @@ LOW_MEMORY_DIV_FACTOR = os.environ.get("LOW_MEMORY_DIV_FACTOR", 16)
def run_torch_entropy(
hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, temperature: float, reduction="none"
) -> typing.List[torch.Tensor]:
) -> list[torch.Tensor]:
# [num_tokens, vocab_size]
if len(hidden.shape) > 2:
hidden = hidden.view(-1, hidden.shape[-1]) # [num_tokens, hidden_size]

View File

@ -106,9 +106,12 @@ class TestNsightSystemsProfiler(unittest.TestCase):
def test_func(self, *args, **kwargs):
return "result"
with patch("torch.cuda.profiler.start") as mock_start, patch("torch.cuda.profiler.stop") as mock_stop, patch(
"verl.utils.profiler.nvtx_profile.mark_start_range"
) as mock_start_range, patch("verl.utils.profiler.nvtx_profile.mark_end_range") as mock_end_range:
with (
patch("torch.cuda.profiler.start") as mock_start,
patch("torch.cuda.profiler.stop") as mock_stop,
patch("verl.utils.profiler.nvtx_profile.mark_start_range") as mock_start_range,
patch("verl.utils.profiler.nvtx_profile.mark_end_range") as mock_end_range,
):
result = test_func(mock_self)
self.assertEqual(result, "result")
mock_start_range.assert_called_once()
@ -127,9 +130,12 @@ class TestNsightSystemsProfiler(unittest.TestCase):
def test_func(self, *args, **kwargs):
return "result"
with patch("torch.cuda.profiler.start") as mock_start, patch("torch.cuda.profiler.stop") as mock_stop, patch(
"verl.utils.profiler.nvtx_profile.mark_start_range"
) as mock_start_range, patch("verl.utils.profiler.nvtx_profile.mark_end_range") as mock_end_range:
with (
patch("torch.cuda.profiler.start") as mock_start,
patch("torch.cuda.profiler.stop") as mock_stop,
patch("verl.utils.profiler.nvtx_profile.mark_start_range") as mock_start_range,
patch("verl.utils.profiler.nvtx_profile.mark_end_range") as mock_end_range,
):
result = test_func(mock_self)
self.assertEqual(result, "result")
mock_start_range.assert_called_once()

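The test changes above reformat long `with` statements into parenthesized context managers, a form that is only officially supported from Python 3.10 onward. A minimal sketch with stand-in context managers:

```python
from contextlib import nullcontext

# Pre-3.10: all managers crammed onto one logical line
with nullcontext("start") as a, nullcontext("stop") as b:
    print(a, b)

# 3.10+: parentheses allow one manager per line, as in the patched tests
with (
    nullcontext("start") as a,
    nullcontext("stop") as b,
):
    print(a, b)
```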
View File

@ -32,7 +32,6 @@ packages:
import os
import time
from typing import Tuple, Union
import ray
from omegaconf import DictConfig
@ -72,7 +71,7 @@ def init_config(n_gpus_per_node) -> DictConfig:
return config
def initialize(config, backend) -> Tuple[Union[AgentLoopManager, RayWorkerGroup], StatefulDataLoader]:
def initialize(config, backend) -> tuple[AgentLoopManager | RayWorkerGroup, StatefulDataLoader]:
env_vars = {
"NCCL_DEBUG": "WARN",
"VLLM_USE_V1": "1",

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Any, Tuple
from typing import Any
import numpy as np
import pytest
@ -120,7 +120,7 @@ class WeatherTool(BaseTool):
schema = get_json_schema(self.get_current_temperature)
return OpenAIFunctionToolSchema(**schema)
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:
try:
result = self.get_current_temperature(**parameters)
return json.dumps(result), 0, {}
@ -151,7 +151,7 @@ class WeatherToolWithData(BaseTool):
"unit": unit,
}
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:
try:
result = self.get_temperature_date(**parameters)
return json.dumps(result), 0, {}

View File

@ -55,7 +55,7 @@ def are_lists_similar(a, b):
total_length = 0
total_diff = 0
for s1, s2 in zip(a, b):
for s1, s2 in zip(a, b, strict=True):
max_len = max(len(s1), len(s2))
total_length += max_len
diff = levenshtein(s1, s2)

View File

@ -19,7 +19,7 @@ import socket
import sys
import tempfile
from contextlib import asynccontextmanager
from typing import Any, Dict, List
from typing import Any
import fastapi
import numpy as np
@ -128,7 +128,7 @@ class CustomCompletionCallback(ToolCompletionCallback):
# TODO: support asyncio executor
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max(32, os.cpu_count() * 5))
async def sandbox_code_execution(self, code: str) -> Dict[str, Any]:
async def sandbox_code_execution(self, code: str) -> dict[str, Any]:
loop = asyncio.get_running_loop()
result_status, metadata = await loop.run_in_executor(
self.executor,
@ -153,7 +153,7 @@ class CustomCompletionCallback(ToolCompletionCallback):
}
return extra
async def __call__(self, messages: List[Dict[str, str]], completions: ChatCompletion, info: Dict[str, Any]):
async def __call__(self, messages: list[dict[str, str]], completions: ChatCompletion, info: dict[str, Any]):
role, content, finish_reason = (
completions.choices[0].message.role,
completions.choices[0].message.content,

View File

@ -262,10 +262,11 @@ class TestRolloutWithMCPSearchTools:
},
}
]
with patch.object(MCPClientManager, "fetch_tool_schemas", return_value=tool_schema), patch.object(
SGLangRollout, "_init_distributed_env", return_value=None
), patch.object(SGLangRollout, "_init_inference_engine", return_value=None), patch.object(
SGLangRollout, "_init_sampling_params", return_value=None
with (
patch.object(MCPClientManager, "fetch_tool_schemas", return_value=tool_schema),
patch.object(SGLangRollout, "_init_distributed_env", return_value=None),
patch.object(SGLangRollout, "_init_inference_engine", return_value=None),
patch.object(SGLangRollout, "_init_sampling_params", return_value=None),
):
rollout = SGLangRollout(
actor_module="",
@ -355,7 +356,7 @@ class TestRolloutWithMCPSearchTools:
mock_rollout._handle_engine_call = MagicMock()
futures = [asyncio.Future() for i in expect_turn_array]
for idx, (i, turn) in enumerate(zip(futures, expect_turn_array)):
for idx, (i, turn) in enumerate(zip(futures, expect_turn_array, strict=True)):
i.set_result(
{
"text": turn,
@ -420,7 +421,7 @@ class TestRolloutWithMCPSearchTools:
req_list.append(MagicMock(wraps=tmp_req, spec=AsyncRolloutRequest))
futures = [asyncio.Future() for _ in expect_turn_array]
for idx, (fut, turn) in enumerate(zip(futures, expect_turn_array)):
for idx, (fut, turn) in enumerate(zip(futures, expect_turn_array, strict=True)):
fut.set_result(
{
"text": turn,

View File

@ -166,9 +166,11 @@ class TestRolloutWithSearchTools:
@pytest.fixture
def mock_rollout(self, search_rollout_config, qwen_tokenizer, qwen_model_config):
"""Mock the rollout instance with sampling_params initialized."""
with patch.object(SGLangRollout, "_init_distributed_env", return_value=None), patch.object(
SGLangRollout, "_init_inference_engine", return_value=None
), patch.object(SGLangRollout, "_init_sampling_params", return_value=None):
with (
patch.object(SGLangRollout, "_init_distributed_env", return_value=None),
patch.object(SGLangRollout, "_init_inference_engine", return_value=None),
patch.object(SGLangRollout, "_init_sampling_params", return_value=None),
):
rollout = SGLangRollout(
actor_module="",
config=search_rollout_config,
@ -308,7 +310,7 @@ class TestRolloutWithSearchTools:
mock_rollout._handle_engine_call = MagicMock()
futures = [asyncio.Future() for i in expect_turn_array]
for idx, (i, turn) in enumerate(zip(futures, expect_turn_array)):
for idx, (i, turn) in enumerate(zip(futures, expect_turn_array, strict=True)):
i.set_result(
{
"text": turn,
@ -376,7 +378,7 @@ class TestRolloutWithSearchTools:
req_list.append(MagicMock(wraps=tmp_req, spec=AsyncRolloutRequest))
futures = [asyncio.Future() for _ in expect_turn_array]
for idx, (fut, turn) in enumerate(zip(futures, expect_turn_array)):
for idx, (fut, turn) in enumerate(zip(futures, expect_turn_array, strict=True)):
fut.set_result(
{
"text": turn,

View File

@ -117,9 +117,11 @@ class TestSGLangMultiInteraction:
try:
# Mock SGLang engine and initialization methods like the reference test
with patch.object(SGLangRollout, "_init_distributed_env", return_value=None), patch.object(
SGLangRollout, "_init_inference_engine", return_value=None
), patch.object(SGLangRollout, "_init_sampling_params", return_value=None):
with (
patch.object(SGLangRollout, "_init_distributed_env", return_value=None),
patch.object(SGLangRollout, "_init_inference_engine", return_value=None),
patch.object(SGLangRollout, "_init_sampling_params", return_value=None),
):
# Create a real tokenizer like the reference test
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token
@ -172,9 +174,11 @@ class TestSGLangMultiInteraction:
config, temp_config_path = create_mock_config_with_multi_interactions()
try:
with patch.object(SGLangRollout, "_init_distributed_env", return_value=None), patch.object(
SGLangRollout, "_init_inference_engine", return_value=None
), patch.object(SGLangRollout, "_init_sampling_params", return_value=None):
with (
patch.object(SGLangRollout, "_init_distributed_env", return_value=None),
patch.object(SGLangRollout, "_init_inference_engine", return_value=None),
patch.object(SGLangRollout, "_init_sampling_params", return_value=None),
):
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token
@ -282,9 +286,11 @@ class TestSGLangMultiInteraction:
)
try:
with patch.object(SGLangRollout, "_init_distributed_env", return_value=None), patch.object(
SGLangRollout, "_init_inference_engine", return_value=None
), patch.object(SGLangRollout, "_init_sampling_params", return_value=None):
with (
patch.object(SGLangRollout, "_init_distributed_env", return_value=None),
patch.object(SGLangRollout, "_init_inference_engine", return_value=None),
patch.object(SGLangRollout, "_init_sampling_params", return_value=None),
):
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token
@ -321,9 +327,11 @@ class TestSGLangMultiInteraction:
config, temp_config_path = create_mock_config_with_multi_interactions()
try:
with patch.object(SGLangRollout, "_init_distributed_env", return_value=None), patch.object(
SGLangRollout, "_init_inference_engine", return_value=None
), patch.object(SGLangRollout, "_init_sampling_params", return_value=None):
with (
patch.object(SGLangRollout, "_init_distributed_env", return_value=None),
patch.object(SGLangRollout, "_init_inference_engine", return_value=None),
patch.object(SGLangRollout, "_init_sampling_params", return_value=None),
):
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token
@ -388,9 +396,11 @@ class TestSGLangMultiInteraction:
}
)
with patch.object(SGLangRollout, "_init_distributed_env", return_value=None), patch.object(
SGLangRollout, "_init_inference_engine", return_value=None
), patch.object(SGLangRollout, "_init_sampling_params", return_value=None):
with (
patch.object(SGLangRollout, "_init_distributed_env", return_value=None),
patch.object(SGLangRollout, "_init_inference_engine", return_value=None),
patch.object(SGLangRollout, "_init_sampling_params", return_value=None),
):
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token

View File

@ -43,7 +43,7 @@ def are_lists_similar(a, b, threshold=10):
return False
total_length = 0
total_diff = 0
for s1, s2 in zip(a, b):
for s1, s2 in zip(a, b, strict=True):
max_len = max(len(s1), len(s2))
total_length += max_len
total_diff += levenshtein(s1, s2)

View File

@ -17,7 +17,7 @@ import logging
import os
import random
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Type
from typing import Any
import numpy as np
import ray
@ -46,7 +46,7 @@ class AsyncLLMServerManager:
- Sticky session: send multi-turn chat completions to same server for automatic prefix caching
"""
def __init__(self, config: DictConfig, server_handles: List[ray.actor.ActorHandle], max_cache_size: int = 10000):
def __init__(self, config: DictConfig, server_handles: list[ray.actor.ActorHandle], max_cache_size: int = 10000):
"""Initialize the AsyncLLMServerManager.
Args:
@ -81,9 +81,9 @@ class AsyncLLMServerManager:
self,
request_id,
*,
prompt_ids: List[int],
sampling_params: Dict[str, Any],
) -> List[int]:
prompt_ids: list[int],
sampling_params: dict[str, Any],
) -> list[int]:
"""Generate tokens from prompt ids.
Args:
@ -113,9 +113,9 @@ class AgentLoopMetrics(BaseModel):
class AgentLoopOutput(BaseModel):
"""Agent loop output."""
prompt_ids: List[int]
response_ids: List[int]
response_mask: List[int]
prompt_ids: list[int]
response_ids: list[int]
response_mask: list[int]
num_turns: int = 0
metrics: AgentLoopMetrics
@ -148,7 +148,7 @@ class AgentLoopBase(ABC):
cls._class_initialized = True
@abstractmethod
async def run(self, messages: List[Dict[str, Any]], sampling_params: Dict[str, Any]) -> AgentLoopOutput:
async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput:
"""Run agent loop to interact with LLM server and environment.
Args:
@ -165,7 +165,7 @@ class AgentLoopBase(ABC):
class AgentLoopWorker:
"""Agent loop worker takes a batch of messages and run each message in an agent loop."""
def __init__(self, config: DictConfig, server_handles: List[ray.actor.ActorHandle]):
def __init__(self, config: DictConfig, server_handles: list[ray.actor.ActorHandle]):
"""Initialize agent loop manager.
Args:
@ -236,7 +236,7 @@ class AgentLoopWorker:
trajectory_info = await get_trajectory_info(batch.meta_info.get("global_steps", -1), index)
for agent_name, messages, trajectory in zip(agent_names, raw_prompts, trajectory_info):
for agent_name, messages, trajectory in zip(agent_names, raw_prompts, trajectory_info, strict=True):
tasks.append(
asyncio.create_task(self._run_agent_loop(agent_name, messages.tolist(), sampling_params, trajectory))
)
@ -248,9 +248,9 @@ class AgentLoopWorker:
async def _run_agent_loop(
self,
agent_name: str,
messages: List[Dict[str, Any]],
sampling_params: Dict[str, Any],
trajectory: Dict[str, Any],
messages: list[dict[str, Any]],
sampling_params: dict[str, Any],
trajectory: dict[str, Any],
) -> AgentLoopOutput:
with rollout_trace_attr(
step=trajectory["step"], sample_index=trajectory["sample_index"], rollout_n=trajectory["rollout_n"]
@ -260,7 +260,7 @@ class AgentLoopWorker:
output = await agent_loop.run(messages, sampling_params)
return output
def get_agent_loop_class(self, agent_name: str) -> Type[AgentLoopBase]:
def get_agent_loop_class(self, agent_name: str) -> type[AgentLoopBase]:
"""Get the appropriate agent loop class based on agent name.
Factory method that returns the correct agent loop class implementation
@ -285,7 +285,7 @@ class AgentLoopWorker:
return ToolAgentLoop
raise ValueError(f"Unknown agent_name: {agent_name}")
def _postprocess(self, inputs: List[AgentLoopOutput]) -> DataProto:
def _postprocess(self, inputs: list[AgentLoopOutput]) -> DataProto:
# NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py
# prompts: left pad
# responses: right pad
@ -452,7 +452,10 @@ class AgentLoopManager:
self.wake_up()
chunkes = prompts.chunk(len(self.agent_loop_workers))
outputs = ray.get(
[worker.generate_sequences.remote(chunk) for worker, chunk in zip(self.agent_loop_workers, chunkes)]
[
worker.generate_sequences.remote(chunk)
for worker, chunk in zip(self.agent_loop_workers, chunkes, strict=True)
]
)
output = DataProto.concat(outputs)
if self.config.actor_rollout_ref.rollout.free_cache_engine:
@ -465,7 +468,7 @@ class AgentLoopManager:
output.meta_info = {"timing": timing}
return output
def _performance_metrics(self, metrics: List[List[Dict[str, str]]], output: DataProto) -> Dict[str, float]:
def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: DataProto) -> dict[str, float]:
timing = {}
t_generate_sequences = np.array([metric["generate_sequences"] for chunk in metrics for metric in chunk])
t_tool_calls = np.array([metric["tool_calls"] for chunk in metrics for metric in chunk])

View File

@ -13,7 +13,7 @@
# limitations under the License.
import logging
import os
from typing import Any, Dict, List
from typing import Any
from uuid import uuid4
from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput
@ -31,7 +31,7 @@ class SingleTurnAgentLoop(AgentLoopBase):
self.prompt_length = config.actor_rollout_ref.rollout.prompt_length
self.response_length = config.actor_rollout_ref.rollout.response_length
async def run(self, messages: List[Dict[str, Any]], sampling_params: Dict[str, Any]) -> AgentLoopOutput:
async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput:
metrics = {}
request_id = uuid4().hex
prompt_ids = await self.loop.run_in_executor(

View File

@ -16,7 +16,7 @@ import json
import logging
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from typing import Any
from uuid import uuid4
import regex as re
@ -46,7 +46,7 @@ class FunctionCall(BaseModel):
class ToolParser(ABC):
@abstractmethod
async def extract_tool_calls(self, responses_ids: List[int]) -> List[FunctionCall]:
async def extract_tool_calls(self, responses_ids: list[int]) -> list[FunctionCall]:
"""Extract tool calls from the responses.
Args:
@ -69,7 +69,7 @@ class HermesToolParser(ToolParser):
self.tool_call_regex = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)
@rollout_trace_op
async def extract_tool_calls(self, responses_ids: List[int]) -> List[FunctionCall]:
async def extract_tool_calls(self, responses_ids: list[int]) -> list[FunctionCall]:
loop = asyncio.get_running_loop()
text = await loop.run_in_executor(None, self.tokenizer.decode, responses_ids)
if self.tool_call_start_token not in text or self.tool_call_end_token not in text:
@ -117,7 +117,7 @@ class ToolAgentLoop(AgentLoopBase):
cls.system_prompt = tokenizer.apply_chat_template([{}], add_generation_prompt=False, tokenize=True)
@rollout_trace_op
async def run(self, messages: List[Dict[str, Any]], sampling_params: Dict[str, Any]) -> AgentLoopOutput:
async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput:
metrics = {}
request_id = uuid4().hex
prompt_ids = await self.loop.run_in_executor(
@ -194,7 +194,7 @@ class ToolAgentLoop(AgentLoopBase):
)
return output
async def _call_tool(self, tool_call: FunctionCall) -> Dict[str, str]:
async def _call_tool(self, tool_call: FunctionCall) -> dict[str, str]:
"""Call tool and return tool response."""
tool, instance_id = None, None
try:

View File

@ -21,7 +21,7 @@ on rollout data.
import logging
from abc import ABC, abstractmethod
from typing import List, Optional, Union
from typing import Optional
import datasets
from omegaconf import DictConfig
@ -73,7 +73,7 @@ class DynamicGenDataset(RLHFDataset):
def __init__(
self,
data_files: Union[str, List[str]],
data_files: str | list[str],
tokenizer: PreTrainedTokenizer,
config: DictConfig,
processor: Optional[ProcessorMixin] = None,

View File

@ -13,12 +13,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Optional
from uuid import uuid4
class BaseInteraction:
def __init__(self, config: Dict[str, Any]):
def __init__(self, config: dict[str, Any]):
self.config = config
self.name: str = config.get("name", "interaction_agent") # More general agent default role name
@ -37,8 +37,8 @@ class BaseInteraction:
return instance_id
async def generate_response(
self, instance_id: str, messages: List[Dict[str, Any]], **kwargs
) -> Tuple[bool, str, float, Dict[str, Any]]: # More clear response generation method
self, instance_id: str, messages: list[dict[str, Any]], **kwargs
) -> tuple[bool, str, float, dict[str, Any]]: # More clear response generation method
"""
Generates a response for the current turn of interaction.
Returns a tuple containing:
@ -50,7 +50,7 @@ class BaseInteraction:
should_terminate_sequence: bool = False # if True, end rollout
response_content: str = "Your current result seems acceptable."
current_turn_score: float = 0.8
additional_data: Dict[str, Any] = {}
additional_data: dict[str, Any] = {}
return should_terminate_sequence, response_content, current_turn_score, additional_data
async def calculate_score(self) -> float: # More clear score calculation method

View File

@ -16,7 +16,7 @@
import logging
import os
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Optional
from uuid import uuid4
from verl.utils.reward_score import gsm8k
@ -53,8 +53,8 @@ class Gsm8kInteraction(BaseInteraction):
return instance_id
async def generate_response(
self, instance_id: str, messages: List[Dict[str, Any]], **kwargs
) -> Tuple[bool, str, float, dict]:
self, instance_id: str, messages: list[dict[str, Any]], **kwargs
) -> tuple[bool, str, float, dict]:
content = ""
for i in range(len(messages) - 1, -1, -1):
item = messages[i]

View File

@ -16,7 +16,7 @@ import os
import warnings
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, ContextManager, Dict, List
from typing import Any, Callable, ContextManager
import torch
from accelerate import init_empty_weights
@ -140,7 +140,7 @@ class MegatronModelMerger(BaseModelMerger):
"output_layer": "lm_head",
}
def _load_state_dicts(self, model_ckpt_path: str) -> Dict[str, Any]:
def _load_state_dicts(self, model_ckpt_path: str) -> dict[str, Any]:
"""_summary_
Use Megatron dist_checkpointing to load the model state dicts from the checkpoint directory.
@ -270,7 +270,7 @@ class MegatronModelMerger(BaseModelMerger):
else:
return [tensor]
def _merge_state_dicts(self, model_state_dict_list: List[Dict[str, Any]]) -> dict[str, torch.Tensor]:
def _merge_state_dicts(self, model_state_dict_list: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
state_dict = {}
layers_cum = 0
@ -306,7 +306,7 @@ class MegatronModelMerger(BaseModelMerger):
state_dict[hf_name] = split_tensor[0]
elif len(split_tensor) == 3:
# split qkv
for n, d in zip(["q", "k", "v"], split_tensor):
for n, d in zip(["q", "k", "v"], split_tensor, strict=True):
state_dict[hf_name.replace("qkv", n)] = d
elif len(split_tensor) == 2:
# split gate up

View File

@ -86,7 +86,7 @@ def load_state_dict_to_megatron_llama(
assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"
if not isinstance(wrapped_models, (list, tuple)):
if not isinstance(wrapped_models, list | tuple):
wrapped_models = list(wrapped_models)
assert len(wrapped_models) == virtual_pp_size

View File

@ -86,7 +86,7 @@ def load_state_dict_to_megatron_llama(
assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"
if not isinstance(wrapped_models, (list, tuple)):
if not isinstance(wrapped_models, list | tuple):
wrapped_models = list(wrapped_models)
assert len(wrapped_models) == virtual_pp_size

View File

@ -100,7 +100,7 @@ def merge_megatron_ckpt_llama(wrapped_models, config, dtype, is_value_model=Fals
assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"
if not isinstance(wrapped_models, (list, tuple)):
if not isinstance(wrapped_models, list | tuple):
wrapped_models = list(wrapped_models)
assert len(wrapped_models) == virtual_pp_size

View File

@ -19,7 +19,7 @@
# limitations under the License.
import math
from typing import Optional, Tuple
from typing import Optional
import torch
import torch.nn.functional as F
@ -290,7 +290,7 @@ class ParallelLlamaAttention(nn.Module):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
qkv = self.qkv_proj(hidden_states)[0]
query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)

View File

@ -18,7 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
from typing import Optional
import torch
from megatron.core import ModelParallelConfig
@ -49,7 +49,7 @@ class ParallelLlamaDecoderLayer(nn.Module):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
@ -119,7 +119,7 @@ class ParallelLlamaDecoderLayerRmPad(nn.Module):
indices: torch.Tensor = None,
cu_seqlens: int = None,
max_seqlen_in_batch: int = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states # (total_nnz // sp, 1, hidden_size)
hidden_states = self.input_layernorm(hidden_states)

View File

@ -19,7 +19,7 @@
# limitations under the License.
"""PyTorch LLaMA model with Megatron-style acceleration."""
from typing import Optional, Tuple, Union
from typing import Optional
import torch
import torch.utils.checkpoint
@ -125,7 +125,7 @@ class ParallelLlamaModel(nn.Module):
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
) -> tuple | BaseModelOutputWithPast:
"""
Args:
@ -184,7 +184,7 @@ class ParallelLlamaForCausalLM(nn.Module):
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
) -> tuple | CausalLMOutputWithPast:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -255,7 +255,7 @@ class ParallelLlamaModelRmPad(nn.Module):
indices: torch.Tensor = None,
cu_seqlens: int = None,
max_seqlen_in_batch: int = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
) -> tuple | BaseModelOutputWithPast:
"""
Args:
@ -325,7 +325,7 @@ class ParallelLlamaForCausalLMRmPad(nn.Module):
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
) -> tuple | CausalLMOutputWithPast:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -404,7 +404,7 @@ class ParallelLlamaForValueRmPad(ParallelLlamaForCausalLMRmPad):
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
) -> tuple | CausalLMOutputWithPast:
output = super().forward(input_ids, attention_mask, position_ids)
output.logits = torch.squeeze(output.logits, dim=-1)
return output
@ -487,7 +487,7 @@ class ParallelLlamaModelRmPadPP(nn.Module):
indices: torch.Tensor = None,
cu_seqlens: int = None,
max_seqlen_in_batch: int = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
) -> tuple | BaseModelOutputWithPast:
"""
Args:
@ -595,7 +595,7 @@ class ParallelLlamaForCausalLMRmPadPP(nn.Module):
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
) -> tuple | CausalLMOutputWithPast:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -679,7 +679,7 @@ class ParallelLlamaForValueRmPadPP(ParallelLlamaForCausalLMRmPadPP):
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
) -> tuple | CausalLMOutputWithPast:
output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
if self.post_process:
output.logits = torch.squeeze(output.logits, dim=-1)

View File

@ -87,7 +87,7 @@ def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, par
assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"
if not isinstance(wrapped_models, (list, tuple)):
if not isinstance(wrapped_models, list | tuple):
wrapped_models = list(wrapped_models)
assert len(wrapped_models) == virtual_pp_size

View File

@ -17,7 +17,7 @@ Registry module for model architecture components.
"""
from enum import Enum
from typing import Callable, Dict, Type
from typing import Callable
import torch
import torch.nn as nn
@ -73,7 +73,7 @@ class SupportedModel(Enum):
# Registry for model configuration converters
MODEL_CONFIG_CONVERTER_REGISTRY: Dict[SupportedModel, Callable[[PretrainedConfig, torch.dtype], TransformerConfig]] = {
MODEL_CONFIG_CONVERTER_REGISTRY: dict[SupportedModel, Callable[[PretrainedConfig, torch.dtype], TransformerConfig]] = {
SupportedModel.LLAMA: hf_to_mcore_config_dense,
SupportedModel.QWEN2: hf_to_mcore_config_dense,
SupportedModel.QWEN2_MOE: hf_to_mcore_config_qwen2moe,
@ -87,7 +87,7 @@ MODEL_CONFIG_CONVERTER_REGISTRY: Dict[SupportedModel, Callable[[PretrainedConfig
}
# Registry for model initializers
MODEL_INITIALIZER_REGISTRY: Dict[SupportedModel, Type[BaseModelInitializer]] = {
MODEL_INITIALIZER_REGISTRY: dict[SupportedModel, type[BaseModelInitializer]] = {
SupportedModel.LLAMA: DenseModel,
SupportedModel.QWEN2: DenseModel,
SupportedModel.QWEN2_MOE: Qwen2MoEModel,
@ -101,7 +101,7 @@ MODEL_INITIALIZER_REGISTRY: Dict[SupportedModel, Type[BaseModelInitializer]] = {
}
# Registry for model forward functions
MODEL_FORWARD_REGISTRY: Dict[SupportedModel, Callable] = {
MODEL_FORWARD_REGISTRY: dict[SupportedModel, Callable] = {
SupportedModel.LLAMA: gptmodel_forward,
SupportedModel.QWEN2: gptmodel_forward,
SupportedModel.QWEN2_MOE: gptmodel_forward,
@ -116,7 +116,7 @@ MODEL_FORWARD_REGISTRY: Dict[SupportedModel, Callable] = {
}
# Registry for model forward functions
MODEL_FORWARD_FUSED_REGISTRY: Dict[SupportedModel, Callable] = {
MODEL_FORWARD_FUSED_REGISTRY: dict[SupportedModel, Callable] = {
SupportedModel.LLAMA: fused_forward_gptmodel,
SupportedModel.QWEN2: fused_forward_gptmodel,
SupportedModel.QWEN2_MOE: fused_forward_gptmodel,
@ -131,7 +131,7 @@ MODEL_FORWARD_FUSED_REGISTRY: Dict[SupportedModel, Callable] = {
}
# Registry for model weight converters
MODEL_WEIGHT_CONVERTER_REGISTRY: Dict[SupportedModel, Type] = {
MODEL_WEIGHT_CONVERTER_REGISTRY: dict[SupportedModel, type] = {
SupportedModel.LLAMA: McoreToHFWeightConverterDense,
SupportedModel.QWEN2: McoreToHFWeightConverterDense,
SupportedModel.QWEN2_MOE: McoreToHFWeightConverterQwen2Moe,
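
The registry hunks above change only the annotations (`Dict`/`Type` to `dict`/`type`); the mappings themselves are untouched. A toy registry under the same typing conventions — the enum members and converter are invented and not part of verl:

```python
# Toy config-converter registry in the MODEL_*_REGISTRY style: a plain dict
# annotated with built-in generics and a Callable value type. Names invented.
from collections.abc import Callable
from enum import Enum


class ToyModel(Enum):
    LLAMA = "LlamaForCausalLM"
    QWEN2 = "Qwen2ForCausalLM"


def toy_dense_converter(hidden_size: int) -> dict[str, int]:
    return {"hidden_size": hidden_size}


TOY_CONFIG_CONVERTER_REGISTRY: dict[ToyModel, Callable[[int], dict[str, int]]] = {
    ToyModel.LLAMA: toy_dense_converter,
    ToyModel.QWEN2: toy_dense_converter,
}

print(TOY_CONFIG_CONVERTER_REGISTRY[ToyModel.LLAMA](4096))  # {'hidden_size': 4096}
```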

View File

@ -112,7 +112,7 @@ def merge_megatron_ckpt_gptmodel(wrapped_models, config, dtype, is_value_model=F
assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"
if not isinstance(wrapped_models, (list, tuple)):
if not isinstance(wrapped_models, list | tuple):
wrapped_models = list(wrapped_models)
assert len(wrapped_models) == virtual_pp_size

View File

@ -84,7 +84,7 @@ def load_state_dict_to_megatron_qwen2(
assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"
if not isinstance(wrapped_models, (list, tuple)):
if not isinstance(wrapped_models, list | tuple):
wrapped_models = list(wrapped_models)
assert len(wrapped_models) == virtual_pp_size

View File

@ -84,7 +84,7 @@ def load_state_dict_to_megatron_qwen2(
assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"
if not isinstance(wrapped_models, (list, tuple)):
if not isinstance(wrapped_models, list | tuple):
wrapped_models = list(wrapped_models)
assert len(wrapped_models) == virtual_pp_size

View File

@ -100,7 +100,7 @@ def merge_megatron_ckpt_qwen2(wrapped_models, config, dtype, is_value_model=Fals
assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"
if not isinstance(wrapped_models, (list, tuple)):
if not isinstance(wrapped_models, list | tuple):
wrapped_models = list(wrapped_models)
assert len(wrapped_models) == virtual_pp_size

View File

@ -19,7 +19,7 @@
# limitations under the License.
import math
from typing import Optional, Tuple
from typing import Optional
import torch.nn.functional as F
from einops import rearrange
@ -235,7 +235,7 @@ class ParallelQwen2Attention(nn.Module):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
qkv = self.qkv_proj(hidden_states)[0]
query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)

View File

@ -18,7 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
from typing import Optional
import torch
from megatron.core import ModelParallelConfig
@ -49,7 +49,7 @@ class ParallelQwen2DecoderLayer(nn.Module):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
@ -119,7 +119,7 @@ class ParallelQwen2DecoderLayerRmPad(nn.Module):
indices: torch.Tensor = None,
cu_seqlens: int = None,
max_seqlen_in_batch: int = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states # (total_nnz // sp, 1, hidden_size)
hidden_states = self.input_layernorm(hidden_states)

View File

@ -19,7 +19,7 @@
# limitations under the License.
"""PyTorch Qwen2 model."""
from typing import Optional, Tuple, Union
from typing import Optional
import torch
import torch.utils.checkpoint
@ -126,7 +126,7 @@ class ParallelQwen2Model(nn.Module):
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
) -> tuple | BaseModelOutputWithPast:
"""
Args:
@ -185,7 +185,7 @@ class ParallelQwen2ForCausalLM(nn.Module):
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
) -> tuple | CausalLMOutputWithPast:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -256,7 +256,7 @@ class ParallelQwen2ModelRmPad(nn.Module):
indices: torch.Tensor = None,
cu_seqlens: int = None,
max_seqlen_in_batch: int = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
) -> tuple | BaseModelOutputWithPast:
"""
Args:
@ -326,7 +326,7 @@ class ParallelQwen2ForCausalLMRmPad(nn.Module):
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
) -> tuple | CausalLMOutputWithPast:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -405,7 +405,7 @@ class ParallelQwen2ForValueRmPad(ParallelQwen2ForCausalLMRmPad):
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
) -> tuple | CausalLMOutputWithPast:
output = super().forward(input_ids, attention_mask, position_ids)
output.logits = torch.squeeze(output.logits, dim=-1)
return output
@ -487,7 +487,7 @@ class ParallelQwen2ModelRmPadPP(nn.Module):
indices: torch.Tensor = None,
cu_seqlens: int = None,
max_seqlen_in_batch: int = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
) -> tuple | BaseModelOutputWithPast:
"""
Args:
@ -645,7 +645,7 @@ class ParallelQwen2ForCausalLMRmPadPP(nn.Module):
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
) -> tuple | CausalLMOutputWithPast:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -728,7 +728,7 @@ class ParallelQwen2ForValueRmPadPP(ParallelQwen2ForCausalLMRmPadPP):
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
) -> tuple | CausalLMOutputWithPast:
output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
if self.post_process:
output.logits = torch.squeeze(output.logits, dim=-1)

View File

@ -13,7 +13,7 @@
# limitations under the License.
import importlib
from typing import List, Optional, Type
from typing import Optional
import torch.nn as nn
@ -38,7 +38,7 @@ _MODELS = {
# return model class
class ModelRegistry:
@staticmethod
def load_model_cls(model_arch: str, value=False) -> Optional[Type[nn.Module]]:
def load_model_cls(model_arch: str, value=False) -> Optional[type[nn.Module]]:
if model_arch not in _MODELS:
return None
@ -54,5 +54,5 @@ class ModelRegistry:
return getattr(module, model_cls_name, None)
@staticmethod
def get_supported_archs() -> List[str]:
def get_supported_archs() -> list[str]:
return list(_MODELS.keys())
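
`load_model_cls` now advertises `Optional[type[nn.Module]]`: the return value is the class object itself (or `None`), with `type[X]` being the PEP 585 spelling of `typing.Type[X]`. A sketch of the same lookup shape with an invented registry:

```python
# Invented class-object lookup in the load_model_cls style: values are classes
# (type[nn.Module]), not instances, and a miss returns None.
from typing import Optional

import torch.nn as nn

_TOY_MODELS: dict[str, type[nn.Module]] = {
    "ToyLinearModel": nn.Linear,  # stand-in class, not a real verl model
}


def load_toy_model_cls(model_arch: str) -> Optional[type[nn.Module]]:
    return _TOY_MODELS.get(model_arch)


cls = load_toy_model_cls("ToyLinearModel")
if cls is not None:
    print(cls(4, 4))  # Linear(in_features=4, out_features=4, bias=True)
```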

View File

@ -13,7 +13,7 @@
# limitations under the License.
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import Optional, Union
import torch
from transformers.cache_utils import Cache
@ -73,7 +73,7 @@ def forward_with_torch_backend(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union["Cache", List[torch.FloatTensor]]] = None,
past_key_values: Optional[Union["Cache", list[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
@ -81,10 +81,10 @@ def forward_with_torch_backend(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
logits_to_keep: int | torch.Tensor = 0,
temperature: float = 1.0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputForPPO]:
) -> tuple | CausalLMOutputForPPO:
from verl.utils.experimental.torch_functional import FusedLinearForPPO
outputs = forward_base_model(
@ -135,7 +135,7 @@ def forward_with_triton_backend(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union["Cache", List[torch.FloatTensor]]] = None,
past_key_values: Optional[Union["Cache", list[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
@ -143,10 +143,10 @@ def forward_with_triton_backend(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
logits_to_keep: int | torch.Tensor = 0,
temperature: float = 1.0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputForPPO]:
) -> tuple | CausalLMOutputForPPO:
from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy
outputs = forward_base_model(

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
from typing import Optional
import torch
import torch.nn.functional as F
@ -93,7 +93,7 @@ def _ulysses_flash_attn_forward(
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
if self.q_lora_rank is None:

View File

@ -13,7 +13,7 @@
# limitations under the License.
import sys
from typing import Callable, Optional, Tuple
from typing import Callable, Optional
import torch
@ -46,9 +46,9 @@ def llama_flash_attn_forward(
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
"""
Adapted from transformers 4.47.1 to support Ulysses sequence parallelism.
@ -168,12 +168,12 @@ def llama_flash_attn_forward(
def llama_attn_forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
"""
Adapted from transformers 4.49.0 to support Ulysses sequence parallelism for transformers >= 4.48.0.

View File

@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch
import torch_npu
@ -28,7 +27,7 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2RMSNorm
# https://github.com/huggingface/transformers/pull/38491
def apply_rotary_pos_emb_flashatt_npu(
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
cos = cos.chunk(2, dim=-1)[0].contiguous()
sin = sin.chunk(2, dim=-1)[0].contiguous()
cos = cos.repeat(1, 2)

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional, Tuple
from typing import Callable, Optional
import torch
from transformers.cache_utils import Cache
@ -39,7 +39,7 @@ def qwen2_flash_attn_forward(
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
):
"""
Adapted from transformers 4.47.1 to support Ulysses sequence parallelism.
@ -157,12 +157,12 @@ def qwen2_flash_attn_forward(
def qwen2_attn_forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
"""
Adapted from transformers 4.49.0 to support Ulysses sequence parallelism for transformers >= 4.48.0.

View File

@ -13,7 +13,7 @@
# limitations under the License.
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import Optional
import torch
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
@ -33,7 +33,7 @@ def forward_base_model(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
@ -46,7 +46,7 @@ def forward_base_model(
rope_deltas: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
) -> tuple | Qwen2_5_VLCausalLMOutputWithPast:
r"""
Copy paste Qwen2_5_VL's forward
https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/qwen2_5_vl.py
@ -143,7 +143,7 @@ def forward_with_torch_backend(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
@ -159,7 +159,7 @@ def forward_with_torch_backend(
second_per_grid_ts: Optional[torch.Tensor] = None,
temperature: float = 1.0,
**loss_kwargs,
) -> Union[Tuple, Qwen2_5_VLCausalLMOutputForPPO]:
) -> tuple | Qwen2_5_VLCausalLMOutputForPPO:
from verl.utils.experimental.torch_functional import FusedLinearForPPO
outputs = forward_base_model(
@ -218,7 +218,7 @@ def forward_with_triton_backend(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
@ -234,7 +234,7 @@ def forward_with_triton_backend(
second_per_grid_ts: Optional[torch.Tensor] = None,
temperature: float = 1.0,
**loss_kwargs,
) -> Union[Tuple, Qwen2_5_VLCausalLMOutputForPPO]:
) -> tuple | Qwen2_5_VLCausalLMOutputForPPO:
from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy
outputs = forward_base_model(

View File

@ -15,7 +15,7 @@
import inspect
import os
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import Optional
import torch
from transformers.modeling_flash_attention_utils import _flash_attention_forward
@ -230,9 +230,9 @@ def ulysses_flash_attn_forward(
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs,
) -> Tuple[torch.Tensor, None, None]:
) -> tuple[torch.Tensor, None, None]:
from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_multimodal_rotary_pos_emb, repeat_kv
bsz, q_len, _ = hidden_states.size() # q_len = seq_length / sp_size
@ -315,7 +315,7 @@ def forward_base_model(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
@ -327,7 +327,7 @@ def forward_base_model(
video_grid_thw: Optional[torch.LongTensor] = None,
rope_deltas: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
) -> tuple | Qwen2VLCausalLMOutputWithPast:
r"""
Copy paste Qwen2VL's forward
https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/qwen2_vl.py
@ -418,7 +418,7 @@ def forward_with_torch_backend(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
@ -433,7 +433,7 @@ def forward_with_torch_backend(
cache_position: Optional[torch.LongTensor] = None,
temperature: float = 1.0,
**loss_kwargs,
) -> Union[Tuple, Qwen2VLCausalLMOutputForPPO]:
) -> tuple | Qwen2VLCausalLMOutputForPPO:
from verl.utils.experimental.torch_functional import FusedLinearForPPO
outputs = forward_base_model(
@ -491,7 +491,7 @@ def forward_with_triton_backend(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
@ -506,7 +506,7 @@ def forward_with_triton_backend(
cache_position: Optional[torch.LongTensor] = None,
temperature: float = 1.0,
**loss_kwargs,
) -> Union[Tuple, Qwen2VLCausalLMOutputForPPO]:
) -> tuple | Qwen2VLCausalLMOutputForPPO:
from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy
outputs = forward_base_model(

View File

@ -22,7 +22,7 @@ import logging
import os
import pickle
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Union
from typing import Callable, Optional
import numpy as np
import pandas as pd
@ -200,8 +200,8 @@ def collate_fn(x: list["DataProtoItem"]):
class DataProtoItem:
# TODO(zhangchi.usc1992) add consistency check
batch: TensorDict = None
non_tensor_batch: Dict = field(default_factory=dict)
meta_info: Dict = field(default_factory=dict)
non_tensor_batch: dict = field(default_factory=dict)
meta_info: dict = field(default_factory=dict)
@dataclass
@ -214,8 +214,8 @@ class DataProto:
"""
batch: TensorDict = None
non_tensor_batch: Dict = field(default_factory=dict)
meta_info: Dict = field(default_factory=dict)
non_tensor_batch: dict = field(default_factory=dict)
meta_info: dict = field(default_factory=dict)
def __post_init__(self):
# perform necessary checking
@ -251,11 +251,11 @@ class DataProto:
return self.slice(item.start, item.stop, item.step)
# Case 2: List, numpy array, or torch tensor - use sel_idxs
elif isinstance(item, (list, np.ndarray, torch.Tensor)):
elif isinstance(item, list | np.ndarray | torch.Tensor):
return self.select_idxs(item)
# Case 3: Single integer - return DataProtoItem for backward compatibility
elif isinstance(item, (int, np.integer)):
elif isinstance(item, int | np.integer):
tensor_data = self.batch[item] if self.batch is not None else None
non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()}
return DataProtoItem(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)
@ -343,7 +343,7 @@ class DataProto:
)
@classmethod
def from_single_dict(cls, data: Dict[str, Union[torch.Tensor, np.ndarray]], meta_info=None, auto_padding=False):
def from_single_dict(cls, data: dict[str, torch.Tensor | np.ndarray], meta_info=None, auto_padding=False):
"""Create a DataProto from a dict of tensors and non_tensors"""
tensors = {}
non_tensors = {}
@ -361,7 +361,7 @@ class DataProto:
@classmethod
def from_dict(
cls,
tensors: Optional[Dict[str, torch.Tensor]] = None,
tensors: Optional[dict[str, torch.Tensor]] = None,
non_tensors=None,
meta_info=None,
num_batch_dims=1,
@ -649,7 +649,7 @@ class DataProto:
else:
generator = None
assert isinstance(dataloader_kwargs, Dict)
assert isinstance(dataloader_kwargs, dict)
train_dataloader = DataLoader(
dataset=self, batch_size=mini_batch_size, collate_fn=collate_fn, generator=generator, **dataloader_kwargs
)
@ -686,7 +686,7 @@ class DataProto:
self.batch = padded_dp.batch
self.non_tensor_batch = padded_dp.non_tensor_batch
def chunk(self, chunks: int) -> List["DataProto"]:
def chunk(self, chunks: int) -> list["DataProto"]:
"""Split the batch among dim=0 into chunks. The meta_info is passed to each DataProto after split.
Args:
@ -728,7 +728,7 @@ class DataProto:
return output
@staticmethod
def concat(data: List["DataProto"]) -> "DataProto":
def concat(data: list["DataProto"]) -> "DataProto":
"""Concat a list of DataProto. The batch is concatenated among dim=0.
The meta_info is assumed to be identical and will use the first one.
@ -802,7 +802,7 @@ class DataProto:
meta_info=self.meta_info,
)
def unfold_column_chunks(self, n_split: int, split_keys: Optional[List[str]] = None):
def unfold_column_chunks(self, n_split: int, split_keys: Optional[list[str]] = None):
"""Split along the second dim into `n_split`, unfold it to the first dim (batch dim)
Useful in passing grouped tensors that doesn't want to be shuffled in dataset.
keys not in split_keys are repeated to match the shape
@ -906,15 +906,15 @@ class DataProtoFuture:
"""
collect_fn: Callable
futures: List[ray.ObjectRef]
futures: list[ray.ObjectRef]
dispatch_fn: Callable = None
@staticmethod
def concat(data: List[ray.ObjectRef]) -> "DataProtoFuture":
def concat(data: list[ray.ObjectRef]) -> "DataProtoFuture":
output = DataProtoFuture(collect_fn=DataProto.concat, futures=data)
return output
def chunk(self, chunks: int) -> List["DataProtoFuture"]:
def chunk(self, chunks: int) -> list["DataProtoFuture"]:
from functools import partial
arg_future_lst = []
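
`DataProto` and `DataProtoItem` above keep their mutable `non_tensor_batch`/`meta_info` fields behind `field(default_factory=dict)`; the diff only lowercases the annotation. A minimal sketch of why the factory matters — the class here is a stand-in, not DataProto itself:

```python
# Stand-in for the DataProto-style dataclass fields: mutable defaults must go
# through default_factory so each instance owns its own dict.
from dataclasses import dataclass, field


@dataclass
class ToyProto:
    non_tensor_batch: dict = field(default_factory=dict)
    meta_info: dict = field(default_factory=dict)


a, b = ToyProto(), ToyProto()
a.meta_info["global_steps"] = 10
print(a.meta_info)  # {'global_steps': 10}
print(b.meta_info)  # {} -- no sharing between instances
```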

View File

@ -15,7 +15,6 @@
import inspect
from functools import wraps
from types import FunctionType
from typing import Dict, List, Tuple
from verl.protocol import DataProtoFuture, _padding_size_key
from verl.utils.py_functional import DynamicEnum
@ -79,12 +78,12 @@ def _split_args_kwargs_data_proto(chunks, *args, **kwargs):
splitted_args = []
for arg in args:
assert isinstance(arg, (DataProto, DataProtoFuture))
assert isinstance(arg, DataProto | DataProtoFuture)
splitted_args.append(arg.chunk(chunks=chunks))
splitted_kwargs = {}
for key, val in kwargs.items():
assert isinstance(val, (DataProto, DataProtoFuture))
assert isinstance(val, DataProto | DataProtoFuture)
splitted_kwargs[key] = val.chunk(chunks=chunks)
return splitted_args, splitted_kwargs
@ -99,7 +98,7 @@ def _split_args_kwargs_data_proto_with_auto_padding(chunks, *args, **kwargs):
data_proto_len = None
padding_size = None
for arg in args:
assert isinstance(arg, (DataProto, DataProtoFuture))
assert isinstance(arg, DataProto | DataProtoFuture)
if isinstance(arg, DataProto) and arg.is_padding_enabled():
# for padding, we only support DataProto with same length
if data_proto_len is None:
@ -116,7 +115,7 @@ def _split_args_kwargs_data_proto_with_auto_padding(chunks, *args, **kwargs):
splitted_args.append(arg.chunk(chunks=chunks))
for key, val in kwargs.items():
assert isinstance(val, (DataProto, DataProtoFuture))
assert isinstance(val, DataProto | DataProtoFuture)
if isinstance(val, DataProto) and val.is_padding_enabled():
# for padding, we only support DataProto with same length
if data_proto_len is None:
@ -169,7 +168,7 @@ def dispatch_megatron_compute(worker_group, *args, **kwargs):
all_args = []
for arg in args:
assert isinstance(arg, (Tuple, List)) and len(arg) == worker_group.dp_size
assert isinstance(arg, tuple | list) and len(arg) == worker_group.dp_size
transformed_args = []
for i in range(worker_group.world_size):
local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank
@ -179,7 +178,7 @@ def dispatch_megatron_compute(worker_group, *args, **kwargs):
all_kwargs = {}
for k, v in kwargs.items():
assert isinstance(v, (Tuple, List)) and len(v) == worker_group.dp_size
assert isinstance(v, tuple | list) and len(v) == worker_group.dp_size
transformed_v = []
for i in range(worker_group.world_size):
local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank
@ -216,7 +215,7 @@ def dispatch_megatron_compute_data_proto(worker_group, *args, **kwargs):
return dispatch_megatron_compute(worker_group, *splitted_args, **splitted_kwargs)
def _concat_data_proto_or_future(output: List):
def _concat_data_proto_or_future(output: list):
import ray
from verl.protocol import DataProto, DataProtoFuture
@ -245,7 +244,7 @@ def collect_megatron_compute_data_proto(worker_group, output):
output = collect_megatron_compute(worker_group, output)
for o in output:
assert isinstance(o, (DataProto, ray.ObjectRef)), f"expecting {o} to be DataProto, but got {type(o)}"
assert isinstance(o, DataProto | ray.ObjectRef), f"expecting {o} to be DataProto, but got {type(o)}"
return _concat_data_proto_or_future(output)
@ -265,7 +264,7 @@ def dispatch_megatron_pp_as_dp(worker_group, *args, **kwargs):
all_args = []
for arg in args:
assert isinstance(arg, (List, Tuple)) and len(arg) == pp_dp_cp_size
assert isinstance(arg, list | tuple) and len(arg) == pp_dp_cp_size
transformed_args = []
for i in range(worker_group.world_size):
local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank
@ -290,7 +289,7 @@ def dispatch_megatron_pp_as_dp(worker_group, *args, **kwargs):
all_kwargs = {}
for k, v in kwargs.items():
assert isinstance(v, (List, Tuple)) and len(v) == pp_dp_cp_size, f"expect len(v)=={pp_dp_cp_size}, got {len(v)}"
assert isinstance(v, list | tuple) and len(v) == pp_dp_cp_size, f"expect len(v)=={pp_dp_cp_size}, got {len(v)}"
transformed_v = []
for i in range(worker_group.world_size):
local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank
@ -359,9 +358,9 @@ def dispatch_dp_compute(worker_group, *args, **kwargs):
assert isinstance(worker_group, WorkerGroup)
for arg in args:
assert isinstance(arg, (Tuple, List)) and len(arg) == worker_group.world_size
assert isinstance(arg, tuple | list) and len(arg) == worker_group.world_size
for k, v in kwargs.items():
assert isinstance(v, (Tuple, List)) and len(v) == worker_group.world_size
assert isinstance(v, tuple | list) and len(v) == worker_group.world_size
return args, kwargs
@ -403,7 +402,7 @@ def collect_dp_compute_data_proto(worker_group, output):
from verl.protocol import DataProto
for o in output:
assert isinstance(o, (DataProto, ray.ObjectRef)), f"expecting {o} to be DataProto, but got {type(o)}"
assert isinstance(o, DataProto | ray.ObjectRef), f"expecting {o} to be DataProto, but got {type(o)}"
output = collect_dp_compute(worker_group, output)
return _concat_data_proto_or_future(output)
@ -489,10 +488,10 @@ def get_predefined_execute_fn(execute_mode):
def _check_dispatch_mode(dispatch_mode):
assert isinstance(dispatch_mode, (Dispatch, Dict)), (
assert isinstance(dispatch_mode, Dispatch | dict), (
f"dispatch_mode must be a Dispatch or a Dict. Got {dispatch_mode}"
)
if isinstance(dispatch_mode, Dict):
if isinstance(dispatch_mode, dict):
necessary_keys = ["dispatch_fn", "collect_fn"]
for key in necessary_keys:
assert key in dispatch_mode, f"key {key} should be in dispatch_mode if it is a dictionary"

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
from verl.single_controller.base import ResourcePool, WorkerGroup
@ -25,7 +24,7 @@ class MegatronWorkerGroup(WorkerGroup):
self._megatron_rank_info = None
self._megatron_global_info: DistGlobalInfo = None
def init_megatron(self, default_megatron_kwargs: Dict = None):
def init_megatron(self, default_megatron_kwargs: dict = None):
raise NotImplementedError("MegatronWorkerGroup.init_megatron should be overwritten")
def get_megatron_rank_info(self, rank: int) -> DistRankInfo:

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
import ray
@ -22,7 +21,7 @@ class WorkerGroupRegisterCenter:
def __init__(self, rank_zero_info):
self.rank_zero_info = rank_zero_info
# rank -> node_id
self.workers_info: Dict[int, str] = {}
self.workers_info: dict[int, str] = {}
def get_rank_zero_info(self):
return self.rank_zero_info
@ -30,7 +29,7 @@ class WorkerGroupRegisterCenter:
def set_worker_info(self, rank, node_id) -> None:
self.workers_info[rank] = node_id
def get_worker_info(self) -> Dict[int, str]:
def get_worker_info(self) -> dict[int, str]:
return self.workers_info

View File

@ -18,7 +18,6 @@ the class for Worker
import os
import socket
from dataclasses import dataclass
from typing import Dict
import ray
@ -246,7 +245,7 @@ class Worker(WorkerHelper):
os.environ["LOCAL_RANK"] = local_rank
get_torch_device().set_device(int(local_rank))
def _configure_with_store(self, store: Dict):
def _configure_with_store(self, store: dict):
"""
This function should only be called inside by WorkerGroup
"""

View File

@ -19,7 +19,7 @@ import logging
import signal
import threading
import time
from typing import Any, Callable, Dict, List
from typing import Any, Callable
from .decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn
@ -60,14 +60,14 @@ class ResourcePool:
def store(self):
return self._store
def local_world_size_list(self) -> List[int]:
def local_world_size_list(self) -> list[int]:
"""Returns a flat list where each process has its local world size."""
nested_local_world_size_list = [
[local_world_size for _ in range(local_world_size)] for local_world_size in self._store
]
return [item for row in nested_local_world_size_list for item in row]
def local_rank_list(self) -> List[int]:
def local_rank_list(self) -> list[int]:
"""Returns a flat list of local ranks for all processes across all nodes."""
nested_local_rank_list = [[i for i in range(local_world_size)] for local_world_size in self._store]
return [item for row in nested_local_rank_list for item in row]
@ -99,7 +99,7 @@ class ClassWithInitArgs:
return self.cls(*self.args, **self.kwargs)
def check_workers_alive(workers: List, is_alive: Callable, gap_time: float = 1) -> None:
def check_workers_alive(workers: list, is_alive: Callable, gap_time: float = 1) -> None:
"""Continuously monitors worker processes and raises SIGABRT if any worker dies.
Args:
@ -201,7 +201,7 @@ class WorkerGroup:
if hasattr(method, MAGIC_ATTR):
# this method is decorated by register
attribute = getattr(method, MAGIC_ATTR)
assert isinstance(attribute, Dict), f"attribute must be a dictionary. Got {type(attribute)}"
assert isinstance(attribute, dict), f"attribute must be a dictionary. Got {type(attribute)}"
assert "dispatch_mode" in attribute, "attribute must contain dispatch_mode in its key"
dispatch_mode = attribute["dispatch_mode"]

View File

@ -17,7 +17,7 @@ import logging
import os
import time
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Optional
from unittest.mock import patch
import ray
@ -62,7 +62,7 @@ def func_generator(self, method_name, dispatch_fn, collect_fn, execute_fn, block
return type(method_name, (Functor,), {})()
def sort_placement_group_by_node_ip(pgs: List[PlacementGroup]) -> List[PlacementGroup]:
def sort_placement_group_by_node_ip(pgs: list[PlacementGroup]) -> list[PlacementGroup]:
"""
Sort the placement groups by node ip, all bundles in a single placement group should be on the same node.
@ -85,7 +85,7 @@ def sort_placement_group_by_node_ip(pgs: List[PlacementGroup]) -> List[Placement
class RayResourcePool(ResourcePool):
def __init__(
self,
process_on_nodes: Optional[List[int]] = None,
process_on_nodes: Optional[list[int]] = None,
use_gpu: bool = True,
name_prefix: str = None,
max_colocate_count: int = 10,
@ -134,8 +134,8 @@ class RayResourcePool(ResourcePool):
def extract_pg_from_exist(
resource_pools: Dict[str, RayResourcePool], src_role_names: List[str], resource_pool: RayResourcePool
) -> List:
resource_pools: dict[str, RayResourcePool], src_role_names: list[str], resource_pool: RayResourcePool
) -> list:
src_pgs = [
pg
for role_name, resource_pool in resource_pools.items()
@ -146,7 +146,7 @@ def extract_pg_from_exist(
sorted_src_pgs = sorted(src_pgs, key=lambda pg: pg.bundle_count, reverse=True)
sorted_process_on_nodes = sorted([(val, idx) for idx, val in enumerate(resource_pool.store)], reverse=True)
unsorted_pgs: List[Tuple[int, PlacementGroup]] = []
unsorted_pgs: list[tuple[int, PlacementGroup]] = []
searching_idx = 0
for request_process, original_idx in sorted_process_on_nodes:
assert searching_idx < len(sorted_src_pgs), f"no enough nodes for request: searching {searching_idx} th node"
@ -195,7 +195,7 @@ class RayClassWithInitArgs(ClassWithInitArgs):
"""
self._additional_resource = additional_resource
def update_options(self, options: Dict):
def update_options(self, options: dict):
"""Update the Ray actor creation options.
Args:
@ -269,7 +269,7 @@ class RayWorkerGroup(WorkerGroup):
name_prefix: str = None,
detached=False,
worker_names=None,
worker_handles: List[ray.actor.ActorHandle] = None,
worker_handles: list[ray.actor.ActorHandle] = None,
ray_wait_register_center_timeout: int = 300,
device_name="cuda",
**kwargs,
@ -499,7 +499,6 @@ class RayWorkerGroup(WorkerGroup):
prefix: str = actor_name + "_"
for method_name in dir(worker_group):
if method_name.startswith(prefix):
# only valid when Python >= 3.9
original_method_name = method_name.removeprefix(prefix)
method = getattr(worker_group, method_name)
setattr(worker_group, original_method_name, method)
@ -740,7 +739,7 @@ def _unwrap_ray_remote(cls):
return cls
def _determine_fsdp_megatron_base_class(mros: List):
def _determine_fsdp_megatron_base_class(mros: list):
"""
- megatron: base class should be MegatronWorker
- fsdp: base class should be Worker
@ -836,7 +835,11 @@ def create_colocated_worker_raw_cls(class_dict: dict[str, RayClassWithInitArgs])
self.init_kwargs_dict = init_kwargs_dict
for cls_name, udc, ud_args, ud_kwargs in zip(
self.cls_names, self.raw_cls_dict.values(), self.init_args_dict.values(), self.init_kwargs_dict.values()
self.cls_names,
self.raw_cls_dict.values(),
self.init_args_dict.values(),
self.init_kwargs_dict.values(),
strict=True,
):
with patch.dict(os.environ, {"DISABLE_WORKER_INIT": "1"}):
udc._get_ray_actor_cls_name = lambda x, name_renamed=class_name_renamed: name_renamed
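
The dropped `# only valid when Python >= 3.9` comment above refers to `str.removeprefix`, which the version bump lets the code rely on unconditionally. Unlike `str.lstrip`, it removes the exact prefix at most once rather than stripping a character set. A sketch; the prefix below is illustrative, not the exact one built from `actor_name`:

```python
# str.removeprefix (Python 3.9+) strips one exact leading substring, which is
# what the worker-method renaming above depends on. Prefix is illustrative.
prefix = "wg_actor_"
method_name = "wg_actor_generate_sequences"

print(method_name.removeprefix(prefix))  # generate_sequences
print(method_name.lstrip(prefix))        # enerate_sequences -- lstrip eats a char set
```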

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional
from typing import Optional
import ray
@ -55,7 +55,7 @@ class MegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup):
self,
resource_pool: RayResourcePool,
ray_cls_with_init: RayClassWithInitArgs,
default_megatron_kwargs: Dict = None,
default_megatron_kwargs: dict = None,
**kwargs,
):
super().__init__(
@ -70,7 +70,7 @@ class MegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup):
self.execute_rank_zero_async(method_name="get_megatron_global_info")
)
def init_megatron(self, default_megatron_kwargs: Optional[Dict] = None):
def init_megatron(self, default_megatron_kwargs: Optional[dict] = None):
# after super, we will call init of each worker
if not self._is_init_with_detached_workers:
# only init_megatron if the WorkerGroup is created from scratch

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Any, Optional, Tuple
from typing import Any, Optional
from uuid import uuid4
from verl.utils.rollout_trace import rollout_trace_op
@ -58,7 +58,7 @@ class BaseTool:
return instance_id
@rollout_trace_op
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:
"""Execute the tool.
Args:

View File

@ -16,7 +16,7 @@
import logging
import os
from typing import Any, Optional, Tuple
from typing import Any, Optional
from uuid import uuid4
from verl.utils.reward_score import geo3k
@ -75,7 +75,7 @@ class Geo3kTool(BaseTool):
return instance_id, None
@rollout_trace_op
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:
answer = parameters.get("answer", "")
if not isinstance(answer, str):
answer = str(answer)

View File

@ -15,7 +15,7 @@
import logging
import os
from typing import Any, Optional, Tuple
from typing import Any, Optional
from uuid import uuid4
from verl.utils.reward_score import gsm8k
@ -75,7 +75,7 @@ class Gsm8kTool(BaseTool):
return instance_id
@rollout_trace_op
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:
answer = parameters.get("answer", "")
if not isinstance(answer, str):
answer = str(answer)

View File

@ -15,7 +15,7 @@
import json
import logging
import os
from typing import Any, Optional, Tuple
from typing import Any, Optional
from uuid import uuid4
from fastmcp.exceptions import ClientError
@ -60,7 +60,7 @@ class MCPBaseTool(BaseTool):
}
return instance_id
async def _call_tool(self, instance_id, parameters) -> Tuple[str, dict]:
async def _call_tool(self, instance_id, parameters) -> tuple[str, dict]:
err_msg = ""
try:
call_tool_result = await ClientManager.call_tool(self.name, parameters, self.timeout)
@ -77,7 +77,7 @@ class MCPBaseTool(BaseTool):
return result, metadata
@rollout_trace_op
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:
if self.name == "" or self.name is None or parameters is None:
error_msg = "Error: 'parameters' is missing or empty."
logger.error(f"[MCPTool] {error_msg} Received tool name: {self.name}, parameters: {parameters}")
@ -111,6 +111,6 @@ class MCPBaseTool(BaseTool):
if instance_id in self._instance_dict:
del self._instance_dict[instance_id]
def _parse_tool_result(self, content: list) -> Tuple[str, dict]:
def _parse_tool_result(self, content: list) -> tuple[str, dict]:
tools_content = [part.text for part in filter(lambda x: x.type == "text", content)]
return " ".join(tools_content), {}

View File

@ -16,7 +16,6 @@ import json
import logging
import os
import re
from typing import Tuple
from verl.tools.mcp_base_tool import MCPBaseTool
@ -30,7 +29,7 @@ class MCPSearchTool(MCPBaseTool):
def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):
super().__init__(config, tool_schema)
def _parse_tool_result(self, content: list) -> Tuple[str, dict]:
def _parse_tool_result(self, content: list) -> tuple[str, dict]:
res = ""
res_cnt = 0
query_list = []

View File

@ -17,7 +17,7 @@ import os
import threading
from contextlib import ExitStack
from enum import Enum
from typing import Any, Callable, Optional, Tuple, TypeVar
from typing import Any, Callable, Optional, TypeVar
from uuid import uuid4
import ray
@ -163,7 +163,7 @@ class SandboxFusionTool(BaseTool):
return instance_id
@rollout_trace_op
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:
code = parameters.get("code", "")
timeout = parameters.get("timeout", self.default_timeout)
language = parameters.get("language", self.default_language)

View File

@ -19,7 +19,7 @@ import os
import threading
from contextlib import ExitStack
from enum import Enum
from typing import Any, Callable, Optional, Tuple, TypeVar
from typing import Any, Callable, Optional, TypeVar
from uuid import uuid4
import ray
@ -226,7 +226,7 @@ class SearchTool(BaseTool):
return result_text, metadata
@rollout_trace_op
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:
"""Execute the search tool.
Args:

View File

@ -19,7 +19,7 @@ import threading
import time
import traceback
import uuid
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Optional
import requests
@ -33,11 +33,11 @@ logger = logging.getLogger(__name__)
def call_search_api(
retrieval_service_url: str,
query_list: List[str],
query_list: list[str],
topk: int = 3,
return_scores: bool = True,
timeout: int = DEFAULT_TIMEOUT,
) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
) -> tuple[Optional[dict[str, Any]], Optional[str]]:
"""
Calls the remote search API to perform retrieval with retry logic for various errors,
using increasing delay between retries. Logs internal calls with a unique ID.
@ -140,11 +140,11 @@ def _passages2string(retrieval_result):
def perform_single_search_batch(
retrieval_service_url: str,
query_list: List[str],
query_list: list[str],
topk: int = 3,
concurrent_semaphore: Optional[threading.Semaphore] = None,
timeout: int = DEFAULT_TIMEOUT,
) -> Tuple[str, Dict[str, Any]]:
) -> tuple[str, dict[str, Any]]:
"""
Performs a single batch search for multiple queries (original search tool behavior).

View File

@ -17,7 +17,7 @@ Metrics related to the PPO trainer.
from collections import defaultdict
from functools import partial
from typing import Any, Callable, Dict, List
from typing import Any, Callable
import numpy as np
import torch
@ -27,7 +27,7 @@ from verl.utils.import_utils import deprecated
@deprecated("verl.utils.metric.reduce_metrics")
def reduce_metrics(metrics: Dict[str, List[Any]]) -> Dict[str, Any]:
def reduce_metrics(metrics: dict[str, list[Any]]) -> dict[str, Any]:
"""
Reduces a dictionary of metric lists by computing the mean of each list.
@ -47,7 +47,7 @@ def reduce_metrics(metrics: Dict[str, List[Any]]) -> Dict[str, Any]:
return reduce_metrics(metrics)
def _compute_response_info(batch: DataProto) -> Dict[str, Any]:
def _compute_response_info(batch: DataProto) -> dict[str, Any]:
"""
Computes information about prompts and responses from a batch.
@ -77,7 +77,7 @@ def _compute_response_info(batch: DataProto) -> Dict[str, Any]:
)
def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> Dict[str, Any]:
def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str, Any]:
"""
Computes various metrics from a batch of data for PPO training.
@ -180,7 +180,7 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> Dict[str,
return metrics
def compute_timing_metrics(batch: DataProto, timing_raw: Dict[str, float]) -> Dict[str, Any]:
def compute_timing_metrics(batch: DataProto, timing_raw: dict[str, float]) -> dict[str, Any]:
"""
Computes timing metrics for different processing stages in PPO training.
@ -222,7 +222,7 @@ def compute_timing_metrics(batch: DataProto, timing_raw: Dict[str, float]) -> Di
}
def compute_throughout_metrics(batch: DataProto, timing_raw: Dict[str, float], n_gpus: int) -> Dict[str, Any]:
def compute_throughout_metrics(batch: DataProto, timing_raw: dict[str, float], n_gpus: int) -> dict[str, Any]:
"""
Computes throughput metrics for PPO training.
@ -416,7 +416,9 @@ def process_validation_metrics(
metric[f"best@{n}/mean"], metric[f"best@{n}/std"] = bon_mean, bon_std
metric[f"worst@{n}/mean"], metric[f"worst@{n}/std"] = won_mean, won_std
if var2vals.get("pred", None) is not None:
vote_data = [{"val": val, "pred": pred} for val, pred in zip(var_vals, var2vals["pred"])]
vote_data = [
{"val": val, "pred": pred} for val, pred in zip(var_vals, var2vals["pred"], strict=True)
]
[(maj_n_mean, maj_n_std)] = bootstrap_metric(
data=vote_data,
subset_size=n,

View File

@ -26,7 +26,7 @@ from copy import deepcopy
from dataclasses import dataclass, field
from enum import Enum
from pprint import pprint
from typing import Optional, Type
from typing import Optional
import numpy as np
import ray
@ -61,7 +61,7 @@ from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seql
from verl.utils.torch_functional import masked_mean
from verl.utils.tracking import ValidationGenerationsLogger
WorkerType = Type[Worker]
WorkerType = type[Worker]
class Role(Enum):
@ -674,7 +674,7 @@ class RayPPOTrainer:
import numpy as np
# Create tuples of (input, output, score) and sort by input text
samples = list(zip(inputs, outputs, scores))
samples = list(zip(inputs, outputs, scores, strict=True))
samples.sort(key=lambda x: x[0]) # Sort by input text
# Use fixed random seed for deterministic shuffling

View File

@ -263,10 +263,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any:
torch_stray_tensor = isinstance(
tensor,
(
torch._subclasses.fake_tensor.FakeTensor,
torch._subclasses.functional_tensor.FunctionalTensor,
),
torch._subclasses.fake_tensor.FakeTensor | torch._subclasses.functional_tensor.FunctionalTensor,
)
need_offload = not torch_stray_tensor
need_offload = need_offload and self.tensor_need_offloading_checker(tensor)
@ -451,7 +448,7 @@ class ActivationHandler:
if len(kwarg_keys) == 0:
return flat_args, {}
args = flat_args[: -len(kwarg_keys)]
kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys) :]))
kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys) :], strict=True))
return args, kwargs
def _ckpt_forward(self, forward_method, *args, **kwargs):
@ -526,7 +523,7 @@ def enable_activation_offloading(model, strategy, enable_ckpt=False):
def get_layers(module):
for name, child in module.named_children():
if not isinstance(child, (FSDP, FSDP2)):
if not isinstance(child, FSDP | FSDP2):
get_layers(child)
else:
wrapped_module = child

View File

@ -15,7 +15,6 @@
import os
import random
import shutil
from typing import Union
import numpy as np
import torch
@ -46,7 +45,7 @@ class BaseCheckpointManager:
model,
optimizer: torch.optim.Optimizer,
lr_scheduler: torch.optim.lr_scheduler.LRScheduler = None,
processing_class: Union[PreTrainedTokenizer, ProcessorMixin] = None,
processing_class: PreTrainedTokenizer | ProcessorMixin = None,
checkpoint_config: DictConfig = None,
):
self.checkpoint_config = checkpoint_config

View File

@ -17,7 +17,7 @@ import logging
import os
import warnings
from dataclasses import asdict, dataclass
from typing import Optional, Union
from typing import Optional
import torch
import torch.distributed
@ -73,7 +73,7 @@ class FSDPCheckpointManager(BaseCheckpointManager):
model: FSDP,
optimizer: Optional[torch.optim.Optimizer] = None,
lr_scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
processing_class: Union[PreTrainedTokenizer, ProcessorMixin] = None,
processing_class: PreTrainedTokenizer | ProcessorMixin = None,
checkpoint_config: DictConfig = None,
**kwargs,
):

View File

@ -13,14 +13,14 @@
# limitations under the License.
from dataclasses import is_dataclass
from typing import Any, Dict, Optional, Type, Union
from typing import Any, Optional
from omegaconf import DictConfig, ListConfig, OmegaConf
__all__ = ["omega_conf_to_dataclass"]
def omega_conf_to_dataclass(config: Union[DictConfig, dict], dataclass_type: Optional[Type[Any]] = None) -> Any:
def omega_conf_to_dataclass(config: DictConfig | dict, dataclass_type: Optional[type[Any]] = None) -> Any:
"""
Convert an OmegaConf DictConfig to a dataclass.
@ -36,7 +36,7 @@ def omega_conf_to_dataclass(config: Union[DictConfig, dict], dataclass_type: Opt
if not config:
return dataclass_type if dataclass_type is None else dataclass_type()
# Got an object
if not isinstance(config, (DictConfig, ListConfig, dict, list)):
if not isinstance(config, DictConfig | ListConfig | dict | list):
return config
if dataclass_type is None:
@ -59,7 +59,7 @@ def omega_conf_to_dataclass(config: Union[DictConfig, dict], dataclass_type: Opt
return config_object
def update_dict_with_config(dictionary: Dict, config: DictConfig):
def update_dict_with_config(dictionary: dict, config: DictConfig):
for key in dictionary:
if hasattr(config, key):
dictionary[key] = getattr(config, key)

View File

@ -18,7 +18,7 @@ Multi-turn SFT dataset that supports training on conversation data with multiple
import json
import logging
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Optional
import numpy as np
import pandas as pd
@ -48,7 +48,7 @@ class MultiTurnSFTDataset(Dataset):
Dataset for multi-turn conversations where each assistant response should be trained
"""
def __init__(self, parquet_files: Union[str, List[str]], tokenizer, config=None):
def __init__(self, parquet_files: str | list[str], tokenizer, config=None):
# Set defaults and extract parameters from config if provided
config = config or {}
self.truncation = config.get("truncation", "error")
@ -60,7 +60,7 @@ class MultiTurnSFTDataset(Dataset):
self.enable_thinking_key = multiturn_config.get("enable_thinking_key", "enable_thinking")
assert self.truncation in ["error", "left", "right"]
if not isinstance(parquet_files, List):
if not isinstance(parquet_files, list):
parquet_files = [parquet_files]
self.parquet_files = parquet_files
@ -80,7 +80,7 @@ class MultiTurnSFTDataset(Dataset):
import numpy
import pandas
while isinstance(ls, (pandas.core.series.Series, numpy.ndarray)) and len(ls) == 1:
while isinstance(ls, pandas.core.series.Series | numpy.ndarray) and len(ls) == 1:
ls = ls[0]
return ls
@ -109,13 +109,13 @@ class MultiTurnSFTDataset(Dataset):
def _process_message_tokens(
self,
messages: List[Dict[str, Any]],
messages: list[dict[str, Any]],
start_idx: int,
end_idx: int,
is_assistant: bool = False,
enable_thinking: Optional[bool] = None,
tools: Optional[List[Dict[str, Any]]] = None,
) -> Tuple[List[int], List[int], List[int]]:
tools: Optional[list[dict[str, Any]]] = None,
) -> tuple[list[int], list[int], list[int]]:
"""
Process tokens for a single message or a group of messages.
@ -185,10 +185,10 @@ class MultiTurnSFTDataset(Dataset):
def _validate_and_convert_tokens(
self,
full_tokens: torch.Tensor,
concat_tokens: List[int],
concat_loss_mask: List[int],
concat_attention_mask: List[int],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
concat_tokens: list[int],
concat_loss_mask: list[int],
concat_attention_mask: list[int],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Validate tokenization and convert to tensors.
@ -204,7 +204,7 @@ class MultiTurnSFTDataset(Dataset):
full_tokens_list = full_tokens.tolist()
if len(concat_tokens) != len(full_tokens_list) or not all(
a == b for a, b in zip(concat_tokens, full_tokens_list)
a == b for a, b in zip(concat_tokens, full_tokens_list, strict=True)
):
logging.warning(
f"Token mismatch detected! Full tokenization length: {len(full_tokens_list)}, Concatenated tokens "

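The `strict=True` added to `zip` above is also Python 3.10+ (PEP 618): instead of silently truncating to the shortest iterable, `zip` raises `ValueError` when the lengths differ, which is exactly the failure the token-by-token comparison wants to surface. Illustrative sketch with made-up token lists:

```python
# Minimal sketch, assuming Python 3.10+: zip(..., strict=True) raises
# ValueError on a length mismatch instead of silently truncating.
expected = [101, 7592, 102]  # made-up token ids
actual = [101, 7592]         # one token short

try:
    same = all(a == b for a, b in zip(expected, actual, strict=True))
except ValueError as err:
    # e.g. "zip() argument 2 is shorter than argument 1"
    print(f"length mismatch: {err}")
else:
    print(f"tokens match: {same}")
```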
View File

@ -19,7 +19,7 @@ import logging
import os
import re
from collections import defaultdict
from typing import List, Optional, Union
from typing import Optional
import datasets
import numpy as np
@ -84,12 +84,12 @@ class RLHFDataset(Dataset):
def __init__(
self,
data_files: Union[str, List[str]],
data_files: str | list[str],
tokenizer: PreTrainedTokenizer,
config: DictConfig,
processor: Optional[ProcessorMixin] = None,
):
if not isinstance(data_files, (List, ListConfig)):
if not isinstance(data_files, list | ListConfig):
data_files = [data_files]
self.data_files = copy.deepcopy(data_files)

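Hunks like this one also lean on PEP 585 (Python 3.9+): the builtin containers are subscriptable, so `list[str]` and `dict[str, Any]` replace `typing.List` / `typing.Dict`, and plain `list` is used directly in `isinstance` checks. A self-contained sketch with hypothetical helpers (not verl APIs):

```python
# Minimal sketch, assuming Python 3.9+ for PEP 585 generics and 3.10+ for
# the `str | list[str]` union; the helpers below are hypothetical.
from typing import Any


def normalize_files(data_files: str | list[str]) -> list[str]:
    """Accept a single path or a list of paths and always return a list."""
    if not isinstance(data_files, list):
        data_files = [data_files]
    return data_files


def count_keys(rows: list[dict[str, Any]]) -> dict[str, int]:
    """Count how many rows contain each key."""
    counts: dict[str, int] = {}
    for row in rows:
        for key in row:
            counts[key] = counts.get(key, 0) + 1
    return counts


print(normalize_files("train.parquet"))                              # ['train.parquet']
print(count_keys([{"prompt": "a"}, {"prompt": "b", "image": "c"}]))  # {'prompt': 2, 'image': 1}
```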
View File

@ -13,7 +13,6 @@
# limitations under the License.
import os
from typing import List, Union
import pandas as pd
import torch
@ -39,7 +38,7 @@ def download_files_distributed(download_fn):
class RMDataset(Dataset):
def __init__(
self,
parquet_files: Union[str, List[str]],
parquet_files: str | list[str],
tokenizer,
prompt_key="prompt",
chosen_key="chosen",
@ -48,7 +47,7 @@ class RMDataset(Dataset):
add_eos=True,
cache_dir="~/.cache/verl/rm",
):
if not isinstance(parquet_files, List):
if not isinstance(parquet_files, list):
parquet_files = [parquet_files]
self.parquet_files = parquet_files

View File

@ -18,8 +18,6 @@ SFT dataset
Each parquet file contains
"""
from typing import Union
import pandas as pd
import torch
from omegaconf.listconfig import ListConfig
@ -39,7 +37,7 @@ class SFTDataset(Dataset):
config (OmegaConf): the data config
"""
def __init__(self, parquet_files: Union[str, ListConfig], tokenizer, config):
def __init__(self, parquet_files: str | ListConfig, tokenizer, config):
prompt_key = config.get("prompt_key", "prompt")
prompt_dict_keys = config.get("prompt_dict_keys", None)
response_key = config.get("response_key", "response")
@ -60,8 +58,8 @@ class SFTDataset(Dataset):
tokenizer = hf_tokenizer(tokenizer)
self.tokenizer: PreTrainedTokenizer = tokenizer
self.prompt_key = prompt_key if isinstance(prompt_key, (tuple, list)) else [prompt_key]
self.response_key = response_key if isinstance(response_key, (tuple, list)) else [response_key]
self.prompt_key = prompt_key if isinstance(prompt_key, tuple | list) else [prompt_key]
self.response_key = response_key if isinstance(response_key, tuple | list) else [response_key]
self.prompt_dict_keys = prompt_dict_keys if prompt_dict_keys else []
self.response_dict_keys = response_dict_keys if response_dict_keys else []
@ -79,7 +77,7 @@ class SFTDataset(Dataset):
import numpy
import pandas
while isinstance(ls, (pandas.core.series.Series, numpy.ndarray)) and len(ls) == 1:
while isinstance(ls, pandas.core.series.Series | numpy.ndarray) and len(ls) == 1:
ls = ls[0]
return ls

View File

@ -13,14 +13,14 @@
# limitations under the License.
from io import BytesIO
from typing import Optional, Union
from typing import Optional
import torch
from PIL import Image
from qwen_vl_utils import fetch_image, fetch_video
def process_image(image: Union[dict, Image.Image]) -> Image.Image:
def process_image(image: dict | Image.Image) -> Image.Image:
if isinstance(image, Image.Image):
return image.convert("RGB")

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
from typing import Optional
import torch
@ -22,7 +22,7 @@ def _fused_linear_for_ppo_fwd(
vocab_weights: torch.FloatTensor,
input_ids: torch.LongTensor,
temperature: float = 1.0,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
logits = (hidden_states @ vocab_weights.t()) / temperature
orig_dtype = logits.dtype
logits = logits.to(torch.float32)
@ -44,7 +44,7 @@ def _fused_linear_for_ppo_bwd(
vocab_weights: torch.FloatTensor,
input_ids: torch.LongTensor,
temperature: float = 1.0,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
logits = (hidden_states @ vocab_weights.t()) / temperature
orig_dtype = logits.dtype
logits = logits.to(torch.float32)
@ -81,7 +81,7 @@ class FusedLinearForPPOFunction(torch.autograd.Function):
input_ids: torch.LongTensor,
temperature: float = 1.0,
chunk_size: int = 512,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
ctx.set_materialize_grads(False)
# Cast to a 2D tensor of the shape [T, D] for ease of working
@ -205,7 +205,7 @@ class FusedLinearForPPO(torch.nn.Module):
vocab_weights: torch.FloatTensor,
input_ids: torch.LongTensor,
temperature: float = 1.0,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
input_ids = input_ids.to(torch.int64)
return FusedLinearForPPOFunction.apply(
hidden_states,

View File

@ -19,7 +19,6 @@ import math
import os
from collections import OrderedDict
from contextlib import contextmanager, nullcontext
from typing import Dict
import torch
import torch.distributed as dist
@ -308,7 +307,7 @@ def parallel_load_safetensors(filepath):
return shard_states
def parallel_init_module_fn(module: torch.nn.Module, shard_states: Dict[str, torch.nn.Parameter]):
def parallel_init_module_fn(module: torch.nn.Module, shard_states: dict[str, torch.nn.Parameter]):
"""
Generate a function to initialize sub-modules in the `module` with `shard_states`
from huggingface checkpoint.
@ -339,7 +338,7 @@ def parallel_init_module_fn(module: torch.nn.Module, shard_states: Dict[str, tor
else: # buffer
param = torch.empty_like(state.data, device=device)
loaded = shard_states[param_name]
if isinstance(loaded, (torch.nn.Parameter, torch.Tensor)):
if isinstance(loaded, torch.nn.Parameter | torch.Tensor):
# NOTE: loaded.dtype can be different with param.dtype
param.data.copy_(loaded.data)
dist.broadcast(param.data, src=dist.get_rank())

View File

@ -21,7 +21,7 @@ import importlib.util
import os
import warnings
from functools import cache, wraps
from typing import List, Optional
from typing import Optional
@cache
@ -72,7 +72,7 @@ def is_trl_available():
def import_external_libs(external_libs=None):
if external_libs is None:
return
if not isinstance(external_libs, List):
if not isinstance(external_libs, list):
external_libs = [external_libs]
import importlib

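For completeness, the `@cache` decorator imported above (`from functools import cache`) is itself a Python 3.9 addition, equivalent to `functools.lru_cache(maxsize=None)`. A small illustrative sketch of the memoized availability-check pattern (the function name below is hypothetical):

```python
# Minimal sketch, assuming Python 3.9+: functools.cache memoizes the probe
# so the import machinery is consulted only once per package name.
import importlib.util
from functools import cache


@cache
def is_package_available(name: str) -> bool:  # hypothetical helper name
    """Return True if `name` can be imported in this environment."""
    return importlib.util.find_spec(name) is not None


print(is_package_available("json"))  # True, computed once
print(is_package_available("json"))  # True, served from the cache
```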
Some files were not shown because too many files have changed in this diff.