From ac2f6674a33e8eaffdf868e1fa6cbc8e722f469e Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Mon, 28 Nov 2022 07:21:42 -0800
Subject: [PATCH] [FLAX] Add dtype to embedding for bert/bart/opt/t5 (#20340)

* [FLAX] Add dtype to embedding for bert/bart/opt/t5

* Fix all copies

* Add a test case
---
 .../models/bart/modeling_flax_bart.py              |  4 ++++
 .../models/bert/modeling_flax_bert.py              |  3 +++
 .../models/big_bird/modeling_flax_big_bird.py      |  3 +++
 .../models/blenderbot/modeling_flax_blenderbot.py  |  1 +
 .../modeling_flax_blenderbot_small.py              |  1 +
 .../models/longt5/modeling_flax_longt5.py          |  3 +++
 .../models/mbart/modeling_flax_mbart.py            |  1 +
 src/transformers/models/opt/modeling_flax_opt.py   |  2 ++
 .../models/pegasus/modeling_flax_pegasus.py        |  1 +
 .../models/roberta/modeling_flax_roberta.py        |  3 +++
 src/transformers/models/t5/modeling_flax_t5.py     |  4 ++++
 tests/models/t5/test_modeling_flax_t5.py           | 15 +++++++++++++++
 12 files changed, 41 insertions(+)

diff --git a/src/transformers/models/bart/modeling_flax_bart.py b/src/transformers/models/bart/modeling_flax_bart.py
index 5704147872f..90ddfa57cbd 100644
--- a/src/transformers/models/bart/modeling_flax_bart.py
+++ b/src/transformers/models/bart/modeling_flax_bart.py
@@ -715,6 +715,7 @@ class FlaxBartEncoder(nn.Module):
             self.config.max_position_embeddings + self.offset,
             embed_dim,
             embedding_init=jax.nn.initializers.normal(self.config.init_std),
+            dtype=self.dtype,
         )
         self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype)
         self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
@@ -779,6 +780,7 @@ class FlaxBartDecoder(nn.Module):
             self.config.max_position_embeddings + self.offset,
             embed_dim,
             embedding_init=jax.nn.initializers.normal(self.config.init_std),
+            dtype=self.dtype,
         )

         self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
@@ -842,6 +844,7 @@ class FlaxBartModule(nn.Module):
             self.config.vocab_size,
             self.config.d_model,
             embedding_init=jax.nn.initializers.normal(self.config.init_std),
+            dtype=self.dtype,
         )

         self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
@@ -1888,6 +1891,7 @@ class FlaxBartDecoderWrapper(nn.Module):
             self.config.vocab_size,
             embed_dim,
             embedding_init=jax.nn.initializers.normal(self.config.init_std),
+            dtype=self.dtype,
         )

         self.decoder = FlaxBartDecoder(config=self.config, embed_tokens=embed_tokens, dtype=self.dtype)

diff --git a/src/transformers/models/bert/modeling_flax_bert.py b/src/transformers/models/bert/modeling_flax_bert.py
index 0cdf622f33e..f7c78632e5e 100644
--- a/src/transformers/models/bert/modeling_flax_bert.py
+++ b/src/transformers/models/bert/modeling_flax_bert.py
@@ -187,16 +187,19 @@ class FlaxBertEmbeddings(nn.Module):
             self.config.vocab_size,
             self.config.hidden_size,
             embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+            dtype=self.dtype,
         )
         self.position_embeddings = nn.Embed(
             self.config.max_position_embeddings,
             self.config.hidden_size,
             embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+            dtype=self.dtype,
         )
         self.token_type_embeddings = nn.Embed(
             self.config.type_vocab_size,
             self.config.hidden_size,
             embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+            dtype=self.dtype,
         )
         self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
         self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
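For context on why the `dtype` argument matters: `flax.linen.Embed` casts its embedding table to its `dtype` before the lookup, and in the Flax releases current at the time of this patch that attribute defaulted to float32. Without threading the module's dtype through, a model otherwise running in bfloat16 silently computed its embedding outputs in float32. A minimal sketch of the difference (not part of the patch; shapes and names are illustrative):

```python
# Minimal sketch of what dtype= changes in flax.linen.Embed.
# Assumes a Flax version where Embed's output dtype defaults to float32.
import jax
import jax.numpy as jnp
import flax.linen as nn

ids = jnp.array([0, 1, 2])
default_embed = nn.Embed(num_embeddings=8, features=4)
bf16_embed = nn.Embed(num_embeddings=8, features=4, dtype=jnp.bfloat16)

# Both layers share the same parameter structure (one float32 table),
# so the initialized params are interchangeable between them.
params = bf16_embed.init(jax.random.PRNGKey(0), ids)

print(default_embed.apply(params, ids).dtype)  # float32: lookup ignores the compute dtype
print(bf16_embed.apply(params, ids).dtype)     # bfloat16: lookup follows the requested dtype
```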
diff --git a/src/transformers/models/big_bird/modeling_flax_big_bird.py b/src/transformers/models/big_bird/modeling_flax_big_bird.py
index 140beb64239..b38492f61fb 100644
--- a/src/transformers/models/big_bird/modeling_flax_big_bird.py
+++ b/src/transformers/models/big_bird/modeling_flax_big_bird.py
@@ -205,16 +205,19 @@ class FlaxBigBirdEmbeddings(nn.Module):
             self.config.vocab_size,
             self.config.hidden_size,
             embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+            dtype=self.dtype,
         )
         self.position_embeddings = nn.Embed(
             self.config.max_position_embeddings,
             self.config.hidden_size,
             embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+            dtype=self.dtype,
         )
         self.token_type_embeddings = nn.Embed(
             self.config.type_vocab_size,
             self.config.hidden_size,
             embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+            dtype=self.dtype,
         )
         self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
         self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
diff --git a/src/transformers/models/blenderbot/modeling_flax_blenderbot.py b/src/transformers/models/blenderbot/modeling_flax_blenderbot.py
index a75fe4d5b74..1b3b57b95b1 100644
--- a/src/transformers/models/blenderbot/modeling_flax_blenderbot.py
+++ b/src/transformers/models/blenderbot/modeling_flax_blenderbot.py
@@ -817,6 +817,7 @@ class FlaxBlenderbotModule(nn.Module):
             self.config.vocab_size,
             self.config.d_model,
             embedding_init=jax.nn.initializers.normal(self.config.init_std),
+            dtype=self.dtype,
         )

         self.encoder = FlaxBlenderbotEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
diff --git a/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py
index ddace51e7e2..e5a0352d24d 100644
--- a/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py
+++ b/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py
@@ -815,6 +815,7 @@ class FlaxBlenderbotSmallModule(nn.Module):
             self.config.vocab_size,
             self.config.d_model,
             embedding_init=jax.nn.initializers.normal(self.config.init_std),
+            dtype=self.dtype,
         )

         self.encoder = FlaxBlenderbotSmallEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
diff --git a/src/transformers/models/longt5/modeling_flax_longt5.py b/src/transformers/models/longt5/modeling_flax_longt5.py
index e2cce0a3f4c..6e4558f3ff3 100644
--- a/src/transformers/models/longt5/modeling_flax_longt5.py
+++ b/src/transformers/models/longt5/modeling_flax_longt5.py
@@ -368,6 +368,7 @@ class FlaxLongT5Attention(nn.Module):
                 self.relative_attention_num_buckets,
                 self.n_heads,
                 embedding_init=jax.nn.initializers.normal(kv_init_std),
+                dtype=self.dtype,
             )

     @staticmethod
@@ -2032,6 +2033,7 @@ class FlaxLongT5Module(nn.Module):
             self.config.vocab_size,
             self.config.d_model,
             embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0),
+            dtype=self.dtype,
         )

         encoder_config = copy.deepcopy(self.config)
@@ -2160,6 +2162,7 @@ class FlaxLongT5ForConditionalGenerationModule(nn.Module):
             self.config.vocab_size,
             self.config.d_model,
             embedding_init=jax.nn.initializers.normal(self.config.initializer_factor),
+            dtype=self.dtype,
         )

         encoder_config = copy.deepcopy(self.config)
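Note that for the T5-family models the fix covers not only the shared token embedding but also the `relative_attention_bias` embedding inside the attention layers, so position-bias lookups follow the module dtype as well. With these changes in place, a half-precision model is instantiated in the usual way (hedged usage sketch, mirroring the test added at the end of this patch):

```python
import jax.numpy as jnp
from transformers import FlaxT5ForConditionalGeneration

# dtype sets the computation dtype for the Flax model; weights load as usual.
model = FlaxT5ForConditionalGeneration.from_pretrained("t5-small", dtype=jnp.bfloat16)
```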
diff --git a/src/transformers/models/mbart/modeling_flax_mbart.py b/src/transformers/models/mbart/modeling_flax_mbart.py
index 7cb52033b78..afc67be57ba 100644
--- a/src/transformers/models/mbart/modeling_flax_mbart.py
+++ b/src/transformers/models/mbart/modeling_flax_mbart.py
@@ -881,6 +881,7 @@ class FlaxMBartModule(nn.Module):
             self.config.vocab_size,
             self.config.d_model,
             embedding_init=jax.nn.initializers.normal(self.config.init_std),
+            dtype=self.dtype,
         )

         self.encoder = FlaxMBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
diff --git a/src/transformers/models/opt/modeling_flax_opt.py b/src/transformers/models/opt/modeling_flax_opt.py
index adb38f4138a..1237e3b25f7 100644
--- a/src/transformers/models/opt/modeling_flax_opt.py
+++ b/src/transformers/models/opt/modeling_flax_opt.py
@@ -436,12 +436,14 @@ class FlaxOPTDecoder(nn.Module):
             self.config.vocab_size,
             self.config.word_embed_proj_dim,
             embedding_init=jax.nn.initializers.normal(self.config.init_std),
+            dtype=self.dtype,
         )

         self.embed_positions = FlaxOPTLearnedPositionalEmbedding(
             self.config.max_position_embeddings,
             embed_dim,
             embedding_init=jax.nn.initializers.normal(self.config.init_std),
+            dtype=self.dtype,
         )

         if self.config.word_embed_proj_dim != self.config.hidden_size:
diff --git a/src/transformers/models/pegasus/modeling_flax_pegasus.py b/src/transformers/models/pegasus/modeling_flax_pegasus.py
index 303d0055716..c4ecd25b6eb 100644
--- a/src/transformers/models/pegasus/modeling_flax_pegasus.py
+++ b/src/transformers/models/pegasus/modeling_flax_pegasus.py
@@ -831,6 +831,7 @@ class FlaxPegasusModule(nn.Module):
             self.config.vocab_size,
             self.config.d_model,
             embedding_init=jax.nn.initializers.normal(self.config.init_std),
+            dtype=self.dtype,
         )

         self.encoder = FlaxPegasusEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
diff --git a/src/transformers/models/roberta/modeling_flax_roberta.py b/src/transformers/models/roberta/modeling_flax_roberta.py
index 5cc3da84cc3..b7494e19f4d 100644
--- a/src/transformers/models/roberta/modeling_flax_roberta.py
+++ b/src/transformers/models/roberta/modeling_flax_roberta.py
@@ -147,16 +147,19 @@ class FlaxRobertaEmbeddings(nn.Module):
             self.config.vocab_size,
             self.config.hidden_size,
             embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+            dtype=self.dtype,
         )
         self.position_embeddings = nn.Embed(
             self.config.max_position_embeddings,
             self.config.hidden_size,
             embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+            dtype=self.dtype,
         )
         self.token_type_embeddings = nn.Embed(
             self.config.type_vocab_size,
             self.config.hidden_size,
             embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+            dtype=self.dtype,
         )
         self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
         self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
diff --git a/src/transformers/models/t5/modeling_flax_t5.py b/src/transformers/models/t5/modeling_flax_t5.py
index 2732bf59169..1e93fb32357 100644
--- a/src/transformers/models/t5/modeling_flax_t5.py
+++ b/src/transformers/models/t5/modeling_flax_t5.py
@@ -228,6 +228,7 @@ class FlaxT5Attention(nn.Module):
                 self.relative_attention_num_buckets,
                 self.n_heads,
                 embedding_init=jax.nn.initializers.normal(kv_init_std),
+                dtype=self.dtype,
             )

     @staticmethod
@@ -1292,6 +1293,7 @@ class FlaxT5Module(nn.Module):
             self.config.vocab_size,
             self.config.d_model,
             embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0),
+            dtype=self.dtype,
         )

         encoder_config = copy.deepcopy(self.config)
@@ -1417,6 +1419,7 @@ class FlaxT5EncoderModule(nn.Module):
             self.config.vocab_size,
             self.config.d_model,
             embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0),
+            dtype=self.dtype,
         )

         encoder_config = copy.deepcopy(self.config)
@@ -1512,6 +1515,7 @@ class FlaxT5ForConditionalGenerationModule(nn.Module):
             self.config.vocab_size,
             self.config.d_model,
             embedding_init=jax.nn.initializers.normal(self.config.initializer_factor),
+            dtype=self.dtype,
         )

         encoder_config = copy.deepcopy(self.config)
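For the encoder-style models touched above (BERT, BigBird, RoBERTa), the effect can be checked by comparing the output dtype before and after the patch. A hedged sketch of such a check (the checkpoint name is illustrative; any Flax-compatible BERT checkpoint should behave the same):

```python
# Quick output-dtype check for a bfloat16 Flax BERT model.
import jax.numpy as jnp
from transformers import BertTokenizerFast, FlaxBertModel

model = FlaxBertModel.from_pretrained("bert-base-uncased", dtype=jnp.bfloat16)
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

inputs = tokenizer("Hello there", return_tensors="np")
outputs = model(**inputs)
print(outputs.last_hidden_state.dtype)  # expected bfloat16 once embeddings honor dtype
```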
diff --git a/tests/models/t5/test_modeling_flax_t5.py b/tests/models/t5/test_modeling_flax_t5.py
index a1dfa095712..f4bd54e97af 100644
--- a/tests/models/t5/test_modeling_flax_t5.py
+++ b/tests/models/t5/test_modeling_flax_t5.py
@@ -865,6 +865,21 @@ class FlaxT5ModelIntegrationTests(unittest.TestCase):
         output_str = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
         self.assertTrue(output_str == "Hello there!")

+    @slow
+    def test_small_generation_bfloat16(self):
+        model = FlaxT5ForConditionalGeneration.from_pretrained("t5-small", dtype=jnp.bfloat16)
+        model.config.max_length = 8
+        model.config.num_beams = 1
+        model.config.do_sample = False
+        tokenizer = T5Tokenizer.from_pretrained("t5-small")
+
+        input_ids = tokenizer("summarize: Hello there", return_tensors="np").input_ids
+
+        sequences = model.generate(input_ids).sequences
+
+        output_str = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
+        self.assertTrue(output_str == "Hello there!")
+
     @slow
     def test_summarization(self):
         model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base")
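The new `test_small_generation_bfloat16` mirrors the float32 generation test directly above it but loads `t5-small` with `dtype=jnp.bfloat16`, exercising the embedding dtype end to end through `generate`. Since it is marked `@slow`, transformers gates it behind the slow-test flag; one way to run it locally is `RUN_SLOW=1 python -m pytest tests/models/t5/test_modeling_flax_t5.py -k test_small_generation_bfloat16`.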