mirror of https://github.com/huggingface/transformers.git
do not mess with the negative sign
commit ca1330f0b2 (parent 10989715d0)
@@ -894,7 +894,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
             # for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times
             if do_sample is False:
                 beam_scores_begin = tf.zeros((batch_size, 1), dtype=tf.float32)
-                beam_scores_end = tf.zeros((batch_size, num_beams - 1), dtype=tf.float32) * 1e-9
+                beam_scores_end = tf.ones((batch_size, num_beams - 1), dtype=tf.float32) * (-1e9)
                 beam_scores = tf.concat([beam_scores_begin, beam_scores_end], -1)
             else:
                 beam_scores = tf.zeros((batch_size, num_beams), dtype=tf.float32)
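This hunk is the fix the commit message refers to: tf.zeros(...) * 1e-9 is a no-op (zero times anything is still zero), so every beam started with an identical score of 0 and greedy beam search could expand the exact same hypothesis num_beams times. Below is a standalone sketch of the difference; the values and shapes are illustrative, not taken from the repo.

import tensorflow as tf

batch_size, num_beams = 2, 3

# Buggy initialization: multiplying zeros by 1e-9 still yields all zeros,
# so all beams tie and the same continuation can be selected num_beams times.
buggy_end = tf.zeros((batch_size, num_beams - 1), dtype=tf.float32) * 1e-9
print(buggy_end.numpy())  # [[0. 0.] [0. 0.]]

# Fixed initialization: a large negative score on beams 1..num_beams-1
# ensures only continuations of the first beam survive the initial top-k.
fixed_end = tf.ones((batch_size, num_beams - 1), dtype=tf.float32) * (-1e9)
print(fixed_end.numpy())  # [[-1.e+09 -1.e+09] [-1.e+09 -1.e+09]]

With the corrected sign, the behavior matches the comment in the hunk: for greedy decoding, only tokens of the first beam are considered at the first step.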
@@ -926,6 +926,21 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
             if temperature != 1.0:
                 next_token_logits = next_token_logits / temperature

+            # calculate log softmax score
+            scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
+
+            # set eos token prob to zero if min_length is not reached
+            if eos_token_ids is not None and cur_len < min_length:
+                # create eos_token_ids boolean mask
+                is_token_logit_eos_token = tf.convert_to_tensor(
+                    [True if token in eos_token_ids else False for token in range(vocab_size)], dtype=tf.bool
+                )
+                eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
+
+                scores = set_tensor_by_indices_to_value(
+                    scores, eos_token_indices_mask, -float("inf")
+                )
+
             if no_repeat_ngram_size > 0:
                 # calculate a list of banned tokens to prevent repetitively generating the same ngrams
                 # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
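This hunk moves the log-softmax computation up so that the min_length EOS mask is applied to scores (log-probabilities) instead of the raw logits. A minimal sketch of the masking step follows; set_tensor_by_indices_to_value is stubbed with tf.where under the assumption that it overwrites masked positions with a constant (the real helper lives elsewhere in the repo).

import tensorflow as tf

def set_tensor_by_indices_to_value(tensor, indices, value):
    # Assumed behavior of the repo helper: where `indices` is True,
    # replace the entry of `tensor` with `value`.
    return tf.where(indices, tf.fill(tf.shape(tensor), value), tensor)

vocab_size, eos_token_ids = 5, [3]
scores = tf.math.log(tf.constant([[0.1, 0.2, 0.3, 0.25, 0.15]]))

# Boolean mask that is True exactly at the EOS token ids.
is_token_logit_eos_token = tf.convert_to_tensor(
    [token in eos_token_ids for token in range(vocab_size)], dtype=tf.bool
)
eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, scores.shape)

# Until min_length is reached, EOS must never win the top-k/argmax,
# so its log-probability is forced to -inf.
masked = set_tensor_by_indices_to_value(scores, eos_token_indices_mask, -float("inf"))
print(masked.numpy())  # the entry at index 3 is now -inf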
@@ -937,24 +952,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                         [True if token in banned_tokens_slice else False for token in range(vocab_size)]
                     )

-                next_token_logits = set_tensor_by_indices_to_value(
-                    next_token_logits, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
+                scores = set_tensor_by_indices_to_value(
+                    scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
                 )

-            # set eos token prob to zero if min_length is not reached
-            if eos_token_ids is not None and cur_len < min_length:
-                # create eos_token_ids boolean mask
-                is_token_logit_eos_token = tf.convert_to_tensor(
-                    [True if token in eos_token_ids else False for token in range(vocab_size)], dtype=tf.bool
-                )
-                eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
-
-                next_token_logits = set_tensor_by_indices_to_value(
-                    next_token_logits, eos_token_indices_mask, -float("inf")
-                )
-
-            # calculate log softmax score
-            scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
             assert shape_list(scores) == [batch_size * num_beams, vocab_size]

             if do_sample:
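One effect of the reordering in the two hunks above: masking tokens before the log-softmax renormalizes probability mass onto the surviving tokens, while masking after it (the new behavior) leaves all other log-probabilities exactly as the model produced them. A small illustration with hypothetical logits:

import tensorflow as tf

logits = tf.constant([[2.0, 1.0, 0.0]])
banned = tf.constant([[False, True, False]])
neg_inf = tf.fill(tf.shape(logits), -float("inf"))

# Mask first, normalize second: the remaining two tokens renormalize to sum to 1.
print(tf.nn.log_softmax(tf.where(banned, neg_inf, logits), axis=-1).numpy())

# Normalize first, mask second (what the diff does now): the surviving
# log-probabilities are unchanged from the full distribution.
print(tf.where(banned, neg_inf, tf.nn.log_softmax(logits, axis=-1)).numpy())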
@@ -991,6 +992,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                 )  # (batch_size, num_beams * vocab_size)

                 next_scores, next_tokens = tf.math.top_k(next_scores, k=2 * num_beams, sorted=True)
+                print(next_tokens)

                 assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams]

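Aside from what looks like a leftover debug print(next_tokens), the context here shows why k is 2 * num_beams: even if up to num_beams of the top candidates end in EOS, there are still enough live continuations to refill every beam. The sketch below decodes the flattened candidate indices; the beam/token split via // and % is the usual beam-search convention, assumed here from the shapes in this diff.

import tensorflow as tf

batch_size, num_beams, vocab_size = 1, 2, 4
next_scores = tf.random.normal((batch_size, num_beams * vocab_size))

# Keep twice as many candidates as beams so EOS-terminated ones can be set aside.
top_scores, top_tokens = tf.math.top_k(next_scores, k=2 * num_beams, sorted=True)

beam_idx = top_tokens // vocab_size  # which beam each candidate extends
token_idx = top_tokens % vocab_size  # which vocabulary token it proposes
print(beam_idx.numpy(), token_idx.numpy())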
@@ -1064,7 +1066,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
             # re-order batch
             input_ids = tf.stack([tf.identity(input_ids[x, :]) for x in beam_idx])
             input_ids = tf.concat([input_ids, tf.expand_dims(beam_tokens, 1)], axis=-1)
-
             # re-order internal states
             if past:
                 past = self._reorder_cache(past, beam_idx)
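For context on the re-ordering in this hunk: after top-k selection, each surviving hypothesis may descend from a different parent beam, so the rows of input_ids are shuffled to follow beam_idx before the chosen tokens are appended. A toy example with hypothetical indices:

import tensorflow as tf

input_ids = tf.constant([[1, 2], [3, 4], [5, 6]])
beam_idx = tf.constant([2, 0, 0])     # hypothetical parent-beam indices
beam_tokens = tf.constant([7, 8, 9])  # hypothetical newly chosen tokens

# Re-order rows to follow their parent beams, then append the new tokens.
input_ids = tf.stack([tf.identity(input_ids[x, :]) for x in beam_idx])
input_ids = tf.concat([input_ids, tf.expand_dims(beam_tokens, 1)], axis=-1)
print(input_ids.numpy())  # [[5 6 7] [1 2 8] [1 2 9]]

tf.gather(input_ids, beam_idx) would be the more idiomatic equivalent of the tf.stack/tf.identity list comprehension used here.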