Fix TimesFm doc issue (#37552)

* fix doc

* code block
Cyril Vallez 2025-04-16 16:28:42 +02:00 committed by GitHub
parent 2f517200c1
commit dc8227827d
2 changed files with 64 additions and 54 deletions


@@ -33,7 +33,6 @@ from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
@@ -44,8 +43,6 @@ from .configuration_timesfm import TimesFmConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "google/timesfm-2.0-500m-pytorch"
_CONFIG_FOR_DOC = "TimesFmConfig"
@@ -734,11 +731,6 @@ class TimesFmModelForPrediction(TimesFmPreTrainedModel):
@can_return_tuple
@add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TimesFmOutputForPrediction, config_class=_CONFIG_FOR_DOC)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TimesFmOutputForPrediction,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
past_values: Sequence[torch.Tensor],
@@ -752,28 +744,40 @@ class TimesFmModelForPrediction(TimesFmPreTrainedModel):
output_hidden_states: Optional[bool] = None,
) -> TimesFmOutputForPrediction:
r"""
window_size (`int`, *optional*):
Window size of trend + residual decomposition. If None then we do not do decomposition.
future_values (`torch.Tensor`, *optional*):
Optional future time series values to be used for loss computation.
forecast_context_len (`int`, *optional*):
Optional max context length.
return_forecast_on_context (`bool`, *optional*):
True to return the forecast on the context when available, i.e. after the first input patch.
truncate_negative (`bool`, *optional*):
Truncate to only non-negative values if any of the contexts have non-negative values,
otherwise do nothing.
output_attentions (`bool`, *optional*):
Whether to output the attentions.
output_hidden_states (`bool`, *optional*):
Whether to output the hidden states.
window_size (`int`, *optional*):
Window size of trend + residual decomposition. If None then we do not do decomposition.
future_values (`torch.Tensor`, *optional*):
Optional future time series values to be used for loss computation.
forecast_context_len (`int`, *optional*):
Optional max context length.
return_forecast_on_context (`bool`, *optional*):
True to return the forecast on the context when available, i.e. after the first input patch.
truncate_negative (`bool`, *optional*):
Truncate to only non-negative values if any of the contexts have non-negative values,
otherwise do nothing.
output_attentions (`bool`, *optional*):
Whether to output the attentions.
output_hidden_states (`bool`, *optional*):
Whether to output the hidden states.
Returns:
A TimesFmOutputForPrediction object or a tuple containing:
- the mean forecast of size (# past_values, # forecast horizon),
- the full forecast (mean + quantiles) of size
(# past_values, # forecast horizon, 1 + # quantiles).
- loss: the mean squared error loss + quantile loss if `future_values` is provided.
Example:
```python
>>> import torch
>>> from transformers import TimesFmModelForPrediction

>>> model = TimesFmModelForPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch")

>>> forecast_input = [torch.linspace(0, 20, 100).sin(), torch.linspace(0, 20, 200).sin(), torch.linspace(0, 20, 400).sin()]
>>> frequency_input = torch.tensor([0, 1, 2], dtype=torch.long)

>>> # Generate
>>> with torch.no_grad():
...     outputs = model(past_values=forecast_input, freq=frequency_input, return_dict=True)
...     point_forecast_conv = outputs.mean_predictions
...     quantile_forecast_conv = outputs.full_predictions
```
"""
if forecast_context_len is None:
fcontext_len = self.context_len
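
The Returns section above says the mean forecast has shape (# past_values, # forecast horizon) and the full forecast has shape (# past_values, # forecast horizon, 1 + # quantiles). Below is a minimal sketch of checking those shapes after a forward pass, assuming the same `google/timesfm-2.0-500m-pytorch` checkpoint as the docstring example; reading the quantile levels from `model.config.quantiles` is an assumption, not something stated in the diff.

```python
import torch
from transformers import TimesFmModelForPrediction

# Same checkpoint used in the docstring example above.
model = TimesFmModelForPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch")

# Three sine-wave contexts of different lengths, mirroring the example.
forecast_input = [torch.linspace(0, 20, n).sin() for n in (100, 200, 400)]
frequency_input = torch.tensor([0, 1, 2], dtype=torch.long)

with torch.no_grad():
    outputs = model(past_values=forecast_input, freq=frequency_input, return_dict=True)

# Per the Returns section: (# past_values, # forecast horizon)
print(outputs.mean_predictions.shape)
# Per the Returns section: (# past_values, # forecast horizon, 1 + # quantiles)
print(outputs.full_predictions.shape)

# Assumption: the configured quantile levels are exposed as `quantiles` on the config.
print(getattr(model.config, "quantiles", None))
```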


@@ -27,7 +27,6 @@ from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
@@ -690,11 +689,6 @@ class TimesFmModelForPrediction(TimesFmPreTrainedModel):
@can_return_tuple
@add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TimesFmOutputForPrediction, config_class=_CONFIG_FOR_DOC)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TimesFmOutputForPrediction,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
past_values: Sequence[torch.Tensor],
@@ -708,28 +702,40 @@ class TimesFmModelForPrediction(TimesFmPreTrainedModel):
output_hidden_states: Optional[bool] = None,
) -> TimesFmOutputForPrediction:
r"""
window_size (`int`, *optional*):
Window size of trend + residual decomposition. If None then we do not do decomposition.
future_values (`torch.Tensor`, *optional*):
Optional future time series values to be used for loss computation.
forecast_context_len (`int`, *optional*):
Optional max context length.
return_forecast_on_context (`bool`, *optional*):
True to return the forecast on the context when available, i.e. after the first input patch.
truncate_negative (`bool`, *optional*):
Truncate to only non-negative values if any of the contexts have non-negative values,
otherwise do nothing.
output_attentions (`bool`, *optional*):
Whether to output the attentions.
output_hidden_states (`bool`, *optional*):
Whether to output the hidden states.
window_size (`int`, *optional*):
Window size of trend + residual decomposition. If None then we do not do decomposition.
future_values (`torch.Tensor`, *optional*):
Optional future time series values to be used for loss computation.
forecast_context_len (`int`, *optional*):
Optional max context length.
return_forecast_on_context (`bool`, *optional*):
True to return the forecast on the context when available, i.e. after the first input patch.
truncate_negative (`bool`, *optional*):
Truncate to only non-negative values if any of the contexts have non-negative values,
otherwise do nothing.
output_attentions (`bool`, *optional*):
Whether to output the attentions.
output_hidden_states (`bool`, *optional*):
Whether to output the hidden states.
Returns:
A TimesFmOutputForPrediction object or a tuple containing:
- the mean forecast of size (# past_values, # forecast horizon),
- the full forecast (mean + quantiles) of size
(# past_values, # forecast horizon, 1 + # quantiles).
- loss: the mean squared error loss + quantile loss if `future_values` is provided.
Example:
```python
>>> import torch
>>> from transformers import TimesFmModelForPrediction

>>> model = TimesFmModelForPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch")

>>> forecast_input = [torch.linspace(0, 20, 100).sin(), torch.linspace(0, 20, 200).sin(), torch.linspace(0, 20, 400).sin()]
>>> frequency_input = torch.tensor([0, 1, 2], dtype=torch.long)

>>> # Generate
>>> with torch.no_grad():
...     outputs = model(past_values=forecast_input, freq=frequency_input, return_dict=True)
...     point_forecast_conv = outputs.mean_predictions
...     quantile_forecast_conv = outputs.full_predictions
```
"""
if forecast_context_len is None:
fcontext_len = self.context_len
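
The Returns section also notes that `loss` is the mean squared error plus quantile loss when `future_values` is passed. Below is a hedged sketch of that path, assuming `future_values` is a tensor of held-out targets whose length matches the checkpoint's forecast horizon; the 128-step holdout is an illustrative assumption, not a value taken from the diff.

```python
import torch
from transformers import TimesFmModelForPrediction

model = TimesFmModelForPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch")

# Split each series into a context and a held-out target window.
# Assumption: the holdout length (128) matches the checkpoint's forecast horizon.
series = [torch.linspace(0, 25, n).sin() for n in (400, 500, 600)]
past_values = [s[:-128] for s in series]
future_values = torch.stack([s[-128:] for s in series])
freq = torch.tensor([0, 0, 0], dtype=torch.long)

outputs = model(past_values=past_values, future_values=future_values, freq=freq, return_dict=True)

# Per the Returns section, loss = MSE + quantile loss when future_values is provided.
print(outputs.loss)
```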