Mirror of https://github.com/huggingface/transformers.git
Fix beam search when using model parallel (#24969)
* Fix GPTNeoX beam search when using parallelize
* Fix beam search idx device when using model parallel
* remove onnx related stuff
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* fix: move test_beam_search_on_multi_gpu to GenerationTesterMixin
* fix: add right item to _no_split_modules of MegaPreTrainedModel
* fix: add num_beams within parallelized beam_search test
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 0dd06c3f78
commit 8881f38a4f
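Why this was failing: during beam search, generate() calls each model's _reorder_cache to reorder the cached key/value states along the batch dimension according to the surviving beam indices. When a model is split across several GPUs (device_map="auto" via accelerate, or the older parallelize() API), each layer's cache lives on its own device while beam_idx lives on a single one, so past_state.index_select(0, beam_idx) raises a device-mismatch error. Nearly every modeling change below applies the same fix: move beam_idx onto the device of the tensor being reordered. A minimal sketch of the resulting pattern (decoder-only cache layout; the encoder-decoder variants reorder only layer_past[:2], since the cached cross-attention states are identical across beams):

    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                # move beam_idx to whichever GPU holds this layer's cache before
                # indexing, so index_select never crosses devices
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past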
@@ -1467,7 +1467,8 @@ class BartForConditionalGeneration(BartPreTrainedModel):
         for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past

@@ -1946,5 +1947,7 @@ class BartForCausalLM(BartPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past

@@ -1294,7 +1294,9 @@ class BertLMHeadModel(BertPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
@@ -1002,5 +1002,7 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past

@@ -2638,7 +2638,8 @@ class BigBirdForCausalLM(BigBirdPreTrainedModel):
         reordered_past = ()
         for layer_past in past_key_values:
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past

@@ -2651,7 +2651,8 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
         for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past

@@ -3121,5 +3122,7 @@ class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
@@ -752,7 +752,9 @@ class BioGptForCausalLM(BioGptPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past

@@ -1412,7 +1412,8 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
         for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past

@@ -1634,5 +1635,7 @@ class BlenderbotForCausalLM(BlenderbotPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past

@@ -1379,7 +1379,8 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
         for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past

@@ -1601,5 +1602,7 @@ class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
@@ -934,5 +934,7 @@ class BlipTextLMHeadModel(BlipTextPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past

@@ -1551,7 +1551,9 @@ class CamembertForCausalLM(CamembertPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past

@@ -1018,7 +1018,9 @@ class Data2VecTextForCausalLM(Data2VecTextPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past

@@ -881,7 +881,9 @@ class OpenLlamaForCausalLM(OpenLlamaPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
@@ -1677,5 +1677,7 @@ class ElectraForCausalLM(ElectraPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past

@@ -1236,7 +1236,9 @@ class ErnieForCausalLM(ErniePreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past

@@ -1568,5 +1568,7 @@ class GitForCausalLM(GitPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past

@@ -839,7 +839,8 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
         reordered_past = ()
         for layer_past in past_key_values:
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past
@@ -723,6 +723,7 @@ class GPTNeoXJapaneseForCausalLM(GPTNeoXJapanesePreTrainedModel):
         reordered_past = ()
         for layer_past in past_key_values:
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past

@@ -2509,7 +2509,8 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
         for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past

@@ -1385,5 +1385,7 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past

@@ -1532,7 +1532,8 @@ class MarianMTModel(MarianPreTrainedModel):
         for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past

@@ -1752,5 +1753,7 @@ class MarianForCausalLM(MarianPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
@@ -961,7 +961,9 @@ class MarkupLMModel(MarkupLMPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past

@@ -1431,7 +1431,8 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
         for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past

@@ -1904,5 +1905,7 @@ class MBartForCausalLM(MBartPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past

@@ -1341,7 +1341,7 @@ class MegaPreTrainedModel(PreTrainedModel):
     config_class = MegaConfig
     base_model_prefix = "mega"
     supports_gradient_checkpointing = False
-    _no_split_modules = []
+    _no_split_modules = ["MegaMovingAverageGatedAttention"]

     def _init_weights(self, module):
         """Initialize the weights"""
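The Mega hunk above is the one non-mechanical change: registering MegaMovingAverageGatedAttention in _no_split_modules tells accelerate's device-map planner to keep each of those blocks whole on a single GPU, which is what lets Mega participate in the new device_map="auto" beam-search test. A hedged sketch of the effect (the checkpoint name is illustrative, not taken from this commit):

    from transformers import AutoModelForCausalLM

    # accelerate now places every MegaMovingAverageGatedAttention block on a
    # single device instead of splitting its weights across GPUs
    model = AutoModelForCausalLM.from_pretrained(
        "mnaylor/mega-base-wikitext",  # assumed checkpoint, for illustration only
        device_map="auto",
    )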
@@ -1802,7 +1802,9 @@ class MegaForCausalLM(MegaPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past

@@ -1260,7 +1260,9 @@ class MegatronBertForCausalLM(MegatronBertPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past

@@ -1595,7 +1595,8 @@ class MvpForConditionalGeneration(MvpPreTrainedModel):
         for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past

@@ -2066,5 +2067,7 @@ class MvpForCausalLM(MvpPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past

@@ -1826,5 +1826,7 @@ class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past

@@ -1003,7 +1003,9 @@ class OPTForCausalLM(OPTPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
@@ -1489,7 +1489,8 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
         for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past

@@ -1731,5 +1732,7 @@ class PegasusForCausalLM(PegasusPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past

@@ -1691,7 +1691,8 @@ class PegasusXForConditionalGeneration(PegasusXPreTrainedModel):
         for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past

@@ -1402,7 +1402,8 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel):
         for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past

@@ -1751,5 +1752,7 @@ class PLBartForCausalLM(PLBartPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past

@@ -2068,7 +2068,8 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
         for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past

@@ -2312,7 +2313,9 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
@@ -1160,7 +1160,9 @@ class QDQBertLMHeadModel(QDQBertPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past

@@ -1213,7 +1213,9 @@ class RagTokenForGeneration(RagPreTrainedModel):
         reordered_past = ()
         for layer_past in past_key_values:
             # get the correct batch idx from decoder layer's batch dim for cross and self-attn
-            reordered_past += (tuple(_reorder_stacked(past_state, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(_reorder_stacked(past_state, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )

         return reordered_past

@@ -2300,12 +2300,12 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
         for layer_past in past_key_values:
             # buckets
             if layer_past[0] is not None:
-                reord_buckets = layer_past[0].index_select(0, beam_idx)
+                reord_buckets = layer_past[0].index_select(0, beam_idx.to(layer_past[0].device))
             else:
                 reord_buckets = None

             # hidden states
-            reord_hidden_states = layer_past[1].index_select(0, beam_idx)
+            reord_hidden_states = layer_past[1].index_select(0, beam_idx.to(layer_past[1].device))
             reord_past_buckets_states.append((reord_buckets, reord_hidden_states))
         return reord_past_buckets_states
@@ -1157,7 +1157,8 @@ class RemBertForCausalLM(RemBertPreTrainedModel):
         reordered_past = ()
         for layer_past in past_key_values:
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past

@@ -1016,7 +1016,9 @@ class RobertaForCausalLM(RobertaPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past

@@ -1023,7 +1023,9 @@ class RobertaPreLayerNormForCausalLM(RobertaPreLayerNormPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past

@@ -1580,7 +1580,9 @@ class RoCBertForCausalLM(RoCBertPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past

@@ -1188,7 +1188,8 @@ class RoFormerForCausalLM(RoFormerPreTrainedModel):
         reordered_past = ()
         for layer_past in past_key_values:
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past
@@ -1418,5 +1418,7 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past

@@ -976,5 +976,7 @@ class Speech2Text2ForCausalLM(Speech2Text2PreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past

@@ -2525,7 +2525,9 @@ class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past

@@ -1016,5 +1016,7 @@ class TrOCRForCausalLM(TrOCRPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past

@@ -1331,7 +1331,9 @@ class UMT5ForConditionalGeneration(UMT5PreTrainedModel):
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
        return reordered_past
@@ -1794,7 +1794,9 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past

     def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_precision=0.02):

@@ -881,5 +881,7 @@ class XGLMForCausalLM(XGLMPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past

@@ -2095,7 +2095,8 @@ class XLMProphetNetForConditionalGeneration(XLMProphetNetPreTrainedModel):
         for layer_past in past_key_values:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
             )
         return reordered_past

@@ -2340,7 +2341,9 @@ class XLMProphetNetForCausalLM(XLMProphetNetPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past

@@ -1020,7 +1020,9 @@ class XLMRobertaForCausalLM(XLMRobertaPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past

@@ -979,7 +979,9 @@ class XLMRobertaXLForCausalLM(XLMRobertaXLPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past

@@ -1128,7 +1128,9 @@ class XmodForCausalLM(XmodPreTrainedModel):
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
         return reordered_past
@@ -1180,7 +1180,7 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
     def _reorder_cache(self, past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],)
+            reordered_past += (tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) + layer_past[2:],)
         return reordered_past

 class {{cookiecutter.camelcase_modelname}}ClassificationHead(nn.Module):

@@ -2898,7 +2898,7 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),)
         return reordered_past

@@ -3335,6 +3335,6 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),)
         return reordered_past
 {% endif -%}
@@ -15,13 +15,14 @@

 import inspect
+import tempfile
 import unittest
 import warnings

 import numpy as np

 from transformers import is_torch_available, pipeline
-from transformers.testing_utils import require_torch, slow, torch_device
+from transformers.testing_utils import require_accelerate, require_torch, require_torch_multi_gpu, slow, torch_device

 from ..test_modeling_common import floats_tensor, ids_tensor
 from .test_framework_agnostic import GenerationIntegrationTestsMixin
@@ -1017,6 +1018,27 @@ class GenerationTesterMixin:
                 output, input_ids, model.config, use_cache=True, num_return_sequences=beam_scorer.num_beams
             )

+    @require_accelerate
+    @require_torch_multi_gpu
+    def test_model_parallel_beam_search(self):
+        for model_class in self.all_generative_model_classes:
+            if model_class._no_split_modules is None:
+                continue
+
+            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
+
+            model = model_class(config).eval()
+            with tempfile.TemporaryDirectory() as tmp_dir:
+                model.cpu().save_pretrained(tmp_dir)
+                new_model = model_class.from_pretrained(tmp_dir, device_map="auto")
+
+                new_model.generate(
+                    input_ids,
+                    attention_mask=attention_mask,
+                    max_length=max_length,
+                    num_beams=2,
+                )
+
     def test_beam_sample_generate(self):
         for model_class in self.all_generative_model_classes:
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
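For a quick end-to-end check outside the test suite, something along these lines exercises the fixed path (a hedged sketch: the checkpoint is arbitrary, and at least two GPUs are needed for device_map="auto" to actually spread the layers, and hence the cache, across devices):

    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
    model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn", device_map="auto")

    inputs = tok("Beam search used to crash under model parallelism.", return_tensors="pt").to("cuda:0")
    # before this commit, num_beams > 1 raised a device-mismatch error inside _reorder_cache
    summary_ids = model.generate(**inputs, num_beams=2, max_length=32)
    print(tok.decode(summary_ids[0], skip_special_tokens=True))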
@@ -2482,34 +2482,6 @@ class ModelTesterMixin:
                     for value_, parallel_value_ in zip(value, parallel_value):
                         self.assertTrue(torch.allclose(value_, parallel_value_.to("cpu"), atol=1e-7))

-    @require_torch_multi_gpu
-    def test_model_parallel_beam_search(self):
-        if not self.test_model_parallel:
-            return
-
-        all_generative_and_parallelizable_model_classes = tuple(
-            set(self.all_generative_model_classes).intersection(self.all_parallelizable_model_classes)
-        )
-
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
-        for model_class in all_generative_and_parallelizable_model_classes:
-            inputs_dict = self._prepare_for_class(inputs_dict, model_class)
-            model = model_class(config)
-
-            def cast_to_device(dictionary, device):
-                output = {}
-                for k, v in dictionary.items():
-                    if isinstance(v, torch.Tensor):
-                        output[k] = v.to(device)
-                    else:
-                        output[k] = v
-
-                return output
-
-            model.parallelize()
-            model.generate(**cast_to_device(inputs_dict, "cuda:0"), num_beams=2)
-
     def check_device_map_is_respected(self, model, device_map):
         for param_name, param in model.named_parameters():
             # Find device in device_map