Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Fix Padded Batch Error 12282 (#12487)
This fixes the padded batch [issue](https://github.com/huggingface/transformers/issues/12282). The error was caused by the maximum sequence length of the attention mask not matching the padded sequence length of the hidden_states. `np.allclose` now passes with a 1e-2 absolute tolerance. This change fixes the mismatch by passing the padded sequence length of `hidden_states` as `maxlen` to `tf.sequence_mask`.
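A minimal sketch of that mismatch (not part of the commit), using made-up shapes: without `maxlen`, `tf.sequence_mask` sizes the mask to the longest value in `output_lengths`, which can be shorter than the padded time dimension of `hidden_states`. Here `tf.shape(hidden_states)[1]` stands in for the library's `shape_list(hidden_states)[1]`.

```python
import tensorflow as tf

# hypothetical shapes: a batch padded to 100 frames, real output lengths 80 and 64
hidden_states = tf.zeros((2, 100, 768))
output_lengths = tf.constant([80, 64])

# old call: mask width is max(output_lengths) = 80, not the padded length 100
mask_before = tf.sequence_mask(output_lengths, dtype=hidden_states.dtype)
print(mask_before.shape)  # (2, 80)

# fixed call: mask width matches hidden_states' padded time dimension
mask_after = tf.sequence_mask(
    output_lengths, maxlen=tf.shape(hidden_states)[1], dtype=hidden_states.dtype
)
print(mask_after.shape)  # (2, 100)
```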
This commit is contained in:
parent 7fae535052
commit 6f8e367ae9
@@ -1213,7 +1213,10 @@ class TFWav2Vec2MainLayer(tf.keras.layers.Layer):
         if inputs["attention_mask"] is not None:
             # compute real output lengths according to convolution formula
             output_lengths = self._get_feat_extract_output_lengths(tf.reduce_sum(inputs["attention_mask"], -1))
-            attention_mask = tf.sequence_mask(output_lengths, dtype=hidden_states.dtype)
+
+            attention_mask = tf.sequence_mask(
+                output_lengths, maxlen=shape_list(hidden_states)[1], dtype=hidden_states.dtype
+            )
 
         hidden_states = self.feature_projection(hidden_states, training=inputs["training"])
 
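For context, a hedged sketch of the kind of padded-vs-unpadded comparison the commit message refers to (not the repository's actual test). It runs a padded two-clip batch and the same clips individually, then reports whether `np.allclose` holds at a 1e-2 absolute tolerance; the checkpoint name, clip lengths, and random audio are illustrative assumptions.

```python
import numpy as np
import tensorflow as tf
from transformers import TFWav2Vec2Model

model = TFWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")  # illustrative checkpoint

# two raw-audio clips of different lengths (random data; only the shapes matter here)
short_clip = np.random.randn(8000).astype(np.float32)
long_clip = np.random.randn(16000).astype(np.float32)

# padded batch with an attention mask marking the real samples
batch = np.zeros((2, 16000), dtype=np.float32)
batch[0, :8000], batch[1, :] = short_clip, long_clip
mask = np.zeros((2, 16000), dtype=np.int32)
mask[0, :8000], mask[1, :] = 1, 1

batched = model(tf.constant(batch), attention_mask=tf.constant(mask)).last_hidden_state

# the same clips, one at a time and unpadded
out_short = model(tf.constant(short_clip[None, :])).last_hidden_state
out_long = model(tf.constant(long_clip[None, :])).last_hidden_state

# with the fix, the padded rows should agree with the unpadded runs
# over each clip's real output length
print(np.allclose(batched[1], out_long[0], atol=1e-2))
print(np.allclose(batched[0, : out_short.shape[1]], out_short[0], atol=1e-2))
```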