mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[CP] Fix load balancer incorrectly assuming batch dimension exists (#165792)
https://github.com/pytorch/pytorch/pull/163617 removes the if/else statement to check if the input buffers have the batch dimension. This PR fixes the issue and also adds a test. In the future, we should explicitly ask users to unsqueeze the batch dimension. This is a BC of the existing contract but implicitly infers the batch dimension existence is not safe. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165792 Approved by: https://github.com/XilunWu
This commit is contained in:
committed by
PyTorch MergeBot
parent
ad67170c8b
commit
4740ce7787
@ -1068,10 +1068,16 @@ def _context_parallel_buffers(
|
||||
for buffer, seq_dim in zip(buffers, buffer_seq_dims):
|
||||
if isinstance(buffer, torch.Tensor):
|
||||
# TODO: the load balance doesn't perform error handling.
|
||||
|
||||
# NOTE: assuming batch dim is 0
|
||||
|
||||
if load_balance_indices is not None:
|
||||
# NOTE: assuming batch dim is 0
|
||||
# TODO: we should expclitly ask users to unsqueeze the batch dim.
|
||||
# But this is a BC breaking ask.
|
||||
# However, what we have done today is also not very safe.
|
||||
idx_batch_size = load_balance_indices.size(0)
|
||||
data_batch_size = buffer.size(0)
|
||||
data_batch_size = buffer.size(0) if seq_dim > 0 else 1
|
||||
|
||||
if idx_batch_size != 1 and idx_batch_size != data_batch_size:
|
||||
raise ValueError(
|
||||
"Cannot rearrange buffer: "
|
||||
@ -1079,16 +1085,20 @@ def _context_parallel_buffers(
|
||||
f"but buffer has shape {buffer.shape}."
|
||||
)
|
||||
|
||||
for i in range(data_batch_size):
|
||||
index = (
|
||||
load_balance_indices[0] # identical load-balance in batch
|
||||
if idx_batch_size == 1
|
||||
else load_balance_indices[i]
|
||||
if seq_dim == 0:
|
||||
buffer = torch.index_select(
|
||||
buffer, dim=0, index=load_balance_indices[0]
|
||||
)
|
||||
buffer_batch_i = torch.index_select(
|
||||
buffer[i], dim=seq_dim - 1, index=index
|
||||
)
|
||||
buffer[i] = buffer_batch_i
|
||||
else:
|
||||
indices = load_balance_indices
|
||||
if idx_batch_size == 1:
|
||||
size = [data_batch_size] + list(indices.size())[1:]
|
||||
indices = indices.expand(*size)
|
||||
|
||||
for i in range(data_batch_size):
|
||||
buffer[i] = torch.index_select(
|
||||
buffer[i], dim=seq_dim - 1, index=indices[i]
|
||||
)
|
||||
|
||||
# use DTensor to shard the buffer on sequence dimension, retain the local tensor
|
||||
sharded_buffer = distribute_tensor(
|
||||
|
Reference in New Issue
Block a user