diff --git a/examples/summarization/bertabs/modeling_bertabs.py b/examples/summarization/bertabs/modeling_bertabs.py
index 0691403186c..e314ff122bb 100644
--- a/examples/summarization/bertabs/modeling_bertabs.py
+++ b/examples/summarization/bertabs/modeling_bertabs.py
@@ -15,11 +15,14 @@
 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+
+
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 import copy
+
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 import math
 
 import numpy as np
diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py
index 21c51f971e0..d689756f691 100644
--- a/src/transformers/modeling_bart.py
+++ b/src/transformers/modeling_bart.py
@@ -640,9 +640,10 @@ class SelfAttention(nn.Module):
             reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool)
             attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-        attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32)
-        attn_weights = attn_weights_float.type_as(attn_weights)
+        attn_weights_float = F.softmax(attn_weights, dim=-1)
         attn_probs = F.dropout(attn_weights_float, p=self.dropout, training=self.training,)
+        attn_weights = attn_weights_float.type_as(attn_weights)
+
         assert v is not None
         attn_output = torch.bmm(attn_probs, v)
         assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim)
@@ -696,8 +697,8 @@ class SelfAttention(nn.Module):
         elif prev_key_padding_mask is not None:
             filler = torch.zeros(batch_size, src_len - prev_key_padding_mask.size(1))
             if prev_key_padding_mask.is_cuda:
-                filler = filler.cuda()
+                filler = filler.to(prev_key_padding_mask.device)
             new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
         elif key_padding_mask is not None:
             filler = torch.zeros(batch_size, src_len - key_padding_mask.size(1))
             if key_padding_mask.is_cuda:
diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py
index 559046f66bd..ccb1946080b 100644
--- a/tests/test_modeling_bart.py
+++ b/tests/test_modeling_bart.py
@@ -243,15 +243,15 @@ class BartHeadTests(unittest.TestCase):
             decoder_ffn_dim=32,
             max_position_embeddings=48,
         )
-        lm_model = BartForMaskedLM(config)
-        context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long()
-        summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long()
+        lm_model = BartForMaskedLM(config).to(torch_device)
+        context = _long_tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]])
+        summary = _long_tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]])
         logits, enc_features = lm_model.forward(input_ids=context, decoder_input_ids=summary)
         expected_shape = (*summary.shape, config.vocab_size)
         self.assertEqual(logits.shape, expected_shape)
 
     def test_generate_beam_search(self):
-        input_ids = torch.Tensor([[71, 82, 2], [68, 34, 2]]).long()
+        input_ids = _long_tensor([[71, 82, 2], [68, 34, 2]])
         config = BartConfig(
             vocab_size=self.vocab_size,
             d_model=24,
@@ -264,7 +264,7 @@ class BartHeadTests(unittest.TestCase):
             max_position_embeddings=48,
             output_past=True,
         )
-        lm_model = BartForMaskedLM(config)
+        lm_model = BartForMaskedLM(config).to(torch_device)
         lm_model.eval()
 
         new_input_ids = lm_model.generate(
@@ -294,6 +294,13 @@ class BartHeadTests(unittest.TestCase):
             bart_toks = tokenizer.encode(ex, return_tensors="pt")
             _assert_tensors_equal(desired_result.long(), bart_toks, prefix=ex)
 
+    @unittest.skipIf(torch_device == "cpu", "Cant do half precision")
+    def test_generate_fp16(self):
+        config, input_ids, batch_size = self._get_config_and_data(output_past=True)
+        attention_mask = input_ids.ne(1)
+        lm_model = BartForMaskedLM(config).eval().to(torch_device).half()
+        lm_model.generate(input_ids, attention_mask)
+
 
 def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
     """If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
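
Note on the device fix above: replacing `filler = filler.cuda()` with `filler = filler.to(prev_key_padding_mask.device)` keeps the zero filler on the same device as the cached key-padding mask before the two are concatenated. A minimal sketch of that pattern outside the library (the helper name and shapes here are illustrative, not part of the transformers API):

```python
import torch


def pad_key_padding_mask(prev_mask: torch.Tensor, src_len: int) -> torch.Tensor:
    """Extend a cached key-padding mask to src_len with zeros ("not padding").

    Illustrative sketch of the .cuda() -> .to(device) change in SelfAttention:
    the filler ends up on prev_mask's device, so the concatenation never mixes
    devices (e.g. a mask on cuda:1 with a filler on the default cuda:0).
    """
    batch_size, prev_len = prev_mask.shape
    filler = torch.zeros(batch_size, src_len - prev_len, device=prev_mask.device)
    return torch.cat([prev_mask.float(), filler], dim=1)


if __name__ == "__main__":
    prev = torch.tensor([[1.0, 0.0], [0.0, 0.0]])  # (batch=2, prev_len=2)
    print(pad_key_padding_mask(prev, src_len=4))   # shape (2, 4), zeros appended
```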