{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "ed65145e-1dcc-4fad-85cd-8cc683fde91b",
|
|
"metadata": {},
|
|
"source": [
|
|
"# VeRL Ray API Tutorial"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "5f90dfa8-3285-41e4-ba44-12abcb76b3ce",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"source": [
|
|
"## Chapter 1: Ray Basics"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "cafc9be5-614b-4380-9b69-31525ea4e73a",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import os\n",
|
|
"\n",
|
|
"# turn off Megatron timer\n",
|
|
"os.environ['MEGATRON_USE_CUDA_TIMER'] = '0'\n",
|
|
"os.environ['MEGATRON_START_PROCESS_TIMER'] = 'False'\n",
|
|
"os.environ['NCCL_DEBUG'] = 'WARN'"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "3fb65288-4410-43b4-9ffc-8538197ae039",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import ray\n",
|
|
"import torch"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "ca7f0778-f676-4f67-8c1f-9f6c8eee74e2",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"2024-01-22 14:52:10,398\tINFO worker.py:1655 -- Started a local Ray instance. View the dashboard at \u001b[1m\u001b[32m127.0.0.1:8265 \u001b[39m\u001b[22m\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div>\n",
|
|
" <div style=\"margin-left: 50px;display: flex;flex-direction: row;align-items: center\">\n",
|
|
" <h3 style=\"color: var(--jp-ui-font-color0)\">Ray</h3>\n",
|
|
" <svg version=\"1.1\" id=\"ray\" width=\"3em\" viewBox=\"0 0 144.5 144.6\" style=\"margin-left: 3em;margin-right: 3em\">\n",
|
|
" <g id=\"layer-1\">\n",
|
|
" <path fill=\"#00a2e9\" class=\"st0\" d=\"M97.3,77.2c-3.8-1.1-6.2,0.9-8.3,5.1c-3.5,6.8-9.9,9.9-17.4,9.6S58,88.1,54.8,81.2c-1.4-3-3-4-6.3-4.1\n",
|
|
" c-5.6-0.1-9.9,0.1-13.1,6.4c-3.8,7.6-13.6,10.2-21.8,7.6C5.2,88.4-0.4,80.5,0,71.7c0.1-8.4,5.7-15.8,13.8-18.2\n",
|
|
" c8.4-2.6,17.5,0.7,22.3,8c1.3,1.9,1.3,5.2,3.6,5.6c3.9,0.6,8,0.2,12,0.2c1.8,0,1.9-1.6,2.4-2.8c3.5-7.8,9.7-11.8,18-11.9\n",
|
|
" c8.2-0.1,14.4,3.9,17.8,11.4c1.3,2.8,2.9,3.6,5.7,3.3c1-0.1,2,0.1,3,0c2.8-0.5,6.4,1.7,8.1-2.7s-2.3-5.5-4.1-7.5\n",
|
|
" c-5.1-5.7-10.9-10.8-16.1-16.3C84,38,81.9,37.1,78,38.3C66.7,42,56.2,35.7,53,24.1C50.3,14,57.3,2.8,67.7,0.5\n",
|
|
" C78.4-2,89,4.7,91.5,15.3c0.1,0.3,0.1,0.5,0.2,0.8c0.7,3.4,0.7,6.9-0.8,9.8c-1.7,3.2-0.8,5,1.5,7.2c6.7,6.5,13.3,13,19.8,19.7\n",
|
|
" c1.8,1.8,3,2.1,5.5,1.2c9.1-3.4,17.9-0.6,23.4,7c4.8,6.9,4.6,16.1-0.4,22.9c-5.4,7.2-14.2,9.9-23.1,6.5c-2.3-0.9-3.5-0.6-5.1,1.1\n",
|
|
" c-6.7,6.9-13.6,13.7-20.5,20.4c-1.8,1.8-2.5,3.2-1.4,5.9c3.5,8.7,0.3,18.6-7.7,23.6c-7.9,5-18.2,3.8-24.8-2.9\n",
|
|
" c-6.4-6.4-7.4-16.2-2.5-24.3c4.9-7.8,14.5-11,23.1-7.8c3,1.1,4.7,0.5,6.9-1.7C91.7,98.4,98,92.3,104.2,86c1.6-1.6,4.1-2.7,2.6-6.2\n",
|
|
" c-1.4-3.3-3.8-2.5-6.2-2.6C99.8,77.2,98.9,77.2,97.3,77.2z M72.1,29.7c5.5,0.1,9.9-4.3,10-9.8c0-0.1,0-0.2,0-0.3\n",
|
|
" C81.8,14,77,9.8,71.5,10.2c-5,0.3-9,4.2-9.3,9.2c-0.2,5.5,4,10.1,9.5,10.3C71.8,29.7,72,29.7,72.1,29.7z M72.3,62.3\n",
|
|
" c-5.4-0.1-9.9,4.2-10.1,9.7c0,0.2,0,0.3,0,0.5c0.2,5.4,4.5,9.7,9.9,10c5.1,0.1,9.9-4.7,10.1-9.8c0.2-5.5-4-10-9.5-10.3\n",
|
|
" C72.6,62.3,72.4,62.3,72.3,62.3z M115,72.5c0.1,5.4,4.5,9.7,9.8,9.9c5.6-0.2,10-4.8,10-10.4c-0.2-5.4-4.6-9.7-10-9.7\n",
|
|
" c-5.3-0.1-9.8,4.2-9.9,9.5C115,72.1,115,72.3,115,72.5z M19.5,62.3c-5.4,0.1-9.8,4.4-10,9.8c-0.1,5.1,5.2,10.4,10.2,10.3\n",
|
|
" c5.6-0.2,10-4.9,9.8-10.5c-0.1-5.4-4.5-9.7-9.9-9.6C19.6,62.3,19.5,62.3,19.5,62.3z M71.8,134.6c5.9,0.2,10.3-3.9,10.4-9.6\n",
|
|
" c0.5-5.5-3.6-10.4-9.1-10.8c-5.5-0.5-10.4,3.6-10.8,9.1c0,0.5,0,0.9,0,1.4c-0.2,5.3,4,9.8,9.3,10\n",
|
|
" C71.6,134.6,71.7,134.6,71.8,134.6z\"/>\n",
|
|
" </g>\n",
|
|
" </svg>\n",
|
|
" <table>\n",
|
|
" <tr>\n",
|
|
" <td style=\"text-align: left\"><b>Python version:</b></td>\n",
|
|
" <td style=\"text-align: left\"><b>3.9.2</b></td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td style=\"text-align: left\"><b>Ray version:</b></td>\n",
|
|
" <td style=\"text-align: left\"><b> 2.3.0</b></td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <td style=\"text-align: left\"><b>Dashboard:</b></td>\n",
|
|
" <td style=\"text-align: left\"><b><a href=\"http://127.0.0.1:8265\" target=\"_blank\">http://127.0.0.1:8265</a></b></td>\n",
|
|
"</tr>\n",
|
|
"\n",
|
|
" </table>\n",
|
|
" </div>\n",
|
|
"</div>\n"
|
|
],
|
|
"text/plain": [
|
|
"RayContext(dashboard_url='127.0.0.1:8265', python_version='3.9.2', ray_version='2.3.0', ray_commit='{{RAY_COMMIT_SHA}}', address_info={'node_ip_address': '10.122.229.29', 'raylet_ip_address': '10.122.229.29', 'redis_address': None, 'object_store_address': '/tmp/ray/session_2024-01-22_14-52-08_389091_1470202/sockets/plasma_store', 'raylet_socket_name': '/tmp/ray/session_2024-01-22_14-52-08_389091_1470202/sockets/raylet', 'webui_url': '127.0.0.1:8265', 'session_dir': '/tmp/ray/session_2024-01-22_14-52-08_389091_1470202', 'metrics_export_port': 46547, 'gcs_address': '10.122.229.29:57281', 'address': '10.122.229.29:57281', 'dashboard_agent_listen_port': 52365, 'node_id': 'ec81d96ec29985ce0b98441aabb46fc8dcb49b0e1dec1088f973977f'})"
|
|
]
|
|
},
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\u001b[2m\u001b[36m(GPUAccumulator pid=1489193)\u001b[0m rank 0, value: tensor([1.], device='cuda:0')\n",
|
|
"\u001b[2m\u001b[36m(GPUAccumulator pid=1489357)\u001b[0m rank 2, value: tensor([3.], device='cuda:0')\n",
|
|
"\u001b[2m\u001b[36m(GPUAccumulator pid=1489358)\u001b[0m rank 3, value: tensor([4.], device='cuda:0')\n",
|
|
"\u001b[2m\u001b[36m(GPUAccumulator pid=1489356)\u001b[0m rank 1, value: tensor([2.], device='cuda:0')\n",
|
|
"\u001b[2m\u001b[36m(GPUAccumulator pid=1489837)\u001b[0m rank 0, value: tensor([2.], device='cuda:0')\n",
|
|
"\u001b[2m\u001b[36m(GPUAccumulator pid=1489999)\u001b[0m rank 3, value: tensor([5.], device='cuda:0')\n",
|
|
"\u001b[2m\u001b[36m(GPUAccumulator pid=1489998)\u001b[0m rank 2, value: tensor([4.], device='cuda:0')\n",
|
|
"\u001b[2m\u001b[36m(GPUAccumulator pid=1489997)\u001b[0m rank 1, value: tensor([3.], device='cuda:0')\n",
|
|
"\u001b[2m\u001b[36m(GPUAccumulator pid=1490000)\u001b[0m rank 0, value: tensor([3.], device='cuda:0')\n",
|
|
"\u001b[2m\u001b[36m(GPUAccumulator pid=1490633)\u001b[0m rank 3, value: tensor([6.], device='cuda:0')\n",
|
|
"\u001b[2m\u001b[36m(GPUAccumulator pid=1490635)\u001b[0m rank 5, value: tensor([8.], device='cuda:0')\n",
|
|
"\u001b[2m\u001b[36m(GPUAccumulator pid=1490636)\u001b[0m rank 6, value: tensor([9.], device='cuda:0')\n",
|
|
"\u001b[2m\u001b[36m(GPUAccumulator pid=1490631)\u001b[0m rank 1, value: tensor([4.], device='cuda:0')\n",
|
|
"\u001b[2m\u001b[36m(GPUAccumulator pid=1490637)\u001b[0m rank 7, value: tensor([10.], device='cuda:0')\n",
|
|
"\u001b[2m\u001b[36m(GPUAccumulator pid=1490634)\u001b[0m rank 4, value: tensor([7.], device='cuda:0')\n",
|
|
"\u001b[2m\u001b[36m(GPUAccumulator pid=1490632)\u001b[0m rank 2, value: tensor([5.], device='cuda:0')\n",
|
|
"\u001b[2m\u001b[36m(GPUAccumulatorDecorator pid=1491761)\u001b[0m 10\n",
|
|
"\u001b[2m\u001b[36m(GPUAccumulatorDecorator pid=1491761)\u001b[0m rank 0, value: tensor([10.], device='cuda:0')\n",
|
|
"\u001b[2m\u001b[36m(GPUAccumulatorDecorator pid=1491923)\u001b[0m 10\n",
|
|
"\u001b[2m\u001b[36m(GPUAccumulatorDecorator pid=1491922)\u001b[0m 10\n",
|
|
"\u001b[2m\u001b[36m(GPUAccumulatorDecorator pid=1491921)\u001b[0m 10\n",
|
|
"\u001b[2m\u001b[36m(GPUAccumulatorDecorator pid=1491924)\u001b[0m 10\n",
|
|
"\u001b[2m\u001b[36m(GPUAccumulatorDecorator pid=1491919)\u001b[0m 10\n",
|
|
"\u001b[2m\u001b[36m(GPUAccumulatorDecorator pid=1491925)\u001b[0m 10\n",
|
|
"\u001b[2m\u001b[36m(GPUAccumulatorDecorator pid=1491920)\u001b[0m 10\n",
|
|
"\u001b[2m\u001b[36m(GPUAccumulatorDecorator pid=1491923)\u001b[0m rank 5, value: tensor([15.], device='cuda:0')\n",
|
|
"\u001b[2m\u001b[36m(GPUAccumulatorDecorator pid=1491921)\u001b[0m rank 3, value: tensor([13.], device='cuda:0')\n",
|
|
"\u001b[2m\u001b[36m(GPUAccumulatorDecorator pid=1491924)\u001b[0m rank 7, value: tensor([17.], device='cuda:0')\n",
|
|
"\u001b[2m\u001b[36m(GPUAccumulatorDecorator pid=1491925)\u001b[0m rank 6, value: tensor([16.], device='cuda:0')\n",
|
|
"\u001b[2m\u001b[36m(GPUAccumulatorDecorator pid=1491920)\u001b[0m rank 2, value: tensor([12.], device='cuda:0')\n",
|
|
"\u001b[2m\u001b[36m(GPUAccumulatorDecorator pid=1491919)\u001b[0m rank 1, value: tensor([11.], device='cuda:0')\n",
|
|
"\u001b[2m\u001b[36m(GPUAccumulatorDecorator pid=1491922)\u001b[0m rank 4, value: tensor([14.], device='cuda:0')\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m WARNING: overriding default argument timing_log_level:2 with timing_log_level:0\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m WARNING: overriding default argument tensor_model_parallel_size:1 with tensor_model_parallel_size:4\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m WARNING: overriding default argument use_distributed_optimizer:False with use_distributed_optimizer:True\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m WARNING: overriding default argument weight_decay:0.01 with weight_decay:0.0\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m WARNING: overriding default argument sgd_momentum:0.9 with sgd_momentum:0.0\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m WARNING: overriding default argument bf16:False with bf16:True\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m WARNING: overriding default argument clip_grad:1.0 with clip_grad:0.0\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m WARNING: overriding default argument sequence_parallel:False with sequence_parallel:True\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m using world size: 4, data-parallel-size: 1, tensor-model-parallel size: 4, pipeline-model-parallel size: 1 \n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m setting global batch size to 1\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m accumulate gradients in fp32 and all-reduce gradients in fp32 for bfloat16 data type. \n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m using torch.bfloat16 for parameters ...\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m setting number of micro-batches to constant 1\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m [2024-01-22 14:53:47] > initializing torch distributed ...\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m [rank0]\t[group1]\tnew_group dp\thigh_stream False\tranks: [0]\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m [rank0]\t[group2]\tnew_group dp\thigh_stream False\tranks: [1]\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m [rank0]\t[group3]\tnew_group dp\thigh_stream False\tranks: [2]\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m [rank0]\t[group4]\tnew_group dp\thigh_stream False\tranks: [3]\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m [rank0]\t[group5]\tnew_group mp\thigh_stream False\tranks: [0, 1, 2, 3]\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m [rank0]\t[group6]\tnew_group tp\thigh_stream False\tranks: [0, 1, 2, 3]\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m [rank0]\t[group8]\tnew_group pp\thigh_stream False\tranks: [0]\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m [rank0]\t[group9]\tnew_group emb\thigh_stream False\tranks: [0]\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m [rank0]\t[group10]\tnew_group posemb\thigh_stream False\tranks: [0]\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m [rank0]\t[group11]\tnew_group pp\thigh_stream False\tranks: [1]\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m [rank0]\t[group12]\tnew_group emb\thigh_stream False\tranks: [1]\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m [rank0]\t[group13]\tnew_group posemb\thigh_stream False\tranks: [1]\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m [rank0]\t[group14]\tnew_group pp\thigh_stream False\tranks: [2]\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m [rank0]\t[group15]\tnew_group emb\thigh_stream False\tranks: [2]\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m [rank0]\t[group16]\tnew_group posemb\thigh_stream False\tranks: [2]\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m [rank0]\t[group17]\tnew_group pp\thigh_stream False\tranks: [3]\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m [rank0]\t[group18]\tnew_group emb\thigh_stream False\tranks: [3]\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m [rank0]\t[group19]\tnew_group posemb\thigh_stream False\tranks: [3]\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m NCCL version 2.18.6+cuda12.1\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m this torch version does not have hires nccl profiling. nccl profiling is not enabled this time. this is only a kindly notice, not a error message\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m [2024-01-22 14:53:52] > initialized tensor model parallel with size 4\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m [2024-01-22 14:53:52] > initialized pipeline model parallel with size 1\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m > setting random seeds to 1234 ...\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m > compiling dataset index builder ...\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m already compiled, skip compiling\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m >>> done with dataset index builder. Compilation time: 0.005 seconds\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m > compiling and loading fused kernels ...\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m will compile object fused_mix_prec_layer_norm_cuda with extra cuda flags ['-maxrregcount=50', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__']\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m [2024-01-22 14:53:52] >>> done with compiling and loading fused kernels. Compilation time: 0.039 seconds\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493654)\u001b[0m this torch version does not have hires nccl profiling. nccl profiling is not enabled this time. this is only a kindly notice, not a error message\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493654)\u001b[0m > compiling and loading fused kernels ...\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493654)\u001b[0m will compile object fused_mix_prec_layer_norm_cuda with extra cuda flags ['-maxrregcount=50', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__']\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493654)\u001b[0m [2024-01-22 14:53:52] >>> done with compiling and loading fused kernels. Compilation time: 0.044 seconds\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493653)\u001b[0m this torch version does not have hires nccl profiling. nccl profiling is not enabled this time. this is only a kindly notice, not a error message\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493653)\u001b[0m > compiling and loading fused kernels ...\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493653)\u001b[0m will compile object fused_mix_prec_layer_norm_cuda with extra cuda flags ['-maxrregcount=50', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__']\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493653)\u001b[0m [2024-01-22 14:53:52] >>> done with compiling and loading fused kernels. Compilation time: 0.044 seconds\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493652)\u001b[0m this torch version does not have hires nccl profiling. nccl profiling is not enabled this time. this is only a kindly notice, not a error message\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493652)\u001b[0m > compiling and loading fused kernels ...\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493652)\u001b[0m will compile object fused_mix_prec_layer_norm_cuda with extra cuda flags ['-maxrregcount=50', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__']\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493652)\u001b[0m [2024-01-22 14:53:52] >>> done with compiling and loading fused kernels. Compilation time: 0.045 seconds\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m No modifications detected for re-loaded extension module fused_mix_prec_layer_norm_cuda, skipping build step...\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m Loading extension module fused_mix_prec_layer_norm_cuda...\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493654)\u001b[0m No modifications detected for re-loaded extension module fused_mix_prec_layer_norm_cuda, skipping build step...\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493654)\u001b[0m Loading extension module fused_mix_prec_layer_norm_cuda...\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493653)\u001b[0m No modifications detected for re-loaded extension module fused_mix_prec_layer_norm_cuda, skipping build step...\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493653)\u001b[0m Loading extension module fused_mix_prec_layer_norm_cuda...\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493652)\u001b[0m No modifications detected for re-loaded extension module fused_mix_prec_layer_norm_cuda, skipping build step...\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493652)\u001b[0m Loading extension module fused_mix_prec_layer_norm_cuda...\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493490)\u001b[0m Using `is_flash_attn_available` is deprecated and will be removed in v4.38. Please use `is_flash_attn_2_available` instead.\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493654)\u001b[0m Using `is_flash_attn_available` is deprecated and will be removed in v4.38. Please use `is_flash_attn_2_available` instead.\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493653)\u001b[0m Using `is_flash_attn_available` is deprecated and will be removed in v4.38. Please use `is_flash_attn_2_available` instead.\n",
|
|
"\u001b[2m\u001b[36m(MLPLayerWorker pid=1493652)\u001b[0m Using `is_flash_attn_available` is deprecated and will be removed in v4.38. Please use `is_flash_attn_2_available` instead.\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Build a local ray cluster. The head node and worker node are on this machine\n",
|
|
"ray.init()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "a0d35eb2-cf28-411a-b39f-f45e97c2edba",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"source": [
|
|
"Implement an Accumulator class."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "86c98e9c-26a5-44d0-bbe9-064f8809708c",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"@ray.remote\n",
|
|
"class Accumulator:\n",
|
|
" def __init__(self):\n",
|
|
" self.value = 0\n",
|
|
" \n",
|
|
" def add(self, x):\n",
|
|
" self.value += x\n",
|
|
" \n",
|
|
" def get_value(self):\n",
|
|
" return self.value"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "9d5aa406-23c3-4270-a6f8-d4af20769ed9",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Instantiate an accumulator. Accumulator can be viewed as a process, acting as an RPC service.\n",
|
|
"accumulator = Accumulator.remote()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "56850d3a-3803-4395-a73e-a27e99dc6905",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"0\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"value_ref = accumulator.get_value.remote() # Check the current value. Note that this function returns immediately and does not actually wait for the remote execution to complete.\n",
|
|
"# Get the value\n",
|
|
"value = ray.get(value_ref)\n",
|
|
"print(value)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "932807f1-792b-42af-aa44-6a6202718919",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"10\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Accumulate, then check the result.\n",
|
|
"accumulator.add.remote(10) # Similarly, the 'add' here will return immediately.\n",
|
|
"new_value = ray.get(accumulator.get_value.remote())\n",
|
|
"print(new_value)"
|
|
]
|
|
},
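  {
   "cell_type": "markdown",
   "id": "a1f3c2d4-0001-4aaa-9bbb-cccccccc0001",
   "metadata": {},
   "source": [
    "As a quick illustration (an added sketch; this cell is not part of the recorded run): because `.remote()` calls return immediately, we can launch several calls first and then wait for all of them with a single `ray.get`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1f3c2d4-0001-4aaa-9bbb-cccccccc0002",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Launch five asynchronous increments; each call returns an ObjectRef immediately.\n",
    "refs = [accumulator.add.remote(1) for _ in range(5)]\n",
    "ray.get(refs)  # block until all five increments have been applied\n",
    "# Assuming the cells above ran in order (the value was 10), this should print 15.\n",
    "print(ray.get(accumulator.get_value.remote()))"
   ]
  },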
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "ca31e450-e02c-44d0-86a1-a3f9eeadd311",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Chapter 2: Resource Pool and RayWorkerGroup\n",
|
|
"In the previous example, it was a simple single-process worker. \n",
|
|
"In this example, we implement a worker with a GPU and form a RayWorkerGroup. Within this RayWorkerGroup, we implement a simple operation of an accumulator."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "84727b3d-1351-4550-9009-48526d18dc2c",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, merge_resource_pool\n",
|
|
"from single_controller.base import Worker"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "aab5eeb7-6485-4911-94cc-3f5ca52b6634",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"resource_pool = RayResourcePool([4], use_gpu=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"id": "54e4d531-ca50-43e9-ab79-d24bb66ffe0e",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"@ray.remote\n",
|
|
"class GPUAccumulator(Worker):\n",
|
|
"\n",
|
|
" def __init__(self) -> None:\n",
|
|
" super().__init__()\n",
|
|
" # The initial value of each rank is the same as the rank\n",
|
|
" self.value = torch.zeros(size=(1,), device='cuda') + self.rank\n",
|
|
"\n",
|
|
" def add(self, x):\n",
|
|
" self.value += x\n",
|
|
" print(f'rank {self.rank}, value: {self.value}')\n",
|
|
" return self.value.cpu()\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"id": "73e0658b-a6a7-4a86-a53f-5442c8cd56fe",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[tensor([1.]), tensor([2.]), tensor([3.]), tensor([4.])]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Each worker's initial value is its rank, and then each rank's value is incremented by 1, so the values obtained on each rank are [1, 2, 3, 4]\n",
|
|
"class_with_args = RayClassWithInitArgs(cls=GPUAccumulator)\n",
|
|
"worker_group = RayWorkerGroup(resource_pool, class_with_args)\n",
|
|
"print(worker_group.execute_all_sync('add', x=[1,1,1,1]))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "e36f73e8-4f38-459f-a09b-f6ffb95fd42b",
|
|
"metadata": {},
|
|
"source": [
|
|
"The principle of parameter passing: The input parameter is a list of length world_size, where each element in the list is dispatched respectively to each worker in the RayWorkerGroup. \n",
|
|
"The return parameter is also a list, corresponding to the return value of each worker."
|
|
]
|
|
},
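  {
   "cell_type": "markdown",
   "id": "a1f3c2d4-0002-4aaa-9bbb-cccccccc0001",
   "metadata": {},
   "source": [
    "For example (an added sketch; not executed in the recorded run), passing `x=[10, 20, 30, 40]` sends 10 to rank 0, 20 to rank 1, and so on, and the collected list preserves the per-rank order."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1f3c2d4-0002-4aaa-9bbb-cccccccc0002",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Element i of each argument list is dispatched to worker i; results are collected in the same order.\n",
    "# Assuming the workers still hold [1, 2, 3, 4] from the previous call, this would print\n",
    "# [tensor([11.]), tensor([22.]), tensor([33.]), tensor([44.])].\n",
    "print(worker_group.execute_all_sync('add', x=[10, 20, 30, 40]))"
   ]
  },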
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "3873383b-e89d-4283-ae3e-af367e6160a5",
|
|
"metadata": {},
|
|
"source": [
|
|
"### GPU Resource Sharing"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "cf62933e-143e-4a15-90c3-3d0d0568d89c",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"source": [
|
|
"RayWorkerGroups mapped to the same resource pool share the GPU. In this example, we implement three resource pools: the first occupies 4 GPUs, the second also occupies 4 GPUs, and the last occupies all 8 GPUs. Among them, the first resource pool reuses the resource pool mentioned above."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"id": "100e75f1-6f2b-44b6-b4e1-f703eae8e5bf",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Create a new resource pool and then merge the newly created resource pool with the previous one.\n",
|
|
"resource_pool_1 = RayResourcePool([4], use_gpu=True, name_prefix='a')\n",
|
|
"resource_pool_merge = merge_resource_pool(resource_pool, resource_pool_1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"id": "59393509-037b-475a-800a-5e6de9ea7b66",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Establish a RayWorkerGroup on the newly created resource pool.\n",
|
|
"worker_group_1 = RayWorkerGroup(resource_pool_1, class_with_args)\n",
|
|
"worker_group_merge = RayWorkerGroup(resource_pool_merge, class_with_args)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"id": "558dc5bf-75aa-4675-b8de-99944e088522",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[tensor([2.]), tensor([3.]), tensor([4.]), tensor([5.])]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Run 'add' on the second set of 4 GPUs; the result should be [2, 3, 4, 5].\n",
|
|
"output_1 = worker_group_1.execute_all_sync('add', x=[2,2,2,2])\n",
|
|
"print(output_1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"id": "30228667-f8ef-448c-a2cc-44241ede86d9",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[tensor([3.]), tensor([4.]), tensor([5.]), tensor([6.]), tensor([7.]), tensor([8.]), tensor([9.]), tensor([10.])]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Run 'add' on the merged set of 8 GPUs; the result should be [3, 4, 5, 6, 7, 8, 9, 10].\n",
|
|
"output_merge = worker_group_merge.execute_all_sync('add', x=[3,3,3,3,3,3,3,3])\n",
|
|
"print(output_merge)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 16,
|
|
"id": "041329de-0e27-4e31-ac11-f0dd54f4f4b3",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"4 4 8\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(worker_group.world_size, worker_group_1.world_size, worker_group_merge.world_size)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "8cb9fc59-6035-4b72-9863-67bd4d48f3fe",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Chapter 3: Data Dispatch, Execution and Collection"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "0e1fedea-16d1-4dec-b269-c402747e11e8",
|
|
"metadata": {},
|
|
"source": [
|
|
"In the above example, we used the `execute_all_sync` function in the RayWorkerGroup to dispatch data from the driver to each worker. This is very inconvenient for coding. \n",
|
|
"In this chapter, we use the form of function decorators to allow RayWorkerGroup to directly call functions written in the Worker, and to greatly simplify parameter passing."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 17,
|
|
"id": "c70d1b33-030a-4681-aeb2-16ab48b6445b",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from single_controller.ray.decorator import register, Dispatch, Execute"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 18,
|
|
"id": "bb6b76c1-8caf-4749-b20d-c3f842751aa4",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"@ray.remote\n",
|
|
"class GPUAccumulatorDecorator(Worker):\n",
|
|
"\n",
|
|
" def __init__(self) -> None:\n",
|
|
" super().__init__()\n",
|
|
" # The initial value of each rank is the same as the rank\n",
|
|
" self.value = torch.zeros(size=(1,), device='cuda') + self.rank\n",
|
|
" \n",
|
|
" # map from a single input to all the worker\n",
|
|
" @register(Dispatch.ONE_TO_ALL)\n",
|
|
" def add(self, x):\n",
|
|
" print(x)\n",
|
|
" self.value = self.value + x\n",
|
|
" print(f'rank {self.rank}, value: {self.value}')\n",
|
|
" return self.value.cpu()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 19,
|
|
"id": "2d2bb460-4e9d-4fc5-a767-e7c8b9a4d7fe",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"class_with_args = RayClassWithInitArgs(cls=GPUAccumulatorDecorator)\n",
|
|
"gpu_accumulator_decorator = RayWorkerGroup(resource_pool_merge, class_with_args)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 20,
|
|
"id": "e4db17f7-8896-4ea6-a732-2146400ff2da",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[tensor([10.]), tensor([11.]), tensor([12.]), tensor([13.]), tensor([14.]), tensor([15.]), tensor([16.]), tensor([17.])]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# As we can see, 10 is automatically dispatched to each Worker in this RayWorkerGroup.\n",
|
|
"print(gpu_accumulator_decorator.add(x=10))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "f8b54ccb-0064-4d06-b74b-fbf39f8c5faf",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Custom Dispatch, Collection\n",
|
|
"Users can customize `dispatch` and `collection` function. You only need to write the `dispatch_fn` and `collect_fn` functions yourself. We also support executing RPC only on rank_zero, with specific examples provided below."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 21,
|
|
"id": "9ad68e0f-372e-4792-b4db-c2406f211113",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from single_controller.ray.decorator import register, Dispatch, collect_all_to_all, Execute"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 22,
|
|
"id": "74f91242-3176-4ea8-adce-104a58d21874",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def two_to_all_dispatch_fn(worker_group, *args, **kwargs):\n",
|
|
" \"\"\"\n",
|
|
" Assume the input is a list of 2. Duplicate the input interleaved and pass to each worker.\n",
|
|
" \"\"\"\n",
|
|
" for arg in args:\n",
|
|
" assert len(arg) == 2\n",
|
|
" for i in range(worker_group.world_size - 2):\n",
|
|
" arg.append(arg[i % 2])\n",
|
|
" for k, v in kwargs.items():\n",
|
|
" assert len(v) == 2\n",
|
|
" for i in range(worker_group.world_size - 2):\n",
|
|
" v.append(v[i % 2])\n",
|
|
" return args, kwargs\n",
|
|
"\n",
|
|
"\n",
|
|
"@ray.remote\n",
|
|
"class TestActor(Worker):\n",
|
|
" # TODO: pass *args and **kwargs is bug prone and not very convincing\n",
|
|
" def __init__(self, x) -> None:\n",
|
|
" super().__init__()\n",
|
|
" self._x = x\n",
|
|
"\n",
|
|
" def foo(self, y):\n",
|
|
" return self._x + y\n",
|
|
"\n",
|
|
" @register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.RANK_ZERO)\n",
|
|
" def foo_rank_zero(self, x, y):\n",
|
|
" return self._x + y + x\n",
|
|
"\n",
|
|
" @register(dispatch_mode={'dispatch_fn': two_to_all_dispatch_fn, 'collect_fn': collect_all_to_all})\n",
|
|
" def foo_custom(self, x, y):\n",
|
|
" return self._x + y + x"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 23,
|
|
"id": "00f386a8-c5f7-49c4-87ec-6e2ab3547bf1",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"class_with_args = RayClassWithInitArgs(cls=TestActor, x=2)\n",
|
|
"worker_group = RayWorkerGroup(resource_pool, class_with_args)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 24,
|
|
"id": "45ae935c-4f92-4c28-9e35-add439dcd2a0",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"output_ref = worker_group.foo_custom(x=[1, 2], y=[5, 6])\n",
|
|
"assert output_ref == [8, 10, 8, 10]\n",
|
|
"\n",
|
|
"output_ref = worker_group.foo_rank_zero(x=1, y=2)\n",
|
|
"assert output_ref == 5"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "165ae826-7b20-4305-a5c6-d7f236eab410",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Chapter 4: MegatronRayWorkerGroup"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "d0855a6b-41d7-41bb-95e6-d9395b175d81",
|
|
"metadata": {},
|
|
"source": [
|
|
"Finally, we implement a `MegatronRayWorkerGroup`, within which we create a Megatron and then run a tensor parallel (tp) split Llama mlp layer. Here, we use a complex dispatch mode, `Megatron_COMPUTE`. This dispatch mode assumes that user passes the data partitioned by DP dimension. The data is dispatched to all tp/pp ranks within the same dp group, and ultimately only collects output data from tp=0 and the last pp. In this way, for users that only write code on the driver, the Megatron behind the RPC becomes transparent."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 25,
|
|
"id": "d36b4606-0e28-4fd6-8e04-498408ca161a",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from single_controller.ray import MegatronRayWorkerGroup\n",
|
|
"from single_controller.megatron.worker import MegatronWorker\n",
|
|
"from omegaconf import OmegaConf"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 26,
|
|
"id": "430fa41d-7874-49fd-b7f0-2c173a6c849b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"@ray.remote\n",
|
|
"class MLPLayerWorker(MegatronWorker):\n",
|
|
" @register(Dispatch.ONE_TO_ALL)\n",
|
|
" def init_model(self, config):\n",
|
|
" from omegaconf import OmegaConf\n",
|
|
" from verl.models.llama.megatron.layers import ParallelLlamaMLP\n",
|
|
" megatron_config = OmegaConf.create({'sequence_parallel_enabled': False})\n",
|
|
" self.parallel_layer = ParallelLlamaMLP(config=config, megatron_config=megatron_config)\n",
|
|
" \n",
|
|
" @register(Dispatch.ONE_TO_ALL)\n",
|
|
" def get_weights(self):\n",
|
|
" output = {}\n",
|
|
" for key, val in self.parallel_layer.named_parameters():\n",
|
|
" output[key] = val\n",
|
|
" return output\n",
|
|
" \n",
|
|
" @register(Dispatch.MEGATRON_COMPUTE)\n",
|
|
" def run_layer(self, x):\n",
|
|
" x = x.to('cuda')\n",
|
|
" y = self.parallel_layer(x)\n",
|
|
" return y"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 27,
|
|
"id": "2b749128-4190-454f-96a2-72477a7a77bd",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"layer_cls = RayClassWithInitArgs(cls=MLPLayerWorker)\n",
|
|
"layer_worker_group = MegatronRayWorkerGroup(resource_pool=resource_pool,\n",
|
|
" ray_cls_with_init=layer_cls,\n",
|
|
" default_megatron_kwargs={\n",
|
|
" 'tensor_model_parallel_size': 4,\n",
|
|
" 'pipeline_model_parallel_size': 1,\n",
|
|
" 'num_layers_per_virtual_pipeline_stage': None\n",
|
|
" })\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 28,
|
|
"id": "f134f99e-678a-4b6f-8d70-fce2a5c9ac3e",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"4 4 1 1\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(layer_worker_group.world_size, layer_worker_group.tp_size, layer_worker_group.pp_size, layer_worker_group.dp_size)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 29,
|
|
"id": "204ce866-868a-4b0b-900b-0dda6164a995",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"ffn_hidden_size = 11008\n",
|
|
"batch_size = 16\n",
|
|
"seq_len = 2048\n",
|
|
"hidden_size = 4096\n",
|
|
"\n",
|
|
"config = OmegaConf.create({\n",
|
|
" 'hidden_size': hidden_size,\n",
|
|
" 'intermediate_size': ffn_hidden_size,\n",
|
|
" 'hidden_act': 'silu',\n",
|
|
" 'pretraining_tp': 1,\n",
|
|
" 'tp': layer_worker_group.tp_size,\n",
|
|
"})"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 30,
|
|
"id": "d35a0216-7787-4ea7-86ab-c87f8e98f13d",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"x = torch.rand(size=(seq_len, batch_size, hidden_size), dtype=torch.float32)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 31,
|
|
"id": "0f84eaf8-adcd-4f89-a1ce-800db07df242",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"[None, None, None, None]"
|
|
]
|
|
},
|
|
"execution_count": 31,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"layer_worker_group.init_model(config)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 32,
|
|
"id": "5203bb40-f831-4898-832c-5f65596ba0f6",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"torch.Size([2048, 16, 4096])\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"output = layer_worker_group.run_layer([x]) # This must be a list of size 1, ensuring that the input equals the data parallel (dp).\n",
|
|
"print(output[0].shape)"
|
|
]
|
|
},
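  {
   "cell_type": "markdown",
   "id": "a1f3c2d4-0003-4aaa-9bbb-cccccccc0001",
   "metadata": {},
   "source": [
    "A generalized sketch (added for illustration; not part of the recorded run): with `MEGATRON_COMPUTE`, the driver passes one chunk per data-parallel rank, where `dp_size = world_size // (tp_size * pp_size)`. Here `dp_size` is 1, so the list above has a single element, but the same pattern works for any `dp_size`, e.g. by chunking the batch dimension."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1f3c2d4-0003-4aaa-9bbb-cccccccc0002",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One way to build the per-DP-rank input list: split the batch dimension (dim=1 for [seq, batch, hidden]).\n",
    "dp_chunks = list(x.chunk(layer_worker_group.dp_size, dim=1))\n",
    "outputs = layer_worker_group.run_layer(dp_chunks)\n",
    "# One output per DP rank, collected from tp rank 0 of the last pipeline stage.\n",
    "print(len(outputs), outputs[0].shape)"
   ]
  },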
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 27,
|
|
"id": "6a159884-7409-4dfe-a367-3c31090b09a1",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"0\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(gpu_accumulator_decorator.world_size)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 33,
|
|
"id": "806c36c3-a6a6-4c99-88fb-f2e2468b12db",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Shutdown ray cluster\n",
|
|
"ray.shutdown()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "ea54e858-5221-493e-b53a-4765b918a0fa",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"fileId": "e862b5a9-13b2-48bf-9eda-96bd0c76e37c",
|
|
"kernelspec": {
|
|
"display_name": "Python 3.9.2 64-bit",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.9.2"
|
|
},
|
|
"vscode": {
|
|
"interpreter": {
|
|
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
|
|
}
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|