
naive_speculate.infer.interface.kvcache

Define KVCache and KVState.

KVCache

Bases: Protocol

Stores layer-wise key and value tensors for use by an Inferencer during inference.
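The members documented on this page can be collected into a Python `typing.Protocol`. This is a sketch reconstructed from the documented signatures, not the module's actual source. Three things are assumptions: that `Tensor` is `torch.Tensor`, that `crop` and `update` return `None` (the page does not say), and the `@runtime_checkable` decorator, which is added only so `isinstance` checks work.

```python
from __future__ import annotations

from collections.abc import Sequence
from typing import TYPE_CHECKING, NamedTuple, Protocol, runtime_checkable

if TYPE_CHECKING:  # imported only for type checkers; torch is not needed at runtime here
    from torch import Tensor  # assumption: Tensor refers to torch.Tensor


class KVState(NamedTuple):
    """Key and value tensors for one layer, shape [batch, heads, num_tokens, head_dim]."""

    keys: Tensor
    values: Tensor


@runtime_checkable  # assumption: added for isinstance() support; may differ from the real module
class KVCache(Protocol):
    def crop(self, num_tokens_crop: int) -> None: ...  # return type assumed

    def get_kv_states(self) -> list[KVState]: ...

    def get_num_tokens(self) -> int: ...

    def update(self, kv_states: Sequence[KVState]) -> None: ...  # return type assumed
```

Because `Protocol` uses structural typing, any class that defines these four methods satisfies `KVCache` without inheriting from it.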

crop(num_tokens_crop)

Crop the latest num_tokens_crop tokens from the cache.

num_tokens_crop should be non-negative and not exceed the current number of tokens in the cache.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| num_tokens_crop | int | Number of latest tokens to crop from the cache. | required |
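Cropping a single layer amounts to slicing the token axis (dim 2 of `[batch_size, num_attention_heads, num_tokens, head_dim]`). The helper below is a hypothetical sketch, assuming `Tensor` is `torch.Tensor`; `crop_kv_state` is not part of the documented interface.

```python
from typing import NamedTuple

import torch


class KVState(NamedTuple):
    keys: torch.Tensor    # [batch_size, num_attention_heads, num_tokens, head_dim]
    values: torch.Tensor  # same shape as keys


def crop_kv_state(state: KVState, num_tokens_crop: int) -> KVState:
    """Hypothetical helper: drop the latest `num_tokens_crop` tokens of one layer."""
    num_tokens = state.keys.shape[2]
    # Enforce the documented precondition: 0 <= num_tokens_crop <= num_tokens.
    if not 0 <= num_tokens_crop <= num_tokens:
        raise ValueError(
            f"num_tokens_crop must be in [0, {num_tokens}], got {num_tokens_crop}"
        )
    keep = num_tokens - num_tokens_crop
    # Keep the first `keep` tokens; slicing returns views, not copies.
    return KVState(state.keys[:, :, :keep, :], state.values[:, :, :keep, :])
```

Computing `keep` explicitly avoids a slicing pitfall: `x[:, :, :-0]` selects zero tokens rather than all of them, so slicing with `:-num_tokens_crop` would be wrong when `num_tokens_crop` is 0.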

get_kv_states()

Get the current stored KV states for all layers.

If the cache has just been initialized and no KV state has been stored yet (i.e., update has never been called), an empty list is returned.

Returns:

| Type | Description |
|---|---|
| list[KVState] | A list of KVState. Empty if update has never been called; otherwise its length equals the number of transformer layers. |

get_num_tokens()

Get the current number of tokens stored in the cache.

Returns:

| Type | Description |
|---|---|
| int | Current number of tokens in the cache. |
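Since every layer stores the same number of tokens, an implementation could derive this count from the token axis of the first layer's keys. The function below is a hypothetical sketch; it assumes `Tensor` is `torch.Tensor` and that an empty cache (update never called) reports 0 tokens, which the page implies but does not state.

```python
from typing import NamedTuple

import torch


class KVState(NamedTuple):
    keys: torch.Tensor    # [batch_size, num_attention_heads, num_tokens, head_dim]
    values: torch.Tensor


def num_tokens_from_states(kv_states: list[KVState]) -> int:
    """Hypothetical helper: token count implied by a list of per-layer KV states."""
    # An empty list means update() has never been called, so nothing is cached.
    if not kv_states:
        return 0
    # All layers hold the same number of tokens; dim 2 is the token axis.
    return kv_states[0].keys.shape[2]
```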

update(kv_states)

Update the storage with new key and value tensors.

The length of kv_states should be equal to the number of transformer layers, and each KVState contains the new key and value tensors for the corresponding layer.

All key and value tensors in kv_states should have the same shape.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| kv_states | Sequence[KVState] | New key and value tensors for each transformer layer. | required |
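Putting the four methods together, a minimal implementation might append each update along the token axis with `torch.cat`. `ConcatKVCache` is a hypothetical name, and appending on dim 2 is an assumption about what "update the storage" means; the sketch only illustrates the contract described above (per-layer states, equal shapes, an empty list before the first update).

```python
from collections.abc import Sequence
from typing import NamedTuple

import torch


class KVState(NamedTuple):
    keys: torch.Tensor    # [batch_size, num_attention_heads, num_tokens, head_dim]
    values: torch.Tensor  # same shape as keys


class ConcatKVCache:
    """Hypothetical KVCache that appends new tokens along the token axis (dim 2)."""

    def __init__(self) -> None:
        self._states: list[KVState] = []  # one KVState per layer once populated

    def update(self, kv_states: Sequence[KVState]) -> None:
        if not kv_states:
            raise ValueError("expected one KVState per transformer layer")
        # Documented precondition: all key and value tensors share one shape.
        shape = kv_states[0].keys.shape
        if any(t.shape != shape for state in kv_states for t in state):
            raise ValueError("all key and value tensors must have the same shape")
        if not self._states:
            self._states = list(kv_states)  # first call: store as-is
            return
        if len(kv_states) != len(self._states):
            raise ValueError("expected one KVState per transformer layer")
        self._states = [
            KVState(
                torch.cat([old.keys, new.keys], dim=2),
                torch.cat([old.values, new.values], dim=2),
            )
            for old, new in zip(self._states, kv_states)
        ]

    def get_kv_states(self) -> list[KVState]:
        return list(self._states)  # empty until update() is first called

    def get_num_tokens(self) -> int:
        return self._states[0].keys.shape[2] if self._states else 0

    def crop(self, num_tokens_crop: int) -> None:
        num_tokens = self.get_num_tokens()
        if not 0 <= num_tokens_crop <= num_tokens:
            raise ValueError(f"num_tokens_crop must be in [0, {num_tokens}]")
        keep = num_tokens - num_tokens_crop
        self._states = [
            KVState(s.keys[:, :, :keep, :], s.values[:, :, :keep, :])
            for s in self._states
        ]
```

Because `torch.cat` allocates a new tensor on every update, this sketch favors simplicity over speed; an implementation tuned for decoding might preallocate a buffer and track a length instead.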

KVState

Bases: NamedTuple

Key and value tensors for a single transformer layer.

Both keys and values have shape [batch_size, num_attention_heads, num_tokens, head_dim].

Attributes:

| Name | Type | Description |
|---|---|---|
| keys | Tensor | The keys tensor. |
| values | Tensor | The values tensor. |
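A KVState can be built from any two tensors of matching shape. The snippet below assumes `Tensor` is `torch.Tensor` and uses made-up sizes; because KVState is a NamedTuple, its fields can be read by name or by tuple unpacking.

```python
from typing import NamedTuple

import torch


class KVState(NamedTuple):
    keys: torch.Tensor
    values: torch.Tensor


# Illustrative sizes only.
batch_size, num_attention_heads, num_tokens, head_dim = 1, 8, 16, 64
state = KVState(
    keys=torch.randn(batch_size, num_attention_heads, num_tokens, head_dim),
    values=torch.randn(batch_size, num_attention_heads, num_tokens, head_dim),
)

# NamedTuple fields support both attribute access and unpacking.
keys, values = state
print(state.keys.shape)  # torch.Size([1, 8, 16, 64])
```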