fix docstrings

This commit is contained in:
Kashif Rasul 2024-12-05 10:04:14 +01:00 committed by Jinan Zhou
parent b437e87b88
commit 9f0f086a34

View File

@ -185,7 +185,7 @@ class TimesFMPositionalEmbedding(nn.Module):
class TimesFMAttention(nn.Module): class TimesFMAttention(nn.Module):
"""Implements the attention used in TimesFM.""" """Implements the attention used in TimesFM. One key diffrence is that there is _per_dim_scaling of the query."""
def __init__(self, config: TimesFMConfig): def __init__(self, config: TimesFMConfig):
super().__init__() super().__init__()
@ -655,7 +655,8 @@ class TimesFMDecoder(TimesFMPreTrainedModel):
freq: torch.Tensor, freq: torch.Tensor,
output_attentions: bool = False, output_attentions: bool = False,
output_hidden_states: bool = False, output_hidden_states: bool = False,
) -> TimesFMDecoderOutput: return_dict: bool = True,
) -> TimesFMDecoderOutput | tuple[torch.Tensor, ...]:
model_input, patched_padding, stats, _ = self._preprocess_input( model_input, patched_padding, stats, _ = self._preprocess_input(
input_ts=input_ts, input_ts=input_ts,
input_padding=input_padding, input_padding=input_padding,
@ -674,13 +675,22 @@ class TimesFMDecoder(TimesFMPreTrainedModel):
else: else:
all_hidden_states = None all_hidden_states = None
return TimesFMDecoderOutput( if return_dict:
last_hidden_state=transformer_output.last_hidden_state, return TimesFMDecoderOutput(
hidden_states=all_hidden_states, last_hidden_state=transformer_output.last_hidden_state,
attentions=transformer_output.attentions if output_attentions else None, hidden_states=all_hidden_states,
loc=stats[0], attentions=transformer_output.attentions if output_attentions else None,
scale=stats[1], loc=stats[0],
) scale=stats[1],
)
else:
return (
transformer_output.last_hidden_state,
all_hidden_states,
transformer_output.attentions,
stats[0],
stats[1],
)
class TimesFMModelForPrediction(TimesFMPreTrainedModel): class TimesFMModelForPrediction(TimesFMPreTrainedModel):
@ -778,7 +788,7 @@ class TimesFMModelForPrediction(TimesFMPreTrainedModel):
return_forecast_on_context: bool = False, return_forecast_on_context: bool = False,
output_attentions: bool = False, output_attentions: bool = False,
output_hidden_states: bool = False, output_hidden_states: bool = False,
): ) -> tuple[torch.Tensor, ...]:
"""Auto-regressive decoding without caching. """Auto-regressive decoding without caching.
Args: Args:
@ -799,6 +809,9 @@ class TimesFMModelForPrediction(TimesFMPreTrainedModel):
B x H' x (1 + # quantiles). B x H' x (1 + # quantiles).
In particular, if return_forecast_on_context is True, H' is H plus In particular, if return_forecast_on_context is True, H' is H plus
the forecastable context length, i.e. context_len - (first) patch_len. the forecastable context length, i.e. context_len - (first) patch_len.
Raises:
ValueError: If the paddings do not match the input + horizon_len.
""" """
final_out = input_ts final_out = input_ts
context_len = final_out.shape[1] context_len = final_out.shape[1]
@ -871,7 +884,7 @@ class TimesFMModelForPrediction(TimesFMPreTrainedModel):
output_attentions: bool | None = None, output_attentions: bool | None = None,
output_hidden_states: bool | None = None, output_hidden_states: bool | None = None,
return_dict: bool | None = None, return_dict: bool | None = None,
) -> TimesFMOutputForPrediction: ) -> TimesFMOutputForPrediction | tuple[torch.Tensor, ...]:
"""Forecasts on a list of time series. """Forecasts on a list of time series.
Args: Args:
@ -887,15 +900,15 @@ class TimesFMModelForPrediction(TimesFMPreTrainedModel):
when available, i.e. after the first input patch. when available, i.e. after the first input patch.
truncate_negative: truncate to only non-negative values if all the contexts truncate_negative: truncate to only non-negative values if all the contexts
have non-negative values. have non-negative values.
output_attentions: Whether to return the attentions.
output_hidden_states: Whether to return the hidden states.
return_dict: Whether to return a TimesFMOutputForPrediction object.
Returns: Returns:
A tuple for Tensors: A TimesFMOutputForPrediction object containing:
- the mean forecast of size (# inputs, # forecast horizon), - the mean forecast of size (# inputs, # forecast horizon),
- the full forecast (mean + quantiles) of size - the full forecast (mean + quantiles) of size
(# inputs, # forecast horizon, 1 + # quantiles). (# inputs, # forecast horizon, 1 + # quantiles).
Raises:
ValueError: If the checkpoint is not properly loaded.
""" """
if return_dict is None: if return_dict is None:
return_dict = self.config.use_return_dict return_dict = self.config.use_return_dict