mirror of https://github.com/huggingface/transformers.git
fix docstrings
commit 9f0f086a34 (parent b437e87b88)
@@ -185,7 +185,7 @@ class TimesFMPositionalEmbedding(nn.Module):


 class TimesFMAttention(nn.Module):
-    """Implements the attention used in TimesFM."""
+    """Implements the attention used in TimesFM. One key difference is that there is _per_dim_scaling of the query."""

     def __init__(self, config: TimesFMConfig):
         super().__init__()
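The per-dim scaling called out in the new docstring replaces attention's usual single 1/sqrt(head_dim) factor with a learned scale per query dimension. A minimal runnable sketch of the idea, with illustrative names (this is not the exact transformers implementation):

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class PerDimScaledQuery(nn.Module):
    """Toy module: one learned scale per query dimension."""

    def __init__(self, head_dim: int):
        super().__init__()
        self.head_dim = head_dim
        # initialized at zero so the effective factor starts at 1/sqrt(head_dim)
        self.scaling = nn.Parameter(torch.zeros(head_dim))

    def forward(self, query: torch.Tensor) -> torch.Tensor:
        # query: (batch, num_heads, seq_len, head_dim)
        r_softplus_0 = 1.442695041  # 1 / softplus(0), so scale * softplus(0) == 1/sqrt(head_dim) at init
        scale = r_softplus_0 / math.sqrt(self.head_dim)
        return query * scale * F.softplus(self.scaling)

q = torch.randn(1, 4, 16, 32)
print(PerDimScaledQuery(32)(q).shape)  # torch.Size([1, 4, 16, 32])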
@@ -655,7 +655,8 @@ class TimesFMDecoder(TimesFMPreTrainedModel):
         freq: torch.Tensor,
         output_attentions: bool = False,
         output_hidden_states: bool = False,
-    ) -> TimesFMDecoderOutput:
+        return_dict: bool = True,
+    ) -> TimesFMDecoderOutput | tuple[torch.Tensor, ...]:
         model_input, patched_padding, stats, _ = self._preprocess_input(
             input_ts=input_ts,
             input_padding=input_padding,
@@ -674,13 +675,22 @@ class TimesFMDecoder(TimesFMPreTrainedModel):
         else:
             all_hidden_states = None

-        return TimesFMDecoderOutput(
-            last_hidden_state=transformer_output.last_hidden_state,
-            hidden_states=all_hidden_states,
-            attentions=transformer_output.attentions if output_attentions else None,
-            loc=stats[0],
-            scale=stats[1],
-        )
+        if return_dict:
+            return TimesFMDecoderOutput(
+                last_hidden_state=transformer_output.last_hidden_state,
+                hidden_states=all_hidden_states,
+                attentions=transformer_output.attentions if output_attentions else None,
+                loc=stats[0],
+                scale=stats[1],
+            )
+        else:
+            return (
+                transformer_output.last_hidden_state,
+                all_hidden_states,
+                transformer_output.attentions,
+                stats[0],
+                stats[1],
+            )


 class TimesFMModelForPrediction(TimesFMPreTrainedModel):
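The new if return_dict branch follows the standard transformers convention: a ModelOutput object when return_dict is True, a plain tuple of the same values otherwise. A self-contained sketch of that convention (ToyOutput stands in for TimesFMDecoderOutput; the computation is made up):

from dataclasses import dataclass
import torch

@dataclass
class ToyOutput:
    last_hidden_state: torch.Tensor
    loc: torch.Tensor
    scale: torch.Tensor

def forward(x: torch.Tensor, return_dict: bool = True):
    # toy stand-in for the decoder's normalization stats
    loc, scale = x.mean(dim=1), x.std(dim=1)
    hidden = (x - loc[:, None]) / scale[:, None]
    if return_dict:
        return ToyOutput(last_hidden_state=hidden, loc=loc, scale=scale)
    return (hidden, loc, scale)  # same values, positional access

x = torch.randn(2, 8)
print(forward(x).last_hidden_state.shape)      # attribute access on the dataclass
print(forward(x, return_dict=False)[0].shape)  # index access on the tuple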
@@ -778,7 +788,7 @@ class TimesFMModelForPrediction(TimesFMPreTrainedModel):
         return_forecast_on_context: bool = False,
         output_attentions: bool = False,
         output_hidden_states: bool = False,
-    ):
+    ) -> tuple[torch.Tensor, ...]:
         """Auto-regressive decoding without caching.

         Args:
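For context on the method name: decoding "without caching" means each step re-runs the model on the whole context grown so far, rather than reusing cached key/value states. A generic toy sketch of that loop (step_model is a stand-in, not the TimesFM decoder):

import torch

def step_model(context: torch.Tensor) -> torch.Tensor:
    return context[:, -1:] + 1.0  # toy "forecast": last value + 1

def decode(context: torch.Tensor, horizon: int) -> torch.Tensor:
    out = context
    for _ in range(horizon):
        step = step_model(out)            # full recompute every step: no KV cache
        out = torch.cat([out, step], dim=1)
    return out[:, context.shape[1]:]

print(decode(torch.zeros(1, 4), horizon=3))  # tensor([[1., 2., 3.]])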
@@ -799,6 +809,9 @@ class TimesFMModelForPrediction(TimesFMPreTrainedModel):
                 B x H' x (1 + # quantiles).
+            In particular, if return_forecast_on_context is True, H' is H plus
+                the forecastable context length, i.e. context_len - (first) patch_len.
+
         Raises:
             ValueError: If the paddings do not match the input + horizon_len.
         """
         final_out = input_ts
         context_len = final_out.shape[1]
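To make the H' bookkeeping in the added lines concrete, a quick check with made-up sizes (all three numbers below are illustrative):

context_len = 512   # length of the input context
patch_len = 32      # length of the first input patch
horizon_len = 128   # H, the requested forecast horizon

# with return_forecast_on_context=True, the forecastable part of the context
# (everything after the first patch) is prepended to the horizon:
h_prime = horizon_len + (context_len - patch_len)
print(h_prime)  # 608, so the full forecast has shape B x 608 x (1 + # quantiles)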
@@ -871,7 +884,7 @@ class TimesFMModelForPrediction(TimesFMPreTrainedModel):
         output_attentions: bool | None = None,
         output_hidden_states: bool | None = None,
         return_dict: bool | None = None,
-    ) -> TimesFMOutputForPrediction:
+    ) -> TimesFMOutputForPrediction | tuple[torch.Tensor, ...]:
         """Forecasts on a list of time series.

         Args:
@@ -887,15 +900,15 @@ class TimesFMModelForPrediction(TimesFMPreTrainedModel):
                 when available, i.e. after the first input patch.
             truncate_negative: truncate to only non-negative values if all the contexts
                 have non-negative values.
             output_attentions: Whether to return the attentions.
             output_hidden_states: Whether to return the hidden states.
             return_dict: Whether to return a TimesFMOutputForPrediction object.

         Returns:
-            A tuple for Tensors:
-            - the mean forecast of size (# inputs, # forecast horizon),
-            - the full forecast (mean + quantiles) of size
+            A TimesFMOutputForPrediction object containing:
+            - the mean forecast of size (# inputs, # forecast horizon),
+            - the full forecast (mean + quantiles) of size
                 (# inputs, # forecast horizon, 1 + # quantiles).

         Raises:
             ValueError: If the checkpoint is not properly loaded.
         """
         if return_dict is None:
             return_dict = self.config.use_return_dict
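A small shape sketch of the documented return values, with made-up sizes; the 9 quantiles and the convention that the mean occupies the first slice of the full forecast are assumptions based on TimesFM's usual decile setup, not stated in this diff:

import torch

num_inputs, horizon, num_quantiles = 4, 128, 9  # illustrative sizes
mean_forecast = torch.zeros(num_inputs, horizon)
full_forecast = torch.zeros(num_inputs, horizon, 1 + num_quantiles)

# assumed layout: full_forecast[..., 0] is the mean, the rest are quantiles
assert torch.equal(full_forecast[..., 0], mean_forecast)
print(mean_forecast.shape, full_forecast.shape)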