mirror of
https://github.com/huggingface/trl.git
synced 2025-10-20 18:43:52 +08:00
Add torch_dtype to model kwargs in reward modeling example (#2266)
Update model_kwargs to include torch_dtype.

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
This commit is contained in:
@@ -81,6 +81,7 @@ if __name__ == "__main__":
|
||||
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
||||
quantization_config=quantization_config,
|
||||
use_cache=False if training_args.gradient_checkpointing else True,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True
|
||||
|
Reference in New Issue
Block a user