Files
verl/tests/utils/test_rollout_trace_on_cpu.py
OC def5b28e3d [rollout] feat: support mlflow in rollout trace (#2440)
Implemented mlflow as a rollout trace backend. Compared to weave, mlflow
is a lightweight solution and can be deployed on-premises easily.

### API and Usage Example

docs/advance/rollout_trace.rst
2025-07-15 05:18:40 +08:00

171 lines
6.1 KiB
Python

# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from unittest.mock import MagicMock, patch
import pytest
from verl.utils.rollout_trace import RolloutTraceConfig, rollout_trace_attr, rollout_trace_op
@pytest.fixture(autouse=True)
def reset_rollout_trace_config_singleton():
    """Reset the ``RolloutTraceConfig`` singleton ahead of every test.

    The fixture is ``autouse``, so pytest applies it to each test in this
    module without the test having to request it.
    """
    RolloutTraceConfig.reset()
@pytest.fixture
def mock_weave_client():
    """Install a fake ``weave`` module and yield the mocked client it hands out.

    While the fixture is active, ``sys.modules`` maps ``weave`` and
    ``weave.trace.context`` to ``MagicMock`` objects. ``weave.init()`` returns
    the yielded client, and ``client.create_call()`` returns a dedicated mock
    call object. ``patch.dict`` restores ``sys.modules`` on teardown.
    """
    client = MagicMock()
    client.create_call.return_value = MagicMock()

    weave_module = MagicMock()
    weave_module.init.return_value = client
    # The decorator may consult the call context internally, so stub it too.
    weave_module.trace.context.call_context.return_value = MagicMock()

    fake_modules = {
        "weave": weave_module,
        "weave.trace.context": weave_module.trace.context,
    }
    with patch.dict(sys.modules, fake_modules):
        yield client
class TracedClass:
    """Sample class whose async methods are wrapped with ``rollout_trace_op``.

    The tests below invoke these methods and inspect what the configured
    trace client records. ``upper_method`` itself is left undecorated.
    """

    # For reference: @weave.op / @mlflow.trace
    @rollout_trace_op
    async def my_method(self, a, b="default"):
        """Return a string formatted from ``a`` and ``b``."""
        return f"result: {a}, {b}"

    # For reference: @weave.op / @mlflow.trace
    @rollout_trace_op
    async def middle_method(self, a, b="default"):
        """Await a nested ``my_method`` call, then return a string from ``a`` and ``b``."""
        await self.my_method("test_a1", b="test_b1")
        return f"result: {a}, {b}"

    # For reference: @mlflow.trace
    @rollout_trace_op
    async def my_method_with_exception(self):
        """Always raise ``ValueError("Test Exception")``."""
        raise ValueError("Test Exception")

    async def upper_method(self):
        """Await ``my_method`` and then ``middle_method``; return ``True``."""
        await self.my_method("test_a0", b="test_b0")
        await self.middle_method("test_a2", b="test_b2")
        return True
class UntracedClass:
    """Sample class used by the test that runs without any configured backend."""

    @rollout_trace_op
    async def my_method(self, x):
        """Return ``x`` multiplied by two."""
        doubled = x * 2
        return doubled
async def test_rollout_trace_on_untraced_class():
    """The decorated method still returns its normal result when no backend is configured."""
    untraced = UntracedClass()
    doubled = await untraced.my_method(10)
    assert doubled == 20
async def test_rollout_trace_with_tracer(mock_weave_client):
    """A traced call opens exactly one weave call and finishes it with the method's output."""
    RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="weave")
    traced = TracedClass()

    # The configured client must be the one returned by the mocked ``weave.init``.
    assert RolloutTraceConfig.get_client() is mock_weave_client

    output = await traced.my_method("test_a", b="test_b")
    assert output == "result: test_a, test_b"

    # One call is created, carrying the qualified op name and the bound arguments.
    create_call = mock_weave_client.create_call
    create_call.assert_called_once()
    recorded = create_call.call_args.kwargs
    assert recorded["op"] == "TracedClass.my_method"
    assert recorded["inputs"] == {"a": "test_a", "b": "test_b"}

    # That same call object is finished with the method's return value.
    started_call = create_call.return_value
    mock_weave_client.finish_call.assert_called_once_with(started_call, output=output)
async def test_rollout_trace_with_exception(mock_weave_client):
    """When the traced method raises, ``finish_call`` receives the exception."""
    RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="weave")
    traced = TracedClass()

    with pytest.raises(ValueError, match="Test Exception"):
        await traced.my_method_with_exception()

    mock_weave_client.create_call.assert_called_once()
    started_call = mock_weave_client.create_call.return_value

    finish_call = mock_weave_client.finish_call
    finish_call.assert_called_once()

    # The call is closed with the started call object and an ``exception`` keyword.
    positional, keyword = finish_call.call_args
    assert positional[0] == started_call
    assert "exception" in keyword
    assert isinstance(keyword["exception"], ValueError)
async def test_rollout_trace_with_dummy_backend(mock_weave_client):
    """With the 'dummy' backend selected, the weave client never creates a call."""
    RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="dummy")
    traced = TracedClass()

    await traced.my_method("test_a")

    mock_weave_client.create_call.assert_not_called()
@pytest.mark.skipif(
    os.environ.get("RUN_WEAVE_INTEGRATION_TESTS", "false").lower() != "true",
    reason="Skipping weave integration test. Set RUN_WEAVE_INTEGRATION_TESTS=true to run.",
)
async def test_rollout_trace_with_real_weave_backend():
    """Run the traced methods against a real weave backend (opt-in).

    Skipped unless the ``RUN_WEAVE_INTEGRATION_TESTS`` environment variable is
    set to ``true`` (case-insensitive). Assumes the weave environment, e.g. the
    project, is already configured.
    """
    RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="weave")
    traced = TracedClass()

    with rollout_trace_attr(step=1, sample_index=2, rollout_n=3):
        await traced.upper_method()

        # The failing method must still propagate its exception under tracing.
        with pytest.raises(ValueError, match="Test Exception"):
            await traced.my_method_with_exception()

    print("\nWeave integration test ran successfully. Check your weave project for the trace.")
@pytest.mark.skipif(
    os.environ.get("RUN_MLFLOW_INTEGRATION_TESTS", "false").lower() != "true",
    reason="Skipping mlflow integration test. Set RUN_MLFLOW_INTEGRATION_TESTS=true to run.",
)
async def test_rollout_trace_with_real_mlflow_backend():
    """Run the traced methods against a real mlflow backend (opt-in).

    Skipped unless the ``RUN_MLFLOW_INTEGRATION_TESTS`` environment variable is
    set to ``true`` (case-insensitive). Assumes the mlflow environment is
    already configured.
    """
    RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="mlflow")
    instance = TracedClass()

    with rollout_trace_attr(step=1, sample_index=2, rollout_n=3, name="agent_run"):
        assert await instance.upper_method()

        # NOTE(review): the exception-path check is disabled for mlflow in the
        # original; the reason is not visible in this file.
        # with pytest.raises(ValueError, match="Test Exception"):
        #     await instance.my_method_with_exception()

    # The message names mlflow; it previously repeated the weave test's text.
    print("\nMLflow integration test ran successfully. Check your mlflow experiment for the trace.")