mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 03:01:07 +06:00
set redorder past sort dimension to its default
This commit is contained in:
parent
61fef6e957
commit
2529b2d37e
@ -941,9 +941,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
for layer_past in past:
|
for layer_past in past:
|
||||||
# get the correct batch idx from layer past batch dim
|
# get the correct batch idx from layer past batch dim
|
||||||
# batch dim of `past` and `mems` is at 2nd position
|
# batch dim of `past` and `mems` is at 2nd position
|
||||||
reordered_layer_past = [tf.identity(tf.expand_dims(layer_past[i], 0)) for i in beam_idx]
|
reordered_layer_past = [tf.identity(tf.expand_dims(layer_past[:, i], 1)) for i in beam_idx]
|
||||||
# TODO: check whether it is an error that TF past.shape != Torch past.shape
|
# TODO: check whether it is an error that TF past.shape != Torch past.shape
|
||||||
reordered_layer_past = tf.concat(reordered_layer_past, axis=0)
|
reordered_layer_past = tf.concat(reordered_layer_past, axis=1)
|
||||||
# check that shape matches
|
# check that shape matches
|
||||||
assert shape_list(reordered_layer_past) == shape_list(layer_past)
|
assert shape_list(reordered_layer_past) == shape_list(layer_past)
|
||||||
reordered_past.append(reordered_layer_past)
|
reordered_past.append(reordered_layer_past)
|
||||||
|
Loading…
Reference in New Issue
Block a user