diff --git a/.circleci/config.yml b/.circleci/config.yml index 58621a9ba9e..f2ee432c241 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -279,7 +279,7 @@ jobs: - v0.4-tf-{{ checksum "setup.py" }} - v0.4-{{ checksum "setup.py" }} - run: pip install --upgrade pip - - run: pip install .[sklearn,tf-cpu,testing,sentencepiece,tf-speech] + - run: pip install .[sklearn,tf-cpu,testing,sentencepiece,tf-speech,vision] - run: pip install tensorflow_probability - save_cache: key: v0.4-tf-{{ checksum "setup.py" }} @@ -313,7 +313,7 @@ jobs: - v0.4-tf-{{ checksum "setup.py" }} - v0.4-{{ checksum "setup.py" }} - run: pip install --upgrade pip - - run: pip install .[sklearn,tf-cpu,testing,sentencepiece,tf-speech] + - run: pip install .[sklearn,tf-cpu,testing,sentencepiece,tf-speech,vision] - run: pip install tensorflow_probability - save_cache: key: v0.4-tf-{{ checksum "setup.py" }} diff --git a/.github/workflows/self-nightly-scheduled.yml b/.github/workflows/self-nightly-scheduled.yml index 6f76e9e8a39..93e9e317a0c 100644 --- a/.github/workflows/self-nightly-scheduled.yml +++ b/.github/workflows/self-nightly-scheduled.yml @@ -205,8 +205,9 @@ jobs: apt -y update && apt install -y libaio-dev pip install --upgrade pip pip install --pre torch torchvision torchaudio -f https://download.pytorch.org/whl/nightly/cu111/torch_nightly.html -U - pip install .[testing,deepspeed,fairscale] - pip install git+https://github.com/microsoft/DeepSpeed + rm -rf ~/.cache/torch_extensions/ # shared between conflicting builds + pip install .[testing,fairscale] + pip install git+https://github.com/microsoft/DeepSpeed # testing bleeding edge - name: Are GPUs recognized by our DL frameworks run: | @@ -218,7 +219,7 @@ jobs: - name: Run all tests on GPU run: | python -m pytest -n 1 -v --dist=loadfile --make-reports=tests_torch_cuda_extensions_multi_gpu tests/deepspeed tests/extended - + - name: Failure short reports if: ${{ always() }} run: cat reports/tests_torch_cuda_extensions_multi_gpu_failures_short.txt diff --git a/.github/workflows/self-push.yml b/.github/workflows/self-push.yml index 57473f45b0e..5d4218787f0 100644 --- a/.github/workflows/self-push.yml +++ b/.github/workflows/self-push.yml @@ -50,7 +50,7 @@ jobs: python -c "import torch; print('Cuda version:', torch.version.cuda)" python -c "import torch; print('CuDNN version:', torch.backends.cudnn.version())" python -c "import torch; print('Number of GPUs available:', torch.cuda.device_count())" - + - name: Fetch the tests to run run: | python utils/tests_fetcher.py --diff_with_last_commit | tee test_preparation.txt @@ -105,7 +105,7 @@ jobs: run: | python -c "from jax.lib import xla_bridge; print('GPU available:', xla_bridge.get_backend().platform)" python -c "import jax; print('Number of GPUs available:', len(jax.local_devices()))" - + - name: Fetch the tests to run run: | python utils/tests_fetcher.py --diff_with_last_commit | tee test_preparation.txt @@ -203,7 +203,7 @@ jobs: apt install -y libsndfile1-dev pip install --upgrade pip pip install .[sklearn,testing,onnxruntime,sentencepiece,torch-speech,vision,timm] - + - name: Launcher docker uses: actions/checkout@v2 with: @@ -277,7 +277,7 @@ jobs: # run: | # python -c "from jax.lib import xla_bridge; print('GPU available:', xla_bridge.get_backend().platform)" # python -c "import jax; print('Number of GPUs available:', len(jax.local_devices()))" -# +# # - name: Fetch the tests to run # run: | # python utils/tests_fetcher.py --diff_with_last_commit | tee test_preparation.txt @@ -389,11 +389,11 @@ jobs: python 
-c "import torch; print('Cuda version:', torch.version.cuda)" python -c "import torch; print('CuDNN version:', torch.backends.cudnn.version())" python -c "import torch; print('Number of GPUs available:', torch.cuda.device_count())" - + - name: Fetch the tests to run run: | python utils/tests_fetcher.py --diff_with_last_commit --filters tests/deepspeed tests/extended | tee test_preparation.txt - + - name: Report fetched tests uses: actions/upload-artifact@v2 with: @@ -437,6 +437,7 @@ jobs: run: | apt -y update && apt install -y libaio-dev pip install --upgrade pip + rm -rf ~/.cache/torch_extensions/ # shared between conflicting builds pip install .[testing,deepspeed,fairscale] - name: Are GPUs recognized by our DL frameworks diff --git a/.github/workflows/self-scheduled.yml b/.github/workflows/self-scheduled.yml index 0027e139975..f6b3a617589 100644 --- a/.github/workflows/self-scheduled.yml +++ b/.github/workflows/self-scheduled.yml @@ -143,7 +143,7 @@ jobs: run: | apt -y update && apt install -y libsndfile1-dev git pip install --upgrade pip - pip install .[sklearn,testing,onnx,sentencepiece,tf-speech] + pip install .[sklearn,testing,onnx,sentencepiece,tf-speech,vision] - name: Are GPUs recognized by our DL frameworks run: | @@ -293,7 +293,7 @@ jobs: run: | apt -y update && apt install -y libsndfile1-dev git pip install --upgrade pip - pip install .[sklearn,testing,onnx,sentencepiece,tf-speech] + pip install .[sklearn,testing,onnx,sentencepiece,tf-speech,vision] - name: Are GPUs recognized by our DL frameworks run: | @@ -429,6 +429,7 @@ jobs: run: | apt -y update && apt install -y libaio-dev pip install --upgrade pip + rm -rf ~/.cache/torch_extensions/ # shared between conflicting builds pip install .[testing,deepspeed,fairscale] - name: Are GPUs recognized by our DL frameworks diff --git a/docs/source/main_classes/deepspeed.rst b/docs/source/main_classes/deepspeed.rst index db639bb53d5..5b2e6e64e5c 100644 --- a/docs/source/main_classes/deepspeed.rst +++ b/docs/source/main_classes/deepspeed.rst @@ -46,6 +46,20 @@ won't be possible on a single GPU. parts of DeepSpeed like ``zero.Init`` for ZeRO stage 3 and higher. To tap into this feature read the docs on :ref:`deepspeed-non-trainer-integration`. +What is integrated: + +Training: + +1. DeepSpeed ZeRO training supports the full ZeRO stages 1, 2 and 3 with ZeRO-Infinity (CPU and NVME offload). + +Inference: + +1. DeepSpeed ZeRO Inference supports ZeRO stage 3 with ZeRO-Infinity. It uses the same ZeRO protocol as training, but + it doesn't use an optimizer and a lr scheduler and only stage 3 is relevant. For more details see: + :ref:`deepspeed-zero-inference`. + +There is also DeepSpeed Inference - this is a totally different technology which uses Tensor Parallelism instead of +ZeRO (coming soon). @@ -1628,6 +1642,47 @@ larger multi-dimensional shape, this means that the parameter is partitioned and +.. _deepspeed-zero-inference: + + +ZeRO Inference +======================================================================================================================= + +ZeRO Inference uses the same config as ZeRO-3 Training. You just don't need the optimizer and scheduler sections. In +fact you can leave these in the config file if you want to share the same one with the training. They will just be +ignored. + +Otherwise you just need to pass the usual :class:`~transformers.TrainingArguments` arguments. For example: + +.. 
code-block:: bash + + deepspeed --num_gpus=2 your_program.py --do_eval --deepspeed ds_config.json + +The only important thing is that you need to use a ZeRO-3 configuration, since ZeRO-1 and ZeRO-2 provide no benefit +for inference: only ZeRO-3 shards the parameters, whereas ZeRO-1 and ZeRO-2 shard only the optimizer states and +gradients, which inference doesn't use. + +Here is an example of running ``run_translation.py`` under DeepSpeed, using all available GPUs: + +.. code-block:: bash + + deepspeed examples/pytorch/translation/run_translation.py \ + --deepspeed tests/deepspeed/ds_config_zero3.json \ + --model_name_or_path t5-small --output_dir output_dir \ + --do_eval --max_eval_samples 50 --warmup_steps 50 \ + --max_source_length 128 --val_max_target_length 128 \ + --overwrite_output_dir --per_device_eval_batch_size 4 \ + --predict_with_generate --dataset_config "ro-en" --fp16 \ + --source_lang en --target_lang ro --dataset_name wmt16 \ + --source_prefix "translate English to Romanian: " + +Since inference doesn't need the additional large memory used by the optimizer states and the gradients, you should be +able to fit much larger batches and/or sequence lengths onto the same hardware. + + +Additionally, DeepSpeed is currently developing a related product called DeepSpeed-Inference, which has no relationship +to the ZeRO technology but instead uses tensor parallelism to scale models that can't fit onto a single GPU. This is a +work in progress and we will provide the integration once that product is complete. + Filing Issues ======================================================================================================================= diff --git a/docs/source/main_classes/tokenizer.rst b/docs/source/main_classes/tokenizer.rst index 8ef1ac56ba1..18798e9b49d 100644 --- a/docs/source/main_classes/tokenizer.rst +++ b/docs/source/main_classes/tokenizer.rst @@ -39,7 +39,8 @@ methods for using all the tokenizers: - Managing special tokens (like mask, beginning-of-sentence, etc.): adding them, assigning them to attributes in the tokenizer for easy access and making sure they are not split during tokenization. -:class:`~transformers.BatchEncoding` holds the output of the tokenizer's encoding methods (``__call__``, +:class:`~transformers.BatchEncoding` holds the output of the +:class:`~transformers.tokenization_utils_base.PreTrainedTokenizerBase`'s encoding methods (``__call__``, ``encode_plus`` and ``batch_encode_plus``) and is derived from a Python dictionary. When the tokenizer is a pure python tokenizer, this class behaves just like a standard python dictionary and holds the various model inputs computed by these methods (``input_ids``, ``attention_mask``...). When the tokenizer is a "Fast" tokenizer (i.e., backed by diff --git a/docs/source/model_doc/imagegpt.rst b/docs/source/model_doc/imagegpt.rst index 2f332aa645c..9b32b429001 100644 --- a/docs/source/model_doc/imagegpt.rst +++ b/docs/source/model_doc/imagegpt.rst @@ -96,10 +96,10 @@ ImageGPTModel :members: forward -ImageGPTForCausalLM ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +ImageGPTForCausalImageModeling ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: transformers.ImageGPTForCausalLM +.. 
autoclass:: transformers.ImageGPTForCausalImageModeling :members: forward diff --git a/docs/source/quicktour.rst b/docs/source/quicktour.rst index a853944b44f..0abf9bccada 100644 --- a/docs/source/quicktour.rst +++ b/docs/source/quicktour.rst @@ -51,6 +51,15 @@ The easiest way to use a pretrained model on a given task is to use :func:`~tran Let's see how this work for sentiment analysis (the other tasks are all covered in the :doc:`task summary `): +Install the following dependencies (if not already installed): + +.. code-block:: bash + + ## PYTORCH CODE + pip install torch + ## TENSORFLOW CODE + pip install tensorflow + .. code-block:: >>> from transformers import pipeline @@ -337,27 +346,42 @@ Once your model is fine-tuned, you can save it with its tokenizer in the followi .. code-block:: - tokenizer.save_pretrained(save_directory) - model.save_pretrained(save_directory) + >>> ## PYTORCH CODE + >>> pt_save_directory = './pt_save_pretrained' + >>> tokenizer.save_pretrained(pt_save_directory) + >>> pt_model.save_pretrained(pt_save_directory) + >>> ## TENSORFLOW CODE + >>> tf_save_directory = './tf_save_pretrained' + >>> tokenizer.save_pretrained(tf_save_directory) + >>> tf_model.save_pretrained(tf_save_directory) You can then load this model back using the :func:`~transformers.AutoModel.from_pretrained` method by passing the directory name instead of the model name. One cool feature of 🤗 Transformers is that you can easily switch between -PyTorch and TensorFlow: any model saved as before can be loaded back either in PyTorch or TensorFlow. If you are -loading a saved PyTorch model in a TensorFlow model, use :func:`~transformers.TFAutoModel.from_pretrained` like this: +PyTorch and TensorFlow: any model saved as before can be loaded back either in PyTorch or TensorFlow. + + +If you would like to load your saved model in the other framework, first make sure it is installed: + +.. code-block:: bash + + ## PYTORCH CODE + pip install tensorflow + ## TENSORFLOW CODE + pip install torch + +Then, use the corresponding Auto class to load it like this: .. code-block:: - from transformers import TFAutoModel - tokenizer = AutoTokenizer.from_pretrained(save_directory) - model = TFAutoModel.from_pretrained(save_directory, from_pt=True) + ## PYTORCH CODE + >>> from transformers import TFAutoModel + >>> tokenizer = AutoTokenizer.from_pretrained(pt_save_directory) + >>> tf_model = TFAutoModel.from_pretrained(pt_save_directory, from_pt=True) + ## TENSORFLOW CODE + >>> from transformers import AutoModel + >>> tokenizer = AutoTokenizer.from_pretrained(tf_save_directory) + >>> pt_model = AutoModel.from_pretrained(tf_save_directory, from_tf=True) -and if you are loading a saved TensorFlow model in a PyTorch model, you should use the following code: - -.. 
code-block:: - - from transformers import AutoModel - tokenizer = AutoTokenizer.from_pretrained(save_directory) - model = AutoModel.from_pretrained(save_directory, from_tf=True) Lastly, you can also ask the model to return all hidden states and all attention weights if you need them: diff --git a/examples/flax/language-modeling/run_clm_flax.py b/examples/flax/language-modeling/run_clm_flax.py index 95c313c6d30..50054a6044e 100755 --- a/examples/flax/language-modeling/run_clm_flax.py +++ b/examples/flax/language-modeling/run_clm_flax.py @@ -27,6 +27,7 @@ import os import sys import time from dataclasses import dataclass, field +from itertools import chain from pathlib import Path from typing import Callable, Optional @@ -430,7 +431,7 @@ def main(): # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. def group_texts(examples): # Concatenate all texts. - concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()} + concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} total_length = len(concatenated_examples[list(examples.keys())[0]]) # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can # customize this part to your needs. diff --git a/examples/flax/language-modeling/run_mlm_flax.py b/examples/flax/language-modeling/run_mlm_flax.py index 322479148db..3be4bf387d1 100755 --- a/examples/flax/language-modeling/run_mlm_flax.py +++ b/examples/flax/language-modeling/run_mlm_flax.py @@ -25,6 +25,7 @@ import os import sys import time from dataclasses import dataclass, field +from itertools import chain # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments. from pathlib import Path @@ -453,7 +454,7 @@ if __name__ == "__main__": # max_seq_length. def group_texts(examples): # Concatenate all texts. - concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()} + concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} total_length = len(concatenated_examples[list(examples.keys())[0]]) # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can # customize this part to your needs. diff --git a/examples/flax/language-modeling/run_t5_mlm_flax.py b/examples/flax/language-modeling/run_t5_mlm_flax.py index e75b0f290f4..b78dc0431ad 100755 --- a/examples/flax/language-modeling/run_t5_mlm_flax.py +++ b/examples/flax/language-modeling/run_t5_mlm_flax.py @@ -25,6 +25,7 @@ import os import sys import time from dataclasses import dataclass, field +from itertools import chain from pathlib import Path from typing import Dict, List, Optional @@ -563,7 +564,7 @@ if __name__ == "__main__": # Main data processing function that will concatenate all texts from our dataset and generate chunks of expanded_inputs_length. def group_texts(examples): # Concatenate all texts. - concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()} + concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} total_length = len(concatenated_examples[list(examples.keys())[0]]) # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can # customize this part to your needs. 
diff --git a/examples/pytorch/language-modeling/run_clm.py b/examples/pytorch/language-modeling/run_clm.py index 444df1b8091..f098f139ae8 100755 --- a/examples/pytorch/language-modeling/run_clm.py +++ b/examples/pytorch/language-modeling/run_clm.py @@ -26,6 +26,7 @@ import math import os import sys from dataclasses import dataclass, field +from itertools import chain from typing import Optional import datasets @@ -408,7 +409,7 @@ def main(): # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. def group_texts(examples): # Concatenate all texts. - concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()} + concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} total_length = len(concatenated_examples[list(examples.keys())[0]]) # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can # customize this part to your needs. diff --git a/examples/pytorch/language-modeling/run_clm_no_trainer.py b/examples/pytorch/language-modeling/run_clm_no_trainer.py index ef9edffb348..ed0702e3bb2 100755 --- a/examples/pytorch/language-modeling/run_clm_no_trainer.py +++ b/examples/pytorch/language-modeling/run_clm_no_trainer.py @@ -27,6 +27,7 @@ import logging import math import os import random +from itertools import chain from pathlib import Path import datasets @@ -366,7 +367,7 @@ def main(): # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. def group_texts(examples): # Concatenate all texts. - concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()} + concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} total_length = len(concatenated_examples[list(examples.keys())[0]]) # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can # customize this part to your needs. diff --git a/examples/pytorch/language-modeling/run_mlm.py b/examples/pytorch/language-modeling/run_mlm.py index a1b5b7aca38..3f8ab03f45e 100755 --- a/examples/pytorch/language-modeling/run_mlm.py +++ b/examples/pytorch/language-modeling/run_mlm.py @@ -26,6 +26,7 @@ import math import os import sys from dataclasses import dataclass, field +from itertools import chain from typing import Optional import datasets @@ -432,7 +433,7 @@ def main(): # max_seq_length. def group_texts(examples): # Concatenate all texts. - concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()} + concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} total_length = len(concatenated_examples[list(examples.keys())[0]]) # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can # customize this part to your needs. diff --git a/examples/pytorch/language-modeling/run_mlm_no_trainer.py b/examples/pytorch/language-modeling/run_mlm_no_trainer.py index e356741dafe..2fc492daa16 100755 --- a/examples/pytorch/language-modeling/run_mlm_no_trainer.py +++ b/examples/pytorch/language-modeling/run_mlm_no_trainer.py @@ -27,6 +27,7 @@ import logging import math import os import random +from itertools import chain from pathlib import Path import datasets @@ -406,7 +407,7 @@ def main(): # max_seq_length. def group_texts(examples): # Concatenate all texts. 
- concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()} + concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} total_length = len(concatenated_examples[list(examples.keys())[0]]) # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can # customize this part to your needs. diff --git a/examples/pytorch/language-modeling/run_plm.py b/examples/pytorch/language-modeling/run_plm.py index 840bfa9ad67..063393e0a42 100755 --- a/examples/pytorch/language-modeling/run_plm.py +++ b/examples/pytorch/language-modeling/run_plm.py @@ -23,6 +23,7 @@ import math import os import sys from dataclasses import dataclass, field +from itertools import chain from typing import Optional import datasets @@ -403,7 +404,7 @@ def main(): # max_seq_length. def group_texts(examples): # Concatenate all texts. - concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()} + concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} total_length = len(concatenated_examples[list(examples.keys())[0]]) # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can # customize this part to your needs. diff --git a/examples/pytorch/multiple-choice/run_swag.py b/examples/pytorch/multiple-choice/run_swag.py index 54a80a5c327..b18ea1288c1 100755 --- a/examples/pytorch/multiple-choice/run_swag.py +++ b/examples/pytorch/multiple-choice/run_swag.py @@ -22,6 +22,7 @@ import logging import os import sys from dataclasses import dataclass, field +from itertools import chain from typing import Optional, Union import datasets @@ -185,7 +186,7 @@ class DataCollatorForMultipleChoice: flattened_features = [ [{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features ] - flattened_features = sum(flattened_features, []) + flattened_features = list(chain(*flattened_features)) batch = self.tokenizer.pad( flattened_features, @@ -333,8 +334,8 @@ def main(): ] # Flatten out - first_sentences = sum(first_sentences, []) - second_sentences = sum(second_sentences, []) + first_sentences = list(chain(*first_sentences)) + second_sentences = list(chain(*second_sentences)) # Tokenize tokenized_examples = tokenizer( diff --git a/examples/pytorch/multiple-choice/run_swag_no_trainer.py b/examples/pytorch/multiple-choice/run_swag_no_trainer.py index 07d212a65a2..6f0f38a8318 100755 --- a/examples/pytorch/multiple-choice/run_swag_no_trainer.py +++ b/examples/pytorch/multiple-choice/run_swag_no_trainer.py @@ -24,6 +24,7 @@ import math import os import random from dataclasses import dataclass +from itertools import chain from pathlib import Path from typing import Optional, Union @@ -224,7 +225,7 @@ class DataCollatorForMultipleChoice: flattened_features = [ [{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features ] - flattened_features = sum(flattened_features, []) + flattened_features = list(chain(*flattened_features)) batch = self.tokenizer.pad( flattened_features, @@ -365,8 +366,8 @@ def main(): labels = examples[label_column_name] # Flatten out - first_sentences = sum(first_sentences, []) - second_sentences = sum(second_sentences, []) + first_sentences = list(chain(*first_sentences)) + second_sentences = list(chain(*second_sentences)) # Tokenize tokenized_examples = tokenizer( diff --git a/examples/research_projects/jax-projects/model_parallel/run_clm_mp.py 
b/examples/research_projects/jax-projects/model_parallel/run_clm_mp.py index 34fa5d3b159..c56f10478f5 100644 --- a/examples/research_projects/jax-projects/model_parallel/run_clm_mp.py +++ b/examples/research_projects/jax-projects/model_parallel/run_clm_mp.py @@ -23,6 +23,7 @@ import os import sys import time from dataclasses import dataclass, field +from itertools import chain from pathlib import Path from typing import Callable, Optional @@ -364,7 +365,7 @@ def main(): # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. def group_texts(examples): # Concatenate all texts. - concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()} + concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} total_length = len(concatenated_examples[list(examples.keys())[0]]) # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can # customize this part to your needs. diff --git a/examples/tensorflow/language-modeling/run_clm.py b/examples/tensorflow/language-modeling/run_clm.py index 5f1adc5ccf6..d8383b0f242 100755 --- a/examples/tensorflow/language-modeling/run_clm.py +++ b/examples/tensorflow/language-modeling/run_clm.py @@ -30,6 +30,7 @@ import random import sys from dataclasses import dataclass, field from functools import partial +from itertools import chain from pathlib import Path from typing import Optional @@ -406,7 +407,7 @@ def main(): # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. def group_texts(examples): # Concatenate all texts. - concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()} + concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} total_length = len(concatenated_examples[list(examples.keys())[0]]) # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can # customize this part to your needs. diff --git a/examples/tensorflow/language-modeling/run_mlm.py b/examples/tensorflow/language-modeling/run_mlm.py index 244a3a9a475..c4f318416cf 100755 --- a/examples/tensorflow/language-modeling/run_mlm.py +++ b/examples/tensorflow/language-modeling/run_mlm.py @@ -32,6 +32,7 @@ import random import sys from dataclasses import dataclass, field from functools import partial +from itertools import chain from pathlib import Path from typing import Optional @@ -462,7 +463,7 @@ def main(): # max_seq_length. def group_texts(examples): # Concatenate all texts. - concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()} + concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} total_length = len(concatenated_examples[list(examples.keys())[0]]) # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can # customize this part to your needs. 
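The recurring ``sum(list_of_lists, [])`` to ``list(chain(*list_of_lists))`` change in these example scripts swaps a flattening idiom that re-copies the accumulated list on every addition (quadratic in the total number of elements) for a single linear pass with ``itertools.chain``. A minimal sketch of the equivalence, using toy data that is not part of the patch:

    from itertools import chain

    examples = {"input_ids": [[1, 2], [3, 4, 5], [6]]}

    flat_old = sum(examples["input_ids"], [])        # quadratic: each + builds a brand-new list
    flat_new = list(chain(*examples["input_ids"]))   # linear: walks the sub-lists once

    assert flat_old == flat_new == [1, 2, 3, 4, 5, 6]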
diff --git a/examples/tensorflow/multiple-choice/run_swag.py b/examples/tensorflow/multiple-choice/run_swag.py index 77dab86f5b7..56e6012ac60 100644 --- a/examples/tensorflow/multiple-choice/run_swag.py +++ b/examples/tensorflow/multiple-choice/run_swag.py @@ -22,6 +22,7 @@ import logging import os import sys from dataclasses import dataclass, field +from itertools import chain from pathlib import Path from typing import Optional @@ -342,8 +343,8 @@ def main(): ] # Flatten out - first_sentences = sum(first_sentences, []) - second_sentences = sum(second_sentences, []) + first_sentences = list(chain(*first_sentences)) + second_sentences = list(chain(*second_sentences)) # Tokenize tokenized_examples = tokenizer(first_sentences, second_sentences, truncation=True, max_length=max_seq_length) diff --git a/setup.py b/setup.py index cf96f9e4ef1..4d59a717f27 100644 --- a/setup.py +++ b/setup.py @@ -97,7 +97,7 @@ _deps = [ "cookiecutter==1.7.2", "dataclasses", "datasets", - "deepspeed>=0.5.3", + "deepspeed>=0.5.7", "docutils==0.16.0", "fairscale>0.3", "faiss-cpu", diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 7881c0fa0ba..a6d0b2a2ebb 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -94,7 +94,8 @@ _import_structure = { "DataCollatorWithPadding", "default_data_collator", ], - "feature_extraction_sequence_utils": ["BatchFeature", "SequenceFeatureExtractor"], + "feature_extraction_sequence_utils": ["SequenceFeatureExtractor"], + "feature_extraction_utils": ["BatchFeature"], "file_utils": [ "CONFIG_NAME", "MODEL_CARD_NAME", @@ -618,6 +619,7 @@ if is_torch_available(): _import_structure["models.auto"].extend( [ "MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING", + "MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING", "MODEL_FOR_CAUSAL_LM_MAPPING", "MODEL_FOR_CTC_MAPPING", "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING", @@ -976,7 +978,7 @@ if is_torch_available(): _import_structure["models.imagegpt"].extend( [ "IMAGEGPT_PRETRAINED_MODEL_ARCHIVE_LIST", - "ImageGPTForCausalLM", + "ImageGPTForCausalImageModeling", "ImageGPTForImageClassification", "ImageGPTModel", "ImageGPTPreTrainedModel", @@ -2071,9 +2073,10 @@ if TYPE_CHECKING: DataCollatorWithPadding, default_data_collator, ) + from .feature_extraction_sequence_utils import SequenceFeatureExtractor # Feature Extractor - from .feature_extraction_utils import BatchFeature, SequenceFeatureExtractor + from .feature_extraction_utils import BatchFeature # Files and general utilities from .file_utils import ( @@ -2531,6 +2534,7 @@ if TYPE_CHECKING: ) from .models.auto import ( MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, + MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING, MODEL_FOR_CAUSAL_LM_MAPPING, MODEL_FOR_CTC_MAPPING, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, @@ -2833,7 +2837,7 @@ if TYPE_CHECKING: ) from .models.imagegpt import ( IMAGEGPT_PRETRAINED_MODEL_ARCHIVE_LIST, - ImageGPTForCausalLM, + ImageGPTForCausalImageModeling, ImageGPTForImageClassification, ImageGPTModel, ImageGPTPreTrainedModel, diff --git a/src/transformers/deepspeed.py b/src/transformers/deepspeed.py index bb5d25d4b23..edbcbd50cca 100644 --- a/src/transformers/deepspeed.py +++ b/src/transformers/deepspeed.py @@ -111,6 +111,29 @@ class HfDeepSpeedConfig: return default return config.get(ds_key, default) + def del_config_sub_tree(self, ds_key_long, must_exist=False): + """ + Deletes a sub-section of the config file if it's found. + + Unless ``must_exist`` is :obj:`True` the section doesn't have to exist. 
+ """ + config = self.config + + # find the config node of interest if it exists + nodes = ds_key_long.split(".") + for node in nodes: + parent_config = config + config = config.get(node) + if config is None: + if must_exist: + raise ValueError(f"Can't find {ds_key_long} entry in the config: {self.config}") + else: + return + + # if found remove it + if parent_config is not None: + parent_config.pop(node) + def is_true(self, ds_key_long): """ Returns :obj:`True`/:obj:`False` only if the value is set, always :obj:`False` otherwise. So use this method to @@ -280,30 +303,10 @@ def deepspeed_config(): return None -def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None): +def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps): """ - Init DeepSpeed, after updating the DeepSpeed configuration with any relevant Trainer's args. - - If ``resume_from_checkpoint`` was passed then an attempt to resume from a previously saved checkpoint will be made. - - Args: - trainer: Trainer object - num_training_steps: per single gpu - resume_from_checkpoint: path to a checkpoint if to resume from after normal DeepSpeedEngine load - - Returns: model, optimizer, lr_scheduler - + A convenience wrapper that deals with optimizer and lr scheduler configuration. """ - import deepspeed - from deepspeed.utils import logger as ds_logger - - model = trainer.model - args = trainer.args - - hf_deepspeed_config = args.hf_deepspeed_config - hf_deepspeed_config.trainer_config_finalize(args, model, num_training_steps) - - # resume config update - some bits like `model` and `num_training_steps` only become available during train config = hf_deepspeed_config.config # Optimizer + Scheduler @@ -351,13 +354,54 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None): else: lr_scheduler = trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer) - # keep for quick debug: - # from pprint import pprint; pprint(config) + return optimizer, lr_scheduler - # set the Deepspeed log level consistent with the trainer + +def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inference=False): + """ + Init DeepSpeed, after updating the DeepSpeed configuration with any relevant Trainer's args. + + If ``resume_from_checkpoint`` was passed then an attempt to resume from a previously saved checkpoint will be made. 
+ + Args: + trainer: Trainer object + num_training_steps: per single gpu + resume_from_checkpoint: path to a checkpoint if to resume from after normal DeepSpeedEngine load + inference: launch in inference mode (no optimizer and no lr scheduler) + + Returns: model, optimizer, lr_scheduler + + """ + import deepspeed + from deepspeed.utils import logger as ds_logger + + model = trainer.model + args = trainer.args + + # resume config update - some bits like `model` and `num_training_steps` only become available during train + hf_deepspeed_config = args.hf_deepspeed_config + hf_deepspeed_config.trainer_config_finalize(args, model, num_training_steps) + config = hf_deepspeed_config.config + + # set the Deepspeed log level consistent with the Trainer ds_logger.setLevel(args.get_process_log_level()) - model_parameters = filter(lambda p: p.requires_grad, model.parameters()) + if inference: + # only Z3 makes sense for the inference + if not hf_deepspeed_config.is_zero3(): + raise ValueError("ZeRO inference only makes sense with ZeRO Stage 3 - please adjust your config") + + # in case the training config is re-used for inference + hf_deepspeed_config.del_config_sub_tree("optimizer") + hf_deepspeed_config.del_config_sub_tree("lr_scheduler") + optimizer, lr_scheduler = None, None + model_parameters = None + else: + optimizer, lr_scheduler = deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps) + model_parameters = filter(lambda p: p.requires_grad, model.parameters()) + + # keep for quick debug: + # from pprint import pprint; pprint(config) model, optimizer, _, lr_scheduler = deepspeed.initialize( model=model, diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index 294cd16c9b1..b074ffe13a3 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -8,7 +8,7 @@ deps = { "cookiecutter": "cookiecutter==1.7.2", "dataclasses": "dataclasses", "datasets": "datasets", - "deepspeed": "deepspeed>=0.5.3", + "deepspeed": "deepspeed>=0.5.7", "docutils": "docutils==0.16.0", "fairscale": "fairscale>0.3", "faiss-cpu": "faiss-cpu", diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index 502effc8cf3..5294f3aab79 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -35,6 +35,7 @@ from dataclasses import fields from enum import Enum from functools import partial, wraps from hashlib import sha256 +from itertools import chain from pathlib import Path from types import ModuleType from typing import Any, BinaryIO, ContextManager, Dict, List, Optional, Tuple, Union @@ -2148,7 +2149,7 @@ class _LazyModule(ModuleType): for value in values: self._class_to_module[value] = key # Needed for autocompletion in an IDE - self.__all__ = list(import_structure.keys()) + sum(import_structure.values(), []) + self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values())) self.__file__ = module_file self.__spec__ = module_spec self.__path__ = [os.path.dirname(module_file)] diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py index 03f360403f5..bd6e3a369bd 100644 --- a/src/transformers/models/auto/__init__.py +++ b/src/transformers/models/auto/__init__.py @@ -32,6 +32,7 @@ _import_structure = { if is_torch_available(): _import_structure["modeling_auto"] = [ "MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING", + "MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING", "MODEL_FOR_CAUSAL_LM_MAPPING", 
"MODEL_FOR_CTC_MAPPING", "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING", @@ -139,6 +140,7 @@ if TYPE_CHECKING: if is_torch_available(): from .modeling_auto import ( MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, + MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING, MODEL_FOR_CAUSAL_LM_MAPPING, MODEL_FOR_CTC_MAPPING, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index dc534c6ccf1..403c59c67d1 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -147,7 +147,6 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict( MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict( [ # Model with LM heads mapping - ("imagegpt", "ImageGPTForCausalLM"), ("qdqbert", "QDQBertForMaskedLM"), ("fnet", "FNetForMaskedLM"), ("gptj", "GPTJForCausalLM"), @@ -199,7 +198,6 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict( MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( [ # Model for Causal LM mapping - ("imagegpt", "ImageGPTForCausalLM"), ("qdqbert", "QDQBertLMHeadModel"), ("trocr", "TrOCRForCausalLM"), ("gptj", "GPTJForCausalLM"), @@ -233,6 +231,13 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( ] ) +MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES = OrderedDict( + # Model for Causal Image Modeling mapping + [ + ("imagegpt", "ImageGPTForCausalImageModeling"), + ] +) + MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Image Classification mapping @@ -524,6 +529,9 @@ MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES) MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES) MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES) MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES) +MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES +) MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES ) diff --git a/src/transformers/models/gpt2/modeling_flax_gpt2.py b/src/transformers/models/gpt2/modeling_flax_gpt2.py index 2a2f7bffb45..00a3e6d4034 100644 --- a/src/transformers/models/gpt2/modeling_flax_gpt2.py +++ b/src/transformers/models/gpt2/modeling_flax_gpt2.py @@ -444,7 +444,7 @@ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel): init_variables = self.module.init( jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True ) - return init_variables["cache"] + return unfreeze(init_variables["cache"]) @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) def __call__( diff --git a/src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py index a62e52e3bcb..c43343ecaf8 100644 --- a/src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py @@ -388,7 +388,7 @@ class FlaxGPTNeoPreTrainedModel(FlaxPreTrainedModel): init_variables = self.module.init( jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True ) - return init_variables["cache"] + return unfreeze(init_variables["cache"]) @add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING) def __call__( diff --git a/src/transformers/models/imagegpt/__init__.py 
b/src/transformers/models/imagegpt/__init__.py index 16be206d418..4fc9496ee90 100644 --- a/src/transformers/models/imagegpt/__init__.py +++ b/src/transformers/models/imagegpt/__init__.py @@ -31,7 +31,7 @@ if is_vision_available(): if is_torch_available(): _import_structure["modeling_imagegpt"] = [ "IMAGEGPT_PRETRAINED_MODEL_ARCHIVE_LIST", - "ImageGPTForCausalLM", + "ImageGPTForCausalImageModeling", "ImageGPTForImageClassification", "ImageGPTModel", "ImageGPTPreTrainedModel", @@ -48,7 +48,7 @@ if TYPE_CHECKING: if is_torch_available(): from .modeling_imagegpt import ( IMAGEGPT_PRETRAINED_MODEL_ARCHIVE_LIST, - ImageGPTForCausalLM, + ImageGPTForCausalImageModeling, ImageGPTForImageClassification, ImageGPTModel, ImageGPTPreTrainedModel, diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index 6f3a3c3c2af..4652774d07a 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -881,7 +881,7 @@ class ImageGPTModel(ImageGPTPreTrainedModel): """, IMAGEGPT_START_DOCSTRING, ) -class ImageGPTForCausalLM(ImageGPTPreTrainedModel): +class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel): _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"] def __init__(self, config): @@ -958,13 +958,13 @@ class ImageGPTForCausalLM(ImageGPTPreTrainedModel): Examples:: - >>> from transformers import ImageGPTFeatureExtractor, ImageGPTForCausalLM + >>> from transformers import ImageGPTFeatureExtractor, ImageGPTForCausalImageModeling >>> import torch >>> import matplotlib.pyplot as plt >>> import numpy as np >>> feature_extractor = ImageGPTFeatureExtractor.from_pretrained('openai/imagegpt-small') - >>> model = ImageGPTForCausalLM.from_pretrained('openai/imagegpt-small') + >>> model = ImageGPTForCausalImageModeling.from_pretrained('openai/imagegpt-small') >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu") >>> model.to(device) diff --git a/src/transformers/models/layoutlmv2/feature_extraction_layoutlmv2.py b/src/transformers/models/layoutlmv2/feature_extraction_layoutlmv2.py index 0fa06b670b3..7a8c4fab7bf 100644 --- a/src/transformers/models/layoutlmv2/feature_extraction_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/feature_extraction_layoutlmv2.py @@ -47,11 +47,11 @@ def normalize_box(box, width, height): ] -def apply_tesseract(image: Image.Image): +def apply_tesseract(image: Image.Image, lang: Optional[str]): """Applies Tesseract OCR on a document image, and returns recognized words + normalized bounding boxes.""" # apply OCR - data = pytesseract.image_to_data(image, output_type="dict") + data = pytesseract.image_to_data(image, lang=lang, output_type="dict") words, left, top, width, height = data["text"], data["left"], data["top"], data["width"], data["height"] # filter empty words and corresponding coordinates @@ -102,6 +102,9 @@ class LayoutLMv2FeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM Only has an effect if :obj:`do_resize` is set to :obj:`True`. apply_ocr (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether to apply the Tesseract OCR engine to get words + normalized bounding boxes. + ocr_lang (:obj:`Optional[str]`, `optional`): + The language, specified by its ISO code, to be used by the Tesseract OCR engine. By default, English is + used. .. 
note:: @@ -110,12 +113,13 @@ class LayoutLMv2FeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM model_input_names = ["pixel_values"] - def __init__(self, do_resize=True, size=224, resample=Image.BILINEAR, apply_ocr=True, **kwargs): + def __init__(self, do_resize=True, size=224, resample=Image.BILINEAR, apply_ocr=True, ocr_lang=None, **kwargs): super().__init__(**kwargs) self.do_resize = do_resize self.size = size self.resample = resample self.apply_ocr = apply_ocr + self.ocr_lang = ocr_lang if apply_ocr: requires_backends(self, "pytesseract") @@ -199,7 +203,7 @@ class LayoutLMv2FeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionM words_batch = [] boxes_batch = [] for image in images: - words, boxes = apply_tesseract(self.to_pil_image(image)) + words, boxes = apply_tesseract(self.to_pil_image(image), self.ocr_lang) words_batch.append(words) boxes_batch.append(boxes) diff --git a/src/transformers/models/layoutlmv2/tokenization_layoutlmv2.py b/src/transformers/models/layoutlmv2/tokenization_layoutlmv2.py index 940a3e03779..2c1f6eb7121 100644 --- a/src/transformers/models/layoutlmv2/tokenization_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/tokenization_layoutlmv2.py @@ -1275,7 +1275,7 @@ class LayoutLMv2Tokenizer(PreTrainedTokenizer): if "bbox" in encoded_inputs: encoded_inputs["bbox"] = [self.pad_token_box] * difference + encoded_inputs["bbox"] if "labels" in encoded_inputs: - encoded_inputs["labels"] = [self.pad_token_label] * difference + encoded_inputs["bbox"] + encoded_inputs["labels"] = [self.pad_token_label] * difference + encoded_inputs["labels"] if "special_tokens_mask" in encoded_inputs: encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input diff --git a/src/transformers/models/layoutlmv2/tokenization_layoutlmv2_fast.py b/src/transformers/models/layoutlmv2/tokenization_layoutlmv2_fast.py index 0a2c84469f7..73a2cc2cb34 100644 --- a/src/transformers/models/layoutlmv2/tokenization_layoutlmv2_fast.py +++ b/src/transformers/models/layoutlmv2/tokenization_layoutlmv2_fast.py @@ -746,7 +746,7 @@ class LayoutLMv2TokenizerFast(PreTrainedTokenizerFast): if "bbox" in encoded_inputs: encoded_inputs["bbox"] = [self.pad_token_box] * difference + encoded_inputs["bbox"] if "labels" in encoded_inputs: - encoded_inputs["labels"] = [self.pad_token_label] * difference + encoded_inputs["bbox"] + encoded_inputs["labels"] = [self.pad_token_label] * difference + encoded_inputs["labels"] if "special_tokens_mask" in encoded_inputs: encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input diff --git a/src/transformers/models/layoutxlm/tokenization_layoutxlm.py b/src/transformers/models/layoutxlm/tokenization_layoutxlm.py index e71e8eef2ca..0e40cb06fe2 100644 --- a/src/transformers/models/layoutxlm/tokenization_layoutxlm.py +++ b/src/transformers/models/layoutxlm/tokenization_layoutxlm.py @@ -1051,7 +1051,7 @@ class LayoutXLMTokenizer(PreTrainedTokenizer): if "bbox" in encoded_inputs: encoded_inputs["bbox"] = [self.pad_token_box] * difference + encoded_inputs["bbox"] if "labels" in encoded_inputs: - encoded_inputs["labels"] = [self.pad_token_label] * difference + encoded_inputs["bbox"] + encoded_inputs["labels"] = [self.pad_token_label] * difference + encoded_inputs["labels"] if "special_tokens_mask" in 
encoded_inputs: encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input diff --git a/src/transformers/models/layoutxlm/tokenization_layoutxlm_fast.py b/src/transformers/models/layoutxlm/tokenization_layoutxlm_fast.py index 4f91da1f1ca..4b9170250f5 100644 --- a/src/transformers/models/layoutxlm/tokenization_layoutxlm_fast.py +++ b/src/transformers/models/layoutxlm/tokenization_layoutxlm_fast.py @@ -614,7 +614,7 @@ class LayoutXLMTokenizerFast(PreTrainedTokenizerFast): if "bbox" in encoded_inputs: encoded_inputs["bbox"] = [self.pad_token_box] * difference + encoded_inputs["bbox"] if "labels" in encoded_inputs: - encoded_inputs["labels"] = [self.pad_token_label] * difference + encoded_inputs["bbox"] + encoded_inputs["labels"] = [self.pad_token_label] * difference + encoded_inputs["labels"] if "special_tokens_mask" in encoded_inputs: encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 0da576be742..d72ad37dcff 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1747,6 +1747,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): init_configuration, *init_inputs, use_auth_token=use_auth_token, + cache_dir=cache_dir, **kwargs, ) @@ -1758,6 +1759,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): init_configuration, *init_inputs, use_auth_token=None, + cache_dir=None, **kwargs ): # We instantiate fast tokenizers based on a slow tokenizer if we don't have access to the tokenizer.json @@ -1797,7 +1799,11 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): # Second attempt. If we have not yet found tokenizer_class, let's try to use the config. try: - config = AutoConfig.from_pretrained(pretrained_model_name_or_path, use_auth_token=use_auth_token) + config = AutoConfig.from_pretrained( + pretrained_model_name_or_path, + use_auth_token=use_auth_token, + cache_dir=cache_dir, + ) config_tokenizer_class = config.tokenizer_class except (OSError, ValueError, KeyError): # skip if an error occurred. 
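The ``ocr_lang`` argument added to ``LayoutLMv2FeatureExtractor`` above is forwarded directly to ``pytesseract.image_to_data(..., lang=...)``, so it takes a Tesseract language code and falls back to English when left as ``None``. A short usage sketch (the file name and the German ``deu`` language pack are assumptions for illustration; the matching Tesseract language data must be installed):

    from PIL import Image
    from transformers import LayoutLMv2FeatureExtractor

    # run the Tesseract OCR pass in German instead of the default English
    feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=True, ocr_lang="deu")

    image = Image.open("invoice_de.png").convert("RGB")  # hypothetical scanned document
    encoding = feature_extractor(image)
    print(encoding["words"], encoding["boxes"])  # recognized words + normalized bounding boxes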
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index f954fe3ae01..7e6d5002657 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2229,15 +2229,12 @@ class Trainer: # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval # from the checkpoint eventually - deepspeed_engine, _, _ = deepspeed_init(self, num_training_steps=0, resume_from_checkpoint=None) + deepspeed_engine, _, _ = deepspeed_init( + self, num_training_steps=0, resume_from_checkpoint=None, inference=True + ) self.model = deepspeed_engine.module self.model_wrapped = deepspeed_engine self.deepspeed = deepspeed_engine - # XXX: we don't need optim/sched for inference, but this needs to be sorted out, since - # for example the Z3-optimizer is a must for zero3 to work even for inference - what we - # don't need is the deepspeed basic optimizer which is self.optimizer.optimizer - deepspeed_engine.optimizer.optimizer = None - deepspeed_engine.lr_scheduler = None model = self._wrap_model(self.model, training=False) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index fa8bb6d04c1..77cc378926d 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -341,6 +341,9 @@ def load_tf_weights_in_albert(*args, **kwargs): MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = None +MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING = None + + MODEL_FOR_CAUSAL_LM_MAPPING = None @@ -2661,7 +2664,7 @@ class IBertPreTrainedModel: IMAGEGPT_PRETRAINED_MODEL_ARCHIVE_LIST = None -class ImageGPTForCausalLM: +class ImageGPTForCausalImageModeling: def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) diff --git a/tests/deepspeed/test_deepspeed.py b/tests/deepspeed/test_deepspeed.py index 5c06d8b57f4..8e7587235df 100644 --- a/tests/deepspeed/test_deepspeed.py +++ b/tests/deepspeed/test_deepspeed.py @@ -697,11 +697,10 @@ class TestDeepSpeedWithLauncher(TestCasePlus): def test_basic_distributed(self, stage): self.run_and_check(stage=stage, distributed=True) - @parameterized.expand(stages) - def test_do_eval_no_train(self, stage): - # we should not fail if train is skipped + def test_do_eval_no_train(self): + # testing only zero3 since zero2 makes no sense with inference self.run_and_check( - stage=stage, + stage=ZERO3, eval_steps=1, distributed=False, do_train=False, @@ -755,6 +754,22 @@ class TestDeepSpeedWithLauncher(TestCasePlus): self.do_checks(output_dir, do_train=do_train, do_eval=do_eval) + @require_torch_multi_gpu + @parameterized.expand(["fp16", "fp32"]) + def test_inference(self, dtype): + # this is just inference, so no optimizer should be loaded + # it only works for z3 (makes no sense with z1-z2) + fp16 = True if dtype == "fp16" else False + self.run_and_check( + stage=ZERO3, + model_name=T5_TINY, + distributed=True, + do_train=False, + do_eval=True, + quality_checks=False, + fp16=fp16, + ) + def do_checks(self, output_dir, do_train=True, do_eval=True, quality_checks=True): if do_train: diff --git a/tests/test_modeling_beit.py b/tests/test_modeling_beit.py index 9ead09a7d35..db8bd8c6d07 100644 --- a/tests/test_modeling_beit.py +++ b/tests/test_modeling_beit.py @@ -414,6 +414,7 @@ def prepare_img(): return image +@require_torch @require_vision class BeitModelIntegrationTest(unittest.TestCase): @cached_property diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 05c980c6422..6640028293e 100755 --- a/tests/test_modeling_common.py +++ 
b/tests/test_modeling_common.py @@ -61,6 +61,7 @@ if is_torch_available(): from transformers import ( BERT_PRETRAINED_MODEL_ARCHIVE_LIST, + MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING, MODEL_FOR_CAUSAL_LM_MAPPING, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, MODEL_FOR_MASKED_LM_MAPPING, @@ -150,6 +151,7 @@ class ModelTesterMixin: elif model_class in [ *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING), *get_values(MODEL_FOR_CAUSAL_LM_MAPPING), + *get_values(MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING), *get_values(MODEL_FOR_MASKED_LM_MAPPING), *get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING), ]: diff --git a/tests/test_modeling_deit.py b/tests/test_modeling_deit.py index 222f2afbe2e..925dbc6b0eb 100644 --- a/tests/test_modeling_deit.py +++ b/tests/test_modeling_deit.py @@ -391,6 +391,7 @@ def prepare_img(): return image +@require_torch @require_vision class DeiTModelIntegrationTest(unittest.TestCase): @cached_property diff --git a/tests/test_modeling_imagegpt.py b/tests/test_modeling_imagegpt.py index 6a2562d407e..85526800dcb 100644 --- a/tests/test_modeling_imagegpt.py +++ b/tests/test_modeling_imagegpt.py @@ -34,7 +34,7 @@ if is_torch_available(): from transformers import ( IMAGEGPT_PRETRAINED_MODEL_ARCHIVE_LIST, - ImageGPTForCausalLM, + ImageGPTForCausalImageModeling, ImageGPTForImageClassification, ImageGPTModel, ) @@ -207,14 +207,14 @@ class ImageGPTModelTester: self.parent.assertEqual(len(result.past_key_values), config.n_layer) def create_and_check_lm_head_model(self, config, pixel_values, input_mask, head_mask, token_type_ids, *args): - model = ImageGPTForCausalLM(config) + model = ImageGPTForCausalImageModeling(config) model.to(torch_device) model.eval() labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size - 1) result = model(pixel_values, token_type_ids=token_type_ids, labels=labels) self.parent.assertEqual(result.loss.shape, ()) - # ImageGPTForCausalLM doens't have tied input- and output embeddings + # ImageGPTForCausalImageModeling doens't have tied input- and output embeddings self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size - 1)) def create_and_check_imagegpt_for_image_classification( @@ -255,9 +255,9 @@ class ImageGPTModelTester: class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_model_classes = ( - (ImageGPTForCausalLM, ImageGPTForImageClassification, ImageGPTModel) if is_torch_available() else () + (ImageGPTForCausalImageModeling, ImageGPTForImageClassification, ImageGPTModel) if is_torch_available() else () ) - all_generative_model_classes = (ImageGPTForCausalLM,) if is_torch_available() else () + all_generative_model_classes = (ImageGPTForCausalImageModeling,) if is_torch_available() else () test_missing_keys = False input_name = "pixel_values" @@ -273,7 +273,7 @@ class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCa return inputs_dict - # we overwrite the _check_scores method of GenerationTesterMixin, as ImageGPTForCausalLM doesn't have tied input- and output embeddings + # we overwrite the _check_scores method of GenerationTesterMixin, as ImageGPTForCausalImageModeling doesn't have tied input- and output embeddings def _check_scores(self, batch_size, scores, length, config): expected_shape = (batch_size, config.vocab_size - 1) self.assertIsInstance(scores, tuple) @@ -519,7 +519,7 @@ class ImageGPTModelIntegrationTest(unittest.TestCase): @slow def test_inference_causal_lm_head(self): - model = 
ImageGPTForCausalLM.from_pretrained("openai/imagegpt-small").to(torch_device) + model = ImageGPTForCausalImageModeling.from_pretrained("openai/imagegpt-small").to(torch_device) feature_extractor = self.default_feature_extractor image = prepare_img() diff --git a/tests/test_modeling_tf_vit.py b/tests/test_modeling_tf_vit.py index eb342aa68da..ea493fc593c 100644 --- a/tests/test_modeling_tf_vit.py +++ b/tests/test_modeling_tf_vit.py @@ -353,7 +353,7 @@ class TFViTModelTest(TFModelTesterMixin, unittest.TestCase): @slow def test_model_from_pretrained(self): - model = TFViTModel.from_pretrained("google/vit-base-patch16-224", from_pt=True) + model = TFViTModel.from_pretrained("google/vit-base-patch16-224") self.assertIsNotNone(model) @@ -363,6 +363,7 @@ def prepare_img(): return image +@require_tf @require_vision class TFViTModelIntegrationTest(unittest.TestCase): @cached_property diff --git a/tests/test_modeling_vit.py b/tests/test_modeling_vit.py index 6073bf2392d..c24ae535a13 100644 --- a/tests/test_modeling_vit.py +++ b/tests/test_modeling_vit.py @@ -331,6 +331,7 @@ def prepare_img(): return image +@require_torch @require_vision class ViTModelIntegrationTest(unittest.TestCase): @cached_property diff --git a/tests/test_pipelines_audio_classification.py b/tests/test_pipelines_audio_classification.py index f01825dd990..ef2dc26aa55 100644 --- a/tests/test_pipelines_audio_classification.py +++ b/tests/test_pipelines_audio_classification.py @@ -114,12 +114,12 @@ class AudioClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest audio = np.array(dataset[3]["speech"], dtype=np.float32) output = audio_classifier(audio, top_k=4) self.assertEqual( - nested_simplify(output, decimals=4), + nested_simplify(output, decimals=3), [ - {"score": 0.9809, "label": "go"}, - {"score": 0.0073, "label": "up"}, - {"score": 0.0064, "label": "_unknown_"}, - {"score": 0.0015, "label": "down"}, + {"score": 0.981, "label": "go"}, + {"score": 0.007, "label": "up"}, + {"score": 0.006, "label": "_unknown_"}, + {"score": 0.001, "label": "down"}, ], )
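Taken together, the ``deepspeed.py`` and ``trainer.py`` changes above mean that an evaluation-only run under a ZeRO-3 config now goes through ``deepspeed_init(..., inference=True)``, so no DeepSpeed optimizer or LR scheduler is constructed. A minimal sketch of such a run, with placeholder model name, dataset variable and config path, launched via the ``deepspeed`` launcher (e.g. ``deepspeed --num_gpus=2 eval_only.py``):

    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Trainer, TrainingArguments

    args = TrainingArguments(
        output_dir="output_dir",
        do_train=False,
        do_eval=True,
        per_device_eval_batch_size=4,
        fp16=True,
        deepspeed="tests/deepspeed/ds_config_zero3.json",  # ZeRO stage 3 is required for ZeRO inference
    )

    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

    # my_eval_dataset: a user-provided, already tokenized evaluation dataset
    trainer = Trainer(model=model, args=args, tokenizer=tokenizer, eval_dataset=my_eval_dataset)
    metrics = trainer.evaluate()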