do not mess with the negative sign

This commit is contained in:
Patrick von Platen 2020-03-10 14:22:16 +01:00
parent 10989715d0
commit ca1330f0b2

View File

@ -894,7 +894,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
# for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times
if do_sample is False:
beam_scores_begin = tf.zeros((batch_size, 1), dtype=tf.float32)
beam_scores_end = tf.zeros((batch_size, num_beams - 1), dtype=tf.float32) * 1e-9
beam_scores_end = tf.ones((batch_size, num_beams - 1), dtype=tf.float32) * (-1e9)
beam_scores = tf.concat([beam_scores_begin, beam_scores_end], -1)
else:
beam_scores = tf.zeros((batch_size, num_beams), dtype=tf.float32)
@ -926,6 +926,21 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
if temperature != 1.0:
next_token_logits = next_token_logits / temperature
# calculate log softmax score
scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
# set eos token prob to zero if min_length is not reached
if eos_token_ids is not None and cur_len < min_length:
# create eos_token_ids boolean mask
is_token_logit_eos_token = tf.convert_to_tensor(
[True if token in eos_token_ids else False for token in range(vocab_size)], dtype=tf.bool
)
eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
scores = set_tensor_by_indices_to_value(
scores, eos_token_indices_mask, -float("inf")
)
if no_repeat_ngram_size > 0:
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
@ -937,24 +952,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
[True if token in banned_tokens_slice else False for token in range(vocab_size)]
)
next_token_logits = set_tensor_by_indices_to_value(
next_token_logits, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
scores = set_tensor_by_indices_to_value(
scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
)
# set eos token prob to zero if min_length is not reached
if eos_token_ids is not None and cur_len < min_length:
# create eos_token_ids boolean mask
is_token_logit_eos_token = tf.convert_to_tensor(
[True if token in eos_token_ids else False for token in range(vocab_size)], dtype=tf.bool
)
eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
next_token_logits = set_tensor_by_indices_to_value(
next_token_logits, eos_token_indices_mask, -float("inf")
)
# calculate log softmax score
scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
assert shape_list(scores) == [batch_size * num_beams, vocab_size]
if do_sample:
@ -991,6 +992,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
) # (batch_size, num_beams * vocab_size)
next_scores, next_tokens = tf.math.top_k(next_scores, k=2 * num_beams, sorted=True)
print(next_tokens)
assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams]
@ -1064,7 +1066,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
# re-order batch
input_ids = tf.stack([tf.identity(input_ids[x, :]) for x in beam_idx])
input_ids = tf.concat([input_ids, tf.expand_dims(beam_tokens, 1)], axis=-1)
# re-order internal states
if past:
past = self._reorder_cache(past, beam_idx)