Mirror of https://github.com/huggingface/transformers.git
Merge pull request #3137 from tomhosking/bart-refactor
Refactor BartModel so that input checks are handled within enc/dec
Commit 6ffe03a0a1
@@ -271,6 +271,12 @@ class BartEncoder(nn.Module):
                 - **all_attentions** (List[Tensor]): Attention weights for each layer.
                 During training might not be of length n_layers because of layer dropout.
         """
+        # check attention mask and invert
+        if attention_mask is not None:
+            assert attention_mask.dim() == 2
+
+            attention_mask = (1.0 - attention_mask.long()) * -10000.0
+            assert attention_mask.max() <= 0
         inputs_embeds = self.embed_tokens(input_ids)
         embed_pos = self.embed_positions(input_ids)
         x = inputs_embeds + embed_pos
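For context, the check being added to the encoder turns a conventional 2-D padding mask (1 for real tokens, 0 for padding) into an additive mask whose padded positions carry a large negative value. A minimal standalone sketch using the same formula as the diff, with made-up tensor values:

import torch

# Toy 1/0 padding mask: 1 = real token, 0 = padding (values are illustrative).
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])
assert attention_mask.dim() == 2

# Same inversion as in the encoder hunk above:
# real tokens -> 0.0, padding -> -10000.0
additive_mask = (1.0 - attention_mask.long()) * -10000.0
assert additive_mask.max() <= 0

Adding -10000.0 to the pre-softmax attention scores of padded positions drives their attention weights to effectively zero, which is why the assertion only requires the resulting mask to be non-positive.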
@@ -448,6 +454,13 @@ class BartDecoder(nn.Module):
                 - hidden states
                 - attentions
         """
+        # check attention mask and invert
+        if encoder_padding_mask is not None:
+            assert encoder_padding_mask.dim() == 2
+
+            encoder_padding_mask = (1.0 - encoder_padding_mask.long()) * -10000.0
+            assert encoder_padding_mask.max() <= 0
+
         # embed positions
         positions = self.embed_positions(input_ids, generation_mode=self.generation_mode)
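To show why the decoder wants encoder_padding_mask in this additive form, here is an illustrative scaled dot-product step that consumes such a mask during cross-attention. The names, shapes, and use of torch.nn.functional are assumptions for the sketch, not a copy of BART's attention module:

import torch
import torch.nn.functional as F

batch, tgt_len, src_len, dim = 2, 3, 5, 8
query = torch.randn(batch, tgt_len, dim)   # decoder states (hypothetical shapes)
key = torch.randn(batch, src_len, dim)     # encoder outputs
value = torch.randn(batch, src_len, dim)

# Additive mask as produced above: 0.0 for real encoder tokens, -10000.0 for padding.
encoder_padding_mask = torch.tensor([[0.0, 0.0, 0.0, -10000.0, -10000.0],
                                     [0.0, 0.0, 0.0, 0.0, 0.0]])

scores = query @ key.transpose(-1, -2) / dim ** 0.5   # (batch, tgt_len, src_len)
scores = scores + encoder_padding_mask.unsqueeze(1)   # broadcast over tgt_len
weights = F.softmax(scores, dim=-1)                   # padded keys get ~0 weight
context = weights @ value                             # (batch, tgt_len, dim)

Padded encoder positions therefore contribute essentially nothing to the cross-attention output, regardless of which module performed the inversion.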
@@ -808,11 +821,6 @@ class BartModel(PretrainedBartModel):
         decoder_attention_mask=None,
         decoder_cached_states=None,
     ):
-        if attention_mask is not None:
-            assert attention_mask.dim() == 2
-
-            attention_mask = (1.0 - attention_mask.long()) * -10000.0
-            assert attention_mask.max() <= 0
 
         # make masks if user doesn't supply
         if not self.decoder.generation_mode:
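This hunk is the other half of the refactor: the wrapper model stops inverting the mask itself and forwards the raw 1/0 mask, since the encoder and decoder now validate and invert it internally. A minimal sketch of that pattern, using hypothetical ToyEncoder/ToyModel classes rather than the real BART classes:

import torch


def invert_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # Same check-and-invert logic that now lives inside the encoder/decoder.
    assert attention_mask.dim() == 2
    inverted = (1.0 - attention_mask.long()) * -10000.0
    assert inverted.max() <= 0
    return inverted


class ToyEncoder:
    def forward(self, input_ids, attention_mask=None):
        # After the refactor, the sub-module owns the mask validation ...
        if attention_mask is not None:
            attention_mask = invert_mask(attention_mask)
        return input_ids, attention_mask


class ToyModel:
    def __init__(self):
        self.encoder = ToyEncoder()

    def forward(self, input_ids, attention_mask=None):
        # ... so the wrapper no longer repeats it and passes the raw 1/0 mask through.
        return self.encoder.forward(input_ids, attention_mask)

Whichever path reaches the encoder or decoder, the same validation runs exactly once, inside the module that actually consumes the mask.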