transformers/tests/test_modeling_flax_roberta.py
Funtowicz Morgan 75627148ee
Flax Masked Language Modeling training example (#8728)
* Remove "Model" suffix from Flax models to look more 🤗

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Initial working (forward + backward) for Flax MLM training example.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Simplify code

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Addressing comments, using module and moving to LM task.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Restore parameter name "module" wrongly renamed model.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Restore correct output ordering...

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Actually commit the example 😅

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Add FlaxBertModelForMaskedLM after rebasing.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Make it possible to initialize the training from scratch

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Reuse flax linen example of cross entropy loss

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Added specific data collator for flax

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Remove todo for data collator

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Added evaluation step

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Added ability to provide dtype to support bfloat16 on TPU

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Enable flax tensorboard output

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Enable jax.pmap support.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Ensure batches are correctly sized to be dispatched with jax.pmap

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Enable bfloat16 with --fp16 cmdline args

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Correctly export metrics to tensorboard

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Added dropout and ability to use it.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Effectively enable & disable during training and evaluation steps.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Oops.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Enable specifying kernel initializer scale

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Style.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Added warmup step to the learning rate scheduler.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Fix typo.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Print training loss

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Make style

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* fix linter issue (flake8)

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Fix model matching

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Fix dummies

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Fix non default dtype on Flax models

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Use the same create_position_ids_from_input_ids for FlaxRoberta

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Make Roberta attention as Bert

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* fix copy

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Wording.

Co-authored-by: Marc van Zee <marcvanzee@gmail.com>

Co-authored-by: Marc van Zee <marcvanzee@gmail.com>
2020-12-09 17:13:56 +01:00

84 lines
3.4 KiB
Python

# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from numpy import ndarray
from transformers import RobertaTokenizerFast, TensorType, is_flax_available, is_torch_available
from transformers.testing_utils import require_flax, require_torch
if is_flax_available():
import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
import jax
from transformers.models.roberta.modeling_flax_roberta import FlaxRobertaModel
if is_torch_available():
import torch
from transformers.models.roberta.modeling_roberta import RobertaModel
@require_flax
@require_torch
class FlaxRobertaModelTest(unittest.TestCase):
    """Compare FlaxRobertaModel with the PyTorch RobertaModel and exercise it under jax.jit."""

    def assert_almost_equals(self, a: ndarray, b: ndarray, tol: float):
        """Assert that ``a`` and ``b`` agree element-wise within ``tol``.

        Args:
            a: First array to compare.
            b: Second array to compare; ``a - b`` must be a valid array expression.
            tol: Largest absolute element-wise difference that is still accepted.

        Raises:
            AssertionError: If any element of ``a`` and ``b`` differs by more than ``tol``.
        """
        # Take the absolute value before reducing so that errors of opposite sign
        # cannot cancel each other and a large negative deviation cannot slip
        # under the tolerance. The builtin abs() is used so that no extra import
        # is needed for array inputs.
        diff = abs(a - b).max()
        self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (> {tol})")

    def test_from_pytorch(self):
        """Load "roberta-base" in Flax and PyTorch and check that their outputs match."""
        with torch.no_grad():
            with self.subTest("roberta-base"):
                tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
                fx_model = FlaxRobertaModel.from_pretrained("roberta-base")
                pt_model = RobertaModel.from_pretrained("roberta-base")

                # Check for simple input: the same sentence is encoded once per framework.
                pt_inputs = tokenizer.encode_plus("This is a simple input", return_tensors=TensorType.PYTORCH)
                fx_inputs = tokenizer.encode_plus("This is a simple input", return_tensors=TensorType.JAX)
                pt_outputs = pt_model(**pt_inputs)
                fx_outputs = fx_model(**fx_inputs)

                self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")

                # Compare the outputs pairwise, in the order both models return them.
                for fx_output, pt_output in zip(fx_outputs, pt_outputs.to_tuple()):
                    self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-3)

    def test_multiple_sequences(self):
        """Run a padded batch of three sentences with JIT disabled and enabled and check output shapes."""
        tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
        model = FlaxRobertaModel.from_pretrained("roberta-base")

        sequences = ["this is an example sentence", "this is another", "and a third one"]
        encodings = tokenizer(sequences, return_tensors=TensorType.JAX, padding=True, truncation=True)

        @jax.jit
        def model_jitted(input_ids, attention_mask=None, token_type_ids=None):
            # Thin wrapper so the model call can be traced by jax.jit.
            return model(input_ids, attention_mask, token_type_ids)

        # Expected shapes: 3 is the number of input sentences; 7 and 768 are
        # presumably the padded token length and the roberta-base hidden size
        # -- TODO confirm against the tokenizer/model configuration.
        with self.subTest("JIT Disabled"):
            with jax.disable_jit():
                tokens, pooled = model_jitted(**encodings)
                self.assertEqual(tokens.shape, (3, 7, 768))
                self.assertEqual(pooled.shape, (3, 768))

        with self.subTest("JIT Enabled"):
            jitted_tokens, jitted_pooled = model_jitted(**encodings)
            self.assertEqual(jitted_tokens.shape, (3, 7, 768))
            self.assertEqual(jitted_pooled.shape, (3, 768))