naive_speculate.infer.inferencer.utils.collection
OutputCollection
Container for model intermediate outputs during decode or prefill.
Call update to collect outputs on demand during decode or prefill,
then call finalize to retrieve the final collected outputs.
Attributes:
| Name | Type | Description |
|---|---|---|
| output_ids | list[Tensor] | Collected output ids, each of shape |
| output_logits | list[Tensor] | Collected output logits, each of shape |
finalize(num_tokens_trim=0)
Finalize the collected outputs and return them.
If num_tokens_trim <= 0, return all collected outputs.
Otherwise, trim the last num_tokens_trim tokens from the collected outputs.
Return empty tensors if no outputs have been collected or if
num_tokens_trim is greater than or equal to the number of collected tokens.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| num_tokens_trim | int | Number of tokens to trim from the end of the outputs. | 0 |
Returns:
| Type | Description |
|---|---|
| tuple[Tensor, Tensor] | A tuple of collected output ids and logits. |
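The trimming rules described for finalize can be illustrated with a small sketch. It uses a plain Python list in place of tensors, and `trim_collected` is a hypothetical helper written for this illustration, not part of the documented module.

```python
def trim_collected(tokens: list[int], num_tokens_trim: int = 0) -> list[int]:
    """Apply the documented trimming rule to a plain list (hypothetical helper)."""
    # Nothing collected, or the trim would remove every token: return empty.
    if not tokens or num_tokens_trim >= len(tokens):
        return []
    # A non-positive trim keeps everything.
    if num_tokens_trim <= 0:
        return list(tokens)
    # Otherwise drop the last num_tokens_trim tokens.
    return tokens[:-num_tokens_trim]


collected = [11, 12, 13, 14]
untrimmed = trim_collected(collected)         # all four tokens
trimmed_one = trim_collected(collected, 1)    # last token dropped
trimmed_all = trim_collected(collected, 4)    # trim >= length, so empty
```

The real method returns a tuple of tensors for ids and logits; the sketch only mirrors the case analysis on num_tokens_trim.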
find(token_id, start_idx)
Find the first occurrence of a token id in the collected output ids, starting from start_idx.
If start_idx is out of bounds (less than 0, or greater than or equal to
the number of collected tokens), -1 is returned.
Incurs one device synchronization.
Currently only meaningful for batch_size=1 scenarios.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| token_id | int | The token id to search for. | required |
| start_idx | int | The index to start searching from. | required |
Returns:
| Type | Description |
|---|---|
| int | The index of the first occurrence of the token id within the search length, or -1 if not found. |
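The search behaviour described for find can be sketched on a plain Python list standing in for a single (batch_size=1) sequence of collected ids. `find_token` is a hypothetical helper, and the sketch assumes the returned index is absolute (counted from the start of the collected tokens), which the description above does not state explicitly.

```python
def find_token(tokens: list[int], token_id: int, start_idx: int) -> int:
    """Return the first index >= start_idx holding token_id, else -1 (hypothetical helper)."""
    # Out-of-bounds start: below 0, or at/after the number of collected tokens.
    if start_idx < 0 or start_idx >= len(tokens):
        return -1
    # Assumption: the result is an absolute index into the collected tokens.
    for idx in range(start_idx, len(tokens)):
        if tokens[idx] == token_id:
            return idx
    return -1


collected = [5, 7, 9, 7]
first_hit = find_token(collected, 7, 0)      # search from the beginning
later_hit = find_token(collected, 7, 2)      # skip the first match
out_of_range = find_token(collected, 7, 4)   # start_idx == number of tokens
```

The documented method operates on device tensors and incurs one device synchronization; this list-based version has no such cost and only mirrors the bounds handling.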
update(output_ids, output_logits)
Update collected outputs.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| output_ids | Tensor | New output ids to collect. | required |
| output_logits | Tensor | New output logits to collect. | required |
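The collect-then-finalize flow described for this class can be sketched with a list-based stand-in. `ListOutputCollection` is hypothetical and stores nested Python lists instead of tensors; it also assumes that finalize joins the collected chunks along the token axis in collection order, which the description above implies but does not spell out.

```python
from dataclasses import dataclass, field


@dataclass
class ListOutputCollection:
    """Hypothetical list-based stand-in; the documented class stores tensors."""

    output_ids: list[list[int]] = field(default_factory=list)
    output_logits: list[list[list[float]]] = field(default_factory=list)

    def update(self, output_ids: list[int], output_logits: list[list[float]]) -> None:
        # Store each new chunk as-is; nothing is merged until finalize.
        self.output_ids.append(output_ids)
        self.output_logits.append(output_logits)

    def finalize(self) -> tuple[list[int], list[list[float]]]:
        # Assumption: chunks are joined along the token axis, in collection order.
        ids = [tok for chunk in self.output_ids for tok in chunk]
        logits = [row for chunk in self.output_logits for row in chunk]
        return ids, logits


collection = ListOutputCollection()
collection.update([3, 8], [[0.1, 0.9], [0.7, 0.3]])  # e.g. a prefill chunk
collection.update([5], [[0.4, 0.6]])                 # e.g. one decode step
ids, logits = collection.finalize()
```

Keeping the chunks in lists and merging them only in finalize keeps each update call cheap, which is one plausible reason for the list[Tensor] attributes listed above.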