Add torch_dtype to model kwargs in reward modeling example (#2266)

Update model_kwargs to include torch_dtype.

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
This commit is contained in:
Cameron Chen
2024-10-24 14:12:26 -04:00
committed by GitHub
parent 9c376c571f
commit c2bb1eed14

View File

@ -81,6 +81,7 @@ if __name__ == "__main__":
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
use_cache=False if training_args.gradient_checkpointing else True,
torch_dtype=torch_dtype,
)
tokenizer = AutoTokenizer.from_pretrained(
model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True