mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 03:01:07 +06:00
This reverts commit 0b2da0e592.
This commit is contained in:
parent
0b2da0e592
commit
615be03f9d
@ -47,7 +47,7 @@ class PretrainedConfig(object):
|
|||||||
Whether or not the model should return all hidden-states.
|
Whether or not the model should return all hidden-states.
|
||||||
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Whether or not the model should return all attentions.
|
Whether or not the model should return all attentions.
|
||||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
use_cache (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Whether or not the model should return the last key/values attentions (not used by all models).
|
Whether or not the model should return the last key/values attentions (not used by all models).
|
||||||
return_tuple (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
return_tuple (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Whether or not the model should return tuples instead of :obj:`ModelOutput` objects.
|
Whether or not the model should return tuples instead of :obj:`ModelOutput` objects.
|
||||||
|
@ -110,8 +110,6 @@ class XLNetConfig(PretrainedConfig):
|
|||||||
Used in the SQuAD evaluation script for XLM and XLNet.
|
Used in the SQuAD evaluation script for XLM and XLNet.
|
||||||
end_n_top (:obj:`int`, optional, defaults to 5):
|
end_n_top (:obj:`int`, optional, defaults to 5):
|
||||||
Used in the SQuAD evaluation script for XLM and XLNet.
|
Used in the SQuAD evaluation script for XLM and XLNet.
|
||||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
|
||||||
Differs slightly from other models as it is always turned on at training time.
|
|
||||||
|
|
||||||
Example::
|
Example::
|
||||||
|
|
||||||
|
@ -575,7 +575,7 @@ class XLNetModelOutput(ModelOutput):
|
|||||||
``num_predict`` corresponds to ``target_mapping.shape[1]``. If ``target_mapping`` is ``None``, then
|
``num_predict`` corresponds to ``target_mapping.shape[1]``. If ``target_mapping`` is ``None``, then
|
||||||
``num_predict`` corresponds to ``sequence_length``.
|
``num_predict`` corresponds to ``sequence_length``.
|
||||||
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
||||||
Contains pre-computed hidden-states.
|
Contains pre-computed hidden-states (key and values in the attention blocks).
|
||||||
Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
|
Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
|
||||||
should not be passed as input ids as they have already been computed.
|
should not be passed as input ids as they have already been computed.
|
||||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
@ -611,7 +611,7 @@ class XLNetLMHeadModelOutput(ModelOutput):
|
|||||||
``num_predict`` corresponds to ``target_mapping.shape[1]``. If ``target_mapping`` is ``None``, then
|
``num_predict`` corresponds to ``target_mapping.shape[1]``. If ``target_mapping`` is ``None``, then
|
||||||
``num_predict`` corresponds to ``sequence_length``.
|
``num_predict`` corresponds to ``sequence_length``.
|
||||||
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
||||||
Contains pre-computed hidden-states.
|
Contains pre-computed hidden-states (key and values in the attention blocks).
|
||||||
Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
|
Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
|
||||||
should not be passed as input ids as they have already been computed.
|
should not be passed as input ids as they have already been computed.
|
||||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
@ -645,7 +645,7 @@ class XLNetForSequenceClassificationOutput(ModelOutput):
|
|||||||
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
|
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
|
||||||
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||||
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
||||||
Contains pre-computed hidden-states.
|
Contains pre-computed hidden-states (key and values in the attention blocks).
|
||||||
Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
|
Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
|
||||||
should not be passed as input ids as they have already been computed.
|
should not be passed as input ids as they have already been computed.
|
||||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
@ -679,7 +679,7 @@ class XLNetForTokenClassificationOutput(ModelOutput):
|
|||||||
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
|
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
|
||||||
Classification scores (before SoftMax).
|
Classification scores (before SoftMax).
|
||||||
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
||||||
Contains pre-computed hidden-states.
|
Contains pre-computed hidden-states (key and values in the attention blocks).
|
||||||
Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
|
Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
|
||||||
should not be passed as input ids as they have already been computed.
|
should not be passed as input ids as they have already been computed.
|
||||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
@ -715,7 +715,7 @@ class XLNetForMultipleChoiceOutput(ModelOutput):
|
|||||||
|
|
||||||
Classification scores (before SoftMax).
|
Classification scores (before SoftMax).
|
||||||
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
||||||
Contains pre-computed hidden-states.
|
Contains pre-computed hidden-states (key and values in the attention blocks).
|
||||||
Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
|
Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
|
||||||
should not be passed as input ids as they have already been computed.
|
should not be passed as input ids as they have already been computed.
|
||||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
@ -751,7 +751,7 @@ class XLNetForQuestionAnsweringSimpleOutput(ModelOutput):
|
|||||||
end_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
|
end_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
|
||||||
Span-end scores (before SoftMax).
|
Span-end scores (before SoftMax).
|
||||||
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
||||||
Contains pre-computed hidden-states.
|
Contains pre-computed hidden-states (key and values in the attention blocks).
|
||||||
Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
|
Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
|
||||||
should not be passed as input ids as they have already been computed.
|
should not be passed as input ids as they have already been computed.
|
||||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
@ -794,7 +794,7 @@ class XLNetForQuestionAnsweringOutput(ModelOutput):
|
|||||||
cls_logits (``torch.FloatTensor`` of shape ``(batch_size,)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
|
cls_logits (``torch.FloatTensor`` of shape ``(batch_size,)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
|
||||||
Log probabilities for the ``is_impossible`` label of the answers.
|
Log probabilities for the ``is_impossible`` label of the answers.
|
||||||
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
||||||
Contains pre-computed hidden-states.
|
Contains pre-computed hidden-states (key and values in the attention blocks).
|
||||||
Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
|
Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
|
||||||
should not be passed as input ids as they have already been computed.
|
should not be passed as input ids as they have already been computed.
|
||||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
@ -850,7 +850,7 @@ XLNET_INPUTS_DOCSTRING = r"""
|
|||||||
|
|
||||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
||||||
Contains pre-computed hidden-states as computed by the model
|
Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||||
(see `mems` output below). Can be used to speed up sequential decoding. The token ids which have their mems
|
(see `mems` output below). Can be used to speed up sequential decoding. The token ids which have their mems
|
||||||
given to this model should not be passed as input ids as they have already been computed.
|
given to this model should not be passed as input ids as they have already been computed.
|
||||||
`use_cache` has to be set to `True` to make use of `mems`.
|
`use_cache` has to be set to `True` to make use of `mems`.
|
||||||
@ -964,19 +964,10 @@ class XLNetModel(XLNetPreTrainedModel):
|
|||||||
if self.reuse_len is not None and self.reuse_len > 0:
|
if self.reuse_len is not None and self.reuse_len > 0:
|
||||||
curr_out = curr_out[: self.reuse_len]
|
curr_out = curr_out[: self.reuse_len]
|
||||||
|
|
||||||
if self.mem_len is None or self.mem_len == 0:
|
|
||||||
# If `use_cache` is active but no `mem_len` is defined, the model behaves like GPT-2 at inference time
|
|
||||||
# and returns all of the past and current hidden states.
|
|
||||||
cutoff = 0
|
|
||||||
else:
|
|
||||||
# If `use_cache` is active and `mem_len` is defined, the model returns the last `mem_len` hidden
|
|
||||||
# states. This is the preferred setting for training and long-form generation.
|
|
||||||
cutoff = -self.mem_len
|
|
||||||
if prev_mem is None:
|
if prev_mem is None:
|
||||||
# if `use_cache` is active and `mem_len` is defined, the model
|
new_mem = curr_out[-self.mem_len :]
|
||||||
new_mem = curr_out[cutoff:]
|
|
||||||
else:
|
else:
|
||||||
new_mem = torch.cat([prev_mem, curr_out], dim=0)[cutoff:]
|
new_mem = torch.cat([prev_mem, curr_out], dim=0)[-self.mem_len :]
|
||||||
|
|
||||||
return new_mem.detach()
|
return new_mem.detach()
|
||||||
|
|
||||||
@ -1048,7 +1039,7 @@ class XLNetModel(XLNetPreTrainedModel):
|
|||||||
input_mask=None,
|
input_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
use_cache=None,
|
use_cache=True,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_tuple=None,
|
return_tuple=None,
|
||||||
@ -1058,7 +1049,6 @@ class XLNetModel(XLNetPreTrainedModel):
|
|||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
)
|
)
|
||||||
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
|
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
|
||||||
use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)
|
|
||||||
|
|
||||||
# the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
|
# the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
|
||||||
# but we want a unified interface in the library with the batch size on the first dimension
|
# but we want a unified interface in the library with the batch size on the first dimension
|
||||||
@ -1189,7 +1179,7 @@ class XLNetModel(XLNetPreTrainedModel):
|
|||||||
attentions = [] if output_attentions else None
|
attentions = [] if output_attentions else None
|
||||||
hidden_states = [] if output_hidden_states else None
|
hidden_states = [] if output_hidden_states else None
|
||||||
for i, layer_module in enumerate(self.layer):
|
for i, layer_module in enumerate(self.layer):
|
||||||
if use_cache:
|
if self.mem_len is not None and self.mem_len > 0 and use_cache is True:
|
||||||
# cache new mems
|
# cache new mems
|
||||||
new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
|
new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
@ -1221,7 +1211,7 @@ class XLNetModel(XLNetPreTrainedModel):
|
|||||||
output = output.permute(1, 0, 2).contiguous()
|
output = output.permute(1, 0, 2).contiguous()
|
||||||
|
|
||||||
# TODO Teven: fix this test to only use use_cache.
|
# TODO Teven: fix this test to only use use_cache.
|
||||||
if not use_cache:
|
if not (self.mem_len is not None and self.mem_len > 0 and use_cache is True):
|
||||||
new_mems = None
|
new_mems = None
|
||||||
|
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
@ -1322,7 +1312,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
labels=None,
|
labels=None,
|
||||||
use_cache=None,
|
use_cache=True,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_tuple=None,
|
return_tuple=None,
|
||||||
@ -1370,7 +1360,6 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
|
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
|
||||||
use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)
|
|
||||||
|
|
||||||
transformer_outputs = self.transformer(
|
transformer_outputs = self.transformer(
|
||||||
input_ids,
|
input_ids,
|
||||||
@ -1444,7 +1433,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
labels=None,
|
labels=None,
|
||||||
use_cache=None,
|
use_cache=True,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_tuple=None,
|
return_tuple=None,
|
||||||
@ -1457,7 +1446,6 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
|
|||||||
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
|
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
|
||||||
"""
|
"""
|
||||||
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
|
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
|
||||||
use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)
|
|
||||||
|
|
||||||
transformer_outputs = self.transformer(
|
transformer_outputs = self.transformer(
|
||||||
input_ids,
|
input_ids,
|
||||||
@ -1536,7 +1524,7 @@ class XLNetForTokenClassification(XLNetPreTrainedModel):
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
labels=None,
|
labels=None,
|
||||||
use_cache=None,
|
use_cache=True,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_tuple=None,
|
return_tuple=None,
|
||||||
@ -1548,7 +1536,6 @@ class XLNetForTokenClassification(XLNetPreTrainedModel):
|
|||||||
of the input tensors. (see `input_ids` above)
|
of the input tensors. (see `input_ids` above)
|
||||||
"""
|
"""
|
||||||
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
|
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
|
||||||
use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)
|
|
||||||
|
|
||||||
outputs = self.transformer(
|
outputs = self.transformer(
|
||||||
input_ids,
|
input_ids,
|
||||||
@ -1631,7 +1618,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
labels=None,
|
labels=None,
|
||||||
use_cache=None,
|
use_cache=True,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_tuple=None,
|
return_tuple=None,
|
||||||
@ -1643,7 +1630,6 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
|
|||||||
of the input tensors. (see `input_ids` above)
|
of the input tensors. (see `input_ids` above)
|
||||||
"""
|
"""
|
||||||
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
|
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
|
||||||
use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)
|
|
||||||
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
|
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
|
||||||
|
|
||||||
flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
|
flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
|
||||||
@ -1731,7 +1717,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
|
|||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
start_positions=None,
|
start_positions=None,
|
||||||
end_positions=None,
|
end_positions=None,
|
||||||
use_cache=None,
|
use_cache=True,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_tuple=None,
|
return_tuple=None,
|
||||||
@ -1747,7 +1733,6 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
|
|||||||
Positions outside of the sequence are not taken into account for computing the loss.
|
Positions outside of the sequence are not taken into account for computing the loss.
|
||||||
"""
|
"""
|
||||||
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
|
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
|
||||||
use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)
|
|
||||||
|
|
||||||
outputs = self.transformer(
|
outputs = self.transformer(
|
||||||
input_ids,
|
input_ids,
|
||||||
@ -1839,7 +1824,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
|
|||||||
is_impossible=None,
|
is_impossible=None,
|
||||||
cls_index=None,
|
cls_index=None,
|
||||||
p_mask=None,
|
p_mask=None,
|
||||||
use_cache=None,
|
use_cache=True,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_tuple=None,
|
return_tuple=None,
|
||||||
@ -1879,7 +1864,6 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
|
|||||||
>>> loss = outputs[0]
|
>>> loss = outputs[0]
|
||||||
"""
|
"""
|
||||||
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
|
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
|
||||||
use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)
|
|
||||||
|
|
||||||
transformer_outputs = self.transformer(
|
transformer_outputs = self.transformer(
|
||||||
input_ids,
|
input_ids,
|
||||||
|
@ -191,8 +191,8 @@ class XLNetModelTester:
|
|||||||
model = XLNetModel(config)
|
model = XLNetModel(config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
base_model_output = model(input_ids_1)
|
no_mems_outputs = model(input_ids_1)
|
||||||
self.parent.assertEqual(len(base_model_output), 2)
|
self.parent.assertEqual(len(no_mems_outputs), 1)
|
||||||
|
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["outputs"].size()), [self.batch_size, self.seq_length, self.hidden_size],
|
list(result["outputs"].size()), [self.batch_size, self.seq_length, self.hidden_size],
|
||||||
@ -202,72 +202,6 @@ class XLNetModelTester:
|
|||||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_and_check_xlnet_model_use_cache(
|
|
||||||
self,
|
|
||||||
config,
|
|
||||||
input_ids_1,
|
|
||||||
input_ids_2,
|
|
||||||
input_ids_q,
|
|
||||||
perm_mask,
|
|
||||||
input_mask,
|
|
||||||
target_mapping,
|
|
||||||
segment_ids,
|
|
||||||
lm_labels,
|
|
||||||
sequence_labels,
|
|
||||||
is_impossible_labels,
|
|
||||||
token_labels,
|
|
||||||
):
|
|
||||||
model = XLNetModel(config=config)
|
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
# first forward pass
|
|
||||||
causal_mask = torch.ones(
|
|
||||||
input_ids_1.shape[0],
|
|
||||||
input_ids_1.shape[1],
|
|
||||||
input_ids_1.shape[1],
|
|
||||||
dtype=torch.float,
|
|
||||||
device=input_ids_1.device,
|
|
||||||
)
|
|
||||||
causal_mask = torch.triu(causal_mask, diagonal=0)
|
|
||||||
outputs_cache = model(input_ids_1, use_cache=True, perm_mask=causal_mask)
|
|
||||||
outputs_no_cache = model(input_ids_1, use_cache=False, perm_mask=causal_mask)
|
|
||||||
outputs_conf = model(input_ids_1)
|
|
||||||
|
|
||||||
self.parent.assertTrue(len(outputs_cache) == len(outputs_conf))
|
|
||||||
self.parent.assertTrue(len(outputs_cache) == len(outputs_no_cache) + 1)
|
|
||||||
|
|
||||||
output, mems = outputs_cache
|
|
||||||
|
|
||||||
# create hypothetical next token and extend to next_input_ids
|
|
||||||
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
|
||||||
|
|
||||||
# append to next input_ids and token_type_ids
|
|
||||||
next_input_ids = torch.cat([input_ids_1, next_tokens], dim=-1)
|
|
||||||
|
|
||||||
# causal mask
|
|
||||||
causal_mask = torch.ones(
|
|
||||||
input_ids_1.shape[0],
|
|
||||||
input_ids_1.shape[1] + 1,
|
|
||||||
input_ids_1.shape[1] + 1,
|
|
||||||
dtype=torch.float,
|
|
||||||
device=input_ids_1.device,
|
|
||||||
)
|
|
||||||
causal_mask = torch.triu(causal_mask, diagonal=0)
|
|
||||||
single_mask = torch.ones(input_ids_1.shape[0], 1, 1)
|
|
||||||
|
|
||||||
# second forward pass
|
|
||||||
output_from_no_past, _ = model(next_input_ids, perm_mask=causal_mask)
|
|
||||||
output_from_past, _ = model(next_tokens, mems=mems, perm_mask=single_mask)
|
|
||||||
|
|
||||||
# select random slice
|
|
||||||
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
|
||||||
output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
|
|
||||||
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
|
|
||||||
|
|
||||||
# test that outputs are equal for slice
|
|
||||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
|
||||||
|
|
||||||
def create_and_check_xlnet_base_model_with_att_output(
|
def create_and_check_xlnet_base_model_with_att_output(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
@ -517,6 +451,7 @@ class XLNetModelTester:
|
|||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(
|
(
|
||||||
XLNetModel,
|
XLNetModel,
|
||||||
@ -547,12 +482,6 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_xlnet_base_model(*config_and_inputs)
|
self.model_tester.create_and_check_xlnet_base_model(*config_and_inputs)
|
||||||
|
|
||||||
def test_xlnet_base_model_use_cache(self):
|
|
||||||
# checking that in auto-regressive mode, `use_cache` gives the same results
|
|
||||||
self.model_tester.set_seed()
|
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
|
||||||
self.model_tester.create_and_check_xlnet_model_use_cache(*config_and_inputs)
|
|
||||||
|
|
||||||
def test_xlnet_base_model_with_att_output(self):
|
def test_xlnet_base_model_with_att_output(self):
|
||||||
self.model_tester.set_seed()
|
self.model_tester.set_seed()
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
@ -945,33 +874,33 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase):
|
|||||||
9,
|
9,
|
||||||
69,
|
69,
|
||||||
27,
|
27,
|
||||||
442,
|
50,
|
||||||
|
551,
|
||||||
22,
|
22,
|
||||||
2771,
|
2771,
|
||||||
24,
|
4901,
|
||||||
11335,
|
19,
|
||||||
20,
|
21,
|
||||||
|
45,
|
||||||
|
668,
|
||||||
|
21,
|
||||||
18,
|
18,
|
||||||
9225,
|
416,
|
||||||
2198,
|
41,
|
||||||
9,
|
1499,
|
||||||
69,
|
|
||||||
27,
|
|
||||||
442,
|
|
||||||
22,
|
22,
|
||||||
2771,
|
755,
|
||||||
24,
|
|
||||||
11335,
|
|
||||||
20,
|
|
||||||
18,
|
18,
|
||||||
9225,
|
14285,
|
||||||
2198,
|
|
||||||
9,
|
9,
|
||||||
69,
|
12943,
|
||||||
|
4354,
|
||||||
|
153,
|
||||||
27,
|
27,
|
||||||
442,
|
1499,
|
||||||
|
22,
|
||||||
|
642,
|
||||||
22,
|
22,
|
||||||
2771,
|
|
||||||
]
|
]
|
||||||
# In 1991, the remains of Russian Tsar Nicholas II and his family (except for Alexei and Maria)
|
# In 1991, the remains of Russian Tsar Nicholas II and his family (except for Alexei and Maria)
|
||||||
# are discovered. The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich,
|
# are discovered. The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich,
|
||||||
@ -981,8 +910,9 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase):
|
|||||||
# him for making such an accusation, Rasputin watches as the man is chased outside and beaten.
|
# him for making such an accusation, Rasputin watches as the man is chased outside and beaten.
|
||||||
# Twenty years later, Rasputin sees a vision of the Virgin Mary, prompting him to become a priest.
|
# Twenty years later, Rasputin sees a vision of the Virgin Mary, prompting him to become a priest.
|
||||||
# Rasputin quickly becomes famous, with people, even a bishop, begging for his blessing.
|
# Rasputin quickly becomes famous, with people, even a bishop, begging for his blessing.
|
||||||
# <sep><cls>, Rasputin is asked to perform magic. He is asked to perform a ritual of the Virgin Mary.
|
# <sep><cls>, Rasputin is asked to perform magic.
|
||||||
# He is asked to perform a ritual of the Virgin Mary. He is asked to perform
|
# He is not able to perform magic, and his father and
|
||||||
|
# the men are forced to leave the monastery. Rasputin is forced to return to
|
||||||
|
|
||||||
output_ids = model.generate(input_ids, max_length=200, do_sample=False)
|
output_ids = model.generate(input_ids, max_length=200, do_sample=False)
|
||||||
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|
||||||
|
Loading…
Reference in New Issue
Block a user