[CP] Fix load balancer incorrectly assuming batch dimension exists (#165792)

https://github.com/pytorch/pytorch/pull/163617 removes the if/else statement to check if the input buffers have the batch dimension.

This PR fixes the issue and also adds a test.

In the future, we should explicitly ask users to unsqueeze the batch dimension. This is a BC of the existing contract but implicitly infers the batch dimension existence is not safe.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165792
Approved by: https://github.com/XilunWu
This commit is contained in:
Chien-Chin Huang
2025-10-17 23:41:05 -07:00
committed by PyTorch MergeBot
parent ad67170c8b
commit 4740ce7787
2 changed files with 56 additions and 11 deletions

View File

@ -1068,10 +1068,16 @@ def _context_parallel_buffers(
for buffer, seq_dim in zip(buffers, buffer_seq_dims):
if isinstance(buffer, torch.Tensor):
# TODO: the load balance doesn't perform error handling.
# NOTE: assuming batch dim is 0
if load_balance_indices is not None:
# NOTE: assuming batch dim is 0
# TODO: we should expclitly ask users to unsqueeze the batch dim.
# But this is a BC breaking ask.
# However, what we have done today is also not very safe.
idx_batch_size = load_balance_indices.size(0)
data_batch_size = buffer.size(0)
data_batch_size = buffer.size(0) if seq_dim > 0 else 1
if idx_batch_size != 1 and idx_batch_size != data_batch_size:
raise ValueError(
"Cannot rearrange buffer: "
@ -1079,16 +1085,20 @@ def _context_parallel_buffers(
f"but buffer has shape {buffer.shape}."
)
for i in range(data_batch_size):
index = (
load_balance_indices[0] # identical load-balance in batch
if idx_batch_size == 1
else load_balance_indices[i]
if seq_dim == 0:
buffer = torch.index_select(
buffer, dim=0, index=load_balance_indices[0]
)
buffer_batch_i = torch.index_select(
buffer[i], dim=seq_dim - 1, index=index
)
buffer[i] = buffer_batch_i
else:
indices = load_balance_indices
if idx_batch_size == 1:
size = [data_batch_size] + list(indices.size())[1:]
indices = indices.expand(*size)
for i in range(data_batch_size):
buffer[i] = torch.index_select(
buffer[i], dim=seq_dim - 1, index=indices[i]
)
# use DTensor to shard the buffer on sequence dimension, retain the local tensor
sharded_buffer = distribute_tensor(