[CP] Fix load balancer incorrectly assuming batch dimension exists (#165792)
https://github.com/pytorch/pytorch/pull/163617 removed the if/else statement that checked whether the input buffers have a batch dimension. This PR fixes the issue and adds a test. In the future, we should explicitly ask users to unsqueeze the batch dimension; that would be a BC-breaking change to the existing contract, but implicitly inferring whether the batch dimension exists is not safe.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165792
Approved by: https://github.com/XilunWu
Committed by: PyTorch MergeBot
Parent: ad67170c8b
Commit: 4740ce7787
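To see why the previous behavior was unsafe, consider a buffer that is sharded on dim 0 and therefore has no leading batch dimension, such as a 1-D freqs_cis tensor. The sketch below is a standalone illustration of the rearrangement logic, not the actual _context_parallel_buffers implementation; the rearrange helper and its shapes are assumptions made only for demonstration.

import torch

def rearrange(buffer: torch.Tensor, seq_dim: int, indices: torch.Tensor) -> torch.Tensor:
    # indices: (idx_batch_size, seq_len) load-balance order.
    # A buffer whose sequence dimension is 0 has no leading batch dimension,
    # so its batch size is 1; reading buffer.size(0) there would return seq_len.
    data_batch_size = buffer.size(0) if seq_dim > 0 else 1

    if seq_dim == 0:
        # No batch dimension: reorder the whole buffer along dim 0.
        return torch.index_select(buffer, dim=0, index=indices[0])

    # Batch dimension present: reorder each batch element along seq_dim - 1.
    if indices.size(0) == 1:
        indices = indices.expand(data_batch_size, -1)
    return torch.stack(
        [
            torch.index_select(buffer[i], dim=seq_dim - 1, index=indices[i])
            for i in range(data_batch_size)
        ]
    )

seq_len, B = 8, 2
perm = torch.randperm(seq_len).unsqueeze(0)        # (1, seq_len) shared order
freqs_cis = torch.arange(seq_len)                  # 1-D buffer, seq_dim == 0
q = torch.arange(B * seq_len).reshape(B, seq_len)  # batched buffer, seq_dim == 1

print(rearrange(freqs_cis, 0, perm))  # reordered along dim 0
print(rearrange(q, 1, perm))          # each row reordered with the same permutation

The key point is that the batch size must be derived from seq_dim rather than read unconditionally from buffer.size(0).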
@@ -771,5 +771,40 @@ class TestCPCustomOps(DTensorTestBase):
         torch.library.opcheck(flex_cp_allgather, example)


+class TestSharding(DTensorTestBase):
+    @property
+    def world_size(self) -> int:
+        return 2
+
+    @skip_if_lt_x_gpu(2)
+    @with_comms
+    def test_context_parallel_shard(self) -> None:
+        B = 4
+        seq_len = 32
+
+        device_mesh = init_device_mesh(
+            mesh_shape=(2,), mesh_dim_names=("cp",), device_type=self.device_type
+        )
+        freqs_cis = torch.arange(0, seq_len, device=self.device_type)
+        q = torch.ones(B * seq_len, device=self.device_type).reshape(B, seq_len)
+        k = torch.ones(B * seq_len, device=self.device_type).reshape(B, seq_len)
+        v = torch.ones(B * seq_len, device=self.device_type).reshape(B, seq_len)
+
+        load_balancer = _HeadTailLoadBalancer(
+            seq_len, self.world_size, torch.device(self.device_type)
+        )
+        freqs_cis_shard, q_shard, k_shard, v_shard = _context_parallel_shard(
+            device_mesh, [freqs_cis, q, k, v], [0, 1, 1, 1], load_balancer=load_balancer
+        )
+        self.assertEqual(freqs_cis_shard.size(), (seq_len // 2,))
+        chunks = freqs_cis.chunk(self.world_size * 2)
+        self.assertEqual(
+            freqs_cis_shard,
+            torch.cat(
+                [chunks[self.rank], chunks[self.world_size * 2 - self.rank - 1]], dim=0
+            ),
+        )
+
+
 if __name__ == "__main__":
     run_tests()
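For context on the assertions in test_context_parallel_shard: a head-tail load balancer splits the sequence into 2 * world_size chunks and gives rank r the pair (r, 2 * world_size - r - 1), so each rank holds an equally expensive mix of early and late positions under causal attention. The following standalone sketch only mirrors what the test asserts; the expected_head_tail_shard helper is hypothetical, not the _HeadTailLoadBalancer API.

import torch

def expected_head_tail_shard(x: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Split the sequence into 2 * world_size equal chunks and concatenate the
    # "head" chunk for this rank with the matching "tail" chunk from the end.
    chunks = x.chunk(world_size * 2)
    return torch.cat([chunks[rank], chunks[world_size * 2 - rank - 1]], dim=0)

seq_len, world_size = 32, 2
freqs_cis = torch.arange(seq_len)
# Rank 0 holds positions 0-7 and 24-31; rank 1 holds positions 8-15 and 16-23.
# Each shard has seq_len // 2 elements, matching the size assertion in the test.
print(expected_head_tail_shard(freqs_cis, rank=0, world_size=world_size))
print(expected_head_tail_shard(freqs_cis, rank=1, world_size=world_size))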
@@ -1068,10 +1068,16 @@ def _context_parallel_buffers(
     for buffer, seq_dim in zip(buffers, buffer_seq_dims):
         if isinstance(buffer, torch.Tensor):
             # TODO: the load balance doesn't perform error handling.
-            # NOTE: assuming batch dim is 0

             if load_balance_indices is not None:
+                # NOTE: assuming batch dim is 0
+                # TODO: we should explicitly ask users to unsqueeze the batch dim.
+                # But this is a BC breaking ask.
+                # However, what we have done today is also not very safe.
                 idx_batch_size = load_balance_indices.size(0)
-                data_batch_size = buffer.size(0)
+                data_batch_size = buffer.size(0) if seq_dim > 0 else 1

                 if idx_batch_size != 1 and idx_batch_size != data_batch_size:
                     raise ValueError(
                         "Cannot rearrange buffer: "
@@ -1079,16 +1085,20 @@ def _context_parallel_buffers(
                         f"but buffer has shape {buffer.shape}."
                     )

-                for i in range(data_batch_size):
-                    index = (
-                        load_balance_indices[0]  # identical load-balance in batch
-                        if idx_batch_size == 1
-                        else load_balance_indices[i]
-                    )
-                    buffer_batch_i = torch.index_select(
-                        buffer[i], dim=seq_dim - 1, index=index
-                    )
-                    buffer[i] = buffer_batch_i
+                if seq_dim == 0:
+                    buffer = torch.index_select(
+                        buffer, dim=0, index=load_balance_indices[0]
+                    )
+                else:
+                    indices = load_balance_indices
+                    if idx_batch_size == 1:
+                        size = [data_batch_size] + list(indices.size())[1:]
+                        indices = indices.expand(*size)
+
+                    for i in range(data_batch_size):
+                        buffer[i] = torch.index_select(
+                            buffer[i], dim=seq_dim - 1, index=indices[i]
+                        )

             # use DTensor to shard the buffer on sequence dimension, retain the local tensor
             sharded_buffer = distribute_tensor(
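One detail of the new batched path: when a single load-balance order is shared across the batch (idx_batch_size == 1), the indices are broadcast with expand rather than materialized per batch element, so no copy is made. A small standalone sketch of that idiom follows; the tensors and sizes are illustrative, not taken from the PR.

import torch

B, seq_len = 4, 8
indices = torch.randperm(seq_len).unsqueeze(0)      # (1, seq_len), shared order

# expand builds a broadcast view with batch size B; the storage is not copied.
size = [B] + list(indices.size())[1:]
expanded = indices.expand(*size)                     # (B, seq_len)
assert expanded.data_ptr() == indices.data_ptr()     # same underlying storage

buffer = torch.arange(B * seq_len, dtype=torch.float32).reshape(B, seq_len)
reordered = torch.stack(
    [torch.index_select(buffer[i], dim=0, index=expanded[i]) for i in range(B)]
)
print(reordered.shape)  # torch.Size([4, 8])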