transformers/tests/models/xglm/test_modeling_tf_xglm.py
Sylvain Gugger 6f79d26442
Update quality tooling for formatting (#21480)
* Result of black 23.1

* Update target to Python 3.7

* Switch flake8 to ruff

* Configure isort

* Configure isort

* Apply isort with line limit

* Put the right black version

* adapt black in check copies

* Fix copies
2023-02-06 18:10:56 -05:00

284 lines
11 KiB
Python

# coding=utf-8
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import XGLMConfig, XGLMTokenizer, is_tf_available
from transformers.testing_utils import require_tf, slow
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_tf_available():
import tensorflow as tf
from transformers.models.xglm.modeling_tf_xglm import (
TF_XGLM_PRETRAINED_MODEL_ARCHIVE_LIST,
TFXGLMForCausalLM,
TFXGLMModel,
)
@require_tf
class TFXGLMModelTester:
    """Helper that builds an ``XGLMConfig`` and matching inputs for the TF XGLM tests.

    The constructor stores the hyper-parameters; the ``prepare_*`` methods turn
    them into a config object plus the tensors the test cases feed to the models.
    """

    config_cls = XGLMConfig
    config_updates = {}
    hidden_act = "gelu"

    def __init__(
        self,
        parent,
        batch_size=14,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_labels=True,
        vocab_size=99,
        d_model=32,
        num_hidden_layers=5,
        num_attention_heads=4,
        ffn_dim=37,
        activation_function="gelu",
        activation_dropout=0.1,
        attention_dropout=0.1,
        max_position_embeddings=512,
        initializer_range=0.02,
    ):
        """Record the owning test case and every hyper-parameter as an attribute."""
        # The test case that owns this tester.
        self.parent = parent

        # Special token ids are fixed rather than configurable.
        self.bos_token_id = 0
        self.pad_token_id = 1
        self.eos_token_id = 2
        self.scope = None

        # Input shape and input-building switches.
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_labels = use_labels

        # Model dimensions; note that ``d_model`` is exposed as ``hidden_size``.
        self.vocab_size = vocab_size
        self.hidden_size = d_model
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.ffn_dim = ffn_dim
        self.max_position_embeddings = max_position_embeddings

        # Activation, dropout and initialisation settings.
        self.activation_function = activation_function
        self.activation_dropout = activation_dropout
        self.attention_dropout = attention_dropout
        self.initializer_range = initializer_range

    def get_large_model_config(self):
        """Return the configuration loaded from the ``facebook/xglm-564M`` checkpoint."""
        return XGLMConfig.from_pretrained("facebook/xglm-564M")

    def prepare_config_and_inputs(self):
        """Return ``(config, input_ids, input_mask, head_mask)``.

        ``input_mask`` is ``None`` when ``use_input_mask`` is false.
        """
        # Random ids are clipped into the closed range [0, 3].
        raw_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
        input_ids = tf.clip_by_value(raw_ids, clip_value_min=0, clip_value_max=3)

        input_mask = random_attention_mask([self.batch_size, self.seq_length]) if self.use_input_mask else None

        config = self.get_config()

        # One head-mask entry per (layer, attention head) pair.
        head_mask = floats_tensor([self.num_hidden_layers, self.num_attention_heads], 2)

        return (config, input_ids, input_mask, head_mask)

    def get_config(self):
        """Build an ``XGLMConfig`` from the hyper-parameters stored on this tester."""
        config_kwargs = {
            "vocab_size": self.vocab_size,
            "d_model": self.hidden_size,
            "num_layers": self.num_hidden_layers,
            "attention_heads": self.num_attention_heads,
            "ffn_dim": self.ffn_dim,
            "activation_function": self.activation_function,
            "activation_dropout": self.activation_dropout,
            "attention_dropout": self.attention_dropout,
            "max_position_embeddings": self.max_position_embeddings,
            "initializer_range": self.initializer_range,
            "use_cache": True,
            "bos_token_id": self.bos_token_id,
            "eos_token_id": self.eos_token_id,
            "pad_token_id": self.pad_token_id,
            "return_dict": True,
        }
        return XGLMConfig(**config_kwargs)

    def prepare_config_and_inputs_for_common(self):
        """Return ``(config, inputs_dict)`` holding only ``input_ids`` and ``head_mask``."""
        # The attention mask is produced by prepare_config_and_inputs but is not forwarded.
        config, input_ids, _input_mask, head_mask = self.prepare_config_and_inputs()
        return config, {"input_ids": input_ids, "head_mask": head_mask}
@require_tf
class TFXGLMModelTest(TFModelTesterMixin, unittest.TestCase):
    """Test case for ``TFXGLMModel`` and ``TFXGLMForCausalLM``.

    Combines the checks inherited from ``TFModelTesterMixin`` with a few
    XGLM-specific tests defined below.
    """

    all_model_classes = (TFXGLMModel, TFXGLMForCausalLM) if is_tf_available() else ()
    all_generative_model_classes = (TFXGLMForCausalLM,) if is_tf_available() else ()
    test_onnx = False
    test_missing_keys = False
    test_pruning = False

    def setUp(self):
        """Create the model tester and the config tester used by the tests."""
        self.model_tester = TFXGLMModelTester(self)
        self.config_tester = ConfigTester(self, config_class=XGLMConfig, n_embd=37)

    def test_config(self):
        """Run the common configuration tests for ``XGLMConfig``."""
        self.config_tester.run_common_tests()

    def test_model_common_attributes(self):
        """Check the embedding/bias accessors of every model class.

        Input embeddings must be a Keras layer; output embeddings must be a
        Keras layer for generative classes and ``None`` otherwise; ``get_bias``
        must return ``None`` for all classes.
        """
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)

            # unittest assertions are used instead of bare ``assert`` so the checks
            # survive ``python -O`` and report informative failure messages.
            self.assertIsInstance(model.get_input_embeddings(), tf.keras.layers.Layer)

            output_embeddings = model.get_output_embeddings()
            if model_class in self.all_generative_model_classes:
                self.assertIsInstance(output_embeddings, tf.keras.layers.Layer)
            else:
                self.assertIsNone(output_embeddings)

            self.assertIsNone(model.get_bias())

    @slow
    def test_batch_generation(self):
        """Left-padded batched generation must match generating each sentence alone."""
        model = TFXGLMForCausalLM.from_pretrained("facebook/xglm-564M")
        tokenizer = XGLMTokenizer.from_pretrained("facebook/xglm-564M")

        tokenizer.padding_side = "left"

        # use different length sentences to test batching
        sentences = [
            "Hello, my dog is a little",
            "Today, I",
        ]

        inputs = tokenizer(sentences, return_tensors="tf", padding=True)
        outputs = model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])

        # The first sentence is the longer one, so it needs no padding in the batch.
        inputs_non_padded = tokenizer(sentences[0], return_tensors="tf").input_ids
        output_non_padded = model.generate(input_ids=inputs_non_padded)

        # Number of pad tokens added to the last sentence: length of the longest
        # sentence minus the number of attended positions in the last row.
        num_paddings = (
            inputs_non_padded.shape[-1]
            - tf.math.reduce_sum(tf.cast(inputs["attention_mask"][-1], dtype=tf.int64)).numpy()
        )

        # Shorten max_length by the padding count for the unpadded single-sentence run.
        inputs_padded = tokenizer(sentences[1], return_tensors="tf").input_ids
        output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)

        batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
        padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)

        expected_output_sentence = [
            "Hello, my dog is a little bit of a shy one, but he is very friendly",
            "Today, I am going to share with you a few of my favorite things",
        ]
        self.assertListEqual(expected_output_sentence, batch_out_sentence)
        self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence])

    @slow
    def test_model_from_pretrained(self):
        """The first checkpoint in the archive list must load successfully."""
        for model_name in TF_XGLM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = TFXGLMModel.from_pretrained(model_name)
            self.assertIsNotNone(model)

    @unittest.skip(reason="Currently, model embeddings are going to undergo a major refactor.")
    def test_resize_token_embeddings(self):
        """Skipped; would otherwise delegate to the inherited test."""
        super().test_resize_token_embeddings()
@require_tf
class TFXGLMModelLanguageGenerationTest(unittest.TestCase):
    """Slow generation tests that run against the ``facebook/xglm-564M`` checkpoint."""

    @slow
    def test_lm_generate_xglm(self, verify_outputs=True):
        """Generate without sampling, with a single beam, and compare to reference ids."""
        model = TFXGLMForCausalLM.from_pretrained("facebook/xglm-564M")
        prompt_ids = tf.convert_to_tensor([[2, 268, 9865]], dtype=tf.int32)  # The dog

        # </s> The dog is a very friendly dog. He is very affectionate and loves to play with other
        # fmt: off
        reference_ids = [2, 268, 9865, 67, 11, 1988, 57252, 9865, 5, 984, 67, 1988, 213838, 1658, 53, 70446, 33, 6657, 278, 1581]
        # fmt: on

        generated_ids = model.generate(prompt_ids, do_sample=False, num_beams=1)
        if not verify_outputs:
            return
        self.assertListEqual(generated_ids[0].numpy().tolist(), reference_ids)

    @slow
    def test_xglm_sample(self):
        """Seeded sampling must reproduce the reference continuation."""
        tokenizer = XGLMTokenizer.from_pretrained("facebook/xglm-564M")
        model = TFXGLMForCausalLM.from_pretrained("facebook/xglm-564M")

        tf.random.set_seed(0)
        prompt_ids = tokenizer("Today is a nice day and", return_tensors="tf").input_ids
        sampled_ids = model.generate(prompt_ids, do_sample=True, seed=[7, 0])
        sampled_text = tokenizer.decode(sampled_ids[0], skip_special_tokens=True)

        reference_text = (
            "Today is a nice day and warm evening here over Southern Alberta!! Today when they closed schools due"
        )
        self.assertEqual(sampled_text, reference_text)

    @slow
    def test_lm_generate_xglm_left_padding(self):
        """Tests that the generated text is the same, regardless of left padding"""
        tokenizer = XGLMTokenizer.from_pretrained("facebook/xglm-564M")
        model = TFXGLMForCausalLM.from_pretrained("facebook/xglm-564M")
        tokenizer.padding_side = "left"

        generation_kwargs = {
            "bad_words_ids": [tokenizer("is").input_ids, tokenizer("angry about").input_ids],
            "no_repeat_ngram_size": 2,
            "do_sample": False,
            "repetition_penalty": 1.3,
        }

        reference_text = "Today is a beautiful day and I am so glad that we have the opportunity to spend time with"

        # Each case pairs the input sentences with the extra keyword arguments for generate().
        cases = [
            # using default length
            (["Today is a beautiful day and"], {}),
            # longer max length to capture the full length (remember: it is left padded)
            (
                ["Today is a beautiful day and", "This is a very long input that we absolutely don't care about"],
                {"max_length": 28},
            ),
        ]

        for sentences, extra_kwargs in cases:
            encoded = tokenizer(sentences, return_tensors="tf", padding=True)
            generated_ids = model.generate(**encoded, **generation_kwargs, **extra_kwargs)
            decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            # Only the first sentence is checked; its continuation must not depend on padding.
            self.assertEqual(decoded[0], reference_text)