[FLAX] Add dtype to embedding for bert/bart/opt/t5 (#20340)

* [FLAX] Add dtype to embedding for bert/bart/opt/t5

* Fix all copies

* Add a test case
Lianmin Zheng, 2022-11-28 07:21:42 -08:00, committed by GitHub
parent 667ccea722
commit ac2f6674a3
12 changed files with 41 additions and 0 deletions
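For readers skimming the diff: the entire change is threading the module's dtype into flax.linen.Embed, which otherwise returns lookups in the parameter dtype (float32 by default) even when the rest of the model computes in half precision. A minimal sketch of that behavior, not part of the commit (TinyEmbedder is an illustrative name):

import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyEmbedder(nn.Module):
    vocab_size: int = 128
    hidden_size: int = 16
    dtype: jnp.dtype = jnp.float32  # computation dtype of the module

    def setup(self):
        # Without dtype=self.dtype, nn.Embed keeps the lookup result in the
        # parameter dtype (float32), regardless of the module's dtype.
        self.embed = nn.Embed(
            self.vocab_size,
            self.hidden_size,
            embedding_init=jax.nn.initializers.normal(0.02),
            dtype=self.dtype,
        )

    def __call__(self, input_ids):
        return self.embed(input_ids)

module = TinyEmbedder(dtype=jnp.bfloat16)
input_ids = jnp.ones((1, 4), dtype=jnp.int32)
variables = module.init(jax.random.PRNGKey(0), input_ids)
print(module.apply(variables, input_ids).dtype)  # bfloat16; float32 if dtype were dropped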

View File

@@ -715,6 +715,7 @@ class FlaxBartEncoder(nn.Module):
self.config.max_position_embeddings + self.offset,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
dtype=self.dtype,
)
self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype)
self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
@@ -779,6 +780,7 @@ class FlaxBartDecoder(nn.Module):
self.config.max_position_embeddings + self.offset,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
dtype=self.dtype,
)
self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
@@ -842,6 +844,7 @@ class FlaxBartModule(nn.Module):
self.config.vocab_size,
self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
dtype=self.dtype,
)
self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
@@ -1888,6 +1891,7 @@ class FlaxBartDecoderWrapper(nn.Module):
self.config.vocab_size,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
dtype=self.dtype,
)
self.decoder = FlaxBartDecoder(config=self.config, embed_tokens=embed_tokens, dtype=self.dtype)

View File

@@ -187,16 +187,19 @@ class FlaxBertEmbeddings(nn.Module):
self.config.vocab_size,
self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.position_embeddings = nn.Embed(
self.config.max_position_embeddings,
self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.token_type_embeddings = nn.Embed(
self.config.type_vocab_size,
self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

View File

@@ -205,16 +205,19 @@ class FlaxBigBirdEmbeddings(nn.Module):
self.config.vocab_size,
self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.position_embeddings = nn.Embed(
self.config.max_position_embeddings,
self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.token_type_embeddings = nn.Embed(
self.config.type_vocab_size,
self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

View File

@@ -817,6 +817,7 @@ class FlaxBlenderbotModule(nn.Module):
self.config.vocab_size,
self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
dtype=self.dtype,
)
self.encoder = FlaxBlenderbotEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)

View File

@@ -815,6 +815,7 @@ class FlaxBlenderbotSmallModule(nn.Module):
self.config.vocab_size,
self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
dtype=self.dtype,
)
self.encoder = FlaxBlenderbotSmallEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)

View File

@@ -368,6 +368,7 @@ class FlaxLongT5Attention(nn.Module):
self.relative_attention_num_buckets,
self.n_heads,
embedding_init=jax.nn.initializers.normal(kv_init_std),
dtype=self.dtype,
)
@staticmethod
@@ -2032,6 +2033,7 @@ class FlaxLongT5Module(nn.Module):
self.config.vocab_size,
self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0),
dtype=self.dtype,
)
encoder_config = copy.deepcopy(self.config)
@@ -2160,6 +2162,7 @@ class FlaxLongT5ForConditionalGenerationModule(nn.Module):
self.config.vocab_size,
self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.initializer_factor),
dtype=self.dtype,
)
encoder_config = copy.deepcopy(self.config)

View File

@@ -881,6 +881,7 @@ class FlaxMBartModule(nn.Module):
self.config.vocab_size,
self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
dtype=self.dtype,
)
self.encoder = FlaxMBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)

View File

@@ -436,12 +436,14 @@ class FlaxOPTDecoder(nn.Module):
self.config.vocab_size,
self.config.word_embed_proj_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
dtype=self.dtype,
)
self.embed_positions = FlaxOPTLearnedPositionalEmbedding(
self.config.max_position_embeddings,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
dtype=self.dtype,
)
if self.config.word_embed_proj_dim != self.config.hidden_size:

View File

@@ -831,6 +831,7 @@ class FlaxPegasusModule(nn.Module):
self.config.vocab_size,
self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
dtype=self.dtype,
)
self.encoder = FlaxPegasusEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)

View File

@@ -147,16 +147,19 @@ class FlaxRobertaEmbeddings(nn.Module):
self.config.vocab_size,
self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.position_embeddings = nn.Embed(
self.config.max_position_embeddings,
self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.token_type_embeddings = nn.Embed(
self.config.type_vocab_size,
self.config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

View File

@@ -228,6 +228,7 @@ class FlaxT5Attention(nn.Module):
self.relative_attention_num_buckets,
self.n_heads,
embedding_init=jax.nn.initializers.normal(kv_init_std),
dtype=self.dtype,
)
@staticmethod
@@ -1292,6 +1293,7 @@ class FlaxT5Module(nn.Module):
self.config.vocab_size,
self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0),
dtype=self.dtype,
)
encoder_config = copy.deepcopy(self.config)
@@ -1417,6 +1419,7 @@ class FlaxT5EncoderModule(nn.Module):
self.config.vocab_size,
self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0),
dtype=self.dtype,
)
encoder_config = copy.deepcopy(self.config)
@@ -1512,6 +1515,7 @@ class FlaxT5ForConditionalGenerationModule(nn.Module):
self.config.vocab_size,
self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.initializer_factor),
dtype=self.dtype,
)
encoder_config = copy.deepcopy(self.config)

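The new test below exercises this end to end. For reference, a hedged sketch of the same usage outside the test harness: dtype=jnp.bfloat16 selects the computation dtype (which the embeddings now follow), while to_bf16 additionally casts the stored weights; the checkpoint and generation settings mirror the test and are otherwise illustrative.

import jax.numpy as jnp
from transformers import FlaxT5ForConditionalGeneration, T5Tokenizer

# bfloat16 computation dtype; with this change the token, position and
# relative-position-bias embeddings also emit bfloat16 activations.
model = FlaxT5ForConditionalGeneration.from_pretrained("t5-small", dtype=jnp.bfloat16)
# Optionally cast the parameters themselves to bfloat16 to halve their memory.
model.params = model.to_bf16(model.params)

tokenizer = T5Tokenizer.from_pretrained("t5-small")
input_ids = tokenizer("summarize: Hello there", return_tensors="np").input_ids
sequences = model.generate(input_ids, max_length=8).sequences
print(tokenizer.batch_decode(sequences, skip_special_tokens=True)[0])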
View File

@@ -865,6 +865,21 @@ class FlaxT5ModelIntegrationTests(unittest.TestCase):
output_str = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
self.assertTrue(output_str == "Hello there!")
@slow
def test_small_generation_bfloat16(self):
model = FlaxT5ForConditionalGeneration.from_pretrained("t5-small", dtype=jnp.bfloat16)
model.config.max_length = 8
model.config.num_beams = 1
model.config.do_sample = False
tokenizer = T5Tokenizer.from_pretrained("t5-small")
input_ids = tokenizer("summarize: Hello there", return_tensors="np").input_ids
sequences = model.generate(input_ids).sequences
output_str = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
self.assertTrue(output_str == "Hello there!")
@slow
def test_summarization(self):
model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base")