Fix max length for BLIP generation (#29296)

* fix mal_length for blip

* update also min length

* fixes

* add a comment

* Update src/transformers/models/instructblip/modeling_instructblip.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/models/blip_2/modeling_blip_2.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* make fixup

* fix length when user passed

* remove else

* remove brackets

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
Raushan Turganbay 2024-03-05 12:18:22 +05:00 committed by GitHub
parent 4fc708f98c
commit bd891aed01
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 11 additions and 0 deletions

View File

@ -1453,6 +1453,7 @@ class GenerationMixin:
and not self.config.is_encoder_decoder
):
generation_config.max_length -= inputs_tensor.shape[1]
generation_config.min_length = max(generation_config.min_length - inputs_tensor.shape[1], 0)
if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
if generation_config.cache_implementation == "static":

View File

@ -1827,6 +1827,11 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel):
inputs_embeds = self.get_input_embeddings()(input_ids)
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
# add image_embeds length to max_length, so that the final max_length in counted only on token embeds
if not self.language_model.config.is_encoder_decoder:
generate_kwargs["max_length"] = generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1]
generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]
outputs = self.language_model.generate(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,

View File

@ -1537,6 +1537,11 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel):
inputs_embeds = self.get_input_embeddings()(input_ids)
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
# add image_embeds length to max_length, so that the final max_length in counted only on token embeds
if not self.language_model.config.is_encoder_decoder:
generate_kwargs["max_length"] = generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1]
generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]
outputs = self.language_model.generate(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,