Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[VLM][BugFix] Make sure that multi_modal_kwargs can broadcast properly with ring buffer. (#5905)
Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
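For orientation, the payload this fix targets is a per-step dict of tensors that also carries a nested `multi_modal_kwargs` sub-dict. The sketch below only illustrates that shape; the field names and sizes are invented, not vLLM's actual keys.

import torch

# Illustrative payload only: the structure matches the problem the commit
# addresses, a dict of tensors that contains a nested dict of tensors
# under "multi_modal_kwargs".
payload = {
    "input_tokens": torch.randint(0, 32000, (8, 16)),
    "input_positions": torch.arange(16).repeat(8, 1),
    "multi_modal_kwargs": {
        "pixel_values": torch.randn(8, 3, 224, 224),
        "image_sizes": torch.tensor([[224, 224]] * 8),
    },
}
# A top-level-only broadcast would push the nested dict through the generic
# object path instead of sending its tensors as metadata plus payloads,
# which is why keys are narrowed to str and nested entries are flattened
# and re-nested in the hunks below.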
@@ -45,7 +45,7 @@ TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
 
 
 def _split_tensor_dict(
-        tensor_dict: Dict[Any, Union[torch.Tensor, Any]],
+        tensor_dict: Dict[str, Union[torch.Tensor, Any]],
         prefix: str = "") -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
     """Split the tensor dictionary into two parts:
     1. A list of (key, value) pairs. If the value is a tensor, it is replaced
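The `prefix` parameter suggests the splitting step now recurses into nested sub-dicts and emits flattened string keys. A minimal sketch of that idea follows; the "%" separator, the recursion, and the placeholder metadata tuple are assumptions for illustration, not the function body from this commit.

from typing import Any, Dict, List, Tuple, Union

import torch


def split_tensor_dict_sketch(
        tensor_dict: Dict[str, Union[torch.Tensor, Any]],
        prefix: str = "") -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
    """Flatten nested dicts into prefixed keys; collect tensors separately."""
    metadata_list: List[Tuple[str, Any]] = []
    tensor_list: List[torch.Tensor] = []
    for key, value in tensor_dict.items():
        if isinstance(value, torch.Tensor):
            # Replace the tensor with a small metadata stand-in so only
            # cheap-to-pickle objects travel with the key list.
            metadata_list.append(
                (prefix + key, (str(value.device), value.dtype, value.size())))
            tensor_list.append(value)
        elif isinstance(value, dict):
            # Recurse, extending the prefix so keys stay unique and the
            # receiver can rebuild the nesting.
            sub_meta, sub_tensors = split_tensor_dict_sketch(
                value, prefix=prefix + key + "%")
            metadata_list.extend(sub_meta)
            tensor_list.extend(sub_tensors)
        else:
            metadata_list.append((prefix + key, value))
    return metadata_list, tensor_list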
@@ -473,11 +473,11 @@ class GroupCoordinator:
 
     def broadcast_tensor_dict(
         self,
-        tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
+        tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
         src: int = 0,
         group: Optional[ProcessGroup] = None,
         metadata_group: Optional[ProcessGroup] = None
-    ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
+    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
         """Broadcast the input tensor dictionary.
         NOTE: `src` is the local rank of the source rank.
         """
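Call-site shape, under the assumption that `coordinator` is an initialized GroupCoordinator for the target group; the function name and the driver check are illustrative, not an API from this commit.

# Hypothetical call sites; only the driver passes a dict, the other ranks
# receive the reconstructed dict (including nested multi_modal_kwargs).
def broadcast_step_inputs(coordinator, payload, is_driver: bool):
    if is_driver:
        coordinator.broadcast_tensor_dict(payload, src=0)
        return payload
    return coordinator.broadcast_tensor_dict(None, src=0)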
@@ -558,9 +558,9 @@ class GroupCoordinator:
 
     def send_tensor_dict(
         self,
-        tensor_dict: Dict[Any, Union[torch.Tensor, Any]],
+        tensor_dict: Dict[str, Union[torch.Tensor, Any]],
         dst: Optional[int] = None
-    ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
+    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
         """Send the input tensor dictionary.
         NOTE: `dst` is the local rank of the source rank.
         """
@@ -599,7 +599,7 @@ class GroupCoordinator:
     def recv_tensor_dict(
         self,
         src: Optional[int] = None
-    ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
+    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
         """Recv the input tensor dictionary.
         NOTE: `src` is the local rank of the source rank.
         """
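send_tensor_dict and recv_tensor_dict are the point-to-point counterparts with the same payload rules; a paired usage sketch, with the surrounding function and rank variables purely illustrative:

# One rank sends the dict to a peer; the peer reconstructs it, nested
# entries included, mirroring the broadcast path above.
def pass_tensor_dict(coordinator, payload, peer_rank, is_sender: bool):
    if is_sender:
        coordinator.send_tensor_dict(payload, dst=peer_rank)
        return None
    return coordinator.recv_tensor_dict(src=peer_rank)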
@@ -615,7 +615,7 @@ class GroupCoordinator:
         assert src < self.world_size, f"Invalid src rank ({src})"
 
         recv_metadata_list = self.recv_object(src=src)
-        tensor_dict = {}
+        tensor_dict: Dict[str, Any] = {}
         for key, value in recv_metadata_list:
             if isinstance(value, TensorMetadata):
                 tensor = torch.empty(value.size,
@@ -623,7 +623,7 @@ class GroupCoordinator:
                                      device=value.device)
                 if tensor.numel() == 0:
                     # Skip broadcasting empty tensors.
-                    tensor_dict[key] = tensor
+                    _update_nested_dict(tensor_dict, key, tensor)
                     continue
                 if tensor.is_cpu:
                     # use metadata_group for CPU tensors
@@ -633,9 +633,9 @@ class GroupCoordinator:
                 else:
                     # use group for GPU tensors
                     torch.distributed.recv(tensor, src=src, group=group)
-                tensor_dict[key] = tensor
+                _update_nested_dict(tensor_dict, key, tensor)
             else:
-                tensor_dict[key] = value
+                _update_nested_dict(tensor_dict, key, value)
         return tensor_dict
 
     def barrier(self):
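`_update_nested_dict` itself is defined outside the quoted hunks. A minimal sketch of what such a helper could do, assuming flattened keys use a "%"-style separator matching the prefixing sketched earlier; the separator and exact behaviour are assumptions.

from typing import Any, Dict


def update_nested_dict_sketch(tensor_dict: Dict[str, Any], flattened_key: str,
                              value: Any, separator: str = "%") -> None:
    """Re-nest a flattened "parent%child" key, creating intermediate dicts."""
    keys = flattened_key.split(separator)
    current = tensor_dict
    for key in keys[:-1]:
        current = current.setdefault(key, {})
    current[keys[-1]] = value


# Example: rebuilding a nested multi_modal_kwargs entry on the receiving rank.
rebuilt: Dict[str, Any] = {}
update_nested_dict_sketch(rebuilt, "multi_modal_kwargs%pixel_values", "tensor")
assert rebuilt == {"multi_modal_kwargs": {"pixel_values": "tensor"}}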