* trackio

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

* seven -> eight

* Add trackio as a real tracker instead

* Sort

* Style

* Style

* Remove step

* Disable trackio on Python < 3.10

* Update src/accelerate/tracking.py

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* More style

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
Pedro Cuenca
2025-07-15 17:17:49 +02:00
committed by GitHub
parent 847ae58c74
commit e2cc537db8
10 changed files with 128 additions and 3 deletions

View File

@ -29,6 +29,11 @@ rendered properly in your Markdown viewer.
[[autodoc]] tracking.WandBTracker
- __init__
## TrackioTracker
[[autodoc]] tracking.TrackioTracker
- __init__
## CometMLTracker
[[autodoc]] tracking.CometMLTracker

View File

@ -20,10 +20,11 @@ Accelerate provides a general tracking API that can be used to log useful items
## Integrated Trackers
Currently `Accelerate` supports seven trackers out-of-the-box:
Currently `Accelerate` supports eight trackers out-of-the-box:
- TensorBoard
- WandB
- WandB
- Trackio
- CometML
- Aim
- MLFlow

View File

@ -49,6 +49,7 @@ extras["test_trackers"] = [
"mlflow",
"matplotlib",
"swanlab",
"trackio",
]
extras["dev"] = extras["quality"] + extras["testing"] + extras["rich"]

View File

@ -233,7 +233,11 @@ class Accelerator:
- `"all"`
- `"tensorboard"`
- `"wandb"`
- `"trackio"`
- `"aim"`
- `"comet_ml"`
- `"mlflow"`
- `"dvclive"`
- `"swanlab"`
If `"all"` is selected, will pick up all available trackers in the environment and initialize them. Can
also accept implementations of `GeneralTracker` for custom trackers, and can be combined with `"all"`.

View File

@ -69,6 +69,7 @@ from ..utils import (
is_torchao_available,
is_torchdata_stateful_dataloader_available,
is_torchvision_available,
is_trackio_available,
is_transformer_engine_available,
is_transformers_available,
is_triton_available,
@ -459,6 +460,13 @@ def require_wandb(test_case):
return unittest.skipUnless(is_wandb_available(), "test requires wandb")(test_case)
def require_trackio(test_case):
    """
    Decorator for tests that need trackio. The wrapped test is skipped, with the reason "test requires trackio",
    whenever `is_trackio_available()` returns a falsy value.
    """
    skip_without_trackio = unittest.skipUnless(is_trackio_available(), "test requires trackio")
    return skip_without_trackio(test_case)
def require_comet_ml(test_case):
"""
Decorator marking a test that requires comet_ml installed. These tests are skipped when comet_ml isn't installed
@ -548,7 +556,8 @@ def require_matplotlib(test_case):
_atleast_one_tracker_available = (
any([is_wandb_available(), is_tensorboard_available(), is_swanlab_available()]) and not is_comet_ml_available()
any([is_wandb_available(), is_tensorboard_available(), is_trackio_available(), is_swanlab_available()])
and not is_comet_ml_available()
)

View File

@ -36,6 +36,7 @@ from .utils import (
is_mlflow_available,
is_swanlab_available,
is_tensorboard_available,
is_trackio_available,
is_wandb_available,
listify,
)
@ -67,6 +68,9 @@ if is_dvclive_available():
if is_swanlab_available():
_available_trackers.append(LoggerType.SWANLAB)
if is_trackio_available():
_available_trackers.append(LoggerType.TRACKIO)
logger = get_logger(__name__)
@ -415,6 +419,83 @@ class WandBTracker(GeneralTracker):
logger.debug("WandB run closed")
class TrackioTracker(GeneralTracker):
    """
    A `Tracker` class that supports `trackio`. Should be initialized at the start of your script.

    Args:
        run_name (`str`):
            The name of the experiment run. Will be used as the `project` name when instantiating trackio.
        **kwargs (additional keyword arguments, *optional*):
            Additional keyword arguments passed along to the `trackio.init` method. Refer to this
            [init](https://github.com/gradio-app/trackio/blob/814809552310468b13f84f33764f1369b4e5136c/trackio/__init__.py#L22)
            to see all supported keyword arguments.
    """

    name = "trackio"
    requires_logging_directory = False
    main_process_only = False

    def __init__(self, run_name: str, **kwargs):
        super().__init__()
        # Only record the arguments here; `trackio.init` is not called until `start()`.
        self.run_name = run_name
        self.init_kwargs = kwargs

    @on_main_process
    def start(self):
        """
        Calls `trackio.init` with `run_name` as the `project` and the stored keyword arguments, and keeps the
        returned object on `self.run`.
        """
        # Imported inside the method rather than at module level (presumably so trackio stays optional).
        import trackio

        self.run = trackio.init(project=self.run_name, **self.init_kwargs)
        logger.debug(f"Initialized trackio project {self.run_name}")
        logger.debug(
            "Make sure to log any initial configurations with `self.store_init_configuration` before training!"
        )

    @property
    def tracker(self):
        """
        The object returned by `trackio.init`. `self.run` is assigned in `start()`; this class's `__init__` does
        not set it.
        """
        return self.run

    @on_main_process
    def store_init_configuration(self, values: dict):
        """
        Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.

        Args:
            values (Dictionary `str` to `bool`, `str`, `float` or `int`):
                Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
                `str`, `float`, `int`, or `None`.
        """
        import trackio

        # The configuration is updated through the module-level `trackio.config`, not through `self.run`.
        trackio.config.update(values, allow_val_change=True)
        logger.debug("Stored initial configuration hyperparameters to trackio")

    @on_main_process
    def log(self, values: dict, step: Optional[int] = None, **kwargs):
        """
        Logs `values` to the current run.

        Args:
            values (Dictionary `str` to `str`, `float`, `int` or `dict` of `str` to `float`/`int`):
                Values to be logged as key-value pairs. The values need to have type `str`, `float`, `int` or `dict` of
                `str` to `float`/`int`.
            step (`int`, *optional*):
                Accepted as part of the signature but currently unused: it is not forwarded to `self.run.log`, so
                passing it has no effect on the logged entry.
            kwargs:
                Additional keyword arguments passed along to `self.run.log`.
        """
        # NOTE: `step` is intentionally absent from this call; only `values` and `kwargs` are forwarded.
        self.run.log(values, **kwargs)
        logger.debug("Successfully logged to trackio")

    @on_main_process
    def finish(self):
        """
        Closes `trackio` run
        """
        self.run.finish()
        logger.debug("trackio run closed")
class CometMLTracker(GeneralTracker):
"""
A `Tracker` class that supports `comet_ml`. Should be initialized at the start of your script.
@ -1174,6 +1255,7 @@ LOGGER_TYPE_TO_CLASS = {
"clearml": ClearMLTracker,
"dvclive": DVCLiveTracker,
"swanlab": SwanLabTracker,
"trackio": TrackioTracker,
}
@ -1195,6 +1277,8 @@ def filter_trackers(
- `"all"`
- `"tensorboard"`
- `"wandb"`
- `"trackio"`
- `"aim"`
- `"comet_ml"`
- `"mlflow"`
- `"dvclive"`

View File

@ -129,6 +129,7 @@ from .imports import (
is_torchdata_available,
is_torchdata_stateful_dataloader_available,
is_torchvision_available,
is_trackio_available,
is_transformer_engine_available,
is_transformers_available,
is_triton_available,

View File

@ -701,7 +701,10 @@ class LoggerType(BaseEnum):
- **ALL** -- all available trackers in the environment that are supported
- **TENSORBOARD** -- TensorBoard as an experiment tracker
- **WANDB** -- wandb as an experiment tracker
- **TRACKIO** -- trackio as an experiment tracker
- **COMETML** -- comet_ml as an experiment tracker
- **MLFLOW** -- mlflow as an experiment tracker
- **CLEARML** -- clearml as an experiment tracker
- **DVCLIVE** -- dvclive as an experiment tracker
- **SWANLAB** -- swanlab as an experiment tracker
"""
@ -710,6 +713,7 @@ class LoggerType(BaseEnum):
AIM = "aim"
TENSORBOARD = "tensorboard"
WANDB = "wandb"
TRACKIO = "trackio"
COMETML = "comet_ml"
MLFLOW = "mlflow"
CLEARML = "clearml"

View File

@ -15,6 +15,7 @@
import importlib
import importlib.metadata
import os
import sys
import warnings
from functools import lru_cache, wraps
@ -285,6 +286,10 @@ def is_swanlab_available():
return _is_package_available("swanlab")
def is_trackio_available():
    """
    Report whether trackio can be used: always `False` on Python older than 3.10, otherwise whatever
    `_is_package_available("trackio")` returns.
    """
    if sys.version_info < (3, 10):
        return False
    return _is_package_available("trackio")
def is_boto3_available():
    """Return the result of `_is_package_available("boto3")`."""
    boto3_found = _is_package_available("boto3")
    return boto3_found

View File

@ -45,6 +45,7 @@ from accelerate.test_utils.testing import (
require_pandas,
require_swanlab,
require_tensorboard,
require_trackio,
require_wandb,
skip,
)
@ -57,6 +58,7 @@ from accelerate.tracking import (
MLflowTracker,
SwanLabTracker,
TensorBoardTracker,
TrackioTracker,
WandBTracker,
)
from accelerate.utils import (
@ -801,6 +803,15 @@ class TrackerDeferredInitializationTest(unittest.TestCase):
_ = Accelerator(log_with=tracker)
self.assertNotEqual(PartialState._shared_state, {})
@require_trackio
def test_trackio_deferred_init(self):
    """Constructing a `TrackioTracker` must leave `PartialState` empty; building the `Accelerator` fills it."""
    # Start from a clean slate so state from earlier tests cannot leak in.
    PartialState._reset_state()

    trackio_tracker = TrackioTracker(run_name="test_trackio")
    # Creating the tracker on its own should not have touched the shared state.
    self.assertEqual(PartialState._shared_state, {})

    # Handing the tracker to an Accelerator is what populates the shared state.
    _ = Accelerator(log_with=trackio_tracker)
    self.assertNotEqual(PartialState._shared_state, {})
@require_comet_ml
def test_comet_ml_deferred_init(self):
"""Test that CometML tracker initialization doesn't initialize distributed"""