Integrate OrpoTrainer with PyTorchXLA for faster step time on TPUs (#2001)

* make OrpoTrainer run faster on TPU

* less data transfer

* train-trl.py

* fix

* set device_map=auto

* add is_torch_xla_available guards

* delete file

* address comments

* make presubmit

* Update transformers version in setup.py

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
This commit is contained in:
wenxindongwork
2024-09-11 06:11:28 -07:00
committed by GitHub
parent 37934d70a9
commit e2966c8d99
3 changed files with 31 additions and 15 deletions

View File

@ -63,7 +63,7 @@ __version__ = "0.11.0.dev0" # expected format is one of x.y.z.dev0, or x.y.z.rc
REQUIRED_PKGS = [
"torch>=1.4.0",
"transformers>=4.31.0",
"transformers>=4.39.0",
"numpy>=1.18.2;platform_system!='Windows'",
"numpy<2;platform_system=='Windows'",
"accelerate",