mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix how we compute the final non-padding token for ForSequenceClassification models (#35911)
* Fix how we compute the final non-padding token for Gemma (and probably other models) * .size() -> .shape[] * Propagating changes to other models * Propagating changes to other models * Change it for all ForSequenceClassification models * Fix batch dim * More TF fixes * Copy the TF fix around as well * Correct layer name for TFCTRL * Cleaner .to() * Clean up the nested if-else * Use argmax() instead of .max().values
This commit is contained in:
parent
531d1511f5
commit
694aaa7fbc
@ -1127,21 +1127,20 @@ class BloomForSequenceClassification(BloomPreTrainedModel):
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
elif input_ids is not None:
|
||||
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
||||
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
|
||||
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||
else:
|
||||
if input_ids is not None:
|
||||
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
||||
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
||||
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
||||
sequence_lengths = sequence_lengths.to(logits.device)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
last_non_pad_token = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
@ -789,23 +789,21 @@ class CTRLForSequenceClassification(CTRLPreTrainedModel):
|
||||
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
elif input_ids is not None:
|
||||
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
||||
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
|
||||
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||
else:
|
||||
if input_ids is not None:
|
||||
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
||||
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
||||
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
||||
sequence_lengths = sequence_lengths.to(logits.device)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
last_non_pad_token = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[range(batch_size), sequence_lengths]
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
@ -868,42 +868,33 @@ class TFCTRLForSequenceClassification(TFCTRLPreTrainedModel, TFSequenceClassific
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
)
|
||||
|
||||
hidden_states = transformer_outputs[0]
|
||||
logits = self.classifier(hidden_states)
|
||||
in_logits = None
|
||||
logits_shape = shape_list(logits)
|
||||
batch_size = logits_shape[0]
|
||||
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1)
|
||||
else:
|
||||
if input_ids is not None:
|
||||
sequence_lengths = (
|
||||
tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1)
|
||||
- 1
|
||||
)
|
||||
sequence_lengths = tf.where(sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1)
|
||||
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
|
||||
token_indices = tf.range(shape_list(input_ids)[-1])
|
||||
non_pad_mask = tf.cast(input_ids != self.config.pad_token_id, token_indices.dtype)
|
||||
last_non_pad_token = tf.reduce_max(token_indices * non_pad_mask, axis=-1)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1)
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
loss = None
|
||||
|
||||
pooled_logits = tf.gather(logits, last_non_pad_token, batch_dims=1, axis=1)
|
||||
|
||||
if labels is not None:
|
||||
if input_ids is not None:
|
||||
batch_size, sequence_length = shape_list(input_ids)[:2]
|
||||
else:
|
||||
batch_size, sequence_length = shape_list(inputs_embeds)[:2]
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
if self.config.pad_token_id is None and logits_shape[0] != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
|
||||
if not tf.is_tensor(sequence_lengths):
|
||||
in_logits = logits[0:batch_size, sequence_lengths]
|
||||
|
||||
loss = self.hf_compute_loss(tf.reshape(labels, [-1, 1]), tf.reshape(in_logits, [-1, self.num_labels]))
|
||||
|
||||
pooled_logits = in_logits if in_logits is not None else logits
|
||||
loss = self.hf_compute_loss(tf.reshape(labels, [-1]), tf.reshape(pooled_logits, [-1, self.num_labels]))
|
||||
|
||||
if not return_dict:
|
||||
output = (pooled_logits,) + transformer_outputs[1:]
|
||||
|
@ -1216,17 +1216,20 @@ class DiffLlamaForSequenceClassification(DiffLlamaPreTrainedModel):
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
elif input_ids is not None:
|
||||
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
||||
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
|
||||
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||
else:
|
||||
if input_ids is not None:
|
||||
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
||||
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
||||
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
||||
sequence_lengths = sequence_lengths.to(logits.device)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
@ -1359,21 +1359,20 @@ class FalconForSequenceClassification(FalconPreTrainedModel):
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
elif input_ids is not None:
|
||||
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
||||
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
|
||||
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||
else:
|
||||
if input_ids is not None:
|
||||
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
||||
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
||||
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
||||
sequence_lengths = sequence_lengths.to(logits.device)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
last_non_pad_token = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
@ -948,17 +948,20 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel):
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
elif input_ids is not None:
|
||||
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
||||
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
|
||||
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||
else:
|
||||
if input_ids is not None:
|
||||
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
||||
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
||||
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
||||
sequence_lengths = sequence_lengths.to(logits.device)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
@ -1083,17 +1083,20 @@ class Gemma2ForSequenceClassification(Gemma2PreTrainedModel):
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
elif input_ids is not None:
|
||||
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
||||
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
|
||||
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||
else:
|
||||
if input_ids is not None:
|
||||
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
||||
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
||||
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
||||
sequence_lengths = sequence_lengths.to(logits.device)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
@ -958,17 +958,20 @@ class GlmForSequenceClassification(GlmPreTrainedModel):
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
elif input_ids is not None:
|
||||
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
||||
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
|
||||
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||
else:
|
||||
if input_ids is not None:
|
||||
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
||||
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
||||
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
||||
sequence_lengths = sequence_lengths.to(logits.device)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
@ -1393,25 +1393,23 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
|
||||
else:
|
||||
batch_size, sequence_length = inputs_embeds.shape[:2]
|
||||
|
||||
assert (
|
||||
self.config.pad_token_id is not None or batch_size == 1
|
||||
), "Cannot handle batch sizes > 1 if no padding token is defined."
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
elif input_ids is not None:
|
||||
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
||||
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
|
||||
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||
else:
|
||||
if input_ids is not None:
|
||||
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
||||
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
||||
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
||||
sequence_lengths = sequence_lengths.to(logits.device)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
last_non_pad_token = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
@ -1177,39 +1177,33 @@ class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassific
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
)
|
||||
|
||||
hidden_states = transformer_outputs[0]
|
||||
logits = self.score(hidden_states)
|
||||
logits_shape = shape_list(logits)
|
||||
in_logits = None
|
||||
batch_size = logits_shape[0]
|
||||
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1)
|
||||
else:
|
||||
if input_ids is not None:
|
||||
sequence_lengths = (
|
||||
tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1)
|
||||
- 1
|
||||
)
|
||||
sequence_lengths = tf.where(sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1)
|
||||
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
|
||||
token_indices = tf.range(shape_list(input_ids)[-1])
|
||||
non_pad_mask = tf.cast(input_ids != self.config.pad_token_id, token_indices.dtype)
|
||||
last_non_pad_token = tf.reduce_max(token_indices * non_pad_mask, axis=-1)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1)
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
loss = None
|
||||
|
||||
pooled_logits = tf.gather(logits, last_non_pad_token, batch_dims=1, axis=1)
|
||||
|
||||
if labels is not None:
|
||||
assert (
|
||||
self.config.pad_token_id is not None or logits_shape[0] == 1
|
||||
), "Cannot handle batch sizes > 1 if no padding token is defined."
|
||||
if self.config.pad_token_id is None and logits_shape[0] != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
|
||||
if not tf.is_tensor(sequence_lengths):
|
||||
in_logits = logits[0 : logits_shape[0], sequence_lengths]
|
||||
|
||||
loss = self.hf_compute_loss(tf.reshape(labels, [-1]), tf.reshape(in_logits, [-1, self.num_labels]))
|
||||
pooled_logits = in_logits if in_logits is not None else logits
|
||||
loss = self.hf_compute_loss(tf.reshape(labels, [-1]), tf.reshape(pooled_logits, [-1, self.num_labels]))
|
||||
|
||||
if not return_dict:
|
||||
output = (pooled_logits,) + transformer_outputs[1:]
|
||||
|
@ -1280,25 +1280,23 @@ class GPTBigCodeForSequenceClassification(GPTBigCodePreTrainedModel):
|
||||
else:
|
||||
batch_size, sequence_length = inputs_embeds.shape[:2]
|
||||
|
||||
assert (
|
||||
self.config.pad_token_id is not None or batch_size == 1
|
||||
), "Cannot handle batch sizes > 1 if no padding token is defined."
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
elif input_ids is not None:
|
||||
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
||||
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
|
||||
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||
else:
|
||||
if input_ids is not None:
|
||||
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
||||
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
||||
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
||||
sequence_lengths = sequence_lengths.to(logits.device)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
last_non_pad_token = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
@ -1101,21 +1101,20 @@ class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
elif input_ids is not None:
|
||||
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
||||
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
|
||||
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||
else:
|
||||
if input_ids is not None:
|
||||
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
||||
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
||||
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
||||
sequence_lengths = sequence_lengths.to(logits.device)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
last_non_pad_token = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
@ -935,21 +935,20 @@ class GPTNeoXForSequenceClassification(GPTNeoXPreTrainedModel):
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
elif input_ids is not None:
|
||||
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
||||
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
|
||||
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||
else:
|
||||
if input_ids is not None:
|
||||
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
||||
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
||||
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
||||
sequence_lengths = sequence_lengths.to(logits.device)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
last_non_pad_token = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
@ -1243,21 +1243,20 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel):
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
elif input_ids is not None:
|
||||
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
||||
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
|
||||
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||
else:
|
||||
if input_ids is not None:
|
||||
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
||||
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
||||
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
||||
sequence_lengths = sequence_lengths.to(logits.device)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
last_non_pad_token = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
@ -941,35 +941,30 @@ class TFGPTJForSequenceClassification(TFGPTJPreTrainedModel, TFSequenceClassific
|
||||
hidden_states = transformer_outputs[0]
|
||||
logits = self.score(hidden_states)
|
||||
logits_shape = shape_list(logits)
|
||||
in_logits = None
|
||||
batch_size = logits_shape[0]
|
||||
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1)
|
||||
else:
|
||||
if input_ids is not None:
|
||||
sequence_lengths = (
|
||||
tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1)
|
||||
- 1
|
||||
)
|
||||
sequence_lengths = tf.where(
|
||||
sequence_lengths >= 0,
|
||||
sequence_lengths,
|
||||
tf.cast(shape_list(input_ids[-1]), sequence_lengths.dtype) - 1,
|
||||
)
|
||||
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
|
||||
token_indices = tf.range(shape_list(input_ids)[-1])
|
||||
non_pad_mask = tf.cast(input_ids != self.config.pad_token_id, token_indices.dtype)
|
||||
last_non_pad_token = tf.reduce_max(token_indices * non_pad_mask, axis=-1)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1)
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
loss = None
|
||||
|
||||
if labels is not None:
|
||||
if not tf.is_tensor(sequence_lengths):
|
||||
in_logits = logits[0 : logits_shape[0], sequence_lengths]
|
||||
pooled_logits = tf.gather(logits, last_non_pad_token, batch_dims=1, axis=1)
|
||||
|
||||
loss = self.hf_compute_loss(tf.reshape(labels, [-1]), tf.reshape(in_logits, [-1, self.num_labels]))
|
||||
pooled_logits = in_logits if in_logits is not None else logits
|
||||
if labels is not None:
|
||||
if self.config.pad_token_id is None and logits_shape[0] != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
|
||||
loss = self.hf_compute_loss(tf.reshape(labels, [-1]), tf.reshape(pooled_logits, [-1, self.num_labels]))
|
||||
|
||||
if not return_dict:
|
||||
output = (pooled_logits,) + transformer_outputs[1:]
|
||||
|
@ -945,17 +945,20 @@ class HeliumForSequenceClassification(HeliumPreTrainedModel):
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
elif input_ids is not None:
|
||||
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
||||
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
|
||||
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||
else:
|
||||
if input_ids is not None:
|
||||
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
||||
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
||||
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
||||
sequence_lengths = sequence_lengths.to(logits.device)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
@ -1685,17 +1685,20 @@ class JambaForSequenceClassification(JambaPreTrainedModel):
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
elif input_ids is not None:
|
||||
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
||||
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
|
||||
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||
else:
|
||||
if input_ids is not None:
|
||||
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
||||
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
||||
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
||||
sequence_lengths = sequence_lengths.to(logits.device)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
@ -1457,17 +1457,20 @@ class JetMoeForSequenceClassification(JetMoePreTrainedModel):
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
elif input_ids is not None:
|
||||
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
||||
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
|
||||
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||
else:
|
||||
if input_ids is not None:
|
||||
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
||||
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
||||
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
||||
sequence_lengths = sequence_lengths.to(logits.device)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
@ -947,17 +947,20 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
elif input_ids is not None:
|
||||
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
||||
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
|
||||
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||
else:
|
||||
if input_ids is not None:
|
||||
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
||||
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
||||
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
||||
sequence_lengths = sequence_lengths.to(logits.device)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
@ -1036,17 +1036,20 @@ class MistralForSequenceClassification(MistralPreTrainedModel):
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
elif input_ids is not None:
|
||||
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
||||
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
|
||||
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||
else:
|
||||
if input_ids is not None:
|
||||
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
||||
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
||||
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
||||
sequence_lengths = sequence_lengths.to(logits.device)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
@ -996,39 +996,30 @@ class TFMistralForSequenceClassification(TFMistralPreTrainedModel, TFSequenceCla
|
||||
hidden_states = transformer_outputs[0]
|
||||
logits = self.score(hidden_states)
|
||||
logits_shape = shape_list(logits)
|
||||
in_logits = None
|
||||
batch_size = logits_shape[0]
|
||||
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1)
|
||||
else:
|
||||
if input_ids is not None:
|
||||
sequence_lengths = (
|
||||
tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1)
|
||||
- 1
|
||||
)
|
||||
sequence_lengths = tf.where(
|
||||
sequence_lengths >= 0,
|
||||
sequence_lengths,
|
||||
tf.cast(shape_list(input_ids[-1]), sequence_lengths.dtype) - 1,
|
||||
)
|
||||
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
|
||||
token_indices = tf.range(shape_list(input_ids)[-1])
|
||||
non_pad_mask = tf.cast(input_ids != self.config.pad_token_id, token_indices.dtype)
|
||||
last_non_pad_token = tf.reduce_max(token_indices * non_pad_mask, axis=-1)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1)
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
loss = None
|
||||
|
||||
pooled_logits = tf.gather(logits, last_non_pad_token, batch_dims=1, axis=1)
|
||||
|
||||
if labels is not None:
|
||||
if self.config.pad_token_id is None and logits_shape[0] != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
|
||||
if not tf.is_tensor(sequence_lengths):
|
||||
in_logits = logits[0 : logits_shape[0], sequence_lengths]
|
||||
|
||||
loss = self.hf_compute_loss(tf.reshape(labels, [-1]), tf.reshape(in_logits, [-1, self.num_labels]))
|
||||
pooled_logits = in_logits if in_logits is not None else logits
|
||||
loss = self.hf_compute_loss(tf.reshape(labels, [-1]), tf.reshape(pooled_logits, [-1, self.num_labels]))
|
||||
|
||||
if not return_dict:
|
||||
output = (pooled_logits,) + transformer_outputs[1:]
|
||||
|
@ -1189,17 +1189,20 @@ class MixtralForSequenceClassification(MixtralPreTrainedModel):
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
elif input_ids is not None:
|
||||
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
||||
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
|
||||
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||
else:
|
||||
if input_ids is not None:
|
||||
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
||||
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
||||
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
||||
sequence_lengths = sequence_lengths.to(logits.device)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
@ -681,21 +681,20 @@ class MptForSequenceClassification(MptPreTrainedModel):
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
elif input_ids is not None:
|
||||
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
||||
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
|
||||
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||
else:
|
||||
if input_ids is not None:
|
||||
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
||||
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
||||
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
||||
sequence_lengths = sequence_lengths.to(logits.device)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
last_non_pad_token = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
@ -1194,17 +1194,20 @@ class NemotronForSequenceClassification(NemotronPreTrainedModel):
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
elif input_ids is not None:
|
||||
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
||||
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
|
||||
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||
else:
|
||||
if input_ids is not None:
|
||||
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
||||
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
||||
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
||||
sequence_lengths = sequence_lengths.to(logits.device)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
@ -805,23 +805,21 @@ class OpenAIGPTForSequenceClassification(OpenAIGPTPreTrainedModel):
|
||||
# Ensure the batch size is > 1 if there is no padding.
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
elif input_ids is not None:
|
||||
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
||||
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
|
||||
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||
else:
|
||||
if input_ids is not None:
|
||||
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
||||
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
||||
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
||||
sequence_lengths = sequence_lengths.to(logits.device)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
last_non_pad_token = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[range(batch_size), sequence_lengths]
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
@ -876,43 +876,33 @@ class TFOpenAIGPTForSequenceClassification(TFOpenAIGPTPreTrainedModel, TFSequenc
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
)
|
||||
|
||||
hidden_states = transformer_outputs[0]
|
||||
logits = self.score(hidden_states)
|
||||
in_logits = None
|
||||
logits_shape = shape_list(logits)
|
||||
batch_size = logits_shape[0]
|
||||
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1)
|
||||
else:
|
||||
if input_ids is not None:
|
||||
sequence_lengths = (
|
||||
tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1)
|
||||
- 1
|
||||
)
|
||||
sequence_lengths = tf.where(sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1)
|
||||
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
|
||||
token_indices = tf.range(shape_list(input_ids)[-1])
|
||||
non_pad_mask = tf.cast(input_ids != self.config.pad_token_id, token_indices.dtype)
|
||||
last_non_pad_token = tf.reduce_max(token_indices * non_pad_mask, axis=-1)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = tf.fill((batch_size,), value=logits_shape[1] - 1)
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
loss = None
|
||||
|
||||
pooled_logits = tf.gather(logits, last_non_pad_token, batch_dims=1, axis=1)
|
||||
|
||||
if labels is not None:
|
||||
if input_ids is not None:
|
||||
batch_size, sequence_length = shape_list(input_ids)[:2]
|
||||
else:
|
||||
batch_size, sequence_length = shape_list(inputs_embeds)[:2]
|
||||
assert (
|
||||
self.config.pad_token_id is not None or batch_size == 1
|
||||
), "Cannot handle batch sizes > 1 if no padding token is defined."
|
||||
if self.config.pad_token_id is None and logits_shape[0] != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
|
||||
if not tf.is_tensor(sequence_lengths):
|
||||
in_logits = logits[0:batch_size, sequence_lengths]
|
||||
|
||||
loss = self.hf_compute_loss(tf.reshape(labels, [-1, 1]), tf.reshape(in_logits, [-1, self.num_labels]))
|
||||
|
||||
pooled_logits = in_logits if in_logits is not None else logits
|
||||
loss = self.hf_compute_loss(tf.reshape(labels, [-1]), tf.reshape(pooled_logits, [-1, self.num_labels]))
|
||||
|
||||
if not return_dict:
|
||||
output = (pooled_logits,) + transformer_outputs[1:]
|
||||
|
@ -1295,22 +1295,23 @@ class OPTForSequenceClassification(OPTPreTrainedModel):
|
||||
else:
|
||||
batch_size, sequence_length = inputs_embeds.shape[:2]
|
||||
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
elif input_ids is not None:
|
||||
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
||||
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
|
||||
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||
else:
|
||||
if input_ids is not None:
|
||||
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
||||
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
||||
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
||||
sequence_lengths = sequence_lengths.to(logits.device)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
last_non_pad_token = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
@ -1009,17 +1009,20 @@ class PersimmonForSequenceClassification(PersimmonPreTrainedModel):
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
elif input_ids is not None:
|
||||
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
||||
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
|
||||
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||
else:
|
||||
if input_ids is not None:
|
||||
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
||||
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
||||
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
||||
sequence_lengths = sequence_lengths.to(logits.device)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
@ -921,17 +921,20 @@ class PhiForSequenceClassification(PhiPreTrainedModel):
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
elif input_ids is not None:
|
||||
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
||||
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
|
||||
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||
else:
|
||||
if input_ids is not None:
|
||||
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
||||
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
||||
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
||||
sequence_lengths = sequence_lengths.to(logits.device)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
@ -1057,17 +1057,20 @@ class Phi3ForSequenceClassification(Phi3PreTrainedModel):
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
elif input_ids is not None:
|
||||
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
||||
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
|
||||
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||
else:
|
||||
if input_ids is not None:
|
||||
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
||||
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
||||
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
||||
sequence_lengths = sequence_lengths.to(logits.device)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
@ -1597,17 +1597,20 @@ class PhimoeForSequenceClassification(PhimoePreTrainedModel):
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
elif input_ids is not None:
|
||||
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
||||
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
|
||||
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||
else:
|
||||
if input_ids is not None:
|
||||
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
||||
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
||||
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
||||
sequence_lengths = sequence_lengths.to(logits.device)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
@ -932,17 +932,20 @@ class Qwen2ForSequenceClassification(Qwen2PreTrainedModel):
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
elif input_ids is not None:
|
||||
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
||||
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
|
||||
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||
else:
|
||||
if input_ids is not None:
|
||||
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
||||
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
||||
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
||||
sequence_lengths = sequence_lengths.to(logits.device)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
@ -1438,17 +1438,20 @@ class Qwen2MoeForSequenceClassification(Qwen2MoePreTrainedModel):
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
elif input_ids is not None:
|
||||
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
||||
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
|
||||
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||
else:
|
||||
if input_ids is not None:
|
||||
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
||||
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
||||
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
||||
sequence_lengths = sequence_lengths.to(logits.device)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
@ -1265,17 +1265,20 @@ class StableLmForSequenceClassification(StableLmPreTrainedModel):
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
elif input_ids is not None:
|
||||
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
||||
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
|
||||
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||
else:
|
||||
if input_ids is not None:
|
||||
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
||||
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
||||
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
||||
sequence_lengths = sequence_lengths.to(logits.device)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
@ -944,17 +944,20 @@ class Starcoder2ForSequenceClassification(Starcoder2PreTrainedModel):
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
elif input_ids is not None:
|
||||
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
||||
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
|
||||
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||
else:
|
||||
if input_ids is not None:
|
||||
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
||||
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
||||
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
||||
sequence_lengths = sequence_lengths.to(logits.device)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
@ -1438,17 +1438,20 @@ class ZambaForSequenceClassification(ZambaPreTrainedModel):
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
elif input_ids is not None:
|
||||
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
||||
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
|
||||
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||
else:
|
||||
if input_ids is not None:
|
||||
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
||||
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
||||
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
||||
sequence_lengths = sequence_lengths.to(logits.device)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
@ -1875,17 +1875,20 @@ class Zamba2ForSequenceClassification(Zamba2PreTrainedModel):
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
elif input_ids is not None:
|
||||
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
||||
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
|
||||
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||
else:
|
||||
if input_ids is not None:
|
||||
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
||||
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
||||
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
||||
sequence_lengths = sequence_lengths.to(logits.device)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
last_non_pad_token = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
Loading…
Reference in New Issue
Block a user