OPT - Fix Softmax NaN in half precision mode (#17437)

Younes Belkada 2022-06-29 19:15:32 +02:00 committed by GitHub
parent 9fe2403bc5
commit d444edb3f6
2 changed files with 31 additions and 4 deletions

src/transformers/models/opt/modeling_opt.py

@@ -109,7 +109,6 @@ class OPTLearnedPositionalEmbedding(nn.Embedding):
return super().forward(positions + self.offset)
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->OPT
class OPTAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -212,9 +211,15 @@ class OPTAttention(nn.Module):
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+dtype_attn_weights = attn_weights.dtype
+# upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
+if dtype_attn_weights == torch.float16:
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(dtype_attn_weights)
+else:
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None:
    if layer_head_mask.size() != (self.num_heads,):
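Note on the upcast added in this hunk: the commit title and the test comment below indicate that large OPT checkpoints produced NaN attention probabilities when the softmax over the masked scores was computed in fp16 during batched (padded) generation; running the softmax in fp32 and casting the result back avoids this. A minimal standalone sketch of the same pattern, where stable_softmax is an illustrative helper name and not part of the library:

import torch
import torch.nn as nn

def stable_softmax(scores: torch.Tensor) -> torch.Tensor:
    # Same pattern as the fix above: upcast fp16 scores to fp32 for the softmax,
    # then cast the probabilities back to the original dtype.
    if scores.dtype == torch.float16:
        return nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(torch.float16)
    return nn.functional.softmax(scores, dim=-1)

# Fully masked rows are filled with the fp16 minimum, mirroring the torch.finfo(...).min clamp above.
scores = torch.full((2, 4), torch.finfo(torch.float16).min, dtype=torch.float16)
probs = stable_softmax(scores)
print(probs.dtype, torch.isnan(probs).any().item())  # torch.float16 False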
@@ -382,7 +387,7 @@ class OPTPreTrainedModel(PreTrainedModel):
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["OPTDecoderLayer"]
-_keys_to_ignore_on_load_unexpected = [r"decoder.version"]
+_keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
def _init_weights(self, module):
    std = self.config.init_std
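Note on the `_keys_to_ignore_on_load_unexpected` change in this hunk: the entries are matched against checkpoint keys as regular expressions, so an unescaped dot matches any character; escaping it restricts the pattern to the literal key `decoder.version`. A small illustration (standalone; the key "decoderXversion" is made up for the example):

import re

print(bool(re.search(r"decoder.version", "decoderXversion")))   # True -- unintended match, "." matches any char
print(bool(re.search(r"decoder\.version", "decoderXversion")))  # False
print(bool(re.search(r"decoder\.version", "decoder.version")))  # True -- the literal key we mean to ignore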

tests/models/opt/test_modeling_opt.py

@@ -22,7 +22,7 @@ import unittest
import timeout_decorator # noqa
from transformers import OPTConfig, is_torch_available
-from transformers.testing_utils import require_torch, slow, torch_device
+from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
from ...generation.test_generation_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
@@ -428,3 +428,25 @@ class OPTGenerationTest(unittest.TestCase):
predicted_outputs += generated_string
self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS)
@require_torch_gpu
def test_batched_nan_fp16(self):
    # The bug manifested with batched generation on facebook/opt-1.3b and larger checkpoints,
    # so we do not use a tiny model here but the smallest model the problem was observed with, opt-1.3b.
    # Please refer to this GitHub thread for more details: https://github.com/huggingface/transformers/pull/17437
    model_name = "facebook/opt-1.3b"
    tokenizer = GPT2Tokenizer.from_pretrained(model_name, use_fast=False, padding_side="left")
    model = OPTForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, use_cache=True).cuda()
    model = model.eval()
    batch = tokenizer(["Who are you?", "Joe Biden is the president of"], padding=True, return_tensors="pt")

    input_ids = batch["input_ids"].cuda()
    attention_mask = batch["attention_mask"].cuda()

    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        self.assertFalse(
            torch.isnan(outputs.logits[0]).any().item()
        )  # the logits of the first sequence would contain NaNs if the bug were present
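For reference, the user-facing scenario this regression test guards against is batched fp16 generation with left padding. A sketch along the lines of the test above (requires a CUDA GPU; the checkpoint and prompts mirror the test, the rest is an illustrative usage example, not part of the commit):

import torch
from transformers import GPT2Tokenizer, OPTForCausalLM

tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-1.3b", use_fast=False, padding_side="left")
model = OPTForCausalLM.from_pretrained("facebook/opt-1.3b", torch_dtype=torch.float16).cuda().eval()

# Left padding makes the shorter prompt start with pad tokens, the case that used to yield NaN logits.
batch = tokenizer(["Who are you?", "Joe Biden is the president of"], padding=True, return_tensors="pt").to("cuda")
with torch.no_grad():
    generated = model.generate(**batch, max_new_tokens=20)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))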