mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[Flax community event] How to use hub during training (#12447)
* fix_torch_device_generate_test
* remove @
* upload
* finish doc
* Apply suggestions from code review

Co-authored-by: Omar Sanseviero <osanseviero@users.noreply.github.com>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Julien Chaumond <chaumond@gmail.com>

* finish

Co-authored-by: Omar Sanseviero <osanseviero@users.noreply.github.com>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
parent
3aa37b945e
commit
b655f16d4e
@@ -24,10 +24,10 @@ Don't forget to sign up [here](https://forms.gle/tVGPhjKXyEsSgUcs8)!
- [Quickstart Flax/JAX in 🤗 Transformers](#quickstart-flax-and-jax-in-transformers)
- [Flax design philosophy in 🤗 Transformers](#flax-design-philosophy-in-transformers)
- [How to use flax models & scripts](#how-to-use-flax-models-and-example-scripts)
- [How to make a demo for submission](#how-to-make-a-demo)
- [Talks](#talks)
- [How to use the 🤗 Hub for training](#how-to-use-the-hub-for-training)
- [How to setup TPU VM](#how-to-setup-tpu-vm)
- [How to use the 🤗 Hub for training and demo](#how-to-use-the-hub-for-training-and-demo)
- [How to use the 🤗 Hub for demo](#how-to-use-the-hub-for-demo)
- [Project evaluation](#project-evaluation)
- [General Tips & Tricks](#general-tips-and-tricks)
- [FAQ](#faq)
@@ -369,7 +369,7 @@ be available in a couple of days.
- [RoBERTa](https://github.com/huggingface/transformers/blob/master/src/transformers/models/roberta/modeling_flax_roberta.py)
- [T5](https://github.com/huggingface/transformers/blob/master/src/transformers/models/t5/modeling_flax_t5.py)
- [ViT](https://github.com/huggingface/transformers/blob/master/src/transformers/models/vit/modeling_flax_vit.py)
- [(TODO) Wav2Vec2](https://github.com/huggingface/transformers/blob/master/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py)
- [Wav2Vec2](https://github.com/huggingface/transformers/blob/master/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py)

You can find all available training scripts for JAX/Flax under the
official [flax example folder](https://github.com/huggingface/transformers/tree/master/examples/flax). Note that a couple of training scripts will be released in the following week.
@@ -379,6 +379,8 @@ official [flax example folder](https://github.com/huggingface/transformers/tree/
- [Text classification (BERT, RoBERTa, ELECTRA, BigBird)](https://github.com/huggingface/transformers/blob/master/examples/flax/text-classification/run_flax_glue.py)
- [Summarization / Seq2Seq (BART, MBART, T5)](https://github.com/huggingface/transformers/blob/master/examples/flax/summarization/run_summarization_flax.py)
- [Masked Seq2Seq pre-training (T5)](https://github.com/huggingface/transformers/blob/master/examples/flax/language-modeling/run_t5_mlm_flax.py)
- [Contrastive Loss pretraining for Wav2Vec2](https://github.com/huggingface/transformers/blob/master/examples/research_projects/jax-projects/wav2vec2)
- [Fine-tuning long-range QA for BigBird](https://github.com/huggingface/transformers/blob/master/examples/research_projects/jax-projects/big_bird)
- [(TODO) Image classification (ViT)]( )
- [(TODO) CLIP pretraining, fine-tuning (CLIP)]( )
@@ -683,10 +685,6 @@ That means you could use the same training script on CPUs, GPUs, TPUs.

To learn more about how to train the Flax models on different devices (GPU, multi-GPU, TPU) and use the example scripts, please look at the [examples README](https://github.com/huggingface/transformers/tree/master/examples/flax).

## How to make a demo

TODO (should be filled by 30.06.)...

## Talks

Super excited to kick off 3 days of talks around JAX / Flax, Transformers, large-scale language modeling and other great topics during our community event! Find the schedule, Zoom links and calendar events below!

@@ -800,7 +798,180 @@ Super excited to kick off 3 days of talks around JAX / Flax, Transformers, large

- Website: https://cohere.ai/

## How to use the hub for collaboration

In this section, we will explain how a team can use the 🤗 hub to collaborate on a project.
The 🤗 hub allows each team to create a repository with integrated git version control that
should be used for their project.
The advantages of using a repository on the 🤗 hub are:

- easy collaboration - each team member has write access to the model repository
- integrated git version control - scripts as well as large model files are tracked using git version control
- easy sharing - the hub allows each team to easily share their work during and after the event
- integrated tensorboard functionality - uploaded tensorboard traces are automatically displayed on an integrated tensorboard tab

We highly recommend that each team make use of the 🤗 hub during the event.
To better understand how repositories and the hub in general work, please take a look at the documentation and the videos [here](https://huggingface.co/docs/hub).
Now let's explain in more detail how a project can be created on the hub. Once your project is officially listed on [this](https://docs.google.com/spreadsheets/d/1GpHebL7qrwJOc9olTpIPgjf8vOS0jNb6zR_B8x_Jtik/edit?usp=sharing) Google Sheet, you should be part of [the Flax Community organization on the hub](https://huggingface.co/flax-community). All repositories should be created under this organization so that write access can be shared and everybody can easily access other participants'
work 🤗. Note that we are giving each team member access to all repositories created under [flax-community](https://huggingface.co/flax-community), but we encourage participants to only clone and edit repositories corresponding to their own team. If you want to help other teams, please ask them before changing files in their repository! The integrated git version control keeps track of
all changes, so if a file was deleted by mistake, it is trivial to re-create it.

Awesome! Now, let's go over a simple example where we'll pre-train a RoBERTa model on a low-resource language. To begin with, we create a repository
under [the Flax Community organization on the hub](https://huggingface.co/flax-community) by logging in to the hub and going to [*"Add model"*](https://huggingface.co/new). By default
the username should be displayed under "*Owner*", which we want to change to *flax-community*. Next, we give our repository a fitting name for the project - here we'll just call it
*roberta-base-als* because we'll be pretraining a RoBERTa model on the super low-resource language *Alemannic* (`als`). We make sure that the model is a public repository and create it!
It should then be displayed on [the Flax Community organization on the hub](https://huggingface.co/flax-community).
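If you prefer the command line over the web UI, the repository can also be created with the `huggingface-cli` - a quick sketch, assuming a recent version of `huggingface_hub` is installed and you are already logged in:

```bash
# create the repository under the flax-community organization instead of using the "Add model" page
$ huggingface-cli repo create roberta-base-als --organization flax-community
```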
Great, now we have a project directory with integrated git version control and a public model page, which we can access under [flax-community/roberta-base-als](https://huggingface.co/flax-community/roberta-base-als). Let's create a short README so that other participants know what this model is about. You can create the README.md directly on the model page as a markdown file.
Let's now make use of the repository for training.
We assume that the 🤗 Transformers library and [git-lfs](https://git-lfs.github.com/) are correctly installed on our machine or the TPU assigned to us.
If this is not the case, please refer to the [Installation guide](#how-to-install-relevant-libraries) and the official [git-lfs](https://git-lfs.github.com/) website.
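A quick way to verify that `git-lfs` is set up correctly - a minimal sketch, since the actual installation steps depend on your system - is:

```bash
# print the installed git-lfs version and set up the git hooks for the current user
$ git lfs version
$ git lfs install
```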
First, we should log in:

```bash
$ huggingface-cli login
```
Next we can clone the repo:

```bash
$ git clone https://huggingface.co/flax-community/roberta-base-als
```
We have now cloned the model's repository and it should be under `roberta-base-als`. As you can see,
we have all the usual git functionalities in this repo - when adding a file, we can do `git add .`, `git commit -m "add file"` and `git push`
as usual. Let's try it out by adding the model's config.

We go into the folder:

```bash
$ cd ./roberta-base-als
```

and run the following commands in a Python shell to save a config.
```python
from transformers import RobertaConfig

# start from the config of the official roberta-base checkpoint and save it locally as config.json
config = RobertaConfig.from_pretrained("roberta-base")
config.save_pretrained("./")
```
Now we've added a `config.json` file and can upload it by running

```bash
$ git add . && git commit -m "add config" && git push
```
Cool! The file is now displayed on the model page under the [files tab](https://huggingface.co/flax-community/roberta-base-als/tree/main).
We encourage you to upload all files except maybe the actual data files to the repository. This includes training scripts, model weights,
model configurations, training logs, etc.
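One simple way to make sure large raw data files stay out of the repository - assuming, just as an example, that they live in a local `data/` folder - is to ignore them via `.gitignore`:

```bash
# keep local raw data out of version control (assumes the data lives in ./data)
$ echo "data/" >> .gitignore
$ git add .gitignore && git commit -m "ignore local data files" && git push
```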
Next, let's create a tokenizer and save it to the model dir by following the instructions of the [official Flax MLM README](https://github.com/huggingface/transformers/tree/master/examples/flax/language-modeling#train-tokenizer). We can again use a simple Python shell.
```python
from datasets import load_dataset
from tokenizers import ByteLevelBPETokenizer

# load dataset
dataset = load_dataset("oscar", "unshuffled_deduplicated_als", split="train")

# Instantiate tokenizer
tokenizer = ByteLevelBPETokenizer()

def batch_iterator(batch_size=1000):
    for i in range(0, len(dataset), batch_size):
        yield dataset[i: i + batch_size]["text"]

# Customized training
tokenizer.train_from_iterator(batch_iterator(), vocab_size=50265, min_frequency=2, special_tokens=[
    "<s>",
    "<pad>",
    "</s>",
    "<unk>",
    "<mask>",
])

# Save files to disk
tokenizer.save("./tokenizer.json")
```
This creates and saves our tokenizer directly in the cloned repository.
Finally, we can start training. For now, we'll simply use the official [`run_mlm_flax`](https://github.com/huggingface/transformers/blob/master/examples/flax/language-modeling/run_mlm_flax.py)
script, but we might make some changes later. So let's copy the script into our model repository.
```bash
$ cp ~/transformers/examples/flax/language-modeling/run_mlm_flax.py ./
```
This way we are certain to have all the code used to train the model tracked in our repository.
Let's start training by running:
```bash
./run_mlm_flax.py \
    --output_dir="./" \
    --model_type="roberta" \
    --config_name="./" \
    --tokenizer_name="./" \
    --dataset_name="oscar" \
    --dataset_config_name="unshuffled_deduplicated_als" \
    --max_seq_length="128" \
    --per_device_train_batch_size="4" \
    --per_device_eval_batch_size="4" \
    --learning_rate="3e-4" \
    --warmup_steps="1000" \
    --overwrite_output_dir \
    --num_train_epochs="8" \
    --push_to_hub
```
Since the dataset is tiny, this command should actually run in less than 5 minutes. Note that we attach
the flag `--push_to_hub` so that both model weights and tensorboard traces are automatically uploaded to the hub.
You can see the tensorboard directly on the model page, under the [Training metrics tab](https://huggingface.co/flax-community/roberta-base-als/tensorboard).
As you can see, it is pretty simple to upload model weights and training logs to the model hub. Since the repository
has git version control, you & your team probably already have the necessary skills to collaborate. Thanks
to `git-lfs` being integrated into the hub, model weights and other larger files can just as easily be uploaded
and changed. Finally, at Hugging Face, we believe that the model hub is a great platform to share your project
while you are still working on it:

- Bugs in training scripts can be found and corrected by anybody participating in the event
- Loss curves can be analyzed directly on the model page
- Model weights can be accessed and analyzed by everybody from the model repository
If you are not using a transformers model, don't worry - you should still be able to make use of the hub's functionalities!
The [huggingface_hub](https://github.com/huggingface/huggingface_hub) library allows you to upload essentially any JAX/Flax model to the hub with
just a couple of lines of code. *E.g.* assuming you want to call your model simply `flax-model-dummy`, you can upload it to the hub with
just a few lines of code:
```python
from flax import serialization
from jax import random
from flax import linen as nn

from huggingface_hub import Repository

# define and initialize a tiny Flax model
model = nn.Dense(features=5)

key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (10,))
params = model.init(key2, x)

# serialize the parameters to msgpack bytes
bytes_output = serialization.to_bytes(params)

repo = Repository("flax-model", clone_from="flax-community/flax-model-dummy", use_auth_token=True)
with repo.commit("My cool Flax model :)"):
    with open("flax_model.msgpack", "wb") as f:
        f.write(bytes_output)

# Repo is created and available here: https://huggingface.co/flax-community/flax-model-dummy
```
**Note**: Make sure to have `huggingface_hub >= 0.0.13` to make this command work.
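If you are unsure which version is installed, one way to check and upgrade it - a sketch, assuming a standard `pip`-based environment - is:

```bash
# show the currently installed version and upgrade it if it is too old
$ pip show huggingface_hub
$ pip install --upgrade "huggingface_hub>=0.0.13"
```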
For more information, check out [this PR](https://github.com/huggingface/huggingface_hub/pull/143) on how to upload any framework to the hub.
## How to setup TPU VM