[rollout] feat: support mlflow in rollout trace (#2440)

Implemented mlflow as rollout trace backend. Comparing to weave, mlflow
is a lite weight solution and can be deployed on-premises easily.

### API and Usage Example

docs/advance/rollout_trace.rst
This commit is contained in:
OC
2025-07-15 05:18:40 +08:00
committed by GitHub
parent 141b1d3251
commit def5b28e3d
12 changed files with 248 additions and 60 deletions

View File

@ -0,0 +1,125 @@
Trace Function Usage Instructions
========================================
Last updated: 07/10/2025.
Applicable Scenarios
--------------------
Agentic RL involves multiple turns of conversations, tool invocations, and user interactions during the rollout process. During the Model Training process, it is necessary to track function calls, inputs, and outputs to understand the flow path of data within the application. The Trace feature helps, in complex multi-round conversations, to view the transformation of data during each interaction and the entire process leading to the final output by recording the inputs, outputs, and corresponding timestamps of functions, which is conducive to understanding the details of how the model processes data and optimizing the training results.
The Trace feature integrates commonly used Agent trace tools, including wandb weave and mlflow, which are already supported. Users can choose the appropriate trace tool according to their own needs and preferences. Here, we introduce the usage of each tool.
Trace Parameter Configuration
-----------------------------
- ``actor_rollout_ref.rollout.trace.backend=mlflow|weave`` # the trace backend type
- ``actor_rollout_ref.rollout.trace.token2text=True`` # To show decoded text in trace view
Glossary
--------
+----------------+------------------------------------------------------------------------------------------------------+
| Object | Explaination |
+================+======================================================================================================+
| trajectory | A complete multi-turn conversation includes: |
| | 1. LLM output at least once |
| | 2. Tool Call |
+----------------+------------------------------------------------------------------------------------------------------+
| step | The training step corresponds to the global_steps variable in the trainer |
+----------------+------------------------------------------------------------------------------------------------------+
| sample_index | The identifier of the sample, defined in the extra_info.index of the dataset. It is usually a number,|
| | but may also be a uuid in some cases. |
+----------------+------------------------------------------------------------------------------------------------------+
| rollout_n | In the GROP algorithm, each sample is rolled out n times. rollout_n represents the serial number of |
| | the rollout. |
+----------------+------------------------------------------------------------------------------------------------------+
| validate | Whether the test dataset is used for evaluation? |
+----------------+------------------------------------------------------------------------------------------------------+
Rollout trace functions
-----------------------
There are 2 functions used for tracing:
1. ``rollout_trace_op``: This is a decorator function used to mark the functions to trace. In default, only few method has it, you can add it to more functions to trace more infor.
2. ``rollout_trace_attr``: This function is used to mark the entry of a trajectory and input some info to trace. If you add new type of agent, you may need to add it to enable trace.
Usage of wandb weave
--------------------
1.1 Basic Configuration
~~~~~~~~~~~~~~~~~~~~~~~
1. Set the ``WANDB_API_KEY`` environment variable
2. Configuration Parameters
1. ``actor_rollout_ref.rollout.trace.backend=weave``
2. ``trainer.logger=['console', 'wandb']``: This item is optional. Trace and logger are independent functions. When using Weave, it is recommended to also enable the wandb logger to implement both functions in one system.
3. ``trainer.project_name=$project_name``
4. ``trainer.experiment_name=$experiment_name``
5. ``actor_rollout_ref.rollout.mode=async``: Since trace is mainly used for agentic RL, need to enable agent toop using async mode for either vllm or sglang.
Note:
The Weave Free Plan comes with a default monthly network traffic allowance of 1GB. During the training process, the amount of trace data generated is substantial, reaching dozens of gigabytes per day, so it is necessary to select an appropriate wandb plan.
1.2 View Trace Logs
~~~~~~~~~~~~~~~~~~~
After executing the training, on the project page, you can see the WEAVE sidebar. Click Traces to view it.
Each Trace project corresponds to a trajectory. You can filter and select the trajectories you need to view by step, sample_index, rollout_n, and experiment_name.
After enabling token2text, prompt_text and response_text will be automatically added to the output of ToolAgentLoop.run, making it convenient to view the input and output content.
.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/weave_trace_list.png?raw=true
1.3 Compare Trace Logs
~~~~~~~~~~~~~~~~~~~~~~
Weave can select multiple trace items and then compare the differences among them.
.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/weave_trace_compare.png?raw=true
Usage of mlflow
---------------
1. Basic Configuration
~~~~~~~~~~~~~~~~~~~~~~
1. Set the ``MLFLOW_TRACKING_URI`` environment variable, which can be:
1. Http and https URLs corresponding to online services
2. Local files or directories, such as ``sqlite:////tmp/mlruns.db``, indicate that data is stored in ``/tmp/mlruns.db``. When using local files, it is necessary to initialize the file first (e.g., start the UI: ``mlflow ui --backend-store-uri sqlite:////tmp/mlruns.db``) to avoid conflicts when multiple workers create files simultaneously.
2. Configuration Parameters
1. ``actor_rollout_ref.rollout.trace.backend=mlflow``
2. ``trainer.logger=['console', 'mlflow']``. This item is optional. Trace and logger are independent functions. When using mlflow, it is recommended to also enable the mlflow logger to implement both functions in one system.
3. ``trainer.project_name=$project_name``
4. ``trainer.experiment_name=$experiment_name``
2. View Log
~~~~~~~~~~~
Since ``trainer.project_name`` corresponds to Experiments in mlflow, in the mlflow view, you need to select the corresponding project name, then click the "Traces" tab to view traces. Among them, ``trainer.experiment_name`` corresponds to the experiment_name of tags, and tags corresponding to step, sample_index, rollout_n, etc., are used for filtering and viewing.
For example, searching for ``"tags.step = '1'"`` can display all trajectories of step 1.
.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/mlflow_trace_list.png?raw=true
Opening one of the trajectories allows you to view each function call process within it.
After enabling token2text, prompt_text and response_text will be automatically added to the output of ToolAgentLoop.run, making it convenient to view the content.
.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/mlflow_trace_view.png?raw=true
Note:
1. mlflow does not support comparing multiple traces
2. rollout_trace can not associate the mlflow trace with the run, so the trace content cannot be seen in the mlflow run logs.

View File

@ -177,21 +177,3 @@ Comparing to using global https_proxy env variable, this approach won't mess up
+trainer.wandb_proxy=http://<your proxy and port>
How to trace rollout and toolcall data?
---------------------------------------
To enable trace rollout data, you can set the config term ``trainer.rollout_trace.backend`` to a backend name.
For example:
.. code:: bash
+trainer.rollout_trace.backend=weave # only wandb weave is support now
To show decoded text in trace view, you can set the config term ``trainer.rollout_trace.token2text`` to True.
For example:
.. code:: bash
+trainer.rollout_trace.token2text=True # default to False for better performance

View File

@ -114,6 +114,7 @@ verl is fast with:
advance/placement
advance/dpo_extension
examples/sandbox_fusion_example
advance/rollout_trace.rst
.. toctree::
:maxdepth: 1

View File

@ -259,12 +259,12 @@ async def test_get_trajectory_info():
step = 10
index = [1, 1, 3, 3]
expected_info = [
{"step": step, "sample_index": 1, "rollout_n": 0},
{"step": step, "sample_index": 1, "rollout_n": 1},
{"step": step, "sample_index": 3, "rollout_n": 0},
{"step": step, "sample_index": 3, "rollout_n": 1},
{"step": step, "sample_index": 1, "rollout_n": 0, "validate": False},
{"step": step, "sample_index": 1, "rollout_n": 1, "validate": False},
{"step": step, "sample_index": 3, "rollout_n": 0, "validate": False},
{"step": step, "sample_index": 3, "rollout_n": 1, "validate": False},
]
trajectory_info = await get_trajectory_info(step, index)
trajectory_info = await get_trajectory_info(step, index, validate=False)
assert trajectory_info == expected_info

View File

@ -46,22 +46,26 @@ def mock_weave_client():
class TracedClass:
@rollout_trace_op
# @weave.op
# @mlflow.trace
async def my_method(self, a, b="default"):
return f"result: {a}, {b}"
@rollout_trace_op
# @weave.op
# @mlflow.trace
async def middle_method(self, a, b="default"):
await self.my_method("test_a1", b="test_b1")
return f"result: {a}, {b}"
@rollout_trace_op
# @mlflow.trace
async def my_method_with_exception(self):
raise ValueError("Test Exception")
async def upper_method(self):
await self.my_method("test_a0", b="test_b0")
await self.middle_method("test_a2", b="test_b2")
return True
class UntracedClass:
@ -143,3 +147,24 @@ async def test_rollout_trace_with_real_weave_backend():
await instance.my_method_with_exception()
print("\nWeave integration test ran successfully. Check your weave project for the trace.")
@pytest.mark.skipif(
os.environ.get("RUN_MLFLOW_INTEGRATION_TESTS", "false").lower() != "true",
reason="Skipping mlflow integration test. Set RUN_MLFLOW_INTEGRATION_TESTS=true to run.",
)
async def test_rollout_trace_with_real_mlflow_backend():
"""Integration test with a real mlflow backend."""
# This assumes that the mlflow environment (e.g., project) is configured
RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="mlflow")
instance = TracedClass()
with rollout_trace_attr(step=1, sample_index=2, rollout_n=3, name="agent_run"):
assert await instance.upper_method()
# with pytest.raises(ValueError, match="Test Exception"):
# await instance.my_method_with_exception()
print("\nWeave integration test ran successfully. Check your weave project for the trace.")

View File

@ -69,6 +69,7 @@ def test_sglang_spmd():
mem_fraction_static=0.5,
enable_memory_saver=True,
tp_size=inference_device_mesh_cpu["tp"].size(),
attention_backend="fa3",
)
input_ids = input_ids.cuda()

View File

@ -180,11 +180,10 @@ class AgentLoopWorker:
local_path = copy_to_local(config.actor_rollout_ref.model.path)
self.tokenizer = hf_tokenizer(local_path, trust_remote_code=True)
trace_config = config.trainer.get("rollout_trace", {})
trace_config = self.config.actor_rollout_ref.rollout.get("trace", {})
RolloutTraceConfig.init(
config.trainer.project_name,
config.trainer.experiment_name,
self.config.trainer.project_name,
self.config.trainer.experiment_name,
trace_config.get("backend"),
trace_config.get("token2text", False),
)
@ -234,7 +233,9 @@ class AgentLoopWorker:
else:
index = np.arange(len(raw_prompts))
trajectory_info = await get_trajectory_info(batch.meta_info.get("global_steps", -1), index)
trajectory_info = await get_trajectory_info(
batch.meta_info.get("global_steps", -1), index, batch.meta_info.get("validate", False)
)
for agent_name, messages, trajectory in zip(agent_names, raw_prompts, trajectory_info, strict=True):
tasks.append(
@ -253,7 +254,11 @@ class AgentLoopWorker:
trajectory: dict[str, Any],
) -> AgentLoopOutput:
with rollout_trace_attr(
step=trajectory["step"], sample_index=trajectory["sample_index"], rollout_n=trajectory["rollout_n"]
step=trajectory["step"],
sample_index=trajectory["sample_index"],
rollout_n=trajectory["rollout_n"],
validate=trajectory["validate"],
name="agent_loop",
):
agent_loop_class = self.get_agent_loop_class(agent_name)
agent_loop = agent_loop_class(self.config, self.server_manager, self.tokenizer)
@ -350,8 +355,17 @@ class AgentLoopWorker:
return DataProto(batch=batch, non_tensor_batch={"__num_turns__": num_turns}, meta_info={"metrics": metrics})
async def get_trajectory_info(step, index):
"""Get the trajectory info (step, sample_index, rollout_n) asynchrously"""
async def get_trajectory_info(step, index, validate):
"""Get trajectory info.
Args:
step (int): global steps in the trainer.
index (list): form datastore extra_info.index column.
validate (bool): whether is a validate step.
Returns:
list: trajectory.
"""
trajectory_info = []
rollout_n = 0
for i in range(len(index)):
@ -359,7 +373,7 @@ async def get_trajectory_info(step, index):
rollout_n += 1
else:
rollout_n = 0
trajectory_info.append({"step": step, "sample_index": index[i], "rollout_n": rollout_n})
trajectory_info.append({"step": step, "sample_index": index[i], "rollout_n": rollout_n, "validate": validate})
return trajectory_info

View File

@ -144,7 +144,6 @@ trainer:
capture-range: "cudaProfilerApi"
capture-range-end: null
kill: none
ray_init:
num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then.
timeline_json_file: null

View File

@ -178,3 +178,12 @@ agent:
# Class name of the custom async server class (e.g. AsyncvLLMServer)
name: null
# trace rollout data
trace:
# trace backend, support mlflow, weave
backend: null
# whether translate token id to text in output
token2text: False

View File

@ -738,6 +738,7 @@ class RayPPOTrainer:
"recompute_log_prob": False,
"do_sample": self.config.actor_rollout_ref.rollout.val_kwargs.do_sample,
"validate": True,
"global_steps": self.global_steps,
}
print(f"test_gen_batch meta info: {test_gen_batch.meta_info}")

View File

@ -16,6 +16,7 @@ import asyncio
import contextlib
import functools
import inspect
import os
from typing import Optional
@ -24,10 +25,14 @@ class RolloutTraceConfig:
backend: Optional[str] = None
client: Optional[object] = None
token2text: bool = False
_initialized: bool = False
project_name: str = None
experiment_name: str = None
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
@classmethod
@ -39,10 +44,14 @@ class RolloutTraceConfig:
@classmethod
def init(cls, project_name: str, experiment_name: str, backend: str, token2text: bool = False):
config = cls.get_instance()
if config._initialized:
return
config.backend = backend
config.token2text = token2text
config.project_name = project_name
config.experiment_name = experiment_name
if backend == "weave":
import weave
@ -50,10 +59,18 @@ class RolloutTraceConfig:
elif backend == "mlflow":
import mlflow
mlflow.config.enable_async_logging()
config.client = mlflow
MLFLOW_TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI", "sqlite:////tmp/mlruns.db")
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
mlflow.set_experiment(project_name)
else:
config.client = None
config._initialized = True
@classmethod
def get_backend(cls) -> Optional[str]:
return cls.get_instance().backend
@ -72,17 +89,19 @@ class RolloutTraceConfig:
@contextlib.contextmanager
def rollout_trace_attr(sample_index=None, step=None, rollout_n=None):
def rollout_trace_attr(sample_index=None, step=None, rollout_n=None, name="rollout_trace", validate=False):
"""A context manager to add attributes to a trace for the configured backend."""
backend = RolloutTraceConfig.get_backend()
attributes = {}
if sample_index is not None:
attributes["sample_index"] = sample_index
if step is not None:
attributes["step"] = step
if rollout_n is not None:
attributes["rollout_n"] = rollout_n
attributes["experiment_name"] = RolloutTraceConfig.get_instance().experiment_name
if backend:
if sample_index is not None:
attributes["sample_index"] = sample_index
if step is not None:
attributes["step"] = step
if rollout_n is not None:
attributes["rollout_n"] = rollout_n
attributes["validate"] = validate
attributes["experiment_name"] = RolloutTraceConfig.get_instance().experiment_name
if not attributes or backend is None:
yield
@ -93,17 +112,14 @@ def rollout_trace_attr(sample_index=None, step=None, rollout_n=None):
with weave.attributes(attributes):
yield
# TODO implement mlfow trace
# elif backend == "mlflow":
# import mlflow
# # This assumes a run is already active.
# # We are setting tags for the current active run.
# try:
# mlflow.set_tags(attributes)
# except Exception:
# # Silently fail if there is no active run.
# pass
# yield
elif backend == "mlflow":
import mlflow
with mlflow.start_span(name=name) as span:
trace_id = span.trace_id
for key, value in attributes.items():
mlflow.set_trace_tag(trace_id, str(key), str(value))
yield
else:
yield
@ -124,15 +140,15 @@ def rollout_trace_op(func):
async def add_token2text(self, result):
if hasattr(result, "prompt_ids") and hasattr(self, "tokenizer") and hasattr(self.tokenizer, "decode"):
_result = [result]
_result = vars(result)
loop = asyncio.get_running_loop()
if hasattr(result, "prompt_ids"):
prompt_text = await loop.run_in_executor(None, self.tokenizer.decode, result.prompt_ids)
_result.append(prompt_text)
_result["prompt_text"] = prompt_text
if hasattr(result, "response_ids"):
response_text = await loop.run_in_executor(None, self.tokenizer.decode, result.response_ids)
_result.append(response_text)
_result["response_text"] = response_text
return _result
return result
@ -156,7 +172,20 @@ def rollout_trace_op(func):
except Exception as e:
tracer.finish_call(call, exception=e)
raise e
# TODO implement other backends such as mlflow
elif backend == "mlflow":
import mlflow
with mlflow.start_span(name=func.__qualname__) as span:
span.set_inputs(inputs)
result = await func(self, *args, **kwargs)
if enable_token2text:
_result = await add_token2text(self, result)
span.set_outputs(_result)
else:
span.set_outputs(result)
return result
else:
return await func(self, *args, **kwargs)
@ -185,7 +214,10 @@ def rollout_trace_op(func):
except Exception as e:
tracer.finish_call(call, exception=e)
raise e
# TODO implement other backends such as mlflow
elif backend == "mlflow":
import mlflow
return mlflow.trace(func)(self, *args, **kwargs)
else:
return func(self, *args, **kwargs)

View File

@ -63,9 +63,8 @@ class Tracking:
import mlflow
MLFLOW_TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI", None)
if MLFLOW_TRACKING_URI:
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
MLFLOW_TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI", "sqlite:////tmp/mlruns.db")
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
# Project_name is actually experiment_name in MLFlow
# If experiment does not exist, will create a new experiment