diff --git a/examples/flax/language-modeling/README.md b/examples/flax/language-modeling/README.md
index 75fd64eb085..5346904d84c 100644
--- a/examples/flax/language-modeling/README.md
+++ b/examples/flax/language-modeling/README.md
@@ -129,7 +129,7 @@ look at [this](https://colab.research.google.com/github/huggingface/notebooks/bl
 
 In the following, we demonstrate how to train an auto-regressive causal transformer model
 in JAX/Flax.
-More specifically, we pretrain a randomely initialized [**`gpt2`**](https://huggingface.co/gpt2) model in Norwegian on a single TPUv3-8.
+More specifically, we pretrain a randomly initialized [**`gpt2`**](https://huggingface.co/gpt2) model in Norwegian on a single TPUv3-8.
 to pre-train 124M [**`gpt2`**](https://huggingface.co/gpt2)
 in Norwegian on a single TPUv3-8 pod.
 
diff --git a/examples/research_projects/jax-projects/README.md b/examples/research_projects/jax-projects/README.md
index 0b3f0dc5d24..e1f37d12fb4 100644
--- a/examples/research_projects/jax-projects/README.md
+++ b/examples/research_projects/jax-projects/README.md
@@ -710,7 +710,7 @@ class FlaxMLPModel(FlaxMLPPreTrainedModel):
     module_class = FlaxMLPModule
 ```
 
-Now the `FlaxMLPModel` will have a similar interface as PyTorch or Tensorflow models and allows us to attach loaded or randomely initialized weights to the model instance.
+Now the `FlaxMLPModel` will have a similar interface to PyTorch or TensorFlow models and allows us to attach loaded or randomly initialized weights to the model instance.
 So the important point to remember is that the `model` is not an instance of `nn.Module`; it's an abstract class, like a container that holds a Flax module, its parameters and provides convenient methods for initialization and forward pass. The key take-away here is that an instance of `FlaxMLPModel` is very much stateful now since it holds all the model parameters, whereas the underlying Flax module `FlaxMLPModule` is still stateless. Now to make `FlaxMLPModel` fully compliant with JAX transformations, it is always possible to pass the parameters to `FlaxMLPModel` as well to make it stateless and easier to work with during training.
 
 Feel free to take a look at the code to see how exactly this is implemented for ex. [`modeling_flax_bert.py`](https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_flax_bert.py#L536)
diff --git a/examples/research_projects/jax-projects/dataset-streaming/run_mlm_flax_stream.py b/examples/research_projects/jax-projects/dataset-streaming/run_mlm_flax_stream.py
index e6bbdbee8cb..e4bec5e2886 100755
--- a/examples/research_projects/jax-projects/dataset-streaming/run_mlm_flax_stream.py
+++ b/examples/research_projects/jax-projects/dataset-streaming/run_mlm_flax_stream.py
@@ -562,7 +562,7 @@ if __name__ == "__main__":
             samples = advance_iter_and_group_samples(training_iter, train_batch_size, max_seq_length)
         except StopIteration:
             # Once the end of the dataset stream is reached, the training iterator
-            # is reinitialized and reshuffled and a new eval dataset is randomely chosen.
+            # is reinitialized and reshuffled and a new eval dataset is randomly chosen.
             shuffle_seed += 1
             tokenized_datasets.set_epoch(shuffle_seed)
 
diff --git a/tests/generation/test_beam_search.py b/tests/generation/test_beam_search.py
index 426f9d872ce..72202ae2dad 100644
--- a/tests/generation/test_beam_search.py
+++ b/tests/generation/test_beam_search.py
@@ -59,7 +59,7 @@ class BeamSearchTester:
         self.do_early_stopping = do_early_stopping
         self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
 
-        # cannot be randomely generated
+        # cannot be randomly generated
         self.eos_token_id = vocab_size + 1
 
     def prepare_beam_scorer(self, **kwargs):
@@ -283,7 +283,7 @@ class ConstrainedBeamSearchTester:
         constraints = [PhrasalConstraint(force_tokens), DisjunctiveConstraint(disjunctive_tokens)]
         self.constraints = constraints
 
-        # cannot be randomely generated
+        # cannot be randomly generated
         self.eos_token_id = vocab_size + 1
 
     def prepare_constrained_beam_scorer(self, **kwargs):
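
Reviewer note: the `jax-projects/README.md` hunk above edits the passage contrasting the stateless `FlaxMLPModule` with the stateful `FlaxMLPModel` wrapper that holds the (loaded or randomly initialized) parameters. For context, here is a minimal sketch of that split; the class names `MLPModule`/`MLPModel`, the layer sizes, and the wrapper's exact interface are assumptions made for illustration, not the implementation in `modeling_flax_bert.py`.

```python
# Minimal sketch (assumed names and shapes) of the stateless-module /
# stateful-wrapper pattern described in the README hunk above.
import jax
import jax.numpy as jnp
import flax.linen as nn


class MLPModule(nn.Module):
    """Stateless Flax module: holds no weights, only the computation."""

    hidden_dim: int = 32

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.hidden_dim)(x)
        x = nn.relu(x)
        return nn.Dense(1)(x)


class MLPModel:
    """Stateful wrapper (loosely mimicking the FlaxPreTrainedModel idea):
    it owns the parameters, while the module stays a pure function of them."""

    def __init__(self, module, rng, input_shape):
        self.module = module
        # Randomly initialize parameters and attach them to the instance.
        self.params = module.init(rng, jnp.ones(input_shape))

    def __call__(self, x, params=None):
        # Accepting params explicitly keeps the call compatible with JAX
        # transformations such as jax.jit and jax.grad.
        return self.module.apply(params if params is not None else self.params, x)


model = MLPModel(MLPModule(), jax.random.PRNGKey(0), input_shape=(1, 8))
out = model(jnp.ones((4, 8)))                 # uses the stored (stateful) parameters
out = model(jnp.ones((4, 8)), model.params)   # stateless call, transform-friendly
print(out.shape)  # (4, 1)
```

Passing `params` explicitly in the second call is what the README means by making the wrapper "stateless and easier to work with during training" under JAX transformations.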