From b369e507aaa78103baf5d3f3563952b44a0408a1 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Thu, 4 May 2023 18:36:23 +0100
Subject: [PATCH] Generate: text generation pipeline no longer emits `max_length` warning when it is not set (#23139)

---
 src/transformers/generation/flax_utils.py     |  2 +-
 src/transformers/generation/tf_utils.py       |  2 +-
 src/transformers/generation/utils.py          |  2 +-
 src/transformers/pipelines/text_generation.py | 32 +++++++++++++------
 .../test_pipelines_text_generation.py         | 32 ++++++++++++++++++-
 5 files changed, 56 insertions(+), 14 deletions(-)

diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py
index 58a2bf13ba6..65d65869afd 100644
--- a/src/transformers/generation/flax_utils.py
+++ b/src/transformers/generation/flax_utils.py
@@ -385,7 +385,6 @@ class FlaxGenerationMixin:
                 UserWarning,
             )
         elif generation_config.max_new_tokens is not None:
-            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
             if not has_default_max_length:
                 logger.warning(
                     f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
@@ -393,6 +392,7 @@ class FlaxGenerationMixin:
                     "Please refer to the documentation for more information. "
                     "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
                 )
+            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
 
         if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
             raise ValueError(
diff --git a/src/transformers/generation/tf_utils.py b/src/transformers/generation/tf_utils.py
index 5cd8153c2bf..5e4bc58c840 100644
--- a/src/transformers/generation/tf_utils.py
+++ b/src/transformers/generation/tf_utils.py
@@ -858,7 +858,6 @@ class TFGenerationMixin:
                 UserWarning,
             )
         elif generation_config.max_new_tokens is not None:
-            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
             if not has_default_max_length:
                 logger.warning(
                     f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
@@ -866,6 +865,7 @@ class TFGenerationMixin:
                     "Please refer to the documentation for more information. "
                     "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
                 )
+            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
 
         # If the input length is a tensor (i.e. dynamic length), skip length checks
         if not isinstance(input_ids_seq_length, tf.Tensor):
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 06836d4d4ae..0f0191fb144 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1348,7 +1348,6 @@ class GenerationMixin:
                 UserWarning,
             )
         elif generation_config.max_new_tokens is not None:
-            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
             if not has_default_max_length:
                 logger.warning(
                     f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
@@ -1356,6 +1355,7 @@ class GenerationMixin:
                     "Please refer to the documentation for more information. "
                     "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
                 )
+            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
 
         if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
             raise ValueError(
diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py
index f95acf7d307..60037339072 100644
--- a/src/transformers/pipelines/text_generation.py
+++ b/src/transformers/pipelines/text_generation.py
@@ -1,3 +1,4 @@
+import copy
 import enum
 import warnings
 
@@ -105,17 +106,8 @@ class TextGenerationPipeline(Pipeline):
             prefix_inputs = self.tokenizer(
                 prefix, padding=False, add_special_tokens=False, return_tensors=self.framework
             )
-            prefix_length = prefix_inputs["input_ids"].shape[-1]
+            generate_kwargs["prefix_length"] = prefix_inputs["input_ids"].shape[-1]
 
-            if "max_new_tokens" in generate_kwargs:
-                pass
-            elif "max_length" in generate_kwargs:
-                generate_kwargs["max_length"] += prefix_length
-            else:
-                generate_kwargs["max_length"] = self.model.config.max_length + prefix_length
-
-            if "min_length" in generate_kwargs:
-                generate_kwargs["min_length"] += prefix_length
         if handle_long_generation is not None:
             if handle_long_generation not in {"hole"}:
                 raise ValueError(
@@ -247,6 +239,26 @@ class TextGenerationPipeline(Pipeline):
         else:
             in_b = input_ids.shape[0]
         prompt_text = model_inputs.pop("prompt_text")
+
+        # If there is a prefix, we may need to adjust the generation length. Do so without permanently modifying
+        # generate_kwargs, as some of the parameterization may come from the initialization of the pipeline.
+        generate_kwargs = copy.deepcopy(generate_kwargs)
+        prefix_length = generate_kwargs.pop("prefix_length", 0)
+        if prefix_length > 0:
+            has_max_new_tokens = "max_new_tokens" in generate_kwargs or (
+                "generation_config" in generate_kwargs
+                and generate_kwargs["generation_config"].max_new_tokens is not None
+            )
+            if not has_max_new_tokens:
+                generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.model.config.max_length
+                generate_kwargs["max_length"] += prefix_length
+            has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
+                "generation_config" in generate_kwargs
+                and generate_kwargs["generation_config"].min_new_tokens is not None
+            )
+            if not has_min_new_tokens and "min_length" in generate_kwargs:
+                generate_kwargs["min_length"] += prefix_length
+
         # BS x SL
         generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
         out_b = generated_sequence.shape[0]
diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py
index 34dbef6df2d..84b14490521 100644
--- a/tests/pipelines/test_pipelines_text_generation.py
+++ b/tests/pipelines/test_pipelines_text_generation.py
@@ -14,8 +14,15 @@
 
 import unittest
 
-from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_CAUSAL_LM_MAPPING, TextGenerationPipeline, pipeline
+from transformers import (
+    MODEL_FOR_CAUSAL_LM_MAPPING,
+    TF_MODEL_FOR_CAUSAL_LM_MAPPING,
+    TextGenerationPipeline,
+    logging,
+    pipeline,
+)
 from transformers.testing_utils import (
+    CaptureLogger,
     is_pipeline_test,
     require_accelerate,
     require_tf,
@@ -323,3 +330,26 @@ class TextGenerationPipelineTests(unittest.TestCase):
 
         pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto", torch_dtype=torch.float16)
         pipe("This is a test", do_sample=True, top_p=0.5)
+
+    def test_pipeline_length_setting_warning(self):
+        prompt = """Hello world"""
+        text_generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2")
+        if text_generator.model.framework == "tf":
+            logger = logging.get_logger("transformers.generation.tf_utils")
+        else:
+            logger = logging.get_logger("transformers.generation.utils")
+        logger_msg = "Both `max_new_tokens`"  # The beginning of the message to be checked in this test
+
+        # Both are set by the user -> log warning
+        with CaptureLogger(logger) as cl:
+            _ = text_generator(prompt, max_length=10, max_new_tokens=1)
+        self.assertIn(logger_msg, cl.out)
+
+        # The user only sets one -> no warning
+        with CaptureLogger(logger) as cl:
+            _ = text_generator(prompt, max_new_tokens=1)
+        self.assertNotIn(logger_msg, cl.out)
+
+        with CaptureLogger(logger) as cl:
+            _ = text_generator(prompt, max_length=10)
+        self.assertNotIn(logger_msg, cl.out)
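
Illustrative usage (not part of the patch): a minimal sketch of the pipeline calls this change affects, assuming the same `hf-internal-testing/tiny-random-gpt2` checkpoint exercised by the test above. After the fix, only the first call, which sets both length controls, should log the "Both `max_new_tokens` ..." warning; the single-limit calls should stay silent.

    from transformers import pipeline

    generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2")

    # Both limits set by the user -> `max_new_tokens` takes precedence and the warning is logged.
    generator("Hello world", max_length=10, max_new_tokens=1)

    # Only one limit set -> no `max_length` warning is expected after this patch.
    generator("Hello world", max_new_tokens=1)
    generator("Hello world", max_length=10)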