diff --git a/tests/distributed/test_pipeline_partition.py b/tests/distributed/test_pipeline_partition.py
index 3ed104820b4..18c5be29c5c 100644
--- a/tests/distributed/test_pipeline_partition.py
+++ b/tests/distributed/test_pipeline_partition.py
@@ -34,3 +34,27 @@ def test_custom_layer_partition():
     # Wrong number of layers
     with pytest.raises(ValueError):
         _verify("5,5,5,5", 21, 4, [(0, 5), (5, 10), (10, 15), (15, 20)])
+
+
+@pytest.mark.parametrize(
+    "num_hidden_layers,pp_size,pp_rank,indices",
+    [
+        # pp_size 2
+        (2, 2, 0, (0, 1)),
+        (2, 2, 1, (1, 2)),
+        (3, 2, 0, (0, 2)),
+        (3, 2, 1, (2, 3)),
+        # pp_size 3
+        (3, 3, 0, (0, 1)),
+        (3, 3, 1, (1, 2)),
+        (3, 3, 2, (2, 3)),
+        (4, 3, 0, (0, 1)),
+        (4, 3, 1, (1, 3)),
+        (4, 3, 2, (3, 4)),
+        (5, 3, 0, (0, 2)),
+        (5, 3, 1, (2, 4)),
+        (5, 3, 2, (4, 5)),
+    ])
+def test_uneven_auto_partition(num_hidden_layers: int, pp_size: int,
+                               pp_rank: int, indices: tuple[int, int]):
+    assert indices == get_pp_indices(num_hidden_layers, pp_rank, pp_size)
diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py
index 79f9a84b476..d6fca4f0221 100644
--- a/vllm/distributed/utils.py
+++ b/vllm/distributed/utils.py
@@ -67,8 +67,17 @@ def split_tensor_along_last_dim(
 def get_pp_indices(num_hidden_layers: int, pp_rank: int,
                    pp_size: int) -> Tuple[int, int]:
     """Try to evenly distribute layers across partitions.
+
     If the number of layers is not divisible by the number of partitions,
-    the last partition will have the remaining layers.
+    the remaining layers are evenly distributed across all but the last
+    partition. The last partition is excluded because it often contains an
+    additional norm layer and we are attempting to balance compute.
+
+    If `pp_size > 2` and the number of remaining layers is
+    `0 < x <= pp_size - 2` then the remaining layers are evenly distributed
+    across the middle partitions. The first and last partitions are excluded
+    because they contain the input and output embeddings respectively and we
+    are attempting to reduce maximum memory consumption across partitions.
     """
     partition_list_str = envs.VLLM_PP_LAYER_PARTITION
     if partition_list_str is not None:
@@ -84,15 +93,20 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int,
         if sum(partitions) != num_hidden_layers:
             raise ValueError(
                 f"{sum(partitions)=} does not match {num_hidden_layers=}.")
-        start_layer = sum(partitions[:pp_rank])
-        end_layer = start_layer + partitions[pp_rank]
     else:
         layers_per_partition = num_hidden_layers // pp_size
-        start_layer = pp_rank * layers_per_partition
-        end_layer = start_layer + layers_per_partition
+        partitions = [layers_per_partition for _ in range(pp_size)]
 
-        if pp_rank == pp_size - 1:
-            end_layer = num_hidden_layers
+        if remaining_layers := num_hidden_layers % pp_size:
+            for i in range(2, remaining_layers + 2):
+                partitions[-i] += 1
+            logger.info("Hidden layers were unevenly partitioned: %s",
+                        ",".join(str(p) for p in partitions))
+            logger.info("This can be manually overridden using the "
+                        "VLLM_PP_LAYER_PARTITION environment variable")
+
+    start_layer = sum(partitions[:pp_rank])
+    end_layer = start_layer + partitions[pp_rank]
 
     return (start_layer, end_layer)
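
Note (not part of the patch): below is a minimal standalone sketch of the auto-partitioning scheme described in the get_pp_indices docstring above, useful for checking the expectations in test_uneven_auto_partition without installing vLLM. The helper name pp_partition_sketch is hypothetical and only mirrors the logic; the real entry point remains vllm.distributed.utils.get_pp_indices, and the VLLM_PP_LAYER_PARTITION override path is not reproduced here.

def pp_partition_sketch(num_hidden_layers: int,
                        pp_size: int) -> list[tuple[int, int]]:
    # Start from an even split, then hand the remainder to the partitions
    # just before the last one, keeping the first and last ranks lighter.
    layers_per_partition = num_hidden_layers // pp_size
    partitions = [layers_per_partition] * pp_size
    if remaining_layers := num_hidden_layers % pp_size:
        for i in range(2, remaining_layers + 2):
            partitions[-i] += 1
    # Turn per-rank layer counts into (start_layer, end_layer) index pairs.
    starts = [sum(partitions[:rank]) for rank in range(pp_size)]
    return [(start, start + count) for start, count in zip(starts, partitions)]


if __name__ == "__main__":
    # 5 layers over 3 ranks -> [(0, 2), (2, 4), (4, 5)], matching the
    # (5, 3, *) cases in test_uneven_auto_partition above.
    print(pp_partition_sketch(5, 3))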