Merge pull request #1668 from tlkh/fix-tf-xlm

Fixed training for TF XLM
This commit is contained in:
Thomas Wolf 2019-10-30 17:08:00 +01:00 committed by GitHub
commit 22838f19fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -84,7 +84,8 @@ def get_masks(slen, lengths, causal, padding_mask=None, dtype=tf.float32):
attn_mask = mask
# sanity check
assert shape_list(mask) == [bs, slen]
# assert shape_list(mask) == [bs, slen]
tf.debugging.assert_equal(shape_list(mask), [bs, slen])
assert causal is False or shape_list(attn_mask) == [bs, slen, slen]
mask = tf.cast(mask, dtype=dtype)
@ -318,7 +319,8 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
# check inputs
bs, slen = shape_list(input_ids)
assert shape_list(lengths)[0] == bs
# assert shape_list(lengths)[0] == bs
tf.debugging.assert_equal(shape_list(lengths)[0], bs)
# assert lengths.max().item() <= slen
# input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
# assert (src_enc is None) == (src_len is None)
@ -335,12 +337,14 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
if position_ids is None:
position_ids = tf.expand_dims(tf.range(slen), axis=0)
else:
assert shape_list(position_ids) == [bs, slen] # (slen, bs)
# assert shape_list(position_ids) == [bs, slen] # (slen, bs)
tf.debugging.assert_equal(shape_list(position_ids), [bs, slen])
# position_ids = position_ids.transpose(0, 1)
# langs
if langs is not None:
assert shape_list(langs) == [bs, slen] # (slen, bs)
# assert shape_list(langs) == [bs, slen] # (slen, bs)
tf.debugging.assert_equal(shape_list(langs), [bs, slen])
# langs = langs.transpose(0, 1)
# Prepare head mask if needed