fix docstrings

Kashif Rasul 2024-12-05 10:04:14 +01:00 committed by Jinan Zhou
parent b437e87b88
commit 9f0f086a34


@@ -185,7 +185,7 @@ class TimesFMPositionalEmbedding(nn.Module):
 
 class TimesFMAttention(nn.Module):
-    """Implements the attention used in TimesFM."""
+    """Implements the attention used in TimesFM. One key difference is that there is _per_dim_scaling of the query."""
 
     def __init__(self, config: TimesFMConfig):
        super().__init__()
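The per-dim scaling the new docstring mentions replaces the usual single 1/sqrt(head_dim) factor with one learned multiplier per query dimension. A minimal, illustrative sketch of that pattern follows; the class and attribute names here are assumptions for illustration, not the actual TimesFM implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F

class PerDimScaledQuery(nn.Module):
    """Toy module: learn one scale per head dimension for the query."""

    def __init__(self, head_dim: int):
        super().__init__()
        self.head_dim = head_dim
        # One learnable parameter per query dimension, initialized to zero
        # so every dimension starts from the same softplus(0) baseline.
        self.per_dim_scale = nn.Parameter(torch.zeros(head_dim))

    def forward(self, q: torch.Tensor) -> torch.Tensor:
        # q: (batch, num_heads, seq_len, head_dim). Instead of one global
        # 1/sqrt(head_dim) factor, each dimension gets its own learned
        # multiplier on top of the usual normalization.
        scale = F.softplus(self.per_dim_scale) / (self.head_dim ** 0.5)
        return q * scale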
@@ -655,7 +655,8 @@ class TimesFMDecoder(TimesFMPreTrainedModel):
         freq: torch.Tensor,
         output_attentions: bool = False,
         output_hidden_states: bool = False,
-    ) -> TimesFMDecoderOutput:
+        return_dict: bool = True,
+    ) -> TimesFMDecoderOutput | tuple[torch.Tensor, ...]:
         model_input, patched_padding, stats, _ = self._preprocess_input(
             input_ts=input_ts,
             input_padding=input_padding,
@@ -674,13 +675,22 @@ class TimesFMDecoder(TimesFMPreTrainedModel):
         else:
             all_hidden_states = None
 
-        return TimesFMDecoderOutput(
-            last_hidden_state=transformer_output.last_hidden_state,
-            hidden_states=all_hidden_states,
-            attentions=transformer_output.attentions if output_attentions else None,
-            loc=stats[0],
-            scale=stats[1],
-        )
+        if return_dict:
+            return TimesFMDecoderOutput(
+                last_hidden_state=transformer_output.last_hidden_state,
+                hidden_states=all_hidden_states,
+                attentions=transformer_output.attentions if output_attentions else None,
+                loc=stats[0],
+                scale=stats[1],
+            )
+        else:
+            return (
+                transformer_output.last_hidden_state,
+                all_hidden_states,
+                transformer_output.attentions,
+                stats[0],
+                stats[1],
+            )
 
 
 class TimesFMModelForPrediction(TimesFMPreTrainedModel):
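The new return_dict flag follows the usual transformers convention: a structured output object by default, a plain tuple on request. A usage sketch, assuming a constructed decoder and input tensors matching the signature above:

# `decoder`, `input_ts`, `input_padding`, and `freq` are assumed to exist.
outputs = decoder(input_ts, input_padding, freq, return_dict=True)
loc, scale = outputs.loc, outputs.scale  # attribute access on the output dataclass

# With return_dict=False the same values come back positionally, in the
# order laid out in the tuple branch of the diff above.
last_hidden_state, all_hidden_states, attentions, loc, scale = decoder(
    input_ts, input_padding, freq, return_dict=False
)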
@@ -778,7 +788,7 @@ class TimesFMModelForPrediction(TimesFMPreTrainedModel):
         return_forecast_on_context: bool = False,
         output_attentions: bool = False,
         output_hidden_states: bool = False,
-    ):
+    ) -> tuple[torch.Tensor, ...]:
         """Auto-regressive decoding without caching.
 
         Args:
@@ -799,6 +809,9 @@ class TimesFMModelForPrediction(TimesFMPreTrainedModel):
               B x H' x (1 + # quantiles).
             In particular, if return_forecast_on_context is True, H' is H plus
             the forecastable context length, i.e. context_len - (first) patch_len.
+
+        Raises:
+            ValueError: If the paddings do not match the input + horizon_len.
         """
         final_out = input_ts
         context_len = final_out.shape[1]
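The new Raises: entry documents a shape precondition: the padding mask must cover the context plus every step to be generated. A minimal sketch of the kind of check implied, with hypothetical names (the model's actual validation may differ):

import torch

def validate_paddings(input_ts: torch.Tensor, paddings: torch.Tensor, horizon_len: int) -> None:
    # The padding mask must span the full decode range:
    # the observed context plus the forecast horizon.
    expected_len = input_ts.shape[1] + horizon_len
    if paddings.shape[1] != expected_len:
        raise ValueError(
            f"paddings must have length context_len + horizon_len = {expected_len}, "
            f"got {paddings.shape[1]}"
        )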
@@ -871,7 +884,7 @@ class TimesFMModelForPrediction(TimesFMPreTrainedModel):
         output_attentions: bool | None = None,
         output_hidden_states: bool | None = None,
         return_dict: bool | None = None,
-    ) -> TimesFMOutputForPrediction:
+    ) -> TimesFMOutputForPrediction | tuple[torch.Tensor, ...]:
         """Forecasts on a list of time series.
 
         Args:
@@ -887,15 +900,15 @@ class TimesFMModelForPrediction(TimesFMPreTrainedModel):
                 when available, i.e. after the first input patch.
             truncate_negative: truncate to only non-negative values if all the contexts
                 have non-negative values.
             output_attentions: Whether to return the attentions.
             output_hidden_states: Whether to return the hidden states.
             return_dict: Whether to return a TimesFMOutputForPrediction object.
 
         Returns:
-            A tuple for Tensors:
-            - the mean forecast of size (# inputs, # forecast horizon),
-            - the full forecast (mean + quantiles) of size
+            A TimesFMOutputForPrediction object containing:
+            - the mean forecast of size (# inputs, # forecast horizon),
+            - the full forecast (mean + quantiles) of size
               (# inputs, # forecast horizon, 1 + # quantiles).
 
         Raises:
             ValueError: If the checkpoint is not properly loaded.
         """
         if return_dict is None:
             return_dict = self.config.use_return_dict
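The use_return_dict fallback at the end of the hunk is the standard transformers pattern: an explicit argument wins, otherwise the config decides. A standalone sketch of the same logic, with illustrative names in place of the real config class:

from dataclasses import dataclass

@dataclass
class Config:
    use_return_dict: bool = True  # config-level default, as in TimesFMConfig

def resolve_return_dict(return_dict: bool | None, config: Config) -> bool:
    # An explicit caller choice takes precedence; None defers to the config.
    return config.use_return_dict if return_dict is None else return_dict

assert resolve_return_dict(None, Config()) is True
assert resolve_return_dict(False, Config()) is False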