naive_speculate.speculate.drafter
Define Drafter class, implementing token drafting functionality.
DraftOut
Bases: NamedTuple
Output of Drafter.draft method.
Attributes:
| Name | Type | Description |
|---|---|---|
| token_ids | Tensor | The token ids of the drafted tokens. Shape [batch_size, num_drafted_tokens]. |
| token_logits | Tensor | The logits used to sample the drafted tokens. Shape [batch_size, num_drafted_tokens, vocab_size]. |
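As a rough illustration of how the two fields line up, here is a minimal sketch of a named tuple with the same field names. It is not the library's definition: DraftOutSketch is a hypothetical name, and nested Python lists stand in for tensors so the snippet runs without PyTorch.

```python
from typing import NamedTuple

# Stand-ins for torch.Tensor (an assumption made so the sketch has no dependencies).
TokenIds = list[list[int]]             # [batch_size, num_drafted_tokens]
TokenLogits = list[list[list[float]]]  # [batch_size, num_drafted_tokens, vocab_size]


class DraftOutSketch(NamedTuple):
    """Hypothetical mirror of DraftOut, for illustration only."""

    token_ids: TokenIds
    token_logits: TokenLogits


# One sequence in the batch, two drafted tokens, a vocabulary of three entries.
out = DraftOutSketch(
    token_ids=[[2, 0]],
    token_logits=[[[0.1, 0.2, 0.9], [1.5, 0.3, 0.2]]],
)

# Fields are reachable by name or by positional unpacking.
ids, logits = out
assert ids is out.token_ids and logits is out.token_logits

# There is one logits row per drafted token.
assert len(out.token_logits[0]) == len(out.token_ids[0])
```

In this toy data each drafted id is the argmax of its logits row, which corresponds to greedy sampling; other sampling strategies would not guarantee that.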
Drafter
Drafter generates draft tokens given query tokens and a KV cache.
Drafter delegates token drafting to an Inferencer instance.
In the context of speculative decoding, a drafter generates draft tokens that are later verified by a more accurate model (the verifier). Typically, the drafter is a smaller but faster model than the verifier.
Attributes:
| Name | Type | Description |
|---|---|---|
| inferencer | Inferencer | The inferencer used for drafting tokens. |
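To make the drafter/verifier relationship described above concrete, the following sketch shows one common greedy acceptance rule. It is an assumption for illustration, not necessarily what this library's verifier does (for example, it may use rejection sampling instead). The function accept_draft and its arguments are hypothetical names.

```python
def accept_draft(draft_ids: list[int], verifier_ids: list[int]) -> list[int]:
    """Keep the longest draft prefix the verifier agrees with.

    verifier_ids[i] is the verifier's greedy choice at the position of
    draft_ids[i]; it carries one extra entry for the position right after
    the last draft token.
    """
    assert len(verifier_ids) == len(draft_ids) + 1

    accepted: list[int] = []
    for draft_tok, verifier_tok in zip(draft_ids, verifier_ids):
        if draft_tok != verifier_tok:
            # First disagreement: take the verifier's token and stop.
            accepted.append(verifier_tok)
            return accepted
        accepted.append(draft_tok)

    # Every draft token matched, so the verifier's extra token comes for free.
    accepted.append(verifier_ids[len(draft_ids)])
    return accepted


partly_accepted = accept_draft([5, 7, 9], [5, 7, 3, 4])
fully_accepted = accept_draft([5, 7], [5, 7, 1])
```

Under this rule every verification step yields at least one token, which is why a fast drafter can speed up decoding even when many of its guesses are rejected.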
draft(query_token_ids, kv_cache, num_draft_tokens, sample_strategy)
Generate candidate tokens given query tokens and KV cache.
kv_cache will be updated internally as a side effect of this method.
Return DraftOut, which includes:
- token_ids: the generated draft token ids, of shape [batch_size, num_drafted_tokens], where num_drafted_tokens <= num_draft_tokens, because generation may stop early if the end-of-sequence token is generated.
- token_logits: the logits used to sample the drafted tokens, of shape [batch_size, num_drafted_tokens, vocab_size].
num_query_tokens := query_token_ids.shape[1] is expected to be positive.
num_draft_tokens is expected to be positive.
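The behavior described above (a cache that is updated as a side effect, a cap on the number of drafted tokens, and an early stop at end-of-sequence) can be sketched with a toy loop. Everything here is an assumption for illustration: toy_draft, ToyDraftOut, next_logits, eos_token_id and count_up are hypothetical, the batch dimension is dropped, sampling is fixed to greedy instead of a SampleStrategy, and a plain list of seen token ids stands in for the real KVCache.

```python
from typing import Callable, NamedTuple


class ToyDraftOut(NamedTuple):
    token_ids: list[int]             # [num_drafted_tokens]
    token_logits: list[list[float]]  # [num_drafted_tokens, vocab_size]


def toy_draft(
    query_token_ids: list[int],
    kv_cache: list[int],
    num_draft_tokens: int,
    next_logits: Callable[[list[int]], list[float]],
    eos_token_id: int,
) -> ToyDraftOut:
    """Draft up to num_draft_tokens tokens greedily, stopping early at EOS."""
    assert len(query_token_ids) > 0, "num_query_tokens must be positive"
    assert num_draft_tokens > 0, "num_draft_tokens must be positive"

    token_ids: list[int] = []
    token_logits: list[list[float]] = []
    pending = list(query_token_ids)  # tokens not yet absorbed into the cache
    for _ in range(num_draft_tokens):
        kv_cache.extend(pending)        # side effect: the caller's cache grows
        logits = next_logits(kv_cache)  # logits for the next position
        token = max(range(len(logits)), key=logits.__getitem__)  # greedy pick
        token_ids.append(token)
        token_logits.append(logits)
        if token == eos_token_id:
            break                       # early stop: fewer tokens than the limit
        pending = [token]
    return ToyDraftOut(token_ids, token_logits)


def count_up(context: list[int]) -> list[float]:
    """Hypothetical model: favors (last token + 1) modulo a vocabulary of 4."""
    logits = [0.0] * 4
    logits[(context[-1] + 1) % 4] = 1.0
    return logits


cache: list[int] = []
out = toy_draft([0, 1], cache, num_draft_tokens=5, next_logits=count_up, eos_token_id=3)
```

With token 3 acting as end-of-sequence, the loop stops after two tokens even though the limit is five, so num_drafted_tokens <= num_draft_tokens holds, and the list passed as the cache has been extended by the call.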
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| query_token_ids | Tensor | Query tokens of shape [batch_size, num_query_tokens]. | required |
| kv_cache | KVCache | Key and value tensors of past tokens. | required |
| num_draft_tokens | int | Upper limit on the number of tokens to draft; must be positive. | required |
| sample_strategy | SampleStrategy | The sampling strategy to use during generation. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| DraftOut | DraftOut | A named tuple containing: token_ids (torch.Tensor), the generated draft token ids; token_logits (torch.Tensor), the logits for the drafted tokens. |