TF: GPT-2 generation supports left-padding (#17426)
* TF GPT-2 now properly works with left padding
* Throw a warning when the eos token == pad token and there is no attention mask
parent c1a138613d
commit 975dd2bbbc
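
For context, this is the usage pattern the change enables: a minimal sketch mirroring the tests added below (model name, prompts, and max_length are illustrative).

from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

# Minimal sketch of left-padded batch generation with TF GPT-2 (mirrors the tests in this diff).
tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")

# GPT-2 has no pad token, so the eos token is reused and shorter prompts are padded on the left.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

sentences = ["Today is a beautiful day and", "Yesterday was"]
inputs = tokenizer(sentences, return_tensors="tf", padding=True)

# Passing the attention mask (via **inputs) is what makes left padding behave correctly;
# without it, generate() warns and the results may be unexpected.
output_ids = model.generate(**inputs, do_sample=False, max_length=20)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))
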
@@ -1498,8 +1498,14 @@ class TFGenerationMixin:
         )

         if pad_token_id is None and eos_token_id is not None:
+            if attention_mask is None:
+                logger.warning(
+                    "The attention mask and the pad token id were not set. As a consequence, you may observe "
+                    "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
+                )
             logger.warning(f"Setting `pad_token_id` to {eos_token_id} (first `eos_token_id`) to generate sequence")
             pad_token_id = eos_token_id

         if min_length is not None and min_length > max_length:
             raise ValueError(
                 f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum "
@@ -1525,7 +1531,9 @@ class TFGenerationMixin:
         requires_attention_mask = "encoder_outputs" not in model_kwargs

         if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
-            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(input_ids, pad_token_id)
+            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
+                input_ids, pad_token_id, eos_token_id
+            )

         # 4. Prepare model inputs which will be used for auto-regressive generation
         if self.config.is_encoder_decoder:
@@ -1653,12 +1661,17 @@ class TFGenerationMixin:
     def _prepare_attention_mask_for_generation(
         self,
         inputs: tf.Tensor,
-        pad_token_id: int,
+        pad_token_id: Optional[int],
+        eos_token_id: Optional[int],
     ) -> tf.Tensor:
         is_input_ids = len(inputs.shape) == 2 and inputs.dtype in (tf.int32, tf.int64)
         is_pad_token_in_inputs = (pad_token_id is not None) and tf.math.reduce_any(inputs == pad_token_id)
+        is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
+            (eos_token_id is not None) and (pad_token_id != eos_token_id)
+        )

         # Check if input is input_ids and padded -> only then is attention_mask defined
-        if is_input_ids and is_pad_token_in_inputs:
+        if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
             return tf.cast(tf.math.not_equal(inputs, pad_token_id), dtype=tf.int32)
         else:
             return tf.ones(inputs.shape[:2], dtype=tf.int32)
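
As a rough illustration of the mask logic above (a standalone re-implementation written for this note, not part of the diff): when the pad token doubles as the eos token, padding can no longer be inferred from the inputs, so the helper falls back to an all-ones mask; that is why generate() now warns when no attention_mask was passed in that situation.

import tensorflow as tf

def build_mask(input_ids, pad_token_id, eos_token_id):
    # Hypothetical standalone version of the helper above, for illustration only.
    is_input_ids = len(input_ids.shape) == 2 and input_ids.dtype in (tf.int32, tf.int64)
    is_pad_in_inputs = (pad_token_id is not None) and bool(tf.math.reduce_any(input_ids == pad_token_id))
    pad_differs_from_eos = (eos_token_id is None) or (pad_token_id != eos_token_id)
    if is_input_ids and is_pad_in_inputs and pad_differs_from_eos:
        return tf.cast(tf.math.not_equal(input_ids, pad_token_id), tf.int32)
    return tf.ones(input_ids.shape[:2], dtype=tf.int32)

ids = tf.constant([[50256, 50256, 464, 3290], [1212, 318, 257, 1332]], dtype=tf.int32)
print(build_mask(ids, pad_token_id=50256, eos_token_id=50256))  # all ones: padding is ambiguous
print(build_mask(ids, pad_token_id=50256, eos_token_id=None))   # zeros over the left padding
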
@@ -1954,6 +1967,7 @@ class TFGenerationMixin:
         # 1. init greedy_search values
         logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList()

         max_length = max_length if max_length is not None else self.config.max_length
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
         output_scores = output_scores if output_scores is not None else self.config.output_scores
@@ -1973,10 +1987,9 @@ class TFGenerationMixin:
         decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None

         # 3. init tensors to use for "xla-compileable" generate function
-        # define bsz, seq_length
-        batch_size, seq_length = input_ids.shape
+        batch_size, cur_len = input_ids.shape

-        # initialize `generated`, `finished_sequences`, and `current_pos`
+        # initialize `generated` (pre-populated with `pad_token_id`), `finished_sequences`
         generated = tf.TensorArray(
             element_shape=(batch_size,),
             dtype=tf.int32,
@@ -1984,25 +1997,26 @@ class TFGenerationMixin:
             size=max_length,
             clear_after_read=False,
         )
+        if pad_token_id:  # ignores the cases when it is 0 or None
+            for i in range(max_length):
+                generated = generated.write(i, tf.broadcast_to(pad_token_id, (batch_size,)))

         # write prompt to generated
-        for i in range(seq_length):
+        for i in range(cur_len):
             generated = generated.write(i, input_ids[:, i])

         finished_sequences = tf.zeros((batch_size,), dtype=tf.bool)
-        current_pos = tf.ones(shape=(1,), dtype=tf.int32) * seq_length

         # 4. define "xla-compile-able" stop-condition and auto-regressive function
         # define condition fn
-        def greedy_search_cond_fn(generated, finished_sequences, next_tokens, current_pos, model_kwargs):
+        def greedy_search_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
             """state termination condition fn."""
             return ~tf.reduce_all(finished_sequences)

         # define condition fn
-        def greedy_search_body_fn(generated, finished_sequences, next_tokens, current_pos, model_kwargs):
+        def greedy_search_body_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
             """state update fn."""
-            # TODO(pvp, Joao) - `use_xla` can be removed here as soon as `position_ids` are corrected for the non-xla case in gpt2's `prepare_inputs_for_generation`.
-            model_inputs = self.prepare_inputs_for_generation(next_tokens, use_xla=use_xla, **model_kwargs)
+            model_inputs = self.prepare_inputs_for_generation(next_tokens, **model_kwargs)
             # forward pass to get next token logits
             outputs = self(
                 **model_inputs,
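
A simplified sketch of the `generated` buffer handling introduced above (illustrative values, not the actual method): the TensorArray has a fixed `max_length`, is pre-populated with `pad_token_id`, the left-padded prompt is copied in, and each new token is written at `cur_len`, which keeps all shapes static for XLA.

import tensorflow as tf

# Simplified illustration of the TensorArray bookkeeping above; the values are made up.
batch_size, max_length = 2, 6
pad_token_id = 50256
input_ids = tf.constant([[50256, 464, 3290], [1212, 318, 257]], dtype=tf.int32)  # first row is left-padded

generated = tf.TensorArray(
    dtype=tf.int32, size=max_length, clear_after_read=False, element_shape=(batch_size,)
)
if pad_token_id:  # pre-populate with the pad token (skipped when it is 0 or None)
    for i in range(max_length):
        generated = generated.write(i, tf.broadcast_to(pad_token_id, (batch_size,)))

# write the prompt, then one fake "next token" at cur_len
cur_len = input_ids.shape[1]
for i in range(cur_len):
    generated = generated.write(i, input_ids[:, i])
generated = generated.write(cur_len, tf.constant([11, 12], dtype=tf.int32))

output_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size)))
print(output_ids.numpy())  # shape (batch_size, max_length); positions after the new token stay at pad_token_id
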
@@ -2029,13 +2043,8 @@ class TFGenerationMixin:
                 decoder_hidden_states.append(outputs.hidden_states)

             # pre-process distribution
-            # TODO(pvp, joao, matt) - all the logits processors need to be adapted
-            # to be XLA compatible
-            input_ids = None
-            if not use_xla:
-                input_ids = tf.reshape(generated.concat(), (-1, batch_size))
-                input_ids = tf.transpose(input_ids[: current_pos[0]])
-            next_tokens_scores = logits_processor(input_ids, next_token_logits, current_pos[0])
+            input_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size)))
+            next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len)

             # argmax
             next_tokens = tf.argmax(next_tokens_scores, axis=-1, output_type=tf.int32)
@@ -2047,16 +2056,14 @@ class TFGenerationMixin:
             next_tokens = next_tokens * unfinished_seq + pad_token_id * (1 - unfinished_seq)
             finished_sequences = finished_sequences | (next_tokens == eos_token_id)

-            # update `generated` and `current_pos`
-            generated = generated.write(current_pos[0], next_tokens)
+            # update `generated` and `cur_len`
+            generated = generated.write(cur_len, next_tokens)
             next_tokens = next_tokens[:, None]
-            current_pos += 1
+            cur_len += 1

             # update model_kwargs
             if use_xla:
-                model_kwargs = self._update_model_kwargs_for_xla_generation(
-                    outputs, model_kwargs, current_pos, max_length
-                )
+                model_kwargs = self._update_model_kwargs_for_xla_generation(outputs, model_kwargs, cur_len, max_length)
             else:
                 model_kwargs = self._update_model_kwargs_for_generation(
                     outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
@@ -2067,24 +2074,24 @@ class TFGenerationMixin:
                     model_kwargs.pop("past", None)

             next_tokens = tf.reshape(generated.concat(), (-1, batch_size))
-            next_tokens = tf.transpose(next_tokens[: current_pos[0]])
+            next_tokens = tf.transpose(next_tokens[:cur_len])

-            return generated, finished_sequences, next_tokens, current_pos, model_kwargs
+            return generated, finished_sequences, next_tokens, cur_len, model_kwargs

         # 5. run generation
         # 1st generation step has to be run before to initialize `past`
-        generated, finished_sequences, next_tokens, current_pos, model_kwargs = greedy_search_body_fn(
-            generated, finished_sequences, input_ids, current_pos, model_kwargs
+        generated, finished_sequences, next_tokens, cur_len, model_kwargs = greedy_search_body_fn(
+            generated, finished_sequences, input_ids, cur_len, model_kwargs
         )

         # 2-to-n generation steps can then be run in autoregressive fashion
         # only in case 1st generation step does NOT yield EOS token though
-        if greedy_search_cond_fn(generated, finished_sequences, next_tokens, current_pos, model_kwargs):
-            maximum_iterations = max_length - seq_length - 1
-            generated, _, _, current_pos, _ = tf.while_loop(
+        if greedy_search_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
+            maximum_iterations = max_length - cur_len
+            generated, _, _, cur_len, _ = tf.while_loop(
                 greedy_search_cond_fn,
                 greedy_search_body_fn,
-                (generated, finished_sequences, next_tokens, current_pos, model_kwargs),
+                (generated, finished_sequences, next_tokens, cur_len, model_kwargs),
                 maximum_iterations=maximum_iterations,
             )
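
The driver above runs the first generation step eagerly (to initialize `past`) and hands the remaining steps to `tf.while_loop`, bounded by `maximum_iterations = max_length - cur_len` now that `cur_len` counts the full padded prompt. A stripped-down sketch of that control flow, with no model involved (purely illustrative):

import tensorflow as tf

# Stripped-down sketch of the "first step eagerly, rest in tf.while_loop" pattern above.
max_length = 10
cur_len = tf.constant(4)           # length of the (left-padded) prompt
finished = tf.constant(False)

def cond_fn(cur_len, finished):
    return ~finished

def body_fn(cur_len, finished):
    # one fake decoding step: pretend an eos token is generated once cur_len reaches 7
    finished = cur_len + 1 >= 7
    return cur_len + 1, finished

# 1st step outside the loop (in the real code this initializes `past`)
cur_len, finished = body_fn(cur_len, finished)

# remaining steps, bounded so the loop has a static trip count for XLA
if cond_fn(cur_len, finished):
    cur_len, finished = tf.while_loop(
        cond_fn, body_fn, (cur_len, finished), maximum_iterations=max_length - cur_len
    )
print(int(cur_len))  # 7
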
@@ -2093,7 +2100,7 @@ class TFGenerationMixin:

         if not use_xla:
             # cut for backward compatibility
-            output_ids = output_ids[:, : current_pos[0]]
+            output_ids = output_ids[:, :cur_len]

         if return_dict_in_generate:
             if self.config.is_encoder_decoder:
@@ -2231,6 +2238,7 @@ class TFGenerationMixin:
         logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList()
         logits_warper = logits_warper if logits_warper is not None else TFLogitsProcessorList()

         max_length = max_length if max_length is not None else self.config.max_length
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
         output_scores = output_scores if output_scores is not None else self.config.output_scores
@@ -2250,10 +2258,9 @@ class TFGenerationMixin:
         decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None

         # 3. init tensors to use for "xla-compileable" generate function
-        # define bsz, seq_length
         batch_size, cur_len = input_ids.shape

-        # initialize `generated`, `finished_sequences`
+        # initialize `generated` (pre-populated with `pad_token_id`), `finished_sequences`
         generated = tf.TensorArray(
             element_shape=(batch_size,),
             dtype=tf.int32,
@@ -2261,19 +2268,22 @@ class TFGenerationMixin:
             size=max_length,
             clear_after_read=False,
         )
-        finished_sequences = tf.zeros((batch_size,), dtype=tf.bool)
+        if pad_token_id:  # ignores the cases when it is 0 or None
+            for i in range(max_length):
+                generated = generated.write(i, tf.broadcast_to(pad_token_id, (batch_size,)))

+        # write prompt to generated
         for i in range(cur_len):
             generated = generated.write(i, input_ids[:, i])

+        finished_sequences = tf.zeros((batch_size,), dtype=tf.bool)

         # 4. define "xla-compile-able" stop-condition and auto-regressive function
         def sample_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
             return ~tf.reduce_all(finished_sequences)

         def sample_body_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
-            # TODO(pvp, Joao) - `use_xla` can be removed here as soon as `position_ids` are corrected for the non-xla case in gpt2's `prepare_inputs_for_generation`.
-            model_inputs = self.prepare_inputs_for_generation(next_tokens, use_xla=use_xla, **model_kwargs)
+            model_inputs = self.prepare_inputs_for_generation(next_tokens, **model_kwargs)
             # forward pass to get next token logits
             outputs = self(
                 **model_inputs,
@@ -2300,12 +2310,7 @@ class TFGenerationMixin:
                 decoder_hidden_states.append(outputs.hidden_states)

             # pre-process distribution
-            # TODO(pvp, joao, matt) - all the logits processors/wrappers need to be adapted
-            # to be XLA compatible
-            input_ids = None
-            if not use_xla:
-                input_ids = tf.reshape(generated.concat(), (-1, batch_size))
-                input_ids = tf.transpose(input_ids[:cur_len])
+            input_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size)))
             next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len)
             next_tokens_scores = logits_warper(input_ids, next_tokens_scores, cur_len)

@@ -2359,7 +2364,7 @@ class TFGenerationMixin:
         # 2-to-n generation steps can then be run in autoregressive fashion
         # only in case 1st generation step does NOT yield EOS token though
         if sample_cond_fn(generated, finished_sequences, next_tokens, cur_len, model_kwargs):
-            maximum_iterations = max_length - cur_len - 1
+            maximum_iterations = max_length - cur_len
             generated, _, _, cur_len, _ = tf.while_loop(
                 sample_cond_fn,
                 sample_body_fn,
@@ -2613,12 +2618,13 @@ class TFGenerationMixin:
             size=max_length,
             clear_after_read=False,
         )
-        for i in range(max_length):
-            sequences = sequences.write(i, tf.broadcast_to(pad_token_id, (batch_size, num_beams)))
-            running_sequences = running_sequences.write(i, tf.broadcast_to(pad_token_id, (batch_size, num_beams)))
-            intermediary_running_sequences = intermediary_running_sequences.write(
-                i, tf.broadcast_to(pad_token_id, (batch_size, num_beams * 2))
-            )
+        if pad_token_id:  # ignores the cases when it is 0 or None
+            for i in range(max_length):
+                sequences = sequences.write(i, tf.broadcast_to(pad_token_id, (batch_size, num_beams)))
+                running_sequences = running_sequences.write(i, tf.broadcast_to(pad_token_id, (batch_size, num_beams)))
+                intermediary_running_sequences = intermediary_running_sequences.write(
+                    i, tf.broadcast_to(pad_token_id, (batch_size, num_beams * 2))
+                )

         # write prompt to running_sequences
         for i in range(cur_len):
@@ -2699,9 +2705,7 @@ class TFGenerationMixin:
                 (0, 0, cur_len - input_ids_length),
                 (batch_size, num_beams, input_ids_length),
             )
-            model_inputs = self.prepare_inputs_for_generation(
-                flatten_beam_dim(input_token), use_xla=use_xla, **model_kwargs
-            )
+            model_inputs = self.prepare_inputs_for_generation(flatten_beam_dim(input_token), **model_kwargs)
             model_outputs = self(
                 **model_inputs,
                 return_dict=True,
@@ -490,8 +490,8 @@ class GenerationMixin:
     def _prepare_attention_mask_for_generation(
         self,
         inputs: torch.Tensor,
-        pad_token_id: int,
-        eos_token_id: int,
+        pad_token_id: Optional[int],
+        eos_token_id: Optional[int],
     ) -> torch.LongTensor:
         is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
         is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
@@ -1137,7 +1137,11 @@ class GenerationMixin:
             eos_token_id = self.config.decoder.eos_token_id

         if pad_token_id is None and eos_token_id is not None:
-            # special case if pad_token_id is not defined
+            if model_kwargs.get("attention_mask", None) is None:
+                logger.warning(
+                    "The attention mask and the pad token id were not set. As a consequence, you may observe "
+                    "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
+                )
             logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
             pad_token_id = eos_token_id

@@ -813,25 +813,21 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
     def set_output_embeddings(self, value):
         self.set_input_embeddings(value)

-    def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, use_xla=False, **kwargs):
-        # TODO: (Joao) after the TF generator is complete, update GPT2 TF generation to match PT's. NB -- some GPT2
-        # tests will need to be fixed after the change
+    def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, **kwargs):
         token_type_ids = kwargs.get("token_type_ids", None)
         # only last token for inputs_ids if past is defined in kwargs
         if past:
             inputs = tf.expand_dims(inputs[:, -1], -1)
             if token_type_ids is not None:
                 token_type_ids = tf.expand_dims(token_type_ids[:, -1], -1)

-        # TODO(pvp, Joao) - this `if use_xla` statement can be removed, but is left
-        # for a future PR to not change too many things for now.
-        # All statements in this if case apply for both xla and non-xla (as they already do in PyTorch)
-        position_ids = None
-        attention_mask = None
-        if use_xla:
-            attention_mask = kwargs.get("attention_mask", None)
-            if past is not None and attention_mask is not None:
-                position_ids = tf.reduce_sum(attention_mask, axis=1, keepdims=True) - 1
-            elif attention_mask is not None:
-                position_ids = tf.math.cumsum(attention_mask, axis=1, exclusive=True)
+        position_ids = kwargs.get("position_ids", None)
+        attention_mask = kwargs.get("attention_mask", None)
+
+        if attention_mask is not None and position_ids is None:
+            position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)
+            if past:
+                position_ids = tf.expand_dims(position_ids[:, -1], -1)

         return {
             "input_ids": inputs,
@@ -839,6 +835,7 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
             "position_ids": position_ids,
             "past": past,
             "use_cache": use_cache,
             "token_type_ids": token_type_ids,
         }

     def _update_model_kwargs_for_xla_generation(self, outputs, model_kwargs, current_pos, max_length):
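
The core of the GPT-2 side of the fix is deriving `position_ids` from the attention mask with an exclusive cumulative sum, so left padding does not shift the positions of the real tokens. A small sketch of what that computes (illustrative values):

import tensorflow as tf

# Illustration of the position_ids computation above: an exclusive cumsum over the attention mask
# gives position 0 to the first real token, regardless of how much left padding precedes it.
attention_mask = tf.constant([[0, 0, 1, 1, 1],
                              [1, 1, 1, 1, 1]], dtype=tf.int32)
position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)
print(position_ids.numpy())
# [[0 0 0 1 2]
#  [0 1 2 3 4]]

# With a cache (`past`), only the position of the last (newly fed) token is needed.
last_position = tf.expand_dims(position_ids[:, -1], -1)
print(last_position.numpy())  # [[2], [4]]
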
@@ -456,7 +456,7 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
         tokenizer.padding_side = "left"

         sentences = ["Today is a beautiful day and", "Yesterday was"]
-        input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
+        input_ids = tokenizer(sentences, return_tensors="tf", padding=True)

         generation_kwargs = {
             "bad_words_ids": [tokenizer("is").input_ids, tokenizer("angry about").input_ids],
@@ -465,12 +465,12 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
             "repetition_penalty": 1.3,
         }

-        output_ids = model.generate(input_ids, **generation_kwargs)
+        output_ids = model.generate(**input_ids, **generation_kwargs)

         output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
         expected_output_string = [
             "Today is a beautiful day and I am so happy to be able take part in this amazing event.",
-            "Yesterday was a very busy day for the first time since I started writing this post",
+            "Yesterday was a very interesting time for the world to see how much of this is",
         ]
         self.assertListEqual(output_strings, expected_output_string)

@@ -483,7 +483,7 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
         tokenizer.padding_side = "left"

         sentences = ["Today is a beautiful day and", "Yesterday was"]
-        input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
+        input_ids = tokenizer(sentences, return_tensors="tf", padding=True)

         generation_kwargs = {
             "do_sample": True,
@@ -498,13 +498,13 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):

         # forces the generation to happen on CPU, to avoid GPU-related quirks
         with tf.device(":/CPU:0"):
-            output_ids = model.generate(input_ids, **generation_kwargs)
+            output_ids = model.generate(**input_ids, **generation_kwargs)

         output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

         expected_output_string = [
-            "Today is a beautiful day and we will make you feel very hot/terrific in all",
-            "Yesterday was another solid success as news coverage became standard American domestic television hit.",
+            "Today is a beautiful day and we will make you feel very hot/terrific in all your",
+            "Yesterday was known by national television networks as Le Big Show or Wild Dog Jeopard",
         ]
         self.assertListEqual(output_strings, expected_output_string)

@@ -517,7 +517,7 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
         tokenizer.padding_side = "left"

         sentences = ["Today is a beautiful day and", "Yesterday was"]
-        input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
+        input_ids = tokenizer(sentences, return_tensors="tf", padding=True)

         generation_kwargs = {
             "bad_words_ids": [tokenizer("is").input_ids, tokenizer("angry about").input_ids],
@@ -526,37 +526,69 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
             "num_beams": 2,
         }

-        output_ids = model.generate(input_ids, **generation_kwargs)
+        output_ids = model.generate(**input_ids, **generation_kwargs)

         output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
         expected_output_string = [
             "Today is a beautiful day and a great day for all of us.\n\nI’m",
-            "Yesterday was the first day of the year for the second time in a row,",
+            "Yesterday was the first time that a person has been arrested in the United States for",
         ]
         self.assertListEqual(output_strings, expected_output_string)

+    @slow
+    def test_lm_generate_distilgpt2_left_padding(self):
+        """Tests that the generated text is the same, regarless of left padding"""
+        model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")
+        tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
+
+        tokenizer.pad_token = tokenizer.eos_token
+        tokenizer.padding_side = "left"
+
+        generation_kwargs = {
+            "bad_words_ids": [tokenizer("is").input_ids, tokenizer("angry about").input_ids],
+            "no_repeat_ngram_size": 2,
+            "do_sample": False,
+            "repetition_penalty": 1.3,
+        }
+        expected_output_string = (
+            "Today is a beautiful day and I am so happy to be able take part in this amazing event."
+        )
+
+        sentences = ["Today is a beautiful day and"]
+        input_ids = tokenizer(sentences, return_tensors="tf", padding=True)
+        # using default length
+        output_ids = model.generate(**input_ids, **generation_kwargs)
+        output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+        self.assertEqual(output_strings[0], expected_output_string)
+
+        sentences = ["Today is a beautiful day and", "This is a very long input that we absolutely don't care about"]
+        input_ids = tokenizer(sentences, return_tensors="tf", padding=True)
+        # longer max length to capture the full length (remember: it is left padded)
+        output_ids = model.generate(**input_ids, **generation_kwargs, max_length=27)
+        output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+        self.assertEqual(output_strings[0], expected_output_string)
+
     @slow
     def test_lm_generate_gpt2_greedy_xla(self):
-        # TODO (Joao): convert this to an example with a batch size>1 with different input lengths that works (and fix
-        # the underlying problem)
         model = TFGPT2LMHeadModel.from_pretrained("gpt2")
         tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

         tokenizer.pad_token = tokenizer.eos_token
         tokenizer.padding_side = "left"

-        sentences = ["The dog"]
+        sentences = ["The dog", "The flying machine"]
         expected_output_strings = [
-            "The dog was found in a field near the intersection of West and West Streets.\n\nThe dog",
+            "The dog was found in a field near the intersection of West and West Streets.\n\nThe",
+            "The flying machine is a small, lightweight, and lightweight aircraft that can be used for any type of",
         ]
-        input_ids = tokenizer(sentences, return_tensors="tf", padding=True).input_ids
+        input_ids = tokenizer(sentences, return_tensors="tf", padding=True)

-        output_ids = model.generate(input_ids, do_sample=False)
+        output_ids = model.generate(**input_ids, do_sample=False)
         output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
         self.assertListEqual(output_strings, expected_output_strings)

         xla_generate = tf.function(model.generate, jit_compile=True)
-        output_ids = xla_generate(input_ids, do_sample=False)
+        output_ids = xla_generate(**input_ids, do_sample=False)
         output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
         self.assertListEqual(output_strings, expected_output_strings)

@@ -574,21 +606,24 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
         tokenizer.pad_token = tokenizer.eos_token
         tokenizer.padding_side = "left"

-        sentence = ["The dog"]
+        sentence = ["The dog", "The flying machine"]
         expected_output_string = [
             "The dog owner asked why did our vet decide there needed to be extra ventilation inside because most"
-            " puppies"
+            " puppies",
+            "The flying machine was made by an artist who found it difficult to control it as it did not use",
         ]
         expected_output_string_xla = [
-            "The dog has been named in connection with the murder of a 20-year-old man in!"
+            "The dog has been named in connection with the murder of a 20-year-old man in",
+            "The flying machine is a new and improved system to operate and operate a new system and system "
+            "system system",
         ]
-        input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
+        input_ids = tokenizer(sentence, return_tensors="tf", padding=True)

-        output_ids = model.generate(input_ids, do_sample=True, seed=[7, 0])
+        output_ids = model.generate(**input_ids, do_sample=True, seed=[7, 0])
         output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
         self.assertListEqual(output_strings, expected_output_string)

         xla_generate = tf.function(model.generate, jit_compile=True)
-        output_ids = xla_generate(input_ids, do_sample=True, seed=[7, 0])
+        output_ids = xla_generate(**input_ids, do_sample=True, seed=[7, 0])
         output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
         self.assertListEqual(output_strings, expected_output_string_xla)