mirror of
https://github.com/volcengine/verl.git
synced 2025-10-20 13:43:50 +08:00
[env] feat: safely bump py version to 3.10 (#2421)
### What does this PR do? This PR safely bumps python version to 3.10 for two reasons: 1. [`removeprefix`](https://docs.python.org/3.9/whatsnew/3.9.html#new-string-methods-to-remove-prefixes-and-suffixes) was introduced in python 3.9588f9728f3/verl/single_controller/ray/base.py (L498-L505)
2. [`match`](https://docs.python.org/3.10/whatsnew/3.10.html#simple-pattern-match-to-a-literal) was introduced in python 3.10588f9728f3/verl/tools/utils/tool_registry.py (L81-L92)
### Checklist Before Starting - [x] Search for similar PRs. Paste at least one query link here: ... - [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
This commit is contained in:
@ -1,6 +1,6 @@
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: "v0.11.4"
|
||||
rev: "v0.12.2"
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: ["--fix", "--show-fixes", "--output-format=full"]
|
||||
|
@ -62,7 +62,7 @@ def generate_rm_dataset(target_hdfs_path_dir, local_dir="~/data/full_hh_rlh/rm")
|
||||
local_dir = os.path.expanduser(local_dir)
|
||||
os.makedirs(local_dir, exist_ok=True)
|
||||
|
||||
for dataset, name in zip([train_dataset, test_dataset], ["train", "test"]):
|
||||
for dataset, name in zip([train_dataset, test_dataset], ["train", "test"], strict=True):
|
||||
output = {"prompt": [], "chosen": [], "rejected": []}
|
||||
for data in tqdm(dataset):
|
||||
# add chosen
|
||||
|
@ -18,7 +18,7 @@
|
||||
import argparse
|
||||
import json
|
||||
import warnings
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import datasets
|
||||
import faiss
|
||||
@ -75,7 +75,7 @@ class Encoder:
|
||||
self.model.eval()
|
||||
|
||||
@torch.no_grad()
|
||||
def encode(self, query_list: List[str], is_query=True) -> np.ndarray:
|
||||
def encode(self, query_list: list[str], is_query=True) -> np.ndarray:
|
||||
# processing query for different encoders
|
||||
if isinstance(query_list, str):
|
||||
query_list = [query_list]
|
||||
@ -133,13 +133,13 @@ class BaseRetriever:
|
||||
def _search(self, query: str, num: int, return_score: bool):
|
||||
raise NotImplementedError
|
||||
|
||||
def _batch_search(self, query_list: List[str], num: int, return_score: bool):
|
||||
def _batch_search(self, query_list: list[str], num: int, return_score: bool):
|
||||
raise NotImplementedError
|
||||
|
||||
def search(self, query: str, num: int = None, return_score: bool = False):
|
||||
return self._search(query, num, return_score)
|
||||
|
||||
def batch_search(self, query_list: List[str], num: int = None, return_score: bool = False):
|
||||
def batch_search(self, query_list: list[str], num: int = None, return_score: bool = False):
|
||||
return self._batch_search(query_list, num, return_score)
|
||||
|
||||
|
||||
@ -190,7 +190,7 @@ class BM25Retriever(BaseRetriever):
|
||||
else:
|
||||
return results
|
||||
|
||||
def _batch_search(self, query_list: List[str], num: int = None, return_score: bool = False):
|
||||
def _batch_search(self, query_list: list[str], num: int = None, return_score: bool = False):
|
||||
results = []
|
||||
scores = []
|
||||
for query in query_list:
|
||||
@ -237,7 +237,7 @@ class DenseRetriever(BaseRetriever):
|
||||
else:
|
||||
return results
|
||||
|
||||
def _batch_search(self, query_list: List[str], num: int = None, return_score: bool = False):
|
||||
def _batch_search(self, query_list: list[str], num: int = None, return_score: bool = False):
|
||||
if isinstance(query_list, str):
|
||||
query_list = [query_list]
|
||||
if num is None:
|
||||
@ -318,7 +318,7 @@ class Config:
|
||||
|
||||
|
||||
class QueryRequest(BaseModel):
|
||||
queries: List[str]
|
||||
queries: list[str]
|
||||
topk: Optional[int] = None
|
||||
return_scores: bool = False
|
||||
|
||||
@ -365,7 +365,7 @@ def retrieve_endpoint(request: QueryRequest):
|
||||
if request.return_scores:
|
||||
# If scores are returned, combine them with results
|
||||
combined = []
|
||||
for doc, score in zip(single_result, scores[i]):
|
||||
for doc, score in zip(single_result, scores[i], strict=True):
|
||||
combined.append({"document": doc, "score": score})
|
||||
resp.append(combined)
|
||||
else:
|
||||
|
@ -21,7 +21,7 @@ dynamic = ["version", "dependencies", "optional-dependencies", "authors", "urls"
|
||||
description = "verl: Volcano Engine Reinforcement Learning for LLM"
|
||||
license = {text = "Apache-2.0"} # Changed from file to text format
|
||||
readme = {file = "README.md", content-type = "text/markdown"}
|
||||
requires-python = ">=3.8"
|
||||
requires-python = ">=3.10"
|
||||
|
||||
# -------------------------------
|
||||
# tool.ruff - Linting configuration
|
||||
@ -57,10 +57,12 @@ ignore = [
|
||||
"B007",
|
||||
# f-string format
|
||||
"UP032",
|
||||
# Can remove once 3.10+ is the minimum Python version
|
||||
"UP007",
|
||||
# `.log()` statement uses f-string
|
||||
"G004",
|
||||
# X | None for type annotations
|
||||
"UP045",
|
||||
# deprecated import
|
||||
"UP035",
|
||||
]
|
||||
|
||||
# -------------------------------
|
||||
|
@ -214,7 +214,7 @@ class RayDAPOTrainer(RayPPOTrainer):
|
||||
# Collect the sequence reward for each trajectory
|
||||
prompt_uid2metric_vals = defaultdict(list)
|
||||
for uid, metric_val in zip(
|
||||
new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch[metric_name]
|
||||
new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch[metric_name], strict=True
|
||||
):
|
||||
prompt_uid2metric_vals[uid].append(metric_val)
|
||||
|
||||
|
@ -202,7 +202,7 @@ class RayEntropyTrainer(RayPPOTrainer):
|
||||
# Collect the sequence reward for each trajectory
|
||||
prompt_uid2metric_vals = defaultdict(list)
|
||||
for uid, metric_val in zip(
|
||||
new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch[metric_name]
|
||||
new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch[metric_name], strict=True
|
||||
):
|
||||
prompt_uid2metric_vals[uid].append(metric_val)
|
||||
|
||||
|
@ -28,7 +28,7 @@ def _default_compute_score(
|
||||
|
||||
if isinstance(res, dict):
|
||||
return res
|
||||
elif isinstance(res, (int, float, bool)):
|
||||
elif isinstance(res, int | float | bool):
|
||||
return float(res)
|
||||
else:
|
||||
return float(res[0])
|
||||
|
@ -977,7 +977,7 @@ def grade_answer_sympy(given_answer: str, ground_truth: str) -> bool:
|
||||
elif len(ground_truth_elems) != len(given_elems):
|
||||
is_correct = False
|
||||
else:
|
||||
for ground_truth_elem, given_elem in zip(ground_truth_elems, given_elems):
|
||||
for ground_truth_elem, given_elem in zip(ground_truth_elems, given_elems, strict=True):
|
||||
if _is_frac(ground_truth_elem) and _is_frac(given_elem):
|
||||
# if fractions aren't reduced, then shouldn't be marked as correct
|
||||
# so, we don't want to allow sympy.simplify in this case
|
||||
|
@ -96,7 +96,6 @@ import contextlib
|
||||
import math
|
||||
import re
|
||||
from math import isclose
|
||||
from typing import Union
|
||||
|
||||
# sympy related
|
||||
from sympy import N, simplify
|
||||
@ -173,8 +172,8 @@ def handle_pi(string, pi):
|
||||
|
||||
|
||||
def math_equal(
|
||||
prediction: Union[bool, float, str],
|
||||
reference: Union[float, str],
|
||||
prediction: bool | float | str,
|
||||
reference: float | str,
|
||||
include_percentage: bool = True,
|
||||
tolerance: float = 1e-4,
|
||||
timeout: float = 10.0,
|
||||
@ -251,7 +250,7 @@ def math_equal(
|
||||
if len(pred_parts) == len(ref_parts) and all(
|
||||
[
|
||||
math_equal(pred_pt, ref_pt, include_percentage, tolerance)
|
||||
for pred_pt, ref_pt in zip(pred_parts, ref_parts)
|
||||
for pred_pt, ref_pt in zip(pred_parts, ref_parts, strict=True)
|
||||
]
|
||||
):
|
||||
return True
|
||||
@ -277,7 +276,7 @@ def math_equal(
|
||||
if len(pred_parts) == len(ref_parts) and all(
|
||||
[
|
||||
math_equal(pred_pt, ref_pt, include_percentage, tolerance)
|
||||
for pred_pt, ref_pt in zip(pred_parts, ref_parts)
|
||||
for pred_pt, ref_pt in zip(pred_parts, ref_parts, strict=True)
|
||||
]
|
||||
):
|
||||
return True
|
||||
@ -290,7 +289,7 @@ def math_equal(
|
||||
if len(pred_matrix) == len(ref_matrix_items) and all(
|
||||
[
|
||||
math_equal(pred, ref, include_percentage, tolerance)
|
||||
for ref, pred in zip(ref_matrix_items, pred_matrix)
|
||||
for ref, pred in zip(ref_matrix_items, pred_matrix, strict=True)
|
||||
]
|
||||
):
|
||||
return True
|
||||
@ -312,7 +311,7 @@ def math_equal(
|
||||
if len(pred_matrix) == len(ref_matrix_items) and all(
|
||||
[
|
||||
math_equal(pred, ref, include_percentage, tolerance)
|
||||
for ref, pred in zip(ref_matrix_items, pred_matrix)
|
||||
for ref, pred in zip(ref_matrix_items, pred_matrix, strict=True)
|
||||
]
|
||||
):
|
||||
return True
|
||||
|
@ -100,7 +100,7 @@ def compute_score_batch(data_sources, solution_strs, ground_truths, extra_infos)
|
||||
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
|
||||
futures = []
|
||||
for data_source, solution_str, ground_truth, extra_info in zip(
|
||||
data_sources, solution_strs, ground_truths, extra_infos
|
||||
data_sources, solution_strs, ground_truths, extra_infos, strict=True
|
||||
):
|
||||
future = executor.submit(compute_score, data_source, solution_str, ground_truth, extra_info)
|
||||
futures.append(future)
|
||||
|
@ -19,7 +19,7 @@ import logging
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
@ -88,7 +88,7 @@ def preprocess(
|
||||
assert conversations[0]["role"] == "user", "the first role must be user"
|
||||
|
||||
if slice_config is not None:
|
||||
assert isinstance(slice_config, Dict)
|
||||
assert isinstance(slice_config, dict)
|
||||
assert "patch_size" in slice_config
|
||||
assert "max_slice_nums" in slice_config
|
||||
assert "scale_resolution" in slice_config
|
||||
@ -404,12 +404,12 @@ class RLHFDataset(Dataset):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_files: Union[str, List[str]],
|
||||
data_files: str | list[str],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
config: DictConfig,
|
||||
processor: Optional[ProcessorMixin] = None,
|
||||
):
|
||||
if not isinstance(data_files, (List, ListConfig)):
|
||||
if not isinstance(data_files, list | ListConfig):
|
||||
data_files = [data_files]
|
||||
|
||||
self.data_files = copy.deepcopy(data_files)
|
||||
|
@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, Tuple
|
||||
from typing import Any
|
||||
|
||||
import datasets
|
||||
|
||||
@ -32,7 +32,7 @@ class CustomSandboxFusionTool(SandboxFusionTool):
|
||||
self.code_pattern = re.compile(r"```python(.*?)```", re.DOTALL)
|
||||
|
||||
@rollout_trace_op
|
||||
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:
|
||||
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:
|
||||
code = parameters["code"]
|
||||
matches = self.code_pattern.findall(code)
|
||||
if matches:
|
||||
@ -81,7 +81,7 @@ class CustomRLHFDataset(RLHFDataset):
|
||||
|
||||
print(f"dataset len: {len(self.dataframe)}")
|
||||
|
||||
def map_fn(self, row: Dict, *, data_source: str = None):
|
||||
def map_fn(self, row: dict, *, data_source: str = None):
|
||||
if data_source == "Maxwell-Jia/AIME_2024":
|
||||
problem, answer = row["Problem"], row["Answer"]
|
||||
elif data_source == "yentinglin/aime_2025":
|
||||
@ -97,7 +97,7 @@ class CustomRLHFDataset(RLHFDataset):
|
||||
}
|
||||
return data
|
||||
|
||||
def map_fn2(self, row: Dict):
|
||||
def map_fn2(self, row: dict):
|
||||
content = row["prompt"][0]["content"]
|
||||
row["prompt"][0]["content"] = content + answer_format
|
||||
row["agent_name"] = "tool_agent"
|
||||
|
@ -17,7 +17,7 @@ Convert JoeYing/ReTool-SFT to standard multi-turn tool calling messages.
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Dict, Tuple
|
||||
from typing import Any
|
||||
|
||||
import datasets
|
||||
from omegaconf import OmegaConf
|
||||
@ -25,7 +25,7 @@ from omegaconf import OmegaConf
|
||||
code_pattern = re.compile(r"```python(.*?)```", re.DOTALL)
|
||||
|
||||
|
||||
def extract_code_message(content: str) -> Tuple[Dict[str, Any], str]:
|
||||
def extract_code_message(content: str) -> tuple[dict[str, Any], str]:
|
||||
start, stop = "<code>", "</code>"
|
||||
i = content.find(start)
|
||||
if i == -1:
|
||||
@ -54,7 +54,7 @@ def extract_code_message(content: str) -> Tuple[Dict[str, Any], str]:
|
||||
return message, content[j + len(stop) :]
|
||||
|
||||
|
||||
def extract_answer_message(content: str) -> Tuple[Dict[str, Any], str]:
|
||||
def extract_answer_message(content: str) -> tuple[dict[str, Any], str]:
|
||||
start, stop = "<answer>", "</answer>"
|
||||
i = content.find(start)
|
||||
if i == -1:
|
||||
@ -70,7 +70,7 @@ def extract_answer_message(content: str) -> Tuple[Dict[str, Any], str]:
|
||||
return message, content[j + len(stop) :]
|
||||
|
||||
|
||||
def extract_interpreter_message(content: str) -> Tuple[Dict[str, Any], str]:
|
||||
def extract_interpreter_message(content: str) -> tuple[dict[str, Any], str]:
|
||||
start, stop = "<interpreter>", "</interpreter>"
|
||||
i = content.find(start)
|
||||
if i == -1:
|
||||
@ -86,7 +86,7 @@ def extract_interpreter_message(content: str) -> Tuple[Dict[str, Any], str]:
|
||||
return message, content[j + len(stop) :]
|
||||
|
||||
|
||||
def process(row: Dict, *, tools: str):
|
||||
def process(row: dict, *, tools: str):
|
||||
messages = []
|
||||
|
||||
# extract problem
|
||||
|
@ -21,7 +21,7 @@ from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from pprint import pprint
|
||||
from typing import Any, Dict, Optional, Type
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import ray
|
||||
@ -49,7 +49,7 @@ from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seql
|
||||
from verl.utils.torch_functional import masked_mean
|
||||
from verl.utils.tracking import ValidationGenerationsLogger
|
||||
|
||||
WorkerType = Type[Worker]
|
||||
WorkerType = type[Worker]
|
||||
|
||||
|
||||
class Role(Enum):
|
||||
@ -142,7 +142,7 @@ class ResourcePoolManager:
|
||||
)
|
||||
|
||||
|
||||
def _compute_response_info(batch: DataProto) -> Dict[str, Any]:
|
||||
def _compute_response_info(batch: DataProto) -> dict[str, Any]:
|
||||
"""Placeholder: Computes prompt and response lengths."""
|
||||
try:
|
||||
# Assuming 'prompts' and 'responses' keys exist after generation/union
|
||||
@ -189,7 +189,7 @@ def _compute_response_info(batch: DataProto) -> Dict[str, Any]:
|
||||
|
||||
|
||||
# --- Modified Metric Function ---
|
||||
def compute_dpo_data_metrics(batch: DataProto) -> Dict[str, Any]:
|
||||
def compute_dpo_data_metrics(batch: DataProto) -> dict[str, Any]:
|
||||
"""
|
||||
Computes and returns metrics relevant for the DPO-like process.
|
||||
Assumes 'batch' contains results after generation and preference marking,
|
||||
@ -354,7 +354,7 @@ def compute_onlineDPO_pref(data: DataProto):
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _timer(name: str, timing_raw: Dict[str, float]):
|
||||
def _timer(name: str, timing_raw: dict[str, float]):
|
||||
with Timer(name=name, logger=None) as timer:
|
||||
yield
|
||||
timing_raw[name] = timer.last
|
||||
@ -634,7 +634,7 @@ class RaySPINTrainer:
|
||||
import numpy as np
|
||||
|
||||
# Create tuples of (input, output, score) and sort by input text
|
||||
samples = list(zip(inputs, outputs, scores))
|
||||
samples = list(zip(inputs, outputs, scores, strict=True))
|
||||
samples.sort(key=lambda x: x[0]) # Sort by input text
|
||||
|
||||
# Use fixed random seed for deterministic shuffling
|
||||
@ -1415,7 +1415,7 @@ class RaySPINTrainer:
|
||||
postfix_metrics = {
|
||||
k: f"{v:.3f}" if isinstance(v, float) else v
|
||||
for k, v in metrics.items()
|
||||
if isinstance(v, (int, float))
|
||||
if isinstance(v, int | float)
|
||||
}
|
||||
progress_bar.set_postfix(postfix_metrics)
|
||||
|
||||
|
@ -103,7 +103,7 @@ def convert_checkpoint_from_transformers_to_megatron(hf_model, model, hf_config)
|
||||
has_share_expert = getattr(hf_config, "shared_expert_intermediate_size", None)
|
||||
with torch.no_grad():
|
||||
model.embedding.word_embeddings.weight.copy_(hf_model.model.embed_tokens.weight)
|
||||
for layer, hf_layer in zip(model.decoder.layers, hf_model.model.layers):
|
||||
for layer, hf_layer in zip(model.decoder.layers, hf_model.model.layers, strict=True):
|
||||
layer.self_attention.linear_qkv.layer_norm_weight.copy_(hf_layer.input_layernorm.weight)
|
||||
|
||||
q = hf_layer.self_attn.q_proj.weight.view(
|
||||
@ -178,7 +178,7 @@ def convert_checkpoint_from_transformers_to_megatron_qwen2_5_vl(hfmodel, mgmodel
|
||||
copied_numel = 0
|
||||
safe_copy(hfvision.rotary_pos_emb.inv_freq, mgvision.rotary_pos_emb.inv_freq)
|
||||
copied_numel += safe_copy(hfvision.patch_embed.proj.weight, mgvision.patch_embed.proj.weight)
|
||||
for hfblock, mgblock in zip(hfvision.blocks, mgvision.decoder.layers):
|
||||
for hfblock, mgblock in zip(hfvision.blocks, mgvision.decoder.layers, strict=True):
|
||||
# norm1 --> linear_qkv.norm
|
||||
copied_numel += safe_copy(hfblock.norm1.weight, mgblock.self_attention.linear_qkv.layer_norm_weight)
|
||||
# norm2 --> mlp.linear_fc1.norm
|
||||
@ -227,7 +227,7 @@ def convert_checkpoint_from_transformers_to_megatron_qwen2_5_vl(hfmodel, mgmodel
|
||||
mgllm = mgmodel.language_model
|
||||
copied_numel = 0
|
||||
copied_numel += safe_copy(hfllm.embed_tokens.weight, mgllm.embedding.word_embeddings.weight)
|
||||
for mglayer, hflayer in zip(mgllm.decoder.layers, hfllm.layers):
|
||||
for mglayer, hflayer in zip(mgllm.decoder.layers, hfllm.layers, strict=True):
|
||||
copied_numel += safe_copy(hflayer.input_layernorm.weight, mglayer.self_attention.linear_qkv.layer_norm_weight)
|
||||
|
||||
q_proj_weight = hflayer.self_attn.q_proj.weight.view(num_query_groups, -1, head_dim, hidden_size)
|
||||
@ -264,7 +264,7 @@ def convert_checkpoint_from_transformers_to_megatron_dpskv3(hf_model, model, hf_
|
||||
numel: int = 0
|
||||
numel += safe_copy(hf_model.model.embed_tokens.weight, model.embedding.word_embeddings.weight)
|
||||
print(f"{numel=}")
|
||||
for layer_idx, (layer, hf_layer) in enumerate(zip(model.decoder.layers, hf_model.model.layers)):
|
||||
for layer_idx, (layer, hf_layer) in enumerate(zip(model.decoder.layers, hf_model.model.layers, strict=True)):
|
||||
numel_cur: int = numel
|
||||
numel += safe_copy(hf_layer.input_layernorm.weight, layer.input_layernorm.weight)
|
||||
|
||||
|
@ -29,7 +29,7 @@ import argparse
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, PretrainedConfig
|
||||
|
||||
@ -52,7 +52,7 @@ def check_output_path(output_path: str):
|
||||
print(f"Output path '{output_path}' created.")
|
||||
|
||||
|
||||
def check_configs(original_config: Dict[str, Any], new_config: Dict[str, Any]) -> bool:
|
||||
def check_configs(original_config: dict[str, Any], new_config: dict[str, Any]) -> bool:
|
||||
"""
|
||||
Check if the original config and new config are compatible.
|
||||
This is a placeholder function; actual implementation may vary based on requirements.
|
||||
|
@ -11,7 +11,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Union
|
||||
|
||||
import ray
|
||||
from omegaconf import DictConfig
|
||||
@ -23,7 +22,7 @@ from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
|
||||
from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker
|
||||
|
||||
|
||||
def init_agent_loop_manager(config: DictConfig) -> Union[AgentLoopManager, RayWorkerGroup]:
|
||||
def init_agent_loop_manager(config: DictConfig) -> AgentLoopManager | RayWorkerGroup:
|
||||
# =========================== 1. Create hybrid ActorRollout workers ===========================
|
||||
actor_rollout_cls = (
|
||||
AsyncActorRolloutRefWorker if config.actor_rollout_ref.rollout.mode == "async" else ActorRolloutRefWorker
|
||||
|
@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Tuple
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@ -119,7 +119,7 @@ class WeatherTool(BaseTool):
|
||||
schema = get_json_schema(self.get_current_temperature)
|
||||
return OpenAIFunctionToolSchema(**schema)
|
||||
|
||||
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:
|
||||
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:
|
||||
try:
|
||||
result = self.get_current_temperature(**parameters)
|
||||
return json.dumps(result), 0, {}
|
||||
@ -150,7 +150,7 @@ class WeatherToolWithData(BaseTool):
|
||||
"unit": unit,
|
||||
}
|
||||
|
||||
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:
|
||||
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:
|
||||
try:
|
||||
result = self.get_temperature_date(**parameters)
|
||||
return json.dumps(result), 0, {}
|
||||
|
@ -96,7 +96,7 @@ def test_data_transfer():
|
||||
# takes around 40 seconds
|
||||
output_lst = ray.get(output_ref)
|
||||
|
||||
for input_data, output_data in zip(data_list, output_lst):
|
||||
for input_data, output_data in zip(data_list, output_lst, strict=True):
|
||||
for key in input_data.batch.keys():
|
||||
assert torch.all(torch.eq(input_data.batch[key] + 1, output_data.batch[key])), (
|
||||
input_data.batch[key],
|
||||
|
@ -80,7 +80,7 @@ def test_fused_workers():
|
||||
print(y)
|
||||
z = fused_wg.foo(0.1)
|
||||
print(z)
|
||||
for i, j in zip(y, z):
|
||||
for i, j in zip(y, z, strict=True):
|
||||
assert i == j
|
||||
|
||||
ray.shutdown()
|
||||
|
@ -21,7 +21,7 @@ This is heavily inspired from CanineTokenizer in transformers package.
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Sequence, Union
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
|
||||
|
||||
@ -86,7 +86,7 @@ class CharTokenizer(PreTrainedTokenizer):
|
||||
def get_vocab(self):
|
||||
return self._vocab_str_to_int
|
||||
|
||||
def _tokenize(self, text: str) -> List[str]:
|
||||
def _tokenize(self, text: str) -> list[str]:
|
||||
return list(text)
|
||||
|
||||
def _convert_token_to_id(self, token: str) -> int:
|
||||
@ -99,8 +99,8 @@ class CharTokenizer(PreTrainedTokenizer):
|
||||
return "".join(tokens)
|
||||
|
||||
def build_inputs_with_special_tokens(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
) -> List[int]:
|
||||
self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
|
||||
) -> list[int]:
|
||||
sep = [self.sep_token_id]
|
||||
cls = [self.cls_token_id]
|
||||
result = cls + token_ids_0 + sep
|
||||
@ -110,10 +110,10 @@ class CharTokenizer(PreTrainedTokenizer):
|
||||
|
||||
def get_special_tokens_mask(
|
||||
self,
|
||||
token_ids_0: List[int],
|
||||
token_ids_1: Optional[List[int]] = None,
|
||||
token_ids_0: list[int],
|
||||
token_ids_1: Optional[list[int]] = None,
|
||||
already_has_special_tokens: bool = False,
|
||||
) -> List[int]:
|
||||
) -> list[int]:
|
||||
if already_has_special_tokens:
|
||||
return super().get_special_tokens_mask(
|
||||
token_ids_0=token_ids_0,
|
||||
@ -126,7 +126,7 @@ class CharTokenizer(PreTrainedTokenizer):
|
||||
result += ([0] * len(token_ids_1)) + [1]
|
||||
return result
|
||||
|
||||
def get_config(self) -> Dict:
|
||||
def get_config(self) -> dict:
|
||||
return {
|
||||
"char_ords": [ord(ch) for ch in self.characters],
|
||||
"model_max_length": self.model_max_length,
|
||||
@ -134,21 +134,21 @@ class CharTokenizer(PreTrainedTokenizer):
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict):
|
||||
def from_config(cls, config: dict):
|
||||
cfg = {}
|
||||
cfg["characters"] = [chr(i) for i in config["char_ords"]]
|
||||
cfg["model_max_length"] = config["model_max_length"]
|
||||
cfg["chat_template"] = config["chat_template"]
|
||||
return cls(**cfg)
|
||||
|
||||
def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
|
||||
def save_pretrained(self, save_directory: str | os.PathLike, **kwargs):
|
||||
cfg_file = Path(save_directory) / "tokenizer_config.json"
|
||||
cfg = self.get_config()
|
||||
with open(cfg_file, "w") as f:
|
||||
json.dump(cfg, f, indent=4)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, save_directory: Union[str, os.PathLike], **kwargs):
|
||||
def from_pretrained(cls, save_directory: str | os.PathLike, **kwargs):
|
||||
cfg_file = Path(save_directory) / "tokenizer_config.json"
|
||||
with open(cfg_file) as f:
|
||||
cfg = json.load(f)
|
||||
|
@ -20,7 +20,6 @@ Checks that every public function and class has proper docstring documentation.
|
||||
import ast
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
class DocstringChecker(ast.NodeVisitor):
|
||||
@ -28,7 +27,7 @@ class DocstringChecker(ast.NodeVisitor):
|
||||
|
||||
def __init__(self, filename: str):
|
||||
self.filename = filename
|
||||
self.missing_docstrings: List[Tuple[str, str, int]] = []
|
||||
self.missing_docstrings: list[tuple[str, str, int]] = []
|
||||
self.current_class = None
|
||||
self.function_nesting_level = 0
|
||||
|
||||
@ -70,7 +69,7 @@ class DocstringChecker(ast.NodeVisitor):
|
||||
return ast.get_docstring(node) is not None
|
||||
|
||||
|
||||
def check_file_docstrings(filepath: str) -> List[Tuple[str, str, int]]:
|
||||
def check_file_docstrings(filepath: str) -> list[tuple[str, str, int]]:
|
||||
"""Check docstrings in a single file."""
|
||||
try:
|
||||
with open(filepath, encoding="utf-8") as f:
|
||||
|
@ -22,23 +22,22 @@ import ast
|
||||
import linecache
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import List, Set, Tuple
|
||||
|
||||
|
||||
def get_changed_files() -> List[Path]:
|
||||
def get_changed_files() -> list[Path]:
|
||||
result = subprocess.run(
|
||||
["git", "diff", "--name-only", "--diff-filter=AM", "origin/main...HEAD"], stdout=subprocess.PIPE, text=True
|
||||
)
|
||||
return [Path(f) for f in result.stdout.splitlines() if f.endswith(".py")]
|
||||
|
||||
|
||||
def get_changed_lines(file_path: Path) -> Set[int]:
|
||||
def get_changed_lines(file_path: Path) -> set[int]:
|
||||
result = subprocess.run(
|
||||
["git", "diff", "-U0", "origin/main...HEAD", "--", str(file_path)],
|
||||
stdout=subprocess.PIPE,
|
||||
text=True,
|
||||
)
|
||||
lines: Set[int] = set()
|
||||
lines: set[int] = set()
|
||||
for line in result.stdout.splitlines():
|
||||
if line.startswith("@@"):
|
||||
for part in line.split():
|
||||
@ -84,19 +83,19 @@ def has_type_annotations(node: ast.AST, debug: bool = False) -> int:
|
||||
|
||||
|
||||
def check_file(
|
||||
file_path: Path, changed_lines: Set[int], debug: bool = False
|
||||
) -> Tuple[int, int, List[Tuple[Path, int, str]], List[Tuple[Path, int, str]]]:
|
||||
file_path: Path, changed_lines: set[int], debug: bool = False
|
||||
) -> tuple[int, int, list[tuple[Path, int, str]], list[tuple[Path, int, str]]]:
|
||||
with open(file_path) as f:
|
||||
source: str = f.read()
|
||||
tree = ast.parse(source, filename=str(file_path))
|
||||
annotated = 0
|
||||
total = 0
|
||||
warning_lines: List[Tuple[Path, int, str]] = []
|
||||
failure_lines: List[Tuple[Path, int, str]] = []
|
||||
warning_lines: list[tuple[Path, int, str]] = []
|
||||
failure_lines: list[tuple[Path, int, str]] = []
|
||||
|
||||
for node in ast.walk(tree):
|
||||
if hasattr(node, "lineno") and node.lineno in changed_lines:
|
||||
if isinstance(node, (ast.FunctionDef, ast.Assign, ast.AnnAssign)):
|
||||
if isinstance(node, ast.FunctionDef | ast.Assign | ast.AnnAssign):
|
||||
total += 1
|
||||
result = has_type_annotations(node, debug)
|
||||
if result == CHECK_SUCCESS or result == CHECK_WARNING:
|
||||
@ -128,8 +127,8 @@ def main() -> None:
|
||||
|
||||
total_changed = 0
|
||||
total_annotated = 0
|
||||
all_warnings: List[Tuple[Path, int, str]] = []
|
||||
all_failures: List[Tuple[Path, int, str]] = []
|
||||
all_warnings: list[tuple[Path, int, str]] = []
|
||||
all_failures: list[tuple[Path, int, str]] = []
|
||||
|
||||
target_files = [args.target_file] if args.target_file is not None else get_changed_files()
|
||||
for fpath in target_files:
|
||||
|
@ -57,7 +57,7 @@ def test_memory_buffers():
|
||||
change_ratio = (a - a_before) / a_before
|
||||
assert change_ratio < 0.01, f"make sure the allocated change is less than 1%, Got {change_ratio}"
|
||||
|
||||
for (name1, param1), (name2, param2) in zip(model.named_parameters(), model_copy.named_parameters()):
|
||||
for (name1, param1), (name2, param2) in zip(model.named_parameters(), model_copy.named_parameters(), strict=True):
|
||||
assert name1 == name2
|
||||
assert torch.eq(param1.data, param2.data).all(), f"{param1.data}, {param2.data}, {name1}"
|
||||
|
||||
|
@ -79,7 +79,7 @@ def test_tensor_dict_make_iterator():
|
||||
for data in data_iter_2:
|
||||
data_list_2.append(data)
|
||||
|
||||
for data1, data2 in zip(data_list_1, data_list_2):
|
||||
for data1, data2 in zip(data_list_1, data_list_2, strict=True):
|
||||
assert isinstance(data1, DataProto)
|
||||
assert isinstance(data2, DataProto)
|
||||
result = torch.all(torch.eq(data1.batch["obs"], data2.batch["obs"]))
|
||||
|
@ -14,7 +14,7 @@
|
||||
# Unit Tests for `initialize_tools_from_config`
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Tuple
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from transformers.utils import get_json_schema
|
||||
@ -44,7 +44,7 @@ class WeatherToolForTest(BaseTool):
|
||||
schema = get_json_schema(self.get_current_temperature)
|
||||
return OpenAIFunctionToolSchema(**schema)
|
||||
|
||||
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:
|
||||
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:
|
||||
try:
|
||||
result = self.get_current_temperature(**parameters)
|
||||
return json.dumps(result), 0, {}
|
||||
@ -75,7 +75,7 @@ class WeatherToolWithDataForTest(BaseTool):
|
||||
"unit": unit,
|
||||
}
|
||||
|
||||
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:
|
||||
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:
|
||||
try:
|
||||
result = self.get_temperature_date(**parameters)
|
||||
return json.dumps(result), 0, {}
|
||||
|
@ -57,7 +57,7 @@ class TestConfigComparison(unittest.TestCase):
|
||||
len(legacy_config),
|
||||
f"List lengths differ at {path}: current={len(current_config)}, legacy={len(legacy_config)}",
|
||||
)
|
||||
for i, (current_item, legacy_item) in enumerate(zip(current_config, legacy_config)):
|
||||
for i, (current_item, legacy_item) in enumerate(zip(current_config, legacy_config, strict=True)):
|
||||
self._compare_configs_recursively(current_item, legacy_item, f"{path}[{i}]")
|
||||
else:
|
||||
self.assertEqual(
|
||||
|
@ -120,7 +120,7 @@ def test_prime_code():
|
||||
Test PRIME code sandbox.
|
||||
"""
|
||||
data_source = "codecontests"
|
||||
for completion, ground_truth, score_ in zip(prime_code_answers, prime_code_gts, prime_code_scores):
|
||||
for completion, ground_truth, score_ in zip(prime_code_answers, prime_code_gts, prime_code_scores, strict=True):
|
||||
score = default_compute_score(data_source, completion, ground_truth)
|
||||
assert float(score) == score_
|
||||
|
||||
@ -136,7 +136,7 @@ def test_prime_code_sandbox_fusion():
|
||||
sandbox_fusion_url = os.environ.get("SANDBOX_FUSION_URL")
|
||||
# Removed the previous 'if not sandbox_url' check block
|
||||
|
||||
for completion, ground_truth, score_ in zip(prime_code_answers, prime_code_gts, prime_code_scores):
|
||||
for completion, ground_truth, score_ in zip(prime_code_answers, prime_code_gts, prime_code_scores, strict=True):
|
||||
score = default_compute_score(
|
||||
data_source, completion, ground_truth, extra_info={"sandbox_fusion_url": sandbox_fusion_url}
|
||||
) # <-- Use the URL obtained from the environment variable
|
||||
@ -180,6 +180,6 @@ def test_check_correctness():
|
||||
|
||||
def test_prime_math():
|
||||
data_source = "numina_aops_forum"
|
||||
for completion, ground_truth in zip(prime_math_answers, prime_math_gts):
|
||||
for completion, ground_truth in zip(prime_math_answers, prime_math_gts, strict=True):
|
||||
score = default_compute_score(data_source, completion, ground_truth)
|
||||
assert float(score) == 1.0
|
||||
|
@ -146,7 +146,9 @@ def test_flops_counter(config_type: str):
|
||||
test_config = CONFIG[config_type]
|
||||
config = Config(test_config["config"])
|
||||
flops_counter = FlopsCounter(config)
|
||||
for batch_seqlens, expected_flops in zip(test_config["batch_seqlens_tuple"], test_config["expected_flops_tuple"]):
|
||||
for batch_seqlens, expected_flops in zip(
|
||||
test_config["batch_seqlens_tuple"], test_config["expected_flops_tuple"], strict=True
|
||||
):
|
||||
# set delta time to 1 to get the flops
|
||||
counted_flops, _ = flops_counter.estimate_flops(batch_seqlens, 1)
|
||||
print(f"Expect flops for {test_config['config']} is {expected_flops}, but get {counted_flops}")
|
||||
|
@ -30,7 +30,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import typing
|
||||
|
||||
import torch
|
||||
|
||||
@ -48,7 +47,7 @@ MAX_TEST_CASES = os.environ.get("MAX_TEST_CASES", 5)
|
||||
|
||||
def run_torch_entropy(
|
||||
hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, temperature: float, reduction="none"
|
||||
) -> typing.List[torch.Tensor]:
|
||||
) -> list[torch.Tensor]:
|
||||
hidden = hidden.squeeze(0).to(torch.float32)
|
||||
weight = weight.transpose(0, 1).to(torch.float32)
|
||||
logits = torch.matmul(hidden, weight) # [num_tokens, vocab_size]
|
||||
@ -67,7 +66,7 @@ def run_verl_original_entropy(
|
||||
weight: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
temperature: float,
|
||||
) -> typing.List[torch.Tensor]:
|
||||
) -> list[torch.Tensor]:
|
||||
hidden = hidden.squeeze(0).to(torch.float32)
|
||||
weight = weight.transpose(0, 1).to(torch.float32)
|
||||
logits = torch.matmul(hidden, weight) # [num_tokens, vocab_size]
|
||||
|
@ -30,7 +30,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import typing
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -57,7 +56,7 @@ LOW_MEMORY_DIV_FACTOR = os.environ.get("LOW_MEMORY_DIV_FACTOR", 16)
|
||||
|
||||
def run_torch_entropy(
|
||||
hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, temperature: float, reduction="none"
|
||||
) -> typing.List[torch.Tensor]:
|
||||
) -> list[torch.Tensor]:
|
||||
# [num_tokens, vocab_size]
|
||||
if len(hidden.shape) > 2:
|
||||
hidden = hidden.view(-1, hidden.shape[-1]) # [num_tokens, hidden_size]
|
||||
|
@ -106,9 +106,12 @@ class TestNsightSystemsProfiler(unittest.TestCase):
|
||||
def test_func(self, *args, **kwargs):
|
||||
return "result"
|
||||
|
||||
with patch("torch.cuda.profiler.start") as mock_start, patch("torch.cuda.profiler.stop") as mock_stop, patch(
|
||||
"verl.utils.profiler.nvtx_profile.mark_start_range"
|
||||
) as mock_start_range, patch("verl.utils.profiler.nvtx_profile.mark_end_range") as mock_end_range:
|
||||
with (
|
||||
patch("torch.cuda.profiler.start") as mock_start,
|
||||
patch("torch.cuda.profiler.stop") as mock_stop,
|
||||
patch("verl.utils.profiler.nvtx_profile.mark_start_range") as mock_start_range,
|
||||
patch("verl.utils.profiler.nvtx_profile.mark_end_range") as mock_end_range,
|
||||
):
|
||||
result = test_func(mock_self)
|
||||
self.assertEqual(result, "result")
|
||||
mock_start_range.assert_called_once()
|
||||
@ -127,9 +130,12 @@ class TestNsightSystemsProfiler(unittest.TestCase):
|
||||
def test_func(self, *args, **kwargs):
|
||||
return "result"
|
||||
|
||||
with patch("torch.cuda.profiler.start") as mock_start, patch("torch.cuda.profiler.stop") as mock_stop, patch(
|
||||
"verl.utils.profiler.nvtx_profile.mark_start_range"
|
||||
) as mock_start_range, patch("verl.utils.profiler.nvtx_profile.mark_end_range") as mock_end_range:
|
||||
with (
|
||||
patch("torch.cuda.profiler.start") as mock_start,
|
||||
patch("torch.cuda.profiler.stop") as mock_stop,
|
||||
patch("verl.utils.profiler.nvtx_profile.mark_start_range") as mock_start_range,
|
||||
patch("verl.utils.profiler.nvtx_profile.mark_end_range") as mock_end_range,
|
||||
):
|
||||
result = test_func(mock_self)
|
||||
self.assertEqual(result, "result")
|
||||
mock_start_range.assert_called_once()
|
||||
|
@ -32,7 +32,6 @@ packages:
|
||||
|
||||
import os
|
||||
import time
|
||||
from typing import Tuple, Union
|
||||
|
||||
import ray
|
||||
from omegaconf import DictConfig
|
||||
@ -72,7 +71,7 @@ def init_config(n_gpus_per_node) -> DictConfig:
|
||||
return config
|
||||
|
||||
|
||||
def initialize(config, backend) -> Tuple[Union[AgentLoopManager, RayWorkerGroup], StatefulDataLoader]:
|
||||
def initialize(config, backend) -> tuple[AgentLoopManager | RayWorkerGroup, StatefulDataLoader]:
|
||||
env_vars = {
|
||||
"NCCL_DEBUG": "WARN",
|
||||
"VLLM_USE_V1": "1",
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
from typing import Any, Tuple
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@ -120,7 +120,7 @@ class WeatherTool(BaseTool):
|
||||
schema = get_json_schema(self.get_current_temperature)
|
||||
return OpenAIFunctionToolSchema(**schema)
|
||||
|
||||
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:
|
||||
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:
|
||||
try:
|
||||
result = self.get_current_temperature(**parameters)
|
||||
return json.dumps(result), 0, {}
|
||||
@ -151,7 +151,7 @@ class WeatherToolWithData(BaseTool):
|
||||
"unit": unit,
|
||||
}
|
||||
|
||||
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:
|
||||
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:
|
||||
try:
|
||||
result = self.get_temperature_date(**parameters)
|
||||
return json.dumps(result), 0, {}
|
||||
|
@ -55,7 +55,7 @@ def are_lists_similar(a, b):
|
||||
total_length = 0
|
||||
total_diff = 0
|
||||
|
||||
for s1, s2 in zip(a, b):
|
||||
for s1, s2 in zip(a, b, strict=True):
|
||||
max_len = max(len(s1), len(s2))
|
||||
total_length += max_len
|
||||
diff = levenshtein(s1, s2)
|
||||
|
@ -19,7 +19,7 @@ import socket
|
||||
import sys
|
||||
import tempfile
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
import fastapi
|
||||
import numpy as np
|
||||
@ -128,7 +128,7 @@ class CustomCompletionCallback(ToolCompletionCallback):
|
||||
# TODO: support asyncio executor
|
||||
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max(32, os.cpu_count() * 5))
|
||||
|
||||
async def sandbox_code_execution(self, code: str) -> Dict[str, Any]:
|
||||
async def sandbox_code_execution(self, code: str) -> dict[str, Any]:
|
||||
loop = asyncio.get_running_loop()
|
||||
result_status, metadata = await loop.run_in_executor(
|
||||
self.executor,
|
||||
@ -153,7 +153,7 @@ class CustomCompletionCallback(ToolCompletionCallback):
|
||||
}
|
||||
return extra
|
||||
|
||||
async def __call__(self, messages: List[Dict[str, str]], completions: ChatCompletion, info: Dict[str, Any]):
|
||||
async def __call__(self, messages: list[dict[str, str]], completions: ChatCompletion, info: dict[str, Any]):
|
||||
role, content, finish_reason = (
|
||||
completions.choices[0].message.role,
|
||||
completions.choices[0].message.content,
|
||||
|
@ -262,10 +262,11 @@ class TestRolloutWithMCPSearchTools:
|
||||
},
|
||||
}
|
||||
]
|
||||
with patch.object(MCPClientManager, "fetch_tool_schemas", return_value=tool_schema), patch.object(
|
||||
SGLangRollout, "_init_distributed_env", return_value=None
|
||||
), patch.object(SGLangRollout, "_init_inference_engine", return_value=None), patch.object(
|
||||
SGLangRollout, "_init_sampling_params", return_value=None
|
||||
with (
|
||||
patch.object(MCPClientManager, "fetch_tool_schemas", return_value=tool_schema),
|
||||
patch.object(SGLangRollout, "_init_distributed_env", return_value=None),
|
||||
patch.object(SGLangRollout, "_init_inference_engine", return_value=None),
|
||||
patch.object(SGLangRollout, "_init_sampling_params", return_value=None),
|
||||
):
|
||||
rollout = SGLangRollout(
|
||||
actor_module="",
|
||||
@ -355,7 +356,7 @@ class TestRolloutWithMCPSearchTools:
|
||||
|
||||
mock_rollout._handle_engine_call = MagicMock()
|
||||
futures = [asyncio.Future() for i in expect_turn_array]
|
||||
for idx, (i, turn) in enumerate(zip(futures, expect_turn_array)):
|
||||
for idx, (i, turn) in enumerate(zip(futures, expect_turn_array, strict=True)):
|
||||
i.set_result(
|
||||
{
|
||||
"text": turn,
|
||||
@ -420,7 +421,7 @@ class TestRolloutWithMCPSearchTools:
|
||||
req_list.append(MagicMock(wraps=tmp_req, spec=AsyncRolloutRequest))
|
||||
|
||||
futures = [asyncio.Future() for _ in expect_turn_array]
|
||||
for idx, (fut, turn) in enumerate(zip(futures, expect_turn_array)):
|
||||
for idx, (fut, turn) in enumerate(zip(futures, expect_turn_array, strict=True)):
|
||||
fut.set_result(
|
||||
{
|
||||
"text": turn,
|
||||
|
@ -166,9 +166,11 @@ class TestRolloutWithSearchTools:
|
||||
@pytest.fixture
|
||||
def mock_rollout(self, search_rollout_config, qwen_tokenizer, qwen_model_config):
|
||||
"""Mock the rollout instance with sampling_params initialized."""
|
||||
with patch.object(SGLangRollout, "_init_distributed_env", return_value=None), patch.object(
|
||||
SGLangRollout, "_init_inference_engine", return_value=None
|
||||
), patch.object(SGLangRollout, "_init_sampling_params", return_value=None):
|
||||
with (
|
||||
patch.object(SGLangRollout, "_init_distributed_env", return_value=None),
|
||||
patch.object(SGLangRollout, "_init_inference_engine", return_value=None),
|
||||
patch.object(SGLangRollout, "_init_sampling_params", return_value=None),
|
||||
):
|
||||
rollout = SGLangRollout(
|
||||
actor_module="",
|
||||
config=search_rollout_config,
|
||||
@ -308,7 +310,7 @@ class TestRolloutWithSearchTools:
|
||||
|
||||
mock_rollout._handle_engine_call = MagicMock()
|
||||
futures = [asyncio.Future() for i in expect_turn_array]
|
||||
for idx, (i, turn) in enumerate(zip(futures, expect_turn_array)):
|
||||
for idx, (i, turn) in enumerate(zip(futures, expect_turn_array, strict=True)):
|
||||
i.set_result(
|
||||
{
|
||||
"text": turn,
|
||||
@ -376,7 +378,7 @@ class TestRolloutWithSearchTools:
|
||||
req_list.append(MagicMock(wraps=tmp_req, spec=AsyncRolloutRequest))
|
||||
|
||||
futures = [asyncio.Future() for _ in expect_turn_array]
|
||||
for idx, (fut, turn) in enumerate(zip(futures, expect_turn_array)):
|
||||
for idx, (fut, turn) in enumerate(zip(futures, expect_turn_array, strict=True)):
|
||||
fut.set_result(
|
||||
{
|
||||
"text": turn,
|
||||
|
@ -117,9 +117,11 @@ class TestSGLangMultiInteraction:
|
||||
|
||||
try:
|
||||
# Mock SGLang engine and initialization methods like the reference test
|
||||
with patch.object(SGLangRollout, "_init_distributed_env", return_value=None), patch.object(
|
||||
SGLangRollout, "_init_inference_engine", return_value=None
|
||||
), patch.object(SGLangRollout, "_init_sampling_params", return_value=None):
|
||||
with (
|
||||
patch.object(SGLangRollout, "_init_distributed_env", return_value=None),
|
||||
patch.object(SGLangRollout, "_init_inference_engine", return_value=None),
|
||||
patch.object(SGLangRollout, "_init_sampling_params", return_value=None),
|
||||
):
|
||||
# Create a real tokenizer like the reference test
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", padding_side="left")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
@ -172,9 +174,11 @@ class TestSGLangMultiInteraction:
|
||||
config, temp_config_path = create_mock_config_with_multi_interactions()
|
||||
|
||||
try:
|
||||
with patch.object(SGLangRollout, "_init_distributed_env", return_value=None), patch.object(
|
||||
SGLangRollout, "_init_inference_engine", return_value=None
|
||||
), patch.object(SGLangRollout, "_init_sampling_params", return_value=None):
|
||||
with (
|
||||
patch.object(SGLangRollout, "_init_distributed_env", return_value=None),
|
||||
patch.object(SGLangRollout, "_init_inference_engine", return_value=None),
|
||||
patch.object(SGLangRollout, "_init_sampling_params", return_value=None),
|
||||
):
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", padding_side="left")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
@ -282,9 +286,11 @@ class TestSGLangMultiInteraction:
|
||||
)
|
||||
|
||||
try:
|
||||
with patch.object(SGLangRollout, "_init_distributed_env", return_value=None), patch.object(
|
||||
SGLangRollout, "_init_inference_engine", return_value=None
|
||||
), patch.object(SGLangRollout, "_init_sampling_params", return_value=None):
|
||||
with (
|
||||
patch.object(SGLangRollout, "_init_distributed_env", return_value=None),
|
||||
patch.object(SGLangRollout, "_init_inference_engine", return_value=None),
|
||||
patch.object(SGLangRollout, "_init_sampling_params", return_value=None),
|
||||
):
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", padding_side="left")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
@ -321,9 +327,11 @@ class TestSGLangMultiInteraction:
|
||||
config, temp_config_path = create_mock_config_with_multi_interactions()
|
||||
|
||||
try:
|
||||
with patch.object(SGLangRollout, "_init_distributed_env", return_value=None), patch.object(
|
||||
SGLangRollout, "_init_inference_engine", return_value=None
|
||||
), patch.object(SGLangRollout, "_init_sampling_params", return_value=None):
|
||||
with (
|
||||
patch.object(SGLangRollout, "_init_distributed_env", return_value=None),
|
||||
patch.object(SGLangRollout, "_init_inference_engine", return_value=None),
|
||||
patch.object(SGLangRollout, "_init_sampling_params", return_value=None),
|
||||
):
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", padding_side="left")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
@ -388,9 +396,11 @@ class TestSGLangMultiInteraction:
|
||||
}
|
||||
)
|
||||
|
||||
with patch.object(SGLangRollout, "_init_distributed_env", return_value=None), patch.object(
|
||||
SGLangRollout, "_init_inference_engine", return_value=None
|
||||
), patch.object(SGLangRollout, "_init_sampling_params", return_value=None):
|
||||
with (
|
||||
patch.object(SGLangRollout, "_init_distributed_env", return_value=None),
|
||||
patch.object(SGLangRollout, "_init_inference_engine", return_value=None),
|
||||
patch.object(SGLangRollout, "_init_sampling_params", return_value=None),
|
||||
):
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", padding_side="left")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
|
@ -43,7 +43,7 @@ def are_lists_similar(a, b, threshold=10):
|
||||
return False
|
||||
total_length = 0
|
||||
total_diff = 0
|
||||
for s1, s2 in zip(a, b):
|
||||
for s1, s2 in zip(a, b, strict=True):
|
||||
max_len = max(len(s1), len(s2))
|
||||
total_length += max_len
|
||||
total_diff += levenshtein(s1, s2)
|
||||
|
@ -17,7 +17,7 @@ import logging
|
||||
import os
|
||||
import random
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Type
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import ray
|
||||
@ -46,7 +46,7 @@ class AsyncLLMServerManager:
|
||||
- Sticky session: send multi-turn chat completions to same server for automatic prefix caching
|
||||
"""
|
||||
|
||||
def __init__(self, config: DictConfig, server_handles: List[ray.actor.ActorHandle], max_cache_size: int = 10000):
|
||||
def __init__(self, config: DictConfig, server_handles: list[ray.actor.ActorHandle], max_cache_size: int = 10000):
|
||||
"""Initialize the AsyncLLMServerManager.
|
||||
|
||||
Args:
|
||||
@ -81,9 +81,9 @@ class AsyncLLMServerManager:
|
||||
self,
|
||||
request_id,
|
||||
*,
|
||||
prompt_ids: List[int],
|
||||
sampling_params: Dict[str, Any],
|
||||
) -> List[int]:
|
||||
prompt_ids: list[int],
|
||||
sampling_params: dict[str, Any],
|
||||
) -> list[int]:
|
||||
"""Generate tokens from prompt ids.
|
||||
|
||||
Args:
|
||||
@ -113,9 +113,9 @@ class AgentLoopMetrics(BaseModel):
|
||||
class AgentLoopOutput(BaseModel):
|
||||
"""Agent loop output."""
|
||||
|
||||
prompt_ids: List[int]
|
||||
response_ids: List[int]
|
||||
response_mask: List[int]
|
||||
prompt_ids: list[int]
|
||||
response_ids: list[int]
|
||||
response_mask: list[int]
|
||||
num_turns: int = 0
|
||||
metrics: AgentLoopMetrics
|
||||
|
||||
@ -148,7 +148,7 @@ class AgentLoopBase(ABC):
|
||||
cls._class_initialized = True
|
||||
|
||||
@abstractmethod
|
||||
async def run(self, messages: List[Dict[str, Any]], sampling_params: Dict[str, Any]) -> AgentLoopOutput:
|
||||
async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput:
|
||||
"""Run agent loop to interact with LLM server and environment.
|
||||
|
||||
Args:
|
||||
@ -165,7 +165,7 @@ class AgentLoopBase(ABC):
|
||||
class AgentLoopWorker:
|
||||
"""Agent loop worker takes a batch of messages and run each message in an agent loop."""
|
||||
|
||||
def __init__(self, config: DictConfig, server_handles: List[ray.actor.ActorHandle]):
|
||||
def __init__(self, config: DictConfig, server_handles: list[ray.actor.ActorHandle]):
|
||||
"""Initialize agent loop manager.
|
||||
|
||||
Args:
|
||||
@ -236,7 +236,7 @@ class AgentLoopWorker:
|
||||
|
||||
trajectory_info = await get_trajectory_info(batch.meta_info.get("global_steps", -1), index)
|
||||
|
||||
for agent_name, messages, trajectory in zip(agent_names, raw_prompts, trajectory_info):
|
||||
for agent_name, messages, trajectory in zip(agent_names, raw_prompts, trajectory_info, strict=True):
|
||||
tasks.append(
|
||||
asyncio.create_task(self._run_agent_loop(agent_name, messages.tolist(), sampling_params, trajectory))
|
||||
)
|
||||
@ -248,9 +248,9 @@ class AgentLoopWorker:
|
||||
async def _run_agent_loop(
|
||||
self,
|
||||
agent_name: str,
|
||||
messages: List[Dict[str, Any]],
|
||||
sampling_params: Dict[str, Any],
|
||||
trajectory: Dict[str, Any],
|
||||
messages: list[dict[str, Any]],
|
||||
sampling_params: dict[str, Any],
|
||||
trajectory: dict[str, Any],
|
||||
) -> AgentLoopOutput:
|
||||
with rollout_trace_attr(
|
||||
step=trajectory["step"], sample_index=trajectory["sample_index"], rollout_n=trajectory["rollout_n"]
|
||||
@ -260,7 +260,7 @@ class AgentLoopWorker:
|
||||
output = await agent_loop.run(messages, sampling_params)
|
||||
return output
|
||||
|
||||
def get_agent_loop_class(self, agent_name: str) -> Type[AgentLoopBase]:
|
||||
def get_agent_loop_class(self, agent_name: str) -> type[AgentLoopBase]:
|
||||
"""Get the appropriate agent loop class based on agent name.
|
||||
|
||||
Factory method that returns the correct agent loop class implementation
|
||||
@ -285,7 +285,7 @@ class AgentLoopWorker:
|
||||
return ToolAgentLoop
|
||||
raise ValueError(f"Unknown agent_name: {agent_name}")
|
||||
|
||||
def _postprocess(self, inputs: List[AgentLoopOutput]) -> DataProto:
|
||||
def _postprocess(self, inputs: list[AgentLoopOutput]) -> DataProto:
|
||||
# NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py
|
||||
# prompts: left pad
|
||||
# responses: right pad
|
||||
@ -452,7 +452,10 @@ class AgentLoopManager:
|
||||
self.wake_up()
|
||||
chunkes = prompts.chunk(len(self.agent_loop_workers))
|
||||
outputs = ray.get(
|
||||
[worker.generate_sequences.remote(chunk) for worker, chunk in zip(self.agent_loop_workers, chunkes)]
|
||||
[
|
||||
worker.generate_sequences.remote(chunk)
|
||||
for worker, chunk in zip(self.agent_loop_workers, chunkes, strict=True)
|
||||
]
|
||||
)
|
||||
output = DataProto.concat(outputs)
|
||||
if self.config.actor_rollout_ref.rollout.free_cache_engine:
|
||||
@ -465,7 +468,7 @@ class AgentLoopManager:
|
||||
output.meta_info = {"timing": timing}
|
||||
return output
|
||||
|
||||
def _performance_metrics(self, metrics: List[List[Dict[str, str]]], output: DataProto) -> Dict[str, float]:
|
||||
def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: DataProto) -> dict[str, float]:
|
||||
timing = {}
|
||||
t_generate_sequences = np.array([metric["generate_sequences"] for chunk in metrics for metric in chunk])
|
||||
t_tool_calls = np.array([metric["tool_calls"] for chunk in metrics for metric in chunk])
|
||||
|
@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput
|
||||
@ -31,7 +31,7 @@ class SingleTurnAgentLoop(AgentLoopBase):
|
||||
self.prompt_length = config.actor_rollout_ref.rollout.prompt_length
|
||||
self.response_length = config.actor_rollout_ref.rollout.response_length
|
||||
|
||||
async def run(self, messages: List[Dict[str, Any]], sampling_params: Dict[str, Any]) -> AgentLoopOutput:
|
||||
async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput:
|
||||
metrics = {}
|
||||
request_id = uuid4().hex
|
||||
prompt_ids = await self.loop.run_in_executor(
|
||||
|
@ -16,7 +16,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import regex as re
|
||||
@ -46,7 +46,7 @@ class FunctionCall(BaseModel):
|
||||
|
||||
class ToolParser(ABC):
|
||||
@abstractmethod
|
||||
async def extract_tool_calls(self, responses_ids: List[int]) -> List[FunctionCall]:
|
||||
async def extract_tool_calls(self, responses_ids: list[int]) -> list[FunctionCall]:
|
||||
"""Extract tool calls from the responses.
|
||||
|
||||
Args:
|
||||
@ -69,7 +69,7 @@ class HermesToolParser(ToolParser):
|
||||
self.tool_call_regex = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)
|
||||
|
||||
@rollout_trace_op
|
||||
async def extract_tool_calls(self, responses_ids: List[int]) -> List[FunctionCall]:
|
||||
async def extract_tool_calls(self, responses_ids: list[int]) -> list[FunctionCall]:
|
||||
loop = asyncio.get_running_loop()
|
||||
text = await loop.run_in_executor(None, self.tokenizer.decode, responses_ids)
|
||||
if self.tool_call_start_token not in text or self.tool_call_end_token not in text:
|
||||
@ -117,7 +117,7 @@ class ToolAgentLoop(AgentLoopBase):
|
||||
cls.system_prompt = tokenizer.apply_chat_template([{}], add_generation_prompt=False, tokenize=True)
|
||||
|
||||
@rollout_trace_op
|
||||
async def run(self, messages: List[Dict[str, Any]], sampling_params: Dict[str, Any]) -> AgentLoopOutput:
|
||||
async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput:
|
||||
metrics = {}
|
||||
request_id = uuid4().hex
|
||||
prompt_ids = await self.loop.run_in_executor(
|
||||
@ -194,7 +194,7 @@ class ToolAgentLoop(AgentLoopBase):
|
||||
)
|
||||
return output
|
||||
|
||||
async def _call_tool(self, tool_call: FunctionCall) -> Dict[str, str]:
|
||||
async def _call_tool(self, tool_call: FunctionCall) -> dict[str, str]:
|
||||
"""Call tool and return tool response."""
|
||||
tool, instance_id = None, None
|
||||
try:
|
||||
|
@ -21,7 +21,7 @@ on rollout data.
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
import datasets
|
||||
from omegaconf import DictConfig
|
||||
@ -73,7 +73,7 @@ class DynamicGenDataset(RLHFDataset):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_files: Union[str, List[str]],
|
||||
data_files: str | list[str],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
config: DictConfig,
|
||||
processor: Optional[ProcessorMixin] = None,
|
||||
|
@ -13,12 +13,12 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
|
||||
class BaseInteraction:
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
def __init__(self, config: dict[str, Any]):
|
||||
self.config = config
|
||||
self.name: str = config.get("name", "interaction_agent") # More general agent default role name
|
||||
|
||||
@ -37,8 +37,8 @@ class BaseInteraction:
|
||||
return instance_id
|
||||
|
||||
async def generate_response(
|
||||
self, instance_id: str, messages: List[Dict[str, Any]], **kwargs
|
||||
) -> Tuple[bool, str, float, Dict[str, Any]]: # More clear response generation method
|
||||
self, instance_id: str, messages: list[dict[str, Any]], **kwargs
|
||||
) -> tuple[bool, str, float, dict[str, Any]]: # More clear response generation method
|
||||
"""
|
||||
Generates a response for the current turn of interaction.
|
||||
Returns a tuple containing:
|
||||
@ -50,7 +50,7 @@ class BaseInteraction:
|
||||
should_terminate_sequence: bool = False # if True, end rollout
|
||||
response_content: str = "Your current result seems acceptable."
|
||||
current_turn_score: float = 0.8
|
||||
additional_data: Dict[str, Any] = {}
|
||||
additional_data: dict[str, Any] = {}
|
||||
return should_terminate_sequence, response_content, current_turn_score, additional_data
|
||||
|
||||
async def calculate_score(self) -> float: # More clear score calculation method
|
||||
|
@ -16,7 +16,7 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from verl.utils.reward_score import gsm8k
|
||||
@ -53,8 +53,8 @@ class Gsm8kInteraction(BaseInteraction):
|
||||
return instance_id
|
||||
|
||||
async def generate_response(
|
||||
self, instance_id: str, messages: List[Dict[str, Any]], **kwargs
|
||||
) -> Tuple[bool, str, float, dict]:
|
||||
self, instance_id: str, messages: list[dict[str, Any]], **kwargs
|
||||
) -> tuple[bool, str, float, dict]:
|
||||
content = ""
|
||||
for i in range(len(messages) - 1, -1, -1):
|
||||
item = messages[i]
|
||||
|
@ -16,7 +16,7 @@ import os
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, ContextManager, Dict, List
|
||||
from typing import Any, Callable, ContextManager
|
||||
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
@ -140,7 +140,7 @@ class MegatronModelMerger(BaseModelMerger):
|
||||
"output_layer": "lm_head",
|
||||
}
|
||||
|
||||
def _load_state_dicts(self, model_ckpt_path: str) -> Dict[str, Any]:
|
||||
def _load_state_dicts(self, model_ckpt_path: str) -> dict[str, Any]:
|
||||
"""_summary_
|
||||
Use Megatron dist_checkpointing to load the model state dicts from the checkpoint directory.
|
||||
|
||||
@ -270,7 +270,7 @@ class MegatronModelMerger(BaseModelMerger):
|
||||
else:
|
||||
return [tensor]
|
||||
|
||||
def _merge_state_dicts(self, model_state_dict_list: List[Dict[str, Any]]) -> dict[str, torch.Tensor]:
|
||||
def _merge_state_dicts(self, model_state_dict_list: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
|
||||
state_dict = {}
|
||||
layers_cum = 0
|
||||
|
||||
@ -306,7 +306,7 @@ class MegatronModelMerger(BaseModelMerger):
|
||||
state_dict[hf_name] = split_tensor[0]
|
||||
elif len(split_tensor) == 3:
|
||||
# split qkv
|
||||
for n, d in zip(["q", "k", "v"], split_tensor):
|
||||
for n, d in zip(["q", "k", "v"], split_tensor, strict=True):
|
||||
state_dict[hf_name.replace("qkv", n)] = d
|
||||
elif len(split_tensor) == 2:
|
||||
# split gate up
|
||||
|
@ -86,7 +86,7 @@ def load_state_dict_to_megatron_llama(
|
||||
assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
|
||||
assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"
|
||||
|
||||
if not isinstance(wrapped_models, (list, tuple)):
|
||||
if not isinstance(wrapped_models, list | tuple):
|
||||
wrapped_models = list(wrapped_models)
|
||||
|
||||
assert len(wrapped_models) == virtual_pp_size
|
||||
|
@ -86,7 +86,7 @@ def load_state_dict_to_megatron_llama(
|
||||
assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
|
||||
assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"
|
||||
|
||||
if not isinstance(wrapped_models, (list, tuple)):
|
||||
if not isinstance(wrapped_models, list | tuple):
|
||||
wrapped_models = list(wrapped_models)
|
||||
|
||||
assert len(wrapped_models) == virtual_pp_size
|
||||
|
@ -100,7 +100,7 @@ def merge_megatron_ckpt_llama(wrapped_models, config, dtype, is_value_model=Fals
|
||||
assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
|
||||
assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"
|
||||
|
||||
if not isinstance(wrapped_models, (list, tuple)):
|
||||
if not isinstance(wrapped_models, list | tuple):
|
||||
wrapped_models = list(wrapped_models)
|
||||
|
||||
assert len(wrapped_models) == virtual_pp_size
|
||||
|
@ -19,7 +19,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -290,7 +290,7 @@ class ParallelLlamaAttention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
qkv = self.qkv_proj(hidden_states)[0]
|
||||
query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
|
||||
|
@ -18,7 +18,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from megatron.core import ModelParallelConfig
|
||||
@ -49,7 +49,7 @@ class ParallelLlamaDecoderLayer(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
@ -119,7 +119,7 @@ class ParallelLlamaDecoderLayerRmPad(nn.Module):
|
||||
indices: torch.Tensor = None,
|
||||
cu_seqlens: int = None,
|
||||
max_seqlen_in_batch: int = None,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
residual = hidden_states # (total_nnz // sp, 1, hidden_size)
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
@ -19,7 +19,7 @@
|
||||
# limitations under the License.
|
||||
"""PyTorch LLaMA model with Megatron-style acceleration."""
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
@ -125,7 +125,7 @@ class ParallelLlamaModel(nn.Module):
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
) -> tuple | BaseModelOutputWithPast:
|
||||
"""
|
||||
|
||||
Args:
|
||||
@ -184,7 +184,7 @@ class ParallelLlamaForCausalLM(nn.Module):
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
) -> tuple | CausalLMOutputWithPast:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
@ -255,7 +255,7 @@ class ParallelLlamaModelRmPad(nn.Module):
|
||||
indices: torch.Tensor = None,
|
||||
cu_seqlens: int = None,
|
||||
max_seqlen_in_batch: int = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
) -> tuple | BaseModelOutputWithPast:
|
||||
"""
|
||||
|
||||
Args:
|
||||
@ -325,7 +325,7 @@ class ParallelLlamaForCausalLMRmPad(nn.Module):
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
) -> tuple | CausalLMOutputWithPast:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
@ -404,7 +404,7 @@ class ParallelLlamaForValueRmPad(ParallelLlamaForCausalLMRmPad):
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
) -> tuple | CausalLMOutputWithPast:
|
||||
output = super().forward(input_ids, attention_mask, position_ids)
|
||||
output.logits = torch.squeeze(output.logits, dim=-1)
|
||||
return output
|
||||
@ -487,7 +487,7 @@ class ParallelLlamaModelRmPadPP(nn.Module):
|
||||
indices: torch.Tensor = None,
|
||||
cu_seqlens: int = None,
|
||||
max_seqlen_in_batch: int = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
) -> tuple | BaseModelOutputWithPast:
|
||||
"""
|
||||
|
||||
Args:
|
||||
@ -595,7 +595,7 @@ class ParallelLlamaForCausalLMRmPadPP(nn.Module):
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
) -> tuple | CausalLMOutputWithPast:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
@ -679,7 +679,7 @@ class ParallelLlamaForValueRmPadPP(ParallelLlamaForCausalLMRmPadPP):
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
) -> tuple | CausalLMOutputWithPast:
|
||||
output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
|
||||
if self.post_process:
|
||||
output.logits = torch.squeeze(output.logits, dim=-1)
|
||||
|
@ -87,7 +87,7 @@ def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, par
|
||||
assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
|
||||
assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"
|
||||
|
||||
if not isinstance(wrapped_models, (list, tuple)):
|
||||
if not isinstance(wrapped_models, list | tuple):
|
||||
wrapped_models = list(wrapped_models)
|
||||
|
||||
assert len(wrapped_models) == virtual_pp_size
|
||||
|
@ -17,7 +17,7 @@ Registry module for model architecture components.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Callable, Dict, Type
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -73,7 +73,7 @@ class SupportedModel(Enum):
|
||||
|
||||
|
||||
# Registry for model configuration converters
|
||||
MODEL_CONFIG_CONVERTER_REGISTRY: Dict[SupportedModel, Callable[[PretrainedConfig, torch.dtype], TransformerConfig]] = {
|
||||
MODEL_CONFIG_CONVERTER_REGISTRY: dict[SupportedModel, Callable[[PretrainedConfig, torch.dtype], TransformerConfig]] = {
|
||||
SupportedModel.LLAMA: hf_to_mcore_config_dense,
|
||||
SupportedModel.QWEN2: hf_to_mcore_config_dense,
|
||||
SupportedModel.QWEN2_MOE: hf_to_mcore_config_qwen2moe,
|
||||
@ -87,7 +87,7 @@ MODEL_CONFIG_CONVERTER_REGISTRY: Dict[SupportedModel, Callable[[PretrainedConfig
|
||||
}
|
||||
|
||||
# Registry for model initializers
|
||||
MODEL_INITIALIZER_REGISTRY: Dict[SupportedModel, Type[BaseModelInitializer]] = {
|
||||
MODEL_INITIALIZER_REGISTRY: dict[SupportedModel, type[BaseModelInitializer]] = {
|
||||
SupportedModel.LLAMA: DenseModel,
|
||||
SupportedModel.QWEN2: DenseModel,
|
||||
SupportedModel.QWEN2_MOE: Qwen2MoEModel,
|
||||
@ -101,7 +101,7 @@ MODEL_INITIALIZER_REGISTRY: Dict[SupportedModel, Type[BaseModelInitializer]] = {
|
||||
}
|
||||
|
||||
# Registry for model forward functions
|
||||
MODEL_FORWARD_REGISTRY: Dict[SupportedModel, Callable] = {
|
||||
MODEL_FORWARD_REGISTRY: dict[SupportedModel, Callable] = {
|
||||
SupportedModel.LLAMA: gptmodel_forward,
|
||||
SupportedModel.QWEN2: gptmodel_forward,
|
||||
SupportedModel.QWEN2_MOE: gptmodel_forward,
|
||||
@ -116,7 +116,7 @@ MODEL_FORWARD_REGISTRY: Dict[SupportedModel, Callable] = {
|
||||
}
|
||||
|
||||
# Registry for model forward functions
|
||||
MODEL_FORWARD_FUSED_REGISTRY: Dict[SupportedModel, Callable] = {
|
||||
MODEL_FORWARD_FUSED_REGISTRY: dict[SupportedModel, Callable] = {
|
||||
SupportedModel.LLAMA: fused_forward_gptmodel,
|
||||
SupportedModel.QWEN2: fused_forward_gptmodel,
|
||||
SupportedModel.QWEN2_MOE: fused_forward_gptmodel,
|
||||
@ -131,7 +131,7 @@ MODEL_FORWARD_FUSED_REGISTRY: Dict[SupportedModel, Callable] = {
|
||||
}
|
||||
|
||||
# Registry for model weight converters
|
||||
MODEL_WEIGHT_CONVERTER_REGISTRY: Dict[SupportedModel, Type] = {
|
||||
MODEL_WEIGHT_CONVERTER_REGISTRY: dict[SupportedModel, type] = {
|
||||
SupportedModel.LLAMA: McoreToHFWeightConverterDense,
|
||||
SupportedModel.QWEN2: McoreToHFWeightConverterDense,
|
||||
SupportedModel.QWEN2_MOE: McoreToHFWeightConverterQwen2Moe,
|
||||
|
@ -112,7 +112,7 @@ def merge_megatron_ckpt_gptmodel(wrapped_models, config, dtype, is_value_model=F
|
||||
assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
|
||||
assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"
|
||||
|
||||
if not isinstance(wrapped_models, (list, tuple)):
|
||||
if not isinstance(wrapped_models, list | tuple):
|
||||
wrapped_models = list(wrapped_models)
|
||||
|
||||
assert len(wrapped_models) == virtual_pp_size
|
||||
|
@ -84,7 +84,7 @@ def load_state_dict_to_megatron_qwen2(
|
||||
assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
|
||||
assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"
|
||||
|
||||
if not isinstance(wrapped_models, (list, tuple)):
|
||||
if not isinstance(wrapped_models, list | tuple):
|
||||
wrapped_models = list(wrapped_models)
|
||||
|
||||
assert len(wrapped_models) == virtual_pp_size
|
||||
|
@ -84,7 +84,7 @@ def load_state_dict_to_megatron_qwen2(
|
||||
assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
|
||||
assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"
|
||||
|
||||
if not isinstance(wrapped_models, (list, tuple)):
|
||||
if not isinstance(wrapped_models, list | tuple):
|
||||
wrapped_models = list(wrapped_models)
|
||||
|
||||
assert len(wrapped_models) == virtual_pp_size
|
||||
|
@ -100,7 +100,7 @@ def merge_megatron_ckpt_qwen2(wrapped_models, config, dtype, is_value_model=Fals
|
||||
assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
|
||||
assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"
|
||||
|
||||
if not isinstance(wrapped_models, (list, tuple)):
|
||||
if not isinstance(wrapped_models, list | tuple):
|
||||
wrapped_models = list(wrapped_models)
|
||||
|
||||
assert len(wrapped_models) == virtual_pp_size
|
||||
|
@ -19,7 +19,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
@ -235,7 +235,7 @@ class ParallelQwen2Attention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
qkv = self.qkv_proj(hidden_states)[0]
|
||||
query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
|
||||
|
@ -18,7 +18,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from megatron.core import ModelParallelConfig
|
||||
@ -49,7 +49,7 @@ class ParallelQwen2DecoderLayer(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
@ -119,7 +119,7 @@ class ParallelQwen2DecoderLayerRmPad(nn.Module):
|
||||
indices: torch.Tensor = None,
|
||||
cu_seqlens: int = None,
|
||||
max_seqlen_in_batch: int = None,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
residual = hidden_states # (total_nnz // sp, 1, hidden_size)
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
@ -19,7 +19,7 @@
|
||||
# limitations under the License.
|
||||
"""PyTorch Qwen2 model."""
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
@ -126,7 +126,7 @@ class ParallelQwen2Model(nn.Module):
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
) -> tuple | BaseModelOutputWithPast:
|
||||
"""
|
||||
|
||||
Args:
|
||||
@ -185,7 +185,7 @@ class ParallelQwen2ForCausalLM(nn.Module):
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
) -> tuple | CausalLMOutputWithPast:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
@ -256,7 +256,7 @@ class ParallelQwen2ModelRmPad(nn.Module):
|
||||
indices: torch.Tensor = None,
|
||||
cu_seqlens: int = None,
|
||||
max_seqlen_in_batch: int = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
) -> tuple | BaseModelOutputWithPast:
|
||||
"""
|
||||
|
||||
Args:
|
||||
@ -326,7 +326,7 @@ class ParallelQwen2ForCausalLMRmPad(nn.Module):
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
) -> tuple | CausalLMOutputWithPast:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
@ -405,7 +405,7 @@ class ParallelQwen2ForValueRmPad(ParallelQwen2ForCausalLMRmPad):
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
) -> tuple | CausalLMOutputWithPast:
|
||||
output = super().forward(input_ids, attention_mask, position_ids)
|
||||
output.logits = torch.squeeze(output.logits, dim=-1)
|
||||
return output
|
||||
@ -487,7 +487,7 @@ class ParallelQwen2ModelRmPadPP(nn.Module):
|
||||
indices: torch.Tensor = None,
|
||||
cu_seqlens: int = None,
|
||||
max_seqlen_in_batch: int = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
) -> tuple | BaseModelOutputWithPast:
|
||||
"""
|
||||
|
||||
Args:
|
||||
@ -645,7 +645,7 @@ class ParallelQwen2ForCausalLMRmPadPP(nn.Module):
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
) -> tuple | CausalLMOutputWithPast:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
@ -728,7 +728,7 @@ class ParallelQwen2ForValueRmPadPP(ParallelQwen2ForCausalLMRmPadPP):
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
) -> tuple | CausalLMOutputWithPast:
|
||||
output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
|
||||
if self.post_process:
|
||||
output.logits = torch.squeeze(output.logits, dim=-1)
|
||||
|
@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
from typing import List, Optional, Type
|
||||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
@ -38,7 +38,7 @@ _MODELS = {
|
||||
# return model class
|
||||
class ModelRegistry:
|
||||
@staticmethod
|
||||
def load_model_cls(model_arch: str, value=False) -> Optional[Type[nn.Module]]:
|
||||
def load_model_cls(model_arch: str, value=False) -> Optional[type[nn.Module]]:
|
||||
if model_arch not in _MODELS:
|
||||
return None
|
||||
|
||||
@ -54,5 +54,5 @@ class ModelRegistry:
|
||||
return getattr(module, model_cls_name, None)
|
||||
|
||||
@staticmethod
|
||||
def get_supported_archs() -> List[str]:
|
||||
def get_supported_archs() -> list[str]:
|
||||
return list(_MODELS.keys())
|
||||
|
@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers.cache_utils import Cache
|
||||
@ -73,7 +73,7 @@ def forward_with_torch_backend(
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union["Cache", List[torch.FloatTensor]]] = None,
|
||||
past_key_values: Optional[Union["Cache", list[torch.FloatTensor]]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
@ -81,10 +81,10 @@ def forward_with_torch_backend(
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
logits_to_keep: int | torch.Tensor = 0,
|
||||
temperature: float = 1.0,
|
||||
**loss_kwargs,
|
||||
) -> Union[Tuple, CausalLMOutputForPPO]:
|
||||
) -> tuple | CausalLMOutputForPPO:
|
||||
from verl.utils.experimental.torch_functional import FusedLinearForPPO
|
||||
|
||||
outputs = forward_base_model(
|
||||
@ -135,7 +135,7 @@ def forward_with_triton_backend(
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union["Cache", List[torch.FloatTensor]]] = None,
|
||||
past_key_values: Optional[Union["Cache", list[torch.FloatTensor]]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
@ -143,10 +143,10 @@ def forward_with_triton_backend(
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
logits_to_keep: int | torch.Tensor = 0,
|
||||
temperature: float = 1.0,
|
||||
**loss_kwargs,
|
||||
) -> Union[Tuple, CausalLMOutputForPPO]:
|
||||
) -> tuple | CausalLMOutputForPPO:
|
||||
from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy
|
||||
|
||||
outputs = forward_base_model(
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -93,7 +93,7 @@ def _ulysses_flash_attn_forward(
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
if self.q_lora_rank is None:
|
||||
|
@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
from typing import Callable, Optional, Tuple
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@ -46,9 +46,9 @@ def llama_flash_attn_forward(
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
||||
"""
|
||||
Adapted from transformers 4.47.1 to support Ulysses sequence parallelism.
|
||||
|
||||
@ -168,12 +168,12 @@ def llama_flash_attn_forward(
|
||||
def llama_attn_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
||||
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
past_key_value: Optional[Cache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
||||
"""
|
||||
Adapted from transformers 4.49.0 to support Ulysses sequence parallelism for transformers >= 4.48.0.
|
||||
|
||||
|
@ -14,7 +14,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
@ -28,7 +27,7 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2RMSNorm
|
||||
# https://github.com/huggingface/transformers/pull/38491
|
||||
def apply_rotary_pos_emb_flashatt_npu(
|
||||
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
cos = cos.chunk(2, dim=-1)[0].contiguous()
|
||||
sin = sin.chunk(2, dim=-1)[0].contiguous()
|
||||
cos = cos.repeat(1, 2)
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Callable, Optional, Tuple
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from transformers.cache_utils import Cache
|
||||
@ -39,7 +39,7 @@ def qwen2_flash_attn_forward(
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
):
|
||||
"""
|
||||
Adapted from transformers 4.47.1 to support Ulysses sequence parallelism.
|
||||
@ -157,12 +157,12 @@ def qwen2_flash_attn_forward(
|
||||
def qwen2_attn_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
||||
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
past_key_value: Optional[Cache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
||||
"""
|
||||
Adapted from transformers 4.49.0 to support Ulysses sequence parallelism for transformers >= 4.48.0.
|
||||
|
||||
|
@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
||||
@ -33,7 +33,7 @@ def forward_base_model(
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
@ -46,7 +46,7 @@ def forward_base_model(
|
||||
rope_deltas: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
second_per_grid_ts: Optional[torch.Tensor] = None,
|
||||
) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
|
||||
) -> tuple | Qwen2_5_VLCausalLMOutputWithPast:
|
||||
r"""
|
||||
Copy paste Qwen2_5_VL's forward
|
||||
https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/qwen2_5_vl.py
|
||||
@ -143,7 +143,7 @@ def forward_with_torch_backend(
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
@ -159,7 +159,7 @@ def forward_with_torch_backend(
|
||||
second_per_grid_ts: Optional[torch.Tensor] = None,
|
||||
temperature: float = 1.0,
|
||||
**loss_kwargs,
|
||||
) -> Union[Tuple, Qwen2_5_VLCausalLMOutputForPPO]:
|
||||
) -> tuple | Qwen2_5_VLCausalLMOutputForPPO:
|
||||
from verl.utils.experimental.torch_functional import FusedLinearForPPO
|
||||
|
||||
outputs = forward_base_model(
|
||||
@ -218,7 +218,7 @@ def forward_with_triton_backend(
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
@ -234,7 +234,7 @@ def forward_with_triton_backend(
|
||||
second_per_grid_ts: Optional[torch.Tensor] = None,
|
||||
temperature: float = 1.0,
|
||||
**loss_kwargs,
|
||||
) -> Union[Tuple, Qwen2_5_VLCausalLMOutputForPPO]:
|
||||
) -> tuple | Qwen2_5_VLCausalLMOutputForPPO:
|
||||
from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy
|
||||
|
||||
outputs = forward_base_model(
|
||||
|
@ -15,7 +15,7 @@
|
||||
import inspect
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
||||
@ -230,9 +230,9 @@ def ulysses_flash_attn_forward(
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, None, None]:
|
||||
) -> tuple[torch.Tensor, None, None]:
|
||||
from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_multimodal_rotary_pos_emb, repeat_kv
|
||||
|
||||
bsz, q_len, _ = hidden_states.size() # q_len = seq_length / sp_size
|
||||
@ -315,7 +315,7 @@ def forward_base_model(
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
@ -327,7 +327,7 @@ def forward_base_model(
|
||||
video_grid_thw: Optional[torch.LongTensor] = None,
|
||||
rope_deltas: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
|
||||
) -> tuple | Qwen2VLCausalLMOutputWithPast:
|
||||
r"""
|
||||
Copy paste Qwen2VL's forward
|
||||
https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/qwen2_vl.py
|
||||
@ -418,7 +418,7 @@ def forward_with_torch_backend(
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
@ -433,7 +433,7 @@ def forward_with_torch_backend(
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
temperature: float = 1.0,
|
||||
**loss_kwargs,
|
||||
) -> Union[Tuple, Qwen2VLCausalLMOutputForPPO]:
|
||||
) -> tuple | Qwen2VLCausalLMOutputForPPO:
|
||||
from verl.utils.experimental.torch_functional import FusedLinearForPPO
|
||||
|
||||
outputs = forward_base_model(
|
||||
@ -491,7 +491,7 @@ def forward_with_triton_backend(
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
@ -506,7 +506,7 @@ def forward_with_triton_backend(
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
temperature: float = 1.0,
|
||||
**loss_kwargs,
|
||||
) -> Union[Tuple, Qwen2VLCausalLMOutputForPPO]:
|
||||
) -> tuple | Qwen2VLCausalLMOutputForPPO:
|
||||
from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy
|
||||
|
||||
outputs = forward_base_model(
|
||||
|
@ -22,7 +22,7 @@ import logging
|
||||
import os
|
||||
import pickle
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
from typing import Callable, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@ -200,8 +200,8 @@ def collate_fn(x: list["DataProtoItem"]):
|
||||
class DataProtoItem:
|
||||
# TODO(zhangchi.usc1992) add consistency check
|
||||
batch: TensorDict = None
|
||||
non_tensor_batch: Dict = field(default_factory=dict)
|
||||
meta_info: Dict = field(default_factory=dict)
|
||||
non_tensor_batch: dict = field(default_factory=dict)
|
||||
meta_info: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -214,8 +214,8 @@ class DataProto:
|
||||
"""
|
||||
|
||||
batch: TensorDict = None
|
||||
non_tensor_batch: Dict = field(default_factory=dict)
|
||||
meta_info: Dict = field(default_factory=dict)
|
||||
non_tensor_batch: dict = field(default_factory=dict)
|
||||
meta_info: dict = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
# perform necessary checking
|
||||
@ -251,11 +251,11 @@ class DataProto:
|
||||
return self.slice(item.start, item.stop, item.step)
|
||||
|
||||
# Case 2: List, numpy array, or torch tensor - use sel_idxs
|
||||
elif isinstance(item, (list, np.ndarray, torch.Tensor)):
|
||||
elif isinstance(item, list | np.ndarray | torch.Tensor):
|
||||
return self.select_idxs(item)
|
||||
|
||||
# Case 3: Single integer - return DataProtoItem for backward compatibility
|
||||
elif isinstance(item, (int, np.integer)):
|
||||
elif isinstance(item, int | np.integer):
|
||||
tensor_data = self.batch[item] if self.batch is not None else None
|
||||
non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()}
|
||||
return DataProtoItem(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)
|
||||
@ -343,7 +343,7 @@ class DataProto:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_single_dict(cls, data: Dict[str, Union[torch.Tensor, np.ndarray]], meta_info=None, auto_padding=False):
|
||||
def from_single_dict(cls, data: dict[str, torch.Tensor | np.ndarray], meta_info=None, auto_padding=False):
|
||||
"""Create a DataProto from a dict of tensors and non_tensors"""
|
||||
tensors = {}
|
||||
non_tensors = {}
|
||||
@ -361,7 +361,7 @@ class DataProto:
|
||||
@classmethod
|
||||
def from_dict(
|
||||
cls,
|
||||
tensors: Optional[Dict[str, torch.Tensor]] = None,
|
||||
tensors: Optional[dict[str, torch.Tensor]] = None,
|
||||
non_tensors=None,
|
||||
meta_info=None,
|
||||
num_batch_dims=1,
|
||||
@ -649,7 +649,7 @@ class DataProto:
|
||||
else:
|
||||
generator = None
|
||||
|
||||
assert isinstance(dataloader_kwargs, Dict)
|
||||
assert isinstance(dataloader_kwargs, dict)
|
||||
train_dataloader = DataLoader(
|
||||
dataset=self, batch_size=mini_batch_size, collate_fn=collate_fn, generator=generator, **dataloader_kwargs
|
||||
)
|
||||
@ -686,7 +686,7 @@ class DataProto:
|
||||
self.batch = padded_dp.batch
|
||||
self.non_tensor_batch = padded_dp.non_tensor_batch
|
||||
|
||||
def chunk(self, chunks: int) -> List["DataProto"]:
|
||||
def chunk(self, chunks: int) -> list["DataProto"]:
|
||||
"""Split the batch among dim=0 into chunks. The meta_info is passed to each DataProto after split.
|
||||
|
||||
Args:
|
||||
@ -728,7 +728,7 @@ class DataProto:
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def concat(data: List["DataProto"]) -> "DataProto":
|
||||
def concat(data: list["DataProto"]) -> "DataProto":
|
||||
"""Concat a list of DataProto. The batch is concatenated among dim=0.
|
||||
The meta_info is assumed to be identical and will use the first one.
|
||||
|
||||
@ -802,7 +802,7 @@ class DataProto:
|
||||
meta_info=self.meta_info,
|
||||
)
|
||||
|
||||
def unfold_column_chunks(self, n_split: int, split_keys: Optional[List[str]] = None):
|
||||
def unfold_column_chunks(self, n_split: int, split_keys: Optional[list[str]] = None):
|
||||
"""Split along the second dim into `n_split`, unfold it to the first dim (batch dim)
|
||||
Useful in passing grouped tensors that doesn't want to be shuffled in dataset.
|
||||
keys not in split_keys are repeated to match the shape
|
||||
@ -906,15 +906,15 @@ class DataProtoFuture:
|
||||
"""
|
||||
|
||||
collect_fn: Callable
|
||||
futures: List[ray.ObjectRef]
|
||||
futures: list[ray.ObjectRef]
|
||||
dispatch_fn: Callable = None
|
||||
|
||||
@staticmethod
|
||||
def concat(data: List[ray.ObjectRef]) -> "DataProtoFuture":
|
||||
def concat(data: list[ray.ObjectRef]) -> "DataProtoFuture":
|
||||
output = DataProtoFuture(collect_fn=DataProto.concat, futures=data)
|
||||
return output
|
||||
|
||||
def chunk(self, chunks: int) -> List["DataProtoFuture"]:
|
||||
def chunk(self, chunks: int) -> list["DataProtoFuture"]:
|
||||
from functools import partial
|
||||
|
||||
arg_future_lst = []
|
||||
|
@ -15,7 +15,6 @@
|
||||
import inspect
|
||||
from functools import wraps
|
||||
from types import FunctionType
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from verl.protocol import DataProtoFuture, _padding_size_key
|
||||
from verl.utils.py_functional import DynamicEnum
|
||||
@ -79,12 +78,12 @@ def _split_args_kwargs_data_proto(chunks, *args, **kwargs):
|
||||
|
||||
splitted_args = []
|
||||
for arg in args:
|
||||
assert isinstance(arg, (DataProto, DataProtoFuture))
|
||||
assert isinstance(arg, DataProto | DataProtoFuture)
|
||||
splitted_args.append(arg.chunk(chunks=chunks))
|
||||
|
||||
splitted_kwargs = {}
|
||||
for key, val in kwargs.items():
|
||||
assert isinstance(val, (DataProto, DataProtoFuture))
|
||||
assert isinstance(val, DataProto | DataProtoFuture)
|
||||
splitted_kwargs[key] = val.chunk(chunks=chunks)
|
||||
|
||||
return splitted_args, splitted_kwargs
|
||||
@ -99,7 +98,7 @@ def _split_args_kwargs_data_proto_with_auto_padding(chunks, *args, **kwargs):
|
||||
data_proto_len = None
|
||||
padding_size = None
|
||||
for arg in args:
|
||||
assert isinstance(arg, (DataProto, DataProtoFuture))
|
||||
assert isinstance(arg, DataProto | DataProtoFuture)
|
||||
if isinstance(arg, DataProto) and arg.is_padding_enabled():
|
||||
# for padding, we only support DataProto with same length
|
||||
if data_proto_len is None:
|
||||
@ -116,7 +115,7 @@ def _split_args_kwargs_data_proto_with_auto_padding(chunks, *args, **kwargs):
|
||||
splitted_args.append(arg.chunk(chunks=chunks))
|
||||
|
||||
for key, val in kwargs.items():
|
||||
assert isinstance(val, (DataProto, DataProtoFuture))
|
||||
assert isinstance(val, DataProto | DataProtoFuture)
|
||||
if isinstance(val, DataProto) and val.is_padding_enabled():
|
||||
# for padding, we only support DataProto with same length
|
||||
if data_proto_len is None:
|
||||
@ -169,7 +168,7 @@ def dispatch_megatron_compute(worker_group, *args, **kwargs):
|
||||
|
||||
all_args = []
|
||||
for arg in args:
|
||||
assert isinstance(arg, (Tuple, List)) and len(arg) == worker_group.dp_size
|
||||
assert isinstance(arg, tuple | list) and len(arg) == worker_group.dp_size
|
||||
transformed_args = []
|
||||
for i in range(worker_group.world_size):
|
||||
local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank
|
||||
@ -179,7 +178,7 @@ def dispatch_megatron_compute(worker_group, *args, **kwargs):
|
||||
|
||||
all_kwargs = {}
|
||||
for k, v in kwargs.items():
|
||||
assert isinstance(v, (Tuple, List)) and len(v) == worker_group.dp_size
|
||||
assert isinstance(v, tuple | list) and len(v) == worker_group.dp_size
|
||||
transformed_v = []
|
||||
for i in range(worker_group.world_size):
|
||||
local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank
|
||||
@ -216,7 +215,7 @@ def dispatch_megatron_compute_data_proto(worker_group, *args, **kwargs):
|
||||
return dispatch_megatron_compute(worker_group, *splitted_args, **splitted_kwargs)
|
||||
|
||||
|
||||
def _concat_data_proto_or_future(output: List):
|
||||
def _concat_data_proto_or_future(output: list):
|
||||
import ray
|
||||
|
||||
from verl.protocol import DataProto, DataProtoFuture
|
||||
@ -245,7 +244,7 @@ def collect_megatron_compute_data_proto(worker_group, output):
|
||||
|
||||
output = collect_megatron_compute(worker_group, output)
|
||||
for o in output:
|
||||
assert isinstance(o, (DataProto, ray.ObjectRef)), f"expecting {o} to be DataProto, but got {type(o)}"
|
||||
assert isinstance(o, DataProto | ray.ObjectRef), f"expecting {o} to be DataProto, but got {type(o)}"
|
||||
|
||||
return _concat_data_proto_or_future(output)
|
||||
|
||||
@ -265,7 +264,7 @@ def dispatch_megatron_pp_as_dp(worker_group, *args, **kwargs):
|
||||
|
||||
all_args = []
|
||||
for arg in args:
|
||||
assert isinstance(arg, (List, Tuple)) and len(arg) == pp_dp_cp_size
|
||||
assert isinstance(arg, list | tuple) and len(arg) == pp_dp_cp_size
|
||||
transformed_args = []
|
||||
for i in range(worker_group.world_size):
|
||||
local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank
|
||||
@ -290,7 +289,7 @@ def dispatch_megatron_pp_as_dp(worker_group, *args, **kwargs):
|
||||
|
||||
all_kwargs = {}
|
||||
for k, v in kwargs.items():
|
||||
assert isinstance(v, (List, Tuple)) and len(v) == pp_dp_cp_size, f"expect len(v)=={pp_dp_cp_size}, got {len(v)}"
|
||||
assert isinstance(v, list | tuple) and len(v) == pp_dp_cp_size, f"expect len(v)=={pp_dp_cp_size}, got {len(v)}"
|
||||
transformed_v = []
|
||||
for i in range(worker_group.world_size):
|
||||
local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank
|
||||
@ -359,9 +358,9 @@ def dispatch_dp_compute(worker_group, *args, **kwargs):
|
||||
|
||||
assert isinstance(worker_group, WorkerGroup)
|
||||
for arg in args:
|
||||
assert isinstance(arg, (Tuple, List)) and len(arg) == worker_group.world_size
|
||||
assert isinstance(arg, tuple | list) and len(arg) == worker_group.world_size
|
||||
for k, v in kwargs.items():
|
||||
assert isinstance(v, (Tuple, List)) and len(v) == worker_group.world_size
|
||||
assert isinstance(v, tuple | list) and len(v) == worker_group.world_size
|
||||
return args, kwargs
|
||||
|
||||
|
||||
@ -403,7 +402,7 @@ def collect_dp_compute_data_proto(worker_group, output):
|
||||
from verl.protocol import DataProto
|
||||
|
||||
for o in output:
|
||||
assert isinstance(o, (DataProto, ray.ObjectRef)), f"expecting {o} to be DataProto, but got {type(o)}"
|
||||
assert isinstance(o, DataProto | ray.ObjectRef), f"expecting {o} to be DataProto, but got {type(o)}"
|
||||
|
||||
output = collect_dp_compute(worker_group, output)
|
||||
return _concat_data_proto_or_future(output)
|
||||
@ -489,10 +488,10 @@ def get_predefined_execute_fn(execute_mode):
|
||||
|
||||
|
||||
def _check_dispatch_mode(dispatch_mode):
|
||||
assert isinstance(dispatch_mode, (Dispatch, Dict)), (
|
||||
assert isinstance(dispatch_mode, Dispatch | dict), (
|
||||
f"dispatch_mode must be a Dispatch or a Dict. Got {dispatch_mode}"
|
||||
)
|
||||
if isinstance(dispatch_mode, Dict):
|
||||
if isinstance(dispatch_mode, dict):
|
||||
necessary_keys = ["dispatch_fn", "collect_fn"]
|
||||
for key in necessary_keys:
|
||||
assert key in dispatch_mode, f"key {key} should be in dispatch_mode if it is a dictionary"
|
||||
|
@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from verl.single_controller.base import ResourcePool, WorkerGroup
|
||||
|
||||
@ -25,7 +24,7 @@ class MegatronWorkerGroup(WorkerGroup):
|
||||
self._megatron_rank_info = None
|
||||
self._megatron_global_info: DistGlobalInfo = None
|
||||
|
||||
def init_megatron(self, default_megatron_kwargs: Dict = None):
|
||||
def init_megatron(self, default_megatron_kwargs: dict = None):
|
||||
raise NotImplementedError("MegatronWorkerGroup.init_megatron should be overwritten")
|
||||
|
||||
def get_megatron_rank_info(self, rank: int) -> DistRankInfo:
|
||||
|
@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import ray
|
||||
|
||||
@ -22,7 +21,7 @@ class WorkerGroupRegisterCenter:
|
||||
def __init__(self, rank_zero_info):
|
||||
self.rank_zero_info = rank_zero_info
|
||||
# rank -> node_id
|
||||
self.workers_info: Dict[int, str] = {}
|
||||
self.workers_info: dict[int, str] = {}
|
||||
|
||||
def get_rank_zero_info(self):
|
||||
return self.rank_zero_info
|
||||
@ -30,7 +29,7 @@ class WorkerGroupRegisterCenter:
|
||||
def set_worker_info(self, rank, node_id) -> None:
|
||||
self.workers_info[rank] = node_id
|
||||
|
||||
def get_worker_info(self) -> Dict[int, str]:
|
||||
def get_worker_info(self) -> dict[int, str]:
|
||||
return self.workers_info
|
||||
|
||||
|
||||
|
@ -18,7 +18,6 @@ the class for Worker
|
||||
import os
|
||||
import socket
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict
|
||||
|
||||
import ray
|
||||
|
||||
@ -246,7 +245,7 @@ class Worker(WorkerHelper):
|
||||
os.environ["LOCAL_RANK"] = local_rank
|
||||
get_torch_device().set_device(int(local_rank))
|
||||
|
||||
def _configure_with_store(self, store: Dict):
|
||||
def _configure_with_store(self, store: dict):
|
||||
"""
|
||||
This function should only be called inside by WorkerGroup
|
||||
"""
|
||||
|
@ -19,7 +19,7 @@ import logging
|
||||
import signal
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Callable, Dict, List
|
||||
from typing import Any, Callable
|
||||
|
||||
from .decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn
|
||||
|
||||
@ -60,14 +60,14 @@ class ResourcePool:
|
||||
def store(self):
|
||||
return self._store
|
||||
|
||||
def local_world_size_list(self) -> List[int]:
|
||||
def local_world_size_list(self) -> list[int]:
|
||||
"""Returns a flat list where each process has its local world size."""
|
||||
nested_local_world_size_list = [
|
||||
[local_world_size for _ in range(local_world_size)] for local_world_size in self._store
|
||||
]
|
||||
return [item for row in nested_local_world_size_list for item in row]
|
||||
|
||||
def local_rank_list(self) -> List[int]:
|
||||
def local_rank_list(self) -> list[int]:
|
||||
"""Returns a flat list of local ranks for all processes across all nodes."""
|
||||
nested_local_rank_list = [[i for i in range(local_world_size)] for local_world_size in self._store]
|
||||
return [item for row in nested_local_rank_list for item in row]
|
||||
@ -99,7 +99,7 @@ class ClassWithInitArgs:
|
||||
return self.cls(*self.args, **self.kwargs)
|
||||
|
||||
|
||||
def check_workers_alive(workers: List, is_alive: Callable, gap_time: float = 1) -> None:
|
||||
def check_workers_alive(workers: list, is_alive: Callable, gap_time: float = 1) -> None:
|
||||
"""Continuously monitors worker processes and raises SIGABRT if any worker dies.
|
||||
|
||||
Args:
|
||||
@ -201,7 +201,7 @@ class WorkerGroup:
|
||||
if hasattr(method, MAGIC_ATTR):
|
||||
# this method is decorated by register
|
||||
attribute = getattr(method, MAGIC_ATTR)
|
||||
assert isinstance(attribute, Dict), f"attribute must be a dictionary. Got {type(attribute)}"
|
||||
assert isinstance(attribute, dict), f"attribute must be a dictionary. Got {type(attribute)}"
|
||||
assert "dispatch_mode" in attribute, "attribute must contain dispatch_mode in its key"
|
||||
|
||||
dispatch_mode = attribute["dispatch_mode"]
|
||||
|
@ -17,7 +17,7 @@ import logging
import os
import time
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Optional
from unittest.mock import patch
import ray

@ -62,7 +62,7 @@ def func_generator(self, method_name, dispatch_fn, collect_fn, execute_fn, block
return type(method_name, (Functor,), {})()
def sort_placement_group_by_node_ip(pgs: List[PlacementGroup]) -> List[PlacementGroup]:
def sort_placement_group_by_node_ip(pgs: list[PlacementGroup]) -> list[PlacementGroup]:
"""
Sort the placement groups by node ip, all bundles in a single placement group should be on the same node.

@ -85,7 +85,7 @@ def sort_placement_group_by_node_ip(pgs: List[PlacementGroup]) -> List[Placement
class RayResourcePool(ResourcePool):
def __init__(
self,
process_on_nodes: Optional[List[int]] = None,
process_on_nodes: Optional[list[int]] = None,
use_gpu: bool = True,
name_prefix: str = None,
max_colocate_count: int = 10,

@ -134,8 +134,8 @@ class RayResourcePool(ResourcePool):
def extract_pg_from_exist(
resource_pools: Dict[str, RayResourcePool], src_role_names: List[str], resource_pool: RayResourcePool
) -> List:
resource_pools: dict[str, RayResourcePool], src_role_names: list[str], resource_pool: RayResourcePool
) -> list:
src_pgs = [
pg
for role_name, resource_pool in resource_pools.items()

@ -146,7 +146,7 @@ def extract_pg_from_exist(
sorted_src_pgs = sorted(src_pgs, key=lambda pg: pg.bundle_count, reverse=True)
sorted_process_on_nodes = sorted([(val, idx) for idx, val in enumerate(resource_pool.store)], reverse=True)
unsorted_pgs: List[Tuple[int, PlacementGroup]] = []
unsorted_pgs: list[tuple[int, PlacementGroup]] = []
searching_idx = 0
for request_process, original_idx in sorted_process_on_nodes:
assert searching_idx < len(sorted_src_pgs), f"no enough nodes for request: searching {searching_idx} th node"

@ -195,7 +195,7 @@ class RayClassWithInitArgs(ClassWithInitArgs):
"""
self._additional_resource = additional_resource
def update_options(self, options: Dict):
def update_options(self, options: dict):
"""Update the Ray actor creation options.
Args:

@ -269,7 +269,7 @@ class RayWorkerGroup(WorkerGroup):
name_prefix: str = None,
detached=False,
worker_names=None,
worker_handles: List[ray.actor.ActorHandle] = None,
worker_handles: list[ray.actor.ActorHandle] = None,
ray_wait_register_center_timeout: int = 300,
device_name="cuda",
**kwargs,

@ -499,7 +499,6 @@ class RayWorkerGroup(WorkerGroup):
prefix: str = actor_name + "_"
for method_name in dir(worker_group):
if method_name.startswith(prefix):
# only valid when Python >= 3.9
original_method_name = method_name.removeprefix(prefix)
method = getattr(worker_group, method_name)
setattr(worker_group, original_method_name, method)
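The `@ -499` hunk above also drops the `# only valid when Python >= 3.9` comment: with the new minimum version, `str.removeprefix` is always available. A quick illustration with made-up names:

```python
# str.removeprefix / str.removesuffix were added in Python 3.9.
prefix = "actor_rollout_"

print("actor_rollout_generate_sequences".removeprefix(prefix))  # -> "generate_sequences"
print("other_method".removeprefix(prefix))                      # -> unchanged when the prefix is absent
```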
@ -740,7 +739,7 @@ def _unwrap_ray_remote(cls):
return cls
def _determine_fsdp_megatron_base_class(mros: List):
def _determine_fsdp_megatron_base_class(mros: list):
"""
- megatron: base class should be MegatronWorker
- fsdp: base class should be Worker

@ -836,7 +835,11 @@ def create_colocated_worker_raw_cls(class_dict: dict[str, RayClassWithInitArgs])
self.init_kwargs_dict = init_kwargs_dict
for cls_name, udc, ud_args, ud_kwargs in zip(
self.cls_names, self.raw_cls_dict.values(), self.init_args_dict.values(), self.init_kwargs_dict.values()
self.cls_names,
self.raw_cls_dict.values(),
self.init_args_dict.values(),
self.init_kwargs_dict.values(),
strict=True,
):
with patch.dict(os.environ, {"DISABLE_WORKER_INIT": "1"}):
udc._get_ray_actor_cls_name = lambda x, name_renamed=class_name_renamed: name_renamed
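The reflowed `zip(...)` call above also gains `strict=True`, a Python 3.10 addition that several other hunks in this commit adopt: instead of silently truncating to the shortest iterable, `zip` raises when the inputs have different lengths. A small illustration with throwaway data:

```python
names = ["actor", "critic", "ref"]
ranks = [0, 1]  # one entry short by mistake

# Default behavior: the mismatch is hidden and "ref" is silently dropped.
print(list(zip(names, ranks)))  # [('actor', 0), ('critic', 1)]

# Python 3.10+: strict=True surfaces the bug instead of hiding it.
try:
    list(zip(names, ranks, strict=True))
except ValueError as exc:
    print(f"caught: {exc}")
```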
@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional
from typing import Optional
import ray

@ -55,7 +55,7 @@ class MegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup):
self,
resource_pool: RayResourcePool,
ray_cls_with_init: RayClassWithInitArgs,
default_megatron_kwargs: Dict = None,
default_megatron_kwargs: dict = None,
**kwargs,
):
super().__init__(

@ -70,7 +70,7 @@ class MegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup):
self.execute_rank_zero_async(method_name="get_megatron_global_info")
)
def init_megatron(self, default_megatron_kwargs: Optional[Dict] = None):
def init_megatron(self, default_megatron_kwargs: Optional[dict] = None):
# after super, we will call init of each worker
if not self._is_init_with_detached_workers:
# only init_megatron if the WorkerGroup is created from scratch

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Any, Optional, Tuple
from typing import Any, Optional
from uuid import uuid4
from verl.utils.rollout_trace import rollout_trace_op

@ -58,7 +58,7 @@ class BaseTool:
return instance_id
@rollout_trace_op
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:
"""Execute the tool.
Args:

@ -16,7 +16,7 @@
import logging
import os
from typing import Any, Optional, Tuple
from typing import Any, Optional
from uuid import uuid4
from verl.utils.reward_score import geo3k

@ -75,7 +75,7 @@ class Geo3kTool(BaseTool):
return instance_id, None
@rollout_trace_op
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:
answer = parameters.get("answer", "")
if not isinstance(answer, str):
answer = str(answer)

@ -15,7 +15,7 @@
import logging
import os
from typing import Any, Optional, Tuple
from typing import Any, Optional
from uuid import uuid4
from verl.utils.reward_score import gsm8k

@ -75,7 +75,7 @@ class Gsm8kTool(BaseTool):
return instance_id
@rollout_trace_op
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:
answer = parameters.get("answer", "")
if not isinstance(answer, str):
answer = str(answer)

@ -15,7 +15,7 @@
import json
import logging
import os
from typing import Any, Optional, Tuple
from typing import Any, Optional
from uuid import uuid4
from fastmcp.exceptions import ClientError

@ -60,7 +60,7 @@ class MCPBaseTool(BaseTool):
}
return instance_id
async def _call_tool(self, instance_id, parameters) -> Tuple[str, dict]:
async def _call_tool(self, instance_id, parameters) -> tuple[str, dict]:
err_msg = ""
try:
call_tool_result = await ClientManager.call_tool(self.name, parameters, self.timeout)

@ -77,7 +77,7 @@ class MCPBaseTool(BaseTool):
return result, metadata
@rollout_trace_op
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:
if self.name == "" or self.name is None or parameters is None:
error_msg = "Error: 'parameters' is missing or empty."
logger.error(f"[MCPTool] {error_msg} Received tool name: {self.name}, parameters: {parameters}")

@ -111,6 +111,6 @@ class MCPBaseTool(BaseTool):
if instance_id in self._instance_dict:
del self._instance_dict[instance_id]
def _parse_tool_result(self, content: list) -> Tuple[str, dict]:
def _parse_tool_result(self, content: list) -> tuple[str, dict]:
tools_content = [part.text for part in filter(lambda x: x.type == "text", content)]
return " ".join(tools_content), {}

@ -16,7 +16,6 @@ import json
import logging
import os
import re
from typing import Tuple
from verl.tools.mcp_base_tool import MCPBaseTool

@ -30,7 +29,7 @@ class MCPSearchTool(MCPBaseTool):
def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):
super().__init__(config, tool_schema)
def _parse_tool_result(self, content: list) -> Tuple[str, dict]:
def _parse_tool_result(self, content: list) -> tuple[str, dict]:
res = ""
res_cnt = 0
query_list = []

@ -17,7 +17,7 @@ import os
import threading
from contextlib import ExitStack
from enum import Enum
from typing import Any, Callable, Optional, Tuple, TypeVar
from typing import Any, Callable, Optional, TypeVar
from uuid import uuid4
import ray

@ -163,7 +163,7 @@ class SandboxFusionTool(BaseTool):
return instance_id
@rollout_trace_op
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:
code = parameters.get("code", "")
timeout = parameters.get("timeout", self.default_timeout)
language = parameters.get("language", self.default_language)

@ -19,7 +19,7 @@ import os
import threading
from contextlib import ExitStack
from enum import Enum
from typing import Any, Callable, Optional, Tuple, TypeVar
from typing import Any, Callable, Optional, TypeVar
from uuid import uuid4
import ray

@ -226,7 +226,7 @@ class SearchTool(BaseTool):
return result_text, metadata
@rollout_trace_op
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:
"""Execute the search tool.
Args:

@ -19,7 +19,7 @@ import threading
import time
import traceback
import uuid
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Optional
import requests

@ -33,11 +33,11 @@ logger = logging.getLogger(__name__)
def call_search_api(
retrieval_service_url: str,
query_list: List[str],
query_list: list[str],
topk: int = 3,
return_scores: bool = True,
timeout: int = DEFAULT_TIMEOUT,
) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
) -> tuple[Optional[dict[str, Any]], Optional[str]]:
"""
Calls the remote search API to perform retrieval with retry logic for various errors,
using increasing delay between retries. Logs internal calls with a unique ID.

@ -140,11 +140,11 @@ def _passages2string(retrieval_result):
def perform_single_search_batch(
retrieval_service_url: str,
query_list: List[str],
query_list: list[str],
topk: int = 3,
concurrent_semaphore: Optional[threading.Semaphore] = None,
timeout: int = DEFAULT_TIMEOUT,
) -> Tuple[str, Dict[str, Any]]:
) -> tuple[str, dict[str, Any]]:
"""
Performs a single batch search for multiple queries (original search tool behavior).
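The signature changes in the tool and search-utils files follow the same typing modernization: `Tuple[...]`/`Dict[...]` become `tuple[...]`/`dict[...]` (PEP 585, Python 3.9), and elsewhere in the commit `Union[A, B]` becomes the PEP 604 form `A | B` (Python 3.10). A schematic before/after with invented names, not code from verl:

```python
from typing import Any, Dict, Optional, Tuple, Union

# Before: typing-module spellings.
def call_api_old(url: str, query_list: list) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
    return None, "not implemented"

def load_old(processing_class: Union[str, bytes] = None):
    return processing_class

# After: builtin generics, and "A | B" in place of Union[A, B].
def call_api_new(url: str, query_list: list[str]) -> tuple[Optional[dict[str, Any]], Optional[str]]:
    return None, "not implemented"

def load_new(processing_class: str | bytes = None):
    return processing_class
```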
@ -17,7 +17,7 @@ Metrics related to the PPO trainer.
from collections import defaultdict
from functools import partial
from typing import Any, Callable, Dict, List
from typing import Any, Callable
import numpy as np
import torch

@ -27,7 +27,7 @@ from verl.utils.import_utils import deprecated
@deprecated("verl.utils.metric.reduce_metrics")
def reduce_metrics(metrics: Dict[str, List[Any]]) -> Dict[str, Any]:
def reduce_metrics(metrics: dict[str, list[Any]]) -> dict[str, Any]:
"""
Reduces a dictionary of metric lists by computing the mean of each list.

@ -47,7 +47,7 @@ def reduce_metrics(metrics: Dict[str, List[Any]]) -> Dict[str, Any]:
return reduce_metrics(metrics)
def _compute_response_info(batch: DataProto) -> Dict[str, Any]:
def _compute_response_info(batch: DataProto) -> dict[str, Any]:
"""
Computes information about prompts and responses from a batch.

@ -77,7 +77,7 @@ def _compute_response_info(batch: DataProto) -> Dict[str, Any]:
)
def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> Dict[str, Any]:
def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str, Any]:
"""
Computes various metrics from a batch of data for PPO training.

@ -180,7 +180,7 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> Dict[str,
return metrics
def compute_timing_metrics(batch: DataProto, timing_raw: Dict[str, float]) -> Dict[str, Any]:
def compute_timing_metrics(batch: DataProto, timing_raw: dict[str, float]) -> dict[str, Any]:
"""
Computes timing metrics for different processing stages in PPO training.

@ -222,7 +222,7 @@ def compute_timing_metrics(batch: DataProto, timing_raw: Dict[str, float]) -> Di
}
def compute_throughout_metrics(batch: DataProto, timing_raw: Dict[str, float], n_gpus: int) -> Dict[str, Any]:
def compute_throughout_metrics(batch: DataProto, timing_raw: dict[str, float], n_gpus: int) -> dict[str, Any]:
"""
Computes throughput metrics for PPO training.

@ -416,7 +416,9 @@ def process_validation_metrics(
metric[f"best@{n}/mean"], metric[f"best@{n}/std"] = bon_mean, bon_std
metric[f"worst@{n}/mean"], metric[f"worst@{n}/std"] = won_mean, won_std
if var2vals.get("pred", None) is not None:
vote_data = [{"val": val, "pred": pred} for val, pred in zip(var_vals, var2vals["pred"])]
vote_data = [
{"val": val, "pred": pred} for val, pred in zip(var_vals, var2vals["pred"], strict=True)
]
[(maj_n_mean, maj_n_std)] = bootstrap_metric(
data=vote_data,
subset_size=n,
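As its docstring says, the deprecated `reduce_metrics` shim above collapses a dictionary of metric lists into per-key means and simply delegates to `verl.utils.metric.reduce_metrics`. A minimal sketch of that behavior (illustrative only, not the verl implementation):

```python
import numpy as np

def reduce_metrics_sketch(metrics: dict[str, list[float]]) -> dict[str, float]:
    # One mean per key; assumes every list is non-empty.
    return {key: float(np.mean(values)) for key, values in metrics.items()}

print(reduce_metrics_sketch({"actor/loss": [0.9, 0.7], "critic/score": [0.4, 0.6]}))
# {'actor/loss': 0.8, 'critic/score': 0.5}
```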
@ -26,7 +26,7 @@ from copy import deepcopy
from dataclasses import dataclass, field
from enum import Enum
from pprint import pprint
from typing import Optional, Type
from typing import Optional
import numpy as np
import ray

@ -61,7 +61,7 @@ from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seql
from verl.utils.torch_functional import masked_mean
from verl.utils.tracking import ValidationGenerationsLogger
WorkerType = Type[Worker]
WorkerType = type[Worker]
class Role(Enum):

@ -674,7 +674,7 @@ class RayPPOTrainer:
import numpy as np
# Create tuples of (input, output, score) and sort by input text
samples = list(zip(inputs, outputs, scores))
samples = list(zip(inputs, outputs, scores, strict=True))
samples.sort(key=lambda x: x[0]) # Sort by input text
# Use fixed random seed for deterministic shuffling
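The comments in the trainer hunk above describe the intent of that code path: pair inputs, outputs, and scores, sort by input text, then shuffle with a fixed seed so the logged validation samples are deterministic across runs; `strict=True` additionally guards against the three lists drifting out of sync. A stand-alone sketch of the pattern with invented data (verl's own implementation may differ in detail):

```python
import random

inputs = ["prompt b", "prompt a", "prompt c"]
outputs = ["B", "A", "C"]
scores = [0.1, 0.9, 0.5]

samples = list(zip(inputs, outputs, scores, strict=True))  # errors on length mismatch
samples.sort(key=lambda x: x[0])                           # sort by input text
random.Random(42).shuffle(samples)                         # fixed seed -> deterministic order
print(samples)
```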
@ -263,10 +263,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any:
torch_stray_tensor = isinstance(
tensor,
(
torch._subclasses.fake_tensor.FakeTensor,
torch._subclasses.functional_tensor.FunctionalTensor,
),
torch._subclasses.fake_tensor.FakeTensor | torch._subclasses.functional_tensor.FunctionalTensor,
)
need_offload = not torch_stray_tensor
need_offload = need_offload and self.tensor_need_offloading_checker(tensor)

@ -451,7 +448,7 @@ class ActivationHandler:
if len(kwarg_keys) == 0:
return flat_args, {}
args = flat_args[: -len(kwarg_keys)]
kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys) :]))
kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys) :], strict=True))
return args, kwargs
def _ckpt_forward(self, forward_method, *args, **kwargs):

@ -526,7 +523,7 @@ def enable_activation_offloading(model, strategy, enable_ckpt=False):
def get_layers(module):
for name, child in module.named_children():
if not isinstance(child, (FSDP, FSDP2)):
if not isinstance(child, FSDP | FSDP2):
get_layers(child)
else:
wrapped_module = child
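The `isinstance(child, FSDP | FSDP2)` form above leans on another Python 3.10 feature: `X | Y` builds a `types.UnionType` that `isinstance` and `issubclass` accept directly, so the tuple-of-types spelling becomes optional. A quick check with builtin types:

```python
value = 3.14

print(isinstance(value, (int, float)))  # True - classic tuple-of-types spelling
print(isinstance(value, int | float))   # True - Python 3.10+ union object
print(type(int | float))                # <class 'types.UnionType'>
```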
@ -15,7 +15,6 @@
import os
import random
import shutil
from typing import Union
import numpy as np
import torch

@ -46,7 +45,7 @@ class BaseCheckpointManager:
model,
optimizer: torch.optim.Optimizer,
lr_scheduler: torch.optim.lr_scheduler.LRScheduler = None,
processing_class: Union[PreTrainedTokenizer, ProcessorMixin] = None,
processing_class: PreTrainedTokenizer | ProcessorMixin = None,
checkpoint_config: DictConfig = None,
):
self.checkpoint_config = checkpoint_config

@ -17,7 +17,7 @@ import logging
import os
import warnings
from dataclasses import asdict, dataclass
from typing import Optional, Union
from typing import Optional
import torch
import torch.distributed

@ -73,7 +73,7 @@ class FSDPCheckpointManager(BaseCheckpointManager):
model: FSDP,
optimizer: Optional[torch.optim.Optimizer] = None,
lr_scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
processing_class: Union[PreTrainedTokenizer, ProcessorMixin] = None,
processing_class: PreTrainedTokenizer | ProcessorMixin = None,
checkpoint_config: DictConfig = None,
**kwargs,
):

@ -13,14 +13,14 @@
# limitations under the License.
from dataclasses import is_dataclass
from typing import Any, Dict, Optional, Type, Union
from typing import Any, Optional
from omegaconf import DictConfig, ListConfig, OmegaConf
__all__ = ["omega_conf_to_dataclass"]
def omega_conf_to_dataclass(config: Union[DictConfig, dict], dataclass_type: Optional[Type[Any]] = None) -> Any:
def omega_conf_to_dataclass(config: DictConfig | dict, dataclass_type: Optional[type[Any]] = None) -> Any:
"""
Convert an OmegaConf DictConfig to a dataclass.

@ -36,7 +36,7 @@ def omega_conf_to_dataclass(config: Union[DictConfig, dict], dataclass_type: Opt
if not config:
return dataclass_type if dataclass_type is None else dataclass_type()
# Got an object
if not isinstance(config, (DictConfig, ListConfig, dict, list)):
if not isinstance(config, DictConfig | ListConfig | dict | list):
return config
if dataclass_type is None:

@ -59,7 +59,7 @@ def omega_conf_to_dataclass(config: Union[DictConfig, dict], dataclass_type: Opt
return config_object
def update_dict_with_config(dictionary: Dict, config: DictConfig):
def update_dict_with_config(dictionary: dict, config: DictConfig):
for key in dictionary:
if hasattr(config, key):
dictionary[key] = getattr(config, key)
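`omega_conf_to_dataclass` above accepts either a `DictConfig` or a plain `dict` and produces a dataclass instance. A generic sketch of that conversion idea, with an invented config class; it is not verl's implementation, which handles more cases (empty configs, non-config objects, nested types):

```python
from dataclasses import dataclass
from omegaconf import DictConfig, OmegaConf

@dataclass
class OptimCfg:
    lr: float = 1e-5
    warmup_steps: int = 0

def to_dataclass_sketch(config: DictConfig | dict, dataclass_type: type) -> object:
    # Resolve interpolations, then hand the plain dict to the dataclass constructor.
    container = OmegaConf.to_container(config, resolve=True) if isinstance(config, DictConfig) else dict(config)
    return dataclass_type(**container)

cfg = OmegaConf.create({"lr": 3e-6, "warmup_steps": 10})
print(to_dataclass_sketch(cfg, OptimCfg))  # OptimCfg(lr=3e-06, warmup_steps=10)
```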
@ -18,7 +18,7 @@ Multi-turn SFT dataset that supports training on conversation data with multiple
import json
import logging
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Optional
import numpy as np
import pandas as pd

@ -48,7 +48,7 @@ class MultiTurnSFTDataset(Dataset):
Dataset for multi-turn conversations where each assistant response should be trained
"""
def __init__(self, parquet_files: Union[str, List[str]], tokenizer, config=None):
def __init__(self, parquet_files: str | list[str], tokenizer, config=None):
# Set defaults and extract parameters from config if provided
config = config or {}
self.truncation = config.get("truncation", "error")

@ -60,7 +60,7 @@ class MultiTurnSFTDataset(Dataset):
self.enable_thinking_key = multiturn_config.get("enable_thinking_key", "enable_thinking")
assert self.truncation in ["error", "left", "right"]
if not isinstance(parquet_files, List):
if not isinstance(parquet_files, list):
parquet_files = [parquet_files]
self.parquet_files = parquet_files

@ -80,7 +80,7 @@ class MultiTurnSFTDataset(Dataset):
import numpy
import pandas
while isinstance(ls, (pandas.core.series.Series, numpy.ndarray)) and len(ls) == 1:
while isinstance(ls, pandas.core.series.Series | numpy.ndarray) and len(ls) == 1:
ls = ls[0]
return ls

@ -109,13 +109,13 @@ class MultiTurnSFTDataset(Dataset):
def _process_message_tokens(
self,
messages: List[Dict[str, Any]],
messages: list[dict[str, Any]],
start_idx: int,
end_idx: int,
is_assistant: bool = False,
enable_thinking: Optional[bool] = None,
tools: Optional[List[Dict[str, Any]]] = None,
) -> Tuple[List[int], List[int], List[int]]:
tools: Optional[list[dict[str, Any]]] = None,
) -> tuple[list[int], list[int], list[int]]:
"""
Process tokens for a single message or a group of messages.

@ -185,10 +185,10 @@ class MultiTurnSFTDataset(Dataset):
def _validate_and_convert_tokens(
self,
full_tokens: torch.Tensor,
concat_tokens: List[int],
concat_loss_mask: List[int],
concat_attention_mask: List[int],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
concat_tokens: list[int],
concat_loss_mask: list[int],
concat_attention_mask: list[int],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Validate tokenization and convert to tensors.

@ -204,7 +204,7 @@ class MultiTurnSFTDataset(Dataset):
full_tokens_list = full_tokens.tolist()
if len(concat_tokens) != len(full_tokens_list) or not all(
a == b for a, b in zip(concat_tokens, full_tokens_list)
a == b for a, b in zip(concat_tokens, full_tokens_list, strict=True)
):
logging.warning(
f"Token mismatch detected! Full tokenization length: {len(full_tokens_list)}, Concatenated tokens "
@ -19,7 +19,7 @@ import logging
import os
import re
from collections import defaultdict
from typing import List, Optional, Union
from typing import Optional
import datasets
import numpy as np

@ -84,12 +84,12 @@ class RLHFDataset(Dataset):
def __init__(
self,
data_files: Union[str, List[str]],
data_files: str | list[str],
tokenizer: PreTrainedTokenizer,
config: DictConfig,
processor: Optional[ProcessorMixin] = None,
):
if not isinstance(data_files, (List, ListConfig)):
if not isinstance(data_files, list | ListConfig):
data_files = [data_files]
self.data_files = copy.deepcopy(data_files)

@ -13,7 +13,6 @@
# limitations under the License.
import os
from typing import List, Union
import pandas as pd
import torch

@ -39,7 +38,7 @@ def download_files_distributed(download_fn):
class RMDataset(Dataset):
def __init__(
self,
parquet_files: Union[str, List[str]],
parquet_files: str | list[str],
tokenizer,
prompt_key="prompt",
chosen_key="chosen",

@ -48,7 +47,7 @@ class RMDataset(Dataset):
add_eos=True,
cache_dir="~/.cache/verl/rm",
):
if not isinstance(parquet_files, List):
if not isinstance(parquet_files, list):
parquet_files = [parquet_files]
self.parquet_files = parquet_files

@ -18,8 +18,6 @@ SFT dataset
Each parquet file contains
"""
from typing import Union
import pandas as pd
import torch
from omegaconf.listconfig import ListConfig

@ -39,7 +37,7 @@ class SFTDataset(Dataset):
config (OmegaConf): the data config
"""
def __init__(self, parquet_files: Union[str, ListConfig], tokenizer, config):
def __init__(self, parquet_files: str | ListConfig, tokenizer, config):
prompt_key = config.get("prompt_key", "prompt")
prompt_dict_keys = config.get("prompt_dict_keys", None)
response_key = config.get("response_key", "response")

@ -60,8 +58,8 @@ class SFTDataset(Dataset):
tokenizer = hf_tokenizer(tokenizer)
self.tokenizer: PreTrainedTokenizer = tokenizer
self.prompt_key = prompt_key if isinstance(prompt_key, (tuple, list)) else [prompt_key]
self.response_key = response_key if isinstance(response_key, (tuple, list)) else [response_key]
self.prompt_key = prompt_key if isinstance(prompt_key, tuple | list) else [prompt_key]
self.response_key = response_key if isinstance(response_key, tuple | list) else [response_key]
self.prompt_dict_keys = prompt_dict_keys if prompt_dict_keys else []
self.response_dict_keys = response_dict_keys if response_dict_keys else []

@ -79,7 +77,7 @@ class SFTDataset(Dataset):
import numpy
import pandas
while isinstance(ls, (pandas.core.series.Series, numpy.ndarray)) and len(ls) == 1:
while isinstance(ls, pandas.core.series.Series | numpy.ndarray) and len(ls) == 1:
ls = ls[0]
return ls

@ -13,14 +13,14 @@
# limitations under the License.
from io import BytesIO
from typing import Optional, Union
from typing import Optional
import torch
from PIL import Image
from qwen_vl_utils import fetch_image, fetch_video
def process_image(image: Union[dict, Image.Image]) -> Image.Image:
def process_image(image: dict | Image.Image) -> Image.Image:
if isinstance(image, Image.Image):
return image.convert("RGB")
@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
from typing import Optional
import torch

@ -22,7 +22,7 @@ def _fused_linear_for_ppo_fwd(
vocab_weights: torch.FloatTensor,
input_ids: torch.LongTensor,
temperature: float = 1.0,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
logits = (hidden_states @ vocab_weights.t()) / temperature
orig_dtype = logits.dtype
logits = logits.to(torch.float32)

@ -44,7 +44,7 @@ def _fused_linear_for_ppo_bwd(
vocab_weights: torch.FloatTensor,
input_ids: torch.LongTensor,
temperature: float = 1.0,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
logits = (hidden_states @ vocab_weights.t()) / temperature
orig_dtype = logits.dtype
logits = logits.to(torch.float32)

@ -81,7 +81,7 @@ class FusedLinearForPPOFunction(torch.autograd.Function):
input_ids: torch.LongTensor,
temperature: float = 1.0,
chunk_size: int = 512,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
ctx.set_materialize_grads(False)
# Cast to a 2D tensor of the shape [T, D] for ease of working

@ -205,7 +205,7 @@ class FusedLinearForPPO(torch.nn.Module):
vocab_weights: torch.FloatTensor,
input_ids: torch.LongTensor,
temperature: float = 1.0,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
input_ids = input_ids.to(torch.int64)
return FusedLinearForPPOFunction.apply(
hidden_states,
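The fused-linear hunks above show the core computation: temperature-scaled logits from `hidden_states @ vocab_weights.t()`, upcast to float32 and driven by a `chunk_size` parameter, presumably so the full `[tokens, vocab]` logits tensor never has to be materialized at once. A simplified, forward-only sketch of that idea (not verl's implementation, which also returns entropy and defines a custom backward):

```python
import torch

def chunked_logprobs(hidden_states: torch.Tensor, vocab_weights: torch.Tensor,
                     input_ids: torch.Tensor, temperature: float = 1.0,
                     chunk_size: int = 512) -> torch.Tensor:
    """Per-token log-probs computed chunk by chunk to bound peak memory."""
    out = []
    for start in range(0, hidden_states.shape[0], chunk_size):
        hs = hidden_states[start : start + chunk_size]             # [chunk, hidden]
        ids = input_ids[start : start + chunk_size].to(torch.int64)
        logits = (hs @ vocab_weights.t()) / temperature            # [chunk, vocab]
        logp = torch.log_softmax(logits.to(torch.float32), dim=-1)
        out.append(logp.gather(-1, ids.unsqueeze(-1)).squeeze(-1))
    return torch.cat(out)
```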
@ -19,7 +19,6 @@ import math
import os
from collections import OrderedDict
from contextlib import contextmanager, nullcontext
from typing import Dict
import torch
import torch.distributed as dist

@ -308,7 +307,7 @@ def parallel_load_safetensors(filepath):
return shard_states
def parallel_init_module_fn(module: torch.nn.Module, shard_states: Dict[str, torch.nn.Parameter]):
def parallel_init_module_fn(module: torch.nn.Module, shard_states: dict[str, torch.nn.Parameter]):
"""
Generate a function to initialize sub-modules in the `module` with `shard_states`
from huggingface checkpoint.

@ -339,7 +338,7 @@ def parallel_init_module_fn(module: torch.nn.Module, shard_states: Dict[str, tor
else: # buffer
param = torch.empty_like(state.data, device=device)
loaded = shard_states[param_name]
if isinstance(loaded, (torch.nn.Parameter, torch.Tensor)):
if isinstance(loaded, torch.nn.Parameter | torch.Tensor):
# NOTE: loaded.dtype can be different with param.dtype
param.data.copy_(loaded.data)
dist.broadcast(param.data, src=dist.get_rank())

@ -21,7 +21,7 @@ import importlib.util
import os
import warnings
from functools import cache, wraps
from typing import List, Optional
from typing import Optional

@cache

@ -72,7 +72,7 @@ def is_trl_available():
def import_external_libs(external_libs=None):
if external_libs is None:
return
if not isinstance(external_libs, List):
if not isinstance(external_libs, list):
external_libs = [external_libs]
import importlib
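`import_external_libs` above normalizes its argument to a list before importing each entry. A minimal sketch of that normalize-then-import pattern (illustrative module names only):

```python
import importlib

def import_external_libs_sketch(external_libs=None):
    if external_libs is None:
        return
    if not isinstance(external_libs, list):
        external_libs = [external_libs]
    for lib in external_libs:
        importlib.import_module(lib)  # raises ImportError if the library is missing

import_external_libs_sketch("math")          # a single name is wrapped in a list
import_external_libs_sketch(["json", "os"])  # a list is used as-is
```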
Some files were not shown because too many files have changed in this diff.