naive_speculate.speculate.utils
Provide speculative_sampling and greedy_match utility functions.
greedy_match(target_dists, candidate_tokens)
Verify candidate tokens against target distributions using greedy matching.
Expect target_dists to be proper probability distributions, instead of raw logits.
target_dists should have shape [num_draft_tokens + 1, vocab_size], where the
extra distribution at the end is used to sample one new token if no rejection happens.
If rejection happens at position i, the token for position i will be resampled from
the target distribution.
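The loop below is a minimal pure-Python sketch of this verification, not the library's implementation. It uses plain lists in place of `Tensor`s so it runs without extra dependencies. It assumes that "greedy matching" means drawing one token per position from `target_dists` and accepting candidates up to the first mismatch. If the target rows are one-hot (for example, under greedy decoding), this reduces to comparing each candidate with the argmax. The names `greedy_match_sketch` and `_sample` are hypothetical.

```python
import random
from typing import List, Optional, Tuple

Dist = List[float]  # one probability distribution over the vocabulary


def _sample(dist: Dist, rng: random.Random) -> int:
    """Draw one token id; random.choices normalizes the weights."""
    return rng.choices(range(len(dist)), weights=dist, k=1)[0]


def greedy_match_sketch(
    target_dists: List[Dist],
    candidate_tokens: List[int],
    rng: Optional[random.Random] = None,
) -> Tuple[int, int]:
    """Hypothetical list-based analogue of greedy_match.

    Returns (rejected_idx, resampled_token). target_dists must have one more
    row than candidate_tokens; the last row is only used when every
    candidate is accepted.
    """
    rng = rng if rng is not None else random.Random()
    n = len(candidate_tokens)
    if len(target_dists) != n + 1:
        raise ValueError("target_dists needs num_draft_tokens + 1 rows")
    for i, candidate in enumerate(candidate_tokens):
        # Assumption: draw the target model's own token for position i ...
        target_token = _sample(target_dists[i], rng)
        if target_token != candidate:
            # ... and stop at the first disagreement, keeping the target's token.
            return i, target_token
    # No rejection: draw one extra token from the last distribution.
    return n, _sample(target_dists[n], rng)
```

With one-hot rows such as `[[0, 1, 0], [0, 0, 1], [1, 0, 0]]`, candidates `[1, 0]` give `(1, 2)`, and candidates `[1, 2]` are fully accepted and give `(2, 0)`.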
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `target_dists` | `Tensor` | Target distributions of shape [num_draft_tokens + 1, vocab_size]. | required |
| `candidate_tokens` | `Tensor` | Candidate sequence of shape [num_draft_tokens]. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| `rejected_idx` | `Tensor` | Index of the first rejected token. Scalar tensor with empty shape. Range: `[0, num_draft_tokens]`. |
| `resampled_token` | `Tensor` | Resampled token at the rejected position. Scalar tensor with empty shape. If no rejection happens, this will be the token sampled from the extra distribution at the end of `target_dists`. |
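The two return values are enough to build the tokens a decoding step emits: the accepted prefix of `candidate_tokens` followed by `resampled_token`. The helper below is a hypothetical sketch of that bookkeeping on plain Python ints (scalar tensors can be converted with `int(...)` first). It assumes `rejected_idx` equals `num_draft_tokens` when every candidate is accepted.

```python
from typing import List


def accepted_sequence(
    candidate_tokens: List[int], rejected_idx: int, resampled_token: int
) -> List[int]:
    """Hypothetical helper: tokens to append after one verification step.

    Candidates before rejected_idx were accepted. resampled_token replaces the
    rejected candidate, or is the extra token when nothing was rejected
    (assumed to be signalled by rejected_idx == len(candidate_tokens)).
    """
    if not 0 <= rejected_idx <= len(candidate_tokens):
        raise ValueError("rejected_idx out of range")
    return candidate_tokens[:rejected_idx] + [resampled_token]
```

For example, `accepted_sequence([5, 7, 9], 1, 4)` gives `[5, 4]`, while full acceptance, `accepted_sequence([5, 7, 9], 3, 2)`, gives `[5, 7, 9, 2]`. Either way, each step emits at least one token.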
speculative_sampling(target_dists, proposal_dists, candidate_tokens)
Verify candidate samples against target distributions using speculative sampling.
Expect target_dists and proposal_dists to be proper probability distributions,
instead of raw logits.
target_dists should have shape [num_draft_tokens + 1, vocab_size], where the
extra distribution at the end is used to sample one new token if no rejection happens.
If rejection happens at position i, the token for position i will be resampled from
the residual distribution.
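For reference, the standard speculative sampling rule is summarized below, assuming this function follows it. The symbols are introduced here for illustration: $p_i$ is `target_dists[i]`, $q_i$ is `proposal_dists[i]`, and $x_i$ is `candidate_tokens[i]`. Each candidate is accepted with the probability on the left. At the first rejection, the token is redrawn from the residual $r_i$ on the right. If all candidates are accepted, one extra token is drawn from $p_n$ with $n$ = num_draft_tokens.

$$
\Pr[\text{accept } x_i] = \min\left(1,\ \frac{p_i(x_i)}{q_i(x_i)}\right),
\qquad
r_i(x) = \frac{\max\bigl(0,\ p_i(x) - q_i(x)\bigr)}{\sum_{x'} \max\bigl(0,\ p_i(x') - q_i(x')\bigr)}
$$

This choice of residual makes each emitted token distributed exactly according to $p_i$, even though the candidates were drawn from $q_i$.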
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `target_dists` | `Tensor` | Target distributions of shape [num_draft_tokens + 1, vocab_size]. | required |
| `proposal_dists` | `Tensor` | Proposal distributions of shape [num_draft_tokens, vocab_size]. | required |
| `candidate_tokens` | `Tensor` | Candidate tokens of shape [num_draft_tokens]. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| `rejected_idx` | `Tensor` | Index of the first rejected token. Scalar tensor with empty shape. Range: `[0, num_draft_tokens]`. |
| `resampled_token` | `Tensor` | Resampled token at the rejected position. Scalar tensor with empty shape. If no rejection happens, this will be the token sampled from the extra distribution at the end of `target_dists`. |