mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 02:31:11 +06:00)
Generate: multi-device support for contrastive search (#24635)
parent 4b26a61631
commit 9934bb1f42
@@ -2060,8 +2060,10 @@ class GenerationMixin:
 context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0)
 
 # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the
-# model confidence
+# model confidence. Keeping `selected_idx` on CPU enables multi-device contrastive search and doesn't
+# introduce (noticeable) slowdowns on single-device runs.
 selected_idx = _ranking_fast(context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k)
+selected_idx = selected_idx.to("cpu")
 
 # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing
 # the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores
 
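For readers unfamiliar with the re-ranking step the comments describe, here is a minimal pure-Python sketch of the idea, for a single sequence. It is not the library's `_ranking_fast`, which operates on batched tensors. The names `cosine` and `rank_candidates` are hypothetical, and the scoring rule (model confidence weighted by `1 - penalty_alpha`, minus `penalty_alpha` times the largest cosine similarity to any previous hidden state) is the usual contrastive search formulation, assumed here rather than taken from this diff.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length, non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def rank_candidates(context_hidden, next_hidden, top_k_probs, penalty_alpha):
    """Return the index of the best candidate (hypothetical helper).

    context_hidden: hidden vectors of the tokens generated so far
    next_hidden:    one hidden vector per top-k candidate
    top_k_probs:    model probability of each candidate
    penalty_alpha:  weight of the degeneration penalty, in [0, 1]
    """
    scores = []
    for candidate, prob in zip(next_hidden, top_k_probs):
        # Degeneration penalty: how similar the candidate is to its
        # closest previous token representation.
        penalty = max(cosine(candidate, h) for h in context_hidden)
        scores.append((1.0 - penalty_alpha) * prob - penalty_alpha * penalty)
    # Plain Python int, analogous to keeping `selected_idx` off the accelerator.
    return max(range(len(scores)), key=scores.__getitem__)


context = [[1.0, 0.0], [0.9, 0.1]]
candidates = [[1.0, 0.0], [0.0, 1.0]]
probs = [0.6, 0.4]

print(rank_candidates(context, candidates, probs, penalty_alpha=0.6))  # 1
print(rank_candidates(context, candidates, probs, penalty_alpha=0.0))  # 0
```

With `penalty_alpha=0.0` the choice reduces to greedy selection by probability. With a non-zero penalty, the candidate that repeats the context loses to the less probable but more novel one. In the real code the selected index is then used to gather from several tensors (hidden states, logits, cached keys and values), which may sit on different devices in a multi-device setup. To my understanding, PyTorch accepts a CPU index tensor when indexing a tensor on another device, which would explain why the commit moves `selected_idx` to CPU.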