naive_speculate.speculate.utils

Provide speculative_sampling and greedy_match utility functions.

greedy_match(target_dists, candidate_tokens)

Verify candidate tokens against target distributions using greedy matching.

target_dists must contain proper probability distributions (each row non-negative and summing to 1), not raw logits.

target_dists should have shape [num_draft_tokens + 1, vocab_size]. The extra distribution at the end is used to sample one additional token when every candidate token is accepted.

If a rejection happens at position i, the token for position i is resampled from the target distribution at position i.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| target_dists | Tensor | Target distributions of shape [num_draft_tokens + 1, vocab_size]. | required |
| candidate_tokens | Tensor | Candidate sequence of shape [num_draft_tokens]. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| rejected_idx | Tensor | Index of the first rejected token. Scalar tensor with empty shape. Range: [0, num_draft_tokens]. If no rejection happens, equal to num_draft_tokens. |
| resampled_token | Tensor | Resampled token at the rejected position. Scalar tensor with empty shape. If no rejection happens, this will be the token sampled from the extra distribution at the end of target_dists. |
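
The contract described above can be illustrated with a minimal sketch. This is not the library's implementation: the name greedy_match_sketch is hypothetical, and the sketch assumes that "greedy" means comparing each candidate token with the argmax of the corresponding target distribution, and that the resampled token is the argmax at the returned index.

```python
import torch


def greedy_match_sketch(target_dists: torch.Tensor, candidate_tokens: torch.Tensor):
    """Hypothetical illustration of the documented greedy_match contract."""
    num_draft_tokens = candidate_tokens.shape[0]

    # Assumption: the target model's greedy choice is the argmax of each row,
    # including the extra row at the end. Shape: [num_draft_tokens + 1].
    target_tokens = target_dists.argmax(dim=-1)

    # True wherever a candidate disagrees with the target's greedy choice.
    mismatch = target_tokens[:num_draft_tokens] != candidate_tokens

    # Append a sentinel True so that "no rejection" maps to num_draft_tokens.
    sentinel = torch.ones(1, dtype=torch.bool, device=mismatch.device)
    rejected_idx = torch.nonzero(torch.cat([mismatch, sentinel]))[0, 0]

    # Either the replacement at the rejected position, or the token taken
    # from the extra distribution when every candidate matched.
    resampled_token = target_tokens[rejected_idx]
    return rejected_idx, resampled_token


# Three draft tokens over a vocabulary of size 3, plus one extra row.
target_dists = torch.tensor([
    [0.1, 0.9, 0.0],
    [0.8, 0.1, 0.1],
    [0.2, 0.2, 0.6],
    [0.3, 0.4, 0.3],
])
idx, tok = greedy_match_sketch(target_dists, torch.tensor([1, 0, 1]))
print(int(idx), int(tok))  # 2 2: the third candidate is rejected and replaced by token 2
```

When all candidates match (for example candidate_tokens = [1, 0, 2] with the same target_dists), the sketch returns rejected_idx = 3 and takes the token from the extra row, which mirrors the "no rejection" case in the Returns table.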

speculative_sampling(target_dists, proposal_dists, candidate_tokens)

Verify candidate samples against target distributions using speculative sampling.

target_dists and proposal_dists must contain proper probability distributions (each row non-negative and summing to 1), not raw logits.

target_dists should have shape [num_draft_tokens + 1, vocab_size]. The extra distribution at the end is used to sample one additional token when every candidate token is accepted.

If a rejection happens at position i, the token for position i is resampled from the residual distribution at position i.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| target_dists | Tensor | Target distributions of shape [num_draft_tokens + 1, vocab_size]. | required |
| proposal_dists | Tensor | Proposal distributions of shape [num_draft_tokens, vocab_size]. | required |
| candidate_tokens | Tensor | Candidate tokens of shape [num_draft_tokens]. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| rejected_idx | Tensor | Index of the first rejected token. Scalar tensor with empty shape. Range: [0, num_draft_tokens]. If no rejection happens, equal to num_draft_tokens. |
| resampled_token | Tensor | Resampled token at the rejected position. Scalar tensor with empty shape. If no rejection happens, this will be the token sampled from the extra distribution at the end of target_dists. |
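
A minimal sketch of the same contract follows. It uses the standard speculative sampling scheme: a candidate token x is accepted with probability min(1, p(x) / q(x)), and the residual distribution is the normalized positive part of p - q. Whether the library computes it exactly this way is an assumption. The name speculative_sampling_sketch and the optional generator argument are hypothetical and not part of the documented signature.

```python
import torch


def speculative_sampling_sketch(target_dists, proposal_dists, candidate_tokens, generator=None):
    """Hypothetical illustration of the documented speculative_sampling contract."""
    num_draft_tokens = candidate_tokens.shape[0]
    idx = candidate_tokens.unsqueeze(-1)  # [num_draft_tokens, 1]

    # Probability of each candidate token under the target (p) and proposal (q).
    p = target_dists[:num_draft_tokens].gather(-1, idx).squeeze(-1)
    q = proposal_dists.gather(-1, idx).squeeze(-1)

    # Standard rule: accept with probability min(1, p / q).
    accept_prob = torch.clamp(p / q.clamp_min(1e-10), max=1.0)
    u = torch.rand(num_draft_tokens, generator=generator, device=p.device, dtype=p.dtype)
    rejected = u >= accept_prob

    # Append a sentinel True so that "no rejection" maps to num_draft_tokens.
    sentinel = torch.ones(1, dtype=torch.bool, device=rejected.device)
    rejected_idx = torch.nonzero(torch.cat([rejected, sentinel]))[0, 0]

    # Pad the proposals with a zero row: at the extra position the residual
    # max(0, p - 0) is then just the extra target distribution.
    padded_proposals = torch.cat([proposal_dists, torch.zeros_like(target_dists[-1:])])
    residual = (target_dists[rejected_idx] - padded_proposals[rejected_idx]).clamp_min(0)
    residual = residual / residual.sum()

    resampled_token = torch.multinomial(residual, 1, generator=generator).squeeze(0)
    return rejected_idx, resampled_token


# The target assigns zero probability to the second candidate, so it is always rejected.
target_dists = torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
proposal_dists = torch.tensor([[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
idx, tok = speculative_sampling_sketch(target_dists, proposal_dists, torch.tensor([0, 0]))
print(int(idx), int(tok))  # 1 2: rejected at position 1, resampled from the residual
```

Padding the proposals with a zero row is a design convenience of this sketch: it lets the rejection case and the all-accepted case share one sampling path, since the residual at the extra position reduces to the extra target distribution.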