[FLAX] Add dtype to embedding for bert/bart/opt/t5 (#20340)

* [FLAX] Add dtype to embedding for bert/bart/opt/t5
* Fix all copies
* Add a test case

parent: 667ccea722
commit: ac2f6674a3
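For context, a minimal stand-alone sketch (not part of the commit) of what forwarding a dtype to flax.linen.Embed does: the embedding lookup is computed and returned in the requested dtype instead of defaulting to float32. The layer sizes, initializer stddev, and variable names below are illustrative only.

import jax
import jax.numpy as jnp
import flax.linen as nn

# Illustrative Embed layer; `dtype` sets the computation/output dtype of the lookup,
# mirroring the `dtype=self.dtype` argument added throughout this diff.
embed = nn.Embed(
    32,  # num_embeddings (hypothetical vocab size)
    8,   # features (hypothetical embedding dim)
    embedding_init=jax.nn.initializers.normal(stddev=0.02),
    dtype=jnp.bfloat16,
)

token_ids = jnp.array([[1, 2, 3]])
params = embed.init(jax.random.PRNGKey(0), token_ids)
print(embed.apply(params, token_ids).dtype)  # bfloat16; without `dtype` the output stays float32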
@@ -715,6 +715,7 @@ class FlaxBartEncoder(nn.Module):
             self.config.max_position_embeddings + self.offset,
             embed_dim,
             embedding_init=jax.nn.initializers.normal(self.config.init_std),
+            dtype=self.dtype,
         )
         self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype)
         self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
@@ -779,6 +780,7 @@ class FlaxBartDecoder(nn.Module):
             self.config.max_position_embeddings + self.offset,
             embed_dim,
             embedding_init=jax.nn.initializers.normal(self.config.init_std),
+            dtype=self.dtype,
         )

         self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
@@ -842,6 +844,7 @@ class FlaxBartModule(nn.Module):
             self.config.vocab_size,
             self.config.d_model,
             embedding_init=jax.nn.initializers.normal(self.config.init_std),
+            dtype=self.dtype,
         )

         self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
@@ -1888,6 +1891,7 @@ class FlaxBartDecoderWrapper(nn.Module):
             self.config.vocab_size,
             embed_dim,
             embedding_init=jax.nn.initializers.normal(self.config.init_std),
+            dtype=self.dtype,
         )
         self.decoder = FlaxBartDecoder(config=self.config, embed_tokens=embed_tokens, dtype=self.dtype)

@@ -187,16 +187,19 @@ class FlaxBertEmbeddings(nn.Module):
             self.config.vocab_size,
             self.config.hidden_size,
             embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+            dtype=self.dtype,
         )
         self.position_embeddings = nn.Embed(
             self.config.max_position_embeddings,
             self.config.hidden_size,
             embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+            dtype=self.dtype,
         )
         self.token_type_embeddings = nn.Embed(
             self.config.type_vocab_size,
             self.config.hidden_size,
             embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+            dtype=self.dtype,
         )
         self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
         self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
@@ -205,16 +205,19 @@ class FlaxBigBirdEmbeddings(nn.Module):
             self.config.vocab_size,
             self.config.hidden_size,
             embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+            dtype=self.dtype,
         )
         self.position_embeddings = nn.Embed(
             self.config.max_position_embeddings,
             self.config.hidden_size,
             embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+            dtype=self.dtype,
         )
         self.token_type_embeddings = nn.Embed(
             self.config.type_vocab_size,
             self.config.hidden_size,
             embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+            dtype=self.dtype,
         )
         self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
         self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
@@ -817,6 +817,7 @@ class FlaxBlenderbotModule(nn.Module):
             self.config.vocab_size,
             self.config.d_model,
             embedding_init=jax.nn.initializers.normal(self.config.init_std),
+            dtype=self.dtype,
         )

         self.encoder = FlaxBlenderbotEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
@@ -815,6 +815,7 @@ class FlaxBlenderbotSmallModule(nn.Module):
             self.config.vocab_size,
             self.config.d_model,
             embedding_init=jax.nn.initializers.normal(self.config.init_std),
+            dtype=self.dtype,
         )

         self.encoder = FlaxBlenderbotSmallEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
@@ -368,6 +368,7 @@ class FlaxLongT5Attention(nn.Module):
                 self.relative_attention_num_buckets,
                 self.n_heads,
                 embedding_init=jax.nn.initializers.normal(kv_init_std),
+                dtype=self.dtype,
             )

     @staticmethod
@@ -2032,6 +2033,7 @@ class FlaxLongT5Module(nn.Module):
             self.config.vocab_size,
             self.config.d_model,
             embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0),
+            dtype=self.dtype,
         )

         encoder_config = copy.deepcopy(self.config)
@@ -2160,6 +2162,7 @@ class FlaxLongT5ForConditionalGenerationModule(nn.Module):
             self.config.vocab_size,
             self.config.d_model,
             embedding_init=jax.nn.initializers.normal(self.config.initializer_factor),
+            dtype=self.dtype,
         )

         encoder_config = copy.deepcopy(self.config)
@@ -881,6 +881,7 @@ class FlaxMBartModule(nn.Module):
             self.config.vocab_size,
             self.config.d_model,
             embedding_init=jax.nn.initializers.normal(self.config.init_std),
+            dtype=self.dtype,
         )

         self.encoder = FlaxMBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
@@ -436,12 +436,14 @@ class FlaxOPTDecoder(nn.Module):
             self.config.vocab_size,
             self.config.word_embed_proj_dim,
             embedding_init=jax.nn.initializers.normal(self.config.init_std),
+            dtype=self.dtype,
         )

         self.embed_positions = FlaxOPTLearnedPositionalEmbedding(
             self.config.max_position_embeddings,
             embed_dim,
             embedding_init=jax.nn.initializers.normal(self.config.init_std),
+            dtype=self.dtype,
         )

         if self.config.word_embed_proj_dim != self.config.hidden_size:
@@ -831,6 +831,7 @@ class FlaxPegasusModule(nn.Module):
             self.config.vocab_size,
             self.config.d_model,
             embedding_init=jax.nn.initializers.normal(self.config.init_std),
+            dtype=self.dtype,
         )

         self.encoder = FlaxPegasusEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
@@ -147,16 +147,19 @@ class FlaxRobertaEmbeddings(nn.Module):
             self.config.vocab_size,
             self.config.hidden_size,
             embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+            dtype=self.dtype,
         )
         self.position_embeddings = nn.Embed(
             self.config.max_position_embeddings,
             self.config.hidden_size,
             embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+            dtype=self.dtype,
         )
         self.token_type_embeddings = nn.Embed(
             self.config.type_vocab_size,
             self.config.hidden_size,
             embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+            dtype=self.dtype,
         )
         self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
         self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
@@ -228,6 +228,7 @@ class FlaxT5Attention(nn.Module):
                 self.relative_attention_num_buckets,
                 self.n_heads,
                 embedding_init=jax.nn.initializers.normal(kv_init_std),
+                dtype=self.dtype,
             )

     @staticmethod
@@ -1292,6 +1293,7 @@ class FlaxT5Module(nn.Module):
             self.config.vocab_size,
             self.config.d_model,
             embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0),
+            dtype=self.dtype,
         )

         encoder_config = copy.deepcopy(self.config)
@@ -1417,6 +1419,7 @@ class FlaxT5EncoderModule(nn.Module):
             self.config.vocab_size,
             self.config.d_model,
             embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0),
+            dtype=self.dtype,
         )

         encoder_config = copy.deepcopy(self.config)
@@ -1512,6 +1515,7 @@ class FlaxT5ForConditionalGenerationModule(nn.Module):
             self.config.vocab_size,
             self.config.d_model,
             embedding_init=jax.nn.initializers.normal(self.config.initializer_factor),
+            dtype=self.dtype,
         )

         encoder_config = copy.deepcopy(self.config)
@@ -865,6 +865,21 @@ class FlaxT5ModelIntegrationTests(unittest.TestCase):
         output_str = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
         self.assertTrue(output_str == "Hello there!")

+    @slow
+    def test_small_generation_bfloat16(self):
+        model = FlaxT5ForConditionalGeneration.from_pretrained("t5-small", dtype=jnp.bfloat16)
+        model.config.max_length = 8
+        model.config.num_beams = 1
+        model.config.do_sample = False
+        tokenizer = T5Tokenizer.from_pretrained("t5-small")
+
+        input_ids = tokenizer("summarize: Hello there", return_tensors="np").input_ids
+
+        sequences = model.generate(input_ids).sequences
+
+        output_str = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
+        self.assertTrue(output_str == "Hello there!")
+
     @slow
     def test_summarization(self):
         model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base")
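As a rough follow-on to the new test, a hypothetical check (not part of the commit) that the dtype actually propagates through the embeddings: encoder hidden states from a model loaded with dtype=jnp.bfloat16 should now come out in bfloat16. The model name and prompt mirror the test above; the encode-based check itself is an illustrative assumption, not something the commit adds.

import jax.numpy as jnp
from transformers import FlaxT5ForConditionalGeneration, T5Tokenizer

model = FlaxT5ForConditionalGeneration.from_pretrained("t5-small", dtype=jnp.bfloat16)
tokenizer = T5Tokenizer.from_pretrained("t5-small")

inputs = tokenizer("summarize: Hello there", return_tensors="np")
encoder_outputs = model.encode(input_ids=inputs.input_ids)
# With dtype forwarded to the embedding layers, this is expected to print bfloat16.
print(encoder_outputs.last_hidden_state.dtype)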