Compare commits

...

7 Commits

Author SHA1 Message Date
e793b9a70c Merge remote-tracking branch 'origin/main' into il_tool
Signed-off-by: Lu Fang <fanglu@fb.com>
2025-09-08 17:33:55 -07:00
76c9ec0ddf adjust config type and remove config path for simplicity
Signed-off-by: Lu Fang <fanglu@fb.com>
2025-09-08 17:23:15 -07:00
87c737016d Merge remote-tracking branch 'origin/main' into il_tool
Signed-off-by: Lucia Fang <fanglu@fb.com>
2025-08-05 09:48:28 -07:00
ba90794ff1 remove feature for il_tool_compare
Signed-off-by: Lucia Fang <fanglu@fb.com>
2025-08-05 09:47:16 -07:00
ab4ab0fd28 address arg utils fix
Signed-off-by: Lucia Fang <fanglu@fb.com>
2025-08-05 09:45:13 -07:00
2af83ebdde remove feature for metadata dump and input reload
Signed-off-by: Lucia Fang <fanglu@fb.com>
2025-08-05 09:25:17 -07:00
d8bff253d7 add il tool
more changes

Apply suggestions from code review

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

fix tp

Signed-off-by: Lu Fang <fanglu@fb.com>

add comparison tool

tmp

add unit test and fix format

Signed-off-by: Lu Fang <fanglu@fb.com>

add comparison script and documentation

Signed-off-by: Lu Fang <fanglu@fb.com>

provide default intermediate logging

Signed-off-by: Lu Fang <fanglu@fb.com>

optional register il

Signed-off-by: Lu Fang <fanglu@fb.com>

add input reload and improve intermediate compare
2025-07-28 18:32:10 -07:00
10 changed files with 966 additions and 5 deletions

View File

@ -0,0 +1,65 @@
# Intermediate Tensor Logging
This document provides guidance on using the intermediate tensor logging feature in vLLM, which allows you to capture and save intermediate tensors during model execution.
## Overview
The intermediate tensor logging feature enables you to:
- Log input and output tensors for modules selected by a configurable set of filters
- Filter modules by name using regex patterns
- Filter by module forward-call index (e.g. dump only the 2nd forward call of a module within a step)
- Filter tensors by device
- Filter by whole-model forward step ID
## Usage
### Enabling via parameters or config file
**Offline Inference example**
Dump all modules on all devices for step 0 (the default behavior):
```bash
python3 ./examples/offline_inference/llm_engine_example.py --model "meta-llama/Llama-3.1-8B-Instruct" --enforce-eager --intermediate-log-config '{"enabled": true}'
```
Dump only the first layer's modules, on all devices, for step 0:
```bash
python3 ./examples/offline_inference/llm_engine_example.py --model "meta-llama/Llama-3.1-8B-Instruct" --enforce-eager --intermediate-log-config '{"enabled": true, "module_call_match": "layers\\.0\\."}'
```
#### Configuration Parameters
| Parameter | Type | Description | Default |
|-----------|------|-------------|---------|
| `enabled` | boolean | Whether intermediate logging is enabled | `true` |
| `output_dir` | string | Directory in which to save the intermediate tensors | `/tmp/vllm_intermediates` |
| `module_call_match` | array | Regex patterns to filter module names; to limit logging to the i-th call of a matched module, append `:i` | `null` (log all modules) |
| `log_step_ids` | array | List of step IDs to log | `[0]` |
| `max_tensor_size` | integer | Maximum number of elements in tensors to log | `null` (no limit) |
| `device_names` | array | List of device names to log | `[]` (log all devices) |
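
The same parameters can also be set programmatically on `IntermediateLoggingConfig` (imported from `vllm.config`). A minimal sketch; the regex, step IDs, and device below are illustrative:

```python
from vllm.config import IntermediateLoggingConfig

# Log only modules in the first layer, restricted to call index 1
# (the second call within a step, if the module is called more than once),
# for steps 0 and 2 on cuda:0.
config = IntermediateLoggingConfig(
    enabled=True,
    output_dir="/tmp/vllm_intermediates",
    module_call_match=[r"layers\.0\.:1"],
    log_step_ids=[0, 2],
    device_names=["cuda:0"],
)
```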
### Output Directory Structure
When you enable intermediate logging, the system creates a uniquely named run directory (a UUID) under your specified `output_dir`. This keeps multiple logging sessions separate:
```
/tmp/vllm_intermediates/010fed05-4a36-4c19-ab44-7cd67e3f63ce/
└── step_0
├── model.embed_tokens
│ ├── inputs_0_cuda_0.pt
│ ├── inputs.json
│ ├── outputs_cuda_0.pt
│ └── outputs.json
├── model.layers.0.input_layernorm
│ ├── inputs_0_cuda_0.pt
│ ├── inputs.json
│ ├── outputs_cuda_0.pt
│ └── outputs.json
└── step_1/
└── ...
```
Each tensor is saved in a `.pt` file containing the full PyTorch tensor, which can be loaded with `torch.load()`.
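
For example, a saved tensor can be reloaded for offline inspection or comparison (the run UUID and module path below are illustrative; substitute the directory printed for your run):

```python
import torch

# Illustrative paths following the directory layout shown above.
module_dir = "/tmp/vllm_intermediates/<run-uuid>/step_0/model.embed_tokens"
inputs = torch.load(f"{module_dir}/inputs_0_cuda_0.pt")
outputs = torch.load(f"{module_dir}/outputs_cuda_0.pt")
print(inputs.shape, outputs.shape)
```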

View File

@ -0,0 +1,320 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for the intermediate tensor logging functionality.
"""
import json
import os
import shutil
import tempfile
from pathlib import Path
from unittest import mock
import pytest
import torch
import torch.nn as nn
from vllm.config import IntermediateLoggingConfig
from vllm.v1.intermediates.intermediates_logging import (
get_current_il_config, get_step, increment_step, intermediate_logging,
register_intermediate_hooks, reset_step, should_log_device,
should_log_module, should_log_step)
class SimpleModel(nn.Module):
"""A simple model for testing."""
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(10, 20)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(20, 5)
def forward(self, x):
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x
@pytest.fixture
def temp_output_dir():
"""Create a temporary directory for test outputs."""
temp_dir = tempfile.mkdtemp()
yield temp_dir
# Clean up after the test
shutil.rmtree(temp_dir)
@pytest.fixture
def simple_model():
"""Create a simple model for testing."""
return SimpleModel()
@pytest.fixture
def il_config(temp_output_dir):
"""Create a basic IntermediateLoggingConfig for testing."""
return IntermediateLoggingConfig(output_dir=temp_output_dir,
enabled=True,
log_step_ids=[0, 1],
module_call_match=[".*linear.*"])
def test_step_counter():
"""Test the step counter functionality."""
# Reset the step counter
reset_step()
assert get_step() == 0
# Increment the step counter
increment_step()
assert get_step() == 1
# Increment again
increment_step()
assert get_step() == 2
# Reset again
reset_step()
assert get_step() == 0
def test_intermediate_logging_context_manager():
"""Test the intermediate_logging context manager."""
# Create a config
config = IntermediateLoggingConfig(enabled=True)
# Initially, there should be no global config
assert get_current_il_config() is None
# Use the context manager
with intermediate_logging(config):
# Inside the context, the global config should be set
assert get_current_il_config() is not None
assert get_current_il_config().enabled is True
# After the context, the global config should be None again
assert get_current_il_config() is None
# Test with a different config
config2 = IntermediateLoggingConfig(enabled=False)
with intermediate_logging(config2):
assert get_current_il_config() is not None
assert get_current_il_config().enabled is False
def test_should_log_step():
"""Test the should_log_step function."""
# Reset step counter
reset_step()
# Create configs with different step IDs
config_all_steps = IntermediateLoggingConfig(
enabled=True,
log_step_ids=[] # Empty list means log all steps
)
config_specific_steps = IntermediateLoggingConfig(
enabled=True,
log_step_ids=[0, 2, 4] # Only log steps 0, 2, and 4
)
config_disabled = IntermediateLoggingConfig(enabled=False,
log_step_ids=[0, 1, 2])
# Test with all steps config
with intermediate_logging(config_all_steps):
assert should_log_step(config_all_steps) is True # Step 0
increment_step()
assert should_log_step(config_all_steps) is True # Step 1
# Reset step counter
reset_step()
# Test with specific steps config
with intermediate_logging(config_specific_steps):
assert should_log_step(config_specific_steps) is True # Step 0
increment_step()
assert should_log_step(config_specific_steps) is False # Step 1
increment_step()
assert should_log_step(config_specific_steps) is True # Step 2
increment_step()
assert should_log_step(config_specific_steps) is False # Step 3
increment_step()
assert should_log_step(config_specific_steps) is True # Step 4
# Test with disabled config
with intermediate_logging(config_disabled):
assert should_log_step(config_disabled) is False # Disabled
def test_should_log_device():
"""Test the should_log_device function."""
# Create configs with different device filters
config_all_devices = IntermediateLoggingConfig(
enabled=True,
device_names=[] # Empty list means log all devices
)
config_specific_devices = IntermediateLoggingConfig(
enabled=True,
device_names=["cuda:0", "cpu"] # Only log cuda:0 and cpu
)
config_disabled = IntermediateLoggingConfig(enabled=False,
device_names=["cuda:0", "cpu"])
# Test with all devices config
with intermediate_logging(config_all_devices):
assert should_log_device(config_all_devices, "cuda:0") is True
assert should_log_device(config_all_devices, "cuda:1") is True
assert should_log_device(config_all_devices, "cpu") is True
# Test with specific devices config
with intermediate_logging(config_specific_devices):
assert should_log_device(config_specific_devices, "cuda:0") is True
assert should_log_device(config_specific_devices, "cuda:1") is False
assert should_log_device(config_specific_devices, "cpu") is True
# Test with disabled config
with intermediate_logging(config_disabled):
assert should_log_device(config_disabled, "cuda:0") is False
assert should_log_device(config_disabled, "cpu") is False
def test_should_log_module(simple_model):
"""Test the should_log_module function."""
# Create configs with different module name filters
config_all_modules = IntermediateLoggingConfig(
enabled=True,
module_call_match=None # None means log all modules
)
config_specific_modules = IntermediateLoggingConfig(
enabled=True,
module_call_match=[".*linear.*"
] # Only log modules with "linear" in the name
)
config_disabled = IntermediateLoggingConfig(enabled=False,
module_call_match=[".*"])
# Test with all modules config
with intermediate_logging(config_all_modules):
assert should_log_module(config_all_modules, "linear1",
simple_model.linear1) is True
assert should_log_module(config_all_modules, "relu",
simple_model.relu) is True
# Test with specific modules config
with intermediate_logging(config_specific_modules):
assert should_log_module(config_specific_modules, "linear1",
simple_model.linear1) is True
assert should_log_module(config_specific_modules, "relu",
simple_model.relu) is False
# Test with disabled config
with intermediate_logging(config_disabled):
assert should_log_module(config_disabled, "linear1",
simple_model.linear1) is False
assert should_log_module(config_disabled, "relu",
simple_model.relu) is False
def test_register_hooks(simple_model, il_config):
"""Test registering hooks on a model."""
# Register hooks
logger_instance = register_intermediate_hooks(simple_model, il_config)
# Check that hooks were registered
assert len(logger_instance.hooks) > 0
# Remove hooks
logger_instance.remove_hooks()
# Check that hooks were removed
assert len(logger_instance.hooks) == 0
@mock.patch(
'vllm.v1.intermediates.intermediates_logging.dump_intermediates_to_json')
@mock.patch('vllm.v1.intermediates.intermediates_logging.save_tensors')
def test_forward_hooks(mock_save_tensors, mock_dump_json, simple_model,
il_config, temp_output_dir):
"""Test that forward hooks are called during model execution."""
mock_save_tensors.return_value = None
# Register hooks
with intermediate_logging(il_config):
logger_instance = register_intermediate_hooks(simple_model, il_config)
# Create input tensor
input_tensor = torch.randn(2, 10)
# Reset step counter
reset_step()
# Forward pass
simple_model(input_tensor)
# Check that the step counter was incremented
assert get_step() == 1
# Check that dump_intermediates_to_json and save_tensors were called
assert mock_dump_json.called
assert mock_save_tensors.called
# Remove hooks
logger_instance.remove_hooks()
def test_end_to_end(simple_model, il_config, temp_output_dir):
"""Test the entire intermediate logging workflow end-to-end."""
# Register hooks
with intermediate_logging(il_config):
logger_instance = register_intermediate_hooks(simple_model, il_config)
# Create input tensor
input_tensor = torch.randn(2, 10)
# Reset step counter
reset_step()
# Forward pass
simple_model(input_tensor)
# Check that output directories were created
root_dir = Path(il_config._output_run_dir)
assert root_dir.exists()
step_dir = root_dir / "step_0"
assert step_dir.exists()
module_dirs = list(step_dir.glob("*"))
print(f"{module_dirs=}")
assert len(module_dirs) > 0
# Check that input and output files were created
for module_dir in module_dirs:
print(f"{module_dir=}")
if os.path.isdir(module_dir):
inputs_json = module_dir / "inputs.json"
outputs_json = module_dir / "outputs.json"
# Check that JSON files exist
assert inputs_json.exists()
assert outputs_json.exists()
# Check that JSON files contain valid data
with open(inputs_json) as f:
inputs_data = json.load(f)
assert "type" in inputs_data
with open(outputs_json) as f:
outputs_data = json.load(f)
assert "type" in outputs_data
# Check that tensor files exist
tensor_files = list(module_dir.glob("*.pt"))
assert len(tensor_files) > 0
# Remove hooks
logger_instance.remove_hooks()
if __name__ == "__main__":
pytest.main(["-xvs", __file__])

View File

@ -3311,6 +3311,119 @@ class KVTransferConfig:
return self.kv_connector_extra_config.get(key, default)
@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class IntermediateLoggingConfig:
"""Configuration for intermediate tensor logging."""
output_dir: str = "/tmp/vllm_intermediates"
"""Directory where to save the intermediate tensors."""
module_call_match: Optional[list[str]] = None
"""Match modules by name regex and call index (
a module can be called multiple times in a step)
List of regex:call_idx, call_idx is -1 for default for all calls """
log_step_ids: list[int] = field(default_factory=lambda: [0])
"""List of step IDs to log (empty list means log all steps)."""
log_post_fwd_inputs: bool = False
"""Whether logging inputs after forwards for each module"""
max_tensor_size: Optional[int] = None
"""Maximum number of elements in tensors to log (None = no limit)."""
enabled: bool = True
"""Whether logging is enabled."""
device_names: list[str] = field(default_factory=list)
"""List of device names to log (empty list means log all devices)."""
_compiled_module_calls: dict[re.Pattern, int] = field(default_factory=dict,
init=False)
"""Compiled regex patterns for module filtering."""
_module_call: dict[str, int] = field(default_factory=dict, init=False)
_step_id_set: set[int] = field(default_factory=set, init=False)
"""Set of step IDs for faster lookup."""
_output_run_dir: str = "/tmp/vllm_intermediates"
"""Unique directory to save single run/serve logging result."""
def __post_init__(self):
"""Initialize derived fields after instance creation."""
self._compile_regex_patterns()
self._output_run_dir = self.output_dir + "/" + str(uuid.uuid4())
self._step_id_set = set(self.log_step_ids)
def _compile_regex_patterns(self):
"""Compile regex patterns for module name filtering."""
from vllm.logger import init_logger
logger = init_logger(__name__)
self._compiled_module_calls = {}
if self.module_call_match is None:
logger.info(
"No module name regex patterns provided, will log all modules")
return
# Compile all patterns
for regex_pattern_call_idx in self.module_call_match:
try:
splits = regex_pattern_call_idx.split(":", 2)
regex_pattern = splits[0]
call_idx = -1
if len(splits) > 1:
call_idx = int(splits[1])
compiled_pattern: re.Pattern[str] = re.compile(regex_pattern)
self._compiled_module_calls[compiled_pattern] = call_idx
logger.info("Successfully compiled regex pattern: '%s'",
regex_pattern)
except Exception as e:
logger.error("Failed to parse module_call_match '%s': %s",
regex_pattern_call_idx, e)
logger.info("Compiled %d regex patterns",
len(self._compiled_module_calls))
def to_dict(self) -> dict:
"""Convert the config to a dictionary for serialization."""
return {
"output_run_dir": self.output_run_dir,
"module_call_match": self.module_call_match,
"log_step_ids": self.log_step_ids,
"max_tensor_size": self.max_tensor_size,
"enabled": self.enabled,
"device_names": self.device_names
}
@classmethod
def from_dict(cls, dict_value: dict) -> "IntermediateLoggingConfig":
"""Parse the CLI value for the speculative config."""
return cls(**dict_value)
@property
def output_run_dir(self) -> str:
return self._output_run_dir
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# Intermediate logging doesn't affect the computation graph
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode(),
usedforsecurity=False).hexdigest()
return hash_str
@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class VllmConfig:
@ -3362,6 +3475,8 @@ class VllmConfig:
"""The configurations for distributed KV cache transfer."""
kv_events_config: Optional[KVEventsConfig] = None
"""The configurations for event publishing."""
intermediate_log_config: Optional[IntermediateLoggingConfig] = None
"""Configuration for intermediate tensor logging."""
# some opaque config, only used to provide additional information
# for the hash computation, mainly used for testing, debugging or out of
# tree config registration.
@ -3446,6 +3561,10 @@ class VllmConfig:
vllm_factors.append(self.kv_transfer_config.compute_hash())
else:
vllm_factors.append("None")
if self.intermediate_log_config:
vllm_factors.append(self.intermediate_log_config.compute_hash())
else:
vllm_factors.append("None")
if self.additional_config:
if isinstance(additional_config := self.additional_config, dict):
additional_config_hash = hashlib.md5(

View File

@ -409,6 +409,7 @@ class EngineArgs:
speculative_config: Optional[Dict[str, Any]] = None
show_hidden_metrics_for_version: Optional[str] = \
ObservabilityConfig.show_hidden_metrics_for_version
otlp_traces_endpoint: Optional[str] = \
@ -456,6 +457,8 @@ class EngineArgs:
async_scheduling: bool = SchedulerConfig.async_scheduling
intermediate_log_config: Optional[dict[str, Any]] = None
kv_sharing_fast_prefill: bool = \
CacheConfig.kv_sharing_fast_prefill
@ -883,6 +886,9 @@ class EngineArgs:
title="VllmConfig",
description=VllmConfig.__doc__,
)
vllm_group.add_argument("--intermediate-log-config",
**vllm_kwargs["intermediate_log_config"])
# We construct SpeculativeConfig using fields from other configs in
# create_engine_config. So we set the type to a JSON string here to
# delay the Pydantic validation that comes with SpeculativeConfig.
@ -1394,7 +1400,6 @@ class EngineArgs:
otlp_traces_endpoint=self.otlp_traces_endpoint,
collect_detailed_traces=self.collect_detailed_traces,
)
config = VllmConfig(
model_config=model_config,
cache_config=cache_config,
@ -1409,6 +1414,7 @@ class EngineArgs:
compilation_config=self.compilation_config,
kv_transfer_config=self.kv_transfer_config,
kv_events_config=self.kv_events_config,
intermediate_log_config=self.intermediate_log_config,
additional_config=self.additional_config,
)

View File

@ -80,6 +80,10 @@ class EngineCore:
# Setup Model.
self.model_executor = executor_class(vllm_config)
if vllm_config.intermediate_log_config is not None:
self.collective_rpc("register_intermediate_hooks",
args=(vllm_config.intermediate_log_config, ))
if executor_fail_callback is not None:
self.model_executor.register_failure_callback(
executor_fail_callback)

View File

View File

@ -0,0 +1,405 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Module for logging intermediate tensors during model execution.
This module provides functionality to capture and save intermediate tensors
(inputs and outputs) from PyTorch modules during forward passes.
"""
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Optional
import torch
from torch.utils.hooks import RemovableHandle
from vllm.config import IntermediateLoggingConfig
from vllm.logger import init_logger
logger = init_logger(__name__)
# Global step counter
_CURRENT_STEP = 0
_CURRENT_STEP_MODULE_CALL_STEP: dict[str, int] = {}
IL_MODULE_NAME = "_il_module_name"
IL_MODULE_CALL_IDX = "_il_module_call_idx"
# Utility functions for intermediate logging
def should_log_step(config):
"""Check if the current step should be logged based on the step IDs.
Args:
config: The IntermediateLoggingConfig instance.
Returns:
True if the current step should be logged, False otherwise.
"""
if not is_log_enabled(config):
return False
# If log_step_ids is empty, log all steps
if not config.log_step_ids:
return True
# Otherwise, check if current step is in the set of step IDs to log
return get_step() in config._step_id_set
def should_log_device(config, device_name):
"""Check if a device should be logged based on the device names.
Args:
config: The IntermediateLoggingConfig instance.
device_name: The name of the device to check (e.g., 'cuda:0', 'cpu').
Returns:
True if the device should be logged, False otherwise.
If device_names is empty, all devices are logged.
"""
if not is_log_enabled(config):
return False
# If device_names is empty, log all devices
if not config.device_names:
return True
# Otherwise, check if device_name is in the list of device names to log
return device_name in config.device_names
def should_log_module(config, module_name, module: torch.nn.Module) -> bool:
"""Check if a module should be logged based on the name regex patterns.
Args:
config: The IntermediateLoggingConfig instance.
module_name: The name of the module to check.
module: The module instance; tagged with its matched name and call index.
Returns:
True if the module should be logged, False otherwise.
If no patterns are defined, all modules are logged.
If patterns are defined, the module is logged if it matches ANY pattern.
"""
if not is_log_enabled(config):
return False
# If no patterns are defined, log all modules
if not config._compiled_module_calls:
set_il_module_name(module, module_name)
set_il_module_call_idx(module, -1)
return True
# Check if the module name matches any of the patterns
for pattern, call_idx in config._compiled_module_calls.items():
match = pattern.search(module_name)
if match:
logger.debug(
"Module %s, %s matches pattern: '%s', call_idx=%s",
module_name,
module.__class__.__name__,
pattern.pattern,
call_idx,
)
set_il_module_name(module, module_name)
set_il_module_call_idx(module, call_idx)
return True
return False
def is_log_enabled(config):
if not config or not config.enabled:
return False
if torch.compiler.is_compiling():
logger.debug("Not logging because torch.compile is in progress")
return False
return True
def get_il_module_name(module: torch.nn.Module) -> str:
return getattr(module, IL_MODULE_NAME, module.__class__.__name__)
def get_il_module_call_idx(module: torch.nn.Module) -> int:
return getattr(module, IL_MODULE_CALL_IDX, -1)
def set_il_module_name(module: torch.nn.Module, name: str) -> None:
setattr(module, IL_MODULE_NAME, name)
def set_il_module_call_idx(module: torch.nn.Module, idx: int) -> None:
setattr(module, IL_MODULE_CALL_IDX, idx)
_global_config: Optional[IntermediateLoggingConfig] = None
@contextmanager
def intermediate_logging(config: Optional[IntermediateLoggingConfig]):
"""
Temporarily sets the global config for the duration of the context.
:param config: Keyword arguments to set as global config
"""
global _global_config
old_config = _global_config
try:
_global_config = config
yield
finally:
_global_config = old_config
def get_current_il_config():
return _global_config
def save_tensors(tensor: Any, file_path: str) -> Any:
"""Utility function to dump tensor to a file.
Args:
tensor: The tensor to dump. Can be a torch.Tensor, a list/tuple of
tensors, or a dictionary containing tensors.
file_path: Base path where to save the tensor (without extension).
"""
if isinstance(tensor, torch.Tensor):
device_name = str(tensor.device)
intermediate_log_config = get_current_il_config()
if not should_log_device(intermediate_log_config, device_name):
return tensor
pt_path = f"{file_path}_{device_name.replace(':', '_')}.pt"
try:
torch.save(tensor, pt_path)
logger.debug("Saved tensor of shape %s to %s", tensor.shape,
pt_path)
except Exception as e:
logger.warning("Failed to save tensor to %s: %s", pt_path, e)
return tensor
if isinstance(tensor, (list, tuple)):
for i, item in enumerate(tensor):
save_tensors(item, f"{file_path}_{i}")
return tensor
if isinstance(tensor, dict):
for k, v in tensor.items():
save_tensors(v, f"{file_path}_{k}")
return tensor
def step_fwd(module: torch.nn.Module, inputs: tuple[Any, ...],
outputs: Any) -> None:
"""Hook to increment the global step counter after a forward pass.
Args:
module: The PyTorch module being executed.
inputs: The inputs to the module's forward function.
outputs: The outputs from the module's forward function.
"""
if get_current_il_config() is None:
return
# Increment the global step counter
increment_step()
global _CURRENT_STEP_MODULE_CALL_STEP
_CURRENT_STEP_MODULE_CALL_STEP = {}
def _prepare_module_log_dir(
intermediate_log_config: IntermediateLoggingConfig,
module_name: str,
is_pre_fwd: bool = False,
) -> Path:
# Create a directory for this step if it does not already exist
dump_dir = Path(
intermediate_log_config.output_run_dir) / f"step_{get_step()}"
dump_dir.mkdir(exist_ok=True, parents=True)
# Create module directory
suffix = ""
module_call_idx = get_current_step_module_call(module_name)
if module_call_idx > 0:
suffix = f"_{module_call_idx}"
module_dir = dump_dir / (module_name + suffix)
if is_pre_fwd:
_log_module_call(intermediate_log_config, module_name + suffix)
module_dir.mkdir(exist_ok=True, parents=True)
logger.debug("Logging module %s inputs/outputs to %s", module_name,
module_dir)
return module_dir
def _log_module_call(
intermediate_log_config: IntermediateLoggingConfig,
module_name: str,
) -> None:
file = (Path(intermediate_log_config.output_run_dir) /
f"step_{get_step()}" / "module_calls.txt")
with open(file, "a") as f:
f.write(f"{module_name}\n")
def update_current_step_module_call(module_name: str) -> None:
logger.debug("Updating current step module call for %s", module_name)
global _CURRENT_STEP_MODULE_CALL_STEP
if module_name not in _CURRENT_STEP_MODULE_CALL_STEP:
_CURRENT_STEP_MODULE_CALL_STEP[module_name] = 0
else:
_CURRENT_STEP_MODULE_CALL_STEP[module_name] += 1
def get_current_step_module_call(module_name: str) -> int:
return _CURRENT_STEP_MODULE_CALL_STEP.get(module_name, 0)
def prepare_log_current_fwd(module,
is_pre_fwd: bool = False) -> Optional[Path]:
intermediate_log_config = get_current_il_config()
if intermediate_log_config is None or not intermediate_log_config.enabled:
return None
if not should_log_step(intermediate_log_config):
return None
module_name = get_il_module_name(module)
log_call_idx = get_il_module_call_idx(module)
current_call_idx = get_current_step_module_call(module_name)
should_log = True
if log_call_idx >= 0 and current_call_idx != log_call_idx:
should_log = False
log_dir = None
if is_pre_fwd:
update_current_step_module_call(module_name)
if should_log:
log_dir = _prepare_module_log_dir(intermediate_log_config,
module_name,
is_pre_fwd=is_pre_fwd)
return log_dir
def log_pre_fwd_hook(module: torch.nn.Module,
inputs: tuple[Any, ...]) -> tuple[Any, ...]:
"""Hook to capture module inputs before forward pass.
Args:
module: The PyTorch module being executed.
inputs: The inputs to the module's forward function.
Returns:
The unchanged inputs.
"""
if log_dir := prepare_log_current_fwd(module, is_pre_fwd=True):
save_tensors(inputs, str(log_dir / "inputs"))
return inputs
def log_post_fwd_hook(module: torch.nn.Module, inputs: tuple[Any, ...],
outputs: Any) -> None:
"""Hook to capture module outputs after forward pass.
Args:
module: The PyTorch module being executed.
inputs: The inputs to the module's forward function.
outputs: The outputs from the module's forward function.
"""
if log_dir := prepare_log_current_fwd(module, is_pre_fwd=False):
save_tensors(outputs, str(log_dir / "outputs"))
intermediate_log_config = get_current_il_config()
assert intermediate_log_config is not None, \
"IL config should not be None"
if intermediate_log_config.log_post_fwd_inputs:
save_tensors(inputs, str(log_dir / "post_fwd_inputs"))
def get_step() -> int:
"""Get the current global step counter.
Returns:
The current global step counter.
"""
return _CURRENT_STEP
def increment_step() -> int:
"""Increment the global step counter.
Returns:
The new step counter value.
"""
global _CURRENT_STEP
_CURRENT_STEP += 1
return _CURRENT_STEP
def reset_step() -> None:
"""Reset the global step counter to zero."""
global _CURRENT_STEP
_CURRENT_STEP = 0
class IntermediatesLogger:
"""Class to manage logging of intermediate tensors during model
execution."""
def __init__(self, config: IntermediateLoggingConfig):
self.config = config
self.hooks: list[tuple[str, str, Optional[RemovableHandle],
Optional[RemovableHandle]]] = []
logger.debug("Created IntermediatesLogger with config: %s", config)
path = Path(config.output_run_dir)
path.mkdir(exist_ok=True, parents=True)
# Log configuration
logger.info("Intermediates will be logged in %s",
config.output_run_dir)
def register_hooks(self, model: torch.nn.Module) -> None:
"""Register hooks for the model.
Args:
model: The PyTorch model to register hooks for.
"""
for name, module in model.named_modules():
if name and should_log_module(self.config, name, module):
pre_hook = module.register_forward_pre_hook(log_pre_fwd_hook)
logger.debug("Registered pre_fwd hook for %s",
module.__class__.__name__)
post_hook = module.register_forward_hook(log_post_fwd_hook)
logger.debug("Registered post_fwd hook for %s",
module.__class__.__name__)
self.hooks.append((name, module, pre_hook, post_hook))
# Register a step counter hook for the root model
step_hook = model.register_forward_hook(step_fwd)
self.hooks.append(("", model, None, step_hook))
logger.info("Registered hooks for %s modules", len(self.hooks))
def remove_hooks(self) -> None:
"""Remove all registered hooks."""
for _, _, pre_hook, post_hook in self.hooks:
if pre_hook is not None:
pre_hook.remove()
if post_hook is not None:
post_hook.remove()
logger.info("Removed %s hooks", len(self.hooks))
self.hooks = []
def register_intermediate_hooks(
model: torch.nn.Module,
config: Optional[IntermediateLoggingConfig] = None
) -> IntermediatesLogger:
"""Register hooks to log intermediate tensors for a model.
Args:
model: The PyTorch model to log intermediates for.
config: Configuration for intermediate logging.
Returns:
An IntermediatesLogger instance that can be used to manage the hooks.
"""
logger_instance = IntermediatesLogger(config)
logger_instance.register_hooks(model)
return logger_instance

View File

@ -27,6 +27,7 @@ from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask
from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.intermediates.intermediates_logging import intermediate_logging
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
DraftTokenIds, ModelRunnerOutput)
@ -362,7 +363,7 @@ class Worker(WorkerBase):
intermediate_tensors = IntermediateTensors(
get_pp_group().recv_tensor_dict(
all_gather_group=get_tp_group()))
with intermediate_logging(self.vllm_config.intermediate_log_config):
output = self.model_runner.execute_model(scheduler_output,
intermediate_tensors)
if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)):

View File

@ -6,8 +6,10 @@ from typing import Optional
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.config import IntermediateLoggingConfig, VllmConfig
from vllm.logger import init_logger
from vllm.v1.intermediates.intermediates_logging import (
register_intermediate_hooks)
from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.worker.worker_base import WorkerBase as WorkerBaseV0
@ -63,3 +65,26 @@ class WorkerBase(WorkerBaseV0):
def check_health(self) -> None:
"""Basic health check (override for device-specific checks)."""
return
def register_intermediate_hooks(
self, config: Optional[IntermediateLoggingConfig] = None) -> None:
"""Register hooks for intermediate tensor logging.
This method is called via collective_rpc from the engine core.
It registers hooks on the model to dump intermediate tensors during
execution.
Args:
config: Configuration for intermediate logging.
"""
if self.model_runner is None or not hasattr(
self.model_runner, "model") or self.model_runner.model is None:
logger.error("Could not register intermediate hooks: "
"model_runner.model is not accessible")
return
model = self.model_runner.model
try:
register_intermediate_hooks(model, config)
except Exception:
logger.exception("Error registering intermediate hooks")

View File

@ -129,6 +129,22 @@ class WorkerBase:
"""Get vocabulary size from model configuration."""
return self.model_config.get_vocab_size()
def register_intermediate_hooks(self, config=None) -> None:
"""Register hooks for intermediate tensor logging.
This method is a stub for v0 workers. The actual implementation is
in v1 workers. It's included here for compatibility with the
collective_rpc mechanism.
Args:
config: Configuration for intermediate logging.
"""
logger.warning(
"register_intermediate_hooks is not implemented in v0 workers. "
"This is only available in v1 workers. No hooks will be registered."
)
return None
def shutdown(self) -> None:
"""Clean up resources held by the worker."""
return