mirror of https://github.com/huggingface/transformers.git
Enable pegasus fp16 by clamping large activations (#7243)
* Clean clamp
* boom boom
* Take some other changes
* boom boom
* boom boom
* boom boom
* one chg
* fix test
* Use finfo
* style
This commit is contained in:
parent be51c1039d
commit 9e80f972fb
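As background for the diff below: the fix clamps activations that have overflowed to inf (or turned into nan) back inside the finite range of the tensor's dtype, which is what makes fp16 inference usable. A minimal, self-contained sketch of that idea, where the helper name clamp_to_finfo and the example values are illustrative and not part of the commit:

import torch

def clamp_to_finfo(x: torch.Tensor) -> torch.Tensor:
    # Mirror the commit's strategy: only clamp when the tensor already
    # contains inf/nan, the typical symptom of fp16 overflow.
    if torch.isinf(x).any() or torch.isnan(x).any():
        clamp_value = torch.finfo(x.dtype).max - 1000
        x = torch.clamp(x, min=-clamp_value, max=clamp_value)
    return x

# fp16 overflows above ~65504, so squaring 1000 produces inf.
x = torch.tensor([1000.0, 2.0], dtype=torch.float16)
print(clamp_to_finfo(x * x))  # the inf entry becomes a large finite value; 4.0 is unchanged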
@@ -269,6 +269,9 @@ class EncoderLayer(nn.Module):
         x = residual + x
         if not self.normalize_before:
             x = self.final_layer_norm(x)
+        if torch.isinf(x).any() or torch.isnan(x).any():
+            clamp_value = torch.finfo(x.dtype).max - 1000
+            x = torch.clamp(x, min=-clamp_value, max=clamp_value)
         return x, attn_weights
 
 
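A note on the design choice above: because the clamp threshold is derived from torch.finfo(x.dtype), the added lines are effectively inert in fp32 and only bite in fp16, whose finite range is small enough for encoder activations to overflow. A quick way to see the scale difference, assuming only that PyTorch is installed:

import torch

# Largest finite value per dtype; the commit clamps to max - 1000.
print(torch.finfo(torch.float16).max)  # 65504.0
print(torch.finfo(torch.float32).max)  # ~3.4e38, far beyond any realistic activation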
@@ -47,9 +47,11 @@ class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest):
         # Demonstrate fp16 issue, Contributions welcome!
         self.model.half()
         translated_tokens_fp16 = self.model.generate(**inputs, max_length=10)
-        decoded = self.tokenizer.batch_decode(translated_tokens_fp16, skip_special_tokens=True)
-        bad_fp16_result = ["unk_7unk_7unk_7unk_7unk_7unk_7unk_7", "unk_7unk_7unk_7unk_7unk_7unk_7unk_7"]
-        self.assertListEqual(decoded, bad_fp16_result)
+        decoded_fp16 = self.tokenizer.batch_decode(translated_tokens_fp16, skip_special_tokens=True)
+        assert decoded_fp16 == [
+            "California's largest electricity provider has begun",
+            "N-Dubz have revealed they were",
+        ]
 
 
 class PegasusConfigTests(unittest.TestCase):
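The updated test expects real summary prefixes from fp16 generation instead of the repeated unk_7 output it previously asserted. A stand-alone way to try the same thing outside the test harness, assuming a CUDA device and the google/pegasus-xsum checkpoint; the article text is a placeholder, not the test's fixture:

from transformers import PegasusForConditionalGeneration, PegasusTokenizer

model_name = "google/pegasus-xsum"
tokenizer = PegasusTokenizer.from_pretrained(model_name)
model = PegasusForConditionalGeneration.from_pretrained(model_name).half().to("cuda")

article = "PG&E stated it scheduled the blackouts in response to forecasts for high winds."
inputs = tokenizer([article], return_tensors="pt", truncation=True).to("cuda")

# Before this commit, fp16 activations overflowed and generation emitted unk tokens.
summary_ids = model.generate(**inputs, max_length=10)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))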