[CP] Fix load balancer incorrectly assuming batch dimension exists (#165792)

https://github.com/pytorch/pytorch/pull/163617 removed the if/else check that determines whether the input buffers have a batch dimension.

This PR fixes the issue and also adds a test.

In the future, we should explicitly ask users to unsqueeze the batch dimension. That would be a BC-breaking change to the existing contract, but implicitly inferring whether the batch dimension exists is not safe.
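For illustration, a minimal sketch of the distinction (the shapes and the unsqueeze step below are hypothetical and not part of this PR):

```python
import torch

# Hypothetical shapes mirroring the test below; not part of this PR.
seq_len = 32
freqs_cis = torch.arange(0, seq_len)   # shape (seq_len,): no batch dimension
q = torch.ones(4, seq_len)             # shape (B, seq_len): batch dimension present

# Today the caller passes per-buffer sequence dims (0 for freqs_cis, 1 for q/k/v)
# and the load balancer must not assume a batch dimension exists. Under the
# proposed future contract, the caller would make the batch dimension explicit:
freqs_cis_batched = freqs_cis.unsqueeze(0)  # shape (1, seq_len)
```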

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165792
Approved by: https://github.com/XilunWu
Author: Chien-Chin Huang
Date: 2025-10-17 23:41:05 -07:00
Committed by: PyTorch MergeBot
commit 4740ce7787 (parent ad67170c8b)
2 changed files with 56 additions and 11 deletions


@@ -771,5 +771,40 @@ class TestCPCustomOps(DTensorTestBase):
        torch.library.opcheck(flex_cp_allgather, example)


class TestSharding(DTensorTestBase):
    @property
    def world_size(self) -> int:
        return 2

    @skip_if_lt_x_gpu(2)
    @with_comms
    def test_context_parallel_shard(self) -> None:
        B = 4
        seq_len = 32
        device_mesh = init_device_mesh(
            mesh_shape=(2,), mesh_dim_names=("cp",), device_type=self.device_type
        )
        freqs_cis = torch.arange(0, seq_len, device=self.device_type)
        q = torch.ones(B * seq_len, device=self.device_type).reshape(B, seq_len)
        k = torch.ones(B * seq_len, device=self.device_type).reshape(B, seq_len)
        v = torch.ones(B * seq_len, device=self.device_type).reshape(B, seq_len)
        load_balancer = _HeadTailLoadBalancer(
            seq_len, self.world_size, torch.device(self.device_type)
        )
        # freqs_cis has no batch dimension (sequence dim 0); q/k/v do (sequence dim 1).
        freqs_cis_shard, q_shard, k_shard, v_shard = _context_parallel_shard(
            device_mesh, [freqs_cis, q, k, v], [0, 1, 1, 1], load_balancer=load_balancer
        )
        self.assertEqual(freqs_cis_shard.size(), (seq_len // 2,))
        # Head-tail load balancing pairs chunk i with chunk (2 * world_size - 1 - i).
        chunks = freqs_cis.chunk(self.world_size * 2)
        self.assertEqual(
            freqs_cis_shard,
            torch.cat(
                [chunks[self.rank], chunks[self.world_size * 2 - self.rank - 1]], dim=0
            ),
        )


if __name__ == "__main__":
    run_tests()
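For reference, a standalone sketch of the head-tail split that the assertion above checks, assuming world_size = 2 and seq_len = 32 as in the test:

```python
import torch

world_size, seq_len = 2, 32
freqs_cis = torch.arange(0, seq_len)
chunks = freqs_cis.chunk(world_size * 2)  # 4 chunks of 8 elements each

for rank in range(world_size):
    # Each rank receives chunk `rank` concatenated with its mirror from the tail.
    shard = torch.cat([chunks[rank], chunks[world_size * 2 - rank - 1]])
    print(rank, shard.tolist())
# rank 0 -> elements 0..7 followed by 24..31
# rank 1 -> elements 8..15 followed by 16..23
```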