some fixes, loss_kwargs should never have been

Arthur 2025-07-01 15:19:32 +02:00
parent 0b119ffb1f
commit f7a1f0da3d
18 changed files with 52 additions and 42 deletions
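
Editor's note: the recurring change below is a rename of the catch-all keyword parameter in forward() from **loss_kwargs to **kwargs. As the Gemma2/Gemma3/T5Gemma hunks show, the same dictionary is forwarded to the backbone call as well as to self.loss_function, so the old name wrongly implied the arguments were loss-only. A minimal sketch of that flow, using hypothetical ToyBackbone/ToyForCausalLM stand-ins rather than the actual transformers classes:

# Illustrative sketch only -- ToyBackbone / ToyForCausalLM are hypothetical
# stand-ins, not the transformers classes touched by this commit.
class ToyBackbone:
    def __call__(self, input_ids, **kwargs):
        # backbone-level keyword arguments land here too, which is why the
        # catch-all is not loss-specific
        print("backbone received:", sorted(kwargs))
        return [float(t) for t in input_ids]  # stand-in for hidden states


class ToyForCausalLM:
    vocab_size = 8

    def __init__(self):
        self.model = ToyBackbone()

    def loss_function(self, logits, labels, vocab_size, num_items_in_batch=None, **kwargs):
        # num_items_in_batch is the keyword the Trainer injects (see the trainer.py hunk at the end)
        total = sum(abs(p - y) for p, y in zip(logits, labels))
        return total / (num_items_in_batch or len(labels))

    def forward(self, input_ids, labels=None, **kwargs):
        hidden_states = self.model(input_ids, **kwargs)  # kwargs reach the backbone...
        logits = hidden_states  # (lm_head projection omitted)
        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)  # ...and the loss
        return loss, logits


model = ToyForCausalLM()
loss, _ = model.forward([1, 2, 3], labels=[1, 2, 4], num_items_in_batch=3)
print("loss:", loss)
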

View File

@@ -1170,7 +1170,7 @@ class BertLMHeadModel(BertPreTrainedModel, GenerationMixin):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- **loss_kwargs,
+ **kwargs,
) -> Union[tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1203,7 +1203,7 @@ class BertLMHeadModel(BertPreTrainedModel, GenerationMixin):
lm_loss = None
if labels is not None:
- lm_loss = self.loss_function(prediction_scores, labels, self.config.vocab_size, **loss_kwargs)
+ lm_loss = self.loss_function(prediction_scores, labels, self.config.vocab_size, **kwargs)
if not return_dict:
output = (prediction_scores,) + outputs[2:]

View File

@@ -1672,7 +1672,7 @@ class DFineForObjectDetection(DFinePreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- **loss_kwargs,
+ **kwargs,
) -> Union[tuple[torch.FloatTensor], DFineObjectDetectionOutput]:
r"""
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -1777,7 +1777,7 @@ class DFineForObjectDetection(DFinePreTrainedModel):
denoising_meta_values=denoising_meta_values,
predicted_corners=predicted_corners,
initial_reference_points=initial_reference_points,
- **loss_kwargs,
+ **kwargs,
)
if not return_dict:

View File

@@ -549,7 +549,7 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
- **loss_kwargs,
+ **kwargs,
) -> CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -594,7 +594,7 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
- **loss_kwargs,
+ **kwargs,
)
hidden_states = outputs.last_hidden_state
@@ -608,7 +608,7 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
- loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+ loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
return CausalLMOutputWithPast(
loss=loss,

View File

@@ -498,7 +498,7 @@ class Gemma2ForCausalLM(GemmaForCausalLM):
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
- **loss_kwargs,
+ **kwargs,
) -> CausalLMOutputWithPast:
r"""
Example:
@@ -538,7 +538,7 @@ class Gemma2ForCausalLM(GemmaForCausalLM):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
- **loss_kwargs,
+ **kwargs,
)
hidden_states = outputs.last_hidden_state
@@ -552,7 +552,7 @@ class Gemma2ForCausalLM(GemmaForCausalLM):
loss = None
if labels is not None:
- loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+ loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
return CausalLMOutputWithPast(
loss=loss,

View File

@@ -654,7 +654,7 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
- **loss_kwargs,
+ **kwargs,
) -> CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -699,7 +699,7 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
- **loss_kwargs,
+ **kwargs,
)
hidden_states = outputs.last_hidden_state
@@ -713,7 +713,7 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
- loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+ loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
return CausalLMOutputWithPast(
loss=loss,

View File

@@ -1826,7 +1826,7 @@ class Gemma3nForCausalLM(Gemma3nPreTrainedModel, GenerationMixin):
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
- **loss_kwargs,
+ **kwargs,
) -> CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1871,7 +1871,7 @@ class Gemma3nForCausalLM(Gemma3nPreTrainedModel, GenerationMixin):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
- **loss_kwargs,
+ **kwargs,
)
hidden_states = outputs.last_hidden_state
@@ -1885,7 +1885,7 @@ class Gemma3nForCausalLM(Gemma3nPreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
- loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+ loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
return CausalLMOutputWithPast(
loss=loss,

View File

@@ -1386,7 +1386,7 @@ class JambaForCausalLM(JambaPreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
- loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+ loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
aux_loss = None
if output_router_logits:

View File

@@ -912,7 +912,7 @@ class NemotronForCausalLM(NemotronPreTrainedModel, GenerationMixin):
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
- **loss_kwargs,
+ **kwargs,
) -> CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -961,7 +961,7 @@ class NemotronForCausalLM(NemotronPreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
- loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+ loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
return CausalLMOutputWithPast(
loss=loss,

View File

@@ -1043,7 +1043,7 @@ class OlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
- **loss_kwargs,
+ **kwargs,
) -> Union[tuple, MoeCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1099,7 +1099,7 @@ class OlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
- loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+ loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
aux_loss = None
if output_router_logits:

View File

@@ -1272,7 +1272,7 @@ class PhimoeForCausalLM(PhimoePreTrainedModel, GenerationMixin):
output_router_logits: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
- **loss_kwargs,
+ **kwargs,
) -> MoeCausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1331,7 +1331,7 @@ class PhimoeForCausalLM(PhimoePreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
- loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+ loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
aux_loss = None
if output_router_logits:

View File

@@ -1118,7 +1118,7 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel, GenerationMixin):
output_router_logits: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
- **loss_kwargs,
+ **kwargs,
) -> MoeCausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1172,7 +1172,7 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
- loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+ loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
aux_loss = None
if output_router_logits:

View File

@@ -1860,7 +1860,7 @@ class RTDetrForObjectDetection(RTDetrPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- **loss_kwargs,
+ **kwargs,
) -> Union[tuple[torch.FloatTensor], RTDetrObjectDetectionOutput]:
r"""
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -1968,7 +1968,7 @@ class RTDetrForObjectDetection(RTDetrPreTrainedModel):
denoising_meta_values=denoising_meta_values,
predicted_corners=predicted_corners,
initial_reference_points=initial_reference_points,
- **loss_kwargs,
+ **kwargs,
)
if not return_dict:

View File

@@ -1853,7 +1853,7 @@ class RTDetrV2ForObjectDetection(RTDetrV2PreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- **loss_kwargs,
+ **kwargs,
) -> Union[tuple[torch.FloatTensor], RTDetrV2ObjectDetectionOutput]:
r"""
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -1961,7 +1961,7 @@ class RTDetrV2ForObjectDetection(RTDetrV2PreTrainedModel):
denoising_meta_values=denoising_meta_values,
predicted_corners=predicted_corners,
initial_reference_points=initial_reference_points,
- **loss_kwargs,
+ **kwargs,
)
if not return_dict:

View File

@@ -402,7 +402,6 @@ class T5GemmaEncoderLayer(GradientCheckpointingLayer):
position_embeddings=position_embeddings,
attention_mask=attention_mask,
position_ids=position_ids,
- use_cache=False,
past_key_value=None,
**kwargs,
)
@@ -1062,7 +1061,7 @@ class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin):
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
- **loss_kwargs,
+ **kwargs: Unpack[TransformersKwargs],
) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutput]:
r"""
decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
@@ -1100,7 +1099,7 @@ class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin):
decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
cache_position=cache_position,
- **loss_kwargs,
+ **kwargs,
)
hidden_states = decoder_outputs.last_hidden_state
@@ -1116,7 +1115,7 @@ class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
# Input has right-shifted so we directly perform masked lm loss
- loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+ loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
return Seq2SeqLMOutput(
loss=loss,
@@ -1179,6 +1178,7 @@ class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
) -> SequenceClassifierOutput:
r"""
decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
@@ -1216,6 +1216,7 @@ class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=False,
+ **kwargs,
)
last_hidden_state = outputs.last_hidden_state
hidden_states = outputs.decoder_hidden_states
@@ -1226,6 +1227,7 @@ class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
attention_mask=attention_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
+ **kwargs,
)
last_hidden_state = outputs.last_hidden_state
hidden_states = outputs.hidden_states
@@ -1318,6 +1320,7 @@ class T5GemmaForTokenClassification(T5GemmaPreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
) -> TokenClassifierOutput:
r"""
decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
@@ -1355,6 +1358,7 @@ class T5GemmaForTokenClassification(T5GemmaPreTrainedModel):
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=False,
+ **kwargs,
)
last_hidden_state = outputs.last_hidden_state
hidden_states = outputs.decoder_hidden_states
@@ -1365,6 +1369,7 @@ class T5GemmaForTokenClassification(T5GemmaPreTrainedModel):
attention_mask=attention_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
+ **kwargs,
)
last_hidden_state = outputs.last_hidden_state
hidden_states = outputs.hidden_states
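
Editor's note: in the two T5Gemma files the rename also tightens the annotation to **kwargs: Unpack[TransformersKwargs] and threads **kwargs through the sequence- and token-classification heads into the underlying model calls. The sketch below shows the Unpack[TypedDict] pattern in isolation; AttnKwargs is a hypothetical stand-in for transformers' TransformersKwargs:

# Sketch of the Unpack[TypedDict] pattern used by the new signatures;
# AttnKwargs is a hypothetical stand-in, not the real TransformersKwargs.
from typing_extensions import TypedDict, Unpack


class AttnKwargs(TypedDict, total=False):
    output_attentions: bool
    num_items_in_batch: int


def forward(input_ids, labels=None, **kwargs: Unpack[AttnKwargs]):
    # Static type checkers can validate keyword names and types at call sites,
    # while at runtime kwargs is still an ordinary dict that gets forwarded to
    # sub-modules (encoder/decoder, loss function) unchanged.
    print(sorted(kwargs))


forward([1, 2, 3], output_attentions=True, num_items_in_batch=4)
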

View File

@@ -374,7 +374,6 @@ class T5GemmaEncoderLayer(GradientCheckpointingLayer):
position_embeddings=position_embeddings,
attention_mask=attention_mask,
position_ids=position_ids,
- use_cache=False,
past_key_value=None,
**kwargs,
)
@@ -924,7 +923,7 @@ class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin):
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
- **loss_kwargs,
+ **kwargs: Unpack[TransformersKwargs],
) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutput]:
r"""
decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
@@ -962,7 +961,7 @@ class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin):
decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
cache_position=cache_position,
- **loss_kwargs,
+ **kwargs,
)
hidden_states = decoder_outputs.last_hidden_state
@@ -978,7 +977,7 @@ class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
# Input has right-shifted so we directly perform masked lm loss
- loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+ loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
return Seq2SeqLMOutput(
loss=loss,
@@ -1041,6 +1040,7 @@ class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
) -> SequenceClassifierOutput:
r"""
decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
@@ -1078,6 +1078,7 @@ class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=False,
+ **kwargs,
)
last_hidden_state = outputs.last_hidden_state
hidden_states = outputs.decoder_hidden_states
@@ -1088,6 +1089,7 @@ class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
attention_mask=attention_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
+ **kwargs,
)
last_hidden_state = outputs.last_hidden_state
hidden_states = outputs.hidden_states
@@ -1180,6 +1182,7 @@ class T5GemmaForTokenClassification(T5GemmaPreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
) -> TokenClassifierOutput:
r"""
decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
@@ -1217,6 +1220,7 @@ class T5GemmaForTokenClassification(T5GemmaPreTrainedModel):
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=False,
+ **kwargs,
)
last_hidden_state = outputs.last_hidden_state
hidden_states = outputs.decoder_hidden_states
@@ -1227,6 +1231,7 @@ class T5GemmaForTokenClassification(T5GemmaPreTrainedModel):
attention_mask=attention_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
+ **kwargs,
)
last_hidden_state = outputs.last_hidden_state
hidden_states = outputs.hidden_states

View File

@@ -1102,7 +1102,7 @@ class ZambaForCausalLM(ZambaPreTrainedModel, GenerationMixin):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
- **loss_kwargs,
+ **kwargs,
) -> Union[tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1155,7 +1155,7 @@ class ZambaForCausalLM(ZambaPreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
- loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+ loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
if not return_dict:
output = (logits,) + outputs[1:]

View File

@@ -1527,7 +1527,7 @@ class Zamba2ForCausalLM(Zamba2PreTrainedModel, GenerationMixin):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
- **loss_kwargs,
+ **kwargs,
) -> Union[tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1580,7 +1580,7 @@ class Zamba2ForCausalLM(Zamba2PreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
- loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+ loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
if not return_dict:
output = (logits,) + outputs[1:]

View File

@@ -3832,7 +3832,7 @@ class Trainer:
loss_kwargs = {}
if num_items_in_batch is not None:
loss_kwargs["num_items_in_batch"] = num_items_in_batch
- inputs = {**inputs, **loss_kwargs}
+ inputs = {**inputs, **kwargs}
outputs = model(**inputs)
# Save past state if it exists
# TODO: this needs to be fixed and made cleaner later.
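
Editor's note: the trainer.py hunk is the producer side of the num_items_in_batch keyword: it is packed into a dict and merged into inputs before the model call, so it arrives in forward()'s renamed **kwargs and is forwarded on to self.loss_function. A loss that understands it can sum the per-token losses and normalize by that global count instead of taking a per-batch mean, which keeps loss scaling consistent under gradient accumulation. A simplified sketch of such a loss (hypothetical, not the actual transformers loss implementation):

# Simplified sketch of a causal-LM loss that consumes the forwarded kwarg.
import torch
import torch.nn.functional as F


def causal_lm_loss_sketch(logits, labels, vocab_size, num_items_in_batch=None, ignore_index=-100, **kwargs):
    # shift so that tokens < n predict token n
    shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    shift_labels = labels[..., 1:].contiguous().view(-1)
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=ignore_index, reduction=reduction)
    if num_items_in_batch is not None:
        # normalize by the label-token count of the whole accumulated batch
        loss = loss / num_items_in_batch
    return loss


logits = torch.randn(2, 5, 11)
labels = torch.randint(0, 11, (2, 5))
print(causal_lm_loss_sketch(logits, labels, vocab_size=11, num_items_in_batch=8))
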