naive_speculate.testing.infer.lm.transformer
Provides `Transformer`, a simple implementation of a decoder-only transformer model.
Transformer
Bases: `torch.nn.Module`
A simple implementation of a decoder-only transformer model, consisting of an embedding layer followed by a stack of transformer blocks.
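The class description implies a three-stage pipeline: embed the token IDs, run them through a stack of blocks, and map the result to vocabulary logits. The final projection is an assumption here, inferred from the `vocab_size` dimension of the returned logits. The following is a minimal sketch under those assumptions, not the `naive_speculate` implementation: `ToyTransformer` is a hypothetical name, each block is a stock `torch.nn.TransformerEncoderLayer` made decoder-only by an additive causal mask, and the KV cache and positional encoding are omitted for brevity.

```python
import torch
import torch.nn as nn


class ToyTransformer(nn.Module):
    """Hypothetical decoder-only model: embedding -> blocks -> LM head.

    Illustration only, not the naive_speculate code: no KV cache and no
    positional encoding.
    """

    def __init__(self, vocab_size: int, d_model: int, n_heads: int, n_layers: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        # A standard self-attention layer becomes a decoder-only block once it
        # is given a causal mask in forward().
        self.blocks = nn.ModuleList(
            nn.TransformerEncoderLayer(
                d_model, n_heads, dim_feedforward=4 * d_model,
                dropout=0.0, batch_first=True,
            )
            for _ in range(n_layers)
        )
        self.norm = nn.LayerNorm(d_model)
        # Assumed output projection from hidden states to vocabulary logits.
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        seq_len = token_ids.shape[1]
        # Additive causal mask: -inf above the diagonal hides future positions.
        mask = torch.triu(
            torch.full((seq_len, seq_len), float("-inf"), device=token_ids.device),
            diagonal=1,
        )
        h = self.embed(token_ids)  # (batch_size, seq_len, d_model)
        for block in self.blocks:
            h = block(h, src_mask=mask)
        return self.lm_head(self.norm(h))  # (batch_size, seq_len, vocab_size)
```

Because of the causal mask, changing a later token leaves the logits at earlier positions unchanged, which is the property a KV cache relies on.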
forward(query_token_ids, kv_cache)
Runs the input token IDs through the transformer model.
As a side effect, `kv_cache` is updated with the newly computed key/value states.
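To see why the cache can be updated as a side effect without changing the result, consider a single attention head. This sketch is an illustration only: `SimpleKVCache` and `cached_attention` are hypothetical names, not the library's `KVCache` API. Each call appends the new keys and values to the cache and lets every query attend to all cached positions up to its own absolute position, so processing a sequence in chunks gives the same output as processing it in one pass.

```python
import torch
import torch.nn.functional as F


class SimpleKVCache:
    """Hypothetical per-layer cache; NOT the library's KVCache API."""

    def __init__(self, num_layers: int):
        self.keys = [None] * num_layers
        self.values = [None] * num_layers

    def update(self, layer: int, k: torch.Tensor, v: torch.Tensor):
        # Append new states along the sequence dimension (dim=1) and
        # return the full key/value history for this layer.
        if self.keys[layer] is None:
            self.keys[layer], self.values[layer] = k, v
        else:
            self.keys[layer] = torch.cat([self.keys[layer], k], dim=1)
            self.values[layer] = torch.cat([self.values[layer], v], dim=1)
        return self.keys[layer], self.values[layer]


def cached_attention(q, k_new, v_new, cache: SimpleKVCache, layer: int = 0):
    """Single-head causal attention; q, k_new, v_new: (batch, new_len, dim)."""
    k, v = cache.update(layer, k_new, v_new)  # side effect: cache grows
    new_len, total_len = q.shape[1], k.shape[1]
    past_len = total_len - new_len
    # Query i sits at absolute position past_len + i and may attend to
    # keys 0 .. past_len + i.
    q_pos = torch.arange(new_len).unsqueeze(1) + past_len
    k_pos = torch.arange(total_len).unsqueeze(0)
    allowed = k_pos <= q_pos  # (new_len, total_len), broadcast over batch
    scores = q @ k.transpose(1, 2) / k.shape[-1] ** 0.5
    scores = scores.masked_fill(~allowed, float("-inf"))
    return F.softmax(scores, dim=-1) @ v
```

The key detail is the offset `past_len`: queries in a new chunk are shifted by the number of already cached tokens, so the causal mask lines up with the one a full-sequence pass would use.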
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `query_token_ids` | `Tensor` | Input token IDs of shape `(batch_size, seq_len)`. | required |
| `kv_cache` | `KVCache` | Cache of past key and value tensors. | required |
Returns:
| Type | Description |
|---|---|
| `Tensor` | Output logits of shape `(batch_size, seq_len, vocab_size)`. |
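Because `forward` returns logits for every query position and keeps past key/value states in `kv_cache`, a caller would typically run the whole prompt once (prefill) and then feed back one token per step. The sketch below shows greedy decoding against that interface under this assumption. `greedy_generate` and `CountingStub` are hypothetical; the stub stands in for a real `Transformer` and uses a plain list as its "cache" purely to record how many tokens each call received.

```python
import torch


def greedy_generate(model, prompt_ids: torch.Tensor, kv_cache, max_new_tokens: int):
    """Greedy decoding for any callable with a forward(query_token_ids, kv_cache)
    interface that returns logits of shape (batch_size, seq_len, vocab_size)."""
    logits = model(prompt_ids, kv_cache)  # prefill: whole prompt in one call
    generated = []
    for step in range(max_new_tokens):
        # Only the last position's logits predict the next token.
        next_ids = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # (batch_size, 1)
        generated.append(next_ids)
        if step + 1 < max_new_tokens:
            # Past tokens are in the cache, so feed only the new token.
            logits = model(next_ids, kv_cache)
    return torch.cat(generated, dim=1)  # (batch_size, max_new_tokens)


class CountingStub:
    """Hypothetical stand-in model: always predicts (token + 1) % vocab_size."""

    def __init__(self, vocab_size: int):
        self.vocab_size = vocab_size

    def __call__(self, query_token_ids: torch.Tensor, kv_cache: list) -> torch.Tensor:
        kv_cache.append(query_token_ids.shape[1])  # record tokens fed this call
        batch_size, seq_len = query_token_ids.shape
        logits = torch.zeros(batch_size, seq_len, self.vocab_size)
        targets = (query_token_ids + 1) % self.vocab_size
        logits.scatter_(2, targets.unsqueeze(-1), 1.0)  # one-hot "logits"
        return logits
```

With the stub, `greedy_generate(CountingStub(10), torch.tensor([[1, 2, 3]]), [], 3)` returns `tensor([[4, 5, 6]])`, and the cache list ends up as `[3, 1, 1]`: the three prompt tokens in one call, then one token per decoding step.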