
naive_speculate.infer.interface.inferencer

Define Inferencer, PrefillOutput and DecodeOutput.

DecodeOutput

Bases: NamedTuple

Output of Inferencer.decode method.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `token_ids` | `Tensor` | The newly generated token ids after decode. Shape `[batch_size, num_generated_tokens]`. |
| `token_logits` | `Tensor` | The logits used to sample the newly generated tokens. Shape `[batch_size, num_generated_tokens, vocab_size]`. |
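As a `NamedTuple`, `DecodeOutput` might be declared along these lines (a sketch, not the actual source; `typing.Any` stands in for `torch.Tensor` so the snippet runs without torch):

```python
from typing import Any, NamedTuple

Tensor = Any  # stand-in for torch.Tensor in this sketch


class DecodeOutput(NamedTuple):
    """Output of Inferencer.decode."""

    token_ids: Tensor  # shape [batch_size, num_generated_tokens]
    token_logits: Tensor  # shape [batch_size, num_generated_tokens, vocab_size]
```

Being a named tuple, it supports both attribute access (`out.token_ids`) and plain tuple unpacking (`ids, logits = out`).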

Inferencer

Bases: Protocol

An Inferencer processes token sequences and generates new tokens using a specified sampling strategy.

Inferencer should either be itself a transformer model or be able to delegate the token processing to a transformer model. The transformer model is expected to support using KV cache to avoid redundant computations during inference.

Inferencer will update the KVCache internally when processing the query tokens.
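Since Inferencer is a Protocol, any object with conforming prefill and decode methods satisfies it structurally. A hedged sketch of what the protocol could look like, with `Any` standing in for `Tensor`, `KVCache`, and `SampleStrategy` (signatures inferred from the method docs below, not copied from source):

```python
from typing import Any, NamedTuple, Protocol

# Stand-ins for the real types, to keep this sketch self-contained.
Tensor = Any
KVCache = Any
SampleStrategy = Any


class PrefillOutput(NamedTuple):
    token_ids: Tensor  # [batch_size, 1]
    token_logits: Tensor  # [batch_size, num_query_tokens, vocab_size]


class DecodeOutput(NamedTuple):
    token_ids: Tensor  # [batch_size, num_generated_tokens]
    token_logits: Tensor  # [batch_size, num_generated_tokens, vocab_size]


class Inferencer(Protocol):
    def prefill(
        self,
        query_token_ids: Tensor,
        kv_cache: KVCache,
        sample_strategy: SampleStrategy,
    ) -> PrefillOutput: ...

    def decode(
        self,
        query_token_ids: Tensor,
        kv_cache: KVCache,
        max_new_tokens: int,
        sample_strategy: SampleStrategy,
    ) -> DecodeOutput: ...
```

Any class implementing both methods type-checks as an Inferencer without inheriting from it, which is the point of using a Protocol here.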

decode(query_token_ids, kv_cache, max_new_tokens, sample_strategy)

Process query_token_ids and autoregressively generate new tokens.

Expect kv_cache to contain the key and value tensors for all tokens preceding the query tokens.

kv_cache will be updated internally with the newly computed key and value tensors, i.e. the key and value tensors corresponding to the query tokens.

Expect query_token_ids to contain only the new query tokens since the last call to prefill or decode, i.e., of shape [batch_size, 1].

Stop when max_new_tokens is reached or an EOS token is generated.

Return DecodeOutput, which includes:

- the newly generated token ids, of shape [batch_size, num_generated_tokens].
- the logits used to sample the newly generated tokens, of shape [batch_size, num_generated_tokens, vocab_size].

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `query_token_ids` | `Tensor` | Query token ids of shape `[batch_size, 1]`. | required |
| `kv_cache` | `KVCache` | Contains the past key and value tensors for each transformer layer. | required |
| `max_new_tokens` | `int` | Limit on the number of new tokens to generate; must be positive (> 0). | required |
| `sample_strategy` | `SampleStrategy` | Token sampling strategy during decoding. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `DecodeOutput` | `DecodeOutput` | Contains the generated new token ids of shape `[batch_size, num_generated_tokens]` and the token logits of shape `[batch_size, num_generated_tokens, vocab_size]`. |
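The decode contract (query is only the single newest token, the loop stops at EOS or after max_new_tokens, the cache is updated in place) can be illustrated with a dependency-free toy. Everything here is invented for the sketch: `EOS_ID`, the `toy_logits` "model" (next token = current + 1, wrapping to EOS), batch_size of 1, plain lists instead of tensors, and the cache modeled as a list of processed token ids:

```python
EOS_ID = 0

def toy_logits(token_id, vocab_size=4):
    # Toy "model": put all probability mass on (token_id + 1) % vocab_size.
    logits = [0.0] * vocab_size
    logits[(token_id + 1) % vocab_size] = 1.0
    return logits


def decode(query_token_ids, kv_cache, max_new_tokens):
    """Greedy toy decode honoring the documented contract: the query is the
    single newest token, generation stops on EOS or after max_new_tokens,
    and kv_cache grows by one entry per processed token."""
    assert max_new_tokens > 0
    token_ids, token_logits = [], []
    current = query_token_ids[-1]
    for _ in range(max_new_tokens):
        kv_cache.append(current)  # "compute" keys/values for the query token
        logits = toy_logits(current)
        nxt = max(range(len(logits)), key=logits.__getitem__)  # greedy sample
        token_ids.append(nxt)
        token_logits.append(logits)
        if nxt == EOS_ID:
            break
        current = nxt
    return token_ids, token_logits
```

With this toy model, `decode([2], cache, 5)` generates token 3, then token 0 (EOS) and stops, leaving `cache == [2, 3]`.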

prefill(query_token_ids, kv_cache, sample_strategy)

Process query_token_ids in parallel and generate the next new token for each sequence.

Expect kv_cache to contain the key and value tensors for all tokens preceding the query tokens.

kv_cache will be updated internally with the newly computed key and value tensors, i.e. the key and value tensors corresponding to the query tokens.

Return PrefillOutput, which includes:

- the generated new token ids, of shape [batch_size, 1].
- the token logits at the query token positions, of shape [batch_size, num_query_tokens, vocab_size].

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `query_token_ids` | `Tensor` | Query token ids of shape `[batch_size, num_query_tokens]`. | required |
| `kv_cache` | `KVCache` | Contains the past key and value tensors for each transformer layer. | required |
| `sample_strategy` | `SampleStrategy` | Token sampling strategy for generating new tokens. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `PrefillOutput` | `PrefillOutput` | Contains the generated new token ids of shape `[batch_size, 1]` and the token logits of shape `[batch_size, num_query_tokens, vocab_size]`. |
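Putting prefill and decode together gives the usual two-phase generation loop: prefill the whole prompt in parallel once, then decode autoregressively from the single token prefill produced. A hedged, dependency-free sketch with a toy `MockInferencer` (batch_size 1, lists instead of tensors, a fixed next-token rule instead of a transformer, and greedy sampling regardless of sample_strategy):

```python
class MockInferencer:
    """Toy Inferencer: the "model" always predicts (current_token + 1) % VOCAB,
    token 0 is EOS, and the kv_cache is a plain list of processed token ids."""

    VOCAB = 4
    EOS = 0

    def _logits(self, tok):
        row = [0.0] * self.VOCAB
        row[(tok + 1) % self.VOCAB] = 1.0
        return row

    def prefill(self, query_token_ids, kv_cache, sample_strategy=None):
        kv_cache.extend(query_token_ids)  # cache keys/values for the whole prompt
        logits = [self._logits(t) for t in query_token_ids]
        # Sample the next token from the last query position (greedy here).
        next_tok = max(range(self.VOCAB), key=logits[-1].__getitem__)
        return [next_tok], logits

    def decode(self, query_token_ids, kv_cache, max_new_tokens, sample_strategy=None):
        out_ids, out_logits = [], []
        cur = query_token_ids[-1]  # the single newest token
        for _ in range(max_new_tokens):
            kv_cache.append(cur)  # cache keys/values for the query token
            logits = self._logits(cur)
            cur = max(range(self.VOCAB), key=logits.__getitem__)
            out_ids.append(cur)
            out_logits.append(logits)
            if cur == self.EOS:
                break
        return out_ids, out_logits


# Two-phase generation: prefill the prompt once, then decode from the
# single new token that prefill produced.
inferencer = MockInferencer()
kv_cache = []
first_ids, _ = inferencer.prefill([1, 2], kv_cache)
new_ids, _ = inferencer.decode(first_ids, kv_cache, max_new_tokens=8)
```

Note how decode receives only `first_ids`, not the prompt: the prompt's keys and values already live in `kv_cache` from the prefill call, which is exactly the precondition the method docs state.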

PrefillOutput

Bases: NamedTuple

Output of Inferencer.prefill method.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `token_ids` | `Tensor` | The newly generated token ids after prefill. Shape `[batch_size, 1]`. |
| `token_logits` | `Tensor` | The logits at the query token positions. Shape `[batch_size, num_query_tokens, vocab_size]`. |
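A sketch of the declaration, mirroring DecodeOutput (again with `typing.Any` standing in for `torch.Tensor`; note the shape asymmetry: one sampled token per sequence, but logits for every query position):

```python
from typing import Any, NamedTuple

Tensor = Any  # stand-in for torch.Tensor in this sketch


class PrefillOutput(NamedTuple):
    """Output of Inferencer.prefill."""

    token_ids: Tensor  # shape [batch_size, 1]
    token_logits: Tensor  # shape [batch_size, num_query_tokens, vocab_size]
```

The extra logits positions are what make prefill useful for verification-style workloads (e.g. scoring draft tokens), since they cover every query token, not just the last one.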