[Flax T5] Speed up t5 training (#13012)

* fix_torch_device_generate_test

* remove @

* update

* up

* fix

* remove f-stings

* correct readme

* up

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
Patrick von Platen 2021-08-06 11:21:37 +02:00 committed by GitHub
parent 60e448c87e
commit 2e4082364e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 7 deletions

View File

@ -373,15 +373,15 @@ Next we can run the example script to pretrain the model:
--weight_decay="0.001" \
--warmup_steps="2000" \
--overwrite_output_dir \
--logging_steps="100" \
--save_steps="1000" \
--eval_steps="1000" \
--logging_steps="500" \
--save_steps="10000" \
--eval_steps="2500" \
--push_to_hub
```
Training should converge at a loss and accuracy
of 2.2 and 58.0 respectively after 2 epochs on a single TPUv3-8.
This should take around 24 hours.
of 2.36 and 57.0 respectively after 3 epochs on a single TPUv3-8.
This should take around 4.5 hours.
Training statistics can be accessed on directly on the 🤗 [hub](https://huggingface.co/patrickvonplaten/t5-base-norwegian/tensorboard)
## Runtime evaluation

View File

@ -353,7 +353,8 @@ class FlaxDataCollatorForT5MLM:
np.random.shuffle(mask_indices)
first_in_segment = np.pad(mask_indices, [[1, 0]])
segment_id = np.cumsum(first_in_segment)
segment_length = np.asarray(jax.ops.segment_sum(np.ones_like(segment_id), segment_id))
# count length of sub segments assuming that list is sorted
_, segment_length = np.unique(segment_id, return_counts=True)
return segment_length
noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
@ -720,7 +721,7 @@ if __name__ == "__main__":
state = jax_utils.replicate(state)
train_time = 0
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
epochs = tqdm(range(num_epochs), desc="Epoch ... ", position=0)
for epoch in epochs:
# ======================== Training ================================
train_start = time.time()