[Flax] Adapt Flax models to new structure (#9484)

* Create modeling_flax_eletra with code copied from modeling_flax_bert

* Add ElectraForMaskedLM and ElectraForPretraining

* Add modeling test for Flax electra and fix naming and arg in Flax Electra model

* Add documentation

* Fix code style

* Create modeling_flax_eletra with code copied from modeling_flax_bert

* Add ElectraForMaskedLM and ElectraForPretraining

* Add modeling test for Flax electra and fix naming and arg in Flax Electra model

* Add documentation

* Fix code style

* Fix code quality

* Adjust tol in assert_almost_equal due to very small difference between model output, ranging 0.0010 - 0.0016

* Remove redundant ElectraPooler

* save intermediate

* adapt

* correct bert flax design

* adapt roberta as well

* finish roberta flax

* finish

* apply suggestions

* apply suggestions

Co-authored-by: Chris Nguyen <anhtu2687@gmail.com>
This commit is contained in:
Patrick von Platen 2021-03-18 09:44:17 +03:00 committed by GitHub
parent 5c0bf39782
commit 0b98ca368f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 316 additions and 500 deletions

View File

@ -97,6 +97,7 @@ class FlaxBertLayerNorm(nn.Module):
Layer normalization (https://arxiv.org/abs/1607.06450). Operates on the last axis of the input data.
"""
hidden_size: int
epsilon: float = 1e-6
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
bias: bool = True # If True, bias (beta) is added.
@ -106,7 +107,10 @@ class FlaxBertLayerNorm(nn.Module):
scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
@nn.compact
def setup(self):
self.gamma = self.param("gamma", self.scale_init, (self.hidden_size,))
self.beta = self.param("beta", self.scale_init, (self.hidden_size,))
def __call__(self, x):
"""
Applies layer normalization on the input. It normalizes the activations of the layer for each given example in
@ -119,18 +123,17 @@ class FlaxBertLayerNorm(nn.Module):
Returns:
Normalized inputs (the same shape as inputs).
"""
features = x.shape[-1]
mean = jnp.mean(x, axis=-1, keepdims=True)
mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
var = mean2 - jax.lax.square(mean)
mul = jax.lax.rsqrt(var + self.epsilon)
if self.scale:
mul = mul * jnp.asarray(self.param("gamma", self.scale_init, (features,)))
mul = mul * jnp.asarray(self.gamma)
y = (x - mean) * mul
if self.bias:
y = y + jnp.asarray(self.param("beta", self.bias_init, (features,)))
y = y + jnp.asarray(self.beta)
return y
@ -142,278 +145,232 @@ class FlaxBertEmbedding(nn.Module):
vocab_size: int
hidden_size: int
kernel_init_scale: float = 0.2
emb_init: Callable[..., np.ndarray] = jax.nn.initializers.normal(stddev=kernel_init_scale)
initializer_range: float
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, inputs):
embedding = self.param("weight", self.emb_init, (self.vocab_size, self.hidden_size))
return jnp.take(embedding, inputs, axis=0)
def setup(self):
init_fn: Callable[..., np.ndarray] = jax.nn.initializers.normal(stddev=self.initializer_range)
self.embeddings = self.param("weight", init_fn, (self.vocab_size, self.hidden_size))
def __call__(self, input_ids):
return jnp.take(self.embeddings, input_ids, axis=0)
class FlaxBertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""
vocab_size: int
hidden_size: int
type_vocab_size: int
max_length: int
kernel_init_scale: float = 0.2
dropout_rate: float = 0.0
config: BertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
# Embed
w_emb = FlaxBertEmbedding(
self.vocab_size,
self.hidden_size,
kernel_init_scale=self.kernel_init_scale,
def setup(self):
self.word_embeddings = FlaxBertEmbedding(
self.config.vocab_size,
self.config.hidden_size,
initializer_range=self.config.initializer_range,
name="word_embeddings",
dtype=self.dtype,
)(jnp.atleast_2d(input_ids.astype("i4")))
p_emb = FlaxBertEmbedding(
self.max_length,
self.hidden_size,
kernel_init_scale=self.kernel_init_scale,
)
self.position_embeddings = FlaxBertEmbedding(
self.config.max_position_embeddings,
self.config.hidden_size,
initializer_range=self.config.initializer_range,
name="position_embeddings",
dtype=self.dtype,
)(jnp.atleast_2d(position_ids.astype("i4")))
t_emb = FlaxBertEmbedding(
self.type_vocab_size,
self.hidden_size,
kernel_init_scale=self.kernel_init_scale,
)
self.token_type_embeddings = FlaxBertEmbedding(
self.config.type_vocab_size,
self.config.hidden_size,
initializer_range=self.config.initializer_range,
name="token_type_embeddings",
dtype=self.dtype,
)(jnp.atleast_2d(token_type_ids.astype("i4")))
)
self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
# Embed
inputs_embeds = self.word_embeddings(jnp.atleast_2d(input_ids.astype("i4")))
position_embeds = self.position_embeddings(jnp.atleast_2d(position_ids.astype("i4")))
token_type_embeddings = self.token_type_embeddings(jnp.atleast_2d(token_type_ids.astype("i4")))
# Sum all embeddings
summed_emb = w_emb + jnp.broadcast_to(p_emb, w_emb.shape) + t_emb
hidden_states = inputs_embeds + jnp.broadcast_to(position_embeds, inputs_embeds.shape) + token_type_embeddings
# Layer Norm
layer_norm = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(summed_emb)
embeddings = nn.Dropout(rate=self.dropout_rate)(layer_norm, deterministic=deterministic)
return embeddings
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
return hidden_states
class FlaxBertAttention(nn.Module):
num_heads: int
head_size: int
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
config: BertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
def setup(self):
self.self_attention = nn.attention.SelfAttention(
num_heads=self.config.num_attention_heads,
qkv_features=self.config.hidden_size,
dropout_rate=self.config.attention_probs_dropout_prob,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
bias_init=jax.nn.initializers.zeros,
name="self",
dtype=self.dtype,
)
self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype)
def __call__(self, hidden_states, attention_mask, deterministic=True):
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
self_att = nn.attention.SelfAttention(
num_heads=self.num_heads,
qkv_features=self.head_size,
dropout_rate=self.dropout_rate,
deterministic=deterministic,
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
bias_init=jax.nn.initializers.zeros,
name="self",
dtype=self.dtype,
)(hidden_states, attention_mask)
self_attn_output = self.self_attention(hidden_states, attention_mask, deterministic=deterministic)
layer_norm = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(self_att + hidden_states)
return layer_norm
hidden_states = self.layer_norm(self_attn_output + hidden_states)
return hidden_states
class FlaxBertIntermediate(nn.Module):
output_size: int
hidden_act: str = "gelu"
kernel_init_scale: float = 0.2
config: BertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, hidden_states):
hidden_states = nn.Dense(
features=self.output_size,
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
def setup(self):
self.dense = nn.Dense(
self.config.intermediate_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
name="dense",
dtype=self.dtype,
)(hidden_states)
hidden_states = ACT2FN[self.hidden_act](hidden_states)
)
self.activation = ACT2FN[self.config.hidden_act]
def __call__(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
class FlaxBertOutput(nn.Module):
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
config: BertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, intermediate_output, attention_output, deterministic: bool = True):
hidden_states = nn.Dense(
attention_output.shape[-1],
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
name="dense",
dtype=self.dtype,
)(intermediate_output)
hidden_states = nn.Dropout(rate=self.dropout_rate)(hidden_states, deterministic=deterministic)
hidden_states = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_states + attention_output)
)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype)
def __call__(self, hidden_states, attention_output, deterministic: bool = True):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
hidden_states = self.layer_norm(hidden_states + attention_output)
return hidden_states
class FlaxBertLayer(nn.Module):
num_heads: int
head_size: int
intermediate_size: int
hidden_act: str = "gelu"
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
config: BertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
attention = FlaxBertAttention(
self.num_heads,
self.head_size,
kernel_init_scale=self.kernel_init_scale,
dropout_rate=self.dropout_rate,
name="attention",
dtype=self.dtype,
)(hidden_states, attention_mask, deterministic=deterministic)
intermediate = FlaxBertIntermediate(
self.intermediate_size,
kernel_init_scale=self.kernel_init_scale,
hidden_act=self.hidden_act,
name="intermediate",
dtype=self.dtype,
)(attention)
output = FlaxBertOutput(
kernel_init_scale=self.kernel_init_scale, dropout_rate=self.dropout_rate, name="output", dtype=self.dtype
)(intermediate, attention, deterministic=deterministic)
def setup(self):
self.attention = FlaxBertAttention(self.config, name="attention", dtype=self.dtype)
self.intermediate = FlaxBertIntermediate(self.config, name="intermediate", dtype=self.dtype)
self.output = FlaxBertOutput(self.config, name="output", dtype=self.dtype)
return output
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
attention_output = self.attention(hidden_states, attention_mask, deterministic=deterministic)
hidden_states = self.intermediate(attention_output)
hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
return hidden_states
class FlaxBertLayerCollection(nn.Module):
"""
Stores N BertLayer(s)
"""
num_layers: int
num_heads: int
head_size: int
intermediate_size: int
hidden_act: str = "gelu"
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
config: BertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, inputs, attention_mask, deterministic: bool = True):
assert self.num_layers > 0, f"num_layers should be >= 1, got ({self.num_layers})"
def setup(self):
self.layers = [
FlaxBertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
]
# Initialize input / output
input_i = inputs
# Forward over all encoders
for i in range(self.num_layers):
layer = FlaxBertLayer(
self.num_heads,
self.head_size,
self.intermediate_size,
kernel_init_scale=self.kernel_init_scale,
dropout_rate=self.dropout_rate,
hidden_act=self.hidden_act,
name=f"{i}",
dtype=self.dtype,
)
input_i = layer(input_i, attention_mask, deterministic=deterministic)
return input_i
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
for layer in self.layers:
hidden_states = layer(hidden_states, attention_mask, deterministic=deterministic)
return hidden_states
class FlaxBertEncoder(nn.Module):
num_layers: int
num_heads: int
head_size: int
intermediate_size: int
hidden_act: str = "gelu"
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
config: BertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def setup(self):
self.layers = FlaxBertLayerCollection(self.config, name="layer", dtype=self.dtype)
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
layer = FlaxBertLayerCollection(
self.num_layers,
self.num_heads,
self.head_size,
self.intermediate_size,
hidden_act=self.hidden_act,
kernel_init_scale=self.kernel_init_scale,
dropout_rate=self.dropout_rate,
name="layer",
dtype=self.dtype,
)(hidden_states, attention_mask, deterministic=deterministic)
return layer
return self.layers(hidden_states, attention_mask, deterministic=deterministic)
class FlaxBertPooler(nn.Module):
kernel_init_scale: float = 0.2
config: BertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, hidden_states):
cls_token = hidden_states[:, 0]
out = nn.Dense(
hidden_states.shape[-1],
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
name="dense",
dtype=self.dtype,
)(cls_token)
return nn.tanh(out)
)
def __call__(self, hidden_states):
cls_hidden_state = hidden_states[:, 0]
cls_hidden_state = self.dense(cls_hidden_state)
return nn.tanh(cls_hidden_state)
class FlaxBertPredictionHeadTransform(nn.Module):
hidden_act: str = "gelu"
config: BertConfig
dtype: jnp.dtype = jnp.float32
@nn.compact
def setup(self):
self.dense = nn.Dense(self.config.hidden_size, name="dense", dtype=self.dtype)
self.activation = ACT2FN[self.config.hidden_act]
self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype)
def __call__(self, hidden_states):
hidden_states = nn.Dense(hidden_states.shape[-1], name="dense", dtype=self.dtype)(hidden_states)
hidden_states = ACT2FN[self.hidden_act](hidden_states)
return FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_states)
hidden_states = self.dense(hidden_states)
hidden_states = self.activation(hidden_states)
return self.layer_norm(hidden_states)
class FlaxBertLMPredictionHead(nn.Module):
vocab_size: int
hidden_act: str = "gelu"
config: BertConfig
dtype: jnp.dtype = jnp.float32
@nn.compact
def setup(self):
self.transform = FlaxBertPredictionHeadTransform(self.config, name="transform", dtype=self.dtype)
self.decoder = nn.Dense(self.config.vocab_size, name="decoder", dtype=self.dtype)
def __call__(self, hidden_states):
# TODO: The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
# Need a link between the two variables so that the bias is correctly
# resized with `resize_token_embeddings`
hidden_states = FlaxBertPredictionHeadTransform(
name="transform", hidden_act=self.hidden_act, dtype=self.dtype
)(hidden_states)
hidden_states = nn.Dense(self.vocab_size, name="decoder", dtype=self.dtype)(hidden_states)
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
return hidden_states
class FlaxBertOnlyMLMHead(nn.Module):
vocab_size: int
hidden_act: str = "gelu"
config: BertConfig
dtype: jnp.dtype = jnp.float32
@nn.compact
def setup(self):
self.mlm_head = FlaxBertLMPredictionHead(self.config, name="predictions", dtype=self.dtype)
def __call__(self, hidden_states):
hidden_states = FlaxBertLMPredictionHead(
vocab_size=self.vocab_size, hidden_act=self.hidden_act, name="predictions", dtype=self.dtype
)(hidden_states)
hidden_states = self.mlm_head(hidden_states)
return hidden_states
@ -543,20 +500,7 @@ class FlaxBertModel(FlaxBertPreTrainedModel):
def __init__(
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
):
module = FlaxBertModule(
vocab_size=config.vocab_size,
hidden_size=config.hidden_size,
type_vocab_size=config.type_vocab_size,
max_length=config.max_position_embeddings,
num_encoder_layers=config.num_hidden_layers,
num_heads=config.num_attention_heads,
head_size=config.hidden_size,
intermediate_size=config.intermediate_size,
dropout_rate=config.hidden_dropout_prob,
hidden_act=config.hidden_act,
dtype=dtype,
**kwargs,
)
module = FlaxBertModule(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
@ -592,71 +536,34 @@ class FlaxBertModel(FlaxBertPreTrainedModel):
class FlaxBertModule(nn.Module):
vocab_size: int
hidden_size: int
type_vocab_size: int
max_length: int
num_encoder_layers: int
num_heads: int
head_size: int
intermediate_size: int
hidden_act: str = "gelu"
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
config: BertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
add_pooling_layer: bool = True
@nn.compact
def setup(self):
self.embeddings = FlaxBertEmbeddings(self.config, name="embeddings", dtype=self.dtype)
self.encoder = FlaxBertEncoder(self.config, name="encoder", dtype=self.dtype)
self.pooler = FlaxBertPooler(self.config, name="pooler", dtype=self.dtype)
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
# Embedding
embeddings = FlaxBertEmbeddings(
self.vocab_size,
self.hidden_size,
self.type_vocab_size,
self.max_length,
kernel_init_scale=self.kernel_init_scale,
dropout_rate=self.dropout_rate,
name="embeddings",
dtype=self.dtype,
)(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic)
# N stacked encoding layers
encoder = FlaxBertEncoder(
self.num_encoder_layers,
self.num_heads,
self.head_size,
self.intermediate_size,
kernel_init_scale=self.kernel_init_scale,
dropout_rate=self.dropout_rate,
hidden_act=self.hidden_act,
name="encoder",
dtype=self.dtype,
)(embeddings, attention_mask, deterministic=deterministic)
hidden_states = self.embeddings(
input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
)
hidden_states = self.encoder(hidden_states, attention_mask, deterministic=deterministic)
if not self.add_pooling_layer:
return encoder
return hidden_states
pooled = FlaxBertPooler(kernel_init_scale=self.kernel_init_scale, name="pooler", dtype=self.dtype)(encoder)
return encoder, pooled
pooled = self.pooler(hidden_states)
return hidden_states, pooled
class FlaxBertForMaskedLM(FlaxBertPreTrainedModel):
def __init__(
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
):
module = FlaxBertForMaskedLMModule(
vocab_size=config.vocab_size,
type_vocab_size=config.type_vocab_size,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
head_size=config.hidden_size,
num_heads=config.num_attention_heads,
num_encoder_layers=config.num_hidden_layers,
max_length=config.max_position_embeddings,
hidden_act=config.hidden_act,
**kwargs,
)
module = FlaxBertForMaskedLMModule(config, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
@ -691,43 +598,32 @@ class FlaxBertForMaskedLM(FlaxBertPreTrainedModel):
class FlaxBertForMaskedLMModule(nn.Module):
vocab_size: int
hidden_size: int
intermediate_size: int
head_size: int
num_heads: int
num_encoder_layers: int
type_vocab_size: int
max_length: int
hidden_act: str
dropout_rate: float = 0.0
config: BertConfig
dtype: jnp.dtype = jnp.float32
@nn.compact
def setup(self):
self.encoder = FlaxBertModule(
config=self.config,
add_pooling_layer=False,
name="bert",
)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
self.mlm_head = FlaxBertOnlyMLMHead(
config=self.config,
name="cls",
dtype=self.dtype,
)
def __call__(
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
):
# Model
encoder = FlaxBertModule(
vocab_size=self.vocab_size,
type_vocab_size=self.type_vocab_size,
hidden_size=self.hidden_size,
intermediate_size=self.intermediate_size,
head_size=self.hidden_size,
num_heads=self.num_heads,
num_encoder_layers=self.num_encoder_layers,
max_length=self.max_length,
dropout_rate=self.dropout_rate,
hidden_act=self.hidden_act,
dtype=self.dtype,
add_pooling_layer=False,
name="bert",
)(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic)
hidden_states = self.encoder(
input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic
)
# Compute the prediction scores
encoder = nn.Dropout(rate=self.dropout_rate)(encoder, deterministic=deterministic)
logits = FlaxBertOnlyMLMHead(
vocab_size=self.vocab_size, hidden_act=self.hidden_act, name="cls", dtype=self.dtype
)(encoder)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
logits = self.mlm_head(hidden_states)
return (logits,)

View File

@ -114,6 +114,7 @@ class FlaxRobertaLayerNorm(nn.Module):
Layer normalization (https://arxiv.org/abs/1607.06450). Operates on the last axis of the input data.
"""
hidden_size: int
epsilon: float = 1e-6
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
bias: bool = True # If True, bias (beta) is added.
@ -123,7 +124,10 @@ class FlaxRobertaLayerNorm(nn.Module):
scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
@nn.compact
def setup(self):
self.gamma = self.param("gamma", self.scale_init, (self.hidden_size,))
self.beta = self.param("beta", self.scale_init, (self.hidden_size,))
def __call__(self, x):
"""
Applies layer normalization on the input. It normalizes the activations of the layer for each given example in
@ -136,18 +140,17 @@ class FlaxRobertaLayerNorm(nn.Module):
Returns:
Normalized inputs (the same shape as inputs).
"""
features = x.shape[-1]
mean = jnp.mean(x, axis=-1, keepdims=True)
mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
var = mean2 - jax.lax.square(mean)
mul = jax.lax.rsqrt(var + self.epsilon)
if self.scale:
mul = mul * jnp.asarray(self.param("gamma", self.scale_init, (features,)))
mul = mul * jnp.asarray(self.gamma)
y = (x - mean) * mul
if self.bias:
y = y + jnp.asarray(self.param("beta", self.bias_init, (features,)))
y = y + jnp.asarray(self.beta)
return y
@ -160,243 +163,202 @@ class FlaxRobertaEmbedding(nn.Module):
vocab_size: int
hidden_size: int
kernel_init_scale: float = 0.2
emb_init: Callable[..., np.ndarray] = jax.nn.initializers.normal(stddev=kernel_init_scale)
initializer_range: float
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, inputs):
embedding = self.param("weight", self.emb_init, (self.vocab_size, self.hidden_size))
return jnp.take(embedding, inputs, axis=0)
def setup(self):
init_fn: Callable[..., np.ndarray] = jax.nn.initializers.normal(stddev=self.initializer_range)
self.embeddings = self.param("weight", init_fn, (self.vocab_size, self.hidden_size))
def __call__(self, input_ids):
return jnp.take(self.embeddings, input_ids, axis=0)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings with Bert->Roberta
class FlaxRobertaEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""
vocab_size: int
hidden_size: int
type_vocab_size: int
max_length: int
kernel_init_scale: float = 0.2
dropout_rate: float = 0.0
config: RobertaConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
# Embed
w_emb = FlaxRobertaEmbedding(
self.vocab_size,
self.hidden_size,
kernel_init_scale=self.kernel_init_scale,
def setup(self):
self.word_embeddings = FlaxRobertaEmbedding(
self.config.vocab_size,
self.config.hidden_size,
initializer_range=self.config.initializer_range,
name="word_embeddings",
dtype=self.dtype,
)(jnp.atleast_2d(input_ids.astype("i4")))
p_emb = FlaxRobertaEmbedding(
self.max_length,
self.hidden_size,
kernel_init_scale=self.kernel_init_scale,
)
self.position_embeddings = FlaxRobertaEmbedding(
self.config.max_position_embeddings,
self.config.hidden_size,
initializer_range=self.config.initializer_range,
name="position_embeddings",
dtype=self.dtype,
)(jnp.atleast_2d(position_ids.astype("i4")))
t_emb = FlaxRobertaEmbedding(
self.type_vocab_size,
self.hidden_size,
kernel_init_scale=self.kernel_init_scale,
)
self.token_type_embeddings = FlaxRobertaEmbedding(
self.config.type_vocab_size,
self.config.hidden_size,
initializer_range=self.config.initializer_range,
name="token_type_embeddings",
dtype=self.dtype,
)(jnp.atleast_2d(token_type_ids.astype("i4")))
)
self.layer_norm = FlaxRobertaLayerNorm(
hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype
)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
# Embed
inputs_embeds = self.word_embeddings(jnp.atleast_2d(input_ids.astype("i4")))
position_embeds = self.position_embeddings(jnp.atleast_2d(position_ids.astype("i4")))
token_type_embeddings = self.token_type_embeddings(jnp.atleast_2d(token_type_ids.astype("i4")))
# Sum all embeddings
summed_emb = w_emb + jnp.broadcast_to(p_emb, w_emb.shape) + t_emb
hidden_states = inputs_embeds + jnp.broadcast_to(position_embeds, inputs_embeds.shape) + token_type_embeddings
# Layer Norm
layer_norm = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(summed_emb)
embeddings = nn.Dropout(rate=self.dropout_rate)(layer_norm, deterministic=deterministic)
return embeddings
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
return hidden_states
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Roberta
class FlaxRobertaAttention(nn.Module):
num_heads: int
head_size: int
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
config: RobertaConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
def setup(self):
self.self_attention = nn.attention.SelfAttention(
num_heads=self.config.num_attention_heads,
qkv_features=self.config.hidden_size,
dropout_rate=self.config.attention_probs_dropout_prob,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
bias_init=jax.nn.initializers.zeros,
name="self",
dtype=self.dtype,
)
self.layer_norm = FlaxRobertaLayerNorm(
hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype
)
def __call__(self, hidden_states, attention_mask, deterministic=True):
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
self_att = nn.attention.SelfAttention(
num_heads=self.num_heads,
qkv_features=self.head_size,
dropout_rate=self.dropout_rate,
deterministic=deterministic,
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
bias_init=jax.nn.initializers.zeros,
name="self",
dtype=self.dtype,
)(hidden_states, attention_mask)
self_attn_output = self.self_attention(hidden_states, attention_mask, deterministic=deterministic)
layer_norm = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(self_att + hidden_states)
return layer_norm
hidden_states = self.layer_norm(self_attn_output + hidden_states)
return hidden_states
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta
class FlaxRobertaIntermediate(nn.Module):
output_size: int
hidden_act: str = "gelu"
kernel_init_scale: float = 0.2
config: RobertaConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, hidden_states):
hidden_states = nn.Dense(
features=self.output_size,
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
def setup(self):
self.dense = nn.Dense(
self.config.intermediate_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
name="dense",
dtype=self.dtype,
)(hidden_states)
hidden_states = ACT2FN[self.hidden_act](hidden_states)
)
self.activation = ACT2FN[self.config.hidden_act]
def __call__(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->Roberta
class FlaxRobertaOutput(nn.Module):
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
config: RobertaConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, intermediate_output, attention_output, deterministic: bool = True):
hidden_states = nn.Dense(
attention_output.shape[-1],
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
name="dense",
dtype=self.dtype,
)(intermediate_output)
hidden_states = nn.Dropout(rate=self.dropout_rate)(hidden_states, deterministic=deterministic)
hidden_states = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_states + attention_output)
)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
self.layer_norm = FlaxRobertaLayerNorm(
hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype
)
def __call__(self, hidden_states, attention_output, deterministic: bool = True):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
hidden_states = self.layer_norm(hidden_states + attention_output)
return hidden_states
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayer with Bert->Roberta
class FlaxRobertaLayer(nn.Module):
num_heads: int
head_size: int
intermediate_size: int
hidden_act: str = "gelu"
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
config: RobertaConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
attention = FlaxRobertaAttention(
self.num_heads,
self.head_size,
kernel_init_scale=self.kernel_init_scale,
dropout_rate=self.dropout_rate,
name="attention",
dtype=self.dtype,
)(hidden_states, attention_mask, deterministic=deterministic)
intermediate = FlaxRobertaIntermediate(
self.intermediate_size,
kernel_init_scale=self.kernel_init_scale,
hidden_act=self.hidden_act,
name="intermediate",
dtype=self.dtype,
)(attention)
output = FlaxRobertaOutput(
kernel_init_scale=self.kernel_init_scale, dropout_rate=self.dropout_rate, name="output", dtype=self.dtype
)(intermediate, attention, deterministic=deterministic)
def setup(self):
self.attention = FlaxRobertaAttention(self.config, name="attention", dtype=self.dtype)
self.intermediate = FlaxRobertaIntermediate(self.config, name="intermediate", dtype=self.dtype)
self.output = FlaxRobertaOutput(self.config, name="output", dtype=self.dtype)
return output
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
attention_output = self.attention(hidden_states, attention_mask, deterministic=deterministic)
hidden_states = self.intermediate(attention_output)
hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
return hidden_states
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->Roberta
class FlaxRobertaLayerCollection(nn.Module):
"""
Stores N RobertaLayer(s)
"""
num_layers: int
num_heads: int
head_size: int
intermediate_size: int
hidden_act: str = "gelu"
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
config: RobertaConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, inputs, attention_mask, deterministic: bool = True):
assert self.num_layers > 0, f"num_layers should be >= 1, got ({self.num_layers})"
def setup(self):
self.layers = [
FlaxRobertaLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
]
# Initialize input / output
input_i = inputs
# Forward over all encoders
for i in range(self.num_layers):
layer = FlaxRobertaLayer(
self.num_heads,
self.head_size,
self.intermediate_size,
kernel_init_scale=self.kernel_init_scale,
dropout_rate=self.dropout_rate,
hidden_act=self.hidden_act,
name=f"{i}",
dtype=self.dtype,
)
input_i = layer(input_i, attention_mask, deterministic=deterministic)
return input_i
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
for layer in self.layers:
hidden_states = layer(hidden_states, attention_mask, deterministic=deterministic)
return hidden_states
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->Roberta
class FlaxRobertaEncoder(nn.Module):
num_layers: int
num_heads: int
head_size: int
intermediate_size: int
hidden_act: str = "gelu"
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
config: RobertaConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def setup(self):
self.layers = FlaxRobertaLayerCollection(self.config, name="layer", dtype=self.dtype)
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
layer = FlaxRobertaLayerCollection(
self.num_layers,
self.num_heads,
self.head_size,
self.intermediate_size,
hidden_act=self.hidden_act,
kernel_init_scale=self.kernel_init_scale,
dropout_rate=self.dropout_rate,
name="layer",
dtype=self.dtype,
)(hidden_states, attention_mask, deterministic=deterministic)
return layer
return self.layers(hidden_states, attention_mask, deterministic=deterministic)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->Roberta
class FlaxRobertaPooler(nn.Module):
kernel_init_scale: float = 0.2
config: RobertaConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
@nn.compact
def __call__(self, hidden_states):
cls_token = hidden_states[:, 0]
out = nn.Dense(
hidden_states.shape[-1],
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
name="dense",
dtype=self.dtype,
)(cls_token)
return nn.tanh(out)
)
def __call__(self, hidden_states):
cls_hidden_state = hidden_states[:, 0]
cls_hidden_state = self.dense(cls_hidden_state)
return nn.tanh(cls_hidden_state)
class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
@ -520,21 +482,7 @@ class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
dtype: jnp.dtype = jnp.float32,
**kwargs
):
module = FlaxRobertaModule(
vocab_size=config.vocab_size,
hidden_size=config.hidden_size,
type_vocab_size=config.type_vocab_size,
max_length=config.max_position_embeddings,
num_encoder_layers=config.num_hidden_layers,
num_heads=config.num_attention_heads,
head_size=config.hidden_size,
hidden_act=config.hidden_act,
intermediate_size=config.intermediate_size,
dropout_rate=config.hidden_dropout_prob,
dtype=dtype,
**kwargs,
)
module = FlaxRobertaModule(config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ -570,50 +518,24 @@ class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta
class FlaxRobertaModule(nn.Module):
vocab_size: int
hidden_size: int
type_vocab_size: int
max_length: int
num_encoder_layers: int
num_heads: int
head_size: int
intermediate_size: int
hidden_act: str = "gelu"
dropout_rate: float = 0.0
kernel_init_scale: float = 0.2
config: RobertaConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
add_pooling_layer: bool = True
@nn.compact
def setup(self):
self.embeddings = FlaxRobertaEmbeddings(self.config, name="embeddings", dtype=self.dtype)
self.encoder = FlaxRobertaEncoder(self.config, name="encoder", dtype=self.dtype)
self.pooler = FlaxRobertaPooler(self.config, name="pooler", dtype=self.dtype)
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
# Embedding
embeddings = FlaxRobertaEmbeddings(
self.vocab_size,
self.hidden_size,
self.type_vocab_size,
self.max_length,
kernel_init_scale=self.kernel_init_scale,
dropout_rate=self.dropout_rate,
name="embeddings",
dtype=self.dtype,
)(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic)
# N stacked encoding layers
encoder = FlaxRobertaEncoder(
self.num_encoder_layers,
self.num_heads,
self.head_size,
self.intermediate_size,
kernel_init_scale=self.kernel_init_scale,
dropout_rate=self.dropout_rate,
hidden_act=self.hidden_act,
name="encoder",
dtype=self.dtype,
)(embeddings, attention_mask, deterministic=deterministic)
hidden_states = self.embeddings(
input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
)
hidden_states = self.encoder(hidden_states, attention_mask, deterministic=deterministic)
if not self.add_pooling_layer:
return encoder
return hidden_states
pooled = FlaxRobertaPooler(kernel_init_scale=self.kernel_init_scale, name="pooler", dtype=self.dtype)(encoder)
return encoder, pooled
pooled = self.pooler(hidden_states)
return hidden_states, pooled

View File

@ -60,6 +60,7 @@ def random_attention_mask(shape, rng=None):
return attn_mask
@require_flax
class FlaxModelTesterMixin:
model_tester = None
all_model_classes = ()
@ -90,7 +91,7 @@ class FlaxModelTesterMixin:
fx_outputs = fx_model(**inputs_dict)
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
self.assert_almost_equals(fx_output, pt_output.numpy(), 2e-3)
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
@ -103,7 +104,6 @@ class FlaxModelTesterMixin:
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 5e-3)
@require_flax
def test_from_pretrained_save_pretrained(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -121,7 +121,6 @@ class FlaxModelTesterMixin:
for output_loaded, output in zip(outputs_loaded, outputs):
self.assert_almost_equals(output_loaded, output, 5e-3)
@require_flax
def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -144,7 +143,6 @@ class FlaxModelTesterMixin:
for jitted_output, output in zip(jitted_outputs, outputs):
self.assertEqual(jitted_output.shape, output.shape)
@require_flax
def test_naming_convention(self):
for model_class in self.all_model_classes:
model_class_name = model_class.__name__