naive_speculate.infer.interface.lm¶
Define LanguageModel.
LanguageModel
¶
Bases: Protocol
A LanguageModel executes forward computation given input token ids and a KV cache.
A LanguageModel holds pre-trained weights and model configuration such as special token ids, vocabulary size, etc.
eos_token_id
property
¶
Id of the end-of-sequence (EOS) token specified in the model's configuration.
The EOS token id is used to check the generation stopping criterion.
Raises:

| Type | Description |
|---|---|
| ValueError | If the model configuration does not have an eos_token_id. |
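A hedged sketch of how a conforming implementation might satisfy this contract; `_config` is a hypothetical stand-in for the model's configuration object:

```python
class ConfiguredModel:
    """Illustrative only: raises ValueError when the config lacks an eos_token_id."""

    def __init__(self, config: dict):
        self._config = config

    @property
    def eos_token_id(self) -> int:
        # Mirror the documented behavior: a missing eos_token_id is an error,
        # since generation could otherwise never detect the stopping criterion.
        eos = self._config.get("eos_token_id")
        if eos is None:
            raise ValueError("model configuration does not define eos_token_id")
        return eos
```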
forward(query_token_ids, kv_cache)
¶
Forward the query_token_ids with the given kv_cache.
Expects kv_cache to contain the key and value tensors for all tokens preceding
the query tokens.
kv_cache is updated in place with the newly computed key and value tensors,
i.e. the key and value tensors corresponding to the query tokens.
(Currently this simplifies the implementation, but it also makes the invocation
not purely functional; further consideration may be needed in the future.)
Returns the logits at every query token position, where position i gives the logits
for sampling the token at position i+1.
The output logits have shape [batch_size, num_query_tokens, vocab_size].
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| query_token_ids | Tensor | Query token ids of shape | required |
| kv_cache | KVCache | Contains the key and value tensors of past tokens. | required |
Returns:

| Type | Description |
|---|---|
| Tensor | Logits of shape |
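To make the forward contract concrete without pulling in torch, here is a dependency-free sketch: plain nested lists stand in for tensors, a list-backed class stands in for KVCache, and the zero logits are placeholders where a real model would compute values from its weights and the cache. The in-place cache update mirrors the behavior described above; all names are hypothetical:

```python
class ListKVCache:
    """Toy stand-in for KVCache: one (key, value) pair per cached token."""

    def __init__(self):
        self.entries = []  # grows as forward() processes query tokens


class ToyModel:
    """Toy stand-in for a LanguageModel implementation."""

    def __init__(self, vocab_size: int):
        self.vocab_size = vocab_size

    def forward(self, query_token_ids, kv_cache):
        # Update the cache in place with entries for the query tokens,
        # as the docstring above specifies.
        for seq in query_token_ids:
            for tok in seq:
                kv_cache.entries.append((tok, tok))
        # Logits of shape [batch_size, num_query_tokens, vocab_size];
        # logits[i][j] are used to sample token j+1 of sequence i.
        return [[[0.0] * self.vocab_size for _ in seq] for seq in query_token_ids]


model = ToyModel(vocab_size=5)
cache = ListKVCache()
logits = model.forward([[1, 2, 3]], cache)  # batch of one 3-token query
```

After the call, `cache` holds one entry per query token, so a follow-up `forward` on the next query tokens sees the full prefix, which is the incremental-decoding pattern the protocol is built around.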