diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 638bb3b12e6..baa26202655 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -411,7 +411,7 @@ class GenerationMixin: ) # shape: (batch_size * num_return_sequences * num_beams, cur_len) if self.config.is_encoder_decoder: - # create empty decoder_input_ids + # create empty decoder input_ids input_ids = torch.full( (effective_batch_size * num_beams, 1), decoder_start_token_id,