From 9f0f086a345901b9f126b0729b4f300b7435c0e4 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 5 Dec 2024 10:04:14 +0100 Subject: [PATCH] fix docstrings --- .../models/timesfm/modeling_timesfm.py | 47 ++++++++++++------- 1 file changed, 30 insertions(+), 17 deletions(-) diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 7b664ffd052..74779f6a158 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -185,7 +185,7 @@ class TimesFMPositionalEmbedding(nn.Module): class TimesFMAttention(nn.Module): - """Implements the attention used in TimesFM.""" + """Implements the attention used in TimesFM. One key diffrence is that there is _per_dim_scaling of the query.""" def __init__(self, config: TimesFMConfig): super().__init__() @@ -655,7 +655,8 @@ class TimesFMDecoder(TimesFMPreTrainedModel): freq: torch.Tensor, output_attentions: bool = False, output_hidden_states: bool = False, - ) -> TimesFMDecoderOutput: + return_dict: bool = True, + ) -> TimesFMDecoderOutput | tuple[torch.Tensor, ...]: model_input, patched_padding, stats, _ = self._preprocess_input( input_ts=input_ts, input_padding=input_padding, @@ -674,13 +675,22 @@ class TimesFMDecoder(TimesFMPreTrainedModel): else: all_hidden_states = None - return TimesFMDecoderOutput( - last_hidden_state=transformer_output.last_hidden_state, - hidden_states=all_hidden_states, - attentions=transformer_output.attentions if output_attentions else None, - loc=stats[0], - scale=stats[1], - ) + if return_dict: + return TimesFMDecoderOutput( + last_hidden_state=transformer_output.last_hidden_state, + hidden_states=all_hidden_states, + attentions=transformer_output.attentions if output_attentions else None, + loc=stats[0], + scale=stats[1], + ) + else: + return ( + transformer_output.last_hidden_state, + all_hidden_states, + transformer_output.attentions, + stats[0], + stats[1], + ) class TimesFMModelForPrediction(TimesFMPreTrainedModel): @@ -778,7 +788,7 @@ class TimesFMModelForPrediction(TimesFMPreTrainedModel): return_forecast_on_context: bool = False, output_attentions: bool = False, output_hidden_states: bool = False, - ): + ) -> tuple[torch.Tensor, ...]: """Auto-regressive decoding without caching. Args: @@ -799,6 +809,9 @@ class TimesFMModelForPrediction(TimesFMPreTrainedModel): B x H' x (1 + # quantiles). In particular, if return_forecast_on_context is True, H' is H plus the forecastable context length, i.e. context_len - (first) patch_len. + + Raises: + ValueError: If the paddings do not match the input + horizon_len. """ final_out = input_ts context_len = final_out.shape[1] @@ -871,7 +884,7 @@ class TimesFMModelForPrediction(TimesFMPreTrainedModel): output_attentions: bool | None = None, output_hidden_states: bool | None = None, return_dict: bool | None = None, - ) -> TimesFMOutputForPrediction: + ) -> TimesFMOutputForPrediction | tuple[torch.Tensor, ...]: """Forecasts on a list of time series. Args: @@ -887,15 +900,15 @@ class TimesFMModelForPrediction(TimesFMPreTrainedModel): when available, i.e. after the first input patch. truncate_negative: truncate to only non-negative values if all the contexts have non-negative values. + output_attentions: Whether to return the attentions. + output_hidden_states: Whether to return the hidden states. + return_dict: Whether to return a TimesFMOutputForPrediction object. Returns: - A tuple for Tensors: - - the mean forecast of size (# inputs, # forecast horizon), - - the full forecast (mean + quantiles) of size + A TimesFMOutputForPrediction object containing: + - the mean forecast of size (# inputs, # forecast horizon), + - the full forecast (mean + quantiles) of size (# inputs, # forecast horizon, 1 + # quantiles). - - Raises: - ValueError: If the checkpoint is not properly loaded. """ if return_dict is None: return_dict = self.config.use_return_dict