mirror of https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00

Compare commits: 7 commits, v0.11.0rc5...il_tool

Commits: e793b9a70c, 76c9ec0ddf, 87c737016d, ba90794ff1, ab4ab0fd28, 2af83ebdde, d8bff253d7
docs/contributing/intermediate_logging.md (new file, 65 lines)
# Intermediate Tensor Logging

This document provides guidance on using the intermediate tensor logging feature in vLLM, which allows you to capture and save intermediate tensors during model execution.

## Overview

The intermediate tensor logging feature enables you to:

- Log input and output tensors for modules selected by a configurable set of filters
- Filter modules by name using regex patterns
- Filter by module forward-call index (e.g. dump only the 2nd forward call of the same module within a step)
- Filter tensors by device
- Filter by whole-model forward step ID

## Usage

### Enabling via parameters or config file

**Offline Inference example**

Dump all modules, on all devices, for step 0 (default behavior):

```bash
python3 ./examples/offline_inference/llm_engine_example.py --model "meta-llama/Llama-3.1-8B-Instruct" --enforce-eager --intermediate-log-config '{"enabled": true}'
```

Dump only the first layer's modules, on all devices, for step 0:

```bash
python3 ./examples/offline_inference/llm_engine_example.py --model "meta-llama/Llama-3.1-8B-Instruct" --enforce-eager --intermediate-log-config '{"enabled": true, "module_call_match": "layers\\.0\\."}'
```
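The `--intermediate-log-config` flag is registered on the common engine CLI arguments, so the same JSON value should also be accepted when serving online. A hedged sketch, assuming the standard `vllm serve` entrypoint:

```bash
vllm serve "meta-llama/Llama-3.1-8B-Instruct" --enforce-eager \
    --intermediate-log-config '{"enabled": true, "module_call_match": "layers\\.0\\."}'
```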
#### Configuration Parameters

| Parameter | Type | Description | Default |
|-----------|------|-------------|---------|
| `output_dir` | string | Directory where to save the intermediate tensors | `/tmp/vllm_intermediates` |
| `module_call_match` | array | Regex patterns to filter module names; to limit a pattern to the i-th call only, append `:i` | `null` (log all modules) |
| `log_step_ids` | array | List of step IDs to log | `[0]` |
| `max_tensor_size` | integer | Maximum number of elements in tensors to log | `null` (no limit) |
| `device_names` | array | List of device names to log | `[]` (log all devices) |
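Putting several of these parameters together, a combined invocation might look like the following sketch; the module patterns, step IDs, and size cap are purely illustrative:

```bash
python3 ./examples/offline_inference/llm_engine_example.py \
    --model "meta-llama/Llama-3.1-8B-Instruct" --enforce-eager \
    --intermediate-log-config '{
        "enabled": true,
        "output_dir": "/tmp/vllm_intermediates",
        "module_call_match": ["layers\\.0\\.self_attn:1", "embed_tokens"],
        "log_step_ids": [0, 1],
        "max_tensor_size": 1000000,
        "device_names": ["cuda:0"]
    }'
```

Here `layers\.0\.self_attn:1` uses the `:i` suffix from the table above to keep only the forward call with index 1 of the matched modules.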
|
||||
### Output Directory Structure
|
||||
|
||||
When you enable intermediate logging, the system creates a timestamped directory under your specified `output_dir`. This helps organize multiple logging sessions:
|
||||
|
||||
```
|
||||
/tmp/vllm_intermediates/010fed05-4a36-4c19-ab44-7cd67e3f63ce/
|
||||
└── step_0
|
||||
├── model.embed_tokens
|
||||
│ ├── inputs_0_cuda_0.pt
|
||||
│ ├── inputs.json
|
||||
│ ├── outputs_cuda_0.pt
|
||||
│ └── outputs.json
|
||||
├── model.layers.0.input_layernorm
|
||||
│ ├── inputs_0_cuda_0.pt
|
||||
│ ├── inputs.json
|
||||
│ ├── outputs_cuda_0.pt
|
||||
│ └── outputs.json
|
||||
└── step_1/
|
||||
└── ...
|
||||
```
|
||||
|
||||
Each tensor is saved in a `.pt` file containing the full PyTorch tensors (can be loaded with `torch.load()`)
|
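To inspect a dump afterwards, the saved files can be loaded directly with PyTorch; a minimal sketch (the run directory below is illustrative):

```python
from pathlib import Path

import torch

# Substitute the UUID directory created by your own run.
run_dir = Path("/tmp/vllm_intermediates/010fed05-4a36-4c19-ab44-7cd67e3f63ce")

for pt_file in sorted(run_dir.glob("step_0/**/*.pt")):
    tensor = torch.load(pt_file, map_location="cpu")
    print(pt_file.relative_to(run_dir), tuple(tensor.shape), tensor.dtype)
```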
tests/v1/test_intermediates_logging.py (new file, 320 lines)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for the intermediate tensor logging functionality.
"""

import json
import os
import shutil
import tempfile
from pathlib import Path
from unittest import mock

import pytest
import torch
import torch.nn as nn

from vllm.config import IntermediateLoggingConfig
from vllm.v1.intermediates.intermediates_logging import (
    get_current_il_config, get_step, increment_step, intermediate_logging,
    register_intermediate_hooks, reset_step, should_log_device,
    should_log_module, should_log_step)


class SimpleModel(nn.Module):
    """A simple model for testing."""

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 20)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(20, 5)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x


@pytest.fixture
def temp_output_dir():
    """Create a temporary directory for test outputs."""
    temp_dir = tempfile.mkdtemp()
    yield temp_dir
    # Clean up after the test
    shutil.rmtree(temp_dir)


@pytest.fixture
def simple_model():
    """Create a simple model for testing."""
    return SimpleModel()


@pytest.fixture
def il_config(temp_output_dir):
    """Create a basic IntermediateLoggingConfig for testing."""
    return IntermediateLoggingConfig(output_dir=temp_output_dir,
                                     enabled=True,
                                     log_step_ids=[0, 1],
                                     module_call_match=[".*linear.*"])


def test_step_counter():
    """Test the step counter functionality."""
    # Reset the step counter
    reset_step()
    assert get_step() == 0

    # Increment the step counter
    increment_step()
    assert get_step() == 1

    # Increment again
    increment_step()
    assert get_step() == 2

    # Reset again
    reset_step()
    assert get_step() == 0


def test_intermediate_logging_context_manager():
    """Test the intermediate_logging context manager."""
    # Create a config
    config = IntermediateLoggingConfig(enabled=True)

    # Initially, there should be no global config
    assert get_current_il_config() is None

    # Use the context manager
    with intermediate_logging(config):
        # Inside the context, the global config should be set
        assert get_current_il_config() is not None
        assert get_current_il_config().enabled is True

    # After the context, the global config should be None again
    assert get_current_il_config() is None

    # Test with a different config
    config2 = IntermediateLoggingConfig(enabled=False)
    with intermediate_logging(config2):
        assert get_current_il_config() is not None
        assert get_current_il_config().enabled is False


def test_should_log_step():
    """Test the should_log_step function."""
    # Reset step counter
    reset_step()

    # Create configs with different step IDs
    config_all_steps = IntermediateLoggingConfig(
        enabled=True,
        log_step_ids=[]  # Empty list means log all steps
    )
    config_specific_steps = IntermediateLoggingConfig(
        enabled=True,
        log_step_ids=[0, 2, 4]  # Only log steps 0, 2, and 4
    )
    config_disabled = IntermediateLoggingConfig(enabled=False,
                                                log_step_ids=[0, 1, 2])

    # Test with all steps config
    with intermediate_logging(config_all_steps):
        assert should_log_step(config_all_steps) is True  # Step 0
        increment_step()
        assert should_log_step(config_all_steps) is True  # Step 1

    # Reset step counter
    reset_step()

    # Test with specific steps config
    with intermediate_logging(config_specific_steps):
        assert should_log_step(config_specific_steps) is True  # Step 0
        increment_step()
        assert should_log_step(config_specific_steps) is False  # Step 1
        increment_step()
        assert should_log_step(config_specific_steps) is True  # Step 2
        increment_step()
        assert should_log_step(config_specific_steps) is False  # Step 3
        increment_step()
        assert should_log_step(config_specific_steps) is True  # Step 4

    # Test with disabled config
    with intermediate_logging(config_disabled):
        assert should_log_step(config_disabled) is False  # Disabled


def test_should_log_device():
    """Test the should_log_device function."""
    # Create configs with different device filters
    config_all_devices = IntermediateLoggingConfig(
        enabled=True,
        device_names=[]  # Empty list means log all devices
    )
    config_specific_devices = IntermediateLoggingConfig(
        enabled=True,
        device_names=["cuda:0", "cpu"]  # Only log cuda:0 and cpu
    )
    config_disabled = IntermediateLoggingConfig(enabled=False,
                                                device_names=["cuda:0", "cpu"])

    # Test with all devices config
    with intermediate_logging(config_all_devices):
        assert should_log_device(config_all_devices, "cuda:0") is True
        assert should_log_device(config_all_devices, "cuda:1") is True
        assert should_log_device(config_all_devices, "cpu") is True

    # Test with specific devices config
    with intermediate_logging(config_specific_devices):
        assert should_log_device(config_specific_devices, "cuda:0") is True
        assert should_log_device(config_specific_devices, "cuda:1") is False
        assert should_log_device(config_specific_devices, "cpu") is True

    # Test with disabled config
    with intermediate_logging(config_disabled):
        assert should_log_device(config_disabled, "cuda:0") is False
        assert should_log_device(config_disabled, "cpu") is False


def test_should_log_module(simple_model):
    """Test the should_log_module function."""
    # Create configs with different module name filters
    config_all_modules = IntermediateLoggingConfig(
        enabled=True,
        module_call_match=None  # None means log all modules
    )
    config_specific_modules = IntermediateLoggingConfig(
        enabled=True,
        module_call_match=[".*linear.*"
                           ]  # Only log modules with "linear" in the name
    )
    config_disabled = IntermediateLoggingConfig(enabled=False,
                                                module_call_match=[".*"])

    # Test with all modules config
    with intermediate_logging(config_all_modules):
        assert should_log_module(config_all_modules, "linear1",
                                 simple_model.linear1) is True
        assert should_log_module(config_all_modules, "relu",
                                 simple_model.relu) is True

    # Test with specific modules config
    with intermediate_logging(config_specific_modules):
        assert should_log_module(config_specific_modules, "linear1",
                                 simple_model.linear1) is True
        assert should_log_module(config_specific_modules, "relu",
                                 simple_model.relu) is False

    # Test with disabled config
    with intermediate_logging(config_disabled):
        assert should_log_module(config_disabled, "linear1",
                                 simple_model.linear1) is False
        assert should_log_module(config_disabled, "relu",
                                 simple_model.relu) is False


def test_register_hooks(simple_model, il_config):
    """Test registering hooks on a model."""
    # Register hooks
    logger_instance = register_intermediate_hooks(simple_model, il_config)

    # Check that hooks were registered
    assert len(logger_instance.hooks) > 0

    # Remove hooks
    logger_instance.remove_hooks()

    # Check that hooks were removed
    assert len(logger_instance.hooks) == 0


@mock.patch(
    'vllm.v1.intermediates.intermediates_logging.dump_intermediates_to_json')
@mock.patch('vllm.v1.intermediates.intermediates_logging.save_tensors')
def test_forward_hooks(mock_save_tensors, mock_dump_json, simple_model,
                       il_config, temp_output_dir):
    """Test that forward hooks are called during model execution."""
    mock_save_tensors.return_value = None
    # Register hooks
    with intermediate_logging(il_config):
        logger_instance = register_intermediate_hooks(simple_model, il_config)

        # Create input tensor
        input_tensor = torch.randn(2, 10)

        # Reset step counter
        reset_step()

        # Forward pass
        simple_model(input_tensor)

        # Check that the step counter was incremented
        assert get_step() == 1

        # Check that dump_intermediates_to_json and save_tensors were called
        assert mock_dump_json.called
        assert mock_save_tensors.called

    # Remove hooks
    logger_instance.remove_hooks()


def test_end_to_end(simple_model, il_config, temp_output_dir):
    """Test the entire intermediate logging workflow end-to-end."""
    # Register hooks
    with intermediate_logging(il_config):
        logger_instance = register_intermediate_hooks(simple_model, il_config)

        # Create input tensor
        input_tensor = torch.randn(2, 10)

        # Reset step counter
        reset_step()

        # Forward pass
        simple_model(input_tensor)

        # Check that output directories were created
        root_dir = Path(il_config._output_run_dir)
        assert root_dir.exists()
        step_dir = root_dir / "step_0"
        assert step_dir.exists()

        module_dirs = list(step_dir.glob("*"))
        print(f"{module_dirs=}")
        assert len(module_dirs) > 0

        # Check that input and output files were created
        for module_dir in module_dirs:
            print(f"{module_dir=}")
            if os.path.isdir(module_dir):
                inputs_json = module_dir / "inputs.json"
                outputs_json = module_dir / "outputs.json"

                # Check that JSON files exist
                assert inputs_json.exists()
                assert outputs_json.exists()

                # Check that JSON files contain valid data
                with open(inputs_json) as f:
                    inputs_data = json.load(f)
                    assert "type" in inputs_data

                with open(outputs_json) as f:
                    outputs_data = json.load(f)
                    assert "type" in outputs_data

                # Check that tensor files exist
                tensor_files = list(module_dir.glob("*.pt"))
                assert len(tensor_files) > 0

    # Remove hooks
    logger_instance.remove_hooks()


if __name__ == "__main__":
    pytest.main(["-xvs", __file__])
@@ -3311,6 +3311,119 @@ class KVTransferConfig:
        return self.kv_connector_extra_config.get(key, default)


@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class IntermediateLoggingConfig:
    """Configuration for intermediate tensor logging."""

    output_dir: str = "/tmp/vllm_intermediates"
    """Directory where to save the intermediate tensors."""

    module_call_match: Optional[list[str]] = None
    """Match modules by name regex and call index (a module can be called
    multiple times in a step). List of `regex:call_idx` entries; call_idx
    defaults to -1, which matches all calls."""

    log_step_ids: list[int] = field(default_factory=lambda: [0])
    """List of step IDs to log (empty list means log all steps)."""

    log_post_fwd_inputs: bool = False
    """Whether to also log inputs after the forward pass of each module."""

    max_tensor_size: Optional[int] = None
    """Maximum number of elements in tensors to log (None = no limit)."""

    enabled: bool = True
    """Whether logging is enabled."""
    device_names: list[str] = field(default_factory=list)
    """List of device names to log (empty list means log all devices)."""

    _compiled_module_calls: dict[re.Pattern, int] = field(default_factory=dict,
                                                          init=False)
    """Compiled regex patterns for module filtering."""

    _module_call: dict[str, int] = field(default_factory=dict, init=False)
    _step_id_set: set[int] = field(default_factory=set, init=False)
    """Set of step IDs for faster lookup."""
    _output_run_dir: str = "/tmp/vllm_intermediates"
    """Unique directory where the logging output of a single run/serve is saved."""

    def __post_init__(self):
        """Initialize derived fields after instance creation."""
        self._compile_regex_patterns()
        self._output_run_dir = self.output_dir + "/" + str(uuid.uuid4())
        self._step_id_set = set(self.log_step_ids)

    def _compile_regex_patterns(self):
        """Compile regex patterns for module name filtering."""
        from vllm.logger import init_logger
        logger = init_logger(__name__)

        self._compiled_module_matches = []

        if self.module_call_match is None:
            logger.info(
                "No module name regex patterns provided, will log all modules")
            return

        # Compile all patterns
        for regex_pattern_call_idx in self.module_call_match:
            try:
                splits = regex_pattern_call_idx.split(":", 2)
                regex_pattern = splits[0]
                call_idx = -1
                if len(splits) > 1:
                    call_idx = int(splits[1])
                compiled_pattern: re.Pattern[str] = re.compile(regex_pattern)
                self._compiled_module_calls[compiled_pattern] = call_idx
                logger.info("Successfully compiled regex pattern: '%s'",
                            regex_pattern)
            except Exception as e:
                logger.error("Failed to parse module_call_match '%s': %s",
                             regex_pattern_call_idx, e)

        logger.info("Compiled %d regex patterns",
                    len(self._compiled_module_calls))

    def to_dict(self) -> dict:
        """Convert the config to a dictionary for serialization."""
        return {
            "output_run_dir": self.output_run_dir,
            "module_call_match": self.module_call_match,
            "log_step_ids": self.log_step_ids,
            "max_tensor_size": self.max_tensor_size,
            "enabled": self.enabled,
            "device_names": self.device_names
        }

    @classmethod
    def from_dict(cls, dict_value: dict) -> "IntermediateLoggingConfig":
        """Parse the CLI value for the intermediate logging config."""
        return cls(**dict_value)

    @property
    def output_run_dir(self) -> str:
        return self._output_run_dir

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # Intermediate logging doesn't affect the computation graph
        factors: list[Any] = []
        hash_str = hashlib.md5(str(factors).encode(),
                               usedforsecurity=False).hexdigest()
        return hash_str


@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class VllmConfig:
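As a quick illustration of how the parsed CLI JSON maps onto this config, a minimal sketch using the `from_dict` path shown above (the field values are examples only):

```python
from vllm.config import IntermediateLoggingConfig

# Example values only; from_dict mirrors how the parsed CLI JSON is consumed.
cfg = IntermediateLoggingConfig.from_dict({
    "enabled": True,
    "module_call_match": ["layers\\.0\\.:1"],  # regex:call_idx -> only call index 1
    "log_step_ids": [0, 2],
    "device_names": ["cuda:0"],
})
print(cfg.output_run_dir)  # <output_dir>/<uuid>, chosen in __post_init__
```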
@@ -3362,6 +3475,8 @@ class VllmConfig:
    """The configurations for distributed KV cache transfer."""
    kv_events_config: Optional[KVEventsConfig] = None
    """The configurations for event publishing."""
    intermediate_log_config: Optional[IntermediateLoggingConfig] = None
    """Configuration for intermediate tensor logging."""
    # some opaque config, only used to provide additional information
    # for the hash computation, mainly used for testing, debugging or out of
    # tree config registration.
@@ -3446,6 +3561,10 @@ class VllmConfig:
            vllm_factors.append(self.kv_transfer_config.compute_hash())
        else:
            vllm_factors.append("None")
        if self.intermediate_log_config:
            vllm_factors.append(self.intermediate_log_config.compute_hash())
        else:
            vllm_factors.append("None")
        if self.additional_config:
            if isinstance(additional_config := self.additional_config, dict):
                additional_config_hash = hashlib.md5(
@@ -409,6 +409,7 @@ class EngineArgs:

    speculative_config: Optional[Dict[str, Any]] = None

    show_hidden_metrics_for_version: Optional[str] = \
        ObservabilityConfig.show_hidden_metrics_for_version
    otlp_traces_endpoint: Optional[str] = \

@@ -456,6 +457,8 @@ class EngineArgs:

    async_scheduling: bool = SchedulerConfig.async_scheduling

    intermediate_log_config: Optional[dict[str, Any]] = None

    kv_sharing_fast_prefill: bool = \
        CacheConfig.kv_sharing_fast_prefill

@@ -883,6 +886,9 @@ class EngineArgs:
            title="VllmConfig",
            description=VllmConfig.__doc__,
        )

        vllm_group.add_argument("--intermediate-log-config",
                                **vllm_kwargs["intermediate_log_config"])
        # We construct SpeculativeConfig using fields from other configs in
        # create_engine_config. So we set the type to a JSON string here to
        # delay the Pydantic validation that comes with SpeculativeConfig.

@@ -1394,7 +1400,6 @@ class EngineArgs:
            otlp_traces_endpoint=self.otlp_traces_endpoint,
            collect_detailed_traces=self.collect_detailed_traces,
        )

        config = VllmConfig(
            model_config=model_config,
            cache_config=cache_config,

@@ -1409,6 +1414,7 @@ class EngineArgs:
            compilation_config=self.compilation_config,
            kv_transfer_config=self.kv_transfer_config,
            kv_events_config=self.kv_events_config,
            intermediate_log_config=self.intermediate_log_config,
            additional_config=self.additional_config,
        )
@@ -80,6 +80,10 @@ class EngineCore:

        # Setup Model.
        self.model_executor = executor_class(vllm_config)
        if vllm_config.intermediate_log_config is not None:
            self.collective_rpc("register_intermediate_hooks",
                                args=(vllm_config.intermediate_log_config, ))

        if executor_fail_callback is not None:
            self.model_executor.register_failure_callback(
                executor_fail_callback)
vllm/v1/intermediates/__init__.py (new, empty file)
vllm/v1/intermediates/intermediates_logging.py (new file, 405 lines)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Module for logging intermediate tensors during model execution.

This module provides functionality to capture and save intermediate tensors
(inputs and outputs) from PyTorch modules during forward passes.
"""

from contextlib import contextmanager
from pathlib import Path
from typing import Any, Optional

import torch
from torch.utils.hooks import RemovableHandle

from vllm.config import IntermediateLoggingConfig
from vllm.logger import init_logger

logger = init_logger(__name__)

# Global step counter
_CURRENT_STEP = 0

_CURRENT_STEP_MODULE_CALL_STEP: dict[str, int] = {}

IL_MODULE_NAME = "_il_module_name"
IL_MODULE_CALL_IDX = "_il_module_call_idx"

# Utility functions for intermediate logging


def should_log_step(config):
    """Check if the current step should be logged based on the step IDs.

    Args:
        config: The IntermediateLoggingConfig instance.

    Returns:
        True if the current step should be logged, False otherwise.
    """
    if not is_log_enabled(config):
        return False

    # If log_step_ids is empty, log all steps
    if not config.log_step_ids:
        return True

    # Otherwise, check if current step is in the set of step IDs to log
    return get_step() in config._step_id_set


def should_log_device(config, device_name):
    """Check if a device should be logged based on the device names.

    Args:
        config: The IntermediateLoggingConfig instance.
        device_name: The name of the device to check (e.g., 'cuda:0', 'cpu').

    Returns:
        True if the device should be logged, False otherwise.
        If device_names is empty, all devices are logged.
    """
    if not is_log_enabled(config):
        return False
    # If device_names is empty, log all devices
    if not config.device_names:
        return True

    # Otherwise, check if device_name is in the list of device names to log
    return device_name in config.device_names


def should_log_module(config, module_name, module: torch.nn.Module) -> bool:
    """Check if a module should be logged based on the name regex patterns.

    Args:
        config: The IntermediateLoggingConfig instance.
        module_name: The name of the module to check.
        module: The module itself.

    Returns:
        True if the module should be logged, False otherwise.
        If no patterns are defined, all modules are logged.
        If patterns are defined, the module is logged if it matches ANY pattern.
    """
    if not is_log_enabled(config):
        return False
    # If no patterns are defined, log all modules
    if not config._compiled_module_calls:
        set_il_module_name(module, module_name)
        set_il_module_call_idx(module, -1)
        return True

    # Check if the module name matches any of the patterns
    for pattern, call_idx in config._compiled_module_calls.items():
        match = pattern.search(module_name)
        if match:
            logger.debug(
                "Module %s, %s matches pattern: '%s', call_idx=%s",
                module_name,
                module.__class__.__name__,
                pattern.pattern,
                call_idx,
            )
            set_il_module_name(module, module_name)
            set_il_module_call_idx(module, call_idx)
            return True
    return False


def is_log_enabled(config):
    if not config or not config.enabled:
        return False
    if torch.compiler.is_compiling():
        logger.debug("Not logging because torch.compile is in progress")
        return False
    return True


def get_il_module_name(module: torch.nn.Module) -> str:
    return getattr(module, IL_MODULE_NAME, module.__class__.__name__)


def get_il_module_call_idx(module: torch.nn.Module) -> int:
    return getattr(module, IL_MODULE_CALL_IDX, -1)


def set_il_module_name(module: torch.nn.Module, name: str) -> None:
    setattr(module, IL_MODULE_NAME, name)


def set_il_module_call_idx(module: torch.nn.Module, idx: int) -> None:
    setattr(module, IL_MODULE_CALL_IDX, idx)


_global_config: Optional[IntermediateLoggingConfig] = None


@contextmanager
def intermediate_logging(config: Optional[IntermediateLoggingConfig]):
    """
    Temporarily sets the global intermediate-logging config for the duration
    of the context.

    :param config: The config to set as the global config.
    """
    global _global_config
    old_config = _global_config
    try:
        _global_config = config
        yield
    finally:
        _global_config = old_config


def get_current_il_config():
    return _global_config


def save_tensors(tensor: Any, file_path: str) -> Any:
    """Utility function to dump tensor(s) to a file.

    Args:
        tensor: The tensor to dump. Can be a torch.Tensor, a list/tuple of
            tensors, or a dictionary containing tensors.
        file_path: Base path where to save the tensor (without extension).
    """

    if isinstance(tensor, torch.Tensor):
        device_name = str(tensor.device)
        intermediate_log_config = get_current_il_config()
        if not should_log_device(intermediate_log_config, device_name):
            return tensor
        pt_path = f"{file_path}_{device_name.replace(':', '_')}.pt"
        try:
            torch.save(tensor, pt_path)
            logger.debug("Saved tensor of shape %s to %s", tensor.shape,
                         pt_path)
        except Exception as e:
            logger.warning("Failed to save tensor to %s: %s", pt_path, e)
        return tensor

    if isinstance(tensor, (list, tuple)):
        for i, item in enumerate(tensor):
            save_tensors(item, f"{file_path}_{i}")
        return tensor
    if isinstance(tensor, dict):
        for k, v in tensor.items():
            save_tensors(v, f"{file_path}_{k}")
        return tensor


def step_fwd(module: torch.nn.Module, inputs: tuple[Any, ...],
             outputs: Any) -> None:
    """Hook to increment the global step counter after a forward pass.

    Args:
        module: The PyTorch module being executed.
        inputs: The inputs to the module's forward function.
        outputs: The outputs from the module's forward function.
    """
    if get_current_il_config() is None:
        return
    # Increment the global step counter
    increment_step()
    global _CURRENT_STEP_MODULE_CALL_STEP
    _CURRENT_STEP_MODULE_CALL_STEP = {}


def _prepare_module_log_dir(
    intermediate_log_config: IntermediateLoggingConfig,
    module_name: str,
    is_pre_fwd: bool = False,
) -> Path:
    # Create a unique directory for this step if not already present
    dump_dir = Path(
        intermediate_log_config.output_run_dir) / f"step_{get_step()}"
    dump_dir.mkdir(exist_ok=True, parents=True)

    # Create module directory
    suffix = ""
    module_call_idx = get_current_step_module_call(module_name)
    if module_call_idx > 0:
        suffix = f"_{module_call_idx}"
    module_dir = dump_dir / (module_name + suffix)
    if is_pre_fwd:
        _log_module_call(intermediate_log_config, module_name + suffix)
    module_dir.mkdir(exist_ok=True, parents=True)
    logger.debug("Logging module %s inputs/outputs to %s", module_name,
                 module_dir)
    return module_dir


def _log_module_call(
    intermediate_log_config: IntermediateLoggingConfig,
    module_name: str,
) -> None:
    file = (Path(intermediate_log_config.output_run_dir) /
            f"step_{get_step()}" / "module_calls.txt")
    with open(file, "a") as f:
        f.write(f"{module_name}\n")


def update_current_step_module_call(module_name: str) -> None:
    logger.debug("Updating current step module call for %s", module_name)
    global _CURRENT_STEP_MODULE_CALL_STEP
    if module_name not in _CURRENT_STEP_MODULE_CALL_STEP:
        _CURRENT_STEP_MODULE_CALL_STEP[module_name] = 0
    else:
        _CURRENT_STEP_MODULE_CALL_STEP[module_name] += 1


def get_current_step_module_call(module_name: str) -> int:
    return _CURRENT_STEP_MODULE_CALL_STEP.get(module_name, 0)


def prepare_log_current_fwd(module,
                            is_pre_fwd: bool = False) -> Optional[Path]:
    intermediate_log_config = get_current_il_config()
    if intermediate_log_config is None or not intermediate_log_config.enabled:
        return None
    if not should_log_step(intermediate_log_config):
        return None

    module_name = get_il_module_name(module)
    log_call_idx = get_il_module_call_idx(module)
    current_call_idx = get_current_step_module_call(module_name)
    should_log = True
    if log_call_idx >= 0 and current_call_idx != log_call_idx:
        should_log = False

    log_dir = None
    if is_pre_fwd:
        update_current_step_module_call(module_name)
    if should_log:
        log_dir = _prepare_module_log_dir(intermediate_log_config,
                                          module_name,
                                          is_pre_fwd=is_pre_fwd)
    return log_dir


def log_pre_fwd_hook(module: torch.nn.Module,
                     inputs: tuple[Any, ...]) -> tuple[Any, ...]:
    """Hook to capture module inputs before the forward pass.

    Args:
        module: The PyTorch module being executed.
        inputs: The inputs to the module's forward function.

    Returns:
        The unchanged inputs.
    """
    if log_dir := prepare_log_current_fwd(module, is_pre_fwd=True):
        save_tensors(inputs, str(log_dir / "inputs"))
    return inputs


def log_post_fwd_hook(module: torch.nn.Module, inputs: tuple[Any, ...],
                      outputs: Any) -> None:
    """Hook to capture module outputs after the forward pass.

    Args:
        module: The PyTorch module being executed.
        inputs: The inputs to the module's forward function.
        outputs: The outputs from the module's forward function.
    """
    if log_dir := prepare_log_current_fwd(module, is_pre_fwd=False):
        save_tensors(outputs, str(log_dir / "outputs"))
        intermediate_log_config = get_current_il_config()
        assert intermediate_log_config is not None, \
            "IL config should not be None"
        if intermediate_log_config.log_post_fwd_inputs:
            save_tensors(inputs, str(log_dir / "post_fwd_inputs"))


def get_step() -> int:
    """Get the current global step counter.

    Returns:
        The current global step counter.
    """
    return _CURRENT_STEP


def increment_step() -> int:
    """Increment the global step counter.

    Returns:
        The new step counter value.
    """
    global _CURRENT_STEP
    _CURRENT_STEP += 1
    return _CURRENT_STEP


def reset_step() -> None:
    """Reset the global step counter to zero."""
    global _CURRENT_STEP
    _CURRENT_STEP = 0


class IntermediatesLogger:
    """Class to manage logging of intermediate tensors during model
    execution."""

    def __init__(self, config: IntermediateLoggingConfig):
        self.config = config
        self.hooks: list[tuple[str, torch.nn.Module,
                               Optional[RemovableHandle],
                               Optional[RemovableHandle]]] = []
        logger.debug("Created IntermediatesLogger with config: %s", config)
        path = Path(config.output_run_dir)
        path.mkdir(exist_ok=True, parents=True)
        # Log configuration
        logger.info("Intermediates will be logged in %s",
                    config.output_run_dir)

    def register_hooks(self, model: torch.nn.Module) -> None:
        """Register hooks for the model.

        Args:
            model: The PyTorch model to register hooks for.
        """

        for name, module in model.named_modules():
            if name and should_log_module(self.config, name, module):
                pre_hook = module.register_forward_pre_hook(log_pre_fwd_hook)
                logger.debug("Registered pre_fwd hook for %s",
                             module.__class__.__name__)
                post_hook = module.register_forward_hook(log_post_fwd_hook)
                logger.debug("Registered post_fwd hook for %s",
                             module.__class__.__name__)
                self.hooks.append((name, module, pre_hook, post_hook))

        # Register a step counter hook for the root model
        step_hook = model.register_forward_hook(step_fwd)
        self.hooks.append(("", model, None, step_hook))
        logger.info("Registered hooks for %s modules", len(self.hooks))

    def remove_hooks(self) -> None:
        """Remove all registered hooks."""
        for _, _, pre_hook, post_hook in self.hooks:
            if pre_hook is not None:
                pre_hook.remove()
            if post_hook is not None:
                post_hook.remove()

        logger.info("Removed %s hooks", len(self.hooks))
        self.hooks = []


def register_intermediate_hooks(
    model: torch.nn.Module,
    config: Optional[IntermediateLoggingConfig] = None
) -> IntermediatesLogger:
    """Register hooks to log intermediate tensors for a model.

    Args:
        model: The PyTorch model to log intermediates for.
        config: Configuration for intermediate logging.

    Returns:
        An IntermediatesLogger instance that can be used to manage the hooks.
    """
    logger_instance = IntermediatesLogger(config)
    logger_instance.register_hooks(model)
    return logger_instance
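The module above can also be exercised on its own, outside the engine; a minimal sketch mirroring the unit tests (the toy model and output path are illustrative):

```python
import torch
import torch.nn as nn

from vllm.config import IntermediateLoggingConfig
from vllm.v1.intermediates.intermediates_logging import (
    intermediate_logging, register_intermediate_hooks, reset_step)

# Toy model and config for illustration only; no module_call_match means
# every named submodule gets hooks.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
config = IntermediateLoggingConfig(output_dir="/tmp/vllm_intermediates",
                                   log_step_ids=[0])

il_logger = register_intermediate_hooks(model, config)
reset_step()
with intermediate_logging(config):
    model(torch.randn(2, 8))  # hooks fire here and dump tensors under step_0/
il_logger.remove_hooks()
```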
@@ -27,6 +27,7 @@ from vllm.sequence import IntermediateTensors
 from vllm.tasks import SupportedTask
 from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
 from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
+from vllm.v1.intermediates.intermediates_logging import intermediate_logging
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
                              DraftTokenIds, ModelRunnerOutput)

@@ -362,9 +363,9 @@ class Worker(WorkerBase):
             intermediate_tensors = IntermediateTensors(
                 get_pp_group().recv_tensor_dict(
                     all_gather_group=get_tp_group()))

-        output = self.model_runner.execute_model(scheduler_output,
-                                                 intermediate_tensors)
+        with intermediate_logging(self.vllm_config.intermediate_log_config):
+            output = self.model_runner.execute_model(scheduler_output,
+                                                     intermediate_tensors)
         if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)):
             return output
@@ -6,8 +6,10 @@ from typing import Optional
 import torch
 import torch.nn as nn

-from vllm.config import VllmConfig
+from vllm.config import IntermediateLoggingConfig, VllmConfig
 from vllm.logger import init_logger
+from vllm.v1.intermediates.intermediates_logging import (
+    register_intermediate_hooks)
 from vllm.v1.kv_cache_interface import KVCacheSpec
 from vllm.worker.worker_base import WorkerBase as WorkerBaseV0

@@ -63,3 +65,26 @@ class WorkerBase(WorkerBaseV0):
    def check_health(self) -> None:
        """Basic health check (override for device-specific checks)."""
        return

    def register_intermediate_hooks(
            self, config: Optional[IntermediateLoggingConfig] = None) -> None:
        """Register hooks for intermediate tensor logging.

        This method is called via collective_rpc from the engine core.
        It registers hooks on the model to dump intermediate tensors during
        execution.

        Args:
            config: Configuration for intermediate logging.
        """
        if self.model_runner is None or not hasattr(
                self.model_runner, "model") or self.model_runner.model is None:
            logger.error("Could not register intermediate hooks: "
                         "model_runner.model is not accessible")
            return
        model = self.model_runner.model
        try:
            register_intermediate_hooks(model, config)
        except Exception:
            logger.exception("Error registering intermediate hooks")
@@ -129,6 +129,22 @@ class WorkerBase:
        """Get vocabulary size from model configuration."""
        return self.model_config.get_vocab_size()

    def register_intermediate_hooks(self, config=None) -> None:
        """Register hooks for intermediate tensor logging.

        This method is a stub for v0 workers. The actual implementation is
        in v1 workers. It's included here for compatibility with the
        collective_rpc mechanism.

        Args:
            config: Configuration for intermediate logging.
        """
        logger.warning(
            "register_intermediate_hooks is not implemented in v0 workers. "
            "This is only available in v1 workers. No hooks will be registered."
        )
        return None

    def shutdown(self) -> None:
        """Clean up resources held by the worker."""
        return