Mirror of https://github.com/huggingface/transformers.git
Generalize CFG to allow for positive prompts (#25339)
* Generalize CFG to allow for positive prompts
* Add documentation, fix the correct class
commit d6bfba76be (parent b0f23036f1)
@@ -1346,9 +1346,10 @@ class UnbatchedClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
 
     Args:
         guidance_scale (`float`):
-            The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
+            The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale != 1`.
             Higher guidance scale encourages the model to generate samples that are more closely linked to the input
-            prompt, usually at the expense of poorer quality.
+            prompt, usually at the expense of poorer quality. A value smaller than 1 has the opposite effect, while
+            making the negative prompt provided with negative_prompt_ids (if any) act as a positive prompt.
         unconditional_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Indices of input sequence tokens in the vocabulary for the unconditional branch. If unset, will default to
             the last token of the prompt.
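For intuition on why the check becomes `guidance_scale != 1` rather than `> 1`, here is a minimal sketch of the standard CFG combination formula; the toy logits and the helper name are made up for illustration, and the actual processor works on per-batch log-probabilities:

```python
import torch

# Toy log-probabilities over a 5-token vocabulary (values are made up).
cond = torch.log_softmax(torch.tensor([2.0, 0.5, 0.1, -1.0, 0.0]), dim=-1)    # prompt-conditioned branch
uncond = torch.log_softmax(torch.tensor([0.2, 1.5, 0.1, -1.0, 0.0]), dim=-1)  # negative-prompt / unconditional branch

def cfg_combine(scale):
    # Classifier-free guidance: start from the unconditional branch and move
    # `scale` of the way toward (or past) the conditioned branch.
    return uncond + scale * (cond - uncond)

print(cfg_combine(1.5))   # scale > 1: extrapolates toward the prompt (classic CFG)
print(cfg_combine(1.0))   # scale == 1: reduces to `cond`, i.e. CFG is a no-op
print(cfg_combine(0.0))   # scale == 0: reduces to `uncond`, so the "negative" prompt acts as the positive one
print(cfg_combine(-0.5))  # scale < 0: actively pushes away from the prompt
```

This is why the docstring now says that values below 1 invert the effect: at `scale == 1` the processor changes nothing, and on either side of 1 it pulls the distribution toward one branch or the other.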
@@ -1383,6 +1384,12 @@ class UnbatchedClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
     >>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
     The dragon flew over Paris, France, crashing into Notre Dame Cathedral in the French capital killing at least 127
     people and injuring more than 350.
+
+    >>> # with a positive prompt
+    >>> neg_inputs = tokenizer(["A very happy event happened,"], return_tensors="pt")
+    >>> out = model.generate(inputs["input_ids"], guidance_scale=0, negative_prompt_ids=neg_inputs["input_ids"])
+    >>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
+    Today, a dragon flew over Paris, France, and I'm very happy to be here.
     ```
     """
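The hunk above only shows the tail of the docstring example, so `model`, `tokenizer`, and `inputs` appear undefined. A rough sketch of the surrounding setup follows; the checkpoint name, `guidance_scale=1.5`, and `max_new_tokens` are assumptions for illustration, not the exact values in the file:

```python
# Sketch of the docstring's setup (checkpoint and generation kwargs are assumed).
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

inputs = tokenizer(["Today, a dragon flew over Paris, France,"], return_tensors="pt")
neg_inputs = tokenizer(["A very happy event happened,"], return_tensors="pt")

# Classic CFG: guidance_scale > 1 steers generation away from the negative prompt.
out = model.generate(
    inputs["input_ids"],
    guidance_scale=1.5,
    max_new_tokens=30,
    negative_prompt_ids=neg_inputs["input_ids"],
)
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])

# New in this commit: guidance_scale < 1 (here 0) makes the same ids act as a
# positive prompt instead.
out = model.generate(
    inputs["input_ids"],
    guidance_scale=0,
    max_new_tokens=30,
    negative_prompt_ids=neg_inputs["input_ids"],
)
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
```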
@@ -904,7 +904,7 @@ class GenerationMixin:
         # instantiate processors list
         processors = LogitsProcessorList()
 
-        if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
+        if generation_config.guidance_scale is not None and generation_config.guidance_scale != 1:
             processors.append(
                 UnbatchedClassifierFreeGuidanceLogitsProcessor(
                     generation_config.guidance_scale,
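The only behavioural change in this hunk is the activation condition. A tiny illustrative snippet of which `guidance_scale` values now attach the CFG processor; the helper is hypothetical and simply mirrors the new check in GenerationMixin's processor setup:

```python
# Hypothetical helper mirroring the updated condition (not part of the library).
def cfg_enabled(guidance_scale):
    return guidance_scale is not None and guidance_scale != 1

for scale in (None, 1, 1.5, 0, 0.5, -1):
    print(scale, cfg_enabled(scale))
# Under the old `guidance_scale > 1` check, 0, 0.5, and -1 were silently ignored;
# now any value other than 1 (or None) enables the CFG processor.
```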