
naive_speculate.speculate.drafter

Defines the Drafter class, which implements token drafting.

DraftOut

Bases: NamedTuple

Output of Drafter.draft method.

Attributes:

Name | Type | Description
token_ids | Tensor | The token ids of the drafted tokens, of shape [batch_size, num_drafted_tokens].
token_logits | Tensor | The logits used to sample the drafted tokens, of shape [batch_size, num_drafted_tokens, vocab_size].
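The two fields share their first two dimensions: row b, position i of token_ids is the token sampled from row b, position i of token_logits. A minimal sketch of that shape relationship, assuming nested Python lists as stand-ins for torch tensors so it runs without dependencies; DraftOutSketch and check_shapes are hypothetical names, not part of naive_speculate.

```python
from typing import List, NamedTuple

# Nested lists stand in for torch tensors to keep the sketch dependency-free.
Matrix = List[List[int]]
Cube = List[List[List[float]]]


class DraftOutSketch(NamedTuple):
    """Hypothetical stand-in mirroring the two DraftOut fields."""
    token_ids: Matrix   # [batch_size, num_drafted_tokens]
    token_logits: Cube  # [batch_size, num_drafted_tokens, vocab_size]


def check_shapes(out: DraftOutSketch, vocab_size: int) -> int:
    """Check that the fields agree on shape; return num_drafted_tokens."""
    batch_size = len(out.token_ids)
    assert len(out.token_logits) == batch_size
    num_drafted = len(out.token_ids[0]) if batch_size else 0
    for ids_row, logits_row in zip(out.token_ids, out.token_logits):
        assert len(ids_row) == num_drafted
        assert len(logits_row) == num_drafted
        # One logits vector over the whole vocabulary per drafted position.
        assert all(len(step) == vocab_size for step in logits_row)
    return num_drafted


out = DraftOutSketch(
    token_ids=[[2, 0]],
    token_logits=[[[0.1, 0.2, 3.0], [1.5, 0.3, 0.2]]],
)
ids, logits = out  # a NamedTuple unpacks like a plain tuple
print(check_shapes(out, vocab_size=3))  # 2
```

Because DraftOut is a NamedTuple, callers can use either attribute access (out.token_ids) or tuple unpacking.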

Drafter

Drafter generates draft tokens given query tokens and a KV cache.

Drafter delegates token drafting to an Inferencer instance.

In speculative decoding, a drafter generates draft tokens that are later verified by a more accurate model (the verifier). Typically, the drafter is a smaller, faster model than the verifier.
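To make the draft-then-verify idea concrete, here is a minimal sketch of greedy verification for a single sequence, in plain Python. It assumes the verifier has already scored the query plus all drafted tokens in one forward pass and produced its most likely token at each position (one more position than there are drafts). The function greedy_accept is hypothetical and is not how naive_speculate's verifier is necessarily implemented.

```python
from typing import List


def greedy_accept(draft_ids: List[int], verifier_argmax: List[int]) -> List[int]:
    """Return the tokens kept after greedy verification of one sequence.

    verifier_argmax[i] is the verifier's most likely token at draft position i;
    it has one extra entry for the position after the last drafted token.
    """
    assert len(verifier_argmax) == len(draft_ids) + 1
    accepted: List[int] = []
    for i, token in enumerate(draft_ids):
        if token != verifier_argmax[i]:
            # First disagreement: keep the verifier's token instead and stop.
            accepted.append(verifier_argmax[i])
            return accepted
        accepted.append(token)
    # Every draft matched, so the verifier contributes one extra token for free.
    accepted.append(verifier_argmax[len(draft_ids)])
    return accepted


print(greedy_accept([5, 7, 9], [5, 7, 4, 1]))  # [5, 7, 4]
print(greedy_accept([5, 7], [5, 7, 8]))        # [5, 7, 8]
```

Each verification pass therefore yields at least one token, and up to len(draft_ids) + 1 when the drafter agrees with the verifier everywhere; this is why a fast drafter that is usually right can speed up generation.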

Attributes:

Name | Type | Description
inferencer | Inferencer | The inferencer used for drafting tokens.

draft(query_token_ids, kv_cache, num_draft_tokens, sample_strategy)

Generate candidate tokens given query tokens and a KV cache.

kv_cache is modified as a side effect of this method.

Returns a DraftOut, which includes:
- token_ids: the generated draft token ids, of shape [batch_size, num_drafted_tokens]. num_drafted_tokens <= num_draft_tokens, because generation may stop early once the end-of-sequence token is generated.
- token_logits: the logits used to sample the drafted tokens, of shape [batch_size, num_drafted_tokens, vocab_size].

Both num_query_tokens (that is, query_token_ids.shape[1]) and num_draft_tokens are expected to be positive.
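The contract above (a positive query length, at most num_draft_tokens drafted, early stop at end-of-sequence, and a KV cache mutated along the way) can be sketched as a simple autoregressive loop. This is a minimal sketch under several assumptions, not the library's implementation: nested lists stand in for tensors, the KV cache is modeled as a per-row list of token ids, a hypothetical step callable plays the role of the Inferencer, greedy argmax replaces sample_strategy, and the batched stopping rule (stop once every row has emitted EOS) is a guess because the source does not specify it. All names here are hypothetical.

```python
from typing import Callable, List, NamedTuple

# Hypothetical stand-ins: nested lists play the role of tensors, and the KV
# cache is a per-row list of token ids that the step function appends to.
Batch = List[List[int]]
StepFn = Callable[[Batch, Batch], List[List[float]]]


class DraftOutSketch(NamedTuple):
    token_ids: Batch                       # [batch_size, num_drafted_tokens]
    token_logits: List[List[List[float]]]  # [batch_size, num_drafted_tokens, vocab_size]


def argmax(row: List[float]) -> int:
    return max(range(len(row)), key=row.__getitem__)


def draft_greedy(query_token_ids: Batch, kv_cache: Batch, num_draft_tokens: int,
                 step: StepFn, eos_id: int) -> DraftOutSketch:
    assert query_token_ids and len(query_token_ids[0]) > 0, "num_query_tokens must be positive"
    assert num_draft_tokens > 0, "num_draft_tokens must be positive"
    batch_size = len(query_token_ids)
    token_ids: Batch = [[] for _ in range(batch_size)]
    token_logits: List[List[List[float]]] = [[] for _ in range(batch_size)]
    finished = [False] * batch_size
    feed = query_token_ids  # the first step consumes every query token
    for _ in range(num_draft_tokens):
        logits = step(feed, kv_cache)  # mutates kv_cache as a side effect
        next_ids = [argmax(row) for row in logits]  # greedy in place of sample_strategy
        for b in range(batch_size):
            token_ids[b].append(next_ids[b])
            token_logits[b].append(logits[b])
            finished[b] = finished[b] or next_ids[b] == eos_id
        if all(finished):  # assumed rule: stop once every row has emitted EOS
            break
        feed = [[t] for t in next_ids]  # later steps feed only the newest token
    return DraftOutSketch(token_ids, token_logits)


VOCAB_SIZE, EOS_ID = 5, 4


def toy_step(feed: Batch, kv_cache: Batch) -> List[List[float]]:
    """Toy model: always predicts (last cached token + 1) % VOCAB_SIZE."""
    logits = []
    for b, tokens in enumerate(feed):
        kv_cache[b].extend(tokens)
        nxt = (kv_cache[b][-1] + 1) % VOCAB_SIZE
        logits.append([1.0 if v == nxt else 0.0 for v in range(VOCAB_SIZE)])
    return logits


cache: Batch = [[]]
out = draft_greedy([[0, 1]], cache, num_draft_tokens=5, step=toy_step, eos_id=EOS_ID)
print(out.token_ids)  # [[2, 3, 4]]  (stopped early: 3 <= 5)
print(cache)          # [[0, 1, 2, 3]]
```

Feeding the whole query on the first step and a single token afterwards is what makes a KV cache useful: past keys and values are reused instead of recomputed. Note that in this sketch the final EOS token is returned but never fed back, so the cache holds the query plus all drafts except the last.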

Parameters:

Name | Type | Description | Default
query_token_ids | Tensor | Query tokens of shape [batch_size, num_query_tokens]. | required
kv_cache | KVCache | Key and value tensors of past tokens. | required
num_draft_tokens | int | Maximum number of tokens to draft; should be positive. | required
sample_strategy | SampleStrategy | The sampling strategy to use during generation. | required
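The members of SampleStrategy are not listed on this page, so the following is only a sketch of what a sampling strategy typically decides: how to turn one row of logits into a token id. SampleStrategySketch, its GREEDY and RANDOM members, and sample_token are all hypothetical; the real type may offer different options (for example temperature or top-k).

```python
import math
import random
from enum import Enum, auto
from typing import List


class SampleStrategySketch(Enum):
    """Hypothetical strategies; the real SampleStrategy members may differ."""
    GREEDY = auto()
    RANDOM = auto()


def sample_token(logits: List[float], strategy: SampleStrategySketch,
                 rng: random.Random) -> int:
    """Pick one token id from a single row of logits."""
    if strategy is SampleStrategySketch.GREEDY:
        # Greedy: the highest-scoring token, fully deterministic.
        return max(range(len(logits)), key=logits.__getitem__)
    # Random: draw from softmax(logits). Subtracting the max keeps exp() from
    # overflowing; random.choices normalizes the weights itself.
    m = max(logits)
    weights = [math.exp(x - m) for x in logits]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


rng = random.Random(0)
row = [0.5, 2.0, -1.0]
print(sample_token(row, SampleStrategySketch.GREEDY, rng))  # 1
print(sample_token(row, SampleStrategySketch.RANDOM, rng))  # some id in {0, 1, 2}
```

Passing an explicit random.Random instance keeps random drafting reproducible across runs.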

Returns:

Name | Type | Description
DraftOut | DraftOut | A named tuple containing:
- token_ids (torch.Tensor): The generated draft token ids.
- token_logits (torch.Tensor): The logits for the drafted tokens.
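One common reason to return token_logits alongside token_ids is speculative sampling, where the verifier keeps a drafted token x with probability min(1, p(x)/q(x)), q being the drafter's distribution and p the verifier's at the same position. Whether naive_speculate's verifier uses exactly this rule is not stated on this page. The sketch below computes that probability for one position, assuming q is a plain softmax over one row of token_logits; softmax and acceptance_probability are hypothetical helpers.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def acceptance_probability(draft_logits: List[float], verifier_logits: List[float],
                           token_id: int) -> float:
    """Chance of keeping a drafted token under speculative sampling: min(1, p/q)."""
    q = softmax(draft_logits)[token_id]     # drafter's probability of the token
    p = softmax(verifier_logits)[token_id]  # verifier's probability of the token
    return min(1.0, p / q)


draft_row = [0.0, 0.0]               # drafter: q = [0.5, 0.5]
verifier_row = [0.0, math.log(3.0)]  # verifier: p = [0.25, 0.75]
print(round(acceptance_probability(draft_row, verifier_row, 0), 6))  # 0.5
print(acceptance_probability(draft_row, verifier_row, 1))            # 1.0
```

Tokens the verifier likes at least as much as the drafter are always kept; tokens the drafter over-weights are kept only part of the time, which is what lets the combined output follow the verifier's distribution.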