diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 5b7d18e06c1..4bfae470cf7 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1453,6 +1453,7 @@ class GenerationMixin:
             and not self.config.is_encoder_decoder
         ):
             generation_config.max_length -= inputs_tensor.shape[1]
+            generation_config.min_length = max(generation_config.min_length - inputs_tensor.shape[1], 0)
 
         if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
             if generation_config.cache_implementation == "static":
diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py
index 00433f3ea34..3e63fac66fd 100644
--- a/src/transformers/models/blip_2/modeling_blip_2.py
+++ b/src/transformers/models/blip_2/modeling_blip_2.py
@@ -1827,6 +1827,11 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel):
         inputs_embeds = self.get_input_embeddings()(input_ids)
         inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
 
+        # add image_embeds length to max_length, so that the final max_length is counted only on token embeds
+        if not self.language_model.config.is_encoder_decoder:
+            generate_kwargs["max_length"] = generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1]
+            generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]
+
         outputs = self.language_model.generate(
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py
index e175cd57285..da0b02551ff 100644
--- a/src/transformers/models/instructblip/modeling_instructblip.py
+++ b/src/transformers/models/instructblip/modeling_instructblip.py
@@ -1537,6 +1537,11 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel):
         inputs_embeds = self.get_input_embeddings()(input_ids)
         inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
 
+        # add image_embeds length to max_length, so that the final max_length is counted only on token embeds
+        if not self.language_model.config.is_encoder_decoder:
+            generate_kwargs["max_length"] = generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1]
+            generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]
+
         outputs = self.language_model.generate(
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
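A minimal usage sketch of the behaviour this patch targets (the checkpoint name, image path, and prompt below are illustrative, not taken from the diff): for a decoder-only BLIP-2 or InstructBLIP model, the length limits passed to `generate()` are offset internally by the number of prepended image query-token embeddings, so `max_length`/`min_length` apply to text tokens only.

```python
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration

# Placeholder checkpoint and image path; any decoder-only BLIP-2 checkpoint applies.
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")

image = Image.open("example.jpg")
inputs = processor(images=image, text="Question: what is in the picture? Answer:", return_tensors="pt")

# With this patch, max_length / min_length budget only text token positions
# (prompt text plus generated tokens); the image query-token embeddings that
# the model prepends (32 by default) no longer consume the length budget.
generated_ids = model.generate(**inputs, max_length=20, min_length=5)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip())
```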