mirror of https://github.com/huggingface/transformers.git
do not mess with the negative sign
commit ca1330f0b2
parent 10989715d0
@@ -894,7 +894,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         # for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times
         if do_sample is False:
             beam_scores_begin = tf.zeros((batch_size, 1), dtype=tf.float32)
-            beam_scores_end = tf.zeros((batch_size, num_beams - 1), dtype=tf.float32) * 1e-9
+            beam_scores_end = tf.ones((batch_size, num_beams - 1), dtype=tf.float32) * (-1e9)
             beam_scores = tf.concat([beam_scores_begin, beam_scores_end], -1)
         else:
             beam_scores = tf.zeros((batch_size, num_beams), dtype=tf.float32)
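Note: this hunk is what the commit title refers to. `tf.zeros(...) * 1e-9` still produces all zeros, so every beam started with an identical score of 0.0 and the greedy path could pick the exact same token from each beam; `tf.ones(...) * (-1e9)` gives beams 1..num_beams-1 a huge negative score so only continuations of beam 0 survive the first top-k. A minimal standalone sketch of the difference (not the repository's code), assuming TensorFlow 2.x:

    import tensorflow as tf

    batch_size, num_beams = 2, 3

    # old line: multiplying zeros does nothing, every beam starts at 0.0
    old_end = tf.zeros((batch_size, num_beams - 1), dtype=tf.float32) * 1e-9
    print(old_end.numpy())  # [[0. 0.] [0. 0.]]

    # new line: beams 1..num_beams-1 start at -1e9, so after adding per-token
    # log-probs the first top-k pass only keeps continuations of beam 0
    new_end = tf.ones((batch_size, num_beams - 1), dtype=tf.float32) * (-1e9)
    print(new_end.numpy())  # [[-1.e+09 -1.e+09] [-1.e+09 -1.e+09]]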
@@ -926,6 +926,21 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
             if temperature != 1.0:
                 next_token_logits = next_token_logits / temperature

+            # calculate log softmax score
+            scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
+
+            # set eos token prob to zero if min_length is not reached
+            if eos_token_ids is not None and cur_len < min_length:
+                # create eos_token_ids boolean mask
+                is_token_logit_eos_token = tf.convert_to_tensor(
+                    [True if token in eos_token_ids else False for token in range(vocab_size)], dtype=tf.bool
+                )
+                eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
+
+                scores = set_tensor_by_indices_to_value(
+                    scores, eos_token_indices_mask, -float("inf")
+                )
+
             if no_repeat_ngram_size > 0:
                 # calculate a list of banned tokens to prevent repetitively generating the same ngrams
                 # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
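Note: `set_tensor_by_indices_to_value` is a helper from the same module that this diff does not show. Since TF tensors do not support item assignment, it presumably builds a same-shape tensor filled with the target value and selects per element with `tf.where`. A hedged, self-contained sketch of the EOS-masking step on that assumption (the helper body and the toy probabilities are illustrative, not taken from the repository):

    import tensorflow as tf

    def set_tensor_by_indices_to_value(tensor, indices, value):
        # assumed implementation: item assignment is impossible in TF, so
        # fill a value tensor and pick per element via tf.where
        value_tensor = tf.zeros_like(tensor) + value
        return tf.where(indices, value_tensor, tensor)

    vocab_size = 5
    eos_token_ids = [2]  # hypothetical EOS id
    scores = tf.math.log(tf.constant([[0.1, 0.2, 0.4, 0.2, 0.1]]))
    is_token_logit_eos_token = tf.convert_to_tensor(
        [token in eos_token_ids for token in range(vocab_size)], dtype=tf.bool
    )
    eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, scores.shape)
    masked = set_tensor_by_indices_to_value(scores, eos_token_indices_mask, -float("inf"))
    print(masked.numpy())  # position 2 is now -inf, the others are unchanged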
@@ -937,24 +952,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                         [True if token in banned_tokens_slice else False for token in range(vocab_size)]
                     )

-                next_token_logits = set_tensor_by_indices_to_value(
-                    next_token_logits, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
+                scores = set_tensor_by_indices_to_value(
+                    scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
                 )

-            # set eos token prob to zero if min_length is not reached
-            if eos_token_ids is not None and cur_len < min_length:
-                # create eos_token_ids boolean mask
-                is_token_logit_eos_token = tf.convert_to_tensor(
-                    [True if token in eos_token_ids else False for token in range(vocab_size)], dtype=tf.bool
-                )
-                eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
-
-                next_token_logits = set_tensor_by_indices_to_value(
-                    next_token_logits, eos_token_indices_mask, -float("inf")
-                )
-
-            # calculate log softmax score
-            scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
             assert shape_list(scores) == [batch_size * num_beams, vocab_size]

             if do_sample:
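Note: together with the previous hunk, this moves the log-softmax before the constraints, so the banned-ngram and EOS masks now apply to `scores` (log-probabilities) instead of `next_token_logits`. The order matters: masking the logits before the softmax renormalizes probability mass over the surviving tokens, while masking the log-probs afterwards leaves the other tokens' scores untouched. A small standalone comparison with a hypothetical 4-token vocabulary:

    import tensorflow as tf

    logits = tf.constant([[1.0, 2.0, 0.5, 0.0]])
    banned = tf.constant([[False, True, False, False]])  # ban token 1
    neg_inf = tf.fill(tf.shape(logits), -float("inf"))

    # mask the logits first: the softmax renormalizes over surviving tokens
    masked_first = tf.nn.log_softmax(tf.where(banned, neg_inf, logits), axis=-1)
    print(masked_first.numpy())

    # mask the log-probs afterwards (what this commit does): surviving
    # tokens keep their original log-probabilities
    masked_after = tf.where(banned, neg_inf, tf.nn.log_softmax(logits, axis=-1))
    print(masked_after.numpy())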
@@ -991,6 +992,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                 )  # (batch_size, num_beams * vocab_size)

                 next_scores, next_tokens = tf.math.top_k(next_scores, k=2 * num_beams, sorted=True)
+                print(next_tokens)

             assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams]

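Note: the added `print(next_tokens)` reads as leftover debug output. The `k=2 * num_beams` in `top_k` is deliberate, though: keeping twice as many candidates guarantees that at least `num_beams` non-EOS continuations remain even if some of the top candidates end their hypothesis. A sketch of how the flat indices over `num_beams * vocab_size` decode back into (parent beam, token) pairs (shapes and names assumed from the surrounding diff):

    import tensorflow as tf

    batch_size, num_beams, vocab_size = 1, 2, 6
    tf.random.set_seed(0)
    next_scores = tf.random.uniform((batch_size, num_beams * vocab_size))

    # keep 2 * num_beams candidates so enough non-EOS continuations survive
    next_scores, next_tokens = tf.math.top_k(next_scores, k=2 * num_beams, sorted=True)

    # each flat index encodes which beam it came from and which vocab token
    beam_ids = next_tokens // vocab_size
    token_ids = next_tokens % vocab_size
    print(beam_ids.numpy(), token_ids.numpy())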
@@ -1064,7 +1066,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
             # re-order batch
             input_ids = tf.stack([tf.identity(input_ids[x, :]) for x in beam_idx])
             input_ids = tf.concat([input_ids, tf.expand_dims(beam_tokens, 1)], axis=-1)

             # re-order internal states
             if past:
                 past = self._reorder_cache(past, beam_idx)
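Note: `_reorder_cache` is model-specific and not part of this diff. The general idea is that after `top_k`, surviving beams may descend from different parent beams, so every cached key/value tensor has to be permuted along its beam axis with the same `beam_idx` used to reorder `input_ids`. A hypothetical sketch (cache layout assumed: a tuple of tensors whose first axis is batch_size * num_beams):

    import tensorflow as tf

    def reorder_cache_sketch(past, beam_idx):
        # permute every cached tensor along its beam axis with the
        # winning beams' parent indices
        return tuple(tf.gather(layer_past, beam_idx, axis=0) for layer_past in past)

    past = (tf.reshape(tf.range(12, dtype=tf.float32), (4, 3)),)  # 4 beams, toy states
    beam_idx = tf.constant([2, 2, 0, 1])  # beams 0 and 1 both continue from old beam 2
    print(reorder_cache_sketch(past, beam_idx)[0].numpy())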