
naive_speculate.infer

Basic inference support for speculative decoding.

Exports

- LanguageModel: Interface for language models.
- Inferencer: Interface for inference engines.
- PrefillOutput: Data structure for output from prefill operations.
- DecodeOutput: Data structure for output from decoding operations.
- KVCache: Interface for key-value cache used in transformer models.
- KVState: Data structure for the state of the key-value cache.

DecodeOutput

Bases: NamedTuple

Output of Inferencer.decode method.

Attributes:

token_ids (Tensor): The newly generated token ids after decode. Shape [batch_size, num_generated_tokens].

token_logits (Tensor): The logits used to sample the newly generated tokens. Shape [batch_size, num_generated_tokens, vocab_size].
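As a sketch of how the two fields relate, the snippet below defines a stand-in DecodeOutput with the same two fields (it does not import naive_speculate.infer) and fills it from random logits. It assumes greedy sampling, so each token id is simply the argmax of its own logits row; other sampling strategies would not have that property.

```python
from typing import NamedTuple

import torch


class DecodeOutput(NamedTuple):
    """Stand-in mirroring the documented attributes (sketch, not the library class)."""

    token_ids: torch.Tensor     # [batch_size, num_generated_tokens]
    token_logits: torch.Tensor  # [batch_size, num_generated_tokens, vocab_size]


batch_size, num_generated, vocab_size = 2, 3, 10
logits = torch.randn(batch_size, num_generated, vocab_size)

# Assumed greedy sampling: each generated token is the argmax of its logits row.
out = DecodeOutput(token_ids=logits.argmax(dim=-1), token_logits=logits)

token_ids, token_logits = out  # NamedTuple supports plain tuple unpacking
print(token_ids.shape, token_logits.shape)  # torch.Size([2, 3]) torch.Size([2, 3, 10])
```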

Inferencer

Bases: Protocol

Inferencer processes token sequences and generates new tokens using a specified sampling strategy.

An Inferencer should either be a transformer model itself or delegate token processing to one. That transformer model is expected to support a KV cache to avoid redundant computation during inference.

The Inferencer updates the KVCache internally when processing the query tokens.

decode(query_token_ids, kv_cache, max_new_tokens, sample_strategy)

Process query_token_ids and autoregressively generate new tokens.

Expect kv_cache to contain the key and value tensors for all tokens preceding the query tokens.

kv_cache will be updated internally with the newly computed key and value tensors, i.e. the key and value tensors corresponding to the query tokens.

Expect query_token_ids to contain only the new query tokens since the last call to prefill or decode, i.e., of shape [batch_size, 1].

Stop when max_new_tokens is reached or an EOS token is generated.

Return DecodeOutput, which includes:
- the newly generated token ids. Shape [batch_size, num_generated_tokens].
- the logits used to sample the newly generated tokens. Shape [batch_size, num_generated_tokens, vocab_size].

Parameters:

query_token_ids (Tensor, required): Query token ids of shape [batch_size, 1].

kv_cache (KVCache, required): Contains the past key and value tensors for each transformer layer.

max_new_tokens (int, required): Limit on the number of new tokens to generate; should be positive (> 0).

sample_strategy (SampleStrategy, required): Token sampling strategy during decoding.

Returns:

DecodeOutput: Contains generated new token ids of shape [batch_size, num_generated_tokens] and token logits of shape [batch_size, num_generated_tokens, vocab_size].
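The decode contract above can be sketched as a greedy loop over LanguageModel.forward. Everything here is hypothetical: ToyKVCache, ToyModel and greedy_decode are illustrative stand-ins rather than the package's classes, greedy argmax stands in for a SampleStrategy, and stopping once every row of the batch has produced EOS is one possible batch policy that the docs above do not specify.

```python
from typing import NamedTuple, Sequence

import torch
import torch.nn.functional as F


class KVState(NamedTuple):  # stand-in with the documented fields
    keys: torch.Tensor    # [batch_size, num_attention_heads, num_tokens, head_dim]
    values: torch.Tensor


class ToyKVCache:
    """Hypothetical minimal cache: appends new tokens along dim 2."""

    def __init__(self) -> None:
        self.states: list[KVState] = []

    def update(self, kv_states: Sequence[KVState]) -> None:
        if not self.states:
            self.states = list(kv_states)
        else:
            self.states = [
                KVState(torch.cat([o.keys, n.keys], dim=2), torch.cat([o.values, n.values], dim=2))
                for o, n in zip(self.states, kv_states)
            ]

    def get_num_tokens(self) -> int:
        return self.states[0].keys.shape[2] if self.states else 0


class ToyModel:
    """Hypothetical model whose logits always favour (token + 1) % vocab_size."""

    vocab_size, eos_token_id, num_layers, num_heads, head_dim = 8, 7, 2, 2, 4

    def forward(self, query_token_ids: torch.Tensor, kv_cache: ToyKVCache) -> torch.Tensor:
        batch_size, num_query = query_token_ids.shape
        shape = (batch_size, self.num_heads, num_query, self.head_dim)
        # Append (fake) keys/values for the query tokens, as forward is documented to do.
        kv_cache.update([KVState(torch.zeros(shape), torch.zeros(shape)) for _ in range(self.num_layers)])
        return F.one_hot((query_token_ids + 1) % self.vocab_size, self.vocab_size).float()


def greedy_decode(model, query_token_ids, kv_cache, max_new_tokens):
    """Sketch of Inferencer.decode with greedy sampling (an assumed SampleStrategy)."""
    assert query_token_ids.shape[1] == 1 and max_new_tokens > 0
    finished = torch.zeros(query_token_ids.shape[0], dtype=torch.bool)
    new_ids, new_logits = [], []
    current = query_token_ids
    for _ in range(max_new_tokens):
        logits = model.forward(current, kv_cache)[:, -1, :]  # [batch_size, vocab_size]
        current = logits.argmax(dim=-1, keepdim=True)         # [batch_size, 1]
        new_ids.append(current)
        new_logits.append(logits)
        # Assumed batch policy: stop once every row has emitted EOS at some step.
        # Finished rows keep generating here; a real implementation might mask them.
        finished |= current.squeeze(1) == model.eos_token_id
        if finished.all():
            break
    return torch.cat(new_ids, dim=1), torch.stack(new_logits, dim=1)


model, cache = ToyModel(), ToyKVCache()
prompt = torch.tensor([[0, 1], [3, 4]])
# A single forward over the prompt stands in for prefill.
first = model.forward(prompt, cache)[:, -1, :].argmax(dim=-1, keepdim=True)
generated, step_logits = greedy_decode(model, first, cache, max_new_tokens=10)
print(generated.tolist())      # [[3, 4, 5, 6, 7], [6, 7, 0, 1, 2]]
print(step_logits.shape)       # torch.Size([2, 5, 8])
print(cache.get_num_tokens())  # 7
```

With this loop the cache ends up holding the prompt, the query token and every generated token except the last (2 + 1 + 4 = 7 here). The last generated token has not been fed through the model yet, which is why the next decode call expects it as its single query token.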

prefill(query_token_ids, kv_cache, sample_strategy)

Process query_token_ids in parallel and generate the next token.

Expect kv_cache to contain the key and value tensors for all tokens preceding the query tokens.

kv_cache will be updated internally with the newly computed key and value tensors, i.e. the key and value tensors corresponding to the query tokens.

Return PrefillOutput, which includes:
- the newly generated token ids. Shape [batch_size, 1].
- the token logits at the query token positions. Shape [batch_size, num_query_tokens, vocab_size].

Parameters:

query_token_ids (Tensor, required): Query token ids of shape [batch_size, num_query_tokens].

kv_cache (KVCache, required): Contains the past key and value tensors for each transformer layer.

sample_strategy (SampleStrategy, required): Token sampling strategy for generating new tokens.

Returns:

PrefillOutput: Contains generated new token ids of shape [batch_size, 1] and token logits of shape [batch_size, num_query_tokens, vocab_size].
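A minimal sketch of the prefill return values, assuming greedy sampling and using hand-made logits in place of a real LanguageModel.forward call. Because position i predicts token i+1, only the last query position predicts a token that is not already in the input, so the new token comes from logits[:, -1, :].

```python
import torch
import torch.nn.functional as F

vocab_size = 8
query_token_ids = torch.tensor([[0, 1, 2], [4, 5, 6]])  # [batch_size=2, num_query_tokens=3]

# Hypothetical stand-in for LanguageModel.forward: position i puts all its
# mass on token (query[i] + 1) % vocab_size.
token_logits = F.one_hot((query_token_ids + 1) % vocab_size, vocab_size).float()

# Greedy (assumed) sampling from the last position only.
next_token_ids = token_logits[:, -1, :].argmax(dim=-1, keepdim=True)  # [batch_size, 1]
print(next_token_ids.tolist())  # [[3], [7]]
print(token_logits.shape)       # torch.Size([2, 3, 8])
```

Keeping the logits for every query position, rather than only the last, is presumably what lets a speculative-decoding verifier score all drafted tokens in one pass.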

KVCache

Bases: Protocol

Stores layerwise key and value tensors, which are used by an Inferencer during inference.

crop(num_tokens_crop)

Remove the latest num_tokens_crop tokens from the cache.

num_tokens_crop should be non-negative and not exceed the current number of tokens in the cache.

Parameters:

num_tokens_crop (int, required): Number of latest tokens to crop from the cache.
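A sketch of what crop might do to one layer's tensors, assuming the [batch_size, num_attention_heads, num_tokens, head_dim] layout documented for KVState; crop_tokens is a hypothetical helper. It slices by the number of tokens to keep, because the tempting x[:, :, :-n] returns an empty tensor when n == 0 instead of leaving the cache unchanged.

```python
import torch


def crop_tokens(x: torch.Tensor, num_tokens_crop: int) -> torch.Tensor:
    """Hypothetical helper: drop the latest num_tokens_crop entries along the token axis (dim 2)."""
    num_tokens = x.shape[2]
    if not 0 <= num_tokens_crop <= num_tokens:
        raise ValueError(f"num_tokens_crop must be in [0, {num_tokens}], got {num_tokens_crop}")
    # Slice by the number of tokens to keep: x[:, :, :-0] would be empty, not a no-op.
    return x[:, :, : num_tokens - num_tokens_crop, :]


keys = torch.arange(2 * 1 * 5 * 3).reshape(2, 1, 5, 3)  # [batch, heads, tokens=5, head_dim]
print(crop_tokens(keys, 2).shape)  # torch.Size([2, 1, 3, 3])
print(crop_tokens(keys, 0).shape)  # torch.Size([2, 1, 5, 3])
print(keys[:, :, :-0, :].shape)    # torch.Size([2, 1, 0, 3]), the pitfall
```

In speculative decoding, such a crop would typically be used to discard the cached entries of draft tokens that the target model rejected.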

get_kv_states()

Get the currently stored KV states for all layers.

If the cache was just initialized and holds no KV state (i.e. update has never been called), return an empty list.

Returns:

list[KVState]: A list of KVState. If update has never been called, it is empty. Otherwise, its length is equal to the number of transformer layers.

get_num_tokens()

Get the current number of tokens stored in the cache.

Returns:

int: Current number of tokens in the cache.

update(kv_states)

Update the storage with new key and value tensors.

The length of kv_states should be equal to the number of transformer layers, and each KVState contains the new key and value tensors for the corresponding layer.

All key and value tensors in kv_states should have the same shape.

Parameters:

kv_states (Sequence[KVState], required): New key and value tensors for each transformer layer.
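Putting the four methods together, here is a minimal sketch of a class that satisfies the KVCache protocol by concatenating along the token axis (dim 2). ConcatKVCache, random_states and the redefined KVState are hypothetical stand-ins, not the package's implementation; the checks in update and crop follow the constraints stated above.

```python
from typing import NamedTuple, Sequence

import torch


class KVState(NamedTuple):  # stand-in with the documented fields
    keys: torch.Tensor    # [batch_size, num_attention_heads, num_tokens, head_dim]
    values: torch.Tensor  # same shape as keys


class ConcatKVCache:
    """Hypothetical KVCache sketch: one KVState per layer, grown along dim 2."""

    def __init__(self) -> None:
        self._states: list[KVState] = []

    def update(self, kv_states: Sequence[KVState]) -> None:
        if not self._states:  # first call, e.g. from prefill
            self._states = list(kv_states)
            return
        if len(kv_states) != len(self._states):
            raise ValueError("expected one KVState per transformer layer")
        self._states = [
            KVState(torch.cat([old.keys, new.keys], dim=2), torch.cat([old.values, new.values], dim=2))
            for old, new in zip(self._states, kv_states)
        ]

    def get_kv_states(self) -> list[KVState]:
        return list(self._states)  # empty until update is first called

    def get_num_tokens(self) -> int:
        return self._states[0].keys.shape[2] if self._states else 0

    def crop(self, num_tokens_crop: int) -> None:
        num_tokens = self.get_num_tokens()
        if not 0 <= num_tokens_crop <= num_tokens:
            raise ValueError(f"num_tokens_crop must be in [0, {num_tokens}], got {num_tokens_crop}")
        keep = num_tokens - num_tokens_crop  # slice by tokens kept, so crop(0) is a no-op
        self._states = [KVState(s.keys[:, :, :keep], s.values[:, :, :keep]) for s in self._states]


def random_states(num_layers: int, num_tokens: int) -> list[KVState]:
    shape = (1, 2, num_tokens, 4)  # [batch, heads, tokens, head_dim]
    return [KVState(torch.randn(shape), torch.randn(shape)) for _ in range(num_layers)]


cache = ConcatKVCache()
print(cache.get_kv_states())  # []
cache.update(random_states(num_layers=3, num_tokens=5))  # e.g. a 5-token prompt
cache.update(random_states(num_layers=3, num_tokens=1))  # e.g. one decode step
print(cache.get_num_tokens())  # 6
cache.crop(2)                  # e.g. discard two rejected draft tokens
print(cache.get_num_tokens())  # 4
```

Concatenating on every update reallocates the whole cache each step; preallocating a buffer of maximum length and tracking a fill index is a common alternative when that cost matters.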

KVState

Bases: NamedTuple

Key and value tensors for a single transformer layer.

Both keys and values have shape [batch_size, num_attention_heads, num_tokens, head_dim].

Attributes:

keys (Tensor): The keys tensor.

values (Tensor): The values tensor.
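To illustrate where the [batch_size, num_attention_heads, num_tokens, head_dim] layout comes from, the sketch below splits hypothetical per-token key and value projections of width num_attention_heads * head_dim into heads, the usual multi-head attention reshaping. split_heads, the random projections and the redefined KVState are illustrative assumptions.

```python
from typing import NamedTuple

import torch


class KVState(NamedTuple):  # stand-in with the documented fields
    keys: torch.Tensor
    values: torch.Tensor


batch_size, num_tokens, num_heads, head_dim = 2, 5, 4, 8


def split_heads(x: torch.Tensor) -> torch.Tensor:
    """[batch, tokens, heads * head_dim] -> [batch, heads, tokens, head_dim]."""
    b, t, _ = x.shape
    return x.view(b, t, num_heads, head_dim).transpose(1, 2)


# Hypothetical outputs of one layer's key and value projections.
k_proj = torch.randn(batch_size, num_tokens, num_heads * head_dim)
v_proj = torch.randn(batch_size, num_tokens, num_heads * head_dim)

state = KVState(keys=split_heads(k_proj), values=split_heads(v_proj))
print(state.keys.shape)     # torch.Size([2, 4, 5, 8])
print(state.keys.shape[2])  # 5, the token axis that the cache grows and crops
```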

LanguageModel

Bases: Protocol

LanguageModel executes the forward computation given input token ids and a KV cache.

LanguageModel holds pre-trained weights and a model configuration, including special token ids and the vocabulary size.

eos_token_id property

Id of the end-of-sequence (EOS) token specified in the model's configuration.

The EOS token id is used to check the generation stopping criterion.

Raises:

ValueError: If the model configuration does not have an eos_token_id.
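A minimal sketch of the documented eos_token_id behaviour, assuming a configuration object whose eos_token_id may be missing (None). ToyConfig and ToyLanguageModel are hypothetical names used only for illustration.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyConfig:
    """Hypothetical model configuration; only the field used here is modelled."""

    eos_token_id: Optional[int] = None


class ToyLanguageModel:
    def __init__(self, config: ToyConfig) -> None:
        self.config = config

    @property
    def eos_token_id(self) -> int:
        eos = self.config.eos_token_id
        if eos is None:
            raise ValueError("model configuration does not define eos_token_id")
        return eos


print(ToyLanguageModel(ToyConfig(eos_token_id=2)).eos_token_id)  # 2
try:
    ToyLanguageModel(ToyConfig()).eos_token_id
except ValueError as err:
    print("ValueError:", err)
```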

forward(query_token_ids, kv_cache)

Run the forward pass on query_token_ids with the given kv_cache.

Expect kv_cache to contain the key and value tensors for all tokens preceding the query tokens.

kv_cache will be updated internally with the newly computed key and value tensors, i.e. the key and value tensors corresponding to the query tokens. (Currently, this simplifies the implementation, but it also makes the call not purely functional; further consideration may be needed in the future.)

Return the logits at every query token position, where position i gives the logits for sampling the token at position i+1. The output logits have shape [batch_size, num_query_tokens, vocab_size].

Parameters:

query_token_ids (Tensor, required): Query token ids of shape [batch_size, num_query_tokens].

kv_cache (KVCache, required): Contains the key value tensors of past tokens.

Returns:

torch.Tensor: Logits of shape [batch_size, num_query_tokens, vocab_size].
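Because position i of the forward logits scores the token at position i+1, the probability of an existing sequence can be read out by shifting by one. The sketch below does this with log_softmax and gather; continuation_log_probs is a hypothetical helper, and using it to score drafted tokens during verification is an assumed use case rather than something the docs above describe.

```python
import torch


def continuation_log_probs(logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor:
    """Log-probability of each token_ids[:, i + 1] under logits[:, i].

    logits: [batch_size, num_tokens, vocab_size], from a forward pass over token_ids.
    Returns a [batch_size, num_tokens - 1] tensor.
    """
    # The last position predicts beyond the input, so it is dropped.
    log_probs = torch.log_softmax(logits[:, :-1, :], dim=-1)
    targets = token_ids[:, 1:].unsqueeze(-1)  # [batch_size, num_tokens - 1, 1]
    return log_probs.gather(dim=-1, index=targets).squeeze(-1)


vocab_size = 4
token_ids = torch.tensor([[1, 2, 3]])
logits = torch.zeros(1, 3, vocab_size)  # uniform predictions ...
logits[0, 0, 2] = 10.0                  # ... except position 0 strongly predicts token 2
print(continuation_log_probs(logits, token_ids))  # roughly [[-0.0001, -1.3863]]
```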

PrefillOutput

Bases: NamedTuple

Output of Inferencer.prefill method.

Attributes:

token_ids (Tensor): The newly generated token ids after prefill. Shape [batch_size, 1].

token_logits (Tensor): The logits at the query token positions. Shape [batch_size, num_query_tokens, vocab_size].