From a59e7c1ed4ef1568ca0ba4140c9af641b17fa37e Mon Sep 17 00:00:00 2001 From: Shang Zhang <69697986+shangz-ai@users.noreply.github.com> Date: Fri, 19 Nov 2021 10:33:39 -0800 Subject: [PATCH] Add QDQBert model and quantization examples of SQUAD task (#14066) * clean up branch for add-qdqbert-model * README update for QAT example; update docstrings in modeling_qdqbert.py * Update qdqbert.rst * Update README.md * Update README.md * calibration data using traning set; QAT example runs in fp32 * re-use BERTtokenizer for qdqbert * Update docs/source/model_doc/qdqbert.rst Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update docs/source/model_doc/qdqbert.rst Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update docs/source/model_doc/qdqbert.rst Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * remove qdqbert tokenizer * Update qdqbert.rst * update evaluate-hf-trt-qa.py * update configuration_qdqbert.py * update modeling_qdqbert.py: add copied statement; replace assert with ValueError * update copied from statement * add is_quantization_available; run make fix-copies * unittest add require_quantization * add backend dependency to qdqbert model * update README; update evaluate script; make style * lint * docs qdqbert update * circleci build_doc add pytorch-quantization for qdqbert * update README * update example readme with instructions to upgrade TensorRT to 8.2 * Update src/transformers/models/qdqbert/configuration_qdqbert.py Co-authored-by: Lysandre Debut * Update src/transformers/models/qdqbert/configuration_qdqbert.py Co-authored-by: Lysandre Debut * Update src/transformers/models/qdqbert/configuration_qdqbert.py Co-authored-by: Lysandre Debut * Update src/transformers/models/qdqbert/configuration_qdqbert.py Co-authored-by: Lysandre Debut * change quantization to pytorch_quantization for backend requirement * feed_forward_chunking not supported in QDQBert * make style * update model docstrings and comments in testing scripts * rename example to quantization-qdqbert; rename example scripts from qat to quant * Update src/transformers/models/qdqbert/modeling_qdqbert.py Co-authored-by: Patrick von Platen * rm experimental functions in quant_trainer * qa cleanup * make fix-copies for docs index.rst * fix doctree; use post_init() for qdqbert * fix early device assignment for qdqbert * fix CI:Model templates runner Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Lysandre Debut Co-authored-by: Patrick von Platen --- .circleci/config.yml | 1 + README.md | 1 + README_ko.md | 1 + README_zh-hans.md | 1 + README_zh-hant.md | 1 + docs/source/index.rst | 58 +- docs/source/model_doc/qdqbert.rst | 189 ++ .../quantization-qdqbert/Dockerfile | 37 + .../quantization-qdqbert/README.md | 197 ++ .../evaluate-hf-trt-qa.py | 456 +++++ .../quantization-qdqbert/quant_trainer.py | 303 +++ .../quantization-qdqbert/run_quant_qa.py | 668 +++++++ .../quantization-qdqbert/trainer_quant_qa.py | 212 ++ .../quantization-qdqbert/utils_qa.py | 427 ++++ src/transformers/__init__.py | 45 + src/transformers/file_utils.py | 19 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 3 + src/transformers/models/auto/modeling_auto.py | 9 + .../models/auto/tokenization_auto.py | 1 + src/transformers/models/qdqbert/__init__.py | 67 + .../models/qdqbert/configuration_qdqbert.py | 122 ++ .../models/qdqbert/modeling_qdqbert.py | 1750 +++++++++++++++++ src/transformers/testing_utils.py | 
12 + ..._pytorch_quantization_and_torch_objects.py | 91 + tests/test_modeling_qdqbert.py | 563 ++++++ 26 files changed, 5209 insertions(+), 26 deletions(-) create mode 100644 docs/source/model_doc/qdqbert.rst create mode 100644 examples/research_projects/quantization-qdqbert/Dockerfile create mode 100644 examples/research_projects/quantization-qdqbert/README.md create mode 100755 examples/research_projects/quantization-qdqbert/evaluate-hf-trt-qa.py create mode 100755 examples/research_projects/quantization-qdqbert/quant_trainer.py create mode 100755 examples/research_projects/quantization-qdqbert/run_quant_qa.py create mode 100644 examples/research_projects/quantization-qdqbert/trainer_quant_qa.py create mode 100644 examples/research_projects/quantization-qdqbert/utils_qa.py create mode 100644 src/transformers/models/qdqbert/__init__.py create mode 100644 src/transformers/models/qdqbert/configuration_qdqbert.py create mode 100755 src/transformers/models/qdqbert/modeling_qdqbert.py create mode 100644 src/transformers/utils/dummy_pytorch_quantization_and_torch_objects.py create mode 100644 tests/test_modeling_qdqbert.py diff --git a/.circleci/config.yml b/.circleci/config.yml index 07434fbab19..2f8b3f83a48 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -754,6 +754,7 @@ jobs: - run: pip install --upgrade pip - run: pip install ."[docs]" - run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.10.0+cpu.html + - run: pip install pytorch-quantization --extra-index-url https://pypi.ngc.nvidia.com - save_cache: key: v0.4-build_doc-{{ checksum "setup.py" }} paths: diff --git a/README.md b/README.md index 4a336686612..2f220fcbfb4 100644 --- a/README.md +++ b/README.md @@ -268,6 +268,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. 1. **[Pegasus](https://huggingface.co/transformers/model_doc/pegasus.html)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu. 1. **[PhoBERT](https://huggingface.co/transformers/model_doc/phobert.html)** (from VinAI Research) released with the paper [PhoBERT: Pre-trained language models for Vietnamese](https://www.aclweb.org/anthology/2020.findings-emnlp.92/) by Dat Quoc Nguyen and Anh Tuan Nguyen. 1. **[ProphetNet](https://huggingface.co/transformers/model_doc/prophetnet.html)** (from Microsoft Research) released with the paper [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. +1. **[QDQBert](https://huggingface.co/transformers/model_doc/qdqbert.html)** (from NVIDIA) released with the paper [Integer Quantization for Deep Learning Inference: Principles and Empirical Evaluation](https://arxiv.org/abs/2004.09602) by Hao Wu, Patrick Judd, Xiaojie Zhang, Mikhail Isaev and Paulius Micikevicius. 1. **[Reformer](https://huggingface.co/transformers/model_doc/reformer.html)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. 1. 
**[RemBERT](https://huggingface.co/transformers/model_doc/rembert.html)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/pdf/2010.12821.pdf) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder. 1. **[RoBERTa](https://huggingface.co/transformers/model_doc/roberta.html)** (from Facebook), released together with the paper a [Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. diff --git a/README_ko.md b/README_ko.md index c8d1e3278d5..1a1b958db77 100644 --- a/README_ko.md +++ b/README_ko.md @@ -266,6 +266,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는 1. **[Pegasus](https://huggingface.co/transformers/model_doc/pegasus.html)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu. 1. **[PhoBERT](https://huggingface.co/transformers/model_doc/phobert.html)** (from VinAI Research) released with the paper [PhoBERT: Pre-trained language models for Vietnamese](https://www.aclweb.org/anthology/2020.findings-emnlp.92/) by Dat Quoc Nguyen and Anh Tuan Nguyen. 1. **[ProphetNet](https://huggingface.co/transformers/model_doc/prophetnet.html)** (from Microsoft Research) released with the paper [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. +1. **[QDQBert](https://huggingface.co/transformers/model_doc/qdqbert.html)** (from NVIDIA) released with the paper [Integer Quantization for Deep Learning Inference: Principles and Empirical Evaluation](https://arxiv.org/abs/2004.09602) by Hao Wu, Patrick Judd, Xiaojie Zhang, Mikhail Isaev and Paulius Micikevicius. 1. **[Reformer](https://huggingface.co/transformers/model_doc/reformer.html)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. 1. **[RemBERT](https://huggingface.co/transformers/model_doc/rembert.html)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/pdf/2010.12821.pdf) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder. 1. **[RoBERTa](https://huggingface.co/transformers/model_doc/roberta.html)** (from Facebook), released together with the paper a [Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. diff --git a/README_zh-hans.md b/README_zh-hans.md index 6d565a33f79..1fb2c533094 100644 --- a/README_zh-hans.md +++ b/README_zh-hans.md @@ -290,6 +290,7 @@ conda install -c huggingface transformers 1. **[Pegasus](https://huggingface.co/transformers/model_doc/pegasus.html)** (来自 Google) 伴随论文 [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) 由 Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu 发布。 1. 
**[PhoBERT](https://huggingface.co/transformers/model_doc/phobert.html)** (来自 VinAI Research) 伴随论文 [PhoBERT: Pre-trained language models for Vietnamese](https://www.aclweb.org/anthology/2020.findings-emnlp.92/) 由 Dat Quoc Nguyen and Anh Tuan Nguyen 发布。 1. **[ProphetNet](https://huggingface.co/transformers/model_doc/prophetnet.html)** (来自 Microsoft Research) 伴随论文 [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) 由 Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou 发布。 +1. **[QDQBert](https://huggingface.co/transformers/model_doc/qdqbert.html)** (来自 NVIDIA) 伴随论文 [Integer Quantization for Deep Learning Inference: Principles and Empirical Evaluation](https://arxiv.org/abs/2004.09602) 由 Hao Wu, Patrick Judd, Xiaojie Zhang, Mikhail Isaev and Paulius Micikevicius 发布。 1. **[Reformer](https://huggingface.co/transformers/model_doc/reformer.html)** (来自 Google Research) 伴随论文 [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) 由 Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya 发布。 1. **[RemBERT](https://huggingface.co/transformers/model_doc/rembert.html)** (来自 Google Research) 伴随论文 [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/pdf/2010.12821.pdf) 由 Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder 发布。 1. **[RoBERTa](https://huggingface.co/transformers/model_doc/roberta.html)** (来自 Facebook), 伴随论文 [Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) 由 Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov 发布。 diff --git a/README_zh-hant.md b/README_zh-hant.md index 4d2fb3464de..949bcb4f7b3 100644 --- a/README_zh-hant.md +++ b/README_zh-hant.md @@ -302,6 +302,7 @@ conda install -c huggingface transformers 1. **[Pegasus](https://huggingface.co/transformers/model_doc/pegasus.html)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu. 1. **[PhoBERT](https://huggingface.co/transformers/model_doc/phobert.html)** (from VinAI Research) released with the paper [PhoBERT: Pre-trained language models for Vietnamese](https://www.aclweb.org/anthology/2020.findings-emnlp.92/) by Dat Quoc Nguyen and Anh Tuan Nguyen. 1. **[ProphetNet](https://huggingface.co/transformers/model_doc/prophetnet.html)** (from Microsoft Research) released with the paper [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. +1. **[QDQBert](https://huggingface.co/transformers/model_doc/qdqbert.html)** (from NVIDIA) released with the paper [Integer Quantization for Deep Learning Inference: Principles and Empirical Evaluation](https://arxiv.org/abs/2004.09602) by Hao Wu, Patrick Judd, Xiaojie Zhang, Mikhail Isaev and Paulius Micikevicius. 1. **[Reformer](https://huggingface.co/transformers/model_doc/reformer.html)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. 1. 
**[RemBERT](https://huggingface.co/transformers/model_doc/rembert.html)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/pdf/2010.12821.pdf) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder. 1. **[RoBERTa](https://huggingface.co/transformers/model_doc/roberta.html)** (from Facebook), released together with the paper a [Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. diff --git a/docs/source/index.rst b/docs/source/index.rst index daea16ded25..c2ce07b2a08 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -270,85 +270,88 @@ Supported models 57. :doc:`ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training `__ by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. -58. :doc:`Reformer ` (from Google Research) released with the paper `Reformer: The Efficient +58. :doc:`QDQBert ` (from NVIDIA) released with the paper `Integer Quantization for Deep Learning + Inference: Principles and Empirical Evaluation `__ by Hao Wu, Patrick Judd, + Xiaojie Zhang, Mikhail Isaev and Paulius Micikevicius. +59. :doc:`Reformer ` (from Google Research) released with the paper `Reformer: The Efficient Transformer `__ by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. -59. :doc:`RemBERT ` (from Google Research) released with the paper `Rethinking embedding coupling in +60. :doc:`RemBERT ` (from Google Research) released with the paper `Rethinking embedding coupling in pre-trained language models `__ by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder. -60. :doc:`RoBERTa ` (from Facebook), released together with the paper a `Robustly Optimized BERT +61. :doc:`RoBERTa ` (from Facebook), released together with the paper a `Robustly Optimized BERT Pretraining Approach `__ by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. -61. :doc:`RoFormer ` (from ZhuiyiTechnology), released together with the paper a `RoFormer: +62. :doc:`RoFormer ` (from ZhuiyiTechnology), released together with the paper a `RoFormer: Enhanced Transformer with Rotary Position Embedding `__ by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu. -62. :doc:`SegFormer ` (from NVIDIA) released with the paper `SegFormer: Simple and Efficient +63. :doc:`SegFormer ` (from NVIDIA) released with the paper `SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers `__ by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo. -63. :doc:`SEW ` (from ASAPP) released with the paper `Performance-Efficiency Trade-offs in Unsupervised +64. :doc:`SEW ` (from ASAPP) released with the paper `Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition `__ by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi. -64. :doc:`SEW-D ` (from ASAPP) released with the paper `Performance-Efficiency Trade-offs in +65. :doc:`SEW-D ` (from ASAPP) released with the paper `Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition `__ by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi. -65. 
:doc:`SpeechToTextTransformer ` (from Facebook), released together with the paper +66. :doc:`SpeechToTextTransformer ` (from Facebook), released together with the paper `fairseq S2T: Fast Speech-to-Text Modeling with fairseq `__ by Changhan Wang, Yun Tang, Xutai Ma, Anne Wu, Dmytro Okhonko, Juan Pino. -66. :doc:`SpeechToTextTransformer2 ` (from Facebook), released together with the paper +67. :doc:`SpeechToTextTransformer2 ` (from Facebook), released together with the paper `Large-Scale Self- and Semi-Supervised Learning for Speech Translation `__ by Changhan Wang, Anne Wu, Juan Pino, Alexei Baevski, Michael Auli, Alexis Conneau. -67. :doc:`Splinter ` (from Tel Aviv University), released together with the paper `Few-Shot +68. :doc:`Splinter ` (from Tel Aviv University), released together with the paper `Few-Shot Question Answering by Pretraining Span Selection `__ by Ori Ram, Yuval Kirstain, Jonathan Berant, Amir Globerson, Omer Levy. -68. :doc:`SqueezeBert ` (from Berkeley) released with the paper `SqueezeBERT: What can computer +69. :doc:`SqueezeBert ` (from Berkeley) released with the paper `SqueezeBERT: What can computer vision teach NLP about efficient neural networks? `__ by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. -69. :doc:`T5 ` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a +70. :doc:`T5 ` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer `__ by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. -70. :doc:`T5v1.1 ` (from Google AI) released in the repository +71. :doc:`T5v1.1 ` (from Google AI) released in the repository `google-research/text-to-text-transfer-transformer `__ by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. -71. :doc:`TAPAS ` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via +72. :doc:`TAPAS ` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via Pre-training `__ by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos. -72. :doc:`Transformer-XL ` (from Google/CMU) released with the paper `Transformer-XL: +73. :doc:`Transformer-XL ` (from Google/CMU) released with the paper `Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context `__ by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov. -73. :doc:`TrOCR ` (from Microsoft), released together with the paper `TrOCR: Transformer-based Optical +74. :doc:`TrOCR ` (from Microsoft), released together with the paper `TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models `__ by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei. -74. :doc:`UniSpeech ` (from Microsoft Research) released with the paper `UniSpeech: Unified Speech +75. :doc:`UniSpeech ` (from Microsoft Research) released with the paper `UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled Data `__ by Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei, Michael Zeng, Xuedong Huang. -75. :doc:`UniSpeechSat ` (from Microsoft Research) released with the paper `UNISPEECH-SAT: +76. 
:doc:`UniSpeechSat ` (from Microsoft Research) released with the paper `UNISPEECH-SAT: UNIVERSAL SPEECH REPRESENTATION LEARNING WITH SPEAKER AWARE PRE-TRAINING `__ by Sanyuan Chen, Yu Wu, Chengyi Wang, Zhengyang Chen, Zhuo Chen, Shujie Liu, Jian Wu, Yao Qian, Furu Wei, Jinyu Li, Xiangzhan Yu. -76. :doc:`Vision Transformer (ViT) ` (from Google AI) released with the paper `An Image is Worth 16x16 +77. :doc:`Vision Transformer (ViT) ` (from Google AI) released with the paper `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale `__ by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby. -77. :doc:`VisualBERT ` (from UCLA NLP) released with the paper `VisualBERT: A Simple and +78. :doc:`VisualBERT ` (from UCLA NLP) released with the paper `VisualBERT: A Simple and Performant Baseline for Vision and Language `__ by Liunian Harold Li, Mark Yatskar, Da Yin, Cho-Jui Hsieh, Kai-Wei Chang. -78. :doc:`Wav2Vec2 ` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for +79. :doc:`Wav2Vec2 ` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations `__ by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli. -79. :doc:`XLM ` (from Facebook) released together with the paper `Cross-lingual Language Model +80. :doc:`XLM ` (from Facebook) released together with the paper `Cross-lingual Language Model Pretraining `__ by Guillaume Lample and Alexis Conneau. -80. :doc:`XLM-ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: +81. :doc:`XLM-ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training `__ by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. -81. :doc:`XLM-RoBERTa ` (from Facebook AI), released together with the paper `Unsupervised +82. :doc:`XLM-RoBERTa ` (from Facebook AI), released together with the paper `Unsupervised Cross-lingual Representation Learning at Scale `__ by Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov. -82. :doc:`XLNet ` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive +83. :doc:`XLNet ` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive Pretraining for Language Understanding `__ by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le. -83. :doc:`XLSR-Wav2Vec2 ` (from Facebook AI) released with the paper `Unsupervised +84. :doc:`XLSR-Wav2Vec2 ` (from Facebook AI) released with the paper `Unsupervised Cross-Lingual Representation Learning For Speech Recognition `__ by Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli. @@ -464,6 +467,8 @@ Flax), PyTorch, and/or TensorFlow. 
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | ProphetNet | ✅ | ❌ | ✅ | ❌ | ❌ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| QDQBert | ❌ | ❌ | ✅ | ❌ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | RAG | ✅ | ❌ | ✅ | ✅ | ❌ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | @@ -658,6 +663,7 @@ Flax), PyTorch, and/or TensorFlow. model_doc/pegasus model_doc/phobert model_doc/prophetnet + model_doc/qdqbert model_doc/rag model_doc/reformer model_doc/rembert diff --git a/docs/source/model_doc/qdqbert.rst b/docs/source/model_doc/qdqbert.rst new file mode 100644 index 00000000000..fa54c2b849a --- /dev/null +++ b/docs/source/model_doc/qdqbert.rst @@ -0,0 +1,189 @@ +.. + Copyright 2021 NVIDIA Corporation and The HuggingFace Team. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + specific language governing permissions and limitations under the License. + +QDQBERT +----------------------------------------------------------------------------------------------------------------------- + +Overview +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The QDQBERT model can be referenced in `Integer Quantization for Deep Learning Inference: Principles and Empirical +Evaluation `__ by Hao Wu, Patrick Judd, Xiaojie Zhang, Mikhail Isaev and Paulius +Micikevicius. + +The abstract from the paper is the following: + +*Quantization techniques can reduce the size of Deep Neural Networks and improve inference latency and throughput by +taking advantage of high throughput integer instructions. In this paper we review the mathematical aspects of +quantization parameters and evaluate their choices on a wide range of neural network models for different application +domains, including vision, speech, and language. We focus on quantization techniques that are amenable to acceleration +by processors with high-throughput integer math pipelines. We also present a workflow for 8-bit quantization that is +able to maintain accuracy within 1% of the floating-point baseline on all networks studied, including models that are +more difficult to quantize, such as MobileNets and BERT-large.* + +Tips: + +- QDQBERT model adds fake quantization operations (pair of QuantizeLinear/DequantizeLinear ops) to (i) linear layer + inputs and weights, (ii) matmul inputs, (iii) residual add inputs, in BERT model. + +- QDQBERT requires the dependency of `Pytorch Quantization Toolkit + `__. To install ``pip install + pytorch-quantization --extra-index-url https://pypi.ngc.nvidia.com`` + +- QDQBERT model can be loaded from any checkpoint of HuggingFace BERT model (for example *bert-base-uncased*), and + perform Quantization Aware Training/Post Training Quantization. 
+ +- A complete example of using the QDQBERT model to perform Quantization Aware Training and Post Training Quantization for + the SQUAD task can be found at `transformers/examples/research_projects/quantization-qdqbert/ + `_. + +This model was contributed by `shangz `__. + + +Set default quantizers +_______________________________________________________________________________________________________________________ + +QDQBERT model adds fake quantization operations (pair of QuantizeLinear/DequantizeLinear ops) to BERT by +:obj:`TensorQuantizer` in `Pytorch Quantization Toolkit +`__. :obj:`TensorQuantizer` is the module +for quantizing tensors, with :obj:`QuantDescriptor` defining how the tensor should be quantized. Refer to `Pytorch +Quantization Toolkit userguide +`__ for more details. + +Before creating the QDQBERT model, one has to set the default :obj:`QuantDescriptor` defining the default tensor quantizers. +Example: + +.. code-block:: + + >>> import pytorch_quantization.nn as quant_nn + >>> from pytorch_quantization.tensor_quant import QuantDescriptor + + >>> # The default tensor quantizer is set to use Max calibration method + >>> input_desc = QuantDescriptor(num_bits=8, calib_method="max") + >>> # The default tensor quantizer is set to be per-channel quantization for weights + >>> weight_desc = QuantDescriptor(num_bits=8, axis=((0,))) + >>> quant_nn.QuantLinear.set_default_quant_desc_input(input_desc) + >>> quant_nn.QuantLinear.set_default_quant_desc_weight(weight_desc) + + +Calibration +_______________________________________________________________________________________________________________________ + +Calibration is the process of passing data samples to the quantizer and determining the best scaling factors for +tensors. After setting up the tensor quantizers, one can use the following example to calibrate the model: + +.. code-block:: + + >>> # Find the TensorQuantizer and enable calibration + >>> for name, module in model.named_modules(): + >>> if name.endswith('_input_quantizer'): + >>> module.enable_calib() + >>> module.disable_quant() # Use full precision data to calibrate + + >>> # Feeding data samples + >>> model(x) + >>> # ... + + >>> # Finalize calibration + >>> for name, module in model.named_modules(): + >>> if name.endswith('_input_quantizer'): + >>> module.load_calib_amax() + >>> module.enable_quant() + + >>> # If running on GPU, it needs to call .cuda() again because new tensors will be created by calibration process + >>> model.cuda() + + >>> # Keep running the quantized model + >>> # ... + + +Export to ONNX +_______________________________________________________________________________________________________________________ + +The goal of exporting to ONNX is to deploy inference with `TensorRT `__. Fake +quantization will be broken into a pair of QuantizeLinear/DequantizeLinear ONNX ops. After setting the static member of +TensorQuantizer to use Pytorch’s own fake quantization functions, the fake quantized model can be exported to ONNX by following +the instructions in `torch.onnx `__. Example: + +.. code-block:: + + >>> from pytorch_quantization.nn import TensorQuantizer + >>> TensorQuantizer.use_fb_fake_quant = True + + >>> # Load the calibrated model + >>> ... + >>> # ONNX export + >>> torch.onnx.export(...) + + +QDQBertConfig +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
autoclass:: transformers.QDQBertConfig + :members: + + +QDQBertModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.QDQBertModel + :members: forward + + +QDQBertLMHeadModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.QDQBertLMHeadModel + :members: forward + + +QDQBertForMaskedLM +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.QDQBertForMaskedLM + :members: forward + + +QDQBertForSequenceClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.QDQBertForSequenceClassification + :members: forward + + +QDQBertForNextSentencePrediction +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.QDQBertForNextSentencePrediction + :members: forward + + +QDQBertForMultipleChoice +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.QDQBertForMultipleChoice + :members: forward + + +QDQBertForTokenClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.QDQBertForTokenClassification + :members: forward + + +QDQBertForQuestionAnswering +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.QDQBertForQuestionAnswering + :members: forward + diff --git a/examples/research_projects/quantization-qdqbert/Dockerfile b/examples/research_projects/quantization-qdqbert/Dockerfile new file mode 100644 index 00000000000..b7fb54f955a --- /dev/null +++ b/examples/research_projects/quantization-qdqbert/Dockerfile @@ -0,0 +1,37 @@ +# coding=utf-8 +# Copyright 2021 NVIDIA Corporation. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +FROM nvcr.io/nvidia/pytorch:21.07-py3 +LABEL maintainer="Hugging Face" +LABEL repository="transformers" + +RUN apt-get update +RUN apt-get install sudo + +RUN python3 -m pip install --no-cache-dir --upgrade pip +RUN python3 -m pip install --no-cache-dir --ignore-installed ruamel.yaml \ + mkl \ + absl-py \ + yamlpy \ + tensorboardX +RUN python3 -m pip install --no-cache-dir \ + pytorch-quantization --extra-index-url https://pypi.ngc.nvidia.com + +WORKDIR /workspace +COPY . transformers/ +RUN cd transformers/ && \ + python3 -m pip install --no-cache-dir . 
+ +RUN python3 -m pip install --no-cache-dir datasets \ + accelerate \ No newline at end of file diff --git a/examples/research_projects/quantization-qdqbert/README.md b/examples/research_projects/quantization-qdqbert/README.md new file mode 100644 index 00000000000..e00e1864901 --- /dev/null +++ b/examples/research_projects/quantization-qdqbert/README.md @@ -0,0 +1,197 @@ + + +# Huggingface QDQBERT Quantization Example + +The QDQBERT model adds fake quantization (pair of QuantizeLinear/DequantizeLinear ops) to: + * linear layer inputs and weights + * matmul inputs + * residual add inputs + +In this example, we use QDQBERT model to do quantization on SQuAD task, including Quantization Aware Training (QAT), Post Training Quantization (PTQ) and inferencing using TensorRT. + +Required: +- [pytorch-quantization toolkit](https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization) +- [TensorRT >= 8.2](https://developer.nvidia.com/tensorrt) +- PyTorch >= 1.10.0 + +## Setup the environment with Dockerfile + +Under the directory of `transformers/`, build the docker image: +``` +docker build . -f examples/research_projects/quantization-qdqbert/Dockerfile -t bert_quantization:latest +``` + +Run the docker: +``` +docker run --gpus all --privileged --rm -it --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 bert_quantization:latest +``` + +*Note that the current NGC pytorch container (pytorch:21.07-py3) has TensorRT 8.0 which doesn't meet the requiremnt of TensorRT >= 8.2. One can either update the Dockerfile with the latest [NGC pytorch container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) once it supports TensorRT 8.2, or manually download and install [TensorRT >= 8.2](https://developer.nvidia.com/nvidia-tensorrt-download) in the container.* + + +In the container: +``` +cd transformers/examples/research_projects/quantization-qdqbert/ +``` + +## Quantization Aware Training (QAT) + +Calibrate the pretrained model and finetune with quantization awared: + +``` +python3 run_quant_qa.py \ + --model_name_or_path bert-base-uncased \ + --dataset_name squad \ + --max_seq_length 128 \ + --doc_stride 32 \ + --output_dir calib/bert-base-uncased \ + --do_calib \ + --calibrator percentile \ + --percentile 99.99 +``` + +``` +python3 run_quant_qa.py \ + --model_name_or_path calib/bert-base-uncased \ + --dataset_name squad \ + --do_train \ + --do_eval \ + --per_device_train_batch_size 12 \ + --learning_rate 4e-5 \ + --num_train_epochs 2 \ + --max_seq_length 128 \ + --doc_stride 32 \ + --output_dir finetuned_int8/bert-base-uncased \ + --tokenizer_name bert-base-uncased \ + --save_steps 0 +``` + +### Export QAT model to ONNX + +To export the QAT model finetuned above: + +``` +python3 run_quant_qa.py \ + --model_name_or_path finetuned_int8/bert-base-uncased \ + --output_dir ./ \ + --save_onnx \ + --per_device_eval_batch_size 1 \ + --max_seq_length 128 \ + --doc_stride 32 \ + --dataset_name squad \ + --tokenizer_name bert-base-uncased +``` + +Use `--recalibrate-weights` to calibrate the weight ranges according to the quantizer axis. Use `--quant-per-tensor` for per tensor quantization (default is per channel). +Recalibrating will affect the accuracy of the model, but the change should be minimal (< 0.5 F1). 
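
For reference, the ONNX export triggered by `--save_onnx` boils down to the steps sketched below. This is only an illustration, not the exact code path of `run_quant_qa.py`: the dummy-input construction, output names and opset shown here are assumptions, and the default quantizer descriptors must be set up the same way as during calibration (per channel unless `--quant-per-tensor` was used), otherwise the calibrated `amax` buffers in the checkpoint cannot be loaded.

```python
import torch
import pytorch_quantization.nn as quant_nn
from pytorch_quantization.nn import TensorQuantizer
from pytorch_quantization.tensor_quant import QuantDescriptor
from transformers import QDQBertForQuestionAnswering

# Quantizer descriptors must match the ones used for calibration/QAT,
# otherwise the calibrated amax buffers in the checkpoint will not fit.
quant_nn.QuantLinear.set_default_quant_desc_input(QuantDescriptor(num_bits=8, calib_method="histogram"))
quant_nn.QuantLinear.set_default_quant_desc_weight(QuantDescriptor(num_bits=8, axis=(0,)))

# Route fake quantization through PyTorch's own fake-quant functions so the ONNX
# graph contains explicit QuantizeLinear/DequantizeLinear node pairs for TensorRT.
TensorQuantizer.use_fb_fake_quant = True

model = QDQBertForQuestionAnswering.from_pretrained("finetuned_int8/bert-base-uncased")
model.config.return_dict = False  # trace plain (start_logits, end_logits) tuples
model.eval()

seq_len = 128  # matches --max_seq_length used above
dummy_inputs = (
    torch.ones(1, seq_len, dtype=torch.long),   # input_ids
    torch.ones(1, seq_len, dtype=torch.long),   # attention_mask
    torch.zeros(1, seq_len, dtype=torch.long),  # token_type_ids
)
torch.onnx.export(
    model,
    dummy_inputs,
    "model.onnx",
    input_names=["input_ids", "attention_mask", "token_type_ids"],
    output_names=["output_start_logits", "output_end_logits"],  # hypothetical names
    opset_version=13,  # per-channel QuantizeLinear/DequantizeLinear requires opset >= 13
)
```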
+ +### Benchmark the INT8 QAT ONNX model inference with TensorRT using dummy input + +``` +trtexec --onnx=model.onnx --explicitBatch --workspace=16384 --int8 --shapes=input_ids:64x128,attention_mask:64x128,token_type_ids:64x128 --verbose +``` + +### Evaluate the INT8 QAT ONNX model inference with TensorRT + +``` +python3 evaluate-hf-trt-qa.py \ + --onnx_model_path=./model.onnx \ + --output_dir ./ \ + --per_device_eval_batch_size 64 \ + --max_seq_length 128 \ + --doc_stride 32 \ + --dataset_name squad \ + --tokenizer_name bert-base-uncased \ + --int8 \ + --seed 42 +``` + +## Fine-tuning of FP32 model for comparison + +Finetune a fp32 precision model with [transformers/examples/pytorch/question-answering/](../../pytorch/question-answering/): + +``` +python3 ../../pytorch/question-answering/run_qa.py \ + --model_name_or_path bert-base-uncased \ + --dataset_name squad \ + --per_device_train_batch_size 12 \ + --learning_rate 3e-5 \ + --num_train_epochs 2 \ + --max_seq_length 128 \ + --doc_stride 32 \ + --output_dir ./finetuned_fp32/bert-base-uncased \ + --save_steps 0 \ + --do_train \ + --do_eval +``` + +## Post Training Quantization (PTQ) + +### PTQ by calibrating and evaluating the finetuned FP32 model above: + +``` +python3 run_quant_qa.py \ + --model_name_or_path ./finetuned_fp32/bert-base-uncased \ + --dataset_name squad \ + --calibrator percentile \ + --percentile 99.99 \ + --max_seq_length 128 \ + --doc_stride 32 \ + --output_dir ./calib/bert-base-uncased \ + --save_steps 0 \ + --do_calib \ + --do_eval +``` + +### Export the INT8 PTQ model to ONNX + +``` +python3 run_quant_qa.py \ + --model_name_or_path ./calib/bert-base-uncased \ + --output_dir ./ \ + --save_onnx \ + --per_device_eval_batch_size 1 \ + --max_seq_length 128 \ + --doc_stride 32 \ + --dataset_name squad \ + --tokenizer_name bert-base-uncased +``` + +### Evaluate the INT8 PTQ ONNX model inference with TensorRT + +``` +python3 evaluate-hf-trt-qa.py \ + --onnx_model_path=./model.onnx \ + --output_dir ./ \ + --per_device_eval_batch_size 64 \ + --max_seq_length 128 \ + --doc_stride 32 \ + --dataset_name squad \ + --tokenizer_name bert-base-uncased \ + --int8 \ + --seed 42 +``` + +### Quantization options + +Some useful options to support different implementations and optimizations. These should be specified for both calibration and finetuning. + +|argument|description| +|--------|-----------| +|`--quant-per-tensor`| quantize weights with one quantization range per tensor | +|`--fuse-qkv` | use a single range (the max) for quantizing QKV weights and output activations | +|`--clip-gelu N` | clip the output of GELU to a maximum of N when quantizing (e.g. 10) | +|`--disable-dropout` | disable dropout for consistent activation ranges | diff --git a/examples/research_projects/quantization-qdqbert/evaluate-hf-trt-qa.py b/examples/research_projects/quantization-qdqbert/evaluate-hf-trt-qa.py new file mode 100755 index 00000000000..4a618ed77cd --- /dev/null +++ b/examples/research_projects/quantization-qdqbert/evaluate-hf-trt-qa.py @@ -0,0 +1,456 @@ +# coding=utf-8 +# Copyright 2021 NVIDIA Corporation. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet).""" +import argparse +import logging +import os +import time +import timeit + +import datasets +import numpy as np +import torch +from absl import logging as absl_logging +from datasets import load_dataset, load_metric +from torch.utils.data import DataLoader + +import pycuda.autoinit # noqa: F401 +import pycuda.driver as cuda +import tensorrt as trt +import transformers +from accelerate import Accelerator +from transformers import AutoTokenizer, EvalPrediction, default_data_collator, set_seed +from transformers.trainer_pt_utils import nested_concat, nested_truncate +from utils_qa import postprocess_qa_predictions + + +TRT_LOGGER = trt.Logger(trt.Logger.WARNING) +absl_logger = absl_logging.get_absl_logger() +absl_logger.setLevel(logging.WARNING) + +logger = logging.getLogger(__name__) + +parser = argparse.ArgumentParser() + +# Required parameters +parser.add_argument( + "--onnx_model_path", + default=None, + type=str, + required=True, + help="Path to ONNX model: ", +) + +parser.add_argument( + "--output_dir", + default=None, + type=str, + required=True, + help="The output directory where the model checkpoints and predictions will be written.", +) + +# Other parameters + +parser.add_argument( + "--tokenizer_name", + default="", + type=str, + required=True, + help="Pretrained tokenizer name or path if not the same as model_name", +) + +parser.add_argument( + "--version_2_with_negative", + action="store_true", + help="If true, the SQuAD examples contain some that do not have an answer.", +) +parser.add_argument( + "--null_score_diff_threshold", + type=float, + default=0.0, + help="If null_score - best_non_null is greater than the threshold predict null.", +) + +parser.add_argument( + "--max_seq_length", + default=384, + type=int, + help="The maximum total input sequence length after WordPiece tokenization. Sequences " + "longer than this will be truncated, and sequences shorter than this will be padded.", +) +parser.add_argument( + "--doc_stride", + default=128, + type=int, + help="When splitting up a long document into chunks, how much stride to take between chunks.", +) + +parser.add_argument("--per_device_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation.") + +parser.add_argument( + "--n_best_size", + default=20, + type=int, + help="The total number of n-best predictions to generate in the nbest_predictions.json output file.", +) +parser.add_argument( + "--max_answer_length", + default=30, + type=int, + help="The maximum length of an answer that can be generated. 
This is needed because the start " + "and end predictions are not conditioned on one another.", +) + +parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") + +parser.add_argument( + "--dataset_name", + type=str, + default=None, + required=True, + help="The name of the dataset to use (via the datasets library).", +) +parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The configuration name of the dataset to use (via the datasets library).", +) +parser.add_argument( + "--preprocessing_num_workers", type=int, default=4, help="A csv or a json file containing the training data." +) +parser.add_argument( + "--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets" +) +parser.add_argument( + "--fp16", + action="store_true", + help="Whether to use 16-bit (mixed) precision instead of 32-bit", +) +parser.add_argument( + "--int8", + action="store_true", + help="Whether to use INT8", +) + +args = parser.parse_args() + +if args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=True) +else: + raise ValueError( + "You are instantiating a new tokenizer from scratch. This is not supported by this script." + "You can do it from another script, save it, and load it from here, using --tokenizer_name." + ) + +logger.info("Training/evaluation parameters %s", args) + +args.eval_batch_size = args.per_device_eval_batch_size + +INPUT_SHAPE = (args.eval_batch_size, args.max_seq_length) + +# TRT Engine properties +STRICT_TYPES = True + +engine_name = "temp_engine/bert-fp32.engine" +if args.fp16: + engine_name = "temp_engine/bert-fp16.engine" +if args.int8: + engine_name = "temp_engine/bert-int8.engine" + +# import ONNX file +if not os.path.exists("temp_engine"): + os.makedirs("temp_engine") + +EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) +with trt.Builder(TRT_LOGGER) as builder, builder.create_network(EXPLICIT_BATCH) as network, trt.OnnxParser( + network, TRT_LOGGER +) as parser: + with open(args.onnx_model_path, "rb") as model: + if not parser.parse(model.read()): + for error in range(parser.num_errors): + print(parser.get_error(error)) + + # Query input names and shapes from parsed TensorRT network + network_inputs = [network.get_input(i) for i in range(network.num_inputs)] + input_names = [_input.name for _input in network_inputs] # ex: ["actual_input1"] + + with builder.create_builder_config() as config: + config.max_workspace_size = 1 << 50 + if STRICT_TYPES: + config.set_flag(trt.BuilderFlag.STRICT_TYPES) + if args.fp16: + config.set_flag(trt.BuilderFlag.FP16) + if args.int8: + config.set_flag(trt.BuilderFlag.INT8) + profile = builder.create_optimization_profile() + config.add_optimization_profile(profile) + for i in range(len(input_names)): + profile.set_shape(input_names[i], INPUT_SHAPE, INPUT_SHAPE, INPUT_SHAPE) + engine = builder.build_engine(network, config) + + # serialize_engine and store in file (can be directly loaded and deserialized): + with open(engine_name, "wb") as f: + f.write(engine.serialize()) + + +# run inference with TRT +def model_infer(inputs, context, d_inputs, h_output0, h_output1, d_output0, d_output1, stream): + input_ids = np.asarray(inputs["input_ids"], dtype=np.int32) + attention_mask = np.asarray(inputs["attention_mask"], dtype=np.int32) + token_type_ids = np.asarray(inputs["token_type_ids"], dtype=np.int32) + + # Copy inputs + cuda.memcpy_htod_async(d_inputs[0], input_ids.ravel(), stream) + 
cuda.memcpy_htod_async(d_inputs[1], attention_mask.ravel(), stream) + cuda.memcpy_htod_async(d_inputs[2], token_type_ids.ravel(), stream) + # start time + start_time = time.time() + # Run inference + context.execute_async( + bindings=[int(d_inp) for d_inp in d_inputs] + [int(d_output0), int(d_output1)], stream_handle=stream.handle + ) + # Transfer predictions back from GPU + cuda.memcpy_dtoh_async(h_output0, d_output0, stream) + cuda.memcpy_dtoh_async(h_output1, d_output1, stream) + # Synchronize the stream and take time + stream.synchronize() + # end time + end_time = time.time() + infer_time = end_time - start_time + outputs = (h_output0, h_output1) + # print(outputs) + return outputs, infer_time + + +# Initialize the accelerator. We will let the accelerator handle device placement for us in this example. +accelerator = Accelerator() +# Make one log on every process with the configuration for debugging. +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, +) + +# Setup logging, we only want one process per machine to log things on the screen. +# accelerator.is_local_main_process is only True for one process per machine. +logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR) +if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() +else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + +# If passed along, set the training seed now. +if args.seed is not None: + set_seed(args.seed) + +# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) +# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ +# (the dataset will be downloaded automatically from the datasets Hub). +# +# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called +# 'text' is found. You can easily tweak this behavior (see below). +if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name) +else: + raise ValueError("Evaluation requires a dataset name") +# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at +# https://huggingface.co/docs/datasets/loading_datasets.html. + +# Preprocessing the datasets. +# Preprocessing is slighlty different for training and evaluation. + +column_names = raw_datasets["validation"].column_names + +question_column_name = "question" if "question" in column_names else column_names[0] +context_column_name = "context" if "context" in column_names else column_names[1] +answer_column_name = "answers" if "answers" in column_names else column_names[2] + +# Padding side determines if we do (question|context) or (context|question). +pad_on_right = tokenizer.padding_side == "right" + +if args.max_seq_length > tokenizer.model_max_length: + logger.warning( + f"The max_seq_length passed ({args.max_seq_length}) is larger than the maximum length for the" + f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." 
+ ) + +max_seq_length = min(args.max_seq_length, tokenizer.model_max_length) + + +# Validation preprocessing +def prepare_validation_features(examples): + # Some of the questions have lots of whitespace on the left, which is not useful and will make the + # truncation of the context fail (the tokenized question will take a lots of space). So we remove that + # left whitespace + examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]] + + # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results + # in one example possible giving several features when a context is long, each of those features having a + # context that overlaps a bit the context of the previous feature. + tokenized_examples = tokenizer( + examples[question_column_name if pad_on_right else context_column_name], + examples[context_column_name if pad_on_right else question_column_name], + truncation="only_second" if pad_on_right else "only_first", + max_length=max_seq_length, + stride=args.doc_stride, + return_overflowing_tokens=True, + return_offsets_mapping=True, + padding="max_length", + ) + + # Since one example might give us several features if it has a long context, we need a map from a feature to + # its corresponding example. This key gives us just that. + sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping") + + # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the + # corresponding example_id and we will store the offset mappings. + tokenized_examples["example_id"] = [] + + for i in range(len(tokenized_examples["input_ids"])): + # Grab the sequence corresponding to that example (to know what is the context and what is the question). + sequence_ids = tokenized_examples.sequence_ids(i) + context_index = 1 if pad_on_right else 0 + + # One example can give several spans, this is the index of the example containing this span of text. + sample_index = sample_mapping[i] + tokenized_examples["example_id"].append(examples["id"][sample_index]) + + # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token + # position is part of the context or not. + tokenized_examples["offset_mapping"][i] = [ + (o if sequence_ids[k] == context_index else None) + for k, o in enumerate(tokenized_examples["offset_mapping"][i]) + ] + + return tokenized_examples + + +eval_examples = raw_datasets["validation"] +# Validation Feature Creation +eval_dataset = eval_examples.map( + prepare_validation_features, + batched=True, + num_proc=args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not args.overwrite_cache, + desc="Running tokenizer on validation dataset", +) + +data_collator = default_data_collator + +eval_dataset_for_model = eval_dataset.remove_columns(["example_id", "offset_mapping"]) +eval_dataloader = DataLoader( + eval_dataset_for_model, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size +) + + +# Post-processing: +def post_processing_function(examples, features, predictions, stage="eval"): + # Post-processing: we match the start logits and end logits to answers in the original context. 
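+    # postprocess_qa_predictions (from utils_qa.py) maps each feature back to its example via example_id,
+    # scores the n_best_size start/end logit pairs per feature, discards invalid spans (outside the context
+    # or longer than max_answer_length), uses offset_mapping to turn the best span back into a substring of
+    # the original context, and writes prediction/nbest files to output_dir. It returns a dict of
+    # example id -> predicted answer text (with null-answer handling when version_2_with_negative is set).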
+ predictions = postprocess_qa_predictions( + examples=examples, + features=features, + predictions=predictions, + version_2_with_negative=args.version_2_with_negative, + n_best_size=args.n_best_size, + max_answer_length=args.max_answer_length, + null_score_diff_threshold=args.null_score_diff_threshold, + output_dir=args.output_dir, + prefix=stage, + ) + # Format the result to the format the metric expects. + if args.version_2_with_negative: + formatted_predictions = [ + {"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items() + ] + else: + formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()] + + references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples] + return EvalPrediction(predictions=formatted_predictions, label_ids=references) + + +metric = load_metric("squad_v2" if args.version_2_with_negative else "squad") + +# Evaluation! +logger.info("Loading ONNX model %s for evaluation", args.onnx_model_path) +with open(engine_name, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime, runtime.deserialize_cuda_engine( + f.read() +) as engine, engine.create_execution_context() as context: + + # setup for TRT inferrence + for i in range(len(input_names)): + context.set_binding_shape(i, INPUT_SHAPE) + assert context.all_binding_shapes_specified + + def binding_nbytes(binding): + return trt.volume(engine.get_binding_shape(binding)) * engine.get_binding_dtype(binding).itemsize + + # Allocate device memory for inputs and outputs. + d_inputs = [cuda.mem_alloc(binding_nbytes(binding)) for binding in engine if engine.binding_is_input(binding)] + + # Allocate output buffer + h_output0 = cuda.pagelocked_empty(tuple(context.get_binding_shape(3)), dtype=np.float32) + h_output1 = cuda.pagelocked_empty(tuple(context.get_binding_shape(4)), dtype=np.float32) + d_output0 = cuda.mem_alloc(h_output0.nbytes) + d_output1 = cuda.mem_alloc(h_output1.nbytes) + + # Create a stream in which to copy inputs/outputs and run inference. 
+ stream = cuda.Stream() + + # Evaluation + logger.info("***** Running Evaluation *****") + logger.info(f" Num examples = {len(eval_dataset)}") + logger.info(f" Batch size = {args.per_device_eval_batch_size}") + + total_time = 0.0 + niter = 0 + start_time = timeit.default_timer() + + all_preds = None + for step, batch in enumerate(eval_dataloader): + + outputs, infer_time = model_infer(batch, context, d_inputs, h_output0, h_output1, d_output0, d_output1, stream) + total_time += infer_time + niter += 1 + + start_logits, end_logits = outputs + start_logits = torch.tensor(start_logits) + end_logits = torch.tensor(end_logits) + + # necessary to pad predictions and labels for being gathered + start_logits = accelerator.pad_across_processes(start_logits, dim=1, pad_index=-100) + end_logits = accelerator.pad_across_processes(end_logits, dim=1, pad_index=-100) + + logits = (accelerator.gather(start_logits).cpu().numpy(), accelerator.gather(end_logits).cpu().numpy()) + all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100) + + if all_preds is not None: + all_preds = nested_truncate(all_preds, len(eval_dataset)) + + evalTime = timeit.default_timer() - start_time + logger.info(" Evaluation done in total %f secs (%f sec per example)", evalTime, evalTime / len(eval_dataset)) + # Inference time from TRT + logger.info("Average Inference Time = {:.3f} ms".format(total_time * 1000 / niter)) + logger.info("Total Inference Time = {:.3f} ms".format(total_time * 1000)) + logger.info("Total Number of Inference = %d", niter) + +prediction = post_processing_function(eval_examples, eval_dataset, all_preds) +eval_metric = metric.compute(predictions=prediction.predictions, references=prediction.label_ids) +logger.info(f"Evaluation metrics: {eval_metric}") diff --git a/examples/research_projects/quantization-qdqbert/quant_trainer.py b/examples/research_projects/quantization-qdqbert/quant_trainer.py new file mode 100755 index 00000000000..9fa79787022 --- /dev/null +++ b/examples/research_projects/quantization-qdqbert/quant_trainer.py @@ -0,0 +1,303 @@ +# coding=utf-8 +# Copyright 2021 NVIDIA Corporation. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Helper functions for training models with pytorch-quantization""" +import logging +import re + +import torch + +import pytorch_quantization +import pytorch_quantization.nn as quant_nn +from pytorch_quantization import calib +from pytorch_quantization.tensor_quant import QuantDescriptor + + +logger = logging.getLogger(__name__) + +name_width = 50 # max width of layer names +qname_width = 70 # max width of quantizer names + +# ========================================== Quant Trainer API ========================================== + + +def add_arguments(parser): + """Add arguments to parser for functions defined in quant_trainer.""" + + group = parser.add_argument_group("quant_trainer arguments") + group.add_argument("--wprec", type=int, default=8, help="weight precision") + group.add_argument("--aprec", type=int, default=8, help="activation precision") + group.add_argument("--quant-per-tensor", action="store_true", help="per tensor weight scaling") + group.add_argument("--quant-disable", action="store_true", help="disable all quantizers") + group.add_argument("--quant-disable-embeddings", action="store_true", help="disable all embeddings quantizers") + group.add_argument("--quant-disable-keyword", type=str, nargs="+", help="disable quantizers by keyword") + group.add_argument("--quant-disable-layer-module", type=str, help="disable quantizers by keyword under layer.\d+.") + group.add_argument("--quant-enable-layer-module", type=str, help="enable quantizers by keyword under layer.\d+.") + group.add_argument("--calibrator", default="max", help="which quantization range calibrator to use") + group.add_argument("--percentile", default=None, type=float, help="percentile for PercentileCalibrator") + group.add_argument("--fuse-qkv", action="store_true", help="use the same scale factor for qkv") + group.add_argument("--clip-gelu", metavar="N", type=float, help="clip gelu output maximum value to N") + group.add_argument( + "--recalibrate-weights", + action="store_true", + help="recalibrate weight amaxes by taking the max of the weights." 
+ " amaxes will be computed with the current quantization granularity (axis).", + ) + + +def set_default_quantizers(args): + """Set default quantizers before creating the model.""" + + if args.calibrator == "max": + calib_method = "max" + elif args.calibrator == "percentile": + if args.percentile is None: + raise ValueError("Specify --percentile when using percentile calibrator") + calib_method = "histogram" + elif args.calibrator == "mse": + calib_method = "histogram" + else: + raise ValueError(f"Invalid calibrator {args.calibrator}") + + input_desc = QuantDescriptor(num_bits=args.aprec, calib_method=calib_method) + weight_desc = QuantDescriptor(num_bits=args.wprec, axis=(None if args.quant_per_tensor else (0,))) + quant_nn.QuantLinear.set_default_quant_desc_input(input_desc) + quant_nn.QuantLinear.set_default_quant_desc_weight(weight_desc) + + +def configure_model(model, args, calib=False, eval=False): + """Function called before the training loop.""" + + logger.info("Configuring Model for Quantization") + logger.info(f"using quantization package {pytorch_quantization.__file__}") + + if not calib: + if args.quant_disable_embeddings: + set_quantizer_by_name(model, ["embeddings"], which="weight", _disabled=True) + + if args.quant_disable: + set_quantizer_by_name(model, [""], _disabled=True) + + if args.quant_disable_keyword: + set_quantizer_by_name(model, args.quant_disable_keyword, _disabled=True) + + if args.quant_disable_layer_module: + set_quantizer_by_name(model, ["layer.\d+." + args.quant_disable_layer_module], _disabled=True) + + if args.quant_enable_layer_module: + set_quantizer_by_name(model, ["layer.\d+." + args.quant_enable_layer_module], _disabled=False) + + if args.recalibrate_weights: + recalibrate_weights(model) + + if args.fuse_qkv: + fuse_qkv(model, args) + + if args.clip_gelu: + clip_gelu(model, args.clip_gelu) + + # if args.local_rank in [-1, 0] and not calib: + print_quant_summary(model) + + +def enable_calibration(model): + """Enable calibration of all *_input_quantizer modules in model.""" + + logger.info("Enabling Calibration") + for name, module in model.named_modules(): + if name.endswith("_quantizer"): + if module._calibrator is not None: + module.disable_quant() + module.enable_calib() + else: + module.disable() + logger.info(f"{name:80}: {module}") + + +def finish_calibration(model, args): + """Disable calibration and load amax for all "*_input_quantizer modules in model.""" + + logger.info("Loading calibrated amax") + for name, module in model.named_modules(): + if name.endswith("_quantizer"): + if module._calibrator is not None: + if isinstance(module._calibrator, calib.MaxCalibrator): + module.load_calib_amax() + else: + module.load_calib_amax("percentile", percentile=args.percentile) + module.enable_quant() + module.disable_calib() + else: + module.enable() + model.cuda() + print_quant_summary(model) + + +# ========================================== Helper Function ========================================== + + +def fuse_qkv(model, args): + """Adjust quantization ranges to match an implementation where the QKV projections are implemented with a single GEMM. + Force the weight and output scale factors to match by taking the max of (Q,K,V). 
+ """ + + def fuse3(qq, qk, qv): + for mod in [qq, qk, qv]: + if not hasattr(mod, "_amax"): + print(" WARNING: NO AMAX BUFFER") + return + q = qq._amax.detach().item() + k = qk._amax.detach().item() + v = qv._amax.detach().item() + + amax = max(q, k, v) + qq._amax.fill_(amax) + qk._amax.fill_(amax) + qv._amax.fill_(amax) + logger.info(f" q={q:5.2f} k={k:5.2f} v={v:5.2f} -> {amax:5.2f}") + + for name, mod in model.named_modules(): + if name.endswith(".attention.self"): + logger.info(f"FUSE_QKV: {name:{name_width}}") + fuse3(mod.matmul_q_input_quantizer, mod.matmul_k_input_quantizer, mod.matmul_v_input_quantizer) + if args.quant_per_tensor: + fuse3(mod.query._weight_quantizer, mod.key._weight_quantizer, mod.value._weight_quantizer) + + +def clip_gelu(model, maxval): + """Clip activations generated by GELU to maxval when quantized. + Implemented by adjusting the amax of the following input_quantizer. + """ + + for name, mod in model.named_modules(): + if name.endswith(".output.dense") and not name.endswith("attention.output.dense"): + amax_init = mod._input_quantizer._amax.data.detach().item() + mod._input_quantizer._amax.data.detach().clamp_(max=maxval) + amax = mod._input_quantizer._amax.data.detach().item() + logger.info(f"CLIP_GELU: {name:{name_width}} amax: {amax_init:5.2f} -> {amax:5.2f}") + + +def expand_amax(model): + """Expand per-tensor amax to be per channel, where each channel is assigned the per-tensor amax.""" + + for name, mod in model.named_modules(): + if hasattr(mod, "_weight_quantizer") and mod._weight_quantizer.axis is not None: + k = mod.weight.shape[0] + amax = mod._weight_quantizer._amax.detach() + mod._weight_quantizer._amax = torch.ones(k, dtype=amax.dtype, device=amax.device) * amax + print(f"expanding {name} {amax} -> {mod._weight_quantizer._amax}") + + +def recalibrate_weights(model): + """Performs max calibration on the weights and updates amax.""" + + for name, mod in model.named_modules(): + if hasattr(mod, "_weight_quantizer"): + if not hasattr(mod.weight_quantizer, "_amax"): + print("RECALIB: {name:{name_width}} WARNING: NO AMAX BUFFER") + continue + + # determine which axes to reduce across + # e.g. 
a 4D tensor quantized per axis 0 should reduce over (1,2,3) + axis_set = set() if mod._weight_quantizer.axis is None else set(mod._weight_quantizer.axis) + reduce_axis = set(range(len(mod.weight.size()))) - axis_set + amax = pytorch_quantization.utils.reduce_amax(mod.weight, axis=reduce_axis, keepdims=True).detach() + logger.info(f"RECALIB: {name:{name_width}} {mod._weight_quantizer._amax.flatten()} -> {amax.flatten()}") + mod._weight_quantizer._amax = amax + + +def print_model_summary(model, name_width=25, line_width=180, ignore=None): + """Print model quantization configuration.""" + + if ignore is None: + ignore = [] + elif not isinstance(ignore, list): + ignore = [ignore] + + name_width = 0 + for name, mod in model.named_modules(): + if not hasattr(mod, "weight"): + continue + name_width = max(name_width, len(name)) + + for name, mod in model.named_modules(): + input_q = getattr(mod, "_input_quantizer", None) + weight_q = getattr(mod, "_weight_quantizer", None) + if not hasattr(mod, "weight"): + continue + if type(mod) in ignore: + continue + if [True for s in ignore if type(s) is str and s in name]: + continue + act_str = f"Act:{input_q.extra_repr()}" + wgt_str = f"Wgt:{weight_q.extra_repr()}" + s = f"{name:{name_width}} {act_str} {wgt_str}" + if len(s) <= line_width: + logger.info(s) + else: + logger.info(f"{name:{name_width}} {act_str}") + logger.info(f'{" ":{name_width}} {wgt_str}') + + +def print_quant_summary(model): + """Print summary of all quantizer modules in the model.""" + + count = 0 + for name, mod in model.named_modules(): + if isinstance(mod, pytorch_quantization.nn.TensorQuantizer): + print(f"{name:80} {mod}") + count += 1 + print(f"{count} TensorQuantizers found in model") + + +def set_quantizer(name, mod, quantizer, k, v): + """Set attributes for mod.quantizer.""" + + quantizer_mod = getattr(mod, quantizer, None) + if quantizer_mod is not None: + assert hasattr(quantizer_mod, k) + setattr(quantizer_mod, k, v) + else: + logger.warn(f"{name} has no {quantizer}") + + +def set_quantizers(name, mod, which="both", **kwargs): + """Set quantizer attributes for mod.""" + + s = f"Warning: changing {which} quantizers of {name:{qname_width}}" + for k, v in kwargs.items(): + s += f" {k}={v}" + if which in ["input", "both"]: + set_quantizer(name, mod, "_input_quantizer", k, v) + if which in ["weight", "both"]: + set_quantizer(name, mod, "_weight_quantizer", k, v) + logger.info(s) + + +def set_quantizer_by_name(model, names, **kwargs): + """Set quantizer attributes for layers where name contains a substring in names.""" + + for name, mod in model.named_modules(): + if hasattr(mod, "_input_quantizer") or hasattr(mod, "_weight_quantizer"): + for n in names: + if re.search(n, name): + set_quantizers(name, mod, **kwargs) + elif name.endswith("_quantizer"): + for n in names: + if re.search(n, name): + s = f"Warning: changing {name:{name_width}}" + for k, v in kwargs.items(): + s += f" {k}={v}" + setattr(mod, k, v) + logger.info(s) diff --git a/examples/research_projects/quantization-qdqbert/run_quant_qa.py b/examples/research_projects/quantization-qdqbert/run_quant_qa.py new file mode 100755 index 00000000000..01791681eff --- /dev/null +++ b/examples/research_projects/quantization-qdqbert/run_quant_qa.py @@ -0,0 +1,668 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2020 The HuggingFace Team All rights reserved. +# Copyright 2021 NVIDIA Corporation. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the library models for question answering. +""" +# You can also adapt this script on your own question answering task. Pointers for this are left as comments. + +import logging +import os +import sys +from dataclasses import dataclass, field +from typing import Optional + +import datasets +from datasets import load_dataset, load_metric + +import quant_trainer +import transformers +from trainer_quant_qa import QuestionAnsweringTrainer +from transformers import ( + AutoTokenizer, + DataCollatorWithPadding, + EvalPrediction, + HfArgumentParser, + PreTrainedTokenizerFast, + QDQBertConfig, + QDQBertForQuestionAnswering, + TrainingArguments, + default_data_collator, + set_seed, +) +from transformers.trainer_utils import SchedulerType, get_last_checkpoint +from transformers.utils import check_min_version +from transformers.utils.versions import require_version +from utils_qa import postprocess_qa_predictions + + +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. +check_min_version("4.9.0") + +require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt") + +logger = logging.getLogger(__name__) + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Path to directory to store the pretrained models downloaded from huggingface.co"}, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + use_auth_token: bool = field( + default=False, + metadata={ + "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." + }, + ) + do_calib: bool = field(default=False, metadata={"help": "Whether to run calibration of quantization ranges."}) + num_calib_batch: int = field( + default=4, + metadata={"help": "Number of batches for calibration. 0 will disable calibration "}, + ) + save_onnx: bool = field(default=False, metadata={"help": "Whether to save model to onnx."}) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. 
+ """ + + dataset_name: Optional[str] = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."}) + validation_file: Optional[str] = field( + default=None, + metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, + ) + test_file: Optional[str] = field( + default=None, + metadata={"help": "An optional input test data file to evaluate the perplexity on (a text file)."}, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + max_seq_length: int = field( + default=384, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + pad_to_max_length: bool = field( + default=True, + metadata={ + "help": "Whether to pad all samples to `max_seq_length`. " + "If False, will pad the samples dynamically when batching to the maximum length in the batch (which can " + "be faster on GPU but will be slower on TPU)." + }, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + }, + ) + max_predict_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " + "value if set." + }, + ) + version_2_with_negative: bool = field( + default=False, metadata={"help": "If true, some of the examples do not have an answer."} + ) + null_score_diff_threshold: float = field( + default=0.0, + metadata={ + "help": "The threshold used to select the null answer: if the best answer has a score that is less than " + "the score of the null answer minus this threshold, the null answer is selected for this example. " + "Only useful when `version_2_with_negative=True`." + }, + ) + doc_stride: int = field( + default=128, + metadata={"help": "When splitting up a long document into chunks, how much stride to take between chunks."}, + ) + n_best_size: int = field( + default=20, + metadata={"help": "The total number of n-best predictions to generate when looking for an answer."}, + ) + max_answer_length: int = field( + default=30, + metadata={ + "help": "The maximum length of an answer that can be generated. This is needed because the start " + "and end predictions are not conditioned on one another." 
+ }, + ) + + def __post_init__(self): + if ( + self.dataset_name is None + and self.train_file is None + and self.validation_file is None + and self.test_file is None + ): + raise ValueError("Need either a dataset name or a training/validation file/test_file.") + else: + if self.train_file is not None: + extension = self.train_file.split(".")[-1] + assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." + if self.validation_file is not None: + extension = self.validation_file.split(".")[-1] + assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." + if self.test_file is not None: + extension = self.test_file.split(".")[-1] + assert extension in ["csv", "json"], "`test_file` should be a csv or a json file." + + +def main(): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) + # quant_trainer arguments + quant_trainer.add_arguments(parser) + + # if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # # If we pass only one argument to the script and it's the path to a json file, + # # let's parse it to get our arguments. + # model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + # else: + + model_args, data_args, training_args, quant_trainer_args = parser.parse_args_into_dataclasses() + + # setup QAT training args for scheduler (default to use cosine annealing learning rate schedule) + training_args.lr_scheduler_type = SchedulerType.COSINE + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + + log_level = training_args.get_process_log_level() + logger.setLevel(log_level) + datasets.utils.logging.set_verbosity(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + logger.info(f"Training/evaluation parameters {training_args}") + + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # Set seed before initializing model. 
+ set_seed(training_args.seed) + + # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) + # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ + # (the dataset will be downloaded automatically from the datasets Hub). + # + # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called + # 'text' is found. You can easily tweak this behavior (see below). + # + # In distributed training, the load_dataset function guarantee that only one local process can concurrently + # download the dataset. + if data_args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + raw_datasets = load_dataset( + data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir + ) + else: + data_files = {} + if data_args.train_file is not None: + data_files["train"] = data_args.train_file + extension = data_args.train_file.split(".")[-1] + + if data_args.validation_file is not None: + data_files["validation"] = data_args.validation_file + extension = data_args.validation_file.split(".")[-1] + if data_args.test_file is not None: + data_files["test"] = data_args.test_file + extension = data_args.test_file.split(".")[-1] + raw_datasets = load_dataset(extension, data_files=data_files, field="data", cache_dir=model_args.cache_dir) + # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at + # https://huggingface.co/docs/datasets/loading_datasets.html. + + # set default quantization parameters before building model + quant_trainer.set_default_quantizers(quant_trainer_args) + + # Load pretrained model and tokenizer + # + # Distributed training: + # The .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. + config = QDQBertConfig.from_pretrained( + model_args.config_name if model_args.config_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + use_fast=True, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + model = QDQBertForQuestionAnswering.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + + # Tokenizer check: this script requires a fast tokenizer. + if not isinstance(tokenizer, PreTrainedTokenizerFast): + raise ValueError( + "This example script only works for models that have a fast tokenizer. Checkout the big table of models " + "at https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet this " + "requirement" + ) + + # Preprocessing the datasets. + # Preprocessing is slighlty different for training and evaluation. 
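+ # SQuAD-style datasets expose "question", "context" and "answers" columns; otherwise we fall back to the first
+ # three columns of the dataset.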
+ if training_args.do_train or model_args.do_calib: + column_names = raw_datasets["train"].column_names + elif training_args.do_eval or model_args.save_onnx: + column_names = raw_datasets["validation"].column_names + else: + column_names = raw_datasets["test"].column_names + question_column_name = "question" if "question" in column_names else column_names[0] + context_column_name = "context" if "context" in column_names else column_names[1] + answer_column_name = "answers" if "answers" in column_names else column_names[2] + + # Padding side determines if we do (question|context) or (context|question). + pad_on_right = tokenizer.padding_side == "right" + + if data_args.max_seq_length > tokenizer.model_max_length: + logger.warning( + f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" + f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." + ) + max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) + + # Training preprocessing + def prepare_train_features(examples): + # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results + # in one example possible giving several features when a context is long, each of those features having a + # context that overlaps a bit the context of the previous feature. + tokenized_examples = tokenizer( + examples[question_column_name if pad_on_right else context_column_name], + examples[context_column_name if pad_on_right else question_column_name], + truncation="only_second" if pad_on_right else "only_first", + max_length=max_seq_length, + stride=data_args.doc_stride, + return_overflowing_tokens=True, + return_offsets_mapping=True, + padding="max_length" if data_args.pad_to_max_length else False, + ) + + # Since one example might give us several features if it has a long context, we need a map from a feature to + # its corresponding example. This key gives us just that. + sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping") + # The offset mappings will give us a map from token to character position in the original context. This will + # help us compute the start_positions and end_positions. + offset_mapping = tokenized_examples.pop("offset_mapping") + + # Let's label those examples! + tokenized_examples["start_positions"] = [] + tokenized_examples["end_positions"] = [] + + for i, offsets in enumerate(offset_mapping): + # We will label impossible answers with the index of the CLS token. + input_ids = tokenized_examples["input_ids"][i] + cls_index = input_ids.index(tokenizer.cls_token_id) + + # Grab the sequence corresponding to that example (to know what is the context and what is the question). + sequence_ids = tokenized_examples.sequence_ids(i) + + # One example can give several spans, this is the index of the example containing this span of text. + sample_index = sample_mapping[i] + answers = examples[answer_column_name][sample_index] + # If no answers are given, set the cls_index as answer. + if len(answers["answer_start"]) == 0: + tokenized_examples["start_positions"].append(cls_index) + tokenized_examples["end_positions"].append(cls_index) + else: + # Start/end character index of the answer in the text. + start_char = answers["answer_start"][0] + end_char = start_char + len(answers["text"][0]) + + # Start token index of the current span in the text. 
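+ # (sequence_ids(i) distinguishes question tokens from context tokens, so scanning from both ends of the
+ # feature finds the context boundaries of this span.)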
+ token_start_index = 0 + while sequence_ids[token_start_index] != (1 if pad_on_right else 0): + token_start_index += 1 + + # End token index of the current span in the text. + token_end_index = len(input_ids) - 1 + while sequence_ids[token_end_index] != (1 if pad_on_right else 0): + token_end_index -= 1 + + # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index). + if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char): + tokenized_examples["start_positions"].append(cls_index) + tokenized_examples["end_positions"].append(cls_index) + else: + # Otherwise move the token_start_index and token_end_index to the two ends of the answer. + # Note: we could go after the last offset if the answer is the last word (edge case). + while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char: + token_start_index += 1 + tokenized_examples["start_positions"].append(token_start_index - 1) + while offsets[token_end_index][1] >= end_char: + token_end_index -= 1 + tokenized_examples["end_positions"].append(token_end_index + 1) + + return tokenized_examples + + if training_args.do_train or model_args.do_calib: + if "train" not in raw_datasets: + raise ValueError("--do_train requires a train dataset") + train_dataset = raw_datasets["train"] + if data_args.max_train_samples is not None: + # We will select sample from whole data if agument is specified + train_dataset = train_dataset.select(range(data_args.max_train_samples)) + # Create train feature from dataset + with training_args.main_process_first(desc="train dataset map pre-processing"): + train_dataset = train_dataset.map( + prepare_train_features, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + desc="Running tokenizer on train dataset", + ) + if data_args.max_train_samples is not None: + # Number of samples might increase during Feature Creation, We select only specified max samples + train_dataset = train_dataset.select(range(data_args.max_train_samples)) + + # Validation preprocessing + def prepare_validation_features(examples): + # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results + # in one example possible giving several features when a context is long, each of those features having a + # context that overlaps a bit the context of the previous feature. + tokenized_examples = tokenizer( + examples[question_column_name if pad_on_right else context_column_name], + examples[context_column_name if pad_on_right else question_column_name], + truncation="only_second" if pad_on_right else "only_first", + max_length=max_seq_length, + stride=data_args.doc_stride, + return_overflowing_tokens=True, + return_offsets_mapping=True, + padding="max_length" if data_args.pad_to_max_length else False, + ) + + # Since one example might give us several features if it has a long context, we need a map from a feature to + # its corresponding example. This key gives us just that. + sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping") + + # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the + # corresponding example_id and we will store the offset mappings. 
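+ # (Offsets that do not belong to the context are nulled out below so that post-processing only maps context
+ # tokens back to character positions.)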
+ tokenized_examples["example_id"] = [] + + for i in range(len(tokenized_examples["input_ids"])): + # Grab the sequence corresponding to that example (to know what is the context and what is the question). + sequence_ids = tokenized_examples.sequence_ids(i) + context_index = 1 if pad_on_right else 0 + + # One example can give several spans, this is the index of the example containing this span of text. + sample_index = sample_mapping[i] + tokenized_examples["example_id"].append(examples["id"][sample_index]) + + # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token + # position is part of the context or not. + tokenized_examples["offset_mapping"][i] = [ + (o if sequence_ids[k] == context_index else None) + for k, o in enumerate(tokenized_examples["offset_mapping"][i]) + ] + + return tokenized_examples + + if training_args.do_eval or model_args.save_onnx: + if "validation" not in raw_datasets: + raise ValueError("--do_eval requires a validation dataset") + eval_examples = raw_datasets["validation"] + if data_args.max_eval_samples is not None: + # We will select sample from whole data + eval_examples = eval_examples.select(range(data_args.max_eval_samples)) + # Validation Feature Creation + with training_args.main_process_first(desc="validation dataset map pre-processing"): + eval_dataset = eval_examples.map( + prepare_validation_features, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + desc="Running tokenizer on validation dataset", + ) + if data_args.max_eval_samples is not None: + # During Feature creation dataset samples might increase, we will select required samples again + eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) + + if training_args.do_predict: + if "test" not in raw_datasets: + raise ValueError("--do_predict requires a test dataset") + predict_examples = raw_datasets["test"] + if data_args.max_predict_samples is not None: + # We will select sample from whole data + predict_examples = predict_examples.select(range(data_args.max_predict_samples)) + # Predict Feature Creation + with training_args.main_process_first(desc="prediction dataset map pre-processing"): + predict_dataset = predict_examples.map( + prepare_validation_features, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + desc="Running tokenizer on prediction dataset", + ) + if data_args.max_predict_samples is not None: + # During Feature creation dataset samples might increase, we will select required samples again + predict_dataset = predict_dataset.select(range(data_args.max_predict_samples)) + + # Data collator + # We have already padded to max length if the corresponding flag is True, otherwise we need to pad in the data + # collator. + data_collator = ( + default_data_collator + if data_args.pad_to_max_length + else DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None) + ) + + # Post-processing: + def post_processing_function(examples, features, predictions, stage="eval"): + # Post-processing: we match the start logits and end logits to answers in the original context. 
+ predictions = postprocess_qa_predictions( + examples=examples, + features=features, + predictions=predictions, + version_2_with_negative=data_args.version_2_with_negative, + n_best_size=data_args.n_best_size, + max_answer_length=data_args.max_answer_length, + null_score_diff_threshold=data_args.null_score_diff_threshold, + output_dir=training_args.output_dir, + log_level=log_level, + prefix=stage, + ) + # Format the result to the format the metric expects. + if data_args.version_2_with_negative: + formatted_predictions = [ + {"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items() + ] + else: + formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()] + + references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples] + return EvalPrediction(predictions=formatted_predictions, label_ids=references) + + metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad") + + def compute_metrics(p: EvalPrediction): + return metric.compute(predictions=p.predictions, references=p.label_ids) + + # Initialize our Trainer + trainer = QuestionAnsweringTrainer( + model=model, + args=training_args, + train_dataset=train_dataset if training_args.do_train or model_args.do_calib else None, + eval_dataset=eval_dataset if training_args.do_eval or model_args.save_onnx else None, + eval_examples=eval_examples if training_args.do_eval or model_args.save_onnx else None, + tokenizer=tokenizer, + data_collator=data_collator, + post_process_function=post_processing_function, + compute_metrics=compute_metrics, + quant_trainer_args=quant_trainer_args, + ) + + # Calibration + if model_args.do_calib: + logger.info("*** Calibrate ***") + results = trainer.calibrate() + trainer.save_model() + + # Training + if training_args.do_train: + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + + quant_trainer.configure_model(trainer.model, quant_trainer_args) + + train_result = trainer.train(resume_from_checkpoint=checkpoint) + trainer.save_model() # Saves the tokenizer too for easy upload + + metrics = train_result.metrics + max_train_samples = ( + data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) + ) + metrics["train_samples"] = min(max_train_samples, len(train_dataset)) + + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + # Evaluation + if training_args.do_eval: + logger.info("*** Evaluate ***") + quant_trainer.configure_model(trainer.model, quant_trainer_args, eval=True) + metrics = trainer.evaluate() + + max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) + metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) + + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + # Prediction + if training_args.do_predict: + logger.info("*** Predict ***") + results = trainer.predict(predict_dataset, predict_examples) + metrics = results.metrics + + max_predict_samples = ( + data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset) + ) + metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset)) + + trainer.log_metrics("predict", metrics) + trainer.save_metrics("predict", metrics) + + if training_args.push_to_hub: + kwargs = 
{"finetuned_from": model_args.model_name_or_path, "tasks": "question-answering"} + if data_args.dataset_name is not None: + kwargs["dataset_tags"] = data_args.dataset_name + if data_args.dataset_config_name is not None: + kwargs["dataset_args"] = data_args.dataset_config_name + kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}" + else: + kwargs["dataset"] = data_args.dataset_name + + trainer.push_to_hub(**kwargs) + + if model_args.save_onnx: + logger.info("Exporting model to onnx") + results = trainer.save_onnx(output_dir=training_args.output_dir) + + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +if __name__ == "__main__": + main() diff --git a/examples/research_projects/quantization-qdqbert/trainer_quant_qa.py b/examples/research_projects/quantization-qdqbert/trainer_quant_qa.py new file mode 100644 index 00000000000..b23edb6d518 --- /dev/null +++ b/examples/research_projects/quantization-qdqbert/trainer_quant_qa.py @@ -0,0 +1,212 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Team All rights reserved. +# Copyright 2021 NVIDIA Corporation. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +A subclass of `Trainer` specific to Question-Answering tasks +""" + +import logging +import os + +import torch +from torch.utils.data import DataLoader + +import quant_trainer +from transformers import Trainer, is_torch_tpu_available +from transformers.trainer_utils import PredictionOutput + + +logger = logging.getLogger(__name__) + +if is_torch_tpu_available(): + import torch_xla.core.xla_model as xm + import torch_xla.debug.metrics as met + + +class QuestionAnsweringTrainer(Trainer): + def __init__(self, *args, eval_examples=None, post_process_function=None, quant_trainer_args=None, **kwargs): + super().__init__(*args, **kwargs) + self.eval_examples = eval_examples + self.post_process_function = post_process_function + self.quant_trainer_args = quant_trainer_args + self.calib_num = 128 # default number of calibration samples + + def get_calib_dataloader(self, calib_dataset=None): + """ + Returns the calibration dataloader :class:`~torch.utils.data.DataLoader`. 
+ + Args: + calib_dataset (:obj:`torch.utils.data.Dataset`, `optional`) + """ + if calib_dataset is None and self.calib_dataset is None: + raise ValueError("Trainer: calibration requires an calib_dataset.") + calib_dataset = calib_dataset if calib_dataset is not None else self.calib_dataset + + calib_dataset = self._remove_unused_columns(calib_dataset, description="Calibration") + + return DataLoader( + calib_dataset, + batch_size=self.args.eval_batch_size, + collate_fn=self.data_collator, + drop_last=self.args.dataloader_drop_last, + num_workers=self.args.dataloader_num_workers, + pin_memory=self.args.dataloader_pin_memory, + shuffle=True, + ) + + def calibrate(self, calib_dataset=None): + calib_dataset = self.train_dataset if calib_dataset is None else calib_dataset + calib_dataloader = self.get_calib_dataloader(calib_dataset) + + model = self.model + quant_trainer.configure_model(model, self.quant_trainer_args, calib=True) + model.eval() + quant_trainer.enable_calibration(model) + + logger.info("***** Running calibration *****") + logger.info(f" Num examples = {self.calib_num}") + logger.info(f" Batch size = {calib_dataloader.batch_size}") + + for step, inputs in enumerate(calib_dataloader): + # Prediction step + loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only=True) + if (step + 1) * calib_dataloader.batch_size >= self.calib_num: + break + + quant_trainer.finish_calibration(model, self.quant_trainer_args) + self.model = model + + def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None, metric_key_prefix: str = "eval"): + eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset + eval_dataloader = self.get_eval_dataloader(eval_dataset) + eval_examples = self.eval_examples if eval_examples is None else eval_examples + + # Temporarily disable metric computation, we will do it in the loop here. + compute_metrics = self.compute_metrics + self.compute_metrics = None + eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop + try: + output = eval_loop( + eval_dataloader, + description="Evaluation", + # No point gathering the predictions if there are no metrics, otherwise we defer to + # self.args.prediction_loss_only + prediction_loss_only=True if compute_metrics is None else None, + ignore_keys=ignore_keys, + ) + finally: + self.compute_metrics = compute_metrics + + if self.post_process_function is not None and self.compute_metrics is not None: + eval_preds = self.post_process_function(eval_examples, eval_dataset, output.predictions) + metrics = self.compute_metrics(eval_preds) + + # Prefix all keys with metric_key_prefix + '_' + for key in list(metrics.keys()): + if not key.startswith(f"{metric_key_prefix}_"): + metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) + + self.log(metrics) + else: + metrics = {} + + if self.args.tpu_metrics_debug or self.args.debug: + # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) + xm.master_print(met.metrics_report()) + + self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics) + return metrics + + def predict(self, predict_dataset, predict_examples, ignore_keys=None, metric_key_prefix: str = "test"): + predict_dataloader = self.get_test_dataloader(predict_dataset) + + # Temporarily disable metric computation, we will do it in the loop here. 
+ compute_metrics = self.compute_metrics + self.compute_metrics = None + eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop + try: + output = eval_loop( + predict_dataloader, + description="Prediction", + # No point gathering the predictions if there are no metrics, otherwise we defer to + # self.args.prediction_loss_only + prediction_loss_only=True if compute_metrics is None else None, + ignore_keys=ignore_keys, + ) + finally: + self.compute_metrics = compute_metrics + + if self.post_process_function is None or self.compute_metrics is None: + return output + + predictions = self.post_process_function(predict_examples, predict_dataset, output.predictions, "predict") + metrics = self.compute_metrics(predictions) + + # Prefix all keys with metric_key_prefix + '_' + for key in list(metrics.keys()): + if not key.startswith(f"{metric_key_prefix}_"): + metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) + + return PredictionOutput(predictions=predictions.predictions, label_ids=predictions.label_ids, metrics=metrics) + + def save_onnx(self, output_dir="./"): + eval_dataset = self.eval_dataset + eval_dataloader = self.get_eval_dataloader(eval_dataset) + + batch = next(iter(eval_dataloader)) + + # saving device - to make it consistent + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # convert to tuple + input_tuple = tuple(v.to(device) for k, v in batch.items()) + + logger.info("Converting model to be onnx compatible") + from pytorch_quantization.nn import TensorQuantizer + + TensorQuantizer.use_fb_fake_quant = True + + model = self.model.to(device) + + model.eval() + model.float() + + model_to_save = model.module if hasattr(model, "module") else model + quant_trainer.configure_model(model_to_save, self.quant_trainer_args) + + output_model_file = os.path.join(output_dir, "model.onnx") + logger.info(f"exporting model to {output_model_file}") + + axes = {0: "batch_size", 1: "seq_len"} + + torch.onnx.export( + model_to_save, + input_tuple, + output_model_file, + export_params=True, + opset_version=13, + do_constant_folding=True, + input_names=["input_ids", "attention_mask", "token_type_ids"], + output_names=["output_start_logits", "output_end_logits"], + dynamic_axes={ + "input_ids": axes, + "attention_mask": axes, + "token_type_ids": axes, + "output_start_logits": axes, + "output_end_logits": axes, + }, + verbose=True, + ) + logger.info("onnx export finished") diff --git a/examples/research_projects/quantization-qdqbert/utils_qa.py b/examples/research_projects/quantization-qdqbert/utils_qa.py new file mode 100644 index 00000000000..1157849c991 --- /dev/null +++ b/examples/research_projects/quantization-qdqbert/utils_qa.py @@ -0,0 +1,427 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Team All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Post-processing utilities for question answering. 
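+
+These helpers are shared by the example scripts: ``postprocess_qa_predictions`` turns (start_logits, end_logits)
+pairs back into answer strings taken from the original contexts.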
+""" +import collections +import json +import logging +import os +from typing import Optional, Tuple + +import numpy as np +from tqdm.auto import tqdm + + +logger = logging.getLogger(__name__) + + +def postprocess_qa_predictions( + examples, + features, + predictions: Tuple[np.ndarray, np.ndarray], + version_2_with_negative: bool = False, + n_best_size: int = 20, + max_answer_length: int = 30, + null_score_diff_threshold: float = 0.0, + output_dir: Optional[str] = None, + prefix: Optional[str] = None, + log_level: Optional[int] = logging.WARNING, +): + """ + Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the + original contexts. This is the base postprocessing functions for models that only return start and end logits. + + Args: + examples: The non-preprocessed dataset (see the main script for more information). + features: The processed dataset (see the main script for more information). + predictions (:obj:`Tuple[np.ndarray, np.ndarray]`): + The predictions of the model: two arrays containing the start logits and the end logits respectively. Its + first dimension must match the number of elements of :obj:`features`. + version_2_with_negative (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not the underlying dataset contains examples with no answers. + n_best_size (:obj:`int`, `optional`, defaults to 20): + The total number of n-best predictions to generate when looking for an answer. + max_answer_length (:obj:`int`, `optional`, defaults to 30): + The maximum length of an answer that can be generated. This is needed because the start and end predictions + are not conditioned on one another. + null_score_diff_threshold (:obj:`float`, `optional`, defaults to 0): + The threshold used to select the null answer: if the best answer has a score that is less than the score of + the null answer minus this threshold, the null answer is selected for this example (note that the score of + the null answer for an example giving several features is the minimum of the scores for the null answer on + each feature: all features must be aligned on the fact they `want` to predict a null answer). + + Only useful when :obj:`version_2_with_negative` is :obj:`True`. + output_dir (:obj:`str`, `optional`): + If provided, the dictionaries of predictions, n_best predictions (with their scores and logits) and, if + :obj:`version_2_with_negative=True`, the dictionary of the scores differences between best and null + answers, are saved in `output_dir`. + prefix (:obj:`str`, `optional`): + If provided, the dictionaries mentioned above are saved with `prefix` added to their names. + log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``): + ``logging`` log level (e.g., ``logging.WARNING``) + """ + assert len(predictions) == 2, "`predictions` should be a tuple with two elements (start_logits, end_logits)." + all_start_logits, all_end_logits = predictions + + assert len(predictions[0]) == len(features), f"Got {len(predictions[0])} predictions and {len(features)} features." + + # Build a map example to its corresponding features. + example_id_to_index = {k: i for i, k in enumerate(examples["id"])} + features_per_example = collections.defaultdict(list) + for i, feature in enumerate(features): + features_per_example[example_id_to_index[feature["example_id"]]].append(i) + + # The dictionaries we have to fill. 
+ all_predictions = collections.OrderedDict() + all_nbest_json = collections.OrderedDict() + if version_2_with_negative: + scores_diff_json = collections.OrderedDict() + + # Logging. + logger.setLevel(log_level) + logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.") + + # Let's loop over all the examples! + for example_index, example in enumerate(tqdm(examples)): + # Those are the indices of the features associated to the current example. + feature_indices = features_per_example[example_index] + + min_null_prediction = None + prelim_predictions = [] + + # Looping through all the features associated to the current example. + for feature_index in feature_indices: + # We grab the predictions of the model for this feature. + start_logits = all_start_logits[feature_index] + end_logits = all_end_logits[feature_index] + # This is what will allow us to map some the positions in our logits to span of texts in the original + # context. + offset_mapping = features[feature_index]["offset_mapping"] + # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum context + # available in the current feature. + token_is_max_context = features[feature_index].get("token_is_max_context", None) + + # Update minimum null prediction. + feature_null_score = start_logits[0] + end_logits[0] + if min_null_prediction is None or min_null_prediction["score"] > feature_null_score: + min_null_prediction = { + "offsets": (0, 0), + "score": feature_null_score, + "start_logit": start_logits[0], + "end_logit": end_logits[0], + } + + # Go through all possibilities for the `n_best_size` greater start and end logits. + start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist() + end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist() + for start_index in start_indexes: + for end_index in end_indexes: + # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond + # to part of the input_ids that are not in the context. + if ( + start_index >= len(offset_mapping) + or end_index >= len(offset_mapping) + or offset_mapping[start_index] is None + or offset_mapping[end_index] is None + ): + continue + # Don't consider answers with a length that is either < 0 or > max_answer_length. + if end_index < start_index or end_index - start_index + 1 > max_answer_length: + continue + # Don't consider answer that don't have the maximum context available (if such information is + # provided). + if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False): + continue + prelim_predictions.append( + { + "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]), + "score": start_logits[start_index] + end_logits[end_index], + "start_logit": start_logits[start_index], + "end_logit": end_logits[end_index], + } + ) + if version_2_with_negative: + # Add the minimum null prediction + prelim_predictions.append(min_null_prediction) + null_score = min_null_prediction["score"] + + # Only keep the best `n_best_size` predictions. + predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size] + + # Add back the minimum null prediction if it was removed because of its low score. + if version_2_with_negative and not any(p["offsets"] == (0, 0) for p in predictions): + predictions.append(min_null_prediction) + + # Use the offsets to gather the answer text in the original context. 
+ context = example["context"] + for pred in predictions: + offsets = pred.pop("offsets") + pred["text"] = context[offsets[0] : offsets[1]] + + # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid + # failure. + if len(predictions) == 0 or (len(predictions) == 1 and predictions[0]["text"] == ""): + predictions.insert(0, {"text": "empty", "start_logit": 0.0, "end_logit": 0.0, "score": 0.0}) + + # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using + # the LogSumExp trick). + scores = np.array([pred.pop("score") for pred in predictions]) + exp_scores = np.exp(scores - np.max(scores)) + probs = exp_scores / exp_scores.sum() + + # Include the probabilities in our predictions. + for prob, pred in zip(probs, predictions): + pred["probability"] = prob + + # Pick the best prediction. If the null answer is not possible, this is easy. + if not version_2_with_negative: + all_predictions[example["id"]] = predictions[0]["text"] + else: + # Otherwise we first need to find the best non-empty prediction. + i = 0 + while predictions[i]["text"] == "": + i += 1 + best_non_null_pred = predictions[i] + + # Then we compare to the null prediction using the threshold. + score_diff = null_score - best_non_null_pred["start_logit"] - best_non_null_pred["end_logit"] + scores_diff_json[example["id"]] = float(score_diff) # To be JSON-serializable. + if score_diff > null_score_diff_threshold: + all_predictions[example["id"]] = "" + else: + all_predictions[example["id"]] = best_non_null_pred["text"] + + # Make `predictions` JSON-serializable by casting np.float back to float. + all_nbest_json[example["id"]] = [ + {k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()} + for pred in predictions + ] + + # If we have an output_dir, let's save all those dicts. + if output_dir is not None: + assert os.path.isdir(output_dir), f"{output_dir} is not a directory." + + prediction_file = os.path.join( + output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json" + ) + nbest_file = os.path.join( + output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json" + ) + if version_2_with_negative: + null_odds_file = os.path.join( + output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json" + ) + + logger.info(f"Saving predictions to {prediction_file}.") + with open(prediction_file, "w") as writer: + writer.write(json.dumps(all_predictions, indent=4) + "\n") + logger.info(f"Saving nbest_preds to {nbest_file}.") + with open(nbest_file, "w") as writer: + writer.write(json.dumps(all_nbest_json, indent=4) + "\n") + if version_2_with_negative: + logger.info(f"Saving null_odds to {null_odds_file}.") + with open(null_odds_file, "w") as writer: + writer.write(json.dumps(scores_diff_json, indent=4) + "\n") + + return all_predictions + + +def postprocess_qa_predictions_with_beam_search( + examples, + features, + predictions: Tuple[np.ndarray, np.ndarray], + version_2_with_negative: bool = False, + n_best_size: int = 20, + max_answer_length: int = 30, + start_n_top: int = 5, + end_n_top: int = 5, + output_dir: Optional[str] = None, + prefix: Optional[str] = None, + log_level: Optional[int] = logging.WARNING, +): + """ + Post-processes the predictions of a question-answering model with beam search to convert them to answers that are substrings of the + original contexts. 
This is the postprocessing function for models that return start and end logits, indices, as well as + cls token predictions. + + Args: + examples: The non-preprocessed dataset (see the main script for more information). + features: The processed dataset (see the main script for more information). + predictions (:obj:`Tuple[np.ndarray, np.ndarray]`): + The predictions of the model: two arrays containing the start logits and the end logits respectively. Its + first dimension must match the number of elements of :obj:`features`. + version_2_with_negative (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not the underlying dataset contains examples with no answers. + n_best_size (:obj:`int`, `optional`, defaults to 20): + The total number of n-best predictions to generate when looking for an answer. + max_answer_length (:obj:`int`, `optional`, defaults to 30): + The maximum length of an answer that can be generated. This is needed because the start and end predictions + are not conditioned on one another. + start_n_top (:obj:`int`, `optional`, defaults to 5): + The number of top start logits to keep when searching for the :obj:`n_best_size` predictions. + end_n_top (:obj:`int`, `optional`, defaults to 5): + The number of top end logits to keep when searching for the :obj:`n_best_size` predictions. + output_dir (:obj:`str`, `optional`): + If provided, the dictionaries of predictions, n_best predictions (with their scores and logits) and, if + :obj:`version_2_with_negative=True`, the dictionary of the scores differences between best and null + answers, are saved in `output_dir`. + prefix (:obj:`str`, `optional`): + If provided, the dictionaries mentioned above are saved with `prefix` added to their names. + log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``): + ``logging`` log level (e.g., ``logging.WARNING``) + """ + assert len(predictions) == 5, "`predictions` should be a tuple with five elements." + start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = predictions + + assert len(predictions[0]) == len( + features + ), f"Got {len(predictions[0])} predictions and {len(features)} features." + + # Build a map from example to its corresponding features. + example_id_to_index = {k: i for i, k in enumerate(examples["id"])} + features_per_example = collections.defaultdict(list) + for i, feature in enumerate(features): + features_per_example[example_id_to_index[feature["example_id"]]].append(i) + + # The dictionaries we have to fill. + all_predictions = collections.OrderedDict() + all_nbest_json = collections.OrderedDict() + scores_diff_json = collections.OrderedDict() if version_2_with_negative else None + + # Logging. + logger.setLevel(log_level) + logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.") + + # Let's loop over all the examples! + for example_index, example in enumerate(tqdm(examples)): + # Those are the indices of the features associated to the current example. + feature_indices = features_per_example[example_index] + + min_null_score = None + prelim_predictions = [] + + # Looping through all the features associated to the current example. + for feature_index in feature_indices: + # We grab the predictions of the model for this feature.
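# Illustrative sketch of the example-to-features grouping built above, with made-up ids: a long context is
# split by the sliding window into several features, and all of a question's features must be scanned
# before its best answer can be chosen.
import collections

example_ids = ["q0", "q1"]                # hypothetical example ids
feature_example_ids = ["q0", "q1", "q1"]  # "q1" was split into two features

example_id_to_index = {k: i for i, k in enumerate(example_ids)}
features_per_example = collections.defaultdict(list)
for i, example_id in enumerate(feature_example_ids):
    features_per_example[example_id_to_index[example_id]].append(i)

print(dict(features_per_example))         # {0: [0], 1: [1, 2]}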
+ start_log_prob = start_top_log_probs[feature_index] + start_indexes = start_top_index[feature_index] + end_log_prob = end_top_log_probs[feature_index] + end_indexes = end_top_index[feature_index] + feature_null_score = cls_logits[feature_index] + # This is what will allow us to map some the positions in our logits to span of texts in the original + # context. + offset_mapping = features[feature_index]["offset_mapping"] + # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum context + # available in the current feature. + token_is_max_context = features[feature_index].get("token_is_max_context", None) + + # Update minimum null prediction + if min_null_score is None or feature_null_score < min_null_score: + min_null_score = feature_null_score + + # Go through all possibilities for the `n_start_top`/`n_end_top` greater start and end logits. + for i in range(start_n_top): + for j in range(end_n_top): + start_index = int(start_indexes[i]) + j_index = i * end_n_top + j + end_index = int(end_indexes[j_index]) + # Don't consider out-of-scope answers (last part of the test should be unnecessary because of the + # p_mask but let's not take any risk) + if ( + start_index >= len(offset_mapping) + or end_index >= len(offset_mapping) + or offset_mapping[start_index] is None + or offset_mapping[end_index] is None + ): + continue + # Don't consider answers with a length negative or > max_answer_length. + if end_index < start_index or end_index - start_index + 1 > max_answer_length: + continue + # Don't consider answer that don't have the maximum context available (if such information is + # provided). + if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False): + continue + prelim_predictions.append( + { + "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]), + "score": start_log_prob[i] + end_log_prob[j_index], + "start_log_prob": start_log_prob[i], + "end_log_prob": end_log_prob[j_index], + } + ) + + # Only keep the best `n_best_size` predictions. + predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size] + + # Use the offsets to gather the answer text in the original context. + context = example["context"] + for pred in predictions: + offsets = pred.pop("offsets") + pred["text"] = context[offsets[0] : offsets[1]] + + # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid + # failure. + if len(predictions) == 0: + predictions.insert(0, {"text": "", "start_logit": -1e-6, "end_logit": -1e-6, "score": -2e-6}) + + # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using + # the LogSumExp trick). + scores = np.array([pred.pop("score") for pred in predictions]) + exp_scores = np.exp(scores - np.max(scores)) + probs = exp_scores / exp_scores.sum() + + # Include the probabilities in our predictions. + for prob, pred in zip(probs, predictions): + pred["probability"] = prob + + # Pick the best prediction and set the probability for the null answer. + all_predictions[example["id"]] = predictions[0]["text"] + if version_2_with_negative: + scores_diff_json[example["id"]] = float(min_null_score) + + # Make `predictions` JSON-serializable by casting np.float back to float. 
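# Quick standalone sketch of the max-shifted softmax used above: subtracting the largest score before
# exponentiating keeps np.exp from overflowing while leaving the normalized probabilities unchanged
# (the LogSumExp trick mentioned in the comment).
import numpy as np

def stable_softmax(scores):
    exp_scores = np.exp(scores - np.max(scores))  # largest exponent becomes exp(0) = 1
    return exp_scores / exp_scores.sum()

print(stable_softmax(np.array([1000.0, 1001.0, 1002.0])))
# [0.09003057 0.24472847 0.66524096] -- plain np.exp(1002.0) would overflow to inf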
+ all_nbest_json[example["id"]] = [ + {k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()} + for pred in predictions + ] + + # If we have an output_dir, let's save all those dicts. + if output_dir is not None: + assert os.path.isdir(output_dir), f"{output_dir} is not a directory." + + prediction_file = os.path.join( + output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json" + ) + nbest_file = os.path.join( + output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json" + ) + if version_2_with_negative: + null_odds_file = os.path.join( + output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json" + ) + + logger.info(f"Saving predictions to {prediction_file}.") + with open(prediction_file, "w") as writer: + writer.write(json.dumps(all_predictions, indent=4) + "\n") + logger.info(f"Saving nbest_preds to {nbest_file}.") + with open(nbest_file, "w") as writer: + writer.write(json.dumps(all_nbest_json, indent=4) + "\n") + if version_2_with_negative: + logger.info(f"Saving null_odds to {null_odds_file}.") + with open(null_odds_file, "w") as writer: + writer.write(json.dumps(scores_diff_json, indent=4) + "\n") + + return all_predictions, scores_diff_json diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 09b8d37466b..88099cebda9 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -44,6 +44,7 @@ from . import dependency_versions_check from .file_utils import ( _LazyModule, is_flax_available, + is_pytorch_quantization_available, is_scatter_available, is_sentencepiece_available, is_speech_available, @@ -248,6 +249,7 @@ _import_structure = { "models.pegasus": ["PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP", "PegasusConfig", "PegasusTokenizer"], "models.phobert": ["PhobertTokenizer"], "models.prophetnet": ["PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "ProphetNetConfig", "ProphetNetTokenizer"], + "models.qdqbert": ["QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "QDQBertConfig"], "models.rag": ["RagConfig", "RagRetriever", "RagTokenizer"], "models.reformer": ["REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "ReformerConfig"], "models.rembert": ["REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RemBertConfig"], @@ -529,6 +531,30 @@ else: name for name in dir(dummy_scatter_objects) if not name.startswith("_") ] +if is_torch_available() and is_pytorch_quantization_available(): + _import_structure["models.qdqbert"].extend( + [ + "QDQBERT_PRETRAINED_MODEL_ARCHIVE_LIST", + "QDQBertForMaskedLM", + "QDQBertForMultipleChoice", + "QDQBertForNextSentencePrediction", + "QDQBertForQuestionAnswering", + "QDQBertForSequenceClassification", + "QDQBertForTokenClassification", + "QDQBertLayer", + "QDQBertLMHeadModel", + "QDQBertModel", + "QDQBertPreTrainedModel", + "load_tf_weights_in_qdqbert", + ] + ) +else: + from .utils import dummy_pytorch_quantization_and_torch_objects + + _import_structure["utils.dummy_pytorch_quantization_and_torch_objects"] = [ + name for name in dir(dummy_pytorch_quantization_and_torch_objects) if not name.startswith("_") + ] + # PyTorch-backed objects if is_torch_available(): _import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"] @@ -2188,6 +2214,7 @@ if TYPE_CHECKING: from .models.pegasus import PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP, PegasusConfig, PegasusTokenizer from .models.phobert import PhobertTokenizer from .models.prophetnet import PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ProphetNetConfig, ProphetNetTokenizer + from 
.models.qdqbert import QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, QDQBertConfig from .models.rag import RagConfig, RagRetriever, RagTokenizer from .models.reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig from .models.rembert import REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RemBertConfig @@ -2415,6 +2442,24 @@ if TYPE_CHECKING: else: from .utils.dummy_scatter_objects import * + if is_torch_available() and is_pytorch_quantization_available(): + from .models.qdqbert import ( + QDQBERT_PRETRAINED_MODEL_ARCHIVE_LIST, + QDQBertForMaskedLM, + QDQBertForMultipleChoice, + QDQBertForNextSentencePrediction, + QDQBertForQuestionAnswering, + QDQBertForSequenceClassification, + QDQBertForTokenClassification, + QDQBertLayer, + QDQBertLMHeadModel, + QDQBertModel, + QDQBertPreTrainedModel, + load_tf_weights_in_qdqbert, + ) + else: + from .utils.dummy_pytorch_quantization_and_torch_objects import * + if is_torch_available(): # Benchmarks from .benchmark.benchmark import PyTorchBenchmark diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index 3b24b4c7375..ae9fc499804 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -196,6 +196,14 @@ except importlib_metadata.PackageNotFoundError: _scatter_available = False +_pytorch_quantization_available = importlib.util.find_spec("pytorch_quantization") is not None +try: + _pytorch_quantization_version = importlib_metadata.version("pytorch_quantization") + logger.debug(f"Successfully imported pytorch-quantization version {_pytorch_quantization_version}") +except importlib_metadata.PackageNotFoundError: + _pytorch_quantization_available = False + + _soundfile_available = importlib.util.find_spec("soundfile") is not None try: _soundfile_version = importlib_metadata.version("soundfile") @@ -431,6 +439,10 @@ def is_scatter_available(): return _scatter_available +def is_pytorch_quantization_available(): + return _pytorch_quantization_available + + def is_pandas_available(): return importlib.util.find_spec("pandas") is not None @@ -610,6 +622,12 @@ SCATTER_IMPORT_ERROR = """ explained here: https://github.com/rusty1s/pytorch_scatter. """ +# docstyle-ignore +PYTORCH_QUANTIZATION_IMPORT_ERROR = """ +{0} requires the pytorch-quantization library but it was not found in your environment. You can install it with pip: +`pip install pytorch-quantization --extra-index-url https://pypi.ngc.nvidia.com` +""" + # docstyle-ignore PANDAS_IMPORT_ERROR = """ @@ -661,6 +679,7 @@ BACKENDS_MAPPING = OrderedDict( ("protobuf", (is_protobuf_available, PROTOBUF_IMPORT_ERROR)), ("pytesseract", (is_pytesseract_available, PYTESSERACT_IMPORT_ERROR)), ("scatter", (is_scatter_available, SCATTER_IMPORT_ERROR)), + ("pytorch_quantization", (is_pytorch_quantization_available, PYTORCH_QUANTIZATION_IMPORT_ERROR)), ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)), ("sklearn", (is_sklearn_available, SKLEARN_IMPORT_ERROR)), ("speech", (is_speech_available, SPEECH_IMPORT_ERROR)), diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index e103d3056da..451077e1b84 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -79,6 +79,7 @@ from . 
import ( pegasus, phobert, prophetnet, + qdqbert, rag, reformer, rembert, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 11223da7320..2070a63994d 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -31,6 +31,7 @@ CONFIG_MAPPING_NAMES = OrderedDict( [ # Add configs here ("imagegpt", "ImageGPTConfig"), + ("qdqbert", "QDQBertConfig"), ("vision-encoder-decoder", "VisionEncoderDecoderConfig"), ("trocr", "TrOCRConfig"), ("fnet", "FNetConfig"), @@ -113,6 +114,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict( [ # Add archive maps here ("imagegpt", "IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("qdqbert", "QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("fnet", "FNET_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("pegasus", "PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("segformer", "SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -185,6 +187,7 @@ MODEL_NAMES_MAPPING = OrderedDict( [ # Add full (and cased) model names here ("imagegpt", "ImageGPT"), + ("qdqbert", "QDQBert"), ("vision-encoder-decoder", "Vision Encoder decoder"), ("trocr", "TrOCR"), ("fnet", "FNet"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 3039f4b12f0..dc534c6ccf1 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -29,6 +29,7 @@ MODEL_MAPPING_NAMES = OrderedDict( [ # Base model mapping ("imagegpt", "ImageGPTModel"), + ("qdqbert", "QDQBertModel"), ("fnet", "FNetModel"), ("segformer", "SegformerModel"), ("gptj", "GPTJModel"), @@ -147,6 +148,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict( [ # Model with LM heads mapping ("imagegpt", "ImageGPTForCausalLM"), + ("qdqbert", "QDQBertForMaskedLM"), ("fnet", "FNetForMaskedLM"), ("gptj", "GPTJForCausalLM"), ("rembert", "RemBertForMaskedLM"), @@ -198,6 +200,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( [ # Model for Causal LM mapping ("imagegpt", "ImageGPTForCausalLM"), + ("qdqbert", "QDQBertLMHeadModel"), ("trocr", "TrOCRForCausalLM"), ("gptj", "GPTJForCausalLM"), ("rembert", "RemBertForCausalLM"), @@ -257,6 +260,7 @@ MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict( MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict( [ # Model for Masked LM mapping + ("qdqbert", "QDQBertForMaskedLM"), ("fnet", "FNetForMaskedLM"), ("rembert", "RemBertForMaskedLM"), ("roformer", "RoFormerForMaskedLM"), @@ -327,6 +331,7 @@ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict( MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Sequence Classification mapping + ("qdqbert", "QDQBertForSequenceClassification"), ("fnet", "FNetForSequenceClassification"), ("gptj", "GPTJForSequenceClassification"), ("layoutlmv2", "LayoutLMv2ForSequenceClassification"), @@ -372,6 +377,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( [ # Model for Question Answering mapping + ("qdqbert", "QDQBertForQuestionAnswering"), ("fnet", "FNetForQuestionAnswering"), ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"), ("rembert", "RemBertForQuestionAnswering"), @@ -418,6 +424,7 @@ MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Token Classification mapping + ("qdqbert", "QDQBertForTokenClassification"), ("fnet", "FNetForTokenClassification"), ("layoutlmv2", "LayoutLMv2ForTokenClassification"), 
("rembert", "RemBertForTokenClassification"), @@ -452,6 +459,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict( [ # Model for Multiple Choice mapping + ("qdqbert", "QDQBertForMultipleChoice"), ("fnet", "FNetForMultipleChoice"), ("rembert", "RemBertForMultipleChoice"), ("canine", "CanineForMultipleChoice"), @@ -480,6 +488,7 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict( MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict( [ + ("qdqbert", "QDQBertForNextSentencePrediction"), ("bert", "BertForNextSentencePrediction"), ("fnet", "FNetForNextSentencePrediction"), ("megatron-bert", "MegatronBertForNextSentencePrediction"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 04ac20ddf6a..f25eb0606f1 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -173,6 +173,7 @@ else: ), ), ("ibert", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)), + ("qdqbert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("wav2vec2", ("Wav2Vec2CTCTokenizer", None)), ("hubert", ("Wav2Vec2CTCTokenizer", None)), ("gpt_neo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), diff --git a/src/transformers/models/qdqbert/__init__.py b/src/transformers/models/qdqbert/__init__.py new file mode 100644 index 00000000000..fc033271427 --- /dev/null +++ b/src/transformers/models/qdqbert/__init__.py @@ -0,0 +1,67 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2021 NVIDIA Corporation and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...file_utils import _LazyModule, is_torch_available + + +_import_structure = { + "configuration_qdqbert": ["QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "QDQBertConfig"], +} + +if is_torch_available(): + _import_structure["modeling_qdqbert"] = [ + "QDQBERT_PRETRAINED_MODEL_ARCHIVE_LIST", + "QDQBertForMaskedLM", + "QDQBertForMultipleChoice", + "QDQBertForNextSentencePrediction", + "QDQBertForQuestionAnswering", + "QDQBertForSequenceClassification", + "QDQBertForTokenClassification", + "QDQBertLayer", + "QDQBertLMHeadModel", + "QDQBertModel", + "QDQBertPreTrainedModel", + "load_tf_weights_in_qdqbert", + ] + + +if TYPE_CHECKING: + from .configuration_qdqbert import QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, QDQBertConfig + + if is_torch_available(): + from .modeling_qdqbert import ( + QDQBERT_PRETRAINED_MODEL_ARCHIVE_LIST, + QDQBertForMaskedLM, + QDQBertForMultipleChoice, + QDQBertForNextSentencePrediction, + QDQBertForQuestionAnswering, + QDQBertForSequenceClassification, + QDQBertForTokenClassification, + QDQBertLayer, + QDQBertLMHeadModel, + QDQBertModel, + QDQBertPreTrainedModel, + load_tf_weights_in_qdqbert, + ) + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) diff --git a/src/transformers/models/qdqbert/configuration_qdqbert.py b/src/transformers/models/qdqbert/configuration_qdqbert.py new file mode 100644 index 00000000000..ede907486d6 --- /dev/null +++ b/src/transformers/models/qdqbert/configuration_qdqbert.py @@ -0,0 +1,122 @@ +# coding=utf-8 +# Copyright 2021 NVIDIA Corporation and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" QDQBERT model configuration """ + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/config.json", + # QDQBERT models can be loaded from any BERT checkpoint, available at https://huggingface.co/models?filter=bert +} + + +class QDQBertConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a :class:`~transformers.QDQBertModel`. It is used to + instantiate an QDQBERT model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the BERT `bert-base-uncased + `__ architecture. + + Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model + outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. + + + Args: + vocab_size (:obj:`int`, `optional`, defaults to 30522): + Vocabulary size of the QDQBERT model. Defines the number of different tokens that can be represented by the + :obj:`inputs_ids` passed when calling :class:`~transformers.QDQBertModel`. 
+ hidden_size (:obj:`int`, `optional`, defaults to 768): + Dimension of the encoder layers and the pooler layer. + num_hidden_layers (:obj:`int`, `optional`, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (:obj:`int`, `optional`, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (:obj:`int`, `optional`, defaults to 3072): + Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, + :obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported. + hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (:obj:`int`, `optional`, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (:obj:`int`, `optional`, defaults to 2): + The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.QDQBertModel`. + initializer_range (:obj:`float`, `optional`, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): + The epsilon used by the layer normalization layers. + use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if ``config.is_decoder=True``.
+ + Examples:: + + >>> from transformers import QDQBertModel, QDQBertConfig + + >>> # Initializing a QDQBERT bert-base-uncased style configuration + >>> configuration = QDQBertConfig() + + >>> # Initializing a model from the bert-base-uncased style configuration + >>> model = QDQBertModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + """ + model_type = "qdqbert" + + def __init__( + self, + vocab_size=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + use_cache=True, + is_encoder_decoder=False, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + **kwargs + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.type_vocab_size = type_vocab_size + self.layer_norm_eps = layer_norm_eps + self.use_cache = use_cache diff --git a/src/transformers/models/qdqbert/modeling_qdqbert.py b/src/transformers/models/qdqbert/modeling_qdqbert.py new file mode 100755 index 00000000000..1e664036926 --- /dev/null +++ b/src/transformers/models/qdqbert/modeling_qdqbert.py @@ -0,0 +1,1750 @@ +# coding=utf-8 +# Copyright 2021 NVIDIA Corporation and The HuggingFace Team. +# Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch QDQBERT model. 
""" + + +import math +import os +import warnings + +import torch +import torch.utils.checkpoint +from packaging import version +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...file_utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_pytorch_quantization_available, + replace_return_docstrings, + requires_backends, +) +from ...modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import logging +from .configuration_qdqbert import QDQBertConfig + + +logger = logging.get_logger(__name__) + +# soft dependency +if is_pytorch_quantization_available(): + try: + from pytorch_quantization import nn as quant_nn + from pytorch_quantization.nn.modules.tensor_quantizer import TensorQuantizer + except OSError: + logger.error( + "QDQBERT model are not usable since `pytorch_quantization` can't be loaded. " + "Please try to reinstall it following the instructions here: https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization." + ) + +_CHECKPOINT_FOR_DOC = "bert-base-uncased" +_CONFIG_FOR_DOC = "QDQBertConfig" +_TOKENIZER_FOR_DOC = "BertTokenizer" + +QDQBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "bert-base-uncased", + # See all BERT models at https://huggingface.co/models?filter=bert +] + + +def load_tf_weights_in_qdqbert(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." 
+ ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +# Copied from transformers.models.bert.modeling_bert.BertEmbeddings with Bert -> QDQBert +class QDQBertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + if version.parse(torch.__version__) > version.parse("1.6.0"): + self.register_buffer( + "token_type_ids", + torch.zeros(self.position_ids.size(), dtype=torch.long), + persistent=False, + ) + + def forward( + self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 + ): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = 
inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs + # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves + # issue #5664 + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class QDQBertSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = quant_nn.QuantLinear(config.hidden_size, self.all_head_size) + self.key = quant_nn.QuantLinear(config.hidden_size, self.all_head_size) + self.value = quant_nn.QuantLinear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + + self.is_decoder = config.is_decoder + + self.matmul_q_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input) + self.matmul_k_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input) + self.matmul_v_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input) + self.matmul_a_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not 
attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul( + self.matmul_q_input_quantizer(query_layer), self.matmul_k_input_quantizer(key_layer.transpose(-1, -2)) + ) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in QDQBertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. 
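# Illustrative sketch of the fake-quantized matmul pattern used in this attention block, assuming
# pytorch-quantization is installed; the tensor shapes are made up. Each matmul input goes through a
# TensorQuantizer, which inserts the quantize/dequantize (QDQ) pair that lets TensorRT run the matmul in
# INT8. In the example scripts added by this PR the quantizer ranges are calibrated before inference.
import math
import torch
from pytorch_quantization import nn as quant_nn
from pytorch_quantization.nn.modules.tensor_quantizer import TensorQuantizer

q_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input)
k_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input)

query_layer = torch.randn(1, 12, 8, 64)  # (batch, heads, seq_len, head_dim)
key_layer = torch.randn(1, 12, 8, 64)
attention_scores = torch.matmul(
    q_quantizer(query_layer), k_quantizer(key_layer.transpose(-1, -2))
) / math.sqrt(64)
print(attention_scores.shape)  # torch.Size([1, 12, 8, 8])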
+ attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul( + self.matmul_a_input_quantizer(attention_probs), self.matmul_v_input_quantizer(value_layer) + ) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +class QDQBertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + # Quantize Linear layer + self.dense = quant_nn.QuantLinear(config.hidden_size, config.hidden_size) + + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # Quantize the inputs to the residual add + self.add_local_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input) + self.add_residual_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + # Quantize the inputs to the residual add + add_local = self.add_local_input_quantizer(hidden_states) + add_residual = self.add_residual_input_quantizer(input_tensor) + hidden_states = self.LayerNorm(add_local + add_residual) + return hidden_states + + +# Based on transformers.models.bert.modeling_bert.BertAttention with Bert -> QDQBert +class QDQBertAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.self = QDQBertSelfAttention(config) + self.output = QDQBertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class QDQBertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + # Quantize Linear layer + 
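# Illustrative sketch, assuming pytorch-quantization is installed: quant_nn.QuantLinear is the drop-in
# replacement used throughout this file. It keeps the nn.Linear call signature but fake-quantizes both the
# input activations and the weights, so every linear layer gets QDQ nodes on export; shapes are made up.
import torch
from pytorch_quantization import nn as quant_nn

layer = quant_nn.QuantLinear(768, 3072)   # same constructor arguments as nn.Linear
hidden_states = torch.randn(4, 128, 768)  # hypothetical (batch, seq_len, hidden) input
print(layer(hidden_states).shape)         # torch.Size([4, 128, 3072])
print([name for name, _ in layer.named_modules() if name])
# ['_input_quantizer', '_weight_quantizer'] -- the attached fake-quantizers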
self.dense = quant_nn.QuantLinear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class QDQBertOutput(nn.Module): + def __init__(self, config): + super().__init__() + # Quantize Linear layer + self.dense = quant_nn.QuantLinear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # Quantize the inputs to the residual add + self.add_local_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input) + self.add_residual_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + # Quantize the inputs to the residual add + add_local = self.add_local_input_quantizer(hidden_states) + add_residual = self.add_residual_input_quantizer(input_tensor) + hidden_states = self.LayerNorm(add_local + add_residual) + return hidden_states + + +# Based on transformers.models.bert.modeling_bert.BertLayer with Bert -> QDQBert +class QDQBertLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.seq_len_dim = 1 + self.attention = QDQBertAttention(config) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = QDQBertAttention(config) + self.intermediate = QDQBertIntermediate(config) + self.output = QDQBertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + cross_attn_past_key_value, + output_attentions, + ) + 
attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + layer_output = self.feed_forward_chunk(attention_output) + outputs = (layer_output,) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +# Based on transformers.models.bert.modeling_bert.BertEncoder with Bert -> QDQBert +class QDQBertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([QDQBertLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +# Copied from transformers.models.bert.modeling_bert.BertPooler with Bert -> QDQBert +class QDQBertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +# Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert -> QDQBert +class QDQBertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +# Based on transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert -> QDQBert +class QDQBertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = QDQBertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. 
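# Small sketch, with made-up sizes, of the decoder/bias wiring set up just below: the output-only bias is
# created as a separate Parameter and then assigned to decoder.bias, so the two names refer to the same
# tensor and the bias follows the decoder when the vocabulary is resized via resize_token_embeddings.
import torch
from torch import nn

hidden_size, vocab_size = 16, 10
decoder = nn.Linear(hidden_size, vocab_size, bias=False)
bias = nn.Parameter(torch.zeros(vocab_size))
decoder.bias = bias          # link the two variables
print(decoder.bias is bias)  # True -- one tensor, two names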
+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+        self.decoder.bias = self.bias
+
+    def forward(self, hidden_states):
+        hidden_states = self.transform(hidden_states)
+        hidden_states = self.decoder(hidden_states)
+        return hidden_states
+
+
+# Based on transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert -> QDQBert
+class QDQBertOnlyMLMHead(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.predictions = QDQBertLMPredictionHead(config)
+
+    def forward(self, sequence_output):
+        prediction_scores = self.predictions(sequence_output)
+        return prediction_scores
+
+
+# Copied from transformers.models.bert.modeling_bert.BertOnlyNSPHead with Bert -> QDQBert
+class QDQBertOnlyNSPHead(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.seq_relationship = nn.Linear(config.hidden_size, 2)
+
+    def forward(self, pooled_output):
+        seq_relationship_score = self.seq_relationship(pooled_output)
+        return seq_relationship_score
+
+
+# Based on transformers.models.bert.modeling_bert.BertPreTrainingHeads with Bert -> QDQBert
+class QDQBertPreTrainingHeads(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.predictions = QDQBertLMPredictionHead(config)
+        self.seq_relationship = nn.Linear(config.hidden_size, 2)
+
+    def forward(self, sequence_output, pooled_output):
+        prediction_scores = self.predictions(sequence_output)
+        seq_relationship_score = self.seq_relationship(pooled_output)
+        return prediction_scores, seq_relationship_score
+
+
+# Based on transformers.models.bert.modeling_bert.BertPreTrainedModel with Bert -> QDQBert
+class QDQBertPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = QDQBertConfig
+    load_tf_weights = load_tf_weights_in_qdqbert
+    base_model_prefix = "bert"
+    supports_gradient_checkpointing = True
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        if isinstance(module, nn.Linear):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, QDQBertEncoder):
+            module.gradient_checkpointing = value
+
+
+QDQBERT_START_DOCSTRING = r"""
+
+    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
+    methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
+    pruning heads, etc.)
+
+    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
+    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to
+    general usage and behavior.
+
+    Parameters:
+        config (:class:`~transformers.QDQBertConfig`): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
+            weights.
+"""
+
+QDQBERT_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
+            Indices of input sequence tokens in the vocabulary.
+
+            Indices can be obtained using :class:`~transformers.BertTokenizer`. See
+            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
+            details.
+
+            `What are input IDs? <../glossary.html#input-ids>`__
+        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
+            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            `What are attention masks? <../glossary.html#attention-mask>`__
+        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
+            Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
+            1]``:
+
+            - 0 corresponds to a `sentence A` token,
+            - 1 corresponds to a `sentence B` token.
+
+            `What are token type IDs? <../glossary.html#token-type-ids>`_
+        position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
+            Indices of positions of each input sequence token in the position embeddings. Selected in the range ``[0,
+            config.max_position_embeddings - 1]``.
+
+            `What are position IDs? <../glossary.html#position-ids>`_
+        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
+            Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
+            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
+            vectors than the model's internal embedding lookup matrix.
+        output_attentions (:obj:`bool`, `optional`):
+            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
+            tensors for more detail.
+        output_hidden_states (:obj:`bool`, `optional`):
+            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
+            more detail.
+        return_dict (:obj:`bool`, `optional`):
+            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+    "The bare QDQBERT Model transformer outputting raw hidden-states without any specific head on top.",
+    QDQBERT_START_DOCSTRING,
+)
+class QDQBertModel(QDQBertPreTrainedModel):
+    """
+
+    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
+    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+
+    To behave as a decoder the model needs to be initialized with the :obj:`is_decoder` argument of the configuration
+    set to :obj:`True`. To be used in a Seq2Seq model, the model needs to be initialized with both the :obj:`is_decoder`
+    argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
+    input to the forward pass.
+    """
+
+    def __init__(self, config, add_pooling_layer=True):
+        requires_backends(self, "pytorch_quantization")
+        super().__init__(config)
+        self.config = config
+
+        self.embeddings = QDQBertEmbeddings(config)
+        self.encoder = QDQBertEncoder(config)
+
+        self.pooler = QDQBertPooler(config) if add_pooling_layer else None
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+
+    def _prune_heads(self, heads_to_prune):
+        """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+        class PreTrainedModel
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        processor_class=_TOKENIZER_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=BaseModelOutputWithPoolingAndCrossAttentions,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_values=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        r"""
+        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+            the model is configured as a decoder.
+        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+        use_cache (:obj:`bool`, `optional`):
+            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+            decoding (see :obj:`past_key_values`).
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + batch_size, seq_length = input_shape + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
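+        # Note: `get_extended_attention_mask` (inherited from the base `PreTrainedModel` utilities) broadcasts the
+        # 2D padding mask to [batch_size, 1, 1, seq_length] (combining it with a causal mask when the model is
+        # configured as a decoder) and turns masked positions into large negative additive biases on the attention scores.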
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+
+        # If a 2D or 3D attention mask is provided for the cross-attention
+        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
+        if self.config.is_decoder and encoder_hidden_states is not None:
+            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+            if encoder_attention_mask is None:
+                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+        else:
+            encoder_extended_attention_mask = None
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicates we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+        embedding_output = self.embeddings(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            token_type_ids=token_type_ids,
+            inputs_embeds=inputs_embeds,
+            past_key_values_length=past_key_values_length,
+        )
+        encoder_outputs = self.encoder(
+            embedding_output,
+            attention_mask=extended_attention_mask,
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_extended_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = encoder_outputs[0]
+        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+        if not return_dict:
+            return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+        return BaseModelOutputWithPoolingAndCrossAttentions(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            past_key_values=encoder_outputs.past_key_values,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+            cross_attentions=encoder_outputs.cross_attentions,
+        )
+
+
+@add_start_docstrings(
+    """QDQBERT Model with a `language modeling` head on top for CLM fine-tuning. """, QDQBERT_START_DOCSTRING
+)
+class QDQBertLMHeadModel(QDQBertPreTrainedModel):
+
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        if not config.is_decoder:
+            logger.warning("If you want to use `QDQBertLMHeadModel` as a standalone, add `is_decoder=True`.")
+
+        self.bert = QDQBertModel(config, add_pooling_layer=False)
+        self.cls = QDQBertOnlyMLMHead(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_output_embeddings(self):
+        return self.cls.predictions.decoder
+
+    def set_output_embeddings(self, new_embeddings):
+        self.cls.predictions.decoder = new_embeddings
+
+    @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        labels=None,
+        past_key_values=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        r"""
+        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+            the model is configured as a decoder.
+        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+            ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are
+            ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
+        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+        use_cache (:obj:`bool`, `optional`):
+            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+            decoding (see :obj:`past_key_values`).
+ + Returns: + + Example:: + + >>> from transformers import BertTokenizer, QDQBertLMHeadModel, QDQBertConfig + >>> import torch + + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = QDQBertConfig.from_pretrained("bert-base-cased") + >>> config.is_decoder = True + >>> model = QDQBertLMHeadModel.from_pretrained('bert-base-cased', config=config) + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.logits + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past} + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past + + +@add_start_docstrings("""QDQBERT Model with a `language modeling` head on top. """, QDQBERT_START_DOCSTRING) +class QDQBertForMaskedLM(QDQBertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + if config.is_decoder: + logger.warning( + "If you want to use `QDQBertForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." 
+ ) + + self.bert = QDQBertModel(config, add_pooling_layer=False) + self.cls = QDQBertOnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + effective_batch_size = input_shape[0] + + # add a dummy token + if self.config.pad_token_id is None: + raise ValueError("The PAD token should be defined for generation") + + attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1) + dummy_token = torch.full( + (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device + ) + input_ids = torch.cat([input_ids, dummy_token], dim=1) + + return {"input_ids": input_ids, "attention_mask": attention_mask} + + +@add_start_docstrings( + """Bert Model with a `next sentence prediction (classification)` head on top. 
""", + QDQBERT_START_DOCSTRING, +) +class QDQBertForNextSentencePrediction(QDQBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.bert = QDQBertModel(config) + self.cls = QDQBertOnlyNSPHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + **kwargs, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair + (see ``input_ids`` docstring). Indices should be in ``[0, 1]``: + + - 0 indicates sequence B is a continuation of sequence A, + - 1 indicates sequence B is a random sequence. + + Returns: + + Example:: + + >>> from transformers import BertTokenizer, QDQBertForNextSentencePrediction + >>> import torch + + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + >>> model = QDQBertForNextSentencePrediction.from_pretrained('bert-base-uncased') + + >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light." + >>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt') + + >>> outputs = model(**encoding, labels=torch.LongTensor([1])) + >>> logits = outputs.logits + >>> assert logits[0, 0] < logits[0, 1] # next sentence was random + """ + + if "next_sentence_label" in kwargs: + warnings.warn( + "The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.", + FutureWarning, + ) + labels = kwargs.pop("next_sentence_label") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + seq_relationship_scores = self.cls(pooled_output) + + next_sentence_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1)) + + if not return_dict: + output = (seq_relationship_scores,) + outputs[2:] + return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output + + return NextSentencePredictorOutput( + loss=next_sentence_loss, + logits=seq_relationship_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled + output) e.g. for GLUE tasks. 
+ """, + QDQBERT_START_DOCSTRING, +) +class QDQBertForSequenceClassification(QDQBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.bert = QDQBertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. 
+ """, + QDQBERT_START_DOCSTRING, +) +class QDQBertForMultipleChoice(QDQBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.bert = QDQBertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the multiple choice classification loss. Indices should be in ``[0, ..., + num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See + :obj:`input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + QDQBERT Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. 
+ """, + QDQBERT_START_DOCSTRING, +) +class QDQBertForTokenClassification(QDQBertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.bert = QDQBertModel(config, add_pooling_layer=False) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels - + 1]``. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + # Only keep active parts of the loss + if attention_mask is not None: + active_loss = attention_mask.view(-1) == 1 + active_logits = logits.view(-1, self.num_labels) + active_labels = torch.where( + active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) + ) + loss = loss_fct(active_logits, active_labels) + else: + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + QDQBERT Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). 
+ """, + QDQBERT_START_DOCSTRING, +) +class QDQBertForQuestionAnswering(QDQBertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.bert = QDQBertModel(config, add_pooling_layer=False) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + start_positions=None, + end_positions=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the + sequence are not taken into account for computing the loss. + end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the + sequence are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 
c10c07f788b..bbfdaedfb72 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -39,6 +39,7 @@ from .file_utils import ( is_onnx_available, is_pandas_available, is_pytesseract_available, + is_pytorch_quantization_available, is_rjieba_available, is_scatter_available, is_sentencepiece_available, @@ -371,6 +372,17 @@ def require_scatter(test_case): return test_case +def require_pytorch_quantization(test_case): + """ + Decorator marking a test that requires PyTorch Quantization Toolkit. These tests are skipped when PyTorch + Quantization Toolkit isn't installed. + """ + if not is_pytorch_quantization_available(): + return unittest.skip("test requires PyTorch Quantization Toolkit")(test_case) + else: + return test_case + + def require_vision(test_case): """ Decorator marking a test that requires the vision dependencies. These tests are skipped when torchaudio isn't diff --git a/src/transformers/utils/dummy_pytorch_quantization_and_torch_objects.py b/src/transformers/utils/dummy_pytorch_quantization_and_torch_objects.py new file mode 100644 index 00000000000..79f9c54316a --- /dev/null +++ b/src/transformers/utils/dummy_pytorch_quantization_and_torch_objects.py @@ -0,0 +1,91 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..file_utils import requires_backends + + +QDQBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class QDQBertForMaskedLM: + def __init__(self, *args, **kwargs): + requires_backends(self, ["pytorch_quantization", "torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["pytorch_quantization", "torch"]) + + +class QDQBertForMultipleChoice: + def __init__(self, *args, **kwargs): + requires_backends(self, ["pytorch_quantization", "torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["pytorch_quantization", "torch"]) + + +class QDQBertForNextSentencePrediction: + def __init__(self, *args, **kwargs): + requires_backends(self, ["pytorch_quantization", "torch"]) + + +class QDQBertForQuestionAnswering: + def __init__(self, *args, **kwargs): + requires_backends(self, ["pytorch_quantization", "torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["pytorch_quantization", "torch"]) + + +class QDQBertForSequenceClassification: + def __init__(self, *args, **kwargs): + requires_backends(self, ["pytorch_quantization", "torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["pytorch_quantization", "torch"]) + + +class QDQBertForTokenClassification: + def __init__(self, *args, **kwargs): + requires_backends(self, ["pytorch_quantization", "torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["pytorch_quantization", "torch"]) + + +class QDQBertLayer: + def __init__(self, *args, **kwargs): + requires_backends(self, ["pytorch_quantization", "torch"]) + + +class QDQBertLMHeadModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["pytorch_quantization", "torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["pytorch_quantization", "torch"]) + + +class QDQBertModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["pytorch_quantization", "torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["pytorch_quantization", "torch"]) + + +class QDQBertPreTrainedModel: + def __init__(self, *args, **kwargs): + 
requires_backends(self, ["pytorch_quantization", "torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["pytorch_quantization", "torch"]) + + +def load_tf_weights_in_qdqbert(*args, **kwargs): + requires_backends(load_tf_weights_in_qdqbert, ["pytorch_quantization", "torch"]) diff --git a/tests/test_modeling_qdqbert.py b/tests/test_modeling_qdqbert.py new file mode 100644 index 00000000000..b667b830c84 --- /dev/null +++ b/tests/test_modeling_qdqbert.py @@ -0,0 +1,563 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# Copyright 2021 NVIDIA Corporation. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch QDQBERT model. """ + + +import unittest + +from tests.test_modeling_common import floats_tensor +from transformers import QDQBertConfig, is_torch_available +from transformers.testing_utils import require_pytorch_quantization, require_torch, slow, torch_device + +from .test_configuration_common import ConfigTester +from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask + + +if is_torch_available(): + import torch + + from transformers import ( + QDQBertForMaskedLM, + QDQBertForMultipleChoice, + QDQBertForNextSentencePrediction, + QDQBertForQuestionAnswering, + QDQBertForSequenceClassification, + QDQBertForTokenClassification, + QDQBertLMHeadModel, + QDQBertModel, + ) + from transformers.models.qdqbert.modeling_qdqbert import QDQBERT_PRETRAINED_MODEL_ARCHIVE_LIST + + +class QDQBertModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_input_mask=True, + use_token_type_ids=True, + use_labels=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + num_choices=4, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_token_type_ids = use_token_type_ids + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices + self.scope = scope + + def prepare_config_and_inputs(self): + # Set default quantizers before creating the model. 
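+        # Note: the default QuantDescriptor settings applied below are class-level state of QuantLinear, picked up
+        # when the quantized linear layers are constructed, so they have to be in place before any QDQBert model
+        # is instantiated.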
+ import pytorch_quantization.nn as quant_nn + from pytorch_quantization.tensor_quant import QuantDescriptor + + # The default tensor quantizer is set to use Max calibration method + input_desc = QuantDescriptor(num_bits=8, calib_method="max") + # The default tensor quantizer is set to be per-channel quantization for weights + weight_desc = QuantDescriptor(num_bits=8, axis=((0,))) + quant_nn.QuantLinear.set_default_quant_desc_input(input_desc) + quant_nn.QuantLinear.set_default_quant_desc_weight(weight_desc) + # For the test cases, since QDQBert model is tested in one run without calibration, the quantized tensors are set as fake quantized tensors which give float type tensors in the end. + quant_nn.TensorQuantizer.use_fb_fake_quant = True + + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.seq_length]) + + token_type_ids = None + if self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = self.get_config() + + return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + + def get_config(self): + return QDQBertConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + is_decoder=False, + initializer_range=self.initializer_range, + ) + + def prepare_config_and_inputs_for_decoder(self): + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = self.prepare_config_and_inputs() + + config.is_decoder = True + encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size]) + encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + return ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ) + + def create_and_check_model( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = QDQBertModel(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) + result = model(input_ids, token_type_ids=token_type_ids) + result = model(input_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_model_as_decoder( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.add_cross_attention = True + model = QDQBertModel(config) + model.to(torch_device) + model.eval() + result = model( + input_ids, + 
attention_mask=input_mask, + token_type_ids=token_type_ids, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + ) + result = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + encoder_hidden_states=encoder_hidden_states, + ) + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_for_causal_lm( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + model = QDQBertLMHeadModel(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + def create_and_check_for_masked_lm( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = QDQBertForMaskedLM(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + def create_and_check_model_for_causal_lm_as_decoder( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.add_cross_attention = True + model = QDQBertLMHeadModel(config=config) + model.to(torch_device) + model.eval() + result = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + labels=token_labels, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + ) + result = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + labels=token_labels, + encoder_hidden_states=encoder_hidden_states, + ) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + def create_and_check_decoder_model_past_large_inputs( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.is_decoder = True + config.add_cross_attention = True + model = QDQBertLMHeadModel(config=config) + model.to(torch_device) + model.eval() + + # first forward pass + outputs = model( + input_ids, + attention_mask=input_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=True, + ) + past_key_values = outputs.past_key_values + + # create hypothetical multiple next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_mask = ids_tensor((self.batch_size, 3), vocab_size=2) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + next_attention_mask = torch.cat([input_mask, next_mask], dim=-1) + + output_from_no_past = model( + next_input_ids, + attention_mask=next_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_hidden_states=True, + )["hidden_states"][0] + output_from_past = model( + next_tokens, + 
attention_mask=next_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + output_hidden_states=True, + )["hidden_states"][0] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def create_and_check_for_next_sequence_prediction( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = QDQBertForNextSentencePrediction(config=config) + model.to(torch_device) + model.eval() + result = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + labels=sequence_labels, + ) + self.parent.assertEqual(result.logits.shape, (self.batch_size, 2)) + + def create_and_check_for_question_answering( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = QDQBertForQuestionAnswering(config=config) + model.to(torch_device) + model.eval() + result = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + start_positions=sequence_labels, + end_positions=sequence_labels, + ) + self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length)) + self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length)) + + def create_and_check_for_sequence_classification( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.num_labels = self.num_labels + model = QDQBertForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) + + def create_and_check_for_token_classification( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.num_labels = self.num_labels + model = QDQBertForTokenClassification(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels)) + + def create_and_check_for_multiple_choice( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.num_choices = self.num_choices + model = QDQBertForMultipleChoice(config=config) + model.to(torch_device) + model.eval() + multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + result = model( + multiple_choice_inputs_ids, + attention_mask=multiple_choice_input_mask, + token_type_ids=multiple_choice_token_type_ids, + labels=choice_labels, + ) + self.parent.assertEqual(result.logits.shape, (self.batch_size, 
self.num_choices)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask} + return config, inputs_dict + + +@require_torch +@require_pytorch_quantization +class QDQBertModelTest(ModelTesterMixin, unittest.TestCase): + + all_model_classes = ( + ( + QDQBertModel, + QDQBertForMaskedLM, + QDQBertForMultipleChoice, + QDQBertForNextSentencePrediction, + QDQBertForQuestionAnswering, + QDQBertForSequenceClassification, + QDQBertForTokenClassification, + QDQBertLMHeadModel, + ) + if is_torch_available() + else () + ) + all_generative_model_classes = (QDQBertLMHeadModel,) if is_torch_available() else () + + def setUp(self): + self.model_tester = QDQBertModelTester(self) + self.config_tester = ConfigTester(self, config_class=QDQBertConfig, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_model_various_embeddings(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + for type in ["absolute", "relative_key", "relative_key_query"]: + config_and_inputs[0].position_embedding_type = type + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_model_as_decoder(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_model_as_decoder(*config_and_inputs) + + def test_model_as_decoder_with_default_input_mask(self): + # This regression test was failing with PyTorch < 1.3 + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ) = self.model_tester.prepare_config_and_inputs_for_decoder() + + input_mask = None + + self.model_tester.create_and_check_model_as_decoder( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ) + + def test_for_causal_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_for_causal_lm(*config_and_inputs) + + def test_for_masked_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_masked_lm(*config_and_inputs) + + def test_for_causal_lm_decoder(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_model_for_causal_lm_as_decoder(*config_and_inputs) + + def test_decoder_model_past_with_large_inputs(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) + + def test_for_multiple_choice(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs) + + def test_for_next_sequence_prediction(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_next_sequence_prediction(*config_and_inputs) + + def 
test_for_question_answering(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_question_answering(*config_and_inputs) + + def test_for_sequence_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs) + + def test_for_token_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_token_classification(*config_and_inputs) + + @slow + def test_model_from_pretrained(self): + for model_name in QDQBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = QDQBertModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + # Override + def test_feed_forward_chunking(self): + # feed forward chunking is not supported in QDQBert + pass + + +@require_torch +@require_pytorch_quantization +class QDQBertModelIntegrationTest(unittest.TestCase): + @slow + def test_inference_no_head_absolute_embedding(self): + # Set default quantizers before creating the model. + import pytorch_quantization.nn as quant_nn + from pytorch_quantization.tensor_quant import QuantDescriptor + + # The default tensor quantizer is set to use Max calibration method + input_desc = QuantDescriptor(num_bits=8, calib_method="max") + # The default tensor quantizer is set to be per-channel quantization for weights + weight_desc = QuantDescriptor(num_bits=8, axis=((0,))) + quant_nn.QuantLinear.set_default_quant_desc_input(input_desc) + quant_nn.QuantLinear.set_default_quant_desc_weight(weight_desc) + + model = QDQBertModel.from_pretrained("bert-base-uncased") + input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]]) + attention_mask = torch.tensor([[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]) + output = model(input_ids, attention_mask=attention_mask)[0] + expected_shape = torch.Size((1, 11, 768)) + self.assertEqual(output.shape, expected_shape) + expected_slice = torch.tensor( + [[[0.4571, -0.0735, 0.8594], [0.2774, -0.0278, 0.8794], [0.3548, -0.0473, 0.7593]]] + ) + self.assertTrue(torch.allclose(output[:, 1:4, 1:4], expected_slice, atol=1e-4))