Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 19:21:31 +06:00)
[Flax] Adapt Flax models to new structure (#9484)
* Create modeling_flax_eletra with code copied from modeling_flax_bert
* Add ElectraForMaskedLM and ElectraForPretraining
* Add modeling test for Flax electra and fix naming and arg in Flax Electra model
* Add documentation
* Fix code style
* Fix code quality
* Adjust tol in assert_almost_equal due to very small difference between model output, ranging 0.0010 - 0.0016
* Remove redundant ElectraPooler
* save intermediate
* adapt
* correct bert flax design
* adapt roberta as well
* finish roberta flax
* finish
* apply suggestions

Co-authored-by: Chris Nguyen <anhtu2687@gmail.com>
parent 5c0bf39782
commit 0b98ca368f
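The structural change this commit applies is easiest to see in miniature: the old Flax modules carried individual hyperparameter attributes and built their submodules inline inside an @nn.compact __call__, while the new structure gives every module a single config attribute and declares named submodules once in setup(). The sketch below is not code from the repository; ToyConfig, OldStyleBlock and NewStyleBlock are invented names, and it assumes a recent flax.linen API.

# Minimal, hypothetical sketch (not repository code) of the refactor this commit performs.
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import flax.linen as nn


@dataclass
class ToyConfig:
    hidden_size: int = 8
    initializer_range: float = 0.02


class OldStyleBlock(nn.Module):
    # Old pattern: hyperparameters are separate attributes and the submodule is
    # created inline inside an @nn.compact __call__.
    hidden_size: int
    kernel_init_scale: float = 0.2
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, hidden_states):
        return nn.Dense(
            self.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
            name="dense",
            dtype=self.dtype,
        )(hidden_states)


class NewStyleBlock(nn.Module):
    # New pattern: the module carries the whole config and declares a named
    # submodule once in setup(); __call__ only applies it.
    config: ToyConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # The attribute name ("dense") becomes the parameter scope name.
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        return self.dense(hidden_states)


if __name__ == "__main__":
    x = jnp.ones((1, 4, 8))
    rng = jax.random.PRNGKey(0)
    old_params = OldStyleBlock(hidden_size=8).init(rng, x)
    new_params = NewStyleBlock(config=ToyConfig()).init(rng, x)
    # Both produce a {"params": {"dense": {"kernel", "bias"}}} tree.
    print(jax.tree_util.tree_map(jnp.shape, old_params))
    print(jax.tree_util.tree_map(jnp.shape, new_params))

In the setup() style the parameter tree is keyed by the attribute names assigned in setup(), which is what lets modules such as FlaxBertModule shrink to a config plus a few named submodules in the diff below.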
@@ -97,6 +97,7 @@ class FlaxBertLayerNorm(nn.Module):
    Layer normalization (https://arxiv.org/abs/1607.06450). Operates on the last axis of the input data.
    """

    hidden_size: int
    epsilon: float = 1e-6
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    bias: bool = True  # If True, bias (beta) is added.

@@ -106,7 +107,10 @@ class FlaxBertLayerNorm(nn.Module):
    scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones
    bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros

    @nn.compact
    def setup(self):
        self.gamma = self.param("gamma", self.scale_init, (self.hidden_size,))
        self.beta = self.param("beta", self.scale_init, (self.hidden_size,))

    def __call__(self, x):
        """
        Applies layer normalization on the input. It normalizes the activations of the layer for each given example in

@@ -119,18 +123,17 @@ class FlaxBertLayerNorm(nn.Module):
        Returns:
          Normalized inputs (the same shape as inputs).
        """
        features = x.shape[-1]
        mean = jnp.mean(x, axis=-1, keepdims=True)
        mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
        var = mean2 - jax.lax.square(mean)
        mul = jax.lax.rsqrt(var + self.epsilon)

        if self.scale:
            mul = mul * jnp.asarray(self.param("gamma", self.scale_init, (features,)))
            mul = mul * jnp.asarray(self.gamma)
        y = (x - mean) * mul

        if self.bias:
            y = y + jnp.asarray(self.param("beta", self.bias_init, (features,)))
            y = y + jnp.asarray(self.beta)
        return y


@@ -142,278 +145,232 @@ class FlaxBertEmbedding(nn.Module):

    vocab_size: int
    hidden_size: int
    kernel_init_scale: float = 0.2
    emb_init: Callable[..., np.ndarray] = jax.nn.initializers.normal(stddev=kernel_init_scale)
    initializer_range: float
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    @nn.compact
    def __call__(self, inputs):
        embedding = self.param("weight", self.emb_init, (self.vocab_size, self.hidden_size))
        return jnp.take(embedding, inputs, axis=0)
    def setup(self):
        init_fn: Callable[..., np.ndarray] = jax.nn.initializers.normal(stddev=self.initializer_range)
        self.embeddings = self.param("weight", init_fn, (self.vocab_size, self.hidden_size))

    def __call__(self, input_ids):
        return jnp.take(self.embeddings, input_ids, axis=0)


class FlaxBertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    vocab_size: int
    hidden_size: int
    type_vocab_size: int
    max_length: int
    kernel_init_scale: float = 0.2
    dropout_rate: float = 0.0
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    @nn.compact
    def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):

        # Embed
        w_emb = FlaxBertEmbedding(
            self.vocab_size,
            self.hidden_size,
            kernel_init_scale=self.kernel_init_scale,
    def setup(self):
        self.word_embeddings = FlaxBertEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
            initializer_range=self.config.initializer_range,
            name="word_embeddings",
            dtype=self.dtype,
        )(jnp.atleast_2d(input_ids.astype("i4")))
        p_emb = FlaxBertEmbedding(
            self.max_length,
            self.hidden_size,
            kernel_init_scale=self.kernel_init_scale,
        )
        self.position_embeddings = FlaxBertEmbedding(
            self.config.max_position_embeddings,
            self.config.hidden_size,
            initializer_range=self.config.initializer_range,
            name="position_embeddings",
            dtype=self.dtype,
        )(jnp.atleast_2d(position_ids.astype("i4")))
        t_emb = FlaxBertEmbedding(
            self.type_vocab_size,
            self.hidden_size,
            kernel_init_scale=self.kernel_init_scale,
        )
        self.token_type_embeddings = FlaxBertEmbedding(
            self.config.type_vocab_size,
            self.config.hidden_size,
            initializer_range=self.config.initializer_range,
            name="token_type_embeddings",
            dtype=self.dtype,
        )(jnp.atleast_2d(token_type_ids.astype("i4")))
        )
        self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
        # Embed
        inputs_embeds = self.word_embeddings(jnp.atleast_2d(input_ids.astype("i4")))
        position_embeds = self.position_embeddings(jnp.atleast_2d(position_ids.astype("i4")))
        token_type_embeddings = self.token_type_embeddings(jnp.atleast_2d(token_type_ids.astype("i4")))

        # Sum all embeddings
        summed_emb = w_emb + jnp.broadcast_to(p_emb, w_emb.shape) + t_emb
        hidden_states = inputs_embeds + jnp.broadcast_to(position_embeds, inputs_embeds.shape) + token_type_embeddings

        # Layer Norm
        layer_norm = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(summed_emb)
        embeddings = nn.Dropout(rate=self.dropout_rate)(layer_norm, deterministic=deterministic)
        return embeddings
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        return hidden_states


class FlaxBertAttention(nn.Module):
    num_heads: int
    head_size: int
    dropout_rate: float = 0.0
    kernel_init_scale: float = 0.2
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    @nn.compact
    def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
    def setup(self):
        self.self_attention = nn.attention.SelfAttention(
            num_heads=self.config.num_attention_heads,
            qkv_features=self.config.hidden_size,
            dropout_rate=self.config.attention_probs_dropout_prob,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
            bias_init=jax.nn.initializers.zeros,
            name="self",
            dtype=self.dtype,
        )
        self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype)

    def __call__(self, hidden_states, attention_mask, deterministic=True):
        # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
        # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
        # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
        attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
        self_att = nn.attention.SelfAttention(
            num_heads=self.num_heads,
            qkv_features=self.head_size,
            dropout_rate=self.dropout_rate,
            deterministic=deterministic,
            kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
            bias_init=jax.nn.initializers.zeros,
            name="self",
            dtype=self.dtype,
        )(hidden_states, attention_mask)
        self_attn_output = self.self_attention(hidden_states, attention_mask, deterministic=deterministic)

        layer_norm = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(self_att + hidden_states)
        return layer_norm
        hidden_states = self.layer_norm(self_attn_output + hidden_states)
        return hidden_states


class FlaxBertIntermediate(nn.Module):
    output_size: int
    hidden_act: str = "gelu"
    kernel_init_scale: float = 0.2
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    @nn.compact
    def __call__(self, hidden_states):
        hidden_states = nn.Dense(
            features=self.output_size,
            kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
    def setup(self):
        self.dense = nn.Dense(
            self.config.intermediate_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
            name="dense",
            dtype=self.dtype,
        )(hidden_states)
        hidden_states = ACT2FN[self.hidden_act](hidden_states)
        )
        self.activation = ACT2FN[self.config.hidden_act]

    def __call__(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation(hidden_states)
        return hidden_states


class FlaxBertOutput(nn.Module):
    dropout_rate: float = 0.0
    kernel_init_scale: float = 0.2
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    @nn.compact
    def __call__(self, intermediate_output, attention_output, deterministic: bool = True):
        hidden_states = nn.Dense(
            attention_output.shape[-1],
            kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
    def setup(self):
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
            name="dense",
            dtype=self.dtype,
        )(intermediate_output)
        hidden_states = nn.Dropout(rate=self.dropout_rate)(hidden_states, deterministic=deterministic)
        hidden_states = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_states + attention_output)
        )
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype)

    def __call__(self, hidden_states, attention_output, deterministic: bool = True):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        hidden_states = self.layer_norm(hidden_states + attention_output)
        return hidden_states


class FlaxBertLayer(nn.Module):
    num_heads: int
    head_size: int
    intermediate_size: int
    hidden_act: str = "gelu"
    dropout_rate: float = 0.0
    kernel_init_scale: float = 0.2
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    @nn.compact
    def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
        attention = FlaxBertAttention(
            self.num_heads,
            self.head_size,
            kernel_init_scale=self.kernel_init_scale,
            dropout_rate=self.dropout_rate,
            name="attention",
            dtype=self.dtype,
        )(hidden_states, attention_mask, deterministic=deterministic)
        intermediate = FlaxBertIntermediate(
            self.intermediate_size,
            kernel_init_scale=self.kernel_init_scale,
            hidden_act=self.hidden_act,
            name="intermediate",
            dtype=self.dtype,
        )(attention)
        output = FlaxBertOutput(
            kernel_init_scale=self.kernel_init_scale, dropout_rate=self.dropout_rate, name="output", dtype=self.dtype
        )(intermediate, attention, deterministic=deterministic)
    def setup(self):
        self.attention = FlaxBertAttention(self.config, name="attention", dtype=self.dtype)
        self.intermediate = FlaxBertIntermediate(self.config, name="intermediate", dtype=self.dtype)
        self.output = FlaxBertOutput(self.config, name="output", dtype=self.dtype)

        return output
    def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
        attention_output = self.attention(hidden_states, attention_mask, deterministic=deterministic)
        hidden_states = self.intermediate(attention_output)
        hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
        return hidden_states


class FlaxBertLayerCollection(nn.Module):
    """
    Stores N BertLayer(s)
    """

    num_layers: int
    num_heads: int
    head_size: int
    intermediate_size: int
    hidden_act: str = "gelu"
    dropout_rate: float = 0.0
    kernel_init_scale: float = 0.2
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    @nn.compact
    def __call__(self, inputs, attention_mask, deterministic: bool = True):
        assert self.num_layers > 0, f"num_layers should be >= 1, got ({self.num_layers})"
    def setup(self):
        self.layers = [
            FlaxBertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
        ]

        # Initialize input / output
        input_i = inputs

        # Forward over all encoders
        for i in range(self.num_layers):
            layer = FlaxBertLayer(
                self.num_heads,
                self.head_size,
                self.intermediate_size,
                kernel_init_scale=self.kernel_init_scale,
                dropout_rate=self.dropout_rate,
                hidden_act=self.hidden_act,
                name=f"{i}",
                dtype=self.dtype,
            )
            input_i = layer(input_i, attention_mask, deterministic=deterministic)
        return input_i
    def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask, deterministic=deterministic)
        return hidden_states


class FlaxBertEncoder(nn.Module):
    num_layers: int
    num_heads: int
    head_size: int
    intermediate_size: int
    hidden_act: str = "gelu"
    dropout_rate: float = 0.0
    kernel_init_scale: float = 0.2
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    @nn.compact
    def setup(self):
        self.layers = FlaxBertLayerCollection(self.config, name="layer", dtype=self.dtype)

    def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
        layer = FlaxBertLayerCollection(
            self.num_layers,
            self.num_heads,
            self.head_size,
            self.intermediate_size,
            hidden_act=self.hidden_act,
            kernel_init_scale=self.kernel_init_scale,
            dropout_rate=self.dropout_rate,
            name="layer",
            dtype=self.dtype,
        )(hidden_states, attention_mask, deterministic=deterministic)
        return layer
        return self.layers(hidden_states, attention_mask, deterministic=deterministic)


class FlaxBertPooler(nn.Module):
    kernel_init_scale: float = 0.2
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    @nn.compact
    def __call__(self, hidden_states):
        cls_token = hidden_states[:, 0]
        out = nn.Dense(
            hidden_states.shape[-1],
            kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
    def setup(self):
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
            name="dense",
            dtype=self.dtype,
        )(cls_token)
        return nn.tanh(out)
        )

    def __call__(self, hidden_states):
        cls_hidden_state = hidden_states[:, 0]
        cls_hidden_state = self.dense(cls_hidden_state)
        return nn.tanh(cls_hidden_state)


class FlaxBertPredictionHeadTransform(nn.Module):
    hidden_act: str = "gelu"
    config: BertConfig
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def setup(self):
        self.dense = nn.Dense(self.config.hidden_size, name="dense", dtype=self.dtype)
        self.activation = ACT2FN[self.config.hidden_act]
        self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype)

    def __call__(self, hidden_states):
        hidden_states = nn.Dense(hidden_states.shape[-1], name="dense", dtype=self.dtype)(hidden_states)
        hidden_states = ACT2FN[self.hidden_act](hidden_states)
        return FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_states)
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation(hidden_states)
        return self.layer_norm(hidden_states)


class FlaxBertLMPredictionHead(nn.Module):
    vocab_size: int
    hidden_act: str = "gelu"
    config: BertConfig
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def setup(self):
        self.transform = FlaxBertPredictionHeadTransform(self.config, name="transform", dtype=self.dtype)
        self.decoder = nn.Dense(self.config.vocab_size, name="decoder", dtype=self.dtype)

    def __call__(self, hidden_states):
        # TODO: The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        # Need a link between the two variables so that the bias is correctly
        # resized with `resize_token_embeddings`

        hidden_states = FlaxBertPredictionHeadTransform(
            name="transform", hidden_act=self.hidden_act, dtype=self.dtype
        )(hidden_states)
        hidden_states = nn.Dense(self.vocab_size, name="decoder", dtype=self.dtype)(hidden_states)
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states


class FlaxBertOnlyMLMHead(nn.Module):
    vocab_size: int
    hidden_act: str = "gelu"
    config: BertConfig
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def setup(self):
        self.mlm_head = FlaxBertLMPredictionHead(self.config, name="predictions", dtype=self.dtype)

    def __call__(self, hidden_states):
        hidden_states = FlaxBertLMPredictionHead(
            vocab_size=self.vocab_size, hidden_act=self.hidden_act, name="predictions", dtype=self.dtype
        )(hidden_states)
        hidden_states = self.mlm_head(hidden_states)
        return hidden_states

@@ -543,20 +500,7 @@ class FlaxBertModel(FlaxBertPreTrainedModel):
    def __init__(
        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
    ):
        module = FlaxBertModule(
            vocab_size=config.vocab_size,
            hidden_size=config.hidden_size,
            type_vocab_size=config.type_vocab_size,
            max_length=config.max_position_embeddings,
            num_encoder_layers=config.num_hidden_layers,
            num_heads=config.num_attention_heads,
            head_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            dropout_rate=config.hidden_dropout_prob,
            hidden_act=config.hidden_act,
            dtype=dtype,
            **kwargs,
        )
        module = FlaxBertModule(config=config, dtype=dtype, **kwargs)

        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

@@ -592,71 +536,34 @@ class FlaxBertModel(FlaxBertPreTrainedModel):


class FlaxBertModule(nn.Module):
    vocab_size: int
    hidden_size: int
    type_vocab_size: int
    max_length: int
    num_encoder_layers: int
    num_heads: int
    head_size: int
    intermediate_size: int
    hidden_act: str = "gelu"
    dropout_rate: float = 0.0
    kernel_init_scale: float = 0.2
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    add_pooling_layer: bool = True

    @nn.compact
    def setup(self):
        self.embeddings = FlaxBertEmbeddings(self.config, name="embeddings", dtype=self.dtype)
        self.encoder = FlaxBertEncoder(self.config, name="encoder", dtype=self.dtype)
        self.pooler = FlaxBertPooler(self.config, name="pooler", dtype=self.dtype)

    def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):

        # Embedding
        embeddings = FlaxBertEmbeddings(
            self.vocab_size,
            self.hidden_size,
            self.type_vocab_size,
            self.max_length,
            kernel_init_scale=self.kernel_init_scale,
            dropout_rate=self.dropout_rate,
            name="embeddings",
            dtype=self.dtype,
        )(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic)

        # N stacked encoding layers
        encoder = FlaxBertEncoder(
            self.num_encoder_layers,
            self.num_heads,
            self.head_size,
            self.intermediate_size,
            kernel_init_scale=self.kernel_init_scale,
            dropout_rate=self.dropout_rate,
            hidden_act=self.hidden_act,
            name="encoder",
            dtype=self.dtype,
        )(embeddings, attention_mask, deterministic=deterministic)
        hidden_states = self.embeddings(
            input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
        )
        hidden_states = self.encoder(hidden_states, attention_mask, deterministic=deterministic)

        if not self.add_pooling_layer:
            return encoder
            return hidden_states

        pooled = FlaxBertPooler(kernel_init_scale=self.kernel_init_scale, name="pooler", dtype=self.dtype)(encoder)
        return encoder, pooled
        pooled = self.pooler(hidden_states)
        return hidden_states, pooled


class FlaxBertForMaskedLM(FlaxBertPreTrainedModel):
    def __init__(
        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
    ):
        module = FlaxBertForMaskedLMModule(
            vocab_size=config.vocab_size,
            type_vocab_size=config.type_vocab_size,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            head_size=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_encoder_layers=config.num_hidden_layers,
            max_length=config.max_position_embeddings,
            hidden_act=config.hidden_act,
            **kwargs,
        )
        module = FlaxBertForMaskedLMModule(config, **kwargs)

        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

@@ -691,43 +598,32 @@ class FlaxBertForMaskedLM(FlaxBertPreTrainedModel):


class FlaxBertForMaskedLMModule(nn.Module):
    vocab_size: int
    hidden_size: int
    intermediate_size: int
    head_size: int
    num_heads: int
    num_encoder_layers: int
    type_vocab_size: int
    max_length: int
    hidden_act: str
    dropout_rate: float = 0.0
    config: BertConfig
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def setup(self):
        self.encoder = FlaxBertModule(
            config=self.config,
            add_pooling_layer=False,
            name="bert",
        )
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        self.mlm_head = FlaxBertOnlyMLMHead(
            config=self.config,
            name="cls",
            dtype=self.dtype,
        )

    def __call__(
        self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
    ):
        # Model
        encoder = FlaxBertModule(
            vocab_size=self.vocab_size,
            type_vocab_size=self.type_vocab_size,
            hidden_size=self.hidden_size,
            intermediate_size=self.intermediate_size,
            head_size=self.hidden_size,
            num_heads=self.num_heads,
            num_encoder_layers=self.num_encoder_layers,
            max_length=self.max_length,
            dropout_rate=self.dropout_rate,
            hidden_act=self.hidden_act,
            dtype=self.dtype,
            add_pooling_layer=False,
            name="bert",
        )(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic)
        hidden_states = self.encoder(
            input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic
        )

        # Compute the prediction scores
        encoder = nn.Dropout(rate=self.dropout_rate)(encoder, deterministic=deterministic)
        logits = FlaxBertOnlyMLMHead(
            vocab_size=self.vocab_size, hidden_act=self.hidden_act, name="cls", dtype=self.dtype
        )(encoder)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        logits = self.mlm_head(hidden_states)

        return (logits,)
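A quick, illustrative shape check (not part of the diff) for the attention-mask comment in FlaxBertAttention.__call__ above: Flax's SelfAttention expects a mask that broadcasts against attention weights of shape (*batch_sizes, num_heads, q_length, kv_length).

import jax.numpy as jnp

attention_mask = jnp.ones((2, 5))                                # (*batch_sizes, kv_length)
expanded_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))   # insert the head and query axes
print(expanded_mask.shape)                                       # (2, 1, 1, 5)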
@@ -114,6 +114,7 @@ class FlaxRobertaLayerNorm(nn.Module):
    Layer normalization (https://arxiv.org/abs/1607.06450). Operates on the last axis of the input data.
    """

    hidden_size: int
    epsilon: float = 1e-6
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    bias: bool = True  # If True, bias (beta) is added.

@@ -123,7 +124,10 @@ class FlaxRobertaLayerNorm(nn.Module):
    scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones
    bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros

    @nn.compact
    def setup(self):
        self.gamma = self.param("gamma", self.scale_init, (self.hidden_size,))
        self.beta = self.param("beta", self.scale_init, (self.hidden_size,))

    def __call__(self, x):
        """
        Applies layer normalization on the input. It normalizes the activations of the layer for each given example in

@@ -136,18 +140,17 @@ class FlaxRobertaLayerNorm(nn.Module):
        Returns:
          Normalized inputs (the same shape as inputs).
        """
        features = x.shape[-1]
        mean = jnp.mean(x, axis=-1, keepdims=True)
        mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
        var = mean2 - jax.lax.square(mean)
        mul = jax.lax.rsqrt(var + self.epsilon)

        if self.scale:
            mul = mul * jnp.asarray(self.param("gamma", self.scale_init, (features,)))
            mul = mul * jnp.asarray(self.gamma)
        y = (x - mean) * mul

        if self.bias:
            y = y + jnp.asarray(self.param("beta", self.bias_init, (features,)))
            y = y + jnp.asarray(self.beta)
        return y


@@ -160,243 +163,202 @@ class FlaxRobertaEmbedding(nn.Module):

    vocab_size: int
    hidden_size: int
    kernel_init_scale: float = 0.2
    emb_init: Callable[..., np.ndarray] = jax.nn.initializers.normal(stddev=kernel_init_scale)
    initializer_range: float
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    @nn.compact
    def __call__(self, inputs):
        embedding = self.param("weight", self.emb_init, (self.vocab_size, self.hidden_size))
        return jnp.take(embedding, inputs, axis=0)
    def setup(self):
        init_fn: Callable[..., np.ndarray] = jax.nn.initializers.normal(stddev=self.initializer_range)
        self.embeddings = self.param("weight", init_fn, (self.vocab_size, self.hidden_size))

    def __call__(self, input_ids):
        return jnp.take(self.embeddings, input_ids, axis=0)


# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings with Bert->Roberta
class FlaxRobertaEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    vocab_size: int
    hidden_size: int
    type_vocab_size: int
    max_length: int
    kernel_init_scale: float = 0.2
    dropout_rate: float = 0.0
    config: RobertaConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    @nn.compact
    def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):

        # Embed
        w_emb = FlaxRobertaEmbedding(
            self.vocab_size,
            self.hidden_size,
            kernel_init_scale=self.kernel_init_scale,
    def setup(self):
        self.word_embeddings = FlaxRobertaEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
            initializer_range=self.config.initializer_range,
            name="word_embeddings",
            dtype=self.dtype,
        )(jnp.atleast_2d(input_ids.astype("i4")))
        p_emb = FlaxRobertaEmbedding(
            self.max_length,
            self.hidden_size,
            kernel_init_scale=self.kernel_init_scale,
        )
        self.position_embeddings = FlaxRobertaEmbedding(
            self.config.max_position_embeddings,
            self.config.hidden_size,
            initializer_range=self.config.initializer_range,
            name="position_embeddings",
            dtype=self.dtype,
        )(jnp.atleast_2d(position_ids.astype("i4")))
        t_emb = FlaxRobertaEmbedding(
            self.type_vocab_size,
            self.hidden_size,
            kernel_init_scale=self.kernel_init_scale,
        )
        self.token_type_embeddings = FlaxRobertaEmbedding(
            self.config.type_vocab_size,
            self.config.hidden_size,
            initializer_range=self.config.initializer_range,
            name="token_type_embeddings",
            dtype=self.dtype,
        )(jnp.atleast_2d(token_type_ids.astype("i4")))
        )
        self.layer_norm = FlaxRobertaLayerNorm(
            hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype
        )
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
        # Embed
        inputs_embeds = self.word_embeddings(jnp.atleast_2d(input_ids.astype("i4")))
        position_embeds = self.position_embeddings(jnp.atleast_2d(position_ids.astype("i4")))
        token_type_embeddings = self.token_type_embeddings(jnp.atleast_2d(token_type_ids.astype("i4")))

        # Sum all embeddings
        summed_emb = w_emb + jnp.broadcast_to(p_emb, w_emb.shape) + t_emb
        hidden_states = inputs_embeds + jnp.broadcast_to(position_embeds, inputs_embeds.shape) + token_type_embeddings

        # Layer Norm
        layer_norm = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(summed_emb)
        embeddings = nn.Dropout(rate=self.dropout_rate)(layer_norm, deterministic=deterministic)
        return embeddings
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        return hidden_states


# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Roberta
class FlaxRobertaAttention(nn.Module):
    num_heads: int
    head_size: int
    dropout_rate: float = 0.0
    kernel_init_scale: float = 0.2
    config: RobertaConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    @nn.compact
    def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
    def setup(self):
        self.self_attention = nn.attention.SelfAttention(
            num_heads=self.config.num_attention_heads,
            qkv_features=self.config.hidden_size,
            dropout_rate=self.config.attention_probs_dropout_prob,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
            bias_init=jax.nn.initializers.zeros,
            name="self",
            dtype=self.dtype,
        )
        self.layer_norm = FlaxRobertaLayerNorm(
            hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype
        )

    def __call__(self, hidden_states, attention_mask, deterministic=True):
        # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
        # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
        # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
        attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
        self_att = nn.attention.SelfAttention(
            num_heads=self.num_heads,
            qkv_features=self.head_size,
            dropout_rate=self.dropout_rate,
            deterministic=deterministic,
            kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
            bias_init=jax.nn.initializers.zeros,
            name="self",
            dtype=self.dtype,
        )(hidden_states, attention_mask)
        self_attn_output = self.self_attention(hidden_states, attention_mask, deterministic=deterministic)

        layer_norm = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(self_att + hidden_states)
        return layer_norm
        hidden_states = self.layer_norm(self_attn_output + hidden_states)
        return hidden_states


# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta
class FlaxRobertaIntermediate(nn.Module):
    output_size: int
    hidden_act: str = "gelu"
    kernel_init_scale: float = 0.2
    config: RobertaConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    @nn.compact
    def __call__(self, hidden_states):
        hidden_states = nn.Dense(
            features=self.output_size,
            kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
    def setup(self):
        self.dense = nn.Dense(
            self.config.intermediate_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
            name="dense",
            dtype=self.dtype,
        )(hidden_states)
        hidden_states = ACT2FN[self.hidden_act](hidden_states)
        )
        self.activation = ACT2FN[self.config.hidden_act]

    def __call__(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation(hidden_states)
        return hidden_states


# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->Roberta
class FlaxRobertaOutput(nn.Module):
    dropout_rate: float = 0.0
    kernel_init_scale: float = 0.2
    config: RobertaConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    @nn.compact
    def __call__(self, intermediate_output, attention_output, deterministic: bool = True):
        hidden_states = nn.Dense(
            attention_output.shape[-1],
            kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
    def setup(self):
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
            name="dense",
            dtype=self.dtype,
        )(intermediate_output)
        hidden_states = nn.Dropout(rate=self.dropout_rate)(hidden_states, deterministic=deterministic)
        hidden_states = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_states + attention_output)
        )
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        self.layer_norm = FlaxRobertaLayerNorm(
            hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype
        )

    def __call__(self, hidden_states, attention_output, deterministic: bool = True):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        hidden_states = self.layer_norm(hidden_states + attention_output)
        return hidden_states


# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayer with Bert->Roberta
class FlaxRobertaLayer(nn.Module):
    num_heads: int
    head_size: int
    intermediate_size: int
    hidden_act: str = "gelu"
    dropout_rate: float = 0.0
    kernel_init_scale: float = 0.2
    config: RobertaConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    @nn.compact
    def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
        attention = FlaxRobertaAttention(
            self.num_heads,
            self.head_size,
            kernel_init_scale=self.kernel_init_scale,
            dropout_rate=self.dropout_rate,
            name="attention",
            dtype=self.dtype,
        )(hidden_states, attention_mask, deterministic=deterministic)
        intermediate = FlaxRobertaIntermediate(
            self.intermediate_size,
            kernel_init_scale=self.kernel_init_scale,
            hidden_act=self.hidden_act,
            name="intermediate",
            dtype=self.dtype,
        )(attention)
        output = FlaxRobertaOutput(
            kernel_init_scale=self.kernel_init_scale, dropout_rate=self.dropout_rate, name="output", dtype=self.dtype
        )(intermediate, attention, deterministic=deterministic)
    def setup(self):
        self.attention = FlaxRobertaAttention(self.config, name="attention", dtype=self.dtype)
        self.intermediate = FlaxRobertaIntermediate(self.config, name="intermediate", dtype=self.dtype)
        self.output = FlaxRobertaOutput(self.config, name="output", dtype=self.dtype)

        return output
    def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
        attention_output = self.attention(hidden_states, attention_mask, deterministic=deterministic)
        hidden_states = self.intermediate(attention_output)
        hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
        return hidden_states


# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->Roberta
class FlaxRobertaLayerCollection(nn.Module):
    """
    Stores N RobertaLayer(s)
    """

    num_layers: int
    num_heads: int
    head_size: int
    intermediate_size: int
    hidden_act: str = "gelu"
    dropout_rate: float = 0.0
    kernel_init_scale: float = 0.2
    config: RobertaConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    @nn.compact
    def __call__(self, inputs, attention_mask, deterministic: bool = True):
        assert self.num_layers > 0, f"num_layers should be >= 1, got ({self.num_layers})"
    def setup(self):
        self.layers = [
            FlaxRobertaLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
        ]

        # Initialize input / output
        input_i = inputs

        # Forward over all encoders
        for i in range(self.num_layers):
            layer = FlaxRobertaLayer(
                self.num_heads,
                self.head_size,
                self.intermediate_size,
                kernel_init_scale=self.kernel_init_scale,
                dropout_rate=self.dropout_rate,
                hidden_act=self.hidden_act,
                name=f"{i}",
                dtype=self.dtype,
            )
            input_i = layer(input_i, attention_mask, deterministic=deterministic)
        return input_i
    def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask, deterministic=deterministic)
        return hidden_states


# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->Roberta
class FlaxRobertaEncoder(nn.Module):
    num_layers: int
    num_heads: int
    head_size: int
    intermediate_size: int
    hidden_act: str = "gelu"
    dropout_rate: float = 0.0
    kernel_init_scale: float = 0.2
    config: RobertaConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    @nn.compact
    def setup(self):
        self.layers = FlaxRobertaLayerCollection(self.config, name="layer", dtype=self.dtype)

    def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
        layer = FlaxRobertaLayerCollection(
            self.num_layers,
            self.num_heads,
            self.head_size,
            self.intermediate_size,
            hidden_act=self.hidden_act,
            kernel_init_scale=self.kernel_init_scale,
            dropout_rate=self.dropout_rate,
            name="layer",
            dtype=self.dtype,
        )(hidden_states, attention_mask, deterministic=deterministic)
        return layer
        return self.layers(hidden_states, attention_mask, deterministic=deterministic)


# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->Roberta
class FlaxRobertaPooler(nn.Module):
    kernel_init_scale: float = 0.2
    config: RobertaConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    @nn.compact
    def __call__(self, hidden_states):
        cls_token = hidden_states[:, 0]
        out = nn.Dense(
            hidden_states.shape[-1],
            kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
    def setup(self):
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
            name="dense",
            dtype=self.dtype,
        )(cls_token)
        return nn.tanh(out)
        )

    def __call__(self, hidden_states):
        cls_hidden_state = hidden_states[:, 0]
        cls_hidden_state = self.dense(cls_hidden_state)
        return nn.tanh(cls_hidden_state)


class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):

@@ -520,21 +482,7 @@ class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
        dtype: jnp.dtype = jnp.float32,
        **kwargs
    ):
        module = FlaxRobertaModule(
            vocab_size=config.vocab_size,
            hidden_size=config.hidden_size,
            type_vocab_size=config.type_vocab_size,
            max_length=config.max_position_embeddings,
            num_encoder_layers=config.num_hidden_layers,
            num_heads=config.num_attention_heads,
            head_size=config.hidden_size,
            hidden_act=config.hidden_act,
            intermediate_size=config.intermediate_size,
            dropout_rate=config.hidden_dropout_prob,
            dtype=dtype,
            **kwargs,
        )

        module = FlaxRobertaModule(config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

    @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))

@@ -570,50 +518,24 @@ class FlaxRobertaModel(FlaxRobertaPreTrainedModel):

# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta
class FlaxRobertaModule(nn.Module):
    vocab_size: int
    hidden_size: int
    type_vocab_size: int
    max_length: int
    num_encoder_layers: int
    num_heads: int
    head_size: int
    intermediate_size: int
    hidden_act: str = "gelu"
    dropout_rate: float = 0.0
    kernel_init_scale: float = 0.2
    config: RobertaConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    add_pooling_layer: bool = True

    @nn.compact
    def setup(self):
        self.embeddings = FlaxRobertaEmbeddings(self.config, name="embeddings", dtype=self.dtype)
        self.encoder = FlaxRobertaEncoder(self.config, name="encoder", dtype=self.dtype)
        self.pooler = FlaxRobertaPooler(self.config, name="pooler", dtype=self.dtype)

    def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):

        # Embedding
        embeddings = FlaxRobertaEmbeddings(
            self.vocab_size,
            self.hidden_size,
            self.type_vocab_size,
            self.max_length,
            kernel_init_scale=self.kernel_init_scale,
            dropout_rate=self.dropout_rate,
            name="embeddings",
            dtype=self.dtype,
        )(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic)

        # N stacked encoding layers
        encoder = FlaxRobertaEncoder(
            self.num_encoder_layers,
            self.num_heads,
            self.head_size,
            self.intermediate_size,
            kernel_init_scale=self.kernel_init_scale,
            dropout_rate=self.dropout_rate,
            hidden_act=self.hidden_act,
            name="encoder",
            dtype=self.dtype,
        )(embeddings, attention_mask, deterministic=deterministic)
        hidden_states = self.embeddings(
            input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
        )
        hidden_states = self.encoder(hidden_states, attention_mask, deterministic=deterministic)

        if not self.add_pooling_layer:
            return encoder
            return hidden_states

        pooled = FlaxRobertaPooler(kernel_init_scale=self.kernel_init_scale, name="pooler", dtype=self.dtype)(encoder)
        return encoder, pooled
        pooled = self.pooler(hidden_states)
        return hidden_states, pooled
@@ -60,6 +60,7 @@ def random_attention_mask(shape, rng=None):
    return attn_mask


@require_flax
class FlaxModelTesterMixin:
    model_tester = None
    all_model_classes = ()

@@ -90,7 +91,7 @@ class FlaxModelTesterMixin:
            fx_outputs = fx_model(**inputs_dict)
            self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
            for fx_output, pt_output in zip(fx_outputs, pt_outputs):
                self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
                self.assert_almost_equals(fx_output, pt_output.numpy(), 2e-3)

            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_model.save_pretrained(tmpdirname)

@@ -103,7 +104,6 @@ class FlaxModelTesterMixin:
            for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
                self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 5e-3)

    @require_flax
    def test_from_pretrained_save_pretrained(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

@@ -121,7 +121,6 @@ class FlaxModelTesterMixin:
        for output_loaded, output in zip(outputs_loaded, outputs):
            self.assert_almost_equals(output_loaded, output, 5e-3)

    @require_flax
    def test_jit_compilation(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

@@ -144,7 +143,6 @@ class FlaxModelTesterMixin:
        for jitted_output, output in zip(jitted_outputs, outputs):
            self.assertEqual(jitted_output.shape, output.shape)

    @require_flax
    def test_naming_convention(self):
        for model_class in self.all_model_classes:
            model_class_name = model_class.__name__
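The 1e-3 -> 2e-3 change in assert_almost_equals above reflects the commit message's note that Flax and PyTorch outputs differ by roughly 0.0010 to 0.0016. A hypothetical stand-in for the comparison (not the test suite's actual helper) makes the effect of the tolerance concrete:

import numpy as np

def assert_almost_equals(a: np.ndarray, b: np.ndarray, tol: float) -> None:
    # Illustrative helper only: compare the maximum absolute difference to a tolerance.
    diff = np.abs(a - b).max()
    assert diff <= tol, f"max difference {diff:.4f} exceeds tolerance {tol}"

fx_output = np.array([0.1234, 0.5678])
pt_output = fx_output + 0.0015                      # difference in the observed 0.0010-0.0016 range
assert_almost_equals(fx_output, pt_output, 2e-3)    # passes with the new tolerance
# assert_almost_equals(fx_output, pt_output, 1e-3)  # would fail with the old tolerance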