Merge pull request #3137 from tomhosking/bart-refactor

Refactor BartModel so that input checks are handled within enc/dec
Commit 6ffe03a0a1 by Thomas Wolf, 2020-03-06 13:06:34 +01:00, committed by GitHub

@@ -271,6 +271,12 @@ class BartEncoder(nn.Module):
- **all_attentions** (List[Tensor]): Attention weights for each layer.
During training might not be of length n_layers because of layer dropout.
"""
# check attention mask and invert
if attention_mask is not None:
    assert attention_mask.dim() == 2
    attention_mask = (1.0 - attention_mask.long()) * -10000.0
    assert attention_mask.max() <= 0
inputs_embeds = self.embed_tokens(input_ids)
embed_pos = self.embed_positions(input_ids)
x = inputs_embeds + embed_pos
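
The inversion that now lives in BartEncoder.forward turns a conventional 0/1 padding mask into an additive bias that can simply be summed onto the raw attention scores before the softmax. A minimal standalone sketch of that transform on a toy mask (the tensor values are illustrative only, not taken from the PR):

import torch

# Toy 0/1 padding mask: 1 marks a real token, 0 marks padding (shape: batch x seq_len).
attention_mask = torch.tensor([[1, 1, 1, 0, 0]])

# Same arithmetic as the check above: padded positions become a large negative
# bias (-10000.0) while real tokens contribute (negative) zero, so the result
# can be added directly to the attention scores.
additive_mask = (1.0 - attention_mask.long()) * -10000.0

assert additive_mask.dim() == 2
assert additive_mask.max() <= 0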
@@ -448,6 +454,13 @@ class BartDecoder(nn.Module):
- hidden states
- attentions
"""
# check attention mask and invert
if encoder_padding_mask is not None:
    assert encoder_padding_mask.dim() == 2
    encoder_padding_mask = (1.0 - encoder_padding_mask.long()) * -10000.0
    assert encoder_padding_mask.max() <= 0
# embed positions
positions = self.embed_positions(input_ids, generation_mode=self.generation_mode)
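
The decoder now applies the identical check to the cross-attention padding mask it receives from the encoder. Since the same four lines appear in both BartEncoder.forward and BartDecoder.forward after this change, one way to avoid the duplication, shown purely as an illustrative sketch and not part of this diff, is a small module-level helper:

import torch


def invert_mask(mask: torch.Tensor) -> torch.Tensor:
    # Turn a 2-D 0/1 padding mask into an additive bias:
    # 0.0 at real tokens, -10000.0 at padded positions.
    assert mask.dim() == 2
    inverted = (1.0 - mask.long()) * -10000.0
    assert inverted.max() <= 0
    return inverted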
@@ -808,11 +821,6 @@ class BartModel(PretrainedBartModel):
decoder_attention_mask=None,
decoder_cached_states=None,
):
if attention_mask is not None:
    assert attention_mask.dim() == 2
    attention_mask = (1.0 - attention_mask.long()) * -10000.0
    assert attention_mask.max() <= 0
# make masks if user doesn't supply
if not self.decoder.generation_mode:
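
With the checks moved into the sub-modules, BartModel.forward no longer touches the mask and simply forwards the plain 0/1 attention_mask it received, so caller code is unchanged. A rough usage sketch, assuming the BartConfig/BartModel interface of this version of the library; the tiny config values and token ids below are made up for illustration, and a real use would load a pretrained checkpoint:

import torch
from transformers import BartConfig, BartModel

# Deliberately tiny, made-up configuration so the example runs quickly.
config = BartConfig(
    vocab_size=1000,
    d_model=16,
    encoder_layers=1,
    decoder_layers=1,
    encoder_attention_heads=2,
    decoder_attention_heads=2,
    encoder_ffn_dim=32,
    decoder_ffn_dim=32,
)
model = BartModel(config)

input_ids = torch.tensor([[0, 42, 37, 2, 1, 1]])            # 1 is the default <pad> id
attention_mask = (input_ids != config.pad_token_id).long()  # plain 0/1 mask from the caller

# The 0/1 mask is passed through untouched; the -10000.0 inversion now
# happens inside the encoder and decoder rather than in BartModel.forward.
outputs = model(input_ids, attention_mask=attention_mask, decoder_input_ids=input_ids)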