mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
set redorder past sort dimension to its default
This commit is contained in:
parent
61fef6e957
commit
2529b2d37e
@ -941,9 +941,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
for layer_past in past:
|
||||
# get the correct batch idx from layer past batch dim
|
||||
# batch dim of `past` and `mems` is at 2nd position
|
||||
reordered_layer_past = [tf.identity(tf.expand_dims(layer_past[i], 0)) for i in beam_idx]
|
||||
reordered_layer_past = [tf.identity(tf.expand_dims(layer_past[:, i], 1)) for i in beam_idx]
|
||||
# TODO: check whether it is an error that TF past.shape != Torch past.shape
|
||||
reordered_layer_past = tf.concat(reordered_layer_past, axis=0)
|
||||
reordered_layer_past = tf.concat(reordered_layer_past, axis=1)
|
||||
# check that shape matches
|
||||
assert shape_list(reordered_layer_past) == shape_list(layer_past)
|
||||
reordered_past.append(reordered_layer_past)
|
||||
|
Loading…
Reference in New Issue
Block a user