🪪 Adds a more fine-grained profiling context (#2975)

* adds a more fine grained profiling context

* precommit

* fix reward func name

* add reward to RM name

* Update trl/extras/profiling.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* some doc and fixes

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
This commit is contained in:
Edward Beeching
2025-02-27 21:58:39 +01:00
committed by GitHub
parent c0854c32c9
commit ac327d5e84
4 changed files with 98 additions and 31 deletions

View File

@ -107,4 +107,6 @@
title: Text Environments
- local: script_utils
title: Script Utilities
- local: others
title: Others
title: API

9
docs/source/others.md Normal file
View File

@ -0,0 +1,9 @@
# Other
## profiling_decorator
[[autodoc]] extras.profiling.profiling_decorator
## profiling_context
[[autodoc]] extras.profiling.profiling_context

View File

@ -12,30 +12,78 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import functools
import time
from typing import Generator
from transformers import is_wandb_available
from transformers import Trainer, is_wandb_available
if is_wandb_available():
import wandb
def profiling_decorator(func):
@contextlib.contextmanager
def profiling_context(trainer: Trainer, name: str) -> Generator[None, None, None]:
"""
Decorator to profile a function and log the time taken to execute it.
A context manager function for profiling a block of code. Results are logged to Weights & Biases if enabled.
Args:
trainer (`~transformers.Trainer`):
Trainer object.
name (`str`):
Name of the block to be profiled. Used as a key in the logged dictionary.
Example:
```python
from transformers import Trainer
from trl.extras.profiling import profiling_context
class MyTrainer(Trainer):
def some_method(self):
A = np.random.rand(1000, 1000)
B = np.random.rand(1000, 1000)
with profiling_context(self, "matrix_multiplication"):
# Code to profile: simulate a computationally expensive operation
result = A @ B # Matrix multiplication
```
"""
start_time = time.perf_counter()
yield
end_time = time.perf_counter()
duration = end_time - start_time
if "wandb" in trainer.args.report_to and wandb.run is not None and trainer.accelerator.is_main_process:
wandb.log({f"profiling/Time taken: {trainer.__class__.__name__}.{name}": duration})
def profiling_decorator(func: callable) -> callable:
"""
Decorator to profile a function and log execution time using [`extras.profiling.profiling_context`].
Args:
func (`callable`):
Function to be profiled.
Example:
```python
from transformers import Trainer
from trl.extras.profiling import profiling_decorator
class MyTrainer(Trainer):
@profiling_decorator
def some_method(self):
A = np.random.rand(1000, 1000)
B = np.random.rand(1000, 1000)
# Code to profile: simulate a computationally expensive operation
result = A @ B
```
"""
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
start_time = time.perf_counter()
result = func(self, *args, **kwargs)
end_time = time.perf_counter()
duration = end_time - start_time
if "wandb" in self.args.report_to and wandb.run is not None and self.accelerator.is_main_process:
wandb.log({f"profiling/Time taken: {self.__class__.__name__}.{func.__name__}": duration})
return result
with profiling_context(self, func.__name__):
return func(self, *args, **kwargs)
return wrapper

View File

@ -46,7 +46,7 @@ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.utils import is_peft_available
from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from ..extras.profiling import profiling_decorator
from ..extras.profiling import profiling_context, profiling_decorator
from ..import_utils import is_rich_available, is_vllm_available
from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
from .callbacks import SyncRefModelCallback
@ -729,6 +729,7 @@ class GRPOTrainer(Trainer):
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
# prompt individually.
ordered_set_of_prompts = list(dict.fromkeys(all_prompts_text))
with profiling_context(self, "vLLM.generate"):
all_outputs = self.llm.generate(
ordered_set_of_prompts, sampling_params=self.sampling_params, use_tqdm=False
)
@ -812,6 +813,13 @@ class GRPOTrainer(Trainer):
zip(self.reward_funcs, self.reward_processing_classes)
):
if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models
reward_func_name = f"reward {reward_func.config._name_or_path.split('/')[-1]}"
else:
reward_func_name = reward_func.__name__
with profiling_context(self, reward_func_name):
if isinstance(
reward_func, nn.Module
): # Module instead of PretrainedModel for compat with compiled models
if is_conversational(inputs[0]):
messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]