mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Generate can return cross-attention weights too (#10493)
This commit is contained in:
parent
b013842244
commit
1750e62900
@ -96,6 +96,9 @@ class GreedySearchEncoderDecoderOutput(ModelOutput):
|
||||
decoder_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
|
||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
|
||||
cross_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
|
||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
|
||||
decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, generated_length, hidden_size)`.
|
||||
@ -106,6 +109,7 @@ class GreedySearchEncoderDecoderOutput(ModelOutput):
|
||||
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
|
||||
|
||||
@ -164,6 +168,9 @@ class SampleEncoderDecoderOutput(ModelOutput):
|
||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_return_sequences, num_heads, generated_length,
|
||||
sequence_length)`.
|
||||
cross_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
|
||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
|
||||
decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_return_sequences, generated_length, hidden_size)`.
|
||||
@ -174,6 +181,7 @@ class SampleEncoderDecoderOutput(ModelOutput):
|
||||
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
|
||||
|
||||
@ -239,6 +247,9 @@ class BeamSearchEncoderDecoderOutput(ModelOutput):
|
||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams*num_return_sequences, num_heads,
|
||||
generated_length, sequence_length)`.
|
||||
cross_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
|
||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
|
||||
decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams*num_return_sequences, generated_length,
|
||||
@ -251,6 +262,7 @@ class BeamSearchEncoderDecoderOutput(ModelOutput):
|
||||
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
|
||||
|
||||
@ -314,6 +326,9 @@ class BeamSampleEncoderDecoderOutput(ModelOutput):
|
||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams, num_heads, generated_length,
|
||||
sequence_length)`.
|
||||
cross_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
|
||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
|
||||
decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams, generated_length, hidden_size)`.
|
||||
@ -325,6 +340,7 @@ class BeamSampleEncoderDecoderOutput(ModelOutput):
|
||||
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
|
||||
|
||||
@ -1177,6 +1193,7 @@ class GenerationMixin:
|
||||
# init attention / hidden states / scores tuples
|
||||
scores = () if (return_dict_in_generate and output_scores) else None
|
||||
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
||||
|
||||
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
||||
@ -1212,6 +1229,8 @@ class GenerationMixin:
|
||||
decoder_attentions += (
|
||||
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
||||
)
|
||||
if self.config.is_encoder_decoder:
|
||||
cross_attentions += (outputs.cross_attentions,)
|
||||
|
||||
if output_hidden_states:
|
||||
decoder_hidden_states += (
|
||||
@ -1260,6 +1279,7 @@ class GenerationMixin:
|
||||
encoder_attentions=encoder_attentions,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
decoder_attentions=decoder_attentions,
|
||||
cross_attentions=cross_attentions,
|
||||
decoder_hidden_states=decoder_hidden_states,
|
||||
)
|
||||
else:
|
||||
@ -1384,6 +1404,7 @@ class GenerationMixin:
|
||||
# init attention / hidden states / scores tuples
|
||||
scores = () if (return_dict_in_generate and output_scores) else None
|
||||
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
||||
|
||||
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
||||
@ -1424,6 +1445,8 @@ class GenerationMixin:
|
||||
decoder_attentions += (
|
||||
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
||||
)
|
||||
if self.config.is_encoder_decoder:
|
||||
cross_attentions += (outputs.cross_attentions,)
|
||||
|
||||
if output_hidden_states:
|
||||
decoder_hidden_states += (
|
||||
@ -1468,6 +1491,7 @@ class GenerationMixin:
|
||||
encoder_attentions=encoder_attentions,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
decoder_attentions=decoder_attentions,
|
||||
cross_attentions=cross_attentions,
|
||||
decoder_hidden_states=decoder_hidden_states,
|
||||
)
|
||||
else:
|
||||
@ -1604,6 +1628,7 @@ class GenerationMixin:
|
||||
# init attention / hidden states / scores tuples
|
||||
scores = () if (return_dict_in_generate and output_scores) else None
|
||||
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
||||
|
||||
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
||||
@ -1656,6 +1681,8 @@ class GenerationMixin:
|
||||
decoder_attentions += (
|
||||
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
||||
)
|
||||
if self.config.is_encoder_decoder:
|
||||
cross_attentions += (outputs.cross_attentions,)
|
||||
|
||||
if output_hidden_states:
|
||||
decoder_hidden_states += (
|
||||
@ -1716,6 +1743,7 @@ class GenerationMixin:
|
||||
encoder_attentions=encoder_attentions,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
decoder_attentions=decoder_attentions,
|
||||
cross_attentions=cross_attentions,
|
||||
decoder_hidden_states=decoder_hidden_states,
|
||||
)
|
||||
else:
|
||||
@ -1865,6 +1893,7 @@ class GenerationMixin:
|
||||
# init attention / hidden states / scores tuples
|
||||
scores = () if (return_dict_in_generate and output_scores) else None
|
||||
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
||||
|
||||
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
||||
@ -1913,6 +1942,8 @@ class GenerationMixin:
|
||||
decoder_attentions += (
|
||||
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
||||
)
|
||||
if self.config.is_encoder_decoder:
|
||||
cross_attentions += (outputs.cross_attentions,)
|
||||
|
||||
if output_hidden_states:
|
||||
decoder_hidden_states += (
|
||||
@ -1968,17 +1999,18 @@ class GenerationMixin:
|
||||
if not output_scores:
|
||||
sequence_outputs["sequence_scores"] = None
|
||||
if self.config.is_encoder_decoder:
|
||||
return BeamSearchEncoderDecoderOutput(
|
||||
return BeamSampleEncoderDecoderOutput(
|
||||
sequences=sequence_outputs["sequences"],
|
||||
sequences_scores=sequence_outputs["sequence_scores"],
|
||||
scores=scores,
|
||||
encoder_attentions=encoder_attentions,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
decoder_attentions=decoder_attentions,
|
||||
cross_attentions=cross_attentions,
|
||||
decoder_hidden_states=decoder_hidden_states,
|
||||
)
|
||||
else:
|
||||
return BeamSearchDecoderOnlyOutput(
|
||||
return BeamSampleDecoderOnlyOutput(
|
||||
sequences=sequence_outputs["sequences"],
|
||||
sequences_scores=sequence_outputs["sequence_scores"],
|
||||
scores=scores,
|
||||
@ -2115,6 +2147,7 @@ class GenerationMixin:
|
||||
# init attention / hidden states / scores tuples
|
||||
scores = () if (return_dict_in_generate and output_scores) else None
|
||||
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
||||
|
||||
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
||||
@ -2238,6 +2271,8 @@ class GenerationMixin:
|
||||
decoder_attentions += (
|
||||
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
||||
)
|
||||
if self.config.is_encoder_decoder:
|
||||
cross_attentions += (outputs.cross_attentions,)
|
||||
|
||||
if output_hidden_states:
|
||||
decoder_hidden_states += (
|
||||
@ -2263,7 +2298,7 @@ class GenerationMixin:
|
||||
|
||||
if return_dict_in_generate:
|
||||
if not output_scores:
|
||||
sequence_outputs["sequence_scores"]
|
||||
sequence_outputs["sequence_scores"] = None
|
||||
if self.config.is_encoder_decoder:
|
||||
return BeamSearchEncoderDecoderOutput(
|
||||
sequences=sequence_outputs["sequences"],
|
||||
@ -2272,6 +2307,7 @@ class GenerationMixin:
|
||||
encoder_attentions=encoder_attentions,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
decoder_attentions=decoder_attentions,
|
||||
cross_attentions=cross_attentions,
|
||||
decoder_hidden_states=decoder_hidden_states,
|
||||
)
|
||||
else:
|
||||
|
@ -39,6 +39,8 @@ if is_torch_available():
|
||||
TopPLogitsWarper,
|
||||
)
|
||||
from transformers.generation_utils import (
|
||||
BeamSampleDecoderOnlyOutput,
|
||||
BeamSampleEncoderDecoderOutput,
|
||||
BeamSearchDecoderOnlyOutput,
|
||||
BeamSearchEncoderDecoderOutput,
|
||||
GreedySearchDecoderOnlyOutput,
|
||||
@ -900,11 +902,11 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertIsInstance(output_beam_sample, BeamSearchEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_beam_sample, BeamSampleEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertIsInstance(output_beam_sample, BeamSearchDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_beam_sample, BeamSampleDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)
|
||||
|
||||
self.assertListEqual(output_generate.sequences.tolist(), output_beam_sample.sequences.tolist())
|
||||
self.assertTrue(
|
||||
|
Loading…
Reference in New Issue
Block a user