From 2529b2d37ee234c7b100f3896ce1688fc60580bd Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Wed, 4 Mar 2020 00:41:05 +0100 Subject: [PATCH] set reorder past sort dimension to its default --- src/transformers/modeling_tf_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index bb1856308a2..e3083b6d203 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -941,9 +941,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): for layer_past in past: # get the correct batch idx from layer past batch dim # batch dim of `past` and `mems` is at 2nd position - reordered_layer_past = [tf.identity(tf.expand_dims(layer_past[i], 0)) for i in beam_idx] + reordered_layer_past = [tf.identity(tf.expand_dims(layer_past[:, i], 1)) for i in beam_idx] # TODO: check whether it is an error that TF past.shape != Torch past.shape - reordered_layer_past = tf.concat(reordered_layer_past, axis=0) + reordered_layer_past = tf.concat(reordered_layer_past, axis=1) # check that shape matches assert shape_list(reordered_layer_past) == shape_list(layer_past) reordered_past.append(reordered_layer_past)