[CP] Fix load balancer incorrectly assuming batch dimension exists (#165792)

https://github.com/pytorch/pytorch/pull/163617 removed the if/else statement that checks whether the input buffers have a batch dimension.

This PR fixes the issue and also adds a test.

In the future, we should explicitly ask users to unsqueeze the batch dimension. That would break BC for the existing contract, but implicitly inferring whether the batch dimension exists is not safe.
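
A rough sketch of the two calling conventions, using the names set up in the new test below; the second call shows only the proposed future contract, not behavior this PR implements:

# Current contract (exercised by the new test): freqs_cis has no batch dim, its
# sequence dim is 0, and the load balancer must not assume a leading batch dimension.
freqs_cis_shard, q_shard, k_shard, v_shard = _context_parallel_shard(
    device_mesh, [freqs_cis, q, k, v], [0, 1, 1, 1], load_balancer=load_balancer
)

# Hypothetical future contract (illustration only): callers unsqueeze the batch dim
# explicitly, so every buffer uniformly has batch dim 0 and sequence dim 1.
freqs_cis_shard, *_ = _context_parallel_shard(
    device_mesh, [freqs_cis.unsqueeze(0), q, k, v], [1, 1, 1, 1], load_balancer=load_balancer
)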

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165792
Approved by: https://github.com/XilunWu
Author: Chien-Chin Huang
Date: 2025-10-17 23:41:05 -07:00
Committed by: PyTorch MergeBot
Parent: ad67170c8b
Commit: 4740ce7787
2 changed files with 56 additions and 11 deletions


@@ -771,5 +771,40 @@ class TestCPCustomOps(DTensorTestBase):
         torch.library.opcheck(flex_cp_allgather, example)
 
 
+class TestSharding(DTensorTestBase):
+    @property
+    def world_size(self) -> int:
+        return 2
+
+    @skip_if_lt_x_gpu(2)
+    @with_comms
+    def test_context_parallel_shard(self) -> None:
+        B = 4
+        seq_len = 32
+
+        device_mesh = init_device_mesh(
+            mesh_shape=(2,), mesh_dim_names=("cp",), device_type=self.device_type
+        )
+        freqs_cis = torch.arange(0, seq_len, device=self.device_type)
+        q = torch.ones(B * seq_len, device=self.device_type).reshape(B, seq_len)
+        k = torch.ones(B * seq_len, device=self.device_type).reshape(B, seq_len)
+        v = torch.ones(B * seq_len, device=self.device_type).reshape(B, seq_len)
+
+        load_balancer = _HeadTailLoadBalancer(
+            seq_len, self.world_size, torch.device(self.device_type)
+        )
+        freqs_cis_shard, q_shard, k_shard, v_shard = _context_parallel_shard(
+            device_mesh, [freqs_cis, q, k, v], [0, 1, 1, 1], load_balancer=load_balancer
+        )
+        self.assertEqual(freqs_cis_shard.size(), (seq_len // 2,))
+        chunks = freqs_cis.chunk(self.world_size * 2)
+        self.assertEqual(
+            freqs_cis_shard,
+            torch.cat(
+                [chunks[self.rank], chunks[self.world_size * 2 - self.rank - 1]], dim=0
+            ),
+        )
+
+
 if __name__ == "__main__":
     run_tests()
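
The assertion above encodes the head-tail split; a minimal CPU-only sketch of that layout (assuming a world size of 2, no distributed setup) is:

import torch

world_size, seq_len = 2, 32
freqs_cis = torch.arange(0, seq_len)
chunks = freqs_cis.chunk(world_size * 2)
for rank in range(world_size):
    # rank r keeps head chunk r plus the mirrored tail chunk
    shard = torch.cat([chunks[rank], chunks[world_size * 2 - rank - 1]], dim=0)
    print(rank, shard.tolist())
# rank 0 -> [0..7] + [24..31], rank 1 -> [8..15] + [16..23]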


@@ -1068,10 +1068,16 @@ def _context_parallel_buffers(
     for buffer, seq_dim in zip(buffers, buffer_seq_dims):
         if isinstance(buffer, torch.Tensor):
             # TODO: the load balance doesn't perform error handling.
-            # NOTE: assuming batch dim is 0
             if load_balance_indices is not None:
+                # NOTE: assuming batch dim is 0
+                # TODO: we should explicitly ask users to unsqueeze the batch dim.
+                # But this is a BC breaking ask.
+                # However, what we have done today is also not very safe.
                 idx_batch_size = load_balance_indices.size(0)
-                data_batch_size = buffer.size(0)
+                data_batch_size = buffer.size(0) if seq_dim > 0 else 1
                 if idx_batch_size != 1 and idx_batch_size != data_batch_size:
                     raise ValueError(
                         "Cannot rearrange buffer: "
@@ -1079,16 +1085,20 @@ def _context_parallel_buffers(
                         f"but buffer has shape {buffer.shape}."
                     )
 
-                for i in range(data_batch_size):
-                    index = (
-                        load_balance_indices[0]  # identical load-balance in batch
-                        if idx_batch_size == 1
-                        else load_balance_indices[i]
-                    )
-                    buffer_batch_i = torch.index_select(
-                        buffer[i], dim=seq_dim - 1, index=index
-                    )
-                    buffer[i] = buffer_batch_i
+                if seq_dim == 0:
+                    buffer = torch.index_select(
+                        buffer, dim=0, index=load_balance_indices[0]
+                    )
+                else:
+                    indices = load_balance_indices
+                    if idx_batch_size == 1:
+                        size = [data_batch_size] + list(indices.size())[1:]
+                        indices = indices.expand(*size)
+                    for i in range(data_batch_size):
+                        buffer[i] = torch.index_select(
+                            buffer[i], dim=seq_dim - 1, index=indices[i]
+                        )
 
         # use DTensor to shard the buffer on sequence dimension, retain the local tensor
         sharded_buffer = distribute_tensor(
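
For intuition, a standalone sketch of the two paths in the fixed logic above, with toy sizes and a made-up head-tail-style permutation (not the actual _HeadTailLoadBalancer output):

import torch

seq_len, B, world_size = 8, 2, 2
chunks = torch.arange(seq_len).chunk(world_size * 2)
# one shared permutation of the sequence, carrying a batch dim of size 1
load_balance_indices = torch.cat([chunks[0], chunks[3], chunks[1], chunks[2]]).unsqueeze(0)

# seq_dim == 0: the buffer (e.g. freqs_cis) has no batch dim, reorder it directly
freqs_cis = torch.arange(seq_len)
freqs_cis = torch.index_select(freqs_cis, dim=0, index=load_balance_indices[0])

# seq_dim > 0: the buffer (e.g. q/k/v) has a leading batch dim, reorder per batch entry
q = torch.arange(B * seq_len).reshape(B, seq_len)
indices = load_balance_indices.expand(B, seq_len)  # idx_batch_size == 1 case
for i in range(B):
    q[i] = torch.index_select(q[i], dim=0, index=indices[i])  # dim = seq_dim - 1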