Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)

Commit f7a1f0da3d: "some fixes, loss_kwargs should never have been"
Parent: 0b119ffb1f
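The hunks that follow all make the same mechanical change: model heads accept a generic `**kwargs` instead of a dedicated `**loss_kwargs` and forward it into `self.loss_function(...)`. The sketch below illustrates that forwarding pattern in isolation; it is not the transformers implementation, and the toy `TinyCausalLM`, its sizes, and the `causal_lm_loss` helper are assumptions made for this example.

# Minimal sketch of the kwargs-forwarding pattern (illustrative only, not transformers code).
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


def causal_lm_loss(logits, labels, vocab_size, num_items_in_batch: Optional[int] = None, **kwargs):
    # Shift so that tokens < n predict token n; ignore_index=-100 masks out padded labels.
    shift_logits = logits[:, :-1, :].reshape(-1, vocab_size)
    shift_labels = labels[:, 1:].reshape(-1)
    reduction = "mean" if num_items_in_batch is None else "sum"
    loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction=reduction)
    if num_items_in_batch is not None:
        # Normalize by the token count of the full (possibly accumulated) batch.
        loss = loss / num_items_in_batch
    return loss


class TinyCausalLM(nn.Module):
    def __init__(self, vocab_size=32, hidden=16):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed = nn.Embedding(vocab_size, hidden)
        self.lm_head = nn.Linear(hidden, vocab_size)
        self.loss_function = causal_lm_loss

    def forward(self, input_ids, labels=None, **kwargs):
        # **kwargs (previously **loss_kwargs) flows straight into the loss function,
        # so callers can pass extras such as num_items_in_batch without the model naming them.
        logits = self.lm_head(self.embed(input_ids))
        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
        return loss, logits


if __name__ == "__main__":
    model = TinyCausalLM()
    ids = torch.randint(0, 32, (2, 8))
    loss, _ = model(ids, labels=ids, num_items_in_batch=2 * 7)  # extra kwarg reaches the loss
    print(loss.item())

The point of the pattern is that callers, such as the Trainer hunk at the end of this diff, can thread loss-related values like `num_items_in_batch` through the forward call unchanged.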
@@ -1170,7 +1170,7 @@ class BertLMHeadModel(BertPreTrainedModel, GenerationMixin):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        **loss_kwargs,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1203,7 +1203,7 @@ class BertLMHeadModel(BertPreTrainedModel, GenerationMixin):

         lm_loss = None
         if labels is not None:
-            lm_loss = self.loss_function(prediction_scores, labels, self.config.vocab_size, **loss_kwargs)
+            lm_loss = self.loss_function(prediction_scores, labels, self.config.vocab_size, **kwargs)

         if not return_dict:
             output = (prediction_scores,) + outputs[2:]

@@ -1672,7 +1672,7 @@ class DFineForObjectDetection(DFinePreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        **loss_kwargs,
+        **kwargs,
     ) -> Union[tuple[torch.FloatTensor], DFineObjectDetectionOutput]:
         r"""
         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -1777,7 +1777,7 @@ class DFineForObjectDetection(DFinePreTrainedModel):
                 denoising_meta_values=denoising_meta_values,
                 predicted_corners=predicted_corners,
                 initial_reference_points=initial_reference_points,
-                **loss_kwargs,
+                **kwargs,
             )

         if not return_dict:

@@ -549,7 +549,7 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
         output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **loss_kwargs,
+        **kwargs,
     ) -> CausalLMOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -594,7 +594,7 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             cache_position=cache_position,
-            **loss_kwargs,
+            **kwargs,
         )

         hidden_states = outputs.last_hidden_state
@@ -608,7 +608,7 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):

         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

         return CausalLMOutputWithPast(
             loss=loss,

@@ -498,7 +498,7 @@ class Gemma2ForCausalLM(GemmaForCausalLM):
         output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **loss_kwargs,
+        **kwargs,
     ) -> CausalLMOutputWithPast:
         r"""
         Example:
@@ -538,7 +538,7 @@ class Gemma2ForCausalLM(GemmaForCausalLM):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             cache_position=cache_position,
-            **loss_kwargs,
+            **kwargs,
         )

         hidden_states = outputs.last_hidden_state
@@ -552,7 +552,7 @@ class Gemma2ForCausalLM(GemmaForCausalLM):

         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

         return CausalLMOutputWithPast(
             loss=loss,

@@ -654,7 +654,7 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
         output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **loss_kwargs,
+        **kwargs,
     ) -> CausalLMOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -699,7 +699,7 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             cache_position=cache_position,
-            **loss_kwargs,
+            **kwargs,
         )

         hidden_states = outputs.last_hidden_state
@@ -713,7 +713,7 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):

         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

         return CausalLMOutputWithPast(
             loss=loss,

@@ -1826,7 +1826,7 @@ class Gemma3nForCausalLM(Gemma3nPreTrainedModel, GenerationMixin):
         output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **loss_kwargs,
+        **kwargs,
     ) -> CausalLMOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1871,7 +1871,7 @@ class Gemma3nForCausalLM(Gemma3nPreTrainedModel, GenerationMixin):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             cache_position=cache_position,
-            **loss_kwargs,
+            **kwargs,
         )

         hidden_states = outputs.last_hidden_state
@@ -1885,7 +1885,7 @@ class Gemma3nForCausalLM(Gemma3nPreTrainedModel, GenerationMixin):

         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

         return CausalLMOutputWithPast(
             loss=loss,

@@ -1386,7 +1386,7 @@ class JambaForCausalLM(JambaPreTrainedModel, GenerationMixin):

         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

         aux_loss = None
         if output_router_logits:

@@ -912,7 +912,7 @@ class NemotronForCausalLM(NemotronPreTrainedModel, GenerationMixin):
         output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **loss_kwargs,
+        **kwargs,
     ) -> CausalLMOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -961,7 +961,7 @@ class NemotronForCausalLM(NemotronPreTrainedModel, GenerationMixin):

         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

         return CausalLMOutputWithPast(
             loss=loss,

@@ -1043,7 +1043,7 @@ class OlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin):
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **loss_kwargs,
+        **kwargs,
     ) -> Union[tuple, MoeCausalLMOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1099,7 +1099,7 @@ class OlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin):

         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

         aux_loss = None
         if output_router_logits:

@@ -1272,7 +1272,7 @@ class PhimoeForCausalLM(PhimoePreTrainedModel, GenerationMixin):
         output_router_logits: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **loss_kwargs,
+        **kwargs,
     ) -> MoeCausalLMOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1331,7 +1331,7 @@ class PhimoeForCausalLM(PhimoePreTrainedModel, GenerationMixin):

         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

         aux_loss = None
         if output_router_logits:

@@ -1118,7 +1118,7 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel, GenerationMixin):
         output_router_logits: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **loss_kwargs,
+        **kwargs,
     ) -> MoeCausalLMOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1172,7 +1172,7 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel, GenerationMixin):

         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

         aux_loss = None
         if output_router_logits:

@@ -1860,7 +1860,7 @@ class RTDetrForObjectDetection(RTDetrPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        **loss_kwargs,
+        **kwargs,
     ) -> Union[tuple[torch.FloatTensor], RTDetrObjectDetectionOutput]:
         r"""
         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -1968,7 +1968,7 @@ class RTDetrForObjectDetection(RTDetrPreTrainedModel):
                 denoising_meta_values=denoising_meta_values,
                 predicted_corners=predicted_corners,
                 initial_reference_points=initial_reference_points,
-                **loss_kwargs,
+                **kwargs,
             )

         if not return_dict:

@@ -1853,7 +1853,7 @@ class RTDetrV2ForObjectDetection(RTDetrV2PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        **loss_kwargs,
+        **kwargs,
     ) -> Union[tuple[torch.FloatTensor], RTDetrV2ObjectDetectionOutput]:
         r"""
         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -1961,7 +1961,7 @@ class RTDetrV2ForObjectDetection(RTDetrV2PreTrainedModel):
                 denoising_meta_values=denoising_meta_values,
                 predicted_corners=predicted_corners,
                 initial_reference_points=initial_reference_points,
-                **loss_kwargs,
+                **kwargs,
             )

         if not return_dict:

@@ -402,7 +402,6 @@ class T5GemmaEncoderLayer(GradientCheckpointingLayer):
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            use_cache=False,
            past_key_value=None,
            **kwargs,
        )
@@ -1062,7 +1061,7 @@ class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin):
         use_cache: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **loss_kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutput]:
         r"""
         decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
@@ -1100,7 +1099,7 @@ class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin):
             decoder_inputs_embeds=decoder_inputs_embeds,
             use_cache=use_cache,
             cache_position=cache_position,
-            **loss_kwargs,
+            **kwargs,
         )

         hidden_states = decoder_outputs.last_hidden_state
@@ -1116,7 +1115,7 @@ class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin):
         loss = None
         if labels is not None:
             # Input has right-shifted so we directly perform masked lm loss
-            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

         return Seq2SeqLMOutput(
             loss=loss,
@@ -1179,6 +1178,7 @@ class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> SequenceClassifierOutput:
         r"""
         decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
@@ -1216,6 +1216,7 @@ class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
                 inputs_embeds=inputs_embeds,
                 decoder_inputs_embeds=decoder_inputs_embeds,
                 use_cache=False,
+                **kwargs,
             )
             last_hidden_state = outputs.last_hidden_state
             hidden_states = outputs.decoder_hidden_states
@@ -1226,6 +1227,7 @@ class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
                 attention_mask=attention_mask,
                 position_ids=position_ids,
                 inputs_embeds=inputs_embeds,
+                **kwargs,
             )
             last_hidden_state = outputs.last_hidden_state
             hidden_states = outputs.hidden_states
@@ -1318,6 +1320,7 @@ class T5GemmaForTokenClassification(T5GemmaPreTrainedModel):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> TokenClassifierOutput:
         r"""
         decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
@@ -1355,6 +1358,7 @@ class T5GemmaForTokenClassification(T5GemmaPreTrainedModel):
                 inputs_embeds=inputs_embeds,
                 decoder_inputs_embeds=decoder_inputs_embeds,
                 use_cache=False,
+                **kwargs,
             )
             last_hidden_state = outputs.last_hidden_state
             hidden_states = outputs.decoder_hidden_states
@@ -1365,6 +1369,7 @@ class T5GemmaForTokenClassification(T5GemmaPreTrainedModel):
                 attention_mask=attention_mask,
                 position_ids=position_ids,
                 inputs_embeds=inputs_embeds,
+                **kwargs,
             )
             last_hidden_state = outputs.last_hidden_state
             hidden_states = outputs.hidden_states

@@ -374,7 +374,6 @@ class T5GemmaEncoderLayer(GradientCheckpointingLayer):
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            use_cache=False,
            past_key_value=None,
            **kwargs,
        )
@@ -924,7 +923,7 @@ class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin):
         use_cache: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **loss_kwargs,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutput]:
         r"""
         decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
@@ -962,7 +961,7 @@ class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin):
             decoder_inputs_embeds=decoder_inputs_embeds,
             use_cache=use_cache,
             cache_position=cache_position,
-            **loss_kwargs,
+            **kwargs,
         )

         hidden_states = decoder_outputs.last_hidden_state
@@ -978,7 +977,7 @@ class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin):
         loss = None
         if labels is not None:
             # Input has right-shifted so we directly perform masked lm loss
-            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

         return Seq2SeqLMOutput(
             loss=loss,
@@ -1041,6 +1040,7 @@ class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> SequenceClassifierOutput:
         r"""
         decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
@@ -1078,6 +1078,7 @@ class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
                 inputs_embeds=inputs_embeds,
                 decoder_inputs_embeds=decoder_inputs_embeds,
                 use_cache=False,
+                **kwargs,
             )
             last_hidden_state = outputs.last_hidden_state
             hidden_states = outputs.decoder_hidden_states
@@ -1088,6 +1089,7 @@ class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
                 attention_mask=attention_mask,
                 position_ids=position_ids,
                 inputs_embeds=inputs_embeds,
+                **kwargs,
             )
             last_hidden_state = outputs.last_hidden_state
             hidden_states = outputs.hidden_states
@@ -1180,6 +1182,7 @@ class T5GemmaForTokenClassification(T5GemmaPreTrainedModel):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> TokenClassifierOutput:
         r"""
         decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
@@ -1217,6 +1220,7 @@ class T5GemmaForTokenClassification(T5GemmaPreTrainedModel):
                 inputs_embeds=inputs_embeds,
                 decoder_inputs_embeds=decoder_inputs_embeds,
                 use_cache=False,
+                **kwargs,
             )
             last_hidden_state = outputs.last_hidden_state
             hidden_states = outputs.decoder_hidden_states
@@ -1227,6 +1231,7 @@ class T5GemmaForTokenClassification(T5GemmaPreTrainedModel):
                 attention_mask=attention_mask,
                 position_ids=position_ids,
                 inputs_embeds=inputs_embeds,
+                **kwargs,
             )
             last_hidden_state = outputs.last_hidden_state
             hidden_states = outputs.hidden_states

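In the T5Gemma hunks above, the classification heads gain `**kwargs: Unpack[TransformersKwargs]`, i.e. the catch-all kwargs are typed against a TypedDict instead of being a bare `**loss_kwargs`. The snippet below sketches that typing pattern in general terms; `ExtraKwargs` and its fields are invented for illustration and are not the real `TransformersKwargs` definition.

# Sketch of typing **kwargs with Unpack + TypedDict (illustrative field names only).
from typing import Optional

from typing_extensions import TypedDict, Unpack


class ExtraKwargs(TypedDict, total=False):
    # Hypothetical fields; the actual TransformersKwargs differs.
    num_items_in_batch: int
    output_attentions: bool


def forward(input_ids: list[int], labels: Optional[list[int]] = None, **kwargs: Unpack[ExtraKwargs]) -> dict:
    # Static checkers can verify the names and types of the extra keyword arguments,
    # while at runtime kwargs is still an ordinary dict that can be forwarded wholesale.
    return {"n_tokens": len(input_ids), **kwargs}


print(forward([1, 2, 3], num_items_in_batch=3))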
@@ -1102,7 +1102,7 @@ class ZambaForCausalLM(ZambaPreTrainedModel, GenerationMixin):
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **loss_kwargs,
+        **kwargs,
     ) -> Union[tuple, CausalLMOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1155,7 +1155,7 @@ class ZambaForCausalLM(ZambaPreTrainedModel, GenerationMixin):

         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

         if not return_dict:
             output = (logits,) + outputs[1:]

@@ -1527,7 +1527,7 @@ class Zamba2ForCausalLM(Zamba2PreTrainedModel, GenerationMixin):
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **loss_kwargs,
+        **kwargs,
     ) -> Union[tuple, CausalLMOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1580,7 +1580,7 @@ class Zamba2ForCausalLM(Zamba2PreTrainedModel, GenerationMixin):

         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

         if not return_dict:
             output = (logits,) + outputs[1:]

@@ -3832,7 +3832,7 @@ class Trainer:
            loss_kwargs = {}
            if num_items_in_batch is not None:
                loss_kwargs["num_items_in_batch"] = num_items_in_batch
-            inputs = {**inputs, **loss_kwargs}
+            inputs = {**inputs, **kwargs}
        outputs = model(**inputs)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.

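The Trainer hunk above is where the extra keyword originates: when the model accepts loss kwargs, `num_items_in_batch` is merged into `inputs`, travels through the model's now-generic `**kwargs`, and lands in the loss function. The toy calculation below (made-up numbers, not Trainer code) shows why that count matters under gradient accumulation.

# Why num_items_in_batch is worth threading through: with gradient accumulation,
# summing per-token losses and dividing by the *global* token count reproduces the
# loss of one big batch, whereas averaging each micro-batch separately does not
# when micro-batches contain different numbers of (non-padding) tokens.
import torch

torch.manual_seed(0)
per_token_losses = [torch.rand(5), torch.rand(3)]  # two micro-batches: 5 tokens, then 3

full_batch_mean = torch.cat(per_token_losses).mean()

num_items_in_batch = sum(t.numel() for t in per_token_losses)  # 8 tokens total
accumulated = sum(t.sum() / num_items_in_batch for t in per_token_losses)

naive = sum(t.mean() for t in per_token_losses) / len(per_token_losses)  # average of averages

print(full_batch_mean.item(), accumulated.item(), naive.item())
# full_batch_mean == accumulated; naive generally differs.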