mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
Enable pegasus fp16 by clamping large activations (#7243)
* Clean clamp * boom boom * Take some other changes * boom boom * boom boom * boom boom * one chg * fix test * Use finfo * style
This commit is contained in:
parent
be51c1039d
commit
9e80f972fb
@ -269,6 +269,9 @@ class EncoderLayer(nn.Module):
|
||||
x = residual + x
|
||||
if not self.normalize_before:
|
||||
x = self.final_layer_norm(x)
|
||||
if torch.isinf(x).any() or torch.isnan(x).any():
|
||||
clamp_value = torch.finfo(x.dtype).max - 1000
|
||||
x = torch.clamp(x, min=-clamp_value, max=clamp_value)
|
||||
return x, attn_weights
|
||||
|
||||
|
||||
|
@ -47,9 +47,11 @@ class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest):
|
||||
# Demonstrate fp16 issue, Contributions welcome!
|
||||
self.model.half()
|
||||
translated_tokens_fp16 = self.model.generate(**inputs, max_length=10)
|
||||
decoded = self.tokenizer.batch_decode(translated_tokens_fp16, skip_special_tokens=True)
|
||||
bad_fp16_result = ["unk_7unk_7unk_7unk_7unk_7unk_7unk_7", "unk_7unk_7unk_7unk_7unk_7unk_7unk_7"]
|
||||
self.assertListEqual(decoded, bad_fp16_result)
|
||||
decoded_fp16 = self.tokenizer.batch_decode(translated_tokens_fp16, skip_special_tokens=True)
|
||||
assert decoded_fp16 == [
|
||||
"California's largest electricity provider has begun",
|
||||
"N-Dubz have revealed they were",
|
||||
]
|
||||
|
||||
|
||||
class PegasusConfigTests(unittest.TestCase):
|
||||
|
Loading…
Reference in New Issue
Block a user