Generalize CFG to allow for positive prompts (#25339)

* Generalize CFG to allow for positive prompts

* Add documentation, fix the correct class
oobabooga 2023-08-07 11:25:15 -03:00 committed by GitHub
parent b0f23036f1
commit d6bfba76be
2 changed files with 10 additions and 3 deletions


@@ -1346,9 +1346,10 @@ class UnbatchedClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
     Args:
         guidance_scale (`float`):
-            The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
+            The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale != 1`.
             Higher guidance scale encourages the model to generate samples that are more closely linked to the input
-            prompt, usually at the expense of poorer quality.
+            prompt, usually at the expense of poorer quality. A value smaller than 1 has the opposite effect, while
+            making the negative prompt provided with negative_prompt_ids (if any) act as a positive prompt.
         unconditional_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Indices of input sequence tokens in the vocabulary for the unconditional branch. If unset, will default to
             the last token of the prompt.
@@ -1383,6 +1384,12 @@ class UnbatchedClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
     >>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
     The dragon flew over Paris, France, crashing into Notre Dame Cathedral in the French capital killing at least 127 people and injuring more than 350.
+
+    >>> # with a positive prompt
+    >>> neg_inputs = tokenizer(["A very happy event happened,"], return_tensors="pt")
+    >>> out = model.generate(inputs["input_ids"], guidance_scale=0, negative_prompt_ids=neg_inputs["input_ids"])
+    >>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
+    Today, a dragon flew over Paris, France, and I'm very happy to be here.
     ```
     """


@@ -904,7 +904,7 @@ class GenerationMixin:
         # instantiate processors list
         processors = LogitsProcessorList()
-        if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
+        if generation_config.guidance_scale is not None and generation_config.guidance_scale != 1:
             processors.append(
                 UnbatchedClassifierFreeGuidanceLogitsProcessor(
                     generation_config.guidance_scale,
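
With the relaxed condition, any `guidance_scale` other than 1 now instantiates the CFG processor, including fractional values. A hedged usage sketch follows; the model checkpoint, prompts, and `max_new_tokens` value are illustrative and not part of the commit, while `guidance_scale` and `negative_prompt_ids` mirror the docstring example above.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer(["Today, a dragon flew over Paris, France,"], return_tensors="pt")
neg_inputs = tokenizer(["A very happy event happened,"], return_tensors="pt")

# With 0 <= guidance_scale < 1 the "negative" prompt pulls generation toward
# itself instead of pushing away from it.
out = model.generate(
    inputs["input_ids"],
    guidance_scale=0.5,
    negative_prompt_ids=neg_inputs["input_ids"],
    max_new_tokens=30,
)
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
```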