diff --git a/examples/flax/language-modeling/run_mlm_flax.py b/examples/flax/language-modeling/run_mlm_flax.py
index 1ae9f05230c..34fb9480004 100755
--- a/examples/flax/language-modeling/run_mlm_flax.py
+++ b/examples/flax/language-modeling/run_mlm_flax.py
@@ -478,7 +478,14 @@ if __name__ == "__main__":
     rng = jax.random.PRNGKey(training_args.seed)
     dropout_rngs = jax.random.split(rng, jax.local_device_count())

-    model = FlaxAutoModelForMaskedLM.from_config(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
+    if model_args.model_name_or_path:
+        model = FlaxAutoModelForMaskedLM.from_pretrained(
+            model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+        )
+    else:
+        model = FlaxAutoModelForMaskedLM.from_config(
+            config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+        )

     # Store some constant
     num_epochs = int(training_args.num_train_epochs)
diff --git a/examples/flax/language-modeling/run_t5_mlm_flax.py b/examples/flax/language-modeling/run_t5_mlm_flax.py
index 09d6753a48a..c8b6fa3da69 100755
--- a/examples/flax/language-modeling/run_t5_mlm_flax.py
+++ b/examples/flax/language-modeling/run_t5_mlm_flax.py
@@ -588,7 +588,12 @@ if __name__ == "__main__":
     rng = jax.random.PRNGKey(training_args.seed)
     dropout_rngs = jax.random.split(rng, jax.local_device_count())

-    model = FlaxT5ForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
+    if model_args.model_name_or_path:
+        model = FlaxT5ForConditionalGeneration.from_pretrained(
+            model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+        )
+    else:
+        model = FlaxT5ForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))

     # Data collator
     # This one will take care of randomly masking the tokens.
diff --git a/examples/research_projects/jax-projects/dataset-streaming/run_mlm_flax_stream.py b/examples/research_projects/jax-projects/dataset-streaming/run_mlm_flax_stream.py
index 48de9380df3..1e2c29947ee 100755
--- a/examples/research_projects/jax-projects/dataset-streaming/run_mlm_flax_stream.py
+++ b/examples/research_projects/jax-projects/dataset-streaming/run_mlm_flax_stream.py
@@ -427,7 +427,14 @@ if __name__ == "__main__":
     rng = jax.random.PRNGKey(training_args.seed)
     dropout_rngs = jax.random.split(rng, jax.local_device_count())

-    model = FlaxAutoModelForMaskedLM.from_config(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
+    if model_args.model_name_or_path:
+        model = FlaxAutoModelForMaskedLM.from_pretrained(
+            model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+        )
+    else:
+        model = FlaxAutoModelForMaskedLM.from_config(
+            config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+        )

     # Store some constant
     num_epochs = int(training_args.num_train_epochs)
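All three hunks apply the same pattern: instead of always building a randomly initialized model from the config, each script now loads pretrained weights when `model_args.model_name_or_path` is set and falls back to from-scratch initialization otherwise. Below is a minimal standalone sketch of that logic (not part of the patch), assuming "bert-base-uncased" and float32 as stand-ins for the user-supplied --model_name_or_path and --dtype values.

import jax.numpy as jnp
from transformers import AutoConfig, FlaxAutoModelForMaskedLM

# Hypothetical stand-ins for what the scripts read from model_args / training_args;
# in the examples, `config` is built earlier from --config_name or --model_name_or_path.
model_name_or_path = "bert-base-uncased"  # set to None to train from scratch
config = AutoConfig.from_pretrained("bert-base-uncased")
dtype = getattr(jnp, "float32")
seed = 42

if model_name_or_path:
    # Checkpoint given: reuse its weights (continued pre-training / fine-tuning).
    model = FlaxAutoModelForMaskedLM.from_pretrained(
        model_name_or_path, config=config, seed=seed, dtype=dtype
    )
else:
    # No checkpoint: randomly initialized weights from the config alone (the old behaviour).
    model = FlaxAutoModelForMaskedLM.from_config(config, seed=seed, dtype=dtype)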