[Doc] add debugging tips for crash and multi-node debugging (#5581)

This commit is contained in:
youkaichao
2024-06-16 19:08:01 -07:00
committed by GitHub
parent f07d513320
commit 845a3f26f9

View File

@ -24,6 +24,8 @@ If you have already taken care of the above issues, but the vLLM instance still
With more logging, hopefully you can find the root cause of the issue.
If it crashes, and the error trace shows somewhere around ``self.graph.replay()`` in ``vllm/worker/model_runner.py``, it is a cuda error inside cudagraph. To know the particular cuda operation that causes the error, you can add ``--enforce-eager`` to the command line, or ``enforce_eager=True`` to the ``LLM`` class, to disable the cudagraph optimization. This way, you can locate the exact cuda operation that causes the error.
Here are some common issues that can cause hangs:
- **Incorrect network setup**: The vLLM instance cannot get the correct IP address. You can find the log such as ``DEBUG 06-10 21:32:17 parallel_state.py:88] world_size=8 rank=0 local_rank=0 distributed_init_method=tcp://xxx.xxx.xxx.xxx:54641 backend=nccl``. The IP address should be the correct one. If not, override the IP address by setting the environment variable ``export VLLM_HOST_IP=your_ip_address``.
@ -31,15 +33,26 @@ Here are some common issues that can cause hangs:
.. code-block:: python
# save it as `test.py` , and run it with `NCCL_DEBUG=TRACE torchrun --nproc-per-node=8 test.py`
# adjust `--nproc-per-node` to the number of GPUs you want to use.
import torch
import torch.distributed as dist
dist.init_process_group(backend="nccl")
data = torch.FloatTensor([1,] * 128).to(f"cuda:{dist.get_rank()}")
local_rank = dist.get_rank() % torch.cuda.device_count()
data = torch.FloatTensor([1,] * 128).to(f"cuda:{local_rank}")
dist.all_reduce(data, op=dist.ReduceOp.SUM)
torch.cuda.synchronize()
value = data.mean().item()
assert value == dist.get_world_size()
.. tip::
Save the script as ``test.py``.
If you are testing in a single-node, run it with ``NCCL_DEBUG=TRACE torchrun --nproc-per-node=8 test.py``, adjust ``--nproc-per-node`` to the number of GPUs you want to use.
If you are testing with multi-nodes, run it with ``NCCL_DEBUG=TRACE torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=$MASTER_ADDR test.py``. Adjust ``--nproc-per-node`` and ``--nnodes`` according to your setup. Make sure ``MASTER_ADDR``:
- is the correct IP address of the master node
- is reachable from all nodes
- is set before running the script.
If the problem persists, feel free to `open an issue on GitHub <https://github.com/vllm-project/vllm/issues/new/choose>`_, with a detailed description of the issue, your environment, and the logs.