Fix TFRemBertEncoder all_hidden_states (#15510)

* fix

* fix test

* remove expected_num_hidden_layers

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Yih-Dar 2022-02-04 17:32:14 +01:00 committed by GitHub
parent 854a0d526c
commit bbe9c6981b
2 changed files with 1 addition and 4 deletions

src/transformers/models/rembert/modeling_tf_rembert.py

@@ -477,7 +477,7 @@ class TFRemBertEncoder(tf.keras.layers.Layer):
         training: bool = False,
     ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
         hidden_states = self.embedding_hidden_mapping_in(inputs=hidden_states)
-        all_hidden_states = (hidden_states,) if output_hidden_states else None
+        all_hidden_states = () if output_hidden_states else None
         all_attentions = () if output_attentions else None
         all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
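
Why the initialization matters: in the encoder's layer loop, each iteration records the layer's input before running it, so iteration 0 already captures the up-projected embeddings. Seeding the tuple with (hidden_states,) duplicated that entry and yielded num_hidden_layers + 2 hidden states, one more than the PyTorch RemBertEncoder returns. Below is a minimal runnable sketch of the accumulation pattern (assuming the standard TF encoder loop; the shapes and Dense stand-ins are illustrative, not the library code):

```python
import tensorflow as tf

num_hidden_layers = 3
# stand-ins for the up-projected embeddings and the encoder layers
hidden_states = tf.zeros((1, 4, 8))
layers = [tf.keras.layers.Dense(8) for _ in range(num_hidden_layers)]

all_hidden_states = ()  # the fixed initialization; the bug seeded it with (hidden_states,)
for layer_module in layers:
    # each iteration records the layer input, so iteration 0 already
    # captures the up-projected embeddings; the old seed duplicated them
    all_hidden_states = all_hidden_states + (hidden_states,)
    hidden_states = layer_module(hidden_states)
# record the final layer's output
all_hidden_states = all_hidden_states + (hidden_states,)

assert len(all_hidden_states) == num_hidden_layers + 1  # now matches PyTorch
```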

tests/test_modeling_tf_rembert.py

@@ -90,9 +90,6 @@ class TFRemBertModelTester:
         self.num_choices = 4
         self.scope = None
-        # RemBERT also returns the upprojected word embeddings as an hidden layers
-        self.expected_num_hidden_layers = self.num_hidden_layers + 2
-
     def prepare_config_and_inputs(self):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
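
With the override gone, TFRemBertModelTester falls back to the default expectation in the shared TF test code, which reads the attribute with a num_hidden_layers + 1 fallback. A small self-contained sketch of that check (the getattr pattern is reconstructed from memory of test_modeling_tf_common.py; the dummy objects are hypothetical):

```python
# Sketch of the shared hidden-states check the removed attribute used to override.
class DummyTester:
    num_hidden_layers = 3  # no expected_num_hidden_layers override anymore

model_tester = DummyTester()
# stand-in outputs: up-projected embeddings plus one entry per layer
hidden_states = tuple(range(model_tester.num_hidden_layers + 1))

# the common test falls back to num_hidden_layers + 1 when no override is set
expected_num_layers = getattr(
    model_tester, "expected_num_hidden_layers", model_tester.num_hidden_layers + 1
)
assert len(hidden_states) == expected_num_layers
```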