fix tests

This commit is contained in:
Kashif Rasul 2025-02-14 20:43:45 +01:00 committed by Jinan Zhou
parent ef59621e80
commit d8c2e0d74f
2 changed files with 6 additions and 5 deletions

View File

@ -314,7 +314,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [SigLIP2](https://huggingface.co/docs/transformers/model_doc/siglip2)
* [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel)
* [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model)
* [TimesFM](https://huggingface.co/docs/transformers/model_doc/timesfm#transformers.TimesFmModel)
* [TimesFM](https://huggingface.co/docs/transformers/model_doc/timesfm#transformers.TimesFmDecoder)
* [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel)
* [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel)
* [RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaModel)

View File

@ -1047,8 +1047,8 @@ class TimesFmModelForPrediction(TimesFmPreTrainedModel):
if return_dict:
return TimesFmOutputForPrediction(
last_hidden_state=decoder_output.last_hidden_state,
attentions=decoder_output.all_attentions if output_attentions else None,
hidden_states=decoder_output.all_hidden_states if output_hidden_states else None,
attentions=decoder_output.attentions if output_attentions else None,
hidden_states=decoder_output.hidden_states if output_hidden_states else None,
mean_predictions=mean_outputs,
full_predictions=full_outputs,
loss=loss,
@ -1056,9 +1056,9 @@ class TimesFmModelForPrediction(TimesFmPreTrainedModel):
else:
return_tuple = [decoder_output.last_hidden_state]
if output_hidden_states:
return_tuple.append(decoder_output.all_hidden_states)
return_tuple.append(decoder_output.hidden_states)
if output_attentions:
return_tuple.append(decoder_output.all_attentions)
return_tuple.append(decoder_output.attentions)
return_tuple += [mean_outputs, full_outputs, loss]
return tuple(return_tuple)
@ -1066,4 +1066,5 @@ class TimesFmModelForPrediction(TimesFmPreTrainedModel):
__all__ = [
"TimesFmModelForPrediction",
"TimesFmPreTrainedModel",
"TimesFmDecoder",
]