there is no var `decoder_input_ids`, but there is `input_ids` for decoder :)
This commit is contained in:
Stas Bekman 2020-09-07 02:16:24 -07:00 committed by GitHub
parent 10c6f94adc
commit c3317e1f80
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -411,7 +411,7 @@ class GenerationMixin:
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
if self.config.is_encoder_decoder:
# create empty decoder_input_ids
# create empty decoder input_ids
input_ids = torch.full(
(effective_batch_size * num_beams, 1),
decoder_start_token_id,