Mirror of https://github.com/volcengine/verl.git, synced 2025-10-20 21:53:50 +08:00
As initially mentioned in https://github.com/volcengine/verl/discussions/1941, having structured configuration classes in verl makes argument passing easier for testing and validation. This is an extended thread on the current implementation of the configuration schema in verl.

Related PRs:
- https://github.com/volcengine/verl/pull/2117
- https://github.com/volcengine/verl/pull/2621

# Motivation

By moving from loose `omegaconf.DictConfig`-based parameters to structured dataclasses, we gain:

- Type safety & IDE support when accessing fields (e.g. `cfg.optim.lr`).
- Validation hooks via `__post_init__` in each class.
- Immutable defaults with controlled mutability (e.g., an `extra` field).
- Seamless Hydra/OmegaConf integration and easy per-recipe extension.

# Core: BaseConfig

Hydra natively supports converting a DictConfig to a dataclass, but a dataclass does not support attribute access via `get()`. We introduce a base class to provide backward compatibility and make the change less abrupt for existing users. All config dataclasses inherit from BaseConfig, which:

- Implements `collections.abc.Mapping` → dict-like iteration/access.
- Freezes attributes once set, unless listed in `_mutable_fields`.
- Provides an `extra: dict[str, Any]` field for unchecked extensions.

```python
import collections
from dataclasses import FrozenInstanceError, dataclass, field
from typing import Any


@dataclass
class BaseConfig(collections.abc.Mapping):
    """Dict-like, frozen dataclass with opt-in mutability."""

    # class-level attribute (not a dataclass field): names that may be mutated after init
    _mutable_fields = {"extra"}

    extra: dict[str, Any] = field(default_factory=dict)

    def __setattr__(self, name: str, value):
        if name in self.__dict__ and name not in self._mutable_fields:
            raise FrozenInstanceError(f"Field '{name}' is frozen")
        super().__setattr__(name, value)

    # Mapping methods: get, __getitem__, __iter__, __len__ …
```

# Example Config Classes (verl/trainer/config)

Each sub-component of the trainer has its own dataclass inheriting from BaseConfig.

```yaml
critic:
  checkpoint:
    _target_: verl.trainer.config.CheckpointConfig
    save_contents: ["model", "optimizer", "extra"]
    load_contents: ["model", "optimizer", "extra"]
    async_save: false
```

Definition:

```python
@dataclass
class CheckpointConfig(BaseConfig):
    """What to save/load and async behavior."""

    save_contents: list[str] = field(default_factory=lambda: ["model", "optimizer", "extra"])
    load_contents: list[str] = field(default_factory=lambda: ["model", "optimizer", "extra"])
    async_save: bool = False

    def __post_init__(self):
        # validation checks go here after initialization
        ...


ckpt_cfg = CheckpointConfig(async_save=True)
print(ckpt_cfg.save_contents)
print(ckpt_cfg.get("save_contents", []))  # dict-style access with a default
print(ckpt_cfg["save_contents"])

# converting a hydra-generated omegaconf.DictConfig to the dataclass config:
from verl.utils.config import omegaconf_to_dataclass

ckpt_cfg_from_cli = omegaconf_to_dataclass(config.critic.checkpoint)
```

# Extending existing config classes

Because configs are now structured, unexpected keys raise exceptions instead of being silently accepted, as the sketch below illustrates.
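A minimal sketch of that behavior, assuming the `CheckpointConfig` defined above and the `BaseConfig` semantics described in this thread (the unknown key name is made up for illustration):

```python
from dataclasses import FrozenInstanceError

ckpt_cfg = CheckpointConfig(async_save=True)

# dict-style access still works through the Mapping interface
assert ckpt_cfg["async_save"] is True
assert ckpt_cfg.get("not_a_field", "fallback") == "fallback"

# an unexpected key is rejected at construction time
try:
    CheckpointConfig(not_a_field=123)
except TypeError as err:
    print(f"rejected unknown key: {err}")

# fields are frozen once set, unless listed in _mutable_fields
try:
    ckpt_cfg.async_save = False
except FrozenInstanceError as err:
    print(f"rejected mutation: {err}")
```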
To add new keys, there are two ways.

## Explicit class extensions

```python
from verl.workers.config import FSDPActorConfig


@dataclass
class SPPOActorConfig(FSDPActorConfig):
    """Add SPPO-specific temperature/penalty."""

    sppo_eta: float = 1.0
```

When using YAML or the command line, update the target config class:

```yaml
hydra:
  searchpath:
    - file://verl/trainer/config

defaults:
  - ppo_trainer  # base trainer config
  - _self_       # then apply these overrides

actor_rollout_ref:
  actor:
    _target_: recipe.sppo.config.SPPOActorConfig  # new target dataclass required for the extension
    sppo_eta: 1.0
```

or directly from the command line:

```bash
python main_sppo.py \
    actor_rollout_ref.actor._target_=recipe.sppo.config.SPPOActorConfig \
    actor_rollout_ref.actor.sppo_eta=1.0
```

## Leverage the `extra` field

Adding new keys under the `extra` field of any dataclass that inherits from `BaseConfig` also works. This way there is no need to define your own dataclass in Python:

```yaml
hydra:
  searchpath:
    - file://verl/trainer/config

defaults:
  - ppo_trainer  # base trainer config
  - _self_       # then apply these overrides

actor_rollout_ref:
  actor:
    extra:
      sppo_eta: 1.0
```

# Declaring mutable fields

For historical reasons, some config fields are mutated in place in the codebase, such as batch sizes for data/sequence parallelism. We are in the process of deprecating this behavior. However, if you want to intentionally mutate a field, declare it in the `_mutable_fields` attribute:

```python
@dataclass
class CheckpointConfig(BaseConfig):
    """What to save/load and async behavior."""

    # mark save_contents as mutable
    _mutable_fields = BaseConfig._mutable_fields | {"save_contents"}

    save_contents: list[str] = field(default_factory=lambda: ["model", "optimizer", "extra"])
    load_contents: list[str] = field(default_factory=lambda: ["model", "optimizer", "extra"])
    async_save: bool = False
```

# Other helpful resources

The default verl trainer config combines the following config files, specified in the `defaults` list of https://github.com/volcengine/verl/blob/main/verl/trainer/config/ppo_trainer.yaml#L1-L36:

- verl/trainer/config/ppo_trainer.yaml  # main config for the entrypoint
- verl/trainer/config/actor/dp_actor.yaml
- verl/trainer/config/critic/dp_critic.yaml
- verl/trainer/config/reward_model/dp_reward_model.yaml
- verl/trainer/config/rollout/rollout.yaml

To quickly peek at the full default config in a single file, check the auto-generated config at https://github.com/volcengine/verl/blob/main/verl/trainer/config/_generated_ppo_trainer.yaml

# Change log and impact on existing code

This PR converts the following fields to structured dataclasses in the training pipeline. More can be done in future PRs (contributions from the community are welcome):

- [x] actor_rollout_ref.actor
- [x] critic
- [ ] actor_rollout_ref.rollout
- [ ] actor_rollout_ref.ref
- [ ] reward_model
- [ ] data
- [ ] trainer

Changes needed for existing code that added new fields to the config:

- See recipe/sppo for an example.
- `OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create()))` now has to be manually changed to `self.config.model.get("override_config", {})`, because `OmegaConf.to_container` expects a DictConfig but `config.model.override_config` is already a plain dict (see the sketch after this list).
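A minimal before/after sketch of that last change; `model_cfg` stands in for `self.config.model` and the helper names are hypothetical:

```python
from omegaconf import DictConfig, OmegaConf


def get_override_config_before(model_cfg: DictConfig) -> dict:
    # Before: model_cfg was an omegaconf.DictConfig, so the nested node had to be
    # converted into a plain container first.
    return OmegaConf.to_container(model_cfg.get("override_config", OmegaConf.create()))


def get_override_config_after(model_cfg) -> dict:
    # After: model_cfg is a structured dataclass (BaseConfig); override_config is
    # already a plain dict and can be used directly.
    return model_cfg.get("override_config", {})
```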
# Other Breaking Changes

- `critic.optim.lr` for Megatron changed from 1e-6 to 1e-5.

---------

Signed-off-by: ShareLer <ShareLe@163.com>
Co-authored-by: devin-ai-integration[bot] <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Co-authored-by: Joel <wuxibin@bytedance.com>
Co-authored-by: Cheetah <1659275352@qq.com>
Co-authored-by: 杨睿 <yangruipis@163.com>
Co-authored-by: X. HU <huxiaobo@zju.edu.cn>
Co-authored-by: Le Xue <48175490+ShareLer@users.noreply.github.com>
Co-authored-by: Ziheng Jiang <ziheng@apache.org>
Co-authored-by: Blue Space <57280232+ETOgaosion@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
tests/kernels/test_linear_cross_entropy_tp.py · 515 lines · 22 KiB · Python
#
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import torch
import torch.distributed as dist

try:
    from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy
except ImportError:
    # FIXME: remove these manually included paths
    import sys

    sys.path.append(os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../")))
finally:
    from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy

import verl.utils.torch_functional as verl_F

compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True)

# Environment overrides arrive as strings, so cast them to the types the code expects.
MAX_TEST_CASES = int(os.environ.get("MAX_TEST_CASES", 5))
VERIFY_TORCH_SELF = bool(int(os.environ.get("VERIFY_TORCH_SELF", 0)))
LOW_MEMORY = bool(int(os.environ.get("LOW_MEMORY", 0)))
LOW_MEMORY_DIV_FACTOR = int(os.environ.get("LOW_MEMORY_DIV_FACTOR", 16))
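# Example (assumed invocation; the path matches the torchrun command noted at the bottom of this file):
#   LOW_MEMORY=1 MAX_TEST_CASES=3 torchrun --standalone --nnodes=1 --nproc-per-node=2 \
#       tests/kernels/test_linear_cross_entropy_tp.py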
def run_torch_entropy(
    hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, temperature: float, reduction="none"
) -> tuple[torch.Tensor, torch.Tensor]:
    # [num_tokens, vocab_size]
    if len(hidden.shape) > 2:
        hidden = hidden.view(-1, hidden.shape[-1])  # [num_tokens, hidden_size]
    if len(labels.shape) > 1:
        labels = labels.view(-1)
    logits = torch.matmul(
        hidden.to(torch.float32),
        weight.to(torch.float32) if weight.size(0) == hidden.size(1) else weight.T.to(torch.float32),
    )
    logits /= temperature
    pd = torch.nn.functional.softmax(logits, dim=-1)  # [num_tokens, vocab_size]
    entropy_a = torch.logsumexp(logits, dim=-1)  # [num_tokens]
    entropy_b = torch.sum(pd * logits, dim=-1)  # [num_tokens]
    entropy = entropy_a - entropy_b
    logprobs = torch.nn.functional.cross_entropy(logits, labels, reduction=reduction)  # [num_tokens]
    logprobs = torch.neg(logprobs)
    return logprobs, entropy

class TorchEntropyTP(torch.autograd.Function):
    """Reference tensor-parallel linear + cross-entropy/entropy, used to test the kernel's correctness.

    It is not efficient and is not recommended for use in practice.
    """

    @staticmethod
    def forward(
        ctx,
        hidden: torch.Tensor,
        weight: torch.Tensor,
        labels: torch.Tensor,
        temperature: float,
        dist_process_group: torch.distributed.ProcessGroup,
    ):
        # weight has shape [vocab_size, hidden_size], hidden has shape [num_tokens, hidden_size]
        ctx.original_hidden_shape = hidden.shape
        if len(hidden.shape) > 2:
            hidden = hidden.view(-1, hidden.shape[-1])  # [num_tokens, hidden_size]
        if len(labels.shape) > 1:
            labels = labels.view(-1)

        logits = torch.matmul(hidden.to(torch.float32), weight.to(torch.float32).T)  # [num_tokens, vocab_size]
        logits /= temperature
        whole_logits = torch.empty(
            (logits.shape[0], logits.shape[1] * dist.get_world_size(dist_process_group)),
            dtype=logits.dtype,
            device=logits.device,
        )
        whole_logits_ref = [
            whole_logits[:, i * logits.shape[1] : (i + 1) * logits.shape[1]]
            for i in range(dist.get_world_size(dist_process_group))
        ]
        dist.all_gather(whole_logits_ref, logits, group=dist_process_group)

        pd = torch.nn.functional.softmax(whole_logits, dim=-1)
        entropy_a = torch.logsumexp(whole_logits, dim=-1)  # [num_tokens]
        entropy_b = torch.sum(pd * whole_logits, dim=-1)  # [num_tokens]
        entropy = entropy_a - entropy_b

        logprobs = torch.nn.functional.cross_entropy(whole_logits, labels, reduction="none")
        logprobs = torch.neg(logprobs)

        ctx.save_for_backward(hidden, weight, labels, whole_logits, entropy_b)
        ctx.dist_process_group = dist_process_group
        ctx.temperature = temperature
        return logprobs, entropy

    @staticmethod
    def backward(ctx, g_logprobs: torch.Tensor, g_entropy: torch.Tensor):
        hidden, weight, labels, whole_logits, entropy_b = ctx.saved_tensors
        dist_process_group = ctx.dist_process_group
        temperature = ctx.temperature
        batch_size, hidden_size = hidden.shape
        vocab_size, hidden_size = weight.shape
        rank = dist.get_rank(dist_process_group)

        # Compute softmax probabilities
        maximum, _ = torch.max(whole_logits, dim=-1, keepdim=True)
        exp_logits = torch.exp(whole_logits - maximum)
        accumulate = exp_logits.sum(dim=-1, keepdim=True)
        pd = exp_logits / accumulate

        # Gradient for entropy
        # entropy = entropy_a - entropy_b
        # entropy_a = log(sum(exp(logits)))
        # entropy_b = sum(pd * logits)
        # d_entropy_a/d_logits = pd
        # d_entropy_b/d_logits = pd * (logits - b.unsqueeze(1) + 1)
        # d_entropy/d_logits = d_entropy_a - d_entropy_b
        # d_entropy/d_logits = pd - pd * (logits - b.unsqueeze(1) + 1)
        # d_entropy/d_logits = -pd * (logits - b.unsqueeze(1))
        d_logits_entropy = g_entropy.unsqueeze(1) * (-pd * (whole_logits - entropy_b.unsqueeze(1)))

        # Gradient for logprobs
        # logprobs = -cross_entropy = -log(pd[labels])
        # d_cross_entropy/d_logits = (pd - one_hot(labels)), hence the negation of g_logprobs below
        one_hot = torch.zeros_like(whole_logits)
        one_hot.scatter_(1, labels.unsqueeze(1), 1)
        g_logprobs = torch.neg(g_logprobs)
        d_logits_logprobs = g_logprobs.unsqueeze(1) * (pd - one_hot)
        # NOTE: This will lead to wrong result
        # d_logits_logprobs = g_logprobs.unsqueeze(1) * (pd - 1) * one_hot

        # Combine gradients
        d_logits = d_logits_entropy + d_logits_logprobs
        d_logits /= temperature

        # Get local slice of gradients
        local_d_logits = d_logits[:, rank * vocab_size : (rank + 1) * vocab_size]

        # Compute gradients for hidden and weight
        d_hidden = torch.matmul(local_d_logits, weight.to(torch.float32))
        d_weight = torch.matmul(local_d_logits.T, hidden.to(torch.float32))
        d_hidden = d_hidden.view(ctx.original_hidden_shape)

        return d_hidden, d_weight, None, None, None


run_torch_entropy_tp = TorchEntropyTP.apply

class TestLinearCrossEntropy_TensorParallel:
    def __init__(self):
        dist.init_process_group(backend="nccl")
        self.group = dist.group.WORLD

        self.local_rank = dist.get_rank(self.group)
        self.world_size = dist.get_world_size(self.group)
        device = torch.device(f"cuda:{self.local_rank}")
        torch.cuda.set_device(device)
        print(f"[INFO]: Local rank: {self.local_rank}, World size: {self.world_size}")

    def initialize(self, test_case_idx: int, temperature: float = 1.5):
        self.test_case_idx = test_case_idx
        self.temperature = temperature

    def shutdown(self):
        dist.destroy_process_group()

    def cleanup(self):
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        import gc

        gc.collect()
        torch.cuda.synchronize()

    def generate_hyper(self):
        global LOW_MEMORY, LOW_MEMORY_DIV_FACTOR, MAX_TEST_CASES

        self.dtype = torch.bfloat16
        if self.test_case_idx == 0:
            self.batch_size = 1
            self.num_tokens = 1937
            self.hidden_size = 3584
            self.vocab_size = 152064
        elif self.test_case_idx == 1:
            self.batch_size = 1
            self.num_tokens = 2169
            self.hidden_size = 896
            self.vocab_size = 151936
        elif self.test_case_idx == 2:
            self.batch_size = 1
            self.num_tokens = 1530
            self.hidden_size = 2048
            self.vocab_size = 32256
        elif self.test_case_idx == 3:
            self.batch_size = 1
            self.num_tokens = 1388
            self.hidden_size = 4096
            self.vocab_size = 102400
        elif self.test_case_idx == 4:
            self.batch_size = 1
            self.num_tokens = 8192
            self.hidden_size = 4096
            self.vocab_size = 102400
        else:
            raise ValueError(f"Invalid test case index: {self.test_case_idx}")
        if LOW_MEMORY:
            self.vocab_size = int(self.vocab_size / LOW_MEMORY_DIV_FACTOR)
        assert MAX_TEST_CASES <= 5, "MAX_TEST_CASES should be less than or equal to 5."

    def generate_forward_inputs(self):
        hidden = (
            torch.empty((self.batch_size, self.num_tokens, self.hidden_size), dtype=self.dtype, device="cuda")
            .uniform_(-0.5, 0.5)
            .requires_grad_()
        )
        weight = (
            torch.empty((self.vocab_size, self.hidden_size), dtype=self.dtype, device="cuda")
            .uniform_(-0.5, 0.5)
            .requires_grad_()
        )
        labels = torch.randint(0, self.vocab_size, (self.batch_size, self.num_tokens), device="cuda")
        return hidden, weight, labels

    def generate_backward_inputs(self):
        g_entropy = torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-0.5, 0.5)
        g_logprobs = torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-1, 1)
        return g_entropy, g_logprobs

    def verify_torch_itself(self, iterations: int = 5):
        self.cleanup()
        self.generate_hyper()

        for i in range(iterations):
            hidden, weight, labels = self.generate_forward_inputs()

            # NOTE: we need to manually synchronize hidden and labels among Process Group
            dist.broadcast(hidden, src=0, group=self.group)
            dist.broadcast(labels, src=0, group=self.group)

            # forward pass
            # Create a tensor to hold the gathered weights from all ranks
            # weight has shape [vocab_size, hidden_size]
            # We want to gather along the first dimension to get [vocab_size * world_size, hidden_size]

            # Create a single contiguous tensor to hold all gathered weights
            whole_weight = torch.empty(
                (self.vocab_size * self.world_size, self.hidden_size), dtype=weight.dtype, device=weight.device
            )

            # Create views into the tensor for each rank's portion
            whole_weight_views = [
                whole_weight[i * self.vocab_size : (i + 1) * self.vocab_size] for i in range(self.world_size)
            ]

            # Perform all_gather operation using the views
            dist.all_gather(whole_weight_views, weight, group=self.group)

            # Set requires_grad for autograd
            whole_weight.requires_grad_()

            (single_logprobs, single_entropy) = run_torch_entropy(hidden, whole_weight, labels, self.temperature)

            (tp_logprobs, tp_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.temperature, self.group)

            torch.testing.assert_close(single_logprobs, tp_logprobs, atol=1e-4, rtol=1e-4)
            torch.testing.assert_close(single_entropy, tp_entropy, atol=1e-4, rtol=1e-4)

            # backward pass
            g_entropy, g_logprobs = self.generate_backward_inputs()
            # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group
            dist.broadcast(g_entropy, src=0, group=self.group)
            dist.broadcast(g_logprobs, src=0, group=self.group)

            (single_d_hidden, single_d_weight) = torch.autograd.grad(
                (single_entropy, single_logprobs), (hidden, whole_weight), (g_entropy, g_logprobs), retain_graph=False
            )

            (tp_d_hidden, tp_d_weight) = torch.autograd.grad(
                (tp_entropy, tp_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
            )
            # NOTE: all-reduce on hidden is conducted outside the kernel
            dist.all_reduce(tp_d_hidden, op=dist.ReduceOp.SUM, group=self.group)

            torch.testing.assert_close(tp_d_hidden, single_d_hidden, atol=1e-2, rtol=1e-4)
            # Extract the corresponding slice from single_d_weight for comparison
            # tp_d_weight has shape [vocab_size, hidden_size]
            # single_d_weight has shape [vocab_size * world_size, hidden_size]
            torch.testing.assert_close(
                tp_d_weight,
                single_d_weight[self.local_rank * self.vocab_size : (self.local_rank + 1) * self.vocab_size],
                atol=1e-2,
                rtol=1e-4,
            )

        # atol=1e-3, rtol=1e-4)
        if self.local_rank == 0:
            print("[PASS] torch TP correctness is verified")

    def check_torch_storage(self):
        self.cleanup()
        self.generate_hyper()

        hidden, weight, labels = self.generate_forward_inputs()

        # NOTE: we need to manually synchronize hidden and labels among Process Group
        dist.broadcast(hidden, src=0, group=self.group)
        dist.broadcast(labels, src=0, group=self.group)

        torch.cuda.reset_peak_memory_stats()
        (tp_logprobs, tp_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.temperature, self.group)
        torch.cuda.synchronize()
        forward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024

        g_entropy, g_logprobs = self.generate_backward_inputs()
        # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group
        dist.broadcast(g_entropy, src=0, group=self.group)
        dist.broadcast(g_logprobs, src=0, group=self.group)

        torch.cuda.reset_peak_memory_stats()
        (d_tp_hidden, d_tp_weight) = torch.autograd.grad(
            (tp_entropy, tp_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
        )
        torch.cuda.synchronize()
        backward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024
        # NOTE: all-reduce on hidden is conducted outside the kernel
        dist.all_reduce(d_tp_hidden, op=dist.ReduceOp.SUM, group=self.group)

        if self.local_rank == 0:
            print(f"[INFO]: Torch Forward pass peak memory: {forward_max_memory:.2f} MB")
            print(f"[INFO]: Torch Backward pass peak memory: {backward_max_memory:.2f} MB")

    def verify_kernel_correctness(self, iterations: int = 5):
        self.cleanup()
        self.generate_hyper()

        torch_forward_latency = list()
        torch_backward_latency = list()
        kernel_forward_latency = list()
        kernel_backward_latency = list()

        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

        for i in range(iterations):
            hidden, weight, labels = self.generate_forward_inputs()

            # NOTE: we need to manually synchronize hidden and labels among Process Group
            dist.broadcast(hidden, src=0, group=self.group)
            dist.broadcast(labels, src=0, group=self.group)

            start_event.record()
            (torch_logprobs, torch_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.temperature, self.group)
            end_event.record()
            torch.cuda.synchronize()
            torch_forward_latency.append(start_event.elapsed_time(end_event))

            start_event.record()
            (kernel_logprobs, kernel_entropy) = linear_cross_entropy(
                hidden, weight, labels, self.temperature, "none", self.group
            )
            end_event.record()
            torch.cuda.synchronize()
            kernel_forward_latency.append(start_event.elapsed_time(end_event))

            torch.testing.assert_close(torch_logprobs, kernel_logprobs, atol=1e-1, rtol=1e-2)
            torch.testing.assert_close(torch_entropy, kernel_entropy, atol=1e-1, rtol=1e-2)

            # backward pass
            g_entropy, g_logprobs = self.generate_backward_inputs()
            # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group
            dist.broadcast(g_entropy, src=0, group=self.group)
            dist.broadcast(g_logprobs, src=0, group=self.group)

            start_event.record()
            (torch_d_hidden, torch_d_weight) = torch.autograd.grad(
                (torch_entropy, torch_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
            )
            end_event.record()
            torch.cuda.synchronize()
            torch_backward_latency.append(start_event.elapsed_time(end_event))
            # NOTE: all-reduce on hidden is conducted outside the kernel
            dist.all_reduce(torch_d_hidden, op=dist.ReduceOp.SUM, group=self.group)

            start_event.record()
            (kernel_d_hidden, kernel_d_weight) = torch.autograd.grad(
                (kernel_entropy, kernel_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
            )
            end_event.record()
            torch.cuda.synchronize()
            kernel_backward_latency.append(start_event.elapsed_time(end_event))
            # NOTE: all-reduce on hidden is conducted outside the kernel
            dist.all_reduce(kernel_d_hidden, op=dist.ReduceOp.SUM, group=self.group)

            torch.testing.assert_close(torch_d_hidden, kernel_d_hidden, atol=2e-2, rtol=4e-2)
            torch.testing.assert_close(torch_d_weight, kernel_d_weight, atol=2e-2, rtol=4e-2)

        # remove first latency
        torch_forward_latency = torch_forward_latency[1:]
        torch_backward_latency = torch_backward_latency[1:]
        kernel_forward_latency = kernel_forward_latency[1:]
        kernel_backward_latency = kernel_backward_latency[1:]

        if self.local_rank == 0:
            print("\n[PASS]: Verified kernel forward & backward correctness.")

            print(
                f"[INFO]: Forward pass: Torch implementation average time: "
                f"{sum(torch_forward_latency) / len(torch_forward_latency):.2f} ms"
            )
            print(
                f"[INFO]: Backward pass: torch implementation average time: "
                f"{sum(torch_backward_latency) / len(torch_backward_latency):.2f} ms"
            )
            print(
                f"[INFO]: Forward pass: Kernel implementation average time: "
                f"{sum(kernel_forward_latency) / len(kernel_forward_latency):.2f} ms"
            )
            print(
                f"[INFO]: Backward pass: kernel implementation average time: "
                f"{sum(kernel_backward_latency) / len(kernel_backward_latency):.2f} ms"
            )

    def check_kernel_storage(self):
        self.cleanup()
        self.generate_hyper()

        hidden, weight, labels = self.generate_forward_inputs()

        # NOTE: we need to manually synchronize hidden and labels among Process Group
        dist.broadcast(hidden, src=0, group=self.group)
        dist.broadcast(labels, src=0, group=self.group)

        torch.cuda.reset_peak_memory_stats()
        (kernel_logprobs, kernel_entropy) = linear_cross_entropy(
            hidden, weight, labels, self.temperature, "none", self.group
        )
        torch.cuda.synchronize()
        kernel_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024

        g_entropy, g_logprobs = self.generate_backward_inputs()
        # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group
        dist.broadcast(g_entropy, src=0, group=self.group)
        dist.broadcast(g_logprobs, src=0, group=self.group)

        torch.cuda.reset_peak_memory_stats()
        (d_kernel_hidden, d_kernel_weight) = torch.autograd.grad(
            (kernel_entropy, kernel_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
        )
        torch.cuda.synchronize()
        kernel_backward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024
        # NOTE: all-reduce on hidden is conducted outside the kernel
        dist.all_reduce(d_kernel_hidden, op=dist.ReduceOp.SUM, group=self.group)

        if self.local_rank == 0:
            print(f"[INFO]: Kernel Forward pass peak memory: {kernel_max_memory:.2f} MB")
            print(f"[INFO]: Kernel Backward pass peak memory: {kernel_backward_max_memory:.2f} MB")

if __name__ == "__main__":
    # TP command: torchrun --standalone --nnodes=1 --nproc-per-node=2 tests/kernels/test_linear_cross_entropy_tp.py

    # Check if running with torchrun (distributed mode)
    assert int(os.environ.get("WORLD_SIZE", "1")) > 1, (
        "[ERROR]: This test is designed to run in distributed mode with torchrun. Please use torchrun to "
        "execute this script."
    )
    torch.manual_seed(233376 + int(os.environ.get("RANK", 0)))

    # set_backward_method(BackwardEnum._Total_Fuse_MN)
    # set_backward_method(BackwardEnum._Split_Dlogits_N)

    test = TestLinearCrossEntropy_TensorParallel()
    for test_case_idx in range(MAX_TEST_CASES):
        print(f"[INFO] Running test case {test_case_idx}")
        test.initialize(test_case_idx)
        if VERIFY_TORCH_SELF:
            test.verify_torch_itself()
        test.check_torch_storage()
        test.verify_kernel_correctness()
        test.check_kernel_storage()

    test.shutdown()