[ gemma] Adds support for Gemma 💎 (#29167)

* inital commit

* update

* update conversion checkpoint

* update conversion script

* nits

* some fixes

* nits

* merge

* fix permute

* nits

* fix

* nits

* nits

* nits

* fix rope

* fix both rope

* nites

* style

* make sure flax works

* fix flax init code

* fix foward

* nits

* print flax generation out

* current code

* nits

* SIIIIIIIIIIIIIIIIIII

* update

* add new tokenizer

* correct fast tokenizer

* fix conversion

* more comments

* fix modeling and conversion

* nits and nits

* nits testing

* add some tokenization tests

* add some edge cases

* add slow tests and fix them

* fixup

* fix copies for modeling

* fix copies

* add 7B slow tests

* fix

* fix

* fix tests

* make tokenizer cis go green

* styling

* last tokenizer nits

* update jax tests

* fix flax for 7b

* add jit testing 🤗

* cleanups

* isolated nit, inv_freq for rotary_emb.inv_freq

* propagate to jax

* Apply suggestions from code review

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* adjust test

* fix conversion script

* change name

* correct file names

* update conversion script

* Fix bos and eos token ids in the model configuration (#3)

* update modelling

* update conversion script

* add static cache for gemma

* fix sdpa generate

* fix batched

* multiple fixes

* fix FA2

* final fix

* Rename a few missing strings and filenames (#4)

* merge with upstream main

* fix copies

* fix copies

* fix fixup

* fix fixup

* fix

* fix

* final tests

* fix fx gemma tests

* fix fx bf16/fp16 tests

* update slow fx tests

* fx slow tests: one logits, one generation

* move jit test standalone

* Apply suggestions from code review

* nits

* tokenizer updates

* more tokenization updates: custom GemmaSentencepieceExtrator

* style

* Update src/transformers/cache_utils.py

* Update src/transformers/models/gemma/__init__.py

* Update tests/models/gemma/test_modeling_flax_gemma.py

* small nits

* style

* update tokenization test

* fix the rotary embedding

* with style

* fix slow tests

* WARNING this commit might be very important for precisions

* Update tests/models/gemma/test_modeling_flax_gemma.py

* Update src/transformers/models/gemma/configuration_gemma.py

Co-authored-by: Lysandre Debut <hi@lysand.re>

* Update src/transformers/models/gemma/modeling_flax_gemma.py

Co-authored-by: Lysandre Debut <hi@lysand.re>

* small nits here and there!

* forgotten nit

* remove on the fly computation of inv_freq

* revert previous change, let's be safe and for now re-compute freq cis to make sure it's in float

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update src/transformers/models/gemma/convert_gemma_weights_to_hf.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update src/transformers/models/gemma/convert_gemma_weights_to_hf.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update tests/models/gemma/test_modeling_gemma.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update tests/models/gemma/test_modeling_gemma.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update tests/models/gemma/test_modeling_gemma.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update tests/models/gemma/test_modeling_flax_gemma.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update tests/models/gemma/test_modeling_gemma.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update tests/models/gemma/test_modeling_gemma.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update tests/models/gemma/test_tokenization_gemma.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update tests/models/gemma/test_tokenization_gemma.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update tests/models/gemma/test_tokenization_gemma.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update tests/models/gemma/test_tokenization_gemma.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update tests/models/gemma/test_modeling_gemma.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update tests/models/gemma/test_modeling_gemma.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update tests/models/gemma/test_modeling_gemma.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update tests/models/gemma/test_modeling_gemma.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update tests/models/gemma/test_modeling_gemma.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* nit conversion script link

* fix some tests

* add not doctest and pr doctest

* repo consistency

* fix last CIs 🚀

* update all readmes

---------

Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: sanchit-gandhi <sanchit@huggingface.co>
Co-authored-by: Lysandre Debut <hi@lysand.re>
This commit is contained in:
Arthur 2024-02-21 14:21:28 +01:00 committed by GitHub
parent 58245ba6fb
commit 594c1277b2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
40 changed files with 4811 additions and 6 deletions

View File

@ -374,6 +374,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h
1. **[FocalNet](https://huggingface.co/docs/transformers/model_doc/focalnet)** (from Microsoft Research) released with the paper [Focal Modulation Networks](https://arxiv.org/abs/2203.11926) by Jianwei Yang, Chunyuan Li, Xiyang Dai, Lu Yuan, Jianfeng Gao.
1. **[Funnel Transformer](https://huggingface.co/docs/transformers/model_doc/funnel)** (from CMU/Google Brain) released with the paper [Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing](https://arxiv.org/abs/2006.03236) by Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le.
1. **[Fuyu](https://huggingface.co/docs/transformers/model_doc/fuyu)** (from ADEPT) Rohan Bavishi, Erich Elsen, Curtis Hawthorne, Maxwell Nye, Augustus Odena, Arushi Somani, Sağnak Taşırlar. Released with the paper [blog post](https://www.adept.ai/blog/fuyu-8b)
1. **[Gemma](https://huggingface.co/docs/transformers/main/model_doc/gemma)** (from Google) released with the paper [Gemma: Open Models Based on Gemini Technology and Research](https://blog.google/technology/developers/gemma-open-models/) by the Gemma Google team.
1. **[GIT](https://huggingface.co/docs/transformers/model_doc/git)** (from Microsoft Research) released with the paper [GIT: A Generative Image-to-text Transformer for Vision and Language](https://arxiv.org/abs/2205.14100) by Jianfeng Wang, Zhengyuan Yang, Xiaowei Hu, Linjie Li, Kevin Lin, Zhe Gan, Zicheng Liu, Ce Liu, Lijuan Wang.
1. **[GLPN](https://huggingface.co/docs/transformers/model_doc/glpn)** (from KAIST) released with the paper [Global-Local Path Networks for Monocular Depth Estimation with Vertical CutDepth](https://arxiv.org/abs/2201.07436) by Doyeon Kim, Woonghyun Ga, Pyungwhan Ahn, Donggyu Joo, Sehwan Chun, Junmo Kim.
1. **[GPT](https://huggingface.co/docs/transformers/model_doc/openai-gpt)** (from OpenAI) released with the paper [Improving Language Understanding by Generative Pre-Training](https://openai.com/research/language-unsupervised/) by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever.

View File

@ -347,6 +347,7 @@ Número actual de puntos de control: ![](https://img.shields.io/endpoint?url=htt
1. **[FocalNet](https://huggingface.co/docs/transformers/model_doc/focalnet)** (from Microsoft Research) released with the paper [Focal Modulation Networks](https://arxiv.org/abs/2203.11926) by Jianwei Yang, Chunyuan Li, Xiyang Dai, Lu Yuan, Jianfeng Gao.
1. **[Funnel Transformer](https://huggingface.co/docs/transformers/model_doc/funnel)** (from CMU/Google Brain) released with the paper [Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing](https://arxiv.org/abs/2006.03236) by Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le.
1. **[Fuyu](https://huggingface.co/docs/transformers/model_doc/fuyu)** (from ADEPT) Rohan Bavishi, Erich Elsen, Curtis Hawthorne, Maxwell Nye, Augustus Odena, Arushi Somani, Sağnak Taşırlar. Released with the paper [blog post](https://www.adept.ai/blog/fuyu-8b)
1. **[Gemma](https://huggingface.co/docs/transformers/main/model_doc/gemma)** (from Google) released with the paper [Gemma: Open Models Based on Gemini Technology and Research](https://blog.google/technology/developers/gemma-open-models/) by the Gemma Google team.
1. **[GIT](https://huggingface.co/docs/transformers/model_doc/git)** (from Microsoft Research) released with the paper [GIT: A Generative Image-to-text Transformer for Vision and Language](https://arxiv.org/abs/2205.14100) by Jianfeng Wang, Zhengyuan Yang, Xiaowei Hu, Linjie Li, Kevin Lin, Zhe Gan, Zicheng Liu, Ce Liu, Lijuan Wang.
1. **[GLPN](https://huggingface.co/docs/transformers/model_doc/glpn)** (from KAIST) released with the paper [Global-Local Path Networks for Monocular Depth Estimation with Vertical CutDepth](https://arxiv.org/abs/2201.07436) by Doyeon Kim, Woonghyun Ga, Pyungwhan Ahn, Donggyu Joo, Sehwan Chun, Junmo Kim.
1. **[GPT](https://huggingface.co/docs/transformers/model_doc/openai-gpt)** (from OpenAI) released with the paper [Improving Language Understanding by Generative Pre-Training](https://openai.com/research/language-unsupervised/) by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever.

View File

@ -368,6 +368,7 @@ Nombre actuel de points de contrôle : ![](https://img.shields.io/endpoint?url=h
1. **[FocalNet](https://huggingface.co/docs/transformers/model_doc/focalnet)** (de Microsoft Research) publié dans l'article [Réseaux de modulation focale](https://arxiv.org/abs/2203.11926) par Jianwei Yang, Chunyuan Li, Xiyang Dai, Lu Yuan, Jianfeng Gao.
1. **[Funnel Transformer](https://huggingface.co/docs/transformers/model_doc/funnel)** (de l'Université Carnegie Mellon/Google Brain) publié dans l'article [Funnel-Transformer : Filtrer la redondance séquentielle pour un traitement efficace du langage](https://arxiv.org/abs/2006.03236) par Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le.
1. **[Fuyu](https://huggingface.co/docs/transformers/model_doc/fuyu)** (de ADEPT) Rohan Bavishi, Erich Elsen, Curtis Hawthorne, Maxwell Nye, Augustus Odena, Arushi Somani, Sağnak Taşırlar. Publié dans l'article [billet de blog](https://www.adept.ai/blog/fuyu-8b)
1. **[Gemma](https://huggingface.co/docs/transformers/main/model_doc/gemma)** (de Google) publié dans l'article [Gemma: Open Models Based on Gemini Technology and Research](https://blog.google/technology/developers/gemma-open-models/) parthe Gemma Google team.
1. **[GIT](https://huggingface.co/docs/transformers/model_doc/git)** (de Microsoft Research) publié dans l'article [GIT : Un transformateur génératif d'images en texte pour la vision et le langage](https://arxiv.org/abs/2205.14100) par Jianfeng Wang, Zhengyuan Yang, Xiaowei Hu, Linjie Li, Kevin Lin, Zhe Gan, Zicheng Liu, Ce Liu, Lijuan Wang.
1. **[GLPN](https://huggingface.co/docs/transformers/model_doc/glpn)** (de la KAIST) publié dans l'article [Réseaux de chemins globaux-locaux pour l'estimation de profondeur monoculaire avec Vertical CutDepth](https://arxiv.org/abs/2201.07436) par Doyeon Kim, Woonghyun Ga, Pyungwhan Ahn, Donggyu Joo, Sehwan Chun, Junmo Kim.
1. **[GPT](https://huggingface.co/docs/transformers/model_doc/openai-gpt)** (d'OpenAI) publié dans l'article [Améliorer la compréhension du langage par l'apprentissage préalable génératif](https://openai.com/research/language-unsupervised/) par Alec Radford, Karthik Narasimhan, Tim Salimans et Ilya Sutskever.

View File

@ -321,6 +321,7 @@ conda install conda-forge::transformers
1. **[FocalNet](https://huggingface.co/docs/transformers/model_doc/focalnet)** (Microsoft Research से) Jianwei Yang, Chunyuan Li, Xiyang Dai, Lu Yuan, Jianfeng Gao. द्वाराअनुसंधान पत्र [Focal Modulation Networks](https://arxiv.org/abs/2203.11926) के साथ जारी किया गया
1. **[Funnel Transformer](https://huggingface.co/docs/transformers/model_doc/funnel)** (सीएमयू/गूगल ब्रेन से) साथ में कागज [फ़नल-ट्रांसफॉर्मर: कुशल भाषा प्रसंस्करण के लिए अनुक्रमिक अतिरेक को छानना](https://arxiv.org/abs/2006.03236) जिहांग दाई, गुओकुन लाई, यिमिंग यांग, क्वोक वी. ले द्वारा रिहाई।
1. **[Fuyu](https://huggingface.co/docs/transformers/model_doc/fuyu)** (ADEPT से) रोहन बाविशी, एरिच एलसेन, कर्टिस हॉथोर्न, मैक्सवेल नी, ऑगस्टस ओडेना, अरुशी सोमानी, सागनाक तासिरलार [blog post](https://www.adept.ai/blog/fuyu-8b)
1. **[Gemma](https://huggingface.co/docs/transformers/main/model_doc/gemma)** (Google से) the Gemma Google team. द्वाराअनुसंधान पत्र [Gemma: Open Models Based on Gemini Technology and Research](https://blog.google/technology/developers/gemma-open-models/) के साथ जारी किया गया
1. **[GIT](https://huggingface.co/docs/transformers/model_doc/git)** (from Microsoft Research) released with the paper [GIT: A Generative Image-to-text Transformer for Vision and Language](https://arxiv.org/abs/2205.14100) by Jianfeng Wang, Zhengyuan Yang, Xiaowei Hu, Linjie Li, Kevin Lin, Zhe Gan, Zicheng Liu, Ce Liu, Lijuan Wang.
1. **[GLPN](https://huggingface.co/docs/transformers/model_doc/glpn)** (KAIST से) साथ वाला पेपर [वर्टिकल कटडेप्थ के साथ मोनोकुलर डेप्थ एस्टीमेशन के लिए ग्लोबल-लोकल पाथ नेटवर्क्स](https://arxiv.org/abs/2201.07436) डोयोन किम, वूंगह्युन गा, प्युंगवान आह, डोंगग्यू जू, सेहवान चुन, जुनमो किम द्वारा।
1. **[GPT](https://huggingface.co/docs/transformers/model_doc/openai-gpt)** (OpenAI से) साथ में दिया गया पेपर [जेनरेटिव प्री-ट्रेनिंग द्वारा भाषा की समझ में सुधार](https://openai.com/research/language-unsupervised/) एलेक रैडफोर्ड, कार्तिक नरसिम्हन, टिम सालिमन्स और इल्या सुत्स्केवर द्वारा।

View File

@ -381,6 +381,7 @@ Flax、PyTorch、TensorFlowをcondaでインストールする方法は、それ
1. **[FocalNet](https://huggingface.co/docs/transformers/model_doc/focalnet)** (Microsoft Research から) Jianwei Yang, Chunyuan Li, Xiyang Dai, Lu Yuan, Jianfeng Gao. から公開された研究論文 [Focal Modulation Networks](https://arxiv.org/abs/2203.11926)
1. **[Funnel Transformer](https://huggingface.co/docs/transformers/model_doc/funnel)** (CMU/Google Brain から) Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le から公開された研究論文: [Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing](https://arxiv.org/abs/2006.03236)
1. **[Fuyu](https://huggingface.co/docs/transformers/model_doc/fuyu)** (ADEPT から) Rohan Bavishi, Erich Elsen, Curtis Hawthorne, Maxwell Nye, Augustus Odena, Arushi Somani, Sağnak Taşırlar. から公開された研究論文 [blog post](https://www.adept.ai/blog/fuyu-8b)
1. **[Gemma](https://huggingface.co/docs/transformers/main/model_doc/gemma)** (Google から) the Gemma Google team. から公開された研究論文 [Gemma: Open Models Based on Gemini Technology and Research](https://blog.google/technology/developers/gemma-open-models/)
1. **[GIT](https://huggingface.co/docs/transformers/model_doc/git)** (Microsoft Research から) Jianfeng Wang, Zhengyuan Yang, Xiaowei Hu, Linjie Li, Kevin Lin, Zhe Gan, Zicheng Liu, Ce Liu, Lijuan Wang. から公開された研究論文 [GIT: A Generative Image-to-text Transformer for Vision and Language](https://arxiv.org/abs/2205.14100)
1. **[GLPN](https://huggingface.co/docs/transformers/model_doc/glpn)** (KAIST から) Doyeon Kim, Woonghyun Ga, Pyungwhan Ahn, Donggyu Joo, Sehwan Chun, Junmo Kim から公開された研究論文: [Global-Local Path Networks for Monocular Depth Estimation with Vertical CutDepth](https://arxiv.org/abs/2201.07436)
1. **[GPT](https://huggingface.co/docs/transformers/model_doc/openai-gpt)** (OpenAI から) Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever から公開された研究論文: [Improving Language Understanding by Generative Pre-Training](https://openai.com/research/language-unsupervised/)

View File

@ -296,6 +296,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는
1. **[FocalNet](https://huggingface.co/docs/transformers/model_doc/focalnet)** (from Microsoft Research) released with the paper [Focal Modulation Networks](https://arxiv.org/abs/2203.11926) by Jianwei Yang, Chunyuan Li, Xiyang Dai, Lu Yuan, Jianfeng Gao.
1. **[Funnel Transformer](https://huggingface.co/docs/transformers/model_doc/funnel)** (from CMU/Google Brain) released with the paper [Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing](https://arxiv.org/abs/2006.03236) by Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le.
1. **[Fuyu](https://huggingface.co/docs/transformers/model_doc/fuyu)** (from ADEPT) Rohan Bavishi, Erich Elsen, Curtis Hawthorne, Maxwell Nye, Augustus Odena, Arushi Somani, Sağnak Taşırlar. 논문과 함께 공개 [blog post](https://www.adept.ai/blog/fuyu-8b)
1. **[Gemma](https://huggingface.co/docs/transformers/main/model_doc/gemma)** (Google 에서 제공)은 the Gemma Google team.의 [Gemma: Open Models Based on Gemini Technology and Research](https://blog.google/technology/developers/gemma-open-models/)논문과 함께 발표했습니다.
1. **[GIT](https://huggingface.co/docs/transformers/model_doc/git)** (from Microsoft Research) released with the paper [GIT: A Generative Image-to-text Transformer for Vision and Language](https://arxiv.org/abs/2205.14100) by Jianfeng Wang, Zhengyuan Yang, Xiaowei Hu, Linjie Li, Kevin Lin, Zhe Gan, Zicheng Liu, Ce Liu, Lijuan Wang.
1. **[GLPN](https://huggingface.co/docs/transformers/model_doc/glpn)** (from KAIST) released with the paper [Global-Local Path Networks for Monocular Depth Estimation with Vertical CutDepth](https://arxiv.org/abs/2201.07436) by Doyeon Kim, Woonghyun Ga, Pyungwhan Ahn, Donggyu Joo, Sehwan Chun, Junmo Kim.
1. **[GPT](https://huggingface.co/docs/transformers/model_doc/openai-gpt)** (from OpenAI) released with the paper [Improving Language Understanding by Generative Pre-Training](https://openai.com/research/language-unsupervised/) by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever.

View File

@ -320,6 +320,7 @@ conda install conda-forge::transformers
1. **[FocalNet](https://huggingface.co/docs/transformers/model_doc/focalnet)** (来自 Microsoft Research) 伴随论文 [Focal Modulation Networks](https://arxiv.org/abs/2203.11926) 由 Jianwei Yang, Chunyuan Li, Xiyang Dai, Lu Yuan, Jianfeng Gao 发布。
1. **[Funnel Transformer](https://huggingface.co/docs/transformers/model_doc/funnel)** (来自 CMU/Google Brain) 伴随论文 [Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing](https://arxiv.org/abs/2006.03236) 由 Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le 发布。
1. **[Fuyu](https://huggingface.co/docs/transformers/model_doc/fuyu)** (来自 ADEPT) 伴随论文 [blog post](https://www.adept.ai/blog/fuyu-8b) 由 Rohan Bavishi, Erich Elsen, Curtis Hawthorne, Maxwell Nye, Augustus Odena, Arushi Somani, Sağnak Taşırlar 发布。
1. **[Gemma](https://huggingface.co/docs/transformers/main/model_doc/gemma)** (来自 Google) 伴随论文 [Gemma: Open Models Based on Gemini Technology and Research](https://blog.google/technology/developers/gemma-open-models/) 由 the Gemma Google team 发布。
1. **[GIT](https://huggingface.co/docs/transformers/model_doc/git)** (来自 Microsoft Research) 伴随论文 [GIT: A Generative Image-to-text Transformer for Vision and Language](https://arxiv.org/abs/2205.14100) 由 Jianfeng Wang, Zhengyuan Yang, Xiaowei Hu, Linjie Li, Kevin Lin, Zhe Gan, Zicheng Liu, Ce Liu, Lijuan Wang 发布。
1. **[GLPN](https://huggingface.co/docs/transformers/model_doc/glpn)** (来自 KAIST) 伴随论文 [Global-Local Path Networks for Monocular Depth Estimation with Vertical CutDepth](https://arxiv.org/abs/2201.07436) 由 Doyeon Kim, Woonghyun Ga, Pyungwhan Ahn, Donggyu Joo, Sehwan Chun, Junmo Kim 发布。
1. **[GPT](https://huggingface.co/docs/transformers/model_doc/openai-gpt)** (来自 OpenAI) 伴随论文 [Improving Language Understanding by Generative Pre-Training](https://openai.com/research/language-unsupervised/) 由 Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever 发布。

View File

@ -332,6 +332,7 @@ conda install conda-forge::transformers
1. **[FocalNet](https://huggingface.co/docs/transformers/model_doc/focalnet)** (from Microsoft Research) released with the paper [Focal Modulation Networks](https://arxiv.org/abs/2203.11926) by Jianwei Yang, Chunyuan Li, Xiyang Dai, Lu Yuan, Jianfeng Gao.
1. **[Funnel Transformer](https://huggingface.co/docs/transformers/model_doc/funnel)** (from CMU/Google Brain) released with the paper [Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing](https://arxiv.org/abs/2006.03236) by Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le.
1. **[Fuyu](https://huggingface.co/docs/transformers/model_doc/fuyu)** (from ADEPT) Rohan Bavishi, Erich Elsen, Curtis Hawthorne, Maxwell Nye, Augustus Odena, Arushi Somani, Sağnak Taşırlar. Released with the paper [blog post](https://www.adept.ai/blog/fuyu-8b)
1. **[Gemma](https://huggingface.co/docs/transformers/main/model_doc/gemma)** (from Google) released with the paper [Gemma: Open Models Based on Gemini Technology and Research](https://blog.google/technology/developers/gemma-open-models/) by the Gemma Google team.
1. **[GIT](https://huggingface.co/docs/transformers/model_doc/git)** (from Microsoft Research) released with the paper [GIT: A Generative Image-to-text Transformer for Vision and Language](https://arxiv.org/abs/2205.14100) by Jianfeng Wang, Zhengyuan Yang, Xiaowei Hu, Linjie Li, Kevin Lin, Zhe Gan, Zicheng Liu, Ce Liu, Lijuan Wang.
1. **[GLPN](https://huggingface.co/docs/transformers/model_doc/glpn)** (from KAIST) released with the paper [Global-Local Path Networks for Monocular Depth Estimation with Vertical CutDepth](https://arxiv.org/abs/2201.07436) by Doyeon Kim, Woonghyun Ga, Pyungwhan Ahn, Donggyu Joo, Sehwan Chun, Junmo Kim.
1. **[GPT](https://huggingface.co/docs/transformers/model_doc/openai-gpt)** (from OpenAI) released with the paper [Improving Language Understanding by Generative Pre-Training](https://openai.com/research/language-unsupervised/) by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever.

View File

@ -354,6 +354,8 @@
title: Funnel Transformer
- local: model_doc/fuyu
title: Fuyu
- local: model_doc/gemma
title: Gemma
- local: model_doc/openai-gpt
title: GPT
- local: model_doc/gpt_neo

View File

@ -142,6 +142,7 @@ Flax), PyTorch, and/or TensorFlow.
| [FocalNet](model_doc/focalnet) | ✅ | ❌ | ❌ |
| [Funnel Transformer](model_doc/funnel) | ✅ | ✅ | ❌ |
| [Fuyu](model_doc/fuyu) | ✅ | ❌ | ❌ |
| [Gemma](model_doc/gemma) | ✅ | ❌ | ✅ |
| [GIT](model_doc/git) | ✅ | ❌ | ❌ |
| [GLPN](model_doc/glpn) | ✅ | ❌ | ❌ |
| [GPT Neo](model_doc/gpt_neo) | ✅ | ❌ | ✅ |

View File

@ -0,0 +1,71 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Gemma
## Overview
The Gemma model was proposed in [Gemma: Open Models Based on Gemini Technology and Research](https://blog.google/technology/developers/gemma-open-models/) by Gemma Team, Google.
Gemma models are trained on 6T tokens, and released with 2 versions, 2b and 7b.
The abstract from the paper is the following:
*This work introduces Gemma, a new family of open language models demonstrating strong performance across academic benchmarks for language understanding, reasoning, and safety. We release two sizes of models (2 billion and 7 billion parameters), and provide both pretrained and fine-tuned checkpoints. Gemma outperforms similarly sized open models on 11 out of 18 text-based tasks, and we present comprehensive evaluations of safety and responsibility aspects of the models, alongside a detailed description of our model development. We believe the responsible release of LLMs is critical for improving the safety of frontier models, and for enabling the next wave of LLM innovations*
Tips:
- The original checkpoints can be converted using the conversion script `src/transformers/models/gemma/convert_gemma_weights_to_hf.py`
This model was contributed by [Arthur Zucker](https://huggingface.co/ArthurZ), [Younes Belkada](https://huggingface.co/ybelkada), [Sanchit Gandhi](https://huggingface.co/sanchit-gandhi), [Pedro Cuenca](https://huggingface.co/pcuenq).
## GemmaConfig
[[autodoc]] GemmaConfig
## GemmaTokenizer
[[autodoc]] GemmaTokenizer
## GemmaTokenizerFast
[[autodoc]] GemmaTokenizerFast
## GemmaModel
[[autodoc]] GemmaModel
- forward
## GemmaForCausalLM
[[autodoc]] GemmaForCausalLM
- forward
## GemmaForSequenceClassification
[[autodoc]] GemmaForSequenceClassification
- forward
## FlaxGemmaModel
[[autodoc]] FlaxGemmaModel
- __call__
## FlaxGemmaForCausalLM
[[autodoc]] FlaxGemmaForCausalLM
- __call__

View File

@ -40,6 +40,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [Bark](https://huggingface.co/docs/transformers/model_doc/bark#transformers.BarkModel)
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
* [GPTNeo](https://huggingface.co/docs/transformers/model_doc/gpt_neo#transformers.GPTNeoModel)
* [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel)
@ -171,6 +172,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
* [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel)

View File

@ -37,7 +37,7 @@ You can finetune other architectures for causal language modeling following the
Choose one of the following architectures:
<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->
[BART](../model_doc/bart), [BERT](../model_doc/bert), [Bert Generation](../model_doc/bert-generation), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CodeLlama](../model_doc/code_llama), [CodeGen](../model_doc/codegen), [CPM-Ant](../model_doc/cpmant), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [Falcon](../model_doc/falcon), [Fuyu](../model_doc/fuyu), [GIT](../model_doc/git), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT NeoX Japanese](../model_doc/gpt_neox_japanese), [GPT-J](../model_doc/gptj), [LLaMA](../model_doc/llama), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [Mistral](../model_doc/mistral), [Mixtral](../model_doc/mixtral), [MPT](../model_doc/mpt), [MusicGen](../model_doc/musicgen), [MVP](../model_doc/mvp), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Pegasus](../model_doc/pegasus), [Persimmon](../model_doc/persimmon), [Phi](../model_doc/phi), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [QDQBert](../model_doc/qdqbert), [Qwen2](../model_doc/qwen2), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [RWKV](../model_doc/rwkv), [Speech2Text2](../model_doc/speech_to_text_2), [StableLm](../model_doc/stablelm), [Transformer-XL](../model_doc/transfo-xl), [TrOCR](../model_doc/trocr), [Whisper](../model_doc/whisper), [XGLM](../model_doc/xglm), [XLM](../model_doc/xlm), [XLM-ProphetNet](../model_doc/xlm-prophetnet), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod)
[BART](../model_doc/bart), [BERT](../model_doc/bert), [Bert Generation](../model_doc/bert-generation), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CodeLlama](../model_doc/code_llama), [CodeGen](../model_doc/codegen), [CPM-Ant](../model_doc/cpmant), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [Falcon](../model_doc/falcon), [Fuyu](../model_doc/fuyu), [Gemma](../model_doc/gemma), [GIT](../model_doc/git), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT NeoX Japanese](../model_doc/gpt_neox_japanese), [GPT-J](../model_doc/gptj), [LLaMA](../model_doc/llama), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [Mistral](../model_doc/mistral), [Mixtral](../model_doc/mixtral), [MPT](../model_doc/mpt), [MusicGen](../model_doc/musicgen), [MVP](../model_doc/mvp), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Pegasus](../model_doc/pegasus), [Persimmon](../model_doc/persimmon), [Phi](../model_doc/phi), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [QDQBert](../model_doc/qdqbert), [Qwen2](../model_doc/qwen2), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [RWKV](../model_doc/rwkv), [Speech2Text2](../model_doc/speech_to_text_2), [StableLm](../model_doc/stablelm), [Transformer-XL](../model_doc/transfo-xl), [TrOCR](../model_doc/trocr), [Whisper](../model_doc/whisper), [XGLM](../model_doc/xglm), [XLM](../model_doc/xlm), [XLM-ProphetNet](../model_doc/xlm-prophetnet), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod)

View File

@ -33,7 +33,7 @@ The task illustrated in this tutorial is supported by the following model archit
<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->
[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [CodeLlama](../model_doc/code_llama), [ConvBERT](../model_doc/convbert), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [Falcon](../model_doc/falcon), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [LLaMA](../model_doc/llama), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [Mistral](../model_doc/mistral), [Mixtral](../model_doc/mixtral), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MPT](../model_doc/mpt), [MRA](../model_doc/mra), [MT5](../model_doc/mt5), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Perceiver](../model_doc/perceiver), [Persimmon](../model_doc/persimmon), [Phi](../model_doc/phi), [PLBart](../model_doc/plbart), [QDQBert](../model_doc/qdqbert), [Qwen2](../model_doc/qwen2), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [StableLm](../model_doc/stablelm), [T5](../model_doc/t5), [TAPAS](../model_doc/tapas), [Transformer-XL](../model_doc/transfo-xl), [UMT5](../model_doc/umt5), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [CodeLlama](../model_doc/code_llama), [ConvBERT](../model_doc/convbert), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [Falcon](../model_doc/falcon), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [Gemma](../model_doc/gemma), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [LLaMA](../model_doc/llama), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [Mistral](../model_doc/mistral), [Mixtral](../model_doc/mixtral), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MPT](../model_doc/mpt), [MRA](../model_doc/mra), [MT5](../model_doc/mt5), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Perceiver](../model_doc/perceiver), [Persimmon](../model_doc/persimmon), [Phi](../model_doc/phi), [PLBart](../model_doc/plbart), [QDQBert](../model_doc/qdqbert), [Qwen2](../model_doc/qwen2), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [StableLm](../model_doc/stablelm), [T5](../model_doc/t5), [TAPAS](../model_doc/tapas), [Transformer-XL](../model_doc/transfo-xl), [UMT5](../model_doc/umt5), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)

View File

@ -456,6 +456,7 @@ _import_structure = {
"FunnelTokenizer",
],
"models.fuyu": ["FUYU_PRETRAINED_CONFIG_ARCHIVE_MAP", "FuyuConfig"],
"models.gemma": ["GEMMA_PRETRAINED_CONFIG_ARCHIVE_MAP", "GemmaConfig"],
"models.git": [
"GIT_PRETRAINED_CONFIG_ARCHIVE_MAP",
"GitConfig",
@ -1112,6 +1113,7 @@ else:
_import_structure["models.deberta_v2"].append("DebertaV2Tokenizer")
_import_structure["models.ernie_m"].append("ErnieMTokenizer")
_import_structure["models.fnet"].append("FNetTokenizer")
_import_structure["models.gemma"].append("GemmaTokenizer")
_import_structure["models.gpt_sw3"].append("GPTSw3Tokenizer")
_import_structure["models.layoutxlm"].append("LayoutXLMTokenizer")
_import_structure["models.llama"].append("LlamaTokenizer")
@ -1176,6 +1178,7 @@ else:
_import_structure["models.electra"].append("ElectraTokenizerFast")
_import_structure["models.fnet"].append("FNetTokenizerFast")
_import_structure["models.funnel"].append("FunnelTokenizerFast")
_import_structure["models.gemma"].append("GemmaTokenizerFast")
_import_structure["models.gpt2"].append("GPT2TokenizerFast")
_import_structure["models.gpt_neox"].append("GPTNeoXTokenizerFast")
_import_structure["models.gpt_neox_japanese"].append("GPTNeoXJapaneseTokenizer")
@ -2241,6 +2244,14 @@ else:
]
)
_import_structure["models.fuyu"].extend(["FuyuForCausalLM", "FuyuPreTrainedModel"])
_import_structure["models.gemma"].extend(
[
"GemmaForCausalLM",
"GemmaForSequenceClassification",
"GemmaModel",
"GemmaPreTrainedModel",
]
)
_import_structure["models.git"].extend(
[
"GIT_PRETRAINED_MODEL_ARCHIVE_LIST",
@ -4672,6 +4683,7 @@ else:
)
_import_structure["models.gptj"].extend(["FlaxGPTJForCausalLM", "FlaxGPTJModel", "FlaxGPTJPreTrainedModel"])
_import_structure["models.llama"].extend(["FlaxLlamaForCausalLM", "FlaxLlamaModel", "FlaxLlamaPreTrainedModel"])
_import_structure["models.gemma"].extend(["FlaxGemmaForCausalLM", "FlaxGemmaModel", "FlaxGemmaPreTrainedModel"])
_import_structure["models.longt5"].extend(
[
"FlaxLongT5ForConditionalGeneration",
@ -5208,6 +5220,7 @@ if TYPE_CHECKING:
FunnelTokenizer,
)
from .models.fuyu import FUYU_PRETRAINED_CONFIG_ARCHIVE_MAP, FuyuConfig
from .models.gemma import GEMMA_PRETRAINED_CONFIG_ARCHIVE_MAP, GemmaConfig
from .models.git import (
GIT_PRETRAINED_CONFIG_ARCHIVE_MAP,
GitConfig,
@ -5864,6 +5877,7 @@ if TYPE_CHECKING:
from .models.deberta_v2 import DebertaV2Tokenizer
from .models.ernie_m import ErnieMTokenizer
from .models.fnet import FNetTokenizer
from .models.gemma import GemmaTokenizer
from .models.gpt_sw3 import GPTSw3Tokenizer
from .models.layoutxlm import LayoutXLMTokenizer
from .models.llama import LlamaTokenizer
@ -5920,6 +5934,7 @@ if TYPE_CHECKING:
from .models.electra import ElectraTokenizerFast
from .models.fnet import FNetTokenizerFast
from .models.funnel import FunnelTokenizerFast
from .models.gemma import GemmaTokenizerFast
from .models.gpt2 import GPT2TokenizerFast
from .models.gpt_neox import GPTNeoXTokenizerFast
from .models.gpt_neox_japanese import GPTNeoXJapaneseTokenizer
@ -6848,6 +6863,12 @@ if TYPE_CHECKING:
FuyuForCausalLM,
FuyuPreTrainedModel,
)
from .models.gemma import (
GemmaForCausalLM,
GemmaForSequenceClassification,
GemmaModel,
GemmaPreTrainedModel,
)
from .models.git import (
GIT_PRETRAINED_MODEL_ARCHIVE_LIST,
GitForCausalLM,
@ -8836,6 +8857,11 @@ if TYPE_CHECKING:
FlaxElectraPreTrainedModel,
)
from .models.encoder_decoder import FlaxEncoderDecoderModel
from .models.gemma import (
FlaxGemmaForCausalLM,
FlaxGemmaModel,
FlaxGemmaPreTrainedModel,
)
from .models.gpt2 import (
FlaxGPT2LMHeadModel,
FlaxGPT2Model,

View File

@ -348,7 +348,11 @@ class StaticCache(Cache):
super().__init__()
self.max_batch_size = max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
self.head_dim = config.hidden_size // config.num_attention_heads
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
self.head_dim = (
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
)
self.dtype = dtype if dtype is not None else torch.float32
self.num_key_value_heads = (
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads

View File

@ -62,6 +62,41 @@ class SentencePieceExtractor:
"""
sp = self.sp
vocab = {sp.id_to_piece(index): index for index in range(sp.GetPieceSize())}
if vocab_scores is not None:
vocab_scores, reverse = dict(vocab_scores), True
else:
vocab_scores, reverse = vocab, False
# Merges
merges = []
for merge, piece_score in vocab_scores.items():
local = []
for index in range(1, len(merge)):
piece_l, piece_r = merge[:index], merge[index:]
if piece_l in vocab and piece_r in vocab:
local.append((piece_l, piece_r, piece_score))
local = sorted(local, key=lambda x: (vocab[x[0]], vocab[x[1]]))
merges.extend(local)
merges = sorted(merges, key=lambda val: val[2], reverse=reverse)
merges = [(val[0], val[1]) for val in merges]
return vocab, merges
class GemmaSentencePieceExtractor(SentencePieceExtractor):
def extract(self, vocab_scores=None) -> Tuple[Dict[str, int], List[Tuple]]:
"""
By default will return vocab and merges with respect to their order, by sending `vocab_scores` we're going to
order the merges with respect to the piece scores instead.
"""
sp = self.sp
vocab = {sp.id_to_piece(index): index for index in range(sp.GetPieceSize())}
# there is a missing token in the vocab. We have to do this to support merges
# "<0x09>" is the bytefallback for `\t`
vocab["\t"] = vocab.pop("<0x09>")
if vocab_scores is not None:
vocab_scores, reverse = dict(vocab_scores), True
else:
@ -1190,6 +1225,93 @@ class XGLMConverter(SpmConverter):
)
class GemmaConvert(SpmConverter):
handle_byte_fallback = True
""""
split_by_unicode_script: true
split_by_number: true
split_by_whitespace: true
treat_whitespace_as_suffix: false
allow_whitespace_only_pieces: true
split_digits: true
byte_fallback: true
"""
def normalizer(self, proto):
return normalizers.Replace(" ", "")
def vocab(self, proto):
vocab = [
(self.original_tokenizer.pad_token, 0.0),
(self.original_tokenizer.eos_token, 0.0),
(self.original_tokenizer.bos_token, 0.0),
]
for piece in proto.pieces[3:]:
if piece.piece == "<0x09>":
vocab += [("\t", piece.score)]
else:
vocab += [(piece.piece, piece.score)]
# vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
return vocab
def pre_tokenizer(self, replacement, add_prefix_space):
return None
def unk_id(self, proto):
unk_id = 3
return unk_id
def decoder(self, replacement, add_prefix_space):
return decoders.Sequence(
[
decoders.Replace("", " "),
decoders.ByteFallback(),
decoders.Fuse(),
]
)
def tokenizer(self, proto):
model_type = proto.trainer_spec.model_type
vocab_scores = self.vocab(proto)
if model_type == 1:
import tokenizers
if version.parse(tokenizers.__version__) < version.parse("0.14.0"):
tokenizer = Tokenizer(Unigram(vocab_scores, 0))
else:
tokenizer = Tokenizer(Unigram(vocab_scores, 0, byte_fallback=True))
elif model_type == 2:
_, merges = GemmaSentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}
tokenizer = Tokenizer(
BPE(
bpe_vocab,
merges,
unk_token=proto.trainer_spec.unk_piece,
fuse_unk=True,
byte_fallback=True,
dropout=None,
)
)
tokenizer.add_special_tokens(
[
AddedToken("<pad>", normalized=False, special=True),
AddedToken("<eos>", normalized=False, special=True),
AddedToken("<bos>", normalized=False, special=True),
AddedToken("<unk>", normalized=False, special=True),
]
)
else:
raise Exception(
"You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
)
return tokenizer
class LlamaConverter(SpmConverter):
handle_byte_fallback = True
@ -1356,6 +1478,7 @@ SLOW_TO_FAST_CONVERTERS = {
"XGLMTokenizer": XGLMConverter,
"LlamaTokenizer": LlamaConverter,
"CodeLlamaTokenizer": LlamaConverter,
"GemmaTokenizer": GemmaConvert,
}

View File

@ -319,10 +319,9 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
flat_params = flatten_dict(params)
flat_mask, _ = jax.tree_util.tree_flatten(mask)
for masked, key in zip(flat_mask, flat_params.keys()):
for masked, key in zip(flat_mask, sorted(flat_params.keys())):
if masked:
param = flat_params[key]
flat_params[key] = conditional_cast(param)
flat_params[key] = conditional_cast(flat_params[key])
return unflatten_dict(flat_params)

View File

@ -92,6 +92,7 @@ from . import (
fsmt,
funnel,
fuyu,
gemma,
git,
glpn,
gpt2,

View File

@ -103,6 +103,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("fsmt", "FSMTConfig"),
("funnel", "FunnelConfig"),
("fuyu", "FuyuConfig"),
("gemma", "GemmaConfig"),
("git", "GitConfig"),
("glpn", "GLPNConfig"),
("gpt-sw3", "GPT2Config"),
@ -336,6 +337,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
("fsmt", "FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("funnel", "FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("fuyu", "FUYU_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("gemma", "GEMMA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("git", "GIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("glpn", "GLPN_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("gpt2", "GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
@ -568,6 +570,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("fsmt", "FairSeq Machine-Translation"),
("funnel", "Funnel Transformer"),
("fuyu", "Fuyu"),
("gemma", "Gemma"),
("git", "GIT"),
("glpn", "GLPN"),
("gpt-sw3", "GPT-Sw3"),

View File

@ -103,6 +103,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("focalnet", "FocalNetModel"),
("fsmt", "FSMTModel"),
("funnel", ("FunnelModel", "FunnelBaseModel")),
("gemma", "GemmaModel"),
("git", "GitModel"),
("glpn", "GLPNModel"),
("gpt-sw3", "GPT2Model"),
@ -426,6 +427,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("ernie", "ErnieForCausalLM"),
("falcon", "FalconForCausalLM"),
("fuyu", "FuyuForCausalLM"),
("gemma", "GemmaForCausalLM"),
("git", "GitForCausalLM"),
("gpt-sw3", "GPT2LMHeadModel"),
("gpt2", "GPT2LMHeadModel"),
@ -764,6 +766,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("flaubert", "FlaubertForSequenceClassification"),
("fnet", "FNetForSequenceClassification"),
("funnel", "FunnelForSequenceClassification"),
("gemma", "GemmaForSequenceClassification"),
("gpt-sw3", "GPT2ForSequenceClassification"),
("gpt2", "GPT2ForSequenceClassification"),
("gpt_bigcode", "GPTBigCodeForSequenceClassification"),

View File

@ -39,6 +39,7 @@ FLAX_MODEL_MAPPING_NAMES = OrderedDict(
("clip", "FlaxCLIPModel"),
("distilbert", "FlaxDistilBertModel"),
("electra", "FlaxElectraModel"),
("gemma", "FlaxGemmaModel"),
("gpt-sw3", "FlaxGPT2Model"),
("gpt2", "FlaxGPT2Model"),
("gpt_neo", "FlaxGPTNeoModel"),
@ -144,6 +145,7 @@ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("big_bird", "FlaxBigBirdForCausalLM"),
("bloom", "FlaxBloomForCausalLM"),
("electra", "FlaxElectraForCausalLM"),
("gemma", "FlaxGemmaForCausalLM"),
("gpt-sw3", "FlaxGPT2LMHeadModel"),
("gpt2", "FlaxGPT2LMHeadModel"),
("gpt_neo", "FlaxGPTNeoForCausalLM"),

View File

@ -178,6 +178,13 @@ else:
("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)),
("fsmt", ("FSMTTokenizer", None)),
("funnel", ("FunnelTokenizer", "FunnelTokenizerFast" if is_tokenizers_available() else None)),
(
"gemma",
(
"GemmaTokenizer" if is_sentencepiece_available() else None,
"GemmaTokenizerFast" if is_tokenizers_available() else None,
),
),
("git", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("gpt-sw3", ("GPTSw3Tokenizer" if is_sentencepiece_available() else None, None)),
("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),

View File

@ -0,0 +1,121 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_sentencepiece_available,
is_tokenizers_available,
is_torch_available,
)
_import_structure = {
"configuration_gemma": ["GEMMA_PRETRAINED_CONFIG_ARCHIVE_MAP", "GemmaConfig"],
}
try:
if not is_sentencepiece_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["tokenization_gemma"] = ["GemmaTokenizer"]
try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["tokenization_gemma_fast"] = ["GemmaTokenizerFast"]
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_gemma"] = [
"GemmaForCausalLM",
"GemmaModel",
"GemmaPreTrainedModel",
"GemmaForSequenceClassification",
]
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_gemma"] = [
"FlaxGemmaForCausalLM",
"FlaxGemmaModel",
"FlaxGemmaPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_gemma import GEMMA_PRETRAINED_CONFIG_ARCHIVE_MAP, GemmaConfig
try:
if not is_sentencepiece_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .tokenization_gemma import GemmaTokenizer
try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .tokenization_gemma_fast import GemmaTokenizerFast
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_gemma import (
GemmaForCausalLM,
GemmaForSequenceClassification,
GemmaModel,
GemmaPreTrainedModel,
)
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_flax_gemma import (
FlaxGemmaForCausalLM,
FlaxGemmaModel,
FlaxGemmaPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

View File

@ -0,0 +1,147 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Gemma model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
GEMMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
class GemmaConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate an Gemma
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the Gemma-7B.
e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 256000):
Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`GemmaModel`]
hidden_size (`int`, *optional*, defaults to 3072):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 24576):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 28):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*, defaults to 16):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
head_dim (`int`, *optional*, defaults to 256):
The attention head dimension.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 8192):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*, defaults to 0):
Padding token id.
eos_token_id (`int`, *optional*, defaults to 1):
End of stream token id.
bos_token_id (`int`, *optional*, defaults to 2):
Beginning of stream token id.
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
```python
>>> from transformers import GemmaModel, GemmaConfig
>>> # Initializing a Gemma gemma-7b style configuration
>>> configuration = GemmaConfig()
>>> # Initializing a model from the gemma-7b style configuration
>>> model = GemmaModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "gemma"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=256000,
hidden_size=3072,
intermediate_size=24576,
num_hidden_layers=28,
num_attention_heads=16,
num_key_value_heads=16,
head_dim=256,
hidden_act="gelu",
max_position_embeddings=8192,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
eos_token_id=1,
bos_token_id=2,
tie_word_embeddings=True,
rope_theta=10000.0,
attention_bias=False,
attention_dropout=0.0,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.head_dim = head_dim
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)

View File

@ -0,0 +1,197 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import warnings
import torch
from accelerate import init_empty_weights
from transformers import GemmaConfig, GemmaForCausalLM, GemmaTokenizer
try:
from transformers import GemmaTokenizerFast
except ImportError as e:
warnings.warn(e)
warnings.warn(
"The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion"
)
GemmaTokenizerFast = None
"""
Sample usage:
```
python src/transformers/models/gemma/convert_gemma_weights_to_hf.py \
--input_dir /path/to/downloaded/gemma/weights --model_size 7B --output_dir /output/path
```
Thereafter, models can be loaded via:
```py
from transformers import GemmaForCausalLM, GemmaTokenizerFast
model = GemmaForCausalLM.from_pretrained("/output/path")
tokenizer = GemmaTokenizerFast.from_pretrained("/output/path")
```
Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM).
"""
gemma_2b_config = GemmaConfig(
num_hidden_layers=18,
num_attention_heads=8,
num_key_value_heads=1,
hidden_size=2048,
intermediate_size=16384,
)
gemma_7b_config = GemmaConfig()
CONFIG_MAPPING = {"2B": gemma_2b_config, "7B": gemma_7b_config}
LAYER_NAME_MAPPING = {"embedder.weight": "model.embed_tokens.weight"}
def write_model(save_path, input_base_path, config, safe_serialization=True, push_to_hub=False):
num_attn_heads = config.num_attention_heads
hidden_size = config.hidden_size
num_kv_heads = config.num_key_value_heads
head_dim = config.head_dim
print(f"Fetching all parameters from the checkpoint at '{input_base_path}'")
model_state_dict = torch.load(input_base_path, map_location="cpu")["model_state_dict"]
model_state_dict.pop("freqs_cis")
state_dict = {}
for k, v in model_state_dict.items():
if "qkv_proj" in k:
if num_kv_heads == 1:
v = v.reshape(num_attn_heads + num_kv_heads * 2, head_dim, hidden_size)
q_proj = v[:num_attn_heads, ...]
k_proj = v[num_attn_heads : num_attn_heads + num_kv_heads, ...].repeat(num_kv_heads, 1, 1)
v_proj = v[-num_kv_heads:, ...].repeat(num_kv_heads, 1, 1)
state_dict[k.replace("qkv_proj", "q_proj")] = q_proj.reshape(
num_attn_heads * head_dim, hidden_size
).clone()
state_dict[k.replace("qkv_proj", "k_proj")] = k_proj.reshape(
num_kv_heads * head_dim, hidden_size
).clone()
state_dict[k.replace("qkv_proj", "v_proj")] = v_proj[0].clone()
else:
q_proj, k_proj, v_proj = torch.split(v, v.shape[0] // 3, 0)
state_dict[k.replace("qkv_proj", "q_proj")] = q_proj.reshape(
num_attn_heads * head_dim, hidden_size
).clone()
state_dict[k.replace("qkv_proj", "k_proj")] = k_proj.reshape(
num_kv_heads * head_dim, hidden_size
).clone()
state_dict[k.replace("qkv_proj", "v_proj")] = v_proj.clone()
elif k == "embedder.weight":
state_dict[LAYER_NAME_MAPPING[k]] = v
state_dict["lm_head.weight"] = v
else:
state_dict[k] = v
print("Loading the checkpoint in a Gemma model.")
with init_empty_weights():
model = GemmaForCausalLM(config)
model.load_state_dict(state_dict, assign=True, strict=False)
model.config.torch_dtype = torch.float32
del model.config._name_or_path
print("Saving in the Transformers format.")
if push_to_hub:
print(f"pushing the model to {save_path}")
model.push_to_hub(save_path, safe_serialization=safe_serialization, private=True)
else:
model.save_pretrained(save_path, safe_serialization=safe_serialization)
def write_tokenizer(input_tokenizer_path, save_path, push_to_hub=False):
# Initialize the tokenizer based on the `spm` model
tokenizer_class = GemmaTokenizer if GemmaTokenizerFast is None else GemmaTokenizerFast
print(f"Saving a {tokenizer_class.__name__} to {save_path}.")
tokenizer = tokenizer_class(input_tokenizer_path)
if push_to_hub:
tokenizer.push_to_hub(save_path)
else:
tokenizer.save_pretrained(save_path)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--input_checkpoint",
help="Absolute path to the target Gemma weights.",
required=True,
)
parser.add_argument(
"--tokenizer_checkpoint",
help="Location of Gemma tokenizer model",
)
parser.add_argument(
"--model_size",
default="7B",
choices=["2B", "7B", "tokenizer_only"],
help="'f' models correspond to the finetuned versions, and are specific to the Gemma2 official release. For more details on Gemma2, checkout the original repo: https://huggingface.co/google/gemma-7b",
)
parser.add_argument(
"--output_dir",
default="google/gemma-7b",
help="Location to write HF model and tokenizer",
)
parser.add_argument(
"--pickle_serialization",
help="Whether or not to save using `safetensors`.",
action="store_true",
default=False,
)
parser.add_argument(
"--convert_tokenizer",
help="Whether or not to convert the tokenizer as well.",
action="store_true",
default=False,
)
parser.add_argument(
"--push_to_hub",
help="Whether or not to push the model to the hub at `output_dir` instead of saving it locally.",
action="store_true",
default=False,
)
args = parser.parse_args()
if args.convert_tokenizer:
if args.tokenizer_checkpoint is None:
raise ValueError("Path to the tokenizer is required when passing --convert_tokenizer")
spm_path = os.path.join(args.tokenizer_checkpoint)
write_tokenizer(spm_path, args.output_dir, args.push_to_hub)
config = CONFIG_MAPPING[args.model_size]
write_model(
config=config,
input_base_path=args.input_checkpoint,
save_path=args.output_dir,
safe_serialization=not args.pickle_serialization,
push_to_hub=args.push_to_hub,
)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,763 @@
# coding=utf-8
# Copyright 2024 Google Inc., and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Flax Gemma model."""
from typing import Optional, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_gemma import GemmaConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "GemmaConfig"
_CHECKPOINT_FOR_DOC = "google/gemma-2b"
_REAL_CHECKPOINT_FOR_DOC = "openlm-research/open_llama_3b_v2"
GEMMA_START_DOCSTRING = r"""
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a Flax Linen
[flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
Finally, this model supports inherent JAX features such as:
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
Parameters:
config ([`GemmaConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16`, or
`jax.numpy.bfloat16`.
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified all the computation will be performed with the given `dtype`.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
[`~FlaxPreTrainedModel.to_bf16`].
"""
GEMMA_INPUTS_DOCSTRING = r"""
Args:
input_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
def create_sinusoidal_positions(num_pos, dim):
inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2)[: (dim // 2)] / dim))
freqs = np.einsum("i , j -> i j", np.arange(num_pos), inv_freq).astype("float32")
emb = np.concatenate((freqs, freqs), axis=-1)
out = np.concatenate((np.sin(emb)[:, None, :], np.cos(emb)[:, None, :]), axis=-1)
return jnp.array(out[:, :, :num_pos])
# Copied from transformers.models.llama.modeling_flax_llama.rotate_half
def rotate_half(tensor):
"""Rotates half the hidden dims of the input."""
rotate_half_tensor = jnp.concatenate(
(-tensor[..., tensor.shape[-1] // 2 :], tensor[..., : tensor.shape[-1] // 2]), axis=-1
)
return rotate_half_tensor
# Copied from transformers.models.llama.modeling_flax_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(tensor, sin_pos, cos_pos):
return (tensor * cos_pos) + (rotate_half(tensor) * sin_pos)
class FlaxGemmaRMSNorm(nn.Module):
config: GemmaConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.epsilon = self.config.rms_norm_eps
self.weight = self.param("weight", lambda _, shape: jnp.ones(shape), self.config.hidden_size)
def __call__(self, hidden_states):
variance = jnp.asarray(hidden_states, dtype=jnp.float32)
variance = jnp.power(variance, 2)
variance = variance.mean(-1, keepdims=True)
# use `jax.numpy.sqrt` as `jax.lax.rsqrt` does not match `torch.rsqrt`
hidden_states = hidden_states / jnp.sqrt(variance + self.epsilon)
return (1 + self.weight) * jnp.asarray(hidden_states, dtype=self.dtype)
# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaRotaryEmbedding with Llama->Gemma
class FlaxGemmaRotaryEmbedding(nn.Module):
config: GemmaConfig
dtype: jnp.dtype = jnp.float32
# Ignore copy
def setup(self):
head_dim = self.config.head_dim
self.sincos = create_sinusoidal_positions(self.config.max_position_embeddings, head_dim)
def __call__(self, key, query, position_ids):
sincos = self.sincos[position_ids]
sin_pos, cos_pos = jnp.split(sincos, 2, axis=-1)
key = apply_rotary_pos_emb(key, sin_pos, cos_pos)
query = apply_rotary_pos_emb(query, sin_pos, cos_pos)
key = jnp.asarray(key, dtype=self.dtype)
query = jnp.asarray(query, dtype=self.dtype)
return key, query
class FlaxGemmaAttention(nn.Module):
config: GemmaConfig
dtype: jnp.dtype = jnp.float32
causal: bool = True
is_cross_attention: bool = False
def setup(self):
config = self.config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = config.head_dim
self.attention_softmax_in_fp32 = self.dtype is not jnp.float32
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
kernel = jax.nn.initializers.normal(self.config.initializer_range)
self.q_proj = nn.Dense(
self.num_heads * self.head_dim, use_bias=config.attention_bias, dtype=self.dtype, kernel_init=kernel
)
self.k_proj = nn.Dense(
self.num_key_value_heads * self.head_dim,
use_bias=config.attention_bias,
dtype=self.dtype,
kernel_init=kernel,
)
self.v_proj = nn.Dense(
self.num_key_value_heads * self.head_dim,
use_bias=config.attention_bias,
dtype=self.dtype,
kernel_init=kernel,
)
self.o_proj = nn.Dense(self.embed_dim, use_bias=config.attention_bias, dtype=self.dtype, kernel_init=kernel)
self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")
self.rotary_emb = FlaxGemmaRotaryEmbedding(config, dtype=self.dtype)
def _split_heads(self, hidden_states, num_heads):
return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim))
def _merge_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads * self.head_dim,))
@nn.compact
# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoSelfAttention._concatenate_to_cache
def _concatenate_to_cache(self, key, value, query, attention_mask):
"""
This function takes projected key, value states from a single input token and concatenates the states to cached
states from previous steps. This function is slighly adapted from the official Flax repository:
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
"""
# detect if we're initializing by absence of existing cache data.
is_initialized = self.has_variable("cache", "cached_key")
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
if is_initialized:
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
# update key, value caches with our new 1d spatial slices
cur_index = cache_index.value
indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
key = lax.dynamic_update_slice(cached_key.value, key, indices)
value = lax.dynamic_update_slice(cached_value.value, value, indices)
cached_key.value = key
cached_value.value = value
num_updated_cache_vectors = query.shape[1]
cache_index.value = cache_index.value + num_updated_cache_vectors
# causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
pad_mask = jnp.broadcast_to(
jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
)
attention_mask = combine_masks(pad_mask, attention_mask)
return key, value, attention_mask
def __call__(
self,
hidden_states,
attention_mask,
position_ids,
deterministic: bool = True,
init_cache: bool = False,
output_attentions: bool = False,
):
query = self.q_proj(hidden_states)
key = self.k_proj(hidden_states)
value = self.v_proj(hidden_states)
query = self._split_heads(query, self.num_heads)
key = self._split_heads(key, self.num_key_value_heads)
value = self._split_heads(value, self.num_key_value_heads)
key, query = self.rotary_emb(key, query, position_ids)
query_length, key_length = query.shape[1], key.shape[1]
if self.has_variable("cache", "cached_key"):
mask_shift = self.variables["cache"]["cache_index"]
max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
causal_mask = lax.dynamic_slice(
self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
)
else:
causal_mask = self.causal_mask[:, :, :query_length, :key_length]
batch_size = hidden_states.shape[0]
causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
attention_mask = combine_masks(attention_mask, causal_mask)
dropout_rng = None
if not deterministic and self.config.attention_dropout > 0.0:
dropout_rng = self.make_rng("dropout")
# During fast autoregressive decoding, we feed one position at a time,
# and cache the keys and values step by step.
if self.has_variable("cache", "cached_key") or init_cache:
key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask)
# transform boolean mask into float mask
attention_bias = lax.select(
attention_mask > 0,
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
)
key = jnp.repeat(key, repeats=self.num_key_value_groups, axis=2)
value = jnp.repeat(value, repeats=self.num_key_value_groups, axis=2)
# usual dot product attention
attention_dtype = jnp.float32 if self.attention_softmax_in_fp32 else self.dtype
attn_weights = dot_product_attention_weights(
query,
key,
bias=attention_bias,
dropout_rng=dropout_rng,
dropout_rate=self.config.attention_dropout,
deterministic=deterministic,
dtype=attention_dtype,
)
if self.attention_softmax_in_fp32:
attn_weights = attn_weights.astype(self.dtype)
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
attn_output = self._merge_heads(attn_output)
attn_output = self.o_proj(attn_output)
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
return outputs
# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaMLP with Llama->Gemma
class FlaxGemmaMLP(nn.Module):
config: GemmaConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
embed_dim = self.config.hidden_size
inner_dim = self.config.intermediate_size if self.config.intermediate_size is not None else 4 * embed_dim
kernel_init = jax.nn.initializers.normal(self.config.initializer_range)
self.act = ACT2FN[self.config.hidden_act]
self.gate_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
self.down_proj = nn.Dense(embed_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
self.up_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
def __call__(self, hidden_states):
up_proj_states = self.up_proj(hidden_states)
gate_states = self.act(self.gate_proj(hidden_states))
hidden_states = self.down_proj(up_proj_states * gate_states)
return hidden_states
# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaDecoderLayer with Llama->Gemma
class FlaxGemmaDecoderLayer(nn.Module):
config: GemmaConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.input_layernorm = FlaxGemmaRMSNorm(self.config, dtype=self.dtype)
self.self_attn = FlaxGemmaAttention(self.config, dtype=self.dtype)
self.post_attention_layernorm = FlaxGemmaRMSNorm(self.config, dtype=self.dtype)
self.mlp = FlaxGemmaMLP(self.config, dtype=self.dtype)
def __call__(
self,
hidden_states,
attention_mask=None,
position_ids=None,
deterministic: bool = True,
init_cache: bool = False,
output_attentions: bool = False,
):
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
outputs = self.self_attn(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
deterministic=deterministic,
init_cache=init_cache,
output_attentions=output_attentions,
)
# residual connection
attn_output = outputs[0]
hidden_states = residual + attn_output
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
# residual connection
hidden_states = residual + hidden_states
return (hidden_states,) + outputs[1:]
# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoPreTrainedModel with GPTNeo->Gemma, GPT_NEO->GEMMA, transformer->model
class FlaxGemmaPreTrainedModel(FlaxPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = GemmaConfig
base_model_prefix = "model"
module_class: nn.Module = None
def __init__(
self,
config: GemmaConfig,
input_shape: Tuple = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs,
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
attention_mask = jnp.ones_like(input_ids)
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
random_params = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"]
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
def init_cache(self, batch_size, max_length):
r"""
Args:
batch_size (`int`):
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
max_length (`int`):
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
cache.
"""
# init input variables to retrieve cache
input_ids = jnp.ones((batch_size, max_length))
attention_mask = jnp.ones_like(input_ids)
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
init_variables = self.module.init(
jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
)
return unfreeze(init_variables["cache"])
@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
def __call__(
self,
input_ids,
attention_mask=None,
position_ids=None,
params: dict = None,
past_key_values: dict = None,
dropout_rng: jax.random.PRNGKey = None,
train: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
batch_size, sequence_length = input_ids.shape
if position_ids is None:
if past_key_values is not None:
raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
if attention_mask is None:
attention_mask = jnp.ones((batch_size, sequence_length))
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
inputs = {"params": params or self.params}
# if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be changed by FlaxGemmaAttention module
if past_key_values:
inputs["cache"] = past_key_values
mutable = ["cache"]
else:
mutable = False
outputs = self.module.apply(
inputs,
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
not train,
False,
output_attentions,
output_hidden_states,
return_dict,
rngs=rngs,
mutable=mutable,
)
# add updated cache to model output
if past_key_values is not None and return_dict:
outputs, past_key_values = outputs
outputs["past_key_values"] = unfreeze(past_key_values["cache"])
return outputs
elif past_key_values is not None and not return_dict:
outputs, past_key_values = outputs
outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
return outputs
# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaLayerCollection with Llama->Gemma
class FlaxGemmaLayerCollection(nn.Module):
config: GemmaConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.blocks = [
FlaxGemmaDecoderLayer(self.config, dtype=self.dtype, name=str(i))
for i in range(self.config.num_hidden_layers)
]
def __call__(
self,
hidden_states,
attention_mask=None,
position_ids=None,
deterministic: bool = True,
init_cache: bool = False,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = False,
):
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for block in self.blocks:
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = block(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
deterministic=deterministic,
init_cache=init_cache,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions += (layer_outputs[1],)
# this contains possible `None` values - `FlaxGemmaModule` will filter them out
outputs = (hidden_states, all_hidden_states, all_attentions)
return outputs
# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaModule with Llama->Gemma
class FlaxGemmaModule(nn.Module):
config: GemmaConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.hidden_size = self.config.hidden_size
embedding_init = jax.nn.initializers.normal(stddev=self.config.initializer_range)
self.embed_tokens = nn.Embed(
self.config.vocab_size,
self.hidden_size,
embedding_init=embedding_init,
dtype=self.dtype,
)
self.layers = FlaxGemmaLayerCollection(self.config, dtype=self.dtype)
self.norm = FlaxGemmaRMSNorm(self.config, dtype=self.dtype)
# Ignore copy
def __call__(
self,
input_ids,
attention_mask=None,
position_ids=None,
deterministic=True,
init_cache: bool = False,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
input_embeds = self.embed_tokens(input_ids.astype("i4"))
input_embeds = input_embeds * (self.config.hidden_size**0.5)
outputs = self.layers(
input_embeds,
position_ids=position_ids,
attention_mask=attention_mask,
deterministic=deterministic,
init_cache=init_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
hidden_states = self.norm(hidden_states)
if output_hidden_states:
all_hidden_states = outputs[1] + (hidden_states,)
outputs = (hidden_states, all_hidden_states) + outputs[2:]
else:
outputs = (hidden_states,) + outputs[1:]
if not return_dict:
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=outputs[1],
attentions=outputs[-1],
)
@add_start_docstrings(
"The bare Gemma Model transformer outputting raw hidden-states without any specific head on top.",
GEMMA_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaModel with Llama->Gemma
class FlaxGemmaModel(FlaxGemmaPreTrainedModel):
module_class = FlaxGemmaModule
append_call_sample_docstring(
FlaxGemmaModel,
_CHECKPOINT_FOR_DOC,
FlaxBaseModelOutput,
_CONFIG_FOR_DOC,
real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
)
# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaForCausalLMModule with Llama->Gemma
class FlaxGemmaForCausalLMModule(nn.Module):
config: GemmaConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.model = FlaxGemmaModule(self.config, dtype=self.dtype)
self.lm_head = nn.Dense(
self.config.vocab_size,
use_bias=False,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
)
# Ignore copy
def __call__(
self,
input_ids,
attention_mask=None,
position_ids=None,
deterministic: bool = True,
init_cache: bool = False,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
outputs = self.model(
input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
deterministic=deterministic,
init_cache=init_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
if self.config.tie_word_embeddings:
shared_kernel = self.model.variables["params"]["embed_tokens"]["embedding"].T
lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states)
else:
lm_logits = self.lm_head(hidden_states)
if not return_dict:
return (lm_logits,) + outputs[1:]
return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
@add_start_docstrings(
"""
The Gemma Model transformer with a language modeling head (linear layer) on top.
""",
GEMMA_START_DOCSTRING,
)
# Copied from transformers.models.gptj.modeling_flax_gptj.FlaxGPTJForCausalLM with GPTJ->Gemma
class FlaxGemmaForCausalLM(FlaxGemmaPreTrainedModel):
module_class = FlaxGemmaForCausalLMModule
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
# initializing the cache
batch_size, seq_length = input_ids.shape
past_key_values = self.init_cache(batch_size, max_length)
# Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
# But since Gemma uses a causal mask, those positions are masked anyways.
# Thus we can create a single static attention_mask here, which is more efficient for compilation
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
if attention_mask is not None:
position_ids = attention_mask.cumsum(axis=-1) - 1
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
else:
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
return {
"past_key_values": past_key_values,
"attention_mask": extended_attention_mask,
"position_ids": position_ids,
}
def update_inputs_for_generation(self, model_outputs, model_kwargs):
model_kwargs["past_key_values"] = model_outputs.past_key_values
model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
return model_kwargs
append_call_sample_docstring(
FlaxGemmaForCausalLM,
_CHECKPOINT_FOR_DOC,
FlaxCausalLMOutput,
_CONFIG_FOR_DOC,
real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,326 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for Gemma."""
import os
from shutil import copyfile
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import sentencepiece as spm
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...utils import logging
if TYPE_CHECKING:
pass
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
SPIECE_UNDERLINE = ""
class GemmaTokenizer(PreTrainedTokenizer):
"""
Construct a Gemma tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is
no padding token in the original model.
Args:
vocab_file (`str`):
Path to the vocabulary file.
unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<bos>"`):
The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<eos>"`):
The end of sequence token.
pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<pad>"`):
A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by
attention mechanisms or loss computation.
sp_model_kwargs (`Dict[str, Any]`, `Optional`, *optional*):
Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
to set:
- `enable_sampling`: Enable subword regularization.
- `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
- `nbest_size = {0,1}`: No sampling is performed.
- `nbest_size > 1`: samples from the nbest_size results.
- `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
using forward-filtering-and-backward-sampling algorithm.
- `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
BPE-dropout.
add_bos_token (`bool`, *optional*, defaults to `True`):
Whether or not to add an `bos_token` at the start of sequences.
add_eos_token (`bool`, *optional*, defaults to `False`):
Whether or not to add an `eos_token` at the end of sequences.
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like
extra spaces.
use_default_system_prompt (`bool`, *optional*, defaults to `False`):
Whether or not the default system prompt for Gemma should be used.
spaces_between_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not to add spaces between special tokens.
"""
vocab_files_names = VOCAB_FILES_NAMES
model_input_names = ["input_ids", "attention_mask"]
def __init__(
self,
vocab_file,
unk_token="<unk>",
bos_token="<bos>",
eos_token="<eos>",
pad_token="<pad>",
sp_model_kwargs: Optional[Dict[str, Any]] = None,
add_bos_token=True,
add_eos_token=False,
clean_up_tokenization_spaces=False,
use_default_system_prompt=False,
spaces_between_special_tokens=False,
**kwargs,
):
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token
self.vocab_file = vocab_file
self.add_bos_token = add_bos_token
self.add_eos_token = add_eos_token
self.use_default_system_prompt = use_default_system_prompt
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(vocab_file)
super().__init__(
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
pad_token=pad_token,
add_bos_token=add_bos_token,
add_eos_token=add_eos_token,
sp_model_kwargs=self.sp_model_kwargs,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
use_default_system_prompt=use_default_system_prompt,
spaces_between_special_tokens=spaces_between_special_tokens,
**kwargs,
)
# Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.__getstate__
def __getstate__(self):
state = self.__dict__.copy()
state["sp_model"] = None
state["sp_model_proto"] = self.sp_model.serialized_model_proto()
return state
# Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.__setstate__
def __setstate__(self, d):
self.__dict__ = d
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
@property
# Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.vocab_size
def vocab_size(self):
"""Returns vocab size"""
return self.sp_model.get_piece_size()
# Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_vocab
def get_vocab(self):
"""Returns vocab as a dict"""
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder)
return vocab
def _tokenize(self, text, **kwargs):
"""
Returns a tokenized string. The Gemma tokenizer never adds a prefix space.
"""
return self.sp_model.encode(text, out_type=str)
# Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer._convert_token_to_id
def _convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab."""
return self.sp_model.piece_to_id(token)
# Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer._convert_id_to_token
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
token = self.sp_model.IdToPiece(index)
return token
def _decode(
self,
token_ids: List[int],
skip_special_tokens: bool = False,
spaces_between_special_tokens: bool = False,
**kwargs,
) -> str:
sub_texts = []
current_sub_text = []
for ids in token_ids:
if skip_special_tokens and ids in self.all_special_ids:
continue
if ids in self._added_tokens_decoder:
if current_sub_text:
sub_texts.append(self.sp_model.decode(current_sub_text))
sub_texts.append(self._added_tokens_decoder[ids].content)
current_sub_text = []
else:
current_sub_text.append(ids)
if current_sub_text:
sub_texts.append(self.sp_model.decode(current_sub_text))
if spaces_between_special_tokens:
sub_texts = " ".join(sub_texts)
else:
sub_texts = "".join(sub_texts)
return sub_texts
def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (string) in a single string."""
current_sub_tokens = []
out_string = ""
for token in tokens:
# make sure that special tokens are not decoded using sentencepiece model
if token in self._added_tokens_encoder:
out_string += self.sp_model.decode(current_sub_tokens) + token
current_sub_tokens = []
else:
current_sub_tokens.append(token)
out_string += self.sp_model.decode(current_sub_tokens)
return out_string
# Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.save_vocabulary
def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
"""
Save the vocabulary and special tokens file to a directory.
Args:
save_directory (`str`):
The directory in which to save the vocabulary.
Returns:
`Tuple(str)`: Paths to the files saved.
"""
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
return
out_vocab_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
copyfile(self.vocab_file, out_vocab_file)
elif not os.path.isfile(self.vocab_file):
with open(out_vocab_file, "wb") as fi:
content_spiece_model = self.sp_model.serialized_model_proto()
fi.write(content_spiece_model)
return (out_vocab_file,)
# Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.build_inputs_with_special_tokens
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
output = bos_token_id + token_ids_0 + eos_token_id
if token_ids_1 is not None:
output = output + bos_token_id + token_ids_1 + eos_token_id
return output
# Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_special_tokens_mask
def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
"""
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer `prepare_for_model` method.
Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not the token list is already formatted with special tokens for the model.
Returns:
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
"""
if already_has_special_tokens:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
)
bos_token_id = [1] if self.add_bos_token else []
eos_token_id = [1] if self.add_eos_token else []
if token_ids_1 is None:
return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
return (
bos_token_id
+ ([0] * len(token_ids_0))
+ eos_token_id
+ bos_token_id
+ ([0] * len(token_ids_1))
+ eos_token_id
)
# Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.create_token_type_ids_from_sequences
def create_token_type_ids_from_sequences(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
sequence pair mask has the following format:
```
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
| first sequence | second sequence |
```
if token_ids_1 is None, only returns the first portion of the mask (0s).
Args:
token_ids_0 (`List[int]`):
List of ids.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
"""
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
if token_ids_1 is not None:
output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
return output

View File

@ -0,0 +1,199 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from shutil import copyfile
from typing import Optional, Tuple
from tokenizers import processors
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import is_sentencepiece_available, logging
from ...utils.versions import require_version
require_version("tokenizers>=0.13.3")
if is_sentencepiece_available():
from .tokenization_gemma import GemmaTokenizer
else:
GemmaTokenizer = None
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model", "tokenizer_file": "tokenizer.json"}
class GemmaTokenizerFast(PreTrainedTokenizerFast):
"""
Construct a Gemma tokenizer fast. Based on byte-level Byte-Pair-Encoding.
This uses notably ByteFallback and no prefix space. Normalization is applied to replace `" "` with `""`
```python
>>> from transformers import GemmaTokenizerFast
>>> tokenizer = GemmaTokenizerFast.from_pretrained("hf-internal-testing/dummy-gemma")
>>> tokenizer.encode("Hello this is a test")
[2, 4521, 736, 603, 476, 2121]
```
If you want to change the `bos_token` or the `eos_token`, make sure to specify them when initializing the model, or
call `tokenizer.update_post_processor()` to make sure that the post-processing is correctly done (otherwise the
values of the first token and final token of an encoded sequence will not be correct). For more details, checkout
[post-processors] (https://huggingface.co/docs/tokenizers/api/post-processors) documentation.
This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
refer to this superclass for more information regarding those methods.
Args:
vocab_file (`str`, *optional*):
[SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that
contains the vocabulary necessary to instantiate a tokenizer.
tokenizer_file (`str`, *optional*):
[tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
contains everything needed to load the tokenizer.
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like
extra spaces.
unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<bos>"`):
The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<eos>"`):
The end of sequence token.
pad_token (`str`, *optional*, defaults to `"<pad>"`):
The padding token
add_bos_token (`bool`, *optional*, defaults to `True`):
Whether or not to add an `bos_token` at the start of sequences.
add_eos_token (`bool`, *optional*, defaults to `False`):
Whether or not to add an `eos_token` at the end of sequences.
"""
vocab_files_names = VOCAB_FILES_NAMES
slow_tokenizer_class = GemmaTokenizer
padding_side = "left"
model_input_names = ["input_ids", "attention_mask"]
def __init__(
self,
vocab_file=None,
tokenizer_file=None,
clean_up_tokenization_spaces=False,
unk_token="<unk>",
bos_token="<bos>",
eos_token="<eos>",
pad_token="<pad>",
add_bos_token=True,
add_eos_token=False,
**kwargs,
):
super().__init__(
vocab_file=vocab_file,
tokenizer_file=tokenizer_file,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
unk_token=unk_token,
bos_token=bos_token,
eos_token=eos_token,
pad_token=pad_token,
add_bos_token=add_bos_token,
add_eos_token=add_eos_token,
**kwargs,
)
self._add_bos_token = add_bos_token
self._add_eos_token = add_eos_token
self.update_post_processor()
self.vocab_file = vocab_file
@property
def can_save_slow_tokenizer(self) -> bool:
return os.path.isfile(self.vocab_file) if self.vocab_file else False
# Copied from transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast.update_post_processor
def update_post_processor(self):
"""
Updates the underlying post processor with the current `bos_token` and `eos_token`.
"""
bos = self.bos_token
bos_token_id = self.bos_token_id
if bos is None and self.add_bos_token:
raise ValueError("add_bos_token = True but bos_token = None")
eos = self.eos_token
eos_token_id = self.eos_token_id
if eos is None and self.add_eos_token:
raise ValueError("add_eos_token = True but eos_token = None")
single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
special_tokens = []
if self.add_bos_token:
special_tokens.append((bos, bos_token_id))
if self.add_eos_token:
special_tokens.append((eos, eos_token_id))
self._tokenizer.post_processor = processors.TemplateProcessing(
single=single, pair=pair, special_tokens=special_tokens
)
@property
def add_eos_token(self):
return self._add_eos_token
@property
def add_bos_token(self):
return self._add_bos_token
@add_eos_token.setter
def add_eos_token(self, value):
self._add_eos_token = value
self.update_post_processor()
@add_bos_token.setter
def add_bos_token(self, value):
self._add_bos_token = value
self.update_post_processor()
# Copied from transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast.save_vocabulary
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
if not self.can_save_slow_tokenizer:
raise ValueError(
"Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
"tokenizer."
)
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
return
out_vocab_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
copyfile(self.vocab_file, out_vocab_file)
return (out_vocab_file,)
# Copied from transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast.build_inputs_with_special_tokens
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
output = bos_token_id + token_ids_0 + eos_token_id
if token_ids_1 is not None:
output = output + bos_token_id + token_ids_1 + eos_token_id
return output

View File

@ -737,6 +737,27 @@ class FlaxEncoderDecoderModel(metaclass=DummyObject):
requires_backends(self, ["flax"])
class FlaxGemmaForCausalLM(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxGemmaModel(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxGemmaPreTrainedModel(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxGPT2LMHeadModel(metaclass=DummyObject):
_backends = ["flax"]

View File

@ -3811,6 +3811,34 @@ class FuyuPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
class GemmaForCausalLM(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class GemmaForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class GemmaModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class GemmaPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
GIT_PRETRAINED_MODEL_ARCHIVE_LIST = None

View File

@ -79,6 +79,13 @@ class FNetTokenizer(metaclass=DummyObject):
requires_backends(self, ["sentencepiece"])
class GemmaTokenizer(metaclass=DummyObject):
_backends = ["sentencepiece"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["sentencepiece"])
class GPTSw3Tokenizer(metaclass=DummyObject):
_backends = ["sentencepiece"]

View File

@ -170,6 +170,13 @@ class FunnelTokenizerFast(metaclass=DummyObject):
requires_backends(self, ["tokenizers"])
class GemmaTokenizerFast(metaclass=DummyObject):
_backends = ["tokenizers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tokenizers"])
class GPT2TokenizerFast(metaclass=DummyObject):
_backends = ["tokenizers"]

View File

View File

@ -0,0 +1,267 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from transformers import AutoTokenizer, GemmaConfig, is_flax_available
from transformers.testing_utils import require_flax, slow
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
if is_flax_available():
import jax
import jax.numpy as jnp
from transformers.models.gemma.modeling_flax_gemma import (
FlaxGemmaForCausalLM,
FlaxGemmaModel,
)
class FlaxGemmaModelTester:
def __init__(
self,
parent,
batch_size=2,
seq_length=7,
is_training=True,
use_input_mask=True,
use_token_type_ids=False,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=2,
num_attention_heads=4,
num_key_value_heads=2,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
initializer_range=0.02,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.scope = None
self.bos_token_id = vocab_size - 1
self.eos_token_id = vocab_size - 1
self.pad_token_id = vocab_size - 1
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = np.tril(np.ones((self.batch_size, self.seq_length)))
config = GemmaConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
num_key_value_heads=self.num_key_value_heads,
head_dim=self.hidden_size // self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
use_cache=True,
is_decoder=False,
initializer_range=self.initializer_range,
)
return config, input_ids, input_mask
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_ids, attention_mask = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
return config, inputs_dict
def check_use_cache_forward(self, model_class_name, config, input_ids, attention_mask):
max_decoder_length = 20
model = model_class_name(config)
past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length)
attention_mask = jnp.ones((input_ids.shape[0], max_decoder_length), dtype="i4")
position_ids = jnp.broadcast_to(
jnp.arange(input_ids.shape[-1] - 1)[None, :], (input_ids.shape[0], input_ids.shape[-1] - 1)
)
outputs_cache = model(
input_ids[:, :-1],
attention_mask=attention_mask,
past_key_values=past_key_values,
position_ids=position_ids,
)
position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4")
outputs_cache_next = model(
input_ids[:, -1:],
attention_mask=attention_mask,
past_key_values=outputs_cache.past_key_values,
position_ids=position_ids,
)
outputs = model(input_ids)
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
def check_use_cache_forward_with_attn_mask(self, model_class_name, config, input_ids, attention_mask):
max_decoder_length = 20
model = model_class_name(config)
attention_mask_cache = jnp.concatenate(
[attention_mask, jnp.zeros((attention_mask.shape[0], max_decoder_length - attention_mask.shape[1]))],
axis=-1,
)
past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length)
position_ids = jnp.broadcast_to(
jnp.arange(input_ids.shape[-1] - 1)[None, :], (input_ids.shape[0], input_ids.shape[-1] - 1)
)
outputs_cache = model(
input_ids[:, :-1],
attention_mask=attention_mask_cache,
past_key_values=past_key_values,
position_ids=position_ids,
)
position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4")
outputs_cache_next = model(
input_ids[:, -1:],
past_key_values=outputs_cache.past_key_values,
attention_mask=attention_mask_cache,
position_ids=position_ids,
)
outputs = model(input_ids, attention_mask=attention_mask)
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
@require_flax
class FlaxGemmaModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
all_model_classes = (FlaxGemmaModel, FlaxGemmaForCausalLM) if is_flax_available() else ()
all_generative_model_classes = (FlaxGemmaForCausalLM,) if is_flax_available() else ()
def setUp(self):
self.model_tester = FlaxGemmaModelTester(self)
def test_use_cache_forward(self):
for model_class_name in self.all_model_classes:
config, input_ids, attention_mask = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_use_cache_forward(model_class_name, config, input_ids, attention_mask)
def test_use_cache_forward_with_attn_mask(self):
for model_class_name in self.all_model_classes:
config, input_ids, attention_mask = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_use_cache_forward_with_attn_mask(
model_class_name, config, input_ids, attention_mask
)
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
model = model_class_name.from_pretrained("google/gemma-2b", from_pt=True)
outputs = model(np.ones((1, 1)))
self.assertIsNotNone(outputs)
@slow
@require_flax
class FlaxGemmaIntegrationTest(unittest.TestCase):
input_text = ["The capital of France is", "To play the perfect cover drive"]
model_id = "google/gemma-2b"
revision = "flax"
def setUp(self):
self.model, self.params = FlaxGemmaForCausalLM.from_pretrained(
self.model_id, revision=self.revision, _do_init=False
)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
self.tokenizer.padding_side = "left"
def test_logits(self):
inputs = self.tokenizer(self.input_text, return_tensors="np", padding=True)
# fmt: off
EXPECTED_MEAN = [
[-16.427, -21.386, -35.491, -36.258, -31.401, -36.370, -37.598],
[-21.386, -32.150, -33.155, -34.344, -34.706, -34.678, -38.495],
]
EXPECTED_SLICE = [-33.462, -16.481, -30.837, -32.195, -33.113]
# fmt: on
logits = self.model(**inputs, params=self.params).logits
diff_mean = jnp.abs(logits.mean(-1) - np.array(EXPECTED_MEAN)).max()
diff_slice = jnp.abs(logits[0, -1, 475:480] - np.array(EXPECTED_SLICE)).max()
self.assertAlmostEqual(diff_mean, 0, places=3)
self.assertAlmostEqual(diff_slice, 0, places=3)
def test_generation(self):
EXPECTED_TEXTS = [
"The capital of France is a city of contrasts. It is a city of history, of art, of culture, of fashion",
"To play the perfect cover drive, you need to have a good technique and a good mindset.\n\nThe cover drive is a shot",
]
inputs = self.tokenizer(self.input_text, return_tensors="np", padding=True)
output = self.model.generate(**inputs, params=self.params, max_new_tokens=20, do_sample=False)
output_text = self.tokenizer.batch_decode(output.sequences, skip_special_tokens=True)
self.assertEqual(output_text, EXPECTED_TEXTS)
def test_jit_generation(self):
EXPECTED_TEXTS = [
"The capital of France is a city of contrasts. It is a city of history, culture, and art, but it is",
"To play the perfect cover drive, you need to have a good technique and a good mindset.\n\nThe cover drive is a shot",
]
inputs = self.tokenizer(self.input_text, return_tensors="np", padding=True)
def generate(input_ids, attention_mask):
outputs = self.model.generate(
input_ids, attention_mask=attention_mask, params=self.params, max_new_tokens=20, do_sample=False
)
return outputs
jit_generate = jax.jit(generate)
output_sequences = jit_generate(**inputs).sequences
output_text = self.tokenizer.batch_decode(output_sequences, skip_special_tokens=True)
self.assertEqual(output_text, EXPECTED_TEXTS)

View File

@ -0,0 +1,656 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch Gemma model. """
import tempfile
import unittest
import pytest
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaConfig, is_torch_available
from transformers.testing_utils import (
require_bitsandbytes,
require_flash_attn,
require_torch,
require_torch_gpu,
slow,
torch_device,
)
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers import GemmaForCausalLM, GemmaForSequenceClassification, GemmaModel
class GemmaModelTester:
def __init__(
self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_input_mask=True,
use_token_type_ids=False,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=2,
num_attention_heads=4,
num_key_value_heads=2,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
pad_token_id=0,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.num_choices = num_choices
self.pad_token_id = pad_token_id
self.scope = scope
self.head_dim = self.hidden_size // self.num_attention_heads
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTester.prepare_config_and_inputs
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = self.get_config()
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
# Ignore copy
def get_config(self):
return GemmaConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
num_key_value_heads=self.num_key_value_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
is_decoder=False,
initializer_range=self.initializer_range,
pad_token_id=self.pad_token_id,
head_dim=self.head_dim,
)
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Gemma
def create_and_check_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = GemmaModel(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask)
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model_as_decoder with Llama->Gemma
def create_and_check_model_as_decoder(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
config.add_cross_attention = True
model = GemmaModel(config)
model.to(torch_device)
model.eval()
result = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
)
result = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
)
result = model(input_ids, attention_mask=input_mask)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_for_causal_lm with Llama->Gemma
def create_and_check_for_causal_lm(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
model = GemmaForCausalLM(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_decoder_model_past_large_inputs with Llama->Gemma
def create_and_check_decoder_model_past_large_inputs(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
config.is_decoder = True
config.add_cross_attention = True
model = GemmaForCausalLM(config=config)
model.to(torch_device)
model.eval()
# first forward pass
outputs = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=True,
)
past_key_values = outputs.past_key_values
# create hypothetical multiple next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
# append to next input_ids and
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
output_from_no_past = model(
next_input_ids,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_hidden_states=True,
)["hidden_states"][0]
output_from_past = model(
next_tokens,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
output_hidden_states=True,
)["hidden_states"][0]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common with Llama->Gemma
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
return config, inputs_dict
@require_torch
class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (GemmaModel, GemmaForCausalLM, GemmaForSequenceClassification) if is_torch_available() else ()
all_generative_model_classes = (GemmaForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"feature-extraction": GemmaModel,
"text-classification": GemmaForSequenceClassification,
"text-generation": GemmaForCausalLM,
"zero-shot": GemmaForSequenceClassification,
}
if is_torch_available()
else {}
)
test_headmasking = False
test_pruning = False
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
):
return True
def setUp(self):
self.model_tester = GemmaModelTester(self)
self.config_tester = ConfigTester(self, config_class=GemmaConfig, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_various_embeddings(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
for type in ["absolute", "relative_key", "relative_key_query"]:
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)
def test_Gemma_sequence_classification_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
print(config)
config.num_labels = 3
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
model = GemmaForSequenceClassification(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
def test_Gemma_sequence_classification_model_for_single_label(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
config.problem_type = "single_label_classification"
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
model = GemmaForSequenceClassification(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
def test_Gemma_sequence_classification_model_for_multi_label(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
config.problem_type = "multi_label_classification"
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
sequence_labels = ids_tensor(
[self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
).to(torch.float)
model = GemmaForSequenceClassification(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
@unittest.skip("TODO @gante fix this for Llama")
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
pass
@unittest.skip("Gemma buffers include complex numbers, which breaks this test")
def test_save_load_fast_init_from_base(self):
pass
@unittest.skip("Gemma uses GQA on all models so the KV cache is a non standard format")
def test_past_key_values_format(self):
pass
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_padding_right(self):
import torch
for model_class in self.all_generative_model_classes:
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
torch_device
)
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
with self.assertRaises(ValueError):
_ = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
)
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_use_cache(self):
import torch
max_new_tokens = 30
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
# NOTE: Gemma apparently does not support right padding + use_cache with FA2.
dummy_attention_mask[:, -1] = 1
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
# Just test that a large cache works as expected
_ = model.generate(
dummy_input,
attention_mask=dummy_attention_mask,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=True,
)
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_inference_padding_right(self):
self.skipTest("Gemma flash attention does not support right padding")
@require_torch_gpu
@slow
class GemmaIntegrationTest(unittest.TestCase):
input_text = ["Hello I am doing", "Hi today"]
def test_model_2b_fp32(self):
model_id = "google/gemma-2b"
EXPECTED_TEXTS = [
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
]
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True).to(torch_device)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(output_text, EXPECTED_TEXTS)
def test_model_2b_fp16(self):
model_id = "google/gemma-2b"
EXPECTED_TEXTS = [
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
]
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16).to(
torch_device
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(output_text, EXPECTED_TEXTS)
def test_model_2b_fp16_static_cache(self):
model_id = "google/gemma-2b"
EXPECTED_TEXTS = [
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
]
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16).to(
torch_device
)
model.generation_config.cache_implementation = "static"
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(output_text, EXPECTED_TEXTS)
def test_model_2b_bf16(self):
model_id = "google/gemma-2b"
EXPECTED_TEXTS = [
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Khichdi",
]
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
torch_device
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(output_text, EXPECTED_TEXTS)
@require_bitsandbytes
def test_model_2b_4bit(self):
model_id = "google/gemma-2b"
EXPECTED_TEXTS = [
"Hello I am doing a project and I need to make a 3d model of a house. I have been using",
"Hi today I'd like to share with you my experience with the new wattpad wattpad wattpad wattpad wattpad wattpad wattpad",
]
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, load_in_4bit=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(output_text, EXPECTED_TEXTS)
@unittest.skip("The test will not fit our CI runners")
def test_model_7b_fp32(self):
model_id = "google/gemma-7b"
EXPECTED_TEXTS = [
"Hello my name is ***** ***** I will be assisting you today. I am sorry to hear about your issue. I will",
"Hi,\n\nI have a problem with my 2005 1.6 16",
]
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True).to(torch_device)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(output_text, EXPECTED_TEXTS)
def test_model_7b_fp16(self):
model_id = "google/gemma-7b"
EXPECTED_TEXTS = [
"""Hello I am doing a project on a 1999 4.0L 4x4. I""",
"Hi today I am going to show you how to make a simple and easy to make a DIY 3D",
]
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16).to(
torch_device
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(output_text, EXPECTED_TEXTS)
def test_model_7b_bf16(self):
model_id = "google/gemma-7b"
EXPECTED_TEXTS = [
"""Hello I am doing a project on a 1991 240sx and I am trying to find""",
"Hi today I am going to show you how to make a very simple and easy to make a very simple and",
]
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
torch_device
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(output_text, EXPECTED_TEXTS)
def test_model_7b_fp16_static_cache(self):
model_id = "google/gemma-7b"
EXPECTED_TEXTS = [
"""Hello I am doing a project on a 1999 4.0L 4x4. I""",
"Hi today I am going to show you how to make a simple and easy to make a DIY 3D",
]
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16).to(
torch_device
)
model.generation_config.cache_implementation = "static"
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(output_text, EXPECTED_TEXTS)
@require_bitsandbytes
def test_model_7b_4bit(self):
model_id = "google/gemma-7b"
EXPECTED_TEXTS = [
"Hello I am doing a project for my school and I am trying to make a program that will take a number and then",
"""Hi today I am going to talk about the new update for the game called "The new update" and I""",
]
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, load_in_4bit=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(output_text, EXPECTED_TEXTS)

View File

@ -0,0 +1,497 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import unittest
from datasets import load_dataset
from transformers import (
AddedToken,
GemmaTokenizer,
GemmaTokenizerFast,
is_torch_available,
)
from transformers.convert_slow_tokenizer import convert_slow_tokenizer
from transformers.testing_utils import (
get_tests_dir,
nested_simplify,
require_jinja,
require_sentencepiece,
require_tokenizers,
require_torch,
slow,
)
from ...test_tokenization_common import TokenizerTesterMixin
SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
if is_torch_available():
pass
@require_sentencepiece
@require_tokenizers
class GemmaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = GemmaTokenizer
rust_tokenizer_class = GemmaTokenizerFast
test_rust_tokenizer = False
test_sentencepiece = True
from_pretrained_kwargs = {}
def setUp(self):
super().setUp()
# We have a SentencePiece fixture for testing
tokenizer = GemmaTokenizer(SAMPLE_VOCAB, keep_accents=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.save_pretrained(self.tmpdirname)
@require_torch
def test_batch_tokenization(self):
if not self.test_seq2seq:
return
tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
# Longer text that will definitely require truncation.
text = [
" UN Chief Says There Is No Military Solution in Syria",
" Secretary-General Ban Ki-moon says his response to Russia's stepped up military support for"
" Syria is that 'there is no military solution' to the nearly five-year conflict and more weapons"
" will only worsen the violence and misery for millions of people.",
]
try:
batch = tokenizer(
text=text,
max_length=3,
max_target_length=10,
return_tensors="pt",
)
except NotImplementedError:
return
self.assertEqual(batch.input_ids.shape[1], 3)
# max_target_length will default to max_length if not specified
batch = tokenizer(text, max_length=3, return_tensors="pt")
self.assertEqual(batch.input_ids.shape[1], 3)
batch_encoder_only = tokenizer(text=text, max_length=3, max_target_length=10, return_tensors="pt")
self.assertEqual(batch_encoder_only.input_ids.shape[1], 3)
self.assertEqual(batch_encoder_only.attention_mask.shape[1], 3)
self.assertNotIn("decoder_input_ids", batch_encoder_only)
@unittest.skip("Unfortunately way too slow to build a BPE with SentencePiece.")
def test_save_slow_from_fast_and_reload_fast(self):
pass
def test_special_tokens_initialization(self):
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
added_tokens = [AddedToken("<special>", lstrip=True)]
tokenizer_r = self.rust_tokenizer_class.from_pretrained(
pretrained_name, additional_special_tokens=added_tokens, **kwargs
)
r_output = tokenizer_r.encode("Hey this is a <special> token")
special_token_id = tokenizer_r.encode("<special>", add_special_tokens=False)[0]
self.assertTrue(special_token_id in r_output)
if self.test_slow_tokenizer:
tokenizer_cr = self.rust_tokenizer_class.from_pretrained(
pretrained_name,
additional_special_tokens=added_tokens,
**kwargs, # , from_slow=True <- unfortunately too slow to convert
)
tokenizer_p = self.tokenizer_class.from_pretrained(
pretrained_name, additional_special_tokens=added_tokens, **kwargs
)
p_output = tokenizer_p.encode("Hey this is a <special> token")
cr_output = tokenizer_cr.encode("Hey this is a <special> token")
self.assertEqual(p_output, r_output)
self.assertEqual(cr_output, r_output)
self.assertTrue(special_token_id in p_output)
self.assertTrue(special_token_id in cr_output)
@slow
def test_tokenizer_integration(self):
expected_encoding = {'input_ids': [[2, 158434, 591, 84193, 3836, 685, 6599, 31223, 235290, 140247, 578, 6599, 31223, 235290, 145139, 235290, 3491, 235275, 6572, 3311, 235290, 38197, 109959, 591, 25894, 235269, 162174, 235290, 235284, 235269, 1791, 6362, 12481, 235269, 1576, 18622, 235269, 2900, 1136, 86684, 235269, 29092, 4632, 16994, 604, 13146, 14944, 40371, 591, 19700, 235327, 235275, 578, 13146, 14944, 25511, 591, 235300, 12474, 235275, 675, 1163, 235248, 235304, 235284, 235340, 229903, 5377, 575, 235248, 235274, 235276, 235276, 235340, 17044, 578, 5271, 1061, 118345, 1865, 125247, 235269, 8745, 111226, 578, 176888, 235265], [2, 25894, 603, 6869, 577, 953, 235290, 8297, 5271, 209099, 41642, 774, 748, 78253, 2793, 731, 51506, 34346, 611, 2145, 2731, 578, 1833, 4807, 575, 832, 16630, 235265], [2, 651, 4320, 8426, 25341, 36271, 1163, 573, 27894, 5929, 235265]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]} # fmt: skip
self.tokenizer_integration_test_util(
expected_encoding=expected_encoding,
model_name="hf-internal-testing/dummy-gemma",
revision="",
padding=False,
)
@unittest.skip("worker 'gw4' crashed on CI, passing locally.")
def test_pickle_subword_regularization_tokenizer(self):
pass
@unittest.skip("worker 'gw4' crashed on CI, passing locally.")
def test_subword_regularization_tokenizer(self):
pass
@unittest.skip("This test will be removed from main @LysandreJik")
def test_pretrained_model_lists(self):
pass
@unittest.skip("Skipping")
def test_torch_encode_plus_sent_to_model(self):
pass
@require_torch
@require_sentencepiece
@require_tokenizers
class GemmaIntegrationTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
checkpoint_name = "hf-internal-testing/dummy-gemma"
cls.tokenizer: GemmaTokenizer = GemmaTokenizer.from_pretrained(
checkpoint_name, eos_token="<s>"
) # add this token
cls.rust_tokenizer = GemmaTokenizerFast.from_pretrained(
checkpoint_name, eos_token="<s>", from_slow=True
) # add this token
return cls
@require_torch
def integration_tests(self):
inputs = self.tokenizer(
["The following string should be properly encoded: Hello.", "But ird and ปี ird ด"],
return_tensors="pt",
)
self.assertEqual(
nested_simplify(inputs),
{
"input_ids": [
[2, 450, 1494, 1347, 881, 367, 6284, 18511, 29901, 15043, 29889],
[2, 1205, 29871, 1823, 322, 29871, 31010, 30691, 1678, 1823, 1678, 30718],
],
"attention_mask": [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
},
)
def test_fast_special_tokens(self):
slow_tokenizer = self.tokenizer
fast_tokenizer = self.rust_tokenizer
slow = slow_tokenizer.encode("A sample test", add_special_tokens=True)
assert slow == [2, 235280, 6453, 2121]
fast_tokenizer.add_eos_token = False
fast = fast_tokenizer.encode("A sample test", add_special_tokens=True)
assert fast == [2, 235280, 6453, 2121]
fast_tokenizer.add_eos_token = True
fast = fast_tokenizer.encode("A sample test", add_special_tokens=True)
assert fast == [2, 235280, 6453, 2121, 204]
slow_tokenizer.add_eos_token = True
slow = slow_tokenizer.encode("A sample test", add_special_tokens=True)
assert slow == [2, 235280, 6453, 2121, 204]
self.tokenizer.add_eos_token = False
self.rust_tokenizer.add_eos_token = False
@unittest.skip("Not super important and always failing. Let's skip it")
@slow
def test_conversion(self):
# This is excruciatingly slow since it has to recreate the entire merge
# list from the original vocabulary in spm
self.rust_tokenizer.save_pretrained("./out")
with tempfile.TemporaryDirectory() as dirname:
self.rust_tokenizer.save_pretrained(dirname)
with open(os.path.join(dirname, "tokenizer.json"), "r") as f:
old_serialized = f.read()
new_tokenizer = convert_slow_tokenizer(self.tokenizer)
with tempfile.NamedTemporaryFile() as f:
new_tokenizer.save(f.name)
# Re-opening since `f` is in bytes.
new_serialized = open(f.name, "r").read()
with open("out_tokenizer.json", "w") as g:
g.write(new_serialized)
self.assertEqual(old_serialized, new_serialized)
def test_simple_encode_decode(self):
pyth_tokenizer = self.tokenizer
rust_tokenizer = self.rust_tokenizer
self.tokenizer.add_eos_token = False
self.rust_tokenizer.add_eos_token = False
self.assertEqual(pyth_tokenizer.encode("This is a test"), [2, 1596, 603, 476, 2121])
self.assertEqual(rust_tokenizer.encode("This is a test"), [2, 1596, 603, 476, 2121])
self.assertEqual(pyth_tokenizer.decode([2, 1596, 603, 476, 2121], skip_special_tokens=True), "This is a test")
self.assertEqual(rust_tokenizer.decode([2, 1596, 603, 476, 2121], skip_special_tokens=True), "This is a test")
# bytefallback showcase
self.assertEqual(pyth_tokenizer.encode("生活的真谛是"), [2, 122182, 235710, 245467, 235427] ) # fmt: skip
self.assertEqual(rust_tokenizer.encode("生活的真谛是"), [2, 122182, 235710, 245467, 235427] ) # fmt: skip
self.assertEqual(
pyth_tokenizer.decode([2, 122182, 235710, 245467, 235427], skip_special_tokens=True),
"生活的真谛是",
)
self.assertEqual(
rust_tokenizer.decode([2, 122182, 235710, 245467, 235427], skip_special_tokens=True),
"生活的真谛是",
)
# Inner spaces showcase
self.assertEqual(pyth_tokenizer.encode("Hi Hello"), [2, 2151, 139, 4521])
self.assertEqual(rust_tokenizer.encode("Hi Hello"), [2, 2151, 139, 4521])
self.assertEqual(pyth_tokenizer.decode([2, 2151, 139, 4521], skip_special_tokens=True), "Hi Hello")
self.assertEqual(rust_tokenizer.decode([2, 2151, 139, 4521], skip_special_tokens=True), "Hi Hello")
self.assertEqual(pyth_tokenizer.encode("Hi Hello"), [2, 2151, 140, 4521])
self.assertEqual(rust_tokenizer.encode("Hi Hello"), [2, 2151, 140, 4521])
self.assertEqual(pyth_tokenizer.decode([2, 2151, 140, 4521], skip_special_tokens=True), "Hi Hello")
self.assertEqual(rust_tokenizer.decode([2, 2151, 140, 4521], skip_special_tokens=True), "Hi Hello")
self.assertEqual(pyth_tokenizer.encode(""), [2])
self.assertEqual(rust_tokenizer.encode(""), [2])
self.assertEqual(pyth_tokenizer.encode(" "), [2, 235248])
self.assertEqual(rust_tokenizer.encode(" "), [2, 235248])
self.assertEqual(pyth_tokenizer.encode(" "), [2, 139])
self.assertEqual(rust_tokenizer.encode(" "), [2, 139])
self.assertEqual(pyth_tokenizer.encode(" Hello"), [2, 25957])
self.assertEqual(rust_tokenizer.encode(" Hello"), [2, 25957])
def test_no_differences_decode(self):
self.tokenizer.add_eos_token = False
self.rust_tokenizer.add_eos_token = False
pyth_tokenizer = self.tokenizer
rust_tokenizer = self.rust_tokenizer
self.assertEqual(pyth_tokenizer.decode([869]), "og")
self.assertEqual(rust_tokenizer.decode([869]), "og")
self.assertEqual(pyth_tokenizer.decode([30112, 869]), " expenditureog")
self.assertEqual(rust_tokenizer.decode([30112, 869]), " expenditureog")
def test_no_differences_special_tokens(self):
pyth_tokenizer = self.tokenizer
rust_tokenizer = self.rust_tokenizer
self.assertEqual(pyth_tokenizer.encode(""), [2])
self.assertEqual(rust_tokenizer.encode(""), [2])
self.assertEqual(pyth_tokenizer.encode("<s>"), [2, 204])
self.assertEqual(rust_tokenizer.encode("<s>"), [2, 204])
@unittest.skipIf(
os.getenv("RUN_TOKENIZER_INTEGRATION", "0") == "0",
"RUN_TOKENIZER_INTEGRATION=1 to run tokenizer integration tests",
)
def test_integration_test_xnli(self):
import tqdm
pyth_tokenizer = self.tokenizer
rust_tokenizer = self.rust_tokenizer
dataset = load_dataset("code_x_glue_ct_code_to_text", "go")
for item in tqdm.tqdm(dataset["validation"]):
string = item["code"]
encoded1 = pyth_tokenizer.encode(string)
encoded2 = rust_tokenizer.encode(string)
self.assertEqual(encoded1, encoded2)
decoded1 = pyth_tokenizer.decode(encoded1, skip_special_tokens=True)
decoded2 = rust_tokenizer.decode(encoded1, skip_special_tokens=True)
self.assertEqual(decoded1, decoded2)
dataset = load_dataset("xnli", "all_languages")
for item in tqdm.tqdm(dataset["train"]):
for string in item["premise"].values():
encoded1 = pyth_tokenizer.encode(string)
encoded2 = rust_tokenizer.encode(string)
self.assertEqual(encoded1, encoded2)
decoded1 = pyth_tokenizer.decode(encoded1, skip_special_tokens=True)
decoded2 = rust_tokenizer.decode(encoded2, skip_special_tokens=True)
self.assertEqual(decoded1, decoded2)
def test_special_token_special_word(self):
# the word inform should be split as ['in', 'form']
tokenizer = GemmaTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")
tokenizer.add_tokens([AddedToken("<REPR_END>", rstrip=True, lstrip=True)], special_tokens=False)
out1 = tokenizer.decode(
tokenizer.encode("<REPR_END>inform", add_special_tokens=False), spaces_between_special_tokens=False
)
self.assertEqual(out1, "<REPR_END>inform")
out2 = tokenizer.decode(
tokenizer.encode("<REPR_END>inform", add_special_tokens=False), spaces_between_special_tokens=True
)
# decoding strips the added prefix space.
self.assertEqual(out2, "<REPR_END> inform")
input_ids = tokenizer.encode("<REPR_END>inform", add_special_tokens=False)
self.assertEqual(input_ids, [256000, 43910])
out2 = tokenizer.decode(
tokenizer.encode(" <REPR_END> inform", add_special_tokens=False), spaces_between_special_tokens=False
)
# TODO @ArthurZ currently we strip left and right, so this will not keep the spaces
self.assertEqual(out2, "<REPR_END>inform")
### Let's make sure decoding does not add extra spaces here and there
# TODO @ArthurZ this should be affected by the lstrip/rstrip/single word /normalize refactoring
# Since currently we always strip left and right of the token, results are as such
input_ids = tokenizer.encode("<s> Hello<s>how", add_special_tokens=False)
self.assertEqual(input_ids, [204, 25957, 204, 1139])
tokens = tokenizer.tokenize("<s> Hello<s>how", add_special_tokens=False)
self.assertEqual(tokens, ["<s>", "▁Hello", "<s>", "how"])
decoded_tokens = tokenizer.decode(input_ids)
self.assertEqual(decoded_tokens, "<s> Hello<s>how")
# Let's make sure that if there are any spaces, we don't remove them!
input_ids = tokenizer.encode(" <s> Hello<s> how", add_special_tokens=False)
self.assertEqual(input_ids, [235248, 204, 25957, 204, 1368])
tokens = tokenizer.tokenize(" <s> Hello<s> how", add_special_tokens=False)
self.assertEqual(tokens, ["", "<s>", "▁Hello", "<s>", "▁how"])
decoded_tokens = tokenizer.decode(input_ids)
self.assertEqual(decoded_tokens, " <s> Hello<s> how")
def test_some_edge_cases(self):
tokenizer = GemmaTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")
sp_tokens = tokenizer.sp_model.encode("<s>>", out_type=str)
self.assertEqual(sp_tokens, ["<s>", ">"])
tokens = tokenizer.tokenize("<s>>")
self.assertEqual(sp_tokens, tokens)
self.assertEqual(tokens, ["<s>", ">"])
tokens = tokenizer.tokenize("")
self.assertEqual(tokens, [])
self.assertEqual(tokens, tokenizer.sp_model.encode("", out_type=str))
tokens = tokenizer.tokenize(" ")
self.assertEqual(tokens, [""])
# a dummy prefix space is not added by the sp_model as it was de-activated
self.assertEqual(tokens, tokenizer.sp_model.encode(" ", out_type=str))
tokens = tokenizer.tokenize("")
self.assertEqual(tokens, [""])
# a dummy prefix space is not added by the sp_model as it was de-activated
self.assertEqual(tokens, tokenizer.sp_model.encode("", out_type=str))
tokens = tokenizer.tokenize("")
self.assertEqual(tokens, ["▁▁"])
# a dummy prefix space is not added by the sp_model as it was de-activated
self.assertEqual(tokens, tokenizer.sp_model.encode("▁▁", out_type=str))
@require_jinja
def test_tokenization_for_chat(self):
tokenizer = GemmaTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")
test_chats = [
[{"role": "user", "content": "Hello!"}],
[
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Nice to meet you."},
],
[{"role": "user", "content": "Hello!"}],
]
# Matt: The third test case tests the default system message, but if this is ever changed in the
# class/repo code then that test will fail, and the case will need to be updated.
tokenized_chats = [tokenizer.apply_chat_template(test_chat) for test_chat in test_chats]
expected_tokens = [[235322, 235371, 571, 235298, 2997, 73786, 1645, 108, 4521, 149907, 235371, 571, 235298, 615, 73786, 108], [235322, 235371, 571, 235298, 2997, 73786, 1645, 108, 4521, 149907, 235371, 571, 235298, 615, 73786, 108, 235322, 235371, 571, 235298, 2997, 73786, 105776, 108, 7731, 577, 4664, 692, 35606, 235371, 571, 235298, 615, 73786, 108], [235322, 235371, 571, 235298, 2997, 73786, 1645, 108, 4521, 149907, 235371, 571, 235298, 615, 73786, 108]] # fmt: skip
for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
self.assertListEqual(tokenized_chat, expected_tokens)
@require_sentencepiece
@require_tokenizers
class CommonSpmIntegrationTests(unittest.TestCase):
"""
A class that regroups important test to make sure that we properly handle the special tokens.
"""
def test_edge_case_tabulation(self):
fast_tokenizer = GemmaTokenizerFast.from_pretrained("hf-internal-testing/dummy-gemma")
slow_tokenizer = GemmaTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")
input_text = "Hey<eos>. \t\t \n\nyou é @#😈 🤗! , 1234 15 5,61"
EXPECTED_IDS = [ 2, 6750, 1, 235265, 235248, 255969, 235248, 109, 4747, 139, 235335, 139, 216311, 241316, 139, 239880, 235341, 144, 235269, 235248, 235274, 235284, 235304, 235310, 235248, 235274, 235308, 235248, 235308, 235269, 235318, 235274] # fmt: skip
EXPECTED_TOKENS = [ "Hey", "<eos>", ".", "", "\t\t", "", "\n\n", "you", "▁▁", "é", "▁▁", "@#", "😈", "▁▁", "🤗", "!", "▁▁▁▁▁▁▁", ",", "", "1", "2", "3", "4", "", "1", "5", "", "5", ",", "6", "1"] # fmt: skip
tokens = fast_tokenizer.tokenize(input_text)
with self.subTest("test fast edge case fast"):
self.assertEqual(tokens, EXPECTED_TOKENS)
tokens = slow_tokenizer.tokenize(input_text)
with self.subTest("test fast edge case fast"):
self.assertEqual(tokens, EXPECTED_TOKENS)
input_ids = fast_tokenizer.encode(input_text)
with self.subTest("test fast edge case fast"):
self.assertEqual(input_ids, EXPECTED_IDS)
input_ids = slow_tokenizer.encode(input_text)
with self.subTest("test fast edge case fast"):
self.assertEqual(input_ids, EXPECTED_IDS)
text = fast_tokenizer.decode(EXPECTED_IDS)
with self.subTest("test fast edge case fast"):
self.assertEqual(text, "<bos>Hey<eos>. \t\t \n\nyou é @#😈 🤗! , 1234 15 5,61")
text = slow_tokenizer.decode(EXPECTED_IDS)
with self.subTest("test fast edge case fast"):
self.assertEqual(text, "<bos>Hey<eos>. \t\t \n\nyou é @#😈 🤗! , 1234 15 5,61")
input_text = "\t\t\t\t \n\n61"
EXPECTED_IDS = [2, 255971, 235248, 109, 235318, 235274]
EXPECTED_TOKENS = ["\t\t\t\t", "", "\n\n", "6", "1"]
tokens = fast_tokenizer.tokenize(input_text)
with self.subTest("test fast edge case fast"):
self.assertEqual(tokens, EXPECTED_TOKENS)
tokens = slow_tokenizer.tokenize(input_text)
with self.subTest("test fast edge case fast"):
self.assertEqual(tokens, EXPECTED_TOKENS)
input_ids = fast_tokenizer.encode(input_text)
with self.subTest("test fast edge case fast"):
self.assertEqual(input_ids, EXPECTED_IDS)
input_ids = slow_tokenizer.encode(input_text)
with self.subTest("test fast edge case fast"):
self.assertEqual(input_ids, EXPECTED_IDS)
text = fast_tokenizer.decode(EXPECTED_IDS)
with self.subTest("test fast edge case fast"):
self.assertEqual(text, "<bos>\t\t\t\t \n\n61")
text = slow_tokenizer.decode(EXPECTED_IDS)
with self.subTest("test fast edge case fast"):
self.assertEqual(text, "<bos>\t\t\t\t \n\n61")

View File

@ -233,6 +233,8 @@ OBJECTS_TO_IGNORE = [
"FlaxGPTNeoModel",
"FlaxLlamaForCausalLM",
"FlaxLlamaModel",
"FlaxGemmaForCausalLM",
"FlaxGemmaModel",
"FlaxMBartForConditionalGeneration",
"FlaxMBartForQuestionAnswering",
"FlaxMBartForSequenceClassification",

View File

@ -575,6 +575,10 @@ src/transformers/models/funnel/convert_funnel_original_tf_checkpoint_to_pytorch.
src/transformers/models/funnel/modeling_funnel.py
src/transformers/models/funnel/modeling_tf_funnel.py
src/transformers/models/fuyu/convert_fuyu_model_weights_to_hf.py
src/transformers/models/gemma/configuration_gemma.py
src/transformers/models/gemma/convert_gemma_weights_to_hf.py
src/transformers/models/gemma/modeling_flax_gemma.py
src/transformers/models/gemma/modeling_gemma.py
src/transformers/models/git/configuration_git.py
src/transformers/models/git/convert_git_to_pytorch.py
src/transformers/models/glpn/configuration_glpn.py