Mirror of https://github.com/huggingface/transformers.git
[Flax T5] Speed up t5 training (#13012)
* fix_torch_device_generate_test
* remove @
* update
* up
* fix
* remove f-strings
* correct readme
* up

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
parent 60e448c87e
commit 2e4082364e
````diff
@@ -373,15 +373,15 @@ Next we can run the example script to pretrain the model:
     --weight_decay="0.001" \
     --warmup_steps="2000" \
     --overwrite_output_dir \
-    --logging_steps="100" \
-    --save_steps="1000" \
-    --eval_steps="1000" \
+    --logging_steps="500" \
+    --save_steps="10000" \
+    --eval_steps="2500" \
     --push_to_hub
 ```

 Training should converge at a loss and accuracy
-of 2.2 and 58.0 respectively after 2 epochs on a single TPUv3-8.
-This should take around 24 hours.
+of 2.36 and 57.0 respectively after 3 epochs on a single TPUv3-8.
+This should take around 4.5 hours.
+Training statistics can be accessed directly on the 🤗 [hub](https://huggingface.co/patrickvonplaten/t5-base-norwegian/tensorboard).

 ## Runtime evaluation
````
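Less frequent logging, checkpointing, and evaluation means the host interrupts the accelerator less often, which is one plausible contributor to the shorter wall-clock time above. A rough, hypothetical illustration of the flag change; `total_steps` is an assumed placeholder, not a value taken from this commit:

```python
# Hypothetical illustration; `total_steps` is an assumed placeholder.
total_steps = 100_000

old = {"logging_steps": 100, "save_steps": 1_000, "eval_steps": 1_000}
new = {"logging_steps": 500, "save_steps": 10_000, "eval_steps": 2_500}

# Count how many host-side events each flag triggers over a full run.
for flag in old:
    print(f"{flag}: {total_steps // old[flag]} -> {total_steps // new[flag]} events")

# logging_steps: 1000 -> 200 events
# save_steps: 100 -> 10 events
# eval_steps: 100 -> 40 events
```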
````diff
@@ -353,7 +353,8 @@ class FlaxDataCollatorForT5MLM:
             np.random.shuffle(mask_indices)
             first_in_segment = np.pad(mask_indices, [[1, 0]])
             segment_id = np.cumsum(first_in_segment)
-            segment_length = np.asarray(jax.ops.segment_sum(np.ones_like(segment_id), segment_id))
+            # count length of sub segments assuming that list is sorted
+            _, segment_length = np.unique(segment_id, return_counts=True)
             return segment_length

         noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
````
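The data collator runs on the host for every batch, so replacing the `jax.ops.segment_sum` call (which dispatches a device computation from inside NumPy code) with a pure-NumPy count is a plausible source of the speed-up. A minimal sketch of why the two computations agree, using an assumed toy `segment_id` (in the collator it comes from `np.cumsum` over the first-in-segment flags, so it is sorted and contiguous):

```python
import numpy as np

# Assumed toy input: a sorted, contiguous segment-id array, as produced by
# np.cumsum(first_in_segment) in the collator above.
segment_id = np.array([0, 0, 1, 1, 1, 2])

# New approach from this commit: count occurrences of each id.
_, segment_length = np.unique(segment_id, return_counts=True)
print(segment_length)  # [2 3 1]

# What the old segment_sum computed: sum a vector of ones into one bucket
# per segment id (shown here with np.bincount to stay in NumPy).
print(np.bincount(segment_id, weights=np.ones_like(segment_id, dtype=float)))
# [2. 3. 1.]
```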
````diff
@@ -720,7 +721,7 @@ if __name__ == "__main__":
     state = jax_utils.replicate(state)

     train_time = 0
-    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
+    epochs = tqdm(range(num_epochs), desc="Epoch ... ", position=0)
     for epoch in epochs:
         # ======================== Training ================================
         train_start = time.time()
````
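One note on the progress-bar change: the old f-string was evaluated once at construction, so the "(1/…)" in the description never advanced, while tqdm's built-in counter already shows the current epoch. A minimal sketch, with an assumed `num_epochs`:

```python
from tqdm import tqdm

num_epochs = 3  # assumed placeholder

# With a static desc, tqdm's own counter ("1/3", "2/3", ...) tracks progress;
# an f-string desc would be formatted once and then stay frozen.
epochs = tqdm(range(num_epochs), desc="Epoch ... ", position=0)
for epoch in epochs:
    pass  # one training epoch would run here
```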