Merge pull request #1736 from huggingface/fix-tf-xlnet

Fix TFXLNet
This commit is contained in:
Thomas Wolf 2019-12-21 12:42:05 +01:00 committed by GitHub
commit 8618bf15d6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -552,7 +552,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
assert input_mask is None or attention_mask is None, "You can only use one of input_mask (uses 1 for padding) " \
"or attention_mask (uses 0 for padding, added for compatbility with BERT). Please choose one."
if input_mask is None and attention_mask is not None:
- input_mask = 1.0 - attention_mask
+ input_mask = 1.0 - tf.cast(attention_mask, dtype=dtype_float)
if input_mask is not None and perm_mask is not None:
data_mask = input_mask[None] + perm_mask
elif input_mask is not None and perm_mask is None: