mirror of https://github.com/huggingface/transformers.git
fix docstrings
This commit is contained in:
parent
b437e87b88
commit
9f0f086a34
@@ -185,7 +185,7 @@ class TimesFMPositionalEmbedding(nn.Module):


 class TimesFMAttention(nn.Module):
-    """Implements the attention used in TimesFM."""
+    """Implements the attention used in TimesFM. One key difference is that there is _per_dim_scaling of the query."""

     def __init__(self, config: TimesFMConfig):
         super().__init__()
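For context, the per-dim scaling the new docstring mentions multiplies each query dimension by its own learned factor before the attention dot product, instead of a single global 1/sqrt(head_dim). A minimal sketch of that pattern, assuming a learned `per_dim_scale` parameter with a softplus activation (names and initialization are illustrative, not the exact transformers implementation):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PerDimScaledQuery(nn.Module):
    """Scales each query dimension by a learned factor before attention."""

    def __init__(self, head_dim: int):
        super().__init__()
        # One learnable scale per query dimension; zeros give a neutral
        # softplus(0) starting point.
        self.per_dim_scale = nn.Parameter(torch.zeros(head_dim))
        self.head_dim = head_dim

    def forward(self, query: torch.Tensor) -> torch.Tensor:
        # Fold the usual 1/sqrt(head_dim) attention scaling into the
        # per-dimension factor, then scale each query dimension.
        scale = F.softplus(self.per_dim_scale) * (self.head_dim ** -0.5)
        return query * scale
```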
@@ -655,7 +655,8 @@ class TimesFMDecoder(TimesFMPreTrainedModel):
        freq: torch.Tensor,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
-    ) -> TimesFMDecoderOutput:
+        return_dict: bool = True,
+    ) -> TimesFMDecoderOutput | tuple[torch.Tensor, ...]:
        model_input, patched_padding, stats, _ = self._preprocess_input(
            input_ts=input_ts,
            input_padding=input_padding,
@@ -674,13 +675,22 @@ class TimesFMDecoder(TimesFMPreTrainedModel):
        else:
            all_hidden_states = None

-        return TimesFMDecoderOutput(
-            last_hidden_state=transformer_output.last_hidden_state,
-            hidden_states=all_hidden_states,
-            attentions=transformer_output.attentions if output_attentions else None,
-            loc=stats[0],
-            scale=stats[1],
-        )
+        if return_dict:
+            return TimesFMDecoderOutput(
+                last_hidden_state=transformer_output.last_hidden_state,
+                hidden_states=all_hidden_states,
+                attentions=transformer_output.attentions if output_attentions else None,
+                loc=stats[0],
+                scale=stats[1],
+            )
+        else:
+            return (
+                transformer_output.last_hidden_state,
+                all_hidden_states,
+                transformer_output.attentions,
+                stats[0],
+                stats[1],
+            )


 class TimesFMModelForPrediction(TimesFMPreTrainedModel):
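The change above follows the usual transformers convention of returning either a ModelOutput subclass or a plain positional tuple. A hedged usage sketch against the signature shown in this diff (the `decoder` instance and input tensors are assumptions):

```python
# return_dict=True yields a TimesFMDecoderOutput with named fields.
decoder_output = decoder(
    input_ts=input_ts,
    input_padding=input_padding,
    freq=freq,
    return_dict=True,
)
last_hidden = decoder_output.last_hidden_state

# return_dict=False yields the same values as a positional tuple.
last_hidden, hidden_states, attentions, loc, scale = decoder(
    input_ts=input_ts,
    input_padding=input_padding,
    freq=freq,
    return_dict=False,
)
```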
@@ -778,7 +788,7 @@ class TimesFMModelForPrediction(TimesFMPreTrainedModel):
        return_forecast_on_context: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
-    ):
+    ) -> tuple[torch.Tensor, ...]:
        """Auto-regressive decoding without caching.

        Args:
@@ -799,6 +809,9 @@ class TimesFMModelForPrediction(TimesFMPreTrainedModel):
              B x H' x (1 + # quantiles).
            In particular, if return_forecast_on_context is True, H' is H plus
            the forecastable context length, i.e. context_len - (first) patch_len.
+
+        Raises:
+            ValueError: If the paddings do not match the input + horizon_len.
        """
        final_out = input_ts
        context_len = final_out.shape[1]
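The newly documented ValueError corresponds to a length check on the padding tensor. A minimal sketch of such a check, assuming `paddings` must cover the context plus the forecast horizon (the exact condition and error message in the model may differ):

```python
import torch


def check_padding_length(
    input_ts: torch.Tensor, paddings: torch.Tensor, horizon_len: int
) -> None:
    """Raise if the padding length does not match context + horizon."""
    context_len = input_ts.shape[1]
    expected = context_len + horizon_len
    if paddings.shape[1] != expected:
        raise ValueError(
            f"Expected paddings of length {expected} "
            f"(context {context_len} + horizon {horizon_len}), "
            f"got {paddings.shape[1]}."
        )
```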
@@ -871,7 +884,7 @@ class TimesFMModelForPrediction(TimesFMPreTrainedModel):
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
-    ) -> TimesFMOutputForPrediction:
+    ) -> TimesFMOutputForPrediction | tuple[torch.Tensor, ...]:
        """Forecasts on a list of time series.

        Args:
@@ -887,15 +900,15 @@ class TimesFMModelForPrediction(TimesFMPreTrainedModel):
              when available, i.e. after the first input patch.
            truncate_negative: truncate to only non-negative values if all the contexts
              have non-negative values.
+            output_attentions: Whether to return the attentions.
+            output_hidden_states: Whether to return the hidden states.
+            return_dict: Whether to return a TimesFMOutputForPrediction object.

        Returns:
-        A tuple for Tensors:
+        A TimesFMOutputForPrediction object containing:
        - the mean forecast of size (# inputs, # forecast horizon),
        - the full forecast (mean + quantiles) of size
          (# inputs, # forecast horizon, 1 + # quantiles).
-
-        Raises:
-            ValueError: If the checkpoint is not properly loaded.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict
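With return_dict defaulting to the model config, callers get the structured output unless they opt out explicitly. A hedged usage sketch; the model variable, input construction, and output attribute names are assumptions inferred from the Returns description in this diff, not confirmed API:

```python
import torch

# Hypothetical inputs: a batch of 2 series with a context of 512 steps.
past_values = torch.randn(2, 512)

output = model(
    past_values=past_values,
    return_dict=None,  # falls back to model.config.use_return_dict
)
# Assumed field names matching "mean forecast" and "full forecast
# (mean + quantiles)" from the docstring above.
mean_forecast = output.mean_predictions  # (# inputs, # forecast horizon)
full_forecast = output.full_predictions  # (# inputs, horizon, 1 + # quantiles)
```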