mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
parent
2f517200c1
commit
dc8227827d
@ -33,7 +33,6 @@ from ...modeling_outputs import BaseModelOutput
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
can_return_tuple,
|
||||
@ -44,8 +43,6 @@ from .configuration_timesfm import TimesFmConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "google/timesfm-2.0-500m-pytorch"
|
||||
_CONFIG_FOR_DOC = "TimesFmConfig"
|
||||
|
||||
|
||||
@ -734,11 +731,6 @@ class TimesFmModelForPrediction(TimesFmPreTrainedModel):
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=TimesFmOutputForPrediction, config_class=_CONFIG_FOR_DOC)
|
||||
@add_code_sample_docstrings(
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=TimesFmOutputForPrediction,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
past_values: Sequence[torch.Tensor],
|
||||
@ -752,28 +744,40 @@ class TimesFmModelForPrediction(TimesFmPreTrainedModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
) -> TimesFmOutputForPrediction:
|
||||
r"""
|
||||
window_size (`int`, *optional*):
|
||||
Window size of trend + residual decomposition. If None then we do not do decomposition.
|
||||
future_values (`torch.Tensor`, *optional*):
|
||||
Optional future time series values to be used for loss computation.
|
||||
forecast_context_len (`int`, *optional*):
|
||||
Optional max context length.
|
||||
return_forecast_on_context (`bool`, *optional*):
|
||||
True to return the forecast on the context when available, i.e. after the first input patch.
|
||||
truncate_negative (`bool`, *optional*):
|
||||
Truncate to only non-negative values if any of the contexts have non-negative values,
|
||||
otherwise do nothing.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether to output the attentions.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether to output the hidden states.
|
||||
window_size (`int`, *optional*):
|
||||
Window size of trend + residual decomposition. If None then we do not do decomposition.
|
||||
future_values (`torch.Tensor`, *optional*):
|
||||
Optional future time series values to be used for loss computation.
|
||||
forecast_context_len (`int`, *optional*):
|
||||
Optional max context length.
|
||||
return_forecast_on_context (`bool`, *optional*):
|
||||
True to return the forecast on the context when available, i.e. after the first input patch.
|
||||
truncate_negative (`bool`, *optional*):
|
||||
Truncate to only non-negative values if any of the contexts have non-negative values,
|
||||
otherwise do nothing.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether to output the attentions.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether to output the hidden states.
|
||||
|
||||
Returns:
|
||||
A TimesFmOutputForPrediction object or a tuple containing:
|
||||
- the mean forecast of size (# past_values, # forecast horizon),
|
||||
- the full forecast (mean + quantiles) of size
|
||||
(# past_values, # forecast horizon, 1 + # quantiles).
|
||||
- loss: the mean squared error loss + quantile loss if `future_values` is provided.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import TimesFmModelForPrediction
|
||||
|
||||
>>> model = TimesFmModelForPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch")
|
||||
|
||||
>>> forecast_input = [torch.linspace(0, 20, 100).sin(), torch.linspace(0, 20, 200).sin(), torch.linspace(0, 20, 400).sin()]
|
||||
>>> frequency_input = torch.tensor([0, 1, 2], dtype=torch.long)
|
||||
|
||||
>>> # Generate
|
||||
>>> with torch.no_grad():
|
||||
>>> outputs = model(past_values=forecast_input, freq=frequency_input, return_dict=True)
|
||||
>>> point_forecast_conv = outputs.mean_predictions
|
||||
>>> quantile_forecast_conv = outputs.full_predictions
|
||||
```
|
||||
"""
|
||||
if forecast_context_len is None:
|
||||
fcontext_len = self.context_len
|
||||
|
@ -27,7 +27,6 @@ from ...modeling_outputs import BaseModelOutput
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
can_return_tuple,
|
||||
@ -690,11 +689,6 @@ class TimesFmModelForPrediction(TimesFmPreTrainedModel):
|
||||
@can_return_tuple
|
||||
@add_start_docstrings_to_model_forward(TIMESFM_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=TimesFmOutputForPrediction, config_class=_CONFIG_FOR_DOC)
|
||||
@add_code_sample_docstrings(
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=TimesFmOutputForPrediction,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
past_values: Sequence[torch.Tensor],
|
||||
@ -708,28 +702,40 @@ class TimesFmModelForPrediction(TimesFmPreTrainedModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
) -> TimesFmOutputForPrediction:
|
||||
r"""
|
||||
window_size (`int`, *optional*):
|
||||
Window size of trend + residual decomposition. If None then we do not do decomposition.
|
||||
future_values (`torch.Tensor`, *optional*):
|
||||
Optional future time series values to be used for loss computation.
|
||||
forecast_context_len (`int`, *optional*):
|
||||
Optional max context length.
|
||||
return_forecast_on_context (`bool`, *optional*):
|
||||
True to return the forecast on the context when available, i.e. after the first input patch.
|
||||
truncate_negative (`bool`, *optional*):
|
||||
Truncate to only non-negative values if any of the contexts have non-negative values,
|
||||
otherwise do nothing.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether to output the attentions.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether to output the hidden states.
|
||||
window_size (`int`, *optional*):
|
||||
Window size of trend + residual decomposition. If None then we do not do decomposition.
|
||||
future_values (`torch.Tensor`, *optional*):
|
||||
Optional future time series values to be used for loss computation.
|
||||
forecast_context_len (`int`, *optional*):
|
||||
Optional max context length.
|
||||
return_forecast_on_context (`bool`, *optional*):
|
||||
True to return the forecast on the context when available, i.e. after the first input patch.
|
||||
truncate_negative (`bool`, *optional*):
|
||||
Truncate to only non-negative values if any of the contexts have non-negative values,
|
||||
otherwise do nothing.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether to output the attentions.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether to output the hidden states.
|
||||
|
||||
Returns:
|
||||
A TimesFmOutputForPrediction object or a tuple containing:
|
||||
- the mean forecast of size (# past_values, # forecast horizon),
|
||||
- the full forecast (mean + quantiles) of size
|
||||
(# past_values, # forecast horizon, 1 + # quantiles).
|
||||
- loss: the mean squared error loss + quantile loss if `future_values` is provided.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import TimesFmModelForPrediction
|
||||
|
||||
>>> model = TimesFmModelForPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch")
|
||||
|
||||
>>> forecast_input = [torch.linspace(0, 20, 100).sin(), torch.linspace(0, 20, 200).sin(), torch.linspace(0, 20, 400).sin()]
|
||||
>>> frequency_input = torch.tensor([0, 1, 2], dtype=torch.long)
|
||||
|
||||
>>> # Generate
|
||||
>>> with torch.no_grad():
|
||||
>>> outputs = model(past_values=forecast_input, freq=frequency_input, return_dict=True)
|
||||
>>> point_forecast_conv = outputs.mean_predictions
|
||||
>>> quantile_forecast_conv = outputs.full_predictions
|
||||
```
|
||||
"""
|
||||
if forecast_context_len is None:
|
||||
fcontext_len = self.context_len
|
||||
|
Loading…
Reference in New Issue
Block a user