Files
vllm/tests/v1/test_intermediates_logging.py
2025-08-05 09:25:17 -07:00

321 lines
10 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for the intermediate tensor logging functionality.
"""
import json
import os
import shutil
import tempfile
from pathlib import Path
from unittest import mock
import pytest
import torch
import torch.nn as nn
from vllm.config import IntermediateLoggingConfig
from vllm.v1.intermediates.intermediates_logging import (
get_current_il_config, get_step, increment_step, intermediate_logging,
register_intermediate_hooks, reset_step, should_log_device,
should_log_module, should_log_step)
class SimpleModel(nn.Module):
"""A simple model for testing."""
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(10, 20)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(20, 5)
def forward(self, x):
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x
@pytest.fixture
def temp_output_dir():
"""Create a temporary directory for test outputs."""
temp_dir = tempfile.mkdtemp()
yield temp_dir
# Clean up after the test
shutil.rmtree(temp_dir)
@pytest.fixture
def simple_model():
"""Create a simple model for testing."""
return SimpleModel()
@pytest.fixture
def il_config(temp_output_dir):
"""Create a basic IntermediateLoggingConfig for testing."""
return IntermediateLoggingConfig(output_dir=temp_output_dir,
enabled=True,
log_step_ids=[0, 1],
module_call_match=[".*linear.*"])
def test_step_counter():
"""Test the step counter functionality."""
# Reset the step counter
reset_step()
assert get_step() == 0
# Increment the step counter
increment_step()
assert get_step() == 1
# Increment again
increment_step()
assert get_step() == 2
# Reset again
reset_step()
assert get_step() == 0
def test_intermediate_logging_context_manager():
"""Test the intermediate_logging context manager."""
# Create a config
config = IntermediateLoggingConfig(enabled=True)
# Initially, there should be no global config
assert get_current_il_config() is None
# Use the context manager
with intermediate_logging(config):
# Inside the context, the global config should be set
assert get_current_il_config() is not None
assert get_current_il_config().enabled is True
# After the context, the global config should be None again
assert get_current_il_config() is None
# Test with a different config
config2 = IntermediateLoggingConfig(enabled=False)
with intermediate_logging(config2):
assert get_current_il_config() is not None
assert get_current_il_config().enabled is False
def test_should_log_step():
"""Test the should_log_step function."""
# Reset step counter
reset_step()
# Create configs with different step IDs
config_all_steps = IntermediateLoggingConfig(
enabled=True,
log_step_ids=[] # Empty list means log all steps
)
config_specific_steps = IntermediateLoggingConfig(
enabled=True,
log_step_ids=[0, 2, 4] # Only log steps 0, 2, and 4
)
config_disabled = IntermediateLoggingConfig(enabled=False,
log_step_ids=[0, 1, 2])
# Test with all steps config
with intermediate_logging(config_all_steps):
assert should_log_step(config_all_steps) is True # Step 0
increment_step()
assert should_log_step(config_all_steps) is True # Step 1
# Reset step counter
reset_step()
# Test with specific steps config
with intermediate_logging(config_specific_steps):
assert should_log_step(config_specific_steps) is True # Step 0
increment_step()
assert should_log_step(config_specific_steps) is False # Step 1
increment_step()
assert should_log_step(config_specific_steps) is True # Step 2
increment_step()
assert should_log_step(config_specific_steps) is False # Step 3
increment_step()
assert should_log_step(config_specific_steps) is True # Step 4
# Test with disabled config
with intermediate_logging(config_disabled):
assert should_log_step(config_disabled) is False # Disabled
def test_should_log_device():
"""Test the should_log_device function."""
# Create configs with different device filters
config_all_devices = IntermediateLoggingConfig(
enabled=True,
device_names=[] # Empty list means log all devices
)
config_specific_devices = IntermediateLoggingConfig(
enabled=True,
device_names=["cuda:0", "cpu"] # Only log cuda:0 and cpu
)
config_disabled = IntermediateLoggingConfig(enabled=False,
device_names=["cuda:0", "cpu"])
# Test with all devices config
with intermediate_logging(config_all_devices):
assert should_log_device(config_all_devices, "cuda:0") is True
assert should_log_device(config_all_devices, "cuda:1") is True
assert should_log_device(config_all_devices, "cpu") is True
# Test with specific devices config
with intermediate_logging(config_specific_devices):
assert should_log_device(config_specific_devices, "cuda:0") is True
assert should_log_device(config_specific_devices, "cuda:1") is False
assert should_log_device(config_specific_devices, "cpu") is True
# Test with disabled config
with intermediate_logging(config_disabled):
assert should_log_device(config_disabled, "cuda:0") is False
assert should_log_device(config_disabled, "cpu") is False
def test_should_log_module(simple_model):
"""Test the should_log_module function."""
# Create configs with different module name filters
config_all_modules = IntermediateLoggingConfig(
enabled=True,
module_call_match=None # None means log all modules
)
config_specific_modules = IntermediateLoggingConfig(
enabled=True,
module_call_match=[".*linear.*"
] # Only log modules with "linear" in the name
)
config_disabled = IntermediateLoggingConfig(enabled=False,
module_call_match=[".*"])
# Test with all modules config
with intermediate_logging(config_all_modules):
assert should_log_module(config_all_modules, "linear1",
simple_model.linear1) is True
assert should_log_module(config_all_modules, "relu",
simple_model.relu) is True
# Test with specific modules config
with intermediate_logging(config_specific_modules):
assert should_log_module(config_specific_modules, "linear1",
simple_model.linear1) is True
assert should_log_module(config_specific_modules, "relu",
simple_model.relu) is False
# Test with disabled config
with intermediate_logging(config_disabled):
assert should_log_module(config_disabled, "linear1",
simple_model.linear1) is False
assert should_log_module(config_disabled, "relu",
simple_model.relu) is False
def test_register_hooks(simple_model, il_config):
"""Test registering hooks on a model."""
# Register hooks
logger_instance = register_intermediate_hooks(simple_model, il_config)
# Check that hooks were registered
assert len(logger_instance.hooks) > 0
# Remove hooks
logger_instance.remove_hooks()
# Check that hooks were removed
assert len(logger_instance.hooks) == 0
@mock.patch(
'vllm.v1.intermediates.intermediates_logging.dump_intermediates_to_json')
@mock.patch('vllm.v1.intermediates.intermediates_logging.save_tensors')
def test_forward_hooks(mock_save_tensors, mock_dump_json, simple_model,
il_config, temp_output_dir):
"""Test that forward hooks are called during model execution."""
mock_save_tensors.return_value = None
# Register hooks
with intermediate_logging(il_config):
logger_instance = register_intermediate_hooks(simple_model, il_config)
# Create input tensor
input_tensor = torch.randn(2, 10)
# Reset step counter
reset_step()
# Forward pass
simple_model(input_tensor)
# Check that the step counter was incremented
assert get_step() == 1
# Check that dump_intermediates_to_json and save_tensors were called
assert mock_dump_json.called
assert mock_save_tensors.called
# Remove hooks
logger_instance.remove_hooks()
def test_end_to_end(simple_model, il_config, temp_output_dir):
"""Test the entire intermediate logging workflow end-to-end."""
# Register hooks
with intermediate_logging(il_config):
logger_instance = register_intermediate_hooks(simple_model, il_config)
# Create input tensor
input_tensor = torch.randn(2, 10)
# Reset step counter
reset_step()
# Forward pass
simple_model(input_tensor)
# Check that output directories were created
root_dir = Path(il_config._output_run_dir)
assert root_dir.exists()
step_dir = root_dir / "step_0"
assert step_dir.exists()
module_dirs = list(step_dir.glob("*"))
print(f"{module_dirs=}")
assert len(module_dirs) > 0
# Check that input and output files were created
for module_dir in module_dirs:
print(f"{module_dir=}")
if os.path.isdir(module_dir):
inputs_json = module_dir / "inputs.json"
outputs_json = module_dir / "outputs.json"
# Check that JSON files exist
assert inputs_json.exists()
assert outputs_json.exists()
# Check that JSON files contain valid data
with open(inputs_json) as f:
inputs_data = json.load(f)
assert "type" in inputs_data
with open(outputs_json) as f:
outputs_data = json.load(f)
assert "type" in outputs_data
# Check that tensor files exist
tensor_files = list(module_dir.glob("*.pt"))
assert len(tensor_files) > 0
# Remove hooks
logger_instance.remove_hooks()
if __name__ == "__main__":
pytest.main(["-xvs", __file__])