mirror of
https://github.com/volcengine/verl.git
synced 2025-10-20 21:53:50 +08:00
Implemented mlflow as a rollout trace backend. Compared to weave, mlflow is a lightweight solution and can be deployed on-premises easily. ### API and Usage Example docs/advance/rollout_trace.rst
171 lines
6.1 KiB
Python
171 lines
6.1 KiB
Python
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import os
|
|
import sys
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from verl.utils.rollout_trace import RolloutTraceConfig, rollout_trace_attr, rollout_trace_op
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def reset_rollout_trace_config_singleton():
    """Reset the RolloutTraceConfig singleton ahead of every test.

    Declared ``autouse`` so tests get the reset without having to request
    this fixture explicitly.
    """
    RolloutTraceConfig.reset()
|
|
|
|
|
|
@pytest.fixture
def mock_weave_client():
    """Install a fake ``weave`` module and yield its mock client.

    ``weave.init()`` on the fake module returns the yielded client, and the
    client's ``create_call()`` returns a dedicated mock call object. The
    ``sys.modules`` patch is active only while the fixture is in use.
    """
    fake_client = MagicMock()
    fake_client.create_call.return_value = MagicMock()

    fake_weave = MagicMock()
    fake_weave.init.return_value = fake_client
    # Stub call_context as well, in case the decorator relies on it internally.
    fake_weave.trace.context.call_context.return_value = MagicMock()

    stubbed_modules = {
        "weave": fake_weave,
        "weave.trace.context": fake_weave.trace.context,
    }
    with patch.dict(sys.modules, stubbed_modules):
        yield fake_client
|
|
|
|
|
|
class TracedClass:
    """Sample class whose coroutine methods are wrapped with ``rollout_trace_op``."""

    # NOTE(review): the original carried commented-out `@weave.op` and
    # `@mlflow.trace` decorators next to `@rollout_trace_op`, presumably as
    # drop-in alternatives for debugging a backend directly -- confirm.

    @rollout_trace_op
    async def my_method(self, a, b="default"):
        """Return a string built from ``a`` and ``b``."""
        formatted = f"result: {a}, {b}"
        return formatted

    @rollout_trace_op
    async def middle_method(self, a, b="default"):
        """Await ``my_method`` with fixed arguments, then format ``a`` and ``b``.

        The nested call's result is discarded; only this method's own
        formatted string is returned.
        """
        await self.my_method("test_a1", b="test_b1")
        formatted = f"result: {a}, {b}"
        return formatted

    @rollout_trace_op
    async def my_method_with_exception(self):
        """Always raise ``ValueError("Test Exception")``."""
        raise ValueError("Test Exception")

    async def upper_method(self):
        """Undecorated entry point: await ``my_method`` then ``middle_method``.

        Returns:
            ``True`` once both awaited calls have completed.
        """
        await self.my_method("test_a0", b="test_b0")
        await self.middle_method("test_a2", b="test_b2")
        return True
|
|
|
|
|
|
class UntracedClass:
    """Minimal class with one ``rollout_trace_op``-decorated coroutine."""

    @rollout_trace_op
    async def my_method(self, x):
        """Return ``x * 2``."""
        doubled = x * 2
        return doubled
|
|
|
|
|
|
async def test_rollout_trace_on_untraced_class():
    """The decorated method still returns its result when no backend is configured."""
    doubled = await UntracedClass().my_method(10)
    assert doubled == 20
|
|
|
|
|
|
async def test_rollout_trace_with_tracer(mock_weave_client):
    """With the weave backend, one call is created and finished per invocation.

    Checks the op name and inputs passed to ``create_call`` and that
    ``finish_call`` receives the created call together with the method output.
    """
    RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="weave")
    traced = TracedClass()
    assert RolloutTraceConfig.get_client() is mock_weave_client

    result = await traced.my_method("test_a", b="test_b")
    assert result == "result: test_a, test_b"

    create_call = mock_weave_client.create_call
    create_call.assert_called_once()
    recorded = create_call.call_args.kwargs
    assert recorded["op"] == "TracedClass.my_method"
    assert recorded["inputs"] == {"a": "test_a", "b": "test_b"}

    # finish_call must be given the exact object create_call returned.
    mock_weave_client.finish_call.assert_called_once_with(create_call.return_value, output=result)
|
|
|
|
|
|
async def test_rollout_trace_with_exception(mock_weave_client):
    """A raised exception propagates and is forwarded to ``finish_call``."""
    RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="weave")
    traced = TracedClass()

    with pytest.raises(ValueError, match="Test Exception"):
        await traced.my_method_with_exception()

    mock_weave_client.create_call.assert_called_once()
    created_call = mock_weave_client.create_call.return_value
    mock_weave_client.finish_call.assert_called_once()

    # The created call is the first positional argument, and the error is
    # passed via the ``exception`` keyword.
    finish_invocation = mock_weave_client.finish_call.call_args
    assert finish_invocation.args[0] == created_call
    assert "exception" in finish_invocation.kwargs
    assert isinstance(finish_invocation.kwargs["exception"], ValueError)
|
|
|
|
|
|
async def test_rollout_trace_with_dummy_backend(mock_weave_client):
    """With ``backend="dummy"`` the weave client never creates a call."""
    RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="dummy")

    traced = TracedClass()
    await traced.my_method("test_a")

    mock_weave_client.create_call.assert_not_called()
|
|
|
|
|
|
@pytest.mark.skipif(
    os.getenv("RUN_WEAVE_INTEGRATION_TESTS", "false").lower() != "true",
    reason="Skipping weave integration test. Set RUN_WEAVE_INTEGRATION_TESTS=true to run.",
)
async def test_rollout_trace_with_real_weave_backend():
    """Exercise the traced methods against a real weave backend.

    Skipped unless ``RUN_WEAVE_INTEGRATION_TESTS`` is set to ``true``.
    Assumes the weave environment (e.g., project) is already configured.
    """
    RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="weave")

    traced = TracedClass()
    trace_attrs = dict(step=1, sample_index=2, rollout_n=3)

    with rollout_trace_attr(**trace_attrs):
        await traced.upper_method()

    # The failing method is invoked outside the attribute context.
    with pytest.raises(ValueError, match="Test Exception"):
        await traced.my_method_with_exception()

    print("\nWeave integration test ran successfully. Check your weave project for the trace.")
|
|
|
|
|
|
@pytest.mark.skipif(
    os.environ.get("RUN_MLFLOW_INTEGRATION_TESTS", "false").lower() != "true",
    reason="Skipping mlflow integration test. Set RUN_MLFLOW_INTEGRATION_TESTS=true to run.",
)
async def test_rollout_trace_with_real_mlflow_backend():
    """Integration test with a real mlflow backend.

    Skipped unless ``RUN_MLFLOW_INTEGRATION_TESTS`` is set to ``true``.
    Assumes the mlflow environment (e.g., project) is already configured.
    """
    RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="mlflow")

    instance = TracedClass()

    with rollout_trace_attr(step=1, sample_index=2, rollout_n=3, name="agent_run"):
        assert await instance.upper_method()

    # NOTE(review): the exception-path check below is disabled in the original
    # for the mlflow backend; the reason is not visible here -- confirm before
    # re-enabling.
    # with pytest.raises(ValueError, match="Test Exception"):
    #     await instance.my_method_with_exception()

    # The message names mlflow: the original was copied from the weave test
    # and wrongly reported a weave run.
    print("\nMLflow integration test ran successfully. Check your mlflow project for the trace.")
|