mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Merge pull request #1668 from tlkh/fix-tf-xlm
Fixed training for TF XLM
This commit is contained in:
commit
22838f19fd
@ -84,7 +84,8 @@ def get_masks(slen, lengths, causal, padding_mask=None, dtype=tf.float32):
|
||||
attn_mask = mask
|
||||
|
||||
# sanity check
|
||||
assert shape_list(mask) == [bs, slen]
|
||||
# assert shape_list(mask) == [bs, slen]
|
||||
tf.debugging.assert_equal(shape_list(mask), [bs, slen])
|
||||
assert causal is False or shape_list(attn_mask) == [bs, slen, slen]
|
||||
|
||||
mask = tf.cast(mask, dtype=dtype)
|
||||
@ -318,7 +319,8 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
||||
|
||||
# check inputs
|
||||
bs, slen = shape_list(input_ids)
|
||||
assert shape_list(lengths)[0] == bs
|
||||
# assert shape_list(lengths)[0] == bs
|
||||
tf.debugging.assert_equal(shape_list(lengths)[0], bs)
|
||||
# assert lengths.max().item() <= slen
|
||||
# input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
|
||||
# assert (src_enc is None) == (src_len is None)
|
||||
@ -335,12 +337,14 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
||||
if position_ids is None:
|
||||
position_ids = tf.expand_dims(tf.range(slen), axis=0)
|
||||
else:
|
||||
assert shape_list(position_ids) == [bs, slen] # (slen, bs)
|
||||
# assert shape_list(position_ids) == [bs, slen] # (slen, bs)
|
||||
tf.debugging.assert_equal(shape_list(position_ids), [bs, slen])
|
||||
# position_ids = position_ids.transpose(0, 1)
|
||||
|
||||
# langs
|
||||
if langs is not None:
|
||||
assert shape_list(langs) == [bs, slen] # (slen, bs)
|
||||
# assert shape_list(langs) == [bs, slen] # (slen, bs)
|
||||
tf.debugging.assert_equal(shape_list(langs), [bs, slen])
|
||||
# langs = langs.transpose(0, 1)
|
||||
|
||||
# Prepare head mask if needed
|
||||
|
Loading…
Reference in New Issue
Block a user