Mirror of https://github.com/huggingface/transformers.git
[Flax] Fix BERT initialization & token_type_ids default (#11695)
* fix some stuff

* fix roberta & electra as well

* del run bug

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
parent daf0d6a97b
commit 57b6a80de8
@@ -558,7 +558,9 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}
 
-        return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
+        return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids, return_dict=False)[
+            "params"
+        ]
 
     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     def __call__(
@@ -587,7 +589,7 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
 
         # init input tensors if not passed
         if token_type_ids is None:
-            token_type_ids = jnp.ones_like(input_ids)
+            token_type_ids = jnp.zeros_like(input_ids)
 
         if position_ids is None:
             position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
@@ -502,14 +502,16 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel):
     def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
         # init input tensors
         input_ids = jnp.zeros(input_shape, dtype="i4")
-        token_type_ids = jnp.ones_like(input_ids)
+        token_type_ids = jnp.zeros_like(input_ids)
         position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
         attention_mask = jnp.ones_like(input_ids)
 
         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}
 
-        return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
+        return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids, return_dict=False)[
+            "params"
+        ]
 
     @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     def __call__(
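For context, here is a minimal self-contained sketch of the init_weights pattern the hunks above settle on: dummy inputs with all-zero token_type_ids (segment id 0, the value real single-sentence inputs use) and return_dict=False passed through to module.init. ToyModule and its sizes are invented stand-ins for illustration, not the actual Flax*Module implementations.

import jax
import jax.numpy as jnp
import flax.linen as nn


class ToyModule(nn.Module):
    # Stand-in for a FlaxBertModule/FlaxElectraModule; only the call signature mirrors them.
    @nn.compact
    def __call__(self, input_ids, attention_mask, token_type_ids, position_ids,
                 deterministic: bool = True, return_dict: bool = True):
        hidden = nn.Embed(num_embeddings=128, features=16)(input_ids)
        hidden = hidden + nn.Embed(num_embeddings=2, features=16)(token_type_ids)
        hidden = nn.Dense(16)(hidden) * attention_mask[..., None]
        return {"last_hidden_state": hidden} if return_dict else (hidden,)


# Dummy inputs, built the same way as in the diff above
input_shape = (1, 8)
input_ids = jnp.zeros(input_shape, dtype="i4")
token_type_ids = jnp.zeros_like(input_ids)  # all zeros: the new default segment id
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
attention_mask = jnp.ones_like(input_ids)

params_rng, dropout_rng = jax.random.split(jax.random.PRNGKey(0))
rngs = {"params": params_rng, "dropout": dropout_rng}

# return_dict=False keeps the traced output a plain tuple during initialization
params = ToyModule().init(
    rngs, input_ids, attention_mask, token_type_ids, position_ids, return_dict=False
)["params"]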
@@ -546,7 +546,9 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}
 
-        return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
+        return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids, return_dict=False)[
+            "params"
+        ]
 
     @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     def __call__(
@@ -575,7 +577,7 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
 
         # init input tensors if not passed
         if token_type_ids is None:
-            token_type_ids = jnp.ones_like(input_ids)
+            token_type_ids = jnp.zeros_like(input_ids)
 
         if position_ids is None:
             position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
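A quick usage sketch of the user-facing effect (the checkpoint name and input text are arbitrary, and this assumes bert-base-uncased ships Flax weights on the Hub): omitting token_type_ids is now equivalent to passing an all-zero tensor, which is what an encoded single sentence actually uses, rather than all ones.

import jax.numpy as jnp
from transformers import BertTokenizerFast, FlaxBertModel

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
model = FlaxBertModel.from_pretrained("bert-base-uncased")

encoded = tokenizer("Flax BERT now defaults token_type_ids to zeros.", return_tensors="np")
input_ids = encoded["input_ids"]
attention_mask = encoded["attention_mask"]

# Relying on the new default ...
out_default = model(input_ids, attention_mask=attention_mask)
# ... matches passing segment id 0 explicitly.
out_zeros = model(input_ids, attention_mask=attention_mask,
                  token_type_ids=jnp.zeros_like(input_ids))

assert jnp.allclose(out_default[0], out_zeros[0])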