mirror of
				https://github.com/volcengine/verl.git
				synced 2025-10-20 21:53:50 +08:00 
			
		
		
		
	Implemented mlflow as a rollout trace backend. Compared to weave, mlflow is a lightweight solution and can be deployed on-premises easily. ### API and Usage Example docs/advance/rollout_trace.rst
		
			
				
	
	
		
			171 lines
		
	
	
		
			6.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			171 lines
		
	
	
		
			6.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright 2025 Bytedance Ltd. and/or its affiliates
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #     http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| 
 | |
| import os
 | |
| import sys
 | |
| from unittest.mock import MagicMock, patch
 | |
| 
 | |
| import pytest
 | |
| 
 | |
| from verl.utils.rollout_trace import RolloutTraceConfig, rollout_trace_attr, rollout_trace_op
 | |
| 
 | |
| 
 | |
@pytest.fixture(autouse=True)
def reset_rollout_trace_config_singleton():
    """Reset the RolloutTraceConfig singleton ahead of every test.

    Declared with ``autouse=True`` so tests get the reset without having to
    request this fixture explicitly.
    """
    RolloutTraceConfig.reset()
 | |
| 
 | |
| 
 | |
@pytest.fixture
def mock_weave_client():
    """Install a fake ``weave`` module and yield the client its ``init`` returns.

    While the fixture is active, ``sys.modules`` maps ``"weave"`` and
    ``"weave.trace.context"`` to ``MagicMock`` objects, so importing them
    never touches the real package. ``patch.dict`` restores ``sys.modules``
    when the test finishes.

    Yields:
        The ``MagicMock`` client returned by ``weave.init``; its
        ``create_call`` returns one fixed mock call object.
    """
    fake_call = MagicMock()
    fake_client = MagicMock()
    fake_client.create_call.return_value = fake_call

    fake_weave = MagicMock()
    fake_weave.init.return_value = fake_client
    # Stub call_context too, in case the decorator uses it internally.
    fake_weave.trace.context.call_context.return_value = MagicMock()

    stubbed_modules = {
        "weave": fake_weave,
        "weave.trace.context": fake_weave.trace.context,
    }
    with patch.dict(sys.modules, stubbed_modules):
        yield fake_client
 | |
| 
 | |
| 
 | |
class TracedClass:
    """Sample class whose async methods are wrapped by ``rollout_trace_op``.

    NOTE: the backend-native decorators (``@weave.op`` / ``@mlflow.trace``)
    can be swapped in by hand in place of ``@rollout_trace_op`` when
    comparing trace output; they are intentionally not applied here.
    """

    @rollout_trace_op
    async def my_method(self, a, b="default"):
        """Return a string that embeds both arguments."""
        formatted = f"result: {a}, {b}"
        return formatted

    @rollout_trace_op
    async def middle_method(self, a, b="default"):
        """Await ``my_method`` once (a nested traced call), then format ``a`` and ``b``."""
        await self.my_method("test_a1", b="test_b1")
        formatted = f"result: {a}, {b}"
        return formatted

    @rollout_trace_op
    async def my_method_with_exception(self):
        """Always raise ``ValueError("Test Exception")``."""
        raise ValueError("Test Exception")

    async def upper_method(self):
        """Undecorated entry point: await two traced methods and return ``True``."""
        await self.my_method("test_a0", b="test_b0")
        await self.middle_method("test_a2", b="test_b2")
        return True
 | |
| 
 | |
| 
 | |
class UntracedClass:
    """Minimal class with one decorated method, used when no backend is configured."""

    @rollout_trace_op
    async def my_method(self, x):
        """Return ``x * 2``."""
        doubled = x * 2
        return doubled
 | |
| 
 | |
| 
 | |
async def test_rollout_trace_on_untraced_class():
    """A decorated method still returns its result when no backend was initialised."""
    doubled = await UntracedClass().my_method(10)
    assert doubled == 20
 | |
| 
 | |
| 
 | |
async def test_rollout_trace_with_tracer(mock_weave_client):
    """With the weave backend, one call is created and finished per traced invocation."""
    RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="weave")
    traced = TracedClass()
    assert RolloutTraceConfig.get_client() is mock_weave_client

    output = await traced.my_method("test_a", b="test_b")
    assert output == "result: test_a, test_b"

    # The call must be opened exactly once, with the qualified op name and
    # the arguments bound by parameter name.
    create_call = mock_weave_client.create_call
    create_call.assert_called_once()
    recorded = create_call.call_args.kwargs
    assert recorded["op"] == "TracedClass.my_method"
    assert recorded["inputs"] == {"a": "test_a", "b": "test_b"}

    # ...and closed exactly once, with the method's return value as output.
    opened_call = create_call.return_value
    mock_weave_client.finish_call.assert_called_once_with(opened_call, output=output)
 | |
| 
 | |
| 
 | |
async def test_rollout_trace_with_exception(mock_weave_client):
    """A raising traced method propagates the error and finishes the call with it."""
    RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="weave")
    traced = TracedClass()

    with pytest.raises(ValueError, match="Test Exception"):
        await traced.my_method_with_exception()

    create_call = mock_weave_client.create_call
    finish_call = mock_weave_client.finish_call
    create_call.assert_called_once()
    opened_call = create_call.return_value
    finish_call.assert_called_once()

    # finish_call must receive the opened call positionally and the raised
    # error through the ``exception`` keyword.
    finish_args = finish_call.call_args.args
    finish_kwargs = finish_call.call_args.kwargs
    assert finish_args[0] == opened_call
    assert "exception" in finish_kwargs
    assert isinstance(finish_kwargs["exception"], ValueError)
 | |
| 
 | |
| 
 | |
async def test_rollout_trace_with_dummy_backend(mock_weave_client):
    """With ``backend="dummy"`` the weave client never has a call created on it."""
    RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="dummy")
    traced = TracedClass()

    await traced.my_method("test_a")

    mock_weave_client.create_call.assert_not_called()
 | |
| 
 | |
| 
 | |
@pytest.mark.skipif(
    os.getenv("RUN_WEAVE_INTEGRATION_TESTS", "false").lower() != "true",
    reason="Skipping weave integration test. Set RUN_WEAVE_INTEGRATION_TESTS=true to run.",
)
async def test_rollout_trace_with_real_weave_backend():
    """Opt-in integration test that runs the traced methods against real weave.

    Skipped unless the ``RUN_WEAVE_INTEGRATION_TESTS`` environment variable
    is set to ``true`` (case-insensitive). The weave environment (e.g. the
    project) is assumed to be configured already.
    """
    RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="weave")

    traced = TracedClass()

    # Nested traced calls, run with rollout attributes attached.
    with rollout_trace_attr(step=1, sample_index=2, rollout_n=3):
        await traced.upper_method()

    # The error from a traced method must still reach the caller.
    with pytest.raises(ValueError, match="Test Exception"):
        await traced.my_method_with_exception()

    print("\nWeave integration test ran successfully. Check your weave project for the trace.")
 | |
| 
 | |
| 
 | |
@pytest.mark.skipif(
    os.environ.get("RUN_MLFLOW_INTEGRATION_TESTS", "false").lower() != "true",
    reason="Skipping mlflow integration test. Set RUN_MLFLOW_INTEGRATION_TESTS=true to run.",
)
async def test_rollout_trace_with_real_mlflow_backend():
    """Opt-in integration test that runs the traced methods against real mlflow.

    Skipped unless the ``RUN_MLFLOW_INTEGRATION_TESTS`` environment variable
    is set to ``true`` (case-insensitive). The mlflow environment is assumed
    to be configured already.
    """
    RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="mlflow")

    instance = TracedClass()

    # upper_method returns True after awaiting its nested traced calls.
    with rollout_trace_attr(step=1, sample_index=2, rollout_n=3, name="agent_run"):
        assert await instance.upper_method()

    # NOTE(review): the exception-path check below is disabled for mlflow and
    # the reason is not recorded here; confirm before re-enabling.
    # with pytest.raises(ValueError, match="Test Exception"):
    #     await instance.my_method_with_exception()

    # Name the backend that actually ran; the previous message was copied from
    # the weave test and pointed readers at the wrong place.
    print("\nMLflow integration test ran successfully. Check your mlflow tracking server for the trace.")
 |