Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
[CP] Fix load balancer incorrectly assuming batch dimension exists (#165792)
https://github.com/pytorch/pytorch/pull/163617 removed the if/else statement that checked whether the input buffers have a batch dimension. This PR fixes the issue and adds a test. In the future, we should explicitly require users to unsqueeze the batch dimension; that would be a BC-breaking change to the existing contract, but implicitly inferring whether the batch dimension exists is not safe.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165792
Approved by: https://github.com/XilunWu
This commit is contained in:
committed by PyTorch MergeBot
parent ad67170c8b
commit 4740ce7787
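
For context on the fix, here is a minimal self-contained sketch of the behavior being restored; _head_tail_indices and _shard_buffer are hypothetical stand-ins, not the actual _HeadTailLoadBalancer / _context_parallel_shard internals. The point is that sharding must reorder each buffer along its declared sequence dimension rather than assume a leading batch dimension exists:

import torch

# Head-tail load balancing pairs chunk r with chunk (2 * world_size - 1 - r),
# matching the expectation asserted in the new test below.
def _head_tail_indices(seq_len: int, world_size: int, rank: int) -> torch.Tensor:
    chunks = torch.arange(seq_len).chunk(world_size * 2)
    return torch.cat([chunks[rank], chunks[world_size * 2 - rank - 1]])

def _shard_buffer(buffer: torch.Tensor, seq_dim: int, indices: torch.Tensor) -> torch.Tensor:
    # The restored behavior: index only along the caller-declared seq_dim,
    # instead of assuming every buffer carries a leading batch dimension.
    return buffer.index_select(seq_dim, indices)

# freqs_cis is 1-D (seq_dim=0) while q/k/v are (B, seq_len) (seq_dim=1);
# both shard correctly without the caller unsqueezing a batch dimension.
idx = _head_tail_indices(seq_len=32, world_size=2, rank=0)
freqs_cis_shard = _shard_buffer(torch.arange(32), seq_dim=0, indices=idx)
q_shard = _shard_buffer(torch.ones(4, 32), seq_dim=1, indices=idx)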
@@ -771,5 +771,40 @@ class TestCPCustomOps(DTensorTestBase):
         torch.library.opcheck(flex_cp_allgather, example)
 
 
+class TestSharding(DTensorTestBase):
+    @property
+    def world_size(self) -> int:
+        return 2
+
+    @skip_if_lt_x_gpu(2)
+    @with_comms
+    def test_context_parallel_shard(self) -> None:
+        B = 4
+        seq_len = 32
+
+        device_mesh = init_device_mesh(
+            mesh_shape=(2,), mesh_dim_names=("cp",), device_type=self.device_type
+        )
+        freqs_cis = torch.arange(0, seq_len, device=self.device_type)
+        q = torch.ones(B * seq_len, device=self.device_type).reshape(B, seq_len)
+        k = torch.ones(B * seq_len, device=self.device_type).reshape(B, seq_len)
+        v = torch.ones(B * seq_len, device=self.device_type).reshape(B, seq_len)
+
+        load_balancer = _HeadTailLoadBalancer(
+            seq_len, self.world_size, torch.device(self.device_type)
+        )
+        freqs_cis_shard, q_shard, k_shard, v_shard = _context_parallel_shard(
+            device_mesh, [freqs_cis, q, k, v], [0, 1, 1, 1], load_balancer=load_balancer
+        )
+        self.assertEqual(freqs_cis_shard.size(), (seq_len // 2,))
+        chunks = freqs_cis.chunk(self.world_size * 2)
+        self.assertEqual(
+            freqs_cis_shard,
+            torch.cat(
+                [chunks[self.rank], chunks[self.world_size * 2 - self.rank - 1]], dim=0
+            ),
+        )
+
+
 if __name__ == "__main__":
     run_tests()
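
To make the assertion concrete: with seq_len=32 and world_size=2, freqs_cis is split into world_size * 2 = 4 chunks of 8 positions, and each rank's shard pairs a head chunk with the mirrored tail chunk. The arithmetic can be checked locally without a device mesh:

import torch

seq_len, world_size = 32, 2
freqs_cis = torch.arange(0, seq_len)
chunks = freqs_cis.chunk(world_size * 2)  # four chunks of eight positions each
for rank in range(world_size):
    shard = torch.cat([chunks[rank], chunks[world_size * 2 - rank - 1]], dim=0)
    assert shard.size() == (seq_len // 2,)
# rank 0 holds positions 0..7 and 24..31; rank 1 holds positions 8..15 and 16..23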