Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-04 05:10:06 +06:00
Fix past_key_values type hint in model output types (#37953)

* F: Fix type hint.
* F: Use Cache type.
* F: Sort import.
* U: Format.
* U: Address reviews.
This commit is contained in: parent 07feaad8fb, commit 67b3d45eb6
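At a glance, the commit swaps the tuple-of-tuples annotations for the cache classes the models actually return at runtime. A schematic before/after of the affected field, taken from `BaseModelOutputWithPast` (the seq2seq outputs use `EncoderDecoderCache` instead of `Cache`):

# before
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
# after
past_key_values: Optional[Cache] = None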
src/transformers/cache_utils.py

@@ -464,7 +464,7 @@ class DynamicCache(Cache):
         """Returns the maximum sequence length of the cache object. DynamicCache does not have a maximum length."""
         return None
 
-    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
+    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
         """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for
         backward compatibility."""
         legacy_cache = ()
@@ -473,7 +473,9 @@ class DynamicCache(Cache):
         return legacy_cache
 
     @classmethod
-    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
+    def from_legacy_cache(
+        cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor, torch.FloatTensor]]] = None
+    ) -> "DynamicCache":
         """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
         backward compatibility."""
         cache = cls()
@@ -1505,8 +1507,8 @@ class EncoderDecoderCache(Cache):
         """
         return len(self.self_attention_cache)
 
-    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
-        """Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format."""
+    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor]]:
+        """Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format."""
         legacy_cache = ()
         if len(self.cross_attention_cache) > 0:
             for self_attn, cross_attn in zip(
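The corrected `to_legacy_cache`/`from_legacy_cache` hints above describe a per-layer tuple of `(key, value)` tensor pairs. A minimal round-trip sketch (tensor shapes and layer count are illustrative, not taken from the commit):

import torch
from transformers.cache_utils import DynamicCache

# Toy legacy cache: one (key, value) pair per layer; shapes are made up for the example.
batch, heads, seq, head_dim, num_layers = 1, 2, 3, 4, 2
legacy = tuple(
    (torch.zeros(batch, heads, seq, head_dim), torch.zeros(batch, heads, seq, head_dim))
    for _ in range(num_layers)
)

cache = DynamicCache.from_legacy_cache(legacy)  # Tuple[Tuple[key, value], ...] -> DynamicCache
roundtrip = cache.to_legacy_cache()             # back to per-layer (key, value) tuples

assert len(roundtrip) == num_layers
assert roundtrip[0][0].shape == (batch, heads, seq, head_dim)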
src/transformers/modeling_outputs.py

@@ -18,6 +18,7 @@ from typing import Optional, Tuple
 
 import torch
 
+from .cache_utils import Cache, EncoderDecoderCache
 from .utils import ModelOutput
 
 
@@ -131,11 +132,8 @@ class BaseModelOutputWithPast(ModelOutput):
 
             If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
             hidden_size)` is output.
-        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
-            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
-            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
-            `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
-            encoder_sequence_length, embed_size_per_head)`.
+        past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
 
             Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
             `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
@@ -154,7 +152,7 @@ class BaseModelOutputWithPast(ModelOutput):
     """
 
     last_hidden_state: Optional[torch.FloatTensor] = None
-    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    past_key_values: Optional[Cache] = None
     hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
     attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
 
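With `past_key_values` now annotated as `Optional[Cache]`, downstream code and type checkers can treat the returned object as a cache rather than a raw tuple. A small usage sketch (the checkpoint name is illustrative and assumes a model that already returns a `Cache` when `use_cache=True`):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("Hello", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, use_cache=True)

# The field is typed as Optional[Cache]; DynamicCache is one concrete implementation.
print(type(outputs.past_key_values))
print(outputs.past_key_values.get_seq_length())  # Cache API, e.g. number of cached tokens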
@@ -222,11 +220,8 @@ class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
 
             Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
             weighted average in the cross-attention heads.
-        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
-            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
-            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
-            `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
-            encoder_sequence_length, embed_size_per_head)`.
+        past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
 
             Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
             `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
@@ -236,7 +231,7 @@ class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
     last_hidden_state: Optional[torch.FloatTensor] = None
     pooler_output: Optional[torch.FloatTensor] = None
     hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
-    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    past_key_values: Optional[Cache] = None
     attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
     cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
 
@@ -252,11 +247,8 @@ class BaseModelOutputWithPastAndCrossAttentions(ModelOutput):
 
             If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
             hidden_size)` is output.
-        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
-            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
-            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
-            `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
-            encoder_sequence_length, embed_size_per_head)`.
+        past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
 
             Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
             `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
@@ -281,7 +273,7 @@ class BaseModelOutputWithPastAndCrossAttentions(ModelOutput):
     """
 
     last_hidden_state: Optional[torch.FloatTensor] = None
-    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    past_key_values: Optional[Cache] = None
     hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
     attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
     cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@@ -298,9 +290,8 @@ class MoECausalLMOutputWithPast(ModelOutput):
             Language modeling loss (for next-token prediction).
         logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
             Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
-        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
-            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
-            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+        past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
 
             Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
             `past_key_values` input) to speed up sequential decoding.
@@ -328,7 +319,7 @@ class MoECausalLMOutputWithPast(ModelOutput):
 
     loss: Optional[torch.FloatTensor] = None
     logits: Optional[torch.FloatTensor] = None
-    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    past_key_values: Optional[Cache] = None
     hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
     attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
     z_loss: Optional[torch.FloatTensor] = None
@@ -376,11 +367,8 @@ class MoeModelOutputWithPast(ModelOutput):
     Args:
         last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
             Sequence of hidden-states at the output of the last layer of the model.
-        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
-            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
-            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
-            `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
-            encoder_sequence_length, embed_size_per_head)`.
+        past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
 
             Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
             `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
@@ -404,7 +392,7 @@ class MoeModelOutputWithPast(ModelOutput):
     """
 
     last_hidden_state: Optional[torch.FloatTensor] = None
-    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    past_key_values: Optional[Cache] = None
     hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
     attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
     router_logits: Optional[Tuple[torch.FloatTensor]] = None
@@ -431,9 +419,8 @@ class MoeCausalLMOutputWithPast(ModelOutput):
             Raw router logtis (post-softmax) that are computed by MoE routers, these terms are used to compute the auxiliary
             loss for Mixture of Experts models.
 
-        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
-            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
-            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+        past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
 
             Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
             `past_key_values` input) to speed up sequential decoding.
@@ -453,7 +440,7 @@ class MoeCausalLMOutputWithPast(ModelOutput):
     loss: Optional[torch.FloatTensor] = None
     aux_loss: Optional[torch.FloatTensor] = None
     logits: Optional[torch.FloatTensor] = None
-    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    past_key_values: Optional[Cache] = None
     hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
     attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
     router_logits: Optional[Tuple[torch.FloatTensor]] = None
@@ -471,11 +458,8 @@ class MoEModelOutputWithPastAndCrossAttentions(ModelOutput):
 
             If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
             hidden_size)` is output.
-        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
-            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
-            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
-            `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
-            encoder_sequence_length, embed_size_per_head)`.
+        past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
 
             Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
             `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
@@ -505,7 +489,7 @@ class MoEModelOutputWithPastAndCrossAttentions(ModelOutput):
     """
 
     last_hidden_state: Optional[torch.FloatTensor] = None
-    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    past_key_values: Optional[Cache] = None
     hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
     attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
     cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@@ -524,10 +508,8 @@ class Seq2SeqModelOutput(ModelOutput):
 
             If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
             hidden_size)` is output.
-        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
-            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
-            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
-            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+        past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
 
             Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
             blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
@@ -564,7 +546,7 @@ class Seq2SeqModelOutput(ModelOutput):
     """
 
     last_hidden_state: Optional[torch.FloatTensor] = None
-    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    past_key_values: Optional[EncoderDecoderCache] = None
     decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
     decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
     cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
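For the seq2seq outputs, the field is now annotated as `EncoderDecoderCache`, which bundles a self-attention cache and a cross-attention cache. A minimal construction sketch (the wrapped `DynamicCache` objects are illustrative and empty; a real model fills them during the forward pass):

from transformers.cache_utils import DynamicCache, EncoderDecoderCache

self_attention_cache = DynamicCache()
cross_attention_cache = DynamicCache()
cache = EncoderDecoderCache(self_attention_cache, cross_attention_cache)

# The legacy view groups each layer's self-attention and cross-attention
# key/value tensors together, which is what the old tuple hints described.
legacy = cache.to_legacy_cache()
print(len(legacy))  # 0 here, since nothing has been written to the caches yet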
@@ -585,10 +567,8 @@ class Seq2SeqMoEModelOutput(ModelOutput):
 
             If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
             hidden_size)` is output.
-        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
-            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
-            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
-            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+        past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
 
             Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
             blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
@@ -634,7 +614,7 @@ class Seq2SeqMoEModelOutput(ModelOutput):
     """
 
     last_hidden_state: Optional[torch.FloatTensor] = None
-    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    past_key_values: Optional[EncoderDecoderCache] = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
     decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
     decoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None
@@ -684,9 +664,8 @@ class CausalLMOutputWithPast(ModelOutput):
             Language modeling loss (for next-token prediction).
         logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
             Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
-        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
-            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
-            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+        past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
 
             Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
             `past_key_values` input) to speed up sequential decoding.
@@ -705,7 +684,7 @@ class CausalLMOutputWithPast(ModelOutput):
 
     loss: Optional[torch.FloatTensor] = None
     logits: Optional[torch.FloatTensor] = None
-    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    past_key_values: Optional[Cache] = None
     hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
     attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
 
@@ -737,10 +716,8 @@ class CausalLMOutputWithCrossAttentions(ModelOutput):
 
             Cross attentions weights after the attention softmax, used to compute the weighted average in the
             cross-attention heads.
-        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
-            Tuple of `torch.FloatTensor` tuples of length `config.n_layers`, with each tuple containing the cached key,
-            value states of the self-attention and the cross-attention layers if model is used in encoder-decoder
-            setting. Only relevant if `config.is_decoder = True`.
+        past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
 
             Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
             `past_key_values` input) to speed up sequential decoding.
@@ -748,7 +725,7 @@ class CausalLMOutputWithCrossAttentions(ModelOutput):
 
     loss: Optional[torch.FloatTensor] = None
     logits: Optional[torch.FloatTensor] = None
-    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    past_key_values: Optional[Cache] = None
     hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
     attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
     cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@@ -764,9 +741,8 @@ class SequenceClassifierOutputWithPast(ModelOutput):
             Classification (or regression if config.num_labels==1) loss.
         logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
             Classification (or regression if config.num_labels==1) scores (before SoftMax).
-        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
-            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
-            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+        past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
 
             Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
             `past_key_values` input) to speed up sequential decoding.
@@ -785,7 +761,7 @@ class SequenceClassifierOutputWithPast(ModelOutput):
 
     loss: Optional[torch.FloatTensor] = None
     logits: Optional[torch.FloatTensor] = None
-    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    past_key_values: Optional[Cache] = None
     hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
     attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
 
@@ -829,10 +805,8 @@ class Seq2SeqLMOutput(ModelOutput):
             Language modeling loss.
         logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
             Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
-        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
-            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
-            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
-            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+        past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
 
             Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
             blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
@@ -870,7 +844,7 @@ class Seq2SeqLMOutput(ModelOutput):
 
     loss: Optional[torch.FloatTensor] = None
     logits: Optional[torch.FloatTensor] = None
-    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    past_key_values: Optional[EncoderDecoderCache] = None
     decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
     decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
     cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@@ -889,10 +863,8 @@ class Seq2SeqMoEOutput(ModelOutput):
             Language modeling loss.
         logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
             Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
-        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
-            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
-            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
-            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+        past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
 
             Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
             blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
@@ -943,7 +915,7 @@ class Seq2SeqMoEOutput(ModelOutput):
     decoder_z_loss: Optional[torch.FloatTensor] = None
     encoder_aux_loss: Optional[torch.FloatTensor] = None
     decoder_aux_loss: Optional[torch.FloatTensor] = None
-    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    past_key_values: Optional[EncoderDecoderCache] = None
     decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
     decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
     decoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None
@@ -1023,10 +995,8 @@ class Seq2SeqSequenceClassifierOutput(ModelOutput):
             Classification (or regression if config.num_labels==1) loss.
         logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
             Classification (or regression if config.num_labels==1) scores (before SoftMax).
-        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
-            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
-            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
-            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+        past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
 
             Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
             blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
@@ -1064,7 +1034,7 @@ class Seq2SeqSequenceClassifierOutput(ModelOutput):
 
     loss: Optional[torch.FloatTensor] = None
     logits: Optional[torch.FloatTensor] = None
-    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    past_key_values: Optional[EncoderDecoderCache] = None
     decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
     decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
     cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@@ -1177,10 +1147,8 @@ class Seq2SeqQuestionAnsweringModelOutput(ModelOutput):
             Span-start scores (before SoftMax).
         end_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
             Span-end scores (before SoftMax).
-        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
-            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
-            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
-            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+        past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
 
             Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
             blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
@@ -1219,7 +1187,7 @@ class Seq2SeqQuestionAnsweringModelOutput(ModelOutput):
     loss: Optional[torch.FloatTensor] = None
     start_logits: Optional[torch.FloatTensor] = None
     end_logits: Optional[torch.FloatTensor] = None
-    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    past_key_values: Optional[EncoderDecoderCache] = None
     decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
     decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
     cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@@ -1508,10 +1476,8 @@ class Seq2SeqSpectrogramOutput(ModelOutput):
             Spectrogram generation loss.
         spectrogram (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_bins)`):
             The predicted spectrogram.
-        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
-            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
-            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
-            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+        past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
 
             Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
             blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
@@ -1549,7 +1515,7 @@ class Seq2SeqSpectrogramOutput(ModelOutput):
 
     loss: Optional[torch.FloatTensor] = None
     spectrogram: Optional[torch.FloatTensor] = None
-    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    past_key_values: Optional[EncoderDecoderCache] = None
     decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
     decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
     cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@@ -1570,10 +1536,8 @@ class Seq2SeqTSModelOutput(ModelOutput):
 
             If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
             hidden_size)` is output.
-        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
-            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
-            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
-            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+        past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
 
             Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
             blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
@@ -1618,7 +1582,7 @@ class Seq2SeqTSModelOutput(ModelOutput):
     """
 
     last_hidden_state: Optional[torch.FloatTensor] = None
-    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    past_key_values: Optional[EncoderDecoderCache] = None
     decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
     decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
     cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@@ -1641,10 +1605,8 @@ class Seq2SeqTSPredictionOutput(ModelOutput):
             Distributional loss.
         params (`torch.FloatTensor` of shape `(batch_size, num_samples, num_params)`):
             Parameters of the chosen distribution.
-        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
-            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
-            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
-            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+        past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
 
             Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
             blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
@@ -1690,7 +1652,7 @@ class Seq2SeqTSPredictionOutput(ModelOutput):
 
     loss: Optional[torch.FloatTensor] = None
     params: Optional[Tuple[torch.FloatTensor]] = None
-    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    past_key_values: Optional[EncoderDecoderCache] = None
     decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
     decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
     cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None