Mirror of https://github.com/huggingface/trl.git
🪪 Adds a more fine-grained profiling context (#2975)
* adds a more fine grained profiling context
* precommit
* fix reward func name
* add reward to RM name
* Update trl/extras/profiling.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* some doc and fixes

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
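In short, the timing-and-logging logic previously inlined in `profiling_decorator` moves into a reusable `profiling_context` context manager, so individual blocks (the vLLM generation call, each reward function) can be profiled on their own, not just whole methods. A minimal sketch of the usage pattern this enables, assuming wandb is installed and enabled (`MyTrainer`, `some_method`, and `"inner_block"` are illustrative names, not part of the diff):

    from transformers import Trainer
    from trl.extras.profiling import profiling_context, profiling_decorator

    class MyTrainer(Trainer):
        @profiling_decorator  # logs the duration of the whole method
        def some_method(self):
            with profiling_context(self, "inner_block"):  # logs just this block
                pass  # code to profile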
docs/source/_toctree.yml
@@ -107,4 +107,6 @@
     title: Text Environments
   - local: script_utils
     title: Script Utilities
+  - local: others
+    title: Others
   title: API
docs/source/others.md (new file, 9 lines)
@@ -0,0 +1,9 @@
+# Other
+
+## profiling_decorator
+
+[[autodoc]] extras.profiling.profiling_decorator
+
+## profiling_context
+
+[[autodoc]] extras.profiling.profiling_context
trl/extras/profiling.py
@@ -12,30 +12,78 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import contextlib
 import functools
 import time
+from typing import Generator

-from transformers import is_wandb_available
+from transformers import Trainer, is_wandb_available


 if is_wandb_available():
     import wandb


-def profiling_decorator(func):
+@contextlib.contextmanager
+def profiling_context(trainer: Trainer, name: str) -> Generator[None, None, None]:
     """
-    Decorator to profile a function and log the time taken to execute it.
+    A context manager function for profiling a block of code. Results are logged to Weights & Biases if enabled.
+
+    Args:
+        trainer (`~transformers.Trainer`):
+            Trainer object.
+        name (`str`):
+            Name of the block to be profiled. Used as a key in the logged dictionary.
+
+    Example:
+    ```python
+    from transformers import Trainer
+    from trl.extras.profiling import profiling_context
+
+    class MyTrainer(Trainer):
+        def some_method(self):
+            A = np.random.rand(1000, 1000)
+            B = np.random.rand(1000, 1000)
+            with profiling_context(self, "matrix_multiplication"):
+                # Code to profile: simulate a computationally expensive operation
+                result = A @ B  # Matrix multiplication
+    ```
     """
+    start_time = time.perf_counter()
+    yield
+    end_time = time.perf_counter()
+    duration = end_time - start_time
+
+    if "wandb" in trainer.args.report_to and wandb.run is not None and trainer.accelerator.is_main_process:
+        wandb.log({f"profiling/Time taken: {trainer.__class__.__name__}.{name}": duration})
+
+
+def profiling_decorator(func: callable) -> callable:
+    """
+    Decorator to profile a function and log execution time using [`extras.profiling.profiling_context`].
+
+    Args:
+        func (`callable`):
+            Function to be profiled.
+
+    Example:
+    ```python
+    from transformers import Trainer
+    from trl.extras.profiling import profiling_decorator
+
+    class MyTrainer(Trainer):
+        @profiling_decorator
+        def some_method(self):
+            A = np.random.rand(1000, 1000)
+            B = np.random.rand(1000, 1000)
+            # Code to profile: simulate a computationally expensive operation
+            result = A @ B
+    ```
+    """

     @functools.wraps(func)
     def wrapper(self, *args, **kwargs):
-        start_time = time.perf_counter()
-        result = func(self, *args, **kwargs)
-        end_time = time.perf_counter()
-        duration = end_time - start_time
-
-        if "wandb" in self.args.report_to and wandb.run is not None and self.accelerator.is_main_process:
-            wandb.log({f"profiling/Time taken: {self.__class__.__name__}.{func.__name__}": duration})
-        return result
+        with profiling_context(self, func.__name__):
+            return func(self, *args, **kwargs)

     return wrapper
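After the refactor, the decorator is a thin wrapper over the context manager, so timing and wandb logging live in one place. A standalone sketch of the resulting shape, runnable without transformers or wandb (`timed` and `timed_method` are illustrative stand-ins; note that the docstring examples above also presuppose `import numpy as np`):

    import contextlib
    import functools
    import time

    @contextlib.contextmanager
    def timed(name):
        # Same shape as profiling_context: the wrapped block runs at the yield point.
        start = time.perf_counter()
        yield
        print(f"{name}: {time.perf_counter() - start:.4f}s")

    def timed_method(func):
        # Same shape as the new profiling_decorator: delegate to the context manager.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            with timed(func.__name__):
                return func(*args, **kwargs)
        return wrapper

    @timed_method
    def work():
        return sum(i * i for i in range(1_000_000))

    work()  # prints e.g. "work: 0.0598s"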
trl/trainer/grpo_trainer.py
@@ -46,7 +46,7 @@ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
 from transformers.utils import is_peft_available

 from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
-from ..extras.profiling import profiling_decorator
+from ..extras.profiling import profiling_context, profiling_decorator
 from ..import_utils import is_rich_available, is_vllm_available
 from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
 from .callbacks import SyncRefModelCallback
@@ -729,6 +729,7 @@ class GRPOTrainer(Trainer):
                 # num_generations outputs for each one. This is faster than generating outputs for each duplicate
                 # prompt individually.
                 ordered_set_of_prompts = list(dict.fromkeys(all_prompts_text))
-                all_outputs = self.llm.generate(
-                    ordered_set_of_prompts, sampling_params=self.sampling_params, use_tqdm=False
-                )
+                with profiling_context(self, "vLLM.generate"):
+                    all_outputs = self.llm.generate(
+                        ordered_set_of_prompts, sampling_params=self.sampling_params, use_tqdm=False
+                    )
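Given the f-string in `profiling_context`, the metric logged for this block resolves to the key:

    profiling/Time taken: GRPOTrainer.vLLM.generate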
@@ -812,6 +813,13 @@ class GRPOTrainer(Trainer):
             zip(self.reward_funcs, self.reward_processing_classes)
         ):
-            if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
-                if is_conversational(inputs[0]):
-                    messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
-                    texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
+            if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
+                reward_func_name = f"reward {reward_func.config._name_or_path.split('/')[-1]}"
+            else:
+                reward_func_name = reward_func.__name__
+            with profiling_context(self, reward_func_name):
+                if isinstance(
+                    reward_func, nn.Module
+                ):  # Module instead of PretrainedModel for compat with compiled models
+                    if is_conversational(inputs[0]):
+                        messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
+                        texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
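The block label therefore depends on the reward's type: a reward model takes the last segment of its `_name_or_path`, a plain Python reward function its `__name__`. Illustrative values (hypothetical model path and function name, not from the diff):

    # reward_func.config._name_or_path == "my-org/my-reward-model"
    #   -> profiled as "reward my-reward-model"
    # def accuracy_reward(prompts, completions, **kwargs): ...
    #   -> profiled as "accuracy_reward"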