[Reformer] Adapt Reformer MaskedLM Attn mask (#5560)

* fix attention mask

* fix slow test

* refactor attn masks

* fix fp16 generate test
Patrick von Platen 2020-07-07 10:48:06 +02:00 committed by GitHub
parent 3dcb748e31
commit 989ae326b5
2 changed files with 45 additions and 50 deletions

View File

@@ -373,7 +373,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
# use cached buckets for backprop only
if buckets is None:
# hash query key vectors into buckets
buckets = self._hash_vectors(query_key_vectors, num_hashes)
buckets = self._hash_vectors(query_key_vectors, num_hashes, attention_mask)
assert (
int(buckets.shape[-1]) == num_hashes * sequence_length
@@ -460,7 +460,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
return LSHSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs, buckets=buckets)
def _hash_vectors(self, vectors, num_hashes):
def _hash_vectors(self, vectors, num_hashes, attention_mask):
batch_size = vectors.shape[0]
# See https://arxiv.org/pdf/1509.02897.pdf
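For context, `_hash_vectors` implements the angular LSH scheme from the paper referenced above: vectors are projected with random rotations, the projections and their negations are concatenated, and the argmax picks the bucket. A stripped-down sketch of that hashing step (single batch, invented shapes, `num_buckets = 4`; this is illustrative, not the model's exact per-head implementation):

```python
import torch

def lsh_hash(vectors, num_buckets=4, num_hashes=1):
    # random rotations: one set per hash round, projecting onto num_buckets // 2 directions
    rotations = torch.randn(vectors.shape[-1], num_hashes, num_buckets // 2)
    rotated = torch.einsum("bld,dhr->bhlr", vectors, rotations)
    # angular LSH: concatenate projections with their negations, bucket id = argmax
    rotated = torch.cat([rotated, -rotated], dim=-1)
    return torch.argmax(rotated, dim=-1)  # (batch, num_hashes, seq_len)

vectors = torch.randn(1, 8, 16)  # (batch, seq_len, hidden)
buckets = lsh_hash(vectors)      # nearby vectors tend to land in the same bucket
```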
@@ -514,6 +514,15 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
cur_product = cur_product * bucket_factor
if attention_mask is not None:
# add an extra bucket for padding tokens only
num_buckets = num_buckets + 1
# assign padding tokens extra bucket
buckets_mask = attention_mask.to(torch.uint8)[:, None, None, :].expand(buckets.shape)
buckets = torch.where(
buckets_mask, buckets, torch.tensor(num_buckets - 1, dtype=torch.long, device=buckets.device)
)
# buckets is now (Batch_size x Num_Attn_Heads x Num_Hashes x Seq_Len).
# Next we add offsets so that bucket numbers from different hashing rounds don't overlap.
offsets = torch.arange(num_hashes, device=vectors.device)
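The new block above gives padding tokens a dedicated extra bucket so they can never be sorted next to real tokens, and the per-hash offsets that follow keep bucket ids from different hashing rounds disjoint. A minimal sketch of both steps with invented values (one head, two hashing rounds, a length-4 sequence, `num_buckets = 4`; none of this comes from the model config):

```python
import torch

# toy bucket assignments: (batch, num_heads, num_hashes, seq_len)
buckets = torch.tensor([[[[0, 2, 1, 3],
                          [1, 1, 0, 2]]]])
attention_mask = torch.tensor([[1, 1, 1, 0]])  # last position is padding
num_buckets, num_hashes = 4, buckets.shape[2]

# add an extra bucket and route every padded position into it
num_buckets = num_buckets + 1
buckets_mask = attention_mask.bool()[:, None, None, :].expand(buckets.shape)
buckets = torch.where(buckets_mask, buckets, torch.tensor(num_buckets - 1, dtype=torch.long))

# offset each hashing round so ids from different rounds never overlap
offsets = torch.arange(num_hashes) * num_buckets  # tensor([0, 5])
buckets = buckets + offsets[None, None, :, None]
# round 0 now uses ids 0-4, round 1 uses ids 5-9; padding sits in buckets 4 and 9
```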
@@ -614,7 +623,9 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
self_mask_value = self.self_mask_value_float32
mask_value = self.mask_value_float32
mask = self._compute_attn_mask(query_bucket_idx, key_value_bucket_idx, attention_mask, sequence_length)
mask = self._compute_attn_mask(
query_bucket_idx, key_value_bucket_idx, attention_mask, query_key_dots.shape, sequence_length
)
if mask is not None:
query_key_dots = torch.where(mask, query_key_dots, mask_value)
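With the extra `query_key_dots.shape` argument, the mask returned by `_compute_attn_mask` can be applied directly to the chunked dot products: disallowed positions are replaced with a very negative value so they contribute nothing after the softmax. A tiny sketch of that `torch.where` step (the `-1e9` stands in for the model's `mask_value_float32` buffer; the shapes are made up):

```python
import torch

query_key_dots = torch.tensor([[2.0, 0.5, -1.0],
                               [0.1, 3.0,  0.2]])
mask = torch.tensor([[True, True, False],
                     [True, False, True]])
mask_value = torch.tensor(-1e9)

masked_dots = torch.where(mask, query_key_dots, mask_value)
attention_probs = torch.softmax(masked_dots, dim=-1)
# masked positions end up with (numerically) zero attention weight
```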
@@ -669,45 +680,32 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
return out_vectors, logits, attention_probs
def _compute_attn_mask(self, query_indices, key_indices, attention_mask, sequence_length):
mask = None
def _compute_attn_mask(self, query_indices, key_indices, attention_mask, query_key_dot_shape, sequence_length):
# Causal mask
if self.is_decoder:
mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2)).to(query_indices.device)
# Attention mask: chunk, look up correct mask value from key_value_bucket_idx
# IMPORTANT: official trax code does not use a mask for LSH Attention. Not sure why.
# attention mask for LSH
if attention_mask is not None:
# if chunked attention, the attention mask has to correspond to LSH order
attention_mask = attention_mask.to(torch.uint8)[:, None, :]
if sequence_length > self.chunk_length:
attention_mask = attention_mask.to(torch.uint8)[:, None, None, :]
# expand attn_mask to fit with key_value_bucket_idx shape
attention_mask = attention_mask[:, None, :]
attention_mask = attention_mask.expand(query_indices.shape[:-1] + (-1,))
key_attn_mask = torch.gather(attention_mask, -1, key_indices)
query_attn_mask = torch.gather(attention_mask, -1, query_indices)
# expand to query_key_dots shape: duplicate along query axis since key sorting is the same for each query position in chunk
attn_mask = query_attn_mask.unsqueeze(-1) * key_attn_mask.unsqueeze(-2)
# extract attention mask from LSH sorted key_indices
attention_mask = torch.gather(attention_mask, -1, key_indices)
# free memory
del query_attn_mask, key_attn_mask
attention_mask = attention_mask.unsqueeze(-2).expand(query_key_dot_shape)
# Causal mask
if self.is_decoder is True:
causal_mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2)).to(query_indices.device)
# add attention mask if not None
if attention_mask is not None:
attention_mask = causal_mask * attention_mask
else:
# usual attention mask creation
attention_mask = attention_mask.to(torch.uint8)[:, None, :]
attn_mask = (attention_mask.unsqueeze(-1) * attention_mask.unsqueeze(-2)).expand(
query_indices.shape + attention_mask.shape[-1:]
)
attention_mask = causal_mask
# free memory
del attention_mask
# multiply by causal mask if necessary
if mask is not None:
mask = mask * attn_mask
else:
mask = attn_mask
return mask
return attention_mask
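The refactor above replaces the separate query/key gathers with a single pattern: reorder the padding mask into LSH-sorted key order with `torch.gather`, broadcast it over the dot-product shape, and only then multiply in the causal mask. A small self-contained sketch of that gather-then-combine idea (the length-6 sequence and the index values are invented for illustration):

```python
import torch

# padding mask in original sequence order: last two tokens are padding
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0]], dtype=torch.uint8)

# pretend LSH sorting arranged the keys (and queries) in this order
key_indices = torch.tensor([[3, 0, 5, 1, 4, 2]])
query_indices = key_indices

# look the mask up in sorted key order, then expand over (queries x keys)
attention_mask = torch.gather(attention_mask, -1, key_indices)  # [[1, 1, 0, 1, 0, 1]]
attention_mask = attention_mask.unsqueeze(-2).expand(1, 6, 6)

# causal mask on the original positions carried by the sorted indices
causal_mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2))

final_mask = causal_mask.to(torch.uint8) * attention_mask
```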
def _len_and_dim_norm(self, vectors):
"""
@@ -923,7 +921,6 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
return LocalSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs)
def _compute_attn_mask(self, query_indices, key_indices, attention_mask, query_key_dots_shape, sequence_length):
mask = None
# chunk attention mask and look before and after
if attention_mask is not None:
@@ -931,24 +928,21 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
if self.chunk_length < sequence_length:
attention_mask = self._split_seq_length_dim_to(attention_mask, -1, self.chunk_length, 1)
attention_mask_key = self._look_adjacent(attention_mask, self.num_chunks_before, self.num_chunks_after)
else:
attention_mask_key = attention_mask
attention_mask = self._look_adjacent(attention_mask, self.num_chunks_before, self.num_chunks_after)
# create attn_mask
attention_mask = attention_mask.unsqueeze(-2).expand(query_key_dots_shape)
# Causal mask
if self.is_decoder is True:
mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2)).to(query_indices.device)
causal_mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2)).to(query_indices.device)
# Attention mask
if attention_mask is not None:
# create attn_mask
attn_mask = (attention_mask.unsqueeze(-1) * attention_mask_key.unsqueeze(-2)).expand(query_key_dots_shape)
# multiply by causal mask if necessary
if mask is not None:
mask = mask * attn_mask
# add attention mask if not None
if attention_mask is not None:
attention_mask = causal_mask * attention_mask
else:
mask = attn_mask
return mask
attention_mask = causal_mask
return attention_mask
class ReformerSelfOutput(nn.Module):

View File

@@ -407,7 +407,8 @@ class ReformerModelTester:
model.to(torch_device)
model.half()
model.eval()
output = model.generate(input_ids, attention_mask=input_mask, do_sample=False)
# only use last 10 inputs for generation
output = model.generate(input_ids[:, -10:], attention_mask=input_mask, do_sample=False)
self.parent.assertFalse(torch.isnan(output).any().item())
def create_and_check_reformer_no_chunking(self, config, input_ids, input_mask, choice_labels):
@@ -623,7 +624,7 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T
@require_torch
class ReformerIntegrationTests(unittest.TestCase):
"""
These integration tests test the current layer activations and gradients against the output of the Hugging Face Reformer model at time of integration: 29/04/2020. During integration, the model was tested against the output of the official Trax ReformerLM model for various cases ("lsh" only, "local" only, masked / non-masked, different chunk length, ...). In order to recover the original trax integration tests, one should use patrickvonplaten's fork of trax and the code that lives on the branch `branch_to_save_trax_integration_tests`.
These integration tests test the current layer activations and gradients against the output of the Hugging Face Reformer model at time of integration: 29/06/2020. During integration, the model was tested against the output of the official Trax ReformerLM model for various cases ("lsh" only, "local" only, masked / non-masked, different chunk length, ...). In order to recover the original trax integration tests, one should use patrickvonplaten's fork of trax and the code that lives on the branch `reformer_trax_tests`.
"""
def _get_basic_config_and_input(self):
@@ -940,7 +941,7 @@ class ReformerIntegrationTests(unittest.TestCase):
hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0]
output_slice = hidden_states[1, -1, :5]
expected_output_slice = torch.tensor(
[0.0324, -0.0121, 0.0615, 0.0031, -0.0297], dtype=torch.float, device=torch_device,
[0.0256, -0.0121, 0.0636, 0.0024, -0.0393], dtype=torch.float, device=torch_device,
)
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))