tests: Fix flaky test for NLLB-MoE (#22880)

* add test update and docs edits

* docs edit suggestion
This commit is contained in:
Connor Henderson 2023-04-21 12:09:40 -04:00 committed by GitHub
parent d00997e66c
commit b950c38565
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 10 additions and 10 deletions

View File

@ -43,9 +43,10 @@ This model was contributed by [Arthur Zucker](https://huggingface.co/ArtZucker).
The original code can be found [here](https://github.com/facebookresearch/fairseq).
## Implementation differences with SwitchTransformers
The biggest difference is the way the tokens are routed. NLLB-MoE uses a `top-2-gate` which means that blah blah blah blah.
In SwitchTransformers, once the masks are computed for each experts, we just index the current hidden_states with the routing mask, and feed the
correct tokens to the expert. However here, the implementation varies a lot as the fairseq repository used a different approach.
The biggest difference is the way the tokens are routed. NLLB-MoE uses a `top-2-gate` which means that for each input, only the top two experts are selected based on the
highest predicted probabilities from the gating network, and the remaining experts are ignored. In `SwitchTransformers`, only the top-1 probabilities are computed,
which means that tokens have less probability of being forwarded. Moreover, if a token is not routed to any expert, `SwitchTransformers` still adds its unmodified hidden
states (kind of like a residual connection) while they are masked in `NLLB`'s top-2 routing mechanism.
## Generating with NLLB-MoE
The avalable checkpoints requires around 350GB of storage. Make sure to use `accelerate` if you do not have enough RAM on your machine.

View File

@ -52,7 +52,7 @@ class NllbMoeConfig(PretrainedConfig):
decoder_ffn_dim (`int`, *optional*, defaults to 4096):
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
encoder_ffn_dim (`int`, *optional*, defaults to 4096):
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"silu"` and `"gelu_new"` are supported.

View File

@ -460,7 +460,7 @@ class NllbMoeSparseMLP(nn.Module):
Attention mask. Can be in the causal form or not.
Returns:
hidden_states (`torch.Tensor` of shape `(batch_size, sequence_lenght, hidden_dim)`):
hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_dim)`):
Updated hidden states
router_logits (`torch.Tensor` of shape `(batch_size, sequence_length, num_experts)`):
Needed for computing the loss

View File

@ -21,7 +21,6 @@ import unittest
from transformers import NllbMoeConfig, is_torch_available, set_seed
from transformers.testing_utils import (
is_flaky,
require_sentencepiece,
require_tokenizers,
require_torch,
@ -210,7 +209,7 @@ class NllbMoeModelTester:
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2))
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def check_encoder_decoder_model_standalone(self, config, inputs_dict):
model = NllbMoeModel(config=config).to(torch_device).eval()
@ -290,10 +289,10 @@ class NllbMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], [])
@is_flaky()
def test_decoder_model_past_with_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
config.decoder_sparse_step = 0
self.model_tester.create_and_check_decoder_model_past_large_inputs(config, inputs_dict)
def test_encoder_decoder_model_standalone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()