TF: generate without tf.TensorArray (#17801)

Joao Gante 2022-06-23 12:28:08 +01:00 committed by GitHub
parent ab223fc148
commit 5cce3076c4
3 changed files with 97 additions and 200 deletions

src/transformers/generation_tf_utils.py

@@ -16,7 +16,6 @@
import inspect
from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
@@ -1979,6 +1978,8 @@ class TFGenerationMixin:
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
)
use_xla = not tf.executing_eagerly()
# some models, like XLNet, need more than the last token in the presence of past
needs_full_input = "use_mems" in set(inspect.signature(self.prepare_inputs_for_generation).parameters.keys())
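For reference, the `use_mems` probe added above can be exercised on its own. A minimal sketch with hypothetical stand-in classes (the class names are invented for illustration):

import inspect

class FakeWithMems:
    # hypothetical stand-in for an XLNet-style model that consumes `use_mems`
    def prepare_inputs_for_generation(self, inputs, past=None, use_mems=None, **kwargs):
        ...

class FakeWithoutMems:
    # hypothetical stand-in for a GPT-2-style model without `use_mems`
    def prepare_inputs_for_generation(self, inputs, past=None, **kwargs):
        ...

def needs_full_input(model):
    return "use_mems" in set(inspect.signature(model.prepare_inputs_for_generation).parameters.keys())

assert needs_full_input(FakeWithMems())
assert not needs_full_input(FakeWithoutMems())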
# 2. init `attentions`, `hidden_states`, and `scores` tuples
scores = [] if (return_dict_in_generate and output_scores) else None
@@ -1989,34 +1990,25 @@
# 3. init tensors to use for "xla-compileable" generate function
batch_size, cur_len = input_ids.shape
# initialize `generated` (pre-populated with `pad_token_id`), `finished_sequences`
generated = tf.TensorArray(
element_shape=(batch_size,),
dtype=tf.int32,
dynamic_size=False,
size=max_length,
clear_after_read=False,
)
if pad_token_id: # ignores the cases when it is 0 or None
for i in range(max_length):
generated = generated.write(i, tf.broadcast_to(pad_token_id, (batch_size,)))
# write prompt to generated
for i in range(cur_len):
generated = generated.write(i, input_ids[:, i])
# initialize `generated` (`input_ids` padded with `pad_token_id`), `finished_sequences`
input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0)
generated = tf.concat([input_ids, input_ids_padding], axis=-1)
finished_sequences = tf.zeros((batch_size,), dtype=tf.bool)
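Schematically, the new initialization builds the whole fixed-shape buffer up front instead of writing rows into a `tf.TensorArray`. A self-contained sketch with made-up sizes:

import tensorflow as tf

batch_size, cur_len, max_length, pad_token_id = 2, 3, 6, 0
input_ids = tf.constant([[5, 6, 7], [8, 9, 10]], dtype=tf.int32)

# pad the prompt out to `max_length`; `(pad_token_id or 0)` maps None to 0
input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0)
generated = tf.concat([input_ids, input_ids_padding], axis=-1)
# generated: [[5, 6, 7, 0, 0, 0], [8, 9, 10, 0, 0, 0]]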
# 4. define "xla-compile-able" stop-condition and auto-regressive function
# define condition fn
def greedy_search_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
def greedy_search_cond_fn(generated, finished_sequences, cur_len, model_kwargs):
"""state termination condition fn."""
return ~tf.reduce_all(finished_sequences)
# define condition fn
def greedy_search_body_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
def greedy_search_body_fn(generated, finished_sequences, cur_len, model_kwargs):
"""state update fn."""
model_inputs = self.prepare_inputs_for_generation(next_tokens, **model_kwargs)
if model_kwargs.get("past") is None or needs_full_input:
input_ids = generated[:, :cur_len]
else:
input_ids = tf.expand_dims(generated[:, cur_len - 1], -1)
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# forward pass to get next token logits
outputs = self(
**model_inputs,
@@ -2043,8 +2035,7 @@ class TFGenerationMixin:
decoder_hidden_states.append(outputs.hidden_states)
# pre-process distribution
input_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size)))
next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len)
next_tokens_scores = logits_processor(generated, next_token_logits, cur_len)
# argmax
next_tokens = tf.argmax(next_tokens_scores, axis=-1, output_type=tf.int32)
@@ -2057,8 +2048,8 @@ class TFGenerationMixin:
finished_sequences = finished_sequences | (next_tokens == eos_token_id)
# update `generated` and `cur_len`
generated = generated.write(cur_len, next_tokens)
next_tokens = next_tokens[:, None]
update_indices = tf.stack([tf.range(batch_size), tf.broadcast_to(cur_len, [batch_size])], axis=-1)
generated = tf.tensor_scatter_nd_update(tensor=generated, indices=update_indices, updates=next_tokens)
cur_len += 1
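The scatter above writes one token per batch row at column `cur_len`, replacing the old `TensorArray.write`. A toy run with invented values:

import tensorflow as tf

batch_size, cur_len = 2, 3
generated = tf.constant([[5, 6, 7, 0, 0, 0], [8, 9, 10, 0, 0, 0]], dtype=tf.int32)
next_tokens = tf.constant([11, 12], dtype=tf.int32)  # shape (batch_size,)

# one (row, column) pair per batch item: [[0, 3], [1, 3]]
update_indices = tf.stack([tf.range(batch_size), tf.broadcast_to(cur_len, [batch_size])], axis=-1)
generated = tf.tensor_scatter_nd_update(tensor=generated, indices=update_indices, updates=next_tokens)
# generated: [[5, 6, 7, 11, 0, 0], [8, 9, 10, 12, 0, 0]]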
# update model_kwargs
@@ -2073,34 +2064,29 @@ class TFGenerationMixin:
# let's throw out `past` since we don't want `None` tensors
model_kwargs.pop("past", None)
next_tokens = tf.reshape(generated.concat(), (-1, batch_size))
next_tokens = tf.transpose(next_tokens[:cur_len])
return generated, finished_sequences, next_tokens, cur_len, model_kwargs
return generated, finished_sequences, cur_len, model_kwargs
# 5. run generation
# 1st generation step has to be run beforehand to initialize `past`
generated, finished_sequences, next_tokens, cur_len, model_kwargs = greedy_search_body_fn(
generated, finished_sequences, input_ids, cur_len, model_kwargs
generated, finished_sequences, cur_len, model_kwargs = greedy_search_body_fn(
generated, finished_sequences, cur_len, model_kwargs
)
# 2-to-n generation steps can then be run in autoregressive fashion
# only in case 1st generation step does NOT yield EOS token though
if greedy_search_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
if greedy_search_cond_fn(generated, finished_sequences, cur_len, model_kwargs):
maximum_iterations = max_length - cur_len
generated, _, _, cur_len, _ = tf.while_loop(
generated, _, cur_len, _ = tf.while_loop(
greedy_search_cond_fn,
greedy_search_body_fn,
(generated, finished_sequences, next_tokens, cur_len, model_kwargs),
(generated, finished_sequences, cur_len, model_kwargs),
maximum_iterations=maximum_iterations,
)
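Because the loop state is now a tuple of fixed-shape tensors, it feeds straight into `tf.while_loop`. A stripped-down sketch with a toy condition and body (not the library code):

import tensorflow as tf

def cond_fn(generated, cur_len):
    return cur_len < 6

def body_fn(generated, cur_len):
    # toy update: write the current position index as the "next token"
    indices = tf.stack([tf.range(2), tf.broadcast_to(cur_len, [2])], axis=-1)
    generated = tf.tensor_scatter_nd_update(generated, indices, tf.fill([2], cur_len))
    return generated, cur_len + 1

generated = tf.zeros((2, 6), dtype=tf.int32)
generated, cur_len = tf.while_loop(cond_fn, body_fn, (generated, 3), maximum_iterations=3)
# all shapes are static across iterations, which is what XLA compilation requires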
# 6. prepare outputs
output_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size)))
if not use_xla:
# cut for backward compatibility
output_ids = output_ids[:, :cur_len]
generated = generated[:, :cur_len]
if return_dict_in_generate:
if self.config.is_encoder_decoder:
@@ -2117,7 +2103,7 @@ class TFGenerationMixin:
decoder_hidden_states = tuple(decoder_hidden_states) if decoder_hidden_states is not None else None
return TFGreedySearchEncoderDecoderOutput(
sequences=output_ids,
sequences=generated,
scores=scores,
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
@@ -2127,13 +2113,13 @@ class TFGenerationMixin:
)
else:
return TFGreedySearchDecoderOnlyOutput(
sequences=output_ids,
sequences=generated,
scores=scores,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
)
else:
return output_ids
return generated
def sample(
self,
@@ -2250,6 +2236,8 @@ class TFGenerationMixin:
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
)
use_xla = not tf.executing_eagerly()
# some models, like XLNet, need more than the last token in the presence of past
needs_full_input = "use_mems" in set(inspect.signature(self.prepare_inputs_for_generation).parameters.keys())
# 2. init `attentions`, `hidden_states`, and `scores` tuples
scores = [] if (return_dict_in_generate and output_scores) else None
@@ -2261,29 +2249,20 @@
batch_size, cur_len = input_ids.shape
# initialize `generated` (pre-populated with `pad_token_id`), `finished_sequences`
generated = tf.TensorArray(
element_shape=(batch_size,),
dtype=tf.int32,
dynamic_size=False,
size=max_length,
clear_after_read=False,
)
if pad_token_id: # ignores the cases when it is 0 or None
for i in range(max_length):
generated = generated.write(i, tf.broadcast_to(pad_token_id, (batch_size,)))
# write prompt to generated
for i in range(cur_len):
generated = generated.write(i, input_ids[:, i])
input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0)
generated = tf.concat([input_ids, input_ids_padding], axis=-1)
finished_sequences = tf.zeros((batch_size,), dtype=tf.bool)
# 4. define "xla-compile-able" stop-condition and auto-regressive function
def sample_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
def sample_cond_fn(generated, finished_sequences, cur_len, model_kwargs):
return ~tf.reduce_all(finished_sequences)
def sample_body_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
model_inputs = self.prepare_inputs_for_generation(next_tokens, **model_kwargs)
def sample_body_fn(generated, finished_sequences, cur_len, model_kwargs):
if model_kwargs.get("past") is None or needs_full_input:
input_ids = generated[:, :cur_len]
else:
input_ids = tf.expand_dims(generated[:, cur_len - 1], -1)
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# forward pass to get next token logits
outputs = self(
**model_inputs,
@@ -2310,9 +2289,8 @@ class TFGenerationMixin:
decoder_hidden_states.append(outputs.hidden_states)
# pre-process distribution
input_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size)))
next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len)
next_tokens_scores = logits_warper(input_ids, next_tokens_scores, cur_len)
next_tokens_scores = logits_processor(generated, next_token_logits, cur_len)
next_tokens_scores = logits_warper(generated, next_tokens_scores, cur_len)
# sample
if seed is not None:
@@ -2334,8 +2312,8 @@ class TFGenerationMixin:
finished_sequences = finished_sequences | (next_tokens == eos_token_id)
# update `generated` and `cur_len`
generated = generated.write(cur_len, next_tokens)
next_tokens = next_tokens[:, None]
update_indices = tf.stack([tf.range(batch_size), tf.broadcast_to(cur_len, [batch_size])], axis=-1)
generated = tf.tensor_scatter_nd_update(tensor=generated, indices=update_indices, updates=next_tokens)
cur_len += 1
# update model_kwargs
@@ -2350,34 +2328,29 @@ class TFGenerationMixin:
# let's throw out `past` since we don't want `None` tensors
model_kwargs.pop("past", None)
next_tokens = tf.reshape(generated.concat(), (-1, batch_size))
next_tokens = tf.transpose(next_tokens[:cur_len])
return generated, finished_sequences, next_tokens, cur_len, model_kwargs
return generated, finished_sequences, cur_len, model_kwargs
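The sampling call elided in the hunk above draws the next token from the processed scores; a generic stateless draw looks roughly like this (logits and seed invented, not the exact library code):

import tensorflow as tf

next_tokens_scores = tf.math.log([[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]])  # (batch, vocab) log-probs

# stateless draw: deterministic for a fixed `seed`, convenient under XLA
next_tokens = tf.squeeze(
    tf.random.stateless_categorical(next_tokens_scores, num_samples=1, seed=(0, 1), dtype=tf.int32),
    axis=1,
)  # shape (batch,), ready for the same scatter update used in greedy search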
# 5. run generation
# 1st generation step has to be run beforehand to initialize `past`
generated, finished_sequences, next_tokens, cur_len, model_kwargs = sample_body_fn(
generated, finished_sequences, input_ids, cur_len, model_kwargs
generated, finished_sequences, cur_len, model_kwargs = sample_body_fn(
generated, finished_sequences, cur_len, model_kwargs
)
# 2-to-n generation steps can then be run in autoregressive fashion
# only in case 1st generation step does NOT yield EOS token though
if sample_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
if sample_cond_fn(generated, finished_sequences, cur_len, model_kwargs):
maximum_iterations = max_length - cur_len
generated, _, _, cur_len, _ = tf.while_loop(
generated, _, cur_len, _ = tf.while_loop(
sample_cond_fn,
sample_body_fn,
(generated, finished_sequences, next_tokens, cur_len, model_kwargs),
(generated, finished_sequences, cur_len, model_kwargs),
maximum_iterations=maximum_iterations,
)
# 6. prepare outputs
output_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size)))
if not use_xla:
# cut for backward compatibility
output_ids = output_ids[:, :cur_len]
generated = generated[:, :cur_len]
if return_dict_in_generate:
if self.config.is_encoder_decoder:
@@ -2394,7 +2367,7 @@ class TFGenerationMixin:
decoder_hidden_states = tuple(decoder_hidden_states) if decoder_hidden_states is not None else None
return TFSampleEncoderDecoderOutput(
sequences=output_ids,
sequences=generated,
scores=scores,
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
@@ -2404,13 +2377,13 @@ class TFGenerationMixin:
)
else:
return TFSampleDecoderOnlyOutput(
sequences=output_ids,
sequences=generated,
scores=scores,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
)
else:
return output_ids
return generated
def beam_search(
self,
@@ -2585,6 +2558,8 @@ class TFGenerationMixin:
# GPT2 and other models have a slightly different cache structure, with a different batch axis
model_name = str(self.decoder) if "EncoderDecoder" in str(self) else str(self)
cache_batch_axis = 1 if any([model_prefix in model_name for model_prefix in ("TFGPT2", "TFCTRL")]) else 0
# some models, like XLNet, need more than the last token in the presence of past
needs_full_input = "use_mems" in set(inspect.signature(self.prepare_inputs_for_generation).parameters.keys())
# 2. init `attentions`, `hidden_states`, and `scores` tuples
scores = [] if (return_dict_in_generate and output_scores) else None
@@ -2594,41 +2569,13 @@
# 3. init tensors to use for "xla-compileable" generate function
batch_size, num_beams, cur_len = input_ids.shape
input_ids_length = cur_len
# per batch, beam-item holding current token in loop, pre-populated with `pad_token_id`
sequences = tf.TensorArray(
element_shape=(batch_size, num_beams),
dtype=tf.int32,
dynamic_size=False,
size=max_length,
clear_after_read=False,
input_ids_padding = tf.ones((batch_size, num_beams, max_length - cur_len), dtype=tf.int32) * (
pad_token_id or 0
)
running_sequences = tf.TensorArray(
element_shape=(batch_size, num_beams),
dtype=tf.int32,
dynamic_size=False,
size=max_length,
clear_after_read=False,
)
intermediary_running_sequences = tf.TensorArray(
element_shape=(batch_size, num_beams * 2),
dtype=tf.int32,
dynamic_size=False,
size=max_length,
clear_after_read=False,
)
if pad_token_id: # ignores the cases when it is 0 or None
for i in range(max_length):
sequences = sequences.write(i, tf.broadcast_to(pad_token_id, (batch_size, num_beams)))
running_sequences = running_sequences.write(i, tf.broadcast_to(pad_token_id, (batch_size, num_beams)))
intermediary_running_sequences = intermediary_running_sequences.write(
i, tf.broadcast_to(pad_token_id, (batch_size, num_beams * 2))
)
# write prompt to running_sequences
for i in range(cur_len):
running_sequences = running_sequences.write(i, input_ids[:, :, i])
running_sequences = tf.concat([input_ids, input_ids_padding], axis=-1)
sequences = tf.ones((batch_size, num_beams, max_length), dtype=tf.int32) * (pad_token_id or 0)
# per batch, beam-item state bit indicating if sentence has finished.
is_sent_finished = tf.zeros((batch_size, num_beams), dtype=tf.bool)
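As in greedy search, the beam state becomes fixed-shape tensors, here with an extra beam axis: `running_sequences` holds beams still being extended and `sequences` holds finished ones. A toy-sized sketch:

import tensorflow as tf

batch_size, num_beams, cur_len, max_length, pad_token_id = 1, 2, 2, 5, 0
input_ids = tf.constant([[[3, 4], [3, 4]]], dtype=tf.int32)  # (batch, beams, cur_len)

input_ids_padding = tf.ones((batch_size, num_beams, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0)
running_sequences = tf.concat([input_ids, input_ids_padding], axis=-1)  # (1, 2, 5), beams in progress
sequences = tf.ones((batch_size, num_beams, max_length), dtype=tf.int32) * (pad_token_id or 0)  # finished beams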
@@ -2656,7 +2603,6 @@ class TFGenerationMixin:
sequences,
scores,
is_sent_finished,
input_ids_length,
model_kwargs,
):
"""
@@ -2685,27 +2631,18 @@ class TFGenerationMixin:
sequences,
scores,
is_sent_finished,
input_ids_length,
model_kwargs,
intermediary_running_sequences=None,
):
"""
Beam Search iterative update function -- each iteration adds a new token and updates the best sequences
seen so far
"""
# TODO (joao): this loop is probably faster with gather/scatters, instead of using `tf.TensorArray`.
# Alternativelly, attempt to rewrite function with permuted axis, when enabling XLA.
# 1. Forward current tokens
# TF places the dynamic dimension (seq_len) in the first axis, we want it in the last
running_sequences_seq_last = tf.transpose(running_sequences.stack(), perm=[1, 2, 0])
input_token = tf.slice(
running_sequences_seq_last,
(0, 0, cur_len - input_ids_length),
(batch_size, num_beams, input_ids_length),
)
model_inputs = self.prepare_inputs_for_generation(flatten_beam_dim(input_token), **model_kwargs)
if model_kwargs.get("past") is None or needs_full_input:
input_ids = running_sequences[:, :, :cur_len]
else:
input_ids = tf.expand_dims(running_sequences[:, :, cur_len - 1], -1)
model_inputs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), **model_kwargs)
model_outputs = self(
**model_inputs,
return_dict=True,
@@ -2734,9 +2671,7 @@ class TFGenerationMixin:
# get log probabilities from logits, process logits with processors (*e.g.* min_length, ...), and
# add new logprobs to existing running logprobs scores.
log_probs = tf.nn.log_softmax(logits)
log_probs = logits_processor(
flatten_beam_dim(running_sequences_seq_last), flatten_beam_dim(log_probs), cur_len
)
log_probs = logits_processor(flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), cur_len)
log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
log_probs = log_probs + tf.expand_dims(running_scores, axis=2)
vocab_size = log_probs.shape[2]
@@ -2755,23 +2690,28 @@ class TFGenerationMixin:
beams_to_keep = 2 * num_beams
topk_log_probs, topk_indices = tf.math.top_k(log_probs, k=beams_to_keep)
topk_beam_indices = topk_indices // vocab_size
topk_running_sequences_seq_last = gather_beams(running_sequences_seq_last, topk_beam_indices)
topk_running_sequences = gather_beams(running_sequences, topk_beam_indices)
topk_ids = topk_indices % vocab_size
# writes the new token
intermediary_running_sequences = intermediary_running_sequences.unstack(
tf.transpose(topk_running_sequences_seq_last, perm=[2, 0, 1])
indices_batch = tf.repeat(tf.range(batch_size), [beams_to_keep])
indices_beam = tf.tile(tf.range(beams_to_keep), [batch_size])
update_indices = tf.stack(
[indices_batch, indices_beam, tf.broadcast_to(cur_len, [batch_size * beams_to_keep])], axis=-1
)
topk_sequences = tf.tensor_scatter_nd_update(
tensor=topk_running_sequences,
indices=update_indices,
updates=tf.reshape(topk_ids, [batch_size * beams_to_keep]),
)
topk_sequences = intermediary_running_sequences.write(cur_len, topk_ids)
topk_sequences_seq_last = tf.transpose(topk_sequences.stack(), perm=[1, 2, 0])
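The index construction above emits one (batch, beam, position) triple per candidate beam, so a single scatter writes the new token into all 2*num_beams candidates at once. Spelled out with toy sizes:

import tensorflow as tf

batch_size, beams_to_keep, cur_len = 2, 4, 3
indices_batch = tf.repeat(tf.range(batch_size), [beams_to_keep])  # [0, 0, 0, 0, 1, 1, 1, 1]
indices_beam = tf.tile(tf.range(beams_to_keep), [batch_size])     # [0, 1, 2, 3, 0, 1, 2, 3]
update_indices = tf.stack(
    [indices_batch, indices_beam, tf.broadcast_to(cur_len, [batch_size * beams_to_keep])], axis=-1
)  # shape (batch_size * beams_to_keep, 3): one target element per candidate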
# 4. Check which sequences have ended
# Update current sequences: Did the top `num_beams` sequences reach an end marker?
# To prevent these just finished sequences from being added to the current sequences
# set of active beam search sequences, set their log probs to a very large negative value.
eos_in_next_token = topk_sequences_seq_last[:, :, cur_len] == eos_token_id
eos_in_next_token = topk_sequences[:, :, cur_len] == eos_token_id
if eos_token_id is None:
eos_in_next_token = tf.broadcast_to(eos_in_next_token, topk_sequences_seq_last[:, :, cur_len].shape)
eos_in_next_token = tf.broadcast_to(eos_in_next_token, topk_sequences[:, :, cur_len].shape)
did_topk_just_finished = eos_in_next_token & tf.broadcast_to(
tf.concat((tf.ones((num_beams), dtype=tf.bool), tf.zeros((num_beams), dtype=tf.bool)), axis=0),
eos_in_next_token.shape,
@@ -2785,8 +2725,8 @@ class TFGenerationMixin:
# Determine the top k beam indices (from top 2*k beams) from log probs and gather top k beams
# (from top 2*k beams).
next_topk_indices = tf.math.top_k(running_topk_log_probs, k=num_beams)[1]
next_running_sequences_seq_last, next_running_scores = gather_beams(
[topk_sequences_seq_last, running_topk_log_probs], next_topk_indices
next_running_sequences, next_running_scores = gather_beams(
[topk_sequences, running_topk_log_probs], next_topk_indices
)
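`gather_beams` reorders the beam axis independently for each batch item. The helper is defined elsewhere in this file; a plausible minimal version (an assumption, not the verbatim implementation):

import tensorflow as tf

def gather_beams(nested, beam_indices):
    # per-batch gather along the beam axis for every tensor in the structure
    gather_fn = lambda tensor: tf.gather(tensor, beam_indices, axis=1, batch_dims=1)
    return tf.nest.map_structure(gather_fn, nested)

seqs = tf.reshape(tf.range(2 * 4 * 3), (2, 4, 3))  # (batch, 2 * num_beams, seq_len)
best = tf.constant([[1, 0], [3, 2]])               # top num_beams=2 beams per batch item
print(gather_beams(seqs, best).shape)              # (2, 2, 3)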
# 6. Process topk logits
@@ -2807,18 +2747,18 @@ class TFGenerationMixin:
# 7. Get scores, sequences, is sentence finished for next.
# Combine sequences, scores, and flags along the beam dimension and compare new finished sequence scores
# to existing finished scores and select the best from the new set of beams
sequences_seq_last = tf.transpose(sequences.stack(), perm=[1, 2, 0])
merged_sequences = tf.concat([sequences_seq_last, topk_sequences_seq_last], axis=1)
merged_sequences = tf.concat([sequences, topk_sequences], axis=1)
merged_scores = tf.concat([scores, topk_log_probs], axis=1)
merged_is_sent_finished = tf.concat([is_sent_finished, did_topk_just_finished], axis=1)
topk_merged_indices = tf.math.top_k(merged_scores, k=num_beams)[1]
next_sequences_seq_last, next_scores, next_is_sent_finished = gather_beams(
next_sequences, next_scores, next_is_sent_finished = gather_beams(
[merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices
)
# 8. Prepare data for the next iteration
# Determine the top k beam indices from the original set of all beams. With these, gather the top k
# beam-associated caches.
cur_len = cur_len + 1
if "past_key_values" in model_outputs:
cache = tf.nest.map_structure(
lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams, batch_axis=cache_batch_axis),
@@ -2841,35 +2781,20 @@ class TFGenerationMixin:
# if we don't cache past key values we need the whole input
if model_kwargs.get("past", None) is None:
next_input_ids_length = cur_len + 1
# let's throw out `past` since we don't want `None` tensors
model_kwargs.pop("past", None)
else:
next_input_ids_length = 1
# 9. Prepare the `tf.TensorArray` for the next iteration
next_sequences = sequences.unstack(tf.transpose(next_sequences_seq_last, perm=[2, 0, 1]))
next_running_sequences = running_sequences.unstack(
tf.transpose(next_running_sequences_seq_last, perm=[2, 0, 1])
)
return (
cur_len + 1,
cur_len,
next_running_sequences,
next_running_scores,
next_sequences,
next_scores,
next_is_sent_finished,
next_input_ids_length,
next_model_kwargs,
)
# 5. run generation
# Adds the `intermediary_running_sequences` TensorArray into the body, needed as a scratchpad
beam_search_body_fn = partial(
beam_search_body_fn, intermediary_running_sequences=intermediary_running_sequences
)
# 1st generation step has to be run beforehand to initialize `past` (if active)
(
cur_len,
@@ -2878,66 +2803,38 @@ class TFGenerationMixin:
sequences,
scores,
is_sent_finished,
input_ids_length,
model_kwargs,
) = beam_search_body_fn(
cur_len,
running_sequences,
running_scores,
sequences,
scores,
is_sent_finished,
input_ids_length,
model_kwargs,
cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs
)
# 2-to-n generation steps can then be run in autoregressive fashion (only in case 1st generation step does
# NOT yield EOS token though)
if beam_search_cond_fn(
cur_len,
running_sequences,
running_scores,
sequences,
scores,
is_sent_finished,
input_ids_length,
model_kwargs,
cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs
):
maximum_iterations = max_length - cur_len
cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, _, _ = tf.while_loop(
cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, _ = tf.while_loop(
beam_search_cond_fn,
beam_search_body_fn,
(
cur_len,
running_sequences,
running_scores,
sequences,
scores,
is_sent_finished,
input_ids_length,
model_kwargs,
),
(cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs),
maximum_iterations=maximum_iterations,
)
# 6. prepare outputs
# convert the sequences to tf.Tensor with shape (batch_size, num_beams, seq_len)
sequences_seq_last = tf.transpose(sequences.stack(), perm=[1, 2, 0])
running_sequences_seq_last = tf.transpose(running_sequences.stack(), perm=[1, 2, 0])
# Account for the edge-case where there are no finished sequences for a particular batch item. If so, return
# running sequences for that batch item.
none_finished = tf.math.reduce_any(is_sent_finished, axis=1)
sequences_seq_last = tf.where(none_finished[:, None, None], sequences_seq_last, running_sequences_seq_last)
sequences = tf.where(none_finished[:, None, None], sequences, running_sequences)
scores = tf.where(none_finished[:, None], scores, running_scores)
# Take best beams for each batch (the score is sorted in ascending order)
sequences_seq_last = flatten_beam_dim(sequences_seq_last[:, :num_return_sequences, :])
sequences = flatten_beam_dim(sequences[:, :num_return_sequences, :])
scores = flatten_beam_dim(scores[:, :num_return_sequences])
if not use_xla:
# Cut for backward compatibility
sequences_seq_last = sequences_seq_last[:, :cur_len]
sequences = sequences[:, :cur_len]
if return_dict_in_generate:
if self.config.is_encoder_decoder:
@@ -2948,7 +2845,7 @@ class TFGenerationMixin:
)
return TFBeamSearchEncoderDecoderOutput(
sequences=sequences_seq_last,
sequences=sequences,
scores=scores,
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
@@ -2958,13 +2855,13 @@ class TFGenerationMixin:
)
else:
return TFBeamSearchDecoderOnlyOutput(
sequences=sequences_seq_last,
sequences=sequences,
scores=scores,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
)
else:
return sequences_seq_last
return sequences
def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):

src/transformers/models/gpt2/modeling_tf_gpt2.py

@@ -874,8 +874,9 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
new_past = [None for _ in range(len(past))]
slice_start_base = tf.constant([0, 0, 0, 1, 0])
attention_mask_update_slice = tf.ones((batch_size, 1), dtype=attention_mask.dtype)
# correct 5 here
new_past_index = current_pos - 1
# -1 because current_pos has already been incremented before this function
# -1 again because last index = len - 1
new_past_index = current_pos - 2
for i in range(len(past)):
update_slice = past[i][:, :, :, -1:]
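The two new comments decompose the offset; a made-up walk-through of the arithmetic only (the counter convention is the one the comments describe):

# suppose 5 tokens exist so far and the counter was already advanced
current_pos = 6
seq_len = current_pos - 1      # -1: undo the increment done before this function
new_past_index = seq_len - 1   # -1: last valid index = len - 1
assert new_past_index == current_pos - 2  # == 4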

src/transformers/models/xlnet/modeling_tf_xlnet.py

@@ -1202,7 +1202,6 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
def prepare_inputs_for_generation(self, inputs, past=None, use_mems=None, **kwargs):
# Add dummy token at the end (no attention on this one)
effective_batch_size = inputs.shape[0]
dummy_token = tf.zeros((effective_batch_size, 1), dtype=inputs.dtype)
@@ -1212,12 +1211,12 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
offset = 2
if past:
inputs = tf.concat([inputs[:, -offset:], dummy_token], axis=1)
input_ids = tf.concat([inputs[:, -offset:], dummy_token], axis=1)
else:
inputs = tf.concat([inputs, dummy_token], axis=1)
input_ids = tf.concat([inputs, dummy_token], axis=1)
# Build permutation mask so that previous tokens don't see last token
sequence_length = inputs.shape[1]
sequence_length = input_ids.shape[1]
perm_mask = tf.zeros((effective_batch_size, sequence_length, sequence_length - 1))
perm_mask_seq_end = tf.ones((effective_batch_size, sequence_length, 1))
perm_mask = tf.concat([perm_mask, perm_mask_seq_end], axis=-1)
@@ -1228,7 +1227,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
target_mapping = tf.concat([target_mapping, target_mapping_seq_end], axis=-1)
inputs = {
"input_ids": inputs,
"input_ids": input_ids,
"perm_mask": perm_mask,
"target_mapping": target_mapping,
"use_mems": use_mems,