naive_speculate.speculate

Speculative decoding implementations based on Drafter and Scorer interfaces.

Exports

SpeculativeDecoder: Entry class for speculative decoding.

SpeculativeDecodeOut: The output of the SpeculativeDecoder.decode method.

SpeculativeDecodeOut

Bases: NamedTuple

The output of the SpeculativeDecoder.decode method.

Attributes:

Name Type Description
token_ids Tensor

The generated token ids. Shape: [batch_size, num_generated_tokens], where num_generated_tokens <= num_draft_tokens + 1.

num_accepted_tokens int

The number of accepted drafted tokens.
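
As a rough illustration of the two fields and the shape constraint above, the sketch below mirrors the output with a plain NamedTuple. `DecodeOutSketch` and `satisfies_documented_shape` are hypothetical names introduced here, and nested Python lists stand in for the `Tensor`; this is not the library's class.

```python
from typing import NamedTuple


class DecodeOutSketch(NamedTuple):
    """Hypothetical stand-in for SpeculativeDecodeOut (lists replace Tensor)."""

    token_ids: list[list[int]]  # shape: [batch_size, num_generated_tokens]
    num_accepted_tokens: int


def satisfies_documented_shape(out: DecodeOutSketch, num_draft_tokens: int) -> bool:
    """Check num_generated_tokens <= num_draft_tokens + 1 for every batch row."""
    rows_ok = all(len(row) <= num_draft_tokens + 1 for row in out.token_ids)
    accepted_ok = 0 <= out.num_accepted_tokens <= num_draft_tokens
    return rows_ok and accepted_ok


# Two of three draft tokens accepted, plus one resampled token.
example = DecodeOutSketch(token_ids=[[5, 7, 2]], num_accepted_tokens=2)
print(satisfies_documented_shape(example, num_draft_tokens=3))
```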

SpeculativeDecoder

Performs speculative decoding using a drafter and a scorer.

Attributes:

Name Type Description
drafter Drafter

The drafter used for drafting tokens.

scorer Scorer

The scorer used for scoring tokens.

drafter_kvcache KVCache

The key-value cache used for the drafter.

scorer_kvcache KVCache

The key-value cache used for the scorer.

decode(query_token_ids, num_draft_tokens, sample_strategy, verify_strategy)

Perform speculative decoding.

Currently supports batch_size=1 only.

Decoding stops when <eos> is generated, or when the number of returned tokens (the accepted draft tokens plus one resampled token) reaches num_draft_tokens + 1.

sample_strategy defines how the drafter samples draft tokens. verify_strategy defines how the drafted tokens are verified.

The GREEDY_MATCH verification strategy is legal to combine with any drafter sampling strategy.

However, to achieve a real speedup, prefer the SPECULATIVE_SAMPLING verification strategy over GREEDY_MATCH. GREEDY_MATCH normally leads to more rejections, because it requires exact matches between the drafter's greedy tokens and the target model's greedy tokens.

Also, greedy decoding normally performs worse than random sampling decoding in terms of generation quality.
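
The exact-match requirement can be sketched in a few lines. This is a minimal illustration, not the library's implementation: `greedy_match_verify` is a hypothetical helper, and it assumes the target model's greedy (argmax) token is already known for each draft position plus one extra position after the last draft token.

```python
def greedy_match_verify(
    draft_tokens: list[int], target_greedy_tokens: list[int]
) -> tuple[list[int], int]:
    """Accept the longest prefix of draft tokens that equals the target's greedy tokens.

    target_greedy_tokens must hold len(draft_tokens) + 1 entries: the target's
    argmax token at each draft position, plus one position after the last draft.
    Returns (emitted tokens, number of accepted draft tokens).
    """
    if len(target_greedy_tokens) != len(draft_tokens) + 1:
        raise ValueError("expected one more target token than draft tokens")

    num_accepted = 0
    for drafted, expected in zip(draft_tokens, target_greedy_tokens):
        if drafted != expected:
            break  # the first mismatch rejects this and all later draft tokens
        num_accepted += 1

    # Emit the accepted prefix plus the target's own token at the next position.
    emitted = draft_tokens[:num_accepted] + [target_greedy_tokens[num_accepted]]
    return emitted, num_accepted


tokens, accepted = greedy_match_verify([5, 7, 9], [5, 7, 2, 4])
print(tokens, accepted)
```

A single mismatch discards every later draft token, which is why this strategy tends to reject more often than a probabilistic acceptance test.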

By definition, the SPECULATIVE_SAMPLING verification strategy can be combined with any drafter sampling strategy, as long as the sampling strategy defines a valid proposal distribution.

However, to achieve a real speedup, avoid combining SPECULATIVE_SAMPLING verification with greedy sampling in the drafter.

The reason is that, when the drafter uses greedy sampling, its proposal distribution is always a delta distribution, which makes the acceptance probability ill-defined, so speculative sampling is not a practical option for verification. In this case, every draft token is rejected, and speculative decoding degenerates to a modified auto-regressive decoding in which the drafted token is removed from the vocabulary before resampling.
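
To make the "removed from the vocabulary" effect concrete, here is a sketch of the standard speculative sampling rule from the literature: accept a drafted token x with probability min(1, p(x)/q(x)), and on rejection resample from the normalized residual max(0, p - q). Whether this library follows that rule exactly is an assumption; `acceptance_probability` and `residual_distribution` are hypothetical helpers over plain probability lists.

```python
def acceptance_probability(p: list[float], q: list[float], token: int) -> float:
    """Standard rule: accept the drafted token with probability min(1, p/q)."""
    if q[token] <= 0.0:
        raise ValueError("a drafted token must have positive proposal probability")
    return min(1.0, p[token] / q[token])


def residual_distribution(p: list[float], q: list[float]) -> list[float]:
    """Distribution to resample from after a rejection: normalize max(0, p - q)."""
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    total = sum(residual)
    if total == 0.0:
        # p and q coincide, so a rejection cannot occur; fall back to p.
        return list(p)
    return [r / total for r in residual]


# Target distribution p over a 3-token vocabulary.
p = [0.5, 0.3, 0.2]

# Greedy drafter: all proposal mass sits on token 1 (a delta distribution).
q_delta = [0.0, 1.0, 0.0]
print(residual_distribution(p, q_delta))  # token 1 receives zero residual mass

# Sampling drafter: a proper proposal distribution.
q_soft = [0.2, 0.6, 0.2]
print(acceptance_probability(p, q_soft, token=1))
```

With the delta proposal, the residual assigns zero mass to the drafted token, so any resampled token is drawn from the target distribution with that token excluded and the rest renormalized.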

For consecutive calls to speculative decoding, pay special attention to the drafter's kvcache and to the query tokens passed to the drafter in the next round. There are three situations:

1. If no token is accepted: the scorer and the drafter have both cached the input query tokens. In the next round, the drafter's input query token is the token newly resampled by the speculative decoder.
2. If all tokens are accepted: the scorer's kvcache is one token longer than the drafter's kvcache, because the drafter has no kvcache entry for the last drafted token. In the next round, the drafter's input query tokens should be the concatenation of the last drafted token and the token newly sampled by the speculative decoder.
3. If some but not all tokens are accepted (for example, one token is accepted): the scorer's and the drafter's kvcaches have the same length. In the next round, the drafter's input query token is the token newly resampled by the speculative decoder.

Therefore, the caller should check the num_accepted_tokens field in the output of the decode method and construct the input query tokens for the next round accordingly. Otherwise, the drafter's kvcache will be inconsistent, which may lead to wrong generation results or even errors.
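
The three cases above reduce to a single check on num_accepted_tokens. The helper below is a minimal sketch under stated assumptions: `next_query_tokens` is a hypothetical name, `generated` is one batch row of token_ids converted to a plain list (for example via `token_ids[0].tolist()`), and the row ends with the newly sampled token. It does not handle the <eos> case.

```python
def next_query_tokens(
    generated: list[int], num_accepted_tokens: int, num_draft_tokens: int
) -> list[int]:
    """Pick the drafter's query tokens for the next round.

    generated: accepted draft tokens followed by the newly sampled token.
    """
    if not generated:
        raise ValueError("expected at least the newly sampled token")

    all_accepted = num_accepted_tokens == num_draft_tokens
    if all_accepted and len(generated) >= 2:
        # Case 2: the drafter has no kvcache entry for its last drafted token,
        # so feed that token again together with the newly sampled token.
        return generated[-2:]

    # Cases 1 and 3: only the newly resampled token is needed.
    return generated[-1:]


print(next_query_tokens([5, 7, 9, 4], num_accepted_tokens=3, num_draft_tokens=3))
print(next_query_tokens([5, 7, 2], num_accepted_tokens=2, num_draft_tokens=3))
```

The returned list would then be reshaped to [batch_size, num_query_tokens] before being passed as query_token_ids in the next call.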

Parameters:

Name Type Description Default
query_token_ids Tensor

Ids of the query tokens. Shape: [batch_size, num_query_tokens].

required
num_draft_tokens int

Number of tokens to draft.

required
sample_strategy SampleStrategy

Sampling strategy for drafting tokens.

required
verify_strategy VerifyStrategy

Verification strategy for drafted tokens.

required

Returns:

Name Type Description
SpeculativeDecodeOut SpeculativeDecodeOut

Includes the generated token ids and the number of accepted drafted tokens.