[VLM][BugFix] Make sure that multi_modal_kwargs can broadcast properly with ring buffer. (#5905)

Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
Author: xwjiang2010
Date: 2024-06-28 00:29:13 -07:00
Committed by: GitHub
Parent: f136da15e1
Commit: 74d55c065b

@@ -45,7 +45,7 @@ TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
 def _split_tensor_dict(
-        tensor_dict: Dict[Any, Union[torch.Tensor, Any]],
+        tensor_dict: Dict[str, Union[torch.Tensor, Any]],
         prefix: str = "") -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
     """Split the tensor dictionary into two parts:
     1. A list of (key, value) pairs. If the value is a tensor, it is replaced
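
For context on why the key type tightens to `str`: `multi_modal_kwargs` can carry nested dictionaries of tensors, and `_split_tensor_dict` flattens them into a single `(key, value)` metadata list via the `prefix` argument, so every key must be a string. The helper bodies are not part of this excerpt; the snippet below is only a minimal sketch of that flattening idea, assuming a "%"-separated key scheme (the function name, separator, and example keys are illustrative, not the commit's exact code).

```python
from typing import Any, Dict, List, Tuple

import torch


def flatten_tensor_dict(tensor_dict: Dict[str, Any],
                        prefix: str = "") -> List[Tuple[str, Any]]:
    """Flatten nested dict entries into ("outer%inner", value) pairs (sketch)."""
    items: List[Tuple[str, Any]] = []
    for key, value in tensor_dict.items():
        if isinstance(value, dict):
            # Recurse into nested dicts such as multi_modal_kwargs,
            # prefixing the inner keys with the outer key.
            items.extend(flatten_tensor_dict(value, prefix + key + "%"))
        else:
            items.append((prefix + key, value))
    return items


# A dict shaped like what broadcast_tensor_dict now has to handle.
example = {
    "input_ids": torch.zeros(4, dtype=torch.long),
    "multi_modal_kwargs": {"pixel_values": torch.zeros(1, 3, 224, 224)},
}
print([k for k, _ in flatten_tensor_dict(example)])
# ['input_ids', 'multi_modal_kwargs%pixel_values']
```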
@@ -473,11 +473,11 @@ class GroupCoordinator:
     def broadcast_tensor_dict(
         self,
-        tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
+        tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
         src: int = 0,
         group: Optional[ProcessGroup] = None,
         metadata_group: Optional[ProcessGroup] = None
-    ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
+    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
         """Broadcast the input tensor dictionary.
         NOTE: `src` is the local rank of the source rank.
         """
@@ -558,9 +558,9 @@ class GroupCoordinator:
     def send_tensor_dict(
         self,
-        tensor_dict: Dict[Any, Union[torch.Tensor, Any]],
+        tensor_dict: Dict[str, Union[torch.Tensor, Any]],
         dst: Optional[int] = None
-    ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
+    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
         """Send the input tensor dictionary.
         NOTE: `dst` is the local rank of the source rank.
         """
@@ -599,7 +599,7 @@ class GroupCoordinator:
     def recv_tensor_dict(
         self,
         src: Optional[int] = None
-    ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
+    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
         """Recv the input tensor dictionary.
         NOTE: `src` is the local rank of the source rank.
         """
@@ -615,7 +615,7 @@ class GroupCoordinator:
         assert src < self.world_size, f"Invalid src rank ({src})"
         recv_metadata_list = self.recv_object(src=src)
-        tensor_dict = {}
+        tensor_dict: Dict[str, Any] = {}
         for key, value in recv_metadata_list:
             if isinstance(value, TensorMetadata):
                 tensor = torch.empty(value.size,
@@ -623,7 +623,7 @@
                                      device=value.device)
                 if tensor.numel() == 0:
                     # Skip broadcasting empty tensors.
-                    tensor_dict[key] = tensor
+                    _update_nested_dict(tensor_dict, key, tensor)
                     continue
                 if tensor.is_cpu:
                     # use metadata_group for CPU tensors
@@ -633,9 +633,9 @@
                 else:
                     # use group for GPU tensors
                     torch.distributed.recv(tensor, src=src, group=group)
-                tensor_dict[key] = tensor
+                _update_nested_dict(tensor_dict, key, tensor)
             else:
-                tensor_dict[key] = value
+                _update_nested_dict(tensor_dict, key, value)
         return tensor_dict
 
     def barrier(self):
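
On the receiving side, `_update_nested_dict` (introduced by this commit but not shown in this excerpt) rebuilds the nested structure from each flattened key instead of assigning `tensor_dict[key]` directly. Below is a minimal sketch of what such a helper could look like, under the same assumed "%" key scheme as above (the name and separator are illustrative):

```python
from typing import Any, Dict

import torch


def update_nested_dict(nested_dict: Dict[str, Any], flattened_key: str,
                       value: Any) -> None:
    """Insert `value` at the nested path encoded in `flattened_key` (sketch)."""
    parts = flattened_key.split("%")
    cur = nested_dict
    for part in parts[:-1]:
        # Create intermediate dicts while walking down the key path.
        cur = cur.setdefault(part, {})
    cur[parts[-1]] = value


# Reverse of the flattening sketch above: the nested structure is restored.
received: Dict[str, Any] = {}
update_nested_dict(received, "input_ids", torch.zeros(4, dtype=torch.long))
update_nested_dict(received, "multi_modal_kwargs%pixel_values",
                   torch.zeros(1, 3, 224, 224))
assert received["multi_modal_kwargs"]["pixel_values"].shape == (1, 3, 224, 224)
```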