Fix how we compute the final non-padding token for ForSequenceClassification models (#35911)

* Fix how we compute the final non-padding token for Gemma (and probably other models)

* .size() -> .shape[]

* Propagating changes to other models

* Propagating changes to other models

* Change it for all ForSequenceClassification models

* Fix batch dim

* More TF fixes

* Copy the TF fix around as well

* Correct layer name for TFCTRL

* Cleaner .to()

* Clean up the nested if-else

* Use argmax() instead of .max().values
This commit is contained in:
Matt 2025-02-05 16:23:33 +00:00 committed by GitHub
parent 531d1511f5
commit 694aaa7fbc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
37 changed files with 449 additions and 438 deletions

View File

@ -1127,21 +1127,20 @@ class BloomForSequenceClassification(BloomPreTrainedModel):
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
loss = None
if labels is not None:

View File

@ -789,23 +789,21 @@ class CTRLForSequenceClassification(CTRLPreTrainedModel):
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[range(batch_size), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
loss = None
if labels is not None:

View File

@ -868,42 +868,33 @@ class TFCTRLForSequenceClassification(TFCTRLPreTrainedModel, TFSequenceClassific
return_dict=return_dict,
training=training,
)
hidden_states = transformer_outputs[0]
logits = self.classifier(hidden_states)
in_logits = None
logits_shape = shape_list(logits)
batch_size = logits_shape[0]
if self.config.pad_token_id is None:
sequence_lengths = -1
last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1)
else:
if input_ids is not None:
sequence_lengths = (
tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1)
- 1
)
sequence_lengths = tf.where(sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1)
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
token_indices = tf.range(shape_list(input_ids)[-1])
non_pad_mask = tf.cast(input_ids != self.config.pad_token_id, token_indices.dtype)
last_non_pad_token = tf.reduce_max(token_indices * non_pad_mask, axis=-1)
else:
sequence_lengths = -1
last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1)
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
loss = None
pooled_logits = tf.gather(logits, last_non_pad_token, batch_dims=1, axis=1)
if labels is not None:
if input_ids is not None:
batch_size, sequence_length = shape_list(input_ids)[:2]
else:
batch_size, sequence_length = shape_list(inputs_embeds)[:2]
if self.config.pad_token_id is None and batch_size != 1:
if self.config.pad_token_id is None and logits_shape[0] != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if not tf.is_tensor(sequence_lengths):
in_logits = logits[0:batch_size, sequence_lengths]
loss = self.hf_compute_loss(tf.reshape(labels, [-1, 1]), tf.reshape(in_logits, [-1, self.num_labels]))
pooled_logits = in_logits if in_logits is not None else logits
loss = self.hf_compute_loss(tf.reshape(labels, [-1]), tf.reshape(pooled_logits, [-1, self.num_labels]))
if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]

View File

@ -1216,17 +1216,20 @@ class DiffLlamaForSequenceClassification(DiffLlamaPreTrainedModel):
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
loss = None
if labels is not None:

View File

@ -1359,21 +1359,20 @@ class FalconForSequenceClassification(FalconPreTrainedModel):
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
loss = None
if labels is not None:

View File

@ -948,17 +948,20 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
loss = None
if labels is not None:

View File

@ -1083,17 +1083,20 @@ class Gemma2ForSequenceClassification(Gemma2PreTrainedModel):
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
loss = None
if labels is not None:

View File

@ -958,17 +958,20 @@ class GlmForSequenceClassification(GlmPreTrainedModel):
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
loss = None
if labels is not None:

View File

@ -1393,25 +1393,23 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
else:
batch_size, sequence_length = inputs_embeds.shape[:2]
assert (
self.config.pad_token_id is not None or batch_size == 1
), "Cannot handle batch sizes > 1 if no padding token is defined."
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
loss = None
if labels is not None:

View File

@ -1177,39 +1177,33 @@ class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassific
return_dict=return_dict,
training=training,
)
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
logits_shape = shape_list(logits)
in_logits = None
batch_size = logits_shape[0]
if self.config.pad_token_id is None:
sequence_lengths = -1
last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1)
else:
if input_ids is not None:
sequence_lengths = (
tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1)
- 1
)
sequence_lengths = tf.where(sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1)
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
token_indices = tf.range(shape_list(input_ids)[-1])
non_pad_mask = tf.cast(input_ids != self.config.pad_token_id, token_indices.dtype)
last_non_pad_token = tf.reduce_max(token_indices * non_pad_mask, axis=-1)
else:
sequence_lengths = -1
last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1)
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
loss = None
pooled_logits = tf.gather(logits, last_non_pad_token, batch_dims=1, axis=1)
if labels is not None:
assert (
self.config.pad_token_id is not None or logits_shape[0] == 1
), "Cannot handle batch sizes > 1 if no padding token is defined."
if self.config.pad_token_id is None and logits_shape[0] != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if not tf.is_tensor(sequence_lengths):
in_logits = logits[0 : logits_shape[0], sequence_lengths]
loss = self.hf_compute_loss(tf.reshape(labels, [-1]), tf.reshape(in_logits, [-1, self.num_labels]))
pooled_logits = in_logits if in_logits is not None else logits
loss = self.hf_compute_loss(tf.reshape(labels, [-1]), tf.reshape(pooled_logits, [-1, self.num_labels]))
if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]

View File

@ -1280,25 +1280,23 @@ class GPTBigCodeForSequenceClassification(GPTBigCodePreTrainedModel):
else:
batch_size, sequence_length = inputs_embeds.shape[:2]
assert (
self.config.pad_token_id is not None or batch_size == 1
), "Cannot handle batch sizes > 1 if no padding token is defined."
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
loss = None
if labels is not None:

View File

@ -1101,21 +1101,20 @@ class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
loss = None
if labels is not None:

View File

@ -935,21 +935,20 @@ class GPTNeoXForSequenceClassification(GPTNeoXPreTrainedModel):
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
loss = None
if labels is not None:

View File

@ -1243,21 +1243,20 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel):
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
loss = None
if labels is not None:

View File

@ -941,35 +941,30 @@ class TFGPTJForSequenceClassification(TFGPTJPreTrainedModel, TFSequenceClassific
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
logits_shape = shape_list(logits)
in_logits = None
batch_size = logits_shape[0]
if self.config.pad_token_id is None:
sequence_lengths = -1
last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1)
else:
if input_ids is not None:
sequence_lengths = (
tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1)
- 1
)
sequence_lengths = tf.where(
sequence_lengths >= 0,
sequence_lengths,
tf.cast(shape_list(input_ids[-1]), sequence_lengths.dtype) - 1,
)
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
token_indices = tf.range(shape_list(input_ids)[-1])
non_pad_mask = tf.cast(input_ids != self.config.pad_token_id, token_indices.dtype)
last_non_pad_token = tf.reduce_max(token_indices * non_pad_mask, axis=-1)
else:
sequence_lengths = -1
last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1)
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
loss = None
if labels is not None:
if not tf.is_tensor(sequence_lengths):
in_logits = logits[0 : logits_shape[0], sequence_lengths]
pooled_logits = tf.gather(logits, last_non_pad_token, batch_dims=1, axis=1)
loss = self.hf_compute_loss(tf.reshape(labels, [-1]), tf.reshape(in_logits, [-1, self.num_labels]))
pooled_logits = in_logits if in_logits is not None else logits
if labels is not None:
if self.config.pad_token_id is None and logits_shape[0] != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
loss = self.hf_compute_loss(tf.reshape(labels, [-1]), tf.reshape(pooled_logits, [-1, self.num_labels]))
if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]

View File

@ -945,17 +945,20 @@ class HeliumForSequenceClassification(HeliumPreTrainedModel):
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
loss = None
if labels is not None:

View File

@ -1685,17 +1685,20 @@ class JambaForSequenceClassification(JambaPreTrainedModel):
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
loss = None
if labels is not None:

View File

@ -1457,17 +1457,20 @@ class JetMoeForSequenceClassification(JetMoePreTrainedModel):
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
loss = None
if labels is not None:

View File

@ -947,17 +947,20 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
loss = None
if labels is not None:

View File

@ -1036,17 +1036,20 @@ class MistralForSequenceClassification(MistralPreTrainedModel):
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
loss = None
if labels is not None:

View File

@ -996,39 +996,30 @@ class TFMistralForSequenceClassification(TFMistralPreTrainedModel, TFSequenceCla
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
logits_shape = shape_list(logits)
in_logits = None
batch_size = logits_shape[0]
if self.config.pad_token_id is None:
sequence_lengths = -1
last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1)
else:
if input_ids is not None:
sequence_lengths = (
tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1)
- 1
)
sequence_lengths = tf.where(
sequence_lengths >= 0,
sequence_lengths,
tf.cast(shape_list(input_ids[-1]), sequence_lengths.dtype) - 1,
)
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
token_indices = tf.range(shape_list(input_ids)[-1])
non_pad_mask = tf.cast(input_ids != self.config.pad_token_id, token_indices.dtype)
last_non_pad_token = tf.reduce_max(token_indices * non_pad_mask, axis=-1)
else:
sequence_lengths = -1
last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1)
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
loss = None
pooled_logits = tf.gather(logits, last_non_pad_token, batch_dims=1, axis=1)
if labels is not None:
if self.config.pad_token_id is None and logits_shape[0] != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if not tf.is_tensor(sequence_lengths):
in_logits = logits[0 : logits_shape[0], sequence_lengths]
loss = self.hf_compute_loss(tf.reshape(labels, [-1]), tf.reshape(in_logits, [-1, self.num_labels]))
pooled_logits = in_logits if in_logits is not None else logits
loss = self.hf_compute_loss(tf.reshape(labels, [-1]), tf.reshape(pooled_logits, [-1, self.num_labels]))
if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]

View File

@ -1189,17 +1189,20 @@ class MixtralForSequenceClassification(MixtralPreTrainedModel):
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
loss = None
if labels is not None:

View File

@ -681,21 +681,20 @@ class MptForSequenceClassification(MptPreTrainedModel):
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
loss = None
if labels is not None:

View File

@ -1194,17 +1194,20 @@ class NemotronForSequenceClassification(NemotronPreTrainedModel):
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
loss = None
if labels is not None:

View File

@ -805,23 +805,21 @@ class OpenAIGPTForSequenceClassification(OpenAIGPTPreTrainedModel):
# Ensure the batch size is > 1 if there is no padding.
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[range(batch_size), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
loss = None
if labels is not None:

View File

@ -876,43 +876,33 @@ class TFOpenAIGPTForSequenceClassification(TFOpenAIGPTPreTrainedModel, TFSequenc
return_dict=return_dict,
training=training,
)
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
in_logits = None
logits_shape = shape_list(logits)
batch_size = logits_shape[0]
if self.config.pad_token_id is None:
sequence_lengths = -1
last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1)
else:
if input_ids is not None:
sequence_lengths = (
tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1)
- 1
)
sequence_lengths = tf.where(sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1)
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
token_indices = tf.range(shape_list(input_ids)[-1])
non_pad_mask = tf.cast(input_ids != self.config.pad_token_id, token_indices.dtype)
last_non_pad_token = tf.reduce_max(token_indices * non_pad_mask, axis=-1)
else:
sequence_lengths = -1
last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1)
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
loss = None
pooled_logits = tf.gather(logits, last_non_pad_token, batch_dims=1, axis=1)
if labels is not None:
if input_ids is not None:
batch_size, sequence_length = shape_list(input_ids)[:2]
else:
batch_size, sequence_length = shape_list(inputs_embeds)[:2]
assert (
self.config.pad_token_id is not None or batch_size == 1
), "Cannot handle batch sizes > 1 if no padding token is defined."
if self.config.pad_token_id is None and logits_shape[0] != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if not tf.is_tensor(sequence_lengths):
in_logits = logits[0:batch_size, sequence_lengths]
loss = self.hf_compute_loss(tf.reshape(labels, [-1, 1]), tf.reshape(in_logits, [-1, self.num_labels]))
pooled_logits = in_logits if in_logits is not None else logits
loss = self.hf_compute_loss(tf.reshape(labels, [-1]), tf.reshape(pooled_logits, [-1, self.num_labels]))
if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]

View File

@ -1295,22 +1295,23 @@ class OPTForSequenceClassification(OPTPreTrainedModel):
else:
batch_size, sequence_length = inputs_embeds.shape[:2]
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
loss = None
if labels is not None:

View File

@ -1009,17 +1009,20 @@ class PersimmonForSequenceClassification(PersimmonPreTrainedModel):
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
loss = None
if labels is not None:

View File

@ -921,17 +921,20 @@ class PhiForSequenceClassification(PhiPreTrainedModel):
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
loss = None
if labels is not None:

View File

@ -1057,17 +1057,20 @@ class Phi3ForSequenceClassification(Phi3PreTrainedModel):
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
loss = None
if labels is not None:

View File

@ -1597,17 +1597,20 @@ class PhimoeForSequenceClassification(PhimoePreTrainedModel):
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
loss = None
if labels is not None:

View File

@ -932,17 +932,20 @@ class Qwen2ForSequenceClassification(Qwen2PreTrainedModel):
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
loss = None
if labels is not None:

View File

@ -1438,17 +1438,20 @@ class Qwen2MoeForSequenceClassification(Qwen2MoePreTrainedModel):
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
loss = None
if labels is not None:

View File

@ -1265,17 +1265,20 @@ class StableLmForSequenceClassification(StableLmPreTrainedModel):
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
loss = None
if labels is not None:

View File

@ -944,17 +944,20 @@ class Starcoder2ForSequenceClassification(Starcoder2PreTrainedModel):
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
loss = None
if labels is not None:

View File

@ -1438,17 +1438,20 @@ class ZambaForSequenceClassification(ZambaPreTrainedModel):
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
loss = None
if labels is not None:

View File

@ -1875,17 +1875,20 @@ class Zamba2ForSequenceClassification(Zamba2PreTrainedModel):
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
loss = None
if labels is not None: