diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py
index f1dca793db2..d3bc3cd81e2 100644
--- a/src/transformers/modeling_bart.py
+++ b/src/transformers/modeling_bart.py
@@ -269,6 +269,9 @@ class EncoderLayer(nn.Module):
         x = residual + x
         if not self.normalize_before:
             x = self.final_layer_norm(x)
+        if torch.isinf(x).any() or torch.isnan(x).any():
+            clamp_value = torch.finfo(x.dtype).max - 1000
+            x = torch.clamp(x, min=-clamp_value, max=clamp_value)
         return x, attn_weights

diff --git a/tests/test_modeling_pegasus.py b/tests/test_modeling_pegasus.py
index 68cb5d6e046..61880e66871 100644
--- a/tests/test_modeling_pegasus.py
+++ b/tests/test_modeling_pegasus.py
@@ -47,9 +47,11 @@ class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest):
         # Demonstrate fp16 issue, Contributions welcome!
         self.model.half()
         translated_tokens_fp16 = self.model.generate(**inputs, max_length=10)
-        decoded = self.tokenizer.batch_decode(translated_tokens_fp16, skip_special_tokens=True)
-        bad_fp16_result = ["unk_7unk_7unk_7unk_7unk_7unk_7unk_7", "unk_7unk_7unk_7unk_7unk_7unk_7unk_7"]
-        self.assertListEqual(decoded, bad_fp16_result)
+        decoded_fp16 = self.tokenizer.batch_decode(translated_tokens_fp16, skip_special_tokens=True)
+        assert decoded_fp16 == [
+            "California's largest electricity provider has begun",
+            "N-Dubz have revealed they were",
+        ]


 class PegasusConfigTests(unittest.TestCase):
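
A minimal standalone sketch of the clamping behavior the patch adds, assuming only torch; the tensor values here are illustrative and not taken from the repository:

    import torch

    # fp16 activations can overflow to inf after the residual add and layer norm;
    # clamping to just below the dtype's maximum keeps subsequent generation finite.
    x = torch.tensor([1.0, float("inf"), -float("inf"), 2.0], dtype=torch.float16)

    if torch.isinf(x).any() or torch.isnan(x).any():
        clamp_value = torch.finfo(x.dtype).max - 1000  # 65504 - 1000 for float16
        x = torch.clamp(x, min=-clamp_value, max=clamp_value)

    print(x)  # inf values replaced by +/-(65504 - 1000)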