Enable pegasus fp16 by clamping large activations (#7243)

* Clean clamp

* boom boom

* Take some other changes

* boom boom

* boom boom

* boom boom

* one chg

* fix test

* Use finfo

* style
This commit is contained in:
Sam Shleifer 2020-10-01 04:48:37 -04:00 committed by GitHub
parent be51c1039d
commit 9e80f972fb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 3 deletions

View File

@ -269,6 +269,9 @@ class EncoderLayer(nn.Module):
         x = residual + x
         if not self.normalize_before:
             x = self.final_layer_norm(x)
+        if torch.isinf(x).any() or torch.isnan(x).any():
+            clamp_value = torch.finfo(x.dtype).max - 1000
+            x = torch.clamp(x, min=-clamp_value, max=clamp_value)
         return x, attn_weights

View File

@ -47,9 +47,11 @@ class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest):
         # Demonstrate fp16 issue, Contributions welcome!
         self.model.half()
         translated_tokens_fp16 = self.model.generate(**inputs, max_length=10)
-        decoded = self.tokenizer.batch_decode(translated_tokens_fp16, skip_special_tokens=True)
-        bad_fp16_result = ["unk_7unk_7unk_7unk_7unk_7unk_7unk_7", "unk_7unk_7unk_7unk_7unk_7unk_7unk_7"]
-        self.assertListEqual(decoded, bad_fp16_result)
+        decoded_fp16 = self.tokenizer.batch_decode(translated_tokens_fp16, skip_special_tokens=True)
+        assert decoded_fp16 == [
+            "California's largest electricity provider has begun",
+            "N-Dubz have revealed they were",
+        ]
 class PegasusConfigTests(unittest.TestCase):