naive_speculate.infer.inferencer.utils.collection

OutputCollection

Container for intermediate model outputs collected during prefill or decode.

Call update to collect outputs as needed during prefill or decode, then call finalize to retrieve the collected outputs.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| output_ids | list[Tensor] | Collected output ids, each of shape [batch_size, 1]. |
| output_logits | list[Tensor] | Collected output logits, each of shape [batch_size, 1, vocab_size]. |

finalize(num_tokens_trim=0)

Finalize collected outputs and return them.

Return empty tensors if no outputs have been collected, or if num_tokens_trim is greater than or equal to the number of collected tokens.

If num_tokens_trim <= 0, return all collected outputs.

Otherwise, drop the last num_tokens_trim tokens from the collected outputs and return the rest.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| num_tokens_trim | int | Number of tokens to trim from the end of the outputs. | 0 |

Returns:

| Type | Description |
| --- | --- |
| tuple[Tensor, Tensor] | A tuple of the collected output ids and logits. |
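
The trimming rules for finalize can be summarised by how many tokens survive the call. The following is a minimal plain-Python sketch of that arithmetic only. `kept_token_count` is a hypothetical helper introduced for illustration; it is not part of the library and does not touch tensors.

```python
def kept_token_count(num_collected: int, num_tokens_trim: int = 0) -> int:
    """Model how many tokens finalize keeps, following the documented rules.

    Hypothetical helper for illustration; not part of OutputCollection.
    """
    # Nothing collected, or the trim covers everything: the result is empty.
    if num_collected == 0 or num_tokens_trim >= num_collected:
        return 0
    # A non-positive trim keeps every collected token.
    if num_tokens_trim <= 0:
        return num_collected
    # Otherwise the last num_tokens_trim tokens are dropped.
    return num_collected - num_tokens_trim
```

Under this model, 5 collected tokens with a trim of 2 leave 3 tokens, while a trim of 5 or more leaves none, which corresponds to the empty-tensor case.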

find(token_id, start_idx)

Find the first occurrence of a token id in the collected output ids, starting from start_idx.

If start_idx is out of bounds (less than 0, or greater than or equal to the number of collected tokens), -1 is returned.

Incurs one device synchronization.

Currently only meaningful when batch_size=1.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| token_id | int | The token id to search for. | required |
| start_idx | int | The index to start searching from. | required |

Returns:

| Type | Description |
| --- | --- |
| int | The index of the first occurrence of the token id within the searched range, or -1 if not found. |
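
The bounds handling of find can be illustrated with a plain-Python model. This is a sketch under stated assumptions, not the library's implementation: a flat sequence of ints stands in for the collected output ids of a batch_size=1 run, `find_token` is a hypothetical name, and the returned index is assumed to be absolute (counted from the start of the collected tokens), which the documentation does not state explicitly.

```python
from typing import Sequence


def find_token(ids: Sequence[int], token_id: int, start_idx: int) -> int:
    """Return the first index >= start_idx where ids holds token_id, else -1.

    Hypothetical stand-in for OutputCollection.find on a batch_size=1 run.
    """
    # An out-of-bounds start (negative, or at/past the end) yields -1.
    if start_idx < 0 or start_idx >= len(ids):
        return -1
    for i in range(start_idx, len(ids)):
        if ids[i] == token_id:
            # ASSUMPTION: the index is absolute, not relative to start_idx.
            return i
    return -1
```

With ids `[5, 7, 9, 7]`, searching for 7 from index 0 gives 1 in this model, searching from index 2 gives 3, and any start index of 4 or more gives -1.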

update(output_ids, output_logits)

Update collected outputs.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| output_ids | Tensor | New output ids to collect. | required |
| output_logits | Tensor | New output logits to collect. | required |