naive_speculate.infer
Inference basis support for speculative decoding.
Exports
- LanguageModel: Interface for language models.
- Inferencer: Interface for inference engines.
- PrefillOutput: Data structure for output from prefill operations.
- DecodeOutput: Data structure for output from decoding operations.
- KVCache: Interface for the key-value cache used in transformer models.
- KVState: Data structure for the state of the key-value cache.
DecodeOutput
Bases: NamedTuple
Output of Inferencer.decode method.
Attributes:
| Name | Type | Description |
|---|---|---|
| token_ids | Tensor | The newly generated token ids after decode. Shape [batch_size, num_generated_tokens]. |
| token_logits | Tensor | The logits used to sample the newly generated tokens. Shape [batch_size, num_generated_tokens, vocab_size]. |
Inferencer
Bases: Protocol
Inferencer processes token sequences and generates new tokens using a specified sampling strategy.
Inferencer should either be a transformer model itself or delegate token processing to a transformer model. The transformer model is expected to support a KV cache to avoid redundant computation during inference.
Inferencer updates the KVCache internally when processing the query tokens.
decode(query_token_ids, kv_cache, max_new_tokens, sample_strategy)
Process query_token_ids and autoregressively generate new tokens.
Expect kv_cache to contain the key and value tensors for all tokens preceding
the query tokens.
kv_cache will be updated internally with the newly computed key and value tensors,
i.e. the key and value tensors corresponding to the query tokens.
Expect query_token_ids to contain only the tokens that are new since the last call
to prefill or decode, i.e. a tensor of shape [batch_size, 1].
Stop when max_new_tokens is reached or an EOS token is generated.
Return DecodeOutput, which includes:
- the newly generated token ids. Shape [batch_size, num_generated_tokens].
- the logits used to sample the newly generated tokens.
Shape [batch_size, num_generated_tokens, vocab_size].
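The decode contract above can be sketched as a greedy loop. This is a minimal sketch under stated assumptions, not the library's implementation: ToyCache, ToyModel and greedy_decode are hypothetical names, argmax stands in for the SampleStrategy (whose interface is not documented here), and the batch stopping policy (stop once every sequence has produced EOS) is an assumption.

```python
import torch
import torch.nn.functional as F


class ToyCache:
    """Hypothetical stand-in for KVCache that only counts processed tokens."""

    def __init__(self, num_tokens: int = 0) -> None:
        self.num_tokens = num_tokens


class ToyModel:
    """Hypothetical LanguageModel whose logits always favor (token + 1) % vocab_size."""

    vocab_size = 8
    eos_token_id = 5

    def forward(self, query_token_ids: torch.Tensor, kv_cache: ToyCache) -> torch.Tensor:
        # Mimic the documented side effect: the cache grows by the query length.
        kv_cache.num_tokens += query_token_ids.shape[1]
        next_ids = (query_token_ids + 1) % self.vocab_size
        # Logits of shape [batch_size, num_query_tokens, vocab_size].
        return F.one_hot(next_ids, self.vocab_size).float()


def greedy_decode(model: ToyModel, query_token_ids: torch.Tensor, kv_cache: ToyCache,
                  max_new_tokens: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Greedy sketch of Inferencer.decode; argmax replaces the real SampleStrategy."""
    if max_new_tokens <= 0:
        raise ValueError("max_new_tokens should be positive")
    batch_size = query_token_ids.shape[0]
    finished = torch.zeros(batch_size, 1, dtype=torch.bool)
    new_ids, new_logits = [], []
    query = query_token_ids  # [batch_size, 1]: only the token not yet in the cache
    for _ in range(max_new_tokens):
        logits = model.forward(query, kv_cache)[:, -1, :]  # [batch_size, vocab_size]
        next_token = logits.argmax(dim=-1, keepdim=True)    # [batch_size, 1]
        new_ids.append(next_token)
        new_logits.append(logits)
        finished |= next_token == model.eos_token_id
        if finished.all():  # assumed policy: stop once every sequence has emitted EOS
            break
        query = next_token  # earlier tokens are already represented in kv_cache
    # [batch_size, num_generated_tokens] and [batch_size, num_generated_tokens, vocab_size]
    return torch.cat(new_ids, dim=1), torch.stack(new_logits, dim=1)


# The cache already covers a 3-token prefix; token 1 is the latest (uncached) token.
cache = ToyCache(num_tokens=3)
token_ids, token_logits = greedy_decode(ToyModel(), torch.tensor([[1]]), cache, max_new_tokens=10)
print(token_ids)           # tensor([[2, 3, 4, 5]]), stopped at EOS id 5
print(token_logits.shape)  # torch.Size([1, 4, 8])
```

Note that the last generated token (here 5) has not been forwarded yet, so it is not in the cache; it would be the query of the next decode call.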
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| query_token_ids | Tensor | Query token ids of shape [batch_size, 1]. | required |
| kv_cache | KVCache | Contains the past key and value tensors for each transformer layer. | required |
| max_new_tokens | int | Limit on the number of new tokens to generate; should be positive. | required |
| sample_strategy | SampleStrategy | Token sampling strategy during decoding. | required |
Returns:
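| Name | Type | Description |
|---|---|---|
| DecodeOutput | DecodeOutput | Contains generated new token ids of shape [batch_size, num_generated_tokens] and the logits used to sample them, of shape [batch_size, num_generated_tokens, vocab_size]. |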
prefill(query_token_ids, kv_cache, sample_strategy)
Process query_token_ids in parallel and generate the next token.
Expect kv_cache to contain the key and value tensors for all tokens preceding
the query tokens.
kv_cache will be updated internally with the newly computed key and value tensors,
i.e. the key and value tensors corresponding to the query tokens.
Return PrefillOutput, which includes:
- the generated new token ids. Shape [batch_size, 1].
- the token logits at the query token positions.
Shape [batch_size, num_query_tokens, vocab_size].
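A minimal greedy sketch of the prefill contract, assuming `forward` is any callable that follows LanguageModel.forward's documented signature. greedy_prefill and toy_forward are hypothetical names, and argmax stands in for the SampleStrategy.

```python
import torch
import torch.nn.functional as F


def greedy_prefill(forward, query_token_ids: torch.Tensor,
                   kv_cache) -> tuple[torch.Tensor, torch.Tensor]:
    """Greedy sketch of Inferencer.prefill; argmax replaces the real SampleStrategy."""
    # One parallel pass over all query tokens: [batch_size, num_query_tokens, vocab_size].
    token_logits = forward(query_token_ids, kv_cache)
    # Only the last position predicts the token after the query: [batch_size, 1].
    next_token = token_logits[:, -1, :].argmax(dim=-1, keepdim=True)
    return next_token, token_logits


def toy_forward(query_token_ids: torch.Tensor, kv_cache) -> torch.Tensor:
    """Hypothetical model: position i favors token (2 * id_i) % 7; the cache is ignored."""
    return F.one_hot((query_token_ids * 2) % 7, 7).float()


next_token, token_logits = greedy_prefill(toy_forward, torch.tensor([[1, 2, 3]]), kv_cache=None)
print(next_token)          # tensor([[6]])
print(token_logits.shape)  # torch.Size([1, 3, 7])
```

Returning the logits at every query position, not only the last one, is presumably what lets a model score several draft tokens in a single prefill call during speculative decoding.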
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| query_token_ids | Tensor | Query token ids of shape [batch_size, num_query_tokens]. | required |
| kv_cache | KVCache | Contains the past key and value tensors for each transformer layer. | required |
| sample_strategy | SampleStrategy | Token sampling strategy for generating new tokens. | required |
Returns:
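| Name | Type | Description |
|---|---|---|
| PrefillOutput | PrefillOutput | Contains generated new token ids of shape [batch_size, 1] and the logits at the query token positions, of shape [batch_size, num_query_tokens, vocab_size]. |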
KVCache
Bases: Protocol
Stores layerwise key and value tensors, which are used by an Inferencer during inference.
crop(num_tokens_crop)
Remove the latest num_tokens_crop tokens from the cache.
num_tokens_crop should be non-negative and must not exceed the current number of tokens in the cache.
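For a single layer, cropping amounts to slicing the token axis of the keys and values. A minimal sketch, assuming the KVState layout [batch_size, num_attention_heads, num_tokens, head_dim]; crop_layer is a hypothetical helper, not part of the documented API.

```python
import torch


def crop_layer(keys: torch.Tensor, values: torch.Tensor,
               num_tokens_crop: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Drop the latest num_tokens_crop tokens from one layer's keys and values."""
    num_tokens = keys.shape[2]  # token axis of [batch, heads, num_tokens, head_dim]
    if not 0 <= num_tokens_crop <= num_tokens:
        raise ValueError(f"num_tokens_crop must be in [0, {num_tokens}], got {num_tokens_crop}")
    keep = num_tokens - num_tokens_crop
    # Slice with an explicit end index: keys[:, :, :-0] would keep nothing.
    return keys[:, :, :keep], values[:, :, :keep]


keys = torch.randn(1, 2, 5, 4)
values = torch.randn(1, 2, 5, 4)
cropped_keys, cropped_values = crop_layer(keys, values, num_tokens_crop=2)
print(cropped_keys.shape)  # torch.Size([1, 2, 3, 4])
```

In speculative decoding, a crop like this is a natural way to discard the cache entries of rejected draft tokens.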
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| num_tokens_crop | int | Number of latest tokens to crop from the cache. | required |
get_kv_states()
Get the KV states currently stored for all layers.
If nothing has been stored yet (i.e. update has never been called, for example right after initialization),
return an empty list.
Returns:
| Type | Description |
|---|---|
| list[KVState] | list[KVState]: A list of KVState, one per transformer layer. Empty if update has never been called. |
get_num_tokens()
Get the current number of tokens stored in the cache.
Returns:
| Name | Type | Description |
|---|---|---|
| int | int | Current number of tokens in the cache. |
update(kv_states)
Update the storage with new key and value tensors.
The length of kv_states should be equal to the number of transformer layers,
and each KVState contains the new key and value tensors for the corresponding layer.
All key and value tensors in kv_states should have the same shape.
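One plausible reading of update is that new tensors are appended to each layer's storage along the token axis. The sketch below assumes that append semantics (the text above does not state it outright) and uses a hypothetical list of (keys, values) pairs in place of list[KVState]; append_kv is a hypothetical helper.

```python
import torch

# Hypothetical stand-in for list[KVState]: one (keys, values) pair per layer.
Layers = list[tuple[torch.Tensor, torch.Tensor]]


def append_kv(cache: Layers, kv_states: Layers) -> Layers:
    """Append new per-layer keys/values along the token axis (dim=2)."""
    if not cache:  # first update: nothing stored yet
        return list(kv_states)
    if len(kv_states) != len(cache):
        raise ValueError("kv_states must have one entry per transformer layer")
    return [
        (torch.cat([past_k, new_k], dim=2), torch.cat([past_v, new_v], dim=2))
        for (past_k, past_v), (new_k, new_v) in zip(cache, kv_states)
    ]


prompt = [(torch.zeros(1, 2, 3, 4), torch.zeros(1, 2, 3, 4)) for _ in range(2)]
step = [(torch.ones(1, 2, 1, 4), torch.ones(1, 2, 1, 4)) for _ in range(2)]
cache = append_kv(append_kv([], prompt), step)
print(cache[0][0].shape)  # torch.Size([1, 2, 4, 4])
```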
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| kv_states | Sequence[KVState] | New key and value tensors for each transformer layer. | required |
KVState
Bases: NamedTuple
Key and value tensors for a single transformer layer.
Both keys and values have shape
[batch_size, num_attention_heads, num_tokens, head_dim].
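Because every layer stores keys and values of this shape, the cache size follows directly from the configuration. A minimal worked sketch using a stand-in KVState that mirrors the documented fields; kv_state_bytes, estimate_cache_bytes, and the 32-layer, 32-head, head_dim 128 configuration are hypothetical.

```python
from typing import NamedTuple

import torch


class KVState(NamedTuple):
    """Stand-in mirroring the documented fields."""
    keys: torch.Tensor    # [batch_size, num_attention_heads, num_tokens, head_dim]
    values: torch.Tensor  # same shape as keys


def kv_state_bytes(state: KVState) -> int:
    """Bytes actually held by one layer's keys and values."""
    return sum(t.numel() * t.element_size() for t in state)


def estimate_cache_bytes(num_layers: int, batch_size: int, num_heads: int,
                         num_tokens: int, head_dim: int, bytes_per_elem: int) -> int:
    """Keys plus values (factor 2) for every layer."""
    return 2 * num_layers * batch_size * num_heads * num_tokens * head_dim * bytes_per_elem


small = KVState(torch.zeros(1, 4, 10, 8, dtype=torch.float16),
                torch.zeros(1, 4, 10, 8, dtype=torch.float16))
print(small.keys.shape[2])    # 10 tokens on dim 2
print(kv_state_bytes(small))  # 1280
# Hypothetical config: 32 layers, 32 heads, 1024 tokens, head_dim 128, float16.
print(estimate_cache_bytes(32, 1, 32, 1024, 128, 2) / 2**20)  # 512.0 (MiB)
```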
Attributes:
| Name | Type | Description |
|---|---|---|
| keys | Tensor | The keys tensor. |
| values | Tensor | The values tensor. |
LanguageModel
Bases: Protocol
LanguageModel executes the forward computation given input token ids and a KV cache.
LanguageModel holds pre-trained weights and model configuration such as special token ids and vocabulary size.
eos_token_id (property)
Id of the end-of-sequence (EOS) token specified in the model's configuration.
The EOS token id is used to check the generation stopping criterion.
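A minimal sketch of this property and the stopping check it feeds, assuming the configuration is a plain dict. ToyConfiguredModel and should_stop are hypothetical names; the documented part is only that a missing eos_token_id raises ValueError.

```python
class ToyConfiguredModel:
    """Hypothetical model that keeps its configuration in a plain dict."""

    def __init__(self, config: dict) -> None:
        self.config = config

    @property
    def eos_token_id(self) -> int:
        eos = self.config.get("eos_token_id")
        if eos is None:  # documented behavior: missing EOS id raises ValueError
            raise ValueError("model configuration does not define eos_token_id")
        return eos


def should_stop(last_token_id: int, num_generated: int, max_new_tokens: int,
                model: ToyConfiguredModel) -> bool:
    """Stop when EOS was generated or the token budget is exhausted."""
    return last_token_id == model.eos_token_id or num_generated >= max_new_tokens


model = ToyConfiguredModel({"eos_token_id": 2})
print(should_stop(2, 1, 10, model))  # True: EOS generated
print(should_stop(7, 3, 10, model))  # False: keep generating
```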
Raises:
| Type | Description |
|---|---|
| ValueError | If the model configuration does not have an eos_token_id. |
forward(query_token_ids, kv_cache)
Run a forward pass over query_token_ids using the given kv_cache.
Expect kv_cache to contain the key and value tensors for all tokens preceding
the query tokens.
kv_cache will be updated internally with the newly computed key and value tensors,
i.e. the key and value tensors corresponding to the query tokens.
(Currently, I think this simplifies the implementation, but it also makes the invocation
not purely functional; further consideration may be needed in the future.)
Return the logits at every query token position, where position i gives the logits
for sampling the token at position i+1.
The output logits have shape [batch_size, num_query_tokens, vocab_size].
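The off-by-one alignment matters whenever the logits are used to score a known token sequence. A minimal sketch, assuming `logits` came from forward(token_ids, ...) on the same tokens; next_token_logprobs is a hypothetical helper.

```python
import torch


def next_token_logprobs(logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor:
    """Log-probability of token_ids[:, i + 1] under logits[:, i].

    logits: [batch_size, num_query_tokens, vocab_size]; token_ids: [batch_size, num_query_tokens].
    Returns: [batch_size, num_query_tokens - 1].
    """
    # The last position predicts a token that is not in token_ids, so drop it.
    logprobs = torch.log_softmax(logits[:, :-1, :], dim=-1)
    targets = token_ids[:, 1:].unsqueeze(-1)  # [batch_size, num_query_tokens - 1, 1]
    return logprobs.gather(-1, targets).squeeze(-1)


# Uniform logits over a vocabulary of 4: every token gets log(1/4).
lp = next_token_logprobs(torch.zeros(1, 3, 4), torch.tensor([[0, 1, 2]]))
print(lp)  # tensor([[-1.3863, -1.3863]])
```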
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| query_token_ids | Tensor | Query token ids of shape [batch_size, num_query_tokens]. | required |
| kv_cache | KVCache | Contains the key and value tensors of past tokens. | required |
Returns:
| Type | Description |
|---|---|
| Tensor | torch.Tensor: Logits of shape [batch_size, num_query_tokens, vocab_size]. |
PrefillOutput
Bases: NamedTuple
Output of Inferencer.prefill method.
Attributes:
| Name | Type | Description |
|---|---|---|
| token_ids | Tensor | The newly generated token ids after prefill. Shape [batch_size, 1]. |
| token_logits | Tensor | The logits at the query token positions. Shape [batch_size, num_query_tokens, vocab_size]. |
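Since decode expects a [batch_size, 1] query made of the token produced by the previous call, a natural flow is to feed the prefill token into decode. Under that assumption, the full continuation is the prefill token followed by the decode tokens. A minimal sketch with stand-in NamedTuples mirroring the documented fields; generated_continuation and the tensor values are hypothetical.

```python
from typing import NamedTuple

import torch


class PrefillOutput(NamedTuple):  # stand-in mirroring the documented fields
    token_ids: torch.Tensor     # [batch_size, 1]
    token_logits: torch.Tensor  # [batch_size, num_query_tokens, vocab_size]


class DecodeOutput(NamedTuple):  # stand-in mirroring the documented fields
    token_ids: torch.Tensor     # [batch_size, num_generated_tokens]
    token_logits: torch.Tensor  # [batch_size, num_generated_tokens, vocab_size]


def generated_continuation(prefill_out: PrefillOutput, decode_out: DecodeOutput) -> torch.Tensor:
    """All new tokens: the prefill token (decode's [batch_size, 1] query) comes first."""
    return torch.cat([prefill_out.token_ids, decode_out.token_ids], dim=1)


vocab_size = 10
prefill_out = PrefillOutput(torch.tensor([[7]]), torch.randn(1, 5, vocab_size))
decode_out = DecodeOutput(torch.tensor([[3, 9, 2]]), torch.randn(1, 3, vocab_size))
print(generated_continuation(prefill_out, decode_out))  # tensor([[7, 3, 9, 2]])
```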