From 7db2a79b387fd862ffb0af72f7148e6371339c7f Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Thu, 30 Sep 2021 16:38:07 +0530 Subject: [PATCH] [examples/flax] use Repository API for push_to_hub (#13672) * use Repository for push_to_hub * update readme * update other flax scripts * update readme * update qa example * fix push_to_hub call * fix typo * fix more typos * update readme * use abosolute path to get repo name * fix glue script --- examples/flax/README.md | 11 ++ examples/flax/language-modeling/README.md | 111 ++++-------------- .../flax/language-modeling/run_clm_flax.py | 22 +++- .../flax/language-modeling/run_mlm_flax.py | 22 +++- .../flax/language-modeling/run_t5_mlm_flax.py | 22 +++- examples/flax/question-answering/README.md | 32 +---- examples/flax/question-answering/run_qa.py | 23 +++- examples/flax/summarization/README.md | 33 +----- .../summarization/run_summarization_flax.py | 23 ++-- examples/flax/text-classification/README.md | 44 +------ .../flax/text-classification/run_flax_glue.py | 28 ++++- examples/flax/token-classification/README.md | 27 +---- .../flax/token-classification/run_flax_ner.py | 23 +++- examples/flax/vision/README.md | 33 +----- .../flax/vision/run_image_classification.py | 21 +++- 15 files changed, 183 insertions(+), 292 deletions(-) diff --git a/examples/flax/README.md b/examples/flax/README.md index cbcb1648a88..634537c56e2 100644 --- a/examples/flax/README.md +++ b/examples/flax/README.md @@ -61,3 +61,14 @@ For a complete overview of models that are supported in JAX/Flax, please have a Over 3000 pretrained checkpoints are supported in JAX/Flax as of May 2021. Click [here](https://huggingface.co/models?filter=jax) to see the full list on the 🤗 hub. + +## Upload the trained/fine-tuned model to the Hub + +All the example scripts support automatic upload of your final model to the [Model Hub](https://huggingface.co/models) by adding a `--push_to_hub` argument. It will then create a repository with your username slash the name of the folder you are using as `output_dir`. For instance, `"sgugger/test-mrpc"` if your username is `sgugger` and you are working in the folder `~/tmp/test-mrpc`. + +To specify a given repository name, use the `--hub_model_id` argument. You will need to specify the whole repository name (including your username), for instance `--hub_model_id sgugger/finetuned-bert-mrpc`. To upload to an organization you are a member of, just use the name of that organization instead of your username: `--hub_model_id huggingface/finetuned-bert-mrpc`. + +A few notes on this integration: + +- you will need to be logged in to the Hugging Face website locally for it to work, the easiest way to achieve this is to run `huggingface-cli login` and then type your username and password when prompted. You can also pass along your authentication token with the `--hub_token` argument. +- the `output_dir` you pick will either need to be a new folder or a local clone of the distant repository you are using. diff --git a/examples/flax/language-modeling/README.md b/examples/flax/language-modeling/README.md index b96b78f5db4..435d8618bc6 100644 --- a/examples/flax/language-modeling/README.md +++ b/examples/flax/language-modeling/README.md @@ -33,32 +33,10 @@ in Norwegian on a single TPUv3-8 pod. The example script uses the 🤗 Datasets library. You can easily customize them to your needs if you need extra processing on your datasets. -Let's start by creating a model repository to save the trained model and logs. 
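To make the naming rule in the new `examples/flax/README.md` section concrete: when `--push_to_hub` is passed without `--hub_model_id`, the scripts in this patch derive the repository name from your Hub username and the final component of `output_dir`. A minimal sketch of that resolution, assuming you are already logged in via `huggingface-cli login`; the path and username are just the README's illustration values:

```python
from pathlib import Path

from transformers.file_utils import get_full_repo_name

# Mirrors what the updated scripts do when --hub_model_id is not given:
# the repo name becomes "<your-username>/<output_dir folder name>".
output_dir = "~/tmp/test-mrpc"  # illustration value from the README section above
repo_name = get_full_repo_name(Path(output_dir).expanduser().absolute().name)
print(repo_name)  # e.g. "sgugger/test-mrpc" if you are logged in as sgugger
```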
-Here we call the model `"norwegian-roberta-base"`, but you can change the model name as you like. - -You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that -you are logged in) or via the command line: - -``` -huggingface-cli repo create norwegian-roberta-base -``` - -Next we clone the model repository to add the tokenizer and model files. - -``` -git clone https://huggingface.co//norwegian-roberta-base -``` - -To setup all relevant files for training, let's go into the cloned model directory. +To setup all relevant files for training, let's create a directory. ```bash -cd norwegian-roberta-base -``` - -Next, let's add a symbolic link to the `run_mlm_flax.py`. - -```bash -ln -s ~/transformers/examples/flax/language-modeling/run_mlm_flax.py run_mlm_flax.py +mkdir ./norwegian-roberta-base ``` ### Train tokenizer @@ -92,7 +70,7 @@ tokenizer.train_from_iterator(batch_iterator(), vocab_size=50265, min_frequency= ]) # Save files to disk -tokenizer.save("./tokenizer.json") +tokenizer.save("./norwegian-roberta-base/tokenizer.json") ``` ### Create configuration @@ -105,7 +83,7 @@ in the local model folder: from transformers import RobertaConfig config = RobertaConfig.from_pretrained("roberta-base", vocab_size=50265) -config.save_pretrained("./") +config.save_pretrained("./norwegian-roberta-base") ``` Great, we have set up our model repository. During training, we will automatically @@ -116,11 +94,11 @@ push the training logs and model weights to the repo. Next we can run the example script to pretrain the model: ```bash -./run_mlm_flax.py \ - --output_dir="./" \ +python run_mlm_flax.py \ + --output_dir="./norwegian-roberta-base" \ --model_type="roberta" \ - --config_name="./" \ - --tokenizer_name="./" \ + --config_name="./norwegian-roberta-base" \ + --tokenizer_name="./norwegian-roberta-base" \ --dataset_name="oscar" \ --dataset_config_name="unshuffled_deduplicated_no" \ --max_seq_length="128" \ @@ -157,32 +135,11 @@ in Norwegian on a single TPUv3-8 pod. The example script uses the 🤗 Datasets library. You can easily customize them to your needs if you need extra processing on your datasets. -Let's start by creating a model repository to save the trained model and logs. -Here we call the model `"norwegian-gpt2"`, but you can change the model name as you like. -You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that -you are logged in) or via the command line: - -``` -huggingface-cli repo create norwegian-gpt2 -``` - -Next we clone the model repository to add the tokenizer and model files. - -``` -git clone https://huggingface.co//norwegian-gpt2 -``` - -To setup all relevant files for training, let's go into the cloned model directory. +To setup all relevant files for training, let's create a directory. ```bash -cd norwegian-gpt2 -``` - -Next, let's add a symbolic link to the training script `run_clm_flax.py`. 
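The RoBERTa walkthrough above (and the GPT-2 one that follows) now writes `tokenizer.json` and the config into a plain local folder rather than a git clone. As an optional sanity check, not part of the patch, you can confirm such a folder is usable as `--config_name`/`--tokenizer_name` by loading it back with the Auto classes, shown here for `./norwegian-roberta-base`:

```python
from transformers import AutoConfig, AutoTokenizer

# The training command points --config_name and --tokenizer_name at this folder,
# so both should load cleanly before the run is launched.
config = AutoConfig.from_pretrained("./norwegian-roberta-base")
tokenizer = AutoTokenizer.from_pretrained("./norwegian-roberta-base")
print(config.vocab_size, tokenizer("Dette er en test.").input_ids)  # arbitrary sample sentence
```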
-
-```bash
-ln -s ~/transformers/examples/flax/language-modeling/run_clm_flax.py run_clm_flax.py
+mkdir ./norwegian-gpt2
 ```
 
 ### Train tokenizer
@@ -216,7 +173,7 @@ tokenizer.train_from_iterator(batch_iterator(), vocab_size=50257, min_frequency=
 ])
 
 # Save files to disk
-tokenizer.save("./tokenizer.json")
+tokenizer.save("./norwegian-gpt2/tokenizer.json")
 ```
 
 ### Create configuration
@@ -229,7 +186,7 @@ in the local model folder:
 from transformers import GPT2Config
 
 config = GPT2Config.from_pretrained("gpt2", resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0, vocab_size=50257)
-config.save_pretrained("./")
+config.save_pretrained("./norwegian-gpt2")
 ```
 
 Great, we have set up our model repository. During training, we will now automatically
 push the training logs and model weights to the repo.
 
@@ -240,11 +197,11 @@ push the training logs and model weights to the repo.
 Finally, we can run the example script to pretrain the model:
 
 ```bash
-./run_clm_flax.py \
-    --output_dir="./" \
+python run_clm_flax.py \
+    --output_dir="./norwegian-gpt2" \
     --model_type="gpt2" \
-    --config_name="./" \
-    --tokenizer_name="./" \
+    --config_name="./norwegian-gpt2" \
+    --tokenizer_name="./norwegian-gpt2" \
     --dataset_name="oscar" \
     --dataset_config_name="unshuffled_deduplicated_no" \
     --do_train --do_eval \
@@ -282,30 +239,10 @@ The example script uses the 🤗 Datasets library. You can easily customize them
 
 Let's start by creating a model repository to save the trained model and logs.
 Here we call the model `"norwegian-t5-base"`, but you can change the model name as you like.
 
-You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
-you are logged in) or via the command line:
-
-```
-huggingface-cli repo create norwegian-t5-base
-```
-
-Next we clone the model repository to add the tokenizer and model files.
-
-```
-git clone https://huggingface.co/<your-username>/norwegian-t5-base
-```
-
-To setup all relevant files for trairing, let's go into the cloned model directory.
+To setup all relevant files for training, let's create a directory.
 
 ```bash
-cd norwegian-t5-base
-```
-
-Next, let's add a symbolic link to the `run_t5_mlm_flax.py` and `t5_tokenizer_model` scripts.
-
-```bash
-ln -s ~/transformers/examples/flax/language-modeling/run_t5_mlm_flax.py run_t5_mlm_flax.py
-ln -s ~/transformers/examples/flax/language-modeling/t5_tokenizer_model.py t5_tokenizer_model.py
+mkdir ./norwegian-t5-base
 ```
 
 ### Train tokenizer
@@ -351,7 +288,7 @@ tokenizer.train_from_iterator(
 )
 
 # Save files to disk
-tokenizer.save("./tokenizer.json")
+tokenizer.save("./norwegian-t5-base/tokenizer.json")
 ```
 
 ### Create configuration
@@ -364,7 +301,7 @@ in the local model folder:
 from transformers import T5Config
 
 config = T5Config.from_pretrained("google/t5-v1_1-base", vocab_size=tokenizer.get_vocab_size())
-config.save_pretrained("./")
+config.save_pretrained("./norwegian-t5-base")
 ```
 
 Great, we have set up our model repository. During training, we will automatically
 push the training logs and model weights to the repo.
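Because the T5 config above is created with `vocab_size=tokenizer.get_vocab_size()`, it is worth a quick check that the saved `tokenizer.json` and the saved config agree before launching the run. An optional check, not part of the patch:

```python
from transformers import AutoTokenizer, T5Config

# The config was built with vocab_size=tokenizer.get_vocab_size(), so the two
# numbers printed here should normally match.
tokenizer = AutoTokenizer.from_pretrained("./norwegian-t5-base")
config = T5Config.from_pretrained("./norwegian-t5-base")
print(len(tokenizer), config.vocab_size)
```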
Next we can run the example script to pretrain the model: ```bash -./run_t5_mlm_flax.py \ - --output_dir="./" \ +python run_t5_mlm_flax.py \ + --output_dir="./norwegian-t5-base" \ --model_type="t5" \ - --config_name="./" \ - --tokenizer_name="./" \ + --config_name="./norwegian-t5-base" \ + --tokenizer_name="./norwegian-t5-base" \ --dataset_name="oscar" \ --dataset_config_name="unshuffled_deduplicated_no" \ --max_seq_length="512" \ diff --git a/examples/flax/language-modeling/run_clm_flax.py b/examples/flax/language-modeling/run_clm_flax.py index dfb781c7a4a..95c313c6d30 100755 --- a/examples/flax/language-modeling/run_clm_flax.py +++ b/examples/flax/language-modeling/run_clm_flax.py @@ -43,6 +43,7 @@ from flax import jax_utils, traverse_util from flax.jax_utils import unreplicate from flax.training import train_state from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key +from huggingface_hub import Repository from transformers import ( CONFIG_MAPPING, FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, @@ -54,6 +55,7 @@ from transformers import ( is_tensorboard_available, set_seed, ) +from transformers.file_utils import get_full_repo_name from transformers.testing_utils import CaptureLogger @@ -275,6 +277,16 @@ def main(): # Set seed before initializing model. set_seed(training_args.seed) + # Handle the repository creation + if training_args.push_to_hub: + if training_args.hub_model_id is None: + repo_name = get_full_repo_name( + Path(training_args.output_dir).absolute().name, token=training_args.hub_token + ) + else: + repo_name = training_args.hub_model_id + repo = Repository(training_args.output_dir, clone_from=repo_name) + # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ # (the dataset will be downloaded automatically from the datasets Hub). @@ -654,12 +666,10 @@ def main(): # save checkpoint after each epoch and push checkpoint to the hub if jax.process_index() == 0: params = jax.device_get(unreplicate(state.params)) - model.save_pretrained( - training_args.output_dir, - params=params, - push_to_hub=training_args.push_to_hub, - commit_message=f"Saving weights and logs of step {cur_step}", - ) + model.save_pretrained(training_args.output_dir, params=params) + tokenizer.save_pretrained(training_args.output_dir) + if training_args.push_to_hub: + repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False) if __name__ == "__main__": diff --git a/examples/flax/language-modeling/run_mlm_flax.py b/examples/flax/language-modeling/run_mlm_flax.py index 4ac0fce32ce..322479148db 100755 --- a/examples/flax/language-modeling/run_mlm_flax.py +++ b/examples/flax/language-modeling/run_mlm_flax.py @@ -41,6 +41,7 @@ import optax from flax import jax_utils, traverse_util from flax.training import train_state from flax.training.common_utils import get_metrics, onehot, shard +from huggingface_hub import Repository from transformers import ( CONFIG_MAPPING, FLAX_MODEL_FOR_MASKED_LM_MAPPING, @@ -54,6 +55,7 @@ from transformers import ( is_tensorboard_available, set_seed, ) +from transformers.file_utils import get_full_repo_name MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys()) @@ -308,6 +310,16 @@ if __name__ == "__main__": # Set seed before initializing model. 
set_seed(training_args.seed) + # Handle the repository creation + if training_args.push_to_hub: + if training_args.hub_model_id is None: + repo_name = get_full_repo_name( + Path(training_args.output_dir).absolute().name, token=training_args.hub_token + ) + else: + repo_name = training_args.hub_model_id + repo = Repository(training_args.output_dir, clone_from=repo_name) + # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ # (the dataset will be downloaded automatically from the datasets Hub). @@ -683,9 +695,7 @@ if __name__ == "__main__": # save checkpoint after each epoch and push checkpoint to the hub if jax.process_index() == 0: params = jax.device_get(jax.tree_map(lambda x: x[0], state.params)) - model.save_pretrained( - training_args.output_dir, - params=params, - push_to_hub=training_args.push_to_hub, - commit_message=f"Saving weights and logs of step {cur_step}", - ) + model.save_pretrained(training_args.output_dir, params=params) + tokenizer.save_pretrained(training_args.output_dir) + if training_args.push_to_hub: + repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False) diff --git a/examples/flax/language-modeling/run_t5_mlm_flax.py b/examples/flax/language-modeling/run_t5_mlm_flax.py index 14ef8eb5248..e75b0f290f4 100755 --- a/examples/flax/language-modeling/run_t5_mlm_flax.py +++ b/examples/flax/language-modeling/run_t5_mlm_flax.py @@ -39,6 +39,7 @@ import optax from flax import jax_utils, traverse_util from flax.training import train_state from flax.training.common_utils import get_metrics, onehot, shard +from huggingface_hub import Repository from transformers import ( CONFIG_MAPPING, FLAX_MODEL_FOR_MASKED_LM_MAPPING, @@ -52,6 +53,7 @@ from transformers import ( is_tensorboard_available, set_seed, ) +from transformers.file_utils import get_full_repo_name from transformers.models.t5.modeling_flax_t5 import shift_tokens_right @@ -438,6 +440,16 @@ if __name__ == "__main__": # Set seed before initializing model. set_seed(training_args.seed) + # Handle the repository creation + if training_args.push_to_hub: + if training_args.hub_model_id is None: + repo_name = get_full_repo_name( + Path(training_args.output_dir).absolute().name, token=training_args.hub_token + ) + else: + repo_name = training_args.hub_model_id + repo = Repository(training_args.output_dir, clone_from=repo_name) + # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ # (the dataset will be downloaded automatically from the datasets Hub). 
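The same two-step pattern recurs in every script this patch touches: clone the target Hub repo into `output_dir` before training, then push non-blocking after each checkpoint save. Pulled out of the diff context, the flow is roughly the following condensed sketch (all values are placeholders):

```python
from pathlib import Path

from huggingface_hub import Repository
from transformers.file_utils import get_full_repo_name

output_dir = "./my-output-dir"  # placeholder for --output_dir
hub_model_id = None             # --hub_model_id, if the user passed one
hub_token = None                # --hub_token, if the user passed one

# 1) Before training: make output_dir a local clone of the target Hub repo.
repo_name = hub_model_id or get_full_repo_name(Path(output_dir).absolute().name, token=hub_token)
repo = Repository(output_dir, clone_from=repo_name)

# 2) After each checkpoint: save into output_dir, then push in the background.
#    model.save_pretrained(output_dir, params=params)
#    tokenizer.save_pretrained(output_dir)
repo.push_to_hub(commit_message="Saving weights and logs of step 1000", blocking=False)
```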
@@ -791,9 +803,7 @@ if __name__ == "__main__": # save checkpoint after each epoch and push checkpoint to the hub if jax.process_index() == 0: params = jax.device_get(jax.tree_map(lambda x: x[0], state.params)) - model.save_pretrained( - training_args.output_dir, - params=params, - push_to_hub=training_args.push_to_hub, - commit_message=f"Saving weights and logs of step {cur_step}", - ) + model.save_pretrained(training_args.output_dir, params=params) + tokenizer.save_pretrained(training_args.output_dir) + if training_args.push_to_hub: + repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False) diff --git a/examples/flax/question-answering/README.md b/examples/flax/question-answering/README.md index 6b8360349ef..a5f8ebf6b93 100644 --- a/examples/flax/question-answering/README.md +++ b/examples/flax/question-answering/README.md @@ -26,31 +26,6 @@ of the script. The following example fine-tunes BERT on SQuAD: -To begin with it is recommended to create a model repository to save the trained model and logs. -Here we call the model `"bert-qa-squad-test"`, but you can change the model name as you like. - -You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that -you are logged in) or via the command line: - -``` -huggingface-cli repo create bert-qa-squad-test -``` - -Next we clone the model repository to add the tokenizer and model files. - -``` -git clone https://huggingface.co//bert-qa-squad-test -``` - -Great, we have set up our model repository. During training, we will automatically -push the training logs and model weights to the repo. - -Next, let's add a symbolic link to the `run_qa.py`. - -```bash -export MODEL_DIR="./bert-qa-squad-test" -ln -s ~/transformers/examples/flax/question-answering/run_qa.py run_qa.py -``` ```bash python run_qa.py \ @@ -63,7 +38,7 @@ python run_qa.py \ --learning_rate 3e-5 \ --num_train_epochs 2 \ --per_device_train_batch_size 12 \ - --output_dir ${MODEL_DIR} \ + --output_dir ./bert-qa-squad \ --eval_steps 1000 \ --push_to_hub ``` @@ -101,8 +76,9 @@ python run_qa.py \ --num_train_epochs 2 \ --max_seq_length 384 \ --doc_stride 128 \ ---output_dir /tmp/wwm_uncased_finetuned_squad/ \ ---eval_steps 1000 +--output_dir ./wwm_uncased_finetuned_squad/ \ +--eval_steps 1000 \ +--push_to_hub ``` Training with the previously defined hyper-parameters yields the following results: diff --git a/examples/flax/question-answering/run_qa.py b/examples/flax/question-answering/run_qa.py index b8f06abf046..e3dab2203fa 100644 --- a/examples/flax/question-answering/run_qa.py +++ b/examples/flax/question-answering/run_qa.py @@ -25,6 +25,7 @@ import sys import time from dataclasses import dataclass, field from itertools import chain +from pathlib import Path from typing import Any, Callable, Dict, Optional, Tuple import datasets @@ -41,6 +42,7 @@ from flax.jax_utils import replicate, unreplicate from flax.metrics import tensorboard from flax.training import train_state from flax.training.common_utils import get_metrics, onehot, shard +from huggingface_hub import Repository from transformers import ( AutoConfig, AutoTokenizer, @@ -50,6 +52,7 @@ from transformers import ( PreTrainedTokenizerFast, TrainingArguments, ) +from transformers.file_utils import get_full_repo_name from transformers.utils import check_min_version from utils_qa import postprocess_qa_predictions @@ -359,6 +362,16 @@ def main(): transformers.utils.logging.set_verbosity_error() # endregion + # Handle the repository creation + if 
training_args.push_to_hub: + if training_args.hub_model_id is None: + repo_name = get_full_repo_name( + Path(training_args.output_dir).absolute().name, token=training_args.hub_token + ) + else: + repo_name = training_args.hub_model_id + repo = Repository(training_args.output_dir, clone_from=repo_name) + # region Load Data # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ @@ -891,12 +904,10 @@ def main(): # save checkpoint after each epoch and push checkpoint to the hub if jax.process_index() == 0: params = jax.device_get(unreplicate(state.params)) - model.save_pretrained( - training_args.output_dir, - params=params, - push_to_hub=training_args.push_to_hub, - commit_message=f"Saving weights and logs of step {cur_step}", - ) + model.save_pretrained(training_args.output_dir, params=params) + tokenizer.save_pretrained(training_args.output_dir) + if training_args.push_to_hub: + repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False) epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}" # endregion diff --git a/examples/flax/summarization/README.md b/examples/flax/summarization/README.md index adc9cb15e3f..bbe231f31a5 100644 --- a/examples/flax/summarization/README.md +++ b/examples/flax/summarization/README.md @@ -11,43 +11,12 @@ way which enables simple and efficient model parallelism. For custom datasets in `jsonlines` format please see: https://huggingface.co/docs/datasets/loading_datasets.html#json-files and you also will find examples of these below. -Let's start by creating a model repository to save the trained model and logs. -Here we call the model `"bart-base-xsum"`, but you can change the model name as you like. - -You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that -you are logged in) or via the command line: - -``` -huggingface-cli repo create bart-base-xsum -``` -Next we clone the model repository to add the tokenizer and model files. -``` -git clone https://huggingface.co//bart-base-xsum -``` -To ensure that all tensorboard traces will be uploaded correctly, we need to -track them. You can run the following command inside your model repo to do so. - -``` -cd bart-base-xsum -git lfs track "*tfevents*" -``` - -Great, we have set up our model repository. During training, we will automatically -push the training logs and model weights to the repo. - -Next, let's add a symbolic link to the `run_summarization_flax.py`. 
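Because the checkpoints now save the tokenizer next to the model before every push, the repo (or the local `output_dir`) produced by the summarization example below can be used for generation as soon as a checkpoint has been written. A rough sketch, mirroring how `run_summarization_flax.py` itself calls `generate`; the input text is a placeholder:

```python
from transformers import AutoTokenizer, FlaxAutoModelForSeq2SeqLM

# "./bart-base-xsum" is the output_dir used by the summarization command below;
# the pushed Hub repo would work here as well.
model = FlaxAutoModelForSeq2SeqLM.from_pretrained("./bart-base-xsum")
tokenizer = AutoTokenizer.from_pretrained("./bart-base-xsum")

article = "The full text of a news article to summarize ..."  # placeholder input
inputs = tokenizer(article, return_tensors="np", truncation=True, max_length=512)
output = model.generate(inputs["input_ids"], attention_mask=inputs["attention_mask"], max_length=64, num_beams=4)
print(tokenizer.batch_decode(output.sequences, skip_special_tokens=True)[0])
```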
- -```bash -export MODEL_DIR="./bart-base-xsum" -ln -s ~/transformers/examples/flax/summarization/run_summarization_flax.py run_summarization_flax.py -``` - ### Train the model Next we can run the example script to train the model: ```bash python run_summarization_flax.py \ - --output_dir ${MODEL_DIR} \ + --output_dir ./bart-base-xsum \ --model_name_or_path facebook/bart-base \ --tokenizer_name facebook/bart-base \ --dataset_name="xsum" \ diff --git a/examples/flax/summarization/run_summarization_flax.py b/examples/flax/summarization/run_summarization_flax.py index 60292c4d84c..9c72cce2160 100644 --- a/examples/flax/summarization/run_summarization_flax.py +++ b/examples/flax/summarization/run_summarization_flax.py @@ -42,6 +42,7 @@ from flax import jax_utils, traverse_util from flax.jax_utils import unreplicate from flax.training import train_state from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key +from huggingface_hub import Repository from transformers import ( CONFIG_MAPPING, FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, @@ -52,7 +53,7 @@ from transformers import ( TrainingArguments, is_tensorboard_available, ) -from transformers.file_utils import is_offline_mode +from transformers.file_utils import get_full_repo_name, is_offline_mode logger = logging.getLogger(__name__) @@ -333,6 +334,16 @@ def main(): # Set the verbosity to info of the Transformers logger (on main process only): logger.info(f"Training/evaluation parameters {training_args}") + # Handle the repository creation + if training_args.push_to_hub: + if training_args.hub_model_id is None: + repo_name = get_full_repo_name( + Path(training_args.output_dir).absolute().name, token=training_args.hub_token + ) + else: + repo_name = training_args.hub_model_id + repo = Repository(training_args.output_dir, clone_from=repo_name) + # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ # (the dataset will be downloaded automatically from the datasets Hub). @@ -800,12 +811,10 @@ def main(): # save checkpoint after each epoch and push checkpoint to the hub if jax.process_index() == 0: params = jax.device_get(jax.tree_map(lambda x: x[0], state.params)) - model.save_pretrained( - training_args.output_dir, - params=params, - push_to_hub=training_args.push_to_hub, - commit_message=f"Saving weights and logs of epoch {epoch+1}", - ) + model.save_pretrained(training_args.output_dir, params=params) + tokenizer.save_pretrained(training_args.output_dir) + if training_args.push_to_hub: + repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False) if __name__ == "__main__": diff --git a/examples/flax/text-classification/README.md b/examples/flax/text-classification/README.md index 8fd64f0e44f..bf4c4c79cc1 100644 --- a/examples/flax/text-classification/README.md +++ b/examples/flax/text-classification/README.md @@ -21,47 +21,15 @@ limitations under the License. Based on the script [`run_flax_glue.py`](https://github.com/huggingface/transformers/blob/master/examples/flax/text-classification/run_flax_glue.py). Fine-tuning the library models for sequence classification on the GLUE benchmark: [General Language Understanding -Evaluation](https://gluebenchmark.com/). This script can fine-tune any of the models on the [hub](https://huggingface.co/models). 
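Once the MRPC run below has finished (and pushed), its output folder or Hub repo can be loaded back for paired-sentence inference. A hedged sketch, assuming the `./mrpc` output directory from the command shown below; the label names depend on what the script stored in the config:

```python
from transformers import AutoTokenizer, FlaxAutoModelForSequenceClassification

# "./mrpc" is the output_dir from the GLUE command below (TASK_NAME=mrpc);
# the matching Hub repo name also works once the push has gone through.
model = FlaxAutoModelForSequenceClassification.from_pretrained("./mrpc")
tokenizer = AutoTokenizer.from_pretrained("./mrpc")

inputs = tokenizer("He bought a car.", "He purchased a vehicle.", return_tensors="np")
pred = model(**inputs).logits.argmax(-1)
print(model.config.id2label[int(pred[0])])
```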
- -To begin with it is recommended to create a model repository to save the trained model and logs. -Here we call the model `"bert-glue-mrpc-test"`, but you can change the model name as you like. - -You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that -you are logged in) or via the command line: - -``` -huggingface-cli repo create bert-glue-mrpc-test -``` - -Next we clone the model repository to add the tokenizer and model files. - -``` -git clone https://huggingface.co//bert-glue-mrpc-test -``` - -To ensure that all tensorboard traces will be uploaded correctly, we need to -track them. You can run the following command inside your model repo to do so. - -``` -cd bert-glue-mrpc-test -git lfs track "*tfevents*" -``` - -Great, we have set up our model repository. During training, we will automatically -push the training logs and model weights to the repo. - -Next, let's add a symbolic link to the `run_flax_glue.py`. - -```bash -export TASK_NAME=mrpc -export MODEL_DIR="./bert-glue-mrpc-test" -ln -s ~/transformers/examples/flax/text-classification/run_flax_glue.py run_flax_glue.py -``` - +Evaluation](https://gluebenchmark.com/). This script can fine-tune any of the models on the [hub](https://huggingface.co/models) and can also be used for a +dataset hosted on our [hub](https://huggingface.co/datasets) or your own data in a csv or a JSON file (the script might need some tweaks in that case, +refer to the comments inside for help). GLUE is made up of a total of 9 different tasks. Here is how to run the script on one of them: ```bash +export TASK_NAME=mrpc + python run_flax_glue.py \ --model_name_or_path bert-base-cased \ --task_name ${TASK_NAME} \ @@ -69,7 +37,7 @@ python run_flax_glue.py \ --learning_rate 2e-5 \ --num_train_epochs 3 \ --per_device_train_batch_size 4 \ - --output_dir ${MODEL_DIR} \ + --output_dir ./$TASK_NAME/ \ --push_to_hub ``` diff --git a/examples/flax/text-classification/run_flax_glue.py b/examples/flax/text-classification/run_flax_glue.py index 6a12a855be4..ccccfbea964 100755 --- a/examples/flax/text-classification/run_flax_glue.py +++ b/examples/flax/text-classification/run_flax_glue.py @@ -20,6 +20,7 @@ import os import random import time from itertools import chain +from pathlib import Path from typing import Any, Callable, Dict, Tuple import datasets @@ -34,7 +35,9 @@ from flax.jax_utils import replicate, unreplicate from flax.metrics import tensorboard from flax.training import train_state from flax.training.common_utils import get_metrics, onehot, shard +from huggingface_hub import Repository from transformers import AutoConfig, AutoTokenizer, FlaxAutoModelForSequenceClassification, PretrainedConfig +from transformers.file_utils import get_full_repo_name logger = logging.getLogger(__name__) @@ -128,6 +131,10 @@ def parse_args(): action="store_true", help="If passed, model checkpoints and tensorboard logs will be pushed to the hub", ) + parser.add_argument( + "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`." + ) + parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.") args = parser.parse_args() # Sanity checks @@ -141,6 +148,9 @@ def parse_args(): extension = args.validation_file.split(".")[-1] assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." + if args.push_to_hub: + assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed." 
+ if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) @@ -267,6 +277,14 @@ def main(): datasets.utils.logging.set_verbosity_error() transformers.utils.logging.set_verbosity_error() + # Handle the repository creation + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name(Path(args.output_dir).absolute().name, token=args.hub_token) + else: + repo_name = args.hub_model_id + repo = Repository(args.output_dir, clone_from=repo_name) + # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub). @@ -499,12 +517,10 @@ def main(): # save checkpoint after each epoch and push checkpoint to the hub if jax.process_index() == 0: params = jax.device_get(jax.tree_map(lambda x: x[0], state.params)) - model.save_pretrained( - args.output_dir, - params=params, - push_to_hub=args.push_to_hub, - commit_message=f"Saving weights and logs of epoch {epoch}", - ) + model.save_pretrained(args.output_dir, params=params) + tokenizer.save_pretrained(args.output_dir) + if args.push_to_hub: + repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False) if __name__ == "__main__": diff --git a/examples/flax/token-classification/README.md b/examples/flax/token-classification/README.md index 34f156e1c45..915cf6ae20f 100644 --- a/examples/flax/token-classification/README.md +++ b/examples/flax/token-classification/README.md @@ -22,31 +22,6 @@ It will either run on a datasets hosted on our hub or with your own text files f The following example fine-tunes BERT on CoNLL-2003: -To begin with it is recommended to create a model repository to save the trained model and logs. -Here we call the model `"bert-ner-conll2003-test"`, but you can change the model name as you like. - -You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that -you are logged in) or via the command line: - -``` -huggingface-cli repo create bert-ner-conll2003-test -``` - -Next we clone the model repository to add the tokenizer and model files. - -``` -git clone https://huggingface.co//bert-ner-conll2003-test -``` - -Great, we have set up our model repository. During training, we will automatically -push the training logs and model weights to the repo. - -Next, let's add a symbolic link to the `run_flax_ner.py`. 
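The NER checkpoint pushed by the command below can likewise be reloaded for quick token-level predictions. A hedged sketch, assuming the `./bert-ner-conll2003` output directory from that command; how readable the per-token labels are depends on the label mapping saved in the config:

```python
from transformers import AutoTokenizer, FlaxAutoModelForTokenClassification

# "./bert-ner-conll2003" is the output_dir used by the NER command below.
model = FlaxAutoModelForTokenClassification.from_pretrained("./bert-ner-conll2003")
tokenizer = AutoTokenizer.from_pretrained("./bert-ner-conll2003")

inputs = tokenizer("Hugging Face is based in New York City", return_tensors="np")
predictions = model(**inputs).logits.argmax(-1)[0]
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())
print([(token, model.config.id2label[int(pred)]) for token, pred in zip(tokens, predictions)])
```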
- -```bash -export MODEL_DIR="./bert-ner-conll2003-test" -ln -s ~/transformers/examples/flax/token-classification/run_flax_ner.py run_flax_ner.py -``` ```bash python run_flax_ner.py \ @@ -56,7 +31,7 @@ python run_flax_ner.py \ --learning_rate 2e-5 \ --num_train_epochs 3 \ --per_device_train_batch_size 4 \ - --output_dir ${MODEL_DIR} \ + --output_dir ./bert-ner-conll2003 \ --eval_steps 300 \ --push_to_hub ``` diff --git a/examples/flax/token-classification/run_flax_ner.py b/examples/flax/token-classification/run_flax_ner.py index aedc2eba85a..17a08c5a616 100644 --- a/examples/flax/token-classification/run_flax_ner.py +++ b/examples/flax/token-classification/run_flax_ner.py @@ -21,6 +21,7 @@ import sys import time from dataclasses import dataclass, field from itertools import chain +from pathlib import Path from typing import Any, Callable, Dict, Optional, Tuple import datasets @@ -37,6 +38,7 @@ from flax.jax_utils import replicate, unreplicate from flax.metrics import tensorboard from flax.training import train_state from flax.training.common_utils import get_metrics, onehot, shard +from huggingface_hub import Repository from transformers import ( AutoConfig, AutoTokenizer, @@ -44,6 +46,7 @@ from transformers import ( HfArgumentParser, TrainingArguments, ) +from transformers.file_utils import get_full_repo_name from transformers.utils import check_min_version from transformers.utils.versions import require_version @@ -304,6 +307,16 @@ def main(): datasets.utils.logging.set_verbosity_error() transformers.utils.logging.set_verbosity_error() + # Handle the repository creation + if training_args.push_to_hub: + if training_args.hub_model_id is None: + repo_name = get_full_repo_name( + Path(training_args.output_dir).absolute().name, token=training_args.hub_token + ) + else: + repo_name = training_args.hub_model_id + repo = Repository(training_args.output_dir, clone_from=repo_name) + # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) # or just provide the name of one of the public datasets for token classification task available on the hub at https://huggingface.co/datasets/ # (the dataset will be downloaded automatically from the datasets Hub). @@ -656,12 +669,10 @@ def main(): # save checkpoint after each epoch and push checkpoint to the hub if jax.process_index() == 0: params = jax.device_get(unreplicate(state.params)) - model.save_pretrained( - training_args.output_dir, - params=params, - push_to_hub=training_args.push_to_hub, - commit_message=f"Saving weights and logs of step {cur_step}", - ) + model.save_pretrained(training_args.output_dir, params=params) + tokenizer.save_pretrained(training_args.output_dir) + if training_args.push_to_hub: + repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False) epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}" diff --git a/examples/flax/vision/README.md b/examples/flax/vision/README.md index 19a213b838d..d865b8a30ce 100644 --- a/examples/flax/vision/README.md +++ b/examples/flax/vision/README.md @@ -25,37 +25,6 @@ way which enables simple and efficient model parallelism. In this example we will train/fine-tune the model on the [imagenette](https://github.com/fastai/imagenette) dataset. -Let's start by creating a model repository to save the trained model and logs. -Here we call the model `"vit-base-patch16-imagenette"`, but you can change the model name as you like. 
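The vision example saves only the model weights (its preprocessing is done with torchvision, so there is no tokenizer or feature extractor to save). To reload the fine-tuned checkpoint for prediction, a feature extractor can instead be taken from the base checkpoint — a rough sketch with a hypothetical image path:

```python
from PIL import Image
from transformers import FlaxViTForImageClassification, ViTFeatureExtractor

model = FlaxViTForImageClassification.from_pretrained("./vit-base-patch16-imagenette")
# The script preprocesses with torchvision and does not save a feature extractor,
# so we reuse the one from the base checkpoint it was fine-tuned from.
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")

image = Image.open("imagenette2/val/n01440764/some_image.JPEG")  # hypothetical path in the extracted dataset
inputs = feature_extractor(images=image, return_tensors="np")
logits = model(**inputs).logits
print(int(logits.argmax(-1)[0]))
```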
- -You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that -you are logged in) or via the command line: - -``` -huggingface-cli repo create vit-base-patch16-imagenette -``` -Next we clone the model repository to add the tokenizer and model files. -``` -git clone https://huggingface.co//vit-base-patch16-imagenette -``` -To ensure that all tensorboard traces will be uploaded correctly, we need to -track them. You can run the following command inside your model repo to do so. - -``` -cd vit-base-patch16-imagenette -git lfs track "*tfevents*" -``` - -Great, we have set up our model repository. During training, we will automatically -push the training logs and model weights to the repo. - -Next, let's add a symbolic link to the `run_image_classification_flax.py`. - -```bash -export MODEL_DIR="./vit-base-patch16-imagenette -ln -s ~/transformers/examples/flax/summarization/run_image_classification_flax.py run_image_classification_flax.py -``` - ## Prepare the dataset We will use the [imagenette](https://github.com/fastai/imagenette) dataset to train/fine-tune our model. Imagenette is a subset of 10 easily classified classes from Imagenet (tench, English springer, cassette player, chain saw, church, French horn, garbage truck, gas pump, golf ball, parachute). @@ -86,7 +55,7 @@ Next we can run the example script to fine-tune the model: ```bash python run_image_classification.py \ - --output_dir ${MODEL_DIR} \ + --output_dir ./vit-base-patch16-imagenette \ --model_name_or_path google/vit-base-patch16-224-in21k \ --train_dir="imagenette2/train" \ --validation_dir="imagenette2/val" \ diff --git a/examples/flax/vision/run_image_classification.py b/examples/flax/vision/run_image_classification.py index e12a20aa276..882d846b6f8 100644 --- a/examples/flax/vision/run_image_classification.py +++ b/examples/flax/vision/run_image_classification.py @@ -42,6 +42,7 @@ from flax import jax_utils from flax.jax_utils import unreplicate from flax.training import train_state from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key +from huggingface_hub import Repository from transformers import ( CONFIG_MAPPING, FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, @@ -52,6 +53,7 @@ from transformers import ( is_tensorboard_available, set_seed, ) +from transformers.file_utils import get_full_repo_name logger = logging.getLogger(__name__) @@ -205,6 +207,16 @@ def main(): # set seed for random transforms and torch dataloaders set_seed(training_args.seed) + # Handle the repository creation + if training_args.push_to_hub: + if training_args.hub_model_id is None: + repo_name = get_full_repo_name( + Path(training_args.output_dir).absolute().name, token=training_args.hub_token + ) + else: + repo_name = training_args.hub_model_id + repo = Repository(training_args.output_dir, clone_from=repo_name) + # Initialize datasets and pre-processing transforms # We use torchvision here for faster pre-processing # Note that here we are using some default pre-processing, for maximum accuray @@ -455,12 +467,9 @@ def main(): # save checkpoint after each epoch and push checkpoint to the hub if jax.process_index() == 0: params = jax.device_get(jax.tree_map(lambda x: x[0], state.params)) - model.save_pretrained( - training_args.output_dir, - params=params, - push_to_hub=training_args.push_to_hub, - commit_message=f"Saving weights and logs of epoch {epoch+1}", - ) + model.save_pretrained(training_args.output_dir, params=params) + if training_args.push_to_hub: + 
repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False) if __name__ == "__main__":
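One closing note on the pattern introduced here: the scripts push with `blocking=False`, so uploads run in the background and do not hold up training. If you want a job to exit only after the last checkpoint has finished uploading, the same `Repository` API can be called once more with its default blocking behaviour at the end of training — a suggestion, not something these scripts currently do:

```python
from huggingface_hub import Repository

# output_dir is already a clone of the Hub repo at this point, so no clone_from is needed.
repo = Repository("./vit-base-patch16-imagenette")  # placeholder: whatever --output_dir was
repo.push_to_hub(commit_message="End of training", blocking=True)
```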