naive_speculate.infer.interface.kvcache
Defines the KVCache protocol and the KVState named tuple.
KVCache
Bases: Protocol
Stores layer-wise key and value tensors for use by an Inferencer during inference.
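For readers implementing their own cache, the interface documented on this page can be summarized as a `typing.Protocol`. This is a minimal sketch reconstructed from the page, not the module's source: the return types of crop and update are assumed to be None, Tensor is aliased to Any so the snippet runs without torch, and the `runtime_checkable` decorator is an addition of this sketch.

```python
from typing import Any, NamedTuple, Protocol, Sequence, runtime_checkable

# Assumption: Tensor is aliased to Any so this sketch runs without torch.
Tensor = Any


class KVState(NamedTuple):
    """Keys and values for one layer, each [batch, heads, num_tokens, head_dim]."""

    keys: Tensor
    values: Tensor


@runtime_checkable
class KVCache(Protocol):
    """Structural interface; crop/update returning None is an assumption."""

    def crop(self, num_tokens_crop: int) -> None: ...

    def get_kv_states(self) -> list[KVState]: ...

    def get_num_tokens(self) -> int: ...

    def update(self, kv_states: Sequence[KVState]) -> None: ...
```

Because KVCache is a Protocol, any class that defines these four methods satisfies it structurally, without inheriting from it. With `runtime_checkable`, `isinstance` only checks that the methods exist, not their signatures.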
crop(num_tokens_crop)
Crop the latest num_tokens_crop tokens from the cache.
num_tokens_crop should be non-negative and should not exceed the number of tokens currently in the cache.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| num_tokens_crop | int | Number of latest tokens to crop from the cache. | required |
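The slicing that crop implies can be illustrated on a plain Python list standing in for the token axis. `crop_latest` is a hypothetical helper, not part of the module; for a tensor laid out as [batch_size, num_attention_heads, num_tokens, head_dim], the same end index would apply to dim 2, e.g. `keys[:, :, : num_tokens - num_tokens_crop, :]` in PyTorch.

```python
def crop_latest(tokens: list, num_tokens_crop: int) -> list:
    """Hypothetical helper: drop the latest num_tokens_crop entries."""
    num_tokens = len(tokens)
    # Enforce the documented precondition: 0 <= num_tokens_crop <= num_tokens.
    if not 0 <= num_tokens_crop <= num_tokens:
        raise ValueError(
            f"num_tokens_crop must be in [0, {num_tokens}], got {num_tokens_crop}"
        )
    # Use an explicit end index: tokens[:-num_tokens_crop] would return an
    # empty list when num_tokens_crop == 0, because -0 == 0.
    return tokens[: num_tokens - num_tokens_crop]
```

The explicit end index is the design point: a negative slice such as `[:-0]` is the same as `[:0]`, so it would silently empty the cache when nothing should be cropped.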
get_kv_states()
Get the KV states currently stored for all layers.
If the cache has just been initialized and holds no KV state (i.e. update has never been called),
return an empty list.
Returns:

| Type | Description |
|---|---|
| list[KVState] | A list of KVState, one per transformer layer. Empty if update has never been called. |
get_num_tokens()
Get the current number of tokens stored in the cache.
Returns:

| Type | Description |
|---|---|
| int | Current number of tokens in the cache. |
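One way an implementation could answer get_num_tokens is to read the token axis (dim 2) of the stored states, returning 0 when get_kv_states would return an empty list. This is a sketch under stated assumptions: `num_tokens_from_states` is a hypothetical helper, the tensors are replaced by `SimpleNamespace` objects that only carry a `shape`, and every layer is assumed to hold the same number of tokens.

```python
from types import SimpleNamespace
from typing import Any, NamedTuple, Sequence


class KVState(NamedTuple):
    keys: Any    # stand-in for a tensor of shape [batch, heads, num_tokens, head_dim]
    values: Any


def num_tokens_from_states(kv_states: Sequence[KVState]) -> int:
    """Hypothetical helper: derive the token count from stored KV states."""
    if not kv_states:
        # Nothing stored yet (update never called), so the cache is empty.
        return 0
    # The token axis is dim 2 of the documented layout.
    return kv_states[0].keys.shape[2]


# Shape-only placeholders instead of real tensors.
layer = KVState(
    keys=SimpleNamespace(shape=(1, 8, 5, 64)),
    values=SimpleNamespace(shape=(1, 8, 5, 64)),
)
```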
update(kv_states)
Update the storage with new key and value tensors.
The length of kv_states should equal the number of transformer layers,
and each KVState should hold the new key and value tensors for the corresponding layer.
All key and value tensors in kv_states should have the same shape.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| kv_states | Sequence[KVState] | New key and value tensors for each transformer layer. | required |
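The preconditions on update can be checked before any tensors are touched. This page does not say whether update appends along the token axis or replaces the stored states, so the sketch below only validates the documented requirements. `check_update_args` and its `num_layers` argument are hypothetical, and `SimpleNamespace` objects carrying only a `shape` stand in for tensors.

```python
from types import SimpleNamespace
from typing import Any, NamedTuple, Sequence


class KVState(NamedTuple):
    keys: Any    # stand-in for a tensor; only .shape is used here
    values: Any


def check_update_args(kv_states: Sequence[KVState], num_layers: int) -> None:
    """Hypothetical validator for the documented preconditions of update."""
    # One KVState per transformer layer.
    if len(kv_states) != num_layers:
        raise ValueError(
            f"expected {num_layers} KVState entries, got {len(kv_states)}"
        )
    # A KVState is a tuple, so iterating over it yields (keys, values).
    shapes = {tuple(t.shape) for state in kv_states for t in state}
    if len(shapes) > 1:
        raise ValueError(f"all key/value tensors must share one shape, got {shapes}")


# Example: two layers, each with 3 new tokens.
layer = KVState(
    keys=SimpleNamespace(shape=(1, 8, 3, 64)),
    values=SimpleNamespace(shape=(1, 8, 3, 64)),
)
check_update_args([layer, layer], num_layers=2)
```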
KVState
Bases: NamedTuple
Key and value tensors for a single transformer layer.
Both keys and values have shape
[batch_size, num_attention_heads, num_tokens, head_dim].
Attributes:

| Name | Type | Description |
|---|---|---|
| keys | Tensor | The keys tensor. |
| values | Tensor | The values tensor. |