mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Deprecate old past arguments (#5671)
This commit is contained in:
parent
cdf4cd7068
commit
df983b7483
@ -690,7 +690,7 @@ class AlbertForPreTraining(AlbertPreTrainedModel):
|
||||
if "masked_lm_labels" in kwargs:
|
||||
warnings.warn(
|
||||
"The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
||||
DeprecationWarning,
|
||||
FutureWarning,
|
||||
)
|
||||
labels = kwargs.pop("masked_lm_labels")
|
||||
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
||||
|
@ -111,6 +111,15 @@ BART_INPUTS_DOCSTRING = r"""
|
||||
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
|
||||
If you want to change padding behavior, you should read :func:`~transformers.modeling_bart._prepare_decoder_inputs` and modify.
|
||||
See diagram 1 in the paper for more info on the default strategy
|
||||
decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains pre-computed key and value hidden-states of the attention blocks.
|
||||
Can be used to speed up decoding.
|
||||
If ``decoder_past_key_value_states`` are used, the user can optionally input only the last
|
||||
``decoder_input_ids`` (those that don't have their past key value states given to this model) of shape
|
||||
:obj:`(batch_size, 1)` instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`.
|
||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
If `use_cache` is True, ``decoder_past_key_values`` are returned and can be used to speed up decoding (see
|
||||
``decoder_past_key_values``).
|
||||
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
||||
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
|
||||
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
||||
@ -482,7 +491,7 @@ class BartDecoder(nn.Module):
|
||||
encoder_padding_mask,
|
||||
decoder_padding_mask,
|
||||
decoder_causal_mask,
|
||||
decoder_cached_states=None,
|
||||
decoder_past_key_values=None,
|
||||
use_cache=False,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
@ -499,7 +508,7 @@ class BartDecoder(nn.Module):
|
||||
encoder_hidden_states: output from the encoder, used for
|
||||
encoder-side attention
|
||||
encoder_padding_mask: for ignoring pad tokens
|
||||
decoder_cached_states (dict or None): dictionary used for storing state during generation
|
||||
decoder_past_key_values (dict or None): dictionary used for storing state during generation
|
||||
|
||||
Returns:
|
||||
BaseModelOutputWithPast or tuple:
|
||||
@ -508,6 +517,13 @@ class BartDecoder(nn.Module):
|
||||
- hidden states
|
||||
- attentions
|
||||
"""
|
||||
if "decoder_cached_states" in unused:
|
||||
warnings.warn(
|
||||
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `decoder_past_key_values` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
decoder_past_key_values = unused.pop("decoder_cached_states")
|
||||
|
||||
# check attention mask and invert
|
||||
if encoder_padding_mask is not None:
|
||||
encoder_padding_mask = invert_mask(encoder_padding_mask)
|
||||
@ -541,7 +557,7 @@ class BartDecoder(nn.Module):
|
||||
if self.training and (dropout_probability < self.layerdrop):
|
||||
continue
|
||||
|
||||
layer_state = decoder_cached_states[idx] if decoder_cached_states is not None else None
|
||||
layer_state = decoder_past_key_values[idx] if decoder_past_key_values is not None else None
|
||||
|
||||
x, layer_self_attn, layer_past = decoder_layer(
|
||||
x,
|
||||
@ -854,11 +870,12 @@ class BartModel(PretrainedBartModel):
|
||||
decoder_input_ids=None,
|
||||
encoder_outputs: Optional[Tuple] = None,
|
||||
decoder_attention_mask=None,
|
||||
decoder_cached_states=None,
|
||||
decoder_past_key_values=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_tuple=None,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
if decoder_input_ids is None:
|
||||
@ -908,7 +925,7 @@ class BartModel(PretrainedBartModel):
|
||||
attention_mask,
|
||||
decoder_padding_mask,
|
||||
decoder_causal_mask=causal_mask,
|
||||
decoder_cached_states=decoder_cached_states,
|
||||
decoder_past_key_values=decoder_past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
@ -977,7 +994,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
encoder_outputs=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
decoder_cached_states=None,
|
||||
decoder_past_key_values=None,
|
||||
labels=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
@ -1015,9 +1032,15 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
if "lm_labels" in unused:
|
||||
warnings.warn(
|
||||
"The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
||||
DeprecationWarning,
|
||||
FutureWarning,
|
||||
)
|
||||
labels = unused.pop("lm_labels")
|
||||
if "decoder_cached_states" in unused:
|
||||
warnings.warn(
|
||||
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `decoder_past_key_values` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
decoder_past_key_values = unused.pop("decoder_cached_states")
|
||||
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
|
||||
|
||||
if labels is not None:
|
||||
@ -1029,7 +1052,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
encoder_outputs=encoder_outputs,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
decoder_cached_states=decoder_cached_states,
|
||||
decoder_past_key_values=decoder_past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
@ -1061,11 +1084,11 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs):
|
||||
assert past is not None, "past has to be defined for encoder_outputs"
|
||||
|
||||
encoder_outputs, decoder_cached_states = past
|
||||
encoder_outputs, decoder_past_key_values = past
|
||||
return {
|
||||
"input_ids": None, # encoder_outputs is defined. input_ids not needed
|
||||
"encoder_outputs": encoder_outputs,
|
||||
"decoder_cached_states": decoder_cached_states,
|
||||
"decoder_past_key_values": decoder_past_key_values,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||
@ -1092,9 +1115,9 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past, beam_idx):
|
||||
((enc_out, enc_mask), decoder_cached_states) = past
|
||||
((enc_out, enc_mask), decoder_past_key_values) = past
|
||||
reordered_past = []
|
||||
for layer_past in decoder_cached_states:
|
||||
for layer_past in decoder_past_key_values:
|
||||
# get the correct batch idx from decoder layer's batch dim for cross and self-attn
|
||||
layer_past_new = {
|
||||
attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
|
||||
|
@ -879,7 +879,7 @@ class BertForPreTraining(BertPreTrainedModel):
|
||||
if "masked_lm_labels" in kwargs:
|
||||
warnings.warn(
|
||||
"The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
||||
DeprecationWarning,
|
||||
FutureWarning,
|
||||
)
|
||||
labels = kwargs.pop("masked_lm_labels")
|
||||
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
||||
@ -1076,7 +1076,7 @@ class BertForMaskedLM(BertPreTrainedModel):
|
||||
if "masked_lm_labels" in kwargs:
|
||||
warnings.warn(
|
||||
"The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
||||
DeprecationWarning,
|
||||
FutureWarning,
|
||||
)
|
||||
labels = kwargs.pop("masked_lm_labels")
|
||||
assert "lm_labels" not in kwargs, "Use `BertWithLMHead` for autoregressive language modeling task."
|
||||
|
@ -17,6 +17,7 @@
|
||||
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -246,20 +247,22 @@ CTRL_START_DOCSTRING = r"""
|
||||
CTRL_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`):
|
||||
:obj:`input_ids_length` = ``sequence_length`` if ``past`` is ``None`` else ``past[0].shape[-2]`` (``sequence_length`` of input past key value states).
|
||||
:obj:`input_ids_length` = ``sequence_length`` if ``past_key_values`` is ``None`` else
|
||||
``past_key_values[0].shape[-2]`` (``sequence_length`` of input past key value states).
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
|
||||
If `past` is used, only input_ids that do not have their past calculated should be passed as input_ids.
|
||||
If ``past_key_values`` is used, only input_ids that do not have their past calculated should be passed as
|
||||
``input_ids``.
|
||||
|
||||
Indices can be obtained using :class:`transformers.CTRLTokenizer`.
|
||||
See :func:`transformers.PreTrainedTokenizer.encode` and
|
||||
:func:`transformers.PreTrainedTokenizer.__call__` for details.
|
||||
|
||||
`What are input IDs? <../glossary.html#input-ids>`__
|
||||
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
||||
past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
||||
Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||
(see `past` output below). Can be used to speed up sequential decoding.
|
||||
The input_ids which have their past given to this model should not be passed as input ids as they have already been computed.
|
||||
(see ``past_key_values`` output below). Can be used to speed up sequential decoding.
|
||||
The ``input_ids`` which have their past given to this model should not be passed as input ids as they have already been computed.
|
||||
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
@ -284,10 +287,10 @@ CTRL_INPUTS_DOCSTRING = r"""
|
||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
If `past` is used, optionally only the last `inputs_embeds` have to be input (see `past`).
|
||||
If ``past_key_values`` is used, optionally only the last `inputs_embeds` have to be input (see ``past_key_values``).
|
||||
use_cache (:obj:`bool`):
|
||||
If `use_cache` is True, `past` key value states are returned and
|
||||
can be used to speed up decoding (see `past`). Defaults to `True`.
|
||||
If `use_cache` is True, ``past_key_values`` key value states are returned and
|
||||
can be used to speed up decoding (see ``past_key_values``). Defaults to `True`.
|
||||
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
||||
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
|
||||
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
||||
@ -343,7 +346,7 @@ class CTRLModel(CTRLPreTrainedModel):
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
past=None,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@ -353,7 +356,16 @@ class CTRLModel(CTRLPreTrainedModel):
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_tuple=None,
|
||||
**kwargs,
|
||||
):
|
||||
if "past" in kwargs:
|
||||
warnings.warn(
|
||||
"The `past` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
past_key_values = kwargs.pop("past")
|
||||
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
output_hidden_states = (
|
||||
@ -373,11 +385,11 @@ class CTRLModel(CTRLPreTrainedModel):
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if past is None:
|
||||
if past_key_values is None:
|
||||
past_length = 0
|
||||
past = [None] * len(self.h)
|
||||
past_key_values = [None] * len(self.h)
|
||||
else:
|
||||
past_length = past[0][0].size(-2)
|
||||
past_length = past_key_values[0][0].size(-2)
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
||||
@ -431,7 +443,7 @@ class CTRLModel(CTRLPreTrainedModel):
|
||||
presents = () if use_cache else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_attentions = [] if output_attentions else None
|
||||
for i, (h, layer_past) in enumerate(zip(self.h, past)):
|
||||
for i, (h, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
|
||||
outputs = h(
|
||||
@ -492,7 +504,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
|
||||
if past:
|
||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||
|
||||
return {"input_ids": input_ids, "past": past, "use_cache": kwargs["use_cache"]}
|
||||
return {"input_ids": input_ids, "past_key_values": past, "use_cache": kwargs["use_cache"]}
|
||||
|
||||
@add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
@ -504,7 +516,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
past=None,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@ -515,6 +527,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_tuple=None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
@ -524,11 +537,18 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
|
||||
All labels set to ``-100`` are ignored (masked), the loss is only
|
||||
computed for labels in ``[0, ..., config.vocab_size]``
|
||||
"""
|
||||
if "past" in kwargs:
|
||||
warnings.warn(
|
||||
"The `past` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
past_key_values = kwargs.pop("past")
|
||||
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
||||
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids,
|
||||
past=past,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
|
@ -531,7 +531,7 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
|
||||
if "masked_lm_labels" in kwargs:
|
||||
warnings.warn(
|
||||
"The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
||||
DeprecationWarning,
|
||||
FutureWarning,
|
||||
)
|
||||
labels = kwargs.pop("masked_lm_labels")
|
||||
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
||||
|
@ -622,7 +622,7 @@ class ElectraForMaskedLM(ElectraPreTrainedModel):
|
||||
if "masked_lm_labels" in kwargs:
|
||||
warnings.warn(
|
||||
"The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
||||
DeprecationWarning,
|
||||
FutureWarning,
|
||||
)
|
||||
labels = kwargs.pop("masked_lm_labels")
|
||||
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
||||
|
@ -347,10 +347,12 @@ GPT2_START_DOCSTRING = r"""
|
||||
GPT2_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`):
|
||||
:obj:`input_ids_length` = ``sequence_length`` if ``past`` is ``None`` else ``past[0].shape[-2]`` (``sequence_length`` of input past key value states).
|
||||
:obj:`input_ids_length` = ``sequence_length`` if ``past_key_values`` is ``None`` else
|
||||
``past_key_values[0].shape[-2]`` (``sequence_length`` of input past key value states).
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
|
||||
If `past` is used, only `input_ids` that do not have their past calculated should be passed as `input_ids`.
|
||||
If ``past_key_values`` is used, only ``input_ids`` that do not have their past calculated should be passed
|
||||
as ``input_ids``.
|
||||
|
||||
Indices can be obtained using :class:`transformers.GPT2Tokenizer`.
|
||||
See :func:`transformers.PreTrainedTokenizer.encode` and
|
||||
@ -358,10 +360,10 @@ GPT2_INPUTS_DOCSTRING = r"""
|
||||
|
||||
`What are input IDs? <../glossary.html#input-ids>`__
|
||||
|
||||
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
||||
past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
||||
Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||
(see `past` output below). Can be used to speed up sequential decoding.
|
||||
The `input_ids` which have their past given to this model should not be passed as `input_ids` as they have already been computed.
|
||||
(see ``past_key_values`` output below). Can be used to speed up sequential decoding.
|
||||
The ``input_ids`` which have their past given to this model should not be passed as ``input_ids`` as they have already been computed.
|
||||
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
@ -386,9 +388,9 @@ GPT2_INPUTS_DOCSTRING = r"""
|
||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
If `past` is used, optionally only the last `inputs_embeds` have to be input (see `past`).
|
||||
If ``past_key_values`` is used, optionally only the last `inputs_embeds` have to be input (see ``past_key_values``).
|
||||
use_cache (:obj:`bool`):
|
||||
If `use_cache` is True, `past` key value states are returned and can be used to speed up decoding (see `past`). Defaults to `True`.
|
||||
If `use_cache` is True, ``past_key_values`` key value states are returned and can be used to speed up decoding (see ``past_key_values``). Defaults to `True`.
|
||||
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
||||
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
|
||||
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
||||
@ -437,7 +439,7 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
past=None,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@ -447,7 +449,16 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_tuple=None,
|
||||
**kwargs,
|
||||
):
|
||||
if "past" in kwargs:
|
||||
warnings.warn(
|
||||
"The `past` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
past_key_values = kwargs.pop("past")
|
||||
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
@ -472,11 +483,11 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
if position_ids is not None:
|
||||
position_ids = position_ids.view(-1, input_shape[-1])
|
||||
|
||||
if past is None:
|
||||
if past_key_values is None:
|
||||
past_length = 0
|
||||
past = [None] * len(self.h)
|
||||
past_key_values = [None] * len(self.h)
|
||||
else:
|
||||
past_length = past[0][0].size(-2)
|
||||
past_length = past_key_values[0][0].size(-2)
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
||||
@ -522,7 +533,7 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
presents = () if use_cache else None
|
||||
all_attentions = () if output_attentions else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, past)):
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
|
||||
|
||||
@ -581,7 +592,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
||||
if past:
|
||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||
|
||||
return {"input_ids": input_ids, "past": past, "use_cache": kwargs["use_cache"]}
|
||||
return {"input_ids": input_ids, "past_key_values": past, "use_cache": kwargs["use_cache"]}
|
||||
|
||||
@add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
@ -593,7 +604,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
past=None,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@ -604,6 +615,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_tuple=None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
@ -613,11 +625,18 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
||||
All labels set to ``-100`` are ignored (masked), the loss is only
|
||||
computed for labels in ``[0, ..., config.vocab_size]``
|
||||
"""
|
||||
if "past" in kwargs:
|
||||
warnings.warn(
|
||||
"The `past` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
past_key_values = kwargs.pop("past")
|
||||
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
||||
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids,
|
||||
past=past,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
@ -680,7 +699,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
past=None,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
@ -693,7 +712,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_tuple=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
mc_token_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_choices)`, `optional`, default to index of the last token of the input)
|
||||
@ -741,15 +760,21 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
|
||||
if "lm_labels" in kwargs:
|
||||
warnings.warn(
|
||||
"The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
||||
DeprecationWarning,
|
||||
FutureWarning,
|
||||
)
|
||||
labels = kwargs.pop("lm_labels")
|
||||
if "past" in kwargs:
|
||||
warnings.warn(
|
||||
"The `past` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
past_key_values = kwargs.pop("past")
|
||||
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
||||
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids,
|
||||
past=past,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
|
@ -1094,7 +1094,7 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
|
||||
if "masked_lm_labels" in kwargs:
|
||||
warnings.warn(
|
||||
"The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
||||
DeprecationWarning,
|
||||
FutureWarning,
|
||||
)
|
||||
labels = kwargs.pop("masked_lm_labels")
|
||||
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
||||
|
@ -665,7 +665,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
||||
if "lm_labels" in kwargs:
|
||||
warnings.warn(
|
||||
"The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
||||
DeprecationWarning,
|
||||
FutureWarning,
|
||||
)
|
||||
labels = kwargs.pop("lm_labels")
|
||||
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
||||
|
@ -223,7 +223,7 @@ class RobertaForMaskedLM(BertPreTrainedModel):
|
||||
if "masked_lm_labels" in kwargs:
|
||||
warnings.warn(
|
||||
"The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
||||
DeprecationWarning,
|
||||
FutureWarning,
|
||||
)
|
||||
labels = kwargs.pop("masked_lm_labels")
|
||||
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
||||
|
@ -836,27 +836,27 @@ T5_INPUTS_DOCSTRING = r"""
|
||||
Used in the cross-attention of the decoder.
|
||||
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
Provide for sequence to sequence training. T5 uses the pad_token_id as the starting token for decoder_input_ids generation.
|
||||
If `decoder_past_key_value_states` is used, optionally only the last `decoder_input_ids` have to be input (see `decoder_past_key_value_states`).
|
||||
If `decoder_past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see `decoder_past_key_values`).
|
||||
To know more on how to prepare :obj:`decoder_input_ids` for pre-training take a look at
|
||||
`T5 Training <./t5.html#training>`__. If decoder_input_ids and decoder_inputs_embeds are both None,
|
||||
decoder_input_ids takes the value of input_ids.
|
||||
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
|
||||
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
|
||||
decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
decoder_past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains pre-computed key and value hidden-states of the attention blocks.
|
||||
Can be used to speed up decoding.
|
||||
If `decoder_past_key_value_states` are used, the user can optionally input only the last `decoder_input_ids`
|
||||
If `decoder_past_key_values` are used, the user can optionally input only the last `decoder_input_ids`
|
||||
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||
instead of all `decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
If `use_cache` is True, `decoder_past_key_value_states` are returned and can be used to speed up decoding (see `decoder_past_key_value_states`).
|
||||
If `use_cache` is True, `decoder_past_key_values` are returned and can be used to speed up decoding (see `decoder_past_key_values`).
|
||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
||||
Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation.
|
||||
If `decoder_past_key_value_states` is used, optionally only the last `decoder_inputs_embeds` have to be input (see `decoder_past_key_value_states`).
|
||||
If `decoder_past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be input (see `decoder_past_key_values`).
|
||||
This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix. If decoder_input_ids and decoder_inputs_embeds are both None,
|
||||
decoder_inputs_embeds takes the value of inputs_embeds.
|
||||
@ -923,7 +923,7 @@ class T5Model(T5PreTrainedModel):
|
||||
encoder_outputs=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
decoder_past_key_value_states=None,
|
||||
decoder_past_key_values=None,
|
||||
use_cache=None,
|
||||
inputs_embeds=None,
|
||||
decoder_inputs_embeds=None,
|
||||
@ -931,6 +931,7 @@ class T5Model(T5PreTrainedModel):
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_tuple=None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
@ -947,6 +948,14 @@ class T5Model(T5PreTrainedModel):
|
||||
|
||||
>>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
||||
"""
|
||||
if "decoder_past_key_value_states" in kwargs:
|
||||
warnings.warn(
|
||||
"The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `decoder_past_key_values` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
decoder_past_key_values = kwargs.pop("decoder_past_key_value_states")
|
||||
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
||||
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
|
||||
|
||||
@ -978,7 +987,7 @@ class T5Model(T5PreTrainedModel):
|
||||
|
||||
# If decoding with past key value states, only the last tokens
|
||||
# should be given as an input
|
||||
if decoder_past_key_value_states is not None:
|
||||
if decoder_past_key_values is not None:
|
||||
if decoder_input_ids is not None:
|
||||
decoder_input_ids = decoder_input_ids[:, -1:]
|
||||
if decoder_inputs_embeds is not None:
|
||||
@ -989,7 +998,7 @@ class T5Model(T5PreTrainedModel):
|
||||
input_ids=decoder_input_ids,
|
||||
attention_mask=decoder_attention_mask,
|
||||
inputs_embeds=decoder_inputs_embeds,
|
||||
past_key_value_states=decoder_past_key_value_states,
|
||||
past_key_value_states=decoder_past_key_values,
|
||||
encoder_hidden_states=hidden_states,
|
||||
encoder_attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
@ -1062,7 +1071,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
encoder_outputs=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
decoder_past_key_value_states=None,
|
||||
decoder_past_key_values=None,
|
||||
use_cache=None,
|
||||
labels=None,
|
||||
inputs_embeds=None,
|
||||
@ -1071,7 +1080,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_tuple=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
||||
@ -1103,9 +1112,15 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
if "lm_labels" in kwargs:
|
||||
warnings.warn(
|
||||
"The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
||||
DeprecationWarning,
|
||||
FutureWarning,
|
||||
)
|
||||
labels = kwargs.pop("lm_labels")
|
||||
if "decoder_past_key_value_states" in kwargs:
|
||||
warnings.warn(
|
||||
"The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `decoder_past_key_values` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
decoder_past_key_values = kwargs.pop("decoder_past_key_value_states")
|
||||
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
||||
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
@ -1138,7 +1153,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
|
||||
# If decoding with past key value states, only the last tokens
|
||||
# should be given as an input
|
||||
if decoder_past_key_value_states is not None:
|
||||
if decoder_past_key_values is not None:
|
||||
assert labels is None, "Decoder should not use cached key value states when training."
|
||||
if decoder_input_ids is not None:
|
||||
decoder_input_ids = decoder_input_ids[:, -1:]
|
||||
@ -1150,7 +1165,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
input_ids=decoder_input_ids,
|
||||
attention_mask=decoder_attention_mask,
|
||||
inputs_embeds=decoder_inputs_embeds,
|
||||
past_key_value_states=decoder_past_key_value_states,
|
||||
past_key_value_states=decoder_past_key_values,
|
||||
encoder_hidden_states=hidden_states,
|
||||
encoder_attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
@ -1193,11 +1208,11 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, use_cache, **kwargs):
|
||||
assert past is not None, "past has to be defined for encoder_outputs"
|
||||
|
||||
encoder_outputs, decoder_past_key_value_states = past
|
||||
encoder_outputs, decoder_past_key_values = past
|
||||
|
||||
return {
|
||||
"decoder_input_ids": input_ids,
|
||||
"decoder_past_key_value_states": decoder_past_key_value_states,
|
||||
"decoder_past_key_values": decoder_past_key_values,
|
||||
"encoder_outputs": encoder_outputs,
|
||||
"attention_mask": attention_mask,
|
||||
"use_cache": use_cache,
|
||||
|
Loading…
Reference in New Issue
Block a user