mirror of
https://github.com/volcengine/verl.git
synced 2025-10-20 13:43:50 +08:00
[megatron] feat: add mindspeed engine and support sft (#3599)
### What does this PR do?

As per title. Co-authored with @baymax591

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Co-authored-by: baymax591 <cbai@mail.nwpu.edu.cn>
@@ -38,6 +38,8 @@ CUDA_KEYWORD_CHECK_WHITELIST = [
    "verl/workers/reward_model/megatron/reward_model.py", # appear in default device_name
    "verl/third_party/torch/distributed/_state_dict_utils.py", # torch monkey patch fixes
    "verl/third_party/torch/distributed/checkpoint/state_dict.py", # torch monkey patch fixes
    "verl/workers/engine/base.py", # appear in default device_name
    "verl/workers/engine/fsdp/transformer_impl.py", # appear in default device_name
    "verl/workers/rollout/vllm_rollout/vllm_async_server.py", # appear in config.cudagraph_capture_sizes
    "verl/workers/rollout/sglang_rollout/async_sglang_server.py", # manually set CUDA_VISIBLE_DEVICES
]
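The "appear in default device_name" comments indicate that these files mention "cuda" only as a default value (for example a default device argument), so the CI keyword check whitelists them instead of flagging a hard CUDA dependency. A tiny hypothetical illustration of that pattern (not code from this diff):

```python
# Hypothetical illustration: "cuda" appears only as a default value, which the
# CI keyword check would flag without a whitelist entry.
def build_engine(device_name: str = "cuda") -> str:
    # NPU deployments are expected to pass device_name="npu" explicitly
    return f"engine on {device_name}"
```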
@@ -15,13 +15,20 @@
# limitations under the License.

from importlib.metadata import version as get_version
from typing import Optional

import torch
import torch.nn.functional as F
import torch_npu
from torch_npu import npu_rotary_mul as apply_rotary_emb
from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
from transformers.models.qwen3 import modeling_qwen3
from transformers.models.qwen3_moe import modeling_qwen3_moe
from transformers.utils import logging

logger = logging.get_logger(__name__)


# This patch takes effect when using apply_rotary_pos_emb_flashatt on qwen2_5_vl and will be removed in
@@ -158,6 +165,35 @@ def moe_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    return final_hidden_states, router_logits


@classmethod
def _check_and_enable_flash_attn_2(
    cls,
    config,
    torch_dtype: Optional[torch.dtype] = None,
    device_map: Optional[str | dict[str, int]] = None,
    check_device_map: bool = True,
    hard_check_only: bool = False,
) -> PretrainedConfig:
    """
    Checks the availability of Flash Attention 2 and compatibility with the current model.

    If all checks pass and `hard_check_only` is False, the method will set the config attribute
    `attn_implementation` to "flash_attention_2" so that the model can initialize
    the correct attention module.
    """
    if not cls._supports_flash_attn_2:
        raise ValueError(
            f"{cls.__name__} does not support Flash Attention 2.0 yet. Please request to add support where the"
            f" model is hosted, on its model hub page: https://huggingface.co/{config._name_or_path}/discussions/new"
            " or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new"
        )

    if not hard_check_only:
        config._attn_implementation = "flash_attention_2"
        logger.info("Detect using FlashAttention2 on Ascend NPU.")
    return config


modeling_qwen2_5_vl.Qwen2RMSNorm.forward = rms_norm_forward
modeling_qwen2_5_vl.Qwen2_5_VLMLP.forward = silu_forward
modeling_qwen2_5_vl.apply_rotary_pos_emb_flashatt = apply_rotary_pos_emb_flashatt_qwen2_5_vl_npu
@@ -166,3 +202,6 @@ modeling_qwen3_moe.Qwen3MoeSparseMoeBlock.forward = moe_block_forward
modeling_qwen3_moe.apply_rotary_pos_emb = apply_rotary_pos_emb_qwen3_npu
modeling_qwen3.Qwen3RMSNorm.forward = rms_norm_forward
modeling_qwen3.Qwen3MLP.forward = silu_forward

if get_version("transformers") == "4.52.4":
    PreTrainedModel._check_and_enable_flash_attn_2 = _check_and_enable_flash_attn_2
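The last hunk installs the `_check_and_enable_flash_attn_2` override only for transformers 4.52.4, the exact release this patch file targets. A minimal sketch of how the patches would be exercised on an Ascend NPU host (the import path of the patch module is hypothetical and the model name is only an example):

```python
# Minimal sketch, assuming the NPU patch file is importable at the path below
# (hypothetical): importing it applies the monkey patches above, after which
# loading a Qwen model with flash_attention_2 passes the patched check on NPU.
from transformers import AutoModelForCausalLM

import verl.models.transformers.npu_patch  # noqa: F401 (hypothetical module path for this patch file)

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-8B",  # example model
    attn_implementation="flash_attention_2",  # accepted because the availability check is overridden
)
```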
@@ -16,6 +16,14 @@ from .fsdp import FSDPEngine, FSDPEngineWithLMHead

__all__ = ["BaseEngine", "EngineRegistry", "FSDPEngine", "FSDPEngineWithLMHead"]

# Mindspeed must be imported before Megatron to ensure the related monkey patches take effect as expected
try:
    from .mindspeed import MindspeedEngineWithLMHead

    __all__ += ["MindspeedEngineWithLMHead"]
except ImportError:
    MindspeedEngineWithLMHead = None

try:
    from .megatron import MegatronEngine, MegatronEngineWithLMHead
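Because the Mindspeed import is optional, `MindspeedEngineWithLMHead` is `None` when mindspeed is not installed, which callers can use for feature detection. A minimal sketch, assuming `verl` is importable:

```python
# Minimal sketch: feature-detect the optional Mindspeed engine exposed by the
# guarded import above (it is None when mindspeed is not installed).
from verl.workers import engine as engine_pkg

if getattr(engine_pkg, "MindspeedEngineWithLMHead", None) is not None:
    print("Mindspeed (Ascend NPU) engine is available")
else:
    print("Mindspeed not installed; falling back to the stock Megatron engine")
```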
@@ -20,6 +20,8 @@ from typing import Any, Callable, Optional
import torch
from tensordict import TensorDict

+from verl.utils.device import get_device_name


class BaseEngine:
    """
@@ -189,7 +191,7 @@ class EngineRegistry:
    _engines = {}

    @classmethod
-    def register(cls, model_type: str, backend: list[str] | str):
+    def register(cls, model_type: str, backend: list[str] | str, device: list[str] | str = "cuda"):
        """
        A class method decorator that registers an engine class with a given key.

@@ -198,6 +200,8 @@ class EngineRegistry:
        Args:
            model_type (str): The type of the model
            backend (list[str] | str): The backend to use for the model type
+            device (list[str] | str): The device type (e.g., "cuda", "npu", "cpu") this engine supports,
+                default is "cuda"

        Returns:
            A decorator function that takes an engine class and registers it.
@@ -208,12 +212,15 @@ class EngineRegistry:
            if model_type not in cls._engines:
                cls._engines[model_type] = {}

-            if isinstance(backend, list):
-                for k in backend:
-                    cls._engines[model_type][k] = engine_class
-            else:
-                assert isinstance(backend, str)
-                cls._engines[model_type][backend] = engine_class
+            backends = backend if isinstance(backend, list) else [backend]
+            devices = device if isinstance(device, list) else [device]
+            for current_backend in backends:
+                for current_device in devices:
+                    if current_backend not in cls._engines[model_type]:
+                        cls._engines[model_type][current_backend] = {}
+                    if current_device not in cls._engines[model_type][current_backend]:
+                        cls._engines[model_type][current_backend][current_device] = engine_class

            return engine_class

        return decorator
@@ -222,7 +229,11 @@ class EngineRegistry:
    def get_engine_cls(cls, model_type: str, backend: str):
        assert model_type in cls._engines, f"Unknown model_type: {model_type}"
        assert backend in cls._engines[model_type], f"Unknown backend: {backend}"
-        return cls._engines[model_type][backend]
+        device = get_device_name()
+        assert device in cls._engines[model_type][backend], (
+            f"Unknown device: {device} for model_type: {model_type} and backend: {backend}"
+        )
+        return cls._engines[model_type][backend][device]

    @classmethod
    def new(cls, model_type, backend, *args, **kwargs):
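With the extra device dimension, engine classes are keyed by (model_type, backend, device), and `get_engine_cls` resolves the device via `get_device_name()` at lookup time. A minimal sketch of the registration and lookup flow, assuming `verl` is installed (the `ToyEngine` class and the "toy_backend" key are illustrative only):

```python
from verl.utils.device import get_device_name
from verl.workers.engine import BaseEngine, EngineRegistry


# Hypothetical engine used only to illustrate the new (model_type, backend, device) keying.
@EngineRegistry.register(model_type="language_model", backend="toy_backend", device=["cuda", "npu"])
class ToyEngine(BaseEngine):
    pass


# get_engine_cls picks the entry matching the current device (get_device_name()),
# so the same call site works unchanged on CUDA and Ascend NPU hosts.
engine_cls = EngineRegistry.get_engine_cls(model_type="language_model", backend="toy_backend")
print(engine_cls.__name__, "selected for device:", get_device_name())
```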
@@ -697,7 +697,7 @@ class EngineTrainModeCtx:
        self.engine.mode = None


-@EngineRegistry.register(model_type="language_model", backend=["fsdp", "fsdp2"])
+@EngineRegistry.register(model_type="language_model", backend=["fsdp", "fsdp2"], device=["cuda", "npu"])
class FSDPEngineWithLMHead(FSDPEngine):
    def prepare_model_inputs(self, micro_batch: TensorDict):
        use_remove_padding = tu.get_non_tensor_data(data=micro_batch, key="use_remove_padding", default=True)
@@ -1012,7 +1012,7 @@ class FSDPEngineWithLMHead(FSDPEngine):
        return loss, output


-@EngineRegistry.register(model_type="value_model", backend=["fsdp", "fsdp2"])
+@EngineRegistry.register(model_type="value_model", backend=["fsdp", "fsdp2"], device=["cuda", "npu"])
class FSDPEngineWithValueHead(FSDPEngineWithLMHead):
    """
    The only difference between critic and actor is how the raw model output is processed
verl/workers/engine/mindspeed/__init__.py (new file, 17 lines)
@@ -0,0 +1,17 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .transformer_impl import MindspeedEngineWithLMHead

__all__ = ["MindspeedEngineWithLMHead"]
verl/workers/engine/mindspeed/transformer_impl.py (new file, 45 lines)
@@ -0,0 +1,45 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os

from mindspeed.megatron_adaptor import repatch

from verl.trainer.config import CheckpointConfig
from verl.workers.config import HFModelConfig, McoreEngineConfig, McoreOptimizerConfig

from ..base import EngineRegistry
from ..megatron import MegatronEngineWithLMHead

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))


@EngineRegistry.register(model_type="language_model", backend="megatron", device="npu")
class MindspeedEngineWithLMHead(MegatronEngineWithLMHead):
    def __init__(
        self,
        model_config: HFModelConfig,
        engine_config: McoreEngineConfig,
        optimizer_config: McoreOptimizerConfig,
        checkpoint_config: CheckpointConfig,
    ):
        super().__init__(model_config, engine_config, optimizer_config, checkpoint_config)

        repatch_config = {"use_flash_attn": True}
        if self.engine_config.context_parallel_size > 1:
            repatch_config["context_parallel_size"] = self.engine_config.context_parallel_size

        repatch(repatch_config)
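Because the engine registers itself for backend "megatron" on device "npu", the existing Megatron SFT path selects it automatically on Ascend hosts; no new config keys are needed. A minimal sketch of how the selection would look (the config objects are assumed to be built elsewhere, e.g. from the trainer config, and are omitted here):

```python
# Minimal sketch: on an Ascend NPU host, the device-aware registry resolves the
# "megatron" backend to the Mindspeed engine; on CUDA it still returns the stock
# MegatronEngineWithLMHead.
from verl.workers.engine import EngineRegistry

engine_cls = EngineRegistry.get_engine_cls(model_type="language_model", backend="megatron")
print(engine_cls.__name__)  # "MindspeedEngineWithLMHead" when get_device_name() == "npu"

# Instantiating the class runs MindspeedEngineWithLMHead.__init__, which calls
# mindspeed's repatch() with use_flash_attn=True (plus context_parallel_size when > 1):
# engine = engine_cls(model_config, engine_config, optimizer_config, checkpoint_config)
```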