From 055afdb6bb1f2e4bd5ab4870ba5b3733bb88da4e Mon Sep 17 00:00:00 2001
From: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
Date: Thu, 20 Mar 2025 15:22:48 +0100
Subject: [PATCH] Fix ONNX export for sequence classification head (#36332)

* set dtype to int32

* fix style
---
 src/transformers/models/bloom/modeling_bloom.py | 2 +-
 src/transformers/models/ctrl/modeling_ctrl.py | 2 +-
 src/transformers/models/diffllama/modeling_diffllama.py | 2 +-
 src/transformers/models/falcon/modeling_falcon.py | 2 +-
 src/transformers/models/gemma/modeling_gemma.py | 2 +-
 src/transformers/models/gemma2/modeling_gemma2.py | 2 +-
 src/transformers/models/glm/modeling_glm.py | 2 +-
 src/transformers/models/gpt2/modeling_gpt2.py | 2 +-
 src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py | 2 +-
 src/transformers/models/gpt_neo/modeling_gpt_neo.py | 2 +-
 src/transformers/models/gpt_neox/modeling_gpt_neox.py | 2 +-
 src/transformers/models/gpt_neox/modular_gpt_neox.py | 2 +-
 src/transformers/models/gptj/modeling_gptj.py | 2 +-
 src/transformers/models/helium/modeling_helium.py | 2 +-
 src/transformers/models/jamba/modeling_jamba.py | 2 +-
 src/transformers/models/jetmoe/modeling_jetmoe.py | 2 +-
 src/transformers/models/llama/modeling_llama.py | 2 +-
 src/transformers/models/mistral/modeling_mistral.py | 2 +-
 src/transformers/models/mixtral/modeling_mixtral.py | 2 +-
 src/transformers/models/mpt/modeling_mpt.py | 2 +-
 src/transformers/models/nemotron/modeling_nemotron.py | 2 +-
 src/transformers/models/openai/modeling_openai.py | 2 +-
 src/transformers/models/opt/modeling_opt.py | 2 +-
 src/transformers/models/persimmon/modeling_persimmon.py | 2 +-
 src/transformers/models/phi/modeling_phi.py | 2 +-
 src/transformers/models/phi3/modeling_phi3.py | 2 +-
 src/transformers/models/phimoe/modeling_phimoe.py | 2 +-
 src/transformers/models/qwen2/modeling_qwen2.py | 2 +-
 src/transformers/models/qwen2_moe/modeling_qwen2_moe.py | 2 +-
 src/transformers/models/stablelm/modeling_stablelm.py | 2 +-
 src/transformers/models/starcoder2/modeling_starcoder2.py | 2 +-
 src/transformers/models/zamba/modeling_zamba.py | 2 +-
 src/transformers/models/zamba2/modeling_zamba2.py | 2 +-
 33 files changed, 33 insertions(+), 33 deletions(-)

diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py
index c5783a3b13b..f91cd85779a 100644
--- a/src/transformers/models/bloom/modeling_bloom.py
+++ b/src/transformers/models/bloom/modeling_bloom.py
@@ -1151,7 +1151,7 @@ class BloomForSequenceClassification(BloomPreTrainedModel):
         elif input_ids is not None:
             # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
             non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
             last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
             last_non_pad_token = -1
diff --git a/src/transformers/models/ctrl/modeling_ctrl.py b/src/transformers/models/ctrl/modeling_ctrl.py
index f530b702870..f97adac913a 100644
--- a/src/transformers/models/ctrl/modeling_ctrl.py
+++ b/src/transformers/models/ctrl/modeling_ctrl.py
@@ -796,7 +796,7 @@ class CTRLForSequenceClassification(CTRLPreTrainedModel):
         elif input_ids is not None:
             # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
             non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
             last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
             last_non_pad_token = -1
diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py
index d69dcf3b59d..ab38de3a392 100644
--- a/src/transformers/models/diffllama/modeling_diffllama.py
+++ b/src/transformers/models/diffllama/modeling_diffllama.py
@@ -1234,7 +1234,7 @@ class DiffLlamaForSequenceClassification(DiffLlamaPreTrainedModel):
         elif input_ids is not None:
             # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
             non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
             last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
             last_non_pad_token = -1
diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py
index 1910e32f29b..61c27f5f296 100644
--- a/src/transformers/models/falcon/modeling_falcon.py
+++ b/src/transformers/models/falcon/modeling_falcon.py
@@ -1363,7 +1363,7 @@ class FalconForSequenceClassification(FalconPreTrainedModel):
         elif input_ids is not None:
             # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
             non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
             last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
             last_non_pad_token = -1
diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index 807abe2b9f4..bf835a98055 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -967,7 +967,7 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
         elif input_ids is not None:
             # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
             non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
             last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
             last_non_pad_token = -1
diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py
index ba16f52c354..e5f73756d89 100644
--- a/src/transformers/models/gemma2/modeling_gemma2.py
+++ b/src/transformers/models/gemma2/modeling_gemma2.py
@@ -1060,7 +1060,7 @@ class Gemma2ForSequenceClassification(Gemma2PreTrainedModel):
         elif input_ids is not None:
             # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
             non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
             last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
             last_non_pad_token = -1
diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py
index d155beff771..fbefb0b1b03 100644
--- a/src/transformers/models/glm/modeling_glm.py
+++ b/src/transformers/models/glm/modeling_glm.py
@@ -976,7 +976,7 @@ class GlmForSequenceClassification(GlmPreTrainedModel):
         elif input_ids is not None:
             # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
             non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
             last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
             last_non_pad_token = -1
diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py
index 931bba9ba96..0c045094e7d 100644
--- a/src/transformers/models/gpt2/modeling_gpt2.py
+++ b/src/transformers/models/gpt2/modeling_gpt2.py
@@ -1400,7 +1400,7 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
         elif input_ids is not None:
             # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
             non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
             last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
             last_non_pad_token = -1
diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
index d4691d01431..d388f70a777 100644
--- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
+++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -1288,7 +1288,7 @@ class GPTBigCodeForSequenceClassification(GPTBigCodePreTrainedModel):
         elif input_ids is not None:
             # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
             non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
             last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
             last_non_pad_token = -1
diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py
index 13ff3cd740d..ff963fd665d 100755
--- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py
+++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -1121,7 +1121,7 @@ class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
         elif input_ids is not None:
             # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
             non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
             last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
             last_non_pad_token = -1
diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
index 97590faf705..fd7dac06539 100755
--- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py
+++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -950,7 +950,7 @@ class GPTNeoXForSequenceClassification(GPTNeoXPreTrainedModel):
         elif input_ids is not None:
             # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
             non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
             last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
             last_non_pad_token = -1
diff --git a/src/transformers/models/gpt_neox/modular_gpt_neox.py b/src/transformers/models/gpt_neox/modular_gpt_neox.py
index 3a7cc49542e..8e64fd120ba 100644
--- a/src/transformers/models/gpt_neox/modular_gpt_neox.py
+++ b/src/transformers/models/gpt_neox/modular_gpt_neox.py
@@ -634,7 +634,7 @@ class GPTNeoXForSequenceClassification(GPTNeoXPreTrainedModel):
         elif input_ids is not None:
             # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
             non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
             last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
             last_non_pad_token = -1
diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py
index effaa3b7251..57feb2c5360 100644
--- a/src/transformers/models/gptj/modeling_gptj.py
+++ b/src/transformers/models/gptj/modeling_gptj.py
@@ -1263,7 +1263,7 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel):
         elif input_ids is not None:
             # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
             non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
             last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
             last_non_pad_token = -1
diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py
index 40aac3b9330..7a46de8dd55 100644
--- a/src/transformers/models/helium/modeling_helium.py
+++ b/src/transformers/models/helium/modeling_helium.py
@@ -963,7 +963,7 @@ class HeliumForSequenceClassification(HeliumPreTrainedModel):
         elif input_ids is not None:
             # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
             non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
             last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
             last_non_pad_token = -1
diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py
index abd0d44fe79..482ed8177a4 100755
--- a/src/transformers/models/jamba/modeling_jamba.py
+++ b/src/transformers/models/jamba/modeling_jamba.py
@@ -1687,7 +1687,7 @@ class JambaForSequenceClassification(JambaPreTrainedModel):
         elif input_ids is not None:
             # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
             non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
             last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
             last_non_pad_token = -1
diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py
index 9518ef0be5d..872c3628a05 100644
--- a/src/transformers/models/jetmoe/modeling_jetmoe.py
+++ b/src/transformers/models/jetmoe/modeling_jetmoe.py
@@ -1478,7 +1478,7 @@ class JetMoeForSequenceClassification(JetMoePreTrainedModel):
         elif input_ids is not None:
             # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
             non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
             last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
             last_non_pad_token = -1
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index e52d4087be1..eb37b605291 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -965,7 +965,7 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
         elif input_ids is not None:
             # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
             non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
             last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
             last_non_pad_token = -1
diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py
index 07a3e7ed417..b4ff18cb117 100644
--- a/src/transformers/models/mistral/modeling_mistral.py
+++ b/src/transformers/models/mistral/modeling_mistral.py
@@ -1042,7 +1042,7 @@ class MistralForSequenceClassification(MistralPreTrainedModel):
         elif input_ids is not None:
             # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
             non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
             last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
             last_non_pad_token = -1
diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
index 8fd3f7540ec..172b2d7b3e2 100644
--- a/src/transformers/models/mixtral/modeling_mixtral.py
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -1195,7 +1195,7 @@ class MixtralForSequenceClassification(MixtralPreTrainedModel):
         elif input_ids is not None:
             # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
             non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
             last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
             last_non_pad_token = -1
diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py
index f9bfdf7c1dd..4de99990396 100644
--- a/src/transformers/models/mpt/modeling_mpt.py
+++ b/src/transformers/models/mpt/modeling_mpt.py
@@ -685,7 +685,7 @@ class MptForSequenceClassification(MptPreTrainedModel):
         elif input_ids is not None:
             # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
             non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
             last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
             last_non_pad_token = -1
diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py
index d6d77887a2a..c23f35caa09 100644
--- a/src/transformers/models/nemotron/modeling_nemotron.py
+++ b/src/transformers/models/nemotron/modeling_nemotron.py
@@ -1211,7 +1211,7 @@ class NemotronForSequenceClassification(NemotronPreTrainedModel):
         elif input_ids is not None:
             # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
             non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
             last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
             last_non_pad_token = -1
diff --git a/src/transformers/models/openai/modeling_openai.py b/src/transformers/models/openai/modeling_openai.py
index a86b86bdcac..f777d95bf89 100644
--- a/src/transformers/models/openai/modeling_openai.py
+++ b/src/transformers/models/openai/modeling_openai.py
@@ -812,7 +812,7 @@ class OpenAIGPTForSequenceClassification(OpenAIGPTPreTrainedModel):
         elif input_ids is not None:
             # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
             non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
             last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
             last_non_pad_token = -1
diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py
index 28134e141c9..c13f43f6967 100644
--- a/src/transformers/models/opt/modeling_opt.py
+++ b/src/transformers/models/opt/modeling_opt.py
@@ -1322,7 +1322,7 @@ class OPTForSequenceClassification(OPTPreTrainedModel):
         elif input_ids is not None:
             # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
             non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
             last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
             last_non_pad_token = -1
diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py
index a71169dbdda..7d6965620b8 100644
--- a/src/transformers/models/persimmon/modeling_persimmon.py
+++ b/src/transformers/models/persimmon/modeling_persimmon.py
@@ -1022,7 +1022,7 @@ class PersimmonForSequenceClassification(PersimmonPreTrainedModel):
         elif input_ids is not None:
             # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
             non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
             last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
             last_non_pad_token = -1
diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py
index c273f906283..04500e834c8 100644
--- a/src/transformers/models/phi/modeling_phi.py
+++ b/src/transformers/models/phi/modeling_phi.py
@@ -939,7 +939,7 @@ class PhiForSequenceClassification(PhiPreTrainedModel):
         elif input_ids is not None:
             # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
             non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
             last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
             last_non_pad_token = -1
diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py
index 8b49ad08767..c656d56265d 100644
--- a/src/transformers/models/phi3/modeling_phi3.py
+++ b/src/transformers/models/phi3/modeling_phi3.py
@@ -1068,7 +1068,7 @@ class Phi3ForSequenceClassification(Phi3PreTrainedModel):
         elif input_ids is not None:
             # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
             non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
             last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
             last_non_pad_token = -1
diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py
index 66452fd943a..c419a7e534e 100644
--- a/src/transformers/models/phimoe/modeling_phimoe.py
+++ b/src/transformers/models/phimoe/modeling_phimoe.py
@@ -1602,7 +1602,7 @@ class PhimoeForSequenceClassification(PhimoePreTrainedModel):
         elif input_ids is not None:
             # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
             non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
             last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
             last_non_pad_token = -1
diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py
index d5d922e208e..2a2b45538b9 100644
--- a/src/transformers/models/qwen2/modeling_qwen2.py
+++ b/src/transformers/models/qwen2/modeling_qwen2.py
@@ -967,7 +967,7 @@ class Qwen2ForSequenceClassification(Qwen2PreTrainedModel):
         elif input_ids is not None:
             # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
             non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
             last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
             last_non_pad_token = -1
diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
index c99c74c1fa3..a0ee19a6a3a 100644
--- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
+++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
@@ -1444,7 +1444,7 @@ class Qwen2MoeForSequenceClassification(Qwen2MoePreTrainedModel):
         elif input_ids is not None:
             # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
             non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
             last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
             last_non_pad_token = -1
diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py
index 2f59cda241f..a1806dfe24a 100755
--- a/src/transformers/models/stablelm/modeling_stablelm.py
+++ b/src/transformers/models/stablelm/modeling_stablelm.py
@@ -1278,7 +1278,7 @@ class StableLmForSequenceClassification(StableLmPreTrainedModel):
         elif input_ids is not None:
             # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
             non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
             last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
             last_non_pad_token = -1
diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py
index a23dfa9d46b..a1eb407d608 100644
--- a/src/transformers/models/starcoder2/modeling_starcoder2.py
+++ b/src/transformers/models/starcoder2/modeling_starcoder2.py
@@ -950,7 +950,7 @@ class Starcoder2ForSequenceClassification(Starcoder2PreTrainedModel):
         elif input_ids is not None:
             # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
             non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
             last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
             last_non_pad_token = -1
diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py
index 6c8b4c520b7..2a7a13780e1 100644
--- a/src/transformers/models/zamba/modeling_zamba.py
+++ b/src/transformers/models/zamba/modeling_zamba.py
@@ -1440,7 +1440,7 @@ class ZambaForSequenceClassification(ZambaPreTrainedModel):
         elif input_ids is not None:
             # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
             non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
             last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
             last_non_pad_token = -1
diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py
index dc83935ce37..772500b42c3 100644
--- a/src/transformers/models/zamba2/modeling_zamba2.py
+++ b/src/transformers/models/zamba2/modeling_zamba2.py
@@ -1877,7 +1877,7 @@ class Zamba2ForSequenceClassification(Zamba2PreTrainedModel):
         elif input_ids is not None:
             # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
             non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
             last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
             last_non_pad_token = -1
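Reviewer note (not part of the patch itself): every hunk makes the same one-line change, giving token_indices the same int32 dtype as non_pad_mask so that the product fed to argmax stays in a single integer dtype rather than mixing int64 and int32, which is apparently what broke the ONNX export of these classification heads. Below is a minimal, self-contained sketch of the pooling logic the hunks touch; it uses plain PyTorch and a hypothetical helper name (last_non_pad_index) for illustration only, not the actual Transformers classes.

import torch

def last_non_pad_index(input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    """Return, for each row, the index of the rightmost token that is not pad_token_id."""
    # Same recipe as the modeling code above: an int32 position ramp times an
    # int32 non-pad mask, then argmax, so the rightmost non-pad position is the
    # maximum for both left- and right-padded rows.
    non_pad_mask = (input_ids != pad_token_id).to(torch.int32)
    token_indices = torch.arange(input_ids.shape[-1], device=input_ids.device, dtype=torch.int32)
    return (token_indices * non_pad_mask).argmax(dim=-1)

# Toy usage: one right-padded and one left-padded row (pad_token_id assumed to be 0 here).
batch = torch.tensor([[5, 6, 7, 0, 0],
                      [0, 0, 8, 9, 10]])
print(last_non_pad_index(batch, pad_token_id=0))  # tensor([2, 4])

Keeping both operands in one dtype is the whole fix; the selection logic itself is unchanged from the removed lines.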