[megatron] feat: add mindspeed engine and support sft (#3599)

### What does this PR do?

As per the title: add a MindSpeed-based training engine (Megatron backend on Ascend NPU) and support SFT with it.

Co-authored with @baymax591.

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that cannot be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

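Engine selection is now device-aware: the same `(model_type, backend)` pair can resolve to different engine classes depending on the host device. A minimal lookup sketch (illustrative only, not copied verbatim from this PR):

```python
from verl.utils.device import get_device_name
from verl.workers.engine import EngineRegistry

# The registry is now keyed by (model_type, backend, device). On an Ascend NPU
# host get_device_name() returns "npu", so the "megatron" backend resolves to
# MindspeedEngineWithLMHead, while on a CUDA host it still resolves to
# MegatronEngineWithLMHead. The FSDP engines are registered for both devices.
engine_cls = EngineRegistry.get_engine_cls(model_type="language_model", backend="megatron")
print(get_device_name(), engine_cls.__name__)
```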

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Co-authored-by: baymax591 <cbai@mail.nwpu.edu.cn>
Authored by Huazhong on 2025-09-26 14:39:10 +08:00, committed by GitHub.
Commit 2234810235 (parent 377bbb84f0)
7 changed files with 132 additions and 10 deletions


@@ -38,6 +38,8 @@ CUDA_KEYWORD_CHECK_WHITELIST = [
    "verl/workers/reward_model/megatron/reward_model.py", # appear in default device_name
    "verl/third_party/torch/distributed/_state_dict_utils.py", # torch monkey patch fixes
    "verl/third_party/torch/distributed/checkpoint/state_dict.py", # torch monkey patch fixes
    "verl/workers/engine/base.py", # appear in default device_name
    "verl/workers/engine/fsdp/transformer_impl.py", # appear in default device_name
    "verl/workers/rollout/vllm_rollout/vllm_async_server.py", # appear in config.cudagraph_capture_sizes
    "verl/workers/rollout/sglang_rollout/async_sglang_server.py", # manually set CUDA_VISIBLE_DEVICES
]


@@ -15,13 +15,20 @@
# limitations under the License.
from importlib.metadata import version as get_version
from typing import Optional
import torch
import torch.nn.functional as F
import torch_npu
from torch_npu import npu_rotary_mul as apply_rotary_emb
from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
from transformers.models.qwen3 import modeling_qwen3
from transformers.models.qwen3_moe import modeling_qwen3_moe
from transformers.utils import logging
logger = logging.get_logger(__name__)
# This patch takes effect when using apply_rotary_pos_emb_flashatt on qwen2_5_vl and will be removed in
@@ -158,6 +165,35 @@ def moe_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    return final_hidden_states, router_logits
@classmethod
def _check_and_enable_flash_attn_2(
    cls,
    config,
    torch_dtype: Optional[torch.dtype] = None,
    device_map: Optional[str | dict[str, int]] = None,
    check_device_map: bool = True,
    hard_check_only: bool = False,
) -> PretrainedConfig:
    """
    Checks the availability of Flash Attention 2 and compatibility with the current model.
    If all checks pass and `hard_check_only` is False, the method will set the config attribute
    `attn_implementation` to "flash_attention_2" so that the model can initialize
    the correct attention module.
    """
    if not cls._supports_flash_attn_2:
        raise ValueError(
            f"{cls.__name__} does not support Flash Attention 2.0 yet. Please request to add support where the"
            f" model is hosted, on its model hub page: https://huggingface.co/{config._name_or_path}/discussions/new"
            " or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new"
        )
    if not hard_check_only:
        config._attn_implementation = "flash_attention_2"
        logger.info("Detect using FlashAttention2 on Ascend NPU.")
    return config
modeling_qwen2_5_vl.Qwen2RMSNorm.forward = rms_norm_forward
modeling_qwen2_5_vl.Qwen2_5_VLMLP.forward = silu_forward
modeling_qwen2_5_vl.apply_rotary_pos_emb_flashatt = apply_rotary_pos_emb_flashatt_qwen2_5_vl_npu
@@ -166,3 +202,6 @@ modeling_qwen3_moe.Qwen3MoeSparseMoeBlock.forward = moe_block_forward
modeling_qwen3_moe.apply_rotary_pos_emb = apply_rotary_pos_emb_qwen3_npu
modeling_qwen3.Qwen3RMSNorm.forward = rms_norm_forward
modeling_qwen3.Qwen3MLP.forward = silu_forward
if get_version("transformers") == "4.52.4":
    PreTrainedModel._check_and_enable_flash_attn_2 = _check_and_enable_flash_attn_2
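For context, a rough usage sketch of what this version-gated patch enables on Ascend NPU (illustrative only; the checkpoint name is a placeholder): requesting FlashAttention-2 through the regular `transformers` loading path now only verifies `_supports_flash_attn_2` and sets the attention implementation, instead of running the CUDA-oriented availability checks.

```python
# Assumes an Ascend NPU host with transformers==4.52.4 and the patch above imported.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-8B",  # placeholder checkpoint
    attn_implementation="flash_attention_2",  # handled by the patched check
)
```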


@@ -16,6 +16,14 @@ from .fsdp import FSDPEngine, FSDPEngineWithLMHead
__all__ = ["BaseEngine", "EngineRegistry", "FSDPEngine", "FSDPEngineWithLMHead"]
# Mindspeed must be imported before Megatron to ensure the related monkey patches take effect as expected
try:
    from .mindspeed import MindspeedEngineWithLMHead
    __all__ += ["MindspeedEngineWithLMHead"]
except ImportError:
    MindspeedEngineWithLMHead = None
try:
    from .megatron import MegatronEngine, MegatronEngineWithLMHead
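Because the import is wrapped in `try/except ImportError`, downstream code can feature-gate on the exported symbol. A small illustrative check (not part of this diff):

```python
from verl.workers.engine import MindspeedEngineWithLMHead

# MindspeedEngineWithLMHead is set to None when the mindspeed package is not installed.
if MindspeedEngineWithLMHead is None:
    print("MindSpeed not available; the plain Megatron engine will be used")
```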


@@ -20,6 +20,8 @@ from typing import Any, Callable, Optional
import torch
from tensordict import TensorDict
from verl.utils.device import get_device_name
class BaseEngine:
"""
@@ -189,7 +191,7 @@ class EngineRegistry:
    _engines = {}
    @classmethod
    def register(cls, model_type: str, backend: list[str] | str):
    def register(cls, model_type: str, backend: list[str] | str, device: list[str] | str = "cuda"):
        """
        A class method decorator that registers an engine class with a given key.
@@ -198,6 +200,8 @@
        Args:
            model_type (str): The type of the model
            backend (list[str] | str): The backend to use for the model type
            device (list[str] | str): The device type (e.g., "cuda", "npu", "cpu") this engine supports,
                default is "cuda"
        Returns:
            A decorator function that takes an engine class and registers it.
@@ -208,12 +212,15 @@
            if model_type not in cls._engines:
                cls._engines[model_type] = {}
            if isinstance(backend, list):
                for k in backend:
                    cls._engines[model_type][k] = engine_class
            else:
                assert isinstance(backend, str)
                cls._engines[model_type][backend] = engine_class
            backends = backend if isinstance(backend, list) else [backend]
            devices = device if isinstance(device, list) else [device]
            for current_backend in backends:
                for current_device in devices:
                    if current_backend not in cls._engines[model_type]:
                        cls._engines[model_type][current_backend] = {}
                    if current_device not in cls._engines[model_type][current_backend]:
                        cls._engines[model_type][current_backend][current_device] = engine_class
            return engine_class
        return decorator
@@ -222,7 +229,11 @@
    def get_engine_cls(cls, model_type: str, backend: str):
        assert model_type in cls._engines, f"Unknown model_type: {model_type}"
        assert backend in cls._engines[model_type], f"Unknown backend: {backend}"
        return cls._engines[model_type][backend]
        device = get_device_name()
        assert device in cls._engines[model_type][backend], (
            f"Unknown device: {device} for model_type: {model_type} and backend: {backend}"
        )
        return cls._engines[model_type][backend][device]
    @classmethod
    def new(cls, model_type, backend, *args, **kwargs):
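In short, registrations now populate a nested `_engines[model_type][backend][device]` mapping and `get_engine_cls` selects the entry matching `get_device_name()`. A hedged sketch of the resulting layout (the engine class and backend name below are made up for illustration):

```python
@EngineRegistry.register(model_type="language_model", backend="my_backend", device=["cuda", "npu"])
class MyEngine:  # hypothetical engine used only to illustrate the registration
    ...

# Assumed resulting layout:
#   EngineRegistry._engines["language_model"]["my_backend"]["cuda"] is MyEngine
#   EngineRegistry._engines["language_model"]["my_backend"]["npu"] is MyEngine
# get_engine_cls("language_model", "my_backend") then picks the entry for get_device_name().
```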


@@ -697,7 +697,7 @@ class EngineTrainModeCtx:
        self.engine.mode = None
@EngineRegistry.register(model_type="language_model", backend=["fsdp", "fsdp2"])
@EngineRegistry.register(model_type="language_model", backend=["fsdp", "fsdp2"], device=["cuda", "npu"])
class FSDPEngineWithLMHead(FSDPEngine):
    def prepare_model_inputs(self, micro_batch: TensorDict):
        use_remove_padding = tu.get_non_tensor_data(data=micro_batch, key="use_remove_padding", default=True)
@@ -1012,7 +1012,7 @@ class FSDPEngineWithLMHead(FSDPEngine):
        return loss, output
@EngineRegistry.register(model_type="value_model", backend=["fsdp", "fsdp2"])
@EngineRegistry.register(model_type="value_model", backend=["fsdp", "fsdp2"], device=["cuda", "npu"])
class FSDPEngineWithValueHead(FSDPEngineWithLMHead):
    """
    The only difference between critic and actor is how the raw model output is processed


@@ -0,0 +1,17 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .transformer_impl import MindspeedEngineWithLMHead
__all__ = ["MindspeedEngineWithLMHead"]


@@ -0,0 +1,45 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from mindspeed.megatron_adaptor import repatch
from verl.trainer.config import CheckpointConfig
from verl.workers.config import HFModelConfig, McoreEngineConfig, McoreOptimizerConfig
from ..base import EngineRegistry
from ..megatron import MegatronEngineWithLMHead
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
@EngineRegistry.register(model_type="language_model", backend="megatron", device="npu")
class MindspeedEngineWithLMHead(MegatronEngineWithLMHead):
    def __init__(
        self,
        model_config: HFModelConfig,
        engine_config: McoreEngineConfig,
        optimizer_config: McoreOptimizerConfig,
        checkpoint_config: CheckpointConfig,
    ):
        super().__init__(model_config, engine_config, optimizer_config, checkpoint_config)
        repatch_config = {"use_flash_attn": True}
        if self.engine_config.context_parallel_size > 1:
            repatch_config["context_parallel_size"] = self.engine_config.context_parallel_size
        repatch(repatch_config)
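For reference, a hedged sketch of the config passed to MindSpeed's `repatch` depending on context parallelism (the helper function below is hypothetical and only mirrors the `__init__` branch above):

```python
def build_repatch_config(context_parallel_size: int) -> dict:
    # Mirrors MindspeedEngineWithLMHead.__init__: flash attention is always
    # enabled, and context parallelism is forwarded only when it is in use.
    repatch_config = {"use_flash_attn": True}
    if context_parallel_size > 1:
        repatch_config["context_parallel_size"] = context_parallel_size
    return repatch_config

assert build_repatch_config(1) == {"use_flash_attn": True}
assert build_repatch_config(4) == {"use_flash_attn": True, "context_parallel_size": 4}
```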