From 8d4bb020565e404d9eb814150280147e4963a2ee Mon Sep 17 00:00:00 2001
From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Date: Thu, 10 Dec 2020 15:57:39 -0500
Subject: [PATCH] Refactor FLAX tests (#9034)

---
 src/transformers/tokenization_utils_base.py |   1 +
 tests/test_modeling_flax_bert.py            | 138 ++++++++++++--------
 tests/test_modeling_flax_common.py          | 127 ++++++++++++++++++
 tests/test_modeling_flax_roberta.py         | 138 ++++++++++++--------
 4 files changed, 294 insertions(+), 110 deletions(-)
 create mode 100644 tests/test_modeling_flax_common.py

diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index 77b0362c546..14d0d0bd4e1 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -50,6 +50,7 @@ if is_tf_available():
 if is_torch_available():
     import torch
 
+
 if is_flax_available():
     import jax.numpy as jnp
 
diff --git a/tests/test_modeling_flax_bert.py b/tests/test_modeling_flax_bert.py
index cdb2367c1e2..2276ef3b55c 100644
--- a/tests/test_modeling_flax_bert.py
+++ b/tests/test_modeling_flax_bert.py
@@ -14,70 +14,98 @@
 import unittest
 
-from numpy import ndarray
+from transformers import BertConfig, is_flax_available
+from transformers.testing_utils import require_flax
 
-from transformers import BertTokenizerFast, TensorType, is_flax_available, is_torch_available
-from transformers.testing_utils import require_flax, require_torch
+from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
 
 
 if is_flax_available():
-    import os
-
-    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12"  # assumed parallelism: 8
-
-    import jax
     from transformers.models.bert.modeling_flax_bert import FlaxBertModel
 
-if is_torch_available():
-    import torch
-    from transformers.models.bert.modeling_bert import BertModel
 
+class FlaxBertModelTester(unittest.TestCase):
+    def __init__(
+        self,
+        parent,
+        batch_size=13,
+        seq_length=7,
+        is_training=True,
+        use_attention_mask=True,
+        use_token_type_ids=True,
+        use_labels=True,
+        vocab_size=99,
+        hidden_size=32,
+        num_hidden_layers=5,
+        num_attention_heads=4,
+        intermediate_size=37,
+        hidden_act="gelu",
+        hidden_dropout_prob=0.1,
+        attention_probs_dropout_prob=0.1,
+        max_position_embeddings=512,
+        type_vocab_size=16,
+        type_sequence_label_size=2,
+        initializer_range=0.02,
+    ):
+        self.parent = parent
+        self.batch_size = batch_size
+        self.seq_length = seq_length
+        self.is_training = is_training
+        self.use_attention_mask = use_attention_mask
+        self.use_token_type_ids = use_token_type_ids
+        self.use_labels = use_labels
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.intermediate_size = intermediate_size
+        self.hidden_act = hidden_act
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.max_position_embeddings = max_position_embeddings
+        self.type_vocab_size = type_vocab_size
+        self.type_sequence_label_size = type_sequence_label_size
+        self.initializer_range = initializer_range
+
+    def prepare_config_and_inputs(self):
+        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+
+        attention_mask = None
+        if self.use_attention_mask:
+            attention_mask = random_attention_mask([self.batch_size, self.seq_length])
+
+        token_type_ids = None
+        if self.use_token_type_ids:
+            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
+
+        config = BertConfig(
+            vocab_size=self.vocab_size,
+            hidden_size=self.hidden_size,
+            num_hidden_layers=self.num_hidden_layers,
+            num_attention_heads=self.num_attention_heads,
+            intermediate_size=self.intermediate_size,
+            hidden_act=self.hidden_act,
+            hidden_dropout_prob=self.hidden_dropout_prob,
+            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+            max_position_embeddings=self.max_position_embeddings,
+            type_vocab_size=self.type_vocab_size,
+            is_decoder=False,
+            initializer_range=self.initializer_range,
+        )
+
+        return config, input_ids, token_type_ids, attention_mask
+
+    def prepare_config_and_inputs_for_common(self):
+        config_and_inputs = self.prepare_config_and_inputs()
+        config, input_ids, token_type_ids, attention_mask = config_and_inputs
+        inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask}
+        return config, inputs_dict
 
 
 @require_flax
-@require_torch
-class FlaxBertModelTest(unittest.TestCase):
-    def assert_almost_equals(self, a: ndarray, b: ndarray, tol: float):
-        diff = (a - b).sum()
-        self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol})")
+class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
 
-    def test_from_pytorch(self):
-        with torch.no_grad():
-            with self.subTest("bert-base-cased"):
-                tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
-                fx_model = FlaxBertModel.from_pretrained("bert-base-cased")
-                pt_model = BertModel.from_pretrained("bert-base-cased")
+    all_model_classes = (FlaxBertModel,) if is_flax_available() else ()
 
-                # Check for simple input
-                pt_inputs = tokenizer.encode_plus("This is a simple input", return_tensors=TensorType.PYTORCH)
-                fx_inputs = tokenizer.encode_plus("This is a simple input", return_tensors=TensorType.JAX)
-                pt_outputs = pt_model(**pt_inputs).to_tuple()
-                fx_outputs = fx_model(**fx_inputs)
-
-                self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
-
-                for fx_output, pt_output in zip(fx_outputs, pt_outputs):
-                    self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-3)
-
-    def test_multiple_sequences(self):
-        tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
-        model = FlaxBertModel.from_pretrained("bert-base-cased")
-
-        sequences = ["this is an example sentence", "this is another", "and a third one"]
-        encodings = tokenizer(sequences, return_tensors=TensorType.JAX, padding=True, truncation=True)
-
-        @jax.jit
-        def model_jitted(input_ids, attention_mask=None, token_type_ids=None):
-            return model(input_ids, attention_mask, token_type_ids)
-
-        with self.subTest("JIT Disabled"):
-            with jax.disable_jit():
-                tokens, pooled = model_jitted(**encodings)
-                self.assertEqual(tokens.shape, (3, 7, 768))
-                self.assertEqual(pooled.shape, (3, 768))
-
-        with self.subTest("JIT Enabled"):
-            jitted_tokens, jitted_pooled = model_jitted(**encodings)
-
-            self.assertEqual(jitted_tokens.shape, (3, 7, 768))
-            self.assertEqual(jitted_pooled.shape, (3, 768))
+    def setUp(self):
+        self.model_tester = FlaxBertModelTester(self)
diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py
new file mode 100644
index 00000000000..c79407527ed
--- /dev/null
+++ b/tests/test_modeling_flax_common.py
@@ -0,0 +1,127 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+
+import numpy as np
+
+import transformers
+from transformers import is_flax_available, is_torch_available
+from transformers.testing_utils import require_flax, require_torch
+
+
+if is_flax_available():
+    import os
+
+    import jax
+    import jax.numpy as jnp
+    from flax.traverse_util import unflatten_dict
+
+    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12"  # assumed parallelism: 8
+
+if is_torch_available():
+    import torch
+
+
+def ids_tensor(shape, vocab_size, rng=None):
+    """Creates a random int32 tensor of the shape within the vocab size."""
+    if rng is None:
+        rng = random.Random()
+
+    total_dims = 1
+    for dim in shape:
+        total_dims *= dim
+
+    values = []
+    for _ in range(total_dims):
+        values.append(rng.randint(0, vocab_size - 1))
+
+    output = np.array(values, dtype=jnp.int32).reshape(shape)
+
+    return output
+
+
+def random_attention_mask(shape, rng=None):
+    attn_mask = ids_tensor(shape, vocab_size=2, rng=rng)
+    # make sure that at least one token is attended to for each batch
+    attn_mask[:, -1] = 1
+    return attn_mask
+
+
+def convert_pt_model_to_flax(pt_model, config, flax_model_cls):
+    state = pt_model.state_dict()
+    state = {k: v.numpy() for k, v in state.items()}
+    state = flax_model_cls.convert_from_pytorch(state, config)
+    state = unflatten_dict({tuple(k.split(".")): v for k, v in state.items()})
+    return flax_model_cls(config, state, dtype=jnp.float32)
+
+
+@require_flax
+class FlaxModelTesterMixin:
+    model_tester = None
+    all_model_classes = ()
+
+    def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
+        diff = np.abs((a - b)).sum()
+        self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
+
+    @require_torch
+    def test_equivalence_flax_pytorch(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            with self.subTest(model_class.__name__):
+                pt_model_class_name = model_class.__name__[4:]  # Skip the "Flax" at the beginning
+                pt_model_class = getattr(transformers, pt_model_class_name)
+                pt_model = pt_model_class(config).eval()
+
+                fx_model = convert_pt_model_to_flax(pt_model, config, model_class)
+
+                pt_inputs = {k: torch.tensor(v.tolist()) for k, v in inputs_dict.items()}
+
+                with torch.no_grad():
+                    pt_outputs = pt_model(**pt_inputs).to_tuple()
+                fx_outputs = fx_model(**inputs_dict)
+                self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
+                for fx_output, pt_output in zip(fx_outputs, pt_outputs):
+                    self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-3)
+
+    @require_torch
+    def test_jit_compilation(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            with self.subTest(model_class.__name__):
+
+                # TODO later: have some way to initialize easily a Flax model from config, for now I go through PT
+                pt_model_class_name = model_class.__name__[4:]  # Skip the "Flax" at the beginning
+                pt_model_class = getattr(transformers, pt_model_class_name)
+                pt_model = pt_model_class(config).eval()
+
+                model = convert_pt_model_to_flax(pt_model, config, model_class)
+
+                @jax.jit
+                def model_jitted(input_ids, attention_mask=None, token_type_ids=None):
+                    return model(input_ids, attention_mask, token_type_ids)
+
+                with self.subTest("JIT Disabled"):
+                    with jax.disable_jit():
+                        outputs = model_jitted(**inputs_dict)
+
+                with self.subTest("JIT Enabled"):
+                    jitted_outputs = model_jitted(**inputs_dict)
+
+                self.assertEqual(len(outputs), len(jitted_outputs))
+                for jitted_output, output in zip(jitted_outputs, outputs):
+                    self.assertEqual(jitted_output.shape, output.shape)
diff --git a/tests/test_modeling_flax_roberta.py b/tests/test_modeling_flax_roberta.py
index 3c60b17ab85..fa20d9fa3ca 100644
--- a/tests/test_modeling_flax_roberta.py
+++ b/tests/test_modeling_flax_roberta.py
@@ -14,70 +14,98 @@
 import unittest
 
-from numpy import ndarray
+from transformers import RobertaConfig, is_flax_available
+from transformers.testing_utils import require_flax
 
-from transformers import RobertaTokenizerFast, TensorType, is_flax_available, is_torch_available
-from transformers.testing_utils import require_flax, require_torch
+from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
 
 
 if is_flax_available():
-    import os
-
-    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12"  # assumed parallelism: 8
-
-    import jax
     from transformers.models.roberta.modeling_flax_roberta import FlaxRobertaModel
 
-if is_torch_available():
-    import torch
-    from transformers.models.roberta.modeling_roberta import RobertaModel
 
+class FlaxRobertaModelTester(unittest.TestCase):
+    def __init__(
+        self,
+        parent,
+        batch_size=13,
+        seq_length=7,
+        is_training=True,
+        use_attention_mask=True,
+        use_token_type_ids=True,
+        use_labels=True,
+        vocab_size=99,
+        hidden_size=32,
+        num_hidden_layers=5,
+        num_attention_heads=4,
+        intermediate_size=37,
+        hidden_act="gelu",
+        hidden_dropout_prob=0.1,
+        attention_probs_dropout_prob=0.1,
+        max_position_embeddings=512,
+        type_vocab_size=16,
+        type_sequence_label_size=2,
+        initializer_range=0.02,
+    ):
+        self.parent = parent
+        self.batch_size = batch_size
+        self.seq_length = seq_length
+        self.is_training = is_training
+        self.use_attention_mask = use_attention_mask
+        self.use_token_type_ids = use_token_type_ids
+        self.use_labels = use_labels
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.intermediate_size = intermediate_size
+        self.hidden_act = hidden_act
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.max_position_embeddings = max_position_embeddings
+        self.type_vocab_size = type_vocab_size
+        self.type_sequence_label_size = type_sequence_label_size
+        self.initializer_range = initializer_range
+
+    def prepare_config_and_inputs(self):
+        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+
+        attention_mask = None
+        if self.use_attention_mask:
+            attention_mask = random_attention_mask([self.batch_size, self.seq_length])
+
+        token_type_ids = None
+        if self.use_token_type_ids:
+            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
+
+        config = RobertaConfig(
+            vocab_size=self.vocab_size,
+            hidden_size=self.hidden_size,
+            num_hidden_layers=self.num_hidden_layers,
+            num_attention_heads=self.num_attention_heads,
+            intermediate_size=self.intermediate_size,
+            hidden_act=self.hidden_act,
+            hidden_dropout_prob=self.hidden_dropout_prob,
+            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+            max_position_embeddings=self.max_position_embeddings,
+            type_vocab_size=self.type_vocab_size,
+            is_decoder=False,
+            initializer_range=self.initializer_range,
+        )
+
+        return config, input_ids, token_type_ids, attention_mask
+
+    def prepare_config_and_inputs_for_common(self):
+        config_and_inputs = self.prepare_config_and_inputs()
+        config, input_ids, token_type_ids, attention_mask = config_and_inputs
+        inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask}
+        return config, inputs_dict
 
 
 @require_flax
-@require_torch
-class FlaxRobertaModelTest(unittest.TestCase):
-    def assert_almost_equals(self, a: ndarray, b: ndarray, tol: float):
-        diff = (a - b).sum()
-        self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol})")
+class FlaxRobertaModelTest(FlaxModelTesterMixin, unittest.TestCase):
 
-    def test_from_pytorch(self):
-        with torch.no_grad():
-            with self.subTest("roberta-base"):
-                tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
-                fx_model = FlaxRobertaModel.from_pretrained("roberta-base")
-                pt_model = RobertaModel.from_pretrained("roberta-base")
+    all_model_classes = (FlaxRobertaModel,) if is_flax_available() else ()
 
-                # Check for simple input
-                pt_inputs = tokenizer.encode_plus("This is a simple input", return_tensors=TensorType.PYTORCH)
-                fx_inputs = tokenizer.encode_plus("This is a simple input", return_tensors=TensorType.JAX)
-                pt_outputs = pt_model(**pt_inputs)
-                fx_outputs = fx_model(**fx_inputs)
-
-                self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
-
-                for fx_output, pt_output in zip(fx_outputs, pt_outputs.to_tuple()):
-                    self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-3)
-
-    def test_multiple_sequences(self):
-        tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
-        model = FlaxRobertaModel.from_pretrained("roberta-base")
-
-        sequences = ["this is an example sentence", "this is another", "and a third one"]
-        encodings = tokenizer(sequences, return_tensors=TensorType.JAX, padding=True, truncation=True)
-
-        @jax.jit
-        def model_jitted(input_ids, attention_mask=None, token_type_ids=None):
-            return model(input_ids, attention_mask, token_type_ids)
-
-        with self.subTest("JIT Disabled"):
-            with jax.disable_jit():
-                tokens, pooled = model_jitted(**encodings)
-                self.assertEqual(tokens.shape, (3, 7, 768))
-                self.assertEqual(pooled.shape, (3, 768))
-
-        with self.subTest("JIT Enabled"):
-            jitted_tokens, jitted_pooled = model_jitted(**encodings)
-
-            self.assertEqual(jitted_tokens.shape, (3, 7, 768))
-            self.assertEqual(jitted_pooled.shape, (3, 768))
+    def setUp(self):
+        self.model_tester = FlaxRobertaModelTester(self)
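For reference, a new Flax test file built on this refactor would follow the same pattern as the BERT and RoBERTa files above: a small model tester that returns a config and an inputs_dict, an all_model_classes tuple, and a test class that mixes in FlaxModelTesterMixin to inherit the PyTorch-equivalence and JIT-compilation tests. The sketch below is illustrative only; FlaxFooModel, FooConfig, FlaxFooModelTester and the "foo" module path are hypothetical placeholders, not part of this patch.

import unittest

# FooConfig is a hypothetical config class, used only to illustrate the pattern.
from transformers import FooConfig, is_flax_available
from transformers.testing_utils import require_flax

from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask


if is_flax_available():
    # Hypothetical Flax model class, mirroring FlaxBertModel / FlaxRobertaModel above.
    from transformers.models.foo.modeling_flax_foo import FlaxFooModel


class FlaxFooModelTester(unittest.TestCase):
    # Minimal tester: builds a small config and the inputs_dict consumed by the mixin.
    def __init__(self, parent, batch_size=13, seq_length=7, vocab_size=99):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.vocab_size = vocab_size

    def prepare_config_and_inputs_for_common(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
        attention_mask = random_attention_mask([self.batch_size, self.seq_length])
        config = FooConfig(vocab_size=self.vocab_size)  # hypothetical config class
        inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
        return config, inputs_dict


@require_flax
class FlaxFooModelTest(FlaxModelTesterMixin, unittest.TestCase):

    all_model_classes = (FlaxFooModel,) if is_flax_available() else ()

    def setUp(self):
        self.model_tester = FlaxFooModelTester(self)

Note that the mixin's tests look up the PyTorch counterpart with getattr(transformers, model_class.__name__[4:]), so the Flax class name must mirror an existing PyTorch model class (here a hypothetical FooModel) for convert_pt_model_to_flax to work.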