naive_speculate.infer.inferencer.utils.sample
`sample_tokens(token_logits, sampling_strategy)`
Sample token ids from `token_logits` according to `sampling_strategy`.
`token_logits` is expected to hold raw (unnormalized) logits, not probabilities.
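Raw logits are not probabilities. Any stochastic strategy therefore has to normalize each row before it can draw from it. The sketch below shows the standard numerically stable softmax with an optional temperature. It is an illustration under stated assumptions, not the library's code: it uses plain Python lists instead of `torch.Tensor` so that it runs without dependencies, and the helper name `softmax_rows` is hypothetical rather than part of `naive_speculate`.

```python
import math


def softmax_rows(logits, temperature=1.0):
    """Convert each row of raw logits into probabilities (hypothetical helper)."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    probs = []
    for row in logits:
        # Temperature < 1 sharpens the distribution; > 1 flattens it.
        scaled = [x / temperature for x in row]
        # Subtracting the row max keeps exp() from overflowing without
        # changing the resulting probabilities.
        m = max(scaled)
        exps = [math.exp(x - m) for x in scaled]
        total = sum(exps)
        probs.append([e / total for e in exps])
    return probs


probs = softmax_rows([[2.0, 1.0, 0.1], [1000.0, 1000.0, 0.0]])
```

The second row would overflow a naive `exp` call. Subtracting the maximum makes it evaluate cleanly to `[0.5, 0.5, 0.0]`.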
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `token_logits` | `Tensor` | Logits of shape | required |
| `sampling_strategy` | `SampleStrategy` | Sampling strategy to use. | required |
Returns:
| Type | Description |
|---|---|
| `torch.Tensor` | Sampled next token ids of shape |
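The following is a rough sketch of how a function with this signature could dispatch on its strategy. It returns one token id per row of raw logits. Several parts are assumptions, not the documented API. The `SampleStrategy` members `GREEDY` and `RANDOM` are hypothetical and the real enum may define other options. The name `sample_tokens_sketch` is also hypothetical. The sketch works on nested lists rather than `torch.Tensor`.

```python
import enum
import math
import random


class SampleStrategy(enum.Enum):
    # Hypothetical members; the real SampleStrategy may differ.
    GREEDY = "greedy"
    RANDOM = "random"


def sample_tokens_sketch(token_logits, sampling_strategy, rng=None):
    """Return one sampled token id per row of raw logits (illustrative only)."""
    rng = rng or random.Random()
    ids = []
    for row in token_logits:
        if sampling_strategy is SampleStrategy.GREEDY:
            # Greedy: index of the largest logit (first one wins on ties).
            ids.append(max(range(len(row)), key=row.__getitem__))
        elif sampling_strategy is SampleStrategy.RANDOM:
            # Random: draw from softmax(row). random.choices accepts
            # unnormalized weights, so exp(x - max) suffices here.
            m = max(row)
            weights = [math.exp(x - m) for x in row]
            ids.append(rng.choices(range(len(row)), weights=weights, k=1)[0])
        else:
            raise ValueError(f"unsupported strategy: {sampling_strategy}")
    return ids


logits = [[0.1, 3.0, -1.0], [2.5, 0.0, 0.0]]
greedy_ids = sample_tokens_sketch(logits, SampleStrategy.GREEDY)  # [1, 0]
random_ids = sample_tokens_sketch(logits, SampleStrategy.RANDOM, rng=random.Random(0))
```

Greedy selection is deterministic. Random sampling is reproducible only if a seeded generator is passed in, which is why the sketch takes an optional `rng`. The documented function returns a `torch.Tensor` of token ids, whereas this sketch returns a plain list.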