Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 11:11:05 +06:00)

TF: generate without tf.TensorArray (#17801)

commit 5cce3076c4 (parent ab223fc148)
@@ -16,7 +16,6 @@
 import inspect
 from dataclasses import dataclass
-from functools import partial
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
@@ -1979,6 +1978,8 @@ class TFGenerationMixin:
             return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
         )
         use_xla = not tf.executing_eagerly()
+        # some models, like XLNet, need more than the last token in the presence of past
+        needs_full_input = "use_mems" in set(inspect.signature(self.prepare_inputs_for_generation).parameters.keys())
 
         # 2. init `attentions`, `hidden_states`, and `scores` tuples
         scores = [] if (return_dict_in_generate and output_scores) else None
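The two added lines derive `needs_full_input` purely from the signature of the model's `prepare_inputs_for_generation`. A small stand-alone illustration of that check, using hypothetical stand-in signatures rather than real model classes:

```python
import inspect

# Hypothetical stand-ins for two different `prepare_inputs_for_generation` signatures.
def gpt2_style(input_ids, past=None, **kwargs):
    ...

def xlnet_style(inputs, past=None, use_mems=None, **kwargs):
    ...

def needs_full_input(prepare_fn):
    # Same check as in the diff: does the signature expose a `use_mems` parameter?
    return "use_mems" in set(inspect.signature(prepare_fn).parameters.keys())

print(needs_full_input(gpt2_style))   # False -> only the newest token is fed once `past` exists
print(needs_full_input(xlnet_style))  # True  -> the model gets the full prefix even with `past`
```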
@@ -1989,34 +1990,25 @@ class TFGenerationMixin:
         # 3. init tensors to use for "xla-compileable" generate function
         batch_size, cur_len = input_ids.shape
 
-        # initialize `generated` (pre-populated with `pad_token_id`), `finished_sequences`
-        generated = tf.TensorArray(
-            element_shape=(batch_size,),
-            dtype=tf.int32,
-            dynamic_size=False,
-            size=max_length,
-            clear_after_read=False,
-        )
-        if pad_token_id:  # ignores the cases when it is 0 or None
-            for i in range(max_length):
-                generated = generated.write(i, tf.broadcast_to(pad_token_id, (batch_size,)))
-
-        # write prompt to generated
-        for i in range(cur_len):
-            generated = generated.write(i, input_ids[:, i])
-
+        # initialize `generated` (`input_ids` padded with `pad_token_id`), `finished_sequences`
+        input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0)
+        generated = tf.concat([input_ids, input_ids_padding], axis=-1)
         finished_sequences = tf.zeros((batch_size,), dtype=tf.bool)
 
         # 4. define "xla-compile-able" stop-condition and auto-regressive function
         # define condition fn
-        def greedy_search_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
+        def greedy_search_cond_fn(generated, finished_sequences, cur_len, model_kwargs):
             """state termination condition fn."""
             return ~tf.reduce_all(finished_sequences)
 
         # define condition fn
-        def greedy_search_body_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
+        def greedy_search_body_fn(generated, finished_sequences, cur_len, model_kwargs):
             """state update fn."""
-            model_inputs = self.prepare_inputs_for_generation(next_tokens, **model_kwargs)
+            if model_kwargs.get("past") is None or needs_full_input:
+                input_ids = generated[:, :cur_len]
+            else:
+                input_ids = tf.expand_dims(generated[:, cur_len - 1], -1)
+            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
             # forward pass to get next token logits
             outputs = self(
                 **model_inputs,
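The new `if/else` inside the body function decides how much of `generated` is fed to the model: the full prefix on the first step (no `past` yet) or for models flagged by `needs_full_input`, otherwise only the previous token. A toy sketch of that selection with made-up values, not library code:

```python
import tensorflow as tf

generated = tf.constant([[5, 6, 7, 0, 0]], dtype=tf.int32)  # pre-padded to max_length
cur_len = 3                                                  # tokens 5, 6, 7 are real
needs_full_input = False                                     # would be True for XLNet-style models

def select_input(past, cur_len):
    if past is None or needs_full_input:
        # no cache yet (or the model wants the full prefix): feed everything generated so far
        return generated[:, :cur_len]
    # with a cache, only the most recent token is needed
    return tf.expand_dims(generated[:, cur_len - 1], -1)

print(select_input(None, cur_len).numpy())          # [[5 6 7]]
print(select_input("fake-cache", cur_len).numpy())  # [[7]]
```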
@@ -2043,8 +2035,7 @@ class TFGenerationMixin:
                     decoder_hidden_states.append(outputs.hidden_states)
 
             # pre-process distribution
-            input_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size)))
-            next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len)
+            next_tokens_scores = logits_processor(generated, next_token_logits, cur_len)
 
             # argmax
             next_tokens = tf.argmax(next_tokens_scores, axis=-1, output_type=tf.int32)
@@ -2057,8 +2048,8 @@ class TFGenerationMixin:
                 finished_sequences = finished_sequences | (next_tokens == eos_token_id)
 
             # update `generated` and `cur_len`
-            generated = generated.write(cur_len, next_tokens)
-            next_tokens = next_tokens[:, None]
+            update_indices = tf.stack([tf.range(batch_size), tf.broadcast_to(cur_len, [batch_size])], axis=-1)
+            generated = tf.tensor_scatter_nd_update(tensor=generated, indices=update_indices, updates=next_tokens)
             cur_len += 1
 
             # update model_kwargs
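The replaced `TensorArray.write` becomes a scatter update into a fixed-shape, pre-padded tensor, so no shape changes between loop iterations and XLA can compile the step. A minimal sketch of that pattern in isolation, with toy sizes and token values (not library code):

```python
import tensorflow as tf

batch_size, max_length, pad_token_id = 2, 6, 0

# `generated` is padded up-front to its final shape, so it never grows inside the loop.
prompt = tf.constant([[11, 12], [21, 22]], dtype=tf.int32)
cur_len = prompt.shape[1]
padding = tf.fill((batch_size, max_length - cur_len), pad_token_id)
generated = tf.concat([prompt, padding], axis=-1)

# One decoding step: write every batch item's new token at column `cur_len`.
next_tokens = tf.constant([101, 201], dtype=tf.int32)
update_indices = tf.stack([tf.range(batch_size), tf.broadcast_to(cur_len, [batch_size])], axis=-1)
generated = tf.tensor_scatter_nd_update(generated, update_indices, next_tokens)
cur_len += 1

print(generated.numpy())  # [[ 11  12 101   0   0   0], [ 21  22 201   0   0   0]]
```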
@@ -2073,34 +2064,29 @@ class TFGenerationMixin:
                 # let's throw out `past` since we don't want `None` tensors
                 model_kwargs.pop("past", None)
 
-            next_tokens = tf.reshape(generated.concat(), (-1, batch_size))
-            next_tokens = tf.transpose(next_tokens[:cur_len])
-
-            return generated, finished_sequences, next_tokens, cur_len, model_kwargs
+            return generated, finished_sequences, cur_len, model_kwargs
 
         # 5. run generation
         # 1st generation step has to be run before to initialize `past`
-        generated, finished_sequences, next_tokens, cur_len, model_kwargs = greedy_search_body_fn(
-            generated, finished_sequences, input_ids, cur_len, model_kwargs
+        generated, finished_sequences, cur_len, model_kwargs = greedy_search_body_fn(
+            generated, finished_sequences, cur_len, model_kwargs
         )
 
         # 2-to-n generation steps can then be run in autoregressive fashion
         # only in case 1st generation step does NOT yield EOS token though
-        if greedy_search_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
+        if greedy_search_cond_fn(generated, finished_sequences, cur_len, model_kwargs):
             maximum_iterations = max_length - cur_len
-            generated, _, _, cur_len, _ = tf.while_loop(
+            generated, _, cur_len, _ = tf.while_loop(
                 greedy_search_cond_fn,
                 greedy_search_body_fn,
-                (generated, finished_sequences, next_tokens, cur_len, model_kwargs),
+                (generated, finished_sequences, cur_len, model_kwargs),
                 maximum_iterations=maximum_iterations,
             )
 
         # 6. prepare outputs
-        output_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size)))
-
         if not use_xla:
             # cut for backward compatibility
-            output_ids = output_ids[:, :cur_len]
+            generated = generated[:, :cur_len]
 
         if return_dict_in_generate:
             if self.config.is_encoder_decoder:
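For reference, a self-contained toy of the `tf.while_loop` pattern used here: the loop state keeps fixed shapes across iterations (which is what makes it XLA-compileable without `tf.TensorArray`), and `maximum_iterations` bounds the loop. The "model" below is a hypothetical dummy that just increments the previous token; everything else is made up for illustration:

```python
import tensorflow as tf

max_length = 5
generated = tf.constant([[7, 0, 0, 0, 0]], dtype=tf.int32)  # pre-padded to max_length
finished = tf.constant([False])
cur_len = tf.constant(1)

def cond_fn(generated, finished, cur_len):
    return ~tf.reduce_all(finished)

def body_fn(generated, finished, cur_len):
    # Dummy "model": next token = previous token + 1; a real body would run the forward pass here.
    next_tokens = generated[:, cur_len - 1] + 1
    batch_size = tf.shape(generated)[0]
    indices = tf.stack([tf.range(batch_size), tf.fill([batch_size], cur_len)], axis=-1)
    generated = tf.tensor_scatter_nd_update(generated, indices, next_tokens)
    finished = finished | (next_tokens == 10)  # pretend 10 is the EOS token
    return generated, finished, cur_len + 1

generated, finished, cur_len = tf.while_loop(
    cond_fn, body_fn, (generated, finished, cur_len), maximum_iterations=max_length - 1
)
print(generated.numpy())  # [[ 7  8  9 10  0]]
```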
@@ -2117,7 +2103,7 @@ class TFGenerationMixin:
                 decoder_hidden_states = tuple(decoder_hidden_states) if decoder_hidden_states is not None else None
 
                 return TFGreedySearchEncoderDecoderOutput(
-                    sequences=output_ids,
+                    sequences=generated,
                     scores=scores,
                     encoder_attentions=encoder_attentions,
                     encoder_hidden_states=encoder_hidden_states,
@@ -2127,13 +2113,13 @@ class TFGenerationMixin:
                 )
             else:
                 return TFGreedySearchDecoderOnlyOutput(
-                    sequences=output_ids,
+                    sequences=generated,
                     scores=scores,
                     attentions=decoder_attentions,
                     hidden_states=decoder_hidden_states,
                 )
         else:
-            return output_ids
+            return generated
 
     def sample(
         self,
@@ -2250,6 +2236,8 @@ class TFGenerationMixin:
             return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
         )
         use_xla = not tf.executing_eagerly()
+        # some models, like XLNet, need more than the last token in the presence of past
+        needs_full_input = "use_mems" in set(inspect.signature(self.prepare_inputs_for_generation).parameters.keys())
 
         # 2. init `attentions`, `hidden_states`, and `scores` tuples
         scores = [] if (return_dict_in_generate and output_scores) else None
@@ -2261,29 +2249,20 @@ class TFGenerationMixin:
         batch_size, cur_len = input_ids.shape
 
         # initialize `generated` (pre-populated with `pad_token_id`), `finished_sequences`
-        generated = tf.TensorArray(
-            element_shape=(batch_size,),
-            dtype=tf.int32,
-            dynamic_size=False,
-            size=max_length,
-            clear_after_read=False,
-        )
-        if pad_token_id:  # ignores the cases when it is 0 or None
-            for i in range(max_length):
-                generated = generated.write(i, tf.broadcast_to(pad_token_id, (batch_size,)))
-
-        # write prompt to generated
-        for i in range(cur_len):
-            generated = generated.write(i, input_ids[:, i])
-
+        input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0)
+        generated = tf.concat([input_ids, input_ids_padding], axis=-1)
         finished_sequences = tf.zeros((batch_size,), dtype=tf.bool)
 
         # 4. define "xla-compile-able" stop-condition and auto-regressive function
-        def sample_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
+        def sample_cond_fn(generated, finished_sequences, cur_len, model_kwargs):
            return ~tf.reduce_all(finished_sequences)
 
-        def sample_body_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
-            model_inputs = self.prepare_inputs_for_generation(next_tokens, **model_kwargs)
+        def sample_body_fn(generated, finished_sequences, cur_len, model_kwargs):
+            if model_kwargs.get("past") is None or needs_full_input:
+                input_ids = generated[:, :cur_len]
+            else:
+                input_ids = tf.expand_dims(generated[:, cur_len - 1], -1)
+            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
             # forward pass to get next token logits
             outputs = self(
                 **model_inputs,
@@ -2310,9 +2289,8 @@ class TFGenerationMixin:
                     decoder_hidden_states.append(outputs.hidden_states)
 
             # pre-process distribution
-            input_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size)))
-            next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len)
-            next_tokens_scores = logits_warper(input_ids, next_tokens_scores, cur_len)
+            next_tokens_scores = logits_processor(generated, next_token_logits, cur_len)
+            next_tokens_scores = logits_warper(generated, next_tokens_scores, cur_len)
 
             # sample
             if seed is not None:
@@ -2334,8 +2312,8 @@ class TFGenerationMixin:
                 finished_sequences = finished_sequences | (next_tokens == eos_token_id)
 
             # update `generated` and `cur_len`
-            generated = generated.write(cur_len, next_tokens)
-            next_tokens = next_tokens[:, None]
+            update_indices = tf.stack([tf.range(batch_size), tf.broadcast_to(cur_len, [batch_size])], axis=-1)
+            generated = tf.tensor_scatter_nd_update(tensor=generated, indices=update_indices, updates=next_tokens)
             cur_len += 1
 
             # update model_kwargs
@@ -2350,34 +2328,29 @@ class TFGenerationMixin:
                 # let's throw out `past` since we don't want `None` tensors
                 model_kwargs.pop("past", None)
 
-            next_tokens = tf.reshape(generated.concat(), (-1, batch_size))
-            next_tokens = tf.transpose(next_tokens[:cur_len])
-
-            return generated, finished_sequences, next_tokens, cur_len, model_kwargs
+            return generated, finished_sequences, cur_len, model_kwargs
 
         # 5. run generation
         # 1st generation step has to be run before to initialize `past`
-        generated, finished_sequences, next_tokens, cur_len, model_kwargs = sample_body_fn(
-            generated, finished_sequences, input_ids, cur_len, model_kwargs
+        generated, finished_sequences, cur_len, model_kwargs = sample_body_fn(
+            generated, finished_sequences, cur_len, model_kwargs
        )
 
         # 2-to-n generation steps can then be run in autoregressive fashion
         # only in case 1st generation step does NOT yield EOS token though
-        if sample_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
+        if sample_cond_fn(generated, finished_sequences, cur_len, model_kwargs):
             maximum_iterations = max_length - cur_len
-            generated, _, _, cur_len, _ = tf.while_loop(
+            generated, _, cur_len, _ = tf.while_loop(
                 sample_cond_fn,
                 sample_body_fn,
-                (generated, finished_sequences, next_tokens, cur_len, model_kwargs),
+                (generated, finished_sequences, cur_len, model_kwargs),
                 maximum_iterations=maximum_iterations,
             )
 
         # 6. prepare outputs
-        output_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size)))
-
         if not use_xla:
             # cut for backward compatibility
-            output_ids = output_ids[:, :cur_len]
+            generated = generated[:, :cur_len]
 
         if return_dict_in_generate:
             if self.config.is_encoder_decoder:
@@ -2394,7 +2367,7 @@ class TFGenerationMixin:
                 decoder_hidden_states = tuple(decoder_hidden_states) if decoder_hidden_states is not None else None
 
                 return TFSampleEncoderDecoderOutput(
-                    sequences=output_ids,
+                    sequences=generated,
                     scores=scores,
                     encoder_attentions=encoder_attentions,
                     encoder_hidden_states=encoder_hidden_states,
@@ -2404,13 +2377,13 @@ class TFGenerationMixin:
                 )
             else:
                 return TFSampleDecoderOnlyOutput(
-                    sequences=output_ids,
+                    sequences=generated,
                     scores=scores,
                     attentions=decoder_attentions,
                     hidden_states=decoder_hidden_states,
                 )
         else:
-            return output_ids
+            return generated
 
     def beam_search(
         self,
@@ -2585,6 +2558,8 @@ class TFGenerationMixin:
         # GPT2 and other models has a slightly different cache structure, with a different batch axis
         model_name = str(self.decoder) if "EncoderDecoder" in str(self) else str(self)
         cache_batch_axis = 1 if any([model_prefix in model_name for model_prefix in ("TFGPT2", "TFCTRL")]) else 0
+        # some models, like XLNet, need more than the last token in the presence of past
+        needs_full_input = "use_mems" in set(inspect.signature(self.prepare_inputs_for_generation).parameters.keys())
 
         # 2. init `attentions`, `hidden_states`, and `scores` tuples
         scores = [] if (return_dict_in_generate and output_scores) else None
@@ -2594,41 +2569,13 @@ class TFGenerationMixin:
 
         # 3. init tensors to use for "xla-compileable" generate function
         batch_size, num_beams, cur_len = input_ids.shape
-        input_ids_length = cur_len
 
         # per batch, beam-item holding current token in loop, pre-populated with `pad_token_id`
-        sequences = tf.TensorArray(
-            element_shape=(batch_size, num_beams),
-            dtype=tf.int32,
-            dynamic_size=False,
-            size=max_length,
-            clear_after_read=False,
-        )
-        running_sequences = tf.TensorArray(
-            element_shape=(batch_size, num_beams),
-            dtype=tf.int32,
-            dynamic_size=False,
-            size=max_length,
-            clear_after_read=False,
-        )
-        intermediary_running_sequences = tf.TensorArray(
-            element_shape=(batch_size, num_beams * 2),
-            dtype=tf.int32,
-            dynamic_size=False,
-            size=max_length,
-            clear_after_read=False,
-        )
-        if pad_token_id:  # ignores the cases when it is 0 or None
-            for i in range(max_length):
-                sequences = sequences.write(i, tf.broadcast_to(pad_token_id, (batch_size, num_beams)))
-                running_sequences = running_sequences.write(i, tf.broadcast_to(pad_token_id, (batch_size, num_beams)))
-                intermediary_running_sequences = intermediary_running_sequences.write(
-                    i, tf.broadcast_to(pad_token_id, (batch_size, num_beams * 2))
-                )
-
-        # write prompt to running_sequences
-        for i in range(cur_len):
-            running_sequences = running_sequences.write(i, input_ids[:, :, i])
+        input_ids_padding = tf.ones((batch_size, num_beams, max_length - cur_len), dtype=tf.int32) * (
+            pad_token_id or 0
+        )
+        running_sequences = tf.concat([input_ids, input_ids_padding], axis=-1)
+        sequences = tf.ones((batch_size, num_beams, max_length), dtype=tf.int32) * (pad_token_id or 0)
 
         # per batch,beam-item state bit indicating if sentence has finished.
         is_sent_finished = tf.zeros((batch_size, num_beams), dtype=tf.bool)
@@ -2656,7 +2603,6 @@ class TFGenerationMixin:
             sequences,
             scores,
             is_sent_finished,
-            input_ids_length,
             model_kwargs,
         ):
             """
@@ -2685,27 +2631,18 @@ class TFGenerationMixin:
             sequences,
             scores,
             is_sent_finished,
-            input_ids_length,
             model_kwargs,
-            intermediary_running_sequences=None,
         ):
             """
             Beam Search iterative update function -- each iteration adds a new token and updates the best sequences
             seen so far
             """
-            # TODO (joao): this loop is probably faster with gather/scatters, instead of using `tf.TensorArray`.
-            # Alternativelly, attempt to rewrite function with permuted axis, when enabling XLA.
-
             # 1. Forward current tokens
-
-            # TF places the dynamic dimension (seq_len) in the first axis, we want it in the last
-            running_sequences_seq_last = tf.transpose(running_sequences.stack(), perm=[1, 2, 0])
-            input_token = tf.slice(
-                running_sequences_seq_last,
-                (0, 0, cur_len - input_ids_length),
-                (batch_size, num_beams, input_ids_length),
-            )
-            model_inputs = self.prepare_inputs_for_generation(flatten_beam_dim(input_token), **model_kwargs)
+            if model_kwargs.get("past") is None or needs_full_input:
+                input_ids = running_sequences[:, :, :cur_len]
+            else:
+                input_ids = tf.expand_dims(running_sequences[:, :, cur_len - 1], -1)
+            model_inputs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), **model_kwargs)
             model_outputs = self(
                 **model_inputs,
                 return_dict=True,
@@ -2734,9 +2671,7 @@ class TFGenerationMixin:
             # get log probabilities from logits, process logits with processors (*e.g.* min_length, ...), and
             # add new logprobs to existing running logprobs scores.
             log_probs = tf.nn.log_softmax(logits)
-            log_probs = logits_processor(
-                flatten_beam_dim(running_sequences_seq_last), flatten_beam_dim(log_probs), cur_len
-            )
+            log_probs = logits_processor(flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), cur_len)
             log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
             log_probs = log_probs + tf.expand_dims(running_scores, axis=2)
             vocab_size = log_probs.shape[2]
@@ -2755,23 +2690,28 @@ class TFGenerationMixin:
             beams_to_keep = 2 * num_beams
             topk_log_probs, topk_indices = tf.math.top_k(log_probs, k=beams_to_keep)
             topk_beam_indices = topk_indices // vocab_size
-            topk_running_sequences_seq_last = gather_beams(running_sequences_seq_last, topk_beam_indices)
+            topk_running_sequences = gather_beams(running_sequences, topk_beam_indices)
             topk_ids = topk_indices % vocab_size
 
             # writes the new token
-            intermediary_running_sequences = intermediary_running_sequences.unstack(
-                tf.transpose(topk_running_sequences_seq_last, perm=[2, 0, 1])
-            )
-            topk_sequences = intermediary_running_sequences.write(cur_len, topk_ids)
-            topk_sequences_seq_last = tf.transpose(topk_sequences.stack(), perm=[1, 2, 0])
+            indices_batch = tf.repeat(tf.range(batch_size), [beams_to_keep])
+            indices_beam = tf.tile(tf.range(beams_to_keep), [batch_size])
+            update_indices = tf.stack(
+                [indices_batch, indices_beam, tf.broadcast_to(cur_len, [batch_size * beams_to_keep])], axis=-1
+            )
+            topk_sequences = tf.tensor_scatter_nd_update(
+                tensor=topk_running_sequences,
+                indices=update_indices,
+                updates=tf.reshape(topk_ids, [batch_size * beams_to_keep]),
+            )
 
             # 4. Check which sequences have ended
             # Update current sequences: Did the top `num_beams` sequences reach an end marker?
             # To prevent these just finished sequences from being added to the current sequences
             # set of active beam search sequences, set their log probs to a very large negative value.
-            eos_in_next_token = topk_sequences_seq_last[:, :, cur_len] == eos_token_id
+            eos_in_next_token = topk_sequences[:, :, cur_len] == eos_token_id
             if eos_token_id is None:
-                eos_in_next_token = tf.broadcast_to(eos_in_next_token, topk_sequences_seq_last[:, :, cur_len].shape)
+                eos_in_next_token = tf.broadcast_to(eos_in_next_token, topk_sequences[:, :, cur_len].shape)
             did_topk_just_finished = eos_in_next_token & tf.broadcast_to(
                 tf.concat((tf.ones((num_beams), dtype=tf.bool), tf.zeros((num_beams), dtype=tf.bool)), axis=0),
                 eos_in_next_token.shape,
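The same scatter trick generalizes to the beam dimension: one `(batch, beam, position)` index triple per candidate, all pointing at column `cur_len`. A toy sketch with made-up sizes (not library code):

```python
import tensorflow as tf

batch_size, beams_to_keep, max_length, cur_len = 2, 3, 5, 2

# Candidate sequences, already padded to `max_length`.
topk_running_sequences = tf.zeros((batch_size, beams_to_keep, max_length), dtype=tf.int32)
topk_ids = tf.constant([[1, 2, 3], [4, 5, 6]], dtype=tf.int32)  # new token per (batch, beam)

indices_batch = tf.repeat(tf.range(batch_size), [beams_to_keep])  # [0 0 0 1 1 1]
indices_beam = tf.tile(tf.range(beams_to_keep), [batch_size])     # [0 1 2 0 1 2]
update_indices = tf.stack(
    [indices_batch, indices_beam, tf.broadcast_to(cur_len, [batch_size * beams_to_keep])], axis=-1
)
topk_sequences = tf.tensor_scatter_nd_update(
    tensor=topk_running_sequences,
    indices=update_indices,
    updates=tf.reshape(topk_ids, [batch_size * beams_to_keep]),
)
print(topk_sequences[:, :, cur_len].numpy())  # [[1 2 3], [4 5 6]]
```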
@@ -2785,8 +2725,8 @@ class TFGenerationMixin:
             # Determine the top k beam indices (from top 2*k beams) from log probs and gather top k beams
             # (from top 2*k beams).
             next_topk_indices = tf.math.top_k(running_topk_log_probs, k=num_beams)[1]
-            next_running_sequences_seq_last, next_running_scores = gather_beams(
-                [topk_sequences_seq_last, running_topk_log_probs], next_topk_indices
+            next_running_sequences, next_running_scores = gather_beams(
+                [topk_sequences, running_topk_log_probs], next_topk_indices
             )
 
             # 6. Process topk logits
@@ -2807,18 +2747,18 @@ class TFGenerationMixin:
             # 7. Get scores, sequences, is sentence finished for next.
             # Combine sequences, scores, and flags along the beam dimension and compare new finished sequence scores
             # to existing finished scores and select the best from the new set of beams
-            sequences_seq_last = tf.transpose(sequences.stack(), perm=[1, 2, 0])
-            merged_sequences = tf.concat([sequences_seq_last, topk_sequences_seq_last], axis=1)
+            merged_sequences = tf.concat([sequences, topk_sequences], axis=1)
             merged_scores = tf.concat([scores, topk_log_probs], axis=1)
             merged_is_sent_finished = tf.concat([is_sent_finished, did_topk_just_finished], axis=1)
             topk_merged_indices = tf.math.top_k(merged_scores, k=num_beams)[1]
-            next_sequences_seq_last, next_scores, next_is_sent_finished = gather_beams(
+            next_sequences, next_scores, next_is_sent_finished = gather_beams(
                 [merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices
             )
 
             # 8. Prepare data for the next iteration
             # Determine the top k beam indices from the original set of all beams. With these, gather the top k
             # beam-associated caches.
+            cur_len = cur_len + 1
             if "past_key_values" in model_outputs:
                 cache = tf.nest.map_structure(
                     lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams, batch_axis=cache_batch_axis),
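For intuition, a toy of step 7: merge the already-finished and the freshly-expanded candidate beams, then keep the best `num_beams` per batch item by score. The `tf.gather(..., batch_dims=1)` call below is only a stand-in for the file's `gather_beams` helper, whose implementation is not part of this diff:

```python
import tensorflow as tf

# 2 batch items, 4 merged candidate beams each, toy sequences of length 3.
merged_scores = tf.constant([[0.1, -0.5, 0.7, 0.2],
                             [-1.0, 0.3, 0.0, -0.2]])
merged_sequences = tf.reshape(tf.range(2 * 4 * 3), (2, 4, 3))  # (batch, candidate_beams, seq_len)

num_beams = 2
# Indices of the best `num_beams` candidates per batch item, by score.
topk_merged_indices = tf.math.top_k(merged_scores, k=num_beams)[1]

# Stand-in for `gather_beams`: gather along the beam axis, independently per batch item.
next_sequences = tf.gather(merged_sequences, topk_merged_indices, batch_dims=1)
next_scores = tf.gather(merged_scores, topk_merged_indices, batch_dims=1)
print(topk_merged_indices.numpy())  # [[2 3], [1 2]]
print(next_scores.numpy())          # [[0.7 0.2], [0.3 0. ]]
```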
@@ -2841,35 +2781,20 @@ class TFGenerationMixin:
 
                 # if we don't cache past key values we need the whole input
                 if model_kwargs.get("past", None) is None:
-                    next_input_ids_length = cur_len + 1
                     # let's throw out `past` since we don't want `None` tensors
                     model_kwargs.pop("past", None)
-                else:
-                    next_input_ids_length = 1
-
-            # 9. Prepare the `tf.TensorArray` for the next iteration
-            next_sequences = sequences.unstack(tf.transpose(next_sequences_seq_last, perm=[2, 0, 1]))
-            next_running_sequences = running_sequences.unstack(
-                tf.transpose(next_running_sequences_seq_last, perm=[2, 0, 1])
-            )
 
             return (
-                cur_len + 1,
+                cur_len,
                 next_running_sequences,
                 next_running_scores,
                 next_sequences,
                 next_scores,
                 next_is_sent_finished,
-                next_input_ids_length,
                 next_model_kwargs,
             )
 
         # 5. run generation
-        # Adds the `intermediary_running_sequences` TensorArray into the body, needed as a scratchpad
-        beam_search_body_fn = partial(
-            beam_search_body_fn, intermediary_running_sequences=intermediary_running_sequences
-        )
-
         # 1st generation step has to be run before to initialize `past` (if active)
         (
             cur_len,
@@ -2878,66 +2803,38 @@ class TFGenerationMixin:
             sequences,
             scores,
             is_sent_finished,
-            input_ids_length,
             model_kwargs,
         ) = beam_search_body_fn(
-            cur_len,
-            running_sequences,
-            running_scores,
-            sequences,
-            scores,
-            is_sent_finished,
-            input_ids_length,
-            model_kwargs,
+            cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs
         )
 
         # 2-to-n generation steps can then be run in autoregressive fashion (only in case 1st generation step does
         # NOT yield EOS token though)
         if beam_search_cond_fn(
-            cur_len,
-            running_sequences,
-            running_scores,
-            sequences,
-            scores,
-            is_sent_finished,
-            input_ids_length,
-            model_kwargs,
+            cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs
         ):
             maximum_iterations = max_length - cur_len
-            cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, _, _ = tf.while_loop(
+            cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, _ = tf.while_loop(
                 beam_search_cond_fn,
                 beam_search_body_fn,
-                (
-                    cur_len,
-                    running_sequences,
-                    running_scores,
-                    sequences,
-                    scores,
-                    is_sent_finished,
-                    input_ids_length,
-                    model_kwargs,
-                ),
+                (cur_len, running_sequences, running_scores, sequences, scores, is_sent_finished, model_kwargs),
                 maximum_iterations=maximum_iterations,
             )
 
         # 6. prepare outputs
-        # convert the sequneces to tf.Tensor with shape (batch_size, num_beams, seq_len)
-        sequences_seq_last = tf.transpose(sequences.stack(), perm=[1, 2, 0])
-        running_sequences_seq_last = tf.transpose(running_sequences.stack(), perm=[1, 2, 0])
-
         # Account for the edge-case where there are no finished sequences for a particular batch item. If so, return
         # running sequences for that batch item.
         none_finished = tf.math.reduce_any(is_sent_finished, axis=1)
-        sequences_seq_last = tf.where(none_finished[:, None, None], sequences_seq_last, running_sequences_seq_last)
+        sequences = tf.where(none_finished[:, None, None], sequences, running_sequences)
         scores = tf.where(none_finished[:, None], scores, running_scores)
 
         # Take best beams for each batch (the score is sorted in ascending order)
-        sequences_seq_last = flatten_beam_dim(sequences_seq_last[:, :num_return_sequences, :])
+        sequences = flatten_beam_dim(sequences[:, :num_return_sequences, :])
         scores = flatten_beam_dim(scores[:, :num_return_sequences])
 
         if not use_xla:
             # Cut for backward compatibility
-            sequences_seq_last = sequences_seq_last[:, :cur_len]
+            sequences = sequences[:, :cur_len]
 
         if return_dict_in_generate:
             if self.config.is_encoder_decoder:
@@ -2948,7 +2845,7 @@ class TFGenerationMixin:
                 )
 
                 return TFBeamSearchEncoderDecoderOutput(
-                    sequences=sequences_seq_last,
+                    sequences=sequences,
                     scores=scores,
                     encoder_attentions=encoder_attentions,
                     encoder_hidden_states=encoder_hidden_states,
@@ -2958,13 +2855,13 @@ class TFGenerationMixin:
                 )
             else:
                 return TFBeamSearchDecoderOnlyOutput(
-                    sequences=sequences_seq_last,
+                    sequences=sequences,
                     scores=scores,
                     attentions=decoder_attentions,
                     hidden_states=decoder_hidden_states,
                 )
         else:
-            return sequences_seq_last
+            return sequences
 
 
 def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
@@ -874,8 +874,9 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
         new_past = [None for _ in range(len(past))]
         slice_start_base = tf.constant([0, 0, 0, 1, 0])
         attention_mask_update_slice = tf.ones((batch_size, 1), dtype=attention_mask.dtype)
-        # correct 5 here
-        new_past_index = current_pos - 1
+        # -1 because current_pos has already been incremented before this function
+        # -1 again because last index = len - 1
+        new_past_index = current_pos - 2
 
         for i in range(len(past)):
             update_slice = past[i][:, :, :, -1:]
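A tiny numeric check of the index fix above, using hypothetical values rather than the model's real state (`current_pos` mimics the caller's already-incremented position counter):

```python
# Hypothetical walk-through of the index arithmetic, not library code.
tokens_processed = 3                 # cache slots 0, 1, 2 are filled
current_pos = tokens_processed + 1   # the caller incremented its counter past the last token
cache_len = current_pos - 1          # -1: undo the caller's increment
new_past_index = cache_len - 1       # -1 again: last valid index = len - 1
assert new_past_index == current_pos - 2 == 2
```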
@ -1202,7 +1202,6 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
|
||||
def prepare_inputs_for_generation(self, inputs, past=None, use_mems=None, **kwargs):
|
||||
# Add dummy token at the end (no attention on this one)
|
||||
|
||||
effective_batch_size = inputs.shape[0]
|
||||
dummy_token = tf.zeros((effective_batch_size, 1), dtype=inputs.dtype)
|
||||
|
||||
@ -1212,12 +1211,12 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
offset = 2
|
||||
|
||||
if past:
|
||||
inputs = tf.concat([inputs[:, -offset:], dummy_token], axis=1)
|
||||
input_ids = tf.concat([inputs[:, -offset:], dummy_token], axis=1)
|
||||
else:
|
||||
inputs = tf.concat([inputs, dummy_token], axis=1)
|
||||
input_ids = tf.concat([inputs, dummy_token], axis=1)
|
||||
|
||||
# Build permutation mask so that previous tokens don't see last token
|
||||
sequence_length = inputs.shape[1]
|
||||
sequence_length = input_ids.shape[1]
|
||||
perm_mask = tf.zeros((effective_batch_size, sequence_length, sequence_length - 1))
|
||||
perm_mask_seq_end = tf.ones((effective_batch_size, sequence_length, 1))
|
||||
perm_mask = tf.concat([perm_mask, perm_mask_seq_end], axis=-1)
|
||||
@ -1228,7 +1227,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
target_mapping = tf.concat([target_mapping, target_mapping_seq_end], axis=-1)
|
||||
|
||||
inputs = {
|
||||
"input_ids": inputs,
|
||||
"input_ids": input_ids,
|
||||
"perm_mask": perm_mask,
|
||||
"target_mapping": target_mapping,
|
||||
"use_mems": use_mems,
|
||||
|