From fb2b45e5627c99a8d65c6cb42cad70bfef7b87be Mon Sep 17 00:00:00 2001
From: Pi Esposito
Date: Tue, 29 Nov 2022 12:02:40 -0300
Subject: [PATCH] add in layer gpt2 tokenizer (#20421)

* add minimal working gpt2 tokenizer
* graph mode and output equivalence tests working
* not today tensorflow. serialization test passing!
* fix style, documentation, docstrings and all that jazz
* passing consistency checks
* move keras nlp to tf dependencies
* fix tf modeling utils and gpt2 attention to enable compiling
* fix (I hope) keras nlp dependencies
* revert changes on generation
* remove debug prints
* remove redundant tf dummy objects
* add from config, get config and max length settings to address review
* let flake ignore the error on distillation you are welcome
* test from config
* add padding test
* address sgugger review
---
 docs/source/en/model_doc/gpt2.mdx              |   4 +
 .../seq2seq-distillation/distillation.py       |   2 +-
 setup.py                                       |   6 +-
 src/transformers/__init__.py                   |  22 +++
 src/transformers/dependency_versions_table.py  |   1 +
 src/transformers/models/gpt2/__init__.py       |  17 +++
 .../models/gpt2/tokenization_gpt2_tf.py        | 104 ++++++++++++++
 src/transformers/utils/__init__.py             |   1 +
 .../utils/dummy_keras_nlp_objects.py           |  10 ++
 src/transformers/utils/import_utils.py         |   4 +
 .../models/gpt2/test_tokenization_gpt2_tf.py   | 130 ++++++++++++++++++
 11 files changed, 297 insertions(+), 4 deletions(-)
 create mode 100644 src/transformers/models/gpt2/tokenization_gpt2_tf.py
 create mode 100644 src/transformers/utils/dummy_keras_nlp_objects.py
 create mode 100644 tests/models/gpt2/test_tokenization_gpt2_tf.py

diff --git a/docs/source/en/model_doc/gpt2.mdx b/docs/source/en/model_doc/gpt2.mdx
index 475b0801365..caa23c337f6 100644
--- a/docs/source/en/model_doc/gpt2.mdx
+++ b/docs/source/en/model_doc/gpt2.mdx
@@ -138,6 +138,10 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
 
 [[autodoc]] modeling_tf_outputs.TFSequenceClassifierOutputWithPast
 
+## TFGPT2Tokenizer
+
+[[autodoc]] TFGPT2Tokenizer
+
 ## FlaxGPT2Model
 
 [[autodoc]] FlaxGPT2Model
diff --git a/examples/research_projects/seq2seq-distillation/distillation.py b/examples/research_projects/seq2seq-distillation/distillation.py
index 5a403be8d56..78ff49718bb 100755
--- a/examples/research_projects/seq2seq-distillation/distillation.py
+++ b/examples/research_projects/seq2seq-distillation/distillation.py
@@ -5,7 +5,7 @@ import gc
 import os
 import sys
 from pathlib import Path
-from typing import List
+from typing import List  # noqa: F401
 
 import pytorch_lightning as pl
 import torch
diff --git a/setup.py b/setup.py
index ef533058b7b..a78075875e8 100644
--- a/setup.py
+++ b/setup.py
@@ -124,6 +124,7 @@ _deps = [
     "jaxlib>=0.1.65,<=0.3.6",
     "jieba",
     "kenlm",
+    "keras-nlp>=0.3.1",
     "nltk",
     "natten>=0.14.4",
     "numpy>=1.17",
@@ -241,14 +242,13 @@ class DepsTableUpdateCommand(Command):
         with open(target, "w", encoding="utf-8", newline="\n") as f:
             f.write("\n".join(content))
 
-
 extras = {}
 
 extras["ja"] = deps_list("fugashi", "ipadic", "unidic_lite", "unidic", "sudachipy", "sudachidict_core", "pyknp")
 extras["sklearn"] = deps_list("scikit-learn")
 
-extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "tf2onnx", "tensorflow-text")
-extras["tf-cpu"] = deps_list("tensorflow-cpu", "onnxconverter-common", "tf2onnx", "tensorflow-text")
+extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "tf2onnx", "tensorflow-text", "keras-nlp")
+extras["tf-cpu"] = deps_list("tensorflow-cpu", "onnxconverter-common", "tf2onnx", "tensorflow-text", "keras-nlp")
 
 extras["torch"] = deps_list("torch")
 extras["accelerate"] = deps_list("accelerate")
= deps_list("torch") extras["accelerate"] = deps_list("accelerate") diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index f38918cf24b..75c467fc936 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -32,6 +32,7 @@ from .utils import ( OptionalDependencyNotAvailable, _LazyModule, is_flax_available, + is_keras_nlp_available, is_sentencepiece_available, is_speech_available, is_tensorflow_text_available, @@ -694,6 +695,19 @@ except OptionalDependencyNotAvailable: else: _import_structure["models.bert"].append("TFBertTokenizer") +# keras-nlp-specific objects +try: + if not is_keras_nlp_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils import dummy_keras_nlp_objects + + _import_structure["utils.dummy_keras_nlp_objects"] = [ + name for name in dir(dummy_keras_nlp_objects) if not name.startswith("_") + ] +else: + _import_structure["models.gpt2"].append("TFGPT2Tokenizer") + try: if not (is_sentencepiece_available() and is_speech_available()): raise OptionalDependencyNotAvailable() @@ -3828,6 +3842,14 @@ if TYPE_CHECKING: else: from .models.bert import TFBertTokenizer + try: + if not is_keras_nlp_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_keras_nlp_objects import * + else: + from .models.gpt2 import TFGPT2Tokenizer + try: if not (is_speech_available() and is_sentencepiece_available()): raise OptionalDependencyNotAvailable() diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index b5fd55d1a75..5ee9317270f 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -30,6 +30,7 @@ deps = { "jaxlib": "jaxlib>=0.1.65,<=0.3.6", "jieba": "jieba", "kenlm": "kenlm", + "keras-nlp": "keras-nlp>=0.3.1", "nltk": "nltk", "natten": "natten>=0.14.4", "numpy": "numpy>=1.17", diff --git a/src/transformers/models/gpt2/__init__.py b/src/transformers/models/gpt2/__init__.py index 477f0cc8d8b..e934602496f 100644 --- a/src/transformers/models/gpt2/__init__.py +++ b/src/transformers/models/gpt2/__init__.py @@ -22,6 +22,7 @@ from ...utils import ( OptionalDependencyNotAvailable, _LazyModule, is_flax_available, + is_keras_nlp_available, is_tf_available, is_tokenizers_available, is_torch_available, @@ -74,6 +75,14 @@ else: "TFGPT2PreTrainedModel", ] +try: + if not is_keras_nlp_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["tokenization_gpt2_tf"] = ["TFGPT2Tokenizer"] + try: if not is_flax_available(): raise OptionalDependencyNotAvailable() @@ -127,6 +136,14 @@ if TYPE_CHECKING: TFGPT2PreTrainedModel, ) + try: + if not is_keras_nlp_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .tokenization_gpt2_tf import TFGPT2Tokenizer + try: if not is_flax_available(): raise OptionalDependencyNotAvailable() diff --git a/src/transformers/models/gpt2/tokenization_gpt2_tf.py b/src/transformers/models/gpt2/tokenization_gpt2_tf.py new file mode 100644 index 00000000000..ba6f754373c --- /dev/null +++ b/src/transformers/models/gpt2/tokenization_gpt2_tf.py @@ -0,0 +1,104 @@ +import os +from typing import Dict, List, Union + +import tensorflow as tf + +from keras_nlp.tokenizers import BytePairTokenizer +from tensorflow_text import pad_model_inputs + +from .tokenization_gpt2 import GPT2Tokenizer + + +class 
diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py
index b5fd55d1a75..5ee9317270f 100644
--- a/src/transformers/dependency_versions_table.py
+++ b/src/transformers/dependency_versions_table.py
@@ -30,6 +30,7 @@ deps = {
     "jaxlib": "jaxlib>=0.1.65,<=0.3.6",
     "jieba": "jieba",
     "kenlm": "kenlm",
+    "keras-nlp": "keras-nlp>=0.3.1",
     "nltk": "nltk",
     "natten": "natten>=0.14.4",
     "numpy": "numpy>=1.17",
diff --git a/src/transformers/models/gpt2/__init__.py b/src/transformers/models/gpt2/__init__.py
index 477f0cc8d8b..e934602496f 100644
--- a/src/transformers/models/gpt2/__init__.py
+++ b/src/transformers/models/gpt2/__init__.py
@@ -22,6 +22,7 @@ from ...utils import (
     OptionalDependencyNotAvailable,
     _LazyModule,
     is_flax_available,
+    is_keras_nlp_available,
     is_tf_available,
     is_tokenizers_available,
     is_torch_available,
@@ -74,6 +75,14 @@ else:
         "TFGPT2PreTrainedModel",
     ]
 
+try:
+    if not is_keras_nlp_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    pass
+else:
+    _import_structure["tokenization_gpt2_tf"] = ["TFGPT2Tokenizer"]
+
 try:
     if not is_flax_available():
         raise OptionalDependencyNotAvailable()
@@ -127,6 +136,14 @@ if TYPE_CHECKING:
             TFGPT2PreTrainedModel,
         )
 
+    try:
+        if not is_keras_nlp_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        pass
+    else:
+        from .tokenization_gpt2_tf import TFGPT2Tokenizer
+
     try:
         if not is_flax_available():
             raise OptionalDependencyNotAvailable()
diff --git a/src/transformers/models/gpt2/tokenization_gpt2_tf.py b/src/transformers/models/gpt2/tokenization_gpt2_tf.py
new file mode 100644
index 00000000000..ba6f754373c
--- /dev/null
+++ b/src/transformers/models/gpt2/tokenization_gpt2_tf.py
@@ -0,0 +1,104 @@
+import os
+from typing import Dict, List, Union
+
+import tensorflow as tf
+
+from keras_nlp.tokenizers import BytePairTokenizer
+from tensorflow_text import pad_model_inputs
+
+from .tokenization_gpt2 import GPT2Tokenizer
+
+
+class TFGPT2Tokenizer(tf.keras.layers.Layer):
+    """
+    This is an in-graph tokenizer for GPT2. It should be initialized similarly to other tokenizers, using the
+    `from_pretrained()` method. It can also be initialized with the `from_tokenizer()` method, which imports settings
+    from an existing standard tokenizer object.
+
+    In-graph tokenizers, unlike other Hugging Face tokenizers, are actually Keras layers and are designed to be run
+    when the model is called, rather than during preprocessing. As a result, they have somewhat more limited options
+    than standard tokenizer classes. They are most useful when you want to create an end-to-end model that goes
+    straight from `tf.string` inputs to outputs.
+
+    Args:
+        vocab (Dict[str, int]): Vocabulary dict for Byte Pair Tokenizer
+        merges (List[str]): Merges list for Byte Pair Tokenizer
+    """
+
+    def __init__(self, vocab: Dict[str, int], merges: List[str], max_length: int = None, pad_token_id: int = None):
+        super().__init__()
+        self.pad_token_id = pad_token_id
+        self.max_length = max_length
+        self.vocab = vocab
+        self.merges = merges
+        self.tf_tokenizer = BytePairTokenizer(vocab, merges, sequence_length=max_length)
+
+    @classmethod
+    def from_tokenizer(cls, tokenizer: GPT2Tokenizer, *args, **kwargs):
+        """Creates TFGPT2Tokenizer from GPT2Tokenizer
+
+        Args:
+            tokenizer (GPT2Tokenizer)
+
+        Examples:
+
+        ```python
+        from transformers import AutoTokenizer, TFGPT2Tokenizer
+
+        tokenizer = AutoTokenizer.from_pretrained("gpt2")
+        tf_tokenizer = TFGPT2Tokenizer.from_tokenizer(tokenizer)
+        ```
+        """
+        merges = [" ".join(m) for m in tokenizer.bpe_ranks.keys()]
+        vocab = tokenizer.get_vocab()
+        return cls(vocab, merges, *args, **kwargs)
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], *init_inputs, **kwargs):
+        """Creates TFGPT2Tokenizer from pretrained GPT2Tokenizer
+
+        Args:
+            pretrained_model_name_or_path (Union[str, os.PathLike]): Path to pretrained model
+
+        Examples:
+
+        ```python
+        from transformers import TFGPT2Tokenizer
+
+        tf_tokenizer = TFGPT2Tokenizer.from_pretrained("gpt2")
+        ```
+        """
+        tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs)
+        return cls.from_tokenizer(tokenizer, *init_inputs, **kwargs)
+
+    @classmethod
+    def from_config(cls, config):
+        """Creates TFGPT2Tokenizer from configurations
+
+        Args:
+            config (Dict): Dictionary with keys such as those returned by `get_config`.
+        """
+        return cls(**config)
+
+    def get_config(self):
+        return {
+            "vocab": self.vocab,
+            "merges": self.merges,
+            "max_length": self.max_length,
+            "pad_token_id": self.pad_token_id,
+        }
+
+    def call(self, x, max_length: int = None):
+        input_ids = self.tf_tokenizer(x)
+        attention_mask = tf.ones_like(input_ids)
+
+        if self.pad_token_id is not None:
+            # pad the tokens up to max length
+            max_length = max_length if max_length is not None else self.max_length
+
+            if max_length is not None:
+                input_ids, attention_mask = pad_model_inputs(
+                    input_ids, max_seq_length=max_length, pad_value=self.pad_token_id
+                )
+
+        return {"attention_mask": attention_mask, "input_ids": input_ids}
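The class above is the core of the patch: because `TFGPT2Tokenizer` is a Keras layer, tokenization runs inside the TF graph, so a serving function can accept raw `tf.string` inputs. A minimal sketch of that end-to-end pattern, assuming the `gpt2` checkpoint and reusing the EOS token as a pad token (both choices are illustrative, not part of this file):

```python
import tensorflow as tf

from transformers import GPT2Tokenizer, TFGPT2LMHeadModel, TFGPT2Tokenizer

# Build the in-graph tokenizer from the regular GPT2 tokenizer. GPT-2 ships
# without a pad token, so the EOS token id is reused here purely for padding.
hf_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tf_tokenizer = TFGPT2Tokenizer.from_tokenizer(hf_tokenizer, pad_token_id=hf_tokenizer.eos_token_id)
model = TFGPT2LMHeadModel.from_pretrained("gpt2")


@tf.function(input_signature=(tf.TensorSpec((None,), tf.string),))
def serve(text):
    # Tokenization happens inside the graph: raw strings in, logits out.
    tokenized = tf_tokenizer(text, max_length=64)
    return model(input_ids=tokenized["input_ids"], attention_mask=tokenized["attention_mask"])["logits"]


logits = serve(tf.constant(["In-graph tokenization needs no Python preprocessing step."]))
```

The same idea, wrapped in a `tf.Module`, is what the `ModelToSave` helper in the new test file below exports with `tf.saved_model.save`.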
+ """ + return cls(**config) + + def get_config(self): + return { + "vocab": self.vocab, + "merges": self.merges, + "max_length": self.max_length, + "pad_token_id": self.pad_token_id, + } + + def call(self, x, max_length: int = None): + input_ids = self.tf_tokenizer(x) + attention_mask = tf.ones_like(input_ids) + + if self.pad_token_id is not None: + # pad the tokens up to max length + max_length = max_length if max_length is not None else self.max_length + + if max_length is not None: + input_ids, attention_mask = pad_model_inputs( + input_ids, max_seq_length=max_length, pad_value=self.pad_token_id + ) + + return {"attention_mask": attention_mask, "input_ids": input_ids} diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 0c0c0190567..2ca5bf96f71 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -111,6 +111,7 @@ from .import_utils import ( is_ipex_available, is_jumanpp_available, is_kenlm_available, + is_keras_nlp_available, is_librosa_available, is_more_itertools_available, is_natten_available, diff --git a/src/transformers/utils/dummy_keras_nlp_objects.py b/src/transformers/utils/dummy_keras_nlp_objects.py new file mode 100644 index 00000000000..6d9a466d29e --- /dev/null +++ b/src/transformers/utils/dummy_keras_nlp_objects.py @@ -0,0 +1,10 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +# flake8: noqa +from ..utils import DummyObject, requires_backends + + +class TFGPT2Tokenizer(metaclass=DummyObject): + _backends = ["keras_nlp"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["keras_nlp"]) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 81faa6ea853..bed1ec233cf 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -572,6 +572,10 @@ def is_tensorflow_text_available(): return importlib.util.find_spec("tensorflow_text") is not None +def is_keras_nlp_available(): + return importlib.util.find_spec("keras_nlp") is not None + + def is_in_notebook(): try: # Test adapted from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py diff --git a/tests/models/gpt2/test_tokenization_gpt2_tf.py b/tests/models/gpt2/test_tokenization_gpt2_tf.py new file mode 100644 index 00000000000..416f9e0796a --- /dev/null +++ b/tests/models/gpt2/test_tokenization_gpt2_tf.py @@ -0,0 +1,130 @@ +import unittest +from pathlib import Path +from tempfile import TemporaryDirectory + +from transformers import AutoConfig, TFGPT2LMHeadModel, is_tensorflow_text_available, is_tf_available +from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer +from transformers.testing_utils import require_tensorflow_text, slow + + +if is_tensorflow_text_available(): + from transformers.models.gpt2 import TFGPT2Tokenizer + +if is_tf_available(): + import tensorflow as tf + + +TOKENIZER_CHECKPOINTS = ["gpt2"] +TINY_MODEL_CHECKPOINT = "gpt2" + +if is_tf_available(): + + class ModelToSave(tf.Module): + def __init__(self, tokenizer): + super().__init__() + self.tokenizer = tokenizer + config = AutoConfig.from_pretrained(TINY_MODEL_CHECKPOINT) + self.model = TFGPT2LMHeadModel.from_config(config) + + @tf.function(input_signature=(tf.TensorSpec((None,), tf.string, name="text"),)) + def serving(self, text): + + tokenized = self.tokenizer(text) + input_ids_dense = tokenized["input_ids"].to_tensor() + + input_mask = tf.cast(input_ids_dense > 0, tf.int32) + # input_mask = 
diff --git a/tests/models/gpt2/test_tokenization_gpt2_tf.py b/tests/models/gpt2/test_tokenization_gpt2_tf.py
new file mode 100644
index 00000000000..416f9e0796a
--- /dev/null
+++ b/tests/models/gpt2/test_tokenization_gpt2_tf.py
@@ -0,0 +1,130 @@
+import unittest
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+from transformers import AutoConfig, TFGPT2LMHeadModel, is_tensorflow_text_available, is_tf_available
+from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
+from transformers.testing_utils import require_tensorflow_text, slow
+
+
+if is_tensorflow_text_available():
+    from transformers.models.gpt2 import TFGPT2Tokenizer
+
+if is_tf_available():
+    import tensorflow as tf
+
+
+TOKENIZER_CHECKPOINTS = ["gpt2"]
+TINY_MODEL_CHECKPOINT = "gpt2"
+
+if is_tf_available():
+
+    class ModelToSave(tf.Module):
+        def __init__(self, tokenizer):
+            super().__init__()
+            self.tokenizer = tokenizer
+            config = AutoConfig.from_pretrained(TINY_MODEL_CHECKPOINT)
+            self.model = TFGPT2LMHeadModel.from_config(config)
+
+        @tf.function(input_signature=(tf.TensorSpec((None,), tf.string, name="text"),))
+        def serving(self, text):
+            tokenized = self.tokenizer(text)
+            input_ids_dense = tokenized["input_ids"].to_tensor()
+
+            input_mask = tf.cast(input_ids_dense > 0, tf.int32)
+            # input_mask = tf.reshape(input_mask, [-1, MAX_SEQ_LEN])
+
+            outputs = self.model(input_ids=input_ids_dense, attention_mask=input_mask)["logits"]
+
+            return outputs
+
+
+@require_tensorflow_text
+class GPTTokenizationTest(unittest.TestCase):
+    # The TF tokenizers are usually going to be used as pretrained tokenizers from existing model checkpoints,
+    # so that's what we focus on here.
+
+    def setUp(self):
+        super().setUp()
+
+        self.tokenizers = [GPT2Tokenizer.from_pretrained(checkpoint) for checkpoint in (TOKENIZER_CHECKPOINTS)]
+        self.tf_tokenizers = [TFGPT2Tokenizer.from_pretrained(checkpoint) for checkpoint in TOKENIZER_CHECKPOINTS]
+        assert len(self.tokenizers) == len(self.tf_tokenizers)
+
+        self.test_sentences = [
+            "This is a straightforward English test sentence.",
+            "This one has some weird characters\rto\nsee\r\nif those\u00E9break things.",
+            "Now we're going to add some Chinese: 一 二 三 一二三",
+            "And some much more rare Chinese: 齉 堃 齉堃",
+            "Je vais aussi écrire en français pour tester les accents",
+            "Classical Irish also has some unusual characters, so in they go: Gaelaċ, ꝼ",
+        ]
+        self.paired_sentences = list(zip(self.test_sentences, self.test_sentences[::-1]))
+
+    def test_output_equivalence(self):
+        for tokenizer, tf_tokenizer in zip(self.tokenizers, self.tf_tokenizers):
+            for test_inputs in self.test_sentences:
+                python_outputs = tokenizer([test_inputs], return_tensors="tf")
+                tf_outputs = tf_tokenizer([test_inputs])
+
+                for key in python_outputs.keys():
+                    # convert them to numpy to avoid messing with ragged tensors
+                    python_outputs_values = python_outputs[key].numpy()
+                    tf_outputs_values = tf_outputs[key].numpy()
+
+                    self.assertTrue(tf.reduce_all(python_outputs_values.shape == tf_outputs_values.shape))
+                    self.assertTrue(tf.reduce_all(tf.cast(python_outputs_values, tf.int64) == tf_outputs_values))
+
+    @slow
+    def test_graph_mode(self):
+        for tf_tokenizer in self.tf_tokenizers:
+            compiled_tokenizer = tf.function(tf_tokenizer)
+            for test_inputs in self.test_sentences:
+                test_inputs = tf.constant(test_inputs)
+                compiled_outputs = compiled_tokenizer(test_inputs)
+                eager_outputs = tf_tokenizer(test_inputs)
+
+                for key in eager_outputs.keys():
+                    self.assertTrue(tf.reduce_all(eager_outputs[key] == compiled_outputs[key]))
+
+    @slow
+    def test_saved_model(self):
+        for tf_tokenizer in self.tf_tokenizers:
+            model = ModelToSave(tokenizer=tf_tokenizer)
+            test_inputs = tf.convert_to_tensor([self.test_sentences[0]])
+            out = model.serving(test_inputs)  # Build model with some sample inputs
+            with TemporaryDirectory() as tempdir:
+                save_path = Path(tempdir) / "saved.model"
+                tf.saved_model.save(model, save_path, signatures={"serving_default": model.serving})
+                loaded_model = tf.saved_model.load(save_path)
+                loaded_output = loaded_model.signatures["serving_default"](test_inputs)["output_0"]
+                # We may see small differences because the loaded model is compiled, so we need an epsilon for the test
+                self.assertTrue(tf.reduce_all(out == loaded_output))
+
+    @slow
+    def test_from_config(self):
+        for tf_tokenizer in self.tf_tokenizers:
+            test_inputs = tf.convert_to_tensor([self.test_sentences[0]])
+            out = tf_tokenizer(test_inputs)  # Build model with some sample inputs
+
+            config = tf_tokenizer.get_config()
+            model_from_config = TFGPT2Tokenizer.from_config(config)
+            from_config_output = model_from_config(test_inputs)
+
+            for key in from_config_output.keys():
+                self.assertTrue(tf.reduce_all(from_config_output[key] == out[key]))
+
+    @slow
+    def test_padding(self):
+        for tf_tokenizer in self.tf_tokenizers:
+            # GPT-2 has no pad token by default, so set one here for the padding test to run
+            tf_tokenizer.pad_token_id = 123123
+
+            for max_length in [3, 5, 1024]:
+                test_inputs = tf.convert_to_tensor([self.test_sentences[0]])
+                out = tf_tokenizer(test_inputs, max_length=max_length)
+
+                out_length = out["input_ids"].numpy().shape[1]
+
+                assert out_length == max_length
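A usage note on the padding path that `test_padding` exercises: the layer only pads when a `pad_token_id` is set, and `max_length` can come either from the constructor or from the call itself. A rough sketch (the pad id choice is arbitrary, mirroring the test):

```python
import tensorflow as tf

from transformers import TFGPT2Tokenizer

tf_tokenizer = TFGPT2Tokenizer.from_pretrained("gpt2")
sample = tf.constant(["a short sentence"])

# With no pad token, outputs are ragged and unpadded.
ragged_ids = tf_tokenizer(sample)["input_ids"]

# Once a pad token is set, a per-call max_length pads (or truncates) to a fixed width.
tf_tokenizer.pad_token_id = 50256  # arbitrary choice, as in test_padding
padded = tf_tokenizer(sample, max_length=16)
print(padded["input_ids"].shape)       # (1, 16)
print(padded["attention_mask"].shape)  # (1, 16)
```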