[Bugfix] multi_modal_kwargs broadcast for CPU tensor parallel (#10541)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py
2024-11-23 13:25:46 +08:00
committed by GitHub
parent c8acd80548
commit 4cfe5d2bca
2 changed files with 2 additions and 0 deletions

View File

@ -35,6 +35,7 @@ class EncoderDecoderModelInputForCPU(ModelInputForCPUWithSamplingMetadata):
"input_positions": self.input_positions,
"encoder_input_tokens": self.encoder_input_tokens,
"encoder_input_positions": self.encoder_input_positions,
"multi_modal_kwargs": self.multi_modal_kwargs,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,

View File

@ -83,6 +83,7 @@ class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU):
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
"multi_modal_kwargs": self.multi_modal_kwargs,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,