Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Fix repo consistency (#36063)
* fix 1
* fix 2
* fix modular
* simplify at the same time

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
Parent: ed98ad35e6
Commit: 37faa97d9b
@@ -927,11 +927,7 @@ class GPTNeoXForSequenceClassification(GPTNeoXPreTrainedModel):
         hidden_states = outputs[0]
         logits = self.score(hidden_states)

-        if input_ids is not None:
-            batch_size, sequence_length = input_ids.shape[:2]
-        else:
-            batch_size, sequence_length = inputs_embeds.shape[:2]
-
+        batch_size = logits.shape[0]
         if self.config.pad_token_id is None and batch_size != 1:
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
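This first hunk (and the matching lines in the second hunk below) drops the `input_ids` vs `inputs_embeds` branch and reads the batch size directly off the score-head output. A minimal sketch of why that is equivalent, with illustrative sizes and a bare `nn.Linear` as a stand-in for the model's real `self.score` head: the head is applied per token, so the leading dimension of `logits` is the batch dimension whichever input was provided.

```python
# Minimal sketch (illustrative sizes, not the model's real config): the score head
# is applied token-wise, so logits keeps the batch dimension in front whether the
# forward pass started from input_ids or inputs_embeds.
import torch
import torch.nn as nn

hidden_size, num_labels = 8, 3
score = nn.Linear(hidden_size, num_labels, bias=False)  # stand-in for self.score

hidden_states = torch.randn(4, 10, hidden_size)  # (batch, seq_len, hidden), as if from the backbone
logits = score(hidden_states)                    # (batch, seq_len, num_labels)

assert logits.shape[0] == hidden_states.shape[0] == 4  # batch size recoverable from logits alone
```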
@@ -625,29 +625,24 @@ class GPTNeoXForSequenceClassification(GPTNeoXPreTrainedModel):
         hidden_states = outputs[0]
         logits = self.score(hidden_states)

-        if input_ids is not None:
-            batch_size, sequence_length = input_ids.shape[:2]
-        else:
-            batch_size, sequence_length = inputs_embeds.shape[:2]
-
+        batch_size = logits.shape[0]
         if self.config.pad_token_id is None and batch_size != 1:
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
-            sequence_lengths = -1
-        else:
-            if input_ids is not None:
-                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
-                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
-                sequence_lengths = sequence_lengths % input_ids.shape[-1]
-                sequence_lengths = sequence_lengths.to(logits.device)
-            else:
-                sequence_lengths = -1
-                logger.warning_once(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
+            last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
+        else:
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

-        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

         loss = None
         if labels is not None:
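The second hunk swaps the old first-pad-token arithmetic for a rightmost-non-pad-token lookup, which handles left- as well as right-padded batches. Below is a standalone toy check of that expression and of the final gather into `pooled_logits`; the `pad_token_id` value and tensor sizes are made up for illustration and do not come from the model config.

```python
# Toy check of the new pooling logic; pad_token_id=0 and the tiny tensors are
# assumptions for this sketch, not values taken from the model config.
import torch

pad_token_id = 0
input_ids = torch.tensor(
    [
        [5, 6, 7, 0, 0],  # right-padded: last non-pad token at index 2
        [0, 0, 5, 6, 7],  # left-padded:  last non-pad token at index 4
    ]
)
logits = torch.randn(2, 5, 3)  # stand-in for (batch, seq_len, num_labels) scores

# Rightmost non-pad position per row: zero out the indices at pad positions, then argmax.
non_pad_mask = (input_ids != pad_token_id).to(torch.int32)
token_indices = torch.arange(input_ids.shape[-1])
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
print(last_non_pad_token)  # tensor([2, 4])

# Gather one logit vector per sequence, mirroring the diff's pooled_logits line.
batch_size = logits.shape[0]
pooled_logits = logits[torch.arange(batch_size), last_non_pad_token]
print(pooled_logits.shape)  # torch.Size([2, 3])
```

The retained `inputs_embeds` branch keeps the old fallback behaviour: with no `input_ids` to inspect, index -1 (the last position) is used and a warning is emitted.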