[BugFix] Update draft model TP size check to allow matching target TP size (#9394)

Co-authored-by: Baoyuan Qi <qibaoyuan@126.com>
This commit is contained in:
Nick Hill
2024-10-21 22:14:29 +01:00
committed by GitHub
parent d621c43df7
commit 15713e3b75

View File

@ -1408,11 +1408,11 @@ class SpeculativeConfig:
else:
speculative_draft_tensor_parallel_size = \
target_parallel_config.tensor_parallel_size
elif speculative_draft_tensor_parallel_size != 1:
# TODO(wooyeon): allow tp values larger than 1
elif speculative_draft_tensor_parallel_size not in (
1, target_parallel_config.tensor_parallel_size):
raise ValueError(
f"{speculative_draft_tensor_parallel_size=} cannot be "
f"other value than 1")
f"other value than 1 or target model tensor_parallel_size")
draft_parallel_config = ParallelConfig(
pipeline_parallel_size=target_parallel_config.