set redorder past sort dimension to its default

This commit is contained in:
patrickvonplaten 2020-03-04 00:41:05 +01:00 committed by Patrick von Platen
parent 61fef6e957
commit 2529b2d37e

View File

@ -941,9 +941,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
for layer_past in past:
# get the correct batch idx from layer past batch dim
# batch dim of `past` and `mems` is at 2nd position
reordered_layer_past = [tf.identity(tf.expand_dims(layer_past[i], 0)) for i in beam_idx]
reordered_layer_past = [tf.identity(tf.expand_dims(layer_past[:, i], 1)) for i in beam_idx]
# TODO: check whether it is an error that TF past.shape != Torch past.shape
reordered_layer_past = tf.concat(reordered_layer_past, axis=0)
reordered_layer_past = tf.concat(reordered_layer_past, axis=1)
# check that shape matches
assert shape_list(reordered_layer_past) == shape_list(layer_past)
reordered_past.append(reordered_layer_past)