mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
Generalize CFG to allow for positive prompts (#25339)
* Generalize CFG to allow for positive prompts * Add documentation, fix the correct class
This commit is contained in:
parent
b0f23036f1
commit
d6bfba76be
@ -1346,9 +1346,10 @@ class UnbatchedClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
guidance_scale (`float`):
|
guidance_scale (`float`):
|
||||||
The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
|
The guidance scale for classifier-free guidance (CFG). CFG is enabled by setting `guidance_scale != 1`.
|
||||||
Higher guidance scale encourages the model to generate samples that are more closely linked to the input
|
Higher guidance scale encourages the model to generate samples that are more closely linked to the input
|
||||||
prompt, usually at the expense of poorer quality.
|
prompt, usually at the expense of poorer quality. A value smaller than 1 has the opposite effect, while
|
||||||
|
making the negative prompt provided with `negative_prompt_ids` (if any) act as a positive prompt.
|
||||||
unconditional_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
unconditional_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
Indices of input sequence tokens in the vocabulary for the unconditional branch. If unset, will default to
|
Indices of input sequence tokens in the vocabulary for the unconditional branch. If unset, will default to
|
||||||
the last token of the prompt.
|
the last token of the prompt.
|
||||||
@ -1383,6 +1384,12 @@ class UnbatchedClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
|
|||||||
>>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
|
>>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
|
||||||
The dragon flew over Paris, France, crashing into Notre Dame Cathedral in the French capital killing at least 127
|
The dragon flew over Paris, France, crashing into Notre Dame Cathedral in the French capital killing at least 127
|
||||||
people and injuring more than 350.
|
people and injuring more than 350.
|
||||||
|
|
||||||
|
>>> # with a positive prompt
|
||||||
|
>>> neg_inputs = tokenizer(["A very happy event happened,"], return_tensors="pt")
|
||||||
|
>>> out = model.generate(inputs["input_ids"], guidance_scale=0, negative_prompt_ids=neg_inputs["input_ids"])
|
||||||
|
>>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
|
||||||
|
Today, a dragon flew over Paris, France, and I'm very happy to be here.
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -904,7 +904,7 @@ class GenerationMixin:
|
|||||||
# instantiate processors list
|
# instantiate processors list
|
||||||
processors = LogitsProcessorList()
|
processors = LogitsProcessorList()
|
||||||
|
|
||||||
if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
|
if generation_config.guidance_scale is not None and generation_config.guidance_scale != 1:
|
||||||
processors.append(
|
processors.append(
|
||||||
UnbatchedClassifierFreeGuidanceLogitsProcessor(
|
UnbatchedClassifierFreeGuidanceLogitsProcessor(
|
||||||
generation_config.guidance_scale,
|
generation_config.guidance_scale,
|
||||||
|
Loading…
Reference in New Issue
Block a user