Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-03 12:50:06 +06:00
Fix max length for BLIP generation (#29296)
* fix max_length for blip
* update also min length
* fixes
* add a comment
* Update src/transformers/models/instructblip/modeling_instructblip.py
  Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
* Update src/transformers/models/blip_2/modeling_blip_2.py
  Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
* make fixup
* fix length when passed by the user
* remove else
* remove brackets

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Commit bd891aed01 (parent 4fc708f98c)
src/transformers/generation/utils.py

@@ -1453,6 +1453,7 @@ class GenerationMixin:
             and not self.config.is_encoder_decoder
         ):
             generation_config.max_length -= inputs_tensor.shape[1]
+            generation_config.min_length = max(generation_config.min_length - inputs_tensor.shape[1], 0)
 
         if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
             if generation_config.cache_implementation == "static":
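The reasoning behind this hunk: `max_length` and `min_length` bound the total sequence, but when the prompt arrives as `inputs_embeds` the returned `input_ids` never contains those positions, so the stopping criteria would otherwise hand the model the whole budget again on top of the prompt. A minimal standalone sketch of the arithmetic (the helper name is made up for illustration; this is not the library code):

def adjust_lengths(max_length: int, min_length: int, num_embed_positions: int) -> tuple[int, int]:
    # Both bounds are defined over the full sequence (embedded prompt +
    # generated tokens), but generation from `inputs_embeds` starts with an
    # empty `input_ids`, so shift both to count newly generated tokens only.
    return max_length - num_embed_positions, max(min_length - num_embed_positions, 0)

# A 38-position embedded prompt under a total budget of 50 leaves room
# for at most 12 generated tokens:
print(adjust_lengths(50, 0, 38))  # -> (12, 0)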
src/transformers/models/blip_2/modeling_blip_2.py

@@ -1827,6 +1827,11 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel):
         inputs_embeds = self.get_input_embeddings()(input_ids)
         inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
 
+        # add image_embeds length to max_length, so that the final max_length is counted only on token embeds
+        if not self.language_model.config.is_encoder_decoder:
+            generate_kwargs["max_length"] = generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1]
+            generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]
+
         outputs = self.language_model.generate(
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
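Downstream effect for BLIP-2: the 32 query-token embeddings prepended by the Q-Former no longer eat into the caller's length budget. A quick usage sketch, assuming the public Salesforce/blip2-opt-2.7b checkpoint (the image URL and prompt are placeholders):

import requests
from PIL import Image
from transformers import Blip2ForConditionalGeneration, Blip2Processor

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, text="Question: how many cats are there? Answer:", return_tensors="pt")

# Before this fix, the default max_length of 20 was largely consumed by the
# 32 prepended query embeddings; now it bounds the text tokens alone.
outputs = model.generate(**inputs, max_length=20)
print(processor.batch_decode(outputs, skip_special_tokens=True))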
src/transformers/models/instructblip/modeling_instructblip.py

@@ -1537,6 +1537,11 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel):
         inputs_embeds = self.get_input_embeddings()(input_ids)
         inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
 
+        # add image_embeds length to max_length, so that the final max_length is counted only on token embeds
+        if not self.language_model.config.is_encoder_decoder:
+            generate_kwargs["max_length"] = generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1]
+            generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]
+
         outputs = self.language_model.generate(
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
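Taken together, the two halves of the fix compose: the model-level generate() inflates the bound by the query-embed length (language_model_inputs.shape[1]), and the GenerationMixin hunk above then subtracts the full inputs_embeds length (query embeds plus text prompt), so the user-facing max_length ends up bounding text tokens only. A worked example with illustrative numbers:

Q = 32                # query-token embeds prepended by the model (language_model_inputs.shape[1])
T = 6                 # text prompt tokens (illustrative)
user_max_length = 30

max_length = user_max_length + Q   # model-level generate(): 30 + 32 = 62
max_length -= Q + T                # GenerationMixin inputs_embeds path: 62 - 38 = 24

# The stopping criterion counts only newly generated tokens, so up to 24
# new tokens may be emitted: prompt (6) + new (24) == the user's 30, i.e.
# max_length now covers text tokens only, excluding image-query positions.
assert T + max_length == user_max_length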