From a2ef9c5446b65c9d20c06ab6940f89fcbee89382 Mon Sep 17 00:00:00 2001 From: Tommy Chiang Date: Fri, 24 Sep 2021 16:31:23 +0800 Subject: [PATCH] Use torch.unique_consecutive to check same element (#13637) We use `torch.unique` here only to check whether every elements have the same value. Therefore, we can use `torch.unique_consecutive` here. This function eliminates all but the first element from every consecutive group of equivalent elements. Like, if we apply this function to `[1, 2, 2, 1]`, it will result in `[1, 2, 1]`. As you could see, this is enough for checking whether every elements have the same value. Since `torch.unique_consecutive` do less thing, it is much more faster. On my computer, it is 25x faster on GPU and 15x faster on CPU. --- src/transformers/models/bart/modeling_bart.py | 2 +- .../models/bigbird_pegasus/modeling_bigbird_pegasus.py | 2 +- src/transformers/models/led/modeling_led.py | 2 +- src/transformers/models/mbart/modeling_mbart.py | 2 +- .../modeling_{{cookiecutter.lowercase_modelname}}.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 134669cee4b..08315557d20 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -1457,7 +1457,7 @@ class BartForSequenceClassification(BartPretrainedModel): eos_mask = input_ids.eq(self.config.eos_token_id) - if len(torch.unique(eos_mask.sum(1))) > 1: + if len(torch.unique_consecutive(eos_mask.sum(1))) > 1: raise ValueError("All examples must have the same number of tokens.") sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[ :, -1, : diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 536cd784daa..684e14af8ea 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -2668,7 +2668,7 @@ class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel): eos_mask = input_ids.eq(self.config.eos_token_id) - if len(torch.unique(eos_mask.sum(1))) > 1: + if len(torch.unique_consecutive(eos_mask.sum(1))) > 1: raise ValueError("All examples must have the same number of tokens.") sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[ :, -1, : diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index 926da161a97..863e96c053a 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -2522,7 +2522,7 @@ class LEDForSequenceClassification(LEDPreTrainedModel): eos_mask = input_ids.eq(self.config.eos_token_id) - if len(torch.unique(eos_mask.sum(1))) > 1: + if len(torch.unique_consecutive(eos_mask.sum(1))) > 1: raise ValueError("All examples must have the same number of tokens.") sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[ :, -1, : diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 0ebb5a1a8f3..a45e4eb808a 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -1463,7 +1463,7 @@ class MBartForSequenceClassification(MBartPreTrainedModel): eos_mask = input_ids.eq(self.config.eos_token_id) - if len(torch.unique(eos_mask.sum(1))) > 1: + if len(torch.unique_consecutive(eos_mask.sum(1))) > 1: raise ValueError("All examples must have the same number of tokens.") sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[ :, -1, : diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py index b0482f70621..94c107fb205 100755 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py @@ -2972,7 +2972,7 @@ class {{cookiecutter.camelcase_modelname}}ForSequenceClassification({{cookiecutt eos_mask = input_ids.eq(self.config.eos_token_id) - if len(torch.unique(eos_mask.sum(1))) > 1: + if len(torch.unique_consecutive(eos_mask.sum(1))) > 1: raise ValueError("All examples must have the same number of tokens.") sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[ :, -1, :