naive_speculate.infer.interface.inferencer¶
Define Inferencer, PrefillOutput and DecodeOutput.
DecodeOutput
¶
Bases: NamedTuple
Output of Inferencer.decode method.
Attributes:
| Name | Type | Description |
|---|---|---|
| token_ids | Tensor | The newly generated token ids after decode. Shape [batch_size, num_generated_tokens]. |
| token_logits | Tensor | The logits used to sample the newly generated tokens. Shape [batch_size, num_generated_tokens, vocab_size]. |
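For reference, the field layout above can be sketched as a plain NamedTuple (a minimal illustration only; `Tensor` here is a placeholder for the real tensor type, and plain nested lists stand in for tensor values):

```python
from typing import Any, NamedTuple

# Placeholder for the real tensor type (e.g. torch.Tensor).
Tensor = Any

class DecodeOutput(NamedTuple):
    # Newly generated token ids; shape [batch_size, num_generated_tokens].
    token_ids: Tensor
    # Logits used to sample those tokens;
    # shape [batch_size, num_generated_tokens, vocab_size].
    token_logits: Tensor

# NamedTuple values are accessible both positionally and by name.
out = DecodeOutput(token_ids=[[5, 7]], token_logits=[[[0.1, 0.9], [0.8, 0.2]]])
```

Being a NamedTuple, `DecodeOutput` also unpacks like an ordinary tuple, so `token_ids, token_logits = out` works.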
Inferencer
¶
Bases: Protocol
Inferencer processes token sequences and generates new tokens using specified sampling strategies.
Inferencer should either be itself a transformer model or be able to delegate the token processing to a transformer model. The transformer model is expected to support using KV cache to avoid redundant computations during inference.
Inferencer will update the KVCache internally when processing the query tokens.
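The contract can be sketched as a structural Protocol (a hedged illustration; the signatures follow the method docs below, while `Tensor`, `KVCache`, and `SampleStrategy` are placeholders for the real types):

```python
from typing import Any, Protocol, runtime_checkable

# Stand-ins for the real types used by this interface.
Tensor = Any
KVCache = Any
SampleStrategy = Any

@runtime_checkable
class Inferencer(Protocol):
    """Processes token sequences and generates new tokens."""

    def prefill(self, query_token_ids: Tensor, kv_cache: KVCache,
                sample_strategy: SampleStrategy) -> Any: ...

    def decode(self, query_token_ids: Tensor, kv_cache: KVCache,
               max_new_tokens: int, sample_strategy: SampleStrategy) -> Any: ...

# Any class with matching methods structurally satisfies the protocol,
# whether it is itself a transformer model or merely delegates to one.
class DummyInferencer:
    def prefill(self, query_token_ids, kv_cache, sample_strategy):
        return None
    def decode(self, query_token_ids, kv_cache, max_new_tokens, sample_strategy):
        return None
```

Because the class is a Protocol, implementers never subclass it explicitly; conformance is checked structurally.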
decode(query_token_ids, kv_cache, max_new_tokens, sample_strategy)
¶
Process query_token_ids and autoregressively generate new tokens.
Expect kv_cache to contain the key and value tensors for all tokens preceding
the query tokens.
kv_cache will be updated internally with the newly computed key and value tensors,
i.e. the key and value tensors corresponding to the query tokens.
Expect query_token_ids to contain only the new query tokens
since the last call to prefill or decode, i.e., of shape [batch_size, 1].
Stop when max_new_tokens is reached or an EOS token is generated.
Return DecodeOutput, which includes:
- the newly generated token ids. Shape [batch_size, num_generated_tokens].
- the logits used to sample the newly generated tokens.
Shape [batch_size, num_generated_tokens, vocab_size].
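To make the stopping behavior concrete, here is a toy, tensor-free sketch of the decode loop (assumed semantics only; a real implementation runs batched on the accelerator and updates the KV cache in place, and `next_token` is a hypothetical stub for one model forward pass plus sampling):

```python
def toy_decode(last_token_id, max_new_tokens, eos_id, next_token):
    """Generate tokens one at a time until EOS or max_new_tokens."""
    generated = []
    token = last_token_id
    for _ in range(max_new_tokens):
        token = next_token(token)   # stands in for forward pass + sampling
        generated.append(token)
        if token == eos_id:         # stop as soon as EOS is generated
            break
    return generated

# Stub "model": counts upward and emits EOS (id 0) once it reaches 3.
step = lambda t: 0 if t >= 3 else t + 1
```

Note that the EOS token itself is included in the returned ids, and the loop never produces more than `max_new_tokens` tokens.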
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| query_token_ids | Tensor | Query token ids of shape [batch_size, 1]. | required |
| kv_cache | KVCache | Contains the past key and value tensors for each transformer layer. | required |
| max_new_tokens | int | Limit on the number of new tokens to generate; must be positive. | required |
| sample_strategy | SampleStrategy | Token sampling strategy during decoding. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| DecodeOutput | DecodeOutput | Contains generated new token ids of shape [batch_size, num_generated_tokens] and the corresponding token logits of shape [batch_size, num_generated_tokens, vocab_size]. |
prefill(query_token_ids, kv_cache, sample_strategy)
¶
Process the query_token_ids in parallel and generate the next new token for each sequence.
Expect kv_cache to contain the key and value tensors for all tokens preceding
the query tokens.
kv_cache will be updated internally with the newly computed key and value tensors,
i.e. the key and value tensors corresponding to the query tokens.
Return PrefillOutput, which includes:
- the generated new token ids. Shape [batch_size, 1].
- the token logits at the query token positions.
Shape [batch_size, num_query_tokens, vocab_size].
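A toy sketch of what prefill computes (assumed semantics only; a real implementation processes all query positions in one batched forward pass and writes their keys and values into the cache; `forward` and `sample` are hypothetical stubs):

```python
def toy_prefill(query_token_ids, forward, sample):
    """Compute logits at every query position, sample only from the last.

    forward: stub mapping one token id to a row of logits.
    sample:  stub sampling strategy mapping a logits row to a token id.
    """
    token_logits = [forward(t) for t in query_token_ids]  # all positions
    next_token_id = sample(token_logits[-1])  # new token from last position
    return next_token_id, token_logits

# Stub forward: one-hot "logits" peaked at t + 1; stub sample: greedy argmax.
forward = lambda t: [1.0 if i == t + 1 else 0.0 for i in range(8)]
sample = lambda row: max(range(len(row)), key=row.__getitem__)
next_id, logits = toy_prefill([1, 2, 3], forward, sample)
```

The key difference from decode: prefill returns logits for every query position (useful, e.g., for verifying speculated tokens), but samples only one new token per sequence.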
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| query_token_ids | Tensor | Query token ids of shape [batch_size, num_query_tokens]. | required |
| kv_cache | KVCache | Contains the past key and value tensors for each transformer layer. | required |
| sample_strategy | SampleStrategy | Token sampling strategy for generating new tokens. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| PrefillOutput | PrefillOutput | Contains generated new token ids of shape [batch_size, 1] and the token logits at the query token positions of shape [batch_size, num_query_tokens, vocab_size]. |
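Putting the two methods together, a typical generation flow is prefill once over the prompt, then decode autoregressively. A tensor-free sketch under toy assumptions (the "cache" is just the token history, and `step` is a hypothetical stand-in for one forward pass plus sampling):

```python
def generate(prompt_ids, max_new_tokens, eos_id, step):
    """Prefill over the prompt, then decode one token at a time."""
    kv_cache = list(prompt_ids)          # toy cache: just the token history
    first = step(kv_cache[-1])           # prefill yields the first new token
    kv_cache.append(first)
    generated = [first]
    while len(generated) < max_new_tokens and generated[-1] != eos_id:
        nxt = step(kv_cache[-1])         # decode continues autoregressively
        kv_cache.append(nxt)
        generated.append(nxt)
    return generated

out_ids = generate([1, 2], max_new_tokens=4, eos_id=9, step=lambda t: t + 1)
```

In the real interface the cache holds per-layer key/value tensors rather than token ids, and decode runs the whole loop internally rather than one call per token.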
PrefillOutput
¶
Bases: NamedTuple
Output of Inferencer.prefill method.
Attributes:
| Name | Type | Description |
|---|---|---|
| token_ids | Tensor | The newly generated token ids after prefill. Shape [batch_size, 1]. |
| token_logits | Tensor | The logits at the query token positions. Shape [batch_size, num_query_tokens, vocab_size]. |