From 842f3bf049d4a728cb4bff543e8bbd74020af230 Mon Sep 17 00:00:00 2001 From: Timothy Liu Date: Wed, 30 Oct 2019 01:32:15 +0000 Subject: [PATCH] Fixed training for TF XLM --- transformers/modeling_tf_xlm.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/transformers/modeling_tf_xlm.py b/transformers/modeling_tf_xlm.py index 84de1517ee3..9ac5d28e1f1 100644 --- a/transformers/modeling_tf_xlm.py +++ b/transformers/modeling_tf_xlm.py @@ -84,7 +84,8 @@ def get_masks(slen, lengths, causal, padding_mask=None, dtype=tf.float32): attn_mask = mask # sanity check - assert shape_list(mask) == [bs, slen] + # assert shape_list(mask) == [bs, slen] + tf.debugging.assert_equal(shape_list(mask), [bs, slen]) assert causal is False or shape_list(attn_mask) == [bs, slen, slen] mask = tf.cast(mask, dtype=dtype) @@ -318,7 +319,8 @@ class TFXLMMainLayer(tf.keras.layers.Layer): # check inputs bs, slen = shape_list(input_ids) - assert shape_list(lengths)[0] == bs + # assert shape_list(lengths)[0] == bs + tf.debugging.assert_equal(shape_list(lengths)[0], bs) # assert lengths.max().item() <= slen # input_ids = input_ids.transpose(0, 1) # batch size as dimension 0 # assert (src_enc is None) == (src_len is None) @@ -335,12 +337,14 @@ class TFXLMMainLayer(tf.keras.layers.Layer): if position_ids is None: position_ids = tf.expand_dims(tf.range(slen), axis=0) else: - assert shape_list(position_ids) == [bs, slen] # (slen, bs) + # assert shape_list(position_ids) == [bs, slen] # (slen, bs) + tf.debugging.assert_equal(shape_list(position_ids), [bs, slen]) # position_ids = position_ids.transpose(0, 1) # langs if langs is not None: - assert shape_list(langs) == [bs, slen] # (slen, bs) + # assert shape_list(langs) == [bs, slen] # (slen, bs) + tf.debugging.assert_equal(shape_list(langs), [bs, slen]) # langs = langs.transpose(0, 1) # Prepare head mask if needed