mirror of
https://github.com/huggingface/trl.git
synced 2025-11-05 13:44:28 +08:00
Compare commits
1065 Commits
v0.11.0
...
docs/unify
| Author | SHA1 | Date | |
|---|---|---|---|
| 9bf8db4887 | |||
| 5dfb2db0c1 | |||
| c34de94903 | |||
| 0d5711040e | |||
| 4995b24b24 | |||
| 1cb0161ce7 | |||
| 91e7cdc3b8 | |||
| 800a4d928a | |||
| 6f906d5087 | |||
| 4677cf293e | |||
| 7a9592bc8c | |||
| 7f15a7f629 | |||
| 8b0a3ce7c7 | |||
| d9f9e2b1a9 | |||
| 4e138ab922 | |||
| 43253b2ae4 | |||
| 6f41b18e49 | |||
| 8d64144a23 | |||
| 91e540ce09 | |||
| 7347a10f1d | |||
| 6eb8d46a38 | |||
| 2a6408020b | |||
| bb057d15d9 | |||
| 580c6bb951 | |||
| 41c8ca1ad3 | |||
| 5cefb39fe2 | |||
| 50b96e25a8 | |||
| 3d718df9a9 | |||
| 77e4cd3420 | |||
| 6f8121e477 | |||
| 414cb7dd6d | |||
| ad9d9c927b | |||
| 095544e7a3 | |||
| 06c059bab8 | |||
| f6834206a8 | |||
| 0aef77b4a5 | |||
| 519cdf36eb | |||
| b3bf53f957 | |||
| c26b375ca3 | |||
| a8f70b02e1 | |||
| 1c2322eb7d | |||
| 242de1ee1e | |||
| caaf656271 | |||
| 9925469170 | |||
| 4e9ab9fa6e | |||
| b82a8f401e | |||
| 29fb69f033 | |||
| ac6cea80a3 | |||
| 1e39eb6c5a | |||
| 97830a3cc2 | |||
| d2754185db | |||
| 61bf96cd22 | |||
| b8f23ef3bd | |||
| f8073cba7d | |||
| 55854c8db5 | |||
| 4352074093 | |||
| 928f589746 | |||
| b0889d2188 | |||
| a9d33d052b | |||
| 34fdb6154b | |||
| a23e91c868 | |||
| 5e691d1bf8 | |||
| fa644b1bdf | |||
| fda88c642e | |||
| 2a138c7363 | |||
| 05a1feb050 | |||
| d8543c02b0 | |||
| 23c0062449 | |||
| 47b1aa7757 | |||
| a4872d97a8 | |||
| 3f66564804 | |||
| 9b80e336b3 | |||
| 2819a8f812 | |||
| e1c87e3589 | |||
| 7c547a37b0 | |||
| bfd6f49105 | |||
| 712f6a9c43 | |||
| 1382e564b5 | |||
| cb9bc2acce | |||
| 475c732526 | |||
| 0dc4d53736 | |||
| e2ab435487 | |||
| 46a53cd03b | |||
| 61050401ca | |||
| 5eae44a97c | |||
| 28bba8c6b1 | |||
| 2f1802bc6e | |||
| e0eec055b4 | |||
| f4c554da22 | |||
| a932e2796d | |||
| 04fd1203af | |||
| 19d2f97932 | |||
| 31caf64778 | |||
| 8e2d5516ca | |||
| 94aac4a101 | |||
| 26b7c2507e | |||
| aa25c2697c | |||
| 93c7d88563 | |||
| c7c041ecc8 | |||
| ef40c047aa | |||
| 7e0adbc552 | |||
| 773afd9314 | |||
| 966b397201 | |||
| 927cf6ba46 | |||
| 56cb6ccf76 | |||
| 49c8f14b06 | |||
| cefbacb30e | |||
| fae245a062 | |||
| 2aa9506c69 | |||
| d6eeb290d9 | |||
| 1684ef279a | |||
| aab21eb5e7 | |||
| b997a31981 | |||
| 86d1963cc1 | |||
| 039d526d24 | |||
| bcd059a384 | |||
| 0e57b4a9df | |||
| 98488e0946 | |||
| f45e86571b | |||
| f5827928a0 | |||
| f853e091ea | |||
| 803ec0d856 | |||
| 7a0a615d50 | |||
| c38cb69ec7 | |||
| 68ef15c686 | |||
| 3dd7fc2850 | |||
| 51ced65153 | |||
| 4bb883a6e6 | |||
| f7846321e7 | |||
| a944890ff1 | |||
| 521db3520a | |||
| e2c97a805a | |||
| d1d0407d3c | |||
| 824ff8c73e | |||
| f15399d3d3 | |||
| cc578b6b14 | |||
| 30cf68a97b | |||
| 452284b8dc | |||
| 6be53e19bc | |||
| 3080fc1bd7 | |||
| 5d870955f8 | |||
| 8265800abf | |||
| 65eb45c32b | |||
| ae6837f8d4 | |||
| 56a8f1128b | |||
| 529101537f | |||
| 0588b1f01d | |||
| 45ee98b05e | |||
| 3800a6ecc7 | |||
| 7ad9ce8acc | |||
| 0c2dc14014 | |||
| ced8b337ba | |||
| 1eff7da9e0 | |||
| 1cbfb00b6a | |||
| e086f073cf | |||
| e5d437ed76 | |||
| d1b4691900 | |||
| 39c603872f | |||
| 5a4021f23e | |||
| ea66a9e650 | |||
| da209f89fc | |||
| ebb8899f5d | |||
| 70e2017dbc | |||
| 4368f54c97 | |||
| 22720d176b | |||
| c8a5add88a | |||
| a7b54f988b | |||
| 78bf77abbd | |||
| 3b9ac65a05 | |||
| 7a78320f58 | |||
| 67e83aee90 | |||
| a0df357591 | |||
| 864e593e9f | |||
| 6428647063 | |||
| 8a5bfecc3a | |||
| 910aeebe06 | |||
| e208823b3e | |||
| f397a61e82 | |||
| 7fe9dd42ac | |||
| 79c774af54 | |||
| 9603b41d7e | |||
| 5ee56ed04f | |||
| e85e634bff | |||
| d633c4337f | |||
| d1e24df031 | |||
| 094e0760d4 | |||
| 01c9b4c414 | |||
| 18faf03c4e | |||
| d144e73e78 | |||
| be1ffe59d2 | |||
| fb6bdab33b | |||
| 526303edbd | |||
| 9e5e60c933 | |||
| 5c52f46f9a | |||
| deac14a39f | |||
| 3d5a30bb77 | |||
| 251fdb228a | |||
| 37806e618b | |||
| 008c7ad9aa | |||
| e8ba9eaf27 | |||
| abe07c9e32 | |||
| fe02ea2b52 | |||
| 68408d7219 | |||
| 94f8d00a62 | |||
| b5ca3799ad | |||
| a68b4af50f | |||
| 9f0ed8b130 | |||
| 27f22ba5a1 | |||
| 86f74b486f | |||
| 26b497ea63 | |||
| d22bdb8031 | |||
| 0e204482e6 | |||
| 3c8d7209f1 | |||
| 0450f05ad9 | |||
| 7e2075347e | |||
| 20cc58d777 | |||
| a6c0c57f6b | |||
| 10dc36d610 | |||
| d2d1912d96 | |||
| 08ea00289a | |||
| 4ff8b4e007 | |||
| 6356343fd2 | |||
| 45e59f77ea | |||
| 4bd4acf172 | |||
| 8380869d33 | |||
| 5139af3712 | |||
| 2f46c18a66 | |||
| e2b18ec4e7 | |||
| 78f1a928ce | |||
| 1d0b196f6b | |||
| 5a1c2f9b3b | |||
| 9955ee7eaa | |||
| 304eaf8053 | |||
| 69e288ebad | |||
| d655ce48f8 | |||
| 91c4bba922 | |||
| 2845d024a4 | |||
| f4ff248407 | |||
| b8eb5c5d2d | |||
| 07f9ad982d | |||
| 417915a3e4 | |||
| 44ddc28bcd | |||
| e8b8499f1f | |||
| 7eb7f42372 | |||
| 6adfd138d8 | |||
| a647e5a78a | |||
| 816ac610c0 | |||
| 373a64a7ce | |||
| 09e19244c0 | |||
| a228cb51d1 | |||
| 6c6f13b5f3 | |||
| b3f9f613f9 | |||
| 659d2c1284 | |||
| 82b34e5723 | |||
| 27e30f86ef | |||
| af82b38482 | |||
| 1b799a23c1 | |||
| e4ebf3ba11 | |||
| e458df650a | |||
| a1ee7d2182 | |||
| 1d06757e57 | |||
| 4f9009b0f2 | |||
| c9484b161f | |||
| f5c2fec4a9 | |||
| d1bf56020d | |||
| 19f9b9ee69 | |||
| 1eb38018b7 | |||
| deae7e00b8 | |||
| 0c69fd2867 | |||
| b5fd290b2c | |||
| 67991605c0 | |||
| 208e9f7df7 | |||
| 3bfa981bd2 | |||
| 6a5dfffe56 | |||
| 18633dbb06 | |||
| e4dbf57bf2 | |||
| 12fc85fd13 | |||
| fdd6bda111 | |||
| cb84da0ece | |||
| 35702ce378 | |||
| 705306d78b | |||
| edbe8234bc | |||
| 4c47b32811 | |||
| 92046bb972 | |||
| 39faf36a91 | |||
| 1cb4150dfb | |||
| 3a6b365c0d | |||
| 7ae16d3234 | |||
| ab984fabac | |||
| 419d716a6b | |||
| f538bd3085 | |||
| 8aa0eed816 | |||
| e7b37d4e8d | |||
| b7676d1701 | |||
| 515e9eb255 | |||
| 26442abff2 | |||
| 0c91515b58 | |||
| 4b3517facc | |||
| 6f5865131b | |||
| 0c7ab76a01 | |||
| ffc061b5e5 | |||
| 38fc1f6ecf | |||
| 39cc9a826a | |||
| 1f15f187c3 | |||
| 181a841877 | |||
| da167d88b2 | |||
| 2324245cad | |||
| fe44806b68 | |||
| 251c0488c8 | |||
| e2eaa2334d | |||
| 48d7ecc67b | |||
| 215294872e | |||
| 85ead751f5 | |||
| 8793a46760 | |||
| 730e19d939 | |||
| 7233b981ce | |||
| 18836f078e | |||
| e575ea3815 | |||
| 52eaa552aa | |||
| 0227d68e50 | |||
| b08bc7f33e | |||
| 152235a8e5 | |||
| 4fcef6c32d | |||
| d15049bf71 | |||
| b9718449a8 | |||
| 0e7c99ab07 | |||
| c99cd2361e | |||
| 68937969b4 | |||
| a6f802f41d | |||
| dfb96af810 | |||
| 485e7d1c74 | |||
| 7ee8f796ff | |||
| 64b7028fe9 | |||
| 1324448c6f | |||
| 206964ce16 | |||
| 39efa8affb | |||
| 499d9fb32c | |||
| 44e6c153a5 | |||
| f5b1ed24a0 | |||
| 7f53ac08f2 | |||
| b4c418110c | |||
| 80b660de76 | |||
| 65d7894b6a | |||
| 72d4d82b8c | |||
| de27d612b0 | |||
| a222aeb462 | |||
| cb95323429 | |||
| 2fb7090231 | |||
| f23543fc96 | |||
| d3f63ca292 | |||
| ad0b9dae1e | |||
| f3289be384 | |||
| f9b0947155 | |||
| 46d09bd240 | |||
| 17393b8c82 | |||
| 21060b25a5 | |||
| 5d914a4125 | |||
| 67763762bc | |||
| 072d7dd5a6 | |||
| ead5aaf934 | |||
| dbbc770f45 | |||
| 294e8cb093 | |||
| 79c5797d92 | |||
| ab2400029a | |||
| 3ae60cd1b4 | |||
| 9a1e6a4508 | |||
| 90c7876da5 | |||
| 72bbc6dd0d | |||
| 25ce0f31ae | |||
| 9269f9f151 | |||
| eb5d0fe484 | |||
| 30576d2ddc | |||
| 5522cc0a3f | |||
| 303d3b1d63 | |||
| 3d765b0702 | |||
| fcd3e0fd15 | |||
| 8a23c866f8 | |||
| 5bb3ca4b21 | |||
| fd70021cd7 | |||
| a902450e85 | |||
| 03034317d0 | |||
| 23ea671c5e | |||
| fc08f55518 | |||
| 2f4cb38f28 | |||
| eee9ec94ef | |||
| a043fd74a3 | |||
| d16b960dfa | |||
| daad892730 | |||
| 097d6153a2 | |||
| bc3eebb73e | |||
| 1fb115daff | |||
| 3a40f18192 | |||
| 56f4201db6 | |||
| a50bdc6388 | |||
| e102ac8df1 | |||
| d870230218 | |||
| 68ce3a3f07 | |||
| 5787f3bf63 | |||
| 116ec493fa | |||
| 1b17fa78ae | |||
| c389599057 | |||
| e333da8cf0 | |||
| c8347b4287 | |||
| 8684cb4666 | |||
| 508d551db1 | |||
| 569d60e999 | |||
| 640a9f3916 | |||
| 5a2b04a699 | |||
| dffd1acb94 | |||
| 43e6b24e70 | |||
| 2ae43f80d9 | |||
| c949b66f01 | |||
| 97085539a3 | |||
| 68ed863eed | |||
| 0462dd7f12 | |||
| 68db24e010 | |||
| 2d086f26a5 | |||
| b674989f15 | |||
| 0353d67661 | |||
| d98d53983b | |||
| c30344e9ee | |||
| db19d79e30 | |||
| e8abe03a06 | |||
| 7eb52c1b4e | |||
| 686cd35a72 | |||
| 601a25693e | |||
| d42188b17f | |||
| 4ccc5ca7bd | |||
| d1e116c67d | |||
| 90cdf96418 | |||
| b520378b97 | |||
| e04f7eb3b9 | |||
| 02cce41d06 | |||
| 6a6d4345c9 | |||
| 79ec242aef | |||
| 7e8ef867ae | |||
| 32df09358e | |||
| 0336e4bcbb | |||
| ab331bfd56 | |||
| 84d7b5bbfa | |||
| b40c959c00 | |||
| 34fa6b9af2 | |||
| eef7a43427 | |||
| 89c699f598 | |||
| 559a99f053 | |||
| 5b3ea9dd43 | |||
| c262674ea7 | |||
| 5c3dd3ab24 | |||
| 4c92de0000 | |||
| 67f17f7ea4 | |||
| 37a71e82bf | |||
| b0958c6f8f | |||
| 8bad863ffa | |||
| d00441505d | |||
| 9554c2f319 | |||
| 712afd5dd1 | |||
| 086e9d56e3 | |||
| 5206c927f6 | |||
| e4b586a389 | |||
| 0576346758 | |||
| e63588a56a | |||
| d9d25a71b2 | |||
| 58ea227d4c | |||
| a768484d47 | |||
| d17ec7ad72 | |||
| ed9b78a5f7 | |||
| d6a969ff7d | |||
| 8a235a9b71 | |||
| afa06c3b56 | |||
| 77ec43ce31 | |||
| 4126803875 | |||
| 91b3f5ee9a | |||
| b6e255a9d3 | |||
| 0d54f05fa3 | |||
| 72c91e77f5 | |||
| 32ffa1170e | |||
| fd4c9e3b72 | |||
| c5e64b479b | |||
| 15ff54790b | |||
| 3d077fd3de | |||
| 53c4a7c2b8 | |||
| aff16a5b2f | |||
| 1314aac502 | |||
| e99a8aec4b | |||
| b9572737b4 | |||
| 4cafb2744a | |||
| c49c7b7d4e | |||
| b773a4c191 | |||
| 7c8355d038 | |||
| 50a2fa8ec8 | |||
| 0333108854 | |||
| 6ffde23a45 | |||
| 6f288c2d9d | |||
| 8cf6220cef | |||
| da7b3fe745 | |||
| 24ef9eb8e7 | |||
| b0eff324aa | |||
| 026fc9439c | |||
| a912ad1bcf | |||
| fef915e36f | |||
| 0db63f0f50 | |||
| 7359ddcc6f | |||
| 0844936930 | |||
| 897c87fa91 | |||
| c13de6f9c0 | |||
| 722847abbc | |||
| ef4b0b225c | |||
| 8e8e62b380 | |||
| 824100ce25 | |||
| 4e7f0a5eb9 | |||
| 17a9069710 | |||
| cb07c44920 | |||
| 0b6a1874f1 | |||
| ac18c9d532 | |||
| d1174adc5b | |||
| cd838417e4 | |||
| c7e3f096a5 | |||
| 5c08897570 | |||
| 3ef9faf257 | |||
| 9ac614fb08 | |||
| 29401e790e | |||
| 31bf3f9244 | |||
| 7f32792c07 | |||
| 3d8727918a | |||
| 65245f6be8 | |||
| a528b9c465 | |||
| e0dd525021 | |||
| 64aa06499b | |||
| be93a0c30c | |||
| f9fbd91ea9 | |||
| 54d4f6b13a | |||
| 05bc43e960 | |||
| d3dc8ff654 | |||
| 21738c3732 | |||
| eab175d434 | |||
| 4da4dc9117 | |||
| 6b3a02385d | |||
| abbbb93d6a | |||
| cafa663c84 | |||
| fd04a5461a | |||
| 56e5766205 | |||
| 89d44caece | |||
| adfa7fd59a | |||
| cf5183db7f | |||
| 1954c02d86 | |||
| 45f4c58832 | |||
| cc044e35b2 | |||
| 999acd53ec | |||
| 8606b1ad09 | |||
| a673da5773 | |||
| 00b8e311aa | |||
| c163cf5081 | |||
| bc9c019c43 | |||
| 18596cf232 | |||
| 280d35301b | |||
| 13fa8402a3 | |||
| 09b669fbf7 | |||
| 01d0be15cb | |||
| 3a42af1c78 | |||
| aaf39604ba | |||
| 2bf48478e8 | |||
| a8cfca6d01 | |||
| 1bca49515e | |||
| 39e96394a9 | |||
| 8e6ed93dfd | |||
| 29c5e05e3a | |||
| a9b27f82d6 | |||
| cd6b3de356 | |||
| 36685c8bba | |||
| 89556c8cbf | |||
| f3e8c23044 | |||
| 9ee6c3aa56 | |||
| ef05331752 | |||
| 05e2ba6e01 | |||
| 1b4f189e09 | |||
| 1faa7f9b36 | |||
| 66e6eab9bb | |||
| 27af0aaf4a | |||
| b4ffda769e | |||
| 0dad4eb7ca | |||
| c82f626f94 | |||
| 33add19161 | |||
| 294f35bf3c | |||
| 9874b3aa04 | |||
| 1e61f6cc5a | |||
| 27adc30162 | |||
| df737f99c1 | |||
| c04e84c454 | |||
| d625c5533a | |||
| 6cdd24a360 | |||
| 8b38570258 | |||
| 95b1a9f612 | |||
| 5c1511423b | |||
| 5e2e9cb442 | |||
| 227df8271e | |||
| ae1581474e | |||
| 47b9515fb1 | |||
| c4891dcfee | |||
| 055cee255a | |||
| 73a2fb0554 | |||
| 982ba08092 | |||
| e03e7acc5c | |||
| 9df19e8a75 | |||
| 1d7b8c4f70 | |||
| 7e170612a4 | |||
| 559724ee2c | |||
| a5a46725c8 | |||
| b6bcafb8bb | |||
| 4bfb8eb0d1 | |||
| 4d66bad208 | |||
| e90117b3e1 | |||
| 31b54a6237 | |||
| 17e33cdaa0 | |||
| 5a0cebc786 | |||
| 65308cfd84 | |||
| 1755e03f6f | |||
| 793735a698 | |||
| e70a0efeca | |||
| 7eaca76ed1 | |||
| 657f9ce6ee | |||
| 485852c942 | |||
| 9f3702f6be | |||
| e751a16df5 | |||
| 582bc5684b | |||
| c5ba70d4fc | |||
| 5b586da3cc | |||
| 488025cd87 | |||
| 2594cb39de | |||
| 2fe2337067 | |||
| f6b4d6e569 | |||
| 26d86757a7 | |||
| 9771f259ed | |||
| 7bdedd4075 | |||
| a069a2f19c | |||
| ea45f513f3 | |||
| a91023990a | |||
| 1a9387b922 | |||
| 1884ff1bb8 | |||
| bfe2075608 | |||
| 6067e2a669 | |||
| dee37342a8 | |||
| 8037f18cdf | |||
| a0a53171cc | |||
| 23a635ed61 | |||
| 9b38b0b5ee | |||
| 0f26049ea2 | |||
| 7511aa4e36 | |||
| f713f614e9 | |||
| a34987956c | |||
| 0f88c179e3 | |||
| beda4328cc | |||
| 07cfe1677e | |||
| 9f7755d8ed | |||
| 4e3f569eb8 | |||
| 979fda1548 | |||
| f6fb6a88a9 | |||
| 6cbf8fbc9f | |||
| 5cb390cd30 | |||
| b3c391e628 | |||
| 1b85ca6147 | |||
| e7a1290b0a | |||
| 3822edd67b | |||
| 230455cab0 | |||
| 08f014d559 | |||
| 10740333bd | |||
| 058a733c30 | |||
| 3f193972d8 | |||
| b575596b89 | |||
| 118c43f0e0 | |||
| 40b1c33edf | |||
| 1a2e74cc5a | |||
| 80f7dcb16d | |||
| 4404ccd24a | |||
| 39f77ca2d8 | |||
| 52085dd96b | |||
| c7a1c95017 | |||
| 3003058418 | |||
| a759cee2e0 | |||
| 0a3bad44f0 | |||
| bb5b96a823 | |||
| 8466c7273e | |||
| a871ec8e91 | |||
| f7572221db | |||
| 8ec2e42833 | |||
| 218d493d11 | |||
| 1a9f78eb3a | |||
| a10978ebdf | |||
| 87fbb831d3 | |||
| 52f39d6a24 | |||
| 931f7a14d2 | |||
| 9951105a90 | |||
| 5a6e23aac9 | |||
| d9104c8b0d | |||
| d5a5840307 | |||
| f3cbd41e2c | |||
| d41a32f619 | |||
| fc4dae256d | |||
| e4e5671e80 | |||
| 7c76f103da | |||
| aad18ef52a | |||
| b55d9f0412 | |||
| 4871c82b0c | |||
| fd9e5a7cab | |||
| 5463e49a55 | |||
| 22759c8208 | |||
| 2ee6fd369f | |||
| 844a9c665f | |||
| 04f6597377 | |||
| e3244d2d09 | |||
| 6a02c69789 | |||
| a1c58aa42a | |||
| 3f0695a4ca | |||
| a72b50b772 | |||
| ea1d9be2a7 | |||
| 402187baab | |||
| 5858ceab7e | |||
| 7442d42c21 | |||
| 98de0e7c62 | |||
| 491921c1a4 | |||
| ad6a35bdd5 | |||
| 7bc9858a8f | |||
| b882f57d93 | |||
| ac7bde5832 | |||
| 3d94e4e25c | |||
| 1a303cca8e | |||
| ac327d5e84 | |||
| c0854c32c9 | |||
| aa18ecfde7 | |||
| 6849c050b9 | |||
| 27a6f2201b | |||
| f074dcdc86 | |||
| 0caff61600 | |||
| 019fc6dbaa | |||
| 69ad852e56 | |||
| 45ccdefac4 | |||
| 703484a8c2 | |||
| 9b76d5f2e9 | |||
| cbe0681ba1 | |||
| 4e0cf01aef | |||
| 5c05913196 | |||
| caba04da42 | |||
| be5a088337 | |||
| 38861475e6 | |||
| f69707dab4 | |||
| 76f00fc394 | |||
| 8453017622 | |||
| 3608709529 | |||
| 21f0055893 | |||
| 013d360b8f | |||
| e5ae703d35 | |||
| a92e00e810 | |||
| 9b3c5bf64f | |||
| 15fec312d5 | |||
| be1e34003c | |||
| 6aaf379a82 | |||
| 49adf74833 | |||
| 6c54f023ae | |||
| 963243a7d1 | |||
| aafd8cbea5 | |||
| 822653824b | |||
| ba036576d4 | |||
| 293b620950 | |||
| ae3bd0d07a | |||
| 6d9fc11fd6 | |||
| ffcb9f4aee | |||
| 00e5889380 | |||
| 5c9cf2003d | |||
| 8830786a23 | |||
| b0f513c13d | |||
| 81221661c6 | |||
| 7347c292c3 | |||
| 2106b31298 | |||
| 9b67eea473 | |||
| e752fc6c2e | |||
| 674bb75f59 | |||
| b9df81045b | |||
| 55e680e142 | |||
| 09eefa73ab | |||
| 7fdb69aa7d | |||
| 5b9236d1e8 | |||
| 82d12eb751 | |||
| 84d73fd00b | |||
| 2241f17914 | |||
| cf97133d51 | |||
| 724acb9716 | |||
| 7134a1e73f | |||
| bf6e7edea5 | |||
| e95f9fb74a | |||
| a85768f120 | |||
| 78c5ce23fd | |||
| af4ad47035 | |||
| b2ae99925d | |||
| bd946f93c1 | |||
| f42e34e613 | |||
| 338fbd546b | |||
| 32f8fa8aad | |||
| 1a2276402f | |||
| 1f344c9377 | |||
| 85121fc300 | |||
| bbdd6db17c | |||
| 6e088d165c | |||
| a325a0eec5 | |||
| 0ec1ccd990 | |||
| 1c35a48b50 | |||
| 2ce36ae889 | |||
| bf6919117e | |||
| 265663af6a | |||
| 5ab15d3fef | |||
| fecaa991de | |||
| ab30a01baf | |||
| 6dc278a042 | |||
| 67441bb432 | |||
| 62685fbf20 | |||
| 4197956395 | |||
| 9ac8d9773b | |||
| 094d51b599 | |||
| df8f619ec5 | |||
| 56880ba73d | |||
| 801582ec24 | |||
| ed14ed9043 | |||
| 4659ad916f | |||
| 1123bd0f51 | |||
| 55a329e9f0 | |||
| 4720656654 | |||
| 807046b7d7 | |||
| 317d2d477b | |||
| aeb03cf1a9 | |||
| 2578e95023 | |||
| 6f99f42f72 | |||
| d14f7f3eb2 | |||
| 8e65825d4c | |||
| 5e4d7be0e1 | |||
| f34b70a32e | |||
| 0e216f7411 | |||
| 59c201433c | |||
| 40c238395e | |||
| a1d2955116 | |||
| 887c1f3fa3 | |||
| 949db2357e | |||
| fe4b5efe4e | |||
| a9b54a852e | |||
| d4222a1e08 | |||
| a5c88d6c75 | |||
| b6a084c46e | |||
| d9f056862f | |||
| 3d2c1e49b1 | |||
| 5fd78367ae | |||
| 0f5ffad26e | |||
| 88514d51e3 | |||
| 76837e82b9 | |||
| 35553930da | |||
| fd4b283b82 | |||
| 1b1140aa69 | |||
| 4c7eb6fe29 | |||
| 564fc86759 | |||
| 3215a1c586 | |||
| cdc16f3ac6 | |||
| 2ecd53ad77 | |||
| 5877786b5a | |||
| 57d9a97394 | |||
| 751fb1d84b | |||
| edabe0a2d8 | |||
| abfffc510b | |||
| ed7de87dc7 | |||
| beb892bfe0 | |||
| f2d42fa0c2 | |||
| d6a7e9d6f5 | |||
| 451677203d | |||
| 2f25f54ab9 | |||
| a50124dd3a | |||
| 1d23ecc36f | |||
| 52d213173f | |||
| d9ee2fd202 | |||
| 763738f457 | |||
| aed5da580e | |||
| 99451b421a | |||
| 5239b9462d | |||
| 8fb267ff1e | |||
| 2e1adbb6ff | |||
| b668048fe1 | |||
| 8c49ea39ec | |||
| 88ad1a099c | |||
| 9908dda6d9 | |||
| 5e204e1eaa | |||
| 82cfeb8930 | |||
| 0fe73a8af5 | |||
| 33fb9efc43 | |||
| f68d11f9f9 | |||
| aeca63774f | |||
| 117c6d4b52 | |||
| 6d4ed070f1 | |||
| cd7156fb34 | |||
| ca850be0a2 | |||
| 179ba53671 | |||
| e3e171a26b | |||
| b3aff441ff | |||
| efc687db62 | |||
| f2e362656c | |||
| c9c4f18039 | |||
| 460e780265 | |||
| 7ba118a229 | |||
| 6a05feff02 | |||
| 2f72f47191 | |||
| 9410874787 | |||
| 9c5388b69e | |||
| b02189aaa5 | |||
| 52201d3c18 | |||
| 9ff79a65e3 | |||
| 9001a8682c | |||
| f6f42651e2 | |||
| 148b592313 | |||
| d6a8f2c2f6 | |||
| 8d9cfaafeb | |||
| 94e4135a17 | |||
| ac267781ec | |||
| 2c6e0d9705 | |||
| e1d781353b | |||
| a34e9bf84f | |||
| c10cc8995b | |||
| 9368dccef6 | |||
| 43df3a485a | |||
| baee06f2e8 | |||
| bbd8cbb720 | |||
| 4f937c7629 | |||
| 16fa13ce72 | |||
| 453db5cd79 | |||
| ee3cbe1946 | |||
| 17e8060984 | |||
| 163695e85c | |||
| 672c96546d | |||
| bdeb117320 | |||
| 6578fdc101 | |||
| a0066f47f8 | |||
| 5626806aef | |||
| bb0afc2459 | |||
| 066fc37bd3 | |||
| b80c1a6fb8 | |||
| b5eabbeb07 | |||
| cbf9abcd07 | |||
| 6f8fe59aeb | |||
| 1293f37c5f | |||
| e7870dd5d6 | |||
| 21d5baf338 | |||
| 76dbb1a576 | |||
| b8c9d9c7bc | |||
| 623963126b | |||
| 2d24d35013 | |||
| dde20b23cf | |||
| 015321e135 | |||
| 454f36d951 | |||
| 9b7f9f3519 | |||
| 518e29ca9c | |||
| ac7b6cfdfa | |||
| 0238d96c6f | |||
| c86b51cd12 | |||
| ac77c09223 | |||
| 7f2ccbe3a2 | |||
| 74e20cbbbc | |||
| 27b9e3a93f | |||
| dc2b8b9e90 | |||
| 5e90682836 | |||
| 3b439967f4 | |||
| 2f34a161cd | |||
| 6138439df4 | |||
| d57a181163 | |||
| 73c3970c1f | |||
| 013a32b396 | |||
| 24fb32733f | |||
| bb56c6e6af | |||
| 06be6f409a | |||
| b2696578ce | |||
| 0ce3b65928 | |||
| e155cb8a66 | |||
| ea7a1be92c | |||
| 110d0884c7 | |||
| 57ba9b93aa | |||
| 0de75b26f2 | |||
| e615974a03 | |||
| c2bb1eed14 | |||
| 9c376c571f | |||
| 16994738d0 | |||
| 99225bb6d6 | |||
| 88be2c07e5 | |||
| f2349d2af0 | |||
| d843b3dadd | |||
| 84dab850f6 | |||
| 92f6d246d3 | |||
| 31b7820aad | |||
| b9aa965cce | |||
| a67f2143c3 | |||
| 494b4afa10 | |||
| 02f4e750c0 | |||
| 2ba3005d1c | |||
| 7e394b03e8 | |||
| 14f3613dac | |||
| 5e24101b36 | |||
| b81a6121c3 | |||
| 7f0d246235 | |||
| 70036bf87f | |||
| d0aa421e5e | |||
| 5375d71bbd | |||
| 6004e033a4 | |||
| f436c3e1c9 | |||
| cd1aa6bdcc | |||
| b3f93f0bad | |||
| 6c32c8bfcd | |||
| 3107a40f16 | |||
| 419791695c | |||
| 7e5924d17e | |||
| ed9ea74b62 | |||
| 511c92c91c | |||
| c6cb6353a5 | |||
| adb3e0560b | |||
| adf58d80d0 | |||
| 9aa022503c | |||
| 82ad390caf | |||
| ac038ef03a | |||
| 51ca76b749 | |||
| 7005ab4d11 | |||
| ffb1ab74ba | |||
| 47d08a9626 | |||
| 70327c18e6 | |||
| f05c3fa8fc | |||
| 4799ba4842 | |||
| d45c86e2a7 | |||
| c6b0d1358b | |||
| 3321084e30 | |||
| a9cffc7caf | |||
| 32a928cfc2 | |||
| 1a3bb372ac | |||
| d4564b7c64 | |||
| 1be4d86ccc | |||
| 78249d9de4 | |||
| 5c21de30ae | |||
| 0a566f0c58 | |||
| de3876577c | |||
| 1201aa61b4 | |||
| c00722ce0a | |||
| 124189c86a | |||
| d5eeaab462 | |||
| 5368be1e1e | |||
| b169e1030d | |||
| 9af4734178 | |||
| a0d714949f | |||
| a0e28143ec | |||
| 32d9d34eb1 | |||
| fb1b48fdbe | |||
| b5e4bc5984 | |||
| 7a24565d9d | |||
| 44a06fc487 | |||
| a84fc5d815 | |||
| 80038a5a92 | |||
| cece86b182 | |||
| d005980d8b | |||
| cc23b511e4 | |||
| 2cad48d511 | |||
| 6859e048da | |||
| 92eea1f239 | |||
| 663002f609 | |||
| 44d998b2af | |||
| 9b80f3d50c | |||
| 2038e52c30 | |||
| 10c2f63b2a | |||
| 9fb871f62f | |||
| 3cec013a20 |
76
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
76
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@ -7,36 +7,7 @@ body:
|
||||
value: |
|
||||
Thanks for taking the time to fill out this bug report! 🤗
|
||||
|
||||
Before you submit your bug report:
|
||||
|
||||
- If it is your first time submitting, be sure to check our [bug report guidelines](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md#did-you-find-a-bug)
|
||||
|
||||
- type: textarea
|
||||
id: system-info
|
||||
attributes:
|
||||
label: System Info
|
||||
description: Please share your system info with us. You can run the command `transformers-cli env` and copy-paste its output below.
|
||||
placeholder: trl version, transformers version, platform, python version, ...
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: checkboxes
|
||||
id: information-scripts-examples
|
||||
attributes:
|
||||
label: Information
|
||||
description: 'The problem arises when using:'
|
||||
options:
|
||||
- label: "The official example scripts"
|
||||
- label: "My own modified scripts"
|
||||
|
||||
- type: checkboxes
|
||||
id: information-tasks
|
||||
attributes:
|
||||
label: Tasks
|
||||
description: "The tasks I am working on are:"
|
||||
options:
|
||||
- label: "An officially supported task in the `examples` folder"
|
||||
- label: "My own task or dataset (give details below)"
|
||||
🚩 If it is your first time submitting, be sure to check our [bug report guidelines](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md#did-you-find-a-bug)
|
||||
|
||||
- type: textarea
|
||||
id: reproduction
|
||||
@ -50,18 +21,47 @@ body:
|
||||
Important! Use code tags to correctly format your code. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
|
||||
Do not use screenshots, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code.
|
||||
|
||||
placeholder: |
|
||||
Steps to reproduce the behavior:
|
||||
value: |
|
||||
```python
|
||||
from trl import ...
|
||||
|
||||
1.
|
||||
2.
|
||||
3.
|
||||
```
|
||||
|
||||
outputs:
|
||||
|
||||
```
|
||||
Traceback (most recent call last):
|
||||
File "example.py", line 42, in <module>
|
||||
...
|
||||
```
|
||||
|
||||
- type: textarea
|
||||
id: expected-behavior
|
||||
id: system-info
|
||||
attributes:
|
||||
label: System Info
|
||||
description: |
|
||||
Please provide information about your system: platform, Python version, PyTorch version, Transformers version, devices, TRL version, ...
|
||||
You can get this information by running `trl env` in your terminal.
|
||||
|
||||
placeholder: Copy-paste the output of `trl env`
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: checkboxes
|
||||
id: terms
|
||||
attributes:
|
||||
label: Expected behavior
|
||||
description: "A clear and concise description of what you would expect to happen."
|
||||
label: Checklist
|
||||
description: |
|
||||
Before submitting, please confirm that you've completed each of the following.
|
||||
If an item doesn't apply to your issue, check it anyway to show you've reviewed it.
|
||||
options:
|
||||
- label: "I have checked that my issue isn't already filed (see [open issues](https://github.com/huggingface/trl/issues?q=is%3Aissue))"
|
||||
required: true
|
||||
- label: "I have included my system information"
|
||||
required: true
|
||||
- label: "Any code provided is minimal, complete, and reproducible ([more on MREs](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/creating-and-highlighting-code-blocks))"
|
||||
required: true
|
||||
- label: "Any code provided is properly formatted in code blocks, (no screenshot, [more on code blocks](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/creating-and-highlighting-code-blocks))"
|
||||
required: true
|
||||
- label: "Any traceback provided is complete"
|
||||
required: true
|
||||
|
||||
3
.github/PULL_REQUEST_TEMPLATE.md
vendored
3
.github/PULL_REQUEST_TEMPLATE.md
vendored
@ -21,8 +21,7 @@ Fixes # (issue)
|
||||
Pull Request section?
|
||||
- [ ] Was this discussed/approved via a GitHub issue? Please add a link
|
||||
to it if that's the case.
|
||||
- [ ] Did you make sure to update the documentation with your changes? Here are the
|
||||
[documentation guidelines](https://github.com/huggingface/trl/tree/main/docs).
|
||||
- [ ] Did you make sure to update the documentation with your changes?
|
||||
- [ ] Did you write any new necessary tests?
|
||||
|
||||
|
||||
|
||||
19
.github/codeql/custom-queries.qls
vendored
Normal file
19
.github/codeql/custom-queries.qls
vendored
Normal file
@ -0,0 +1,19 @@
|
||||
import codeql
|
||||
|
||||
from WorkflowString interpolation, Workflow workflow
|
||||
where
|
||||
interpolation.getStringValue().matches("${{ github.event.issue.title }}") or
|
||||
interpolation.getStringValue().matches("${{ github.event.issue.body }}") or
|
||||
interpolation.getStringValue().matches("${{ github.event.pull_request.title }}") or
|
||||
interpolation.getStringValue().matches("${{ github.event.pull_request.body }}") or
|
||||
interpolation.getStringValue().matches("${{ github.event.review.body }}") or
|
||||
interpolation.getStringValue().matches("${{ github.event.comment.body }}") or
|
||||
interpolation.getStringValue().matches("${{ github.event.inputs.* }}") or
|
||||
interpolation.getStringValue().matches("${{ github.event.head_commit.message }}")
|
||||
interpolation.getStringValue().matches("${{ github.event.* }}") and
|
||||
(
|
||||
step.getKey() = "run" or // Injection in run
|
||||
step.getKey() = "env" or // Injection via env
|
||||
step.getKey() = "with" // Injection via with
|
||||
)
|
||||
select workflow, "🚨 Do not use directly as input of action"
|
||||
1
.github/workflows/build_documentation.yml
vendored
1
.github/workflows/build_documentation.yml
vendored
@ -14,6 +14,5 @@ jobs:
|
||||
commit_sha: ${{ github.sha }}
|
||||
package: trl
|
||||
version_tag_suffix: ""
|
||||
custom_container: huggingface/transformers-doc-builder
|
||||
secrets:
|
||||
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
|
||||
|
||||
2
.github/workflows/build_pr_documentation.yml
vendored
2
.github/workflows/build_pr_documentation.yml
vendored
@ -9,10 +9,10 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
build:
|
||||
if: github.event.pull_request.draft == false
|
||||
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
|
||||
with:
|
||||
commit_sha: ${{ github.event.pull_request.head.sha }}
|
||||
pr_number: ${{ github.event.number }}
|
||||
package: trl
|
||||
version_tag_suffix: ""
|
||||
custom_container: huggingface/transformers-doc-builder
|
||||
|
||||
26
.github/workflows/codeQL.yml
vendored
Normal file
26
.github/workflows/codeQL.yml
vendored
Normal file
@ -0,0 +1,26 @@
|
||||
name: "CodeQL Analysis - Workflows"
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
analyze:
|
||||
name: "Analyze GitHub Workflows"
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
security-events: write
|
||||
actions: read
|
||||
contents: read
|
||||
|
||||
steps:
|
||||
- name: "Checkout repository"
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: "Initialize CodeQL"
|
||||
uses: github/codeql-action/init@v2
|
||||
with:
|
||||
languages: "yaml"
|
||||
queries: +security-and-quality, ./.github/codeql/custom-queries.qls
|
||||
|
||||
- name: "Perform CodeQL Analysis"
|
||||
uses: github/codeql-action/analyze@v2
|
||||
89
.github/workflows/docker-build.yml
vendored
89
.github/workflows/docker-build.yml
vendored
@ -1,95 +1,84 @@
|
||||
name: Build Docker images (scheduled)
|
||||
name: Build TRL Docker image
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
workflow_dispatch:
|
||||
workflow_call:
|
||||
schedule:
|
||||
- cron: "0 1 * * *"
|
||||
|
||||
concurrency:
|
||||
group: docker-image-builds
|
||||
cancel-in-progress: false
|
||||
|
||||
env:
|
||||
CI_SLACK_CHANNEL: ${{ secrets.CI_DOCKER_CHANNEL }}
|
||||
|
||||
jobs:
|
||||
trl-latest:
|
||||
name: "Latest TRL GPU"
|
||||
trl:
|
||||
name: "Build and push TRL Docker image"
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Cleanup disk
|
||||
run: |
|
||||
sudo ls -l /usr/local/lib/
|
||||
sudo ls -l /usr/share/
|
||||
sudo du -sh /usr/local/lib/
|
||||
sudo du -sh /usr/share/
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo du -sh /usr/local/lib/
|
||||
sudo du -sh /usr/share/
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v1
|
||||
- name: Check out code
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get TRL version from PyPI
|
||||
run: |
|
||||
VERSION=$(curl -s https://pypi.org/pypi/trl/json | jq -r .info.version)
|
||||
echo "VERSION=$VERSION" >> $GITHUB_ENV
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to DockerHub
|
||||
uses: docker/login-action@v1
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||
|
||||
- name: Build and Push GPU
|
||||
- name: Build and Push
|
||||
uses: docker/build-push-action@v4
|
||||
with:
|
||||
context: ./docker/trl-latest-gpu
|
||||
context: docker/trl
|
||||
push: true
|
||||
tags: huggingface/trl-latest-gpu
|
||||
tags: |
|
||||
huggingface/trl:${{ env.VERSION }}
|
||||
huggingface/trl
|
||||
|
||||
- name: Post to Slack
|
||||
if: always()
|
||||
uses: huggingface/hf-workflows/.github/actions/post-slack@main
|
||||
with:
|
||||
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
|
||||
title: 🤗 Results of the trl-latest-gpu Docker Image build
|
||||
slack_channel: ${{ secrets.CI_DOCKER_CHANNEL }}
|
||||
title: 🤗 Results of the TRL Dev Docker Image build
|
||||
status: ${{ job.status }}
|
||||
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
||||
|
||||
trl-source:
|
||||
name: "Latest TRL + HF ecosystem from source"
|
||||
trl-dev:
|
||||
name: "Build and push TRL Dev Docker image"
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Cleanup disk
|
||||
run: |
|
||||
sudo ls -l /usr/local/lib/
|
||||
sudo ls -l /usr/share/
|
||||
sudo du -sh /usr/local/lib/
|
||||
sudo du -sh /usr/share/
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo du -sh /usr/local/lib/
|
||||
sudo du -sh /usr/share/
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v1
|
||||
- name: Check out code
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to DockerHub
|
||||
uses: docker/login-action@v1
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||
|
||||
- name: Build and Push GPU
|
||||
- name: Build and Push
|
||||
uses: docker/build-push-action@v4
|
||||
with:
|
||||
context: ./docker/trl-source-gpu
|
||||
context: docker/trl-dev
|
||||
push: true
|
||||
tags: huggingface/trl-source-gpu
|
||||
tags: |
|
||||
huggingface/trl:dev
|
||||
|
||||
- name: Post to Slack
|
||||
if: always()
|
||||
uses: huggingface/hf-workflows/.github/actions/post-slack@main
|
||||
with:
|
||||
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
|
||||
title: 🤗 Results of the trl-source-gpu Docker Image build
|
||||
slack_channel: ${{ secrets.CI_DOCKER_CHANNEL }}
|
||||
title: 🤗 Results of the TRL Dev Docker Image build
|
||||
status: ${{ job.status }}
|
||||
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
||||
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
||||
|
||||
15
.github/workflows/issue_auto_labeller.yml
vendored
Normal file
15
.github/workflows/issue_auto_labeller.yml
vendored
Normal file
@ -0,0 +1,15 @@
|
||||
name: "Hugging Face Issue Labeler"
|
||||
on:
|
||||
issues:
|
||||
types: opened
|
||||
|
||||
jobs:
|
||||
triage:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
issues: write
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: August-murr/auto-labeler@main
|
||||
with:
|
||||
hf-api-key: ${{ secrets.CI_HF_API_TOKEN }}
|
||||
127
.github/workflows/pr_style_bot.yml
vendored
Normal file
127
.github/workflows/pr_style_bot.yml
vendored
Normal file
@ -0,0 +1,127 @@
|
||||
name: PR Style Bot
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
run-style-bot:
|
||||
if: >
|
||||
contains(github.event.comment.body, '@bot /style') &&
|
||||
github.event.issue.pull_request != null
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Extract PR details
|
||||
id: pr_info
|
||||
uses: actions/github-script@v6
|
||||
with:
|
||||
script: |
|
||||
const prNumber = context.payload.issue.number;
|
||||
const { data: pr } = await github.rest.pulls.get({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
pull_number: prNumber
|
||||
});
|
||||
|
||||
// We capture both the branch ref and the "full_name" of the head repo
|
||||
// so that we can check out the correct repository & branch (including forks).
|
||||
core.setOutput("prNumber", prNumber);
|
||||
core.setOutput("headRef", pr.head.ref);
|
||||
core.setOutput("headRepoFullName", pr.head.repo.full_name);
|
||||
|
||||
- name: Check out PR branch
|
||||
uses: actions/checkout@v3
|
||||
env:
|
||||
HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
|
||||
HEADREF: ${{ steps.pr_info.outputs.headRef }}
|
||||
with:
|
||||
# Instead of checking out the base repo, use the contributor's repo name
|
||||
repository: ${{ env.HEADREPOFULLNAME }}
|
||||
ref: ${{ env.HEADREF }}
|
||||
# You may need fetch-depth: 0 for being able to push
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Debug
|
||||
env:
|
||||
HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
|
||||
HEADREF: ${{ steps.pr_info.outputs.headRef }}
|
||||
PRNUMBER: ${{ steps.pr_info.outputs.prNumber }}
|
||||
run: |
|
||||
echo "PR number: ${{ env.PRNUMBER }}"
|
||||
echo "Head Ref: ${{ env.HEADREF }}"
|
||||
echo "Head Repo Full Name: ${{ env.HEADREPOFULLNAME }}"
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install ruff pre-commit
|
||||
|
||||
- name: Download Makefile from main branch
|
||||
run: |
|
||||
curl -o main_Makefile https://raw.githubusercontent.com/huggingface/trl/main/Makefile
|
||||
|
||||
- name: Compare Makefiles
|
||||
run: |
|
||||
if ! diff -q main_Makefile Makefile; then
|
||||
echo "Error: The Makefile has changed. Please ensure it matches the main branch."
|
||||
exit 1
|
||||
fi
|
||||
echo "No changes in Makefile. Proceeding..."
|
||||
rm -rf main_Makefile
|
||||
|
||||
- name: Run make style and make quality
|
||||
run: |
|
||||
make precommit || true
|
||||
|
||||
- name: Commit and push changes
|
||||
id: commit_and_push
|
||||
env:
|
||||
HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
|
||||
HEADREF: ${{ steps.pr_info.outputs.headRef }}
|
||||
PRNUMBER: ${{ steps.pr_info.outputs.prNumber }}
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
echo "HEADREPOFULLNAME: ${{ env.HEADREPOFULLNAME }}, HEADREF: ${{ env.HEADREF }}"
|
||||
# Configure git with the Actions bot user
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
# Make sure your 'origin' remote is set to the contributor's fork
|
||||
git remote set-url origin "https://x-access-token:${GITHUB_TOKEN}@github.com/${{ env.HEADREPOFULLNAME }}.git"
|
||||
|
||||
# If there are changes after running style/quality, commit them
|
||||
if [ -n "$(git status --porcelain)" ]; then
|
||||
git add .
|
||||
git commit -m "Apply style fixes"
|
||||
# Push to the original contributor's forked branch
|
||||
git push origin HEAD:${{ env.HEADREF }}
|
||||
echo "changes_pushed=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "No changes to commit."
|
||||
echo "changes_pushed=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Comment on PR with workflow run link
|
||||
if: steps.commit_and_push.outputs.changes_pushed == 'true'
|
||||
uses: actions/github-script@v6
|
||||
with:
|
||||
script: |
|
||||
const prNumber = parseInt(process.env.prNumber, 10);
|
||||
const runUrl = `${process.env.GITHUB_SERVER_URL}/${process.env.GITHUB_REPOSITORY}/actions/runs/${process.env.GITHUB_RUN_ID}`
|
||||
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
body: `Style fixes have been applied. [View the workflow run here](${runUrl}).`
|
||||
});
|
||||
env:
|
||||
prNumber: ${{ steps.pr_info.outputs.prNumber }}
|
||||
43
.github/workflows/publish.yml
vendored
Normal file
43
.github/workflows/publish.yml
vendored
Normal file
@ -0,0 +1,43 @@
|
||||
name: Publish to PyPI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- v*-release
|
||||
paths:
|
||||
- "VERSION"
|
||||
|
||||
jobs:
|
||||
publish:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Read version
|
||||
id: get_version
|
||||
run: echo "version=$(cat VERSION)" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Debug - Show version.txt content
|
||||
run: echo "Version is ${{ steps.get_version.outputs.version }}"
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.x"
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install build twine
|
||||
|
||||
- name: Build package
|
||||
run: python -m build
|
||||
|
||||
- name: Publish to PyPI
|
||||
if: ${{ !contains(steps.get_version.outputs.version, 'dev') }}
|
||||
env:
|
||||
TWINE_USERNAME: __token__
|
||||
TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
|
||||
run: |
|
||||
python -m twine upload dist/*
|
||||
98
.github/workflows/slow-tests.yml
vendored
98
.github/workflows/slow-tests.yml
vendored
@ -2,7 +2,7 @@ name: Slow tests (on push)
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
branches: [main]
|
||||
paths:
|
||||
# Run only when python files are modified
|
||||
- "trl/**.py"
|
||||
@ -11,88 +11,102 @@ env:
|
||||
RUN_SLOW: "yes"
|
||||
IS_GITHUB_CI: "1"
|
||||
SLACK_API_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
||||
|
||||
TRL_EXPERIMENTAL_SILENCE: 1
|
||||
|
||||
jobs:
|
||||
run_all_tests_single_gpu:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
docker-image-name: ["huggingface/trl-latest-gpu:latest", "huggingface/trl-source-gpu:latest"]
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
env:
|
||||
CUDA_VISIBLE_DEVICES: "0"
|
||||
TEST_TYPE: "single_gpu_${{ matrix.docker-image-name }}"
|
||||
TEST_TYPE: "single_gpu"
|
||||
container:
|
||||
image: ${{ matrix.docker-image-name }}
|
||||
options: --gpus all --shm-size "16gb" -e NVIDIA_DISABLE_REQUIRE=true
|
||||
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
|
||||
options: --gpus all --shm-size "16gb"
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Pip install
|
||||
- name: Git checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install system dependencies
|
||||
run: |
|
||||
source activate trl
|
||||
pip install -e ".[test]" --no-deps
|
||||
pip install pytest-reportlog parameterized
|
||||
apt-get update && apt-get install -y make git curl
|
||||
|
||||
- name: Install uv
|
||||
run: |
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
|
||||
- name: Create Python virtual environment
|
||||
run: |
|
||||
uv venv
|
||||
uv pip install --upgrade setuptools wheel
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
uv pip install ".[dev]"
|
||||
uv pip install pytest-reportlog
|
||||
|
||||
- name: Run slow SFT tests on single GPU
|
||||
if: always()
|
||||
run: |
|
||||
source activate trl
|
||||
source .venv/bin/activate
|
||||
make slow_tests
|
||||
|
||||
|
||||
- name: Generate Report
|
||||
if: always()
|
||||
run: |
|
||||
pip install slack_sdk tabulate
|
||||
source .venv/bin/activate
|
||||
uv pip install slack_sdk tabulate
|
||||
python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
|
||||
run_all_tests_multi_gpu:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
docker-image-name: ["huggingface/trl-latest-gpu:latest", "huggingface/trl-source-gpu:latest"]
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
env:
|
||||
CUDA_VISIBLE_DEVICES: "0,1"
|
||||
TEST_TYPE: "multi_gpu_${{ matrix.docker-image-name }}"
|
||||
TEST_TYPE: "multi_gpu"
|
||||
container:
|
||||
image: ${{ matrix.docker-image-name }}
|
||||
options: --gpus all --shm-size "16gb" -e NVIDIA_DISABLE_REQUIRE=true
|
||||
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
|
||||
options: --gpus all --shm-size "16gb"
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Pip install
|
||||
- name: Git checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install system dependencies
|
||||
run: |
|
||||
source activate trl
|
||||
pip install -e ".[test]" --no-deps
|
||||
pip install pytest-reportlog parameterized
|
||||
apt-get update && apt-get install -y make git curl
|
||||
|
||||
- name: Install uv
|
||||
run: |
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
|
||||
- name: Create Python virtual environment
|
||||
run: |
|
||||
uv venv
|
||||
uv pip install --upgrade setuptools wheel
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
uv pip install ".[dev]"
|
||||
uv pip install pytest-reportlog
|
||||
|
||||
- name: Run slow SFT tests on Multi GPU
|
||||
if: always()
|
||||
run: |
|
||||
source activate trl
|
||||
source .venv/bin/activate
|
||||
make slow_tests
|
||||
|
||||
- name: Run end-to-end examples tests on multi GPU
|
||||
if: always()
|
||||
run: |
|
||||
source activate trl
|
||||
pip install deepspeed
|
||||
make test_examples
|
||||
|
||||
- name: Generate Reports
|
||||
if: always()
|
||||
run: |
|
||||
pip install slack_sdk tabulate
|
||||
source .venv/bin/activate
|
||||
uv pip install slack_sdk tabulate
|
||||
python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY
|
||||
python scripts/log_example_reports.py --text_file_name temp_results_sft_tests.txt >> $GITHUB_STEP_SUMMARY
|
||||
python scripts/log_example_reports.py --text_file_name temp_results_dpo_tests.txt >> $GITHUB_STEP_SUMMARY
|
||||
rm *.txt
|
||||
rm *.txt
|
||||
27
.github/workflows/stale.yml
vendored
27
.github/workflows/stale.yml
vendored
@ -1,27 +0,0 @@
|
||||
name: Stale Bot
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: "0 15 * * *"
|
||||
|
||||
jobs:
|
||||
close_stale_issues:
|
||||
name: Close Stale Issues
|
||||
if: github.repository == 'huggingface/trl'
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: 3.8
|
||||
|
||||
- name: Install requirements
|
||||
run: |
|
||||
pip install PyGithub
|
||||
- name: Close stale issues
|
||||
run: |
|
||||
python scripts/stale.py
|
||||
70
.github/workflows/tests-experimental.yml
vendored
Normal file
70
.github/workflows/tests-experimental.yml
vendored
Normal file
@ -0,0 +1,70 @@
|
||||
name: Tests (experimental)
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
# Run only when relevant files are modified
|
||||
- "trl/experimental/**"
|
||||
- "tests/experimental/**"
|
||||
|
||||
env:
|
||||
TQDM_DISABLE: 1
|
||||
PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True"
|
||||
TRL_EXPERIMENTAL_SILENCE: 1
|
||||
|
||||
jobs:
|
||||
check_code_quality:
|
||||
name: Check code quality
|
||||
runs-on: ubuntu-latest
|
||||
if: github.event.pull_request.draft == false
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python 3.13
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: 3.13
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: --all-files
|
||||
|
||||
tests:
|
||||
name: Tests (experimental)
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
|
||||
options: --gpus all
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Git checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python 3.13
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: 3.13
|
||||
|
||||
- name: Install Make and Git
|
||||
run: |
|
||||
apt-get update && apt-get install -y make git curl
|
||||
|
||||
- name: Install uv
|
||||
run: |
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
|
||||
- name: Create Python virtual environment
|
||||
run: |
|
||||
uv venv
|
||||
uv pip install --upgrade setuptools wheel
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
uv pip install ".[dev]"
|
||||
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
make test_experimental
|
||||
46
.github/workflows/tests-main.yml
vendored
46
.github/workflows/tests-main.yml
vendored
@ -1,46 +0,0 @@
|
||||
name: tests on transformers PEFT main
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
|
||||
env:
|
||||
CI_SLACK_CHANNEL: ${{ secrets.CI_PUSH_MAIN_CHANNEL }}
|
||||
|
||||
jobs:
|
||||
tests:
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ['3.9', '3.10', '3.11']
|
||||
os: ['ubuntu-latest', 'windows-latest']
|
||||
fail-fast: false
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: "pip"
|
||||
cache-dependency-path: |
|
||||
setup.py
|
||||
requirements.txt
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
# install PEFT & transformers from source
|
||||
pip install -U git+https://github.com/huggingface/peft.git
|
||||
pip install -U git+https://github.com/huggingface/transformers.git
|
||||
# cpu version of pytorch
|
||||
pip install ".[test, diffusers]"
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
make test
|
||||
- name: Post to Slack
|
||||
if: always()
|
||||
uses: huggingface/hf-workflows/.github/actions/post-slack@main
|
||||
with:
|
||||
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
|
||||
title: 🤗 Results of the TRL CI on transformers/PEFT main
|
||||
status: ${{ job.status }}
|
||||
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
||||
282
.github/workflows/tests.yml
vendored
282
.github/workflows/tests.yml
vendored
@ -1,88 +1,256 @@
|
||||
name: tests
|
||||
name: Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
paths:
|
||||
# Run only when relevant files are modified
|
||||
- "trl/**.py"
|
||||
- ".github/**.yml"
|
||||
- "examples/**.py"
|
||||
- "scripts/**.py"
|
||||
- ".github/**.yml"
|
||||
- "tests/**.py"
|
||||
- "trl/**.py"
|
||||
- "pyproject.toml"
|
||||
# Exclude if only experimental code/tests
|
||||
- "!trl/experimental/**"
|
||||
- "!tests/experimental/**"
|
||||
|
||||
env:
|
||||
TQDM_DISABLE: 1
|
||||
CI_SLACK_CHANNEL: ${{ secrets.CI_PUSH_MAIN_CHANNEL }}
|
||||
PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True"
|
||||
|
||||
jobs:
|
||||
check_code_quality:
|
||||
name: Check code quality
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [3.9]
|
||||
|
||||
if: github.event.pull_request.draft == false
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
submodules: recursive
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
- name: Set up Python 3.12
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
python-version: 3.12
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: --all-files
|
||||
|
||||
tests:
|
||||
needs: check_code_quality
|
||||
name: Tests
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ['3.9', '3.10', '3.11']
|
||||
os: ['ubuntu-latest', 'windows-latest']
|
||||
runs-on: ${{ matrix.os }}
|
||||
python-version: ['3.10', '3.11', '3.12', '3.13']
|
||||
fail-fast: false
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
|
||||
options: --gpus all
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
if: github.event.pull_request.draft == false
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: "pip"
|
||||
cache-dependency-path: |
|
||||
setup.py
|
||||
requirements.txt
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
# install PEFT & transformers from source
|
||||
pip install -U git+https://github.com/huggingface/peft.git
|
||||
pip install -U git+https://github.com/huggingface/transformers.git
|
||||
# cpu version of pytorch
|
||||
pip install ".[test, diffusers]"
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
make test
|
||||
- name: Git checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
tests_no_optional_dep:
|
||||
needs: check_code_quality
|
||||
runs-on: 'ubuntu-latest'
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install Make and Git
|
||||
run: |
|
||||
apt-get update && apt-get install -y make git curl
|
||||
|
||||
- name: Install uv
|
||||
run: |
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
|
||||
- name: Create Python virtual environment
|
||||
run: |
|
||||
uv venv
|
||||
uv pip install --upgrade setuptools wheel
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
uv pip install ".[dev]"
|
||||
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
make test
|
||||
|
||||
- name: Post to Slack
|
||||
if: github.ref == 'refs/heads/main' && always() # Check if the branch is main
|
||||
uses: huggingface/hf-workflows/.github/actions/post-slack@main
|
||||
with:
|
||||
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
|
||||
title: Results with Python ${{ matrix.python-version }} and latest dependencies
|
||||
status: ${{ job.status }}
|
||||
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
||||
|
||||
tests_dev:
|
||||
name: Tests with dev dependencies
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
|
||||
options: --gpus all
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
if: github.event.pull_request.draft == false
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python 3.9
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.9'
|
||||
cache: "pip"
|
||||
cache-dependency-path: |
|
||||
setup.py
|
||||
requirements.txt
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
# cpu version of pytorch
|
||||
pip install .[test]
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
make test
|
||||
- name: Git checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python 3.12
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install Make and Git
|
||||
run: |
|
||||
apt-get update && apt-get install -y make git curl
|
||||
|
||||
- name: Install uv
|
||||
run: |
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
|
||||
- name: Create Python virtual environment
|
||||
run: |
|
||||
uv venv
|
||||
uv pip install --upgrade setuptools wheel
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
uv pip install ".[dev]"
|
||||
uv pip install -U git+https://github.com/huggingface/accelerate.git
|
||||
uv pip install -U git+https://github.com/huggingface/datasets.git
|
||||
uv pip install -U git+https://github.com/huggingface/transformers.git
|
||||
uv pip install -U git+https://github.com/huggingface/peft.git
|
||||
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
make test
|
||||
|
||||
- name: Post to Slack
|
||||
if: github.ref == 'refs/heads/main' && always() # Check if the branch is main
|
||||
uses: huggingface/hf-workflows/.github/actions/post-slack@main
|
||||
with:
|
||||
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
|
||||
title: Results with Python 3.12 and dev dependencies
|
||||
status: ${{ job.status }}
|
||||
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
||||
|
||||
tests_wo_optional_deps:
|
||||
name: Tests without optional dependencies
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
|
||||
options: --gpus all
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
if: github.event.pull_request.draft == false
|
||||
steps:
|
||||
- name: Git checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python 3.12
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install Make and Git
|
||||
run: |
|
||||
apt-get update && apt-get install -y make git curl
|
||||
|
||||
- name: Install uv
|
||||
run: |
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
|
||||
- name: Create Python virtual environment
|
||||
run: |
|
||||
uv venv
|
||||
uv pip install --upgrade setuptools wheel
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
uv pip install ".[test]"
|
||||
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
make test
|
||||
|
||||
- name: Post to Slack
|
||||
if: github.ref == 'refs/heads/main' && always() # Check if the branch is main
|
||||
uses: huggingface/hf-workflows/.github/actions/post-slack@main
|
||||
with:
|
||||
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
|
||||
title: Results with Python 3.12 without optional dependencies
|
||||
status: ${{ job.status }}
|
||||
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
||||
|
||||
tests_min_versions:
|
||||
name: Tests with minimum versions
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
|
||||
options: --gpus all
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
if: github.event.pull_request.draft == false
|
||||
steps:
|
||||
- name: Git checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python 3.12
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install Make and Git
|
||||
run: |
|
||||
apt-get update && apt-get install -y make git curl
|
||||
|
||||
- name: Install uv
|
||||
run: |
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
|
||||
- name: Create Python virtual environment
|
||||
run: |
|
||||
uv venv
|
||||
uv pip install --upgrade setuptools wheel
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
uv pip install ".[dev]"
|
||||
uv pip install accelerate==1.4.0
|
||||
uv pip install datasets==3.0.0
|
||||
uv pip install transformers==4.56.1
|
||||
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
make test
|
||||
|
||||
- name: Post to Slack
|
||||
if: github.ref == 'refs/heads/main' && always() # Check if the branch is main
|
||||
uses: huggingface/hf-workflows/.github/actions/post-slack@main
|
||||
with:
|
||||
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
|
||||
title: Results with Python 3.12 and minimum dependencies versions
|
||||
status: ${{ job.status }}
|
||||
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
||||
67
.github/workflows/tests_latest.yml
vendored
Normal file
67
.github/workflows/tests_latest.yml
vendored
Normal file
@ -0,0 +1,67 @@
|
||||
name: Tests latest TRL release with dev dependencies
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: '0 0 * * *' # Runs daily at midnight UTC
|
||||
|
||||
workflow_dispatch:
|
||||
|
||||
env:
|
||||
TQDM_DISABLE: 1
|
||||
CI_SLACK_CHANNEL: ${{ secrets.CI_PUSH_MAIN_CHANNEL }}
|
||||
TRL_EXPERIMENTAL_SILENCE: 1
|
||||
|
||||
jobs:
|
||||
tests:
|
||||
name: Tests latest TRL release with dev dependencies
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
|
||||
options: --gpus all
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Git checkout
|
||||
uses: actions/checkout@v4
|
||||
with: { ref: v0.24-release }
|
||||
|
||||
- name: Set up Python 3.12
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install Make and Git
|
||||
run: |
|
||||
apt-get update && apt-get install -y make git curl
|
||||
|
||||
- name: Install uv
|
||||
run: |
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
|
||||
- name: Create Python virtual environment
|
||||
run: |
|
||||
uv venv
|
||||
uv pip install --upgrade setuptools wheel
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
uv pip install ".[dev]"
|
||||
uv pip install -U git+https://github.com/huggingface/accelerate.git
|
||||
uv pip install -U git+https://github.com/huggingface/datasets.git
|
||||
uv pip install -U git+https://github.com/huggingface/transformers.git
|
||||
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
make test
|
||||
|
||||
- name: Post to Slack
|
||||
uses: huggingface/hf-workflows/.github/actions/post-slack@main
|
||||
with:
|
||||
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
|
||||
title: Results of latest TRL with Python 3.12 and dev dependencies
|
||||
status: ${{ job.status }}
|
||||
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
||||
5
.github/workflows/trufflehog.yml
vendored
5
.github/workflows/trufflehog.yml
vendored
@ -12,4 +12,7 @@ jobs:
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Secret Scanning
|
||||
uses: trufflesecurity/trufflehog@main
|
||||
uses: trufflesecurity/trufflehog@853e1e8d249fd1e29d0fcc7280d29b03df3d643d
|
||||
with:
|
||||
# exclude buggy postgres detector that is causing false positives and not relevant to our codebase
|
||||
extra_args: --results=verified,unknown --exclude-detectors=postgres
|
||||
|
||||
6
.gitignore
vendored
6
.gitignore
vendored
@ -1,4 +1,3 @@
|
||||
benchmark/trl
|
||||
*.bak
|
||||
.gitattributes
|
||||
.last_checked
|
||||
@ -143,7 +142,4 @@ checklink/cookies.txt
|
||||
# wandb files
|
||||
nbs/wandb/
|
||||
examples/notebooks/wandb/
|
||||
wandb/
|
||||
|
||||
# cli scripts that are symlinked from `examples/scripts`
|
||||
trl/commands/scripts/
|
||||
wandb/
|
||||
@ -1,8 +1,8 @@
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.6.3
|
||||
rev: v0.13.3
|
||||
hooks:
|
||||
- id: ruff
|
||||
- id: ruff-check
|
||||
types_or: [ python, pyi ]
|
||||
args: [ --fix ]
|
||||
- id: ruff-format
|
||||
|
||||
@ -17,6 +17,12 @@ authors:
|
||||
family-names: Thrush
|
||||
- given-names: Nathan
|
||||
family-names: Lambert
|
||||
- given-names: Shengyi
|
||||
family-names: Huang
|
||||
- given-names: Kashif
|
||||
family-names: Rasul
|
||||
- given-names: Quentin
|
||||
family-names: Gallouédec
|
||||
repository-code: 'https://github.com/huggingface/trl'
|
||||
abstract: "With trl you can train transformer language models with Proximal Policy Optimization (PPO). The library is built on top of the transformers library by \U0001F917 Hugging Face. Therefore, pre-trained language models can be directly loaded via transformers. At this point, most decoder and encoder-decoder architectures are supported."
|
||||
keywords:
|
||||
@ -25,4 +31,4 @@ keywords:
|
||||
- pytorch
|
||||
- transformers
|
||||
license: Apache-2.0
|
||||
version: 0.2.1
|
||||
version: "0.24"
|
||||
|
||||
370
CONTRIBUTING.md
370
CONTRIBUTING.md
@ -1,15 +1,10 @@
|
||||
# How to contribute to TRL?
|
||||
|
||||
Everyone is welcome to contribute, and we value everybody's contribution. Code
|
||||
contributions are not the only way to help the community. Answering questions, helping
|
||||
others, and improving the documentation are also immensely valuable.
|
||||
Everyone is welcome to contribute, and we value everybody's contribution. Code contributions are not the only way to help the community. Answering questions, helping others, and improving the documentation are also immensely valuable.
|
||||
|
||||
It also helps us if you spread the word! Reference the library in blog posts
|
||||
about the awesome projects it made possible, shout out on Twitter every time it has
|
||||
helped you, or simply ⭐️ the repository to say thank you.
|
||||
It also helps us if you spread the word! Reference the library in blog posts about the awesome projects it made possible, shout out on Twitter every time it has helped you, or simply ⭐️ the repository to say thank you.
|
||||
|
||||
However you choose to contribute, please be mindful and respect our
|
||||
[code of conduct](https://github.com/huggingface/trl/blob/main/CODE_OF_CONDUCT.md).
|
||||
However you choose to contribute, please be mindful and respect our [code of conduct](https://github.com/huggingface/trl/blob/main/CODE_OF_CONDUCT.md).
|
||||
|
||||
**This guide was heavily inspired by the awesome [scikit-learn guide to contributing](https://github.com/scikit-learn/scikit-learn/blob/main/CONTRIBUTING.md).**
|
||||
|
||||
@ -20,11 +15,9 @@ There are several ways you can contribute to TRL:
|
||||
* Fix outstanding issues with the existing code.
|
||||
* Submit issues related to bugs or desired new features.
|
||||
* Implement trainers for new post-training algorithms.
|
||||
* Contribute to the examples or to the documentation.
|
||||
* Contribute to the examples or the documentation.
|
||||
|
||||
If you don't know where to start, there is a special [Good First
|
||||
Issue](https://github.com/huggingface/trl/contribute) listing. It will give you a list of
|
||||
open issues that are beginner-friendly and help you start contributing to open-source. The best way to do that is to open a Pull Request and link it to the issue that you'd like to work on. We try to give priority to opened PRs as we can easily track the progress of the fix, and if the contributor does not have time anymore, someone else can take the PR over.
|
||||
If you don't know where to start, there is a special [Good First Issue](https://github.com/huggingface/trl/labels/%F0%9F%91%B6%20good%20first%20issue) listing. It will give you a list of open issues that are beginner-friendly and help you start contributing to open-source. The best way to do that is to open a Pull Request and link it to the issue that you'd like to work on. We try to give priority to opened PRs as we can easily track the progress of the fix, and if the contributor does not have time anymore, someone else can take the PR over.
|
||||
|
||||
For something slightly more challenging, you can also take a look at the [Good Second Issue](https://github.com/huggingface/trl/labels/Good%20Second%20Issue) list. In general though, if you feel like you know what you're doing, go for it and we'll help you get there! 🚀
|
||||
|
||||
@ -33,12 +26,12 @@ For something slightly more challenging, you can also take a look at the [Good S
|
||||
Before you start contributing make sure you have installed all the dev tools:
|
||||
|
||||
```bash
|
||||
make dev
|
||||
pip install -e .[dev]
|
||||
```
|
||||
|
||||
## Fixing outstanding issues
|
||||
|
||||
If you notice an issue with the existing code and have a fix in mind, feel free to [start contributing](#create-a-pull-request) and open a Pull Request!
|
||||
If you notice an issue with the existing code and have a fix in mind, feel free to [start contributing](#submitting-a-pull-request-pr) and open a Pull Request!
|
||||
|
||||
## Submitting a bug-related issue or feature request
|
||||
|
||||
@ -48,21 +41,19 @@ Do your best to follow these guidelines when submitting a bug-related issue or a
|
||||
|
||||
The TRL library is robust and reliable thanks to users who report the problems they encounter.
|
||||
|
||||
Before you report an issue, we would really appreciate it if you could **make sure the bug was not
|
||||
already reported** (use the search bar on GitHub under Issues). Your issue should also be related to bugs in the library itself, and not your code.
|
||||
Before you report an issue, we would really appreciate it if you could **make sure the bug was not already reported** (use the search bar on GitHub under Issues). Your issue should also be related to bugs in the library itself, and not your code.
|
||||
|
||||
Once you've confirmed the bug hasn't already been reported, please include the following information in your issue so we can quickly resolve it:
|
||||
|
||||
* Your **OS type and version**, **Python**, **PyTorch**, **TRL** and **Transformers** versions.
|
||||
* A short, self-contained, code snippet that allows us to reproduce the bug in
|
||||
less than 30s.
|
||||
* A short, self-contained, code snippet that allows us to reproduce the bug in less than 30s.
|
||||
* The *full* traceback if an exception is raised.
|
||||
* Attach any other additional information, like screenshots, you think may help.
|
||||
|
||||
To get the OS and software versions automatically, run the following command:
|
||||
|
||||
```bash
|
||||
transformers-cli env
|
||||
trl env
|
||||
```
|
||||
|
||||
### Do you want a new feature?
|
||||
@ -74,19 +65,19 @@ If there is a new feature you'd like to see in TRL, please open an issue and des
|
||||
Whatever it is, we'd love to hear about it!
|
||||
|
||||
2. Describe your requested feature in as much detail as possible. The more you can tell us about it, the better we'll be able to help you.
|
||||
3. Provide a *code snippet* that demonstrates the features usage.
|
||||
3. Provide a *code snippet* that demonstrates the feature's usage.
|
||||
4. If the feature is related to a paper, please include a link.
|
||||
|
||||
If your issue is well written we're already 80% of the way there by the time you create it.
|
||||
|
||||
## Do you want to implement a new trainer?
|
||||
|
||||
New post-training methods are published on a frequent basis and those which satisfy the following criteria are good candidates to be integrated in TRL:
|
||||
New post-training methods are published frequently and those that satisfy the following criteria are good candidates to be integrated into TRL:
|
||||
|
||||
* **Simplicity:** does the new method achieve similar performance as prior methods, but with less complexity? A good example is [Direct Preference Optimization](https://arxiv.org/abs/2305.18290) (DPO), which provided a simpler and compelling alternative to RLHF methods.
|
||||
* **Efficiency:** does the new method provide a significant improvement in training efficiency? A good example is [Odds Ratio Preference Optimization](https://arxiv.org/abs/2403.07691v2), which utilises a similar objective as DPO, but requires half the GPU VRAM.
|
||||
* **Simplicity:** Does the new method achieve similar performance as prior methods, but with less complexity? A good example is Direct Preference Optimization (DPO) [[Rafailov et al, 2023]](https://huggingface.co/papers/2305.18290), which provided a simpler and compelling alternative to RLHF methods.
|
||||
* **Efficiency:** Does the new method provide a significant improvement in training efficiency? A good example is Odds Ratio Preference Optimization (ORPO) [[Hong et al, 2023]](https://huggingface.co/papers/2403.07691), which utilizes a similar objective as DPO but requires half the GPU VRAM.
|
||||
|
||||
Methods which only provide incremental improvements at the expense of added complexity or compute costs are unlikely to be included in TRL.
|
||||
Methods that only provide incremental improvements at the expense of added complexity or compute costs are unlikely to be included in TRL.
|
||||
|
||||
If you want to implement a trainer for a new post-training method, first open an issue and provide the following information:
|
||||
|
||||
@ -102,49 +93,40 @@ Based on the community and maintainer feedback, the next step will be to impleme
|
||||
|
||||
## Do you want to add documentation?
|
||||
|
||||
We're always looking for improvements to the documentation that make it more clear and accurate. Please let us know how the documentation can be improved, such as typos, dead links and any missing, unclear or inaccurate content.. We'll be happy to make the changes or help you make a contribution if you're interested!
|
||||
We're always looking for improvements to the documentation that make it more clear and accurate. Please let us know how the documentation can be improved, such as typos, dead links, and any missing, unclear, or inaccurate content... We'll be happy to make the changes or help you contribute if you're interested!
|
||||
|
||||
## Submitting a pull request (PR)
|
||||
|
||||
Before writing code, we strongly advise you to search through the existing PRs or
|
||||
issues to make sure that nobody is already working on the same thing. If you are
|
||||
unsure, it is always a good idea to open an issue to get some feedback.
|
||||
Before writing code, we strongly advise you to search through the existing PRs or issues to make sure that nobody is already working on the same thing. If you are unsure, it is always a good idea to open an issue to get some feedback.
|
||||
|
||||
You will need basic `git` proficiency to be able to contribute to
|
||||
TRL. `git` is not the easiest tool to use but it has the greatest
|
||||
manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro
|
||||
Git](https://git-scm.com/book/en/v2) is a very good reference.
|
||||
You will need basic `git` proficiency to be able to contribute to TRL. `git` is not the easiest tool to use but it has the greatest manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro Git](https://git-scm.com/book/en/v2) is a very good reference.
|
||||
|
||||
Follow these steps to start contributing:
|
||||
|
||||
1. Fork the [repository](https://github.com/huggingface/trl) by
|
||||
clicking on the 'Fork' button on the repository's page. This creates a copy of the code
|
||||
under your GitHub user account.
|
||||
1. Fork the [repository](https://github.com/huggingface/trl) by clicking on the 'Fork' button on the repository's page. This creates a copy of the code under your GitHub user account.
|
||||
|
||||
2. Clone your fork to your local disk, and add the base repository as a remote. The following command
|
||||
assumes you have your public SSH key uploaded to GitHub. See the following guide for more
|
||||
[information](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository).
|
||||
2. Clone your fork to your local disk, and add the base repository as a remote. The following command assumes you have your public SSH key uploaded to GitHub. See the following guide for more [information](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository).
|
||||
|
||||
```bash
|
||||
$ git clone git@github.com:<your Github handle>/trl.git
|
||||
$ cd trl
|
||||
$ git remote add upstream https://github.com/huggingface/trl.git
|
||||
git clone git@github.com:<your Github handle>/trl.git
|
||||
cd trl
|
||||
git remote add upstream https://github.com/huggingface/trl.git
|
||||
```
|
||||
|
||||
3. Create a new branch to hold your development changes, and do this for every new PR you work on.
|
||||
|
||||
Start by synchronizing your `main` branch with the `upstream/main` branch (ore details in the [GitHub Docs](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/syncing-a-fork)):
|
||||
Start by synchronizing your `main` branch with the `upstream/main` branch (more details in the [GitHub Docs](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/syncing-a-fork)):
|
||||
|
||||
```bash
|
||||
$ git checkout main
|
||||
$ git fetch upstream
|
||||
$ git merge upstream/main
|
||||
git checkout main
|
||||
git fetch upstream
|
||||
git merge upstream/main
|
||||
```
|
||||
|
||||
Once your `main` branch is synchronized, create a new branch from it:
|
||||
|
||||
```bash
|
||||
$ git checkout -b a-descriptive-name-for-my-changes
|
||||
git checkout -b a-descriptive-name-for-my-changes
|
||||
```
|
||||
|
||||
**Do not** work on the `main` branch.
|
||||
@ -152,107 +134,277 @@ Follow these steps to start contributing:
|
||||
4. Set up a development environment by running the following command in a conda or a virtual environment you've created for working on this library:
|
||||
|
||||
```bash
|
||||
$ make dev
|
||||
pip install -e .[dev]
|
||||
```
|
||||
|
||||
(If TRL was already installed in the virtual environment, remove
|
||||
it with `pip uninstall trl` before reinstalling it.)
|
||||
(If TRL was already installed in the virtual environment, remove it with `pip uninstall trl` before reinstalling it.)
|
||||
|
||||
Alternatively, if you are using [Visual Studio Code](https://code.visualstudio.com/Download), the fastest way to get set up is by using
|
||||
the provided Dev Container. Documentation on how to get started with dev containers is available [here](https://code.visualstudio.com/docs/remote/containers).
|
||||
Alternatively, if you are using [Visual Studio Code](https://code.visualstudio.com/Download), the fastest way to get set up is by using the provided Dev Container. Check [the documentation on how to get started with dev containers](https://code.visualstudio.com/docs/remote/containers).
|
||||
|
||||
5. Develop the features on your branch.
|
||||
|
||||
As you work on the features, you should make sure that the test suite
|
||||
passes. You should run the tests impacted by your changes like this (see
|
||||
below an explanation regarding the environment variable):
|
||||
As you work on the features, you should make sure that the test suite passes. You should run the tests impacted by your changes like this (see below an explanation regarding the environment variable):
|
||||
|
||||
```bash
|
||||
$ pytest tests/<TEST_TO_RUN>.py
|
||||
```
|
||||
|
||||
> For the following commands leveraging the `make` utility, we recommend using the WSL system when running on
|
||||
> Windows. More information [here](https://docs.microsoft.com/en-us/windows/wsl/about).
|
||||
```bash
|
||||
pytest tests/<TEST_TO_RUN>.py
|
||||
```
|
||||
|
||||
You can also run the full suite with the following command.
|
||||
> For the following commands leveraging the `make` utility.
|
||||
|
||||
```bash
|
||||
$ make test
|
||||
```
|
||||
You can also run the full suite with the following command.
|
||||
|
||||
TRL relies on `ruff` to format its source code
|
||||
consistently. After you make changes, apply automatic style corrections and code verifications
|
||||
that can't be automated in one go with:
|
||||
```bash
|
||||
make test
|
||||
```
|
||||
|
||||
This target is also optimized to only work with files modified by the PR you're working on.
|
||||
TRL relies on `ruff` for maintaining consistent code formatting across its source files. Before submitting any PR, you should apply automatic style corrections and run code verification checks.
|
||||
|
||||
If you prefer to run the checks one after the other, the following command apply the
|
||||
style corrections:
|
||||
We provide a `precommit` target in the `Makefile` that simplifies this process by running all required checks and optimizations on only the files modified by your PR.
|
||||
|
||||
```bash
|
||||
$ make precommit
|
||||
```
|
||||
To apply these checks and corrections in one step, use:
|
||||
|
||||
Once you're happy with your changes, add changed files using `git add` and
|
||||
make a commit with `git commit` to record your changes locally:
|
||||
```bash
|
||||
make precommit
|
||||
```
|
||||
|
||||
```bash
|
||||
$ git add modified_file.py
|
||||
$ git commit
|
||||
```
|
||||
This command runs the following:
|
||||
|
||||
Please write [good commit messages](https://chris.beams.io/posts/git-commit/).
|
||||
* Executes `pre-commit` hooks to automatically fix style issues with `ruff` and other tools.
|
||||
* Runs additional scripts such as adding copyright information.
|
||||
|
||||
It is a good idea to sync your copy of the code with the original
|
||||
repository regularly. This way you can quickly account for changes:
|
||||
If you prefer to apply the style corrections separately or review them individually, the `pre-commit` hook will handle the formatting for the files in question.
|
||||
|
||||
```bash
|
||||
$ git fetch upstream
|
||||
$ git rebase upstream/main
|
||||
```
|
||||
Once you're happy with your changes, add changed files using `git add` and make a commit with `git commit` to record your changes locally:
|
||||
|
||||
Push the changes to your account using:
|
||||
```bash
|
||||
git add modified_file.py
|
||||
git commit
|
||||
```
|
||||
|
||||
```bash
|
||||
$ git push -u origin a-descriptive-name-for-my-changes
|
||||
```
|
||||
Please write [good commit messages](https://chris.beams.io/posts/git-commit/).
|
||||
|
||||
6. Once you are satisfied (**and the checklist below is happy too**), go to the
|
||||
webpage of your fork on GitHub. Click on 'Pull request' to send your changes
|
||||
to the project maintainers for review.
|
||||
It is a good idea to sync your copy of the code with the original
|
||||
repository regularly. This way you can quickly account for changes:
|
||||
|
||||
7. It's ok if maintainers ask you for changes. It happens to core contributors
|
||||
too! So everyone can see the changes in the Pull request, work in your local
|
||||
branch and push the changes to your fork. They will automatically appear in
|
||||
the pull request.
|
||||
```bash
|
||||
git fetch upstream
|
||||
git rebase upstream/main
|
||||
```
|
||||
|
||||
Push the changes to your account using:
|
||||
|
||||
```bash
|
||||
git push -u origin a-descriptive-name-for-my-changes
|
||||
```
|
||||
|
||||
6. Once you are satisfied (**and the checklist below is happy too**), go to the webpage of your fork on GitHub. Click on 'Pull request' to send your changes to the project maintainers for review.
|
||||
|
||||
7. It's ok if maintainers ask you for changes. It happens to core contributors too! To ensure everyone can review your changes in the pull request, work on your local branch and push the updates to your fork. They will automatically appear in the pull request.
|
||||
|
||||
### Checklist
|
||||
|
||||
1. The title of your pull request should be a summary of its contribution;
|
||||
2. If your pull request addresses an issue, please mention the issue number in
|
||||
the pull request description to make sure they are linked (and people
|
||||
consulting the issue know you are working on it);
|
||||
3. To indicate a work in progress please prefix the title with `[WIP]`, or mark
|
||||
the PR as a draft PR. These are useful to avoid duplicated work, and to differentiate
|
||||
it from PRs ready to be merged;
|
||||
2. If your pull request addresses an issue, please mention the issue number in the pull request description to make sure they are linked (and people consulting the issue know you are working on it);
|
||||
3. To indicate a work in progress please prefix the title with `[WIP]`, or mark the PR as a draft PR. These are useful to avoid duplicated work, and to differentiate it from PRs ready to be merged;
|
||||
4. Make sure existing tests pass;
|
||||
5. Add high-coverage tests. No quality testing = no merge.
|
||||
|
||||
|
||||
### Tests
|
||||
|
||||
An extensive test suite is included to test the library behavior and several examples. Library tests can be found in
|
||||
the [tests folder](https://github.com/huggingface/trl/tree/main/tests).
|
||||
|
||||
We use `pytest` in order to run the tests. From the root of the
|
||||
repository, here's how to run tests with `pytest` for the library:
|
||||
We use `pytest` to run the tests. From the root of the
|
||||
repository here's how to run tests with `pytest` for the library:
|
||||
|
||||
```bash
|
||||
$ python -m pytest -sv ./tests
|
||||
python -m pytest -sv ./tests
|
||||
```
|
||||
|
||||
In fact, that's how `make test` is implemented (sans the `pip install` line)!
|
||||
That's how `make test` is implemented (without the `pip install` line)!
|
||||
|
||||
You can specify a smaller set of tests in order to test only the feature
|
||||
You can specify a smaller set of tests to test only the feature
|
||||
you're working on.
|
||||
|
||||
### Default values guidelines
|
||||
|
||||
1. **Use defaults when appropriate**:
|
||||
|
||||
Provide default values unless the parameter's value varies significantly by use case. For example, datasets or models should not have defaults, but parameters like `learning_rate` should.
|
||||
|
||||
2. **Prioritize proven defaults**:
|
||||
|
||||
Default values should align with those recommended in the original paper or method. Alternatives require strong evidence of superior performance in most cases.
|
||||
|
||||
3. **Ensure safety and predictability**:
|
||||
|
||||
Defaults must be safe, expected and reliable. Avoid settings that could lead to surprising outcomes, such as excessive memory usage or poor performance in edge cases.
|
||||
|
||||
4. **Balance consistency and flexibility**:
|
||||
|
||||
Aim for consistent defaults across similar functions or methods. However, consistency should not be preferred to point 2 or 3.
|
||||
|
||||
5. **Opt-in for new features**:
|
||||
|
||||
Do not enable new features or improvements (e.g., novel loss functions) by default. Users should explicitly opt-in to use these.
|
||||
|
||||
### Writing documentation
|
||||
|
||||
High-quality documentation is crucial for maintaining a project that is easy to use, understand, and extend. When adding new features, ensure they are thoroughly documented to maintain consistency and clarity throughout the project.
|
||||
|
||||
To illustrate what good documentation looks like, here’s an example of a well-documented function:
|
||||
|
||||
````python
|
||||
def replicate_str(string: str, n: int, sep: str = " ") -> str:
|
||||
r"""
|
||||
Replicate a string `n` times with a separator.
|
||||
|
||||
Args:
|
||||
string (`str`):
|
||||
String to replicate.
|
||||
n (`int`):
|
||||
Number of times to replicate the string.
|
||||
sep (`str`, *optional*, defaults to `" "`):
|
||||
Separator to use between each replication.
|
||||
|
||||
Returns:
|
||||
`str`: The replicated string.
|
||||
|
||||
Examples:
|
||||
```python
|
||||
>>> replicate_str("hello", 3)
|
||||
"hello hello hello"
|
||||
>>> replicate_str("hello", 3, sep=", ")
|
||||
"hello, hello, hello"
|
||||
```
|
||||
"""
|
||||
return sep.join([string] * n)
|
||||
````
|
||||
|
||||
* **Line Wrapping:** Applied a consistent line wrap at column 120 to improve readability.
|
||||
* **Definite Articles:** Removed definite articles where possible to streamline language. (Eg: Changed "The string to replicate" to "String to replicate")
|
||||
* **Type Annotations:**
|
||||
* Always include type definitions, indicating if a parameter is optional and specifying the default value.
|
||||
|
||||
* **String Defaults:**
|
||||
* Ensured that default string values are wrapped in double quotes:
|
||||
|
||||
```txt
|
||||
defaults to `"foo"`
|
||||
```
|
||||
|
||||
* **Dictionary Typing:**
|
||||
* Replaced generic `dict` type hints with more explicit `dict[str, Any]` to clarify expected key-value pairs.
|
||||
* **Default Value Formatting:**
|
||||
* Consistently surrounded default values with backticks for improved formatting:
|
||||
|
||||
```txt
|
||||
defaults to `4`
|
||||
```
|
||||
|
||||
* **Sub-sectioning:** When the number of arguments is large, consider breaking them into sub-sections for better readability.
|
||||
|
||||
```python
|
||||
def calculate_statistics(data: list[float], precision: int = 2, include_variance: bool = False) -> dict[str, float]:
|
||||
r"""
|
||||
Calculates basic statistics for a given dataset.
|
||||
|
||||
Args:
|
||||
> Data inputs
|
||||
|
||||
data (`list[float]`):
|
||||
A list of numerical values to analyze.
|
||||
|
||||
> Configuration parameters
|
||||
|
||||
precision (`int`, *optional*, defaults to `2`):
|
||||
Number of decimal places to round the results.
|
||||
include_variance (`bool`, *optional*, defaults to `False`):
|
||||
Whether to include the variance of the dataset in the results.
|
||||
|
||||
Returns:
|
||||
`dict[str, float]`:
|
||||
A dictionary containing calculated statistics such as mean, median, and optionally variance.
|
||||
"""
|
||||
...
|
||||
```
|
||||
|
||||
### Deprecation and backward compatibility
|
||||
|
||||
Our approach to deprecation and backward compatibility is flexible and based on the feature’s usage and impact. Each deprecation is carefully evaluated, aiming to balance innovation with user needs.
|
||||
|
||||
When a feature or component is marked for deprecation, its use will emit a warning message. This warning will include:
|
||||
|
||||
* **Transition Guidance**: Instructions on how to migrate to the alternative solution or replacement.
|
||||
* **Removal Version**: The target version when the feature will be removed, providing users with a clear timeframe to transition.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
warnings.warn(
|
||||
"The `Trainer.foo` method is deprecated and will be removed in version 0.14.0. "
|
||||
"Please use the `Trainer.bar` class instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
```
|
||||
|
||||
The deprecation and removal schedule is based on each feature's usage and impact, with examples at two extremes:
|
||||
|
||||
* **Experimental or Low-Use Features**: For a feature that is experimental or has limited usage, backward compatibility may not be maintained between releases. Users should therefore anticipate potential breaking changes from one version to the next.
|
||||
|
||||
* **Widely-Used Components**: For a feature with high usage, we aim for a more gradual transition period of approximately **5 months**, generally scheduling deprecation around **5 minor releases** after the initial warning.
|
||||
|
||||
These examples represent the two ends of a continuum. The specific timeline for each feature will be determined individually, balancing innovation with user stability needs.
|
||||
|
||||
### Working with warnings
|
||||
|
||||
Warnings play a critical role in guiding users toward resolving potential issues, but they should be used thoughtfully to avoid unnecessary noise. Unlike logging, which provides informational context or operational details, warnings signal conditions that require attention and action. Overusing warnings can dilute their importance, leading users to ignore them entirely.
|
||||
|
||||
#### Definitions
|
||||
|
||||
* **Correct**: An operation is correct if it is valid, follows the intended approach, and aligns with the current best practices or guidelines within the codebase. This is the recommended or intended way to perform the operation.
|
||||
* **Supported**: An operation is supported if it is technically valid and works within the current codebase, but it may not be the most efficient, optimal, or recommended way to perform the task. This includes deprecated features or legacy approaches that still work but may be phased out in the future.
|
||||
|
||||
#### Choosing the right message
|
||||
|
||||
* **Correct → No warning**:
|
||||
If the operation is fully valid and expected, no message should be issued. The system is working as intended, so no warning is necessary.
|
||||
|
||||
* **Correct but deserves attention → No warning, possibly a log message**:
|
||||
When an operation is correct but uncommon or requires special attention, providing an informational message can be helpful. This keeps users informed without implying any issue. If available, use the logger to output this message. Example:
|
||||
|
||||
```python
|
||||
logger.info("This is an informational message about a rare but correct operation.")
|
||||
```
|
||||
|
||||
* **Correct but very likely a mistake → Warning with option to disable**:
|
||||
In rare cases, you may want to issue a warning for a correct operation that’s very likely a mistake. In such cases, you must provide an option to suppress the warning. This can be done with a flag in the function. Example:
|
||||
|
||||
```python
|
||||
def my_function(foo, bar, _warn=True):
|
||||
if foo == bar:
|
||||
if _warn:
|
||||
logger.warning("foo and bar are the same, this is likely a mistake. Ignore this warning by setting `_warn=False`.")
|
||||
# Do something
|
||||
```
|
||||
|
||||
* **Supported but not correct → Warning**:
|
||||
If the operation is technically supported but is deprecated, suboptimal, or could cause future issues (e.g., conflicting arguments), a warning should be raised. This message should be actionable, meaning it must explain how to resolve the issue. Example:
|
||||
|
||||
```python
|
||||
def my_function(foo, bar):
|
||||
if foo and bar:
|
||||
logger.warning("Both `foo` and `bar` were provided, but only one is allowed. Ignoring `foo`. Please pass only one of these arguments.")
|
||||
# Do something
|
||||
```
|
||||
|
||||
* **Not supported → Exception**:
|
||||
If the operation is invalid or unsupported, raise an exception. This indicates that the operation cannot be performed and requires immediate attention. Example:
|
||||
|
||||
```python
|
||||
def my_function(foo, bar):
|
||||
if foo and bar:
|
||||
raise ValueError("Both `foo` and `bar` were provided, but only one is allowed. Please pass only one of these arguments.")
|
||||
```
|
||||
|
||||
By following this classification, you ensure that warnings, information, and exceptions are used appropriately, providing clear guidance to the user without cluttering the system with unnecessary messages.
|
||||
|
||||
2
LICENSE
2
LICENSE
@ -186,7 +186,7 @@
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
Copyright 2020-2025 The HuggingFace Team
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
include settings.ini
|
||||
include LICENSE
|
||||
include CONTRIBUTING.md
|
||||
include README.md
|
||||
recursive-exclude * __pycache__
|
||||
include trl/accelerate_configs/*.yaml
|
||||
include trl/templates/*.md
|
||||
recursive-exclude * __pycache__
|
||||
prune tests
|
||||
|
||||
39
Makefile
39
Makefile
@ -1,44 +1,19 @@
|
||||
.PHONY: test precommit benchmark_core benchmark_aux common_tests slow_tests test_examples tests_gpu
|
||||
.PHONY: test precommit common_tests slow_tests tests_gpu test_experimental
|
||||
|
||||
check_dirs := examples tests trl
|
||||
|
||||
ACCELERATE_CONFIG_PATH = `pwd`/examples/accelerate_configs
|
||||
COMMAND_FILES_PATH = `pwd`/commands
|
||||
|
||||
|
||||
dev:
|
||||
[ -L "$(pwd)/trl/commands/scripts" ] && unlink "$(pwd)/trl/commands/scripts" || true
|
||||
pip install -e ".[dev]"
|
||||
ln -s `pwd`/examples/scripts/ `pwd`/trl/commands
|
||||
|
||||
test:
|
||||
python -m pytest -n auto --dist=loadfile -s -v --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504||not less than or equal to 0.01)' ./tests/
|
||||
pytest -n auto -m "not slow and not low_priority" -s -v --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504||not less than or equal to 0.01)' tests/
|
||||
|
||||
precommit:
|
||||
pre-commit run --all-files
|
||||
python scripts/add_copyrights.py
|
||||
|
||||
benchmark_core:
|
||||
bash ./benchmark/benchmark_core.sh
|
||||
|
||||
benchmark_aux:
|
||||
bash ./benchmark/benchmark_aux.sh
|
||||
|
||||
tests_gpu:
|
||||
python -m pytest tests/test_* $(if $(IS_GITHUB_CI),--report-log "common_tests.log",)
|
||||
pre-commit run --all-files
|
||||
doc-builder style trl tests docs/source --max_len 119
|
||||
|
||||
slow_tests:
|
||||
python -m pytest tests/slow/test_* $(if $(IS_GITHUB_CI),--report-log "slow_tests.log",)
|
||||
pytest -m "slow" tests/ $(if $(IS_GITHUB_CI),--report-log "slow_tests.log",)
|
||||
|
||||
test_examples:
|
||||
touch temp_results_sft_tests.txt
|
||||
for file in $(ACCELERATE_CONFIG_PATH)/*.yaml; do \
|
||||
TRL_ACCELERATE_CONFIG=$${file} bash $(COMMAND_FILES_PATH)/run_sft.sh; \
|
||||
echo $$?','$${file} >> temp_results_sft_tests.txt; \
|
||||
done
|
||||
|
||||
touch temp_results_dpo_tests.txt
|
||||
for file in $(ACCELERATE_CONFIG_PATH)/*.yaml; do \
|
||||
TRL_ACCELERATE_CONFIG=$${file} bash $(COMMAND_FILES_PATH)/run_dpo.sh; \
|
||||
echo $$?','$${file} >> temp_results_dpo_tests.txt; \
|
||||
done
|
||||
test_experimental:
|
||||
pytest -k "experimental" -n auto -s -v
|
||||
256
README.md
256
README.md
@ -1,228 +1,200 @@
|
||||
# TRL - Transformer Reinforcement Learning
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl_banner_dark.png">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png" alt="TRL Banner">
|
||||
</div>
|
||||
|
||||
# TRL - Transformer Reinforcement Learning
|
||||
> Full stack library to fine-tune and align large language models.
|
||||
<hr> <br>
|
||||
|
||||
<h3 align="center">
|
||||
<p>A comprehensive library to post-train foundation models</p>
|
||||
</h3>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://github.com/huggingface/trl/blob/main/LICENSE">
|
||||
<img alt="License" src="https://img.shields.io/github/license/huggingface/trl.svg?color=blue">
|
||||
</a>
|
||||
<a href="https://huggingface.co/docs/trl/index">
|
||||
<img alt="Documentation" src="https://img.shields.io/website/http/huggingface.co/docs/trl/index.svg?down_color=red&down_message=offline&up_message=online">
|
||||
</a>
|
||||
<a href="https://github.com/huggingface/trl/releases">
|
||||
<img alt="GitHub release" src="https://img.shields.io/github/release/huggingface/trl.svg">
|
||||
</a>
|
||||
<a href="https://github.com/huggingface/trl/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/github/license/huggingface/trl.svg?color=blue"></a>
|
||||
<a href="https://huggingface.co/docs/trl/index"><img alt="Documentation" src="https://img.shields.io/website?label=documentation&url=https%3A%2F%2Fhuggingface.co%2Fdocs%2Ftrl%2Findex&down_color=red&down_message=offline&up_color=blue&up_message=online"></a>
|
||||
<a href="https://github.com/huggingface/trl/releases"><img alt="GitHub release" src="https://img.shields.io/github/release/huggingface/trl.svg"></a>
|
||||
<a href="https://huggingface.co/trl-lib"><img alt="Hugging Face Hub" src="https://img.shields.io/badge/🤗%20Hub-trl--lib-yellow"></a>
|
||||
</p>
|
||||
|
||||
## 🎉 What's New
|
||||
|
||||
## What is it?
|
||||
**OpenEnv Integration:** TRL now supports **[OpenEnv](https://huggingface.co/blog/openenv)**, the open-source framework from Meta for defining, deploying, and interacting with environments in reinforcement learning and agentic workflows.
|
||||
|
||||
The `trl` library is a full stack tool to fine-tune and align transformer language and diffusion models using methods such as Supervised Fine-tuning step (SFT), Reward Modeling (RM) and the Proximal Policy Optimization (PPO) as well as Direct Preference Optimization (DPO).
|
||||
Explore how to seamlessly integrate TRL with OpenEnv in our [dedicated documentation](openenv).
|
||||
|
||||
The library is built on top of the [`transformers`](https://github.com/huggingface/transformers) library and thus allows to use any model architecture available there.
|
||||
## Overview
|
||||
|
||||
TRL is a cutting-edge library designed for post-training foundation models using advanced techniques like Supervised Fine-Tuning (SFT), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO). Built on top of the [🤗 Transformers](https://github.com/huggingface/transformers) ecosystem, TRL supports a variety of model architectures and modalities, and can be scaled-up across various hardware setups.
|
||||
|
||||
## Highlights
|
||||
|
||||
- **`Efficient and scalable`**:
|
||||
- [`accelerate`](https://github.com/huggingface/accelerate) is the backbone of `trl` which allows to scale model training from a single GPU to a large scale multi-node cluster with methods such as DDP and DeepSpeed.
|
||||
- [`PEFT`](https://github.com/huggingface/peft) is fully integrated and allows to train even the largest models on modest hardware with quantisation and methods such as LoRA or QLoRA.
|
||||
- [`unsloth`](https://github.com/unslothai/unsloth) is also integrated and allows to significantly speed up training with dedicated kernels.
|
||||
- **`CLI`**: With the [CLI](https://huggingface.co/docs/trl/clis) you can fine-tune and chat with LLMs without writing any code using a single command and a flexible config system.
|
||||
- **`Trainers`**: The Trainer classes are an abstraction to apply many fine-tuning methods with ease such as the [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer), [`DPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.DPOTrainer), [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer), [`PPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.PPOTrainer), [`CPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.CPOTrainer), and [`ORPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.ORPOTrainer).
|
||||
- **`AutoModels`**: The [`AutoModelForCausalLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForCausalLMWithValueHead) & [`AutoModelForSeq2SeqLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForSeq2SeqLMWithValueHead) classes add an additional value head to the model which allows to train them with RL algorithms such as PPO.
|
||||
- **`Examples`**: Train GPT2 to generate positive movie reviews with a BERT sentiment classifier, full RLHF using adapters only, train GPT-j to be less toxic, [StackLlama example](https://huggingface.co/blog/stackllama), etc. following the [examples](https://github.com/huggingface/trl/tree/main/examples).
|
||||
- **Trainers**: Various fine-tuning methods are easily accessible via trainers like [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer), [`GRPOTrainer`](https://huggingface.co/docs/trl/grpo_trainer), [`DPOTrainer`](https://huggingface.co/docs/trl/dpo_trainer), [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer) and more.
|
||||
|
||||
- **Efficient and scalable**:
|
||||
- Leverages [🤗 Accelerate](https://github.com/huggingface/accelerate) to scale from single GPU to multi-node clusters using methods like [DDP](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) and [DeepSpeed](https://github.com/deepspeedai/DeepSpeed).
|
||||
- Full integration with [🤗 PEFT](https://github.com/huggingface/peft) enables training on large models with modest hardware via quantization and LoRA/QLoRA.
|
||||
- Integrates [🦥 Unsloth](https://github.com/unslothai/unsloth) for accelerating training using optimized kernels.
|
||||
|
||||
- **Command Line Interface (CLI)**: A simple interface lets you fine-tune with models without needing to write code.
|
||||
|
||||
## Installation
|
||||
|
||||
### Python package
|
||||
Install the library with `pip`:
|
||||
### Python Package
|
||||
|
||||
Install the library using `pip`:
|
||||
|
||||
```bash
|
||||
pip install trl
|
||||
```
|
||||
|
||||
### From source
|
||||
If you want to use the latest features before an official release you can install from source:
|
||||
|
||||
If you want to use the latest features before an official release, you can install TRL from source:
|
||||
|
||||
```bash
|
||||
pip install git+https://github.com/huggingface/trl.git
|
||||
```
|
||||
|
||||
### Repository
|
||||
|
||||
If you want to use the examples you can clone the repository with the following command:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/trl.git
|
||||
```
|
||||
|
||||
## Command Line Interface (CLI)
|
||||
## Quick Start
|
||||
|
||||
You can use TRL Command Line Interface (CLI) to quickly get started with Supervised Fine-tuning (SFT), Direct Preference Optimization (DPO) and test your aligned model with the chat CLI:
|
||||
|
||||
**SFT:**
|
||||
|
||||
```bash
|
||||
trl sft --model_name_or_path facebook/opt-125m --dataset_name stanfordnlp/imdb --output_dir opt-sft-imdb
|
||||
```
|
||||
|
||||
**DPO:**
|
||||
|
||||
```bash
|
||||
trl dpo --model_name_or_path facebook/opt-125m --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style --output_dir opt-sft-hh-rlhf
|
||||
```
|
||||
|
||||
**Chat:**
|
||||
|
||||
```bash
|
||||
trl chat --model_name_or_path Qwen/Qwen1.5-0.5B-Chat
|
||||
```
|
||||
|
||||
Read more about CLI in the [relevant documentation section](https://huggingface.co/docs/trl/main/en/clis) or use `--help` for more details.
|
||||
|
||||
## How to use
|
||||
|
||||
For more flexibility and control over the training, you can use the dedicated trainer classes to fine-tune the model in Python.
|
||||
For more flexibility and control over training, TRL provides dedicated trainer classes to post-train language models or PEFT adapters on a custom dataset. Each trainer in TRL is a light wrapper around the 🤗 Transformers trainer and natively supports distributed training methods like DDP, DeepSpeed ZeRO, and FSDP.
|
||||
|
||||
### `SFTTrainer`
|
||||
|
||||
This is a basic example of how to use the `SFTTrainer` from the library. The `SFTTrainer` is a light wrapper around the `transformers` Trainer to easily fine-tune language models or adapters on a custom dataset.
|
||||
Here is a basic example of how to use the [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer):
|
||||
|
||||
```python
|
||||
# imports
|
||||
from datasets import load_dataset
|
||||
from trl import SFTTrainer
|
||||
from datasets import load_dataset
|
||||
|
||||
# get dataset
|
||||
dataset = load_dataset("stanfordnlp/imdb", split="train")
|
||||
dataset = load_dataset("trl-lib/Capybara", split="train")
|
||||
|
||||
# get trainer
|
||||
trainer = SFTTrainer(
|
||||
"facebook/opt-350m",
|
||||
model="Qwen/Qwen2.5-0.5B",
|
||||
train_dataset=dataset,
|
||||
dataset_text_field="text",
|
||||
max_seq_length=512,
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
# train
|
||||
### `GRPOTrainer`
|
||||
|
||||
[`GRPOTrainer`](https://huggingface.co/docs/trl/grpo_trainer) implements the [Group Relative Policy Optimization (GRPO) algorithm](https://huggingface.co/papers/2402.03300) that is more memory-efficient than PPO and was used to train [Deepseek AI's R1](https://huggingface.co/deepseek-ai/DeepSeek-R1).
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from trl import GRPOTrainer
|
||||
|
||||
dataset = load_dataset("trl-lib/tldr", split="train")
|
||||
|
||||
# Dummy reward function: count the number of unique characters in the completions
|
||||
def reward_num_unique_chars(completions, **kwargs):
|
||||
return [len(set(c)) for c in completions]
|
||||
|
||||
trainer = GRPOTrainer(
|
||||
model="Qwen/Qwen2-0.5B-Instruct",
|
||||
reward_funcs=reward_num_unique_chars,
|
||||
train_dataset=dataset,
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
### `DPOTrainer`
|
||||
|
||||
[`DPOTrainer`](https://huggingface.co/docs/trl/dpo_trainer) implements the popular [Direct Preference Optimization (DPO) algorithm](https://huggingface.co/papers/2305.18290) that was used to post-train [Llama 3](https://huggingface.co/papers/2407.21783) and many other models. Here is a basic example of how to use the `DPOTrainer`:
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from trl import DPOConfig, DPOTrainer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
||||
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
|
||||
trainer = DPOTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
processing_class=tokenizer
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
### `RewardTrainer`
|
||||
|
||||
This is a basic example of how to use the `RewardTrainer` from the library. The `RewardTrainer` is a wrapper around the `transformers` Trainer to easily fine-tune reward models or adapters on a custom preference dataset.
|
||||
Here is a basic example of how to use the [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer):
|
||||
|
||||
```python
|
||||
# imports
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
from trl import RewardTrainer
|
||||
from datasets import load_dataset
|
||||
|
||||
# load model and dataset - dataset needs to be in a specific format
|
||||
model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=1)
|
||||
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
||||
|
||||
...
|
||||
|
||||
# load trainer
|
||||
trainer = RewardTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
model="Qwen/Qwen2.5-0.5B-Instruct",
|
||||
train_dataset=dataset,
|
||||
)
|
||||
|
||||
# train
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
### `PPOTrainer`
|
||||
## Command Line Interface (CLI)
|
||||
|
||||
This is a basic example of how to use the `PPOTrainer` from the library. Based on a query the language model creates a response which is then evaluated. The evaluation could be a human in the loop or another model's output.
|
||||
You can use the TRL Command Line Interface (CLI) to quickly get started with post-training methods like Supervised Fine-Tuning (SFT) or Direct Preference Optimization (DPO):
|
||||
|
||||
```python
|
||||
# imports
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
|
||||
from trl.core import respond_to_batch
|
||||
**SFT:**
|
||||
|
||||
# get models
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
|
||||
ref_model = create_reference_model(model)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained('gpt2')
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# initialize trainer
|
||||
ppo_config = PPOConfig(batch_size=1, mini_batch_size=1)
|
||||
|
||||
# encode a query
|
||||
query_txt = "This morning I went to the "
|
||||
query_tensor = tokenizer.encode(query_txt, return_tensors="pt")
|
||||
|
||||
# get model response
|
||||
response_tensor = respond_to_batch(model, query_tensor)
|
||||
|
||||
# create a ppo trainer
|
||||
ppo_trainer = PPOTrainer(ppo_config, model, ref_model, tokenizer)
|
||||
|
||||
# define a reward for response
|
||||
# (this could be any reward such as human feedback or output from another model)
|
||||
reward = [torch.tensor(1.0)]
|
||||
|
||||
# train model for one step with ppo
|
||||
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
|
||||
```bash
|
||||
trl sft --model_name_or_path Qwen/Qwen2.5-0.5B \
|
||||
--dataset_name trl-lib/Capybara \
|
||||
--output_dir Qwen2.5-0.5B-SFT
|
||||
```
|
||||
|
||||
### `DPOTrainer`
|
||||
**DPO:**
|
||||
|
||||
`DPOTrainer` is a trainer that uses [Direct Preference Optimization algorithm](https://huggingface.co/papers/2305.18290). This is a basic example of how to use the `DPOTrainer` from the library. The `DPOTrainer` is a wrapper around the `transformers` Trainer to easily fine-tune reward models or adapters on a custom preference dataset.
|
||||
|
||||
```python
|
||||
# imports
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from trl import DPOTrainer
|
||||
|
||||
# load model and dataset - dataset needs to be in a specific format
|
||||
model = AutoModelForCausalLM.from_pretrained("gpt2")
|
||||
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
|
||||
...
|
||||
|
||||
# load trainer
|
||||
trainer = DPOTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dataset,
|
||||
)
|
||||
|
||||
# train
|
||||
trainer.train()
|
||||
```bash
|
||||
trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
|
||||
--dataset_name argilla/Capybara-Preferences \
|
||||
--output_dir Qwen2.5-0.5B-DPO
|
||||
```
|
||||
|
||||
Read more about CLI in the [relevant documentation section](https://huggingface.co/docs/trl/clis) or use `--help` for more details.
|
||||
|
||||
## Development
|
||||
|
||||
If you want to contribute to `trl` or customizing it to your needs make sure to read the [contribution guide](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md) and make sure you make a dev install:
|
||||
If you want to contribute to `trl` or customize it to your needs make sure to read the [contribution guide](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md) and make sure you make a dev install:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/trl.git
|
||||
cd trl/
|
||||
make dev
|
||||
pip install -e .[dev]
|
||||
```
|
||||
|
||||
## References
|
||||
## Experimental
|
||||
|
||||
### Proximal Policy Optimisation
|
||||
The PPO implementation largely follows the structure introduced in the paper **"Fine-Tuning Language Models from Human Preferences"** by D. Ziegler et al. \[[paper](https://huggingface.co/papers/1909.08593), [code](https://github.com/openai/lm-human-preferences)].
|
||||
A minimal incubation area is available under `trl.experimental` for unstable / fast-evolving features. Anything there may change or be removed in any release without notice.
|
||||
|
||||
### Direct Preference Optimization
|
||||
DPO is based on the original implementation of **"Direct Preference Optimization: Your Language Model is Secretly a Reward Model"** by E. Mitchell et al. \[[paper](https://huggingface.co/papers/2305.18290), [code](https://github.com/eric-mitchell/direct-preference-optimization)]
|
||||
Example:
|
||||
|
||||
```python
|
||||
from trl.experimental.new_trainer import NewTrainer
|
||||
```
|
||||
|
||||
Read more in the [Experimental docs](https://huggingface.co/docs/trl/experimental_overview).
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@misc{vonwerra2022trl,
|
||||
author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang},
|
||||
author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallouédec},
|
||||
title = {TRL: Transformer Reinforcement Learning},
|
||||
year = {2020},
|
||||
publisher = {GitHub},
|
||||
@ -230,3 +202,7 @@ DPO is based on the original implementation of **"Direct Preference Optimization
|
||||
howpublished = {\url{https://github.com/huggingface/trl}}
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
This repository's source code is available under the [Apache-2.0 License](LICENSE).
|
||||
|
||||
167
RELEASE.md
Normal file
167
RELEASE.md
Normal file
@ -0,0 +1,167 @@
|
||||
# Making a release
|
||||
|
||||
> [!NOTE]
|
||||
> VERSION needs to be formatted following the `v{major}.{minor}.{patch}` convention. We need to follow this convention to be able to retrieve versioned scripts.
|
||||
|
||||
## Major/Minor Release
|
||||
|
||||
### 1. Ensure your local repository is up to date with the upstream repository
|
||||
|
||||
```bash
|
||||
git checkout main
|
||||
git pull origin main
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> Do not merge other pull requests into `main` until the release is done. This is to ensure that the release is stable and does not include any untested changes. Announce internally (#trl-internal) to other maintainers that you are doing a release and that they must not merge PRs until the release is done.
|
||||
|
||||
### 2. Create a release branch from main
|
||||
|
||||
```bash
|
||||
git checkout -b release-v{major}.{minor}
|
||||
```
|
||||
|
||||
### 3. Change the version in the following files
|
||||
|
||||
- `.github/workflows/tests_latest.yml`:
|
||||
|
||||
```diff
|
||||
- with: { ref: v{major}.{minor-1}-release }
|
||||
+ with: { ref: v{major}.{minor}-release }
|
||||
```
|
||||
|
||||
- `CITATION.cff`
|
||||
|
||||
```diff
|
||||
- version: "{major}.{minor-1}"
|
||||
+ version: "{major}.{minor}"
|
||||
```
|
||||
|
||||
- `VERSION`
|
||||
|
||||
```diff
|
||||
- {major}.{minor}.0.dev0
|
||||
+ {major}.{minor}.0
|
||||
```
|
||||
|
||||
### 4. Commit and push these changes
|
||||
|
||||
```shell
|
||||
git add .github/workflows/tests_latest.yml CITATION.cff VERSION
|
||||
git commit -m 'Release: {major}.{minor}'
|
||||
git push origin release-v{major}.{minor}
|
||||
```
|
||||
|
||||
### 5. Create a pull request
|
||||
|
||||
from `release-v{major}.{minor}` to `main`, named `Release: v{major}.{minor}`, wait for tests to pass, and request a review.
|
||||
|
||||
### 6. Once the pull request is approved, merge it into `main`
|
||||
|
||||
It will automatically publish the new version of the package on PyPI.
|
||||
|
||||
### 7. Add a tag in git to mark the release
|
||||
|
||||
```shell
|
||||
git checkout main
|
||||
git pull origin main
|
||||
git tag -a v{major}.{minor}.0 -m 'Adds tag v{major}.{minor}.0 for PyPI'
|
||||
git push origin v{major}.{minor}.0
|
||||
```
|
||||
|
||||
### 8. Create a branch `v{major}.{minor}-release` for future patch releases
|
||||
|
||||
```shell
|
||||
git checkout -b v{major}.{minor}-release
|
||||
git push origin v{major}.{minor}-release
|
||||
```
|
||||
|
||||
This ensures that future patch releases (`v{major}.{minor}.1`, `v{major}.{minor}.2`, etc.) can be made separately from `main`.
|
||||
|
||||
### 9. Create a GitHub Release
|
||||
|
||||
1. Go to the repo’s [releases section](https://github.com/huggingface/trl/releases) on GitHub.
|
||||
2. Click **Draft a new release**.
|
||||
3. Select the `v{major}.{minor}.0` tag you just created in step 7.
|
||||
4. Add a title (`v{major}.{minor}.0`) and a short description of what’s new.
|
||||
5. Click **Publish Release**.
|
||||
|
||||
### 10. Bump to dev version
|
||||
|
||||
1. Create a branch `bump-dev-version-{major}.{minor+1}` from `main` and checkout to it.
|
||||
|
||||
```shell
|
||||
git checkout -b bump-dev-version-{major}.{minor+1}
|
||||
```
|
||||
|
||||
2. Change the version in file `VERSION`:
|
||||
|
||||
```diff
|
||||
- {major}.{minor}.0
|
||||
+ {major}.{minor+1}.0.dev0
|
||||
```
|
||||
|
||||
3. Commit and push these changes
|
||||
|
||||
```shell
|
||||
git add VERSION
|
||||
git commit -m '⬆️ Bump dev version'
|
||||
git push origin bump-dev-version-{major}.{minor+1}
|
||||
```
|
||||
|
||||
4. Create a pull request from `bump-dev-version-{major}.{minor+1}` to `main`, named `⬆️ Bump dev version`, and request urgent review.
|
||||
|
||||
5. Once the pull request is approved, merge it into `main`.
|
||||
|
||||
6. The codebase is now ready for the next development cycle, inform the team in the #trl-internal channel.
|
||||
|
||||
## Making a patch release
|
||||
|
||||
### 1. Ensure your local repository is up to date with the upstream repository
|
||||
|
||||
```bash
|
||||
git checkout v{major}.{minor}-release
|
||||
git pull origin main
|
||||
```
|
||||
|
||||
### 2. Cherry-pick the changes you want to include in the patch release
|
||||
|
||||
```bash
|
||||
git cherry-pick <commit-hash-0>
|
||||
git cherry-pick <commit-hash-1>
|
||||
...
|
||||
```
|
||||
|
||||
### 3. Change the version in the file `VERSION`
|
||||
|
||||
```diff
|
||||
- {major}.{minor}.{patch-1}
|
||||
+ {major}.{minor}.{patch}
|
||||
```
|
||||
|
||||
### 4. Commit and push these changes
|
||||
|
||||
```shell
|
||||
git add VERSION
|
||||
git commit -m 'Release: {major}.{minor}.{patch}'
|
||||
git push origin v{major}.{minor}-release
|
||||
```
|
||||
|
||||
### 5. Wait for the CI to pass
|
||||
|
||||
The CI will automatically publish the new version of the package on PyPI.
|
||||
|
||||
### 6. Add a tag in git to mark the release
|
||||
|
||||
```shell
|
||||
git tag -a v{major}.{minor}.{patch} -m 'Adds tag v{major}.{minor}.{patch} for PyPI'
|
||||
git push origin v{major}.{minor}.{patch}
|
||||
```
|
||||
|
||||
#### 7. Create a GitHub Release
|
||||
|
||||
1. Go to the repo’s [releases section](https://github.com/huggingface/trl/releases) on GitHub.
|
||||
2. Click **Draft a new release**.
|
||||
3. Select the `v{major}.{minor}.{patch}` tag you just created in step 7.
|
||||
4. Add a title (`v{major}.{minor}.{patch}`) and a short description of what’s new.
|
||||
5. Click **Publish Release**.
|
||||
@ -1,164 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
import shlex
|
||||
import subprocess
|
||||
import uuid
|
||||
from distutils.util import strtobool
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def parse_args():
|
||||
# fmt: off
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--command", type=str, default="",
|
||||
help="the command to run")
|
||||
parser.add_argument("--num-seeds", type=int, default=3,
|
||||
help="the number of random seeds")
|
||||
parser.add_argument("--start-seed", type=int, default=1,
|
||||
help="the number of the starting seed")
|
||||
parser.add_argument("--workers", type=int, default=0,
|
||||
help="the number of workers to run benchmark experimenets")
|
||||
parser.add_argument("--auto-tag", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
|
||||
help="if toggled, the runs will be tagged with git tags, commit, and pull request number if possible")
|
||||
parser.add_argument("--slurm-template-path", type=str, default=None,
|
||||
help="the path to the slurm template file (see docs for more details)")
|
||||
parser.add_argument("--slurm-gpus-per-task", type=int, default=1,
|
||||
help="the number of gpus per task to use for slurm jobs")
|
||||
parser.add_argument("--slurm-total-cpus", type=int, default=50,
|
||||
help="the number of gpus per task to use for slurm jobs")
|
||||
parser.add_argument("--slurm-ntasks", type=int, default=1,
|
||||
help="the number of tasks to use for slurm jobs")
|
||||
parser.add_argument("--slurm-nodes", type=int, default=None,
|
||||
help="the number of nodes to use for slurm jobs")
|
||||
args = parser.parse_args()
|
||||
# fmt: on
|
||||
return args
|
||||
|
||||
|
||||
def run_experiment(command: str):
|
||||
command_list = shlex.split(command)
|
||||
print(f"running {command}")
|
||||
|
||||
# Use subprocess.PIPE to capture the output
|
||||
fd = subprocess.Popen(command_list, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
output, errors = fd.communicate()
|
||||
|
||||
return_code = fd.returncode
|
||||
assert return_code == 0, f"Command failed with error: {errors.decode('utf-8')}"
|
||||
|
||||
# Convert bytes to string and strip leading/trailing whitespaces
|
||||
return output.decode("utf-8").strip()
|
||||
|
||||
|
||||
def autotag() -> str:
|
||||
wandb_tag = ""
|
||||
print("autotag feature is enabled")
|
||||
git_tag = ""
|
||||
try:
|
||||
git_tag = subprocess.check_output(["git", "describe", "--tags"]).decode("ascii").strip()
|
||||
print(f"identified git tag: {git_tag}")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(e)
|
||||
if len(git_tag) == 0:
|
||||
try:
|
||||
count = int(subprocess.check_output(["git", "rev-list", "--count", "HEAD"]).decode("ascii").strip())
|
||||
hash = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).decode("ascii").strip()
|
||||
git_tag = f"no-tag-{count}-g{hash}"
|
||||
print(f"identified git tag: {git_tag}")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(e)
|
||||
wandb_tag = f"{git_tag}"
|
||||
|
||||
git_commit = subprocess.check_output(["git", "rev-parse", "--verify", "HEAD"]).decode("ascii").strip()
|
||||
try:
|
||||
# try finding the pull request number on github
|
||||
prs = requests.get(f"https://api.github.com/search/issues?q=repo:huggingface/trl+is:pr+{git_commit}")
|
||||
if prs.status_code == 200:
|
||||
prs = prs.json()
|
||||
if len(prs["items"]) > 0:
|
||||
pr = prs["items"][0]
|
||||
pr_number = pr["number"]
|
||||
wandb_tag += f",pr-{pr_number}"
|
||||
print(f"identified github pull request: {pr_number}")
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
return wandb_tag
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
if args.auto_tag:
|
||||
existing_wandb_tag = os.environ.get("WANDB_TAGS", "")
|
||||
wandb_tag = autotag()
|
||||
if len(wandb_tag) > 0:
|
||||
if len(existing_wandb_tag) > 0:
|
||||
os.environ["WANDB_TAGS"] = ",".join([existing_wandb_tag, wandb_tag])
|
||||
else:
|
||||
os.environ["WANDB_TAGS"] = wandb_tag
|
||||
print("WANDB_TAGS: ", os.environ.get("WANDB_TAGS", ""))
|
||||
commands = []
|
||||
for seed in range(0, args.num_seeds):
|
||||
commands += [" ".join([args.command, "--seed", str(args.start_seed + seed)])]
|
||||
|
||||
print("======= commands to run:")
|
||||
for command in commands:
|
||||
print(command)
|
||||
|
||||
if args.workers > 0 and args.slurm_template_path is None:
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
executor = ThreadPoolExecutor(max_workers=args.workers, thread_name_prefix="cleanrl-benchmark-worker-")
|
||||
for command in commands:
|
||||
executor.submit(run_experiment, command)
|
||||
executor.shutdown(wait=True)
|
||||
else:
|
||||
print("not running the experiments because --workers is set to 0; just printing the commands to run")
|
||||
|
||||
# SLURM logic
|
||||
if args.slurm_template_path is not None:
|
||||
if not os.path.exists("slurm"):
|
||||
os.makedirs("slurm")
|
||||
if not os.path.exists("slurm/logs"):
|
||||
os.makedirs("slurm/logs")
|
||||
print("======= slurm commands to run:")
|
||||
with open(args.slurm_template_path) as f:
|
||||
slurm_template = f.read()
|
||||
slurm_template = slurm_template.replace("{{array}}", f"0-{len(commands) - 1}%{args.workers}")
|
||||
slurm_template = slurm_template.replace(
|
||||
"{{seeds}}", f"({' '.join([str(args.start_seed + int(seed)) for seed in range(args.num_seeds)])})"
|
||||
)
|
||||
slurm_template = slurm_template.replace("{{len_seeds}}", f"{args.num_seeds}")
|
||||
slurm_template = slurm_template.replace("{{command}}", args.command)
|
||||
slurm_template = slurm_template.replace("{{gpus_per_task}}", f"{args.slurm_gpus_per_task}")
|
||||
total_gpus = args.slurm_gpus_per_task * args.slurm_ntasks
|
||||
slurm_cpus_per_gpu = math.ceil(args.slurm_total_cpus / total_gpus)
|
||||
slurm_template = slurm_template.replace("{{cpus_per_gpu}}", f"{slurm_cpus_per_gpu}")
|
||||
slurm_template = slurm_template.replace("{{ntasks}}", f"{args.slurm_ntasks}")
|
||||
if args.slurm_nodes is not None:
|
||||
slurm_template = slurm_template.replace("{{nodes}}", f"#SBATCH --nodes={args.slurm_nodes}")
|
||||
else:
|
||||
slurm_template = slurm_template.replace("{{nodes}}", "")
|
||||
filename = str(uuid.uuid4())
|
||||
open(os.path.join("slurm", f"{filename}.slurm"), "w").write(slurm_template)
|
||||
slurm_path = os.path.join("slurm", f"{filename}.slurm")
|
||||
print(f"saving command in {slurm_path}")
|
||||
if args.workers > 0:
|
||||
job_id = run_experiment(f"sbatch --parsable {slurm_path}")
|
||||
print(f"Job ID: {job_id}")
|
||||
@ -1,26 +0,0 @@
|
||||
export WANDB_ENTITY=huggingface
|
||||
export WANDB_PROJECT=trl
|
||||
bash $BENCHMARK_SCRIPT > output.txt
|
||||
|
||||
# Extract Job IDs into an array
|
||||
job_ids=($(grep "Job ID:" output.txt | awk '{print $3}'))
|
||||
|
||||
# Extract WANDB_TAGS into an array
|
||||
WANDB_TAGS=($(grep "WANDB_TAGS:" output.txt | awk '{print $2}'))
|
||||
WANDB_TAGS=($(echo $WANDB_TAGS | tr "," "\n"))
|
||||
|
||||
# Print to verify
|
||||
echo "Job IDs: ${job_ids[@]}"
|
||||
echo "WANDB_TAGS: ${WANDB_TAGS[@]}"
|
||||
|
||||
TAGS_STRING="?tag=${WANDB_TAGS[0]}"
|
||||
FOLDER_STRING="${WANDB_TAGS[0]}"
|
||||
for tag in "${WANDB_TAGS[@]:1}"; do
|
||||
TAGS_STRING+="&tag=$tag"
|
||||
FOLDER_STRING+="_$tag"
|
||||
done
|
||||
|
||||
echo "TAGS_STRING: $TAGS_STRING"
|
||||
echo "FOLDER_STRING: $FOLDER_STRING"
|
||||
|
||||
TAGS_STRING=$TAGS_STRING FOLDER_STRING=$FOLDER_STRING BENCHMARK_PLOT_SCRIPT=$BENCHMARK_PLOT_SCRIPT sbatch --dependency=afterany:$job_ids benchmark/post_github_comment.sbatch
|
||||
@ -1,44 +0,0 @@
|
||||
# hello world experiment
|
||||
python benchmark/benchmark.py \
|
||||
--command "python examples/scripts/ppo.py --log_with wandb" \
|
||||
--num-seeds 3 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
--slurm-nodes 1 \
|
||||
--slurm-gpus-per-task 1 \
|
||||
--slurm-ntasks 1 \
|
||||
--slurm-total-cpus 12 \
|
||||
--slurm-template-path benchmark/trl.slurm_template
|
||||
|
||||
python benchmark/benchmark.py \
|
||||
--command "python examples/scripts/dpo.py --model_name_or_path=gpt2 --per_device_train_batch_size 4 --max_steps 1000 --learning_rate 1e-3 --gradient_accumulation_steps 1 --logging_steps 10 --eval_steps 500 --output_dir="dpo_anthropic_hh" --optim adamw_torch --warmup_steps 150 --report_to wandb --bf16 --logging_first_step --no_remove_unused_columns" \
|
||||
--num-seeds 3 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
--slurm-nodes 1 \
|
||||
--slurm-gpus-per-task 1 \
|
||||
--slurm-ntasks 1 \
|
||||
--slurm-total-cpus 12 \
|
||||
--slurm-template-path benchmark/trl.slurm_template
|
||||
|
||||
python benchmark/benchmark.py \
|
||||
--command "python examples/scripts/sft.py --model_name_or_path="facebook/opt-350m" --report_to="wandb" --learning_rate=1.41e-5 --per_device_train_batch_size=64 --gradient_accumulation_steps=16 --output_dir="sft_openassistant-guanaco" --logging_steps=1 --num_train_epochs=3 --max_steps=-1 --push_to_hub --gradient_checkpointing" \
|
||||
--num-seeds 3 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
--slurm-nodes 1 \
|
||||
--slurm-gpus-per-task 1 \
|
||||
--slurm-ntasks 1 \
|
||||
--slurm-total-cpus 12 \
|
||||
--slurm-template-path benchmark/trl.slurm_template
|
||||
|
||||
python benchmark/benchmark.py \
|
||||
--command "python examples/scripts/reward_modeling.py --model_name_or_path=facebook/opt-350m --output_dir="reward_modeling_anthropic_hh" --per_device_train_batch_size=64 --num_train_epochs=1 --gradient_accumulation_steps=16 --gradient_checkpointing=True --learning_rate=1.41e-5 --report_to="wandb" --remove_unused_columns=False --optim="adamw_torch" --logging_steps=10 --eval_strategy="steps" --max_length=512" \
|
||||
--num-seeds 3 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
--slurm-nodes 1 \
|
||||
--slurm-gpus-per-task 1 \
|
||||
--slurm-ntasks 1 \
|
||||
--slurm-total-cpus 12 \
|
||||
--slurm-template-path benchmark/trl.slurm_template
|
||||
@ -1,50 +0,0 @@
|
||||
# pip install openrlbenchmark==0.2.1a5
|
||||
# see https://github.com/openrlbenchmark/openrlbenchmark#get-started for documentation
|
||||
echo "we deal with $TAGS_STRING"
|
||||
|
||||
python -m openrlbenchmark.rlops_multi_metrics \
|
||||
--filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.reward_model&cen=trl_ppo_trainer_config.value.exp_name&metrics=env/reward_mean&metrics=objective/kl' \
|
||||
"ppo$TAGS_STRING" \
|
||||
--env-ids sentiment-analysis:lvwerra/distilbert-imdb \
|
||||
--no-check-empty-runs \
|
||||
--pc.ncols 2 \
|
||||
--pc.ncols-legend 1 \
|
||||
--output-filename benchmark/trl/$FOLDER_STRING/ppo \
|
||||
--scan-history
|
||||
|
||||
python -m openrlbenchmark.rlops_multi_metrics \
|
||||
--filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=output_dir&cen=_name_or_path&metrics=train/rewards/accuracies&metrics=train/loss' \
|
||||
"gpt2$TAGS_STRING" \
|
||||
--env-ids dpo_anthropic_hh \
|
||||
--no-check-empty-runs \
|
||||
--pc.ncols 2 \
|
||||
--pc.ncols-legend 1 \
|
||||
--output-filename benchmark/trl/$FOLDER_STRING/dpo \
|
||||
--scan-history
|
||||
|
||||
python -m openrlbenchmark.rlops_multi_metrics \
|
||||
--filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=output_dir&cen=_name_or_path&metrics=train/loss&metrics=eval/accuracy&metrics=eval/loss' \
|
||||
"facebook/opt-350m$TAGS_STRING" \
|
||||
--env-ids reward_modeling_anthropic_hh \
|
||||
--no-check-empty-runs \
|
||||
--pc.ncols 2 \
|
||||
--pc.ncols-legend 1 \
|
||||
--output-filename benchmark/trl/$FOLDER_STRING/reward_modeling \
|
||||
--scan-history
|
||||
|
||||
python -m openrlbenchmark.rlops_multi_metrics \
|
||||
--filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=output_dir&cen=_name_or_path&metrics=train/loss' \
|
||||
"facebook/opt-350m$TAGS_STRING" \
|
||||
--env-ids sft_openassistant-guanaco \
|
||||
--no-check-empty-runs \
|
||||
--pc.ncols 2 \
|
||||
--pc.ncols-legend 1 \
|
||||
--output-filename benchmark/trl/$FOLDER_STRING/sft \
|
||||
--scan-history
|
||||
|
||||
python benchmark/upload_benchmark.py \
|
||||
--folder_path="benchmark/trl/$FOLDER_STRING" \
|
||||
--path_in_repo="images/benchmark/$FOLDER_STRING" \
|
||||
--repo_id="trl-internal-testing/example-images" \
|
||||
--repo_type="dataset"
|
||||
|
||||
@ -1,23 +0,0 @@
|
||||
# compound experiments: gpt2xl + grad_accu
|
||||
python benchmark/benchmark.py \
|
||||
--command "python examples/scripts/ppo.py --exp_name ppo_gpt2xl_grad_accu --model_name gpt2-xl --mini_batch_size 16 --gradient_accumulation_steps 8 --log_with wandb" \
|
||||
--num-seeds 3 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
--slurm-nodes 1 \
|
||||
--slurm-gpus-per-task 1 \
|
||||
--slurm-ntasks 1 \
|
||||
--slurm-total-cpus 12 \
|
||||
--slurm-template-path benchmark/trl.slurm_template
|
||||
|
||||
# compound experiments: Cerebras-GPT-6.7B + deepspeed zero2 + grad_accu
|
||||
python benchmark/benchmark.py \
|
||||
--command "accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml examples/scripts/ppo.py --exp_name ppo_Cerebras-GPT-6.7B_grad_accu_deepspeed_stage2 --batch_size 32 --mini_batch_size 32 --log_with wandb --model_name cerebras/Cerebras-GPT-6.7B --reward_model sentiment-analysis:cerebras/Cerebras-GPT-6.7B" \
|
||||
--num-seeds 3 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
--slurm-nodes 1 \
|
||||
--slurm-gpus-per-task 8 \
|
||||
--slurm-ntasks 1 \
|
||||
--slurm-total-cpus 90 \
|
||||
--slurm-template-path benchmark/trl.slurm_template
|
||||
@ -1,31 +0,0 @@
|
||||
# pip install openrlbenchmark==0.2.1a5
|
||||
# see https://github.com/openrlbenchmark/openrlbenchmark#get-started for documentation
|
||||
echo "we deal with $TAGS_STRING"
|
||||
|
||||
python -m openrlbenchmark.rlops_multi_metrics \
|
||||
--filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.reward_model&cen=trl_ppo_trainer_config.value.exp_name&metrics=env/reward_mean&metrics=objective/kl' \
|
||||
"ppo$TAGS_STRING" \
|
||||
"ppo_gpt2xl_grad_accu$TAGS_STRING" \
|
||||
--env-ids sentiment-analysis:lvwerra/distilbert-imdb \
|
||||
--no-check-empty-runs \
|
||||
--pc.ncols 2 \
|
||||
--pc.ncols-legend 1 \
|
||||
--output-filename benchmark/trl/$FOLDER_STRING/different_models \
|
||||
--scan-history
|
||||
|
||||
python -m openrlbenchmark.rlops_multi_metrics \
|
||||
--filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.reward_model&cen=trl_ppo_trainer_config.value.exp_name&metrics=env/reward_mean&metrics=objective/kl' \
|
||||
"ppo_Cerebras-GPT-6.7B_grad_accu_deepspeed_stage2$TAGS_STRING" \
|
||||
--env-ids sentiment-analysis:cerebras/Cerebras-GPT-6.7B \
|
||||
--no-check-empty-runs \
|
||||
--pc.ncols 2 \
|
||||
--pc.ncols-legend 1 \
|
||||
--output-filename benchmark/trl/$FOLDER_STRING/deepspeed \
|
||||
--scan-history
|
||||
|
||||
python benchmark/upload_benchmark.py \
|
||||
--folder_path="benchmark/trl/$FOLDER_STRING" \
|
||||
--path_in_repo="images/benchmark/$FOLDER_STRING" \
|
||||
--repo_id="trl-internal-testing/example-images" \
|
||||
--repo_type="dataset"
|
||||
|
||||
@ -1,46 +0,0 @@
|
||||
## w/ and w/o gradient accumulation
|
||||
python benchmark/benchmark.py \
|
||||
--command "python examples/scripts/ppo.py --exp_name ppo_step_grad_accu --mini_batch_size 1 --gradient_accumulation_steps 128 --log_with wandb" \
|
||||
--num-seeds 3 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
--slurm-nodes 1 \
|
||||
--slurm-gpus-per-task 1 \
|
||||
--slurm-ntasks 1 \
|
||||
--slurm-total-cpus 12 \
|
||||
--slurm-template-path benchmark/trl.slurm_template
|
||||
|
||||
## w/ different models (gpt2, gpt2-xl, falcon, llama2)
|
||||
python benchmark/benchmark.py \
|
||||
--command "python examples/scripts/ppo.py --exp_name ppo_gpt2 --log_with wandb" \
|
||||
--num-seeds 3 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
--slurm-nodes 1 \
|
||||
--slurm-gpus-per-task 1 \
|
||||
--slurm-ntasks 1 \
|
||||
--slurm-total-cpus 12 \
|
||||
--slurm-template-path benchmark/trl.slurm_template
|
||||
python benchmark/benchmark.py \
|
||||
--command "python examples/scripts/ppo.py --exp_name ppo_falcon_rw_1b --model_name tiiuae/falcon-rw-1b --log_with wandb" \
|
||||
--num-seeds 3 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
--slurm-nodes 1 \
|
||||
--slurm-gpus-per-task 1 \
|
||||
--slurm-ntasks 1 \
|
||||
--slurm-total-cpus 12 \
|
||||
--slurm-template-path benchmark/trl.slurm_template
|
||||
|
||||
|
||||
## w/ and w/o PEFT
|
||||
python benchmark/benchmark.py \
|
||||
--command "python examples/scripts/ppo.py --exp_name ppo_peft --use_peft --log_with wandb" \
|
||||
--num-seeds 3 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
--slurm-nodes 1 \
|
||||
--slurm-gpus-per-task 1 \
|
||||
--slurm-ntasks 1 \
|
||||
--slurm-total-cpus 12 \
|
||||
--slurm-template-path benchmark/trl.slurm_template
|
||||
@ -1,56 +0,0 @@
|
||||
# pip install openrlbenchmark==0.2.1a5
|
||||
# see https://github.com/openrlbenchmark/openrlbenchmark#get-started for documentation
|
||||
BASELINE_PR_TAG=v0.4.7-55-g110e672
|
||||
BASELINE_PR_NAME=PR-662
|
||||
|
||||
python -m openrlbenchmark.rlops_multi_metrics \
|
||||
--filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.reward_model&cen=trl_ppo_trainer_config.value.exp_name&metrics=env/reward_mean&metrics=objective/kl' \
|
||||
"sentiment_tuning?tag=$BASELINE_PR_TAG&cl=sentiment lvwerra/gpt2-imdb ($BASELINE_PR_NAME)" \
|
||||
--env-ids sentiment-analysis:lvwerra/distilbert-imdb \
|
||||
--no-check-empty-runs \
|
||||
--pc.ncols 2 \
|
||||
--pc.ncols-legend 1 \
|
||||
--output-filename benchmark/trl/$BASELINE_PR_TAG/sentiment \
|
||||
--scan-history
|
||||
|
||||
python -m openrlbenchmark.rlops_multi_metrics \
|
||||
--filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.reward_model&cen=trl_ppo_trainer_config.value.exp_name&metrics=env/reward_mean&metrics=objective/kl' \
|
||||
"sentiment_tuning?tag=$BASELINE_PR_TAG&cl=sentiment lvwerra/gpt2-imdb ($BASELINE_PR_NAME)" \
|
||||
"sentiment_tuning_step_grad_accu?tag=$BASELINE_PR_TAG&cl=sentiment lvwerra/gpt2-imdb gradient accumulation ($BASELINE_PR_NAME)" \
|
||||
--env-ids sentiment-analysis:lvwerra/distilbert-imdb \
|
||||
--no-check-empty-runs \
|
||||
--pc.ncols 2 \
|
||||
--pc.ncols-legend 1 \
|
||||
--output-filename benchmark/trl/$BASELINE_PR_TAG/gradient_accu \
|
||||
--scan-history
|
||||
|
||||
python -m openrlbenchmark.rlops_multi_metrics \
|
||||
--filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.reward_model&cen=trl_ppo_trainer_config.value.exp_name&metrics=env/reward_mean&metrics=objective/kl' \
|
||||
"sentiment_tuning?tag=$BASELINE_PR_TAG&cl=sentiment lvwerra/gpt2-imdb ($BASELINE_PR_NAME)" \
|
||||
"sentiment_tuning_gpt2?tag=$BASELINE_PR_TAG&cl=sentiment gpt2 ($BASELINE_PR_NAME)" \
|
||||
"sentiment_tuning_falcon_rw_1b?tag=$BASELINE_PR_TAG&cl=sentiment tiiuae/falcon-rw-1b ($BASELINE_PR_NAME)" \
|
||||
"sentiment_tuning_gpt2xl_grad_accu?tag=$BASELINE_PR_TAG&cl=sentiment gpt2xl ($BASELINE_PR_NAME)" \
|
||||
--env-ids sentiment-analysis:lvwerra/distilbert-imdb \
|
||||
--no-check-empty-runs \
|
||||
--pc.ncols 2 \
|
||||
--pc.ncols-legend 1 \
|
||||
--output-filename benchmark/trl/$BASELINE_PR_TAG/different_models \
|
||||
--scan-history
|
||||
|
||||
python -m openrlbenchmark.rlops_multi_metrics \
|
||||
--filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.reward_model&cen=trl_ppo_trainer_config.value.exp_name&metrics=env/reward_mean&metrics=objective/kl' \
|
||||
"sentiment_tuning?tag=$BASELINE_PR_TAG&cl=sentiment lvwerra/gpt2-imdb ($BASELINE_PR_NAME)" \
|
||||
"sentiment_tuning_peft?tag=$BASELINE_PR_TAG&cl=sentiment lvwerra/gpt2-imdb w/ peft ($BASELINE_PR_NAME)" \
|
||||
--env-ids sentiment-analysis:lvwerra/distilbert-imdb \
|
||||
--no-check-empty-runs \
|
||||
--pc.ncols 2 \
|
||||
--pc.ncols-legend 1 \
|
||||
--output-filename benchmark/trl/$BASELINE_PR_TAG/peft \
|
||||
--scan-history
|
||||
|
||||
|
||||
python benchmark/upload_benchmark.py \
|
||||
--folder_path="benchmark/trl/$BASELINE_PR_TAG" \
|
||||
--path_in_repo="images/benchmark/$BASELINE_PR_TAG" \
|
||||
--repo_id="trl-internal-testing/example-images" \
|
||||
--repo_type="dataset"
|
||||
@ -1,40 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from ghapi.all import GhApi
|
||||
|
||||
|
||||
FOLDER_STRING = os.environ.get("FOLDER_STRING", "")
|
||||
folder = f"benchmark/trl/{FOLDER_STRING}"
|
||||
host_url = f"https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/benchmark/{FOLDER_STRING}"
|
||||
|
||||
# Create a GitHub API instance
|
||||
github_context = json.loads(os.environ["GITHUB_CONTEXT"])
|
||||
token = os.environ["PERSONAL_ACCESS_TOKEN_GITHUB"] # this needs to refreshed every 12 months
|
||||
status_message = "**[COSTA BENCHMARK BOT]**: Here are the results"
|
||||
body = status_message
|
||||
repo = github_context["repository"]
|
||||
owner, repo = repo.split("/")
|
||||
api = GhApi(owner=owner, repo=repo, token=token)
|
||||
|
||||
# for each `.png` file in the folder, add it to the comment
|
||||
for file in os.listdir(folder):
|
||||
if file.endswith(".png"):
|
||||
body += f"\n"
|
||||
|
||||
# Create a comment on the issue
|
||||
api.issues.create_comment(issue_number=github_context["event"]["issue"]["number"], body=body)
|
||||
@ -1,9 +0,0 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=trl
|
||||
#SBATCH --partition=hopper-cpu
|
||||
#SBATCH --ntasks=1
|
||||
#SBATCH --output=slurm/logs/%x_%j.out
|
||||
|
||||
sleep 2m
|
||||
bash $BENCHMARK_PLOT_SCRIPT
|
||||
srun python benchmark/post_github_comment.py
|
||||
@ -1,3 +0,0 @@
|
||||
BENCHMARK_SCRIPT="benchmark/benchmark_level1.sh" \
|
||||
BENCHMARK_PLOT_SCRIPT="benchmark/benchmark_level1_plot.sh" \
|
||||
bash benchmark/benchmark_and_report.sh
|
||||
@ -1,19 +0,0 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=trl
|
||||
#SBATCH --partition=hopper-prod
|
||||
#SBATCH --gpus-per-task={{gpus_per_task}}
|
||||
#SBATCH --cpus-per-gpu={{cpus_per_gpu}}
|
||||
#SBATCH --ntasks={{ntasks}}
|
||||
#SBATCH --output=slurm/logs/%x_%j.out
|
||||
#SBATCH --array={{array}}
|
||||
##SBATCH --exclude=ip-26-0-149-199
|
||||
|
||||
module load cuda/12.1
|
||||
|
||||
{{nodes}}
|
||||
|
||||
seeds={{seeds}}
|
||||
seed=${seeds[$SLURM_ARRAY_TASK_ID % {{len_seeds}}]}
|
||||
|
||||
echo "Running task $SLURM_ARRAY_TASK_ID with seed: $seed"
|
||||
srun {{command}} --seed $seed
|
||||
@ -1,58 +0,0 @@
|
||||
#!/bin/bash
|
||||
# This script runs an SFT example end-to-end on a tiny model using different possible configurations
|
||||
# but defaults to QLoRA + PEFT
|
||||
OUTPUT_DIR="test_dpo/"
|
||||
MODEL_NAME="trl-internal-testing/tiny-random-LlamaForCausalLM"
|
||||
DATASET_NAME="trl-internal-testing/hh-rlhf-helpful-base-trl-style"
|
||||
MAX_STEPS=5
|
||||
BATCH_SIZE=2
|
||||
SEQ_LEN=128
|
||||
|
||||
# Handle extra arguments in case one passes accelerate configs.
|
||||
EXTRA_ACCELERATE_ARGS=""
|
||||
EXTRA_TRAINING_ARGS="""--use_peft \
|
||||
--load_in_4bit
|
||||
"""
|
||||
|
||||
# This is a hack to get the number of available GPUs
|
||||
NUM_GPUS=2
|
||||
|
||||
if [[ "${TRL_ACCELERATE_CONFIG}" == "" ]]; then
|
||||
EXTRA_ACCELERATE_ARGS=""
|
||||
else
|
||||
EXTRA_ACCELERATE_ARGS="--config_file $TRL_ACCELERATE_CONFIG"
|
||||
# For DeepSpeed configs we need to set the `--fp16` flag to comply with our configs exposed
|
||||
# on `examples/accelerate_configs` and our runners do not support bf16 mixed precision training.
|
||||
if [[ $TRL_ACCELERATE_CONFIG == *"deepspeed"* ]]; then
|
||||
EXTRA_TRAINING_ARGS="--fp16"
|
||||
else
|
||||
echo "Keeping QLoRA + PEFT"
|
||||
fi
|
||||
fi
|
||||
|
||||
|
||||
CMD="""
|
||||
accelerate launch $EXTRA_ACCELERATE_ARGS \
|
||||
--num_processes $NUM_GPUS \
|
||||
--mixed_precision 'fp16' \
|
||||
`pwd`/examples/scripts/dpo.py \
|
||||
--model_name_or_path $MODEL_NAME \
|
||||
--dataset_name $DATASET_NAME \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--max_steps $MAX_STEPS \
|
||||
--per_device_train_batch_size $BATCH_SIZE \
|
||||
--max_length $SEQ_LEN \
|
||||
$EXTRA_TRAINING_ARGS
|
||||
"""
|
||||
|
||||
echo "Starting program..."
|
||||
|
||||
{ # try
|
||||
echo $CMD
|
||||
eval "$CMD"
|
||||
} || { # catch
|
||||
# save log for exception
|
||||
echo "Operation Failed!"
|
||||
exit 1
|
||||
}
|
||||
exit 0
|
||||
@ -1,60 +0,0 @@
|
||||
#!/bin/bash
|
||||
# This script runs an SFT example end-to-end on a tiny model using different possible configurations
|
||||
# but defaults to QLoRA + PEFT
|
||||
OUTPUT_DIR="test_sft/"
|
||||
MODEL_NAME="trl-internal-testing/tiny-random-LlamaForCausalLM"
|
||||
DATASET_NAME="stanfordnlp/imdb"
|
||||
MAX_STEPS=5
|
||||
BATCH_SIZE=2
|
||||
SEQ_LEN=128
|
||||
|
||||
|
||||
# Handle extra arguments in case one passes accelerate configs.
|
||||
EXTRA_ACCELERATE_ARGS=""
|
||||
EXTRA_TRAINING_ARGS="""--use_peft \
|
||||
--load_in_4bit
|
||||
"""
|
||||
|
||||
# Set your number of GPUs here
|
||||
NUM_GPUS=2
|
||||
|
||||
if [[ "${TRL_ACCELERATE_CONFIG}" == "" ]]; then
|
||||
EXTRA_ACCELERATE_ARGS=""
|
||||
else
|
||||
EXTRA_ACCELERATE_ARGS="--config_file $TRL_ACCELERATE_CONFIG"
|
||||
# For DeepSpeed configs we need to set the `--fp16` flag to comply with our configs exposed
|
||||
# on `examples/accelerate_configs` and our runners do not support bf16 mixed precision training.
|
||||
if [[ $TRL_ACCELERATE_CONFIG == *"deepspeed"* ]]; then
|
||||
EXTRA_TRAINING_ARGS="--fp16"
|
||||
else
|
||||
echo "Keeping QLoRA + PEFT"
|
||||
fi
|
||||
fi
|
||||
|
||||
|
||||
CMD="""
|
||||
accelerate launch $EXTRA_ACCELERATE_ARGS \
|
||||
--num_processes $NUM_GPUS \
|
||||
--mixed_precision 'fp16' \
|
||||
`pwd`/examples/scripts/sft.py \
|
||||
--model_name $MODEL_NAME \
|
||||
--dataset_name $DATASET_NAME \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--max_steps $MAX_STEPS \
|
||||
--dataset_text_field 'text' \
|
||||
--per_device_train_batch_size $BATCH_SIZE \
|
||||
--max_seq_length $SEQ_LEN \
|
||||
$EXTRA_TRAINING_ARGS
|
||||
"""
|
||||
|
||||
echo "Starting program..."
|
||||
|
||||
{ # try
|
||||
echo $CMD
|
||||
eval "$CMD"
|
||||
} || { # catch
|
||||
# save log for exception
|
||||
echo "Operation Failed!"
|
||||
exit 1
|
||||
}
|
||||
exit 0
|
||||
6
docker/trl-dev/Dockerfile
Normal file
6
docker/trl-dev/Dockerfile
Normal file
@ -0,0 +1,6 @@
|
||||
FROM pytorch/pytorch:2.8.0-cuda12.8-cudnn9-runtime
|
||||
RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/*
|
||||
RUN pip install --upgrade pip uv
|
||||
RUN uv pip install --system --no-cache "git+https://github.com/huggingface/trl.git#egg=trl[liger,peft,vlm]"
|
||||
RUN uv pip install --system hf_transfer liger_kernel trackio peft
|
||||
RUN uv pip install --system https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.8cxx11abiFALSE-cp311-cp311-linux_x86_64.whl
|
||||
@ -1,66 +0,0 @@
|
||||
# Builds GPU docker image of PyTorch
|
||||
# Uses multi-staged approach to reduce size
|
||||
# Stage 1
|
||||
# Use base conda image to reduce time
|
||||
FROM continuumio/miniconda3:latest AS compile-image
|
||||
# Specify py version
|
||||
ENV PYTHON_VERSION=3.10
|
||||
# Install apt libs - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
|
||||
RUN apt-get update && \
|
||||
apt-get install -y curl git wget software-properties-common git-lfs && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists*
|
||||
|
||||
# Install audio-related libraries
|
||||
RUN apt-get update && \
|
||||
apt install -y ffmpeg
|
||||
|
||||
RUN apt install -y libsndfile1-dev
|
||||
RUN git lfs install
|
||||
|
||||
# Create our conda env - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
|
||||
RUN conda create --name trl python=${PYTHON_VERSION} ipython jupyter pip
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip
|
||||
|
||||
# Below is copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
|
||||
# We don't install pytorch here yet since CUDA isn't available
|
||||
# instead we use the direct torch wheel
|
||||
ENV PATH /opt/conda/envs/trl/bin:$PATH
|
||||
# Activate our bash shell
|
||||
RUN chsh -s /bin/bash
|
||||
SHELL ["/bin/bash", "-c"]
|
||||
|
||||
# Stage 2
|
||||
FROM nvidia/cuda:12.2.2-devel-ubuntu22.04 AS build-image
|
||||
COPY --from=compile-image /opt/conda /opt/conda
|
||||
ENV PATH /opt/conda/bin:$PATH
|
||||
|
||||
RUN chsh -s /bin/bash
|
||||
SHELL ["/bin/bash", "-c"]
|
||||
RUN source activate trl && \
|
||||
python3 -m pip install --no-cache-dir bitsandbytes optimum auto-gptq
|
||||
|
||||
# Install apt libs
|
||||
RUN apt-get update && \
|
||||
apt-get install -y curl git wget && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists*
|
||||
|
||||
# Activate the conda env and install transformers + accelerate from source
|
||||
RUN source activate trl && \
|
||||
python3 -m pip install -U --no-cache-dir \
|
||||
librosa \
|
||||
"soundfile>=0.12.1" \
|
||||
scipy \
|
||||
transformers \
|
||||
accelerate \
|
||||
peft \
|
||||
trl[test]@git+https://github.com/huggingface/trl
|
||||
|
||||
RUN source activate trl && \
|
||||
pip freeze | grep trl
|
||||
|
||||
RUN echo "source activate trl" >> ~/.profile
|
||||
|
||||
# Activate the virtualenv
|
||||
CMD ["/bin/bash"]
|
||||
@ -1,66 +0,0 @@
|
||||
# Builds GPU docker image of PyTorch
|
||||
# Uses multi-staged approach to reduce size
|
||||
# Stage 1
|
||||
# Use base conda image to reduce time
|
||||
FROM continuumio/miniconda3:latest AS compile-image
|
||||
# Specify py version
|
||||
ENV PYTHON_VERSION=3.10
|
||||
# Install apt libs - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
|
||||
RUN apt-get update && \
|
||||
apt-get install -y curl git wget software-properties-common git-lfs && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists*
|
||||
|
||||
# Install audio-related libraries
|
||||
RUN apt-get update && \
|
||||
apt install -y ffmpeg
|
||||
|
||||
RUN apt install -y libsndfile1-dev
|
||||
RUN git lfs install
|
||||
|
||||
# Create our conda env - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
|
||||
RUN conda create --name trl python=${PYTHON_VERSION} ipython jupyter pip
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip
|
||||
|
||||
# Below is copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
|
||||
# We don't install pytorch here yet since CUDA isn't available
|
||||
# instead we use the direct torch wheel
|
||||
ENV PATH /opt/conda/envs/trl/bin:$PATH
|
||||
# Activate our bash shell
|
||||
RUN chsh -s /bin/bash
|
||||
SHELL ["/bin/bash", "-c"]
|
||||
|
||||
# Stage 2
|
||||
FROM nvidia/cuda:12.2.2-devel-ubuntu22.04 AS build-image
|
||||
COPY --from=compile-image /opt/conda /opt/conda
|
||||
ENV PATH /opt/conda/bin:$PATH
|
||||
|
||||
RUN chsh -s /bin/bash
|
||||
SHELL ["/bin/bash", "-c"]
|
||||
RUN source activate trl && \
|
||||
python3 -m pip install --no-cache-dir bitsandbytes optimum auto-gptq
|
||||
|
||||
# Install apt libs
|
||||
RUN apt-get update && \
|
||||
apt-get install -y curl git wget && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists*
|
||||
|
||||
# Activate the conda env and install transformers + accelerate from source
|
||||
RUN source activate trl && \
|
||||
python3 -m pip install -U --no-cache-dir \
|
||||
librosa \
|
||||
"soundfile>=0.12.1" \
|
||||
scipy \
|
||||
git+https://github.com/huggingface/transformers \
|
||||
git+https://github.com/huggingface/accelerate \
|
||||
git+https://github.com/huggingface/peft \
|
||||
trl[test]@git+https://github.com/huggingface/trl
|
||||
|
||||
RUN source activate trl && \
|
||||
pip freeze | grep transformers
|
||||
|
||||
RUN echo "source activate trl" >> ~/.profile
|
||||
|
||||
# Activate the virtualenv
|
||||
CMD ["/bin/bash"]
|
||||
4
docker/trl/Dockerfile
Normal file
4
docker/trl/Dockerfile
Normal file
@ -0,0 +1,4 @@
|
||||
FROM pytorch/pytorch:2.8.0-cuda12.8-cudnn9-runtime
|
||||
RUN pip install --upgrade pip uv
|
||||
RUN uv pip install --system trl[liger,peft,vlm] hf_transfer trackio
|
||||
RUN uv pip install --system https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.8cxx11abiFALSE-cp311-cp311-linux_x86_64.whl
|
||||
@ -5,35 +5,67 @@
|
||||
title: Installation
|
||||
- local: quickstart
|
||||
title: Quickstart
|
||||
- local: clis
|
||||
title: Get started with Command Line Interfaces (CLIs)
|
||||
title: Getting started
|
||||
- sections:
|
||||
- local: dataset_formats
|
||||
title: Dataset Formats
|
||||
- local: how_to_train
|
||||
title: PPO Training FAQ
|
||||
- local: use_model
|
||||
title: Use Trained Models
|
||||
- local: customization
|
||||
title: Customize the Training
|
||||
- local: logging
|
||||
title: Understanding Logs
|
||||
title: Get started
|
||||
- local: paper_index
|
||||
title: Paper Index
|
||||
title: Conceptual Guides
|
||||
- sections:
|
||||
- sections: # Sort alphabetically
|
||||
- local: alignprop_trainer
|
||||
title: AlignProp
|
||||
- local: bco_trainer
|
||||
title: BCO
|
||||
- local: clis
|
||||
title: Command Line Interface (CLI)
|
||||
- local: jobs_training
|
||||
title: Training using Jobs
|
||||
- local: customization
|
||||
title: Customizing the Training
|
||||
- local: reducing_memory_usage
|
||||
title: Reducing Memory Usage
|
||||
- local: speeding_up_training
|
||||
title: Speeding Up Training
|
||||
- local: distributing_training
|
||||
title: Distributing Training
|
||||
- local: use_model
|
||||
title: Using Trained Models
|
||||
title: How-to guides
|
||||
- sections:
|
||||
- local: deepspeed_integration
|
||||
title: DeepSpeed
|
||||
- local: kernels_hub
|
||||
title: Kernels Hub
|
||||
- local: liger_kernel_integration
|
||||
title: Liger Kernel
|
||||
- local: peft_integration
|
||||
title: PEFT
|
||||
- local: rapidfire_integration
|
||||
title: RapidFire AI
|
||||
- local: trackio_integration
|
||||
title: Trackio
|
||||
- local: unsloth_integration
|
||||
title: Unsloth
|
||||
- local: vllm_integration
|
||||
title: vLLM
|
||||
title: Integrations
|
||||
- sections:
|
||||
- local: example_overview
|
||||
title: Example Overview
|
||||
- local: community_tutorials
|
||||
title: Community Tutorials
|
||||
- local: lora_without_regret
|
||||
title: LoRA Without Regret
|
||||
title: Examples
|
||||
- sections:
|
||||
- sections: # Sorted alphabetically
|
||||
- local: cpo_trainer
|
||||
title: CPO
|
||||
- local: ddpo_trainer
|
||||
title: DDPO
|
||||
- local: dpo_trainer
|
||||
title: DPO
|
||||
- local: online_dpo_trainer
|
||||
title: Online DPO
|
||||
- local: gkd_trainer
|
||||
title: GKD
|
||||
- local: grpo_trainer
|
||||
title: GRPO
|
||||
- local: kto_trainer
|
||||
title: KTO
|
||||
- local: nash_md_trainer
|
||||
@ -42,45 +74,51 @@
|
||||
title: ORPO
|
||||
- local: ppo_trainer
|
||||
title: PPO
|
||||
- local: ppov2_trainer
|
||||
title: PPOv2
|
||||
- local: prm_trainer
|
||||
title: PRM
|
||||
- local: reward_trainer
|
||||
title: Reward
|
||||
- local: rloo_trainer
|
||||
title: RLOO
|
||||
- local: sft_trainer
|
||||
title: SFT
|
||||
- local: iterative_sft_trainer
|
||||
title: Iterative SFT
|
||||
- local: xpo_trainer
|
||||
title: XPO
|
||||
title: Trainers
|
||||
- local: models
|
||||
title: Model Classes
|
||||
- local: best_of_n
|
||||
title: Best of N Sampling
|
||||
- local: model_utils
|
||||
title: Model Utilities
|
||||
- local: judges
|
||||
title: Judges
|
||||
- local: callbacks
|
||||
title: Callbacks
|
||||
- local: data_utils
|
||||
title: Data Utilities
|
||||
- local: text_environments
|
||||
title: Text Environments
|
||||
- local: rewards
|
||||
title: Reward Functions
|
||||
- local: script_utils
|
||||
title: Script Utilities
|
||||
- local: others
|
||||
title: Others
|
||||
title: API
|
||||
- sections:
|
||||
- local: example_overview
|
||||
title: Example Overview
|
||||
- local: sentiment_tuning
|
||||
title: Sentiment Tuning
|
||||
- local: lora_tuning_peft
|
||||
title: Training with PEFT
|
||||
- local: detoxifying_a_lm
|
||||
title: Detoxifying a Language Model
|
||||
- local: using_llama_models
|
||||
title: Training StackLlama
|
||||
- local: learning_tools
|
||||
title: Learning to Use Tools
|
||||
- local: multi_adapter_rl
|
||||
title: Multi Adapter RLHF
|
||||
title: Examples
|
||||
- local: experimental_overview
|
||||
title: Experimental Overview
|
||||
- local: bema_for_reference_model # Sorted alphabetically
|
||||
title: BEMA for Reference Model
|
||||
- local: bco_trainer
|
||||
title: BCO
|
||||
- local: gfpo
|
||||
title: GFPO
|
||||
- local: gold_trainer
|
||||
title: GOLD
|
||||
- local: grpo_with_replay_buffer
|
||||
title: GRPO With Replay Buffer
|
||||
- local: gspo_token
|
||||
title: GSPO-token
|
||||
- local: papo_trainer
|
||||
title: PAPO
|
||||
- local: openenv
|
||||
title: OpenEnv Integration
|
||||
title: Experimental
|
||||
@ -1,91 +0,0 @@
|
||||
# Aligning Text-to-Image Diffusion Models with Reward Backpropagation
|
||||
|
||||
## The why
|
||||
|
||||
If your reward function is differentiable, directly backpropagating gradients from the reward models to the diffusion model is significantly more sample and compute efficient (25x) than doing policy gradient algorithm like DDPO.
|
||||
AlignProp does full backpropagation through time, which allows updating the earlier steps of denoising via reward backpropagation.
|
||||
|
||||
<div style="text-align: center"><img src="https://align-prop.github.io/reward_tuning.png"/></div>
|
||||
|
||||
|
||||
## Getting started with `examples/scripts/alignprop.py`
|
||||
|
||||
The `alignprop.py` script is a working example of using the `AlignProp` trainer to finetune a Stable Diffusion model. This example explicitly configures a small subset of the overall parameters associated with the config object (`AlignPropConfig`).
|
||||
|
||||
**Note:** one A100 GPU is recommended to get this running. For lower memory setting, consider setting truncated_backprop_rand to False. With default settings this will do truncated backpropagation with K=1.
|
||||
|
||||
Almost every configuration parameter has a default. There is only one commandline flag argument that is required of the user to get things up and running. The user is expected to have a [huggingface user access token](https://huggingface.co/docs/hub/security-tokens) that will be used to upload the model post finetuning to HuggingFace hub. The following bash command is to be entered to get things running
|
||||
|
||||
```batch
|
||||
python alignprop.py --hf_user_access_token <token>
|
||||
```
|
||||
|
||||
To obtain the documentation of `stable_diffusion_tuning.py`, please run `python stable_diffusion_tuning.py --help`
|
||||
|
||||
The following are things to keep in mind (The code checks this for you as well) in general while configuring the trainer (beyond the use case of using the example script)
|
||||
|
||||
- The configurable randomized truncation range (`--alignprop_config.truncated_rand_backprop_minmax=(0,50)`) the first number should be equal and greater to 0, while the second number should equal or less to the number of diffusion timesteps (sample_num_steps)
|
||||
- The configurable truncation backprop absolute step (`--alignprop_config.truncated_backprop_timestep=49`) the number should be less than the number of diffusion timesteps (sample_num_steps), it only matters when truncated_backprop_rand is set to False
|
||||
|
||||
## Setting up the image logging hook function
|
||||
|
||||
Expect the function to be given a dictionary with keys
|
||||
```python
|
||||
['image', 'prompt', 'prompt_metadata', 'rewards']
|
||||
|
||||
```
|
||||
and `image`, `prompt`, `prompt_metadata`, `rewards`are batched.
|
||||
You are free to log however you want the use of `wandb` or `tensorboard` is recommended.
|
||||
|
||||
### Key terms
|
||||
|
||||
- `rewards` : The rewards/score is a numerical associated with the generated image and is key to steering the RL process
|
||||
- `prompt` : The prompt is the text that is used to generate the image
|
||||
- `prompt_metadata` : The prompt metadata is the metadata associated with the prompt. A situation where this will not be empty is when the reward model comprises of a [`FLAVA`](https://huggingface.co/docs/transformers/model_doc/flava) setup where questions and ground answers (linked to the generated image) are expected with the generated image (See here: https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45)
|
||||
- `image` : The image generated by the Stable Diffusion model
|
||||
|
||||
Example code for logging sampled images with `wandb` is given below.
|
||||
|
||||
```python
|
||||
# for logging these images to wandb
|
||||
|
||||
def image_outputs_hook(image_data, global_step, accelerate_logger):
|
||||
# For the sake of this example, we only care about the last batch
|
||||
# hence we extract the last element of the list
|
||||
result = {}
|
||||
images, prompts, rewards = [image_data['images'],image_data['prompts'],image_data['rewards']]
|
||||
for i, image in enumerate(images):
|
||||
pil = Image.fromarray(
|
||||
(image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
|
||||
)
|
||||
pil = pil.resize((256, 256))
|
||||
result[f"{prompts[i]:.25} | {rewards[i]:.2f}"] = [pil]
|
||||
accelerate_logger.log_images(
|
||||
result,
|
||||
step=global_step,
|
||||
)
|
||||
|
||||
```
|
||||
|
||||
### Using the finetuned model
|
||||
|
||||
Assuming you've done with all the epochs and have pushed up your model to the hub, you can use the finetuned model as follows
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionPipeline
|
||||
pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
|
||||
pipeline.to("cuda")
|
||||
|
||||
pipeline.load_lora_weights('mihirpd/alignprop-trl-aesthetics')
|
||||
|
||||
prompts = ["squirrel", "crab", "starfish", "whale","sponge", "plankton"]
|
||||
results = pipeline(prompts)
|
||||
|
||||
for prompt, image in zip(prompts,results.images):
|
||||
image.save(f"dump/{prompt}.png")
|
||||
```
|
||||
|
||||
## Credits
|
||||
|
||||
This work is heavily influenced by the repo [here](https://github.com/mihirp1998/AlignProp/) and the associated paper [Aligning Text-to-Image Diffusion Models with Reward Backpropagation
|
||||
by Mihir Prabhudesai, Anirudh Goyal, Deepak Pathak, Katerina Fragkiadaki](https://huggingface.co/papers/2310.03739).
|
||||
@ -1,67 +1,27 @@
|
||||
# BCO Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=bco,trl)
|
||||
|
||||
TRL supports the Binary Classifier Optimization (BCO).
|
||||
The [BCO](https://huggingface.co/papers/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0.
|
||||
For a full example have a look at [`examples/scripts/bco.py`].
|
||||
|
||||
## Expected dataset format
|
||||
|
||||
The BCO trainer expects a very specific format for the dataset as it does not require pairwise preferences. Since the model will be trained to directly optimize examples that consist of a prompt, model completion, and a label to indicate whether the completion is "good" or "bad", we expect a dataset with the following columns:
|
||||
|
||||
- `prompt`
|
||||
- `completion`
|
||||
- `label`
|
||||
|
||||
for example:
|
||||
|
||||
```
|
||||
bco_dataset_dict = {
|
||||
"prompt": [
|
||||
"Hey, hello",
|
||||
"How are you",
|
||||
"What is your name?",
|
||||
"What is your name?",
|
||||
"Which is the best programming language?",
|
||||
"Which is the best programming language?",
|
||||
"Which is the best programming language?",
|
||||
],
|
||||
"completion": [
|
||||
"hi nice to meet you",
|
||||
"leave me alone",
|
||||
"I don't have a name",
|
||||
"My name is Mary",
|
||||
"Python",
|
||||
"C++",
|
||||
"Java",
|
||||
],
|
||||
"label": [
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
],
|
||||
}
|
||||
```
|
||||
|
||||
where the `prompt` contains the context inputs, `completion` contains the corresponding responses and `label` contains the corresponding flag that indicates if the generated completion is desired (`True`) or undesired (`False`).
|
||||
A prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays. It is required that the dataset contains at least one desirable and one undesirable completion.
|
||||
## Expected dataset type
|
||||
|
||||
The [`experimental.bco.BCOTrainer`] requires an [unpaired preference dataset](dataset_formats#unpaired-preference).
|
||||
The [`experimental.bco.BCOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
|
||||
|
||||
## Expected model format
|
||||
|
||||
The BCO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.
|
||||
|
||||
## Using the `BCOTrainer`
|
||||
|
||||
For a detailed example have a look at the `examples/scripts/bco.py` script. At a high level we need to initialize the `BCOTrainer` with a `model` we wish to train and a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected response.
|
||||
For a detailed example have a look at the `examples/scripts/bco.py` script. At a high level we need to initialize the `BCOTrainer` with a `model` we wish to train and a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected response.
|
||||
|
||||
The `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (ie decoder only or encoder-decoder).
|
||||
|
||||
|
||||
|
||||
```py
|
||||
```python
|
||||
training_args = BCOConfig(
|
||||
beta=0.1,
|
||||
)
|
||||
@ -71,12 +31,13 @@ bco_trainer = BCOTrainer(
|
||||
model_ref,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
tokenizer=tokenizer,
|
||||
processing_class=tokenizer,
|
||||
)
|
||||
```
|
||||
|
||||
After this one can then call:
|
||||
|
||||
```py
|
||||
```python
|
||||
bco_trainer.train()
|
||||
```
|
||||
|
||||
@ -88,7 +49,7 @@ If the prompts in your desired and undesired datasets differ a lot, it is useful
|
||||
|
||||
Choose an embedding model and tokenizer:
|
||||
|
||||
```py
|
||||
```python
|
||||
embedding_model = AutoModel.from_pretrained(your_model_id)
|
||||
embedding_tokenizer = AutoTokenizer.from_pretrained(your_model_id)
|
||||
|
||||
@ -101,9 +62,9 @@ embedding_model = Accelerator().prepare_model(self.embedding_model)
|
||||
embedding_func = partial(embed_prompt, model=embedding_model)
|
||||
```
|
||||
|
||||
Set `prompt_sample_size` to defined how many prompts are selected to train the UDM classifier and start the training with the provided embedding function:
|
||||
Set `prompt_sample_size` to define how many prompts are selected to train the UDM classifier and start the training with the provided embedding function:
|
||||
|
||||
```py
|
||||
```python
|
||||
training_args = BCOConfig(
|
||||
beta=0.1,
|
||||
prompt_sample_size=512,
|
||||
@ -114,7 +75,7 @@ bco_trainer = BCOTrainer(
|
||||
model_ref,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
tokenizer=tokenizer,
|
||||
processing_class=tokenizer,
|
||||
embedding_func=embedding_func,
|
||||
embedding_tokenizer=self.embedding_tokenizer,
|
||||
)
|
||||
@ -132,8 +93,11 @@ To scale how much the auxiliary loss contributes to the total loss, use the hype
|
||||
|
||||
## BCOTrainer
|
||||
|
||||
[[autodoc]] BCOTrainer
|
||||
[[autodoc]] experimental.bco.BCOTrainer
|
||||
- train
|
||||
- save_model
|
||||
- push_to_hub
|
||||
|
||||
## BCOConfig
|
||||
|
||||
[[autodoc]] BCOConfig
|
||||
[[autodoc]] experimental.bco.BCOConfig
|
||||
31
docs/source/bema_for_reference_model.md
Normal file
31
docs/source/bema_for_reference_model.md
Normal file
@ -0,0 +1,31 @@
|
||||
# BEMA for Reference Model
|
||||
|
||||
This feature implements the BEMA algorithm to update the reference model during DPO training.
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
from trl.experimental.bema_for_ref_model import BEMACallback, DPOTrainer
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
|
||||
pref_dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
|
||||
ref_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
|
||||
|
||||
bema_callback = BEMACallback(update_ref_model=True)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
|
||||
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
trainer = DPOTrainer(
|
||||
model=model,
|
||||
ref_model=ref_model,
|
||||
train_dataset=pref_dataset,
|
||||
processing_class=tokenizer,
|
||||
callbacks=[bema_callback],
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
@ -1,72 +0,0 @@
|
||||
# Best of N sampling: Alternative ways to get better model output without RL based fine-tuning
|
||||
|
||||
Within the extras module is the `best-of-n` sampler class that serves as an alternative method of generating better model output.
|
||||
As to how it fares against the RL based fine-tuning, please look in the `examples` directory for a comparison example
|
||||
|
||||
## Usage
|
||||
|
||||
To get started quickly, instantiate an instance of the class with a model, a length sampler, a tokenizer and a callable that serves as a proxy reward pipeline that outputs reward scores for input queries
|
||||
|
||||
```python
|
||||
|
||||
from transformers import pipeline, AutoTokenizer
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
from trl.core import LengthSampler
|
||||
from trl.extras import BestOfNSampler
|
||||
|
||||
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(ref_model_name)
|
||||
reward_pipe = pipeline("sentiment-analysis", model=reward_model, device=device)
|
||||
tokenizer = AutoTokenizer.from_pretrained(ref_model_name)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
|
||||
# callable that takes a list of raw text and returns a list of corresponding reward scores
|
||||
def queries_to_scores(list_of_strings):
|
||||
return [output["score"] for output in reward_pipe(list_of_strings)]
|
||||
|
||||
best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler)
|
||||
|
||||
|
||||
```
|
||||
|
||||
And assuming you have a list/tensor of tokenized queries, you can generate better output by calling the `generate` method
|
||||
|
||||
```python
|
||||
|
||||
best_of_n.generate(query_tensors, device=device, **gen_kwargs)
|
||||
|
||||
```
|
||||
The default sample size is 4, but you can change it at the time of instance initialization like so
|
||||
|
||||
```python
|
||||
|
||||
best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler, sample_size=8)
|
||||
|
||||
```
|
||||
|
||||
The default output is the result of taking the top scored output for each query, but you can change it to top 2 and so on by passing the `n_candidates` argument at the time of instance initialization
|
||||
|
||||
```python
|
||||
|
||||
best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler, n_candidates=2)
|
||||
|
||||
```
|
||||
|
||||
There is the option of setting the generation settings (like `temperature`, `pad_token_id`) at the time of instance creation as opposed to when calling the `generate` method.
|
||||
This is done by passing a `GenerationConfig` from the `transformers` library at the time of initialization
|
||||
|
||||
```python
|
||||
|
||||
from transformers import GenerationConfig
|
||||
|
||||
generation_config = GenerationConfig(min_length= -1, top_k=0.0, top_p= 1.0, do_sample= True, pad_token_id=tokenizer.eos_token_id)
|
||||
|
||||
best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler, generation_config=generation_config)
|
||||
|
||||
best_of_n.generate(query_tensors, device=device)
|
||||
|
||||
```
|
||||
|
||||
Furthermore, at the time of initialization you can set the seed to control repeatability of the generation process and the number of samples to generate for each query
|
||||
|
||||
|
||||
@ -14,4 +14,16 @@
|
||||
|
||||
## LogCompletionsCallback
|
||||
|
||||
[[autodoc]] LogCompletionsCallback
|
||||
[[autodoc]] LogCompletionsCallback
|
||||
|
||||
## MergeModelCallback
|
||||
|
||||
[[autodoc]] MergeModelCallback
|
||||
|
||||
## BEMACallback
|
||||
|
||||
[[autodoc]] BEMACallback
|
||||
|
||||
## WeaveCallback
|
||||
|
||||
[[autodoc]] WeaveCallback
|
||||
414
docs/source/clis.md
Normal file
414
docs/source/clis.md
Normal file
@ -0,0 +1,414 @@
|
||||
# Command Line Interfaces (CLIs)
|
||||
|
||||
TRL provides a powerful command-line interface (CLI) to fine-tune large language models (LLMs) using methods like Supervised Fine-Tuning (SFT), Direct Preference Optimization (DPO), and more. The CLI abstracts away much of the boilerplate, letting you launch training jobs quickly and reproducibly.
|
||||
|
||||
## Commands
|
||||
|
||||
Currently supported commands are:
|
||||
|
||||
### Training Commands
|
||||
|
||||
- `trl dpo`: fine-tune a LLM with DPO
|
||||
- `trl grpo`: fine-tune a LLM with GRPO
|
||||
- `trl kto`: fine-tune a LLM with KTO
|
||||
- `trl reward`: train a Reward Model
|
||||
- `trl rloo`: fine-tune a LLM with RLOO
|
||||
- `trl sft`: fine-tune a LLM with SFT
|
||||
|
||||
### Other Commands
|
||||
|
||||
- `trl env`: get the system information
|
||||
- `trl vllm-serve`: serve a model with vLLM
|
||||
|
||||
## Fine-Tuning with the TRL CLI
|
||||
|
||||
### Basic Usage
|
||||
|
||||
You can launch training directly from the CLI by specifying required arguments like the model and dataset:
|
||||
|
||||
<hfoptions id="command_line">
|
||||
<hfoption id="SFT">
|
||||
|
||||
```bash
|
||||
trl sft \
|
||||
--model_name_or_path Qwen/Qwen2.5-0.5B \
|
||||
--dataset_name stanfordnlp/imdb
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="DPO">
|
||||
|
||||
```bash
|
||||
trl dpo \
|
||||
--model_name_or_path Qwen/Qwen2.5-0.5B \
|
||||
--dataset_name anthropic/hh-rlhf
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Reward">
|
||||
|
||||
```bash
|
||||
trl reward \
|
||||
--model_name_or_path Qwen/Qwen2.5-0.5B \
|
||||
--dataset_name trl-lib/ultrafeedback_binarized
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Using Configuration Files
|
||||
|
||||
To keep your CLI commands clean and reproducible, you can define all training arguments in a YAML configuration file:
|
||||
|
||||
<hfoptions id="config_file">
|
||||
<hfoption id="SFT">
|
||||
|
||||
```yaml
|
||||
# sft_config.yaml
|
||||
model_name_or_path: Qwen/Qwen2.5-0.5B
|
||||
dataset_name: stanfordnlp/imdb
|
||||
```
|
||||
|
||||
Launch with:
|
||||
|
||||
```bash
|
||||
trl sft --config sft_config.yaml
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="DPO">
|
||||
|
||||
```yaml
|
||||
# dpo_config.yaml
|
||||
model_name_or_path: Qwen/Qwen2.5-0.5B
|
||||
dataset_name: anthropic/hh-rlhf
|
||||
```
|
||||
|
||||
Launch with:
|
||||
|
||||
```bash
|
||||
trl dpo --config dpo_config.yaml
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Reward">
|
||||
|
||||
```yaml
|
||||
# reward_config.yaml
|
||||
model_name_or_path: Qwen/Qwen2.5-0.5B
|
||||
dataset_name: trl-lib/ultrafeedback_binarized
|
||||
```
|
||||
|
||||
Launch with:
|
||||
|
||||
```bash
|
||||
trl reward --config reward_config.yaml
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Scaling Up with Accelerate
|
||||
|
||||
TRL CLI natively supports [🤗 Accelerate](https://huggingface.co/docs/accelerate), making it easy to scale training across multiple GPUs, machines, or use advanced setups like DeepSpeed — all from the same CLI.
|
||||
|
||||
You can pass any `accelerate launch` arguments directly to `trl`, such as `--num_processes`. For more information see [Using accelerate launch](https://huggingface.co/docs/accelerate/en/basic_tutorials/launch#using-accelerate-launch).
|
||||
|
||||
<hfoptions id="launch_args">
|
||||
<hfoption id="SFT inline">
|
||||
|
||||
```bash
|
||||
trl sft \
|
||||
--model_name_or_path Qwen/Qwen2.5-0.5B \
|
||||
--dataset_name stanfordnlp/imdb \
|
||||
--num_processes 4
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="SFT w/ config file">
|
||||
|
||||
```yaml
|
||||
# sft_config.yaml
|
||||
model_name_or_path: Qwen/Qwen2.5-0.5B
|
||||
dataset_name: stanfordnlp/imdb
|
||||
num_processes: 4
|
||||
```
|
||||
|
||||
Launch with:
|
||||
|
||||
```bash
|
||||
trl sft --config sft_config.yaml
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="DPO inline">
|
||||
|
||||
```bash
|
||||
trl dpo \
|
||||
--model_name_or_path Qwen/Qwen2.5-0.5B \
|
||||
--dataset_name anthropic/hh-rlhf \
|
||||
--num_processes 4
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="DPO w/ config file">
|
||||
|
||||
```yaml
|
||||
# dpo_config.yaml
|
||||
model_name_or_path: Qwen/Qwen2.5-0.5B
|
||||
dataset_name: anthropic/hh-rlhf
|
||||
num_processes: 4
|
||||
```
|
||||
|
||||
Launch with:
|
||||
|
||||
```bash
|
||||
trl dpo --config dpo_config.yaml
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Reward inline">
|
||||
|
||||
```bash
|
||||
trl reward \
|
||||
--model_name_or_path Qwen/Qwen2.5-0.5B \
|
||||
--dataset_name trl-lib/ultrafeedback_binarized \
|
||||
--num_processes 4
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Reward w/ config file">
|
||||
|
||||
```yaml
|
||||
# reward_config.yaml
|
||||
model_name_or_path: Qwen/Qwen2.5-0.5B
|
||||
dataset_name: trl-lib/ultrafeedback_binarized
|
||||
num_processes: 4
|
||||
```
|
||||
|
||||
Launch with:
|
||||
|
||||
```bash
|
||||
trl reward --config reward_config.yaml
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Using `--accelerate_config` for Accelerate Configuration
|
||||
|
||||
The `--accelerate_config` flag lets you easily configure distributed training with [🤗 Accelerate](https://github.com/huggingface/accelerate). This flag accepts either:
|
||||
|
||||
- the name of a predefined config profile (built into TRL), or
|
||||
- a path to a custom Accelerate YAML config file.
|
||||
|
||||
#### Predefined Config Profiles
|
||||
|
||||
TRL provides several ready-to-use Accelerate configs to simplify common training setups:
|
||||
|
||||
| Name | Description |
|
||||
| --- | --- |
|
||||
| `fsdp1` | Fully Sharded Data Parallel Stage 1 |
|
||||
| `fsdp2` | Fully Sharded Data Parallel Stage 2 |
|
||||
| `zero1` | DeepSpeed ZeRO Stage 1 |
|
||||
| `zero2` | DeepSpeed ZeRO Stage 2 |
|
||||
| `zero3` | DeepSpeed ZeRO Stage 3 |
|
||||
| `multi_gpu` | Multi-GPU training |
|
||||
| `single_gpu` | Single-GPU training |
|
||||
|
||||
To use one of these, just pass the name to `--accelerate_config`. TRL will automatically load the corresponding config file from `trl/accelerate_config/`.
|
||||
|
||||
#### Example Usage
|
||||
|
||||
<hfoptions id="accelerate_config">
|
||||
<hfoption id="SFT inline">
|
||||
|
||||
```bash
|
||||
trl sft \
|
||||
--model_name_or_path Qwen/Qwen2.5-0.5B \
|
||||
--dataset_name stanfordnlp/imdb \
|
||||
--accelerate_config zero2 # or path/to/my/accelerate/config.yaml
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="SFT w/ config file">
|
||||
|
||||
```yaml
|
||||
# sft_config.yaml
|
||||
model_name_or_path: Qwen/Qwen2.5-0.5B
|
||||
dataset_name: stanfordnlp/imdb
|
||||
accelerate_config: zero2 # or path/to/my/accelerate/config.yaml
|
||||
```
|
||||
|
||||
Launch with:
|
||||
|
||||
```bash
|
||||
trl sft --config sft_config.yaml
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="DPO inline">
|
||||
|
||||
```bash
|
||||
trl dpo \
|
||||
--model_name_or_path Qwen/Qwen2.5-0.5B \
|
||||
--dataset_name anthropic/hh-rlhf \
|
||||
--accelerate_config zero2 # or path/to/my/accelerate/config.yaml
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="DPO w/ config file">
|
||||
|
||||
```yaml
|
||||
# dpo_config.yaml
|
||||
model_name_or_path: Qwen/Qwen2.5-0.5B
|
||||
dataset_name: anthropic/hh-rlhf
|
||||
accelerate_config: zero2 # or path/to/my/accelerate/config.yaml
|
||||
```
|
||||
|
||||
Launch with:
|
||||
|
||||
```bash
|
||||
trl dpo --config dpo_config.yaml
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Reward inline">
|
||||
|
||||
```bash
|
||||
trl reward \
|
||||
--model_name_or_path Qwen/Qwen2.5-0.5B \
|
||||
--dataset_name trl-lib/ultrafeedback_binarized \
|
||||
--accelerate_config zero2 # or path/to/my/accelerate/config.yaml
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Reward w/ config file">
|
||||
|
||||
```yaml
|
||||
# reward_config.yaml
|
||||
model_name_or_path: Qwen/Qwen2.5-0.5B
|
||||
dataset_name: trl-lib/ultrafeedback_binarized
|
||||
accelerate_config: zero2 # or path/to/my/accelerate/config.yaml
|
||||
```
|
||||
|
||||
Launch with:
|
||||
|
||||
```bash
|
||||
trl reward --config reward_config.yaml
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Using dataset mixtures
|
||||
|
||||
You can use dataset mixtures to combine multiple datasets into a single training dataset. This is useful for training on diverse data sources or when you want to mix different types of data.
|
||||
|
||||
<hfoptions id="dataset_mixtures">
|
||||
<hfoption id="SFT">
|
||||
|
||||
```yaml
|
||||
# sft_config.yaml
|
||||
model_name_or_path: Qwen/Qwen2.5-0.5B
|
||||
datasets:
|
||||
- path: stanfordnlp/imdb
|
||||
- path: roneneldan/TinyStories
|
||||
```
|
||||
|
||||
Launch with:
|
||||
|
||||
```bash
|
||||
trl sft --config sft_config.yaml
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="DPO">
|
||||
|
||||
```yaml
|
||||
# dpo_config.yaml
|
||||
model_name_or_path: Qwen/Qwen2.5-0.5B
|
||||
datasets:
|
||||
- path: BAAI/Infinity-Preference
|
||||
- path: argilla/Capybara-Preferences
|
||||
```
|
||||
|
||||
Launch with:
|
||||
|
||||
```bash
|
||||
trl dpo --config dpo_config.yaml
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Reward">
|
||||
|
||||
```yaml
|
||||
# reward_config.yaml
|
||||
model_name_or_path: Qwen/Qwen2.5-0.5B
|
||||
datasets:
|
||||
- path: trl-lib/tldr-preference
|
||||
- path: trl-lib/lm-human-preferences-sentiment
|
||||
```
|
||||
|
||||
Launch with:
|
||||
|
||||
```bash
|
||||
trl reward --config reward_config.yaml
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
To see all the available keywords for defining dataset mixtures, refer to the [`scripts.utils.DatasetConfig`] and [`DatasetMixtureConfig`] classes.
|
||||
|
||||
## Getting the System Information
|
||||
|
||||
You can get the system information by running the following command:
|
||||
|
||||
```bash
|
||||
trl env
|
||||
```
|
||||
|
||||
This will print out the system information, including the GPU information, the CUDA version, the PyTorch version, the transformers version, the TRL version, and any optional dependencies that are installed.
|
||||
|
||||
```txt
|
||||
Copy-paste the following information when reporting an issue:
|
||||
|
||||
- Platform: Linux-5.15.0-1048-aws-x86_64-with-glibc2.31
|
||||
- Python version: 3.11.9
|
||||
- PyTorch version: 2.4.1
|
||||
- accelerator(s): NVIDIA H100 80GB HBM3
|
||||
- Transformers version: 4.45.0.dev0
|
||||
- Accelerate version: 0.34.2
|
||||
- Accelerate config:
|
||||
- compute_environment: LOCAL_MACHINE
|
||||
- distributed_type: DEEPSPEED
|
||||
- mixed_precision: no
|
||||
- use_cpu: False
|
||||
- debug: False
|
||||
- num_processes: 4
|
||||
- machine_rank: 0
|
||||
- num_machines: 1
|
||||
- rdzv_backend: static
|
||||
- same_network: True
|
||||
- main_training_function: main
|
||||
- enable_cpu_affinity: False
|
||||
- deepspeed_config: {'gradient_accumulation_steps': 4, 'offload_optimizer_device': 'none', 'offload_param_device': 'none', 'zero3_init_flag': False, 'zero_stage': 2}
|
||||
- downcast_bf16: no
|
||||
- tpu_use_cluster: False
|
||||
- tpu_use_sudo: False
|
||||
- tpu_env: []
|
||||
- Datasets version: 3.0.0
|
||||
- HF Hub version: 0.24.7
|
||||
- TRL version: 0.12.0.dev0+acb4d70
|
||||
- bitsandbytes version: 0.41.1
|
||||
- DeepSpeed version: 0.15.1
|
||||
- Diffusers version: 0.30.3
|
||||
- Liger-Kernel version: 0.3.0
|
||||
- LLM-Blender version: 0.0.2
|
||||
- OpenAI version: 1.46.0
|
||||
- PEFT version: 0.12.0
|
||||
- vLLM version: not installed
|
||||
```
|
||||
|
||||
This information is required when reporting an issue.
|
||||
@ -1,119 +0,0 @@
|
||||
# Command Line Interfaces (CLIs)
|
||||
|
||||
You can use TRL to fine-tune your Language Model with Supervised Fine-Tuning (SFT) or Direct Policy Optimization (DPO) or even chat with your model using the TRL CLIs.
|
||||
|
||||
Currently supported CLIs are:
|
||||
|
||||
- `trl sft`: fine-tune a LLM on a text/instruction dataset
|
||||
- `trl dpo`: fine-tune a LLM with DPO on a preference dataset
|
||||
- `trl chat`: quickly spin up a LLM fine-tuned for chatting
|
||||
|
||||
## Fine-tuning with the CLI
|
||||
|
||||
Before getting started, pick up a Language Model from Hugging Face Hub. Supported models can be found with the filter "text-generation" within models. Also make sure to pick up a relevant dataset for your task.
|
||||
|
||||
Before using the `sft` or `dpo` commands make sure to run:
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
and pick up the right configuration for your training setup (single / multi-GPU, DeepSpeed, etc.). Make sure to complete all steps of `accelerate config` before running any CLI command.
|
||||
|
||||
We also recommend you passing a YAML config file to configure your training protocol. Below is a simple example of a YAML file that you can use for training your models with `trl sft` command.
|
||||
|
||||
```yaml
|
||||
model_name_or_path:
|
||||
trl-internal-testing/tiny-random-LlamaForCausalLM
|
||||
dataset_name:
|
||||
stanfordnlp/imdb
|
||||
dataset_text_field:
|
||||
text
|
||||
report_to:
|
||||
none
|
||||
learning_rate:
|
||||
0.0001
|
||||
lr_scheduler_type:
|
||||
cosine
|
||||
```
|
||||
|
||||
Save that config in a `.yaml` and get started immediately! An example CLI config is available as `examples/cli_configs/example_config.yaml`. Note you can overwrite the arguments from the config file by explicitly passing them to the CLI, e.g. from the root folder:
|
||||
|
||||
```bash
|
||||
trl sft --config examples/cli_configs/example_config.yaml --output_dir test-trl-cli --lr_scheduler_type cosine_with_restarts
|
||||
```
|
||||
|
||||
Will force-use `cosine_with_restarts` for `lr_scheduler_type`.
|
||||
|
||||
### Supported Arguments
|
||||
|
||||
We do support all arguments from `transformers.TrainingArguments`, for loading your model, we support all arguments from `~trl.ModelConfig`:
|
||||
|
||||
[[autodoc]] ModelConfig
|
||||
|
||||
You can pass any of these arguments either to the CLI or the YAML file.
|
||||
|
||||
### Supervised Fine-tuning (SFT)
|
||||
|
||||
Follow the basic instructions above and run `trl sft --output_dir <output_dir> <*args>`:
|
||||
|
||||
```bash
|
||||
trl sft --model_name_or_path facebook/opt-125m --dataset_name stanfordnlp/imdb --output_dir opt-sft-imdb
|
||||
```
|
||||
|
||||
The SFT CLI is based on the `examples/scripts/sft.py` script.
|
||||
|
||||
### Direct Policy Optimization (DPO)
|
||||
|
||||
To use the DPO CLI, you need to have a dataset in the TRL format such as
|
||||
|
||||
* TRL's Anthropic HH dataset: https://huggingface.co/datasets/trl-internal-testing/hh-rlhf-helpful-base-trl-style
|
||||
* TRL's OpenAI TL;DR summarization dataset: https://huggingface.co/datasets/trl-internal-testing/tldr-preference-trl-style
|
||||
|
||||
These datasets always have at least three columns `prompt, chosen, rejected`:
|
||||
|
||||
* `prompt` is a list of strings.
|
||||
* `chosen` is the chosen response in [chat format](https://huggingface.co/docs/transformers/main/en/chat_templating)
|
||||
* `rejected` is the rejected response [chat format](https://huggingface.co/docs/transformers/main/en/chat_templating)
|
||||
|
||||
|
||||
To do a quick start, you can run the following command:
|
||||
|
||||
```bash
|
||||
trl dpo --model_name_or_path facebook/opt-125m --output_dir trl-hh-rlhf --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style
|
||||
```
|
||||
|
||||
|
||||
The DPO CLI is based on the `examples/scripts/dpo.py` script.
|
||||
|
||||
|
||||
#### Custom preference dataset
|
||||
|
||||
Format the dataset into TRL format (you can adapt the `examples/datasets/anthropic_hh.py`):
|
||||
|
||||
```bash
|
||||
python examples/datasets/anthropic_hh.py --push_to_hub --hf_entity your-hf-org
|
||||
```
|
||||
|
||||
## Chat interface
|
||||
|
||||
The chat CLI lets you quickly load the model and talk to it. Simply run the following:
|
||||
|
||||
```bash
|
||||
trl chat --model_name_or_path Qwen/Qwen1.5-0.5B-Chat
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> To use the chat CLI with the developer installation, you must run `make dev`
|
||||
>
|
||||
|
||||
Note that the chat interface relies on the tokenizer's [chat template](https://huggingface.co/docs/transformers/chat_templating) to format the inputs for the model. Make sure your tokenizer has a chat template defined.
|
||||
|
||||
Besides talking to the model there are a few commands you can use:
|
||||
|
||||
- **clear**: clears the current conversation and start a new one
|
||||
- **example {NAME}**: load example named `{NAME}` from the config and use it as the user input
|
||||
- **set {SETTING_NAME}={SETTING_VALUE};**: change the system prompt or generation settings (multiple settings are separated by a ';').
|
||||
- **reset**: same as clear but also resets the generation configs to defaults if they have been changed by **set**
|
||||
- **save {SAVE_NAME} (optional)**: save the current chat and settings to file by default to `./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml` or `{SAVE_NAME}` if provided
|
||||
- **exit**: closes the interface
|
||||
|
||||
The default examples are defined in `examples/scripts/config/default_chat_config.yaml` but you can pass your own with `--config CONFIG_FILE` where you can also specify the default generation parameters.
|
||||
57
docs/source/community_tutorials.md
Normal file
57
docs/source/community_tutorials.md
Normal file
@ -0,0 +1,57 @@
|
||||
# Community Tutorials
|
||||
|
||||
Community tutorials are made by active members of the Hugging Face community who want to share their knowledge and expertise with others. They are a great way to learn about the library and its features, and to get started with core classes and modalities.
|
||||
|
||||
## Language Models
|
||||
|
||||
### Tutorials
|
||||
|
||||
| Task | Class | Description | Author | Tutorial | Colab |
|
||||
| --- | --- | --- | --- | --- | --- |
|
||||
| Reinforcement Learning | [`GRPOTrainer`] | Efficient Online Training with GRPO and vLLM in TRL | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/grpo_vllm_online_training) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/grpo_vllm_online_training.ipynb) |
|
||||
| Reinforcement Learning | [`GRPOTrainer`] | Post training an LLM for reasoning with GRPO in TRL | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_llm_grpo_trl) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_grpo_trl.ipynb) |
|
||||
| Reinforcement Learning | [`GRPOTrainer`] | Mini-R1: Reproduce Deepseek R1 „aha moment“ a RL tutorial | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/mini-deepseek-r1) | [](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/mini-deepseek-r1-aha-grpo.ipynb) |
|
||||
| Reinforcement Learning | [`GRPOTrainer`] | RL on LLaMA 3.1-8B with GRPO and Unsloth optimizations | [Andrea Manzoni](https://huggingface.co/AManzoni) | [Link](https://colab.research.google.com/github/amanzoni1/fine_tuning/blob/main/RL_LLama3_1_8B_GRPO.ipynb) | [](https://colab.research.google.com/github/amanzoni1/fine_tuning/blob/main/RL_LLama3_1_8B_GRPO.ipynb) |
|
||||
| Instruction tuning | [`SFTTrainer`] | Fine-tuning Google Gemma LLMs using ChatML format with QLoRA | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-google-gemma) | [](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/gemma-lora-example.ipynb) |
|
||||
| Structured Generation | [`SFTTrainer`] | Fine-tuning Llama-2-7B to generate Persian product catalogs in JSON using QLoRA and PEFT | [Mohammadreza Esmaeilian](https://huggingface.co/Mohammadreza) | [Link](https://huggingface.co/learn/cookbook/en/fine_tuning_llm_to_generate_persian_product_catalogs_in_json_format) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_to_generate_persian_product_catalogs_in_json_format.ipynb) |
|
||||
| Preference Optimization | [`DPOTrainer`] | Align Mistral-7b using Direct Preference Optimization for human preference alignment | [Maxime Labonne](https://huggingface.co/mlabonne) | [Link](https://mlabonne.github.io/blog/posts/Fine_tune_Mistral_7b_with_DPO.html) | [](https://colab.research.google.com/github/mlabonne/llm-course/blob/main/Fine_tune_a_Mistral_7b_model_with_DPO.ipynb) |
|
||||
| Preference Optimization | [`ORPOTrainer`] | Fine-tuning Llama 3 with ORPO combining instruction tuning and preference alignment | [Maxime Labonne](https://huggingface.co/mlabonne) | [Link](https://mlabonne.github.io/blog/posts/2024-04-19_Fine_tune_Llama_3_with_ORPO.html) | [](https://colab.research.google.com/drive/1eHNWg9gnaXErdAa8_mcvjMupbSS6rDvi) |
|
||||
| Instruction tuning | [`SFTTrainer`] | How to fine-tune open LLMs in 2025 with Hugging Face | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-llms-in-2025) | [](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/fine-tune-llms-in-2025.ipynb) |
|
||||
|
||||
### Videos
|
||||
|
||||
| Task | Title | Author | Video |
|
||||
| --- | --- | --- | --- |
|
||||
| Instruction tuning | Fine-tuning open AI models using Hugging Face TRL | [Wietse Venema](https://huggingface.co/wietsevenema) | [<img src="https://img.youtube.com/vi/cnGyyM0vOes/0.jpg">](https://youtu.be/cnGyyM0vOes) |
|
||||
| Instruction tuning | How to fine-tune a smol-LM with Hugging Face, TRL, and the smoltalk Dataset | [Mayurji](https://huggingface.co/iammayur) | [<img src="https://img.youtube.com/vi/jKdXv3BiLu0/0.jpg">](https://youtu.be/jKdXv3BiLu0) |
|
||||
|
||||
|
||||
<details>
|
||||
<summary>⚠️ Deprecated features notice for "How to fine-tune a smol-LM with Hugging Face, TRL, and the smoltalk Dataset" (click to expand)</summary>
|
||||
|
||||
> [!WARNING]
|
||||
> The tutorial uses two deprecated features:
|
||||
>
|
||||
> - `SFTTrainer(..., tokenizer=tokenizer)`: Use `SFTTrainer(..., processing_class=tokenizer)` instead, or simply omit it (it will be inferred from the model).
|
||||
> - `setup_chat_format(model, tokenizer)`: Use `SFTConfig(..., chat_template_path="Qwen/Qwen3-0.6B")`, where `chat_template_path` specifies the model whose chat template you want to copy.
|
||||
|
||||
</details>
|
||||
|
||||
## Vision Language Models
|
||||
|
||||
### Tutorials
|
||||
|
||||
| Task | Class | Description | Author | Tutorial | Colab |
|
||||
| --- | --- | --- | --- | --- | --- |
|
||||
| Visual QA | [`SFTTrainer`] | Fine-tuning Qwen2-VL-7B for visual question answering on ChartQA dataset | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_trl) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_trl.ipynb) |
|
||||
| Visual QA | [`SFTTrainer`] | Fine-tuning SmolVLM with TRL on a consumer GPU | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_smol_vlm_sft_trl) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_smol_vlm_sft_trl.ipynb) |
|
||||
| SEO Description | [`SFTTrainer`] | Fine-tuning Qwen2-VL-7B for generating SEO-friendly descriptions from images | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-multimodal-llms-with-trl) | [](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/fine-tune-multimodal-llms-with-trl.ipynb) |
|
||||
| Visual QA | [`DPOTrainer`] | PaliGemma 🤝 Direct Preference Optimization | [Merve Noyan](https://huggingface.co/merve) | [Link](https://github.com/merveenoyan/smol-vision/blob/main/PaliGemma_DPO.ipynb) | [](https://colab.research.google.com/github/merveenoyan/smol-vision/blob/main/PaliGemma_DPO.ipynb) |
|
||||
| Visual QA | [`DPOTrainer`] | Fine-tuning SmolVLM using direct preference optimization (DPO) with TRL on a consumer GPU | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_dpo_smolvlm_instruct) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_dpo_smolvlm_instruct.ipynb) |
|
||||
| Object Detection Grounding | [`SFTTrainer`] | Fine tuning a VLM for Object Detection Grounding using TRL | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_object_detection_grounding) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_object_detection_grounding.ipynb) |
|
||||
| Visual QA | [`DPOTrainer`] | Fine-Tuning a Vision Language Model with TRL using MPO | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_mpo) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_mpo.ipynb) |
|
||||
| Reinforcement Learning | [`GRPOTrainer`] | Post training a VLM for reasoning with GRPO using TRL | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_grpo_trl) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_grpo_trl.ipynb) |
|
||||
|
||||
## Contributing
|
||||
|
||||
If you have a tutorial that you would like to add to this list, please open a PR to add it. We will review it and merge it if it is relevant to the community.
|
||||
126
docs/source/cpo_trainer.md
Normal file
126
docs/source/cpo_trainer.md
Normal file
@ -0,0 +1,126 @@
|
||||
# CPO Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=cpo,trl)
|
||||
|
||||
## Overview
|
||||
|
||||
Contrastive Preference Optimization (CPO) as introduced in the paper [Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation](https://huggingface.co/papers/2401.08417) by [Haoran Xu](https://huggingface.co/haoranxu), [Amr Sharaf](https://huggingface.co/amrsharaf), [Yunmo Chen](https://huggingface.co/yunmochen), Weiting Tan, Lingfeng Shen, Benjamin Van Durme, [Kenton Murray](https://huggingface.co/Kenton), and [Young Jin Kim](https://huggingface.co/ykim362). At a high level, CPO trains models to avoid generating adequate, but not perfect, translations in Machine Translation (MT) tasks. However, CPO is a general approximation of the DPO loss and can be applied to other domains, such as chat.
|
||||
|
||||
CPO aims to mitigate two fundamental shortcomings of SFT. First, SFT’s methodology of minimizing the discrepancy between predicted outputs and gold-standard references inherently caps model performance at the quality level of the training data. Secondly, SFT lacks a mechanism to prevent the model from rejecting mistakes in translations. The CPO objective is derived from the DPO objective.
|
||||
|
||||
## Quick start
|
||||
|
||||
This example demonstrates how to train a model using the CPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model. We use the preference data from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the data in the dataset here:
|
||||
|
||||
<iframe
|
||||
src="https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized/embed/viewer/default/train?row=0"
|
||||
frameborder="0"
|
||||
width="100%"
|
||||
height="560px"
|
||||
></iframe>
|
||||
|
||||
Below is the script to train the model:
|
||||
|
||||
```python
|
||||
# train_cpo.py
|
||||
from datasets import load_dataset
|
||||
from trl import CPOConfig, CPOTrainer
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
||||
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
||||
|
||||
training_args = CPOConfig(output_dir="Qwen2-0.5B-CPO")
|
||||
trainer = CPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
Execute the script using the following command:
|
||||
|
||||
```bash
|
||||
accelerate launch train_cpo.py
|
||||
```
|
||||
|
||||
## Expected dataset type
|
||||
|
||||
CPO requires a [preference dataset](dataset_formats#preference). The [`CPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
|
||||
|
||||
## Example script
|
||||
|
||||
We provide an example script to train a model using the CPO method. The script is available in [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py)
|
||||
|
||||
To test the CPO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), run the following command:
|
||||
|
||||
```bash
|
||||
accelerate launch examples/scripts/cpo.py \
|
||||
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
|
||||
--dataset_name trl-lib/ultrafeedback_binarized \
|
||||
--num_train_epochs 1 \
|
||||
--output_dir Qwen2-0.5B-CPO
|
||||
```
|
||||
|
||||
## Logged metrics
|
||||
|
||||
While training and evaluating, we record the following reward metrics:
|
||||
|
||||
* `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta
|
||||
* `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta
|
||||
* `rewards/accuracies`: mean of how often the chosen rewards are > than the corresponding rejected rewards
|
||||
* `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
|
||||
* `nll_loss`: the mean negative log likelihood loss of the policy model for the chosen responses
|
||||
|
||||
## CPO variants
|
||||
|
||||
### Simple Preference Optimization (SimPO)
|
||||
|
||||
[Simple Preference Optimization](https://huggingface.co/papers/2405.14734) (SimPO) by [Yu Meng](https://huggingface.co/yumeng5), [Mengzhou Xia](https://huggingface.co/mengzhouxia), and [Danqi Chen](https://huggingface.co/cdq10131) proposes a simpler and more effective preference optimization algorithm than DPO without using a reference model. The key designs in SimPO are (1) using length-normalized log likelihood as the implicit reward, and (2) incorporating a target reward margin in the Bradley-Terry ranking objective. The official code can be found at [princeton-nlp/SimPO](https://github.com/princeton-nlp/SimPO).
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
> Direct Preference Optimization (DPO) is a widely used offline preference optimization algorithm that reparameterizes reward functions in reinforcement learning from human feedback (RLHF) to enhance simplicity and training stability. In this work, we propose SimPO, a simpler yet more effective approach. The effectiveness of SimPO is attributed to a key design: using the average log probability of a sequence as the implicit reward. This reward formulation better aligns with model generation and eliminates the need for a reference model, making it more compute and memory efficient. Additionally, we introduce a target reward margin to the Bradley-Terry objective to encourage a larger margin between the winning and losing responses, further enhancing the algorithm's performance. We compare SimPO to DPO and its latest variants across various state-of-the-art training setups, including both base and instruction-tuned models like Mistral and Llama3. We evaluated on extensive instruction-following benchmarks, including AlpacaEval 2, MT-Bench, and the recent challenging Arena-Hard benchmark. Our results demonstrate that SimPO consistently and significantly outperforms existing approaches without substantially increasing response length. Specifically, SimPO outperforms DPO by up to 6.4 points on AlpacaEval 2 and by up to 7.5 points on Arena-Hard. Our top-performing model, built on Llama3-8B-Instruct, achieves a remarkable 44.7 length-controlled win rate on AlpacaEval 2 -- surpassing Claude 3 Opus on the leaderboard, and a 33.8 win rate on Arena-Hard -- making it the strongest 8B open-source model.
|
||||
|
||||
The SimPO loss is integrated in the [`CPOTrainer`], as it's an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, just turn on `loss_type="simpo"` and `cpo_alpha=0.0` in the [`CPOConfig`] and set the `simpo_gamma` to a recommended value.
|
||||
|
||||
### CPO-SimPO
|
||||
|
||||
We also offer the combined use of CPO and SimPO, which enables more stable training and improved performance. Learn more details at [CPO-SimPO GitHub](https://github.com/fe1ixxu/CPO_SIMPO). To use this method, simply enable SimPO by setting `loss_type="simpo"` and a non-zero `cpo_alpha` in the [`CPOConfig`].
|
||||
|
||||
### AlphaPO
|
||||
|
||||
The [AlphaPO -- Reward shape matters for LLM alignment](https://huggingface.co/papers/2501.03884) (AlphaPO) method by Aman Gupta, Shao Tang, Qingquan Song, Sirou Zhu, [Jiwoo Hong](https://huggingface.co/JW17), Ankan Saha, Viral Gupta, Noah Lee, Eunki Kim, Jason Zhu, Natesh Pillai, and S. Sathiya Keerthi is also implemented in the [`CPOTrainer`]. AlphaPO is an alternative method that applies a transformation to the reward function shape in the context of SimPO loss. The abstract from the paper is the following:
|
||||
|
||||
> Reinforcement Learning with Human Feedback (RLHF) and its variants have made huge strides toward the effective alignment of large language models (LLMs) to follow instructions and reflect human values. More recently, Direct Alignment Algorithms (DAAs) have emerged in which the reward modeling stage of RLHF is skipped by characterizing the reward directly as a function of the policy being learned. Some popular examples of DAAs include Direct Preference Optimization (DPO) and Simple Preference Optimization (SimPO). These methods often suffer from likelihood displacement, a phenomenon by which the probabilities of preferred responses are often reduced undesirably. In this paper, we argue that, for DAAs the reward (function) shape matters. We introduce AlphaPO, a new DAA method that leverages an α-parameter to help change the shape of the reward function beyond the standard log reward. AlphaPO helps maintain fine-grained control over likelihood displacement and overoptimization. Compared to SimPO, one of the best performing DAAs, AlphaPO leads to about 7% to 10% relative improvement in alignment performance for the instruct versions of Mistral-7B and Llama3-8B while achieving 15% to 50% relative improvement over DPO on the same models. The analysis and results presented highlight the importance of the reward shape and how one can systematically change it to affect training dynamics, as well as improve alignment performance.
|
||||
|
||||
To use this loss as described in the paper, we can set the `loss_type="alphapo"` which automatically sets `loss_type="simpo"` and `cpo_alpha=0.0`, together with `alpha` and `simpo_gamma` to recommended values in the [`CPOConfig`]. Alternatively, you can manually set `loss_type="simpo"`, `cpo_alpha=0.0`, together with `alpha` and `simpo_gamma` to recommended values. Other variants of this method are also possible, such as setting `loss_type="ipo"` and `alpha` to any non-zero value.
|
||||
|
||||
## Loss functions
|
||||
|
||||
The CPO algorithm supports several loss functions. The loss function can be set using the `loss_type` parameter in the [`CPOConfig`]. The following loss functions are supported:
|
||||
|
||||
| `loss_type=` | Description |
|
||||
| --- | --- |
|
||||
| `"sigmoid"` (default) | Given the preference data, we can fit a binary classifier according to the Bradley-Terry model, and in fact, the [DPO](https://huggingface.co/papers/2305.18290) authors propose the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression. |
|
||||
| `"hinge"` | The [RSO](https://huggingface.co/papers/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. In this case, the `beta` is the reciprocal of the margin. |
|
||||
| `"ipo"` | The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting and propose an alternative loss. In this case, the `beta` is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair, and thus the smaller the `beta`, the larger this gap is. As per the paper, the loss is averaged over log-likelihoods of the completion (unlike DPO, which is summed only). |
|
||||
| `"simpo"` | The [SimPO](https://huggingface.co/papers/2405.14734) method is also implemented in the [`CPOTrainer`]. SimPO is an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, simply set `loss_type="simpo"` and `cpo_alpha=0.0` in the [`CPOConfig`] and `simpo_gamma` to a recommended value. |
|
||||
| `"alphapo"` | The [AlphaPO](https://huggingface.co/papers/2501.03884) method is also implemented in the [`CPOTrainer`]. This is syntactic sugar that automatically sets `loss_type="simpo"` and `cpo_alpha=0.0`. AlphaPO applies a transformation to the reward function shape in the context of SimPO loss when the `alpha` parameter is non-zero. |
|
||||
|
||||
### For Mixture of Experts Models: Enabling the auxiliary loss
|
||||
|
||||
MOEs are the most efficient if the load is about equally distributed between experts.
|
||||
To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
|
||||
|
||||
This option is enabled by setting `output_router_logits=True` in the model config (e.g., [`~transformers.MixtralConfig`]).
|
||||
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config.
|
||||
|
||||
## CPOTrainer
|
||||
|
||||
[[autodoc]] CPOTrainer
|
||||
- train
|
||||
- save_model
|
||||
- push_to_hub
|
||||
|
||||
## CPOConfig
|
||||
|
||||
[[autodoc]] CPOConfig
|
||||
@ -1,113 +0,0 @@
|
||||
# CPO Trainer
|
||||
|
||||
Contrastive Preference Optimization (CPO) as introduced in the paper [Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation](https://huggingface.co/papers/2401.08417) by Haoran Xu, Amr Sharaf, Yunmo Chen, Weiting Tan, Lingfeng Shen, Benjamin Van Durme, Kenton Murray, and Young Jin Kim. At a high-level, CPO trains models to
|
||||
avoid generating adequate, but not perfect translations in Machine Translation (MT) tasks. However, CPO is a general approximation to the DPO loss and can be applied to other domains like chat.
|
||||
|
||||
CPO aims to mitigate two fundamental shortcomings of SFT. First, SFT’s methodology of minimizing the discrepancy between predicted outputs and gold-standard references inherently caps model performance at the quality level of the training data. Secondly, SFT lacks a mechanism to prevent the model from rejecting mistakes in translations. The CPO objective is derived from the DPO objective.
|
||||
|
||||
## SimPO
|
||||
The [SimPO](https://huggingface.co/papers/2405.14734) method is also implemented in the `CPOTrainer`. SimPO is an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, we can use SimPO easily by turning on `loss_type="simpo"` and `cpo_alpha=0` in the `CPOConfig`.
|
||||
|
||||
## CPO-SimPO
|
||||
We also offer the combined use of CPO and SimPO, which enables more stable training and improved performance. Learn more details at [CPO-SimPO Github](https://github.com/fe1ixxu/CPO_SIMPO). To use this method, simply enable SimPO by setting `loss_type="simpo"` and a non-zero `cpo_alpha` in the CPOConfig.
|
||||
|
||||
## Expected dataset format
|
||||
|
||||
The CPO trainer expects a format identical to the DPO trainer, which should include three entries. These entries should be named as follows:
|
||||
|
||||
- `prompt`
|
||||
- `chosen`
|
||||
- `rejected`
|
||||
|
||||
for example:
|
||||
|
||||
```py
|
||||
cpo_dataset_dict = {
|
||||
"prompt": [
|
||||
"hello",
|
||||
"how are you",
|
||||
"What is your name?",
|
||||
"What is your name?",
|
||||
"Which is the best programming language?",
|
||||
"Which is the best programming language?",
|
||||
"Which is the best programming language?",
|
||||
],
|
||||
"chosen": [
|
||||
"hi nice to meet you",
|
||||
"I am fine",
|
||||
"My name is Mary",
|
||||
"My name is Mary",
|
||||
"Python",
|
||||
"Python",
|
||||
"Java",
|
||||
],
|
||||
"rejected": [
|
||||
"leave me alone",
|
||||
"I am not fine",
|
||||
"Whats it to you?",
|
||||
"I dont have a name",
|
||||
"Javascript",
|
||||
"C++",
|
||||
"C++",
|
||||
],
|
||||
}
|
||||
```
|
||||
where the `prompt` contains the context inputs, `chosen` contains the corresponding chosen responses and `rejected` contains the corresponding negative (rejected) responses. As can be seen a prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays.
|
||||
|
||||
## Expected model format
|
||||
The CPO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.
|
||||
|
||||
## Using the `CPOTrainer`
|
||||
For a detailed example have a look at the `examples/scripts/cpo.py` script. At a high level we need to initialize the `CPOTrainer` with a `model` we wish to train. **Note that CPOTrainer eliminates the need to use the reference model, simplifying the optimization process.** The `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above.
|
||||
|
||||
```py
|
||||
cpo_config = CPOConfig(
|
||||
beta=0.1,
|
||||
)
|
||||
|
||||
cpo_trainer = CPOTrainer(
|
||||
model,
|
||||
args=cpo_config,
|
||||
train_dataset=train_dataset,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
```
|
||||
After this one can then call:
|
||||
|
||||
```py
|
||||
cpo_trainer.train()
|
||||
```
|
||||
|
||||
## Loss functions
|
||||
|
||||
Given the preference data, the `CPOTrainer` uses the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression.
|
||||
|
||||
The [RSO](https://huggingface.co/papers/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. The `CPOTrainer` can be switched to this loss via the `loss_type="hinge"` argument and the `beta` in this case is the reciprocal of the margin.
|
||||
|
||||
The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper theoretical understanding of the CPO algorithms and identify an issue with overfitting and propose an alternative loss which can be used via the `loss_type="ipo"` argument to the trainer. Note that the `beta` parameter is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair and thus the smaller the `beta` the larger this gaps is. As per the paper the loss is averaged over log-likelihoods of the completion (unlike CPO which is summed only).
|
||||
|
||||
### For Mixture of Experts Models: Enabling the auxiliary loss
|
||||
|
||||
MOEs are the most efficient if the load is about equally distributed between experts.
|
||||
To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
|
||||
|
||||
This option is enabled by setting `output_router_logits=True` in the model config (e.g. MixtralConfig).
|
||||
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: 0.001).
|
||||
|
||||
## Logging
|
||||
|
||||
While training and evaluating we record the following reward metrics:
|
||||
|
||||
* `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta
|
||||
* `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta
|
||||
* `rewards/accuracies`: mean of how often the chosen rewards are > than the corresponding rejected rewards
|
||||
* `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
|
||||
* `nll_loss`: the mean negative log likelihood loss of the policy model for the chosen responses
|
||||
|
||||
## CPOTrainer
|
||||
|
||||
[[autodoc]] CPOTrainer
|
||||
|
||||
## CPOConfig
|
||||
|
||||
[[autodoc]] CPOConfig
|
||||
119
docs/source/customization.md
Normal file
119
docs/source/customization.md
Normal file
@ -0,0 +1,119 @@
|
||||
# Training customization
|
||||
|
||||
TRL is designed with modularity in mind so that users are able to efficiently customize the training loop for their needs. Below are some examples on how you can apply and test different techniques. Note: Although these examples use the DPOTrainer, the customization applies to most (if not all) trainers.
|
||||
|
||||
## Use different optimizers and schedulers
|
||||
|
||||
By default, the `DPOTrainer` creates a `torch.optim.AdamW` optimizer. You can create and define a different optimizer and pass it to `DPOTrainer` as follows:
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from torch import optim
|
||||
from trl import DPOConfig, DPOTrainer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
||||
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
|
||||
|
||||
optimizer = optim.SGD(model.parameters(), lr=training_args.learning_rate)
|
||||
|
||||
trainer = DPOTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
tokenizer=tokenizer,
|
||||
optimizers=(optimizer, None),
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
### Add a learning rate scheduler
|
||||
|
||||
You can also play with your training by adding learning rate schedulers.
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from torch import optim
|
||||
from trl import DPOConfig, DPOTrainer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
||||
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
|
||||
|
||||
optimizer = optim.AdamW(model.parameters(), lr=training_args.learning_rate)
|
||||
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
|
||||
|
||||
trainer = DPOTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
tokenizer=tokenizer,
|
||||
optimizers=(optimizer, lr_scheduler),
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## Memory efficient fine-tuning by sharing layers
|
||||
|
||||
Another tool you can use for more memory efficient fine-tuning is to share layers between the reference model and the model you want to train.
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from trl import create_reference_model, DPOConfig, DPOTrainer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
ref_model = create_reference_model(model, num_shared_layers=6)
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train[:1%]")
|
||||
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
|
||||
|
||||
trainer = DPOTrainer(
|
||||
model=model,
|
||||
ref_model=ref_model,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## Pass 8-bit reference models
|
||||
|
||||
Since `trl` supports all keyword arguments when loading a model from `transformers` using `from_pretrained`, you can also leverage `load_in_8bit` from `transformers` for more memory efficient fine-tuning.
|
||||
|
||||
Read more about 8-bit model loading in `transformers` [Load in 8bit or 4bit](https://huggingface.co/docs/transformers/en/peft#load-in-8bit-or-4bit).
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
||||
from trl import DPOConfig, DPOTrainer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
ref_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", quantization_config= quantization_config)
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
||||
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
|
||||
|
||||
trainer = DPOTrainer(
|
||||
model=model,
|
||||
ref_model=ref_model,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## Use the accelerator cache optimizer
|
||||
|
||||
When training large models, you should better handle the accelerator cache by iteratively clearing it. To do so, simply pass `optimize_device_cache=True` to [`DPOConfig`]:
|
||||
|
||||
```python
|
||||
training_args = DPOConfig(..., optimize_device_cache=True)
|
||||
```
|
||||
@ -1,216 +0,0 @@
|
||||
# Training customization
|
||||
|
||||
TRL is designed with modularity in mind so that users to be able to efficiently customize the training loop for their needs. Below are some examples on how you can apply and test different techniques.
|
||||
|
||||
## Train on multiple GPUs / nodes
|
||||
|
||||
The trainers in TRL use 🤗 Accelerate to enable distributed training across multiple GPUs or nodes. To do so, first create an 🤗 Accelerate config file by running
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
and answering the questions according to your multi-gpu / multi-node setup. You can then launch distributed training by running:
|
||||
|
||||
```bash
|
||||
accelerate launch your_script.py
|
||||
```
|
||||
|
||||
We also provide config files in the [examples folder](https://github.com/huggingface/trl/tree/main/examples/accelerate_configs) that can be used as templates. To use these templates, simply pass the path to the config file when launching a job, e.g.:
|
||||
|
||||
```shell
|
||||
accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml --num_processes {NUM_GPUS} path_to_script.py --all_arguments_of_the_script
|
||||
```
|
||||
|
||||
Refer to the [examples page](https://github.com/huggingface/trl/tree/main/examples) for more details.
|
||||
|
||||
### Distributed training with DeepSpeed
|
||||
|
||||
All of the trainers in TRL can be run on multiple GPUs together with DeepSpeed ZeRO-{1,2,3} for efficient sharding of the optimizer states, gradients, and model weights. To do so, run:
|
||||
|
||||
```shell
|
||||
accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero{1,2,3}.yaml --num_processes {NUM_GPUS} path_to_your_script.py --all_arguments_of_the_script
|
||||
```
|
||||
|
||||
Note that for ZeRO-3, a small tweak is needed to initialize your reward model on the correct device via the `zero3_init_context_manager()` context manager. In particular, this is needed to avoid DeepSpeed hanging after a fixed number of training steps. Here is a snippet of what is involved from the [`sentiment_tuning`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py) example:
|
||||
|
||||
```python
|
||||
ds_plugin = ppo_trainer.accelerator.state.deepspeed_plugin
|
||||
if ds_plugin is not None and ds_plugin.is_zero3_init_enabled():
|
||||
with ds_plugin.zero3_init_context_manager(enable=False):
|
||||
sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)
|
||||
else:
|
||||
sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)
|
||||
```
|
||||
|
||||
Consult the 🤗 Accelerate [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more information about the DeepSpeed plugin.
|
||||
|
||||
|
||||
## Use different optimizers
|
||||
|
||||
By default, the `PPOTrainer` creates a `torch.optim.Adam` optimizer. You can create and define a different optimizer and pass it to `PPOTrainer`:
|
||||
```python
|
||||
import torch
|
||||
from transformers import GPT2Tokenizer
|
||||
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
|
||||
|
||||
# 1. load a pretrained model
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
|
||||
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
|
||||
# 2. define config
|
||||
ppo_config = {'batch_size': 1, 'learning_rate':1e-5}
|
||||
config = PPOConfig(**ppo_config)
|
||||
|
||||
|
||||
# 2. Create optimizer
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=config.learning_rate)
|
||||
|
||||
|
||||
# 3. initialize trainer
|
||||
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, optimizer=optimizer)
|
||||
```
|
||||
|
||||
For memory efficient fine-tuning, you can also pass `Adam8bit` optimizer from `bitsandbytes`:
|
||||
|
||||
```python
|
||||
import torch
|
||||
import bitsandbytes as bnb
|
||||
|
||||
from transformers import GPT2Tokenizer
|
||||
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
|
||||
|
||||
# 1. load a pretrained model
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
|
||||
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
|
||||
# 2. define config
|
||||
ppo_config = {'batch_size': 1, 'learning_rate':1e-5}
|
||||
config = PPOConfig(**ppo_config)
|
||||
|
||||
|
||||
# 2. Create optimizer
|
||||
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=config.learning_rate)
|
||||
|
||||
# 3. initialize trainer
|
||||
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, optimizer=optimizer)
|
||||
```
|
||||
|
||||
### Use LION optimizer
|
||||
|
||||
You can use the new [LION optimizer from Google](https://huggingface.co/papers/2302.06675) as well, first take the source code of the optimizer definition [here](https://github.com/lucidrains/lion-pytorch/blob/main/lion_pytorch/lion_pytorch.py), and copy it so that you can import the optimizer. Make sure to initialize the optimizer by considering the trainable parameters only for a more memory efficient training:
|
||||
```python
|
||||
optimizer = Lion(filter(lambda p: p.requires_grad, self.model.parameters()), lr=self.config.learning_rate)
|
||||
|
||||
...
|
||||
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, optimizer=optimizer)
|
||||
```
|
||||
We advise you to use the learning rate that you would use for `Adam` divided by 3 as pointed out [here](https://github.com/lucidrains/lion-pytorch#lion---pytorch). We observed an improvement when using this optimizer compared to classic Adam (check the full logs [here](https://wandb.ai/distill-bloom/trl/runs/lj4bheke?workspace=user-younesbelkada)):
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-lion.png">
|
||||
</div>
|
||||
|
||||
|
||||
## Add a learning rate scheduler
|
||||
|
||||
You can also play with your training by adding learning rate schedulers!
|
||||
```python
|
||||
import torch
|
||||
from transformers import GPT2Tokenizer
|
||||
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
|
||||
|
||||
# 1. load a pretrained model
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
|
||||
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
|
||||
# 2. define config
|
||||
ppo_config = {'batch_size': 1, 'learning_rate':1e-5}
|
||||
config = PPOConfig(**ppo_config)
|
||||
|
||||
|
||||
# 2. Create optimizer
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=config.learning_rate)
|
||||
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
|
||||
|
||||
# 3. initialize trainer
|
||||
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, optimizer=optimizer, lr_scheduler=lr_scheduler)
|
||||
```
|
||||
|
||||
## Memory efficient fine-tuning by sharing layers
|
||||
|
||||
Another tool you can use for more memory efficient fine-tuning is to share layers between the reference model and the model you want to train.
|
||||
```python
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
|
||||
|
||||
# 1. load a pretrained model
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained('bigscience/bloom-560m')
|
||||
ref_model = create_reference_model(model, num_shared_layers=6)
|
||||
tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m')
|
||||
|
||||
# 2. initialize trainer
|
||||
ppo_config = {'batch_size': 1}
|
||||
config = PPOConfig(**ppo_config)
|
||||
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer)
|
||||
```
|
||||
|
||||
## Pass 8-bit reference models
|
||||
|
||||
<div>
|
||||
|
||||
Since `trl` supports all key word arguments when loading a model from `transformers` using `from_pretrained`, you can also leverage `load_in_8bit` from `transformers` for more memory efficient fine-tuning.
|
||||
|
||||
Read more about 8-bit model loading in `transformers` [here](https://huggingface.co/docs/transformers/perf_infer_gpu_one#bitsandbytes-integration-for-int8-mixedprecision-matrix-decomposition).
|
||||
|
||||
</div>
|
||||
|
||||
```python
|
||||
# 0. imports
|
||||
# pip install bitsandbytes
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
|
||||
|
||||
# 1. load a pretrained model
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained('bigscience/bloom-560m')
|
||||
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained('bigscience/bloom-560m', device_map="auto", load_in_8bit=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m')
|
||||
|
||||
# 2. initialize trainer
|
||||
ppo_config = {'batch_size': 1}
|
||||
config = PPOConfig(**ppo_config)
|
||||
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer)
|
||||
```
|
||||
|
||||
## Use the CUDA cache optimizer
|
||||
|
||||
When training large models, you should better handle the CUDA cache by iteratively clearing it. Do do so, simply pass `optimize_cuda_cache=True` to `PPOConfig`:
|
||||
|
||||
```python
|
||||
config = PPOConfig(..., optimize_cuda_cache=True)
|
||||
```
|
||||
|
||||
|
||||
|
||||
## Use score scaling/normalization/clipping
|
||||
As suggested by [Secrets of RLHF in Large Language Models Part I: PPO](https://huggingface.co/papers/2307.04964), we support score (aka reward) scaling/normalization/clipping to improve training stability via `PPOConfig`:
|
||||
```python
|
||||
from trl import PPOConfig
|
||||
|
||||
ppo_config = {
|
||||
use_score_scaling=True,
|
||||
use_score_norm=True,
|
||||
score_clip=0.5,
|
||||
}
|
||||
config = PPOConfig(**ppo_config)
|
||||
```
|
||||
|
||||
To run `ppo.py`, you can use the following command:
|
||||
```
|
||||
python examples/scripts/ppo.py --log_with wandb --use_score_scaling --use_score_norm --score_clip 0.5
|
||||
```
|
||||
53
docs/source/data_utils.md
Normal file
53
docs/source/data_utils.md
Normal file
@ -0,0 +1,53 @@
|
||||
# Data Utilities
|
||||
|
||||
## prepare_multimodal_messages
|
||||
|
||||
[[autodoc]] prepare_multimodal_messages
|
||||
|
||||
## prepare_multimodal_messages_vllm
|
||||
|
||||
[[autodoc]] prepare_multimodal_messages_vllm
|
||||
|
||||
## is_conversational
|
||||
|
||||
[[autodoc]] is_conversational
|
||||
|
||||
## is_conversational_from_value
|
||||
|
||||
[[autodoc]] is_conversational_from_value
|
||||
|
||||
## apply_chat_template
|
||||
|
||||
[[autodoc]] apply_chat_template
|
||||
|
||||
## maybe_apply_chat_template
|
||||
|
||||
[[autodoc]] maybe_apply_chat_template
|
||||
|
||||
## maybe_convert_to_chatml
|
||||
|
||||
[[autodoc]] maybe_convert_to_chatml
|
||||
|
||||
## extract_prompt
|
||||
|
||||
[[autodoc]] extract_prompt
|
||||
|
||||
## maybe_extract_prompt
|
||||
|
||||
[[autodoc]] maybe_extract_prompt
|
||||
|
||||
## unpair_preference_dataset
|
||||
|
||||
[[autodoc]] unpair_preference_dataset
|
||||
|
||||
## maybe_unpair_preference_dataset
|
||||
|
||||
[[autodoc]] maybe_unpair_preference_dataset
|
||||
|
||||
## pack_dataset
|
||||
|
||||
[[autodoc]] pack_dataset
|
||||
|
||||
## truncate_dataset
|
||||
|
||||
[[autodoc]] truncate_dataset
|
||||
@ -1,15 +0,0 @@
|
||||
## Data Utilities
|
||||
|
||||
[[autodoc]] is_conversational
|
||||
|
||||
[[autodoc]] apply_chat_template
|
||||
|
||||
[[autodoc]] maybe_apply_chat_template
|
||||
|
||||
[[autodoc]] extract_prompt
|
||||
|
||||
[[autodoc]] maybe_extract_prompt
|
||||
|
||||
[[autodoc]] unpair_preference_dataset
|
||||
|
||||
[[autodoc]] maybe_unpair_preference_dataset
|
||||
993
docs/source/dataset_formats.md
Normal file
993
docs/source/dataset_formats.md
Normal file
@ -0,0 +1,993 @@
|
||||
# Dataset formats and types
|
||||
|
||||
This guide provides an overview of the dataset formats and types supported by each trainer in TRL.
|
||||
|
||||
## Overview of the dataset formats and types
|
||||
|
||||
- The *format* of a dataset refers to how the data is structured, typically categorized as either *standard* or *conversational*.
|
||||
- The *type* is associated with the specific task the dataset is designed for, such as *prompt-only* or *preference*. Each type is characterized by its columns, which vary according to the task, as shown in the table.
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<th>Type \ Format</th>
|
||||
<th>Standard</th>
|
||||
<th>Conversational</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>Language modeling</td>
|
||||
<td>
|
||||
<pre><code>{"text": "The sky is blue."}</code></pre>
|
||||
</td>
|
||||
<td>
|
||||
<pre><code>{"messages": [{"role": "user", "content": "What color is the sky?"},
|
||||
{"role": "assistant", "content": "It is blue."}]}</code></pre>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>Prompt-only</td>
|
||||
<td>
|
||||
<pre><code>{"prompt": "The sky is"}</code></pre>
|
||||
</td>
|
||||
<td>
|
||||
<pre><code>{"prompt": [{"role": "user", "content": "What color is the sky?"}]}</code></pre>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>Prompt-completion</td>
|
||||
<td>
|
||||
<pre><code>{"prompt": "The sky is",
|
||||
"completion": " blue."}</code></pre>
|
||||
</td>
|
||||
<td>
|
||||
<pre><code>{"prompt": [{"role": "user", "content": "What color is the sky?"}],
|
||||
"completion": [{"role": "assistant", "content": "It is blue."}]}</code></pre>
|
||||
</td>
|
||||
</tr>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>Preference</td>
|
||||
<td>
|
||||
<pre><code>{"prompt": "The sky is",
|
||||
"chosen": " blue.",
|
||||
"rejected": " green."}</code></pre>
|
||||
or, with implicit prompt:
|
||||
<pre><code>{"chosen": "The sky is blue.",
|
||||
"rejected": "The sky is green."}</code></pre>
|
||||
</td>
|
||||
<td>
|
||||
<pre><code>{"prompt": [{"role": "user", "content": "What color is the sky?"}],
|
||||
"chosen": [{"role": "assistant", "content": "It is blue."}],
|
||||
"rejected": [{"role": "assistant", "content": "It is green."}]}</code></pre>
|
||||
or, with implicit prompt:
|
||||
<pre><code>{"chosen": [{"role": "user", "content": "What color is the sky?"},
|
||||
{"role": "assistant", "content": "It is blue."}],
|
||||
"rejected": [{"role": "user", "content": "What color is the sky?"},
|
||||
{"role": "assistant", "content": "It is green."}]}</code></pre>
|
||||
</td>
|
||||
</tr>
|
||||
<td>Unpaired preference</td>
|
||||
<td>
|
||||
<pre><code>{"prompt": "The sky is",
|
||||
"completion": " blue.",
|
||||
"label": True}</code></pre>
|
||||
</td>
|
||||
<td>
|
||||
<pre><code>{"prompt": [{"role": "user", "content": "What color is the sky?"}],
|
||||
"completion": [{"role": "assistant", "content": "It is green."}],
|
||||
"label": False}</code></pre>
|
||||
</td>
|
||||
</tr>
|
||||
</tr>
|
||||
<td>Stepwise supervision</td>
|
||||
<td>
|
||||
<pre><code>{"prompt": "Which number is larger, 9.8 or 9.11?",
|
||||
"completions": ["The fractional part of 9.8 is 0.8.",
|
||||
"The fractional part of 9.11 is 0.11.",
|
||||
"0.11 is greater than 0.8.",
|
||||
"Hence, 9.11 > 9.8."],
|
||||
"labels": [True, True, False, False]}</code></pre>
|
||||
</td>
|
||||
<td></td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
### Formats
|
||||
|
||||
#### Standard
|
||||
|
||||
The standard dataset format typically consists of plain text strings. The columns in the dataset vary depending on the task. This is the format expected by TRL trainers. Below are examples of standard dataset formats for different tasks:
|
||||
|
||||
```python
|
||||
# Language modeling
|
||||
language_modeling_example = {"text": "The sky is blue."}
|
||||
# Preference
|
||||
preference_example = {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."}
|
||||
# Unpaired preference
|
||||
unpaired_preference_example = {"prompt": "The sky is", "completion": " blue.", "label": True}
|
||||
```
|
||||
|
||||
#### Conversational
|
||||
|
||||
Conversational datasets are used for tasks involving dialogues or chat interactions between users and assistants. Unlike standard dataset formats, these contain sequences of messages where each message has a `role` (e.g., `"user"` or `"assistant"`) and `content` (the message text).
|
||||
|
||||
```python
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing great. How can I help you today?"},
|
||||
{"role": "user", "content": "I'd like to show off how chat templating works!"},
|
||||
]
|
||||
```
|
||||
|
||||
Just like standard datasets, the columns in conversational datasets vary depending on the task. Below are examples of conversational dataset formats for different tasks:
|
||||
|
||||
```python
|
||||
# Prompt-completion
|
||||
prompt_completion_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}],
|
||||
"completion": [{"role": "assistant", "content": "It is blue."}]}
|
||||
# Preference
|
||||
preference_example = {
|
||||
"prompt": [{"role": "user", "content": "What color is the sky?"}],
|
||||
"chosen": [{"role": "assistant", "content": "It is blue."}],
|
||||
"rejected": [{"role": "assistant", "content": "It is green."}],
|
||||
}
|
||||
```
|
||||
|
||||
#### Tool Calling
|
||||
|
||||
Some chat templates support *tool calling*, which allows the model to interact with external functions—referred to as **tools**—during generation. This extends the conversational capabilities of the model by enabling it to output a `"tool_calls"` field instead of a standard `"content"` message whenever it decides to invoke a tool.
|
||||
|
||||
After the assistant initiates a tool call, the tool executes and returns its output. The assistant can then process this output and continue the conversation accordingly.
|
||||
|
||||
Here’s a simple example of a tool-calling interaction:
|
||||
|
||||
```python
|
||||
messages = [
|
||||
{"role": "user", "content": "Turn on the living room lights."},
|
||||
{"role": "assistant", "tool_calls": [
|
||||
{"type": "function", "function": {
|
||||
"name": "control_light",
|
||||
"arguments": {"room": "living room", "state": "on"}
|
||||
}}]
|
||||
},
|
||||
{"role": "tool", "name": "control_light", "content": "The lights in the living room are now on."},
|
||||
{"role": "assistant", "content": "Done!"}
|
||||
]
|
||||
```
|
||||
|
||||
When preparing datasets for Supervised Fine-Tuning (SFT) with tool calling, it is important that your dataset includes an additional column named `tools`. This column contains the list of available tools for the model, which is usually used by the chat template to construct the system prompt.
|
||||
|
||||
The tools must be specified in a codified JSON schema format. You can automatically generate this schema from Python function signatures using the [`~transformers.utils.get_json_schema`] utility:
|
||||
|
||||
```python
|
||||
from transformers.utils import get_json_schema
|
||||
|
||||
def control_light(room: str, state: str) -> str:
|
||||
"""
|
||||
Controls the lights in a room.
|
||||
|
||||
Args:
|
||||
room: The name of the room.
|
||||
state: The desired state of the light ("on" or "off").
|
||||
|
||||
Returns:
|
||||
str: A message indicating the new state of the lights.
|
||||
"""
|
||||
return f"The lights in {room} are now {state}."
|
||||
|
||||
# Generate JSON schema
|
||||
json_schema = get_json_schema(control_light)
|
||||
```
|
||||
|
||||
The generated schema would look like:
|
||||
|
||||
```python
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "control_light",
|
||||
"description": "Controls the lights in a room.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"room": {"type": "string", "description": "The name of the room."},
|
||||
"state": {"type": "string", "description": 'The desired state of the light ("on" or "off").'},
|
||||
},
|
||||
"required": ["room", "state"],
|
||||
},
|
||||
"return": {"type": "string", "description": "str: A message indicating the new state of the lights."},
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
A complete dataset entry for SFT might look like:
|
||||
|
||||
```python
|
||||
{"messages": messages, "tools": [json_schema]}
|
||||
```
|
||||
|
||||
For more detailed information on tool calling, refer to the [Tool Calling section in the `transformers` documentation](https://huggingface.co/docs/transformers/chat_extras#tools-and-rag) and the blog post [Tool Use, Unified](https://huggingface.co/blog/unified-tool-use).
|
||||
|
||||
### Harmony
|
||||
|
||||
The [Harmony response format](https://cookbook.openai.com/articles/openai-harmony) was introduced with the [OpenAI GPT OSS models](https://huggingface.co/collections/openai/gpt-oss-68911959590a1634ba11c7a4). It extends the conversational format by adding richer structure for reasoning, function calls, and metadata about the model’s behavior. Key features include:
|
||||
|
||||
- **Developer role** – Provides high level instructions (similar to a system prompt) and lists available tools.
|
||||
- **Channels** – Separate types of assistant output into distinct streams:
|
||||
|
||||
- `analysis` – for internal reasoning, from the key `"thinking"`
|
||||
- `final` – for the user-facing answer, from the key `"content"`
|
||||
- `commentary` – for tool calls or meta notes
|
||||
|
||||
- **Reasoning effort** – Signals how much thinking the model should show (e.g., `"low"`, `"medium"`, `"high"`).
|
||||
- **Model identity** – Explicitly defines the assistant’s persona.
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
|
||||
|
||||
messages = [
|
||||
{"role": "developer", "content": "Use a friendly tone."},
|
||||
{"role": "user", "content": "What is the meaning of life?"},
|
||||
{"role": "assistant", "thinking": "Deep reflection...", "content": "The final answer is..."},
|
||||
]
|
||||
|
||||
print(
|
||||
tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
reasoning_effort="low",
|
||||
model_identity="You are HuggingGPT, a large language model trained by Hugging Face."
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
This produces:
|
||||
|
||||
```txt
|
||||
<|start|>system<|message|>You are HuggingGPT, a large language model trained by Hugging Face.
|
||||
Knowledge cutoff: 2024-06
|
||||
Current date: 2025-08-03
|
||||
|
||||
Reasoning: low
|
||||
|
||||
# Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions
|
||||
|
||||
Use a friendly tone.<|end|><|start|>user<|message|>What is the meaning of life?<|end|><|start|>assistant<|channel|>analysis<|message|>Deep reflection...<|end|><|start|>assistant<|channel|>final<|message|>The final answer is...<|return|>
|
||||
```
|
||||
|
||||
For full details on message structure, supported fields, and advanced usage, see the [Harmony documentation](https://cookbook.openai.com/articles/openai-harmony).
|
||||
|
||||
### Types
|
||||
|
||||
#### Language modeling
|
||||
|
||||
A language modeling dataset consists of a column `"text"` (or `"messages"` for conversational datasets) containing a full sequence of text.
|
||||
|
||||
```python
|
||||
# Standard format
|
||||
language_modeling_example = {"text": "The sky is blue."}
|
||||
# Conversational format
|
||||
language_modeling_example = {"messages": [
|
||||
{"role": "user", "content": "What color is the sky?"},
|
||||
{"role": "assistant", "content": "It is blue."}
|
||||
]}
|
||||
```
|
||||
|
||||
#### Prompt-only
|
||||
|
||||
In a prompt-only dataset, only the initial prompt (the question or partial sentence) is provided under the key `"prompt"`. The training typically involves generating completion based on this prompt, where the model learns to continue or complete the given input.
|
||||
|
||||
```python
|
||||
# Standard format
|
||||
prompt_only_example = {"prompt": "The sky is"}
|
||||
# Conversational format
|
||||
prompt_only_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}]}
|
||||
```
|
||||
|
||||
For examples of prompt-only datasets, refer to the [Prompt-only datasets collection](https://huggingface.co/collections/trl-lib/prompt-only-datasets-677ea25245d20252cea00368).
|
||||
|
||||
> [!TIP]
|
||||
> While both the prompt-only and language modeling types are similar, they differ in how the input is handled. In the prompt-only type, the prompt represents a partial input that expects the model to complete or continue, while in the language modeling type, the input is treated as a complete sentence or sequence. These two types are processed differently by TRL. Below is an example showing the difference in the output of the `apply_chat_template` function for each type:
|
||||
>
|
||||
> ```python
|
||||
> from transformers import AutoTokenizer
|
||||
> from trl import apply_chat_template
|
||||
>
|
||||
> tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
|
||||
>
|
||||
> # Example for prompt-only type
|
||||
> prompt_only_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}]}
|
||||
> apply_chat_template(prompt_only_example, tokenizer)
|
||||
> # Output: {'prompt': '<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n'}
|
||||
>
|
||||
> # Example for language modeling type
|
||||
> lm_example = {"messages": [{"role": "user", "content": "What color is the sky?"}]}
|
||||
> apply_chat_template(lm_example, tokenizer)
|
||||
> # Output: {'text': '<|user|>\nWhat color is the sky?<|end|>\n<|endoftext|>'}
|
||||
> ```
|
||||
>
|
||||
> - The prompt-only output includes a `'<|assistant|>\n'`, indicating the beginning of the assistant’s turn and expecting the model to generate a completion.
|
||||
> - In contrast, the language modeling output treats the input as a complete sequence and terminates it with `'<|endoftext|>'`, signaling the end of the text and not expecting any additional content.
|
||||
|
||||
#### Prompt-completion
|
||||
|
||||
A prompt-completion dataset includes a `"prompt"` and a `"completion"`.
|
||||
|
||||
```python
|
||||
# Standard format
|
||||
prompt_completion_example = {"prompt": "The sky is", "completion": " blue."}
|
||||
# Conversational format
|
||||
prompt_completion_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}],
|
||||
"completion": [{"role": "assistant", "content": "It is blue."}]}
|
||||
```
|
||||
|
||||
For examples of prompt-completion datasets, refer to the [Prompt-completion datasets collection](https://huggingface.co/collections/trl-lib/prompt-completion-datasets-677ea2bb20bbb6bdccada216).
|
||||
|
||||
#### Preference
|
||||
|
||||
A preference dataset is used for tasks where the model is trained to choose between two or more possible completions to the same prompt. This dataset includes a `"prompt"`, a `"chosen"` completion, and a `"rejected"` completion. The model is trained to select the `"chosen"` response over the `"rejected"` response.
|
||||
Some datasets may not include the `"prompt"` column, in which case the prompt is implicit and directly included in the `"chosen"` and `"rejected"` completions. We recommend using explicit prompts whenever possible.
|
||||
|
||||
```python
|
||||
# Standard format
|
||||
## Explicit prompt (recommended)
|
||||
preference_example = {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."}
|
||||
# Implicit prompt
|
||||
preference_example = {"chosen": "The sky is blue.", "rejected": "The sky is green."}
|
||||
|
||||
# Conversational format
|
||||
## Explicit prompt (recommended)
|
||||
preference_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}],
|
||||
"chosen": [{"role": "assistant", "content": "It is blue."}],
|
||||
"rejected": [{"role": "assistant", "content": "It is green."}]}
|
||||
## Implicit prompt
|
||||
preference_example = {"chosen": [{"role": "user", "content": "What color is the sky?"},
|
||||
{"role": "assistant", "content": "It is blue."}],
|
||||
"rejected": [{"role": "user", "content": "What color is the sky?"},
|
||||
{"role": "assistant", "content": "It is green."}]}
|
||||
```
|
||||
|
||||
For examples of preference datasets, refer to the [Preference datasets collection](https://huggingface.co/collections/trl-lib/preference-datasets-677e99b581018fcad9abd82c).
|
||||
|
||||
Some preference datasets can be found with [the tag `dpo` on Hugging Face Hub](https://huggingface.co/datasets?other=dpo). You can also explore the [librarian-bots' DPO Collections](https://huggingface.co/collections/librarian-bots/direct-preference-optimization-datasets-66964b12835f46289b6ef2fc) to identify preference datasets.
|
||||
|
||||
#### Unpaired preference
|
||||
|
||||
An unpaired preference dataset is similar to a preference dataset but instead of having `"chosen"` and `"rejected"` completions for the same prompt, it includes a single `"completion"` and a `"label"` indicating whether the completion is preferred or not.
|
||||
|
||||
```python
|
||||
# Standard format
|
||||
unpaired_preference_example = {"prompt": "The sky is", "completion": " blue.", "label": True}
|
||||
# Conversational format
|
||||
unpaired_preference_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}],
|
||||
"completion": [{"role": "assistant", "content": "It is blue."}],
|
||||
"label": True}
|
||||
```
|
||||
|
||||
For examples of unpaired preference datasets, refer to the [Unpaired preference datasets collection](https://huggingface.co/collections/trl-lib/unpaired-preference-datasets-677ea22bf5f528c125b0bcdf).
|
||||
|
||||
#### Stepwise supervision
|
||||
|
||||
A stepwise (or process) supervision dataset is similar to an [unpaired preference](#unpaired-preference) dataset but includes multiple steps of completions, each with its own label. This structure is useful for tasks that need detailed, step-by-step labeling, such as reasoning tasks. By evaluating each step separately and providing targeted labels, this approach helps identify precisely where the reasoning is correct and where errors occur, allowing for targeted feedback on each part of the reasoning process.
|
||||
|
||||
```python
|
||||
stepwise_example = {
|
||||
"prompt": "Which number is larger, 9.8 or 9.11?",
|
||||
"completions": ["The fractional part of 9.8 is 0.8, while the fractional part of 9.11 is 0.11.", "Since 0.11 is greater than 0.8, the number 9.11 is larger than 9.8."],
|
||||
"labels": [True, False]
|
||||
}
|
||||
```
|
||||
|
||||
For examples of stepwise supervision datasets, refer to the [Stepwise supervision datasets collection](https://huggingface.co/collections/trl-lib/stepwise-supervision-datasets-677ea27fd4c5941beed7a96e).
|
||||
|
||||
## Which dataset type to use?
|
||||
|
||||
Choosing the right dataset type depends on the task you are working on and the specific requirements of the TRL trainer you are using. Below is a brief overview of the dataset types supported by each TRL trainer.
|
||||
|
||||
| Trainer | Expected dataset type |
|
||||
| --- | --- |
|
||||
| [`experimental.bco.BCOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) |
|
||||
| [`CPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
|
||||
| [`DPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
|
||||
| [`GKDTrainer`] | [Prompt-completion](#prompt-completion) |
|
||||
| [`GRPOTrainer`] | [Prompt-only](#prompt-only) |
|
||||
| [`KTOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) |
|
||||
| [`NashMDTrainer`] | [Prompt-only](#prompt-only) |
|
||||
| [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) |
|
||||
| [`ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
|
||||
| [`PPOTrainer`] | Tokenized language modeling |
|
||||
| [`PRMTrainer`] | [Stepwise supervision](#stepwise-supervision) |
|
||||
| [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) |
|
||||
| [`RLOOTrainer`] | [Prompt-only](#prompt-only) |
|
||||
| [`SFTTrainer`] | [Language modeling](#language-modeling) or [Prompt-completion](#prompt-completion) |
|
||||
| [`XPOTrainer`] | [Prompt-only](#prompt-only) |
|
||||
|
||||
## Using any dataset with TRL: preprocessing and conversion
|
||||
|
||||
Many datasets come in formats tailored to specific tasks, which might not be directly compatible with TRL. To use such datasets with TRL, you may need to preprocess and convert them into the required format.
|
||||
|
||||
To make this easier, we provide a set of [example scripts](https://github.com/huggingface/trl/tree/main/examples/datasets) that cover common dataset conversions.
|
||||
|
||||
### Example: UltraFeedback dataset
|
||||
|
||||
Let’s take the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback) as an example. Here's a preview of the dataset:
|
||||
|
||||
<iframe
|
||||
src="https://huggingface.co/datasets/openbmb/UltraFeedback/embed/viewer/default/train"
|
||||
frameborder="0"
|
||||
width="100%"
|
||||
height="560px"
|
||||
></iframe>
|
||||
|
||||
As shown above, the dataset format does not match the expected structure. It’s not in a conversational format, the column names differ, and the results pertain to different models (e.g., Bard, GPT-4) and aspects (e.g., "helpfulness", "honesty").
|
||||
|
||||
By using the provided conversion script [`examples/datasets/ultrafeedback.py`](https://github.com/huggingface/trl/tree/main/examples/datasets/ultrafeedback.py), you can transform this dataset into an unpaired preference type, and push it to the Hub:
|
||||
|
||||
```sh
|
||||
python examples/datasets/ultrafeedback.py --push_to_hub --repo_id trl-lib/ultrafeedback-gpt-3.5-turbo-helpfulness
|
||||
```
|
||||
|
||||
Once converted, the dataset will look like this:
|
||||
|
||||
<iframe
|
||||
src="https://huggingface.co/datasets/trl-lib/ultrafeedback-gpt-3.5-turbo-helpfulness/embed/viewer/default/train?row=0"
|
||||
frameborder="0"
|
||||
width="100%"
|
||||
height="560px"
|
||||
></iframe>
|
||||
|
||||
Now, you can use this dataset with TRL!
|
||||
|
||||
By adapting the provided scripts or creating your own, you can convert any dataset into a format compatible with TRL.
|
||||
|
||||
## Utilities for converting dataset types
|
||||
|
||||
This section provides example code to help you convert between different dataset types. While some conversions can be performed after applying the chat template (i.e., in the standard format), we recommend performing the conversion before applying the chat template to ensure it works consistently.
|
||||
|
||||
For simplicity, some of the examples below do not follow this recommendation and use the standard format. However, the conversions can be applied directly to the conversational format without modification.
|
||||
|
||||
| From \ To | Language modeling | Prompt-completion | Prompt-only | Preference with implicit prompt | Preference | Unpaired preference | Stepwise supervision |
|
||||
| --- | --- | --- | --- | --- | --- | --- | --- |
|
||||
| Language modeling | N/A | N/A | N/A | N/A | N/A | N/A | N/A |
|
||||
| Prompt-completion | [🔗](#from-prompt-completion-to-language-modeling-dataset) | N/A | [🔗](#from-prompt-completion-to-prompt-only-dataset) | N/A | N/A | N/A | N/A |
|
||||
| Prompt-only | N/A | N/A | N/A | N/A | N/A | N/A | N/A |
|
||||
| Preference with implicit prompt | [🔗](#from-preference-with-implicit-prompt-to-language-modeling-dataset) | [🔗](#from-preference-with-implicit-prompt-to-prompt-completion-dataset) | [🔗](#from-preference-with-implicit-prompt-to-prompt-only-dataset) | N/A | [🔗](#from-implicit-to-explicit-prompt-preference-dataset) | [🔗](#from-preference-with-implicit-prompt-to-unpaired-preference-dataset) | N/A |
|
||||
| Preference | [🔗](#from-preference-to-language-modeling-dataset) | [🔗](#from-preference-to-prompt-completion-dataset) | [🔗](#from-preference-to-prompt-only-dataset) | [🔗](#from-explicit-to-implicit-prompt-preference-dataset) | N/A | [🔗](#from-preference-to-unpaired-preference-dataset) | N/A |
|
||||
| Unpaired preference | [🔗](#from-unpaired-preference-to-language-modeling-dataset) | [🔗](#from-unpaired-preference-to-prompt-completion-dataset) | [🔗](#from-unpaired-preference-to-prompt-only-dataset) | N/A | N/A | N/A | N/A |
|
||||
| Stepwise supervision | [🔗](#from-stepwise-supervision-to-language-modeling-dataset) | [🔗](#from-stepwise-supervision-to-prompt-completion-dataset) | [🔗](#from-stepwise-supervision-to-prompt-only-dataset) | N/A | N/A | [🔗](#from-stepwise-supervision-to-unpaired-preference-dataset) | N/A |
|
||||
|
||||
### From prompt-completion to language modeling dataset
|
||||
|
||||
To convert a prompt-completion dataset into a language modeling dataset, concatenate the prompt and the completion.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"prompt": ["The sky is", "The sun is"],
|
||||
"completion": [" blue.", " in the sky."],
|
||||
})
|
||||
|
||||
def concat_prompt_completion(example):
|
||||
return {"text": example["prompt"] + example["completion"]}
|
||||
|
||||
dataset = dataset.map(concat_prompt_completion, remove_columns=["prompt", "completion"])
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'text': 'The sky is blue.'}
|
||||
```
|
||||
|
||||
### From prompt-completion to prompt-only dataset
|
||||
|
||||
To convert a prompt-completion dataset into a prompt-only dataset, remove the completion.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"prompt": ["The sky is", "The sun is"],
|
||||
"completion": [" blue.", " in the sky."],
|
||||
})
|
||||
|
||||
dataset = dataset.remove_columns("completion")
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'prompt': 'The sky is'}
|
||||
```
|
||||
|
||||
### From preference with implicit prompt to language modeling dataset
|
||||
|
||||
To convert a preference with implicit prompt dataset into a language modeling dataset, remove the rejected, and rename the column `"chosen"` to `"text"`.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"chosen": ["The sky is blue.", "The sun is in the sky."],
|
||||
"rejected": ["The sky is green.", "The sun is in the sea."],
|
||||
})
|
||||
|
||||
dataset = dataset.rename_column("chosen", "text").remove_columns("rejected")
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'text': 'The sky is blue.'}
|
||||
```
|
||||
|
||||
### From preference with implicit prompt to prompt-completion dataset
|
||||
|
||||
To convert a preference dataset with implicit prompt into a prompt-completion dataset, extract the prompt with [`extract_prompt`], remove the rejected, and rename the column `"chosen"` to `"completion"`.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
from trl import extract_prompt
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"chosen": [
|
||||
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}],
|
||||
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sky."}],
|
||||
],
|
||||
"rejected": [
|
||||
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}],
|
||||
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sea."}],
|
||||
],
|
||||
})
|
||||
dataset = dataset.map(extract_prompt).remove_columns("rejected").rename_column("chosen", "completion")
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}], 'completion': [{'role': 'assistant', 'content': 'It is blue.'}]}
|
||||
```
|
||||
|
||||
### From preference with implicit prompt to prompt-only dataset
|
||||
|
||||
To convert a preference dataset with implicit prompt into a prompt-only dataset, extract the prompt with [`extract_prompt`], and remove the rejected and the chosen.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
from trl import extract_prompt
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"chosen": [
|
||||
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}],
|
||||
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sky."}],
|
||||
],
|
||||
"rejected": [
|
||||
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}],
|
||||
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sea."}],
|
||||
],
|
||||
})
|
||||
dataset = dataset.map(extract_prompt).remove_columns(["chosen", "rejected"])
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}]}
|
||||
```
|
||||
|
||||
### From implicit to explicit prompt preference dataset
|
||||
|
||||
To convert a preference dataset with implicit prompt into a preference dataset with explicit prompt, extract the prompt with [`extract_prompt`].
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
from trl import extract_prompt
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"chosen": [
|
||||
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}],
|
||||
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sky."}],
|
||||
],
|
||||
"rejected": [
|
||||
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}],
|
||||
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sea."}],
|
||||
],
|
||||
})
|
||||
|
||||
dataset = dataset.map(extract_prompt)
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}],
|
||||
'chosen': [{'role': 'assistant', 'content': 'It is blue.'}],
|
||||
'rejected': [{'role': 'assistant', 'content': 'It is green.'}]}
|
||||
```
|
||||
|
||||
### From preference with implicit prompt to unpaired preference dataset
|
||||
|
||||
To convert a preference dataset with implicit prompt into an unpaired preference dataset, extract the prompt with [`extract_prompt`], and unpair the dataset with [`unpair_preference_dataset`].
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
from trl import extract_prompt, unpair_preference_dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"chosen": [
|
||||
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}],
|
||||
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sky."}],
|
||||
],
|
||||
"rejected": [
|
||||
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}],
|
||||
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sea."}],
|
||||
],
|
||||
})
|
||||
|
||||
dataset = dataset.map(extract_prompt)
|
||||
dataset = unpair_preference_dataset(dataset)
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}],
|
||||
'completion': [{'role': 'assistant', 'content': 'It is blue.'}],
|
||||
'label': True}
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> Keep in mind that the `"chosen"` and `"rejected"` completions in a preference dataset can be both good or bad.
|
||||
> Before applying [`unpair_preference_dataset`], please ensure that all `"chosen"` completions can be labeled as good and all `"rejected"` completions as bad.
|
||||
> This can be ensured by checking absolute rating of each completion, e.g. from a reward model.
|
||||
|
||||
### From preference to language modeling dataset
|
||||
|
||||
To convert a preference dataset into a language modeling dataset, remove the rejected, concatenate the prompt and the chosen into the `"text"` column.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"prompt": ["The sky is", "The sun is"],
|
||||
"chosen": [" blue.", " in the sky."],
|
||||
"rejected": [" green.", " in the sea."],
|
||||
})
|
||||
|
||||
def concat_prompt_chosen(example):
|
||||
return {"text": example["prompt"] + example["chosen"]}
|
||||
|
||||
dataset = dataset.map(concat_prompt_chosen, remove_columns=["prompt", "chosen", "rejected"])
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'text': 'The sky is blue.'}
|
||||
```
|
||||
|
||||
### From preference to prompt-completion dataset
|
||||
|
||||
To convert a preference dataset into a prompt-completion dataset, remove the rejected, and rename the column `"chosen"` to `"completion"`.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"prompt": ["The sky is", "The sun is"],
|
||||
"chosen": [" blue.", " in the sky."],
|
||||
"rejected": [" green.", " in the sea."],
|
||||
})
|
||||
|
||||
dataset = dataset.remove_columns("rejected").rename_column("chosen", "completion")
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'prompt': 'The sky is', 'completion': ' blue.'}
|
||||
```
|
||||
|
||||
### From preference to prompt-only dataset
|
||||
|
||||
To convert a preference dataset into a prompt-only dataset, remove the rejected and the chosen.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"prompt": ["The sky is", "The sun is"],
|
||||
"chosen": [" blue.", " in the sky."],
|
||||
"rejected": [" green.", " in the sea."],
|
||||
})
|
||||
|
||||
dataset = dataset.remove_columns(["chosen", "rejected"])
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'prompt': 'The sky is'}
|
||||
```
|
||||
|
||||
### From explicit to implicit prompt preference dataset
|
||||
|
||||
To convert a preference dataset with explicit prompt into a preference dataset with implicit prompt, concatenate the prompt to both chosen and rejected, and remove the prompt.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"prompt": [
|
||||
[{"role": "user", "content": "What color is the sky?"}],
|
||||
[{"role": "user", "content": "Where is the sun?"}],
|
||||
],
|
||||
"chosen": [
|
||||
[{"role": "assistant", "content": "It is blue."}],
|
||||
[{"role": "assistant", "content": "In the sky."}],
|
||||
],
|
||||
"rejected": [
|
||||
[{"role": "assistant", "content": "It is green."}],
|
||||
[{"role": "assistant", "content": "In the sea."}],
|
||||
],
|
||||
})
|
||||
|
||||
def concat_prompt_to_completions(example):
|
||||
return {"chosen": example["prompt"] + example["chosen"], "rejected": example["prompt"] + example["rejected"]}
|
||||
|
||||
dataset = dataset.map(concat_prompt_to_completions, remove_columns="prompt")
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'chosen': [{'role': 'user', 'content': 'What color is the sky?'}, {'role': 'assistant', 'content': 'It is blue.'}],
|
||||
'rejected': [{'role': 'user', 'content': 'What color is the sky?'}, {'role': 'assistant', 'content': 'It is green.'}]}
|
||||
```
|
||||
|
||||
### From preference to unpaired preference dataset
|
||||
|
||||
To convert dataset into an unpaired preference dataset, unpair the dataset with [`unpair_preference_dataset`].
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
from trl import unpair_preference_dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"prompt": [
|
||||
[{"role": "user", "content": "What color is the sky?"}],
|
||||
[{"role": "user", "content": "Where is the sun?"}],
|
||||
],
|
||||
"chosen": [
|
||||
[{"role": "assistant", "content": "It is blue."}],
|
||||
[{"role": "assistant", "content": "In the sky."}],
|
||||
],
|
||||
"rejected": [
|
||||
[{"role": "assistant", "content": "It is green."}],
|
||||
[{"role": "assistant", "content": "In the sea."}],
|
||||
],
|
||||
})
|
||||
|
||||
dataset = unpair_preference_dataset(dataset)
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}],
|
||||
'completion': [{'role': 'assistant', 'content': 'It is blue.'}],
|
||||
'label': True}
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> Keep in mind that the `"chosen"` and `"rejected"` completions in a preference dataset can be both good or bad.
|
||||
> Before applying [`unpair_preference_dataset`], please ensure that all `"chosen"` completions can be labeled as good and all `"rejected"` completions as bad.
|
||||
> This can be ensured by checking absolute rating of each completion, e.g. from a reward model.
|
||||
|
||||
### From unpaired preference to language modeling dataset
|
||||
|
||||
To convert an unpaired preference dataset into a language modeling dataset, concatenate prompts with good completions into the `"text"` column, and remove the prompt, completion and label columns.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"prompt": ["The sky is", "The sun is", "The sky is", "The sun is"],
|
||||
"completion": [" blue.", " in the sky.", " green.", " in the sea."],
|
||||
"label": [True, True, False, False],
|
||||
})
|
||||
|
||||
def concatenate_prompt_completion(example):
|
||||
return {"text": example["prompt"] + example["completion"]}
|
||||
|
||||
dataset = dataset.filter(lambda x: x["label"]).map(concatenate_prompt_completion).remove_columns(["prompt", "completion", "label"])
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'text': 'The sky is blue.'}
|
||||
```
|
||||
|
||||
### From unpaired preference to prompt-completion dataset
|
||||
|
||||
To convert an unpaired preference dataset into a prompt-completion dataset, filter for good labels, then remove the label columns.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"prompt": ["The sky is", "The sun is", "The sky is", "The sun is"],
|
||||
"completion": [" blue.", " in the sky.", " green.", " in the sea."],
|
||||
"label": [True, True, False, False],
|
||||
})
|
||||
|
||||
dataset = dataset.filter(lambda x: x["label"]).remove_columns(["label"])
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'prompt': 'The sky is', 'completion': ' blue.'}
|
||||
```
|
||||
|
||||
### From unpaired preference to prompt-only dataset
|
||||
|
||||
To convert an unpaired preference dataset into a prompt-only dataset, remove the completion and the label columns.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"prompt": ["The sky is", "The sun is", "The sky is", "The sun is"],
|
||||
"completion": [" blue.", " in the sky.", " green.", " in the sea."],
|
||||
"label": [True, True, False, False],
|
||||
})
|
||||
|
||||
dataset = dataset.remove_columns(["completion", "label"])
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'prompt': 'The sky is'}
|
||||
```
|
||||
|
||||
### From stepwise supervision to language modeling dataset
|
||||
|
||||
To convert a stepwise supervision dataset into a language modeling dataset, concatenate prompts with good completions into the `"text"` column.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"prompt": ["Blue light", "Water"],
|
||||
"completions": [[" scatters more in the atmosphere,", " so the sky is green."],
|
||||
[" forms a less dense structure in ice,", " which causes it to expand when it freezes."]],
|
||||
"labels": [[True, False], [True, True]],
|
||||
})
|
||||
|
||||
def concatenate_prompt_completions(example):
|
||||
completion = "".join(example["completions"])
|
||||
return {"text": example["prompt"] + completion}
|
||||
|
||||
dataset = dataset.filter(lambda x: all(x["labels"])).map(concatenate_prompt_completions, remove_columns=["prompt", "completions", "labels"])
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'text': 'Blue light scatters more in the atmosphere, so the sky is green.'}
|
||||
```
|
||||
|
||||
### From stepwise supervision to prompt-completion dataset
|
||||
|
||||
To convert a stepwise supervision dataset into a prompt-completion dataset, join the good completions and remove the labels.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"prompt": ["Blue light", "Water"],
|
||||
"completions": [[" scatters more in the atmosphere,", " so the sky is green."],
|
||||
[" forms a less dense structure in ice,", " which causes it to expand when it freezes."]],
|
||||
"labels": [[True, False], [True, True]],
|
||||
})
|
||||
|
||||
def join_completions(example):
|
||||
completion = "".join(example["completions"])
|
||||
return {"completion": completion}
|
||||
|
||||
dataset = dataset.filter(lambda x: all(x["labels"])).map(join_completions, remove_columns=["completions", "labels"])
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'prompt': 'Blue light', 'completion': ' scatters more in the atmosphere, so the sky is green.'}
|
||||
```
|
||||
|
||||
### From stepwise supervision to prompt-only dataset
|
||||
|
||||
To convert a stepwise supervision dataset into a prompt-only dataset, remove the completions and the labels.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"prompt": ["Blue light", "Water"],
|
||||
"completions": [[" scatters more in the atmosphere,", " so the sky is green."],
|
||||
[" forms a less dense structure in ice,", " which causes it to expand when it freezes."]],
|
||||
"labels": [[True, False], [True, True]],
|
||||
})
|
||||
|
||||
dataset = dataset.remove_columns(["completions", "labels"])
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'prompt': 'Blue light'}
|
||||
```
|
||||
|
||||
### From stepwise supervision to unpaired preference dataset
|
||||
|
||||
To convert a stepwise supervision dataset into an unpaired preference dataset, join the completions and merge the labels.
|
||||
|
||||
The method for merging the labels depends on the specific task. In this example, we use the logical AND operation. This means that if the step labels indicate the correctness of individual steps, the resulting label will reflect the correctness of the entire sequence.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"prompt": ["Blue light", "Water"],
|
||||
"completions": [[" scatters more in the atmosphere,", " so the sky is green."],
|
||||
[" forms a less dense structure in ice,", " which causes it to expand when it freezes."]],
|
||||
"labels": [[True, False], [True, True]],
|
||||
})
|
||||
|
||||
def merge_completions_and_labels(example):
|
||||
return {"prompt": example["prompt"], "completion": "".join(example["completions"]), "label": all(example["labels"])}
|
||||
|
||||
dataset = dataset.map(merge_completions_and_labels, remove_columns=["completions", "labels"])
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'prompt': 'Blue light', 'completion': ' scatters more in the atmosphere, so the sky is green.', 'label': False}
|
||||
```
|
||||
|
||||
## Vision datasets
|
||||
|
||||
Some trainers also support fine-tuning vision-language models (VLMs) using image-text pairs. In this scenario, it's recommended to use a conversational format, as each model handles image placeholders in text differently.
|
||||
|
||||
A conversational vision dataset differs from a standard conversational dataset in two key ways:
|
||||
|
||||
1. The dataset must contain the key `images` with the image data (as lists of PIL images) or `image` with a single PIL image.
|
||||
2. The `"content"` field in messages must be a list of dictionaries, where each dictionary specifies the type of data: `"image"` or `"text"`.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
# Textual dataset:
|
||||
"content": "What color is the sky?"
|
||||
|
||||
# Vision dataset:
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "What color is the sky in the image?"}
|
||||
]
|
||||
```
|
||||
|
||||
An example of a conversational vision dataset is the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset). Below is an embedded view of the dataset's training data, allowing you to explore it directly:
|
||||
|
||||
<iframe
|
||||
src="https://huggingface.co/datasets/trl-lib/rlaif-v/embed/viewer/default/train"
|
||||
frameborder="0"
|
||||
width="100%"
|
||||
height="560px"
|
||||
></iframe>
|
||||
|
||||
> [!NOTE]
|
||||
> Mixing text-only and vision-language data in the dataset is possible, but it requires `transformers` version 4.57.0 or later. Example:
|
||||
>
|
||||
> ```python
|
||||
> dataset = Dataset.from_dict({
|
||||
> "prompt": [
|
||||
> [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky in the image?"}]}],
|
||||
> [{"role": "user", "content": [{"type": "text", "text": "What is the capital of France?"}]}],
|
||||
> ],
|
||||
> "completion": [
|
||||
> [{"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]}],
|
||||
> [{"role": "assistant", "content": [{"type": "text", "text": "Paris."}]}],
|
||||
> ],
|
||||
> "images": [
|
||||
> [PIL.Image.open("path/to/sky_image1.png")],
|
||||
> [],
|
||||
> ],
|
||||
> })
|
||||
> ```
|
||||
@ -1,712 +0,0 @@
|
||||
# Dataset formats
|
||||
|
||||
This guide provides an overview of the dataset formats supported by each trainer in TRL. Since conversational datasets are very common, we also provide a guide on how to use them, and how to convert them into a standard dataset format for TRL trainers.
|
||||
|
||||
## Overview of the dataset formats and types
|
||||
|
||||
The *format* of a dataset refers to how the data is structured, typically categorized as either *standard* or *conversational*. The *type* is associated with the specific task the dataset is designed for, such as *prompt-only* or *preference*. Each type is characterized by its columns, which vary according to the task, as shown in the table.
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<th>Type \ Format</th>
|
||||
<th>Standard</th>
|
||||
<th>Conversational</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>Language modeling</td>
|
||||
<td>
|
||||
<pre><code>{"text": "The sky is blue."}</code></pre>
|
||||
</td>
|
||||
<td>
|
||||
<pre><code>{"messages": [{"role": "user", "content": "What color is the sky?"},
|
||||
{"role": "assistant", "content": "It is blue."}]}</code></pre>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>Prompt-only</td>
|
||||
<td>
|
||||
<pre><code>{"prompt": "The sky is"}</code></pre>
|
||||
</td>
|
||||
<td>
|
||||
<pre><code>{"prompt": [{"role": "user", "content": "What color is the sky?"}]}</code></pre>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>Prompt-completion</td>
|
||||
<td>
|
||||
<pre><code>{"prompt": "The sky is",
|
||||
"completion": " blue."}</code></pre>
|
||||
</td>
|
||||
<td>
|
||||
<pre><code>{"prompt": [{"role": "user", "content": "What color is the sky?"}],
|
||||
"completion": [{"role": "assistant", "content": "It is blue."}]}</code></pre>
|
||||
</td>
|
||||
</tr>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>Preference</td>
|
||||
<td>
|
||||
<pre><code>{"prompt": "The sky is",
|
||||
"chosen": " blue.",
|
||||
"rejected": " green."}</code></pre>
|
||||
or, with implicit prompt:
|
||||
<pre><code>{"chosen": "The sky is blue.",
|
||||
"rejected": "The sky is green."}</code></pre>
|
||||
</td>
|
||||
<td>
|
||||
<pre><code>{"prompt": [{"role": "user", "content": "What color is the sky?"}],
|
||||
"chosen": [{"role": "assistant", "content": "It is blue."}],
|
||||
"rejected": [{"role": "assistant", "content": "It is green."}]}</code></pre>
|
||||
or, with implicit prompt:
|
||||
<pre><code>{"chosen": [{"role": "user", "content": "What color is the sky?"},
|
||||
{"role": "assistant", "content": "It is blue."}],
|
||||
"rejected": [{"role": "user", "content": "What color is the sky?"},
|
||||
{"role": "assistant", "content": "It is green."}]}</code></pre>
|
||||
</td>
|
||||
</tr>
|
||||
<td>Unpaired preference</td>
|
||||
<td>
|
||||
<pre><code>{"prompt": "The sky is",
|
||||
"completion": " blue.",
|
||||
"label": True}</code></pre>
|
||||
</td>
|
||||
<td>
|
||||
<pre><code>{"prompt": [{"role": "user", "content": "What color is the sky?"}],
|
||||
"completion": [{"role": "assistant", "content": "It is green."}],
|
||||
"label": False}</code></pre>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
|
||||
### Standard dataset format
|
||||
|
||||
The standard dataset format typically consists of plain text strings. The columns in the dataset vary depending on the task. This is the format expected by TRL trainers. Below are examples of standard dataset formats for different tasks:
|
||||
|
||||
```python
|
||||
# Language modeling
|
||||
example = {"text": "The sky is blue."}
|
||||
# Preference
|
||||
example = {"chosen": "The sky is blue.", "rejected": "The sky is green."}
|
||||
```
|
||||
|
||||
### Conversational dataset format
|
||||
|
||||
Conversational datasets are used for tasks involving dialogues or chat interactions between users and assistants. Unlike standard dataset formats, these contain sequences of messages where each message has a `role` (e.g., `"user"` or `"assistant"`) and `content` (the message text).
|
||||
|
||||
```python
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing great. How can I help you today?"},
|
||||
{"role": "user", "content": "I'd like to show off how chat templating works!"},
|
||||
]
|
||||
```
|
||||
|
||||
Just like standard datasets, the columns in conversational datasets vary depending on the task. For instance, a preference dataset would include columns like `"chosen"` and `"rejected"` to compare responses:
|
||||
|
||||
```python
|
||||
example = {
|
||||
"chosen": [
|
||||
{"role": "user", "content": "What color is the sky?"},
|
||||
{"role": "assistant", "content": "It is blue."},
|
||||
],
|
||||
"rejected": [
|
||||
{"role": "user", "content": "What color is the sky?"},
|
||||
{"role": "assistant", "content": "It is green."},
|
||||
],
|
||||
}
|
||||
```
|
||||
|
||||
Conversational datasets are useful for training chat models, but must be converted into a standard format before being used with TRL trainers. This is typically done using chat templates specific to the model being used. For more information, refer to the [Working with conversational datasets in TRL](#working-with-conversational-datasets-in-trl) section.
|
||||
|
||||
### Language modeling
|
||||
|
||||
A language modeling dataset consists of a column `"text"` (or `"messages"` for conversational datasets) containing a full sequence of text.
|
||||
|
||||
```python
|
||||
language_modeling_example = {"text": "The sky is blue."}
|
||||
```
|
||||
|
||||
### Prompt-only
|
||||
|
||||
In a prompt-only dataset, only the initial prompt (the question or partial sentence) is provided under the key `"prompt"`. The training typically involves generating the completion based on this prompt, where the model learns to continue or complete the given input.
|
||||
|
||||
```python
|
||||
prompt_only_example = {"prompt": "The sky is"}
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
While both the prompt-only and language modeling formats are similar, they differ in how the input is handled. In the prompt-only format, the prompt represents a partial input that expects the model to complete or continue, while in the language modeling format, the input is treated as a complete sentence or sequence. These two formats are processed differently by TRL. Below is an example showing the difference in the output of the `apply_chat_template` function for each format:
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer
|
||||
from trl import apply_chat_template
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
|
||||
|
||||
# Example for prompt-only format
|
||||
prompt_only_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}]}
|
||||
apply_chat_template(prompt_only_example, tokenizer)
|
||||
# Output: {'prompt': '<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n'}
|
||||
|
||||
# Example for language modeling format
|
||||
lm_example = {"messages": [{"role": "user", "content": "What color is the sky?"}]}
|
||||
apply_chat_template(lm_example, tokenizer)
|
||||
# Output: {'text': '<|user|>\nWhat color is the sky?<|end|>\n<|endoftext|>'}
|
||||
```
|
||||
|
||||
- The prompt-only output includes a `'<|assistant|>\n'`, indicating the beginning of the assistant’s turn and expecting the model to generate a completion.
|
||||
- In contrast, the language modeling output treats the input as a complete sequence and terminates it with `'<|endoftext|>'`, signaling the end of the text and not expecting any additional content.
|
||||
|
||||
</Tip>
|
||||
|
||||
### Prompt-completion
|
||||
|
||||
A prompt-completion dataset includes a `"prompt"` and a `"completion"`.
|
||||
|
||||
```python
|
||||
prompt_completion_example = {"prompt": "The sky is", "completion": " blue."}
|
||||
```
|
||||
|
||||
### Preference
|
||||
|
||||
A preference dataset is used for tasks where the model is trained to choose between two or more possible completions to the same prompt. This dataset includes a `"prompt"`, a `"chosen"` completion, and a `"rejected"` completion. The model is trained to select the `"chosen"` response over the `"rejected"` response.
|
||||
Some dataset may not include the `"prompt"` column, in which case the prompt is implicit and directly included in the `"chosen"` and `"rejected"` completions. We recommend using explicit prompts whenever possible.
|
||||
|
||||
```python
|
||||
preference_example = {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."} # recommended
|
||||
# or,
|
||||
preference_example = {"chosen": "The sky is blue.", "rejected": "The sky is green."}
|
||||
```
|
||||
|
||||
### Unpaired preference
|
||||
|
||||
An unpaired preference dataset is similar to a preference dataset but instead of having `"chosen"` and `"rejected"` completions for the same prompt, it includes a single `"completion"` and a `"label"` indicating whether the completion is preferred or not.
|
||||
|
||||
```python
|
||||
unpaired_preference_example = {"prompt": "The sky is", "completion": " blue.", "label": True}
|
||||
```
|
||||
|
||||
## Which dataset format to use?
|
||||
|
||||
Choosing the right dataset format depends on the task you are working on and the specific requirements of the TRL trainer you are using. Below is a brief overview of the dataset formats supported by each TRL trainer.
|
||||
|
||||
| Trainer | Expected dataset format |
|
||||
| ----------------------- | ---------------------------- |
|
||||
| [`BCOTrainer`] | Unpaired preference |
|
||||
| [`CPOTrainer`] | Preference (explicit prompt) |
|
||||
| [`DPOTrainer`] | Preference (explicit prompt) |
|
||||
| [`IterativeSFTTrainer`] | Unpaired preference |
|
||||
| [`KTOTrainer`] | Unpaired preference |
|
||||
| [`NashMDTrainer`] | Prompt-only |
|
||||
| [`OnlineDPOTrainer`] | Prompt-only |
|
||||
| [`ORPOTrainer`] | Preference (explicit prompt) |
|
||||
| [`PPOv2Trainer`] | Tokenized language modeling |
|
||||
| [`RewardTrainer`] | Preference (implicit prompt) |
|
||||
| [`SFTTrainer`] | Language modeling |
|
||||
| [`XPOTrainer`] | Prompt-only |
|
||||
|
||||
<Tip>
|
||||
|
||||
TRL trainers only support standard dataset formats, [for now](https://github.com/huggingface/trl/issues/2071). If you have a conversational dataset, you must first convert it into a standard format.
|
||||
For more information on how to work with conversational datasets, refer to the [Working with conversational datasets in TRL](#working-with-conversational-datasets-in-trl) section.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Working with conversational datasets in TRL
|
||||
|
||||
Conversational datasets are increasingly common, especially for training chat models. However, TRL trainers (except [`SFTTrainer`]) don't support conversational datasets in their raw format. These datasets must first be converted into a standard format.
|
||||
Fortunately, TRL offers tools to easily handle this conversion, which are detailed below.
|
||||
|
||||
### Converting a conversational dataset into a standard dataset
|
||||
|
||||
TRL trainers do not support conversational datasets in their raw format. To use them, you need to convert them into a standard dataset format using a chat template. This template is provided by the tokenizer of the model you use.
|
||||
|
||||
For detailed instructions on using chat templating, refer to the [Chat templating section in the `transformers` documentation](https://huggingface.co/docs/transformers/en/chat_templating).
|
||||
|
||||
In TRL, the method you apply to convert the dataset will vary depending on the task. Fortunately, TRL provides a helper function called [`apply_chat_template`] to simplify this process. Here's an example of how to use it:
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer
|
||||
from trl import apply_chat_template
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
|
||||
|
||||
example = {
|
||||
"prompt": [{"role": "user", "content": "What color is the sky?"}],
|
||||
"completion": [{"role": "assistant", "content": "It is blue."}]
|
||||
}
|
||||
|
||||
apply_chat_template(example, tokenizer)
|
||||
# Output:
|
||||
# {'prompt': '<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n', 'completion': 'It is blue.<|end|>\n<|endoftext|>'}
|
||||
```
|
||||
|
||||
Alternatively, you can use the [`~datasets.Dataset.map`] method to apply the template across an entire dataset:
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
from trl import apply_chat_template
|
||||
|
||||
dataset_dict = {
|
||||
"prompt": [[{"role": "user", "content": "What color is the sky?"}],
|
||||
[{"role": "user", "content": "Where is the sun?"}]],
|
||||
"completion": [[{"role": "assistant", "content": "It is blue."}],
|
||||
[{"role": "assistant", "content": "In the sky."}]]
|
||||
}
|
||||
|
||||
dataset = Dataset.from_dict(dataset_dict)
|
||||
dataset = dataset.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer})
|
||||
# Output:
|
||||
# {'prompt': ['<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n',
|
||||
# '<|user|>\nWhere is the sun?<|end|>\n<|assistant|>\n'],
|
||||
# 'completion': ['It is blue.<|end|>\n<|endoftext|>', 'In the sky.<|end|>\n<|endoftext|>']}
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
We recommend using the [`apply_chat_template`] function rather than directly calling `tokenizer.apply_chat_template`. Handling chat templates nonlanguage modeling datasets can be tricky and may lead to issues, such as inserting a system prompt in the middle of a conversation. For additional examples, see [#1930 (comment)](https://github.com/huggingface/trl/pull/1930#issuecomment-2292908614). The [`apply_chat_template`] is designed to handle these intricacies and ensure the correct application of chat templates for various tasks.
|
||||
|
||||
</Tip>
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
It's important to note that chat templates are model-specific. For example, if you use the chat template from [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) with the above example, you get a different output:
|
||||
|
||||
```python
|
||||
apply_chat_template(example, AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct"))
|
||||
# Output:
|
||||
# {'prompt': '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n',
|
||||
# 'completion': 'It is blue.<|im_end|>\n'}
|
||||
```
|
||||
|
||||
Always use the chat template associated with the model you're working with. Using the wrong template can lead to inaccurate or unexpected results.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Using any dataset with TRL: preprocessing and conversion
|
||||
|
||||
Many datasets come in formats tailored to specific tasks, which might not be directly compatible with TRL. To use such datasets with TRL, you may need to preprocess and convert them into the required format.
|
||||
|
||||
To make this easier, we provide a set of [example scripts](https://github.com/huggingface/trl/tree/main/examples/datasets) that cover common dataset conversions.
|
||||
|
||||
### Example: UltraFeedback dataset
|
||||
|
||||
Let’s take the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback) as an example. Here's a preview of the dataset:
|
||||
|
||||
<iframe
|
||||
src="https://huggingface.co/datasets/openbmb/UltraFeedback/embed/viewer/default/train"
|
||||
frameborder="0"
|
||||
width="100%"
|
||||
height="560px"
|
||||
></iframe>
|
||||
|
||||
As shown above, the dataset format does not match the expected structure. It’s not in a conversational format, the column names differ, and the results pertain to different models (e.g., Bard, GPT-4) and aspects (e.g., "helpfulness", "honesty").
|
||||
|
||||
By using the provided conversion script [`examples/datasets/ultrafeedback.py`](https://github.com/huggingface/trl/tree/main/examples/datasets/ultrafeedback.py), you can transform this dataset into an unpaired preference format, and push it to the Hub:
|
||||
|
||||
```sh
|
||||
python examples/datasets/ultrafeedback.py --push_to_hub --repo_id trl-lib/ultrafeedback-gpt-3.5-turbo-helpfulness
|
||||
```
|
||||
|
||||
Once converted, the dataset will look like this:
|
||||
|
||||
<iframe
|
||||
src="https://huggingface.co/datasets/trl-lib/ultrafeedback-gpt-3.5-turbo-helpfulness/embed/viewer/default/train?row=0"
|
||||
frameborder="0"
|
||||
width="100%"
|
||||
height="560px"
|
||||
></iframe>
|
||||
|
||||
Now, you can use this dataset with TRL!
|
||||
|
||||
By adapting the provided scripts or creating your own, you can convert any dataset into a format compatible with TRL.
|
||||
|
||||
## Utilities for converting dataset types
|
||||
|
||||
This section provides example code to help you convert between different dataset types. While some conversions can be performed after applying the chat template (i.e., in the standard format), we recommend performing the conversion before applying the chat template to ensure it works consistently.
|
||||
|
||||
For simplicity, some of the examples below do not follow this recommendation and use the standard format. However, the conversions can be applied directly to the conversational format without modification.
|
||||
|
||||
| From \ To | Language modeling | Prompt-completion | Prompt-only | Preference with implicit prompt | Preference | Unpaired preference |
|
||||
| ------------------------------- | ----------------------------------------------------------------------- | ----------------------------------------------------------------------- | ----------------------------------------------------------------- | --------------------------------------------------------- | --------------------------------------------------------- | ------------------------------------------------------------------------- |
|
||||
| Language modeling | N/A | N/A | N/A | N/A | N/A | N/A |
|
||||
| Prompt-completion | [🔗](#from-prompt-completion-to-language-modeling-dataset) | N/A | [🔗](#from-prompt-completion-to-prompt-only-dataset) | N/A | N/A | N/A |
|
||||
| Prompt-only | N/A | N/A | N/A | N/A | N/A | N/A |
|
||||
| Preference with implicit prompt | [🔗](#from-preference-with-implicit-prompt-to-language-modeling-dataset) | [🔗](#from-preference-with-implicit-prompt-to-prompt-completion-dataset) | [🔗](#from-preference-with-implicit-prompt-to-prompt-only-dataset) | N/A | [🔗](#from-implicit-to-explicit-prompt-preference-dataset) | [🔗](#from-preference-with-implicit-prompt-to-unpaired-preference-dataset) |
|
||||
| Preference | [🔗](#from-preference-to-language-modeling-dataset) | [🔗](#from-preference-to-prompt-completion-dataset) | [🔗](#from-preference-to-prompt-only-dataset) | [🔗](#from-explicit-to-implicit-prompt-preference-dataset) | N/A | [🔗](#from-preference-to-unpaired-preference-dataset) |
|
||||
| Unpaired preference | [🔗](#from-unpaired-preference-to-language-modeling-dataset) | [🔗](#from-unpaired-preference-to-prompt-completion-dataset) | [🔗](#from-unpaired-preference-to-prompt-only-dataset) | N/A | N/A | N/A |
|
||||
|
||||
### From prompt-completion to language modeling dataset
|
||||
|
||||
To convert a prompt-completion dataset into a language modeling dataset, concatenate the prompt and the completion.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"prompt": ["The sky is", "The sun is"],
|
||||
"completion": [" blue.", " in the sky."],
|
||||
})
|
||||
|
||||
def concat_prompt_completion(example):
|
||||
return {"text": example["prompt"] + example["completion"]}
|
||||
|
||||
dataset = dataset.map(concat_prompt_completion, remove_columns=["prompt", "completion"])
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'text': 'The sky is blue.'}
|
||||
```
|
||||
|
||||
### From prompt-completion to prompt-only dataset
|
||||
|
||||
To convert a prompt-completion dataset into a prompt-only dataset, remove the completion.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"prompt": ["The sky is", "The sun is"],
|
||||
"completion": [" blue.", " in the sky."],
|
||||
})
|
||||
|
||||
dataset = dataset.remove_columns("completion")
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'prompt': 'The sky is'}
|
||||
```
|
||||
|
||||
### From preference with implicit prompt to language modeling dataset
|
||||
|
||||
To convert a preference with implicit prompt dataset into a language modeling dataset, remove the rejected, and rename the column `"chosen"` to `"text"`.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"chosen": ["The sky is blue.", "The sun is in the sky."],
|
||||
"rejected": ["The sky is green.", "The sun is in the sea."],
|
||||
})
|
||||
|
||||
dataset = dataset.rename_column("chosen", "text").remove_columns("rejected")
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'text': 'The sky is blue.'}
|
||||
```
|
||||
|
||||
### From preference with implicit prompt to prompt-completion dataset
|
||||
|
||||
To convert a preference dataset with implicit prompt into a prompt-completion dataset, extract the prompt with [`extract_prompt`], remove the rejected, and rename the column `"chosen"` to `"completion"`.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
from trl import extract_prompt
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"chosen": [
|
||||
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}],
|
||||
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sky."}],
|
||||
],
|
||||
"rejected": [
|
||||
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}],
|
||||
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sea."}],
|
||||
],
|
||||
})
|
||||
dataset = dataset.map(extract_prompt).remove_columns("rejected").rename_column("chosen", "completion")
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}], 'completion': [{'role': 'assistant', 'content': 'It is blue.'}]}
|
||||
```
|
||||
|
||||
### From preference with implicit prompt to prompt-only dataset
|
||||
|
||||
To convert a preference dataset with implicit prompt into a prompt-only dataset, extract the prompt with [`extract_prompt`], and remove the rejected and the chosen.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
from trl import extract_prompt
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"chosen": [
|
||||
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}],
|
||||
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sky."}],
|
||||
],
|
||||
"rejected": [
|
||||
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}],
|
||||
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sea."}],
|
||||
],
|
||||
})
|
||||
dataset = dataset.map(extract_prompt).remove_columns(["chosen", "rejected"])
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}]}
|
||||
```
|
||||
|
||||
### From implicit to explicit prompt preference dataset
|
||||
|
||||
To convert a preference dataset with implicit prompt into a preference dataset with explicit prompt, extract the prompt with [`extract_prompt`].
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
from trl import extract_prompt
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"chosen": [
|
||||
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}],
|
||||
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sky."}],
|
||||
],
|
||||
"rejected": [
|
||||
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}],
|
||||
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sea."}],
|
||||
],
|
||||
})
|
||||
|
||||
dataset = dataset.map(extract_prompt)
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}],
|
||||
'chosen': [{'role': 'assistant', 'content': 'It is blue.'}],
|
||||
'rejected': [{'role': 'assistant', 'content': 'It is green.'}]}
|
||||
```
|
||||
|
||||
### From preference with implicit prompt to unpaired preference dataset
|
||||
|
||||
To convert a preference dataset with implicit prompt into an unpaired preference dataset, extract the prompt with [`extract_prompt`], and unpair the dataset with [`unpair_preference_dataset`].
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
from trl import extract_prompt, unpair_preference_dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"chosen": [
|
||||
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}],
|
||||
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sky."}],
|
||||
],
|
||||
"rejected": [
|
||||
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}],
|
||||
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sea."}],
|
||||
],
|
||||
})
|
||||
|
||||
dataset = dataset.map(extract_prompt)
|
||||
dataset = unpair_preference_dataset(dataset)
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}],
|
||||
'completion': [{'role': 'assistant', 'content': 'It is blue.'}],
|
||||
'label': True}
|
||||
```
|
||||
|
||||
### From preference to language modeling dataset
|
||||
|
||||
To convert a preference dataset into a language modeling dataset, remove the rejected, concatenate the prompt and the chosen into the `"text"` column.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"prompt": ["The sky is", "The sun is"],
|
||||
"chosen": [" blue.", " in the sky."],
|
||||
"rejected": [" green.", " in the sea."],
|
||||
})
|
||||
|
||||
def concat_prompt_chosen(example):
|
||||
return {"text": example["prompt"] + example["chosen"]}
|
||||
|
||||
dataset = dataset.map(concat_prompt_chosen, remove_columns=["prompt", "chosen", "rejected"])
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'text': 'The sky is blue.'}
|
||||
```
|
||||
|
||||
### From preference to prompt-completion dataset
|
||||
|
||||
To convert a preference dataset into a prompt-completion dataset, remove the rejected, and rename the column `"chosen"` to `"completion"`.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"prompt": ["The sky is", "The sun is"],
|
||||
"chosen": [" blue.", " in the sky."],
|
||||
"rejected": [" green.", " in the sea."],
|
||||
})
|
||||
|
||||
dataset = dataset.remove_columns("rejected").rename_column("chosen", "completion")
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'prompt': 'The sky is', 'completion': ' blue.'}
|
||||
```
|
||||
|
||||
### From preference to prompt-only dataset
|
||||
|
||||
To convert a preference dataset into a prompt-only dataset, remove the rejected and the chosen.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"prompt": ["The sky is", "The sun is"],
|
||||
"chosen": [" blue.", " in the sky."],
|
||||
"rejected": [" green.", " in the sea."],
|
||||
})
|
||||
|
||||
dataset = dataset.remove_columns(["chosen", "rejected"])
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'prompt': 'The sky is'}
|
||||
```
|
||||
|
||||
### From explicit to implicit prompt preference dataset
|
||||
|
||||
To convert a preference dataset with implicit prompt into a preference dataset with explicit prompt, concatenate the prompt to both chosen and rejected, and remove the prompt.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"prompt": [
|
||||
[{"role": "user", "content": "What color is the sky?"}],
|
||||
[{"role": "user", "content": "Where is the sun?"}],
|
||||
],
|
||||
"chosen": [
|
||||
[{"role": "assistant", "content": "It is blue."}],
|
||||
[{"role": "assistant", "content": "In the sky."}],
|
||||
],
|
||||
"rejected": [
|
||||
[{"role": "assistant", "content": "It is green."}],
|
||||
[{"role": "assistant", "content": "In the sea."}],
|
||||
],
|
||||
})
|
||||
|
||||
def concat_prompt_to_completions(example):
|
||||
return {"chosen": example["prompt"] + example["chosen"], "rejected": example["prompt"] + example["rejected"]}
|
||||
|
||||
dataset = dataset.map(concat_prompt_to_completions, remove_columns="prompt")
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'chosen': [{'role': 'user', 'content': 'What color is the sky?'}, {'role': 'assistant', 'content': 'It is blue.'}],
|
||||
'rejected': [{'role': 'user', 'content': 'What color is the sky?'}, {'role': 'assistant', 'content': 'It is green.'}]}
|
||||
```
|
||||
|
||||
### From preference to unpaired preference dataset
|
||||
|
||||
To convert dataset into an unpaired preference dataset, unpair the dataset with [`unpair_preference_dataset`].
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
from trl import unpair_preference_dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"prompt": [
|
||||
[{"role": "user", "content": "What color is the sky?"}],
|
||||
[{"role": "user", "content": "Where is the sun?"}],
|
||||
],
|
||||
"chosen": [
|
||||
[{"role": "assistant", "content": "It is blue."}],
|
||||
[{"role": "assistant", "content": "In the sky."}],
|
||||
],
|
||||
"rejected": [
|
||||
[{"role": "assistant", "content": "It is green."}],
|
||||
[{"role": "assistant", "content": "In the sea."}],
|
||||
],
|
||||
})
|
||||
|
||||
dataset = unpair_preference_dataset(dataset)
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}],
|
||||
'completion': [{'role': 'assistant', 'content': 'It is blue.'}],
|
||||
'label': True}
|
||||
```
|
||||
|
||||
### From unpaired preference to language modeling dataset
|
||||
|
||||
To convert an unpaired preference dataset into a language modeling dataset, concatenate the prompt and the completion into the `"text"` column, and remove the prompt, completion and label columns.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"prompt": ["The sky is", "The sun is", "The sky is", "The sun is"],
|
||||
"completion": [" blue.", " in the sky.", " green.", " in the sea."],
|
||||
"label": [True, True, False, False],
|
||||
})
|
||||
|
||||
def concatenate_prompt_completion(example):
|
||||
return {"text": example["prompt"] + example["completion"]}
|
||||
|
||||
dataset = dataset.map(concatenate_prompt_completion).remove_columns(["prompt", "completion", "label"])
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'text': 'The sky is blue.'}
|
||||
```
|
||||
|
||||
### From unpaired preference to prompt-completion dataset
|
||||
|
||||
To convert an unpaired preference dataset into a prompt-completion dataset, remove the label columns.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"prompt": ["The sky is", "The sun is", "The sky is", "The sun is"],
|
||||
"completion": [" blue.", " in the sky.", " green.", " in the sea."],
|
||||
"label": [True, True, False, False],
|
||||
})
|
||||
|
||||
dataset = dataset.remove_columns(["label"])
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'prompt': 'The sky is', 'completion': ' blue.'}
|
||||
```
|
||||
|
||||
### From unpaired preference to prompt-only dataset
|
||||
|
||||
To convert an unpaired preference dataset into a prompt-only dataset, remove the completion and the label columns.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"prompt": ["The sky is", "The sun is", "The sky is", "The sun is"],
|
||||
"completion": [" blue.", " in the sky.", " green.", " in the sea."],
|
||||
"label": [True, True, False, False],
|
||||
})
|
||||
|
||||
dataset = dataset.remove_columns(["completion", "label"])
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'prompt': 'The sky is'}
|
||||
```
|
||||
@ -1,128 +0,0 @@
|
||||
# Denoising Diffusion Policy Optimization
|
||||
## The why
|
||||
|
||||
| Before | After DDPO finetuning |
|
||||
| --- | --- |
|
||||
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_squirrel.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_squirrel.png"/></div> |
|
||||
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_crab.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_crab.png"/></div> |
|
||||
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_starfish.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_starfish.png"/></div> |
|
||||
|
||||
|
||||
## Getting started with Stable Diffusion finetuning with reinforcement learning
|
||||
|
||||
The machinery for finetuning of Stable Diffusion models with reinforcement learning makes heavy use of HuggingFace's `diffusers`
|
||||
library. A reason for stating this is that getting started requires a bit of familiarity with the `diffusers` library concepts, mainly two of them - pipelines and schedulers.
|
||||
Right out of the box (`diffusers` library), there isn't a `Pipeline` nor a `Scheduler` instance that is suitable for finetuning with reinforcement learning. Some adjustments need to made.
|
||||
|
||||
There is a pipeline interface that is provided by this library that is required to be implemented to be used with the `DDPOTrainer`, which is the main machinery for fine-tuning Stable Diffusion with reinforcement learning. **Note: Only the StableDiffusion architecture is supported at this point.**
|
||||
There is a default implementation of this interface that you can use out of the box. Assuming the default implementation is sufficient and/or to get things moving, refer to the training example alongside this guide.
|
||||
|
||||
The point of the interface is to fuse the pipeline and the scheduler into one object which allows for minimalness in terms of having the constraints all in one place. The interface was designed in hopes of catering to pipelines and schedulers beyond the examples in this repository and elsewhere at this time of writing. Also the scheduler step is a method of this pipeline interface and this may seem redundant given that the raw scheduler is accessible via the interface but this is the only way to constrain the scheduler step output to an output type befitting of the algorithm at hand (DDPO).
|
||||
|
||||
For a more detailed look into the interface and the associated default implementation, go [here](https://github.com/lvwerra/trl/tree/main/trl/models/modeling_sd_base.py)
|
||||
|
||||
Note that the default implementation has a LoRA implementation path and a non-LoRA based implementation path. The LoRA flag enabled by default and this can be turned off by passing in the flag to do so. LORA based training is faster and the LORA associated model hyperparameters responsible for model convergence aren't as finicky as non-LORA based training.
|
||||
|
||||
Also in addition, there is the expectation of providing a reward function and a prompt function. The reward function is used to evaluate the generated images and the prompt function is used to generate the prompts that are used to generate the images.
|
||||
|
||||
## Getting started with `examples/scripts/ddpo.py`
|
||||
|
||||
The `ddpo.py` script is a working example of using the `DDPO` trainer to finetune a Stable Diffusion model. This example explicitly configures a small subset of the overall parameters associated with the config object (`DDPOConfig`).
|
||||
|
||||
**Note:** one A100 GPU is recommended to get this running. Anything below a A100 will not be able to run this example script and even if it does via relatively smaller sized parameters, the results will most likely be poor.
|
||||
|
||||
Almost every configuration parameter has a default. There is only one commandline flag argument that is required of the user to get things up and running. The user is expected to have a [huggingface user access token](https://huggingface.co/docs/hub/security-tokens) that will be used to upload the model post finetuning to HuggingFace hub. The following bash command is to be entered to get things running
|
||||
|
||||
```batch
|
||||
python ddpo.py --hf_user_access_token <token>
|
||||
```
|
||||
|
||||
To obtain the documentation of `stable_diffusion_tuning.py`, please run `python stable_diffusion_tuning.py --help`
|
||||
|
||||
The following are things to keep in mind (The code checks this for you as well) in general while configuring the trainer (beyond the use case of using the example script)
|
||||
|
||||
- The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) should be greater than or equal to the configurable training batch size (`--ddpo_config.train_batch_size=3`)
|
||||
- The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) must be divisible by the configurable train batch size (`--ddpo_config.train_batch_size=3`)
|
||||
- The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) must be divisible by both the configurable gradient accumulation steps (`--ddpo_config.train_gradient_accumulation_steps=1`) and the configurable accelerator processes count
|
||||
|
||||
## Setting up the image logging hook function
|
||||
|
||||
Expect the function to be given a list of lists of the form
|
||||
```python
|
||||
[[image, prompt, prompt_metadata, rewards, reward_metadata], ...]
|
||||
|
||||
```
|
||||
and `image`, `prompt`, `prompt_metadata`, `rewards`, `reward_metadata` are batched.
|
||||
The last list in the lists of lists represents the last sample batch. You are likely to want to log this one
|
||||
While you are free to log however you want the use of `wandb` or `tensorboard` is recommended.
|
||||
|
||||
### Key terms
|
||||
|
||||
- `rewards` : The rewards/score is a numerical associated with the generated image and is key to steering the RL process
|
||||
- `reward_metadata` : The reward metadata is the metadata associated with the reward. Think of this as extra information payload delivered alongside the reward
|
||||
- `prompt` : The prompt is the text that is used to generate the image
|
||||
- `prompt_metadata` : The prompt metadata is the metadata associated with the prompt. A situation where this will not be empty is when the reward model comprises of a [`FLAVA`](https://huggingface.co/docs/transformers/model_doc/flava) setup where questions and ground answers (linked to the generated image) are expected with the generated image (See here: https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45)
|
||||
- `image` : The image generated by the Stable Diffusion model
|
||||
|
||||
Example code for logging sampled images with `wandb` is given below.
|
||||
|
||||
```python
|
||||
# for logging these images to wandb
|
||||
|
||||
def image_outputs_hook(image_data, global_step, accelerate_logger):
|
||||
# For the sake of this example, we only care about the last batch
|
||||
# hence we extract the last element of the list
|
||||
result = {}
|
||||
images, prompts, _, rewards, _ = image_data[-1]
|
||||
for i, image in enumerate(images):
|
||||
pil = Image.fromarray(
|
||||
(image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
|
||||
)
|
||||
pil = pil.resize((256, 256))
|
||||
result[f"{prompts[i]:.25} | {rewards[i]:.2f}"] = [pil]
|
||||
accelerate_logger.log_images(
|
||||
result,
|
||||
step=global_step,
|
||||
)
|
||||
|
||||
```
|
||||
|
||||
### Using the finetuned model
|
||||
|
||||
Assuming you've done with all the epochs and have pushed up your model to the hub, you can use the finetuned model as follows
|
||||
|
||||
```python
|
||||
|
||||
import torch
|
||||
from trl import DefaultDDPOStableDiffusionPipeline
|
||||
|
||||
pipeline = DefaultDDPOStableDiffusionPipeline("metric-space/ddpo-finetuned-sd-model")
|
||||
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
|
||||
# memory optimization
|
||||
pipeline.vae.to(device, torch.float16)
|
||||
pipeline.text_encoder.to(device, torch.float16)
|
||||
pipeline.unet.to(device, torch.float16)
|
||||
|
||||
prompts = ["squirrel", "crab", "starfish", "whale","sponge", "plankton"]
|
||||
results = pipeline(prompts)
|
||||
|
||||
for prompt, image in zip(prompts,results.images):
|
||||
image.save(f"{prompt}.png")
|
||||
|
||||
```
|
||||
|
||||
## Credits
|
||||
|
||||
This work is heavily influenced by the repo [here](https://github.com/kvablack/ddpo-pytorch) and the associated paper [Training Diffusion Models
|
||||
with Reinforcement Learning by Kevin Black, Michael Janner, Yilan Du, Ilya Kostrikov, Sergey Levine](https://huggingface.co/papers/2305.13301).
|
||||
|
||||
## DDPOTrainer
|
||||
|
||||
[[autodoc]] DDPOTrainer
|
||||
|
||||
## DDPOConfig
|
||||
|
||||
[[autodoc]] DDPOConfig
|
||||
|
||||
36
docs/source/deepspeed_integration.md
Normal file
36
docs/source/deepspeed_integration.md
Normal file
@ -0,0 +1,36 @@
|
||||
# DeepSpeed Integration
|
||||
|
||||
> [!WARNING]
|
||||
> Section under construction. Feel free to contribute!
|
||||
|
||||
TRL supports training with DeepSpeed, a library that implements advanced training optimization techniques. These include optimizer state partitioning, offloading, gradient partitioning, and more.
|
||||
|
||||
DeepSpeed integrates the [Zero Redundancy Optimizer (ZeRO)](https://huggingface.co/papers/1910.02054), which allows to scale the model size proportional to the number of devices with sustained high efficiency.
|
||||
|
||||

|
||||
|
||||
## Installation
|
||||
|
||||
To use DeepSpeed with TRL, install it using the following command:
|
||||
|
||||
```bash
|
||||
pip install deepspeed
|
||||
```
|
||||
|
||||
## Running Training Scripts with DeepSpeed
|
||||
|
||||
No modifications to your training script are required. Simply run it with the DeepSpeed configuration file:
|
||||
|
||||
```bash
|
||||
accelerate launch --config_file <ACCELERATE_WITH_DEEPSPEED_CONFIG_FILE.yaml> train.py
|
||||
```
|
||||
|
||||
We provide ready-to-use DeepSpeed configuration files in the [`examples/accelerate_configs`](https://github.com/huggingface/trl/tree/main/examples/accelerate_configs) directory. For example, to run training with ZeRO Stage 2, use the following command:
|
||||
|
||||
```bash
|
||||
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml train.py
|
||||
```
|
||||
|
||||
## Additional Resources
|
||||
|
||||
Consult the 🤗 Accelerate [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more information about the DeepSpeed plugin.
|
||||
@ -1,191 +0,0 @@
|
||||
# Detoxifying a Language Model using PPO
|
||||
|
||||
Language models (LMs) are known to sometimes generate toxic outputs. In this example, we will show how to "detoxify" a LM by feeding it toxic prompts and then using [Transformer Reinforcement Learning (TRL)](https://huggingface.co/docs/trl/index) and Proximal Policy Optimization (PPO) to "detoxify" it.
|
||||
|
||||
Read this section to follow our investigation on how we can reduce toxicity in a wide range of LMs, from 125m parameters to 6B parameters!
|
||||
|
||||
Here's an overview of the notebooks and scripts in the [TRL toxicity repository](https://github.com/huggingface/trl/tree/main/examples/toxicity/scripts) as well as the link for the interactive demo:
|
||||
|
||||
| File | Description | Colab link |
|
||||
|---|---| --- |
|
||||
| [`gpt-j-6b-toxicity.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py) | Detoxify `GPT-J-6B` using PPO | x |
|
||||
| [`evaluate-toxicity.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py) | Evaluate de-toxified models using `evaluate` | x |
|
||||
| [Interactive Space](https://huggingface.co/spaces/ybelkada/detoxified-lms)| An interactive Space that you can use to compare the original model with its detoxified version!| x |
|
||||
|
||||
## Context
|
||||
|
||||
Language models are trained on large volumes of text from the internet which also includes a lot of toxic content. Naturally, language models pick up the toxic patterns during training. Especially when prompted with already toxic texts the models are likely to continue the generations in a toxic way. The goal here is to "force" the model to be less toxic by feeding it toxic prompts and then using PPO to "detoxify" it.
|
||||
|
||||
### Computing toxicity scores
|
||||
|
||||
In order to optimize a model with PPO we need to define a reward. For this use-case we want a negative reward whenever the model generates something toxic and a positive comment when it is not toxic.
|
||||
Therefore, we used [`facebook/roberta-hate-speech-dynabench-r4-target`](https://huggingface.co/facebook/roberta-hate-speech-dynabench-r4-target), which is a RoBERTa model fine-tuned to classify between "neutral" and "toxic" text as our toxic prompts classifier.
|
||||
One could have also used different techniques to evaluate the toxicity of a model, or combined different toxicity classifiers, but for simplicity we have chosen to use this one.
|
||||
|
||||
### Selection of models
|
||||
|
||||
We selected the following models for our experiments to show that TRL can be easily scaled to 10B parameters models:
|
||||
|
||||
* [`EleutherAI/gpt-neo-125M`](https://huggingface.co/EleutherAI/gpt-neo-125M) (125 million parameters)
|
||||
* [`EleutherAI/gpt-neo-2.7B`](https://huggingface.co/EleutherAI/gpt-neo-2.7B) (2.7 billion parameters)
|
||||
* [`EleutherAI/gpt-j-6B`](https://huggingface.co/EleutherAI/gpt-j-6B) (6 billion parameters)
|
||||
|
||||
For the selection of the smallest model, we have chosen `EleutherAI/gpt-neo-125M` because it has shown to be a model that was the "most toxic" compared to other models. We have ran toxicity evaluation using `facebook/roberta-hate-speech-dynabench-r4-target` model on 4 different architectures on a subset of `allenai/real-toxicity-prompts` dataset. Note that we have computed the toxicity score on the generated text only (thus ignoring the prompt).
|
||||
|
||||
| Model | Mean toxicity score |
|
||||
|---|---|
|
||||
| `gpt2` | 0.01602 |
|
||||
| `facebook/opt-350m` | 0.01628 |
|
||||
| `bigscience/bloom-560m` | 0.00767 |
|
||||
| `EleutherAI/gpt-neo-125M` | **0.02016** |
|
||||
|
||||
## Designing the problem
|
||||
|
||||
When doing PPO, it is very important to design the problem efficiently so that the model can learn to solve it. Let's cover the topics that were important for the model to converge.
|
||||
|
||||
### Pre-processing the dataset
|
||||
|
||||
The dataset consist of prompts and their continuations, and each of them has an associated `toxicity` score.
|
||||
|
||||
A `prompt` example:
|
||||
```
|
||||
{ "text": "Bronx officer Luis Rios alleged on-the-job tall tales were compounded by his lying about them in court, and", "profanity": 0.08360514, "sexually_explicit": 0.118823394, "identity_attack": 0.09608547, "flirtation": 0.23102224, "threat": 0.13215046, "insult": 0.10130461, "severe_toxicity": 0.04068885, "toxicity": 0.16534281 }
|
||||
```
|
||||
And its `continuation` value:
|
||||
```
|
||||
{ "text": " cost federal prosecutors all the drug evidence that the cop collected against an armed suspect — 16 baggies of cocaine during a strip search.", "severe_toxicity": 0.067997746, "toxicity": 0.1694093, "profanity": 0.11931301, "sexually_explicit": 0.12521537, "identity_attack": 0.09268324, "flirtation": 0.13452998, "threat": 0.31312028, "insult": 0.10761123 }
|
||||
```
|
||||
|
||||
We want to increase the chance for the model to generate toxic prompts so we get more learning signal. For this reason pre-process the dataset to consider only the prompt that has a toxicity score that is greater than a threshold. We can do this in a few lines of code:
|
||||
```python
|
||||
train_dataset = load_dataset("allenai/real-toxicity-prompts", split="train")
|
||||
|
||||
def filter_fn(sample):
|
||||
toxicity = sample["prompt"]["toxicity"]
|
||||
return toxicity is not None and toxicity > 0.3
|
||||
|
||||
train_dataset = train_dataset.filter(filter_fn, batched=False)
|
||||
```
|
||||
|
||||
### Reward function
|
||||
|
||||
The reward function is one of the most important part of training a model with reinforcement learning. It is the function that will tell the model if it is doing well or not.
|
||||
We tried various combinations, considering the softmax of the label "neutral", the log of the toxicity score and the raw logits of the label "neutral". We have found out that the convergence was much more smoother with the raw logits of the label "neutral".
|
||||
```python
|
||||
logits = toxicity_model(**toxicity_inputs).logits.float()
|
||||
rewards = (logits[:, 0]).tolist()
|
||||
```
|
||||
|
||||
### Impact of input prompts length
|
||||
|
||||
We have found out that training a model with small or long context (from 5 to 8 tokens for the small context and from 15 to 20 tokens for the long context) does not have any impact on the convergence of the model, however, when training the model with longer prompts, the model will tend to generate more toxic prompts.
|
||||
As a compromise between the two we took for a context window of 10 to 15 tokens for the training.
|
||||
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-long-vs-short-context.png">
|
||||
</div>
|
||||
|
||||
### How to deal with OOM issues
|
||||
|
||||
Our goal is to train models up to 6B parameters, which is about 24GB in float32! Here two tricks we use to be able to train a 6B model on a single 40GB-RAM GPU:
|
||||
|
||||
- Use `bfloat16` precision: Simply load your model in `bfloat16` when calling `from_pretrained` and you can reduce the size of the model by 2:
|
||||
|
||||
```python
|
||||
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
and the optimizer will take care of computing the gradients in `bfloat16` precision. Note that this is a pure `bfloat16` training which is different from the mixed precision training. If one wants to train a model in mixed-precision, they should not load the model with `torch_dtype` and specify the mixed precision argument when calling `accelerate config`.
|
||||
|
||||
- Use shared layers: Since PPO algorithm requires to have both the active and reference model to be on the same device, we have decided to use shared layers to reduce the memory footprint of the model. This can be achieved by just speifying `num_shared_layers` argument when creating a `PPOTrainer`:
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-shared-layers.png">
|
||||
</div>
|
||||
|
||||
```python
|
||||
ppo_trainer = PPOTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
num_shared_layers=4,
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
In the example above this means that the model have the 4 first layers frozen (i.e. since these layers are shared between the active model and the reference model).
|
||||
|
||||
- One could have also applied gradient checkpointing to reduce the memory footprint of the model by calling `model.pretrained_model.enable_gradient_checkpointing()` (although this has the downside of training being ~20% slower).
|
||||
|
||||
## Training the model!
|
||||
|
||||
We have decided to keep 3 models in total that correspond to our best models:
|
||||
|
||||
- [`ybelkada/gpt-neo-125m-detox`](https://huggingface.co/ybelkada/gpt-neo-125m-detox)
|
||||
- [`ybelkada/gpt-neo-2.7B-detox`](https://huggingface.co/ybelkada/gpt-neo-2.7B-detox)
|
||||
- [`ybelkada/gpt-j-6b-detox`](https://huggingface.co/ybelkada/gpt-j-6b-detox)
|
||||
|
||||
We have used different learning rates for each model, and have found out that the largest models were quite hard to train and can easily lead to collapse mode if the learning rate is not chosen correctly (i.e. if the learning rate is too high):
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-collapse-mode.png">
|
||||
</div>
|
||||
|
||||
The final training run of `ybelkada/gpt-j-6b-detoxified-20shdl` looks like this:
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-gpt-j-final-run-2.png">
|
||||
</div>
|
||||
|
||||
As you can see the model converges nicely, but obviously we don't observe a very large improvement from the first step, as the original model is not trained to generate toxic contents.
|
||||
|
||||
Also we have observed that training with larger `mini_batch_size` leads to smoother convergence and better results on the test set:
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-gpt-j-mbs-run.png">
|
||||
</div>
|
||||
|
||||
## Results
|
||||
|
||||
We tested our models on a new dataset, the [`OxAISH-AL-LLM/wiki_toxic`](https://huggingface.co/datasets/OxAISH-AL-LLM/wiki_toxic) dataset. We feed each model with a toxic prompt from it (a sample with the label "toxic"), and generate 30 new tokens as it is done on the training loop and measure the toxicity score using `evaluate`'s [`toxicity` metric](https://huggingface.co/spaces/ybelkada/toxicity).
|
||||
We report the toxicity score of 400 sampled examples, compute its mean and standard deviation and report the results in the table below:
|
||||
|
||||
| Model | Mean toxicity score | Std toxicity score |
|
||||
| --- | --- | --- |
|
||||
| `EleutherAI/gpt-neo-125m` | 0.1627 | 0.2997 |
|
||||
| `ybelkada/gpt-neo-125m-detox` | **0.1148** | **0.2506** |
|
||||
| --- | --- | --- |
|
||||
| `EleutherAI/gpt-neo-2.7B` | 0.1884 | 0.3178 |
|
||||
| `ybelkada/gpt-neo-2.7B-detox` | **0.0916** | **0.2104** |
|
||||
| --- | --- | --- |
|
||||
| `EleutherAI/gpt-j-6B` | 0.1699 | 0.3033 |
|
||||
| `ybelkada/gpt-j-6b-detox` | **0.1510** | **0.2798** |
|
||||
|
||||
<div class="column" style="text-align:center">
|
||||
<figure>
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-final-barplot.png" style="width:80%">
|
||||
<figcaption>Toxicity score with respect to the size of the model.</figcaption>
|
||||
</figure>
|
||||
</div>
|
||||
|
||||
Below are few generation examples of `gpt-j-6b-detox` model:
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-toxicity-examples.png">
|
||||
</div>
|
||||
|
||||
The evaluation script can be found [here](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py).
|
||||
|
||||
### Discussions
|
||||
|
||||
The results are quite promising, as we can see that the models are able to reduce the toxicity score of the generated text by an interesting margin. The gap is clear for `gpt-neo-2B` model but we less so for the `gpt-j-6B` model. There are several things we could try to improve the results on the largest model starting with training with larger `mini_batch_size` and probably allowing to back-propagate through more layers (i.e. use less shared layers).
|
||||
|
||||
To sum up, in addition to human feedback this could be a useful additional signal when training large language models to ensure there outputs are less toxic as well as useful.
|
||||
|
||||
### Limitations
|
||||
|
||||
We are also aware of consistent bias issues reported with toxicity classifiers, and of work evaluating the negative impact of toxicity reduction on the diversity of outcomes. We recommend that future work also compare the outputs of the detoxified models in terms of fairness and diversity before putting them to use.
|
||||
|
||||
## What is next?
|
||||
|
||||
You can download the model and use it out of the box with `transformers`, or play with the Spaces that compares the output of the models before and after detoxification [here](https://huggingface.co/spaces/ybelkada/detoxified-lms).
|
||||
190
docs/source/distributing_training.md
Normal file
190
docs/source/distributing_training.md
Normal file
@ -0,0 +1,190 @@
|
||||
# Distributing Training
|
||||
|
||||
> [!WARNING]
|
||||
> Section under construction. Feel free to contribute!
|
||||
|
||||
## Multi-GPU Training with TRL
|
||||
|
||||
The trainers in TRL use [🤗 Accelerate](https://github.com/huggingface/accelerate) to enable distributed training across multiple GPUs or nodes. To do so, first create an [🤗 Accelerate](https://github.com/huggingface/accelerate) config file by running
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
and answering the questions according to your multi-GPU / multi-node setup. You can then launch distributed training by running:
|
||||
|
||||
```bash
|
||||
accelerate launch train.py
|
||||
```
|
||||
|
||||
We also provide config files in the [examples folder](https://github.com/huggingface/trl/tree/main/examples/accelerate_configs) that can be used as templates. To use these templates, simply pass the path to the config file when launching a job, e.g.:
|
||||
|
||||
```shell
|
||||
accelerate launch --config_file examples/accelerate_configs/multi_gpu.yaml train.py <SCRIPT_ARGS>
|
||||
```
|
||||
|
||||
This automatically distributes the workload across all available GPUs.
|
||||
|
||||
Under the hood, [🤗 Accelerate](https://github.com/huggingface/accelerate) creates one model per GPU. Each process:
|
||||
|
||||
- Processes its own batch of data
|
||||
- Computes the loss and gradients for that batch
|
||||
- Shares gradient updates across all GPUs
|
||||
|
||||

|
||||
|
||||
The effective batch size is calculated as:
|
||||
|
||||
$$
|
||||
\text{Batch Size} = \text{per\_device\_train\_batch\_size} \times \text{num\_devices} \times \text{gradient\_accumulation\_steps}
|
||||
$$
|
||||
|
||||
To maintain a consistent batch size when scaling to multiple GPUs, make sure to update `per_device_train_batch_size` and `gradient_accumulation_steps` accordingly.
|
||||
|
||||
Example, these configurations are equivalent, and should yield the same results:
|
||||
|
||||
| Number of GPUs | Per device batch size | Gradient accumulation steps | Comments |
|
||||
| --- | --- | --- | --- |
|
||||
| 1 | 32 | 1 | Possibly high memory usage, but faster training |
|
||||
| 1 | 4 | 8 | Lower memory usage, slower training |
|
||||
| 8 | 4 | 1 | Multi-GPU to get the best of both worlds |
|
||||
|
||||
> [!TIP]
|
||||
> Having one model per GPU can lead to high memory usage, which may not be feasible for large models or low-memory GPUs. In such cases, you can leverage [DeepSpeed](https://github.com/deepspeedai/DeepSpeed), which provides optimizations like model sharding, Zero Redundancy Optimizer, mixed precision training, and offloading to CPU or NVMe. Check out our [DeepSpeed Integration](deepspeed_integration) guide for more details.
|
||||
|
||||
## Context Parallelism
|
||||
|
||||
Context Parallelism (CP) is a parallelization technique that enables training with longer sequences by splitting the sequence dimension across multiple GPUs. Each GPU processes a portion of the sequence, allowing you to train with sequences longer than what would fit on a single GPU's memory.
|
||||
|
||||
For more details on CP, see the [Ultrascale Playbook - Context Parallelism](https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=context_parallelism).
|
||||
|
||||
CP is particularly useful when:
|
||||
|
||||
- You want to train with very long sequences (>32k tokens)
|
||||
- Single GPU memory is insufficient for your desired sequence length
|
||||
- You need to maintain sequence coherence across the full context
|
||||
|
||||
### Requirements and Limitations
|
||||
|
||||
CP has specific requirements:
|
||||
|
||||
1. **Accelerate 1.10 or higher** is required
|
||||
2. **FSDP2 (PyTorch FSDP v2)** is required as the distributed training backend
|
||||
3. **SDPA attention** - Flash Attention is currently not supported with CP
|
||||
4. **Sequence length divisibility** - sequences must be divisible by `cp_size * 2`. This is now automatically handled using the `pad_to_multiple_of` parameter in the data collator, which works seamlessly with both standard and padding-free modes.
|
||||
|
||||
### Configuration
|
||||
|
||||
To enable CP, you need to configure both Accelerate and your training arguments:
|
||||
|
||||
#### Accelerate Configuration
|
||||
|
||||
Use one of the provided accelerate config files (e.g. [`context_parallel_2gpu.yaml`](https://github.com/huggingface/trl/blob/main/examples/accelerate_configs/context_parallel_2gpu.yaml) for 2 GPUs):
|
||||
|
||||
```yaml
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
distributed_type: FSDP
|
||||
downcast_bf16: 'no'
|
||||
enable_cpu_affinity: false
|
||||
fsdp_config:
|
||||
fsdp_activation_checkpointing: true # Enable activation checkpointing for memory efficiency
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_cpu_ram_efficient_loading: true
|
||||
fsdp_offload_params: false
|
||||
fsdp_reshard_after_forward: true
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
fsdp_version: 2
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 2 # Number of GPUs
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
parallelism_config:
|
||||
parallelism_config_dp_replicate_size: 1
|
||||
parallelism_config_dp_shard_size: 1
|
||||
parallelism_config_tp_size: 1
|
||||
parallelism_config_cp_size: 2 # Context parallel size
|
||||
```
|
||||
|
||||
#### Training Configuration
|
||||
|
||||
```python
|
||||
from trl import SFTConfig
|
||||
|
||||
training_args = SFTConfig(
|
||||
# required
|
||||
pad_to_multiple_of=4, # ensures divisibility by cp_size * 2
|
||||
# to get the most out of CP
|
||||
max_length=16384, # long sequence length
|
||||
packing=True, # use packing to reduce padding
|
||||
use_liger_kernel=True, # compatible with CP
|
||||
gradient_checkpointing=False, # The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg can't be set to True simultaneously
|
||||
per_device_train_batch_size=1,
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
Then, launch your training script with the appropriate accelerate config file:
|
||||
|
||||
```bash
|
||||
accelerate launch --config_file context_parallel_2gpu.yaml train.py
|
||||
```
|
||||
|
||||
### Best Practices
|
||||
|
||||
1. **Use the `pad_to_multiple_of` parameter** - This is now the recommended way to ensure sequence length divisibility:
|
||||
- For `cp_size=2`: use `pad_to_multiple_of=4` (since `cp_size * 2 = 4`)
|
||||
- For `cp_size=4`: use `pad_to_multiple_of=8` (since `cp_size * 2 = 8`)
|
||||
- The data collator automatically pads sequences to the required multiple, ensuring compatibility with CP
|
||||
|
||||
2. **Use packing with padding** - The default BFD (Best Fit Decreasing) strategy works perfectly:
|
||||
- Preserves sequence boundaries and maintains training quality
|
||||
- Works seamlessly with both `padding_free=True` and standard padding modes
|
||||
|
||||
3. **Combine with other memory optimizations** like Liger kernels, bfloat16, and gradient checkpointing
|
||||
|
||||
4. **Start with smaller context parallel sizes** (2-4 GPUs) before scaling up
|
||||
|
||||
5. **Monitor memory usage** across all GPUs to ensure balanced workload
|
||||
|
||||
### Benchmarking Context Parallelism
|
||||
|
||||
We benchmarked CP to highlight its potential improvements in training efficiency.
|
||||
Our experiments were conducted using **1, 2, 4, and 8 H100 GPUs**, though the results can be extended to larger clusters with more nodes and GPUs.
|
||||
|
||||
For the setup, we fine-tuned an **8B model** ([Qwen/Qwen3-8B](https://huggingface.co/Qwen/Qwen3-8B)) using the provided accelerate configuration
|
||||
([`context_parallel_2gpu.yaml`](https://github.com/huggingface/trl/blob/main/examples/accelerate_configs/context_parallel_2gpu.yaml)).
|
||||
We adjusted `num_processes` and `parallelism_config_cp_size` based on the number of GPUs for each run.
|
||||
Training was performed with the [sft.py](https://github.com/huggingface/trl/blob/main/trl/scripts/sft.py) example script, combined with the parameters described above.
|
||||
|
||||
The results below summarize the **maximum trainable sequence length** and **iterations per second** for different numbers of GPUs. A value marked as `OOM` indicates that the configuration ran out of memory and could not be trained.
|
||||
|
||||
These results show that **Context Parallelism (CP) scales effectively with more GPUs**, enabling training on much longer sequences. With **8 GPUs**, context lengths of over **300k tokens** become feasible, unlocking training with extremely long contexts while maintaining reasonable throughput.
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/context_parallelism_max_length_plot.png" alt="CP Max content length" width="45%"/>
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/context_parallelism_s_it_plot.png" alt="CP seconds/iteration" width="45%"/>
|
||||
</div>
|
||||
|
||||
> [!TIP]
|
||||
> Accelerate also supports **N-Dimensional Parallelism (ND-parallelism)**, which enables you to combine different parallelization strategies to efficiently distribute model training across multiple GPUs.
|
||||
>
|
||||
> You can learn more and explore configuration examples in the [Accelerate ND-parallelism guide](https://github.com/huggingface/accelerate/blob/main/examples/torch_native_parallelism/README.md#nd-parallelism).
|
||||
|
||||
### Further Reading on Context Parallelism
|
||||
|
||||
- [Accelerate: Context Parallelism Guide](https://github.com/huggingface/accelerate/blob/main/docs/source/concept_guides/context_parallelism.md)
|
||||
- [Accelerate Example: 128k Sequence Length](https://github.com/huggingface/accelerate/blob/main/examples/torch_native_parallelism/README.md#context-parallelism-128k-sequence-length)
|
||||
- [Hugging Face Blog: Enabling Long-Context Training with Sequence Parallelism in Axolotl](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl)
|
||||
- [Snowflake Engineering Blog: Arctic Long Sequence Training (ALST) — Scalable and Efficient Training for Multi-Million Token Sequences (Note that they use a different strategy)](https://www.snowflake.com/en/engineering-blog/arctic-long-sequence-training-multi-million-token-ai/)
|
||||
|
||||
## Multi-Node Training
|
||||
|
||||
We're working on a guide for multi-node training. Stay tuned! 🚀
|
||||
300
docs/source/dpo_trainer.md
Normal file
300
docs/source/dpo_trainer.md
Normal file
@ -0,0 +1,300 @@
|
||||
# DPO Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=dpo,trl) [](https://github.com/huggingface/smol-course/tree/main/2_preference_alignment)
|
||||
|
||||
## Overview
|
||||
|
||||
TRL supports the DPO Trainer for training language models from preference data, as described in the paper [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://huggingface.co/papers/2305.18290) by [Rafael Rafailov](https://huggingface.co/rmrafailov), Archit Sharma, Eric Mitchell, [Stefano Ermon](https://huggingface.co/ermonste), [Christopher D. Manning](https://huggingface.co/manning), [Chelsea Finn](https://huggingface.co/cbfinn).
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
> While large-scale unsupervised language models (LMs) learn broad world knowledge and some reasoning skills, achieving precise control of their behavior is difficult due to the completely unsupervised nature of their training. Existing methods for gaining such steerability collect human labels of the relative quality of model generations and fine-tune the unsupervised LM to align with these preferences, often with reinforcement learning from human feedback (RLHF). However, RLHF is a complex and often unstable procedure, first fitting a reward model that reflects the human preferences, and then fine-tuning the large unsupervised LM using reinforcement learning to maximize this estimated reward without drifting too far from the original model. In this paper we introduce a new parameterization of the reward model in RLHF that enables extraction of the corresponding optimal policy in closed form, allowing us to solve the standard RLHF problem with only a simple classification loss. The resulting algorithm, which we call Direct Preference Optimization (DPO), is stable, performant, and computationally lightweight, eliminating the need for sampling from the LM during fine-tuning or performing significant hyperparameter tuning. Our experiments show that DPO can fine-tune LMs to align with human preferences as well as or better than existing methods. Notably, fine-tuning with DPO exceeds PPO-based RLHF in ability to control sentiment of generations, and matches or improves response quality in summarization and single-turn dialogue while being substantially simpler to implement and train.
|
||||
|
||||
The first step is to train an SFT model, to ensure the data we train on is in-distribution for the DPO algorithm.
|
||||
|
||||
Then, fine-tuning a language model via DPO consists of two steps and is easier than [PPO](ppo_trainer):
|
||||
|
||||
1. **Data collection**: Gather a [preference dataset](dataset_formats#preference) with positive and negative selected pairs of generation, given a prompt.
|
||||
2. **Optimization**: Maximize the log-likelihood of the DPO loss directly.
|
||||
|
||||
This process is illustrated in the sketch below (from [Figure 1 of the DPO paper](https://huggingface.co/papers/2305.18290)):
|
||||
|
||||

|
||||
|
||||
Read more about DPO algorithm in the [original paper](https://huggingface.co/papers/2305.18290).
|
||||
|
||||
## Quick start
|
||||
|
||||
This example demonstrates how to train a model using the DPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model. We use the preference data from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the data in the dataset here:
|
||||
|
||||
<iframe
|
||||
src="https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized/embed/viewer/default/train?row=0"
|
||||
frameborder="0"
|
||||
width="100%"
|
||||
height="560px"
|
||||
></iframe>
|
||||
|
||||
Below is the script to train the model:
|
||||
|
||||
```python
|
||||
# train_dpo.py
|
||||
from datasets import load_dataset
|
||||
from trl import DPOConfig, DPOTrainer
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
||||
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
||||
|
||||
training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO")
|
||||
trainer = DPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
Execute the script using the following command:
|
||||
|
||||
```bash
|
||||
accelerate launch train_dpo.py
|
||||
```
|
||||
|
||||
Distributed across 8 GPUs, the training takes approximately 3 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time.
|
||||
|
||||

|
||||
|
||||
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-DPO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
|
||||
|
||||
<pre><code>$ transformers chat trl-lib/Qwen2-0.5B-DPO
|
||||
<strong><span style="color: red;"><shirin_yamani>:</span></strong>
|
||||
What is Huggingface?
|
||||
|
||||
<strong><span style="color: blue;"><trl-lib/Qwen2-0.5B-DPO>:</span></strong>
|
||||
Huggingface is a platform that allows users to access a variety of open-source machine learning resources such as pre-trained models and datasets Huggingface is a platform that allows users to access a variety of open-source machine learning resources such as pre-trained models and datasets for the development of machine learning models and applications. It provides a repository of over 300, 000 pre-trained models in Huggingface is a platform that allows users to access a variety of open-source machine learning resources such as pre-trained models and datasets for the development of machine learning models and applications. It provides a repository of over 300, 000 pre-trained models in a variety of languages, enabling users to explore and utilize the latest techniques and technologies in the field of machine learning.
|
||||
</code></pre>
|
||||
|
||||
## Expected dataset type
|
||||
|
||||
DPO requires a [preference dataset](dataset_formats#preference). The [`DPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
|
||||
|
||||
Although the [`DPOTrainer`] supports both explicit and implicit prompts, we recommend using explicit prompts. If provided with an implicit prompt dataset, the trainer will automatically extract the prompt from the `"chosen"` and `"rejected"` columns. For more information, refer to the [preference style](dataset_formats#preference) section.
|
||||
|
||||
### Special considerations for vision-language models
|
||||
|
||||
The [`DPOTrainer`] supports fine-tuning vision-language models (VLMs). For these models, a vision dataset is required. To learn more about the specific format for vision datasets, refer to the [Vision dataset format](dataset_formats#vision-datasets) section.
|
||||
|
||||
Additionally, unlike standard text-based models where a `tokenizer` is used, for VLMs, you should replace the `tokenizer` with a `processor`.
|
||||
|
||||
```diff
|
||||
- model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
+ model = AutoModelForImageTextToText.from_pretrained(model_id)
|
||||
|
||||
- tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
+ processor = AutoProcessor.from_pretrained(model_id)
|
||||
|
||||
trainer = DPOTrainer(
|
||||
model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
- processing_class=tokenizer,
|
||||
+ processing_class=processor,
|
||||
)
|
||||
```
|
||||
|
||||
For a complete example of fine-tuning a vision-language model, refer to the script in [`examples/scripts/dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_vlm.py).
|
||||
|
||||
## Example script
|
||||
|
||||
We provide an example script to train a model using the DPO method. The script is available in [`trl/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/dpo.py)
|
||||
|
||||
To test the DPO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), run the following command:
|
||||
|
||||
```bash
|
||||
accelerate launch trl/scripts/dpo.py \
|
||||
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
|
||||
--dataset_name trl-lib/ultrafeedback_binarized \
|
||||
--num_train_epochs 1 \
|
||||
--output_dir Qwen2-0.5B-DPO
|
||||
```
|
||||
|
||||
## Logged metrics
|
||||
|
||||
While training and evaluating, we record the following reward metrics:
|
||||
|
||||
- `rewards/chosen`: the mean difference between the log probabilities of the policy model and the reference model for the chosen responses scaled by beta
|
||||
- `rewards/rejected`: the mean difference between the log probabilities of the policy model and the reference model for the rejected responses scaled by beta
|
||||
- `rewards/accuracies`: mean of how often the chosen rewards are > than the corresponding rejected rewards
|
||||
- `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
|
||||
|
||||
## Loss functions
|
||||
|
||||
The DPO algorithm supports several loss functions. The loss function can be set using the `loss_type` parameter in the [`DPOConfig`]. The following loss functions are supported:
|
||||
|
||||
| `loss_type=` | Description |
|
||||
| --- | --- |
|
||||
| `"sigmoid"` (default) | Given the preference data, we can fit a binary classifier according to the Bradley-Terry model and in fact the [DPO](https://huggingface.co/papers/2305.18290) authors propose the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression. |
|
||||
| `"hinge"` | The [RSO](https://huggingface.co/papers/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. In this case, the `beta` is the reciprocal of the margin. |
|
||||
| `"ipo"` | The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting and propose an alternative loss. In this case, the `beta` is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair and thus the smaller the `beta` the larger this gaps is. As per the paper the loss is averaged over log-likelihoods of the completion (unlike DPO which is summed only). |
|
||||
| `"exo_pair"` | The [EXO](https://huggingface.co/papers/2402.00856) authors propose to minimize the reverse KL instead of the negative log-sigmoid loss of DPO which corresponds to forward KL. Setting non-zero `label_smoothing` (default `1e-3`) leads to a simplified version of EXO on pair-wise preferences (see Eqn. (16) of the [EXO paper](https://huggingface.co/papers/2402.00856)). The full version of EXO uses `K>2` completions generated by the SFT policy, which becomes an unbiased estimator of the PPO objective (up to a constant) when `K` is sufficiently large. |
|
||||
| `"nca_pair"` | The [NCA](https://huggingface.co/papers/2402.05369) authors shows that NCA optimizes the absolute likelihood for each response rather than the relative likelihood. |
|
||||
| `"robust"` | The [Robust DPO](https://huggingface.co/papers/2403.00409) authors propose an unbiased estimate of the DPO loss that is robust to preference noise in the data. Like in cDPO, it assumes that the preference labels are noisy with some probability. In this approach, the `label_smoothing` parameter in the [`DPOConfig`] is used to model the probability of existing label noise. To apply this conservative loss, set `label_smoothing` to a value greater than 0.0 (between 0.0 and 0.5; the default is 0.0) |
|
||||
| `"bco_pair"` | The [BCO](https://huggingface.co/papers/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0. For unpaired data, we recommend the dedicated [`BCOTrainer`]. |
|
||||
| `"sppo_hard"` | The [SPPO](https://huggingface.co/papers/2405.00675) authors claim that SPPO is capable of solving the Nash equilibrium iteratively by pushing the chosen rewards to be as large as 1/2 and the rejected rewards to be as small as -1/2 and can alleviate data sparsity issues. The implementation approximates this algorithm by employing hard label probabilities, assigning 1 to the winner and 0 to the loser. |
|
||||
| `"aot"` or `loss_type="aot_pair"` | The [AOT](https://huggingface.co/papers/2406.05882) authors propose to use Distributional Preference Alignment Via Optimal Transport. Traditionally, the alignment algorithms use paired preferences at a sample level, which does not ensure alignment on the distributional level. AOT, on the other hand, can align LLMs on paired or unpaired preference data by making the reward distribution of the positive samples stochastically dominant in the first order on the distribution of negative samples. Specifically, `loss_type="aot"` is appropriate for paired datasets, where each prompt has both chosen and rejected responses; `loss_type="aot_pair"` is for unpaired datasets. In a nutshell, `loss_type="aot"` ensures that the log-likelihood ratio of chosen to rejected of the aligned model has higher quantiles than that ratio for the reference model. `loss_type="aot_pair"` ensures that the chosen reward is higher on all quantiles than the rejected reward. Note that in both cases quantiles are obtained via sorting. To fully leverage the advantages of the AOT algorithm, it is important to maximize the per-GPU batch size. |
|
||||
| `"apo_zero"` or `loss_type="apo_down"` | The [APO](https://huggingface.co/papers/2408.06266) method introduces an "anchored" version of the alignment objective. There are two variants: `apo_zero` and `apo_down`. The `apo_zero` loss increases the likelihood of winning outputs while decreasing the likelihood of losing outputs, making it suitable when the model is less performant than the winning outputs. On the other hand, `apo_down` decreases the likelihood of both winning and losing outputs, but with a stronger emphasis on reducing the likelihood of losing outputs. This variant is more effective when the model is better than the winning outputs. |
|
||||
| `"discopop"` | The [DiscoPOP](https://huggingface.co/papers/2406.08414) paper uses LLMs to discover more efficient offline preference optimization losses. In the paper the proposed DiscoPOP loss (which is a log-ratio modulated loss) outperformed other optimization losses on different tasks (IMDb positive text generation, Reddit TLDR summarization, and Alpaca Eval 2.0). |
|
||||
| `"sft"` | SFT (Supervised Fine-Tuning) loss is the negative log likelihood loss, used to train the model to generate preferred responses. |
|
||||
|
||||
### Multi-loss combinations
|
||||
|
||||
The DPO trainer supports combining multiple loss functions with different weights, enabling more sophisticated optimization strategies. This is particularly useful for implementing algorithms like MPO (Mixed Preference Optimization). MPO is a training approach that combines multiple optimization objectives, as described in the paper [Enhancing the Reasoning Ability of Multimodal Large Language Models via Mixed Preference Optimization](https://huggingface.co/papers/2411.10442).
|
||||
|
||||
To combine multiple losses, specify the loss types and corresponding weights as lists:
|
||||
|
||||
```python
|
||||
# MPO: Combines DPO (sigmoid) for preference and BCO (bco_pair) for quality
|
||||
training_args = DPOConfig(
|
||||
loss_type=["sigmoid", "bco_pair", "sft"], # Loss types to combine
|
||||
loss_weights=[0.8, 0.2, 1.0] # Corresponding weights, as used in the MPO paper
|
||||
)
|
||||
```
|
||||
|
||||
If `loss_weights` is not provided, all loss types will have equal weights (1.0 by default).
|
||||
|
||||
### Label smoothing
|
||||
|
||||
The [cDPO](https://ericmitchell.ai/cdpo.pdf) is a tweak on the DPO loss where we assume that the preference labels are noisy with some probability. In this approach, the `label_smoothing` parameter in the [`DPOConfig`] is used to model the probability of existing label noise. To apply this conservative loss, set `label_smoothing` to a value greater than 0.0 (between 0.0 and 0.5; the default is 0.0).
|
||||
|
||||
### Syncing the reference model
|
||||
|
||||
The [TR-DPO](https://huggingface.co/papers/2404.09656) paper suggests syncing the reference model weights after every `ref_model_sync_steps` steps of SGD with weight `ref_model_mixup_alpha` during DPO training. To toggle this callback use the `sync_ref_model=True` in the [`DPOConfig`].
|
||||
|
||||
### RPO loss
|
||||
|
||||
The [RPO](https://huggingface.co/papers/2404.19733) paper implements an iterative preference tuning algorithm using a loss related to the RPO loss in this [paper](https://huggingface.co/papers/2405.16436) that essentially consists of a weighted SFT loss on the chosen preferences together with the DPO loss. To use this loss, set the `rpo_alpha` in the [`DPOConfig`] to an appropriate value. The paper suggests setting this weight to `1.0`.
|
||||
|
||||
### WPO loss
|
||||
|
||||
The [WPO](https://huggingface.co/papers/2406.11827) paper adapts off-policy data to resemble on-policy data more closely by reweighting preference pairs according to their probability under the current policy. To use this method, set the `use_weighting` flag to `True` in the [`DPOConfig`].
|
||||
|
||||
### LD-DPO loss
|
||||
|
||||
The [LD-DPO](https://huggingface.co/papers/2409.06411) paper decomposes the portion of the response that exceeds the desired length into two components — human-like preferences and verbosity preference — based on a mixing coefficient \\( \alpha \\). To use this method, set the `ld_alpha` in the [`DPOConfig`] to an appropriate value. The paper suggests setting this value between `0.0` and `1.0`.
|
||||
|
||||
### For Mixture of Experts Models: Enabling the auxiliary loss
|
||||
|
||||
MOEs are the most efficient if the load is about equally distributed between experts.
|
||||
To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
|
||||
|
||||
This option is enabled by setting `output_router_logits=True` in the model config (e.g. [`~transformers.MixtralConfig`]).
|
||||
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config.
|
||||
|
||||
## Accelerate DPO fine-tuning using `unsloth`
|
||||
|
||||
You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is fully compatible with `SFTTrainer`. Currently `unsloth` supports only Llama (Yi, TinyLlama, Qwen, Deepseek etc) and Mistral architectures. Some benchmarks for DPO listed below:
|
||||
|
||||
| GPU | Model | Dataset | 🤗 | 🤗 + FlashAttention 2 | 🦥 Unsloth | 🦥 VRAM saved |
|
||||
| --- | --- | --- | --- | --- | --- | --- |
|
||||
| A100 40G | Zephyr 7b | Ultra Chat | 1x | 1.24x | **1.88x** | -11.6% |
|
||||
| Tesla T4 | Zephyr 7b | Ultra Chat | 1x | 1.09x | **1.55x** | -18.6% |
|
||||
|
||||
First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLanguageModel` as follows:
|
||||
|
||||
```diff
|
||||
from datasets import load_dataset
|
||||
from trl import DPOConfig, DPOTrainer
|
||||
- from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
+ from unsloth import FastLanguageModel
|
||||
|
||||
- model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
||||
- tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
||||
+ model, tokenizer = FastLanguageModel.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
||||
+ model = FastLanguageModel.get_peft_model(model)
|
||||
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
||||
|
||||
- training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO")
|
||||
+ training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO", bf16=True)
|
||||
trainer = DPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
|
||||
trainer.train()
|
||||
|
||||
```
|
||||
|
||||
The saved model is fully compatible with Hugging Face's transformers library. Learn more about unsloth in their [official repository](https://github.com/unslothai/unsloth).
|
||||
|
||||
## Reference model considerations with PEFT
|
||||
|
||||
You have three main options (plus several variants) for how the reference model works when using PEFT, assuming the model that you would like to further enhance with DPO was tuned using (Q)LoRA.
|
||||
|
||||
1. Simply create two instances of the model, each loading your adapter - works fine but is very inefficient.
|
||||
2. Merge the adapter into the base model, create another adapter on top, then leave the `ref_model` param null, in which case DPOTrainer will unload the adapter for reference inference - efficient, but has potential downsides discussed below.
|
||||
3. Load the adapter twice with different names, then use `set_adapter` during training to swap between the adapter being DPO'd and the reference adapter - slightly less efficient compared to 2 (~adapter size VRAM overhead), but avoids the pitfalls.
|
||||
|
||||
### Downsides to merging QLoRA before DPO (approach 2)
|
||||
|
||||
As suggested by [Benjamin Marie](https://medium.com/@bnjmn_marie/dont-merge-your-lora-adapter-into-a-4-bit-llm-65b6da287997), the best option for merging QLoRA adapters is to first dequantize the base model, then merge the adapter. Something similar to [this script](https://github.com/jondurbin/qlora/blob/main/qmerge.py).
|
||||
|
||||
However, after using this approach, you will have an unquantized base model. Therefore, to use QLoRA for DPO, you will need to re-quantize the merged model or use the unquantized merge (resulting in higher memory demand).
|
||||
|
||||
### Using option 3 - load the adapter twice
|
||||
|
||||
To avoid the downsides with option 2, you can load your fine-tuned adapter into the model twice, with different names, and set the model/ref adapter names in [`DPOTrainer`].
|
||||
|
||||
For example:
|
||||
|
||||
```python
|
||||
# Load the base model.
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
llm_int8_threshold=6.0,
|
||||
llm_int8_has_fp16_weight=False,
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"mistralai/mixtral-8x7b-v0.1",
|
||||
load_in_4bit=True,
|
||||
quantization_config=bnb_config,
|
||||
attn_implementation="flash_attention_2",
|
||||
dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
)
|
||||
|
||||
# Load the adapter.
|
||||
model = PeftModel.from_pretrained(
|
||||
model,
|
||||
"/path/to/peft",
|
||||
is_trainable=True,
|
||||
adapter_name="train",
|
||||
)
|
||||
# Load the adapter a second time, with a different name, which will be our reference model.
|
||||
model.load_adapter("/path/to/peft", adapter_name="reference")
|
||||
|
||||
# Initialize the trainer, without a ref_model param.
|
||||
training_args = DPOConfig(
|
||||
model_adapter_name="train",
|
||||
ref_adapter_name="reference",
|
||||
)
|
||||
dpo_trainer = DPOTrainer(
|
||||
model,
|
||||
args=training_args,
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
## DPOTrainer
|
||||
|
||||
[[autodoc]] DPOTrainer
|
||||
- train
|
||||
- save_model
|
||||
- push_to_hub
|
||||
|
||||
## DPOConfig
|
||||
|
||||
[[autodoc]] DPOConfig
|
||||
|
||||
## DataCollatorForPreference
|
||||
|
||||
[[autodoc]] trainer.dpo_trainer.DataCollatorForPreference
|
||||
|
||||
## FDivergenceType
|
||||
|
||||
[[autodoc]] trainer.dpo_trainer.FDivergenceType
|
||||
@ -1,297 +0,0 @@
|
||||
# DPO Trainer
|
||||
|
||||
TRL supports the DPO Trainer for training language models from preference data, as described in the paper [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://huggingface.co/papers/2305.18290) by Rafailov et al., 2023. For a full example have a look at [`examples/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo.py).
|
||||
|
||||
The first step as always is to train your SFT model, to ensure the data we train on is in-distribution for the DPO algorithm.
|
||||
|
||||
## How DPO works
|
||||
|
||||
Fine-tuning a language model via DPO consists of two steps and is easier than PPO:
|
||||
|
||||
1. **Data collection**: Gather a preference dataset with positive and negative selected pairs of generation, given a prompt.
|
||||
2. **Optimization**: Maximize the log-likelihood of the DPO loss directly.
|
||||
|
||||
DPO-compatible datasets can be found with [the tag `dpo` on Hugging Face Hub](https://huggingface.co/datasets?other=dpo). You can also explore the [librarian-bots/direct-preference-optimization-datasets](https://huggingface.co/collections/librarian-bots/direct-preference-optimization-datasets-66964b12835f46289b6ef2fc) Collection to identify datasets that are likely to support DPO training.
|
||||
|
||||
This process is illustrated in the sketch below (from [figure 1 of the original paper](https://huggingface.co/papers/2305.18290)):
|
||||
|
||||
<img width="835" alt="Screenshot 2024-03-19 at 12 39 41" src="https://github.com/huggingface/trl/assets/49240599/9150fac6-3d88-4ca2-8ec6-2a6f3473216d">
|
||||
|
||||
Read more about DPO algorithm in the [original paper](https://huggingface.co/papers/2305.18290).
|
||||
|
||||
|
||||
## Expected dataset format
|
||||
|
||||
The DPO trainer expects a very specific format for the dataset. Since the model will be trained to directly optimize the preference of which sentence is the most relevant, given two sentences. We provide an example from the [`Anthropic/hh-rlhf`](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset below:
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/rlhf-antropic-example.png", width="50%">
|
||||
</div>
|
||||
|
||||
Therefore the final dataset object should contain these 3 entries if you use the default [`DPODataCollatorWithPadding`] data collator. The entries should be named:
|
||||
|
||||
- `prompt`
|
||||
- `chosen`
|
||||
- `rejected`
|
||||
|
||||
for example:
|
||||
|
||||
```py
|
||||
dpo_dataset_dict = {
|
||||
"prompt": [
|
||||
"hello",
|
||||
"how are you",
|
||||
"What is your name?",
|
||||
"What is your name?",
|
||||
"Which is the best programming language?",
|
||||
"Which is the best programming language?",
|
||||
"Which is the best programming language?",
|
||||
],
|
||||
"chosen": [
|
||||
"hi nice to meet you",
|
||||
"I am fine",
|
||||
"My name is Mary",
|
||||
"My name is Mary",
|
||||
"Python",
|
||||
"Python",
|
||||
"Java",
|
||||
],
|
||||
"rejected": [
|
||||
"leave me alone",
|
||||
"I am not fine",
|
||||
"Whats it to you?",
|
||||
"I dont have a name",
|
||||
"Javascript",
|
||||
"C++",
|
||||
"C++",
|
||||
],
|
||||
}
|
||||
```
|
||||
|
||||
where the `prompt` contains the context inputs, `chosen` contains the corresponding chosen responses and `rejected` contains the corresponding negative (rejected) responses. As can be seen a prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays.
|
||||
|
||||
[`DPOTrainer`] can be used to fine-tune visual language models (VLMs). In this case, the dataset must also contain the key `images`, and the trainer's `tokenizer` is the VLM's `processor`. For example, for Idefics2, the processor expects the dataset to have the following format:
|
||||
|
||||
Note: Currently, VLM support is exclusive to Idefics2 and does not extend to other VLMs.
|
||||
|
||||
```py
|
||||
dpo_dataset_dict = {
|
||||
'images': [
|
||||
[Image.open('beach.jpg')],
|
||||
[Image.open('street.jpg')],
|
||||
],
|
||||
'prompt': [
|
||||
'The image <image> shows',
|
||||
'<image> The image depicts',
|
||||
],
|
||||
'chosen': [
|
||||
'a sunny beach with palm trees.',
|
||||
'a busy street with several cars and buildings.',
|
||||
],
|
||||
'rejected': [
|
||||
'a snowy mountain with skiers.',
|
||||
'a calm countryside with green fields.',
|
||||
],
|
||||
}
|
||||
```
|
||||
|
||||
## Expected model format
|
||||
|
||||
The DPO trainer expects a model of `AutoModelForCausalLM` or `AutoModelForVision2Seq`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.
|
||||
|
||||
## Using the `DPOTrainer`
|
||||
|
||||
For a detailed example have a look at the `examples/scripts/dpo.py` script. At a high level we need to initialize the [`DPOTrainer`] with a `model` we wish to train, a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected response, the `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (ie decoder only or encoder-decoder).
|
||||
|
||||
```py
|
||||
training_args = DPOConfig(
|
||||
beta=0.1,
|
||||
)
|
||||
dpo_trainer = DPOTrainer(
|
||||
model,
|
||||
ref_model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
tokenizer=tokenizer, # for visual language models, use tokenizer=processor instead
|
||||
)
|
||||
```
|
||||
|
||||
After this one can then call:
|
||||
|
||||
```py
|
||||
dpo_trainer.train()
|
||||
```
|
||||
|
||||
Note that the `beta` is the temperature parameter for the DPO loss, typically something in the range of `0.1` to `0.5`. We ignore the reference model as `beta` -> 0.
|
||||
|
||||
## Loss functions
|
||||
|
||||
Given the preference data, we can fit a binary classifier according to the Bradley-Terry model and in fact the [DPO](https://huggingface.co/papers/2305.18290) authors propose the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression. To use this loss, set the `loss_type="sigmoid"` (default) in the [`DPOConfig`].
|
||||
|
||||
The [RSO](https://huggingface.co/papers/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. To use this loss, set the `loss_type="hinge"` in the [`DPOConfig`]. In this case, the `beta` is the reciprocal of the margin.
|
||||
|
||||
The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting and propose an alternative loss. To use the loss set the `loss_type="ipo"` in the [`DPOConfig`]. In this case, the `beta` is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair and thus the smaller the `beta` the larger this gaps is. As per the paper the loss is averaged over log-likelihoods of the completion (unlike DPO which is summed only).
|
||||
|
||||
The [cDPO](https://ericmitchell.ai/cdpo.pdf) is a tweak on the DPO loss where we assume that the preference labels are noisy with some probability. In this approach, the `label_smoothing` parameter in the [`DPOConfig`] is used to model the probability of existing label noise. To apply this conservative loss, set `label_smoothing` to a value greater than 0.0 (between 0.0 and 0.5; the default is 0.0).
|
||||
|
||||
The [EXO](https://huggingface.co/papers/2402.00856) authors propose to minimize the reverse KL instead of the negative log-sigmoid loss of DPO which corresponds to forward KL. To use the loss set the `loss_type="exo_pair"` in the [`DPOConfig`]. Setting non-zero `label_smoothing` (default `1e-3`) leads to a simplified version of EXO on pair-wise preferences (see Eqn. (16) of the [EXO paper](https://huggingface.co/papers/2402.00856)). The full version of EXO uses `K>2` completions generated by the SFT policy, which becomes an unbiased estimator of the PPO objective (up to a constant) when `K` is sufficiently large.
|
||||
|
||||
The [NCA](https://huggingface.co/papers/2402.05369) authors shows that NCA optimizes the absolute likelihood for each response rather than the relative likelihood. To use the loss set the `loss_type="nca_pair"` in the [`DPOConfig`].
|
||||
|
||||
The [Robust DPO](https://huggingface.co/papers/2403.00409) authors propose an unbiased estimate of the DPO loss that is robust to preference noise in the data. Like in cDPO, it assumes that the preference labels are noisy with some probability. In this approach, the `label_smoothing` parameter in the [`DPOConfig`] is used to model the probability of existing label noise. To apply this conservative loss, set `label_smoothing` to a value greater than 0.0 (between 0.0 and 0.5; the default is 0.0) and set the `loss_type="robust"` in the [`DPOConfig`].
|
||||
|
||||
The [BCO](https://huggingface.co/papers/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0. To use this loss, set the `loss_type="bco_pair"` in the [`DPOConfig`].
|
||||
|
||||
The [TR-DPO](https://huggingface.co/papers/2404.09656) paper suggests syncing the reference model weights after every `ref_model_sync_steps` steps of SGD with weight `ref_model_mixup_alpha` during DPO training. To toggle this callback use the `sync_ref_model=True` in the [`DPOConfig`].
|
||||
|
||||
The [RPO](https://huggingface.co/papers/2404.19733) paper implements an iterative preference tuning algorithm using a loss related to the RPO loss in this [paper](https://huggingface.co/papers/2405.16436) that essentially consists of a weighted SFT loss on the chosen preferences together with the DPO loss. To use this loss, set the `rpo_alpha` in the [`DPOConfig`] to an appropriate value. The paper suggests setting this weight to 1.0.
|
||||
|
||||
The [SPPO](https://huggingface.co/papers/2405.00675) authors claim that SPPO is capable of solving the Nash equilibrium iteratively by pushing the chosen rewards to be as large as 1/2 and the rejected rewards to be as small as -1/2 and can alleviate data sparsity issues. The implementation approximates this algorithm by employing hard label probabilities, assigning 1 to the winner and 0 to the loser. To use this loss, set the `loss_type="sppo_hard"` in the [`DPOConfig`].
|
||||
|
||||
The [AOT](https://huggingface.co/papers/2406.05882) authors propose to use Distributional Preference Alignment Via Optimal Transport. Traditionally, the alignment algorithms use paired preferences at a sample level, which does not ensure alignment on the distributional level. AOT, on the other hand, can align LLMs on paired or unpaired preference data by making the reward distribution of the positive samples stochastically dominant in the first order on the distribution of negative samples. Specifically, `loss_type="aot"` is appropriate for paired datasets, where each prompt has both chosen and rejected responses; `loss_type="aot_pair"` is for unpaired datasets. In a nutshell, `loss_type="aot"` ensures that the log-likelihood ratio of chosen to rejected of the aligned model has higher quantiles than that ratio for the reference model. `loss_type="aot_pair"` ensures that the chosen reward is higher on all quantiles than the rejected reward. Note that in both cases quantiles are obtained via sorting. To fully leverage the advantages of the AOT algorithm, it is important to maximize the per-GPU batch size.
|
||||
|
||||
The [APO](https://huggingface.co/papers/2408.06266) method introduces an "anchored" version of the alignment objective. There are two variants: `apo_zero` and `apo_down`. The `apo_zero` loss increases the likelihood of winning outputs while decreasing the likelihood of losing outputs, making it suitable when the model is less performant than the winning outputs. On the other hand, `apo_down` decreases the likelihood of both winning and losing outputs, but with a stronger emphasis on reducing the likelihood of losing outputs. This variant is more effective when the model is better than the winning outputs. To use these losses, set `loss_type="apo_zero"` or `loss_type="apo_down"` in the [`DPOConfig`].
|
||||
|
||||
### For Mixture of Experts Models: Enabling the auxiliary loss
|
||||
|
||||
MOEs are the most efficient if the load is about equally distributed between experts.
|
||||
To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
|
||||
|
||||
This option is enabled by setting `output_router_logits=True` in the model config (e.g. MixtralConfig).
|
||||
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: 0.001).
|
||||
|
||||
## Logging
|
||||
|
||||
While training and evaluating we record the following reward metrics:
|
||||
|
||||
- `rewards/chosen`: the mean difference between the log probabilities of the policy model and the reference model for the chosen responses scaled by beta
|
||||
- `rewards/rejected`: the mean difference between the log probabilities of the policy model and the reference model for the rejected responses scaled by beta
|
||||
- `rewards/accuracies`: mean of how often the chosen rewards are > than the corresponding rejected rewards
|
||||
- `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
|
||||
|
||||
## Accelerate DPO fine-tuning using `unsloth`
|
||||
|
||||
You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is fully compatible with `SFTTrainer`. Currently `unsloth` supports only Llama (Yi, TinyLlama, Qwen, Deepseek etc) and Mistral architectures. Some benchmarks for DPO listed below:
|
||||
|
||||
| GPU | Model | Dataset | 🤗 | 🤗 + Flash Attention 2 | 🦥 Unsloth | 🦥 VRAM saved |
|
||||
| -------- | --------- | ---------- | --- | ---------------------- | ---------- | ------------- |
|
||||
| A100 40G | Zephyr 7b | Ultra Chat | 1x | 1.24x | **1.88x** | -11.6% |
|
||||
| Tesla T4 | Zephyr 7b | Ultra Chat | 1x | 1.09x | **1.55x** | -18.6% |
|
||||
|
||||
First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLanguageModel` as follows:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from trl import DPOConfig, DPOTrainer
|
||||
from unsloth import FastLanguageModel
|
||||
|
||||
max_seq_length = 2048 # Supports automatic RoPE Scaling, so choose any number.
|
||||
|
||||
# Load model
|
||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name = "unsloth/zephyr-sft",
|
||||
max_seq_length = max_seq_length,
|
||||
dtype = None, # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
|
||||
load_in_4bit = True, # Use 4bit quantization to reduce memory usage. Can be False.
|
||||
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
|
||||
)
|
||||
|
||||
# Do model patching and add fast LoRA weights
|
||||
model = FastLanguageModel.get_peft_model(
|
||||
model,
|
||||
r = 16,
|
||||
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj",],
|
||||
lora_alpha = 16,
|
||||
lora_dropout = 0, # Dropout = 0 is currently optimized
|
||||
bias = "none", # Bias = "none" is currently optimized
|
||||
use_gradient_checkpointing = True,
|
||||
random_state = 3407,
|
||||
)
|
||||
|
||||
training_args = DPOConfig(
|
||||
output_dir="./output",
|
||||
beta=0.1,
|
||||
)
|
||||
|
||||
dpo_trainer = DPOTrainer(
|
||||
model,
|
||||
ref_model=None,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
dpo_trainer.train()
|
||||
```
|
||||
|
||||
The saved model is fully compatible with Hugging Face's transformers library. Learn more about unsloth in their [official repository](https://github.com/unslothai/unsloth).
|
||||
|
||||
## Reference model considerations with PEFT
|
||||
|
||||
You have three main options (plus several variants) for how the reference model works when using PEFT, assuming the model that you would like to further enhance with DPO was tuned using (Q)LoRA.
|
||||
|
||||
1. Simply create two instances of the model, each loading your adapter - works fine but is very inefficient.
|
||||
2. Merge the adapter into the base model, create another adapter on top, then leave the `ref_model` param null, in which case DPOTrainer will unload the adapter for reference inference - efficient, but has potential downsides discussed below.
|
||||
3. Load the adapter twice with different names, then use `set_adapter` during training to swap between the adapter being DPO'd and the reference adapter - slightly less efficient compared to 2 (~adapter size VRAM overhead), but avoids the pitfalls.
|
||||
|
||||
### Downsides to merging QLoRA before DPO (approach 2)
|
||||
|
||||
As suggested by [Benjamin Marie](https://medium.com/@bnjmn_marie/dont-merge-your-lora-adapter-into-a-4-bit-llm-65b6da287997), the best option for merging QLoRA adapters is to first dequantize the base model, then merge the adapter. Something similar to [this script](https://github.com/jondurbin/qlora/blob/main/qmerge.py).
|
||||
|
||||
However, after using this approach, you will have an unquantized base model. Therefore, to use QLoRA for DPO, you will need to re-quantize the merged model or use the unquantized merge (resulting in higher memory demand).
|
||||
|
||||
### Using option 3 - load the adapter twice
|
||||
|
||||
To avoid the downsides with option 2, you can load your fine-tuned adapter into the model twice, with different names, and set the model/ref adapter names in [`DPOTrainer`].
|
||||
|
||||
For example:
|
||||
|
||||
```python
|
||||
# Load the base model.
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
llm_int8_threshold=6.0,
|
||||
llm_int8_has_fp16_weight=False,
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"mistralai/mixtral-8x7b-v0.1",
|
||||
load_in_4bit=True,
|
||||
quantization_config=bnb_config,
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
)
|
||||
model.config.use_cache = False
|
||||
|
||||
# Load the adapter.
|
||||
model = PeftModel.from_pretrained(
|
||||
model,
|
||||
"/path/to/peft",
|
||||
is_trainable=True,
|
||||
adapter_name="train",
|
||||
)
|
||||
# Load the adapter a second time, with a different name, which will be our reference model.
|
||||
model.load_adapter("/path/to/peft", adapter_name="reference")
|
||||
|
||||
# Initialize the trainer, without a ref_model param.
|
||||
training_args = DPOConfig(
|
||||
model_adapter_name="train",
|
||||
ref_adapter_name="reference",
|
||||
)
|
||||
dpo_trainer = DPOTrainer(
|
||||
model,
|
||||
args=training_args,
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
## DPOTrainer
|
||||
|
||||
[[autodoc]] DPOTrainer
|
||||
|
||||
## DPOConfig
|
||||
|
||||
[[autodoc]] DPOConfig
|
||||
@ -1,82 +1,85 @@
|
||||
# Examples
|
||||
|
||||
This directory contains a collection of examples that demonstrate how to use the TRL library for various applications. We provide both **scripts** for advanced use cases and **notebooks** for an easy start and interactive experimentation.
|
||||
|
||||
## Introduction
|
||||
The notebooks are self-contained and can run on **free Colab**, while the scripts can run on **single GPU, multi-GPU, or DeepSpeed** setups.
|
||||
|
||||
The examples should work in any of the following settings (with the same script):
|
||||
- single GPU
|
||||
- multi GPUS (using PyTorch distributed mode)
|
||||
- multi GPUS (using DeepSpeed ZeRO-Offload stages 1, 2, & 3)
|
||||
- fp16 (mixed-precision), fp32 (normal precision), or bf16 (bfloat16 precision)
|
||||
**Getting Started**
|
||||
|
||||
To run it in each of these various modes, first initialize the accelerate
|
||||
configuration with `accelerate config`
|
||||
|
||||
**NOTE to train with a 4-bit or 8-bit model**, please run
|
||||
Install TRL and additional dependencies as follows:
|
||||
|
||||
```bash
|
||||
pip install --upgrade trl[quantization]
|
||||
```
|
||||
|
||||
Check for additional optional dependencies [here](https://github.com/huggingface/trl/blob/main/pyproject.toml).
|
||||
|
||||
## Accelerate Config
|
||||
For all the examples, you'll need to generate a 🤗 Accelerate config file with:
|
||||
For scripts, you will also need an 🤗 Accelerate config (recommended for multi-gpu settings):
|
||||
|
||||
```shell
|
||||
```bash
|
||||
accelerate config # will prompt you to define the training configuration
|
||||
```
|
||||
|
||||
Then, it is encouraged to launch jobs with `accelerate launch`!
|
||||
This allows you to run scripts with `accelerate launch` in single or multi-GPU settings.
|
||||
|
||||
## Notebooks
|
||||
|
||||
These notebooks are easier to run and are designed for quick experimentation with TRL. The list of notebooks can be found in the [`trl/examples/notebooks/`](https://github.com/huggingface/trl/tree/main/examples/notebooks/) directory.
|
||||
|
||||
|
||||
# Maintained Examples
|
||||
| Notebook | Description | Open in Colab |
|
||||
|----------|-------------|---------------|
|
||||
| [`sft_trl_lora_qlora.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/sft_trl_lora_qlora.ipynb) | Supervised Fine-Tuning (SFT) using QLoRA on free Colab | [](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/sft_trl_lora_qlora.ipynb) |
|
||||
| [`sft_qwen_vl.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/sft_qwen_vl.ipynb) | Supervised Fine-Tuning (SFT) Qwen3-VL with QLoRA using TRL on free Colab | [](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/sft_qwen_vl.ipynb) |
|
||||
| [`grpo_qwen3_vl.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/grpo_qwen3_vl.ipynb) | GRPO Qwen3-VL with QLoRA using TRL on free Colab | [](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/grpo_qwen3_vl.ipynb) |
|
||||
|
||||
## Scripts
|
||||
|
||||
Scripts are maintained in the [`trl/scripts`](https://github.com/huggingface/trl/blob/main/trl/scripts) and [`examples/scripts`](https://github.com/huggingface/trl/blob/main/examples/scripts) directories. They show how to use different trainers such as `SFTTrainer`, `PPOTrainer`, `DPOTrainer`, `GRPOTrainer`, and more.
|
||||
|
||||
| File | Description |
|
||||
| ----------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| [`examples/scripts/alignprop.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/alignprop.py) | This script shows how to use the [`AlignPropTrainer`] to fine-tune a diffusion model. |
|
||||
| [`examples/scripts/bco.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/bco.py) | This script shows how to use the [`KTOTrainer`] with the BCO loss to fine-tune a model to increase instruction-following, truthfulness, honesty and helpfulness using the [openbmb/UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) dataset. |
|
||||
| [`examples/scripts/chat.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/chat.py) | This script allows you to load and use a model as a chatbot. |
|
||||
| [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py) | This script shows how to use the [`CPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
|
||||
| [`examples/scripts/ddpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ddpo.py) | This script shows how to use the [`DDPOTrainer`] to fine-tune a stable diffusion model using reinforcement learning. |
|
||||
| [`examples/scripts/dpo_visual.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_visual.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a Vision Language Model to reduce hallucinations using the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset) dataset. |
|
||||
| [`examples/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a stable to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
|
||||
| [`examples/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/kto.py) | This script shows how to use the [`KTOTrainer`] to fine-tune a model. |
|
||||
| [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py) | This script shows how to use the [`ORPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
|
||||
| [`examples/scripts/ppo_multi_adapter.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo_multi_adapter.py) | This script shows how to use the [`PPOTrainer`] to train a single base model with multiple adapters. Requires you to run the example script with the reward model training beforehand. |
|
||||
| [`examples/scripts/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a sentiment analysis model using [IMDB dataset](https://huggingface.co/datasets/stanfordnlp/imdb). |
|
||||
| [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/reward_modeling.py) | This script shows how to use the [`RewardTrainer`] to train a reward model on your own dataset. |
|
||||
| [`examples/scripts/sft.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a model or adapters into a target dataset. |
|
||||
| [`examples/scripts/vsft_llava.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/vsft_llava.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Vision Language Model in a chat setting. The script has only been tested on a [LLaVA 1.5]([llava-hf/llava-1.5-7b-hf](https://huggingface.co/llava-hf/llava-1.5-7b-hf)) model so users may see unexpected behaviour in other model architectures. |
|
||||
File | Description |
|
||||
| --- | --- |
|
||||
| [`examples/scripts/bco.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/bco.py) | This script shows how to use the [`KTOTrainer`] with the BCO loss to fine-tune a model to increase instruction-following, truthfulness, honesty, and helpfulness using the [openbmb/UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) dataset. |
|
||||
| [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py) | This script shows how to use the [`CPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
|
||||
| [`trl/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/dpo.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a model. |
|
||||
| [`examples/scripts/dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_vlm.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a Vision Language Model to reduce hallucinations using the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset) dataset. |
|
||||
| [`examples/scripts/evals/judge_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/evals/judge_tldr.py) | This script shows how to use [`HfPairwiseJudge`] or [`OpenAIPairwiseJudge`] to judge model generations. |
|
||||
| [`examples/scripts/gkd.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gkd.py) | This script shows how to use the [`GKDTrainer`] to fine-tune a model. |
|
||||
| [`trl/scripts/grpo.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/grpo.py) | This script shows how to use the [`GRPOTrainer`] to fine-tune a model. |
|
||||
| [`examples/scripts/grpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/grpo_vlm.py) | This script shows how to use the [`GRPOTrainer`] to fine-tune a multimodal model for reasoning using the [lmms-lab/multimodal-open-r1-8k-verified](https://huggingface.co/datasets/lmms-lab/multimodal-open-r1-8k-verified) dataset. |
|
||||
| [`examples/scripts/gspo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gspo.py) | This script shows how to use GSPO via the [`GRPOTrainer`] to fine-tune model for reasoning using the [AI-MO/NuminaMath-TIR](https://huggingface.co/datasets/AI-MO/NuminaMath-TIR) dataset. |
|
||||
| [`examples/scripts/gspo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gspo_vlm.py) | This script shows how to use GSPO via the [`GRPOTrainer`] to fine-tune a multimodal model for reasoning using the [lmms-lab/multimodal-open-r1-8k-verified](https://huggingface.co/datasets/lmms-lab/multimodal-open-r1-8k-verified) dataset. |
|
||||
| [`examples/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/kto.py) | This script shows how to use the [`KTOTrainer`] to fine-tune a model. |
|
||||
| [`examples/scripts/mpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/mpo_vlm.py) | This script shows how to use MPO via the [`DPOTrainer`] to align a model based on preferences using the [HuggingFaceH4/rlaif-v_formatted](https://huggingface.co/datasets/HuggingFaceH4/rlaif-v_formatted) dataset and a set of loss weights with weights. |
|
||||
| [`examples/scripts/nash_md.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nash_md.py) | This script shows how to use the [`NashMDTrainer`] to fine-tune a model. |
|
||||
| [`examples/scripts/online_dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo.py) | This script shows how to use the [`OnlineDPOTrainer`] to fine-tune a model. |
|
||||
| [`examples/scripts/online_dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo_vlm.py) | This script shows how to use the [`OnlineDPOTrainer`] to fine-tune a a Vision Language Model. |
|
||||
| [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py) | This script shows how to use the [`ORPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
|
||||
| [`examples/scripts/ppo/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language. |
|
||||
| [`examples/scripts/ppo/ppo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo_tldr.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. |
|
||||
| [`examples/scripts/prm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/prm.py) | This script shows how to use the [`PRMTrainer`] to fine-tune a Process-supervised Reward Model (PRM). |
|
||||
| [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/reward_modeling.py) | This script shows how to use the [`RewardTrainer`] to train an Outcome Reward Model (ORM) on your own dataset. |
|
||||
| [`examples/scripts/rloo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/rloo.py) | This script shows how to use the [`RLOOTrainer`] to fine-tune a model to improve its ability to solve math questions. |
|
||||
| [`examples/scripts/sft.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/sft.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a model. |
|
||||
| [`examples/scripts/sft_gemma3.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_gemma3.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Gemma 3 model. |
|
||||
| [`examples/scripts/sft_video_llm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_video_llm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Video Language Model. |
|
||||
| [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Vision Language Model in a chat setting. The script has only been tested with [LLaVA 1.5](https://huggingface.co/llava-hf/llava-1.5-7b-hf), [LLaVA 1.6](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf), and [Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) models, so users may see unexpected behaviour in other model architectures. |
|
||||
| [`examples/scripts/sft_vlm_gemma3.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm_gemma3.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Gemma 3 model on vision to text tasks. |
|
||||
| [`examples/scripts/sft_vlm_smol_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm_smol_vlm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a SmolVLM model. |
|
||||
| [`examples/scripts/xpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/xpo.py) | This script shows how to use the [`XPOTrainer`] to fine-tune a model. |
|
||||
|
||||
Here are also some easier-to-run colab notebooks that you can use to get started with TRL:
|
||||
## Distributed Training (for scripts)
|
||||
|
||||
| File | Description |
|
||||
| --------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------- |
|
||||
| [`examples/notebooks/best_of_n.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/best_of_n.ipynb) | This notebook demonstrates how to use the "Best of N" sampling strategy using TRL when fine-tuning your model with PPO. |
|
||||
| [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb) | This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook. |
|
||||
| [`examples/notebooks/gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-control.ipynb) | This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook. |
|
||||
|
||||
|
||||
We also have some other examples that are less maintained but can be used as a reference:
|
||||
1. **[research_projects](https://github.com/huggingface/trl/tree/main/examples/research_projects)**: Check out this folder to find the scripts used for some research projects that used TRL (LM de-toxification, Stack-Llama, etc.)
|
||||
|
||||
|
||||
## Distributed training
|
||||
|
||||
All of the scripts can be run on multiple GPUs by providing the path of an 🤗 Accelerate config file when calling `accelerate launch`. To launch one of them on one or multiple GPUs, run the following command (swapping `{NUM_GPUS}` with the number of GPUs in your machine and `--all_arguments_of_the_script` with your arguments.)
|
||||
You can run scripts on multiple GPUs with 🤗 Accelerate:
|
||||
|
||||
```shell
|
||||
accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml --num_processes {NUM_GPUS} path_to_script.py --all_arguments_of_the_script
|
||||
```
|
||||
|
||||
You can also adjust the parameters of the 🤗 Accelerate config file to suit your needs (e.g. training in mixed precision).
|
||||
|
||||
### Distributed training with DeepSpeed
|
||||
|
||||
Most of the scripts can be run on multiple GPUs together with DeepSpeed ZeRO-{1,2,3} for efficient sharding of the optimizer states, gradients, and model weights. To do so, run following command (swapping `{NUM_GPUS}` with the number of GPUs in your machine, `--all_arguments_of_the_script` with your arguments, and `--deepspeed_config` with the path to the DeepSpeed config file such as `examples/deepspeed_configs/deepspeed_zero1.yaml`):
|
||||
For DeepSpeed ZeRO-{1,2,3}:
|
||||
|
||||
```shell
|
||||
accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero{1,2,3}.yaml --num_processes {NUM_GPUS} path_to_script.py --all_arguments_of_the_script
|
||||
```
|
||||
|
||||
Adjust `NUM_GPUS` and `--all_arguments_of_the_script` as needed.
|
||||
|
||||
31
docs/source/experimental_overview.md
Normal file
31
docs/source/experimental_overview.md
Normal file
@ -0,0 +1,31 @@
|
||||
# Experimental
|
||||
|
||||
This directory contains a minimal, clearly separated space for fast iteration on new ideas.
|
||||
|
||||
> [!WARNING]
|
||||
> **Stability contract:** Anything under `trl.experimental` may change or be removed in *any* release (including patch versions) without prior deprecation. Do not rely on these APIs for production workloads.
|
||||
|
||||
## Promotion Path (Simple)
|
||||
|
||||
1. **Prototype outside the main repo:** Start development in your own fork or a separate repository to iterate quickly.
|
||||
2. **Experimental inclusion:** Once it’s ready for early users, move the idea into `trl.experimental.<feature>`.
|
||||
3. **Improve:** Add tests, a short doc/example, and demonstrate the usage.
|
||||
4. **Promote:** Once the API proves stable and there is clear interest or adoption from the community, move it into `trl.<feature>` (stable module).
|
||||
|
||||
## FAQ
|
||||
|
||||
**Why not just use branches?**
|
||||
Because branches are not shipped to users; experimental code inside the package lets early adopters try things and give feedback.
|
||||
|
||||
**Can these APIs change or vanish without warning?**
|
||||
Yes. Anything inside `trl.experimental` can change or disappear in *any* release.
|
||||
|
||||
**Should I use this in production?**
|
||||
Only if you are fine with updating your code quickly when things change.
|
||||
|
||||
**Will maintainers promptly fix issues in `trl.experimental`?**
|
||||
Not necessarily. The experimental module is a playground for new ideas, and maintainers may not prioritize bug fixes or feature requests there. Issues may remain unresolved until (or unless) the feature graduates to the stable API.
|
||||
|
||||
**How to silence the runtime notice?**
|
||||
|
||||
Use: `export TRL_EXPERIMENTAL_SILENCE=1`.
|
||||
39
docs/source/gfpo.md
Normal file
39
docs/source/gfpo.md
Normal file
@ -0,0 +1,39 @@
|
||||
# GFPO
|
||||
|
||||
This feature implements the GFPO algorithm to enforce concise reasoning in the model's output generation, as proposed in the paper [Sample More to Think Less: Group Filtered Policy Optimization for Concise Reasoning](https://huggingface.co/papers/2508.09726).
|
||||
|
||||
## Usage
|
||||
|
||||
To activate GFPO in [`GFPOTrainer`]:
|
||||
|
||||
- set `num_remains_in_group` in [`GFPOConfig`]
|
||||
- define a group filter function and set it to `group_filter_func` in [`GFPOTrainer`]. `group_filter_func` will score the `num_generations` completions and The GFPOTrainer filters groups according to their scores to get top `num_remains_in_group` completions as a new group. Model will be trained on the filtered group.
|
||||
|
||||
```python
|
||||
# train_gfpo.py
|
||||
from trl.experimental.gfpo import GFPOConfig, GFPOTrainer
|
||||
|
||||
# dummy group filter to scores the completions based on its indice in group
|
||||
class GroupFilter:
|
||||
def __call__(self, group_completions, group_rewards, **kwargs):
|
||||
group_scores = []
|
||||
for completions, rewards in zip(group_completions, group_rewards):
|
||||
scores = [float(i) for i in range(len(completions))]
|
||||
group_scores.append(scores)
|
||||
return group_scores
|
||||
|
||||
training_args = GFPOConfig(
|
||||
output_dir="Qwen3-0.6B-GFPO",
|
||||
per_device_train_batch_size=4,
|
||||
num_remains_in_group=2,
|
||||
bf16=True,
|
||||
)
|
||||
trainer = GFPOTrainer(
|
||||
model="Qwen/Qwen3-0.6B",
|
||||
reward_funcs=...,
|
||||
train_dataset=...,
|
||||
args=training_args,
|
||||
group_filter_func=GroupFilter(),
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
@ -1,15 +1,17 @@
|
||||
# Generalized Knowledge Distillation Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=gkd,trl)
|
||||
|
||||
## Overview
|
||||
|
||||
Generalized Knowledge Distillation (GKD) was proposed in [On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes](https://huggingface.co/papers/2306.13649) by Rishabh Agarwal, Nino Vieillard, Yongchao Zhou, Piotr Stanczyk, Sabela Ramos, Matthieu Geist, and Olivier Bachem.
|
||||
Generalized Knowledge Distillation (GKD) was proposed in [On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes](https://huggingface.co/papers/2306.13649) by Rishabh Agarwal, Nino Vieillard, Yongchao Zhou, Piotr Stanczyk, Sabela Ramos, Matthieu Geist, and Olivier Bachem.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
> Knowledge distillation (KD) is widely used for compressing a teacher model to reduce its inference cost and memory footprint, by training a smaller student model. However, current KD methods for auto-regressive sequence models suffer from distribution mismatch between output sequences seen during training and those generated by the student during inference. To address this issue, we introduce Generalized Knowledge Distillation (GKD). Instead of solely relying on a fixed set of output sequences, GKD trains the student on its self-generated output sequences by leveraging feedback from the teacher on such sequences. Unlike supervised KD approaches, GKD also offers the flexibility to employ alternative loss functions between the student and teacher, which can be useful when the student lacks the expressivity to mimic the teacher's distribution. Furthermore, GKD facilitates the seamless integration of distillation with RL fine-tuning (RLHF). We demonstrate the efficacy of GKD for distilling auto-regressive language models on summarization, translation, and arithmetic reasoning tasks, and task-agnostic distillation for instruction-tuning.
|
||||
|
||||
|
||||
The key aspects of GKD are:
|
||||
|
||||
1. It addresses the train-inference distribution mismatch in auto-regressive sequence models by training the student model on its self-generated output sequences.
|
||||
2. GKD allows flexibility in choosing different divergence measures between student and teacher models via the generalized Jensen-Shannon Divergence (JSD), which can be useful when the student lacks the capacity to fully mimic the teacher.
|
||||
|
||||
@ -17,8 +19,10 @@ This post-training method was contributed by [Kashif Rasul](https://huggingface.
|
||||
|
||||
## Usage tips
|
||||
|
||||
The GKD Trainer is a wrapper around the [`SFTTrainer`] class that takes in a teacher model argument. It needs two parameters to be set via the [`GKDConfig`] namely:
|
||||
The [`GKDTrainer`] is a wrapper around the [`SFTTrainer`] class that takes in a teacher model argument. It needs three parameters to be set via the [`GKDConfig`] namely:
|
||||
|
||||
* `lmbda`: controls the student data fraction, i.e., the proportion of on-policy student-generated outputs. When `lmbda=0.0`, the loss reduces to supervised JSD where the student is trained with the token-level probabilities of the teacher. When `lmbda=1.0`, the loss reduces to on-policy JSD, where the student generates output sequences and token-specific feedback on these sequences from the teacher. For values in between [0, 1] it is random between the two based on the `lmbda` value for each batch.
|
||||
* `seq_kd`: controls whether to perform Sequence-Level KD (can be viewed as supervised FT on teacher-generated out). When `seq_kd=True` and `lmbda=0.0`, the loss reduces to supervised JSD, where the teacher generates output sequences and the student receives token-specific feedback on these sequences from the teacher.
|
||||
* `beta`: controls the interpolation in the generalized Jensen-Shannon Divergence. When `beta=0.0` the loss approximates forward KL divergence, while for `beta=1.0` the loss approximates reverse KL divergence. For values in between [0, 1] it interpolates between the two.
|
||||
|
||||
The authors find that on-policy data (high `lmbda`) performs better and the optimal `beta` varied depending on the task and evaluation method.
|
||||
@ -67,28 +71,31 @@ eval_dataset = Dataset.from_dict(
|
||||
}
|
||||
)
|
||||
|
||||
args = GKDConfig(output_dir="gkd-model", per_device_train_batch_size=1)
|
||||
training_args = GKDConfig(output_dir="gkd-model", per_device_train_batch_size=1)
|
||||
trainer = GKDTrainer(
|
||||
model=model,
|
||||
teacher_model=teacher_model,
|
||||
args=args,
|
||||
tokenizer=tokenizer,
|
||||
args=training_args,
|
||||
processing_class=tokenizer,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
### Expected dataset format
|
||||
### Expected dataset type
|
||||
|
||||
The dataset should be formatted as a list of "messages" where each message is a list of dictionaries with the following keys:
|
||||
|
||||
* `role`: either `system`, `assistant` or `user`
|
||||
* `content`: the message content
|
||||
|
||||
|
||||
## GKDTrainer
|
||||
|
||||
[[autodoc]] GKDTrainer
|
||||
- train
|
||||
- save_model
|
||||
- push_to_hub
|
||||
|
||||
## GKDConfig
|
||||
|
||||
|
||||
120
docs/source/gold_trainer.md
Normal file
120
docs/source/gold_trainer.md
Normal file
@ -0,0 +1,120 @@
|
||||
# General Online Logit Distillation (GOLD) Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=sft,gold)
|
||||
|
||||
## Overview
|
||||
|
||||
General Online Logit Distillation (GOLD) is an extension of Universal Logit Distillation (ULD) that supports
|
||||
student/teacher pairs with different tokenizers. It aligns the textual spans produced by both tokenizers and merges the
|
||||
associated logits so no completion tokens are dropped. This enables cross-tokenizer knowledge distillation, including
|
||||
mixed model families (for example, LLaMA students with Qwen teachers).
|
||||
|
||||
Key capabilities:
|
||||
|
||||
1. **Cross-tokenizer alignment** – GOLD incrementally decodes the student and teacher tokens, groups passages with the same visible text, and merges probabilities inside each group. This guarantees loss terms are computed over the full completion even when token boundaries differ.
|
||||
2. **Hybrid ULD loss** – when `uld_use_hybrid_loss` is enabled, GOLD compares exact vocabulary matches directly and falls back to the original sorted-probability ULD loss for unmatched tokens. This improves stability for students whose vocabularies only partially overlap with the teacher.
|
||||
3. **Seamless integration with GKD** – GOLD inherits the on-policy vs. off-policy scheduling from the [`GKDTrainer`](./gkd_trainer.md), so you can combine sequence-level KD, generalized JSD, and cross-tokenizer distillation in a single training run.
|
||||
|
||||
> [!NOTE]
|
||||
> GOLD is currently part of the `trl.experimental` namespace. APIs may change without notice while the feature is iterated on.
|
||||
|
||||
## Usage tips
|
||||
|
||||
The [`GOLDTrainer`] subclasses [`SFTTrainer`] and accepts the same datasets as other TRL trainers (lists of ChatML style
|
||||
messages). Important configuration flags on [`GOLDConfig`] include:
|
||||
|
||||
* `use_uld_loss` – toggles Universal Logit Distillation. Set this to `True` for cross-tokenizer setups.
|
||||
* `teacher_tokenizer_name_or_path` – required when `use_uld_loss=True`; GOLD uses the teacher tokenizer to align tokens.
|
||||
* `uld_use_hybrid_loss`, `uld_hybrid_matched_weight`, `uld_hybrid_unmatched_weight` – enables and weights the hybrid
|
||||
matched/unmatched loss.
|
||||
* `beta`, `lmbda`, `seq_kd` – inherited from `GKDConfig`, controlling the generalized JSD interpolation and on-policy
|
||||
sampling ratio.
|
||||
|
||||
A minimal end-to-end example:
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from trl.experimental.gold import GOLDConfig, GOLDTrainer
|
||||
|
||||
train_dataset = load_dataset(
|
||||
"HuggingFaceTB/OpenR1-Math-220k-default-verified",
|
||||
"all",
|
||||
split="train[:1024]",
|
||||
)
|
||||
|
||||
trainer = GOLDTrainer(
|
||||
model="meta-llama/Llama-3.2-1B-Instruct",
|
||||
teacher_model="Qwen/Qwen2.5-0.5B-Instruct",
|
||||
args=GOLDConfig(output_dir="gold-model", use_uld_loss=True, teacher_tokenizer_name_or_path="Qwen/Qwen2.5-0.5B-Instruct"),
|
||||
train_dataset=train_dataset,
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
For quick-start workflows you can rely on string identifiers as shown above—the trainer will load the model and tokenizer for you. Explicitly instantiating `AutoModelForCausalLM`, `AutoTokenizer`, or populating `GOLDConfig` is recommended only for advanced use cases where you need fine-grained control over initialization.
|
||||
|
||||
A more explicit setup might look like this when you need to customise model loading, tokenizer settings, or training arguments:
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from trl import GOLDConfig, GOLDTrainer
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
student_name = "meta-llama/Llama-3.2-1B-Instruct"
|
||||
teacher_name = "Qwen/Qwen2.5-0.5B-Instruct"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(student_name)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(student_name)
|
||||
teacher_model = AutoModelForCausalLM.from_pretrained(teacher_name)
|
||||
|
||||
train_dataset = load_dataset(
|
||||
"HuggingFaceTB/Countdown-Task-GOLD",
|
||||
"verified_Qwen2.5-0.5B-Instruct",
|
||||
split="train",
|
||||
)
|
||||
|
||||
training_args = GOLDConfig(
|
||||
output_dir="gold-model",
|
||||
per_device_train_batch_size=1,
|
||||
teacher_model=teacher_name,
|
||||
teacher_tokenizer_name_or_path=teacher_name,
|
||||
use_uld_loss=True,
|
||||
uld_use_hybrid_loss=True,
|
||||
)
|
||||
|
||||
trainer = GOLDTrainer(
|
||||
model=model,
|
||||
teacher_model=teacher_model,
|
||||
args=training_args,
|
||||
processing_class=tokenizer,
|
||||
train_dataset=train_dataset,
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
### Expected dataset type
|
||||
|
||||
GOLD requires a [conversational](dataset_formats#conversational) [language modeling](dataset_formats#language_modeling) dataset, e.g.:
|
||||
|
||||
```python
|
||||
{"messages": [{"role": "user", "content": "What color is the sky?"},
|
||||
{"role": "assistant", "content": "It is blue."}]}
|
||||
```
|
||||
|
||||
`GOLDTrainer` keeps the raw messages so the ChatML collator can construct prompts and completions with the correct
|
||||
boundaries.
|
||||
|
||||
## GOLDTrainer
|
||||
|
||||
[[autodoc]] experimental.gold.GOLDTrainer
|
||||
- train
|
||||
- generate_on_policy_outputs
|
||||
- save_model
|
||||
- push_to_hub
|
||||
|
||||
## GOLDConfig
|
||||
|
||||
[[autodoc]] experimental.gold.GOLDConfig
|
||||
598
docs/source/grpo_trainer.md
Normal file
598
docs/source/grpo_trainer.md
Normal file
@ -0,0 +1,598 @@
|
||||
# GRPO Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=grpo,trl)
|
||||
|
||||
## Overview
|
||||
|
||||
TRL supports the GRPO Trainer for training language models, as described in the paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300) by [Zhihong Shao](https://huggingface.co/syhia), [Peiyi Wang](https://huggingface.co/peiyiwang89), [Qihao Zhu](https://huggingface.co/zqh11), Runxin Xu, [Junxiao Song](https://huggingface.co/haha-point), Mingchuan Zhang, Y. K. Li, Y. Wu, [Daya Guo](https://huggingface.co/guoday).
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
> Mathematical reasoning poses a significant challenge for language models due to its complex and structured nature. In this paper, we introduce DeepSeekMath 7B, which continues pre-training DeepSeek-Coder-Base-v1.5 7B with 120B math-related tokens sourced from Common Crawl, together with natural language and code data. DeepSeekMath 7B has achieved an impressive score of 51.7% on the competition-level MATH benchmark without relying on external toolkits and voting techniques, approaching the performance level of Gemini-Ultra and GPT-4. Self-consistency over 64 samples from DeepSeekMath 7B achieves 60.9% on MATH. The mathematical reasoning capability of DeepSeekMath is attributed to two key factors: First, we harness the significant potential of publicly available web data through a meticulously engineered data selection pipeline. Second, we introduce Group Relative Policy Optimization (GRPO), a variant of Proximal Policy Optimization (PPO), that enhances mathematical reasoning abilities while concurrently optimizing the memory usage of PPO.
|
||||
|
||||
This post-training method was contributed by [Quentin Gallouédec](https://huggingface.co/qgallouedec).
|
||||
|
||||
## Quick start
|
||||
|
||||
This example demonstrates how to train a model using the GRPO method. We train a [Qwen 0.5B Instruct model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) with the prompts from the [UltraFeedback prompts dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback-prompt). You can view the data in the dataset here:
|
||||
|
||||
<iframe
|
||||
src="https://huggingface.co/datasets/trl-lib/ultrafeedback-prompt/embed/viewer/default/train?row=0"
|
||||
frameborder="0"
|
||||
width="100%"
|
||||
height="560px"
|
||||
></iframe>
|
||||
|
||||
Below is the script to train the model.
|
||||
|
||||
```python
|
||||
# train_grpo.py
|
||||
from datasets import load_dataset
|
||||
from trl import GRPOConfig, GRPOTrainer
|
||||
|
||||
dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")
|
||||
|
||||
# Dummy reward function for demonstration purposes
|
||||
def reward_num_unique_letters(completions, **kwargs):
|
||||
"""Reward function that rewards completions with more unique letters."""
|
||||
completion_contents = [completion[0]["content"] for completion in completions]
|
||||
return [float(len(set(content))) for content in completion_contents]
|
||||
|
||||
training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO")
|
||||
trainer = GRPOTrainer(
|
||||
model="Qwen/Qwen2-0.5B-Instruct",
|
||||
reward_funcs=reward_num_unique_letters,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
Execute the script using the following command:
|
||||
|
||||
```bash
|
||||
accelerate launch train_grpo.py
|
||||
```
|
||||
|
||||
Distributed across 8 GPUs, the training takes approximately 1 day.
|
||||
|
||||

|
||||
|
||||
## Looking deeper into the GRPO method
|
||||
|
||||
GRPO is an online learning algorithm, meaning it improves iteratively by using the data generated by the trained model itself during training. The intuition behind GRPO objective is to maximize the advantage of the generated completions, while ensuring that the model remains close to the reference policy. To understand how GRPO works, it can be broken down into four main steps: **Generating completions**, **computing the advantage**, **estimating the KL divergence**, and **computing the loss**.
|
||||
|
||||

|
||||
|
||||
### Generating completions
|
||||
|
||||
At each training step, we sample a batch of prompts and generate a set of \\( G \\) completions for each prompt (denoted as \\( o_i \\)).
|
||||
|
||||
### Computing the advantage
|
||||
|
||||
For each of the \\( G \\) sequences, we compute the reward using a reward model or reward function. To align with the comparative nature of reward models—typically trained on datasets of comparisons between outputs for the same question—the advantage is calculated to reflect these relative comparisons. It is normalized as follows:
|
||||
|
||||
$$\hat{A}_{i,t} = \frac{r_i - \text{mean}(\mathbf{r})}{\text{std}(\mathbf{r})}$$
|
||||
|
||||
This approach gives the method its name: **Group Relative Policy Optimization (GRPO)**.
|
||||
|
||||
> [!TIP]
|
||||
> It was shown in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783) that scaling by \\( \text{std}(\mathbf{r}) \\) may cause a question-level difficulty bias. You can disable this scaling by setting `scale_rewards=False` in [`GRPOConfig`].
|
||||
|
||||
> [!TIP]
|
||||
> As shown in [Part I: Tricks or Traps? A Deep Dive into RL for LLM Reasoning (Lite PPO)](https://huggingface.co/papers/2508.08221), calculating the mean at the local (group) level and the standard deviation at the global (batch) level enables more robust reward shaping. You can use this scaling strategy by setting `scale_rewards="batch"` in [`GRPOConfig`].
|
||||
|
||||
### Estimating the KL divergence
|
||||
|
||||
KL divergence is estimated using the approximator introduced by [Schulman et al. (2020)](http://joschu.net/blog/kl-approx.html). The approximator is defined as follows:
|
||||
|
||||
$$\mathbb{D}_{\text{KL}}\left[\pi_\theta \|\pi_{\text{ref}}\right] = \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} - \log \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} - 1,
|
||||
$$
|
||||
|
||||
### Computing the loss
|
||||
|
||||
The objective is to maximize the advantage while ensuring that the model remains close to the reference policy. Consequently, the loss is defined as follows:
|
||||
|
||||
$$
|
||||
\mathcal{L}_{\text{GRPO}}(\theta) = -\frac{1}{\sum_{i=1}^G |o_i|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} \left[ \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\left[\pi_\theta(o_{i,t} \mid q, o_{i,< t})\right]_{\text{no grad}}} \hat{A}_{i,t} - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right],
|
||||
$$
|
||||
|
||||
where the first term represents the scaled advantage and the second term penalizes deviations from the reference policy through KL divergence.
|
||||
|
||||
> [!TIP]
|
||||
> Note that compared to the original formulation in [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300), we don't scale by \\( \frac{1}{|o_i|} \\) because it was shown in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783) that this introduces a response-level length bias. More details in [loss types](#loss-types).
|
||||
|
||||
> [!TIP]
|
||||
> Note that compared to the original formulation in [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300), we use \\( \beta = 0.0 \\) by default, meaning that the KL divergence term is not used. This choice is motivated by several recent studies (e.g., [Open-Reasoner-Zero: An Open Source Approach to Scaling Up Reinforcement Learning on the Base Model](https://huggingface.co/papers/2503.24290)) which have shown that the KL divergence term is not essential for training with GRPO. As a result, it has become common practice to exclude it (e.g. [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783), [DAPO: An Open-Source LLM Reinforcement Learning System at Scale](https://huggingface.co/papers/2503.14476)). If you wish to include the KL divergence term, you can set `beta` in [`GRPOConfig`] to a non-zero value.
|
||||
|
||||
In the original paper, this formulation is generalized to account for multiple updates after each generation (denoted \\( \mu \\), can be set with `num_iterations` in [`GRPOConfig`]) by leveraging the **clipped surrogate objective**:
|
||||
|
||||
$$
|
||||
\mathcal{L}_{\text{GRPO}}(\theta) = - \frac{1}{\sum_{i=1}^G |o_i|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} \left[ \min \left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})} \hat{A}_{i,t}, \, \text{clip}\left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})}, 1 - \epsilon, 1 + \epsilon \right) \hat{A}_{i,t} \right) - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right],
|
||||
$$
|
||||
|
||||
where \\(\text{clip}(\cdot, 1 - \epsilon, 1 + \epsilon) \\) ensures that updates do not deviate excessively from the reference policy by bounding the policy ratio between \\( 1 - \epsilon \\) and \\( 1 + \epsilon \\).
|
||||
When \\( \mu = 1 \\) (default in TRL), the clipped surrogate objective simplifies to the original objective.
|
||||
|
||||
#### Loss Types
|
||||
|
||||
Several formulations of the objective have been proposed in the literature. Initially, the objective of GRPO was defined as follows:
|
||||
|
||||
$$
|
||||
\mathcal{L}_{\text{GRPO}}(\theta) = - \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} l_{i,t},
|
||||
$$
|
||||
|
||||
where
|
||||
|
||||
$$
|
||||
l_{i,t} = \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\left[\pi_\theta(o_{i,t} \mid q, o_{i,< t})\right]_{\text{no grad}}} \hat{A}_{i,t} - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right].
|
||||
$$
|
||||
|
||||
The [DAPO paper](https://huggingface.co/papers/2503.14476) highlights the limitations of the GRPO algorithm’s sample-level loss in long-CoT scenarios, where longer responses are under-penalized, leading to poorer quality outputs. The proposed solution is a token-level normalization, which better handles longer sequences by assigning more balanced rewards to individual tokens, regardless of response length:
|
||||
|
||||
$$
|
||||
\mathcal{L}_{\text{DAPO}}(\theta) = - \frac{1}{\sum_{i=1}^G |o_i|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} l_{i,t},
|
||||
$$
|
||||
|
||||
To use this formulation, set `loss_type="dapo"` in [`GRPOConfig`].
|
||||
|
||||
Furthermore, it was demonstrated in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783) that the initial GRPO formulation introduces a response length bias. They show that while the DAPO formulation reduces this bias, it does not eliminate it completely. To fully remove this bias, they propose dividing by a constant instead of the sequence length, resulting in the following formulation:
|
||||
|
||||
$$
|
||||
\mathcal{L}_{\text{Dr. GRPO}}(\theta) = - \frac{1}{LG} \sum_{i=1}^G \sum_{t=1}^{|o_i|} l_{i,t},
|
||||
$$
|
||||
|
||||
This constant is recommended to be the maximum completion length. To use this formulation, set `loss_type="dr_grpo"` in the [`GRPOConfig`].
|
||||
|
||||
## Logged metrics
|
||||
|
||||
While training and evaluating, we record the following reward metrics:
|
||||
|
||||
- `num_tokens`: The total number of tokens processed so far, including both prompts and completions.
|
||||
- `completions/mean_length`: The average length of generated completions.
|
||||
- `completions/min_length`: The minimum length of generated completions.
|
||||
- `completions/max_length`: The maximum length of generated completions.
|
||||
- `completions/mean_terminated_length`: The average length of generated completions that terminate with EOS.
|
||||
- `completions/min_terminated_length`: The minimum length of generated completions that terminate with EOS.
|
||||
- `completions/max_terminated_length`: The maximum length of generated completions that terminate with EOS.
|
||||
- `completions/clipped_ratio`: The ratio of truncated (clipped) completions.
|
||||
- `reward/{reward_func_name}/mean`: The average reward from a specific reward function.
|
||||
- `reward/{reward_func_name}/std`: The standard deviation of the reward from a specific reward function.
|
||||
- `reward`: The overall average reward after applying reward weights.
|
||||
- `reward_std`: The standard deviation of rewards after applying reward weights.
|
||||
- If `scale_rewards` is `"group"` or `"none"`, this is the average of the per-group standard deviations.
|
||||
- If `scale_rewards` is `"batch"`, this is the standard deviation computed over all rewards in the batch (ignoring groups).
|
||||
- `frac_reward_zero_std`: The fraction of samples in the generation batch with a reward std of zero, implying there is little diversity for that prompt (all answers are correct or incorrect).
|
||||
- `entropy`: Average entropy of token predictions across generated completions. (If `mask_truncated_completions=True`, masked sequences tokens are excluded.)
|
||||
- `kl`: The average KL divergence between the model and the reference model, calculated over generated completions. Logged only if `beta` is nonzero.
|
||||
- `clip_ratio/region_mean`: The ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities where the GRPO objective is clipped to stay within the trust region:
|
||||
$$
|
||||
\text{clip}\left( r_{i,t}(\theta), 1 - \epsilon_\mathrm{low}, 1 + \epsilon_\mathrm{high} \right)\,, \qquad r_{i,t}(\theta) = \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})}\,.
|
||||
$$
|
||||
A higher value means more tokens are clipped, which constrains how much the policy $\pi_\theta$ can change.
|
||||
- `clip_ratio/low_mean`: The average ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities that were clipped on the lower bound of the trust region: \\(r_{i,t}(\theta) < 1 - \epsilon_\mathrm{low}\\)
|
||||
- `clip_ratio/low_min`: The minimum ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities that were clipped on the lower bound of the trust region: \\(r_{i,t}(\theta) < 1 - \epsilon_\mathrm{low}\\)
|
||||
- `clip_ratio/high_mean`: The average ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities that were clipped on the upper bound of the trust region: \\(r_{i,t}(\theta) > 1 + \epsilon_\mathrm{high}\\)
|
||||
- `clip_ratio/high_max`: The maximum ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities that were clipped on the upper bound of the trust region: \\(r_{i,t}(\theta) > 1 + \epsilon_\mathrm{high}\\).
|
||||
|
||||
## Customization
|
||||
|
||||
### Speed up training with vLLM-powered generation
|
||||
|
||||
Generation is often the main bottleneck when training with online methods. To accelerate generation, you can use [vLLM](https://github.com/vllm-project/vllm), a high-throughput, low-latency inference engine for LLMs. To enable it, first install the package with
|
||||
|
||||
```shell
|
||||
pip install trl[vllm]
|
||||
```
|
||||
|
||||
We support two ways of using vLLM during training: **server mode** and **colocate mode**.
|
||||
|
||||
> [!TIP]
|
||||
> By default, Truncated Importance Sampling is activated for vLLM generation to address the generation-training mismatch that occurs when using different frameworks. This can be turned off by setting `vllm_importance_sampling_correction=False`. For more information, see [Truncated Importance Sampling](paper_index#truncated-importance-sampling)
|
||||
|
||||
#### 🔌 Option 1: Server mode
|
||||
|
||||
In this mode, vLLM runs in a separate process (and using separate GPUs) and communicates with the trainer via HTTP. This is ideal if you have dedicated GPUs for inference.
|
||||
|
||||
1. **Start the vLLM server**:
|
||||
|
||||
```bash
|
||||
trl vllm-serve --model <model_name>
|
||||
```
|
||||
|
||||
2. **Enable server mode in your training script**:
|
||||
|
||||
```python
|
||||
from trl import GRPOConfig
|
||||
|
||||
training_args = GRPOConfig(
|
||||
...,
|
||||
use_vllm=True,
|
||||
vllm_mode="server", # default value, can be omitted
|
||||
)
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> Make sure that the server is using different GPUs than the trainer, otherwise you may run into NCCL errors. You can specify the GPUs to use with the `CUDA_VISIBLE_DEVICES` environment variable.
|
||||
|
||||
#### 🧩 Option 2: Colocate mode
|
||||
|
||||
In this mode, vLLM runs inside the trainer process and shares GPU memory with the training model. This avoids launching a separate server and can improve GPU utilization, but may lead to memory contention on the training GPUs.
|
||||
|
||||
```python
|
||||
from trl import GRPOConfig
|
||||
|
||||
training_args = GRPOConfig(
|
||||
...,
|
||||
use_vllm=True,
|
||||
vllm_mode="colocate",
|
||||
)
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> Depending on the model size and the overall GPU memory requirements for training, you may need to adjust the `vllm_gpu_memory_utilization` parameter in [`GRPOConfig`] to avoid underutilization or out-of-memory errors.
|
||||
>
|
||||
> We provide a [HF Space](https://huggingface.co/spaces/trl-lib/recommend-vllm-memory) to help estimate the recommended GPU memory utilization based on your model configuration and experiment settings. Simply use it as follows to get `vllm_gpu_memory_utilization` recommendation:
|
||||
>
|
||||
> <iframe src="https://trl-lib-recommend-vllm-memory.hf.space" frameborder="0" width="850" height="450"></iframe>
|
||||
>
|
||||
> If the recommended value does not work in your environment, we suggest adding a small buffer (e.g., +0.05 or +0.1) to the recommended value to ensure stability.
|
||||
>
|
||||
> If you still find you are getting out-of-memory errors set `vllm_enable_sleep_mode` to True and the vllm parameters and cache will be offloaded during the optimization step. For more information, see [Reducing Memory Usage with vLLM Sleep Mode](reducing_memory_usage#vllm-sleep-mode).
|
||||
|
||||
> [!TIP]
|
||||
> By default, GRPO uses `MASTER_ADDR=localhost` and `MASTER_PORT=12345` for vLLM, but you can override these values by setting the environment variables accordingly.
|
||||
|
||||
For more information, see [Speeding up training with vLLM](speeding_up_training#vllm-for-fast-generation-in-online-methods).
|
||||
|
||||
### GRPO at scale: train a 70B+ Model on multiple nodes
|
||||
|
||||
When training large models like **Qwen2.5-72B**, you need several key optimizations to make the training efficient and scalable across multiple GPUs and nodes. These include:
|
||||
|
||||
- **DeepSpeed ZeRO Stage 3**: ZeRO leverages data parallelism to distribute model states (weights, gradients, optimizer states) across multiple GPUs and CPUs, reducing memory and compute requirements on each device. Since large models cannot fit on a single GPU, using ZeRO Stage 3 is required for training such models. For more details, see [DeepSpeed Integration](deepspeed_integration).
|
||||
- **Accelerate**: Accelerate is a library that simplifies distributed training across multiple GPUs and nodes. It provides a simple API to launch distributed training and handles the complexities of distributed training, such as data parallelism, gradient accumulation, and distributed data loading. For more details, see [Distributing Training](distributing_training).
|
||||
- **vLLM**: See the previous section on how to use vLLM to speed up generation.
|
||||
|
||||
Below is an example SLURM script to train a 70B model with GRPO on multiple nodes. This script trains a model on 4 nodes and uses the 5th node for vLLM-powered generation.
|
||||
|
||||
```sh
|
||||
#!/bin/bash
|
||||
#SBATCH --nodes=5
|
||||
#SBATCH --gres=gpu:8
|
||||
|
||||
# Get the list of allocated nodes
|
||||
NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))
|
||||
|
||||
# Assign the first 4 nodes for training and the 5th node for vLLM
|
||||
TRAIN_NODES="${NODELIST[@]:0:4}" # Nodes 0, 1, 2, 3 for training
|
||||
VLLM_NODE="${NODELIST[4]}" # Node 4 for vLLM
|
||||
|
||||
# Run training on the first 4 nodes (Group 1)
|
||||
srun --nodes=4 --ntasks=4 --nodelist="${NODELIST[@]:0:4}" accelerate launch \
|
||||
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
|
||||
--num_processes 32 \
|
||||
--num_machines 4 \
|
||||
--main_process_ip ${NODELIST[0]} \
|
||||
--machine_rank $SLURM_PROCID \
|
||||
--rdzv_backend c10d \
|
||||
train_grpo.py \
|
||||
--server_ip $VLLM_NODE &
|
||||
|
||||
# Run vLLM server on the 5th node (Group 2)
|
||||
srun --nodes=1 --ntasks=1 --nodelist="${NODELIST[4]}" trl vllm-serve --model Qwen/Qwen2.5-72B --tensor_parallel_size 8 &
|
||||
|
||||
wait
|
||||
```
|
||||
|
||||
```python
|
||||
import argparse
|
||||
|
||||
from datasets import load_dataset
|
||||
from trl import GRPOTrainer, GRPOConfig
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--vllm_server_host", type=str, default="", help="The server IP")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Example dataset from TLDR
|
||||
dataset = load_dataset("trl-lib/tldr", split="train")
|
||||
|
||||
# Dummy reward function: count the number of unique characters in the completions
|
||||
def reward_num_unique_chars(completions, **kwargs):
|
||||
return [len(set(c)) for c in completions]
|
||||
|
||||
training_args = GRPOConfig(
|
||||
output_dir="Qwen2.5-72B-GRPO",
|
||||
per_device_train_batch_size=4,
|
||||
bf16=True,
|
||||
gradient_checkpointing=True,
|
||||
use_vllm=True,
|
||||
vllm_server_host=args.vllm_server_host.replace("ip-", "").replace("-", "."), # from ip-X-X-X-X to X.X.X.X
|
||||
)
|
||||
|
||||
trainer = GRPOTrainer(model="Qwen/Qwen2.5-72B", args=training_args, reward_funcs=reward_num_unique_chars, train_dataset=dataset)
|
||||
trainer.train()
|
||||
|
||||
if __name__=="__main__":
|
||||
main()
|
||||
```
|
||||
|
||||
### Using a custom reward function
|
||||
|
||||
The [`GRPOTrainer`] supports using custom reward functions instead of dense reward models. To ensure compatibility, your reward function must satisfy the following requirements:
|
||||
|
||||
1. **Input arguments**:
|
||||
- The function must accept the following as keyword arguments:
|
||||
- `prompts` (contains the prompts),
|
||||
- `completions` (contains the generated completions),
|
||||
- `completions_ids` (contains the tokenized completions),
|
||||
- `trainer_state` ([`~transformers.TrainerState`]): The current state of the trainer. This can be used to implement dynamic reward functions, such as curriculum learning, where the reward is adjusted based on the training progress.
|
||||
- All column names (but `prompt`) that the dataset may have. For example, if the dataset contains a column named `ground_truth`, the function will be called with `ground_truth` as a keyword argument.
|
||||
|
||||
The easiest way to comply with this requirement is to use `**kwargs` in the function signature.
|
||||
- Depending on the dataset format, the input will vary:
|
||||
- For [standard format](dataset_formats#standard), `prompts` and `completions` will be lists of strings.
|
||||
- For [conversational format](dataset_formats#conversational), `prompts` and `completions` will be lists of message dictionaries.
|
||||
|
||||
2. **Return value**: The function must return a list of floats. Each float represents the reward corresponding to a single completion.
|
||||
|
||||
#### Example 1: Reward longer completions
|
||||
|
||||
Below is an example of a reward function for a standard format that rewards longer completions:
|
||||
|
||||
```python
|
||||
def reward_func(completions_ids, **kwargs):
|
||||
"""Reward function that assigns higher scores to longer completions (in terms of token count)."""
|
||||
return [float(len(ids)) for ids in completions_ids]
|
||||
```
|
||||
|
||||
You can test it as follows:
|
||||
|
||||
```python
|
||||
>>> prompts = ["The sky is", "The sun is"] # not used in the reward function, but the trainer will pass it
|
||||
>>> completions = [" blue.", " in the sky."] # not used in the reward function, but the trainer will pass it
|
||||
>>> completions_ids = [[6303, 13], [304, 279, 12884, 13]]
|
||||
>>> reward_func(prompts=prompts, completions=completions, completions_ids=completions_ids)
|
||||
[2.0, 4.0]
|
||||
```
|
||||
|
||||
#### Example 1.1: Reward longer completions (based on the number of characters)
|
||||
|
||||
Same as the previous example, but this time the reward function is based on the number of characters instead of tokens.
|
||||
|
||||
```python
|
||||
def reward_func(completions, **kwargs):
|
||||
"""Reward function that assigns higher scores to longer completions (in terms of character count)."""
|
||||
return [float(len(completion)) for completion in completions]
|
||||
```
|
||||
|
||||
You can test it as follows:
|
||||
|
||||
```python
|
||||
>>> prompts = ["The sky is", "The sun is"]
|
||||
>>> completions = [" blue.", " in the sky."]
|
||||
>>> completions_ids = [[6303, 13], [304, 279, 12884, 13]] # not used in the reward function, but the trainer will pass it
|
||||
>>> reward_func(prompts=prompts, completions=completions, completions_ids=completions_ids)
|
||||
[6.0, 12.0]
|
||||
```
|
||||
|
||||
#### Example 2: Reward completions with a specific format
|
||||
|
||||
Below is an example of a reward function that checks if the completion has a specific format. This example is inspired by the _format reward_ function used in the paper [DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning](https://huggingface.co/papers/2501.12948).
|
||||
It is designed for a conversational format, where prompts and completions consist of structured messages.
|
||||
|
||||
```python
|
||||
import re
|
||||
|
||||
def format_reward_func(completions, **kwargs):
|
||||
"""Reward function that checks if the completion has a specific format."""
|
||||
pattern = r"^<think>.*?</think><answer>.*?</answer>$"
|
||||
completion_contents = [completion[0]["content"] for completion in completions]
|
||||
matches = [re.match(pattern, content) for content in completion_contents]
|
||||
return [1.0 if match else 0.0 for match in matches]
|
||||
```
|
||||
|
||||
You can test this function as follows:
|
||||
|
||||
```python
|
||||
>>> prompts = [
|
||||
... [{"role": "assistant", "content": "What is the result of (1 + 2) * 4?"}],
|
||||
... [{"role": "assistant", "content": "What is the result of (3 + 1) * 2?"}],
|
||||
... ]
|
||||
>>> completions = [
|
||||
... [{"role": "assistant", "content": "<think>The sum of 1 and 2 is 3, which we multiply by 4 to get 12.</think><answer>(1 + 2) * 4 = 12</answer>"}],
|
||||
... [{"role": "assistant", "content": "The sum of 3 and 1 is 4, which we multiply by 2 to get 8. So (3 + 1) * 2 = 8."}],
|
||||
... ]
|
||||
>>> format_reward_func(prompts=prompts, completions=completions)
|
||||
[1.0, 0.0]
|
||||
```
|
||||
|
||||
#### Example 3: Reward completions based on a reference
|
||||
|
||||
Below is an example of a reward function that checks if the completion is correct. This example is inspired by the _accuracy reward_ function used in the paper [DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning](https://huggingface.co/papers/2501.12948).
|
||||
This example is designed for [standard format](dataset_formats#standard), where the dataset contains a column named `ground_truth`.
|
||||
|
||||
```python
|
||||
import re
|
||||
|
||||
def reward_func(completions, ground_truth, **kwargs):
|
||||
# Regular expression to capture content inside \boxed{}
|
||||
matches = [re.search(r"\\boxed\{(.*?)\}", completion) for completion in completions]
|
||||
contents = [match.group(1) if match else "" for match in matches]
|
||||
# Reward 1 if the content is the same as the ground truth, 0 otherwise
|
||||
return [1.0 if c == gt else 0.0 for c, gt in zip(contents, ground_truth)]
|
||||
```
|
||||
|
||||
You can test this function as follows:
|
||||
|
||||
```python
|
||||
>>> prompts = ["Problem: Solve the equation $2x + 3 = 7$. Solution:", "Problem: Solve the equation $3x - 5 = 10$."]
|
||||
>>> completions = [r" The solution is \boxed{2}.", r" The solution is \boxed{6}."]
|
||||
>>> ground_truth = ["2", "5"]
|
||||
>>> reward_func(prompts=prompts, completions=completions, ground_truth=ground_truth)
|
||||
[1.0, 0.0]
|
||||
```
|
||||
|
||||
#### Example 4: Multi-task reward functions
|
||||
|
||||
Below is an example of using multiple reward functions in the [`GRPOTrainer`]. In this example, we define two task-specific reward functions: `math_reward_func` and `coding_reward_func`. The `math_reward_func` rewards math problems based on their correctness, while the `coding_reward_func` rewards coding problems based on whether the solution works.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
from trl import GRPOTrainer
|
||||
|
||||
# Define a dataset that contains both math and coding problems
|
||||
dataset = Dataset.from_list(
|
||||
[
|
||||
{"prompt": "What is 2+2?", "task": "math"},
|
||||
{"prompt": "Write a function that returns the sum of two numbers.", "task": "code"},
|
||||
{"prompt": "What is 3*4?", "task": "math"},
|
||||
{"prompt": "Write a function that returns the product of two numbers.", "task": "code"},
|
||||
]
|
||||
)
|
||||
|
||||
# Math-specific reward function
|
||||
def math_reward_func(prompts, completions, task, **kwargs):
|
||||
rewards = []
|
||||
for prompt, completion, t in zip(prompts, completions, task):
|
||||
if t == "math":
|
||||
# Calculate math-specific reward
|
||||
correct = check_math_solution(prompt, completion)
|
||||
reward = 1.0 if correct else -1.0
|
||||
rewards.append(reward)
|
||||
else:
|
||||
# Return None for non-math tasks
|
||||
rewards.append(None)
|
||||
return rewards
|
||||
|
||||
# Coding-specific reward function
|
||||
def coding_reward_func(prompts, completions, task, **kwargs):
|
||||
rewards = []
|
||||
for prompt, completion, t in zip(prompts, completions, task):
|
||||
if t == "coding":
|
||||
# Calculate coding-specific reward
|
||||
works = test_code_solution(prompt, completion)
|
||||
reward = 1.0 if works else -1.0
|
||||
rewards.append(reward)
|
||||
else:
|
||||
# Return None for non-coding tasks
|
||||
rewards.append(None)
|
||||
return rewards
|
||||
|
||||
# Use both task-specific reward functions
|
||||
trainer = GRPOTrainer(
|
||||
model="Qwen/Qwen2-0.5B-Instruct",
|
||||
reward_funcs=[math_reward_func, coding_reward_func],
|
||||
train_dataset=dataset,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
In this example, the `math_reward_func` and `coding_reward_func` are designed to work with a mixed dataset that contains both math and coding problems. The `task` column in the dataset is used to determine which reward function to apply to each problem. If there is no relevant reward function for a sample in the dataset, the reward function will return `None`, and the [`GRPOTrainer`] will continue with the valid functions and tasks. This allows the [`GRPOTrainer`] to handle multiple reward functions with different applicability.
|
||||
|
||||
Note that the [`GRPOTrainer`] will ignore the `None` rewards returned by the reward functions and only consider the rewards returned by the relevant functions. This ensures that the model is trained on the relevant tasks and ignores the tasks for which there is no relevant reward function.
|
||||
|
||||
#### Passing the reward function to the trainer
|
||||
|
||||
To use your custom reward function, pass it to the [`GRPOTrainer`] as follows:
|
||||
|
||||
```python
|
||||
from trl import GRPOTrainer
|
||||
|
||||
trainer = GRPOTrainer(
|
||||
reward_funcs=reward_func,
|
||||
...,
|
||||
)
|
||||
```
|
||||
|
||||
If you have multiple reward functions, you can pass them as a list:
|
||||
|
||||
```python
|
||||
from trl import GRPOTrainer
|
||||
|
||||
trainer = GRPOTrainer(
|
||||
reward_funcs=[reward_func1, reward_func2],
|
||||
...,
|
||||
)
|
||||
```
|
||||
|
||||
and the reward will be computed as the sum of the rewards from each function, or the weighted sum if `reward_weights` is provided in the config.
|
||||
|
||||
Note that [`GRPOTrainer`] supports multiple reward functions of different types. See the parameters documentation for more details.
|
||||
|
||||
## Vision-Language Model (VLM) Training
|
||||
|
||||
GRPO supports training Vision-Language Models (VLMs) on multimodal datasets containing both text and images.
|
||||
|
||||
### Supported Models
|
||||
|
||||
Tested with:
|
||||
|
||||
- **Gemma3** — e.g., `google/gemma-3-4b-it`
|
||||
- **LLaVA-NeXT** — e.g., `llava-hf/llava-v1.6-mistral-7b-hf`
|
||||
- **Qwen2-VL** — e.g., `Qwen/Qwen2-VL-2B-Instruct`
|
||||
- **Qwen2.5-VL** — e.g., `Qwen/Qwen2.5-VL-3B-Instruct`
|
||||
- **SmolVLM2** — e.g., `HuggingFaceTB/SmolVLM2-2.2B-Instruct`
|
||||
|
||||
> [!TIP]
|
||||
> Compatibility with all VLMs is not guaranteed. If you believe a model should be supported, feel free to open an issue on GitHub — or better yet, submit a pull request with the required changes.
|
||||
|
||||
### Quick Start
|
||||
|
||||
Use [grpo\_vlm.py](https://github.com/huggingface/trl/blob/main/examples/scripts/grpo_vlm.py) to fine-tune a VLM. Example command for training on [`lmms-lab/multimodal-open-r1-8k-verified`](https://huggingface.co/datasets/lmms-lab/multimodal-open-r1-8k-verified):
|
||||
|
||||
```bash
|
||||
accelerate launch \
|
||||
--config_file=examples/accelerate_configs/deepspeed_zero3.yaml \
|
||||
examples/scripts/grpo_vlm.py \
|
||||
--model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
|
||||
--output_dir grpo-Qwen2.5-VL-3B-Instruct \
|
||||
--learning_rate 1e-5 \
|
||||
--gradient_checkpointing \
|
||||
--dtype bfloat16 \
|
||||
--max_prompt_length 2048 \
|
||||
--max_completion_length 1024 \
|
||||
--use_vllm \
|
||||
--vllm_mode colocate \
|
||||
--use_peft \
|
||||
--lora_target_modules "q_proj", "v_proj" \
|
||||
--log_completions
|
||||
```
|
||||
|
||||
### Configuration Tips
|
||||
|
||||
> [!TIP]
|
||||
> For VLMs, truncating may remove image tokens, leading to errors during training. To avoid this, set `max_prompt_length=None` in the [`GRPOConfig`]. This allows the model to process the full sequence length without truncating image tokens.
|
||||
>
|
||||
> ```python
|
||||
> GRPOConfig(max_prompt_length=None, ...)
|
||||
> ```
|
||||
>
|
||||
> Only use `max_prompt_length` when you've verified that truncation won't remove image tokens for the entire dataset.
|
||||
|
||||
- Use LoRA on vision-language projection layers
|
||||
- Enable 4-bit quantization to reduce memory usage
|
||||
- VLMs are memory-intensive — start with smaller batch sizes
|
||||
- Most models are compatible with vLLM (`server` and `colocate` modes)
|
||||
|
||||
### Dataset Format
|
||||
|
||||
Each training sample should include:
|
||||
|
||||
- `prompt`: Text formatted via the processor's chat template
|
||||
- `image`/`images`: PIL Image or list of PIL Images
|
||||
|
||||
The trainer automatically handles image-to-tensor conversion via the model’s image processor.
|
||||
|
||||
## GRPOTrainer
|
||||
|
||||
[[autodoc]] GRPOTrainer
|
||||
- train
|
||||
- save_model
|
||||
- push_to_hub
|
||||
|
||||
## GRPOConfig
|
||||
|
||||
[[autodoc]] GRPOConfig
|
||||
39
docs/source/grpo_with_replay_buffer.md
Normal file
39
docs/source/grpo_with_replay_buffer.md
Normal file
@ -0,0 +1,39 @@
|
||||
# GRPO With Replay Buffer
|
||||
|
||||
This experimental trainer, trains a model with GRPO but replaces groups (and corresponding completions) that have 0 standard deviation with groups with high rewards and standard deviation that've been used to train a model in prior batches.
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
from trl.experimental.grpo_with_replay_buffer import GRPOWithReplayBufferTrainer
|
||||
from datasets import load_dataset
|
||||
|
||||
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
|
||||
|
||||
# Guarantee that some rewards have 0 std
|
||||
def custom_reward_func(completions, **kwargs):
|
||||
if torch.rand(1).item() < 0.25:
|
||||
return [0] * len(completions) # simulate some None rewards
|
||||
else:
|
||||
return torch.rand(len(completions)).tolist()
|
||||
|
||||
training_args = GRPOWithReplayBufferConfig(
|
||||
output_dir=self.tmp_dir,
|
||||
learning_rate=1e-4,
|
||||
per_device_train_batch_size=4,
|
||||
num_generations=4,
|
||||
max_completion_length=8,
|
||||
replay_buffer_size=8,
|
||||
report_to="none",
|
||||
)
|
||||
trainer = GRPOTrainer(
|
||||
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
|
||||
reward_funcs=[custom_reward_func],
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
)
|
||||
|
||||
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
18
docs/source/gspo_token.md
Normal file
18
docs/source/gspo_token.md
Normal file
@ -0,0 +1,18 @@
|
||||
# GSPO-token
|
||||
|
||||
In the paper [Group Sequence Policy Optimization](https://huggingface.co/papers/2507.18071), the authors propose a token-level objective variant to GSPO, called GSPO-token. To use GSPO-token, you can use the `GRPOTrainer` class in `trl.experimental.gspo_token`.
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
from trl.experimental.gspo_token import GRPOTrainer
|
||||
from trl import GRPOConfig
|
||||
|
||||
training_args = GRPOConfig(
|
||||
importance_sampling_level="sequence_token",
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> To leverage GSPO-token, the user will need to provide the per-token advantage \\( \hat{A_{i,t}} \\) for each token \\( t \\) in the sequence \\( i \\) (i.e., make \\( \hat{A_{i,t}} \\) varies with \\( t \\)—which isn't the case here, \\( \hat{A_{i,t}}=\hat{A_{i}} \\)). Otherwise, GSPO-Token gradient is just equivalent to the original GSPO implementation.
|
||||
@ -1,65 +0,0 @@
|
||||
# Training FAQ
|
||||
|
||||
## What Metrics Should I Look at?
|
||||
|
||||
When performing classical supervised fine-tuning of language models, the loss (especially the validation loss) serves as a good indicator of the training progress. However, in Reinforcement Learning (RL), the loss becomes less informative about the model's performance, and its value may fluctuate while the actual performance improves.
|
||||
|
||||
To address this, we recommend focusing on two key metrics first:
|
||||
|
||||
**Mean Reward**: The primary goal is to maximize the reward achieved by the model during RL training.
|
||||
**Objective KL Divergence**: KL divergence (Kullback-Leibler divergence) measures the dissimilarity between two probability distributions. In the context of RL training, we use it to quantify the difference between the current model and a reference model. Ideally, we want to keep the KL divergence between 0 and 10 to ensure the model's generated text remains close to what the reference model produces.
|
||||
|
||||
However, there are more metrics that can be useful for debugging, checkout the [logging section](logging).
|
||||
|
||||
## Why Do We Use a Reference Model, and What's the Purpose of KL Divergence?
|
||||
|
||||
When training RL models, optimizing solely for reward may lead to unexpected behaviors, where the model exploits the environment in ways that don't align with good language generation. In the case of RLHF, we use a reward model trained to predict whether a generated text is highly ranked by humans.
|
||||
|
||||
However, the RL model being optimized against the reward model may learn patterns that yield high reward but do not represent good language. This can result in extreme cases where the model generates texts with excessive exclamation marks or emojis to maximize the reward. In some worst-case scenarios, the model may generate patterns completely unrelated to natural language yet receive high rewards, similar to adversarial attacks.
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/kl-example.png">
|
||||
<p style="text-align: center;"> <b>Figure:</b> Samples without a KL penalty from <a href="https://huggingface.co/papers/1909.08593">https://huggingface.co/papers/1909.08593</a>. </p>
|
||||
</div>
|
||||
|
||||
To address this issue, we add a penalty to the reward function based on the KL divergence between the current model and the reference model. By doing this, we encourage the model to stay close to what the reference model generates.
|
||||
|
||||
## What Is the Concern with Negative KL Divergence?
|
||||
|
||||
If you generate text by purely sampling from the model distribution things work fine in general. But when you use the `generate` method there are a few caveats because it does not always purely sample depending on the settings which can cause KL-divergence to go negative. Essentially when the active model achieves `log_p_token_active < log_p_token_ref` we get negative KL-div. This can happen in a several cases:
|
||||
|
||||
- **top-k sampling**: the model can smooth out the probability distribution causing the top-k tokens having a smaller probability than those of the reference model but they still are selected
|
||||
- **min_length**: this ignores the EOS token until `min_length` is reached. thus the model can assign a very low log prob to the EOS token and very high probs to all others until min_length is reached
|
||||
|
||||
These are just a few examples. Why is negative KL an issue? The total reward `R` is computed `R = r - beta * KL` so if the model can learn how to drive KL-divergence negative it effectively gets a positive reward. In many cases it can be much easier to exploit such a bug in the generation than actually learning the reward function. In addition the KL can become arbitrarily small thus the actual reward can be very small compared to it.
|
||||
|
||||
So how should you generate text for PPO training? Let's have a look!
|
||||
|
||||
## How to generate text for training?
|
||||
|
||||
In order to avoid the KL issues described above we recommend to use the following settings:
|
||||
|
||||
```python
|
||||
generation_kwargs = {
|
||||
"min_length": -1, # don't ignore the EOS token (see above)
|
||||
"top_k": 0.0, # no top-k sampling
|
||||
"top_p": 1.0, # no nucleus sampling
|
||||
"do_sample": True, # yes, we want to sample
|
||||
"pad_token_id": tokenizer.eos_token_id, # most decoder models don't have a padding token - use EOS token instead
|
||||
"max_new_tokens": 32, # specify how many tokens you want to generate at most
|
||||
}
|
||||
```
|
||||
|
||||
With these settings we usually don't encounter any issues. You can also experiments with other settings but if you encounter issues with negative KL-divergence try to go back to these and see if they persist.
|
||||
|
||||
## How can debug your own use-case?
|
||||
|
||||
Debugging the RL pipeline can be challenging due to its complexity. Here are some tips and suggestions to make the process easier:
|
||||
|
||||
- **Start from a working example**: Begin with a working example from the trl repository and gradually modify it to fit your specific use-case. Changing everything at once can make it difficult to identify the source of potential issues. For example, you can start by replacing the model in the example and once you figure out the best hyperparameters try to switch to your dataset and reward model. If you change everything at once you won't know where a potential problem comes from.
|
||||
- **Start small, scale later**: Training large models can be very slow and take several hours or days until you see any improvement. For debugging this is not a convenient timescale so try to use small model variants during the development phase and scale up once that works. That being said you sometimes have to be careful as small models might not have the capacity to solve a complicated task either.
|
||||
- **Start simple**: Try to start with a minimal example and build complexity from there. Your use-case might require for example a complicated reward function consisting of many different rewards - try to use one signal first and see if you can optimize that and then add more complexity after that.
|
||||
- **Inspect the generations**: It's always a good idea to inspect what the model is generating. Maybe there is a bug in your post-processing or your prompt. Due to bad settings you might cut-off generations too soon. These things are very hard to see on the metrics but very obvious if you look at the generations.
|
||||
- **Inspect the reward model**: If you reward is not improving over time maybe there's an issue with the reward model. You can look at extreme cases to see if it does what it should: e.g. in the sentiment case you can check if simple positive and negative examples really get different rewards. And you can look at the distribution of your dataset. Finally, maybe the reward is dominated by the query which the model can't affect so you might need to normalize this (e.g. reward of query+response minus reward of the query).
|
||||
|
||||
These are just a few tips that we find helpful - if you have more useful tricks feel free to open a PR to add them as well!
|
||||
150
docs/source/index.md
Normal file
150
docs/source/index.md
Normal file
@ -0,0 +1,150 @@
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png">
|
||||
</div>
|
||||
|
||||
# TRL - Transformer Reinforcement Learning
|
||||
|
||||
TRL is a full stack library where we provide a set of tools to train transformer language models with methods like Supervised Fine-Tuning (SFT), Group Relative Policy Optimization (GRPO), Direct Preference Optimization (DPO), Reward Modeling, and more.
|
||||
The library is integrated with 🤗 [transformers](https://github.com/huggingface/transformers).
|
||||
|
||||
## 🎉 What's New
|
||||
|
||||
**OpenEnv Integration:** TRL now supports **[OpenEnv](https://huggingface.co/blog/openenv)**, the open-source framework from Meta for defining, deploying, and interacting with environments in reinforcement learning and agentic workflows.
|
||||
|
||||
Explore how to seamlessly integrate TRL with OpenEnv in our [dedicated documentation](openenv).
|
||||
|
||||
## Taxonomy
|
||||
|
||||
Below is the current list of TRL trainers, organized by method type (⚡️ = vLLM support; 🧪 = experimental).
|
||||
|
||||
<div style="display: flex; justify-content: space-between; width: 100%; gap: 2rem;">
|
||||
<div style="flex: 1; min-width: 0;">
|
||||
|
||||
### Online methods
|
||||
|
||||
- [`GRPOTrainer`] ⚡️
|
||||
- [`RLOOTrainer`] ⚡️
|
||||
- [`OnlineDPOTrainer`] ⚡️
|
||||
- [`NashMDTrainer`] ⚡️
|
||||
- [`XPOTrainer`] ⚡️
|
||||
- [`PPOTrainer`]
|
||||
|
||||
### Reward modeling
|
||||
|
||||
- [`PRMTrainer`]
|
||||
- [`RewardTrainer`]
|
||||
|
||||
</div>
|
||||
<div style="flex: 1; min-width: 0;">
|
||||
|
||||
### Offline methods
|
||||
|
||||
- [`SFTTrainer`]
|
||||
- [`DPOTrainer`]
|
||||
- [`ORPOTrainer`]
|
||||
- [`experimental.bco.BCOTrainer`] 🧪
|
||||
- [`CPOTrainer`]
|
||||
- [`KTOTrainer`]
|
||||
|
||||
### Knowledge distillation
|
||||
|
||||
- [`GKDTrainer`]
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
You can also explore TRL-related models, datasets, and demos in the [TRL Hugging Face organization](https://huggingface.co/trl-lib).
|
||||
|
||||
## Learn
|
||||
|
||||
Learn post-training with TRL and other libraries in 🤗 [smol course](https://github.com/huggingface/smol-course).
|
||||
|
||||
## Contents
|
||||
|
||||
The documentation is organized into the following sections:
|
||||
|
||||
- **Getting Started**: installation and quickstart guide.
|
||||
- **Conceptual Guides**: dataset formats, training FAQ, and understanding logs.
|
||||
- **How-to Guides**: reducing memory usage, speeding up training, distributing training, etc.
|
||||
- **Integrations**: DeepSpeed, Liger Kernel, PEFT, etc.
|
||||
- **Examples**: example overview, community tutorials, etc.
|
||||
- **API**: trainers, utils, etc.
|
||||
|
||||
## Blog posts
|
||||
|
||||
<div class="mt-10">
|
||||
<div class="w-full flex flex-col space-y-4 md:space-y-0 md:grid md:grid-cols-2 md:gap-y-4 md:gap-x-5">
|
||||
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/trl-vlm-alignment">
|
||||
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/openenv/thumbnail.png" alt="thumbnail" class="mt-0">
|
||||
<p class="text-gray-500 text-sm">Published October 23, 2025</p>
|
||||
<p class="text-gray-700">Building the Open Agent Ecosystem Together: Introducing OpenEnv</p>
|
||||
</a>
|
||||
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/trl-vlm-alignment">
|
||||
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/trl_vlm/thumbnail.png" alt="thumbnail" class="mt-0">
|
||||
<p class="text-gray-500 text-sm">Published on August 7, 2025</p>
|
||||
<p class="text-gray-700">Vision Language Model Alignment in TRL ⚡️</p>
|
||||
</a>
|
||||
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/vllm-colocate">
|
||||
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/vllm-colocate/thumbnail.png" alt="thumbnail" class="mt-0">
|
||||
<p class="text-gray-500 text-sm">Published on June 3, 2025</p>
|
||||
<p class="text-gray-700">NO GPU left behind: Unlocking Efficiency with Co-located vLLM in TRL</p>
|
||||
</a>
|
||||
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/liger-grpo">
|
||||
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/liger-grpo/thumbnail.png" alt="thumbnail" class="mt-0">
|
||||
<p class="text-gray-500 text-sm">Published on May 25, 2025</p>
|
||||
<p class="text-gray-700">🐯 Liger GRPO meets TRL</p>
|
||||
</a>
|
||||
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/open-r1">
|
||||
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/open-r1/thumbnails.png" alt="thumbnail" class="mt-0">
|
||||
<p class="text-gray-500 text-sm">Published on January 28, 2025</p>
|
||||
<p class="text-gray-700">Open-R1: a fully open reproduction of DeepSeek-R1</p>
|
||||
</a>
|
||||
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/dpo_vlm">
|
||||
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/dpo_vlm/thumbnail.png" alt="thumbnail" class="mt-0">
|
||||
<p class="text-gray-500 text-sm">Published on July 10, 2024</p>
|
||||
<p class="text-gray-700">Preference Optimization for Vision Language Models with TRL</p>
|
||||
</a>
|
||||
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/putting_rl_back_in_rlhf_with_rloo">
|
||||
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/putting_rl_back_in_rlhf_with_rloo/thumbnail.png" alt="thumbnail" class="mt-0">
|
||||
<p class="text-gray-500 text-sm">Published on June 12, 2024</p>
|
||||
<p class="text-gray-700">Putting RL back in RLHF</p>
|
||||
</a>
|
||||
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/trl-ddpo">
|
||||
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/166_trl_ddpo/thumbnail.png" alt="thumbnail" class="mt-0">
|
||||
<p class="text-gray-500 text-sm">Published on September 29, 2023</p>
|
||||
<p class="text-gray-700">Finetune Stable Diffusion Models with DDPO via TRL</p>
|
||||
</a>
|
||||
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/dpo-trl">
|
||||
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/157_dpo_trl/dpo_thumbnail.png" alt="thumbnail" class="mt-0">
|
||||
<p class="text-gray-500 text-sm">Published on August 8, 2023</p>
|
||||
<p class="text-gray-700">Fine-tune Llama 2 with DPO</p>
|
||||
</a>
|
||||
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/stackllama">
|
||||
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/138_stackllama/thumbnail.png" alt="thumbnail" class="mt-0">
|
||||
<p class="text-gray-500 text-sm">Published on April 5, 2023</p>
|
||||
<p class="text-gray-700">StackLLaMA: A hands-on guide to train LLaMA with RLHF</p>
|
||||
</a>
|
||||
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/trl-peft">
|
||||
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/133_trl_peft/thumbnail.png" alt="thumbnail" class="mt-0">
|
||||
<p class="text-gray-500 text-sm">Published on March 9, 2023</p>
|
||||
<p class="text-gray-700">Fine-tuning 20B LLMs with RLHF on a 24GB consumer GPU</p>
|
||||
</a>
|
||||
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/rlhf">
|
||||
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/120_rlhf/thumbnail.png" alt="thumbnail" class="mt-0">
|
||||
<p class="text-gray-500 text-sm">Published on December 9, 2022</p>
|
||||
<p class="text-gray-700">Illustrating Reinforcement Learning from Human Feedback</p>
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Talks
|
||||
|
||||
<div class="mt-10">
|
||||
<div class="w-full flex flex-col space-y-4 md:space-y-0 md:grid md:grid-cols-2 md:gap-y-4 md:gap-x-5">
|
||||
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/Fine%20tuning%20with%20TRL%20(Oct%2025).pdf">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/Fine%20tuning%20with%20TRL%20(Oct%2025).png" alt="thumbnail" class="mt-0">
|
||||
<p class="text-gray-500 text-sm">Talk given on October 30, 2025</p>
|
||||
<p class="text-gray-700">Fine tuning with TRL</p>
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
@ -1,65 +0,0 @@
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl_banner_dark.png">
|
||||
</div>
|
||||
|
||||
# TRL - Transformer Reinforcement Learning
|
||||
|
||||
TRL is a full stack library where we provide a set of tools to train transformer language models with Reinforcement Learning, from the Supervised Fine-tuning step (SFT), Reward Modeling step (RM) to the Proximal Policy Optimization (PPO) step.
|
||||
The library is integrated with 🤗 [transformers](https://github.com/huggingface/transformers).
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/TRL-readme.png">
|
||||
</div>
|
||||
|
||||
Check the appropriate sections of the documentation depending on your needs:
|
||||
|
||||
## API documentation
|
||||
|
||||
- [Model Classes](models): *A brief overview of what each public model class does.*
|
||||
- [`SFTTrainer`](sft_trainer): *Supervise Fine-tune your model easily with `SFTTrainer`*
|
||||
- [`RewardTrainer`](reward_trainer): *Train easily your reward model using `RewardTrainer`.*
|
||||
- [`PPOTrainer`](ppo_trainer): *Further fine-tune the supervised fine-tuned model using PPO algorithm*
|
||||
- [Best-of-N Sampling](best-of-n): *Use best of n sampling as an alternative way to sample predictions from your active model*
|
||||
- [`DPOTrainer`](dpo_trainer): *Direct Preference Optimization training using `DPOTrainer`.*
|
||||
- [`TextEnvironment`](text_environments): *Text environment to train your model using tools with RL.*
|
||||
|
||||
## Examples
|
||||
|
||||
- [Sentiment Tuning](sentiment_tuning): *Fine tune your model to generate positive movie contents*
|
||||
- [Training with PEFT](lora_tuning_peft): *Memory efficient RLHF training using adapters with PEFT*
|
||||
- [Detoxifying LLMs](detoxifying_a_lm): *Detoxify your language model through RLHF*
|
||||
- [StackLlama](using_llama_models): *End-to-end RLHF training of a Llama model on Stack exchange dataset*
|
||||
- [Learning with Tools](learning_tools): *Walkthrough of using `TextEnvironments`*
|
||||
- [Multi-Adapter Training](multi_adapter_rl): *Use a single base model and multiple adapters for memory efficient end-to-end training*
|
||||
|
||||
|
||||
## Blog posts
|
||||
|
||||
<div class="mt-10">
|
||||
<div class="w-full flex flex-col space-y-4 md:space-y-0 md:grid md:grid-cols-2 md:gap-y-4 md:gap-x-5">
|
||||
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/dpo_vlm">
|
||||
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/dpo_vlm/thumbnail.png" alt="thumbnail">
|
||||
<p class="text-gray-700">Preference Optimization for Vision Language Models with TRL</p>
|
||||
</a>
|
||||
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/rlhf">
|
||||
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/120_rlhf/thumbnail.png" alt="thumbnail">
|
||||
<p class="text-gray-700">Illustrating Reinforcement Learning from Human Feedback</p>
|
||||
</a>
|
||||
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/trl-peft">
|
||||
<img src="https://github.com/huggingface/blog/blob/main/assets/133_trl_peft/thumbnail.png?raw=true" alt="thumbnail">
|
||||
<p class="text-gray-700">Fine-tuning 20B LLMs with RLHF on a 24GB consumer GPU</p>
|
||||
</a>
|
||||
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/stackllama">
|
||||
<img src="https://github.com/huggingface/blog/blob/main/assets/138_stackllama/thumbnail.png?raw=true" alt="thumbnail">
|
||||
<p class="text-gray-700">StackLLaMA: A hands-on guide to train LLaMA with RLHF</p>
|
||||
</a>
|
||||
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/dpo-trl">
|
||||
<img src="https://github.com/huggingface/blog/blob/main/assets/157_dpo_trl/dpo_thumbnail.png?raw=true" alt="thumbnail">
|
||||
<p class="text-gray-700">Fine-tune Llama 2 with DPO</p>
|
||||
</a>
|
||||
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/trl-ddpo">
|
||||
<img src="https://github.com/huggingface/blog/blob/main/assets/166_trl_ddpo/thumbnail.png?raw=true" alt="thumbnail">
|
||||
<p class="text-gray-700">Finetune Stable Diffusion Models with DDPO via TRL</p>
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
42
docs/source/installation.md
Normal file
42
docs/source/installation.md
Normal file
@ -0,0 +1,42 @@
|
||||
# Installation
|
||||
|
||||
You can install TRL either from PyPI or from source:
|
||||
|
||||
## PyPI
|
||||
|
||||
Install the library with pip or [uv](https://docs.astral.sh/uv/):
|
||||
|
||||
<hfoptions id="install">
|
||||
<hfoption id="uv">
|
||||
|
||||
uv is a fast Rust-based Python package and project manager. Refer to [Installation](https://docs.astral.sh/uv/getting-started/installation/) for installation instructions.
|
||||
|
||||
```bash
|
||||
uv pip install trl
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="pip">
|
||||
|
||||
```bash
|
||||
pip install trl
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Source
|
||||
|
||||
You can also install the latest version from source. First clone the repo and then run the installation with `pip`:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/trl.git
|
||||
cd trl/
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
If you want the development install you can replace the pip install with the following:
|
||||
|
||||
```bash
|
||||
pip install -e ".[dev]"
|
||||
```
|
||||
@ -1,24 +0,0 @@
|
||||
# Installation
|
||||
You can install TRL either from pypi or from source:
|
||||
|
||||
## pypi
|
||||
Install the library with pip:
|
||||
|
||||
```bash
|
||||
pip install trl
|
||||
```
|
||||
|
||||
### Source
|
||||
You can also install the latest version from source. First clone the repo and then run the installation with `pip`:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/trl.git
|
||||
cd trl/
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
If you want the development install you can replace the pip install with the following:
|
||||
|
||||
```bash
|
||||
pip install -e ".[dev]"
|
||||
```
|
||||
@ -1,54 +0,0 @@
|
||||
# Iterative Trainer
|
||||
|
||||
Iterative fine-tuning is a training method that enables to perform custom actions (generation and filtering for example) between optimization steps. In TRL we provide an easy-to-use API to fine-tune your models in an iterative way in just a few lines of code.
|
||||
|
||||
## Usage
|
||||
|
||||
To get started quickly, instantiate an instance a model, and a tokenizer.
|
||||
|
||||
```python
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
trainer = IterativeSFTTrainer(
|
||||
model,
|
||||
tokenizer
|
||||
)
|
||||
|
||||
```
|
||||
|
||||
You have the choice to either provide a list of strings or a list of tensors to the step function.
|
||||
|
||||
#### Using a list of tensors as input:
|
||||
|
||||
```python
|
||||
|
||||
inputs = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask
|
||||
}
|
||||
|
||||
trainer.step(**inputs)
|
||||
|
||||
```
|
||||
|
||||
#### Using a list of strings as input:
|
||||
|
||||
```python
|
||||
|
||||
inputs = {
|
||||
"texts": texts
|
||||
}
|
||||
|
||||
trainer.step(**inputs)
|
||||
|
||||
```
|
||||
|
||||
For causal language models, labels will automatically be created from input_ids or from texts. When using sequence to sequence models you will have to provide your own labels or text_labels.
|
||||
|
||||
## IterativeTrainer
|
||||
|
||||
[[autodoc]] IterativeSFTTrainer
|
||||
274
docs/source/jobs_training.md
Normal file
274
docs/source/jobs_training.md
Normal file
@ -0,0 +1,274 @@
|
||||
# Training with Jobs
|
||||
|
||||
[](https://huggingface.co/models?other=hf_jobs,trl)
|
||||
|
||||
[Hugging Face Jobs](https://huggingface.co/docs/huggingface_hub/guides/jobs) lets you run training scripts on fully managed infrastructure—no need to manage GPUs or local environment setup.
|
||||
|
||||
In this guide, you'll learn how to:
|
||||
|
||||
* Use [TRL Jobs](https://github.com/huggingface/trl-jobs) to easily run pre-optimized TRL training
|
||||
* Run any TRL training script with uv scripts
|
||||
|
||||
For general details about Hugging Face Jobs (hardware selection, job monitoring, etc.), see the [Jobs documentation](https://huggingface.co/docs/huggingface_hub/guides/jobs).
|
||||
|
||||
## Requirements
|
||||
|
||||
* A [Pro](https://hf.co/pro), [Team](https://hf.co/enterprise), or [Enterprise](https://hf.co/enterprise) plan
|
||||
* Logged in to the Hugging Face Hub (`hf auth login`)
|
||||
|
||||
## Using TRL Jobs
|
||||
|
||||
[TRL Jobs](https://github.com/huggingface/trl-jobs) is a high-level wrapper around Hugging Face Jobs and TRL that streamlines training. It provides optimized default configurations so you can start quickly without manually tuning parameters.
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
pip install trl-jobs
|
||||
trl-jobs sft --model_name Qwen/Qwen3-0.6B --dataset_name trl-lib/Capybara
|
||||
```
|
||||
|
||||
TRL Jobs supports everything covered in this guide, with additional optimizations to simplify workflows.
|
||||
|
||||
## Using uv Scripts
|
||||
|
||||
For more control, you can run Hugging Face Jobs directly with your own scripts, using [uv scripts](https://docs.astral.sh/uv/guides/scripts/).
|
||||
|
||||
Create a Python script (e.g., `train.py`) containing your training code:
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from trl import SFTTrainer
|
||||
|
||||
dataset = load_dataset("trl-lib/Capybara", split="train")
|
||||
trainer = SFTTrainer(
|
||||
model="Qwen/Qwen2.5-0.5B",
|
||||
train_dataset=dataset,
|
||||
)
|
||||
trainer.train()
|
||||
trainer.push_to_hub("Qwen2.5-0.5B-SFT")
|
||||
```
|
||||
|
||||
Launch the job using either the [`hf jobs` CLI](https://huggingface.co/docs/huggingface_hub/guides/cli#hf-jobs) or the Python API:
|
||||
|
||||
<hfoptions id="script_type">
|
||||
<hfoption id="bash">
|
||||
|
||||
```bash
|
||||
hf jobs uv run \
|
||||
--flavor a100-large \
|
||||
--with trl \
|
||||
--secrets HF_TOKEN \
|
||||
train.py
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="python">
|
||||
|
||||
```python
|
||||
from huggingface_hub import run_uv_job
|
||||
|
||||
run_uv_job(
|
||||
"train.py",
|
||||
dependencies=["trl"],
|
||||
flavor="a100-large",
|
||||
secrets={"HF_TOKEN": "hf_..."},
|
||||
)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
To run successfully, the script needs:
|
||||
|
||||
* **TRL installed**: Use the `--with trl` flag or the `dependencies` argument. uv installs these dependencies automatically before running the script.
|
||||
* **An authentication token**: Required to push the trained model (or perform other authenticated operations). Provide it with the `--secrets HF_TOKEN` flag or the `secrets` argument.
|
||||
|
||||
> [!WARNING]
|
||||
> When training with Jobs, be sure to:
|
||||
>
|
||||
> * **Set a sufficient timeout**. Jobs time out after 30 minutes by default. If your job exceeds the timeout, it will fail and all progress will be lost. See [Setting a custom timeout](https://huggingface.co/docs/huggingface_hub/guides/jobs#setting-a-custom-timeout).
|
||||
> * **Push the model to the Hub**. The Jobs environment is ephemeral—files are deleted when the job ends. If you don’t push the model, it will be lost.
|
||||
|
||||
You can also run a script directly from a URL:
|
||||
|
||||
<hfoptions id="script_type">
|
||||
<hfoption id="bash">
|
||||
|
||||
```bash
|
||||
hf jobs uv run \
|
||||
--flavor a100-large \
|
||||
--with trl \
|
||||
--secrets HF_TOKEN \
|
||||
"https://gist.githubusercontent.com/qgallouedec/eb6a7d20bd7d56f9c440c3c8c56d2307/raw/69fd78a179e19af115e4a54a1cdedd2a6c237f2f/train.py"
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="python">
|
||||
|
||||
```python
|
||||
from huggingface_hub import run_uv_job
|
||||
|
||||
run_uv_job(
|
||||
"https://gist.githubusercontent.com/qgallouedec/eb6a7d20bd7d56f9c440c3c8c56d2307/raw/69fd78a179e19af115e4a54a1cdedd2a6c237f2f/train.py",
|
||||
flavor="a100-large",
|
||||
dependencies=["trl"],
|
||||
secrets={"HF_TOKEN": "hf_..."},
|
||||
)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
To make a script self-contained, declare dependencies at the top:
|
||||
|
||||
```python
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "trl",
|
||||
# "peft",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
from datasets import load_dataset
|
||||
from peft import LoraConfig
|
||||
from trl import SFTTrainer
|
||||
|
||||
dataset = load_dataset("trl-lib/Capybara", split="train")
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model="Qwen/Qwen2.5-0.5B",
|
||||
train_dataset=dataset,
|
||||
peft_config=LoraConfig(),
|
||||
)
|
||||
trainer.train()
|
||||
trainer.push_to_hub("Qwen2.5-0.5B-SFT")
|
||||
```
|
||||
|
||||
You can then run the script without specifying dependencies:
|
||||
|
||||
<hfoptions id="script_type">
|
||||
<hfoption id="bash">
|
||||
|
||||
```bash
|
||||
hf jobs uv run \
|
||||
--flavor a100-large \
|
||||
--secrets HF_TOKEN \
|
||||
train.py
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="python">
|
||||
|
||||
```python
|
||||
from huggingface_hub import run_uv_job
|
||||
|
||||
run_uv_job(
|
||||
"train.py",
|
||||
flavor="a100-large",
|
||||
secrets={"HF_TOKEN": "hf_..."},
|
||||
)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
TRL example scripts are fully uv-compatible, so you can run a complete training workflow directly on Jobs. You can customize training with standard script arguments plus hardware and secrets:
|
||||
|
||||
<hfoptions id="script_type">
|
||||
<hfoption id="bash">
|
||||
|
||||
```bash
|
||||
hf jobs uv run \
|
||||
--flavor a100-large \
|
||||
--secrets HF_TOKEN \
|
||||
https://raw.githubusercontent.com/huggingface/trl/refs/heads/main/examples/scripts/prm.py \
|
||||
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
|
||||
--dataset_name trl-lib/prm800k \
|
||||
--output_dir Qwen2-0.5B-Reward \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="python">
|
||||
|
||||
```python
|
||||
from huggingface_hub import run_uv_job
|
||||
run_uv_job(
|
||||
"https://raw.githubusercontent.com/huggingface/trl/refs/heads/main/examples/scripts/prm.py",
|
||||
flavor="a100-large",
|
||||
secrets={"HF_TOKEN": "hf_..."},
|
||||
script_args=[
|
||||
"--model_name_or_path", "Qwen/Qwen2-0.5B-Instruct",
|
||||
"--dataset_name", "trl-lib/prm800k",
|
||||
"--output_dir", "Qwen2-0.5B-Reward",
|
||||
"--push_to_hub"
|
||||
]
|
||||
)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
See the full list of examples in [Maintained examples](example_overview#maintained-examples).
|
||||
|
||||
### Docker Images
|
||||
|
||||
An up-to-date Docker image with all TRL dependencies is available at [huggingface/trl](https://hub.docker.com/r/huggingface/trl) and can be used directly with Hugging Face Jobs:
|
||||
|
||||
<hfoptions id="script_type">
|
||||
<hfoption id="bash">
|
||||
|
||||
```bash
|
||||
hf jobs uv run \
|
||||
--flavor a100-large \
|
||||
--secrets HF_TOKEN \
|
||||
--image huggingface/trl \
|
||||
train.py
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="python">
|
||||
|
||||
```python
|
||||
from huggingface_hub import run_uv_job
|
||||
|
||||
run_uv_job(
|
||||
"train.py",
|
||||
flavor="a100-large",
|
||||
secrets={"HF_TOKEN": "hf_..."},
|
||||
image="huggingface/trl",
|
||||
)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Jobs runs on a Docker image from Hugging Face Spaces or Docker Hub, so you can also specify any custom image:
|
||||
|
||||
<hfoptions id="script_type">
|
||||
<hfoption id="bash">
|
||||
|
||||
```bash
|
||||
hf jobs uv run \
|
||||
--flavor a100-large \
|
||||
--secrets HF_TOKEN \
|
||||
--image <docker-image> \
|
||||
--secrets HF_TOKEN \
|
||||
train.py
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="python">
|
||||
|
||||
```python
|
||||
from huggingface_hub import run_uv_job
|
||||
|
||||
run_uv_job(
|
||||
"train.py",
|
||||
flavor="a100-large",
|
||||
secrets={"HF_TOKEN": "hf_..."},
|
||||
image="<docker-image>",
|
||||
)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
@ -1,16 +1,19 @@
|
||||
# Judges
|
||||
|
||||
> [!WARNING]
|
||||
> TRL Judges is an experimental API which is subject to change at any time.
|
||||
|
||||
TRL provides judges to easily compare two completions.
|
||||
|
||||
Make sure to have installed the required dependencies by running:
|
||||
|
||||
```bash
|
||||
pip install trl[llm_judge]
|
||||
pip install trl[judges]
|
||||
```
|
||||
|
||||
## Using the provided judges
|
||||
|
||||
TRL provides several judges out of the box. For example, you can use the `HfPairwiseJudge` to compare two completions using a pre-trained model from the Hugging Face model hub:
|
||||
TRL provides several judges out of the box. For example, you can use the [`HfPairwiseJudge`] to compare two completions using a pre-trained model from the Hugging Face model hub:
|
||||
|
||||
```python
|
||||
from trl import HfPairwiseJudge
|
||||
@ -46,34 +49,38 @@ judge.judge(
|
||||
) # Outputs: [0, 1]
|
||||
```
|
||||
|
||||
## BaseJudge
|
||||
## Provided judges
|
||||
|
||||
[[autodoc]] BaseJudge
|
||||
|
||||
## BaseRankJudge
|
||||
|
||||
[[autodoc]] BaseRankJudge
|
||||
|
||||
## BasePairwiseJudge
|
||||
|
||||
[[autodoc]] BasePairwiseJudge
|
||||
|
||||
## RandomRankJudge
|
||||
|
||||
[[autodoc]] RandomRankJudge
|
||||
|
||||
## RandomPairwiseJudge
|
||||
|
||||
[[autodoc]] RandomPairwiseJudge
|
||||
|
||||
## PairRMJudge
|
||||
### PairRMJudge
|
||||
|
||||
[[autodoc]] PairRMJudge
|
||||
|
||||
## HfPairwiseJudge
|
||||
### HfPairwiseJudge
|
||||
|
||||
[[autodoc]] HfPairwiseJudge
|
||||
|
||||
## OpenAIPairwiseJudge
|
||||
### OpenAIPairwiseJudge
|
||||
|
||||
[[autodoc]] OpenAIPairwiseJudge
|
||||
|
||||
### AllTrueJudge
|
||||
|
||||
[[autodoc]] AllTrueJudge
|
||||
|
||||
## Base classes
|
||||
|
||||
### BaseJudge
|
||||
|
||||
[[autodoc]] BaseJudge
|
||||
|
||||
### BaseBinaryJudge
|
||||
|
||||
[[autodoc]] BaseBinaryJudge
|
||||
|
||||
### BaseRankJudge
|
||||
|
||||
[[autodoc]] BaseRankJudge
|
||||
|
||||
### BasePairwiseJudge
|
||||
|
||||
[[autodoc]] BasePairwiseJudge
|
||||
96
docs/source/kernels_hub.md
Normal file
96
docs/source/kernels_hub.md
Normal file
@ -0,0 +1,96 @@
|
||||
# Kernels Hub Integration and Usage
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/4b5175f3-1d60-455b-8664-43b2495ee1c3" width="450" height="450" alt="kernel-builder logo">
|
||||
|
||||
The [`kernels`](https://huggingface.co/blog/hello-hf-kernels#get-started-and-next-steps) library allows optimized compute kernels to be loaded directly from the Hub.
|
||||
You can find `kernels` in [dedicated orgs](https://huggingface.co/kernels-community) or by searching for the [`kernel` tag](https://huggingface.co/models?other=kernel) within the Hub.
|
||||
|
||||
Kernels are **optimized code pieces** that help in model development, training, and inference. Here, we’ll focus on their **integration with TRL**, but check out the above resources to learn more about them.
|
||||
|
||||
## Installation
|
||||
|
||||
To use kernels with TRL, you'd need to install the library in your Python environment:
|
||||
|
||||
```bash
|
||||
pip install kernels
|
||||
```
|
||||
|
||||
## Using Kernels from the Hub in TRL
|
||||
|
||||
Kernels can directly replace attention implementations, removing the need to manually compile attention backends like Flash Attention and boosting training speed just by pulling the respective attention kernel from the Hub.
|
||||
|
||||
You can specify a kernel when loading a model:
|
||||
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"your-model-name",
|
||||
attn_implementation="kernels-community/flash-attn" # other options: kernels-community/vllm-flash-attn3, kernels-community/paged-attention
|
||||
)
|
||||
```
|
||||
|
||||
Or when running a TRL training script:
|
||||
|
||||
```bash
|
||||
python sft.py ... --attn_implementation kernels-community/flash-attn
|
||||
```
|
||||
|
||||
Or using the TRL CLI:
|
||||
|
||||
```bash
|
||||
trl sft ... --attn_implementation kernels-community/flash-attn
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> Now you can leverage faster attention backends with a pre-optimized kernel for your hardware configuration from the Hub, speeding up both development and training.
|
||||
|
||||
## Comparing Attention Implementations
|
||||
|
||||
We evaluated various attention implementations available in transformers, along with different kernel backends, using **TRL** and **SFT**.
|
||||
The experiments were run on a single **H100 GPU** with **CUDA 12.9**, leveraging **Qwen3-8B** with a **batch size of 8**, **gradient accumulation of 1**, and **bfloat16** precision.
|
||||
Keep in mind that the results shown here are specific to this setup and may vary with different training configurations.
|
||||
|
||||
The following figure illustrates both **latency** (time per training step) and **peak allocated memory** for the different attention implementations and kernel backends.
|
||||
Kernel-based implementations perform on par with custom-installed attention, and increasing the model’s `max_length` further enhances performance. Memory consumption is similar across all implementations, showing no significant differences. We get the same performance but with less friction, as described in [the following section](#flash-attention-vs-hub-kernels).
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/kernels_guide_latency.png" alt="Latency and Memory Usage" width="45%"/>
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/kernels_guide_peak_allocated_memory.png" alt="Latency and Memory Usage" width="45%"/>
|
||||
</div>
|
||||
|
||||
## Flash Attention vs. Hub Kernels
|
||||
|
||||
Building Flash Attention from source can be time-consuming, often taking anywhere from several minutes to hours, depending on your hardware, CUDA/PyTorch configuration, and whether precompiled wheels are available.
|
||||
|
||||
In contrast, **Hugging Face Kernels** provide a much faster and more reliable workflow. Developers don’t need to worry about complex setups—everything is handled automatically. In our benchmarks, kernels were ready to use in about **2.5 seconds**, with no compilation required. This allows you to start training almost instantly, significantly accelerating development. Simply specify the desired version, and `kernels` takes care of the rest.
|
||||
|
||||
## Combining FlashAttention Kernels with Liger Kernels
|
||||
|
||||
You can combine **FlashAttention kernels** with **Liger kernels** for additional TRL performance improvements.
|
||||
|
||||
First, install the Liger kernel dependency:
|
||||
|
||||
```bash
|
||||
pip install liger-kernel
|
||||
```
|
||||
|
||||
Then, combine both in your code:
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM
|
||||
from trl import SFTConfig
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"your-model-name",
|
||||
attn_implementation="kernels-community/flash-attn" # choose the desired FlashAttention variant
|
||||
)
|
||||
|
||||
training_args = SFTConfig(
|
||||
use_liger_kernel=True,
|
||||
# ... other TRL training args
|
||||
)
|
||||
```
|
||||
|
||||
Learn more about the [Liger Kernel Integration](./liger_kernel_integration).
|
||||
139
docs/source/kto_trainer.md
Normal file
139
docs/source/kto_trainer.md
Normal file
@ -0,0 +1,139 @@
|
||||
# KTO Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=kto,trl)
|
||||
|
||||
## Overview
|
||||
|
||||
Kahneman-Tversky Optimization (KTO) was introduced in [KTO: Model Alignment as Prospect Theoretic Optimization](https://huggingface.co/papers/2402.01306) by [Kawin Ethayarajh](https://huggingface.co/kawine), [Winnie Xu](https://huggingface.co/xwinxu), [Niklas Muennighoff](https://huggingface.co/Muennighoff), Dan Jurafsky, [Douwe Kiela](https://huggingface.co/douwekiela).
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
> Kahneman & Tversky's prospect theory tells us that humans perceive random variables in a biased but well-defined manner; for example, humans are famously loss-averse. We show that objectives for aligning LLMs with human feedback implicitly incorporate many of these biases -- the success of these objectives (e.g., DPO) over cross-entropy minimization can partly be ascribed to them being human-aware loss functions (HALOs). However, the utility functions these methods attribute to humans still differ from those in the prospect theory literature. Using a Kahneman-Tversky model of human utility, we propose a HALO that directly maximizes the utility of generations instead of maximizing the log-likelihood of preferences, as current methods do. We call this approach Kahneman-Tversky Optimization (KTO), and it matches or exceeds the performance of preference-based methods at scales from 1B to 30B. Crucially, KTO does not need preferences -- only a binary signal of whether an output is desirable or undesirable for a given input. This makes it far easier to use in the real world, where preference data is scarce and expensive.
|
||||
|
||||
The official code can be found in [ContextualAI/HALOs](https://github.com/ContextualAI/HALOs).
|
||||
|
||||
This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif), [Younes Belkada](https://huggingface.co/ybelkada), [Lewis Tunstall](https://huggingface.co/lewtun) and Pablo Vicente.
|
||||
|
||||
## Quick start
|
||||
|
||||
This example demonstrates how to train a model using the KTO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model. We use the preference data from the [KTO Mix 14k](https://huggingface.co/datasets/trl-lib/kto-mix-14k). You can view the data in the dataset here:
|
||||
|
||||
<iframe
|
||||
src="https://huggingface.co/datasets/trl-lib/kto-mix-14k/embed/viewer/default/train?row=0"
|
||||
frameborder="0"
|
||||
width="100%"
|
||||
height="560px"
|
||||
></iframe>
|
||||
|
||||
Below is the script to train the model:
|
||||
|
||||
```python
|
||||
# train_kto.py
|
||||
from datasets import load_dataset
|
||||
from trl import KTOConfig, KTOTrainer
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
||||
train_dataset = load_dataset("trl-lib/kto-mix-14k", split="train")
|
||||
|
||||
training_args = KTOConfig(output_dir="Qwen2-0.5B-KTO")
|
||||
trainer = KTOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
Execute the script using the following command:
|
||||
|
||||
```bash
|
||||
accelerate launch train_kto.py
|
||||
```
|
||||
|
||||
Distributed across 8 x H100 GPUs, the training takes approximately 30 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time.
|
||||
|
||||

|
||||
|
||||
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-KTO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
|
||||
|
||||
<pre><code>$ transformers chat trl-lib/Qwen2-0.5B-KTO
|
||||
<strong><span style="color: red;"><quentin_gallouedec>:</span></strong>
|
||||
What is the best programming language?
|
||||
|
||||
<strong><span style="color: blue;"><trl-lib/Qwen2-0.5B-KTO>:</span></strong>
|
||||
The best programming language can vary depending on individual preferences, industry-specific requirements, technical skills, and familiarity with the specific use case or task. Here are some widely-used programming languages that have been noted as popular and widely used:
|
||||
|
||||
Here are some other factors to consider when choosing a programming language for a project:
|
||||
|
||||
<strong><span style="color: green;">1</span> JavaScript</strong>: JavaScript is at the heart of the web and can be used for building web applications, APIs, and interactive front-end applications like frameworks like React and Angular. It's similar to C, C++, and F# in syntax structure and is accessible and easy to learn, making it a popular choice for beginners and professionals alike.
|
||||
<strong><span style="color: green;">2</span> Java</strong>: Known for its object-oriented programming (OOP) and support for Java 8 and .NET, Java is used for developing enterprise-level software applications, high-performance games, as well as mobile apps, game development, and desktop applications.
|
||||
<strong><span style="color: green;">3</span> C++</strong>: Known for its flexibility and scalability, C++ offers comprehensive object-oriented programming and is a popular choice for high-performance computing and other technical fields. It's a powerful platform for building real-world applications and games at scale.
|
||||
<strong><span style="color: green;">4</span> Python</strong>: Developed by Guido van Rossum in 1991, Python is a high-level, interpreted, and dynamically typed language known for its simplicity, readability, and versatility.
|
||||
</code></pre>
|
||||
|
||||
## Expected dataset format
|
||||
|
||||
KTO requires an [unpaired preference dataset](dataset_formats#unpaired-preference). Alternatively, you can provide a *paired* preference dataset (also known simply as a *preference dataset*). In this case, the trainer will automatically convert it to an unpaired format by separating the chosen and rejected responses, assigning `label = True` to the chosen completions and `label = False` to the rejected ones.
|
||||
|
||||
The [`KTOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
|
||||
|
||||
In theory, the dataset should contain at least one chosen and one rejected completion. However, some users have successfully run KTO using *only* chosen or only rejected data. If using only rejected data, it is advisable to adopt a conservative learning rate.
|
||||
|
||||
## Example script
|
||||
|
||||
We provide an example script to train a model using the KTO method. The script is available in [`trl/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/kto.py)
|
||||
|
||||
To test the KTO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/kto-mix-14k), run the following command:
|
||||
|
||||
```bash
|
||||
accelerate launch trl/scripts/kto.py \
|
||||
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
|
||||
--dataset_name trl-lib/kto-mix-14k \
|
||||
--num_train_epochs 1 \
|
||||
--output_dir Qwen2-0.5B-KTO
|
||||
```
|
||||
|
||||
## Usage tips
|
||||
|
||||
### For Mixture of Experts Models: Enabling the auxiliary loss
|
||||
|
||||
MOEs are the most efficient if the load is about equally distributed between experts.
|
||||
To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
|
||||
|
||||
This option is enabled by setting `output_router_logits=True` in the model config (e.g. [`~transformers.MixtralConfig`]).
|
||||
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config.
|
||||
|
||||
### Batch size recommendations
|
||||
|
||||
Use a per-step batch size that is at least 4, and an effective batch size between 16 and 128. Even if your effective batch size is large, if your per-step batch size is poor, then the KL estimate in KTO will be poor.
|
||||
|
||||
### Learning rate recommendations
|
||||
|
||||
Each choice of `beta` has a maximum learning rate it can tolerate before learning performance degrades. For the default setting of `beta = 0.1`, the learning rate should typically not exceed `1e-6` for most models. As `beta` decreases, the learning rate should also be reduced accordingly. In general, we strongly recommend keeping the learning rate between `5e-7` and `5e-6`. Even with small datasets, we advise against using a learning rate outside this range. Instead, opt for more epochs to achieve better results.
|
||||
|
||||
### Imbalanced data
|
||||
|
||||
The `desirable_weight` and `undesirable_weight` of the [`KTOConfig`] refer to the weights placed on the losses for desirable/positive and undesirable/negative examples.
|
||||
By default, they are both 1. However, if you have more of one or the other, then you should upweight the less common type such that the ratio of (`desirable_weight` \\(\times\\) number of positives) to (`undesirable_weight` \\(\times\\) number of negatives) is in the range 1:1 to 4:3.
|
||||
|
||||
## Logged metrics
|
||||
|
||||
While training and evaluating, we record the following reward metrics:
|
||||
|
||||
- `rewards/chosen_sum`: the sum of log probabilities of the policy model for the chosen responses scaled by beta
|
||||
- `rewards/rejected_sum`: the sum of log probabilities of the policy model for the rejected responses scaled by beta
|
||||
- `logps/chosen_sum`: the sum of log probabilities of the chosen completions
|
||||
- `logps/rejected_sum`: the sum of log probabilities of the rejected completions
|
||||
- `logits/chosen_sum`: the sum of logits of the chosen completions
|
||||
- `logits/rejected_sum`: the sum of logits of the rejected completions
|
||||
- `count/chosen`: the count of chosen samples in a batch
|
||||
- `count/rejected`: the count of rejected samples in a batch
|
||||
|
||||
## KTOTrainer
|
||||
|
||||
[[autodoc]] KTOTrainer
|
||||
- train
|
||||
- save_model
|
||||
- push_to_hub
|
||||
|
||||
## KTOConfig
|
||||
|
||||
[[autodoc]] KTOConfig
|
||||
@ -1,107 +0,0 @@
|
||||
# KTO Trainer
|
||||
|
||||
TRL supports the Kahneman-Tversky Optimization (KTO) Trainer for aligning language models with binary feedback data (e.g., upvote/downvote), as described in the [paper](https://huggingface.co/papers/2402.01306) by Kawin Ethayarajh, Winnie Xu, Niklas Muennighoff, Dan Jurafsky, and Douwe Kiela.
|
||||
For a full example have a look at [`examples/scripts/kto.py`].
|
||||
|
||||
Depending on how good your base model is, you may or may not need to do SFT before KTO.
|
||||
This is different from standard RLHF and DPO, which always require SFT.
|
||||
|
||||
## Expected dataset format
|
||||
|
||||
The KTO trainer expects a very specific format for the dataset as it does not require pairwise preferences. Since the model will be trained to directly optimize examples that consist of a prompt, model completion, and a label to indicate whether the completion is "good" or "bad", we expect a dataset with the following columns:
|
||||
|
||||
- `prompt`
|
||||
- `completion`
|
||||
- `label`
|
||||
|
||||
for example:
|
||||
|
||||
```
|
||||
kto_dataset_dict = {
|
||||
"prompt": [
|
||||
"Hey, hello",
|
||||
"How are you",
|
||||
"What is your name?",
|
||||
"What is your name?",
|
||||
"Which is the best programming language?",
|
||||
"Which is the best programming language?",
|
||||
"Which is the best programming language?",
|
||||
],
|
||||
"completion": [
|
||||
"hi nice to meet you",
|
||||
"leave me alone",
|
||||
"I don't have a name",
|
||||
"My name is Mary",
|
||||
"Python",
|
||||
"C++",
|
||||
"Java",
|
||||
],
|
||||
"label": [
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
],
|
||||
}
|
||||
```
|
||||
|
||||
where the `prompt` contains the context inputs, `completion` contains the corresponding responses and `label` contains the corresponding flag that indicates if the generated completion is desired (`True`) or undesired (`False`).
|
||||
A prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays. It is required that the dataset contains at least one desirable and one undesirable completion.
|
||||
|
||||
|
||||
## Expected model format
|
||||
The KTO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.
|
||||
|
||||
## Using the `KTOTrainer`
|
||||
|
||||
For a detailed example have a look at the `examples/scripts/kto.py` script. At a high level we need to initialize the `KTOTrainer` with a `model` we wish to train and a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected response.
|
||||
|
||||
The `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (ie decoder only or encoder-decoder).
|
||||
|
||||
The `desirable_weight` and `undesirable_weight` refer to the weights placed on the losses for desirable/positive and undesirable/negative examples.
|
||||
By default, they are both 1. However, if you have more of one or the other, then you should upweight the less common type such that the ratio of (`desirable_weight` \\(\times\\) number of positives) to (`undesirable_weight` \\(\times\\) number of negatives) is in the range 1:1 to 4:3.
|
||||
|
||||
<Tip>
|
||||
It is strongly recommended you use a learning rate between `5e-7` and `5e-6` with an effective batch size between `8` and `32`, for both LoRA and full finetuning. Even if you are working with a small dataset, we do not recommend using a learning rate outside this range; instead, using smaller batch sizes and/or more training epochs will give you better results.
|
||||
</Tip>
|
||||
|
||||
```py
|
||||
training_args = KTOConfig(
|
||||
beta=0.1,
|
||||
desirable_weight=1.0,
|
||||
undesirable_weight=1.0,
|
||||
learning_rate=5e-7,
|
||||
)
|
||||
|
||||
kto_trainer = KTOTrainer(
|
||||
model,
|
||||
ref_model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
```
|
||||
After this one can then call:
|
||||
|
||||
```py
|
||||
kto_trainer.train()
|
||||
```
|
||||
|
||||
### For Mixture of Experts Models: Enabling the auxiliary loss
|
||||
|
||||
MOEs are the most efficient if the load is about equally distributed between experts.
|
||||
To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
|
||||
|
||||
This option is enabled by setting `output_router_logits=True` in the model config (e.g. MixtralConfig).
|
||||
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: 0.001).
|
||||
|
||||
## KTOTrainer
|
||||
|
||||
[[autodoc]] KTOTrainer
|
||||
|
||||
## KTOConfig
|
||||
|
||||
[[autodoc]] KTOConfig
|
||||
@ -1,232 +0,0 @@
|
||||
# Learning Tools (Experimental 🧪)
|
||||
|
||||
Using Large Language Models (LLMs) with tools has been a popular topic recently with awesome works such as [ToolFormer](https://huggingface.co/papers/2302.04761) and [ToolBench](https://huggingface.co/papers/2305.16504). In TRL, we provide a simple example of how to teach LLM to use tools with reinforcement learning.
|
||||
|
||||
|
||||
Here's an overview of the scripts in the [trl repository](https://github.com/lvwerra/trl/tree/main/examples/research_projects/tools):
|
||||
|
||||
| File | Description |
|
||||
|---|---|
|
||||
| [`calculator.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/calculator.py) | Script to train LLM to use a calculator with reinforcement learning. |
|
||||
| [`triviaqa.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/triviaqa.py) | Script to train LLM to use a wiki tool to answer questions. |
|
||||
| [`python_interpreter.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/python_interpreter.py) | Script to train LLM to use python interpreter to solve math puzzles. |
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Note that the scripts above rely heavily on the `TextEnvironment` API which is still under active development. The API may change in the future. Please see [`TextEnvironment`](text_environment) for the related docs.
|
||||
</Tip>
|
||||
|
||||
|
||||
## Learning to Use a Calculator
|
||||
|
||||
|
||||
The rough idea is as follows:
|
||||
|
||||
1. Load a tool such as [ybelkada/simple-calculator](https://huggingface.co/spaces/ybelkada/simple-calculator) that parse a text calculation like `"14 + 34"` and return the calulated number:
|
||||
```python
|
||||
from transformers import AutoTokenizer, load_tool
|
||||
tool = load_tool("ybelkada/simple-calculator")
|
||||
tool_fn = lambda text: str(round(float(tool(text)), 2)) # rounding to 2 decimal places
|
||||
```
|
||||
1. Define a reward function that returns a positive reward if the tool returns the correct answer. In the script we create a dummy reward function like `reward_fn = lambda x: 1`, but we override the rewards directly later.
|
||||
1. Create a prompt on how to use the tools
|
||||
```python
|
||||
# system prompt
|
||||
prompt = """\
|
||||
What is 13.1-3?
|
||||
|
||||
<request><SimpleCalculatorTool>13.1-3<call>10.1<response>
|
||||
|
||||
Result=10.1<submit>
|
||||
|
||||
What is 4*3?
|
||||
|
||||
<request><SimpleCalculatorTool>4*3<call>12<response>
|
||||
|
||||
Result=12<submit>
|
||||
|
||||
What is 12.1+1?
|
||||
|
||||
<request><SimpleCalculatorTool>12.1+1<call>13.1<response>
|
||||
|
||||
Result=13.1<submit>
|
||||
|
||||
What is 12.1-20?
|
||||
|
||||
<request><SimpleCalculatorTool>12.1-20<call>-7.9<response>
|
||||
|
||||
Result=-7.9<submit>"""
|
||||
```
|
||||
3. Create a `trl.TextEnvironment` with the model
|
||||
```python
|
||||
env = TextEnvironment(
|
||||
model,
|
||||
tokenizer,
|
||||
{"SimpleCalculatorTool": tool_fn},
|
||||
reward_fn,
|
||||
prompt,
|
||||
generation_kwargs=generation_kwargs,
|
||||
)
|
||||
```
|
||||
4. Then generate some data such as `tasks = ["\n\nWhat is 13.1-3?", "\n\nWhat is 4*3?"]` and run the environment with `queries, responses, masks, rewards, histories = env.run(tasks)`. The environment will look for the `<call>` token in the prompt and append the tool output to the response; it will also return the mask associated with the response. You can further use the `histories` to visualize the interaction between the model and the tool; `histories[0].show_text()` will show the text with color-coded tool output and `histories[0].show_tokens(tokenizer)` will show visualize the tokens.
|
||||

|
||||
1. Finally, we can train the model with `train_stats = ppo_trainer.step(queries, responses, rewards, masks)`. The trainer will use the mask to ignore the tool output when computing the loss, make sure to pass that argument to `step`.
|
||||
|
||||
## Experiment results
|
||||
|
||||
We trained a model with the above script for 10 random seeds. You can reproduce the run with the following command. Feel free to remove the `--slurm-*` arguments if you don't have access to a slurm cluster.
|
||||
|
||||
```
|
||||
WANDB_TAGS="calculator_final" python benchmark/benchmark.py \
|
||||
--command "python examples/research_projects/tools/calculator.py" \
|
||||
--num-seeds 10 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
--slurm-gpus-per-task 1 \
|
||||
--slurm-ntasks 1 \
|
||||
--slurm-total-cpus 8 \
|
||||
--slurm-template-path benchmark/trl.slurm_template
|
||||
```
|
||||
|
||||
We can then use [`openrlbenchmark`](https://github.com/openrlbenchmark/openrlbenchmark) which generates the following plot.
|
||||
```
|
||||
python -m openrlbenchmark.rlops_multi_metrics \
|
||||
--filters '?we=openrlbenchmark&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.tracker_project_name&cen=trl_ppo_trainer_config.value.log_with&metrics=env/reward_mean&metrics=objective/kl' \
|
||||
'wandb?tag=calculator_final&cl=calculator_mask' \
|
||||
--env-ids trl \
|
||||
--check-empty-runs \
|
||||
--pc.ncols 2 \
|
||||
--pc.ncols-legend 1 \
|
||||
--output-filename static/0compare \
|
||||
--scan-history
|
||||
```
|
||||
|
||||

|
||||
|
||||
As we can see, while 1-2 experiments crashed for some reason, most of the runs obtained near perfect proficiency in the calculator task.
|
||||
|
||||
|
||||
## (Early Experiments 🧪): learning to use a wiki tool for question answering
|
||||
|
||||
In the [ToolFormer](https://huggingface.co/papers/2302.04761) paper, it shows an interesting use case that utilizes a Wikipedia Search tool to help answer questions. In this section, we attempt to perform similar experiments but uses RL instead to teach the model to use a wiki tool on the [TriviaQA](https://nlp.cs.washington.edu/triviaqa/) dataset.
|
||||
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
**Note that many settings are different so the results are not directly comparable.**
|
||||
</Tip>
|
||||
|
||||
|
||||
|
||||
|
||||
### Building a search index
|
||||
|
||||
Since [ToolFormer](https://huggingface.co/papers/2302.04761) did not open source, we needed to first replicate the search index. It is mentioned in their paper that the authors built the search index using a BM25 retriever that indexes the Wikipedia dump from [KILT](https://github.com/facebookresearch/KILT)
|
||||
|
||||
Fortunately, [`pyserini`](https://github.com/castorini/pyserini) already implements the BM25 retriever and provides a prebuilt index for the KILT Wikipedia dump. We can use the following code to search the index.
|
||||
|
||||
```python
|
||||
from pyserini.search.lucene import LuceneSearcher
|
||||
import json
|
||||
searcher = LuceneSearcher.from_prebuilt_index('wikipedia-kilt-doc')
|
||||
def search(query):
|
||||
hits = searcher.search(query, k=1)
|
||||
hit = hits[0]
|
||||
contents = json.loads(hit.raw)['contents']
|
||||
return contents
|
||||
print(search("tennis racket"))
|
||||
```
|
||||
```
|
||||
Racket (sports equipment)
|
||||
A racket or racquet is a sports implement consisting of a handled frame with an open hoop across which a network of strings or catgut is stretched tightly. It is used for striking a ball or shuttlecock in games such as squash, tennis, racquetball, and badminton. Collectively, these games are known as racket sports. Racket design and manufacturing has changed considerably over the centuries.
|
||||
|
||||
The frame of rackets for all sports was traditionally made of solid wood (later laminated wood) and the strings of animal intestine known as catgut. The traditional racket size was limited by the strength and weight of the wooden frame which had to be strong enough to hold the strings and stiff enough to hit the ball or shuttle. Manufacturers started adding non-wood laminates to wood rackets to improve stiffness. Non-wood rackets were made first of steel, then of aluminum, and then carbon fiber composites. Wood is still used for real tennis, rackets, and xare. Most rackets are now made of composite materials including carbon fiber or fiberglass, metals such as titanium alloys, or ceramics.
|
||||
...
|
||||
```
|
||||
|
||||
We then basically deployed this snippet as a Hugging Face space [here](https://huggingface.co/spaces/vwxyzjn/pyserini-wikipedia-kilt-doc), so that we can use the space as a `transformers.Tool` later.
|
||||
|
||||

|
||||
|
||||
### Experiment settings
|
||||
|
||||
We use the following settings:
|
||||
|
||||
* use the `bigcode/starcoderbase` model as the base model
|
||||
* use the `pyserini-wikipedia-kilt-doc` space as the wiki tool and only uses the first paragrahs of the search result, allowing the `TextEnvironment` to obtain at most `max_tool_reponse=400` response tokens from the tool.
|
||||
* test if the response contain the answer string, if so, give a reward of 1, otherwise, give a reward of 0.
|
||||
* notice this is a simplified evaluation criteria. In [ToolFormer](https://huggingface.co/papers/2302.04761), the authors checks if the first 20 words of the response contain the correct answer.
|
||||
* used the following prompt that demonstrates the usage of the wiki tool.
|
||||
```python
|
||||
prompt = """\
|
||||
Answer the following question:
|
||||
|
||||
Q: In which branch of the arts is Patricia Neary famous?
|
||||
A: Ballets
|
||||
A2: <request><Wiki>Patricia Neary<call>Patricia Neary (born October 27, 1942) is an American ballerina, choreographer and ballet director, who has been particularly active in Switzerland. She has also been a highly successful ambassador for the Balanchine Trust, bringing George Balanchine's ballets to 60 cities around the globe.<response>
|
||||
Result=Ballets<submit>
|
||||
|
||||
Q: Who won Super Bowl XX?
|
||||
A: Chicago Bears
|
||||
A2: <request><Wiki>Super Bowl XX<call>Super Bowl XX was an American football game between the National Football Conference (NFC) champion Chicago Bears and the American Football Conference (AFC) champion New England Patriots to decide the National Football League (NFL) champion for the 1985 season. The Bears defeated the Patriots by the score of 46–10, capturing their first NFL championship (and Chicago's first overall sports victory) since 1963, three years prior to the birth of the Super Bowl. Super Bowl XX was played on January 26, 1986 at the Louisiana Superdome in New Orleans.<response>
|
||||
Result=Chicago Bears<submit>
|
||||
|
||||
Q: """
|
||||
```
|
||||
|
||||
|
||||
### Result and Discussion
|
||||
|
||||
|
||||
Our experiments show that the agent can learn to use the wiki tool to answer questions. The learning curves would go up mostly, but one of the experiment did crash.
|
||||
|
||||

|
||||
|
||||
Wandb report is [here](https://wandb.ai/costa-huang/cleanRL/reports/TriviaQA-Final-Experiments--Vmlldzo1MjY0ODk5) for further inspection.
|
||||
|
||||
|
||||
Note that the correct rate of the trained model is on the low end, which could be due to the following reasons:
|
||||
|
||||
* **incorrect searches:** When given the question `"What is Bruce Willis' real first name?"` if the model searches for `Bruce Willis`, our wiki tool returns "Patrick Poivey (born 18 February 1948) is a French actor. He is especially known for his voice: he is the French dub voice of Bruce Willis since 1988.` But a correct search should be `Walter Bruce Willis (born March 19, 1955) is an American former actor. He achieved fame with a leading role on the comedy-drama series Moonlighting (1985–1989) and appeared in over a hundred films, gaining recognition as an action hero after his portrayal of John McClane in the Die Hard franchise (1988–2013) and other roles.[1][2]"
|
||||
|
||||
|
||||

|
||||
|
||||
* **unnecessarily long response**: The wiki tool by default sometimes output very long sequences. E.g., when the wiki tool searches for "Brown Act"
|
||||
* Our wiki tool returns "The Ralph M. Brown Act, located at California Government Code 54950 "et seq.", is an act of the California State Legislature, authored by Assemblymember Ralph M. Brown and passed in 1953, that guarantees the public's right to attend and participate in meetings of local legislative bodies."
|
||||
* [ToolFormer](https://huggingface.co/papers/2302.04761)'s wiki tool returns "The Ralph M. Brown Act is an act of the California State Legislature that guarantees the public's right to attend and participate in meetings of local legislative bodies." which is more succinct.
|
||||
|
||||

|
||||
|
||||
|
||||
## (Early Experiments 🧪): solving math puzzles with python interpreter
|
||||
|
||||
In this section, we attempt to teach the model to use a python interpreter to solve math puzzles. The rough idea is to give the agent a prompt like the following:
|
||||
|
||||
```python
|
||||
prompt = """\
|
||||
Example of using a Python API to solve math questions.
|
||||
|
||||
Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?
|
||||
|
||||
<request><PythonInterpreter>
|
||||
def solution():
|
||||
money_initial = 23
|
||||
bagels = 5
|
||||
bagel_cost = 3
|
||||
money_spent = bagels * bagel_cost
|
||||
money_left = money_initial - money_spent
|
||||
result = money_left
|
||||
return result
|
||||
print(solution())
|
||||
<call>72<response>
|
||||
|
||||
Result = 72 <submit>
|
||||
|
||||
Q: """
|
||||
```
|
||||
|
||||
|
||||
Training experiment can be found at https://wandb.ai/lvwerra/trl-gsm8k/runs/a5odv01y
|
||||
|
||||

|
||||
78
docs/source/liger_kernel_integration.md
Normal file
78
docs/source/liger_kernel_integration.md
Normal file
@ -0,0 +1,78 @@
|
||||
# Liger Kernel Integration
|
||||
|
||||
[Liger Kernel](https://github.com/linkedin/Liger-Kernel) is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU training throughput by 20% and reduce memory usage by 60%. That way, we can **4x** our context length, as described in the benchmark below. They have implemented Hugging Face compatible `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, with more to come. The kernel works out of the box with [FlashAttention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed).
|
||||
|
||||
With this memory reduction, you can potentially turn off `cpu_offloading` or gradient checkpointing to further boost the performance.
|
||||
|
||||
| Speed Up | Memory Reduction |
|
||||
| --- | --- |
|
||||
|  |  |
|
||||
|
||||
## Supported Trainers
|
||||
|
||||
Liger Kernel is supported in the following TRL trainers:
|
||||
- **SFT** (Supervised Fine-Tuning)
|
||||
- **DPO** (Direct Preference Optimization)
|
||||
- **GRPO** (Group Relative Policy Optimization)
|
||||
- **KTO** (Kahneman-Tversky Optimization)
|
||||
- **GKD** (Generalized Knowledge Distillation)
|
||||
|
||||
## Usage
|
||||
|
||||
1. First, install Liger Kernel:
|
||||
|
||||
```bash
|
||||
pip install liger-kernel
|
||||
```
|
||||
|
||||
2. Once installed, set `use_liger_kernel=True` in your trainer config. No other changes are needed!
|
||||
|
||||
<hfoptions id="liger">
|
||||
<hfoption id="SFT">
|
||||
|
||||
```python
|
||||
from trl import SFTConfig
|
||||
|
||||
training_args = SFTConfig(..., use_liger_kernel=True)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="DPO">
|
||||
|
||||
```python
|
||||
from trl import DPOConfig
|
||||
|
||||
training_args = DPOConfig(..., use_liger_kernel=True)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="GRPO">
|
||||
|
||||
```python
|
||||
from trl import GRPOConfig
|
||||
|
||||
training_args = GRPOConfig(..., use_liger_kernel=True)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="KTO">
|
||||
|
||||
```python
|
||||
from trl import KTOConfig
|
||||
|
||||
training_args = KTOConfig(..., use_liger_kernel=True)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="GKD">
|
||||
|
||||
```python
|
||||
from trl import GKDConfig
|
||||
|
||||
training_args = GKDConfig(..., use_liger_kernel=True)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
To learn more about Liger-Kernel, visit their [official repository](https://github.com/linkedin/Liger-Kernel/).
|
||||
@ -1,75 +0,0 @@
|
||||
# Logging
|
||||
|
||||
As reinforcement learning algorithms are historically challenging to debug, it's important to pay careful attention to logging.
|
||||
By default, the TRL [`PPOTrainer`] saves a lot of relevant information to `wandb` or `tensorboard`.
|
||||
|
||||
Upon initialization, pass one of these two options to the [`PPOConfig`]:
|
||||
```
|
||||
config = PPOConfig(
|
||||
model_name=args.model_name,
|
||||
log_with=`wandb`, # or `tensorboard`
|
||||
)
|
||||
```
|
||||
If you want to log with tensorboard, add the kwarg `project_kwargs={"logging_dir": PATH_TO_LOGS}` to the PPOConfig.
|
||||
|
||||
## PPO Logging
|
||||
|
||||
Here's a brief explanation for the logged metrics provided in the data:
|
||||
|
||||
Key metrics to monitor. We want to maximize the reward, maintain a low KL divergence, and maximize entropy:
|
||||
1. `env/reward_mean`: The average reward obtained from the environment. Alias `ppo/mean_scores`, which is sed to specifically monitor the reward model.
|
||||
1. `env/reward_std`: The standard deviation of the reward obtained from the environment. Alias ``ppo/std_scores`, which is sed to specifically monitor the reward model.
|
||||
1. `env/reward_dist`: The histogram distribution of the reward obtained from the environment.
|
||||
1. `objective/kl`: The mean Kullback-Leibler (KL) divergence between the old and new policies. It measures how much the new policy deviates from the old policy. The KL divergence is used to compute the KL penalty in the objective function.
|
||||
1. `objective/kl_dist`: The histogram distribution of the `objective/kl`.
|
||||
1. `objective/kl_coef`: The coefficient for Kullback-Leibler (KL) divergence in the objective function.
|
||||
1. `ppo/mean_non_score_reward`: The **KL penalty** calculated by `objective/kl * objective/kl_coef` as the total reward for optimization to prevent the new policy from deviating too far from the old policy.
|
||||
1. `objective/entropy`: The entropy of the model's policy, calculated by `-logprobs.sum(-1).mean()`. High entropy means the model's actions are more random, which can be beneficial for exploration.
|
||||
|
||||
Training stats:
|
||||
1. `ppo/learning_rate`: The learning rate for the PPO algorithm.
|
||||
1. `ppo/policy/entropy`: The entropy of the model's policy, calculated by `pd = torch.nn.functional.softmax(logits, dim=-1); entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1)`. It measures the randomness of the policy.
|
||||
1. `ppo/policy/clipfrac`: The fraction of probability ratios (old policy / new policy) that fell outside the clipping range in the PPO objective. This can be used to monitor the optimization process.
|
||||
1. `ppo/policy/approxkl`: The approximate KL divergence between the old and new policies, measured by `0.5 * masked_mean((logprobs - old_logprobs) ** 2, mask)`, corresponding to the `k2` estimator in http://joschu.net/blog/kl-approx.html
|
||||
1. `ppo/policy/policykl`: Similar to `ppo/policy/approxkl`, but measured by `masked_mean(old_logprobs - logprobs, mask)`, corresponding to the `k1` estimator in http://joschu.net/blog/kl-approx.html
|
||||
1. `ppo/policy/ratio`: The histogram distribution of the ratio between the new and old policies, used to compute the PPO objective.
|
||||
1. `ppo/policy/advantages_mean`: The average of the GAE (Generalized Advantage Estimation) advantage estimates. The advantage function measures how much better an action is compared to the average action at a state.
|
||||
1. `ppo/policy/advantages`: The histogram distribution of `ppo/policy/advantages_mean`.
|
||||
1. `ppo/returns/mean`: The mean of the TD(λ) returns, calculated by `returns = advantage + values`, another indicator of model performance. See https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/ for more details.
|
||||
1. `ppo/returns/var`: The variance of the TD(λ) returns, calculated by `returns = advantage + values`, another indicator of model performance.
|
||||
1. `ppo/val/mean`: The mean of the values, used to monitor the value function's performance.
|
||||
1. `ppo/val/var` : The variance of the values, used to monitor the value function's performance.
|
||||
1. `ppo/val/var_explained`: The explained variance for the value function, used to monitor the value function's performance.
|
||||
1. `ppo/val/clipfrac`: The fraction of the value function's predicted values that are clipped.
|
||||
1. `ppo/val/vpred`: The predicted values from the value function.
|
||||
1. `ppo/val/error`: The mean squared error between the `ppo/val/vpred` and returns, used to monitor the value function's performance.
|
||||
1. `ppo/loss/policy`: The policy loss for the Proximal Policy Optimization (PPO) algorithm.
|
||||
1. `ppo/loss/value`: The loss for the value function in the PPO algorithm. This value quantifies how well the function estimates the expected future rewards.
|
||||
1. `ppo/loss/total`: The total loss for the PPO algorithm. It is the sum of the policy loss and the value function loss.
|
||||
|
||||
|
||||
Stats on queries, responses, and logprobs:
|
||||
1. `tokens/queries_len_mean`: The average length of the queries tokens.
|
||||
1. `tokens/queries_len_std`: The standard deviation of the length of the queries tokens.
|
||||
1. `tokens/queries_dist`: The histogram distribution of the length of the queries tokens.
|
||||
1. `tokens/responses_len_mean`: The average length of the responses tokens.
|
||||
1. `tokens/responses_len_std`: The standard deviation of the length of the responses tokens.
|
||||
1. `tokens/responses_dist`: The histogram distribution of the length of the responses tokens. (Costa: inconsistent naming, should be `tokens/responses_len_dist`)
|
||||
1. `objective/logprobs`: The histogram distribution of the log probabilities of the actions taken by the model.
|
||||
1. `objective/ref_logprobs`: The histogram distribution of the log probabilities of the actions taken by the reference model.
|
||||
|
||||
|
||||
|
||||
### Crucial values
|
||||
During training, many values are logged, here are the most important ones:
|
||||
|
||||
1. `env/reward_mean`,`env/reward_std`, `env/reward_dist`: the properties of the reward distribution from the "environment" / reward model
|
||||
1. `ppo/mean_non_score_reward`: The mean negated KL penalty during training (shows the delta between the reference model and the new policy over the batch in the step)
|
||||
|
||||
Here are some parameters that are useful to monitor for stability (when these diverge or collapse to 0, try tuning variables):
|
||||
|
||||
1. `ppo/loss/value`: it will spike / NaN when not going well.
|
||||
1. `ppo/policy/ratio`: `ratio` being 1 is a baseline value, meaning that the probability of sampling a token is the same under the new and old policy. If the ratio is too high like 200, it means the probability of sampling a token is 200 times higher under the new policy than the old policy. This is a sign that the new policy is too different from the old policy, which will likely cause overoptimization and collapse training later on.
|
||||
1. `ppo/policy/clipfrac` and `ppo/policy/approxkl`: if `ratio` is too high, the `ratio` is going to get clipped, resulting in high `clipfrac` and high `approxkl` as well.
|
||||
1. `objective/kl`: it should stay positive so that the policy is not too far away from the reference policy.
|
||||
1. `objective/kl_coef`: The target coefficient with [`AdaptiveKLController`]. Often increases before numerical instabilities.
|
||||
442
docs/source/lora_without_regret.md
Normal file
442
docs/source/lora_without_regret.md
Normal file
@ -0,0 +1,442 @@
|
||||
# LoRA Without Regret
|
||||
|
||||
Recent research from the team at [Thinking Machines Lab](https://thinkingmachines.ai/blog/lora/) (Schulman et al., 2025) shows that **LoRA can match full fine-tuning performance** when configured correctly, while using only ~67% of the compute. These findings are exciting to TRL users because they're straightforward to implement and can improve model performance on smaller budgets.
|
||||
|
||||
This guide provides simple instructions to reproduce the results of the blog post in TRL.
|
||||
|
||||
> [!TIP]
|
||||
> It is recommended to read the blog post before following this guide, or to consult both resources in parallel for best results.
|
||||
|
||||
## Benefits of LoRA over full fine-tuning
|
||||
|
||||
First of all, let's remind ourselves of the benefits of [LoRA over full fine-tuning](https://huggingface.co/docs/trl/en/peft_integration).
|
||||
|
||||
LoRA adds adapter layers on top of the base model, which contains significantly fewer parameters than the base model itself. This design reduces GPU memory requirements and enables more efficient training. As described in the [blog](https://thinkingmachines.ai/blog/lora/), this approach was originally thought to involve a performance trade-off, although careful configuration can overcome this trade-off and match full fine-tuning performance.
|
||||
|
||||
## Examples with TRL
|
||||
|
||||
Let's implement and train LoRA adapters in TRL scripts based on the core findings of the blog post. Afterwards, we'll revisit each finding in light of the TRL results.
|
||||
|
||||
### Supervised Fine-Tuning (SFT)
|
||||
|
||||
The blog post performs SFT on a range of models and datasets from the Hub, which we can reproduce in TRL.
|
||||
|
||||
| Model | Dataset |
|
||||
| --- | --- |
|
||||
| [Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B) | [allenai/tulu-3-sft-mixture](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture) |
|
||||
| [Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B) | [open-thoughts/OpenThoughts-114k](https://huggingface.co/datasets/open-thoughts/OpenThoughts-114k) |
|
||||
| [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B) | [allenai/tulu-3-sft-mixture](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture) |
|
||||
| [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B) | [open-thoughts/OpenThoughts-114k](https://huggingface.co/datasets/open-thoughts/OpenThoughts-114k) |
|
||||
|
||||
<hfoptions id="sft">
|
||||
<hfoption id="python">
|
||||
|
||||
We can integrate these findings with the TRL Python API like so:
|
||||
|
||||
```python
|
||||
|
||||
from datasets import load_dataset
|
||||
from peft import LoraConfig
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
|
||||
dataset = load_dataset("open-thoughts/OpenThoughts-114k", split="train")
|
||||
|
||||
peft_config = LoraConfig(r=256, lora_alpha=16, target_modules="all-linear")
|
||||
|
||||
training_args = SFTConfig(
|
||||
learning_rate=2e-4,
|
||||
per_device_train_batch_size=1,
|
||||
gradient_accumulation_steps=4,
|
||||
num_train_epochs=1,
|
||||
report_to=["trackio"],
|
||||
)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model="Qwen/Qwen2.5-3B-Instruct",
|
||||
train_dataset=dataset,
|
||||
peft_config=peft_config,
|
||||
args=training_args,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="jobs">
|
||||
|
||||
```bash
|
||||
|
||||
hf jobs uv run \
|
||||
--flavor a100-large \
|
||||
--timeout 8h \
|
||||
--secrets HF_TOKEN \
|
||||
"https://raw.githubusercontent.com/huggingface/trl/main/trl/scripts/sft.py" \
|
||||
--model_name_or_path Qwen/Qwen2.5-3B-Instruct \
|
||||
--dataset_name open-thoughts/OpenThoughts-114k \
|
||||
--learning_rate 2.0e-5 \
|
||||
--num_train_epochs 1 \
|
||||
--packing \
|
||||
--per_device_train_batch_size 2 \
|
||||
--gradient_accumulation_steps 16 \
|
||||
--use_peft \
|
||||
--lora_r 256 \
|
||||
--lora_alpha 16 \
|
||||
--lora_target_modules all-linear \
|
||||
--output_dir Qwen2.5-3B-OpenThoughts-LoRA \
|
||||
--report_to trackio \
|
||||
--push_to_hub
|
||||
|
||||
```
|
||||
|
||||
To use Hugging Face Jobs, you will need to be logged in to the Hugging Face Hub (`hf auth login`) and have a [Pro](https://hf.co/pro), [Team](https://hf.co/enterprise), or [Enterprise](https://hf.co/enterprise) plan. Check out the [Jobs documentation](https://huggingface.co/docs/huggingface_hub/en/guides/jobs) for more details.
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="local">
|
||||
|
||||
```bash
|
||||
|
||||
uv run "https://raw.githubusercontent.com/huggingface/trl/main/trl/scripts/sft.py" \
|
||||
--model_name_or_path Qwen/Qwen2.5-3B-Instruct \
|
||||
--dataset_name open-thoughts/OpenThoughts-114k \
|
||||
--learning_rate 2.0e-5 \
|
||||
--num_train_epochs 1 \
|
||||
--packing \
|
||||
--per_device_train_batch_size 2 \
|
||||
--gradient_accumulation_steps 16 \
|
||||
--gradient_checkpointing \
|
||||
--eval_strategy no \
|
||||
--use_peft \
|
||||
--lora_r 256 \
|
||||
--lora_alpha 16 \
|
||||
--lora_target_modules all-linear \
|
||||
--output_dir Qwen2.5-3B-OpenThoughts-LoRA \
|
||||
--report_to trackio \
|
||||
--push_to_hub
|
||||
|
||||
```
|
||||
|
||||
To run the script locally, you will need to have `uv` installed. Check out the [uv documentation](https://docs.astral.sh/uv/) for more details.
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Once training starts, you can monitor the progress in [Trackio](https://huggingface.co/trackio), which will log the URL.
|
||||
|
||||
### Reinforcement Learning (GRPO)
|
||||
|
||||
The blog post performs GRPO on a range of models and datasets from the Hub, and once again we can reproduce the results in TRL.
|
||||
|
||||
| Model | Dataset |
|
||||
| --- | --- |
|
||||
| [Llama-3.1-8B-Base](https://huggingface.co/meta-llama/Llama-3.2-1B) | [GSM8k](https://huggingface.co/datasets/openai/gsm8k) |
|
||||
| [Llama-3.1-8B-Base](https://huggingface.co/meta-llama/Llama-3.2-1B) | [DeepMath-103K](https://huggingface.co/datasets/zwhe99/DeepMath-103K) |
|
||||
| [Qwen3-8b-base](https://huggingface.co/Qwen/Qwen3-8b-base) | [DeepMath-103K](https://huggingface.co/datasets/zwhe99/DeepMath-103K) |
|
||||
|
||||
For reinforcement learning, the blog uses a math reasoning task that we can reproduce as a Python function.
|
||||
|
||||
<details>
|
||||
<summary>Reward function</summary>
|
||||
|
||||
```python
|
||||
def strip_reasoning_accuracy_reward(
|
||||
completions: list[list[dict[str, str]]], solution: list[str], **kwargs
|
||||
) -> list[float | None]:
|
||||
"""Reward function that strips reasoning tags and checks mathematical accuracy.
|
||||
|
||||
This function:
|
||||
1. Extracts the content from completions
|
||||
2. Removes <think></think> tags (for reasoning that shouldn't be evaluated)
|
||||
3. Parses both the gold solution and the predicted answer
|
||||
4. Uses math_verify to check if they are mathematically equivalent
|
||||
|
||||
Args:
|
||||
completions: List of model completions, each containing a list of messages
|
||||
solution: List of ground truth solutions
|
||||
**kwargs: Additional arguments (ignored but required for trainer compatibility)
|
||||
|
||||
Returns:
|
||||
List of rewards where:
|
||||
- 1.0 if the answer is correct
|
||||
- 0.0 if the answer is incorrect
|
||||
- None if the solution is not parseable (skips this example)
|
||||
"""
|
||||
contents = [completion[0]["content"] for completion in completions]
|
||||
rewards = []
|
||||
|
||||
for content, sol in zip(contents, solution):
|
||||
# Strip reasoning tags from completion
|
||||
while "<think>" in content and "</think>" in content:
|
||||
start = content.find("<think>")
|
||||
end = content.find("</think>", start)
|
||||
if start != -1 and end != -1:
|
||||
content = content[:start] + content[end + len("</think>") :]
|
||||
else:
|
||||
break
|
||||
|
||||
# Parse gold solution
|
||||
gold_parsed = parse(
|
||||
f"${sol}$",
|
||||
extraction_config=[
|
||||
LatexExtractionConfig(
|
||||
boxed_match_priority=0, try_extract_without_anchor=True
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
if len(gold_parsed) != 0:
|
||||
# We require the answer to be provided in correct latex (no malformed operators)
|
||||
answer_parsed = parse(
|
||||
content,
|
||||
extraction_config=[
|
||||
LatexExtractionConfig(
|
||||
boxed_match_priority=0,
|
||||
normalization_config=NormalizationConfig(
|
||||
basic_latex=True,
|
||||
units=True,
|
||||
malformed_operators=False,
|
||||
nits=False,
|
||||
boxed=True,
|
||||
),
|
||||
try_extract_without_anchor=False,
|
||||
)
|
||||
],
|
||||
extraction_mode="first_match",
|
||||
)
|
||||
|
||||
# Compute binary rewards if verifiable, `None` otherwise to skip this example
|
||||
try:
|
||||
reward = float(verify(gold_parsed, answer_parsed))
|
||||
except Exception as e:
|
||||
print(
|
||||
f"verify failed: {e}, answer: {answer_parsed}, gold: {gold_parsed}"
|
||||
)
|
||||
reward = None
|
||||
else:
|
||||
# If the gold solution is not parseable, we assign `None` to skip this example
|
||||
reward = None
|
||||
|
||||
rewards.append(reward)
|
||||
|
||||
return rewards
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<hfoptions id="grpo">
|
||||
<hfoption id="python">
|
||||
|
||||
We can implement these recommendations with the TRL Python API like so:
|
||||
|
||||
```python
|
||||
|
||||
from datasets import load_dataset
|
||||
from peft import LoraConfig
|
||||
from trl import GRPOConfig, GRPOTrainer
|
||||
|
||||
dataset = load_dataset("HuggingFaceH4/OpenR1-Math-220k-default-verified", split="train")
|
||||
|
||||
def strip_reasoning_accuracy_reward(completions, **kwargs):
|
||||
"""Reward function that strips reasoning and accuracy scores from the model outputs."""
|
||||
|
||||
...
|
||||
|
||||
peft_config = LoraConfig(
|
||||
r=1,
|
||||
lora_alpha=32,
|
||||
target_modules="all-linear"
|
||||
)
|
||||
|
||||
training_args = GRPOConfig(
|
||||
learning_rate=5e-5,
|
||||
per_device_train_batch_size=1,
|
||||
gradient_accumulation_steps=4,
|
||||
num_train_epochs=1,
|
||||
num_generations=8,
|
||||
generation_batch_size=8,
|
||||
report_to=["trackio"],
|
||||
)
|
||||
|
||||
trainer = GRPOTrainer(
|
||||
model="Qwen/Qwen3-0.6B",
|
||||
reward_funcs=strip_reasoning_accuracy_reward,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
peft_config=peft_config,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> This snippet skips the reward function which is defined above to keep the example concise.
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="jobs">
|
||||
|
||||
```bash
|
||||
|
||||
hf jobs uv run \
|
||||
--flavor a100-large \
|
||||
--timeout 4h \
|
||||
--secrets HF_TOKEN \
|
||||
--env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
|
||||
"https://huggingface.co/datasets/burtenshaw/lora-without-regrets/resolve/main/grpo.py" \
|
||||
--model_name_or_path Qwen/Qwen3-0.6B \
|
||||
--dataset_name HuggingFaceH4/OpenR1-Math-220k-default-verified \
|
||||
--output_dir grpo-full-qwen3-0.6b \
|
||||
--learning_rate 1.0e-6 \
|
||||
--lr_scheduler_type cosine \
|
||||
--warmup_ratio 0.0 \
|
||||
--max_grad_norm 1.0 \
|
||||
--beta 0.0 \
|
||||
--max_prompt_length 1024 \
|
||||
--max_completion_length 4096 \
|
||||
--num_generations 16 \
|
||||
--generation_batch_size 16 \
|
||||
--gradient_accumulation_steps 8 \
|
||||
--per_device_train_batch_size 1 \
|
||||
--num_train_epochs 1 \
|
||||
--lora_r 1 \
|
||||
--lora_alpha 32 \
|
||||
--lora_dropout 0.0 \
|
||||
--lora_target_modules all-linear \
|
||||
--vllm_mode colocate \
|
||||
--save_strategy steps \
|
||||
--save_steps 50 \
|
||||
--save_total_limit 1 \
|
||||
--logging_steps 1 \
|
||||
--max_steps 200 \
|
||||
--report_to trackio
|
||||
```
|
||||
|
||||
To use Hugging Face Jobs, you will need to be logged in to the Hugging Face Hub (`hf auth login`) and have a [Pro](https://hf.co/pro), [Team](https://hf.co/enterprise), or [Enterprise](https://hf.co/enterprise) plan. Check out the [Jobs documentation](https://huggingface.co/docs/huggingface_hub/en/guides/jobs) for more details.
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="local">
|
||||
|
||||
```bash
|
||||
uv run "https://huggingface.co/datasets/burtenshaw/lora-without-regrets/resolve/main/grpo.py" \
|
||||
--model_name_or_path Qwen/Qwen3-0.6B \
|
||||
--dataset_name HuggingFaceH4/OpenR1-Math-220k-default-verified \
|
||||
--output_dir grpo-full-qwen3-0.6b \
|
||||
--learning_rate 1.0e-6 \
|
||||
--lr_scheduler_type cosine \
|
||||
--warmup_ratio 0.0 \
|
||||
--max_grad_norm 1.0 \
|
||||
--beta 0.0 \
|
||||
--max_prompt_length 1024 \
|
||||
--max_completion_length 4096 \
|
||||
--num_generations 16 \
|
||||
--generation_batch_size 16 \
|
||||
--gradient_accumulation_steps 8 \
|
||||
--per_device_train_batch_size 1 \
|
||||
--num_train_epochs 1 \
|
||||
--lora_r 1 \
|
||||
--lora_alpha 32 \
|
||||
--lora_dropout 0.0 \
|
||||
--lora_target_modules all-linear \
|
||||
--vllm_mode colocate \
|
||||
--save_strategy steps \
|
||||
--save_steps 50 \
|
||||
--save_total_limit 1 \
|
||||
--logging_steps 1 \
|
||||
--max_steps 200 \
|
||||
--report_to trackio
|
||||
```
|
||||
|
||||
To run the script locally, you will need to have `uv` installed. Check out the [uv documentation](https://docs.astral.sh/uv/) for more details.
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
The reinforcement learning script with GRPO is implemented as a custom script in TRL, which uses the reward function shown above. You can review it at [`grpo.py`](https://huggingface.co/datasets/burtenshaw/lora-without-regrets/blob/main/grpo.py) - Reinforcement learning with LoRA best practices
|
||||
|
||||
## Key findings in optimizing LoRA
|
||||
|
||||
The authors recommend applying LoRA to all weight matrices rather than limiting it to attention layers, as increasing the rank does not compensate for this restriction. In TRL, this can be configured using `--lora_target_modules all-linear` to apply LoRA to all weight matrices.
|
||||
|
||||
We were able to reproduce the results of the blog post using TRL and the SmolLM3 model. We trained the model for 500 steps on the [Math 220k dataset](https://huggingface.co/datasets/HuggingFaceH4/OpenR1-Math-220k-default-verified) with the reward function and configuration above. As you can see in the figure below, the LoRA model's average train reward curve matches the full fine-tuning curve.
|
||||
|
||||

|
||||
|
||||
And most importantly, the LoRA model uses significantly less memory than the full fine-tuning model, as we can see in the figure below.
|
||||
|
||||

|
||||
|
||||
Here are the parameters we used to train the above models
|
||||
|
||||
| Parameter | LoRA | Full FT |
|
||||
| --- | --- | --- |
|
||||
| `--model_name_or_path` | HuggingFaceTB/SmolLM3-3B | HuggingFaceTB/SmolLM3-3B |
|
||||
| `--dataset_name` | HuggingFaceH4/OpenR1-Math-220k-default-verified | HuggingFaceH4/OpenR1-Math-220k-default-verified |
|
||||
| `--learning_rate` | 1.0e-5 | 1.0e-6 |
|
||||
| `--max_prompt_length` | 1024 | 1024 |
|
||||
| `--max_completion_length` | 4096 | 4096 |
|
||||
| `--lora_r` | 1 | - |
|
||||
| `--lora_alpha` | 32 | - |
|
||||
| `--lora_dropout` | 0.0 | - |
|
||||
| `--lora_target_modules` | all-linear | - |
|
||||
|
||||
Let's break down the key findings of the blog post and how we were able to reproduce them.
|
||||
|
||||
### 1. *LoRA performs better when applied to all weight matrices*
|
||||
|
||||
The authors recommend applying LoRA to all weight matrices rather than limiting it to attention layers, as increasing the rank does not compensate for this restriction.
|
||||
|
||||

|
||||
|
||||
Attention-only LoRA underperforms even when using a higher rank to match parameter count. In TRL, this can be configured using `--lora_target_modules all-linear` to apply LoRA to all weight matrices. In Python, we can do this like so:
|
||||
|
||||
```python
|
||||
from peft import LoraConfig
|
||||
|
||||
peft_config = LoraConfig(target_modules="all-linear")
|
||||
```
|
||||
|
||||
### 2. *The adapter needs sufficient capacity to learn from the dataset*
|
||||
|
||||
The blog post recommends using a sufficient LoRA rank to learn from the dataset. The rank determines the number of trainable parameters in the LoRA adapter. Therefore, "For datasets that exceed LoRA capacity, LoRA underperforms FullFT".
|
||||
|
||||

|
||||
|
||||
In the TRL script, we could use `--lora_r` to set the rank and adapt it based on the task and dataset we're training on. The blog post recommends the following ranks based on the task and dataset size:
|
||||
|
||||
Reinforcement learning tasks typically require lower capacity, so smaller LoRA ranks can be used. This is because policy gradient algorithms extract roughly ~1 bit of information per episode, demanding minimal parameter capacity.
|
||||
|
||||
The blog post defines the ideal dataset size for LoRA to match full fine-tuning as "Post-training scale". Which we can use to determine the recommended rank for SFT and RL LoRAs as:
|
||||
|
||||
| Task Type | Dataset Size | Recommended Rank |
|
||||
| --- | --- | --- |
|
||||
| **SFT** | Post-training scale | 256 |
|
||||
| **RL** | Any size | 1-32 |
|
||||
|
||||
### 3. *"FullFT and high-rank LoRAs have similar learning curves"*
|
||||
|
||||
Counterintuitively, the blog post recommends using a higher learning rate than for full fine-tuning. In the table above, we used 1.0e-5 for LoRA and 1.0e-6 for full fine-tuning. In the TRL script, we could use `--learning_rate` to set the learning rate. The \\( \frac{1}{r} \\) scaling in LoRA makes the optimal learning rate approximately rank-independent.
|
||||
|
||||

|
||||
|
||||
### 4. *"In some scenarios, LoRA is less tolerant of large batch sizes than full fine-tuning."*
|
||||
|
||||
The blog post recommends using an effective batch size < 32 because the authors found LoRA to be less tolerant of large batch sizes. This could not be mitigated by increasing the LoRA rank. In the TRL script, we could use `--per_device_train_batch_size` and `--gradient_accumulation_steps` to set the batch size.
|
||||
|
||||

|
||||
|
||||
## Takeaways
|
||||
|
||||
Using TRL, you can efficiently implement LoRA adapters to match full fine-tuning performance, applying the core insights (targeting all weight matrices, choosing the right rank, and managing batch size and learning rate) without the heavy compute cost of FullFT.
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{schulman2025lora,
|
||||
title = {{LoRA Without Regret}},
|
||||
author = {John Schulman and Thinking Machines Lab},
|
||||
year = 2025,
|
||||
journal = {Thinking Machines Lab: Connectionism},
|
||||
doi = {10.64434/tml.20250929},
|
||||
note = {https://thinkingmachines.ai/blog/lora/}
|
||||
}
|
||||
```
|
||||
9
docs/source/model_utils.md
Normal file
9
docs/source/model_utils.md
Normal file
@ -0,0 +1,9 @@
|
||||
# Model Utilities
|
||||
|
||||
## clone_chat_template
|
||||
|
||||
[[autodoc]] clone_chat_template
|
||||
|
||||
## get_act_offloading_ctx_manager
|
||||
|
||||
[[autodoc]] models.get_act_offloading_ctx_manager
|
||||
@ -8,7 +8,6 @@ With the `AutoModelForCausalLMWithValueHead` class TRL supports all decoder mode
|
||||
|
||||
## AutoModelForCausalLMWithValueHead
|
||||
|
||||
|
||||
[[autodoc]] AutoModelForCausalLMWithValueHead
|
||||
- __init__
|
||||
- forward
|
||||
@ -25,4 +24,4 @@ With the `AutoModelForCausalLMWithValueHead` class TRL supports all decoder mode
|
||||
|
||||
## create_reference_model
|
||||
|
||||
[[autodoc]] create_reference_model
|
||||
[[autodoc]] create_reference_model
|
||||
@ -1,100 +0,0 @@
|
||||
# Multi Adapter RL (MARL) - a single base model for everything
|
||||
|
||||
Here we present an approach that uses a single base model for the entire PPO algorithm - which includes retrieving the reference logits, computing the active logits and the rewards. This feature is experimental as we did not test the convergence of the approach. We encourage the community to let us know if they potentially face issues.
|
||||
|
||||
## Requirements
|
||||
|
||||
You just need to install `peft` and optionally install `bitsandbytes` as well if you want to go for 8bit base models, for more memory efficient finetuning.
|
||||
|
||||
## Summary
|
||||
|
||||
You need to address this approach in three stages that we summarize as follows:
|
||||
|
||||
1- Train a base model on the target domain (e.g. [IMDB dataset](https://huggingface.co/datasets/stanfordnlp/imdb)) - this is the Supervised Fine Tuning stage - it can leverage the `SFTTrainer` from TRL.
|
||||
2- Train a reward model using `peft`. This is required in order to re-use the adapter during the RL optimisation process (step 3 below). We show an example of leveraging the `RewardTrainer` from TRL in [this example](https://github.com/huggingface/trl/tree/main/examples/scripts/reward_modeling.py)
|
||||
3- Fine tune new adapters on the base model using PPO and the reward adapter. ("0 abstraction RL")
|
||||
|
||||
Make sure to use the same model (i.e. same architecture and same weights) for the stages 2 & 3.
|
||||
|
||||
## Quickstart
|
||||
|
||||
Let us assume you have trained your reward adapter on `llama-7b` model using `RewardTrainer` and pushed the weights on the hub under `trl-lib/llama-7b-hh-rm-adapter`.
|
||||
When doing PPO, before passing the model to `PPOTrainer` create your model as follows:
|
||||
|
||||
```python
|
||||
model_name = "huggyllama/llama-7b"
|
||||
rm_adapter_id = "trl-lib/llama-7b-hh-rm-adapter"
|
||||
|
||||
# PPO adapter
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
model_name,
|
||||
peft_config=lora_config,
|
||||
reward_adapter=rm_adapter_id,
|
||||
)
|
||||
|
||||
...
|
||||
trainer = PPOTrainer(
|
||||
model=model,
|
||||
...
|
||||
)
|
||||
|
||||
...
|
||||
```
|
||||
Then inside your PPO training loop, call the `compute_reward_score` method by accessing the `model` attribute from `PPOTrainer`.
|
||||
|
||||
```python
|
||||
rewards = trainer.model.compute_reward_score(**inputs)
|
||||
```
|
||||
|
||||
## Advanced usage
|
||||
|
||||
### Control on the adapter name
|
||||
|
||||
If you are familiar with the `peft` library, you know that you can use multiple adapters inside the same model. What you can do is train multiple adapters on the same base model to fine-tune on different policies.
|
||||
In this case, you want to be able to control the adapter name you want to activate back, after retrieving the reward. For that, simply pass the appropriate `adapter_name` to `ppo_adapter_name` argument when calling `compute_reward_score`.
|
||||
|
||||
```python
|
||||
adapter_name_policy_1 = "policy_1"
|
||||
rewards = trainer.model.compute_reward_score(**inputs, ppo_adapter_name=adapter_name_policy_1)
|
||||
...
|
||||
```
|
||||
|
||||
### Using 4-bit and 8-bit base models
|
||||
|
||||
For more memory efficient fine-tuning, you can load your base model in 8-bit or 4-bit while keeping the adapters in the default precision (float32).
|
||||
Just pass the appropriate arguments (i.e. `load_in_8bit=True` or `load_in_4bit=True`) to `AutoModelForCausalLMWithValueHead.from_pretrained` as follows (assuming you have installed `bitsandbytes`):
|
||||
```python
|
||||
model_name = "llama-7b"
|
||||
rm_adapter_id = "trl-lib/llama-7b-hh-rm-adapter"
|
||||
|
||||
# PPO adapter
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
model_name,
|
||||
peft_config=lora_config,
|
||||
reward_adapter=rm_adapter_id,
|
||||
load_in_8bit=True,
|
||||
)
|
||||
|
||||
...
|
||||
trainer = PPOTrainer(
|
||||
model=model,
|
||||
...
|
||||
)
|
||||
...
|
||||
```
|
||||
@ -1,18 +1,20 @@
|
||||
# Nash-MD Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=nash-md,trl)
|
||||
|
||||
## Overview
|
||||
|
||||
Nash-MD was proposed in the paper [Nash Learning from Human Feedback](https://huggingface.co/papers/2312.00886) by Rémi Munos, [Michal Valko](https://huggingface.co/misovalko), Daniele Calandriello, Mohammad Gheshlaghi Azar, Mark Rowland, Daniel Guo, Yunhao Tang, Matthieu Geist, Thomas Mésnard, and Andrea Michi.
|
||||
Nash-MD was proposed in the paper [Nash Learning from Human Feedback](https://huggingface.co/papers/2312.00886) by Rémi Munos, [Michal Valko](https://huggingface.co/misovalko), Daniele Calandriello, Mohammad Gheshlaghi Azar, Mark Rowland, Daniel Guo, Yunhao Tang, Matthieu Geist, Thomas Mésnard, and Andrea Michi.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
> Reinforcement learning from human feedback (RLHF) has emerged as the main paradigm for aligning large language models (LLMs) with human preferences. Typically, RLHF involves the initial step of learning a reward model from human feedback, often expressed as preferences between pairs of text generations produced by a pre-trained LLM. Subsequently, the LLM's policy is fine-tuned by optimizing it to maximize the reward model through a reinforcement learning algorithm. However, an inherent limitation of current reward models is their inability to fully represent the richness of human preferences and their dependency on the sampling distribution. In this study, we introduce an alternative pipeline for the fine-tuning of LLMs using pairwise human feedback. Our approach entails the initial learning of a preference model, which is conditioned on two inputs given a prompt, followed by the pursuit of a policy that consistently generates responses preferred over those generated by any competing policy, thus defining the Nash equilibrium of this preference model. We term this approach Nash learning from human feedback (NLHF). In the context of a tabular policy representation, we present a novel algorithmic solution, Nash-MD, founded on the principles of mirror descent. This algorithm produces a sequence of policies, with the last iteration converging to the regularized Nash equilibrium. Additionally, we explore parametric representations of policies and introduce gradient descent algorithms for deep-learning architectures. To demonstrate the effectiveness of our approach, we present experimental results involving the fine-tuning of a LLM for a text summarization task. We believe NLHF offers a compelling avenue for preference learning and policy optimization with the potential of advancing the field of aligning LLMs with human preferences.
|
||||
|
||||
This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif) and [Daniil Tiapkin](https://huggingface.co/dtiapkin), [Pierre Ménard](https://huggingface.co/menardprr), Daniele Calandriello and [Quentin Gallouédec](https://huggingface.co/qgallouedec).
|
||||
This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif) and [Daniil Tiapkin](https://huggingface.co/dtiapkin), [Pierre Ménard](https://huggingface.co/menardprr), Daniele Calandriello and [Quentin Gallouédec](https://huggingface.co/qgallouedec).
|
||||
|
||||
## Quick start
|
||||
|
||||
This example demonstrates how to train a model using the Nash-MD method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model and the [Qwen 0.5B reward model](https://huggingface.co/trl-lib/Qwen2-0.5B-Reward) as the reward model. We use the prompts from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the prompts in the dataset here:
|
||||
This example demonstrates how to train a model using the Nash-MD method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model and [`PairRMJudge`] as a judge. We use the prompts from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the prompts in the dataset here:
|
||||
|
||||
<iframe
|
||||
src="https://huggingface.co/datasets/trl-lib/ultrafeedback-prompt/embed/viewer/default/train?row=0"
|
||||
@ -26,21 +28,17 @@ Below is the script to train the model:
|
||||
```python
|
||||
# train_nash_md.py
|
||||
from datasets import load_dataset
|
||||
from trl import NashMDConfig, NashMDTrainer
|
||||
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
|
||||
from trl import NashMDConfig, NashMDTrainer, PairRMJudge
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
||||
reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1)
|
||||
judge = PairRMJudge()
|
||||
train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")
|
||||
|
||||
args = NashMDConfig(output_dir="nash-md-qwen2", logging_steps=10)
|
||||
training_args = NashMDConfig(output_dir="Qwen2-0.5B-NashMD")
|
||||
trainer = NashMDTrainer(
|
||||
model=model,
|
||||
reward_model=reward_model,
|
||||
args=args,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=train_dataset,
|
||||
model=model, judge=judge, args=training_args, processing_class=tokenizer, train_dataset=train_dataset
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
@ -51,22 +49,51 @@ Execute the script using the following command:
|
||||
accelerate launch train_nash_md.py
|
||||
```
|
||||
|
||||
## Expected dataset format
|
||||
Distributed across 8 GPUs, the training takes approximately 3 hours.
|
||||
|
||||
Nash-MD requires a [prompt-only dataset](dataset_format#preference). The [`NashMDTrainer`] supports both [conversational](dataset_format#conversational-dataset-format) and [standard](dataset_format#standard-dataset-format) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
|
||||
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-NashMD) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
|
||||
|
||||
<pre><code>$ transformers chat trl-lib/Qwen2-0.5B-NashMD
|
||||
<strong><span style="color: red;"><quentin_gallouedec>:</span></strong>
|
||||
What is the best programming language?
|
||||
|
||||
<strong><span style="color: blue;"><trl-lib/Qwen2-0.5B-NashMD>:</span></strong>
|
||||
The best programming language depends on personal preference, the complexity of the project, and the specific requirements of the task. Some programming languages that are often recommended include Python, Java, and JavaScript, and there are many other languages to choose from depending on individual needs.
|
||||
</code></pre>
|
||||
|
||||
## Expected dataset type
|
||||
|
||||
Nash-MD requires a [prompt-only dataset](dataset_formats#prompt-only). The [`NashMDTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
|
||||
|
||||
## Usage tips
|
||||
|
||||
### ⚠️ Use the same chat template
|
||||
### Use a reward model
|
||||
|
||||
Make sure that the SFT model and reward model use the _same_ chat template. Otherwise, you may find the model completions are scored incorrectly during training.
|
||||
Instead of a judge, you can chose to use a reward model -- see [Reward Bench](https://huggingface.co/spaces/allenai/reward-bench) for a leaderboard of public models you can use. Below is a code example showing how to replace a judge with the [trl-lib/Qwen2-0.5B-Reward](https://huggingface.co/trl-lib/Qwen2-0.5B-Reward) model:
|
||||
|
||||
```diff
|
||||
- from trl import PairRMJudge
|
||||
+ from transformers import AutoModelForSequenceClassification
|
||||
|
||||
- judge = PairRMJudge()
|
||||
+ reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1)
|
||||
|
||||
trainer = NashMDTrainer(
|
||||
...
|
||||
- judge=judge,
|
||||
+ reward_funcs=reward_model,
|
||||
)
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> Make sure that the SFT model and reward model use the _same_ chat template and the same tokenizer. Otherwise, you may find the model completions are scored incorrectly during training.
|
||||
|
||||
### Encourage EOS token generation
|
||||
|
||||
We can want the model to generate completion within a given length. During the learning, the model will generate completion up to the maximum completion length specified in the `max_new_tokens` argument of [`NashMDConfig`]. I you want to penalize for not generating an EOS token before the maximum completion length, you can use the `missing_eos_penalty` argument of [`NashMDConfig`]:
|
||||
We may want the model to generate completions within a given length. During training, the model will generate completions up to the maximum length specified in the `max_new_tokens` argument of [`NashMDConfig`]. If you want to penalize the model for not generating an EOS token before reaching the maximum length, you can use the `missing_eos_penalty` argument of [`NashMDConfig`]:
|
||||
|
||||
```python
|
||||
args = NashMDConfig(..., max_new_tokens=128, missing_eos_penalty=1.0)
|
||||
training_args = NashMDConfig(..., max_new_tokens=128, missing_eos_penalty=1.0)
|
||||
```
|
||||
|
||||
### Logging Completions
|
||||
@ -81,39 +108,35 @@ trainer.add_callback(completions_callback)
|
||||
|
||||
This callback logs the model's generated completions directly to Weights & Biases.
|
||||
|
||||

|
||||

|
||||
|
||||
## Example script
|
||||
|
||||
We provide an example script to train a model using the Nash-MD method. The script is available in [`examples/scripts/nash_md.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nash_md.py)
|
||||
|
||||
To test the Nash-MD script with the [Pythia 14M model](https://huggingface.co/EleutherAI/pythia-14m) on the TL;DR summarization task, run the following command:
|
||||
To test the online DPO script with the [Qwen2.5 0.5B model](https://huggingface.co/trl-lib/Qwen/Qwen2.5-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback), run the following command:
|
||||
|
||||
```bash
|
||||
python examples/scripts/nash_md.py \
|
||||
--model_name_or_path EleutherAI/pythia-14m \
|
||||
--reward_model_path EleutherAI/pythia-14m \
|
||||
--dataset_name trl-lib/tldr \
|
||||
--model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
|
||||
--judge pair_rm \
|
||||
--dataset_name trl-lib/ultrafeedback-prompt \
|
||||
--learning_rate 5.0e-7 \
|
||||
--output_dir pythia-14m-tldr-nash-md \
|
||||
--per_device_train_batch_size 4 \
|
||||
--gradient_accumulation_steps 32 \
|
||||
--num_train_epochs 3 \
|
||||
--max_new_tokens 64 \
|
||||
--output_dir Qwen2.5-0.5B-NashMD-PairRM \
|
||||
--warmup_ratio 0.1 \
|
||||
--missing_eos_penalty 1.0 \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
## Logged metrics
|
||||
|
||||
The logged metrics are as follows:
|
||||
While training and evaluating, we record the following reward metrics:
|
||||
|
||||
* `loss/kl`: The mean KL divergence between the model and reference data.
|
||||
* `objective/entropy`: The mean entropy of the model and reference data.
|
||||
* `loss/score`: The mean reinforce score loss.
|
||||
* `rewards/chosen`: The mean scores (according to the reward model) of the model completions.
|
||||
* `rewards/rejected`: The mean scores (according to the reward model) of the mixture completions.
|
||||
* `rewards/probabilities`: The mean probability (according to the reward model or judge) of the model completions chosen vs the mixture completion.
|
||||
* `rewards/accuracies`: The accuracies of the Nash-MD's implicit reward model.
|
||||
* `rewards/margins`: The mean reward margin (according to reward model) between the chosen and mixture completions.
|
||||
* `logps/chosen`: The mean log probabilities of the chosen completions.
|
||||
@ -126,6 +149,9 @@ The logged metrics are as follows:
|
||||
## NashMDTrainer
|
||||
|
||||
[[autodoc]] NashMDTrainer
|
||||
- train
|
||||
- save_model
|
||||
- push_to_hub
|
||||
|
||||
## NashMDConfig
|
||||
|
||||
|
||||
@ -1,20 +1,20 @@
|
||||
# Online DPO Trainer
|
||||
|
||||
## Overview
|
||||
[](https://huggingface.co/models?other=online-dpo,trl)
|
||||
|
||||
Online DPO was proposed in [Direct Language Model Alignment from Online AI Feedback](https://huggingface.co/papers/2402.04792) by Shangmin Guo, Biao Zhang, Tianlin Liu, Tianqi Liu, Misha Khalman, Felipe Llinares, Alexandre Rame, Thomas Mesnard, Yao Zhao, Bilal Piot, Johan Ferret, and Mathieu Blondel.
|
||||
## Overview
|
||||
|
||||
Online DPO was proposed in [Direct Language Model Alignment from Online AI Feedback](https://huggingface.co/papers/2402.04792) by Shangmin Guo, Biao Zhang, Tianlin Liu, Tianqi Liu, Misha Khalman, Felipe Llinares, Alexandre Rame, Thomas Mesnard, Yao Zhao, Bilal Piot, Johan Ferret, and Mathieu Blondel.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
> Direct alignment from preferences (DAP) methods, such as DPO, have recently emerged as efficient alternatives to reinforcement learning from human feedback (RLHF), that do not require a separate reward model. However, the preference datasets used in DAP methods are usually collected ahead of training and never updated, thus the feedback is purely offline. Moreover, responses in these datasets are often sampled from a language model distinct from the one being aligned, and since the model evolves over training, the alignment phase is inevitably off-policy. In this study, we posit that online feedback is key and improves DAP methods. Our method, online AI feedback (OAIF), uses an LLM as annotator: on each training iteration, we sample two responses from the current model and prompt the LLM annotator to choose which one is preferred, thus providing online feedback. Despite its simplicity, we demonstrate via human evaluation in several tasks that OAIF outperforms both offline DAP and RLHF methods. We further show that the feedback leveraged in OAIF is easily controllable, via instruction prompts to the LLM annotator.
|
||||
|
||||
The current implementation uses reward models for scoring completions -- see [Reward Bench](https://huggingface.co/spaces/allenai/reward-bench) for a leaderboard of public models you can use.
|
||||
|
||||
This post-training method was contributed by [Michael Noukhovitch](https://huggingface.co/mnoukhov), [Shengyi Costa Huang](https://huggingface.co/vwxyzjn), [Quentin Gallouédec](https://huggingface.co/qgallouedec), and [Edward Beeching](https://huggingface.co/edbeeching).
|
||||
|
||||
## Quick start
|
||||
|
||||
This example demonstrates how to train a model using the online DPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model and the [Qwen 0.5B reward model](https://huggingface.co/trl-lib/Qwen2-0.5B-Reward) as the reward model. We use the prompts from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the prompts in the dataset here:
|
||||
This example demonstrates how to train a model using the online DPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model and [`PairRMJudge`] as a judge. We use the prompts from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the prompts in the dataset here:
|
||||
|
||||
<iframe
|
||||
src="https://huggingface.co/datasets/trl-lib/ultrafeedback-prompt/embed/viewer/default/train?row=0"
|
||||
@ -28,21 +28,17 @@ Below is the script to train the model:
|
||||
```python
|
||||
# train_online_dpo.py
|
||||
from datasets import load_dataset
|
||||
from trl import OnlineDPOConfig, OnlineDPOTrainer
|
||||
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
|
||||
from trl import OnlineDPOConfig, OnlineDPOTrainer, PairRMJudge
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
||||
reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1)
|
||||
judge = PairRMJudge()
|
||||
train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")
|
||||
|
||||
args = OnlineDPOConfig(output_dir="online-dpo-qwen2", logging_steps=10)
|
||||
training_args = OnlineDPOConfig(output_dir="Qwen2-0.5B-OnlineDPO")
|
||||
trainer = OnlineDPOTrainer(
|
||||
model=model,
|
||||
reward_model=reward_model,
|
||||
args=args,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=train_dataset,
|
||||
model=model, judge=judge, args=training_args, processing_class=tokenizer, train_dataset=train_dataset
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
@ -55,37 +51,51 @@ accelerate launch train_online_dpo.py
|
||||
|
||||
Distributed across 8 GPUs, the training takes approximately 1 hour. You can verify the training progress by checking the reward graph. An increasing trend in both the reward for rejected and chosen completions indicates that the model is improving and generating better responses over time.
|
||||
|
||||

|
||||

|
||||
|
||||
To see how the trained model performs, use the following code to generate completions:
|
||||
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-OnlineDPO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
|
||||
|
||||
```python
|
||||
>>> from transformers import pipeline
|
||||
>>> generator = pipeline("text-generation", model="online-dpo-qwen2/checkpoint-1773", device="cuda")
|
||||
>>> question = "Why is the problem always DNS?"
|
||||
>>> output = generator([{"role": "user", "content": question}], max_new_tokens=200, return_full_text=False)[0]
|
||||
>>> print(output["generated_text"])
|
||||
The reason why the problem of DNS (Domain Name System) can always be encountered is that it is designed to provide reliable and accurate information about the availability, ownership, or expiration of domain names. However, there may be some circumstances where the system fails to resolve an IP address correctly, leading to the problem of DNS.
|
||||
For example, if the server hosting the domain name does not have the correct IP address associated with it, or if the IP address is incorrectly formatted, then the DNS system will fail to resolve the domain name correctly. Additionally, if the server hosting the domain name has been compromised, then the DNS system may also fail to resolve the domain name correctly.
|
||||
It's worth noting that the exact cause of DNS failure can vary depending on the specific situation, so it's important to carefully check all relevant factors before attempting to resolve the issue. If you suspect that your DNS problem may be caused by a bug in the system, you should report it to the DNS provider directly for further investigation.
|
||||
```
|
||||
<pre><code>$ transformers chat trl-lib/Qwen2-0.5B-OnlineDPO
|
||||
<strong><span style="color: red;"><quentin_gallouedec>:</span></strong>
|
||||
What is the best programming language?
|
||||
|
||||
## Expected dataset format
|
||||
<strong><span style="color: blue;"><trl-lib/Qwen2-0.5B-OnlineDPO>:</span></strong>
|
||||
The best programming language depends on your specific needs and priorities. Some people prefer imperative programming languages (like Haskell or Lisp), while others prefer functional programming languages (like Scala or Python). It's important to consider your work style, programming environment, and project requirements when choosing a programming language.
|
||||
</code></pre>
|
||||
|
||||
Online DPO only requires a [prompt-only dataset](dataset_format#preference) (unlike offline DPO, that expects [preference dataset](dataset_format#preference)). The [`OnlineDPOTrainer`] supports both [conversational](dataset_format#conversational-dataset-format) and [standard](dataset_format#standard-dataset-format) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
|
||||
## Expected dataset type
|
||||
|
||||
Online DPO only requires a [prompt-only dataset](dataset_formats#prompt-only) (unlike offline DPO, that expects [preference dataset](dataset_formats#preference)). The [`OnlineDPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
|
||||
|
||||
## Usage tips
|
||||
|
||||
### ⚠️ Use the same chat template
|
||||
### Use a reward model
|
||||
|
||||
Make sure that the SFT model and reward model use the _same_ chat template. Otherwise, you may find the model completions are scored incorrectly during training.
|
||||
Instead of a judge, you can chose to use a reward model -- see [Reward Bench](https://huggingface.co/spaces/allenai/reward-bench) for a leaderboard of public models you can use. Below is a code example showing how to replace a judge with the [trl-lib/Qwen2-0.5B-Reward](https://huggingface.co/trl-lib/Qwen2-0.5B-Reward) model:
|
||||
|
||||
```diff
|
||||
- from trl import PairRMJudge
|
||||
+ from transformers import AutoModelForSequenceClassification
|
||||
|
||||
- judge = PairRMJudge()
|
||||
+ reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1)
|
||||
+ reward_tokenizer = AutoTokenizer.from_pretrained("trl-lib/Qwen2-0.5B-Reward")
|
||||
|
||||
trainer = OnlineDPOTrainer(
|
||||
...
|
||||
- judge=judge,
|
||||
+ reward_funcs=reward_model,
|
||||
+ reward_processing_class=reward_tokenizer,
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
### Encourage EOS token generation
|
||||
|
||||
We can want the model to generate completion within a given length. During the learning, the model will generate completion up to the maximum completion length specified in the `max_new_tokens` argument of [`OnlineDPOConfig`]. I you want to penalize for not generating an EOS token before the maximum completion length, you can use the `missing_eos_penalty` argument of [`OnlineDPOConfig`]:
|
||||
When using a reward model, we may want the model to generate completions within a given length. During training, the model will generate completions up to the maximum length specified in the `max_new_tokens` argument of [`OnlineDPOConfig`]. If you want to penalize the model for not generating an EOS token before reaching the maximum length, you can use the `missing_eos_penalty` argument of [`OnlineDPOConfig`]:
|
||||
|
||||
```python
|
||||
args = OnlineDPOConfig(..., max_new_tokens=128, missing_eos_penalty=1.0)
|
||||
training_args = OnlineDPOConfig(..., max_new_tokens=128, missing_eos_penalty=1.0)
|
||||
```
|
||||
|
||||
### Logging Completions
|
||||
@ -100,40 +110,34 @@ trainer.add_callback(completions_callback)
|
||||
|
||||
This callback logs the model's generated completions directly to Weights & Biases.
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
## Example script
|
||||
|
||||
We provide an example script to train a model using the online DPO method. The script is available in [`examples/scripts/dpo_online.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_online.py)
|
||||
|
||||
To test the online DPO script with the [Pythia 1B model](https://huggingface.co/trl-lib/pythia-1b-deduped-tldr-sft) on the TL;DR summarization task, run the following command:
|
||||
To test the online DPO script with the [Qwen2.5 0.5B model](https://huggingface.co/trl-lib/Qwen/Qwen2.5-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback), run the following command:
|
||||
|
||||
```bash
|
||||
python examples/scripts/dpo_online.py \
|
||||
--model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \
|
||||
--reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \
|
||||
--dataset_name trl-lib/tldr \
|
||||
--model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
|
||||
--judge pair_rm \
|
||||
--dataset_name trl-lib/ultrafeedback-prompt \
|
||||
--learning_rate 5.0e-7 \
|
||||
--output_dir pythia-1b-tldr-online-dpo \
|
||||
--per_device_train_batch_size 4 \
|
||||
--gradient_accumulation_steps 32 \
|
||||
--num_train_epochs 3 \
|
||||
--max_new_tokens 53 \
|
||||
--output_dir Qwen2.5-0.5B-Online-DPO-PairRM \
|
||||
--warmup_ratio 0.1 \
|
||||
--missing_eos_penalty 1.0 \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
## Logged metrics
|
||||
|
||||
The logged metrics are as follows. Here is an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/dd2o3g35)
|
||||
While training and evaluating, we record the following reward metrics. Here is an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/w4apmsi9)
|
||||
|
||||
* `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current model and reference model.
|
||||
* `objective/entropy`: The mean entropy of the model, indicating the randomness of the actions chosen by the model.
|
||||
* `objective/non_score_reward`: The mean reward from non-score-related sources, basically `beta * kl.sum(1)`, where `beta` is the KL penalty coefficient and `kl` is the per-token KL divergence.
|
||||
* `objective/rlhf_reward`: The mean RLHF reward, which is `scores - non_score_reward`. The `rlhf_reward` is the ultimate objective of online DPO training. If training works as intended, this metric should keep going up.
|
||||
* `objective/scores`: The mean scores returned by the reward mode.
|
||||
* `objective/scores`: The mean scores returned by the reward model.
|
||||
* `objective/scores_margin`: The mean score margin (according to the external reward model) between the chosen and rejected completions.
|
||||
* `rewards/chosen`: The mean reward (according to online DPO's implicit reward model)of the chosen completions.
|
||||
* `rewards/rejected`: The mean reward (according to online DPO's implicit reward model) of the rejected completions.
|
||||
@ -148,8 +152,7 @@ The logged metrics are as follows. Here is an example [tracked run at Weights an
|
||||
|
||||
To validate the online DPO implementation works, we ran experiments with the Pythia 1B, 2.8B, and 6.9B models on a single node of 8 x H100s. Here are the commands we used to run the experiments. We take the SFT / RM models directly from [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).
|
||||
|
||||
|
||||
```
|
||||
```shell
|
||||
# 1B Online DPO experiment
|
||||
accelerate launch --config_file examples/accelerate_configs/multi_gpu.yaml \
|
||||
examples/scripts/dpo_online.py \
|
||||
@ -165,7 +168,6 @@ accelerate launch --config_file examples/accelerate_configs/multi_gpu.yaml \
|
||||
--max_new_tokens 53 \
|
||||
--warmup_ratio 0.1 \
|
||||
--missing_eos_penalty 1.0 \
|
||||
--logging_steps 20 \
|
||||
--save_steps 0.1 \
|
||||
--push_to_hub
|
||||
|
||||
@ -184,8 +186,6 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml
|
||||
--max_new_tokens 53 \
|
||||
--warmup_ratio 0.1 \
|
||||
--missing_eos_penalty 1.0 \
|
||||
--bf16 \
|
||||
--logging_steps 20 \
|
||||
--save_steps 0.1 \
|
||||
--push_to_hub
|
||||
|
||||
@ -204,18 +204,15 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml
|
||||
--max_new_tokens 53 \
|
||||
--warmup_ratio 0.1 \
|
||||
--missing_eos_penalty 1.0 \
|
||||
--bf16 \
|
||||
--gradient_checkpointing \
|
||||
--logging_steps 20 \
|
||||
--save_steps 0.1 \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
Checkpoints and experiment tracking are available at:
|
||||
|
||||
- [🤗 Model checkpoints](https://huggingface.co/collections/trl-lib/online-dpo-66acd3fa38a331a9cd457b07)
|
||||
- [🐝 Tracked experiment](https://wandb.ai/huggingface/trl/reports/Online-DPO-experiments-for-TL-DR-summarisation--Vmlldzo5MTczMDU0)
|
||||
|
||||
* [🤗 Model checkpoints](https://huggingface.co/collections/trl-lib/online-dpo-66acd3fa38a331a9cd457b07)
|
||||
* [🐝 Tracked experiment](https://wandb.ai/huggingface/trl/reports/Online-DPO-experiments-for-TL-DR-summarisation--Vmlldzo5MTczMDU0)
|
||||
|
||||
To evaluate, we use [vLLM](https://github.com/vllm-project/vllm) to load the checkpoints and GPT-4o mini as a judge model to evaluate the generated TL;DR against the reference TL;DR.
|
||||
For more information on how to use judges, see [Judges](judges).
|
||||
@ -259,14 +256,15 @@ plt.tight_layout()
|
||||
plt.show()
|
||||
```
|
||||
|
||||

|
||||
|
||||
The online DPO checkpoint gets increasingly more win rate as we scale up the model sizes. This is a good sign that the online DPO implementation is working as intended.
|
||||
|
||||
## OnlineDPOTrainer
|
||||
|
||||
[[autodoc]] OnlineDPOTrainer
|
||||
- train
|
||||
- save_model
|
||||
- push_to_hub
|
||||
|
||||
## OnlineDPOConfig
|
||||
|
||||
[[autodoc]] OnlineDPOConfig
|
||||
[[autodoc]] OnlineDPOConfig
|
||||
|
||||
373
docs/source/openenv.md
Normal file
373
docs/source/openenv.md
Normal file
@ -0,0 +1,373 @@
|
||||
# OpenEnv Integration for Training LLMs with Environments
|
||||
|
||||
## Overview
|
||||
|
||||
[OpenEnv](https://github.com/meta-pytorch/OpenEnv) is an open-source framework from Meta's PyTorch team for defining, deploying, and interacting with environments in reinforcement learning (RL) and agentic workflows. It offers [Gymnasium-style APIs](https://gymnasium.farama.org) (e.g., `reset()` and `step()`) to interface with environments in a standard manner, and supports running these environments as backend servers (for example via HTTP or containerised execution). You can find a collection of ready-to-use OpenEnv environments on the [Hugging Face Hub](https://huggingface.co/collections/openenv/environment-hub).
|
||||
|
||||
In this guide, we’ll focus on **how to integrate OpenEnv with TRL**, but feel free to explore the links above to dive deeper into OpenEnv itself.
|
||||
|
||||
## Installation
|
||||
|
||||
To use OpenEnv with TRL, install the framework:
|
||||
|
||||
```bash
|
||||
pip install openenv-core
|
||||
```
|
||||
|
||||
## Using `rollout_func` with OpenEnv environments
|
||||
|
||||
TRL's [`GRPOTrainer`] supports _custom rollout logic_ through the `rollout_func` argument. This lets you override the trainer's default text-generation loop and directly interact with OpenEnv environments — for instance, to compute environment-driven rewards instead of relying solely on model-based signals.
|
||||
|
||||
### Rollout Function Signature
|
||||
|
||||
A rollout function must have the following signature:
|
||||
|
||||
```python
|
||||
def rollout_func(
|
||||
prompts: list[str],
|
||||
args: GRPOConfig,
|
||||
processing_class
|
||||
) -> dict[str, list]:
|
||||
"""
|
||||
Custom rollout function for generation and reward computation.
|
||||
|
||||
Args:
|
||||
prompts: List of prompts to generate from
|
||||
args: GRPOConfig containing sampling parameters (temperature, top_p, etc.)
|
||||
processing_class: Tokenizer/processor for encoding/decoding
|
||||
|
||||
Returns:
|
||||
Dictionary containing:
|
||||
- prompt_ids: List of token IDs for each prompt
|
||||
- completion_ids: List of token IDs for each completion
|
||||
- logprobs: List of log probabilities for each token
|
||||
- Any additional fields are forwarded to reward functions as kwargs
|
||||
"""
|
||||
pass
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> Any extra fields in the returned dictionary (beyond the required three) are automatically forwarded to your reward functions. This makes it easy to propagate signals such as environment rewards or auxiliary metrics from the rollout step.
|
||||
|
||||
### Integration pattern
|
||||
|
||||
The typical pattern when combining OpenEnv with TRL looks like this:
|
||||
|
||||
1. Start or connect to an OpenEnv environment (e.g., an HTTP endpoint or Dockerized env).
|
||||
2. Generate completions from your model — for example, via a vLLM inference server (`use_vllm=True`, `vllm_mode="server"`).
|
||||
3. Step through the environment using each completion to compute rewards or metrics.
|
||||
4. Add environment results (e.g., `env_reward`) to the rollout result dict.
|
||||
5. Access those rewards inside your reward function via `**kwargs`.
|
||||
|
||||
By using OpenEnv in this loop, you can:
|
||||
|
||||
* Train with realistic or interactive feedback (not just static reward functions).
|
||||
* Plug in custom simulators, web APIs, or evaluators as environments.
|
||||
* Pass structured reward signals back into RL training seamlessly.
|
||||
|
||||
## A simple example
|
||||
|
||||
The [echo.py](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/echo.py) script demonstrates a minimal, end-to-end integration between TRL and OpenEnv. In this example, the Echo environment rewards completions based on their text length, encouraging the model to generate longer outputs. This pattern can be extended to any custom environment that provides structured feedback or task-based rewards:
|
||||
|
||||
```python
|
||||
from envs.echo_env import EchoEnv, EchoAction
|
||||
from trl import GRPOConfig, GRPOTrainer
|
||||
|
||||
# Create HTTP client for Echo Environment
|
||||
client = EchoEnv.from_docker_image("echo-env:latest")
|
||||
|
||||
def rollout_func(prompts, args, processing_class):
|
||||
# 1. Generate completions via vLLM inference server (running on port 8000)
|
||||
payload = {
|
||||
"prompts": prompts,
|
||||
"n": args.num_generations,
|
||||
"temperature": args.temperature,
|
||||
"max_tokens": args.max_completion_length,
|
||||
}
|
||||
response = requests.post("http://0.0.0.0:8000/generate/", json=payload)
|
||||
result = response.json()
|
||||
|
||||
completions_text = processing_class.batch_decode(
|
||||
result["completion_ids"],
|
||||
skip_special_tokens=True
|
||||
)
|
||||
|
||||
# 2. Step through the environment to get rewards
|
||||
client.reset()
|
||||
env_rewards = []
|
||||
for msg in completions_text:
|
||||
env_result = client.step(EchoAction(message=msg))
|
||||
env_rewards.append(env_result.reward)
|
||||
|
||||
# 3. Add environment rewards as extra field
|
||||
result["env_reward"] = env_rewards
|
||||
return result
|
||||
|
||||
def reward_from_env(completions, **kwargs):
|
||||
"""Extract environment rewards passed via rollout_func kwargs."""
|
||||
env_rewards = kwargs.get("env_reward", [])
|
||||
return [float(reward) for reward in env_rewards] if env_rewards else [0.0] * len(completions)
|
||||
|
||||
dataset = Dataset.from_dict({"prompt": ["You are an AI that interacts with an *Echo* environment. Word to echo:"] * 64})
|
||||
|
||||
# Setup trainer with custom rollout
|
||||
trainer = GRPOTrainer(
|
||||
model="Qwen/Qwen2.5-0.5B-Instruct",
|
||||
reward_funcs=reward_from_env,
|
||||
train_dataset=dataset,
|
||||
rollout_func=rollout_func, # Use custom rollout
|
||||
args=GRPOConfig(
|
||||
vllm_mode="server",
|
||||
use_vllm=True,
|
||||
num_train_epochs=1,
|
||||
num_generations=8,
|
||||
max_completion_length=2048,
|
||||
per_device_train_batch_size=8,
|
||||
gradient_accumulation_steps=4,
|
||||
),
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
That's it! Now that you’ve seen the full example, let’s unpack how the main pieces fit together.
|
||||
|
||||
1. **Environment Client:** `EchoEnv` implements an HTTP interface to interact with the environment server.
|
||||
2. **Custom rollout:** The `rollout_func` generates completions and steps through the environment to collect rewards.
|
||||
3. **Extra fields:** The rollout adds `env_reward` to the result dictionary, which is automatically passed to reward functions.
|
||||
4. **Reward function:** Extracts `env_reward` from `kwargs` to apply environment-computed rewards during training.
|
||||
|
||||
> [!WARNING]
|
||||
> The `rollout_func` is currently only supported when using vLLM in server mode (`use_vllm=True`, `vllm_mode="server"`).
|
||||
|
||||
### Running the Example
|
||||
|
||||
The example requires two GPUs:
|
||||
|
||||
```bash
|
||||
# Terminal 1: Start vLLM inference server
|
||||
CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --host 0.0.0.0 --port 8000
|
||||
|
||||
# Terminal 2: Run GRPO training with OpenEnv
|
||||
CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/echo.py
|
||||
```
|
||||
|
||||
Below is the reward curve from training:
|
||||
|
||||
<iframe src="https://trl-lib-trackio.hf.space?project=openenv&metrics=train/rewards/reward_from_env/mean&runs=qgallouedec-1761202871&sidebar=hidden&navbar=hidden" style="width:600px; height:500px; border:0;"></iframe>
|
||||
|
||||
To learn more about how to create custom environments, see the [OpenEnv documentation](https://github.com/meta-pytorch/OpenEnv/blob/main/src/envs/README.md).
|
||||
|
||||
## Advanced Example
|
||||
|
||||
Let's level this up a bit by training a model to interact with a more complex environment. We'll use the game word guessing game [wordle](https://www.nytimes.com/games/wordle/index.html) from the `textarena` environment.
|
||||
|
||||
### The TextArena Environment
|
||||
|
||||
[TextArena](https://huggingface.co/papers/2504.11442) is an open-source collection of competitive text-based games designed to evaluate reasoning skills in LLMs using textual games like Wordle, Snake, Tic-Tac-Toe, and more. Research has shown that such games improve model performance on reasoning tasks.
|
||||
|
||||

|
||||
|
||||
We will use the `textarena` environment to train a model to play Wordle. The environment is a simple text based response environment that allows the model to interact with the game by making guesses and receive feedback on them.
|
||||
|
||||
### Wordle
|
||||
|
||||
Wordle is a useful game to train a model on because it requires the model to reason about the word and the feedback provided by the environment. Also, it is a purely language based game that requires no external tools or knowledge. Furthermore, we found that models from 1 billion parameters and up are able to improve on wordle and only require 8 tokens to generate a guess, which makes the game a good benchmark to experiment with Reinforcement Learning environments without significant compute requirements.
|
||||
|
||||
> [!NOTE] How does Wordle work?
|
||||
> Wordle is a word guessing game where the player has to guess a 5-letter word. The player can make 6 guesses, and for each guess, the environment will provide feedback on the correctness of the guess. The player wins if they guess the word in 6 guesses or less. It challenges the model to generate words that are likely to be correct, and to learn from the feedback provided by the environment.
|
||||
>
|
||||
> For example, if the wordle environment returns the following feedback:
|
||||
>
|
||||
> ```
|
||||
> G U E S S
|
||||
> X G Y X X
|
||||
> ```
|
||||
> The model has guessed the word "GUESS" and the environment has provided feedback as the letters X, G, and Y. Referring to colors in the original game blank, green, and yellow. From this feedback, the model should learn that the word is "GUESS" is incorrect. The letter "E" is in the word, but in the wrong position. The letter "U" is correct and in the correct position.
|
||||
|
||||
In the TextArena environment, reward is only given when the model wins the game. The reward is 1.0 if the model wins, and 0.0 otherwise. This is not a very efficient reward signal for the model, so we have added a number of custom reward functions to the script to help the model learn to play the game. The extensible nature of `reward_funcs` and `rollout_func` allows you to add any custom reward function you want to the script.
|
||||
|
||||
### Rollout Function
|
||||
|
||||
The rollout function runs one full Wordle episode, prompting the model for a guess each turn and capturing both environment rewards and auxiliary signals such as letter coverage and repetition penalties.
|
||||
|
||||
```python
|
||||
def rollout_once(
|
||||
env: TextArenaEnv,
|
||||
tokenizer: AutoTokenizer,
|
||||
args: GRPOConfig,
|
||||
dataset_prompt: str,
|
||||
cli_args: argparse.Namespace,
|
||||
system_prompt: str,
|
||||
) -> dict[str, list]:
|
||||
result = env.reset()
|
||||
observation = result.observation
|
||||
|
||||
prompt_ids: list[int] = []
|
||||
completion_ids: list[int] = []
|
||||
logprobs: list[float] = []
|
||||
raw_rewards: list[float] = []
|
||||
green_scores: list[float] = []
|
||||
yellow_scores: list[float] = []
|
||||
repetition_scores: list[float] = []
|
||||
correct_scores: list[float] = []
|
||||
guess_counts: dict[str, int] = {}
|
||||
|
||||
for _turn in range(cli_args.max_turns):
|
||||
# when the game is over the environment will return a done=True
|
||||
if result.done:
|
||||
break
|
||||
|
||||
# set up the prompt for the model
|
||||
base_prompt = observation.prompt or dataset_prompt
|
||||
user_prompt = make_user_prompt(base_prompt, observation.messages)
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
]
|
||||
prompt_text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=False,
|
||||
enable_thinking=False,
|
||||
)
|
||||
|
||||
# generate the completion from the model using vLLM
|
||||
vllm_result = request_vllm_completion(
|
||||
prompt_text,
|
||||
args,
|
||||
endpoint=cli_args.vllm_endpoint,
|
||||
timeout=cli_args.request_timeout,
|
||||
fallback=cli_args,
|
||||
)
|
||||
prompt_ids.extend(vllm_result["prompt_ids"])
|
||||
completion_ids.extend(vllm_result["completion_ids"])
|
||||
logprobs.extend(vllm_result["logprobs"])
|
||||
completion_text = vllm_result.get("text") or tokenizer.decode(
|
||||
vllm_result["completion_ids"], skip_special_tokens=True
|
||||
)
|
||||
# extract the guess from the completion
|
||||
guess = extract_guess(completion_text)
|
||||
|
||||
# step the environment with the guess
|
||||
result = env.step(TextArenaAction(message=guess))
|
||||
raw_rewards.append(float(result.reward or 0.0))
|
||||
observation = result.observation
|
||||
correct_score = float(result.reward or 0.0)
|
||||
feedback = extract_wordle_feedback(observation)
|
||||
|
||||
# Update guess counts
|
||||
previous_occurrences = guess_counts[guess]
|
||||
repetition_score = scale_repetition_score(previous_occurrences, len(guess_counts))
|
||||
guess_counts[guess] += 1
|
||||
|
||||
# calculate custom reward signals from the feedback
|
||||
if not feedback:
|
||||
green_score = 0.0
|
||||
yellow_score = 0.0
|
||||
else:
|
||||
green_count, yellow_count = extract_feedback_counts(feedback)
|
||||
green_score = green_count / 5.0
|
||||
yellow_score = yellow_count / 5.0
|
||||
|
||||
repetition_scores.append(repetition_score)
|
||||
green_scores.append(green_score)
|
||||
yellow_scores.append(yellow_score)
|
||||
correct_scores.append(correct_score)
|
||||
|
||||
correct_reward_value = correct_scores[-1] if correct_scores else (raw_rewards[-1] if raw_rewards else 0.0)
|
||||
|
||||
return {
|
||||
"prompt_ids": prompt_ids,
|
||||
"completion_ids": completion_ids,
|
||||
"logprobs": logprobs,
|
||||
"raw_rewards": raw_rewards,
|
||||
"correct_reward": correct_reward_value,
|
||||
"green_reward": green_scores[-1] if green_scores else 0.0,
|
||||
"yellow_reward": yellow_scores[-1] if yellow_scores else 0.0,
|
||||
"repetition_reward": repetition_scores[-1] if repetition_scores else 0.0,
|
||||
}
|
||||
```
|
||||
|
||||
The environment has a reward signal based on the completion of the game. We found that most models struggle to ever win the game, so we have added a number of custom reward functions to the script to help the model learn to play the game more iteratively. At first, the model will learn to cover new letters and avoid repeating guesses. As it improves, it will learn to win the game.
|
||||
|
||||
### Reward Functions
|
||||
|
||||
We log four reward streams that encourage the model to solve the puzzle, cover new letters, and avoid repeating guesses:
|
||||
|
||||
- `reward_correct`: final win/loss signal from the environment.
|
||||
- `reward_greens`: density of green letters in the last feedback.
|
||||
- `reward_yellows`: density of yellow letters in the last feedback.
|
||||
- `reward_repetition`: penalty for guessing the same token multiple times.
|
||||
|
||||
```python
|
||||
def reward_correct(completions: List[str], **kwargs: Optional[Dict]) -> List[float]:
|
||||
rewards = kwargs.get("correct_reward") if kwargs else None
|
||||
return [float(r) for r in rewards] if rewards is not None else [0.0] * len(completions)
|
||||
|
||||
|
||||
def reward_greens(completions: List[str], **kwargs: Optional[Dict]) -> List[float]:
|
||||
rewards = kwargs.get("green_reward") if kwargs else None
|
||||
return [float(r) for r in rewards] if rewards is not None else [0.0] * len(completions)
|
||||
|
||||
|
||||
def reward_yellows(completions: List[str], **kwargs: Optional[Dict]) -> List[float]:
|
||||
rewards = kwargs.get("yellow_reward") if kwargs else None
|
||||
return [float(r) for r in rewards] if rewards is not None else [0.0] * len(completions)
|
||||
|
||||
|
||||
def reward_repetition(completions: List[str], **kwargs: Optional[Dict]) -> List[float]:
|
||||
rewards = kwargs.get("repetition_reward") if kwargs else None
|
||||
return [float(r) for r in rewards] if rewards is not None else [0.0] * len(completions)
|
||||
```
|
||||
|
||||
### Training the Model
|
||||
|
||||
The training script wires the custom rollout and rewards into `GRPOTrainer`. The CLI exposes the configuration used during development as defaults, so you can override endpoints or hyperparameters at launch time.
|
||||
|
||||
```python
|
||||
parser = argparse.ArgumentParser()
|
||||
# ... add CLI arguments with sensible defaults ...
|
||||
cli_args = parser.parse_args()
|
||||
|
||||
trainer = GRPOTrainer(
|
||||
model=cli_args.model_id,
|
||||
processing_class=tokenizer,
|
||||
reward_funcs=[
|
||||
reward_correct,
|
||||
reward_greens,
|
||||
reward_yellows,
|
||||
reward_repetition,
|
||||
],
|
||||
train_dataset=dataset,
|
||||
args=grpo_config,
|
||||
rollout_func=lambda prompts, args, processing_class: rollout_func(
|
||||
env=env,
|
||||
tokenizer=tokenizer,
|
||||
prompts=prompts,
|
||||
args=args,
|
||||
cli_args=cli_args,
|
||||
system_prompt=system_prompt,
|
||||
),
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
### Running the Example
|
||||
|
||||
The example requires two GPUs:
|
||||
|
||||
```bash
|
||||
# Terminal 1: Start vLLM inference server
|
||||
CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --host 0.0.0.0 --port 8000
|
||||
|
||||
# Terminal 2: Run GRPO training with OpenEnv
|
||||
CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/wordle.py
|
||||
```
|
||||
|
||||
### Results
|
||||
|
||||
The resulting model improves it's performance on the game, both by reducing the number of repetitions and by increasing the number of correct guesses. However, the the Qwen3-1.7B model we trained is not able to consistently win the game. The following reward curve shows the coverage of the model's guesses and the coverage of correct Y and G letters.
|
||||
|
||||
<iframe src="https://burtenshaw-wordle-grpo.hf.space/?project=group-Qwen-Qwen3-17B&metrics=train/rewards/reward_coverage/mean&runs=run-2025-10-26_09-39-49&sidebar=hidden&navbar=hidden" style="width:600px; height:500px; border:0;"></iframe>
|
||||
|
||||
We experimented larger models like `gpt-oss-20b` and found that model was able to consistently win the game. However, this requires a lot of compute to train and the model. Why not try this out yourself?
|
||||
@ -1,105 +1,130 @@
|
||||
# ORPO Trainer
|
||||
|
||||
[Odds Ratio Preference Optimization](https://huggingface.co/papers/2403.07691) (ORPO) by Jiwoo Hong, Noah Lee, and James Thorne studies the crucial role of SFT within the context of preference alignment. Using preference data the method posits that a minor penalty for the disfavored generation together with a strong adaption signal to the chosen response via a simple log odds ratio term appended to the NLL loss is sufficient for preference-aligned SFT.
|
||||
[](https://huggingface.co/models?other=orpo,trl) [](https://github.com/huggingface/smol-course/tree/main/2_preference_alignment)
|
||||
|
||||
## Overview
|
||||
|
||||
Odds Ratio Preference Optimization (ORPO) was introduced in [ORPO: Monolithic Preference Optimization without Reference Model](https://huggingface.co/papers/2403.07691) by [Jiwoo Hong](https://huggingface.co/JW17), [Noah Lee](https://huggingface.co/nlee-208), and [James Thorne](https://huggingface.co/j6mes).
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
> While recent preference alignment algorithms for language models have demonstrated promising results, supervised fine-tuning (SFT) remains imperative for achieving successful convergence. In this paper, we study the crucial role of SFT within the context of preference alignment, emphasizing that a minor penalty for the disfavored generation style is sufficient for preference-aligned SFT. Building on this foundation, we introduce a straightforward and innovative reference model-free monolithic odds ratio preference optimization algorithm, ORPO, eliminating the necessity for an additional preference alignment phase. We demonstrate, both empirically and theoretically, that the odds ratio is a sensible choice for contrasting favored and disfavored styles during SFT across the diverse sizes from 125M to 7B. Specifically, fine-tuning Phi-2 (2.7B), Llama-2 (7B), and Mistral (7B) with ORPO on the UltraFeedback alone surpasses the performance of state-of-the-art language models with more than 7B and 13B parameters: achieving up to 12.20% on AlpacaEval_{2.0} (Figure 1), 66.19% on IFEval (instruction-level loose, Table 6), and 7.32 in MT-Bench (Figure 12). We release code and model checkpoints for Mistral-ORPO-alpha (7B) and Mistral-ORPO-beta (7B).
|
||||
|
||||
It studies the crucial role of SFT within the context of preference alignment. Using preference data the method posits that a minor penalty for the disfavored generation together with a strong adaption signal to the chosen response via a simple log odds ratio term appended to the NLL loss is sufficient for preference-aligned SFT.
|
||||
|
||||
Thus ORPO is a reference model-free preference optimization algorithm eliminating the necessity for an additional preference alignment phase thus saving compute and memory.
|
||||
|
||||
The official code can be found [xfactlab/orpo](https://github.com/xfactlab/orpo).
|
||||
The official code can be found in [xfactlab/orpo](https://github.com/xfactlab/orpo).
|
||||
|
||||
## Expected dataset format
|
||||
This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif), [Lewis Tunstall](https://huggingface.co/lewtun) and [Alvaro Bartolome](https://huggingface.co/alvarobartt).
|
||||
|
||||
The ORPO trainer expects a format identical to the DPO trainer, which should include three entries. These entries should be named as follows:
|
||||
## Quick start
|
||||
|
||||
- `prompt`
|
||||
- `chosen`
|
||||
- `rejected`
|
||||
This example demonstrates how to train a model using the ORPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model. We use the preference data from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the data in the dataset here:
|
||||
|
||||
for example:
|
||||
<iframe
|
||||
src="https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized/embed/viewer/default/train?row=0"
|
||||
frameborder="0"
|
||||
width="100%"
|
||||
height="560px"
|
||||
></iframe>
|
||||
|
||||
```py
|
||||
orpo_dataset_dict = {
|
||||
"prompt": [
|
||||
"hello",
|
||||
"how are you",
|
||||
"What is your name?",
|
||||
"What is your name?",
|
||||
"Which is the best programming language?",
|
||||
"Which is the best programming language?",
|
||||
"Which is the best programming language?",
|
||||
],
|
||||
"chosen": [
|
||||
"hi nice to meet you",
|
||||
"I am fine",
|
||||
"My name is Mary",
|
||||
"My name is Mary",
|
||||
"Python",
|
||||
"Python",
|
||||
"Java",
|
||||
],
|
||||
"rejected": [
|
||||
"leave me alone",
|
||||
"I am not fine",
|
||||
"Whats it to you?",
|
||||
"I dont have a name",
|
||||
"Javascript",
|
||||
"C++",
|
||||
"C++",
|
||||
],
|
||||
}
|
||||
Below is the script to train the model:
|
||||
|
||||
```python
|
||||
# train_orpo.py
|
||||
from datasets import load_dataset
|
||||
from trl import ORPOConfig, ORPOTrainer
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
||||
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
||||
|
||||
training_args = ORPOConfig(output_dir="Qwen2-0.5B-ORPO")
|
||||
trainer = ORPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
|
||||
trainer.train()
|
||||
```
|
||||
where the `prompt` contains the context inputs, `chosen` contains the corresponding chosen responses and `rejected` contains the corresponding negative (rejected) responses. Note that a prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays.
|
||||
|
||||
## Expected model format
|
||||
The ORPO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.
|
||||
Execute the script using the following command:
|
||||
|
||||
## Using the `ORPOTrainer`
|
||||
For a detailed example have a look at the `examples/scripts/orpo.py` script. At a high level we need to initialize the `ORPOTrainer` with a `model` we wish to train. **Note that ORPOTrainer eliminates the need to use the reference model, simplifying the optimization process.** The `beta` refers to the hyperparameter `lambda` in eq. (6) of the paper and refers to the weighting of the relative odd ratio loss in the standard cross-entropy loss used for SFT.
|
||||
|
||||
```py
|
||||
orpo_config = ORPOConfig(
|
||||
beta=0.1, # the lambda/alpha hyperparameter in the paper/code
|
||||
)
|
||||
|
||||
orpo_trainer = ORPOTrainer(
|
||||
model,
|
||||
args=orpo_config,
|
||||
train_dataset=train_dataset,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
```bash
|
||||
accelerate launch train_orpo.py
|
||||
```
|
||||
After this one can then call:
|
||||
|
||||
```py
|
||||
orpo_trainer.train()
|
||||
Distributed across 8 GPUs, the training takes approximately 30 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time.
|
||||
|
||||

|
||||
|
||||
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-ORPO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
|
||||
|
||||
<pre><code>$ transformers chat trl-lib/Qwen2-0.5B-ORPO
|
||||
<strong><span style="color: red;"><quentin_gallouedec>:</span></strong>
|
||||
What is the best programming language?
|
||||
|
||||
<strong><span style="color: blue;"><trl-lib/Qwen2-0.5B-ORPO>:</span></strong>
|
||||
It's challenging to determine the best programming language as no one language is perfect, as the complexity of a task and the type of project are significant factors. Some popular languages include Java, Python, JavaScript, and
|
||||
C++. If you have specific needs or requirements for a specific project, it's important to choose the language that best suits those needs.
|
||||
|
||||
Here are some other factors to consider when choosing a programming language for a project:
|
||||
|
||||
<strong><span style="color: green;">• Language proficiency:</span></strong> A good programming language is more likely to be easy to understand and use, and will allow developers to collaborate on projects more efficiently.
|
||||
<strong><span style="color: green;">• Ease of use:</span></strong> There are tools and libraries available to make programming more accessible, so developers should choose a language that can help them get started easier.
|
||||
<strong><span style="color: green;">• Code readability:</span></strong> A clear and concise codebase should be easy to read and understand, especially when working with large projects.
|
||||
<strong><span style="color: green;">• Tool and framework support:</span></strong> There are numerous libraries available for Python, Java, and JavaScript, along with tools like IDEs and static code analysis tools.
|
||||
<strong><span style="color: green;">• Accessibility:</span></strong> Some languages and tools have features that make them more accessible to developers with disabilities, such as support for screen readers.
|
||||
<strong><span style="color: green;">• Version control:</span></strong> As your projects grow and complexity increases, version control tools can be beneficial for tracking changes.
|
||||
|
||||
</code></pre>
|
||||
|
||||
## Expected dataset type
|
||||
|
||||
ORPO requires a [preference dataset](dataset_formats#preference). The [`ORPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
|
||||
|
||||
Although the [`ORPOTrainer`] supports both explicit and implicit prompts, we recommend using explicit prompts. If provided with an implicit prompt dataset, the trainer will automatically extract the prompt from the `"chosen"` and `"rejected"` columns. For more information, refer to the [preference style](dataset_formats#preference) section.
|
||||
|
||||
## Example script
|
||||
|
||||
We provide an example script to train a model using the ORPO method. The script is available in [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py)
|
||||
|
||||
To test the ORPO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), run the following command:
|
||||
|
||||
```bash
|
||||
accelerate launch examples/scripts/orpo.py \
|
||||
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
|
||||
--dataset_name trl-lib/ultrafeedback_binarized \
|
||||
--num_train_epochs 1 \
|
||||
--output_dir Qwen2-0.5B-ORPO
|
||||
```
|
||||
|
||||
## Usage tips
|
||||
|
||||
### For Mixture of Experts Models: Enabling the auxiliary loss
|
||||
|
||||
MOEs are the most efficient if the load is about equally distributed between experts.
|
||||
To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
|
||||
To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
|
||||
|
||||
This option is enabled by setting `output_router_logits=True` in the model config (e.g. MixtralConfig).
|
||||
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: 0.001).
|
||||
This option is enabled by setting `output_router_logits=True` in the model config (e.g. [`~transformers.MixtralConfig`]).
|
||||
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config.
|
||||
|
||||
## Logging
|
||||
## Logged metrics
|
||||
|
||||
While training and evaluating we record the following reward metrics:
|
||||
While training and evaluating, we record the following reward metrics:
|
||||
|
||||
* `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta
|
||||
* `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta
|
||||
* `rewards/accuracies`: mean of how often the chosen rewards are > than the corresponding rejected rewards
|
||||
* `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
|
||||
- `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta
|
||||
- `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta
|
||||
- `rewards/accuracies`: mean of how often the chosen rewards are > than the corresponding rejected rewards
|
||||
- `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
|
||||
- `log_odds_chosen`: the mean log odds ratio of the chosen responses over the rejected responses
|
||||
- `log_odds_ratio`: the mean of the `log(sigmoid(log_odds_chosen))`
|
||||
- `nll_loss`: the mean negative log likelihood loss from the SFT part of the loss over chosen responses
|
||||
|
||||
* `log_odds_chosen`: the mean log odds ratio of the chosen responses over the rejected responses
|
||||
|
||||
* `log_odds_ratio`: the mean of the `log(sigmoid(log_odds_chosen))`
|
||||
|
||||
* `nll_loss`: the mean negative log likelihood loss from the SFT part of the loss over chosen responses
|
||||
|
||||
## ORPOTrainer
|
||||
|
||||
[[autodoc]] ORPOTrainer
|
||||
|
||||
- train
|
||||
- save_model
|
||||
- push_to_hub
|
||||
|
||||
## ORPOConfig
|
||||
|
||||
|
||||
9
docs/source/others.md
Normal file
9
docs/source/others.md
Normal file
@ -0,0 +1,9 @@
|
||||
# Other
|
||||
|
||||
## profiling_decorator
|
||||
|
||||
[[autodoc]] extras.profiling.profiling_decorator
|
||||
|
||||
## profiling_context
|
||||
|
||||
[[autodoc]] extras.profiling.profiling_context
|
||||
651
docs/source/paper_index.md
Normal file
651
docs/source/paper_index.md
Normal file
@ -0,0 +1,651 @@
|
||||
# Paper Index
|
||||
|
||||
> [!WARNING]
|
||||
> Section under construction. Feel free to contribute!
|
||||
|
||||
## Group Relative Policy Optimization
|
||||
|
||||
Papers relating to the [`GRPOTrainer`]
|
||||
|
||||
### Group Sequence Policy Optimization
|
||||
|
||||
**📜 Paper**: https://huggingface.co/papers/2507.18071
|
||||
|
||||
GSPO is a GRPO variant that computes importance sampling weights at the sequence level instead of per-token. To reproduce the paper's setting, use this configuration:
|
||||
|
||||
```python
|
||||
from trl import GRPOConfig
|
||||
|
||||
training_args = GRPOConfig(
|
||||
importance_sampling_level="sequence",
|
||||
loss_type="grpo",
|
||||
beta=0.0, # GSPO set KL regularization to zero: https://github.com/volcengine/verl/pull/2775#issuecomment-3131807306
|
||||
epsilon=3e-4, # GSPO paper (v2), section 5.1
|
||||
epsilon_high=4e-4, # GSPO paper (v2), section 5.1
|
||||
gradient_accumulation_steps=1,
|
||||
steps_per_generation=4, # partition rollout batch into 4 mini-batches. GSPO paper (v2), section 5.1. Must be 4 times gradient_accumulation_steps
|
||||
)
|
||||
```
|
||||
|
||||
Note that this method only has an effect when training goes slightly off-policy—for example, when `steps_per_generation > gradient_accumulation_steps` or `num_iterations > 1`. Otherwise, it is effectively equivalent to no modification.
|
||||
|
||||
TRL also provide an experimental implementation of GSPO-token, see [Experimental - GSPO-Token](experimental#gspo-token).
|
||||
|
||||
#### Policy ratio: GRPO vs. GSPO
|
||||
|
||||
In GSPO, the policy ratio is defined at the sequence-level. In other words, it is the ratio between the probability of the current policy generating a sequence over the old policy generating that same sequence.
|
||||
|
||||
The sequence likelihood is defined as:
|
||||
|
||||
$$
|
||||
\pi_\theta (o_i | q) = \prod_{t=1}^{|o_i|} \pi_\theta (o_{i,t} | q, o_{i, < t} ),
|
||||
$$
|
||||
|
||||
where \\( \pi_\theta \\) is the policy \\( \pi \\) with parameters \\(\theta\\), \\( o_i \\) is the \\( i \\)-th output sequence \\( o \\) and \\(o_{i,t}\\) is the \\( t \\)-th token in this sequence, \\( q \\) is the input query. The sequence likelihood ratio \\( s_i (\theta) \\) is defined as:
|
||||
|
||||
$$
|
||||
s_i (\theta) = \left(\frac{\pi_\theta (o_i | q)}{\pi_{\theta_{old}} (o_i | q)} \right)^{\frac{1}{|o_i|}}
|
||||
$$
|
||||
|
||||
The exponent \\( \frac{1}{|o_i|} \\) represents a sequence-length normalization, minimizing the influence of sequence length in sequence likelihood. In other terms, it computes the geometric mean of token probabilities, ensuring a fair comparison across sequences of varying lengths.
|
||||
|
||||
While GSPO defines the policy ratio at the sequence level, GRPO operates at the token level. Specifically, GRPO computes an importance ratio for each token in the sequence:
|
||||
|
||||
$$
|
||||
w_{i,t}(\theta) = \frac{\pi_\theta (o_{i,t} | q, o_{i,< t})}{\pi_{\theta_{\text{old}}} (o_{i,t} | q, o_{i,< t})}
|
||||
$$
|
||||
|
||||
This token-level ratio is then combined with a shared advantage \\( \hat{A}_i \\), and the GRPO objective clips and optimizes each token independently across the sequence.
|
||||
|
||||
### DAPO: An Open-Source LLM Reinforcement Learning System at Scale
|
||||
|
||||
**📜 Paper**: https://huggingface.co/papers/2503.14476
|
||||
|
||||
The DAPO algorithm includes 5 key components:
|
||||
|
||||
- Overlong Filtering
|
||||
- Clip-Higher
|
||||
- Soft Overlong Punishment
|
||||
- Token-level Loss
|
||||
- Dynamic Sampling (⚠️ Not supported in TRL)
|
||||
|
||||
To reproduce the paper's setting, use this configuration:
|
||||
|
||||
```python
|
||||
from trl import GRPOConfig, GRPOTrainer
|
||||
|
||||
training_args = GRPOConfig(
|
||||
# Overlong Filtering
|
||||
mask_truncated_completions=True,
|
||||
# Token-level Loss
|
||||
loss_type="dapo",
|
||||
# Clip-Higher
|
||||
epsilon_high=0.28, # DAPO paper: section 4.1
|
||||
epsilon=0.2, # DAPO paper: section 4.1
|
||||
# Other parameters used
|
||||
per_device_train_batch_size=512, # mini-batch size for training in the paper, DAPO paper: section 4.1
|
||||
num_generations=16, # number of sample responses in the paper, DAPO paper: section 4.1
|
||||
max_completion_length=20480, # maximum number of tokens for generation in the paper, DAPO paper: section 4.1
|
||||
beta=0.0 # section 2.3, DAPO paper
|
||||
|
||||
)
|
||||
# Soft Overlong Punishment
|
||||
sop_reward = get_soft_overlong_punishment(max_completion_len=20480, soft_punish_cache=4096) # DAPO paper: section 4.1
|
||||
trainer = GRPOTrainer(
|
||||
...,
|
||||
args=training_args,
|
||||
reward_funcs=[..., sop_reward],
|
||||
)
|
||||
```
|
||||
|
||||
### Dr. GRPO: Understanding R1-Zero-Like Training: A Critical Perspective
|
||||
|
||||
**📜 Paper**: https://huggingface.co/papers/2503.20783
|
||||
|
||||
A study of R1-Zero training identifies pretraining effects on RL performance and proffers Dr. GRPO to enhance token efficiency, achieving superior accuracy on AIME 2024. To reproduce the paper's setting, use this configuration:
|
||||
|
||||
```python
|
||||
from trl import GRPOConfig
|
||||
|
||||
training_args = GRPOConfig(
|
||||
loss_type="dr_grpo",
|
||||
per_device_train_batch_size=1, # train_batch_size_per_device in the Training section of the repository
|
||||
num_generations=8, # num_samples in the Training section of the repository
|
||||
max_prompt_length=1024, # prompt_max_length in the Training section of the repository
|
||||
max_completion_length=3000, # generate_max_length in the Training section of the repository
|
||||
beta=0.0, # beta in the Training section of the repository
|
||||
)
|
||||
```
|
||||
|
||||
### Part I: Tricks or Traps? A Deep Dive into RL for LLM Reasoning (Lite PPO)
|
||||
|
||||
**📜 Paper**: https://huggingface.co/papers/2508.08221
|
||||
|
||||
The authors of this paper find that the combination of:
|
||||
|
||||
1. scaling rewards by the standard deviation computed over the entire batch and
|
||||
2. aggregating loss over the total number of tokens
|
||||
|
||||
can unlock the learning capability of critic-free policies using vanilla PPO loss. Their results demonstrate that this simple combination consistently improves performance, surpassing strategies like GRPO and [DAPO](https://huggingface.co/papers/2503.14476).
|
||||
|
||||
TRL supports using these learnings to train a GRPO model by:
|
||||
|
||||
```python
|
||||
from trl import GRPOConfig
|
||||
|
||||
training_args = GRPOConfig(
|
||||
...
|
||||
scale_rewards="batch",
|
||||
loss_type="dapo",
|
||||
# Other parameters used
|
||||
beta=0.0, # = init_kl_coef in the paper
|
||||
top_p=0.99,
|
||||
top_k=100,
|
||||
temperature=0.99,
|
||||
num_completions=8, # = num_return_sequences in the paper
|
||||
num_iterations=1, # = ppo_epochs in the paper
|
||||
per_device_train_batch_size=4,
|
||||
gradient_accumulation_steps=32,
|
||||
steps_per_generation=8, # (rollout_batch_size*num_return_sequences) / (per_device_train_batch_size*gradient_accumulation_steps)
|
||||
)
|
||||
```
|
||||
|
||||
Note that when using gradient accumulation, the loss is aggregated over the total number of tokens in the batch, but not over the accumulated batch. For more details, see the [GRPO Trainer - Loss types](grpo_trainer#loss_types).
|
||||
|
||||
### Truncated Importance Sampling
|
||||
|
||||
**📰 Blog**: https://fengyao.notion.site/off-policy-rl
|
||||
|
||||
Online policy learning methods commonly use an optimized inference framework for rollout generation (e.g vLLM) that is separate from the training backend. This introduces a rollout-training mismatch, exemplified in the following PPO objective:
|
||||
|
||||
$$
|
||||
\small{
|
||||
\mathbb{E}_{a\sim\textcolor{red}{\pi_{\text{inference}}}(\theta_{\mathrm{old}})}
|
||||
\Bigl[
|
||||
\min\Bigl(
|
||||
\frac{\textcolor{blue}{\pi_{\text{training}}}(a, \theta)}{\textcolor{blue}{\pi_{\text{training}}}(a, \theta_{\mathrm{old}})}\,\hat A,
|
||||
\;\mathrm{clip}\bigl(\frac{\textcolor{blue}{\pi_{\text{training}}}(a, \theta)}{\textcolor{blue}{\pi_{\text{training}}}(a, \theta_{\mathrm{old}})},\,1-\epsilon,\,1+\epsilon\bigr)\,\hat A
|
||||
\Bigr)
|
||||
\Bigr]
|
||||
}
|
||||
$$
|
||||
|
||||
Despite \\( \textcolor{red}{\pi_{\text{inference}}} \\) and \\( \textcolor{blue}{\pi_{\text{training}}} \\) sharing the same model parameters \\( \theta \\), they can produce significantly different token probabilities. This unexpected behavior implicitly breaks the on-policy assumption, and silently turns training off-policy.
|
||||
|
||||
Truncated Importance Sampling (TIS) addresses this issue by adapting the model update via importance-sampling correction. The gradient computation of the aforementioned PPO objective becomes
|
||||
|
||||
$$
|
||||
\small{
|
||||
\mathbb{E}_{a\sim\textcolor{red}{\pi_{\text{inference}}}(\theta_{\mathrm{old}})}
|
||||
\Bigl[
|
||||
\underbrace{\min(\frac{\textcolor{blue}{\pi_{\text{training}}}(a, \theta_{\mathrm{old}})}{\textcolor{red}{\pi_{\text{inference}}}(a, \theta_{\mathrm{old}})}, C)}_{\text{truncated importance ratio}} \cdot
|
||||
\nabla_\theta
|
||||
\min\Bigl(
|
||||
\frac{\textcolor{blue}{\pi_{\text{training}}}(a, \theta)}{\textcolor{blue}{\pi_{\text{training}}}(a, \theta_{\mathrm{old}})}\,\hat A,
|
||||
\;\mathrm{clip}\bigl(\frac{\textcolor{blue}{\pi_{\text{training}}}(a, \theta)}{\textcolor{blue}{\pi_{\text{training}}}(a, \theta_{\mathrm{old}})},\,1-\epsilon,\,1+\epsilon\bigr)\,\hat A
|
||||
\Bigr)
|
||||
\Bigr]
|
||||
}
|
||||
$$
|
||||
|
||||
where \\( C \\) is a hyper-parameter. In TRL, TIS is implemented for GRPO, and enabled by default when vLLM is used for generation (`use_vllm=True`)
|
||||
|
||||
```python
|
||||
from trl import GRPOConfig
|
||||
|
||||
training_args = GRPOConfig(
|
||||
...
|
||||
use_vllm=True,
|
||||
vllm_importance_sampling_correction=True, # default True
|
||||
vllm_importance_sampling_cap=2.0, # hyper-parameter C
|
||||
)
|
||||
```
|
||||
|
||||
### Sample More to Think Less: Group Filtered Policy Optimization for Concise Reasoning
|
||||
|
||||
**📜 Paper**: https://huggingface.co/papers/2508.09726
|
||||
|
||||
See [Experimental - GFPO](experimental#gfpo).
|
||||
|
||||
### Perception-Aware Policy Optimization for Multimodal Reasoning
|
||||
|
||||
**📜 Paper**: https://huggingface.co/papers/2507.06448
|
||||
|
||||
A novel policy gradient algorithm that encourages VLMs to learn to perceive while learning to reason. This is a TRL adaptation. The TRL implementation is not the official one provided by the authors.
|
||||
This is a TRL adaptation of PAPO. Note that this is not the official implementation. The official code can be found in [MikeWangWZHL/PAPO](https://github.com/MikeWangWZHL/PAPO).
|
||||
|
||||
```python
|
||||
from trl.experimental.papo import PAPOConfig, PAPOTrainer
|
||||
|
||||
training_args = PAPOConfig(
|
||||
# PAPO-specific params
|
||||
perception_loss_weight=0.01, # Weight for perception loss
|
||||
mask_ratio=0.6, # 40% of image will be masked
|
||||
mask_type="random", # Use patch masking (recommended)
|
||||
der_loss_weight1=0.02,
|
||||
der_loss_weight2=0.02,
|
||||
# ...other GRPO params...
|
||||
)
|
||||
trainer = PAPOTrainer(
|
||||
args=training_args,
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
## Direct Policy Optimization
|
||||
|
||||
Papers relating to the [`DPOTrainer`]
|
||||
|
||||
### Direct Preference Optimization (DPO): Your Language Model is Secretly a Reward Model
|
||||
|
||||
**📜 Paper**: https://huggingface.co/papers/2305.18290
|
||||
|
||||
Direct Preference Optimization (DPO) fine-tunes language models more efficiently and with better performance compared to reinforcement learning from human feedback (RLHF), by directly optimizing policy training based on human preferences. To reproduce the paper's setting, use this configuration:
|
||||
|
||||
```python
|
||||
from trl import DPOConfig
|
||||
|
||||
training_args = DPOConfig(
|
||||
loss_type="sigmoid", # losses in Appendix B of the paper
|
||||
per_device_train_batch_size=64, # batch size in Appendix B of the paper
|
||||
learning_rate=1e-6, # learning rate in Appendix B of the paper
|
||||
beta=0.1, # beta in Appendix B of the paper
|
||||
)
|
||||
```
|
||||
|
||||
### A General Theoretical Paradigm to Understand Learning from Human Preferences
|
||||
|
||||
**📜 Paper**: https://huggingface.co/papers/2310.12036
|
||||
|
||||
A new general objective, \\( \Psi \\)$PO, bypasses both key approximations in reinforcement learning from human preferences, allowing for theoretical analysis and empirical superiority over DPO. To reproduce the paper's setting, use this configuration: To reproduce the paper's setting, use this configuration:
|
||||
|
||||
```python
|
||||
from trl import DPOConfig
|
||||
|
||||
training_args = DPOConfig(
|
||||
loss_type="ipo", # Section 5.1 of the paper
|
||||
per_device_train_batch_size=90, # mini-batch size in Section C.1 of the paper
|
||||
learning_rate=1e-2, # learning rate in Section C.1 of the paper
|
||||
)
|
||||
```
|
||||
|
||||
These parameters only appear in the [published version](https://proceedings.mlr.press/v238/gheshlaghi-azar24a/gheshlaghi-azar24a.pdf)
|
||||
|
||||
### SLiC-HF: Sequence Likelihood Calibration with Human Feedback
|
||||
|
||||
**📜 Paper**: https://huggingface.co/papers/2305.10425
|
||||
|
||||
Sequence Likelihood Calibration (SLiC) is shown to be an effective and simpler alternative to Reinforcement Learning from Human Feedback (RLHF) for learning from human preferences in language models. To reproduce the paper's setting, use this configuration:
|
||||
|
||||
```python
|
||||
from trl import DPOConfig
|
||||
|
||||
training_args = DPOConfig(
|
||||
loss_type="hinge", # Section 2 of the paper
|
||||
per_device_train_batch_size=512, # batch size in Section 3.2 of the paper
|
||||
learning_rate=1e-4, # learning rate in Section 3.2 of the paper
|
||||
)
|
||||
```
|
||||
|
||||
These parameters only appear in the [published version](https://openreview.net/pdf?id=0qSOodKmJaN)
|
||||
|
||||
### Towards Efficient and Exact Optimization of Language Model Alignment
|
||||
|
||||
**📜 Paper**: https://huggingface.co/papers/2402.00856
|
||||
|
||||
Efficient exact optimization (EXO) method is proposed to align language models with human preferences, providing a guaranteed and efficient alternative to reinforcement learning and direct preference optimization. To reproduce the paper's setting, use this configuration:
|
||||
|
||||
```python
|
||||
from trl import DPOConfig
|
||||
|
||||
training_args = DPOConfig(
|
||||
loss_type="exo_pair", # Section 3.2 of the paper
|
||||
per_device_train_batch_size=64, # batch size in Section B of the paper
|
||||
learning_rate=1e-6, # learning rate in Section B of the paper
|
||||
beta=0.1, # $\beta_r$ in Section B of the paper
|
||||
)
|
||||
```
|
||||
|
||||
### Noise Contrastive Alignment of Language Models with Explicit Rewards
|
||||
|
||||
**📜 Paper**: https://huggingface.co/papers/2402.05369
|
||||
|
||||
A framework using Noise Contrastive Estimation enhances language model alignment with both scalar rewards and pairwise preferences, demonstrating advantages over Direct Preference Optimization. To reproduce the paper's setting, use this configuration:
|
||||
|
||||
```python
|
||||
from trl import DPOConfig
|
||||
|
||||
training_args = DPOConfig(
|
||||
loss_type="nca_pair", # Section 4.1 of the paper
|
||||
per_device_train_batch_size=32, # batch size in Section C of the paper
|
||||
learning_rate=5e-6, # learning rate in Section C of the paper
|
||||
beta=0.01, # $\alpha$ in Section C of the paper
|
||||
)
|
||||
```
|
||||
|
||||
### Provably Robust DPO: Aligning Language Models with Noisy Feedback
|
||||
|
||||
**📜 Paper**: https://huggingface.co/papers/2403.00409
|
||||
|
||||
The paper introduces a robust direct preference optimization (rDPO) framework to address noise in preference-based feedback for language models, proving its sub-optimality gap and demonstrating its effectiveness through experiments. To reproduce the paper's setting, use this configuration:
|
||||
|
||||
```python
|
||||
from trl import DPOConfig
|
||||
|
||||
training_args = DPOConfig(
|
||||
loss_type="robust", # Section 3.1 of the paper
|
||||
per_device_train_batch_size=16, # batch size in Section B of the paper
|
||||
learning_rate=1e-3, # learning rate in Section B of the paper
|
||||
beta=0.01, # $\beta$ in Section B of the paper,
|
||||
max_prompt_length=128, # max prompt length in Section B of the paper
|
||||
max_length=512, # max length in Section B of the paper
|
||||
label_smoothing=0.1 # label smoothing $\epsilon$ in section 6 of the paper
|
||||
|
||||
)
|
||||
```
|
||||
|
||||
### Binary Classifier Optimization for Large Language Model Alignment
|
||||
|
||||
**📜 Paper**: https://huggingface.co/papers/2404.04656
|
||||
|
||||
Theoretical analysis and a new algorithm, Binary Classifier Optimization, explain and enhance the alignment of large language models using binary feedback signals. To reproduce the paper's setting, use this configuration:
|
||||
|
||||
```python
|
||||
from trl import DPOConfig
|
||||
|
||||
training_args = DPOConfig(
|
||||
loss_type="bco_pair", # Section 4 of the paper
|
||||
per_device_train_batch_size=128, # batch size in Section C of the paper
|
||||
learning_rate=5e-7, # learning rate in Section C of the paper
|
||||
beta=0.01, # $\beta$ in Section C of the paper,
|
||||
max_prompt_length=1536, # max prompt length in Section C of the paper
|
||||
max_completion_length=512, # max completion length in Section C of the paper
|
||||
)
|
||||
```
|
||||
|
||||
For the unpaired version, the user should utilize [`experimental.bco.BCOConfig`] and [`experimental.bco.BCOTrainer`].
|
||||
|
||||
### Self-Play Preference Optimization for Language Model Alignment
|
||||
|
||||
**📜 Paper**: https://huggingface.co/papers/2405.00675
|
||||
|
||||
A self-play method called SPPO for language model alignment achieves state-of-the-art performance by approximating Nash equilibrium policy in a constant-sum game setting, outperforming other approaches with limited data. To reproduce the paper's setting, use this configuration:
|
||||
|
||||
```python
|
||||
from trl import DPOConfig
|
||||
|
||||
training_args = DPOConfig(
|
||||
loss_type="sppo_hard", # Section 3 of the paper
|
||||
per_device_train_batch_size=64, # batch size in Section C of the paper
|
||||
learning_rate=5e-7, # learning rate in Section C of the paper
|
||||
)
|
||||
```
|
||||
|
||||
### Distributional Preference Alignment of LLMs via Optimal Transport
|
||||
|
||||
**📜 Paper**: https://huggingface.co/papers/2406.05882
|
||||
|
||||
Alignment via Optimal Transport (AOT) aligns large language models distributionally by penalizing violations of stochastic dominance between positive and negative sample distributions, achieving state-of-the-art performance on alignment benchmarks. To reproduce the paper's setting, use this configuration:
|
||||
|
||||
```python
|
||||
from trl import DPOConfig
|
||||
|
||||
training_args = DPOConfig(
|
||||
loss_type="aot", # Section 3 of the paper
|
||||
)
|
||||
```
|
||||
|
||||
```python
|
||||
from trl import DPOConfig
|
||||
|
||||
training_args = DPOConfig(
|
||||
loss_type="aot_pair", # Section 3 of the paper
|
||||
)
|
||||
```
|
||||
|
||||
There is no additional hyperparameter in the paper.
|
||||
|
||||
### Discovering Preference Optimization Algorithms with and for Large Language Models
|
||||
|
||||
**📜 Paper**: https://huggingface.co/papers/2406.08414
|
||||
|
||||
An LLM-driven method automatically discovers performant preference optimization algorithms, leading to a new algorithm called DiscoPOP that blends logistic and exponential losses. To reproduce the paper's setting, use this configuration:
|
||||
|
||||
```python
|
||||
from trl import DPOConfig
|
||||
|
||||
training_args = DPOConfig(
|
||||
loss_type="discopop", # Section 3 of the paper
|
||||
per_device_train_batch_size=64, # batch size in Section B.1 of the paper
|
||||
learning_rate=5e-7, # learning rate in Section B.1 of the paper
|
||||
beta=0.05, # $\beta$ in Section B.1 of the paper,
|
||||
discopop_tau=0.05 # $\tau$ in Section E of the paper
|
||||
)
|
||||
```
|
||||
|
||||
### Anchored Preference Optimization and Contrastive Revisions: Addressing Underspecification in Alignment
|
||||
|
||||
**📜 Paper**: https://huggingface.co/papers/2408.06266
|
||||
|
||||
CLAIR and APO enhance LLM alignment through more contrastive preference pairs and controlled alignment objectives, improving model performance close to GPT4-turbo. To reproduce the paper's setting, use this configuration:
|
||||
|
||||
```python
|
||||
from trl import DPOConfig
|
||||
|
||||
training_args = DPOConfig(
|
||||
loss_type="apo_zero", # Section 4 of the paper
|
||||
per_device_train_batch_size=64, # batch size in Section B.1 of the paper
|
||||
learning_rate=2e-7, # learning rate in Section 5.2 of the paper
|
||||
beta=0.1, # $\beta$ in Section 5.2 of the paper,
|
||||
max_prompt_length=512, # prompt length in Section 5.2 of the paper
|
||||
max_completion_length=512, # completion length in Section 5.2 of the paper
|
||||
)
|
||||
```
|
||||
|
||||
```python
|
||||
from trl import DPOConfig
|
||||
|
||||
training_args = DPOConfig(
|
||||
loss_type="apo_down", # Section 4 of the paper
|
||||
per_device_train_batch_size=64, # batch size in Section B.1 of the paper
|
||||
learning_rate=2e-7, # learning rate in Section 5.2 of the paper
|
||||
beta=0.1, # $\beta$ in Section 5.2 of the paper,
|
||||
max_prompt_length=512, # prompt length in Section 5.2 of the paper
|
||||
max_completion_length=512, # completion length in Section 5.2 of the paper
|
||||
)
|
||||
```
|
||||
|
||||
These parameters only appear in the [published version](https://aclanthology.org/2025.tacl-1.22.pdf)
|
||||
|
||||
## Supervised Fine-Tuning
|
||||
|
||||
Papers relating to the [`SFTTrainer`]
|
||||
|
||||
### EMA Without the Lag: Bias-Corrected Iterate Averaging Schemes
|
||||
|
||||
**📜 Paper**: https://huggingface.co/papers/2508.00180
|
||||
|
||||
Bias-Corrected Exponential Moving Average (BEMA) improves the stability and efficiency of language model fine-tuning by reducing stochasticity and eliminating bias. To use BEMA with SFT as described in the paper, you can use the [`BEMACallback`]:
|
||||
|
||||
```python
|
||||
from trl import BEMACallback, SFTTrainer
|
||||
|
||||
trainer = SFTTrainer(
|
||||
...
|
||||
callbacks=[BEMACallback()],
|
||||
)
|
||||
```
|
||||
|
||||
### On the Generalization of SFT: A Reinforcement Learning Perspective with Reward Rectification
|
||||
|
||||
**📜 Paper**: https://huggingface.co/papers/2508.05629
|
||||
|
||||
Dynamic Fine-Tuning (DFT) improves the generalization of Large Language Models (LLMs) by dynamically rescaling gradients, outperforming standard Supervised Fine-Tuning (SFT) and showing competitive results in offline reinforcement learning.
|
||||
|
||||
$$
|
||||
\mathcal{L}_{\text{DFT}}(\theta) = \mathbb{E}_{(x,y) \sim \mathcal{D}} \left[ - \sum_{t=1}^{|y|} \textcolor{red}{\text{sg}\big(\pi_\theta(y_t \mid y_{<t}, x)\big)} \; \log \pi_\theta(y_t \mid y_{<t}, x) \right]
|
||||
$$
|
||||
|
||||
where \\( \text{sg}(\cdot) \\) is the stop-gradient operator. To use DFT with SFT as described in the paper, you can use the `loss_type="dft"` argument:
|
||||
|
||||
```python
|
||||
from trl import SFTConfig
|
||||
|
||||
training_args = SFTConfig(
|
||||
loss_type="dft",
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
To closely match the paper’s setup, you can use the following configuration (see Sec. 4.1). Authors also mention that the hyperparameters are not very sensitive (Sec. 4.3):
|
||||
|
||||
```python
|
||||
SFTConfig(
|
||||
loss_type="dft",
|
||||
learning_rate=5e-5,
|
||||
max_length=2048,
|
||||
# Target batch size 256; achieved via per-device batch 8 * grad accumulation 32
|
||||
per_device_train_batch_size=8,
|
||||
gradient_accumulation_steps=32,
|
||||
)
|
||||
```
|
||||
|
||||
## Reinforce Leave-One-Out
|
||||
|
||||
Papers relating to the [`RLOOTrainer`]
|
||||
|
||||
### Back to Basics: Revisiting REINFORCE Style Optimization for Learning from Human Feedback in LLMs
|
||||
|
||||
**📜 Paper**: https://huggingface.co/papers/2402.14740
|
||||
|
||||
RLOO is a variant of REINFORCE that reduces variance by using leave-one-out baselines. It computes rewards by comparing each sample against the average of all other samples in the batch, providing more stable gradients than standard REINFORCE. To reproduce the paper's setting, use this configuration:
|
||||
|
||||
```python
|
||||
from trl import RLOOConfig
|
||||
|
||||
training_args = RLOOConfig(
|
||||
per_device_train_batch_size=512, # section C Training Detail of the paper
|
||||
steps_per_generation=2 # section C Training Detail of the paper
|
||||
beta=0.03 # section C Training Detail of the paper
|
||||
num_generations=2, # experiments of paper different num_generations={2,4}
|
||||
learning_rate=1e-6 # section C Training Detail of the paper
|
||||
)
|
||||
```
|
||||
|
||||
## Contrastive Preference Optimization
|
||||
|
||||
Papers relating to the [`CPOTrainer`]
|
||||
|
||||
### AlphaPO -- Reward shape matters for LLM alignment
|
||||
|
||||
**📜 Paper**: https://huggingface.co/papers/2501.03884
|
||||
|
||||
AlphaPO is a new Direct Alignment Algorithms (DAAs) method that leverages an alpha-parameter to help change the shape of the reward function beyond the standard log reward. AlphaPO helps maintain fine-grained control over likelihood displacement and over-optimization. To reproduce the paper's setting, use this configuration:
|
||||
|
||||
```python
|
||||
from trl import CPOConfig
|
||||
|
||||
# Mistral-Instruct from Table 3 of the paper
|
||||
training_args = CPOConfig(
|
||||
loss_type="alphapo",
|
||||
alpha=0.25,
|
||||
beta=2.5,
|
||||
simpo_gamma=0.1,
|
||||
learning_rate=7e-7,
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
## Reward Modeling
|
||||
|
||||
Papers relating to the [`RewardTrainer`]
|
||||
|
||||
### Helping or Herding? Reward Model Ensembles Mitigate but do not Eliminate Reward Hacking
|
||||
|
||||
**📜 Paper**: https://huggingface.co/papers/2312.09244
|
||||
|
||||
This paper proposed an auxiliary loss function designed to directly learn a centered reward model. This auxiliary loss minimizes the squared sum of the rewards, encouraging the model to naturally produce mean-zero outputs and thereby resolving the issue of underdetermination.
|
||||
|
||||
$$
|
||||
\mathcal{L}(\theta) = - \mathbb{E}_{(x,y^+,y^-) \sim \mathcal{D}} \left[ \log \sigma(r_\theta(x, y^+) - r_\theta(x, y^-)) \textcolor{red}{- \eta \cdot (r_\theta(x, y^+) + r_\theta(x, y^-))^2} \right].
|
||||
$$
|
||||
|
||||
To use this auxiliary loss with [`RewardTrainer`], you can use the `center_rewards_coefficient` argument in [`RewardConfig`] as follows:
|
||||
|
||||
```python
|
||||
from trl import RewardConfig
|
||||
|
||||
training_args = RewardConfig(
|
||||
center_rewards_coefficient=0.01, # η in the paper
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
### Llama 2: Open Foundation and Fine-Tuned Chat Models
|
||||
|
||||
**📜 Paper**: https://huggingface.co/papers/2307.09288
|
||||
|
||||
In this paper, the authors propose to leverage their preference ratings being decomposed as a scale of four points (e.g., _significantly better_) to provide more informative feedback to the reward model. This is done by adding a margin to the loss function, which encourages the reward model to assign larger gaps in scores for pairs with higher preference ratings.
|
||||
|
||||
$$
|
||||
\mathcal{L}(\theta) = - \mathbb{E}_{(x,y^+,y^-,\textcolor{red}{m}) \sim \mathcal{D}} \left[ \log \sigma(r_\theta(x, y^+) - r_\theta(x, y^-) \textcolor{red}{- m}) \right].
|
||||
$$
|
||||
|
||||
You can add a margin to the loss by adding a `margin` column to the dataset. The following example shows how to set up a the "Margin Small" setting of the paper.
|
||||
|
||||
```python
|
||||
def add_margin(example):
|
||||
preference_to_margin = {
|
||||
"significantly better": 1.0,
|
||||
"better": 2.0/3.0,
|
||||
"slightly better": 1.0/3.0,
|
||||
"negligibly better / unsure": 0.0,
|
||||
}
|
||||
return {"margin": preference_to_margin[example["preference_label"]]}
|
||||
|
||||
dataset = dataset.map(add_margin)
|
||||
```
|
||||
|
||||
## Distillation
|
||||
Papers relating to training a student model with the help of a teacher model.
|
||||
|
||||
### On-Policy Distillation
|
||||
**📰 Blog**: https://thinkingmachines.ai/blog/on-policy-distillation/
|
||||
|
||||
On-Policy Distillation involves a student model generating rollouts for each batch of training data. We subsequently obtain the probability distributions for each token of the rollouts from both the student and teacher models. The student model is then optimized to minimize the negative Kullback-Leibler (KL) divergence between its own token distributions and those of the teacher model.
|
||||
|
||||
| Method | Sampling | Reward signal |
|
||||
|-------------------------|------------|---------------|
|
||||
| Supervised finetuning | off-policy | dense |
|
||||
| Reinforcement learning | on-policy | sparse |
|
||||
| On-policy distillation | on-policy | dense |
|
||||
|
||||
On-Policy Distillation has been shown to outperform SFT, GRPO and can be used to restore generalization capabilities lost during SFT.
|
||||
|
||||
Additionally on-policy distillation is more compute efficient and is less prone to overfitting when trained with limited data.
|
||||
|
||||
To train a model with on-policy distillation using TRL, you can use the following configuration, with the [`GKDTrainer`] and [`GKDConfig`]:
|
||||
|
||||
```python
|
||||
from trl import GKDConfig
|
||||
|
||||
config = GKDConfig(
|
||||
lmbda=1.0, # student produces rollouts for all batches
|
||||
beta=1.0, # to ensure reverse-kl as the loss function
|
||||
teacher_model_name_or_path="teacher-model", # specify the teacher model
|
||||
|
||||
)
|
||||
```
|
||||
|
||||
Alternatively, you can use the [`GOLDTrainer`] and [`GOLDConfig`] to perform on-policy distillation with a similar configuration:
|
||||
|
||||
```python
|
||||
from trl.experimental import GOLDConfig
|
||||
|
||||
config = GOLDConfig(
|
||||
lmbda=1.0, # student produces rollouts for all batches
|
||||
beta=1.0, # to ensure reverse-kl as the loss function
|
||||
teacher_model_name_or_path="teacher-model", # specify the teacher model
|
||||
|
||||
)
|
||||
```
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user