mirror of https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00

[Flax] Refactor gpt2 & bert example docs (#13024)

* fix_torch_device_generate_test
* remove @
* improve docs for clm
* speed-ups
* correct t5 example as well
* push final touches
* Update examples/flax/language-modeling/README.md
* correct docs for mlm
* Update examples/flax/language-modeling/README.md

Co-authored-by: Patrick von Platen <patrick@huggingface.co>

This commit is contained in:
parent 3ff2cde5ca
commit 13a9c9a354
@@ -49,21 +49,15 @@ Next we clone the model repository to add the tokenizer and model files.
git clone https://huggingface.co/<your-username>/norwegian-roberta-base
```

To ensure that all tensorboard traces will be uploaded correctly, we need to
track them. You can run the following command inside your model repo to do so.
To setup all relevant files for training, let's go into the cloned model directory.

```
```bash
cd norwegian-roberta-base
git lfs track "*tfevents*"
```

Great, we have set up our model repository. During training, we will automatically
push the training logs and model weights to the repo.

Next, let's add a symbolic link to the `run_mlm_flax.py`.

```bash
export MODEL_DIR="./norwegian-roberta-base"
ln -s ~/transformers/examples/flax/language-modeling/run_mlm_flax.py run_mlm_flax.py
```

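The net effect of this hunk is easier to read in one piece: the walkthrough now runs everything from inside the cloned model repo and drops both the `git lfs track` step and the `MODEL_DIR` variable. A rough sketch of the resulting flow (a paraphrase, not an exact quote of the updated README):

```bash
git clone https://huggingface.co/<your-username>/norwegian-roberta-base
cd norwegian-roberta-base
# relative paths ("./") now play the role the old ${MODEL_DIR} variable played
ln -s ~/transformers/examples/flax/language-modeling/run_mlm_flax.py run_mlm_flax.py
```
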
@@ -71,15 +65,13 @@ ln -s ~/transformers/examples/flax/language-modeling/run_mlm_flax.py run_mlm_fla

In the first step, we train a tokenizer to efficiently process the text input for the model. Similar to how it is shown in [How to train a new language model from scratch using Transformers and Tokenizers](https://huggingface.co/blog/how-to-train), we use a **`ByteLevelBPETokenizer`**.
The tokenizer is trained on the complete Norwegian dataset of OSCAR
and consequently saved in `${MODEL_DIR}`
and consequently saved in the cloned model directory.
This can take up to 10 minutes depending on your hardware ☕.

```python
from datasets import load_dataset
from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer

model_dir = "./norwegian-roberta-base" # ${MODEL_DIR}

# load dataset
dataset = load_dataset("oscar", "unshuffled_deduplicated_no", split="train")

@@ -100,7 +92,7 @@ tokenizer.train_from_iterator(batch_iterator(), vocab_size=50265, min_frequency=
])

# Save files to disk
tokenizer.save(f"{model_dir}/tokenizer.json")
tokenizer.save("./tokenizer.json")
```
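If you want to sanity-check the trained tokenizer before moving on, a quick round-trip with the `tokenizers` library works (this snippet is illustrative and not part of the README):

```python
from tokenizers import Tokenizer

# Load the file written by tokenizer.save(...) above
tokenizer = Tokenizer.from_file("./tokenizer.json")

print(tokenizer.get_vocab_size())            # expected: 50265, as requested above
encoding = tokenizer.encode("Hei, hvordan har du det?")
print(encoding.tokens)
print(tokenizer.decode(encoding.ids))
```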

### Create configuration
@@ -112,22 +104,23 @@ in the local model folder:

```python
from transformers import RobertaConfig

model_dir = "./norwegian-roberta-base" # ${MODEL_DIR}

config = RobertaConfig.from_pretrained("roberta-base", vocab_size=tokenizer.get_vocab_size())
config.save_pretrained(model_dir)
config = RobertaConfig.from_pretrained("roberta-base", vocab_size=50265)
config.save_pretrained("./")
```
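At this point the current directory should contain both `config.json` and `tokenizer.json`, which is exactly what the training script below is pointed at via `--config_name="./"` and `--tokenizer_name="./"`. A quick check that they load (a sketch, assuming the previous steps were run from inside the model repo):

```python
from transformers import AutoConfig, AutoTokenizer

config = AutoConfig.from_pretrained("./")        # reads ./config.json
tokenizer = AutoTokenizer.from_pretrained("./")  # picks up ./tokenizer.json via the config's model type
print(config.vocab_size, tokenizer.vocab_size)   # both should report 50265 here
```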

Great, we have set up our model repository. During training, we will automatically
push the training logs and model weights to the repo.

### Train model

Next we can run the example script to pretrain the model:

```bash
./run_mlm_flax.py \
    --output_dir="${MODEL_DIR}" \
    --output_dir="./" \
    --model_type="roberta" \
    --config_name="${MODEL_DIR}" \
    --tokenizer_name="${MODEL_DIR}" \
    --config_name="./" \
    --tokenizer_name="./" \
    --dataset_name="oscar" \
    --dataset_config_name="unshuffled_deduplicated_no" \
    --max_seq_length="128" \

@@ -180,25 +173,51 @@ Next we clone the model repository to add the tokenizer and model files.
git clone https://huggingface.co/<your-username>/norwegian-gpt2
```

To ensure that all tensorboard traces will be uploaded correctly, we need to
track them. You can run the following command inside your model repo to do so.

```
cd norwegian-gpt2
git lfs track "*tfevents*"
```

Great, we have set up our model repository. During training, we will automatically
push the training logs and model weights to the repo.

Next, let's add a symbolic link to the `run_clm_flax.py`.
To setup all relevant files for training, let's go into the cloned model directory.

```bash
cd norwegian-gpt2
```

Next, let's add a symbolic link to the training script `run_clm_flax.py`.

```bash
export MODEL_DIR="./norwegian-gpt2"
ln -s ~/transformers/examples/flax/language-modeling/run_clm_flax.py run_clm_flax.py
```

Next, we'll follow the same steps as above in [Train tokenizer](#train-tokenizer) to train the tokenizer.
### Train tokenizer

In the first step, we train a tokenizer to efficiently process the text input for the model. Similar to how it is shown in [How to train a new language model from scratch using Transformers and Tokenizers](https://huggingface.co/blog/how-to-train), we use a **`ByteLevelBPETokenizer`**.
The tokenizer is trained on the complete Norwegian dataset of OSCAR
and consequently saved in the cloned model directory.
This can take up to 10 minutes depending on your hardware ☕.

```python
from datasets import load_dataset
from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer

# load dataset
dataset = load_dataset("oscar", "unshuffled_deduplicated_no", split="train")

# Instantiate tokenizer
tokenizer = ByteLevelBPETokenizer()

def batch_iterator(batch_size=1000):
    for i in range(0, len(dataset), batch_size):
        yield dataset[i: i + batch_size]["text"]

# Customized training
tokenizer.train_from_iterator(batch_iterator(), vocab_size=50257, min_frequency=2, special_tokens=[
    "<s>",
    "<pad>",
    "</s>",
    "<unk>",
    "<mask>",
])

# Save files to disk
tokenizer.save("./tokenizer.json")
```

### Create configuration

@@ -209,22 +228,23 @@ in the local model folder:

```python
from transformers import GPT2Config

model_dir = "./norwegian-gpt2" # ${MODEL_DIR}

config = GPT2Config.from_pretrained("gpt2", resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0, vocab_size=tokenizer.get_vocab_size())
config.save_pretrained(model_dir)
config = GPT2Config.from_pretrained("gpt2", resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0, vocab_size=50257)
config.save_pretrained("./")
```

Great, we have set up our model repository. During training, we will now automatically
push the training logs and model weights to the repo.

### Train model

Next we can run the example script to pretrain the model:
Finally, we can run the example script to pretrain the model:

```bash
./run_clm_flax.py \
    --output_dir="${MODEL_DIR}" \
    --output_dir="./" \
    --model_type="gpt2" \
    --config_name="${MODEL_DIR}" \
    --tokenizer_name="${MODEL_DIR}" \
    --config_name="./" \
    --tokenizer_name="./" \
    --dataset_name="oscar" \
    --dataset_config_name="unshuffled_deduplicated_no" \
    --do_train --do_eval \

@@ -246,6 +266,9 @@ of 3.24 and 25.72 respectively after 20 epochs on a single TPUv3-8.
This should take less than ~21 hours.
Training statistics can be accessed on [tensorboard.dev](https://tensorboard.dev/experiment/2zEhLwJ0Qp2FAkI3WVH9qA).

For a step-by-step walkthrough of how to do causal language modeling in Flax, please have a
look at [this](https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/causal_language_modeling_flax.ipynb) Google Colab.

## T5-like span-masked language modeling

In the following, we demonstrate how to train a T5 model using the span-masked language model

@@ -272,21 +295,15 @@ Next we clone the model repository to add the tokenizer and model files.
git clone https://huggingface.co/<your-username>/norwegian-t5-base
```

To ensure that all tensorboard traces will be uploaded correctly, we need to
track them. You can run the following command inside your model repo to do so.
To setup all relevant files for training, let's go into the cloned model directory.

```
```bash
cd norwegian-t5-base
git lfs track "*tfevents*"
```

Great, we have set up our model repository. During training, we will automatically
push the training logs and model weights to the repo.

Next, let's add a symbolic link to the `run_t5_mlm_flax.py` and `t5_tokenizer_model` scripts.

```bash
export MODEL_DIR="./norwegian-t5-base"
ln -s ~/transformers/examples/flax/language-modeling/run_t5_mlm_flax.py run_t5_mlm_flax.py
ln -s ~/transformers/examples/flax/language-modeling/t5_tokenizer_model.py t5_tokenizer_model.py
```

@@ -299,7 +316,7 @@ a sentencepiece unigram tokenizer as shown in [t5_tokenizer_model.py](https://gi
which is heavily inspired from [yandex-research/DeDLOC's tokenizer model](https://github.com/yandex-research/DeDLOC/blob/5c994bc64e573702a9a79add3ecd68b38f14b548/sahajbert/tokenizer/tokenizer_model.py) .

The tokenizer is trained on the complete Norwegian dataset of OSCAR
and consequently saved in `${MODEL_DIR}`
and consequently saved in the cloned model directory.
This can take up to 120 minutes depending on your hardware ☕☕☕ .

```python

@@ -310,7 +327,6 @@ from t5_tokenizer_model import SentencePieceUnigramTokenizer

vocab_size = 32_000
input_sentence_size = None
model_dir = "./norwegian-t5-base" # ${MODEL_DIR}

# Initialize a dataset
dataset = datasets.load_dataset("oscar", name="unshuffled_deduplicated_no", split="train")

@@ -335,7 +351,7 @@ tokenizer.train_from_iterator(
)

# Save files to disk
tokenizer.save(f"{model_dir}/tokenizer.json")
tokenizer.save("./tokenizer.json")
```

### Create configuration

@@ -347,12 +363,13 @@ in the local model folder:

```python
from transformers import T5Config

model_dir = "./norwegian-t5-base" # ${MODEL_DIR}

config = T5Config.from_pretrained("google/t5-v1_1-base", vocab_size=tokenizer.get_vocab_size())
config.save_pretrained(model_dir)
config.save_pretrained("./")
```

Great, we have set up our model repository. During training, we will automatically
push the training logs and model weights to the repo.

### Train model

Next we can run the example script to pretrain the model:

@@ -31,6 +31,7 @@ from pathlib import Path
from typing import Callable, Optional

import datasets
import numpy as np
from datasets import Dataset, load_dataset
from tqdm import tqdm

@@ -51,6 +52,7 @@ from transformers import (
    HfArgumentParser,
    TrainingArguments,
    is_tensorboard_available,
    set_seed,
)
from transformers.testing_utils import CaptureLogger

@@ -182,18 +184,16 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
    steps_per_epoch = len(dataset) // batch_size

    if shuffle:
        batch_idx = jax.random.permutation(rng, len(dataset))
        batch_idx = np.random.permutation(len(dataset))
    else:
        batch_idx = jnp.arange(len(dataset))
        batch_idx = np.arange(len(dataset))

    batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
    batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))

    for idx in batch_idx:
        batch = dataset[idx]
        batch = {k: jnp.array(v) for k, v in batch.items()}

        batch = shard(batch)
        batch = {k: np.array(v) for k, v in batch.items()}

        yield batch

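The refactor above keeps the loader purely on the host: index shuffling and batch conversion now use NumPy, and device sharding moves out of the generator into the train/eval loops (see the hunks further down). A minimal sketch of the resulting pattern, assuming a `datasets.Dataset` of fixed-length tokenized examples; the script keeps its original signature, this sketch trims it to the essentials:

```python
import numpy as np
from datasets import Dataset


def data_loader(dataset: Dataset, batch_size: int, shuffle: bool = False):
    """Yield batches as dicts of host-side NumPy arrays; no device transfer happens here."""
    steps_per_epoch = len(dataset) // batch_size

    # Host-side index shuffling instead of jax.random / jnp
    batch_idx = np.random.permutation(len(dataset)) if shuffle else np.arange(len(dataset))
    batch_idx = batch_idx[: steps_per_epoch * batch_size]  # skip the incomplete last batch
    batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))

    for idx in batch_idx:
        batch = dataset[idx]
        yield {k: np.array(v) for k, v in batch.items()}


# Sharding now happens in the caller, right before the pmapped step, e.g.:
#
#   from flax.training.common_utils import shard
#
#   batch = next(train_loader)
#   batch = shard(batch)   # [B, ...] -> [n_devices, B // n_devices, ...]
#   state, train_metric = p_train_step(state, batch)
```
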
@@ -269,6 +269,9 @@ def main():
    # Set the verbosity to info of the Transformers logger (on main process only):
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).

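`set_seed(training_args.seed)` seeds Python's and NumPy's global RNGs (among others), which now matters because the data loader above shuffles with `np.random.permutation`; JAX randomness is still driven by explicit `PRNGKey`s. A toy illustration of the two streams (not code from the script):

```python
import jax
import numpy as np

seed = 42
np.random.seed(seed)              # host-side shuffling (np.random.permutation in data_loader)
perm = np.random.permutation(8)

rng = jax.random.PRNGKey(seed)    # device-side randomness (dropout, initialization)
rng, dropout_rng = jax.random.split(rng)
```
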
@@ -577,7 +580,7 @@ def main():

    train_time = 0
    train_metrics = []
    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
    epochs = tqdm(range(num_epochs), desc="Epoch ... ", position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()

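The progress-bar description becomes a static label here. If a per-epoch counter is still wanted in the output, a generic tqdm pattern (illustrative, not code from the script) is:

```python
from tqdm import tqdm

num_epochs = 3
epochs = tqdm(range(num_epochs), desc="Epoch ... ", position=0)
for epoch in epochs:
    # ... training work for this epoch ...
    epochs.write(f"Epoch {epoch + 1}/{num_epochs} done")  # prints above the bar
    # or: epochs.set_description(f"Epoch ... ({epoch + 1}/{num_epochs})")
```
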
@@ -591,6 +594,7 @@ def main():
        # train
        for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
            batch = next(train_loader)
            batch = shard(batch)
            state, train_metric = p_train_step(state, batch)
            train_metrics.append(train_metric)

@@ -617,6 +621,7 @@ def main():
        for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
            # Model forward
            batch = next(eval_loader)
            batch = shard(batch)
            metrics = p_eval_step(state.params, batch)
            eval_metrics.append(metrics)

@@ -214,7 +214,7 @@ class FlaxDataCollatorForLanguageModeling:

    def mask_tokens(
        self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray]
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
        """
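The return annotation switches from `jnp` to `np` arrays, consistent with keeping data collation on the host. For readers unfamiliar with the 80% MASK / 10% random / 10% unchanged scheme the docstring describes, here is a self-contained NumPy sketch (names and signature are illustrative, not the collator's exact code):

```python
import numpy as np


def mask_tokens_sketch(inputs: np.ndarray, special_tokens_mask: np.ndarray,
                       mask_token_id: int, vocab_size: int, mlm_probability: float = 0.15):
    """Illustrative 80/10/10 MLM masking on host-side NumPy arrays."""
    labels = inputs.copy()

    # Sample which positions get masked, never touching special tokens.
    probability_matrix = np.full(labels.shape, mlm_probability)
    probability_matrix[special_tokens_mask.astype(bool)] = 0.0
    masked_indices = np.random.binomial(1, probability_matrix).astype(bool)
    labels[~masked_indices] = -100  # loss is only computed on masked positions

    # 80% of masked positions -> [MASK]
    indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype(bool) & masked_indices
    inputs[indices_replaced] = mask_token_id

    # 10% of masked positions -> random token (half of the remaining 20%)
    indices_random = (np.random.binomial(1, np.full(labels.shape, 0.5)).astype(bool)
                      & masked_indices & ~indices_replaced)
    inputs[indices_random] = np.random.randint(vocab_size, size=labels.shape)[indices_random]

    # The remaining 10% of masked positions keep their original token.
    return inputs, labels
```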