naive_speculate.infer.inferencer.utils.sample

sample_tokens(token_logits, sampling_strategy)

Sample next-token ids from token_logits according to sampling_strategy.

token_logits must contain raw (unnormalized) logits, not probabilities.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| token_logits | Tensor | Logits of shape [batch_size, vocab_size]. | required |
| sampling_strategy | SampleStrategy | Sampling strategy to use. | required |

Returns:

| Type | Description |
|---|---|
| Tensor | Sampled next token ids of shape [batch_size, 1]. |
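
The shape contract above (logits of shape [batch_size, vocab_size] in, token ids of shape [batch_size, 1] out) can be illustrated with a minimal sketch. This is not the library's implementation: the fields of SampleStrategy are not documented here, so the `greedy` and `temperature` arguments below are hypothetical stand-ins for it, and plain nested lists stand in for tensors so the sketch runs without PyTorch.

```python
import math
import random


def softmax(logits, temperature=1.0):
    """Turn one row of raw logits into probabilities."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_tokens_sketch(token_logits, greedy=True, temperature=1.0, rng=None):
    """Hypothetical stand-in for sample_tokens.

    token_logits: nested list of shape [batch_size, vocab_size] (raw logits).
    greedy / temperature: assumed strategy knobs, not the real SampleStrategy.
    Returns a nested list of shape [batch_size, 1].
    """
    rng = rng or random.Random()
    sampled = []
    for row in token_logits:
        if greedy:
            # Greedy decoding: pick the index of the largest logit.
            token_id = max(range(len(row)), key=row.__getitem__)
        else:
            # Stochastic decoding: draw one id from the softmax distribution.
            probs = softmax(row, temperature)
            token_id = rng.choices(range(len(row)), weights=probs, k=1)[0]
        sampled.append([token_id])  # keep the trailing dimension of size 1
    return sampled


logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]  # batch_size=2, vocab_size=3
greedy_ids = sample_tokens_sketch(logits, greedy=True)
random_ids = sample_tokens_sketch(
    logits, greedy=False, temperature=0.8, rng=random.Random(0)
)
```

Each row of the input yields exactly one id, wrapped in its own list so the result keeps the [batch_size, 1] shape. With PyTorch tensors, the same two branches would typically be written with `torch.argmax(..., dim=-1, keepdim=True)` and `torch.multinomial(probs, num_samples=1)`; whether sample_tokens does so is an assumption.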