Mirror of https://github.com/huggingface/transformers.git
Generate: TF can now generate from embeddings in encoder-decoder models (#21475)
This commit is contained in:
parent 12eb528b5a
commit 1e4cf8bb44
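
What this commit enables, as a minimal sketch (it mirrors the new `test_encoder_decoder_generate_with_inputs_embeds` test below; any TF encoder-decoder checkpoint should behave the same way):

from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
model = TFAutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-bart", max_length=5)
model.config.eos_token_id = None  # as in the test, so generation runs until `max_length`

input_ids = tokenizer("Justin Timberlake and Jessica Biel, welcome to parenthood.", return_tensors="tf").input_ids
inputs_embeds = model.get_input_embeddings()(input_ids)

# New in this commit: the embeddings (not `input_ids`) are fed to the encoder.
output_sequences = model.generate(inputs_embeds=inputs_embeds)
print(output_sequences.shape)  # (1, 5), i.e. generated up to `max_length`
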
@@ -664,9 +664,11 @@ class TFGenerationMixin:
             )

         # 4. Define model inputs
-        input_ids = self._prepare_model_inputs(input_ids, generation_config.bos_token_id)
+        inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
+            input_ids, generation_config.bos_token_id, model_kwargs
+        )
         # inputs_ids now has to be defined and cannot be None anymore
-        batch_size = shape_list(input_ids)[0]
+        batch_size = shape_list(inputs_tensor)[0]

         # 5. Prepare other model kwargs
         model_kwargs["output_attentions"] = generation_config.output_attentions
@@ -678,23 +680,26 @@ class TFGenerationMixin:

         if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
             model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
-                input_ids, generation_config.pad_token_id, generation_config.eos_token_id
+                inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
             )

         # decoder-only models should use left-padding for generation
         if not self.config.is_encoder_decoder:
             if generation_config.pad_token_id is not None and tf.math.reduce_any(
-                input_ids[:, -1] == generation_config.pad_token_id
+                inputs_tensor[:, -1] == generation_config.pad_token_id
             ):
                 logger.warning(
                     "A decoder-only architecture is being used, but right-padding was detected! For correct "
                     "generation results, please set `padding_side='left'` when initializing the tokenizer."
                 )
+        if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
+            # if model is encoder decoder encoder_outputs are created and added to `model_kwargs`
+            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
+                inputs_tensor, model_kwargs, model_input_name
+            )

         # 6. Prepare model inputs which will be used for auto-regressive generation
         if self.config.is_encoder_decoder:
-            # if encoder-decoder, we create encoder_outputs and add to `model_kwargs`
-            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
             # if encoder-decoder then `input_ids` come from `decoder_start_token_id`
             input_ids = self._prepare_decoder_input_ids_for_generation(
                 batch_size,
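
A side note on the warning kept above: decoder-only generation expects left padding, so the last column of the batch holds real tokens. A hedged illustration of the tokenizer setup (gpt2 is used only as an example checkpoint):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
tok.pad_token = tok.eos_token  # gpt2 ships without a pad token
batch = tok(["short prompt", "a considerably longer prompt"], padding=True, return_tensors="tf")
# pad tokens now sit on the left, so the right-padding check above does not fire
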
@@ -702,6 +707,9 @@ class TFGenerationMixin:
                 bos_token_id=generation_config.bos_token_id,
                 model_kwargs=model_kwargs,
             )
+        else:
+            # if decoder-only then inputs_tensor has to be `input_ids`
+            input_ids = inputs_tensor

         # 7. Prepare `max_length` depending on other stopping criteria.
         input_ids_seq_length = input_ids.shape[-1]
@@ -924,7 +932,9 @@ class TFGenerationMixin:
         else:
             return tf.ones(inputs.shape[:2], dtype=tf.int32)

-    def _prepare_encoder_decoder_kwargs_for_generation(self, inputs_tensor: tf.Tensor, model_kwargs) -> Dict[str, Any]:
+    def _prepare_encoder_decoder_kwargs_for_generation(
+        self, inputs_tensor: tf.Tensor, model_kwargs, model_input_name: Optional[str] = None
+    ) -> Dict[str, Any]:
         # get encoder and store encoder outputs
         encoder = self.get_encoder()
@@ -938,7 +948,9 @@ class TFGenerationMixin:

         # vision models don't use `attention_mask`.
         encoder_kwargs["return_dict"] = True
-        encoder_kwargs[self.main_input_name] = inputs_tensor
+        encoder_kwargs[model_input_name] = inputs_tensor
+        if model_input_name != self.main_input_name:  # in Keras, the first input must always be passed
+            encoder_kwargs[self.main_input_name] = None
         encoder_outputs = encoder(**encoder_kwargs)
         model_kwargs["encoder_outputs"] = encoder_outputs

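
For readers following the Keras detail above: when generation starts from embeddings, `model_input_name` is "inputs_embeds" while `self.main_input_name` is usually "input_ids", so the encoder call ends up looking roughly like this. This is an illustrative sketch, not the literal call site; the local names `inputs_embeds` and `attention_mask` stand in for the values held in `model_kwargs`:

encoder_kwargs = {
    "return_dict": True,
    "inputs_embeds": inputs_embeds,    # model_input_name carries the actual data
    "input_ids": None,                 # main_input_name must still be passed for Keras
    "attention_mask": attention_mask,  # plus any other surviving model kwargs
}
encoder_outputs = encoder(**encoder_kwargs)
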
@@ -1007,19 +1019,79 @@ class TFGenerationMixin:

         return input_ids, model_kwargs

-    def _prepare_model_inputs(self, inputs: Optional[tf.Tensor] = None, bos_token_id: Optional[int] = None):
-        # TODO(Patrick) - adapt this function when making `generate` more flexible
-        # for all kinds of input types
-        if inputs is None:
-            # if no `inputs` are passed create prompt of size (1,1) filled with BOS token
-            if not isinstance(bos_token_id, int) or bos_token_id < 0:
-                raise ValueError(
-                    "you should either supply a context to complete as `input_ids` input "
-                    "or a `bos_token_id` (integer >= 0) as a first token to start the generation."
-                )
-            return tf.cast(tf.fill((1, 1), bos_token_id), dtype=tf.int32)
-
-        return inputs
+    def _prepare_model_inputs(
+        self,
+        inputs: Optional[tf.Tensor] = None,
+        bos_token_id: Optional[int] = None,
+        model_kwargs: Optional[Dict[str, tf.Tensor]] = None,
+    ) -> Tuple[tf.Tensor, Optional[str], Dict[str, tf.Tensor]]:
+        """
+        This function extracts the model-specific `inputs` for generation.
+        """
+        # 1. retrieve all kwargs that are non-None or non-model input related.
+        # some encoder-decoder models have different names for model and encoder
+        if (
+            self.config.is_encoder_decoder
+            and hasattr(self, "encoder")
+            and hasattr(self.encoder, "main_input_name")
+            and self.encoder.main_input_name != self.main_input_name
+        ):
+            input_name = self.encoder.main_input_name
+        else:
+            input_name = self.main_input_name
+
+        model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name}
+
+        # 2. check whether model_input_name is passed as kwarg
+        # if yes and `inputs` is None use kwarg inputs
+        inputs_kwarg = model_kwargs.pop(input_name, None)
+        if inputs_kwarg is not None and inputs is not None:
+            raise ValueError(
+                f"`inputs`: {inputs}` were passed alongside {input_name} which is not allowed."
+                f"Make sure to either pass {inputs} or {input_name}=..."
+            )
+        elif inputs_kwarg is not None:
+            inputs = inputs_kwarg
+
+        # 3. In the presence of `inputs_embeds` for text models:
+        # - decoder-only models should complain if the user attempts to pass `inputs_embeds`, but the model
+        # doesn't have its forwarding implemented. `inputs_embeds` is kept in `model_kwargs` and can coexist with
+        # input_ids (`inputs_embeds` will be used in the 1st generation step, as opposed to `input_ids`)
+        # - encoder-decoder models should complain if the user attempts to pass `inputs_embeds` and `input_ids`, and
+        # pull the former to inputs. It will be used in place of `input_ids` to get the encoder hidden states.
+        if input_name == "input_ids" and "inputs_embeds" in model_kwargs:
+            if not self.config.is_encoder_decoder:
+                has_inputs_embeds_forwarding = "inputs_embeds" in set(
+                    inspect.signature(self.prepare_inputs_for_generation).parameters.keys()
+                )
+                if not has_inputs_embeds_forwarding:
+                    raise ValueError(
+                        f"You passed `inputs_embeds` to `.generate()`, but the model class {self.__class__.__name__} "
+                        "doesn't have its forwarding implemented. See the GPT2 implementation for an example "
+                        "(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!"
+                    )
+            else:
+                if inputs is not None:
+                    raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")
+                inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
+
+        # 4. if `inputs` is still None, try to create `input_ids` from BOS token
+        if inputs is None:
+            inputs = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))
+
+        return inputs, input_name, model_kwargs
+
+    def _prepare_input_ids_for_generation(
+        self, bos_token_id: Optional[int], encoder_outputs: Optional[ModelOutput]
+    ) -> tf.Tensor:
+        if self.config.is_encoder_decoder and encoder_outputs is not None:
+            # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
+            shape = encoder_outputs.last_hidden_state.size()[:-1]
+            return tf.ones(shape, dtype=tf.int32) * -100
+
+        if bos_token_id is None:
+            raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
+        return tf.ones((1, 1), dtype=tf.int32) * bos_token_id

     @staticmethod
     def _extract_past_from_model_output(outputs: ModelOutput):
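
Taken together, the new `_prepare_model_inputs` gives TF `generate` the following behaviour for encoder-decoder models. These are illustrative calls that reuse the `model` and `inputs_embeds` from the sketch near the top of this page:

# 1) `input_ids` and `inputs_embeds` together are rejected:
#        model.generate(input_ids, inputs_embeds=inputs_embeds)
#        # ValueError: You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.

# 2) `inputs_embeds` alone is pulled out of the kwargs and used as the encoder input:
output_sequences = model.generate(inputs_embeds=inputs_embeds)

# 3) with no inputs at all, a (1, 1) prompt is built from `bos_token_id`:
output_sequences = model.generate()
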
@@ -5,11 +5,13 @@ Framework agnostic tests for generate()-related methods.
 import numpy as np

+from transformers import AutoTokenizer
+from transformers.testing_utils import torch_device


 class GenerationIntegrationTestsMixin:
     # To be populated by the child classes
     framework_dependent_parameters = {
         "AutoModelForCausalLM": None,
         "AutoModelForSeq2SeqLM": None,
         "LogitsProcessorList": None,
         "MinLengthLogitsProcessor": None,
@@ -60,3 +62,91 @@ class GenerationIntegrationTestsMixin:

         bart_model.config.min_length = None
         bart_model.generate(input_ids, logits_processor=logits_processor)
+
+    def test_max_new_tokens_encoder_decoder(self):
+        model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
+        return_tensors = self.framework_dependent_parameters["return_tensors"]
+        is_pt = not model_cls.__name__.startswith("TF")
+
+        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
+        bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
+
+        bart_model = model_cls.from_pretrained("hf-internal-testing/tiny-random-bart")
+        input_ids = bart_tokenizer(article, return_tensors=return_tensors).input_ids
+        if is_pt:
+            bart_model = bart_model.to(torch_device)
+            input_ids = input_ids.to(torch_device)
+
+        self.assertEqual(list(input_ids.shape), [1, 29])
+
+        max_new_tokens = 3
+        bart_model.config.max_length = 20
+        bart_model.config.eos_token_id = None
+
+        # Encoder decoder call
+        outputs = bart_model.generate(input_ids, max_new_tokens=max_new_tokens)
+        # 1 BOS + 3 new tokens
+        self.assertEqual(list(outputs.shape), [1, 4])
+
+        # Decoder only call
+        outputs = bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=max_new_tokens)
+        # 29 + 3 new tokens
+        self.assertEqual(list(outputs.shape), [1, 32])
+
+        # Encoder decoder call > 20
+        outputs = bart_model.generate(max_new_tokens=max_new_tokens + 20)
+
+        # 1 BOS + 20 + 3 new tokens
+        self.assertEqual(list(outputs.shape), [1, 24])
+
+    def test_max_new_tokens_decoder_only(self):
+        model_cls = self.framework_dependent_parameters["AutoModelForCausalLM"]
+        return_tensors = self.framework_dependent_parameters["return_tensors"]
+        is_pt = not model_cls.__name__.startswith("TF")
+
+        article = """Justin Timberlake."""
+        gpt2_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+
+        gpt2_model = model_cls.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        input_ids = gpt2_tokenizer(article, return_tensors=return_tensors).input_ids
+        if is_pt:
+            gpt2_model = gpt2_model.to(torch_device)
+            input_ids = input_ids.to(torch_device)
+
+        self.assertEqual(list(input_ids.shape), [1, 9])
+
+        max_new_tokens = 3
+        gpt2_model.config.max_length = 20
+
+        # call < 20
+        outputs = gpt2_model.generate(input_ids, max_new_tokens=max_new_tokens)
+
+        # 9 input_ids + 3 new tokens
+        self.assertEqual(list(outputs.shape), [1, 12])
+
+        # call > 20
+        outputs = gpt2_model.generate(max_new_tokens=max_new_tokens + 20)
+
+        # 1 BOS token + 23 new tokens
+        self.assertEqual(list(outputs.shape), [1, 24])
+
+    def test_encoder_decoder_generate_with_inputs_embeds(self):
+        model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
+        return_tensors = self.framework_dependent_parameters["return_tensors"]
+        is_pt = not model_cls.__name__.startswith("TF")
+
+        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
+        model = model_cls.from_pretrained("hf-internal-testing/tiny-random-bart", max_length=5)
+        model.config.eos_token_id = None
+        input_ids = tokenizer(article, return_tensors=return_tensors).input_ids
+        if is_pt:
+            model = model.to(torch_device)
+            input_ids = input_ids.to(torch_device)
+
+        inputs_embeds = model.get_input_embeddings()(input_ids)
+
+        output_sequences = model.generate(inputs_embeds=inputs_embeds)
+
+        # make sure model generated correctly until `max_length`
+        self.assertEqual(output_sequences.shape, (1, 5))
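
The mixin above resolves everything framework-specific through `framework_dependent_parameters`: the child classes bind concrete classes to those keys, the tests read `return_tensors` from the same dict when tokenizing, and PyTorch is detected from the class-name prefix (`is_pt = not model_cls.__name__.startswith("TF")`). A rough sketch of the TF side is below; anything beyond the entries visible in the two hunks that follow (notably the `"return_tensors": "tf"` value) is an assumption, not text from this diff:

class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
    if is_tf_available():
        framework_dependent_parameters = {
            "AutoModelForCausalLM": TFAutoModelForCausalLM,
            "AutoModelForSeq2SeqLM": TFAutoModelForSeq2SeqLM,
            "LogitsProcessorList": TFLogitsProcessorList,
            "MinLengthLogitsProcessor": TFMinLengthLogitsProcessor,
            "return_tensors": "tf",  # assumed entry; the mixin reads this key when tokenizing
        }
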
@@ -135,6 +135,7 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
     # setting framework_dependent_parameters needs to be gated, just like its contents' imports
     if is_tf_available():
         framework_dependent_parameters = {
             "AutoModelForCausalLM": TFAutoModelForCausalLM,
+            "AutoModelForSeq2SeqLM": TFAutoModelForSeq2SeqLM,
             "LogitsProcessorList": TFLogitsProcessorList,
             "MinLengthLogitsProcessor": TFMinLengthLogitsProcessor,
@@ -40,7 +40,6 @@ if is_torch_available():
         ImageGPTForCausalImageModeling,
         Speech2TextForConditionalGeneration,
         SpeechEncoderDecoderModel,
-        T5ForConditionalGeneration,
         VisionEncoderDecoderModel,
         top_k_top_p_filtering,
     )
@@ -1792,6 +1791,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
     # setting framework_dependent_parameters needs to be gated, just like its contents' imports
     if is_torch_available():
         framework_dependent_parameters = {
             "AutoModelForCausalLM": AutoModelForCausalLM,
+            "AutoModelForSeq2SeqLM": AutoModelForSeq2SeqLM,
             "LogitsProcessorList": LogitsProcessorList,
             "MinLengthLogitsProcessor": MinLengthLogitsProcessor,
@@ -2094,182 +2094,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
         output = generator(prompt, stop_sequence=" number")
         self.assertEqual(output, [{"generated_text": "Hello I believe in in in number"}])

-    def test_max_new_tokens_encoder_decoder(self):
-        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
-        bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
-        bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
-            torch_device
-        )
-        input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
-
-        self.assertEqual(list(input_ids.shape), [1, 29])
-
-        max_new_tokens = 3
-        bart_model.config.max_length = 20
-        bart_model.config.eos_token_id = None
-
-        # Encoder decoder call
-        outputs = bart_model.generate(input_ids, max_new_tokens=max_new_tokens)
-        # 1 BOS + 3 new tokens
-        self.assertEqual(list(outputs.shape), [1, 4])
-
-        # Decoder only call
-        outputs = bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=max_new_tokens)
-        # 29 + 3 new tokens
-        self.assertEqual(list(outputs.shape), [1, 32])
-
-        # Encoder decoder call > 20
-        outputs = bart_model.generate(max_new_tokens=max_new_tokens + 20)
-
-        # 1 BOS + 20 + 3 new tokens
-        self.assertEqual(list(outputs.shape), [1, 24])
-
-    def test_max_new_tokens_decoder_only_contrastive_search_t5(self):
-        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
-        t5_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
-        t5_model = T5ForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-t5").to(torch_device)
-        input_ids = t5_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
-
-        self.assertEqual(list(input_ids.shape), [1, 56])
-
-        max_new_tokens = 3
-        t5_model.config.max_length = 20
-        t5_model.config.eos_token_id = None
-
-        # Encoder decoder call
-        outputs = t5_model.generate(input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4)
-        # 1 BOS + 3 new tokens
-        self.assertEqual(list(outputs.shape), [1, 4])
-
-        # Decoder only call
-        outputs = t5_model.generate(
-            decoder_input_ids=input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4
-        )
-        # 56 + 3 new tokens
-        self.assertEqual(list(outputs.shape), [1, 59])
-
-        # Encoder decoder call > 20
-        outputs = t5_model.generate(max_new_tokens=max_new_tokens + 20, penalty_alpha=0.6, top_k=4)
-
-        # 1 BOS + 20 + 3 new tokens
-        self.assertEqual(list(outputs.shape), [1, 24])
-
-    def test_max_new_tokens_decoder_only_contrastive_search_bart(self):
-        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
-        bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
-        bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
-            torch_device
-        )
-        input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
-
-        self.assertEqual(list(input_ids.shape), [1, 29])
-
-        max_new_tokens = 3
-        bart_model.config.max_length = 20
-        bart_model.config.eos_token_id = None
-
-        # Encoder decoder call
-        outputs = bart_model.generate(input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4)
-        # 1 BOS + 3 new tokens
-        self.assertEqual(list(outputs.shape), [1, 4])
-
-        # Decoder only call
-        outputs = bart_model.generate(
-            decoder_input_ids=input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4
-        )
-        # 29 + 3 new tokens
-        self.assertEqual(list(outputs.shape), [1, 32])
-
-        # Encoder decoder call > 20
-        outputs = bart_model.generate(max_new_tokens=max_new_tokens + 20, penalty_alpha=0.6, top_k=4)
-
-        # 1 BOS + 20 + 3 new tokens
-        self.assertEqual(list(outputs.shape), [1, 24])
-
-    def test_max_new_tokens_decoder_only_contrastive_search_gptj(self):
-        article = """Justin Timberlake."""
-        gptj_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gptj")
-        gptj_model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gptj").to(torch_device)
-        input_ids = gptj_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
-
-        self.assertEqual(list(input_ids.shape), [1, 9])
-
-        max_new_tokens = 3
-        gptj_model.config.max_length = 20
-
-        # call < 20
-        outputs = gptj_model.generate(input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4)
-
-        # 9 input_ids + 3 new tokens
-        self.assertEqual(list(outputs.shape), [1, 12])
-
-        # call > 20
-        outputs = gptj_model.generate(max_new_tokens=max_new_tokens + 20, penalty_alpha=0.6, top_k=4)
-
-        # 1 BOS token + 23 new tokens
-        self.assertEqual(list(outputs.shape), [1, 24])
-
-    def test_max_new_tokens_decoder_only_contrastive_search_gpt2(self):
-        article = """Justin Timberlake."""
-        gpt2_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
-        gpt2_model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
-        input_ids = gpt2_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
-
-        self.assertEqual(list(input_ids.shape), [1, 9])
-
-        max_new_tokens = 3
-        gpt2_model.config.max_length = 20
-
-        # call < 20
-        outputs = gpt2_model.generate(input_ids, max_new_tokens=max_new_tokens, penalty_alpha=0.6, top_k=4)
-
-        # 9 input_ids + 3 new tokens
-        self.assertEqual(list(outputs.shape), [1, 12])
-
-        # call > 20
-        outputs = gpt2_model.generate(max_new_tokens=max_new_tokens + 20, penalty_alpha=0.6, top_k=4)
-
-        # 1 BOS token + 23 new tokens
-        self.assertEqual(list(outputs.shape), [1, 24])
-
-    def test_max_new_tokens_decoder_only(self):
-        article = """Justin Timberlake."""
-        gpt2_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
-        gpt2_model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
-        input_ids = gpt2_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
-
-        self.assertEqual(list(input_ids.shape), [1, 9])
-
-        max_new_tokens = 3
-        gpt2_model.config.max_length = 20
-
-        # call < 20
-        outputs = gpt2_model.generate(input_ids, max_new_tokens=max_new_tokens)
-
-        # 9 input_ids + 3 new tokens
-        self.assertEqual(list(outputs.shape), [1, 12])
-
-        # call > 20
-        outputs = gpt2_model.generate(max_new_tokens=max_new_tokens + 20)
-
-        # 1 BOS token + 23 new tokens
-        self.assertEqual(list(outputs.shape), [1, 24])
-
-    def test_encoder_decoder_generate_with_inputs_embeds(self):
-        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
-        tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
-        model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart", max_length=5).to(
-            torch_device
-        )
-        model.config.eos_token_id = None
-        input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
-        inputs_embeds = model.get_input_embeddings()(input_ids)
-
-        output_sequences = model.generate(inputs_embeds=inputs_embeds)
-
-        # make sure model generated correctly until `max_length`
-        self.assertEqual(output_sequences.shape, (1, 5))
-
     def test_encoder_decoder_generate_attention_mask(self):
         articles = ["Timberlake", "Jessica Biel, welcome to parenthood among other things"]
         tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")