
naive_speculate.testing.infer.lm.transformer

Provides Transformer, a simple implementation of a decoder-only transformer model.

Transformer

Bases: Module

A simple implementation of a transformer model, consisting of an embedding layer and multiple transformer blocks.

forward(query_token_ids, kv_cache)

Forward the input token ids through the transformer model.

As a side effect, kv_cache is updated in place with the newly computed key and value states.
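The side effect matters because it lets later calls pass only new tokens while still attending to every earlier position. The sketch below illustrates the idea for a single attention head in one layer, using plain Python lists instead of tensors. ToyKVCache, cached_attention, dot and softmax are hypothetical names for illustration only; they are not the library's KVCache API.

```python
import math


class ToyKVCache:
    """Hypothetical stand-in for KVCache: past keys/values for one layer and one head."""

    def __init__(self):
        self.keys = []    # one key vector per past position
        self.values = []  # one value vector per past position

    def update(self, new_keys, new_values):
        # Append the newly computed states in place (the documented side effect).
        self.keys.extend(new_keys)
        self.values.extend(new_values)
        return self.keys, self.values


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def cached_attention(queries, keys, values, cache):
    """Causal attention for the new positions only, reusing cached keys/values."""
    past_len = len(cache.keys)
    all_keys, all_values = cache.update(keys, values)
    scale = 1.0 / math.sqrt(len(queries[0]))
    dim = len(all_values[0])
    outputs = []
    for i, q in enumerate(queries):
        # Causal mask: the new query at absolute position past_len + i
        # may attend to positions 0 .. past_len + i.
        visible = past_len + i + 1
        weights = softmax([dot(q, k) * scale for k in all_keys[:visible]])
        outputs.append(
            [sum(w * v[d] for w, v in zip(weights, all_values[:visible])) for d in range(dim)]
        )
    return outputs
```

Processing three positions in one call, or two positions followed by one more, yields the same output for the last position. The second call only needs the new token because the cache already holds the earlier keys and values.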

Parameters:

Name             Type     Description                                       Default
query_token_ids  Tensor   Input token IDs of shape (batch_size, seq_len).   required
kv_cache         KVCache  Cache of past key and value tensors.              required

Returns:

Type    Description
Tensor  Output logits of shape (batch_size, seq_len, vocab_size).
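A common way to use these logits is greedy decoding. The prompt is processed once (prefill). After that, each step takes the logits at the last position, picks the argmax token, and feeds back only that token, since kv_cache already holds the earlier states. The sketch below assumes only the forward(query_token_ids, kv_cache) calling convention documented above. StubTransformer and greedy_decode are hypothetical, nested lists stand in for tensors, and a list of per-row token lists stands in for KVCache.

```python
class StubTransformer:
    """Hypothetical stand-in with the same calling convention as Transformer.forward."""

    def __init__(self, vocab_size):
        self.vocab_size = vocab_size

    def forward(self, query_token_ids, kv_cache):
        logits = []  # shape (batch_size, seq_len, vocab_size)
        for b, row in enumerate(query_token_ids):
            kv_cache[b].extend(row)  # in-place update, mimicking the documented side effect
            # Toy rule: the most likely next token is (current token + 1) mod vocab_size.
            logits.append(
                [[1.0 if v == (tok + 1) % self.vocab_size else 0.0
                  for v in range(self.vocab_size)] for tok in row]
            )
        return logits


def greedy_decode(model, prompt_ids, kv_cache, max_new_tokens):
    """Prefill the whole prompt once, then feed one new token per step."""
    generated = [list(row) for row in prompt_ids]
    query = prompt_ids  # prefill input, shape (batch_size, prompt_len)
    for _ in range(max_new_tokens):
        logits = model.forward(query, kv_cache)
        # Only the logits at the last position predict the next token.
        next_ids = [max(range(len(row[-1])), key=row[-1].__getitem__) for row in logits]
        for row, tok in zip(generated, next_ids):
            row.append(tok)
        # Earlier positions are already cached, so pass only the new tokens: (batch_size, 1).
        query = [[tok] for tok in next_ids]
    return generated
```

Note that the final generated token has not yet been fed through the model, so the cache covers prompt_len + max_new_tokens - 1 positions when the loop ends.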