mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Update logits_process.py
This commit is contained in:
parent
ab75556e28
commit
5f3eb85aa1
@ -1085,17 +1085,19 @@ class PrefixConstrainedLogitsProcessor(LogitsProcessor):
|
||||
|
||||
class HammingDiversityLogitsProcessor(LogitsProcessor):
|
||||
r"""
|
||||
[`LogitsProcessor`] that enforces diverse beam search. Note that this logits processor is only effective for
|
||||
[`LogitsProcessor`] that enforces diverse beam search.
|
||||
|
||||
Note that this logits processor is only effective for
|
||||
[`PreTrainedModel.group_beam_search`]. See [Diverse Beam Search: Decoding Diverse Solutions from Neural Sequence
|
||||
Models](https://arxiv.org/pdf/1610.02424.pdf) for more details.
|
||||
|
||||
<Tip> again: this logits processor is only effective for [`PreTrainedModel.group_beam_search`]. </Tip>
|
||||
|
||||
<Tip>
|
||||
|
||||
Diverse beam search can be particularly useful in scenarios where a variety of different outputs is desired, rather than multiple similar sequences. It allows the model to explore different generation paths and provides a broader coverage of possible outputs.
|
||||
Diverse beam search can be particularly useful in scenarios where a variety of different outputs is desired, rather than multiple similar sequences.
|
||||
It allows the model to explore different generation paths and provides a broader coverage of possible outputs.
|
||||
|
||||
</Tip>
|
||||
|
||||
<Warning>
|
||||
|
||||
This logits processor can be resource-intensive, especially when using large models or long sequences.
|
||||
@ -1166,7 +1168,7 @@ class HammingDiversityLogitsProcessor(LogitsProcessor):
|
||||
|
||||
# Set up for diverse beam search
|
||||
num_beams = 6
|
||||
num_beam_groups = 2 # To generate two diverse summaries
|
||||
num_beam_groups = 2
|
||||
|
||||
model_kwargs = {
|
||||
"encoder_outputs": model.get_encoder()(
|
||||
@ -1189,7 +1191,7 @@ class HammingDiversityLogitsProcessor(LogitsProcessor):
|
||||
MinLengthLogitsProcessor(10, eos_token_id=model.config.eos_token_id),
|
||||
]
|
||||
)
|
||||
|
||||
#generate the diverse summary using group_beam_search
|
||||
outputs_diverse = model.group_beam_search(
|
||||
encoder_input_ids.repeat_interleave(num_beams, dim=0), beam_scorer, logits_processor=logits_processor_diverse, **model_kwargs
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user