Tracking of (trainable) parameters for MetaMathQA (#2598)

This change adds tracking for the number of (trainable) parameters for each experiment.

Tracking the number of parameters, trainable and total, will make the results
much more transparent regarding model capacity. If a method were accidentally
trained with many more or fewer trainable parameters, it would make for unfair
results. Having these numbers will also make benchmarking parameter efficiency
easier.
This commit is contained in:
githubnemo
2025-06-19 18:08:25 +02:00
committed by GitHub
parent 4721213828
commit 179e29a756
2 changed files with 6 additions and 0 deletions

View File

@ -343,6 +343,8 @@ def train(
losses=losses,
metrics=metrics,
error_msg=error_msg,
num_trainable_params=num_trainable_params,
num_total_params=num_params,
)
return train_result

View File

@ -547,6 +547,8 @@ class TrainResult:
losses: list[float]
metrics: list[Any] # TODO
error_msg: str
num_trainable_params: int
num_total_params: int
def log_to_console(log_data: dict[str, Any], print_fn: Callable[..., None]) -> None:
@ -648,6 +650,8 @@ def log_results(
"cuda_memory_reserved_99th": cuda_memory_reserved_99th,
"train_time": train_result.train_time,
"file_size": file_size,
"num_trainable_params": train_result.num_trainable_params,
"num_total_params": train_result.num_total_params,
"status": train_result.status.value,
"metrics": train_result.metrics,
},