[Bugfix] Check chain_speculative_sampling before calling it (#11673)

Signed-off-by: Lu Fang <lufang@fb.com>
This commit is contained in:
Lu Fang
2025-01-02 16:58:56 -08:00
committed by GitHub
parent 2f1e8e8f54
commit 07064cb1d4

View File

@ -118,7 +118,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
# If use Flashinfer chain_speculative_sampling kernel
# for rejection sampling
if self.use_flashinfer:
if self.use_flashinfer and chain_speculative_sampling is not None:
batch_size, k, _ = draft_probs.shape
uniform_samples = self._create_uniform_samples(
seeded_seqs, batch_size, k, draft_probs.device)