diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py
index 09ec877022f..52dfa5e3922 100755
--- a/src/transformers/models/bart/modeling_bart.py
+++ b/src/transformers/models/bart/modeling_bart.py
@@ -1467,7 +1467,8 @@ class BartForConditionalGeneration(BartPreTrainedModel):
         for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past
@@ -1946,5 +1947,7 @@ class BartForCausalLM(BartPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py
index de6ed382e47..29846b8051f 100755
--- a/src/transformers/models/bert/modeling_bert.py
+++ b/src/transformers/models/bert/modeling_bert.py
@@ -1294,7 +1294,9 @@ class BertLMHeadModel(BertPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py
index 56d32e2910c..f245ac155e7 100755
--- a/src/transformers/models/bert_generation/modeling_bert_generation.py
+++ b/src/transformers/models/bert_generation/modeling_bert_generation.py
@@ -1002,5 +1002,7 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py
index 45b51a21524..867aca67e99 100755
--- a/src/transformers/models/big_bird/modeling_big_bird.py
+++ b/src/transformers/models/big_bird/modeling_big_bird.py
@@ -2638,7 +2638,8 @@ class BigBirdForCausalLM(BigBirdPreTrainedModel):
         reordered_past = ()
         for layer_past in past_key_values:
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past
diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
index 0e0dda8b695..a32f3ecde76 100755
--- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
+++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
@@ -2651,7 +2651,8 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
         for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past
@@ -3121,5 +3122,7 @@ class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py
index fa2cb1994de..7534ed17fe8 100755
--- a/src/transformers/models/biogpt/modeling_biogpt.py
+++ b/src/transformers/models/biogpt/modeling_biogpt.py
@@ -752,7 +752,9 @@ class BioGptForCausalLM(BioGptPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py
index 20779fe66a9..bdb8c52a552 100755
--- a/src/transformers/models/blenderbot/modeling_blenderbot.py
+++ b/src/transformers/models/blenderbot/modeling_blenderbot.py
@@ -1412,7 +1412,8 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
         for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past
@@ -1634,5 +1635,7 @@ class BlenderbotForCausalLM(BlenderbotPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
index dc2f1512ffe..a1e888aec90 100755
--- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
+++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
@@ -1379,7 +1379,8 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
         for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past
@@ -1601,5 +1602,7 @@ class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py
index 59e5539340c..2ae3ac053be 100644
--- a/src/transformers/models/blip/modeling_blip_text.py
+++ b/src/transformers/models/blip/modeling_blip_text.py
@@ -934,5 +934,7 @@ class BlipTextLMHeadModel(BlipTextPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py
index d12882dfec3..4635c061980 100644
--- a/src/transformers/models/camembert/modeling_camembert.py
+++ b/src/transformers/models/camembert/modeling_camembert.py
@@ -1551,7 +1551,9 @@ class CamembertForCausalLM(CamembertPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py
index f949f3cce59..7cbaee69256 100644
--- a/src/transformers/models/data2vec/modeling_data2vec_text.py
+++ b/src/transformers/models/data2vec/modeling_data2vec_text.py
@@ -1018,7 +1018,9 @@ class Data2VecTextForCausalLM(Data2VecTextPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py
index 8469b86eb9d..67142229aad 100644
--- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py
+++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py
@@ -881,7 +881,9 @@ class OpenLlamaForCausalLM(OpenLlamaPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py
index ad9e17cf6ed..c06d306c1a2 100644
--- a/src/transformers/models/electra/modeling_electra.py
+++ b/src/transformers/models/electra/modeling_electra.py
@@ -1677,5 +1677,7 @@ class ElectraForCausalLM(ElectraPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py
index d0a3cc9a81b..7ee6f438129 100644
--- a/src/transformers/models/ernie/modeling_ernie.py
+++ b/src/transformers/models/ernie/modeling_ernie.py
@@ -1236,7 +1236,9 @@ class ErnieForCausalLM(ErniePreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py
index eafb72f3523..00707e42dd0 100644
--- a/src/transformers/models/git/modeling_git.py
+++ b/src/transformers/models/git/modeling_git.py
@@ -1568,5 +1568,7 @@ class GitForCausalLM(GitPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
index 6710892dc5c..9956aae6b30 100755
--- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py
+++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -839,7 +839,8 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
         reordered_past = ()
         for layer_past in past_key_values:
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
            )
         return reordered_past
diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
index e66c51c1b4f..746ae968109 100755
--- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
+++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
@@ -723,6 +723,7 @@ class GPTNeoXJapaneseForCausalLM(GPTNeoXJapanesePreTrainedModel):
         reordered_past = ()
         for layer_past in past_key_values:
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past
diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py
index e405098bf09..f0c22ed9502 100755
--- a/src/transformers/models/led/modeling_led.py
+++ b/src/transformers/models/led/modeling_led.py
@@ -2509,7 +2509,8 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
         for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past
diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py
index b6c31518390..88e543b54b5 100755
--- a/src/transformers/models/m2m_100/modeling_m2m_100.py
+++ b/src/transformers/models/m2m_100/modeling_m2m_100.py
@@ -1385,5 +1385,7 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py
index 6c287151eea..b4e3aac5be0 100755
--- a/src/transformers/models/marian/modeling_marian.py
+++ b/src/transformers/models/marian/modeling_marian.py
@@ -1532,7 +1532,8 @@ class MarianMTModel(MarianPreTrainedModel):
         for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past
@@ -1752,5 +1753,7 @@ class MarianForCausalLM(MarianPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py
index 678367ff94d..ca6bea40337 100755
--- a/src/transformers/models/markuplm/modeling_markuplm.py
+++ b/src/transformers/models/markuplm/modeling_markuplm.py
@@ -961,7 +961,9 @@ class MarkupLMModel(MarkupLMPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py
index 6faa2d7bc7e..b035ef7b10d 100755
--- a/src/transformers/models/mbart/modeling_mbart.py
+++ b/src/transformers/models/mbart/modeling_mbart.py
@@ -1431,7 +1431,8 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
         for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past
@@ -1904,5 +1905,7 @@ class MBartForCausalLM(MBartPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/mega/modeling_mega.py b/src/transformers/models/mega/modeling_mega.py
index 50950e88725..45ce5242428 100644
--- a/src/transformers/models/mega/modeling_mega.py
+++ b/src/transformers/models/mega/modeling_mega.py
@@ -1341,7 +1341,7 @@ class MegaPreTrainedModel(PreTrainedModel):
     config_class = MegaConfig
     base_model_prefix = "mega"
     supports_gradient_checkpointing = False
-    _no_split_modules = []
+    _no_split_modules = ["MegaMovingAverageGatedAttention"]

     def _init_weights(self, module):
         """Initialize the weights"""
@@ -1802,7 +1802,9 @@ class MegaForCausalLM(MegaPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py
index f56745dace0..1c1eeff667d 100755
--- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py
+++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py
@@ -1260,7 +1260,9 @@ class MegatronBertForCausalLM(MegatronBertPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py
index c42ef51c531..21a82f95c33 100644
--- a/src/transformers/models/mvp/modeling_mvp.py
+++ b/src/transformers/models/mvp/modeling_mvp.py
@@ -1595,7 +1595,8 @@ class MvpForConditionalGeneration(MvpPreTrainedModel):
         for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past
@@ -2066,5 +2067,7 @@ class MvpForCausalLM(MvpPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py
index 53f01328c9e..f37f64627df 100644
--- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py
+++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py
@@ -1826,5 +1826,7 @@ class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py
index a6706bfeea0..d24211f0393 100644
--- a/src/transformers/models/opt/modeling_opt.py
+++ b/src/transformers/models/opt/modeling_opt.py
@@ -1003,7 +1003,9 @@ class OPTForCausalLM(OPTPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py
index b64833a8f6a..67934520fbb 100755
--- a/src/transformers/models/pegasus/modeling_pegasus.py
+++ b/src/transformers/models/pegasus/modeling_pegasus.py
@@ -1489,7 +1489,8 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
         for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past
@@ -1731,5 +1732,7 @@ class PegasusForCausalLM(PegasusPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py
index 9311c45e651..def82bdbaa7 100755
--- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py
+++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py
@@ -1691,7 +1691,8 @@ class PegasusXForConditionalGeneration(PegasusXPreTrainedModel):
         for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past
diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py
index 62f62dbb953..93532f4b0d8 100644
--- a/src/transformers/models/plbart/modeling_plbart.py
+++ b/src/transformers/models/plbart/modeling_plbart.py
@@ -1402,7 +1402,8 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel):
         for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past
@@ -1751,5 +1752,7 @@ class PLBartForCausalLM(PLBartPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py
index 1b771705ab7..241a9efea36 100644
--- a/src/transformers/models/prophetnet/modeling_prophetnet.py
+++ b/src/transformers/models/prophetnet/modeling_prophetnet.py
@@ -2068,7 +2068,8 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
         for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past
@@ -2312,7 +2313,9 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/qdqbert/modeling_qdqbert.py b/src/transformers/models/qdqbert/modeling_qdqbert.py
index 6cd7511e00c..47546930ebd 100755
--- a/src/transformers/models/qdqbert/modeling_qdqbert.py
+++ b/src/transformers/models/qdqbert/modeling_qdqbert.py
@@ -1160,7 +1160,9 @@ class QDQBertLMHeadModel(QDQBertPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py
index 15f4fca475d..7048168a064 100644
--- a/src/transformers/models/rag/modeling_rag.py
+++ b/src/transformers/models/rag/modeling_rag.py
@@ -1213,7 +1213,9 @@ class RagTokenForGeneration(RagPreTrainedModel):
         reordered_past = ()
         for layer_past in past_key_values:
             # get the correct batch idx from decoder layer's batch dim for cross and self-attn
-            reordered_past += (tuple(_reorder_stacked(past_state, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(_reorder_stacked(past_state, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/reformer/modeling_reformer.py b/src/transformers/models/reformer/modeling_reformer.py
index 14f735c7e6a..275a1e1dc73 100755
--- a/src/transformers/models/reformer/modeling_reformer.py
+++ b/src/transformers/models/reformer/modeling_reformer.py
@@ -2300,12 +2300,12 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
         for layer_past in past_key_values:
             # buckets
             if layer_past[0] is not None:
-                reord_buckets = layer_past[0].index_select(0, beam_idx)
+                reord_buckets = layer_past[0].index_select(0, beam_idx.to(layer_past[0].device))
             else:
                 reord_buckets = None

             # hidden states
-            reord_hidden_states = layer_past[1].index_select(0, beam_idx)
+            reord_hidden_states = layer_past[1].index_select(0, beam_idx.to(layer_past[1].device))
             reord_past_buckets_states.append((reord_buckets, reord_hidden_states))
         return reord_past_buckets_states
diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py
index f25e8efbe32..745be26ebfc 100755
--- a/src/transformers/models/rembert/modeling_rembert.py
+++ b/src/transformers/models/rembert/modeling_rembert.py
@@ -1157,7 +1157,8 @@ class RemBertForCausalLM(RemBertPreTrainedModel):
         reordered_past = ()
         for layer_past in past_key_values:
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past
diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py
index 38c2d74acc3..67e0fee422c 100644
--- a/src/transformers/models/roberta/modeling_roberta.py
+++ b/src/transformers/models/roberta/modeling_roberta.py
@@ -1016,7 +1016,9 @@ class RobertaForCausalLM(RobertaPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py
index 9dd4b23c43c..ddd87fa9ce0 100644
--- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py
+++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py
@@ -1023,7 +1023,9 @@ class RobertaPreLayerNormForCausalLM(RobertaPreLayerNormPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py
index f6b5df11c6a..35d4be9f20e 100644
--- a/src/transformers/models/roc_bert/modeling_roc_bert.py
+++ b/src/transformers/models/roc_bert/modeling_roc_bert.py
@@ -1580,7 +1580,9 @@ class RoCBertForCausalLM(RoCBertPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py
index 9f040174d39..2c3feeda127 100644
--- a/src/transformers/models/roformer/modeling_roformer.py
+++ b/src/transformers/models/roformer/modeling_roformer.py
@@ -1188,7 +1188,8 @@ class RoFormerForCausalLM(RoFormerPreTrainedModel):
         reordered_past = ()
         for layer_past in past_key_values:
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past
diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py
index 60889972a5b..31c9b6cfe93 100755
--- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py
+++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py
@@ -1418,5 +1418,7 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py
index 1bd51c4e0d4..bfd801b2427 100755
--- a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py
+++ b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py
@@ -976,5 +976,7 @@ class Speech2Text2ForCausalLM(Speech2Text2PreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py
index 3c2d355f1d7..48334deb377 100644
--- a/src/transformers/models/speecht5/modeling_speecht5.py
+++ b/src/transformers/models/speecht5/modeling_speecht5.py
@@ -2525,7 +2525,9 @@ class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py
index 5676689209e..50829592a02 100644
--- a/src/transformers/models/trocr/modeling_trocr.py
+++ b/src/transformers/models/trocr/modeling_trocr.py
@@ -1016,5 +1016,7 @@ class TrOCRForCausalLM(TrOCRPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py
index 16d92c70e09..8323054144f 100644
--- a/src/transformers/models/umt5/modeling_umt5.py
+++ b/src/transformers/models/umt5/modeling_umt5.py
@@ -1331,7 +1331,9 @@ class UMT5ForConditionalGeneration(UMT5PreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py
index 926101156d1..cc380a9dcc9 100644
--- a/src/transformers/models/whisper/modeling_whisper.py
+++ b/src/transformers/models/whisper/modeling_whisper.py
@@ -1794,7 +1794,9 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past

     def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_precision=0.02):
diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py
index 2279537edb9..9ef15977a3d 100755
--- a/src/transformers/models/xglm/modeling_xglm.py
+++ b/src/transformers/models/xglm/modeling_xglm.py
@@ -881,5 +881,7 @@ class XGLMForCausalLM(XGLMPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py
index c84e3fac5ae..cde05cfe8a8 100644
--- a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py
+++ b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py
@@ -2095,7 +2095,8 @@ class XLMProphetNetForConditionalGeneration(XLMProphetNetPreTrainedModel):
         for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past
@@ -2340,7 +2341,9 @@ class XLMProphetNetForCausalLM(XLMProphetNetPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py
index 44b7fd4cca9..761e96a11b7 100644
--- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py
+++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py
@@ -1020,7 +1020,9 @@ class XLMRobertaForCausalLM(XLMRobertaPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py
index 8aac5a8bb1b..025bab3887c 100644
--- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py
+++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py
@@ -979,7 +979,9 @@ class XLMRobertaXLForCausalLM(XLMRobertaXLPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py
index 0940b1cc6c6..61002bd2772 100644
--- a/src/transformers/models/xmod/modeling_xmod.py
+++ b/src/transformers/models/xmod/modeling_xmod.py
@@ -1128,7 +1128,9 @@ class XmodForCausalLM(XmodPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
index f0960a38412..02fcb7d2f51 100755
--- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
+++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
@@ -1180,7 +1180,7 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],)
+            reordered_past += (tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) + layer_past[2:],)
         return reordered_past

 class {{cookiecutter.camelcase_modelname}}ClassificationHead(nn.Module):
@@ -2898,7 +2898,7 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),)
         return reordered_past
@@ -3335,6 +3335,6 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),)
         return reordered_past
 {% endif -%}
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 7373ed6cb80..f73e3f60a55 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -15,13 +15,14 @@
 import inspect
+import tempfile
 import unittest
 import warnings

 import numpy as np

 from transformers import is_torch_available, pipeline
-from transformers.testing_utils import require_torch, slow, torch_device
+from transformers.testing_utils import require_accelerate, require_torch, require_torch_multi_gpu, slow, torch_device

 from ..test_modeling_common import floats_tensor, ids_tensor
 from .test_framework_agnostic import GenerationIntegrationTestsMixin
@@ -1017,6 +1018,27 @@ class GenerationTesterMixin:
             output, input_ids, model.config, use_cache=True, num_return_sequences=beam_scorer.num_beams
         )

+    @require_accelerate
+    @require_torch_multi_gpu
+    def test_model_parallel_beam_search(self):
+        for model_class in self.all_generative_model_classes:
+            if model_class._no_split_modules is None:
+                continue
+
+            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
+
+            model = model_class(config).eval()
+            with tempfile.TemporaryDirectory() as tmp_dir:
+                model.cpu().save_pretrained(tmp_dir)
+                new_model = model_class.from_pretrained(tmp_dir, device_map="auto")
+
+                new_model.generate(
+                    input_ids,
+                    attention_mask=attention_mask,
+                    max_length=max_length,
+                    num_beams=2,
+                )
+
     def test_beam_sample_generate(self):
         for model_class in self.all_generative_model_classes:
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index eed704d3bca..0f7feaa9c56 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -2482,34 +2482,6 @@ class ModelTesterMixin:
                 for value_, parallel_value_ in zip(value, parallel_value):
                     self.assertTrue(torch.allclose(value_, parallel_value_.to("cpu"), atol=1e-7))

-    @require_torch_multi_gpu
-    def test_model_parallel_beam_search(self):
-        if not self.test_model_parallel:
-            return
-
-        all_generative_and_parallelizable_model_classes = tuple(
-            set(self.all_generative_model_classes).intersection(self.all_parallelizable_model_classes)
-        )
-
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
-        for model_class in all_generative_and_parallelizable_model_classes:
-            inputs_dict = self._prepare_for_class(inputs_dict, model_class)
-            model = model_class(config)
-
-            def cast_to_device(dictionary, device):
-                output = {}
-                for k, v in dictionary.items():
-                    if isinstance(v, torch.Tensor):
-                        output[k] = v.to(device)
-                    else:
-                        output[k] = v
-
-                return output
-
-            model.parallelize()
-            model.generate(**cast_to_device(inputs_dict, "cuda:0"), num_beams=2)
-
     def check_device_map_is_respected(self, model, device_map):
         for param_name, param in model.named_parameters():
             # Find device in device_map
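
Note (illustrative, not part of the patch): the recurring one-line change above exists because torch.Tensor.index_select requires its index tensor to live on the same device as the tensor being indexed. With from_pretrained(..., device_map="auto"), accelerate may place different decoder layers on different GPUs, so a layer's cached key/value tensors need not share a device with beam_idx, which lives where beam search runs. A minimal standalone sketch of the pattern, assuming a machine with two CUDA devices; the shapes and cache layout here are hypothetical stand-ins for a real model's past_key_values:

    import torch

    # Hypothetical per-layer (key, value) caches, sharded across two GPUs
    # the way device_map="auto" might place consecutive decoder layers.
    past_key_values = (
        (torch.randn(4, 8, 10, 64, device="cuda:0"), torch.randn(4, 8, 10, 64, device="cuda:0")),
        (torch.randn(4, 8, 10, 64, device="cuda:1"), torch.randn(4, 8, 10, 64, device="cuda:1")),
    )
    # Beam indices live wherever beam search runs (here: the first GPU).
    beam_idx = torch.tensor([2, 0, 1, 3], device="cuda:0")

    reordered_past = tuple(
        # beam_idx.to(past_state.device) is the fix: without it, the layer on
        # cuda:1 raises a cross-device error inside index_select. The .to()
        # call is a no-op for tensors already on beam_idx's device.
        tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
        for layer_past in past_key_values
    )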